diff --git a/lib/statespace_avx.h b/lib/statespace_avx.h index 5ff17a831..f21dd3f82 100644 --- a/lib/statespace_avx.h +++ b/lib/statespace_avx.h @@ -239,6 +239,37 @@ class StateSpaceAVX : public StateSpace, For, float> { state.get()[k + 8] = im; } + // Sets state[i] = val where (i & mask) == bits + void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, + const std::complex& val) const { + BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val)); + } + + // Sets state[i] = complex(re, im) where (i & mask) == bits + void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, fp_type re, + fp_type im) const { + __m256 re_reg = _mm256_set1_ps(re); + __m256 im_reg = _mm256_set1_ps(im); + + auto f = [](unsigned n, unsigned m, uint64_t i, uint64_t maskv, + uint64_t bitsv, __m256 re_n, __m256 im_n, fp_type* p) { + __m256 ml = + _mm256_castsi256_ps(detail::GetZeroMaskAVX(8 * i, maskv, bitsv)); + + __m256 re = _mm256_load_ps(p + 16 * i); + __m256 im = _mm256_load_ps(p + 16 * i + 8); + + re = _mm256_blendv_ps(re, re_n, ml); + im = _mm256_blendv_ps(im, im_n, ml); + + _mm256_store_ps(p + 16 * i, re); + _mm256_store_ps(p + 16 * i + 8, im); + }; + + Base::for_.Run(MinSize(state.num_qubits()) / 16, f, mask, bits, re_reg, + im_reg, state.get()); + } + // Does the equivalent of dest += src elementwise. bool Add(const State& src, State& dest) const { if (src.num_qubits() != dest.num_qubits()) { diff --git a/lib/statespace_basic.h b/lib/statespace_basic.h index c4c41bdd8..2cdab2c8f 100644 --- a/lib/statespace_basic.h +++ b/lib/statespace_basic.h @@ -96,6 +96,28 @@ class StateSpaceBasic : public StateSpace, For, FP> { state.get()[p + 1] = im; } + // Sets state[i] = val where (i & mask) == bits + void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, + const std::complex& val) const { + BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val)); + } + + // Sets state[i] = complex(re, im) where (i & mask) == bits + void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, fp_type re, + fp_type im) const { + auto f = [](unsigned n, unsigned m, uint64_t i, uint64_t maskv, + uint64_t bitsv, fp_type re_n, fp_type im_n, fp_type* p) { + auto s = p + 2 * i; + bool in_mask = (i & maskv) == bitsv; + + s[0] = in_mask ? re_n : s[0]; + s[1] = in_mask ? im_n : s[1]; + }; + + Base::for_.Run(MinSize(state.num_qubits()) / 2, f, mask, bits, re, im, + state.get()); + } + // Does the equivalent of dest += src elementwise. bool Add(const State& src, State& dest) const { if (src.num_qubits() != dest.num_qubits()) { diff --git a/lib/statespace_sse.h b/lib/statespace_sse.h index 9f95217ef..5c1ecd415 100644 --- a/lib/statespace_sse.h +++ b/lib/statespace_sse.h @@ -200,6 +200,36 @@ class StateSpaceSSE : public StateSpace, For, float> { state.get()[p + 4] = im; } + // Sets state[i] = val where (i & mask) == bits + void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, + const std::complex& val) const { + BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val)); + } + + // Sets state[i] = complex(re, im) where (i & mask) == bits + void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, fp_type re, + fp_type im) const { + __m128 re_reg = _mm_set1_ps(re); + __m128 im_reg = _mm_set1_ps(im); + + auto f = [](unsigned n, unsigned m, uint64_t i, uint64_t maskv, + uint64_t bitsv, __m128 re_n, __m128 im_n, fp_type* p) { + __m128 ml = _mm_castsi128_ps(detail::GetZeroMaskSSE(4 * i, maskv, bitsv)); + + __m128 re = _mm_load_ps(p + 8 * i); + __m128 im = _mm_load_ps(p + 8 * i + 4); + + re = _mm_blendv_ps(re, re_n, ml); + im = _mm_blendv_ps(im, im_n, ml); + + _mm_store_ps(p + 8 * i, re); + _mm_store_ps(p + 8 * i + 4, im); + }; + + Base::for_.Run(MinSize(state.num_qubits()) / 8, f, mask, bits, re_reg, + im_reg, state.get()); + } + // Does the equivalent of dest += src elementwise. bool Add(const State& src, State& dest) const { if (src.num_qubits() != dest.num_qubits()) { diff --git a/tests/statespace_avx_test.cc b/tests/statespace_avx_test.cc index 907760910..ca72a13b1 100644 --- a/tests/statespace_avx_test.cc +++ b/tests/statespace_avx_test.cc @@ -62,6 +62,10 @@ TEST(StateSpaceAVXTest, InvalidStateSize) { TestInvalidStateSize>(); } +TEST(StateSpaceBasicTest, BulkSetAmpl) { + TestBulkSetAmplitude>(); +} + } // namespace qsim int main(int argc, char** argv) { diff --git a/tests/statespace_basic_test.cc b/tests/statespace_basic_test.cc index 349108d68..6a789f744 100644 --- a/tests/statespace_basic_test.cc +++ b/tests/statespace_basic_test.cc @@ -62,6 +62,10 @@ TEST(StateSpaceBasicTest, InvalidStateSize) { TestInvalidStateSize>(); } +TEST(StateSpaceBasicTest, BulkSetAmpl) { + TestBulkSetAmplitude>(); +} + } // namespace qsim int main(int argc, char** argv) { diff --git a/tests/statespace_sse_test.cc b/tests/statespace_sse_test.cc index 24a8bc68c..45f61e598 100644 --- a/tests/statespace_sse_test.cc +++ b/tests/statespace_sse_test.cc @@ -62,6 +62,10 @@ TEST(StateSpaceSSETest, InvalidStateSize) { TestInvalidStateSize>(); } +TEST(StateSpaceBasicTest, BulkSetAmpl) { + TestBulkSetAmplitude>(); +} + } // namespace qsim int main(int argc, char** argv) { diff --git a/tests/statespace_testfixture.h b/tests/statespace_testfixture.h index c9f7b42f1..28ca0cfd3 100644 --- a/tests/statespace_testfixture.h +++ b/tests/statespace_testfixture.h @@ -809,6 +809,67 @@ void TestInvalidStateSize() { EXPECT_FALSE(!std::isnan(state_space.RealInnerProduct(state1, state2))); } +template +void TestBulkSetAmplitude() { + using State = typename StateSpace::State; + unsigned num_qubits = 3; + + StateSpace state_space(1); + + State state = state_space.Create(num_qubits); + for(int i = 0; i < 8; i++) { + state_space.SetAmpl(state, i, 1, 1); + } + state_space.BulkSetAmpl(state, 1, 0, 0, 0); + EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex(1, 1)); + + for(int i = 0; i < 8; i++) { + state_space.SetAmpl(state, i, 1, 1); + } + state_space.BulkSetAmpl(state, 2, 0, 0, 0); + EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex(1, 1)); + + for(int i = 0; i < 8; i++) { + state_space.SetAmpl(state, i, 1, 1); + } + state_space.BulkSetAmpl(state, 4, 0, 0, 0); + EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex(1, 1)); + + for(int i = 0; i < 8; i++) { + state_space.SetAmpl(state, i, 1, 1); + } + state_space.BulkSetAmpl(state, 4 | 1, 4, 0, 0); + EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex(1, 1)); +} + } // namespace qsim #endif // STATESPACE_TESTFIXTURE_H_