Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 15 additions & 12 deletions include/xsimd/arch/xsimd_avx2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,15 @@ namespace xsimd
{
constexpr auto bits = std::numeric_limits<T>::digits + std::numeric_limits<T>::is_signed;
static_assert(shift < bits, "Shift must be less than the number of bits in T");
XSIMD_IF_CONSTEXPR(sizeof(T) == 1)
{
// 8-bit left shift via 16-bit shift + mask
__m256i shifted = _mm256_slli_epi16(self, shift);
// TODO(C++17): without `if constexpr` we must ensure the compile-time shift does not overflow
constexpr uint8_t mask8 = static_cast<uint8_t>(sizeof(T) == 1 ? (~0u << shift) : 0);
const __m256i mask = _mm256_set1_epi8(mask8);
return _mm256_and_si256(shifted, mask);
}
XSIMD_IF_CONSTEXPR(sizeof(T) == 2)
{
return _mm256_slli_epi16(self, shift);
Expand All @@ -191,10 +200,6 @@ namespace xsimd
{
return _mm256_slli_epi64(self, shift);
}
else
{
return bitwise_lshift<shift>(self, avx {});
}
}

template <class A, class T, class = typename std::enable_if<std::is_integral<T>::value>::type>
Expand Down Expand Up @@ -312,10 +317,12 @@ namespace xsimd
{
XSIMD_IF_CONSTEXPR(sizeof(T) == 1)
{
const __m256i byte_mask = _mm256_set1_epi16(0x00FF);
__m256i u16 = _mm256_and_si256(self, byte_mask);
__m256i r16 = _mm256_srli_epi16(u16, shift);
return _mm256_and_si256(r16, byte_mask);
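// 8-bit logical right shift via 16-bit shift + per-byte mask
// NOTE(review): the mask must keep the low (8 - shift) bits of each byte, i.e. 0xFF >> shift — verify mask8 below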
const __m256i shifted = _mm256_srli_epi16(self, shift);
// TODO(C++17): without `if constexpr` we must ensure the compile-time shift does not overflow
constexpr uint8_t mask8 = static_cast<uint8_t>(sizeof(T) == 1 ? ((1u << shift) - 1u) : 0);
const __m256i mask = _mm256_set1_epi8(mask8);
return _mm256_and_si256(shifted, mask);
}
XSIMD_IF_CONSTEXPR(sizeof(T) == 2)
{
Expand All @@ -329,10 +336,6 @@ namespace xsimd
{
return _mm256_srli_epi64(self, shift);
}
else
{
return bitwise_rshift<shift>(self, avx {});
}
}
}

Expand Down
14 changes: 9 additions & 5 deletions include/xsimd/arch/xsimd_sse2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,9 @@ namespace xsimd
{
// 8-bit left shift via 16-bit shift + mask
__m128i shifted = _mm_slli_epi16(self, static_cast<int>(shift));
__m128i mask = _mm_set1_epi8(static_cast<char>(0xFF << shift));
// TODO(C++17): without `if constexpr` we must ensure the compile-time shift does not overflow
constexpr uint8_t mask8 = static_cast<uint8_t>(sizeof(T) == 1 ? (~0u << shift) : 0);
const __m128i mask = _mm_set1_epi8(mask8);
return _mm_and_si128(shifted, mask);
}
else XSIMD_IF_CONSTEXPR(sizeof(T) == 2)
Expand Down Expand Up @@ -488,10 +490,12 @@ namespace xsimd
{
XSIMD_IF_CONSTEXPR(sizeof(T) == 1)
{
// Emulate byte-wise logical right shift using 16-bit shifts + per-byte mask.
__m128i s16 = _mm_srli_epi16(self, static_cast<int>(shift));
__m128i mask = _mm_set1_epi8(static_cast<char>(0xFFu >> shift));
return _mm_and_si128(s16, mask);
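// 8-bit logical right shift via 16-bit shift + per-byte mask
// NOTE(review): the previous implementation masked with 0xFFu >> shift (low 8 - shift bits) — verify mask8 below matches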
__m128i shifted = _mm_srli_epi16(self, static_cast<int>(shift));
// TODO(C++17): without `if constexpr` we must ensure the compile-time shift does not overflow
constexpr uint8_t mask8 = static_cast<uint8_t>(sizeof(T) == 1 ? ((1u << shift) - 1u) : 0);
const __m128i mask = _mm_set1_epi8(mask8);
return _mm_and_si128(shifted, mask);
}
else XSIMD_IF_CONSTEXPR(sizeof(T) == 2)
{
Expand Down
8 changes: 4 additions & 4 deletions test/test_xsimd_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ struct xsimd_api_integral_types_functions
value_type val1(shift);
value_type r = val0 << val1;
value_type ir = val0 << shift;
value_type cr = xsimd::bitwise_lshift<shift>(val0);
T cr = xsimd::bitwise_lshift<shift>(T(val0));
CHECK_EQ(extract(xsimd::bitwise_lshift(T(val0), T(val1))), r);
CHECK_EQ(extract(ir), r);
CHECK_EQ(extract(cr), r);
Expand All @@ -371,7 +371,7 @@ struct xsimd_api_integral_types_functions
value_type val1(shift);
value_type r = val0 >> val1;
value_type ir = val0 >> shift;
value_type cr = xsimd::bitwise_rshift<shift>(val0);
T cr = xsimd::bitwise_rshift<shift>(T(val0));
CHECK_EQ(extract(xsimd::bitwise_rshift(T(val0), T(val1))), r);
CHECK_EQ(extract(ir), r);
CHECK_EQ(extract(cr), r);
Expand All @@ -391,7 +391,7 @@ struct xsimd_api_integral_types_functions
value_type val0(12);
value_type val1(count);
value_type r = (val0 << val1) | (val0 >> (N - val1));
value_type cr = xsimd::rotl<count>(val0);
T cr = xsimd::rotl<count>(T(val0));
CHECK_EQ(extract(xsimd::rotl(T(val0), T(val1))), r);
CHECK_EQ(extract(cr), r);
}
Expand All @@ -403,7 +403,7 @@ struct xsimd_api_integral_types_functions
value_type val0(12);
value_type val1(count);
value_type r = (val0 >> val1) | (val0 << (N - val1));
value_type cr = xsimd::rotr<3>(val0);
T cr = xsimd::rotr<3>(T(val0));
CHECK_EQ(extract(xsimd::rotr(T(val0), T(val1))), r);
CHECK_EQ(extract(cr), r);
}
Expand Down