From d61f281347c058518a3cf07ee2466c970332604a Mon Sep 17 00:00:00 2001 From: Michael Broughton Date: Sat, 23 Jan 2021 22:43:58 -0800 Subject: [PATCH 1/3] Add BulkSetAmpl function to statespaces. --- lib/statespace_avx.h | 110 ++++++++++++++++++++++++++--------------- lib/statespace_basic.h | 77 ++++++++++++++++++----------- lib/statespace_sse.h | 84 +++++++++++++++++++++---------- 3 files changed, 176 insertions(+), 95 deletions(-) diff --git a/lib/statespace_avx.h b/lib/statespace_avx.h index 5ff17a831..5771b5d28 100644 --- a/lib/statespace_avx.h +++ b/lib/statespace_avx.h @@ -48,7 +48,7 @@ inline __m256i GetZeroMaskAVX(uint64_t i, uint64_t mask, uint64_t bits) { inline double HorizontalSumAVX(__m256 s) { __m128 l = _mm256_castps256_ps128(s); __m128 h = _mm256_extractf128_ps(s, 1); - __m128 s1 = _mm_add_ps(h, l); + __m128 s1 = _mm_add_ps(h, l); __m128 s1s = _mm_movehdup_ps(s1); __m128 s2 = _mm_add_ps(s1, s1s); @@ -194,25 +194,25 @@ class StateSpaceAVX : public StateSpace, For, float> { fp_type v = double{1} / std::sqrt(uint64_t{1} << state.num_qubits()); switch (state.num_qubits()) { - case 1: - valu = _mm256_set_ps(0, 0, 0, 0, 0, 0, v, v); - break; - case 2: - valu = _mm256_set_ps(0, 0, 0, 0, v, v, v, v); - break; - default: - valu = _mm256_set1_ps(v); - break; + case 1: + valu = _mm256_set_ps(0, 0, 0, 0, 0, 0, v, v); + break; + case 2: + valu = _mm256_set_ps(0, 0, 0, 0, v, v, v, v); + break; + default: + valu = _mm256_set1_ps(v); + break; } - auto f = [](unsigned n, unsigned m, uint64_t i, - __m256& val0, __m256 valu, fp_type* p) { + auto f = [](unsigned n, unsigned m, uint64_t i, __m256& val0, __m256 valu, + fp_type* p) { _mm256_store_ps(p + 16 * i, valu); _mm256_store_ps(p + 16 * i + 8, val0); }; - Base::for_.Run( - MinSize(state.num_qubits()) / 16, f, val0, valu, state.get()); + Base::for_.Run(MinSize(state.num_qubits()) / 16, f, val0, valu, + state.get()); } // |0> state. @@ -226,8 +226,8 @@ class StateSpaceAVX : public StateSpace, For, float> { return std::complex(state.get()[k], state.get()[k + 8]); } - static void SetAmpl( - State& state, uint64_t i, const std::complex& ampl) { + static void SetAmpl(State& state, uint64_t i, + const std::complex& ampl) { uint64_t k = (16 * (i / 8)) + (i % 8); state.get()[k] = std::real(ampl); state.get()[k + 8] = std::imag(ampl); @@ -239,14 +239,44 @@ class StateSpaceAVX : public StateSpace, For, float> { state.get()[k + 8] = im; } + // Sets state[i] = val where (i & mask) == bits + void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, + const std::complex& val) const { + BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val)); + } + + // Sets state[i] = complex(re, im) where (i & mask) == bits + void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, fp_type re, + fp_type im) const { + __m256 re_reg = _mm256_set1_ps(re); + __m256 im_reg = _mm256_set1_ps(im); + + auto f = [](unsigned n, unsigned m, uint64_t i, uint64_t maskv, + uint64_t bitsv, __m256 re_n, __m256 im_n, fp_type* p) { + __m256 ml = (__m256)detail::GetZeroMaskAVX(8 * i, maskv, bitsv); + + __m256 re = _mm256_load_ps(p + 16 * i); + __m256 im = _mm256_load_ps(p + 16 * i + 8); + + re = _mm256_blendv_ps(re, re_n, ml); + im = _mm256_blendv_ps(im, im_n, ml); + + _mm256_store_ps(p + 16 * i, re); + _mm256_store_ps(p + 16 * i + 8, im); + }; + + Base::for_.Run(MinSize(state.num_qubits()) / 16, f, mask, bits, re_reg, + im_reg, state.get()); + } + // Does the equivalent of dest += src elementwise. bool Add(const State& src, State& dest) const { if (src.num_qubits() != dest.num_qubits()) { return false; } - auto f = [](unsigned n, unsigned m, uint64_t i, - const fp_type* p1, fp_type* p2) { + auto f = [](unsigned n, unsigned m, uint64_t i, const fp_type* p1, + fp_type* p2) { __m256 re1 = _mm256_load_ps(p1 + 16 * i); __m256 im1 = _mm256_load_ps(p1 + 16 * i + 8); __m256 re2 = _mm256_load_ps(p2 + 16 * i); @@ -279,14 +309,14 @@ class StateSpaceAVX : public StateSpace, For, float> { Base::for_.Run(MinSize(state.num_qubits()) / 16, f, r, state.get()); } - std::complex InnerProduct( - const State& state1, const State& state2) const { + std::complex InnerProduct(const State& state1, + const State& state2) const { if (state1.num_qubits() != state2.num_qubits()) { return std::nan(""); } - auto f = [](unsigned n, unsigned m, uint64_t i, - const fp_type* p1, const fp_type* p2) -> std::complex { + auto f = [](unsigned n, unsigned m, uint64_t i, const fp_type* p1, + const fp_type* p2) -> std::complex { __m256 re1 = _mm256_load_ps(p1 + 16 * i); __m256 im1 = _mm256_load_ps(p1 + 16 * i + 8); __m256 re2 = _mm256_load_ps(p2 + 16 * i); @@ -302,8 +332,8 @@ class StateSpaceAVX : public StateSpace, For, float> { }; using Op = std::plus>; - return Base::for_.RunReduce(MinSize(state1.num_qubits()) / 16, f, - Op(), state1.get(), state2.get()); + return Base::for_.RunReduce(MinSize(state1.num_qubits()) / 16, f, Op(), + state1.get(), state2.get()); } double RealInnerProduct(const State& state1, const State& state2) const { @@ -311,8 +341,8 @@ class StateSpaceAVX : public StateSpace, For, float> { return std::nan(""); } - auto f = [](unsigned n, unsigned m, uint64_t i, - const fp_type* p1, const fp_type* p2) -> double { + auto f = [](unsigned n, unsigned m, uint64_t i, const fp_type* p1, + const fp_type* p2) -> double { __m256 re1 = _mm256_load_ps(p1 + 16 * i); __m256 im1 = _mm256_load_ps(p1 + 16 * i + 8); __m256 re2 = _mm256_load_ps(p2 + 16 * i); @@ -324,13 +354,13 @@ class StateSpaceAVX : public StateSpace, For, float> { }; using Op = std::plus; - return Base::for_.RunReduce(MinSize(state1.num_qubits()) / 16, f, - Op(), state1.get(), state2.get()); + return Base::for_.RunReduce(MinSize(state1.num_qubits()) / 16, f, Op(), + state1.get(), state2.get()); } template - std::vector Sample( - const State& state, uint64_t num_samples, unsigned seed) const { + std::vector Sample(const State& state, uint64_t num_samples, + unsigned seed) const { std::vector bitstrings; if (num_samples > 0) { @@ -371,8 +401,8 @@ class StateSpaceAVX : public StateSpace, For, float> { using MeasurementResult = typename Base::MeasurementResult; void Collapse(const MeasurementResult& mr, State& state) const { - auto f1 = [](unsigned n, unsigned m, uint64_t i, - uint64_t mask, uint64_t bits, const fp_type* p) -> double { + auto f1 = [](unsigned n, unsigned m, uint64_t i, uint64_t mask, + uint64_t bits, const fp_type* p) -> double { __m256i ml = detail::GetZeroMaskAVX(8 * i, mask, bits); __m256 re = _mm256_maskload_ps(p + 16 * i, ml); @@ -388,8 +418,8 @@ class StateSpaceAVX : public StateSpace, For, float> { __m256 renorm = _mm256_set1_ps(1.0 / std::sqrt(norm)); - auto f2 = [](unsigned n, unsigned m, uint64_t i, - uint64_t mask, uint64_t bits, __m256 renorm, fp_type* p) { + auto f2 = [](unsigned n, unsigned m, uint64_t i, uint64_t mask, + uint64_t bits, __m256 renorm, fp_type* p) { __m256i ml = detail::GetZeroMaskAVX(8 * i, mask, bits); __m256 re = _mm256_maskload_ps(p + 16 * i, ml); @@ -402,8 +432,8 @@ class StateSpaceAVX : public StateSpace, For, float> { _mm256_store_ps(p + 16 * i + 8, im); }; - Base::for_.Run(MinSize(state.num_qubits()) / 16, f2, - mr.mask, mr.bits, renorm, state.get()); + Base::for_.Run(MinSize(state.num_qubits()) / 16, f2, mr.mask, mr.bits, + renorm, state.get()); } std::vector PartialNorms(const State& state) const { @@ -417,12 +447,12 @@ class StateSpaceAVX : public StateSpace, For, float> { }; using Op = std::plus; - return Base::for_.RunReduceP( - MinSize(state.num_qubits()) / 16, f, Op(), state.get()); + return Base::for_.RunReduceP(MinSize(state.num_qubits()) / 16, f, Op(), + state.get()); } - uint64_t FindMeasuredBits( - unsigned m, double r, uint64_t mask, const State& state) const { + uint64_t FindMeasuredBits(unsigned m, double r, uint64_t mask, + const State& state) const { double csum = 0; uint64_t k0 = Base::for_.GetIndex0(MinSize(state.num_qubits()) / 16, m); diff --git a/lib/statespace_basic.h b/lib/statespace_basic.h index c4c41bdd8..347e04b47 100644 --- a/lib/statespace_basic.h +++ b/lib/statespace_basic.h @@ -63,8 +63,7 @@ class StateSpaceBasic : public StateSpace, For, FP> { void SetStateUniform(State& state) const { fp_type val = fp_type{1} / std::sqrt(uint64_t{1} << state.num_qubits()); - auto f = [](unsigned n, unsigned m, uint64_t i, - fp_type val, fp_type* p) { + auto f = [](unsigned n, unsigned m, uint64_t i, fp_type val, fp_type* p) { p[2 * i] = val; p[2 * i + 1] = 0; }; @@ -83,8 +82,8 @@ class StateSpaceBasic : public StateSpace, For, FP> { return std::complex(state.get()[p], state.get()[p + 1]); } - static void SetAmpl( - State& state, uint64_t i, const std::complex& ampl) { + static void SetAmpl(State& state, uint64_t i, + const std::complex& ampl) { uint64_t p = 2 * i; state.get()[p] = std::real(ampl); state.get()[p + 1] = std::imag(ampl); @@ -96,14 +95,36 @@ class StateSpaceBasic : public StateSpace, For, FP> { state.get()[p + 1] = im; } + // Sets state[i] = val where (i & mask) == bits + void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, + const std::complex& val) const { + BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val)); + } + + // Sets state[i] = complex(re, im) where (i & mask) == bits + void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, fp_type re, + fp_type im) const { + auto f = [](unsigned n, unsigned m, uint64_t i, uint64_t maskv, + uint64_t bitsv, fp_type re_n, fp_type im_n, fp_type* p) { + auto s = p + 2 * i; + bool in_mask = (i & maskv) == bitsv; + + s[0] = in_mask ? re_n : s[0]; + s[1] = in_mask ? im_n : s[1]; + }; + + Base::for_.Run(MinSize(state.num_qubits()) / 2, f, mask, bits, re, im, + state.get()); + } + // Does the equivalent of dest += src elementwise. bool Add(const State& src, State& dest) const { if (src.num_qubits() != dest.num_qubits()) { return false; } - auto f = [](unsigned n, unsigned m, uint64_t i, - const fp_type* p1, fp_type* p2) { + auto f = [](unsigned n, unsigned m, uint64_t i, const fp_type* p1, + fp_type* p2) { p2[2 * i] += p1[2 * i]; p2[2 * i + 1] += p1[2 * i + 1]; }; @@ -123,14 +144,14 @@ class StateSpaceBasic : public StateSpace, For, FP> { Base::for_.Run(MinSize(state.num_qubits()) / 2, f, a, state.get()); } - std::complex InnerProduct( - const State& state1, const State& state2) const { + std::complex InnerProduct(const State& state1, + const State& state2) const { if (state1.num_qubits() != state2.num_qubits()) { return std::nan(""); } - auto f = [](unsigned n, unsigned m, uint64_t i, - const fp_type* p1, const fp_type* p2) -> std::complex { + auto f = [](unsigned n, unsigned m, uint64_t i, const fp_type* p1, + const fp_type* p2) -> std::complex { auto s1 = p1 + 2 * i; auto s2 = p2 + 2 * i; @@ -141,8 +162,8 @@ class StateSpaceBasic : public StateSpace, For, FP> { }; using Op = std::plus>; - return Base::for_.RunReduce( - MinSize(state1.num_qubits()) / 2, f, Op(), state1.get(), state2.get()); + return Base::for_.RunReduce(MinSize(state1.num_qubits()) / 2, f, Op(), + state1.get(), state2.get()); } double RealInnerProduct(const State& state1, const State& state2) const { @@ -150,8 +171,8 @@ class StateSpaceBasic : public StateSpace, For, FP> { return std::nan(""); } - auto f = [](unsigned n, unsigned m, uint64_t i, - const fp_type* p1, const fp_type* p2) -> double { + auto f = [](unsigned n, unsigned m, uint64_t i, const fp_type* p1, + const fp_type* p2) -> double { auto s1 = p1 + 2 * i; auto s2 = p2 + 2 * i; @@ -159,13 +180,13 @@ class StateSpaceBasic : public StateSpace, For, FP> { }; using Op = std::plus; - return Base::for_.RunReduce( - MinSize(state1.num_qubits()) / 2, f, Op(), state1.get(), state2.get()); + return Base::for_.RunReduce(MinSize(state1.num_qubits()) / 2, f, Op(), + state1.get(), state2.get()); } template - std::vector Sample( - const State& state, uint64_t num_samples, unsigned seed) const { + std::vector Sample(const State& state, uint64_t num_samples, + unsigned seed) const { std::vector bitstrings; if (num_samples > 0) { @@ -203,8 +224,8 @@ class StateSpaceBasic : public StateSpace, For, FP> { using MeasurementResult = typename Base::MeasurementResult; void Collapse(const MeasurementResult& mr, State& state) const { - auto f1 = [](unsigned n, unsigned m, uint64_t i, - uint64_t mask, uint64_t bits, const fp_type* p) -> double { + auto f1 = [](unsigned n, unsigned m, uint64_t i, uint64_t mask, + uint64_t bits, const fp_type* p) -> double { auto s = p + 2 * i; return (i & mask) == bits ? s[0] * s[0] + s[1] * s[1] : 0; }; @@ -215,8 +236,8 @@ class StateSpaceBasic : public StateSpace, For, FP> { double renorm = 1.0 / std::sqrt(norm); - auto f2 = [](unsigned n, unsigned m, uint64_t i, - uint64_t mask, uint64_t bits, fp_type renorm, fp_type* p) { + auto f2 = [](unsigned n, unsigned m, uint64_t i, uint64_t mask, + uint64_t bits, fp_type renorm, fp_type* p) { auto s = p + 2 * i; bool not_zero = (i & mask) == bits; @@ -224,8 +245,8 @@ class StateSpaceBasic : public StateSpace, For, FP> { s[1] = not_zero ? s[1] * renorm : 0; }; - Base::for_.Run(MinSize(state.num_qubits()) / 2, f2, - mr.mask, mr.bits, renorm, state.get()); + Base::for_.Run(MinSize(state.num_qubits()) / 2, f2, mr.mask, mr.bits, + renorm, state.get()); } std::vector PartialNorms(const State& state) const { @@ -236,12 +257,12 @@ class StateSpaceBasic : public StateSpace, For, FP> { }; using Op = std::plus; - return Base::for_.RunReduceP( - MinSize(state.num_qubits()) / 2, f, Op(), state.get()); + return Base::for_.RunReduceP(MinSize(state.num_qubits()) / 2, f, Op(), + state.get()); } - uint64_t FindMeasuredBits( - unsigned m, double r, uint64_t mask, const State& state) const { + uint64_t FindMeasuredBits(unsigned m, double r, uint64_t mask, + const State& state) const { double csum = 0; uint64_t k0 = Base::for_.GetIndex0(MinSize(state.num_qubits()) / 2, m); diff --git a/lib/statespace_sse.h b/lib/statespace_sse.h index 9f95217ef..baf74dd66 100644 --- a/lib/statespace_sse.h +++ b/lib/statespace_sse.h @@ -167,8 +167,8 @@ class StateSpaceSSE : public StateSpace, For, float> { valu = _mm_set1_ps(v); } - auto f = [](unsigned n, unsigned m, uint64_t i, - __m128 val0, __m128 valu, fp_type* p) { + auto f = [](unsigned n, unsigned m, uint64_t i, __m128 val0, __m128 valu, + fp_type* p) { _mm_store_ps(p + 8 * i, valu); _mm_store_ps(p + 8 * i + 4, val0); }; @@ -187,8 +187,8 @@ class StateSpaceSSE : public StateSpace, For, float> { return std::complex(state.get()[p], state.get()[p + 4]); } - static void SetAmpl( - State& state, uint64_t i, const std::complex& ampl) { + static void SetAmpl(State& state, uint64_t i, + const std::complex& ampl) { uint64_t p = (8 * (i / 4)) + (i % 4); state.get()[p] = std::real(ampl); state.get()[p + 4] = std::imag(ampl); @@ -200,14 +200,44 @@ class StateSpaceSSE : public StateSpace, For, float> { state.get()[p + 4] = im; } + // Sets state[i] = val where (i & mask) == bits + void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, + const std::complex& val) const { + BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val)); + } + + // Sets state[i] = complex(re, im) where (i & mask) == bits + void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, fp_type re, + fp_type im) const { + __m128 re_reg = _mm_set1_ps(re); + __m128 im_reg = _mm_set1_ps(im); + + auto f = [](unsigned n, unsigned m, uint64_t i, uint64_t maskv, + uint64_t bitsv, __m128 re_n, __m128 im_n, fp_type* p) { + __m128 ml = _mm_castsi128_ps(detail::GetZeroMaskSSE(4 * i, maskv, bitsv)); + + __m128 re = _mm_load_ps(p + 8 * i); + __m128 im = _mm_load_ps(p + 8 * i + 4); + + re = _mm_blendv_ps(re, re_n, ml); + im = _mm_blendv_ps(im, im_n, ml); + + _mm_store_ps(p + 8 * i, re); + _mm_store_ps(p + 8 * i + 4, im); + }; + + Base::for_.Run(MinSize(state.num_qubits()) / 8, f, mask, bits, re_reg, + im_reg, state.get()); + } + // Does the equivalent of dest += src elementwise. bool Add(const State& src, State& dest) const { if (src.num_qubits() != dest.num_qubits()) { return false; } - auto f = [](unsigned n, unsigned m, uint64_t i, - const fp_type* p1, fp_type* p2) { + auto f = [](unsigned n, unsigned m, uint64_t i, const fp_type* p1, + fp_type* p2) { __m128 re1 = _mm_load_ps(p1 + 8 * i); __m128 im1 = _mm_load_ps(p1 + 8 * i + 4); __m128 re2 = _mm_load_ps(p2 + 8 * i); @@ -240,14 +270,14 @@ class StateSpaceSSE : public StateSpace, For, float> { Base::for_.Run(MinSize(state.num_qubits()) / 8, f, r, state.get()); } - std::complex InnerProduct( - const State& state1, const State& state2) const { + std::complex InnerProduct(const State& state1, + const State& state2) const { if (state1.num_qubits() != state2.num_qubits()) { return std::nan(""); } - auto f = [](unsigned n, unsigned m, uint64_t i, - const fp_type* p1, const fp_type* p2) -> std::complex { + auto f = [](unsigned n, unsigned m, uint64_t i, const fp_type* p1, + const fp_type* p2) -> std::complex { __m128 re1 = _mm_load_ps(p1 + 8 * i); __m128 im1 = _mm_load_ps(p1 + 8 * i + 4); __m128 re2 = _mm_load_ps(p2 + 8 * i); @@ -263,8 +293,8 @@ class StateSpaceSSE : public StateSpace, For, float> { }; using Op = std::plus>; - return Base::for_.RunReduce( - MinSize(state1.num_qubits()) / 8, f, Op(), state1.get(), state2.get()); + return Base::for_.RunReduce(MinSize(state1.num_qubits()) / 8, f, Op(), + state1.get(), state2.get()); } double RealInnerProduct(const State& state1, const State& state2) const { @@ -272,8 +302,8 @@ class StateSpaceSSE : public StateSpace, For, float> { return std::nan(""); } - auto f = [](unsigned n, unsigned m, uint64_t i, - const fp_type* p1, const fp_type* p2) -> double { + auto f = [](unsigned n, unsigned m, uint64_t i, const fp_type* p1, + const fp_type* p2) -> double { __m128 re1 = _mm_load_ps(p1 + 8 * i); __m128 im1 = _mm_load_ps(p1 + 8 * i + 4); __m128 re2 = _mm_load_ps(p2 + 8 * i); @@ -285,13 +315,13 @@ class StateSpaceSSE : public StateSpace, For, float> { }; using Op = std::plus; - return Base::for_.RunReduce( - MinSize(state1.num_qubits()) / 8, f, Op(), state1.get(), state2.get()); + return Base::for_.RunReduce(MinSize(state1.num_qubits()) / 8, f, Op(), + state1.get(), state2.get()); } template - std::vector Sample( - const State& state, uint64_t num_samples, unsigned seed) const { + std::vector Sample(const State& state, uint64_t num_samples, + unsigned seed) const { std::vector bitstrings; if (num_samples > 0) { @@ -348,9 +378,9 @@ class StateSpaceSSE : public StateSpace, For, float> { }; using Op = std::plus; - double norm = Base::for_.RunReduce(MinSize(state.num_qubits()) / 8, f1, - Op(), mr.mask, mr.bits, zero, - state.get()); + double norm = + Base::for_.RunReduce(MinSize(state.num_qubits()) / 8, f1, Op(), mr.mask, + mr.bits, zero, state.get()); __m128 renorm = _mm_set1_ps(1.0 / std::sqrt(norm)); @@ -368,8 +398,8 @@ class StateSpaceSSE : public StateSpace, For, float> { _mm_store_ps(p + 8 * i + 4, im); }; - Base::for_.Run(MinSize(state.num_qubits()) / 8, f2, - mr.mask, mr.bits, renorm, zero, state.get()); + Base::for_.Run(MinSize(state.num_qubits()) / 8, f2, mr.mask, mr.bits, + renorm, zero, state.get()); } std::vector PartialNorms(const State& state) const { @@ -383,12 +413,12 @@ class StateSpaceSSE : public StateSpace, For, float> { }; using Op = std::plus; - return Base::for_.RunReduceP( - MinSize(state.num_qubits()) / 8, f, Op(), state.get()); + return Base::for_.RunReduceP(MinSize(state.num_qubits()) / 8, f, Op(), + state.get()); } - uint64_t FindMeasuredBits( - unsigned m, double r, uint64_t mask, const State& state) const { + uint64_t FindMeasuredBits(unsigned m, double r, uint64_t mask, + const State& state) const { double csum = 0; uint64_t k0 = Base::for_.GetIndex0(MinSize(state.num_qubits()) / 8, m); From cea59749ef1820355d47c0100aad384f070d906d Mon Sep 17 00:00:00 2001 From: Michael Broughton Date: Sat, 23 Jan 2021 22:56:28 -0800 Subject: [PATCH 2/3] added tests. --- lib/statespace_avx.h | 80 +++++++++++++++++----------------- lib/statespace_basic.h | 55 +++++++++++------------ lib/statespace_sse.h | 54 +++++++++++------------ tests/statespace_avx_test.cc | 4 ++ tests/statespace_basic_test.cc | 4 ++ tests/statespace_sse_test.cc | 4 ++ tests/statespace_testfixture.h | 61 ++++++++++++++++++++++++++ 7 files changed, 168 insertions(+), 94 deletions(-) diff --git a/lib/statespace_avx.h b/lib/statespace_avx.h index 5771b5d28..0c02b2e85 100644 --- a/lib/statespace_avx.h +++ b/lib/statespace_avx.h @@ -48,7 +48,7 @@ inline __m256i GetZeroMaskAVX(uint64_t i, uint64_t mask, uint64_t bits) { inline double HorizontalSumAVX(__m256 s) { __m128 l = _mm256_castps256_ps128(s); __m128 h = _mm256_extractf128_ps(s, 1); - __m128 s1 = _mm_add_ps(h, l); + __m128 s1 = _mm_add_ps(h, l); __m128 s1s = _mm_movehdup_ps(s1); __m128 s2 = _mm_add_ps(s1, s1s); @@ -194,25 +194,25 @@ class StateSpaceAVX : public StateSpace, For, float> { fp_type v = double{1} / std::sqrt(uint64_t{1} << state.num_qubits()); switch (state.num_qubits()) { - case 1: - valu = _mm256_set_ps(0, 0, 0, 0, 0, 0, v, v); - break; - case 2: - valu = _mm256_set_ps(0, 0, 0, 0, v, v, v, v); - break; - default: - valu = _mm256_set1_ps(v); - break; + case 1: + valu = _mm256_set_ps(0, 0, 0, 0, 0, 0, v, v); + break; + case 2: + valu = _mm256_set_ps(0, 0, 0, 0, v, v, v, v); + break; + default: + valu = _mm256_set1_ps(v); + break; } - auto f = [](unsigned n, unsigned m, uint64_t i, __m256& val0, __m256 valu, - fp_type* p) { + auto f = [](unsigned n, unsigned m, uint64_t i, + __m256& val0, __m256 valu, fp_type* p) { _mm256_store_ps(p + 16 * i, valu); _mm256_store_ps(p + 16 * i + 8, val0); }; - Base::for_.Run(MinSize(state.num_qubits()) / 16, f, val0, valu, - state.get()); + Base::for_.Run( + MinSize(state.num_qubits()) / 16, f, val0, valu, state.get()); } // |0> state. @@ -226,8 +226,8 @@ class StateSpaceAVX : public StateSpace, For, float> { return std::complex(state.get()[k], state.get()[k + 8]); } - static void SetAmpl(State& state, uint64_t i, - const std::complex& ampl) { + static void SetAmpl( + State& state, uint64_t i, const std::complex& ampl) { uint64_t k = (16 * (i / 8)) + (i % 8); state.get()[k] = std::real(ampl); state.get()[k + 8] = std::imag(ampl); @@ -275,8 +275,8 @@ class StateSpaceAVX : public StateSpace, For, float> { return false; } - auto f = [](unsigned n, unsigned m, uint64_t i, const fp_type* p1, - fp_type* p2) { + auto f = [](unsigned n, unsigned m, uint64_t i, + const fp_type* p1, fp_type* p2) { __m256 re1 = _mm256_load_ps(p1 + 16 * i); __m256 im1 = _mm256_load_ps(p1 + 16 * i + 8); __m256 re2 = _mm256_load_ps(p2 + 16 * i); @@ -309,14 +309,14 @@ class StateSpaceAVX : public StateSpace, For, float> { Base::for_.Run(MinSize(state.num_qubits()) / 16, f, r, state.get()); } - std::complex InnerProduct(const State& state1, - const State& state2) const { + std::complex InnerProduct( + const State& state1, const State& state2) const { if (state1.num_qubits() != state2.num_qubits()) { return std::nan(""); } - auto f = [](unsigned n, unsigned m, uint64_t i, const fp_type* p1, - const fp_type* p2) -> std::complex { + auto f = [](unsigned n, unsigned m, uint64_t i, + const fp_type* p1, const fp_type* p2) -> std::complex { __m256 re1 = _mm256_load_ps(p1 + 16 * i); __m256 im1 = _mm256_load_ps(p1 + 16 * i + 8); __m256 re2 = _mm256_load_ps(p2 + 16 * i); @@ -332,8 +332,8 @@ class StateSpaceAVX : public StateSpace, For, float> { }; using Op = std::plus>; - return Base::for_.RunReduce(MinSize(state1.num_qubits()) / 16, f, Op(), - state1.get(), state2.get()); + return Base::for_.RunReduce(MinSize(state1.num_qubits()) / 16, f, + Op(), state1.get(), state2.get()); } double RealInnerProduct(const State& state1, const State& state2) const { @@ -341,8 +341,8 @@ class StateSpaceAVX : public StateSpace, For, float> { return std::nan(""); } - auto f = [](unsigned n, unsigned m, uint64_t i, const fp_type* p1, - const fp_type* p2) -> double { + auto f = [](unsigned n, unsigned m, uint64_t i, + const fp_type* p1, const fp_type* p2) -> double { __m256 re1 = _mm256_load_ps(p1 + 16 * i); __m256 im1 = _mm256_load_ps(p1 + 16 * i + 8); __m256 re2 = _mm256_load_ps(p2 + 16 * i); @@ -354,13 +354,13 @@ class StateSpaceAVX : public StateSpace, For, float> { }; using Op = std::plus; - return Base::for_.RunReduce(MinSize(state1.num_qubits()) / 16, f, Op(), - state1.get(), state2.get()); + return Base::for_.RunReduce(MinSize(state1.num_qubits()) / 16, f, + Op(), state1.get(), state2.get()); } template - std::vector Sample(const State& state, uint64_t num_samples, - unsigned seed) const { + std::vector Sample( + const State& state, uint64_t num_samples, unsigned seed) const { std::vector bitstrings; if (num_samples > 0) { @@ -401,8 +401,8 @@ class StateSpaceAVX : public StateSpace, For, float> { using MeasurementResult = typename Base::MeasurementResult; void Collapse(const MeasurementResult& mr, State& state) const { - auto f1 = [](unsigned n, unsigned m, uint64_t i, uint64_t mask, - uint64_t bits, const fp_type* p) -> double { + auto f1 = [](unsigned n, unsigned m, uint64_t i, + uint64_t mask, uint64_t bits, const fp_type* p) -> double { __m256i ml = detail::GetZeroMaskAVX(8 * i, mask, bits); __m256 re = _mm256_maskload_ps(p + 16 * i, ml); @@ -418,8 +418,8 @@ class StateSpaceAVX : public StateSpace, For, float> { __m256 renorm = _mm256_set1_ps(1.0 / std::sqrt(norm)); - auto f2 = [](unsigned n, unsigned m, uint64_t i, uint64_t mask, - uint64_t bits, __m256 renorm, fp_type* p) { + auto f2 = [](unsigned n, unsigned m, uint64_t i, + uint64_t mask, uint64_t bits, __m256 renorm, fp_type* p) { __m256i ml = detail::GetZeroMaskAVX(8 * i, mask, bits); __m256 re = _mm256_maskload_ps(p + 16 * i, ml); @@ -432,8 +432,8 @@ class StateSpaceAVX : public StateSpace, For, float> { _mm256_store_ps(p + 16 * i + 8, im); }; - Base::for_.Run(MinSize(state.num_qubits()) / 16, f2, mr.mask, mr.bits, - renorm, state.get()); + Base::for_.Run(MinSize(state.num_qubits()) / 16, f2, + mr.mask, mr.bits, renorm, state.get()); } std::vector PartialNorms(const State& state) const { @@ -447,12 +447,12 @@ class StateSpaceAVX : public StateSpace, For, float> { }; using Op = std::plus; - return Base::for_.RunReduceP(MinSize(state.num_qubits()) / 16, f, Op(), - state.get()); + return Base::for_.RunReduceP( + MinSize(state.num_qubits()) / 16, f, Op(), state.get()); } - uint64_t FindMeasuredBits(unsigned m, double r, uint64_t mask, - const State& state) const { + uint64_t FindMeasuredBits( + unsigned m, double r, uint64_t mask, const State& state) const { double csum = 0; uint64_t k0 = Base::for_.GetIndex0(MinSize(state.num_qubits()) / 16, m); diff --git a/lib/statespace_basic.h b/lib/statespace_basic.h index 347e04b47..2cdab2c8f 100644 --- a/lib/statespace_basic.h +++ b/lib/statespace_basic.h @@ -63,7 +63,8 @@ class StateSpaceBasic : public StateSpace, For, FP> { void SetStateUniform(State& state) const { fp_type val = fp_type{1} / std::sqrt(uint64_t{1} << state.num_qubits()); - auto f = [](unsigned n, unsigned m, uint64_t i, fp_type val, fp_type* p) { + auto f = [](unsigned n, unsigned m, uint64_t i, + fp_type val, fp_type* p) { p[2 * i] = val; p[2 * i + 1] = 0; }; @@ -82,8 +83,8 @@ class StateSpaceBasic : public StateSpace, For, FP> { return std::complex(state.get()[p], state.get()[p + 1]); } - static void SetAmpl(State& state, uint64_t i, - const std::complex& ampl) { + static void SetAmpl( + State& state, uint64_t i, const std::complex& ampl) { uint64_t p = 2 * i; state.get()[p] = std::real(ampl); state.get()[p + 1] = std::imag(ampl); @@ -123,8 +124,8 @@ class StateSpaceBasic : public StateSpace, For, FP> { return false; } - auto f = [](unsigned n, unsigned m, uint64_t i, const fp_type* p1, - fp_type* p2) { + auto f = [](unsigned n, unsigned m, uint64_t i, + const fp_type* p1, fp_type* p2) { p2[2 * i] += p1[2 * i]; p2[2 * i + 1] += p1[2 * i + 1]; }; @@ -144,14 +145,14 @@ class StateSpaceBasic : public StateSpace, For, FP> { Base::for_.Run(MinSize(state.num_qubits()) / 2, f, a, state.get()); } - std::complex InnerProduct(const State& state1, - const State& state2) const { + std::complex InnerProduct( + const State& state1, const State& state2) const { if (state1.num_qubits() != state2.num_qubits()) { return std::nan(""); } - auto f = [](unsigned n, unsigned m, uint64_t i, const fp_type* p1, - const fp_type* p2) -> std::complex { + auto f = [](unsigned n, unsigned m, uint64_t i, + const fp_type* p1, const fp_type* p2) -> std::complex { auto s1 = p1 + 2 * i; auto s2 = p2 + 2 * i; @@ -162,8 +163,8 @@ class StateSpaceBasic : public StateSpace, For, FP> { }; using Op = std::plus>; - return Base::for_.RunReduce(MinSize(state1.num_qubits()) / 2, f, Op(), - state1.get(), state2.get()); + return Base::for_.RunReduce( + MinSize(state1.num_qubits()) / 2, f, Op(), state1.get(), state2.get()); } double RealInnerProduct(const State& state1, const State& state2) const { @@ -171,8 +172,8 @@ class StateSpaceBasic : public StateSpace, For, FP> { return std::nan(""); } - auto f = [](unsigned n, unsigned m, uint64_t i, const fp_type* p1, - const fp_type* p2) -> double { + auto f = [](unsigned n, unsigned m, uint64_t i, + const fp_type* p1, const fp_type* p2) -> double { auto s1 = p1 + 2 * i; auto s2 = p2 + 2 * i; @@ -180,13 +181,13 @@ class StateSpaceBasic : public StateSpace, For, FP> { }; using Op = std::plus; - return Base::for_.RunReduce(MinSize(state1.num_qubits()) / 2, f, Op(), - state1.get(), state2.get()); + return Base::for_.RunReduce( + MinSize(state1.num_qubits()) / 2, f, Op(), state1.get(), state2.get()); } template - std::vector Sample(const State& state, uint64_t num_samples, - unsigned seed) const { + std::vector Sample( + const State& state, uint64_t num_samples, unsigned seed) const { std::vector bitstrings; if (num_samples > 0) { @@ -224,8 +225,8 @@ class StateSpaceBasic : public StateSpace, For, FP> { using MeasurementResult = typename Base::MeasurementResult; void Collapse(const MeasurementResult& mr, State& state) const { - auto f1 = [](unsigned n, unsigned m, uint64_t i, uint64_t mask, - uint64_t bits, const fp_type* p) -> double { + auto f1 = [](unsigned n, unsigned m, uint64_t i, + uint64_t mask, uint64_t bits, const fp_type* p) -> double { auto s = p + 2 * i; return (i & mask) == bits ? s[0] * s[0] + s[1] * s[1] : 0; }; @@ -236,8 +237,8 @@ class StateSpaceBasic : public StateSpace, For, FP> { double renorm = 1.0 / std::sqrt(norm); - auto f2 = [](unsigned n, unsigned m, uint64_t i, uint64_t mask, - uint64_t bits, fp_type renorm, fp_type* p) { + auto f2 = [](unsigned n, unsigned m, uint64_t i, + uint64_t mask, uint64_t bits, fp_type renorm, fp_type* p) { auto s = p + 2 * i; bool not_zero = (i & mask) == bits; @@ -245,8 +246,8 @@ class StateSpaceBasic : public StateSpace, For, FP> { s[1] = not_zero ? s[1] * renorm : 0; }; - Base::for_.Run(MinSize(state.num_qubits()) / 2, f2, mr.mask, mr.bits, - renorm, state.get()); + Base::for_.Run(MinSize(state.num_qubits()) / 2, f2, + mr.mask, mr.bits, renorm, state.get()); } std::vector PartialNorms(const State& state) const { @@ -257,12 +258,12 @@ class StateSpaceBasic : public StateSpace, For, FP> { }; using Op = std::plus; - return Base::for_.RunReduceP(MinSize(state.num_qubits()) / 2, f, Op(), - state.get()); + return Base::for_.RunReduceP( + MinSize(state.num_qubits()) / 2, f, Op(), state.get()); } - uint64_t FindMeasuredBits(unsigned m, double r, uint64_t mask, - const State& state) const { + uint64_t FindMeasuredBits( + unsigned m, double r, uint64_t mask, const State& state) const { double csum = 0; uint64_t k0 = Base::for_.GetIndex0(MinSize(state.num_qubits()) / 2, m); diff --git a/lib/statespace_sse.h b/lib/statespace_sse.h index baf74dd66..5c1ecd415 100644 --- a/lib/statespace_sse.h +++ b/lib/statespace_sse.h @@ -167,8 +167,8 @@ class StateSpaceSSE : public StateSpace, For, float> { valu = _mm_set1_ps(v); } - auto f = [](unsigned n, unsigned m, uint64_t i, __m128 val0, __m128 valu, - fp_type* p) { + auto f = [](unsigned n, unsigned m, uint64_t i, + __m128 val0, __m128 valu, fp_type* p) { _mm_store_ps(p + 8 * i, valu); _mm_store_ps(p + 8 * i + 4, val0); }; @@ -187,8 +187,8 @@ class StateSpaceSSE : public StateSpace, For, float> { return std::complex(state.get()[p], state.get()[p + 4]); } - static void SetAmpl(State& state, uint64_t i, - const std::complex& ampl) { + static void SetAmpl( + State& state, uint64_t i, const std::complex& ampl) { uint64_t p = (8 * (i / 4)) + (i % 4); state.get()[p] = std::real(ampl); state.get()[p + 4] = std::imag(ampl); @@ -236,8 +236,8 @@ class StateSpaceSSE : public StateSpace, For, float> { return false; } - auto f = [](unsigned n, unsigned m, uint64_t i, const fp_type* p1, - fp_type* p2) { + auto f = [](unsigned n, unsigned m, uint64_t i, + const fp_type* p1, fp_type* p2) { __m128 re1 = _mm_load_ps(p1 + 8 * i); __m128 im1 = _mm_load_ps(p1 + 8 * i + 4); __m128 re2 = _mm_load_ps(p2 + 8 * i); @@ -270,14 +270,14 @@ class StateSpaceSSE : public StateSpace, For, float> { Base::for_.Run(MinSize(state.num_qubits()) / 8, f, r, state.get()); } - std::complex InnerProduct(const State& state1, - const State& state2) const { + std::complex InnerProduct( + const State& state1, const State& state2) const { if (state1.num_qubits() != state2.num_qubits()) { return std::nan(""); } - auto f = [](unsigned n, unsigned m, uint64_t i, const fp_type* p1, - const fp_type* p2) -> std::complex { + auto f = [](unsigned n, unsigned m, uint64_t i, + const fp_type* p1, const fp_type* p2) -> std::complex { __m128 re1 = _mm_load_ps(p1 + 8 * i); __m128 im1 = _mm_load_ps(p1 + 8 * i + 4); __m128 re2 = _mm_load_ps(p2 + 8 * i); @@ -293,8 +293,8 @@ class StateSpaceSSE : public StateSpace, For, float> { }; using Op = std::plus>; - return Base::for_.RunReduce(MinSize(state1.num_qubits()) / 8, f, Op(), - state1.get(), state2.get()); + return Base::for_.RunReduce( + MinSize(state1.num_qubits()) / 8, f, Op(), state1.get(), state2.get()); } double RealInnerProduct(const State& state1, const State& state2) const { @@ -302,8 +302,8 @@ class StateSpaceSSE : public StateSpace, For, float> { return std::nan(""); } - auto f = [](unsigned n, unsigned m, uint64_t i, const fp_type* p1, - const fp_type* p2) -> double { + auto f = [](unsigned n, unsigned m, uint64_t i, + const fp_type* p1, const fp_type* p2) -> double { __m128 re1 = _mm_load_ps(p1 + 8 * i); __m128 im1 = _mm_load_ps(p1 + 8 * i + 4); __m128 re2 = _mm_load_ps(p2 + 8 * i); @@ -315,13 +315,13 @@ class StateSpaceSSE : public StateSpace, For, float> { }; using Op = std::plus; - return Base::for_.RunReduce(MinSize(state1.num_qubits()) / 8, f, Op(), - state1.get(), state2.get()); + return Base::for_.RunReduce( + MinSize(state1.num_qubits()) / 8, f, Op(), state1.get(), state2.get()); } template - std::vector Sample(const State& state, uint64_t num_samples, - unsigned seed) const { + std::vector Sample( + const State& state, uint64_t num_samples, unsigned seed) const { std::vector bitstrings; if (num_samples > 0) { @@ -378,9 +378,9 @@ class StateSpaceSSE : public StateSpace, For, float> { }; using Op = std::plus; - double norm = - Base::for_.RunReduce(MinSize(state.num_qubits()) / 8, f1, Op(), mr.mask, - mr.bits, zero, state.get()); + double norm = Base::for_.RunReduce(MinSize(state.num_qubits()) / 8, f1, + Op(), mr.mask, mr.bits, zero, + state.get()); __m128 renorm = _mm_set1_ps(1.0 / std::sqrt(norm)); @@ -398,8 +398,8 @@ class StateSpaceSSE : public StateSpace, For, float> { _mm_store_ps(p + 8 * i + 4, im); }; - Base::for_.Run(MinSize(state.num_qubits()) / 8, f2, mr.mask, mr.bits, - renorm, zero, state.get()); + Base::for_.Run(MinSize(state.num_qubits()) / 8, f2, + mr.mask, mr.bits, renorm, zero, state.get()); } std::vector PartialNorms(const State& state) const { @@ -413,12 +413,12 @@ class StateSpaceSSE : public StateSpace, For, float> { }; using Op = std::plus; - return Base::for_.RunReduceP(MinSize(state.num_qubits()) / 8, f, Op(), - state.get()); + return Base::for_.RunReduceP( + MinSize(state.num_qubits()) / 8, f, Op(), state.get()); } - uint64_t FindMeasuredBits(unsigned m, double r, uint64_t mask, - const State& state) const { + uint64_t FindMeasuredBits( + unsigned m, double r, uint64_t mask, const State& state) const { double csum = 0; uint64_t k0 = Base::for_.GetIndex0(MinSize(state.num_qubits()) / 8, m); diff --git a/tests/statespace_avx_test.cc b/tests/statespace_avx_test.cc index 907760910..ca72a13b1 100644 --- a/tests/statespace_avx_test.cc +++ b/tests/statespace_avx_test.cc @@ -62,6 +62,10 @@ TEST(StateSpaceAVXTest, InvalidStateSize) { TestInvalidStateSize>(); } +TEST(StateSpaceBasicTest, BulkSetAmpl) { + TestBulkSetAmplitude>(); +} + } // namespace qsim int main(int argc, char** argv) { diff --git a/tests/statespace_basic_test.cc b/tests/statespace_basic_test.cc index 349108d68..6a789f744 100644 --- a/tests/statespace_basic_test.cc +++ b/tests/statespace_basic_test.cc @@ -62,6 +62,10 @@ TEST(StateSpaceBasicTest, InvalidStateSize) { TestInvalidStateSize>(); } +TEST(StateSpaceBasicTest, BulkSetAmpl) { + TestBulkSetAmplitude>(); +} + } // namespace qsim int main(int argc, char** argv) { diff --git a/tests/statespace_sse_test.cc b/tests/statespace_sse_test.cc index 24a8bc68c..45f61e598 100644 --- a/tests/statespace_sse_test.cc +++ b/tests/statespace_sse_test.cc @@ -62,6 +62,10 @@ TEST(StateSpaceSSETest, InvalidStateSize) { TestInvalidStateSize>(); } +TEST(StateSpaceBasicTest, BulkSetAmpl) { + TestBulkSetAmplitude>(); +} + } // namespace qsim int main(int argc, char** argv) { diff --git a/tests/statespace_testfixture.h b/tests/statespace_testfixture.h index c9f7b42f1..28ca0cfd3 100644 --- a/tests/statespace_testfixture.h +++ b/tests/statespace_testfixture.h @@ -809,6 +809,67 @@ void TestInvalidStateSize() { EXPECT_FALSE(!std::isnan(state_space.RealInnerProduct(state1, state2))); } +template +void TestBulkSetAmplitude() { + using State = typename StateSpace::State; + unsigned num_qubits = 3; + + StateSpace state_space(1); + + State state = state_space.Create(num_qubits); + for(int i = 0; i < 8; i++) { + state_space.SetAmpl(state, i, 1, 1); + } + state_space.BulkSetAmpl(state, 1, 0, 0, 0); + EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex(1, 1)); + + for(int i = 0; i < 8; i++) { + state_space.SetAmpl(state, i, 1, 1); + } + state_space.BulkSetAmpl(state, 2, 0, 0, 0); + EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex(1, 1)); + + for(int i = 0; i < 8; i++) { + state_space.SetAmpl(state, i, 1, 1); + } + state_space.BulkSetAmpl(state, 4, 0, 0, 0); + EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex(1, 1)); + + for(int i = 0; i < 8; i++) { + state_space.SetAmpl(state, i, 1, 1); + } + state_space.BulkSetAmpl(state, 4 | 1, 4, 0, 0); + EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex(1, 1)); +} + } // namespace qsim #endif // STATESPACE_TESTFIXTURE_H_ From 0985bcf8b54403724d7e2a6f36aa407ad2c3ea32 Mon Sep 17 00:00:00 2001 From: Michael Broughton Date: Sun, 24 Jan 2021 19:04:39 -0800 Subject: [PATCH 3/3] use builtin mm256i cast for win compat. --- lib/statespace_avx.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/statespace_avx.h b/lib/statespace_avx.h index 0c02b2e85..f21dd3f82 100644 --- a/lib/statespace_avx.h +++ b/lib/statespace_avx.h @@ -253,7 +253,8 @@ class StateSpaceAVX : public StateSpace, For, float> { auto f = [](unsigned n, unsigned m, uint64_t i, uint64_t maskv, uint64_t bitsv, __m256 re_n, __m256 im_n, fp_type* p) { - __m256 ml = (__m256)detail::GetZeroMaskAVX(8 * i, maskv, bitsv); + __m256 ml = + _mm256_castsi256_ps(detail::GetZeroMaskAVX(8 * i, maskv, bitsv)); __m256 re = _mm256_load_ps(p + 16 * i); __m256 im = _mm256_load_ps(p + 16 * i + 8);