diff --git a/include/xsimd/arch/xsimd_avx2.hpp b/include/xsimd/arch/xsimd_avx2.hpp index 2c44df461..435f4ea50 100644 --- a/include/xsimd/arch/xsimd_avx2.hpp +++ b/include/xsimd/arch/xsimd_avx2.hpp @@ -179,6 +179,15 @@ namespace xsimd { constexpr auto bits = std::numeric_limits::digits + std::numeric_limits::is_signed; static_assert(shift < bits, "Shift must be less than the number of bits in T"); + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + // 8-bit left shift via 16-bit shift + mask + __m256i shifted = _mm256_slli_epi16(self, shift); + // TODO(C++17): without `if constexpr` we must ensure the compile-time shift does not overflow + constexpr uint8_t mask8 = static_cast(sizeof(T) == 1 ? (~0u << shift) : 0); + const __m256i mask = _mm256_set1_epi8(mask8); + return _mm256_and_si256(shifted, mask); + } XSIMD_IF_CONSTEXPR(sizeof(T) == 2) { return _mm256_slli_epi16(self, shift); @@ -191,10 +200,6 @@ namespace xsimd { return _mm256_slli_epi64(self, shift); } - else - { - return bitwise_lshift(self, avx {}); - } } template ::value>::type> @@ -312,10 +317,12 @@ namespace xsimd { XSIMD_IF_CONSTEXPR(sizeof(T) == 1) { - const __m256i byte_mask = _mm256_set1_epi16(0x00FF); - __m256i u16 = _mm256_and_si256(self, byte_mask); - __m256i r16 = _mm256_srli_epi16(u16, shift); - return _mm256_and_si256(r16, byte_mask); + // 8-bit logical right shift via 16-bit shift, then mask off the top `shift` bits of each byte (they may hold bits shifted in from the adjacent byte) + const __m256i shifted = _mm256_srli_epi16(self, shift); + // TODO(C++17): without `if constexpr` we must ensure the compile-time shift does not overflow + constexpr uint8_t mask8 = static_cast(sizeof(T) == 1 ? 
(0xFFu >> shift) : 0); + const __m256i mask = _mm256_set1_epi8(mask8); + return _mm256_and_si256(shifted, mask); } XSIMD_IF_CONSTEXPR(sizeof(T) == 2) { @@ -329,10 +336,6 @@ namespace xsimd { return _mm256_srli_epi64(self, shift); } - else - { - return bitwise_rshift(self, avx {}); - } } } diff --git a/include/xsimd/arch/xsimd_sse2.hpp b/include/xsimd/arch/xsimd_sse2.hpp index deb1af542..f77017f82 100644 --- a/include/xsimd/arch/xsimd_sse2.hpp +++ b/include/xsimd/arch/xsimd_sse2.hpp @@ -305,7 +305,9 @@ namespace xsimd { // 8-bit left shift via 16-bit shift + mask __m128i shifted = _mm_slli_epi16(self, static_cast(shift)); - __m128i mask = _mm_set1_epi8(static_cast(0xFF << shift)); + // TODO(C++17): without `if constexpr` we must ensure the compile-time shift does not overflow + constexpr uint8_t mask8 = static_cast(sizeof(T) == 1 ? (~0u << shift) : 0); + const __m128i mask = _mm_set1_epi8(mask8); return _mm_and_si128(shifted, mask); } else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) @@ -488,10 +490,12 @@ namespace xsimd { XSIMD_IF_CONSTEXPR(sizeof(T) == 1) { - // Emulate byte-wise logical right shift using 16-bit shifts + per-byte mask. - __m128i s16 = _mm_srli_epi16(self, static_cast(shift)); - __m128i mask = _mm_set1_epi8(static_cast(0xFFu >> shift)); - return _mm_and_si128(s16, mask); + // 8-bit logical right shift via 16-bit shift, then mask off the top `shift` bits of each byte (they may hold bits shifted in from the adjacent byte) + __m128i shifted = _mm_srli_epi16(self, static_cast(shift)); + // TODO(C++17): without `if constexpr` we must ensure the compile-time shift does not overflow + constexpr uint8_t mask8 = static_cast(sizeof(T) == 1 ? 
(0xFFu >> shift) : 0); + const __m128i mask = _mm_set1_epi8(mask8); + return _mm_and_si128(shifted, mask); } else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) { diff --git a/test/test_xsimd_api.cpp b/test/test_xsimd_api.cpp index a61c0e6ad..2e62c292c 100644 --- a/test/test_xsimd_api.cpp +++ b/test/test_xsimd_api.cpp @@ -358,7 +358,7 @@ struct xsimd_api_integral_types_functions value_type val1(shift); value_type r = val0 << val1; value_type ir = val0 << shift; - value_type cr = xsimd::bitwise_lshift(val0); + T cr = xsimd::bitwise_lshift(T(val0)); CHECK_EQ(extract(xsimd::bitwise_lshift(T(val0), T(val1))), r); CHECK_EQ(extract(ir), r); CHECK_EQ(extract(cr), r); @@ -371,7 +371,7 @@ struct xsimd_api_integral_types_functions value_type val1(shift); value_type r = val0 >> val1; value_type ir = val0 >> shift; - value_type cr = xsimd::bitwise_rshift(val0); + T cr = xsimd::bitwise_rshift(T(val0)); CHECK_EQ(extract(xsimd::bitwise_rshift(T(val0), T(val1))), r); CHECK_EQ(extract(ir), r); CHECK_EQ(extract(cr), r); @@ -391,7 +391,7 @@ struct xsimd_api_integral_types_functions value_type val0(12); value_type val1(count); value_type r = (val0 << val1) | (val0 >> (N - val1)); - value_type cr = xsimd::rotl(val0); + T cr = xsimd::rotl(T(val0)); CHECK_EQ(extract(xsimd::rotl(T(val0), T(val1))), r); CHECK_EQ(extract(cr), r); } @@ -403,7 +403,7 @@ struct xsimd_api_integral_types_functions value_type val0(12); value_type val1(count); value_type r = (val0 >> val1) | (val0 << (N - val1)); - value_type cr = xsimd::rotr<3>(val0); + T cr = xsimd::rotr<3>(T(val0)); CHECK_EQ(extract(xsimd::rotr(T(val0), T(val1))), r); CHECK_EQ(extract(cr), r); }