From 88e88a180cbb546cb77af64326d3b7924eb3610c Mon Sep 17 00:00:00 2001
From: AntoinePrv
Date: Tue, 11 Nov 2025 15:40:17 -0800
Subject: [PATCH] AVX2 constant swizzle 8/16 bits and optimizations

---
 .../arch/common/xsimd_common_swizzle.hpp |  93 ++++----
 include/xsimd/arch/xsimd_avx2.hpp        | 201 +++++++++++++++---
 include/xsimd/arch/xsimd_common_fwd.hpp  |   6 +-
 test/test_batch_manip.cpp                |  18 +-
 4 files changed, 233 insertions(+), 85 deletions(-)

diff --git a/include/xsimd/arch/common/xsimd_common_swizzle.hpp b/include/xsimd/arch/common/xsimd_common_swizzle.hpp
index 7f40dd542..535f1bf74 100644
--- a/include/xsimd/arch/common/xsimd_common_swizzle.hpp
+++ b/include/xsimd/arch/common/xsimd_common_swizzle.hpp
@@ -16,6 +16,8 @@
 #include <cstddef>
 #include <cstdint>
 
+#include "../../config/xsimd_inline.hpp"
+
 namespace xsimd
 {
     template <class T, class A, T... Vs>
@@ -39,7 +41,7 @@ namespace xsimd
         };
 
         // ────────────────────────────────────────────────────────────────────────
-        // 1) identity_impl
+        // identity_impl
         template <std::size_t I, class T>
         XSIMD_INLINE constexpr bool identity_impl() noexcept { return true; }
         template <std::size_t I, class T, T V0, T... Vs>
         XSIMD_INLINE constexpr bool identity_impl() noexcept
         {
             return (V0 == static_cast<T>(I)) && identity_impl<I + 1, T, Vs...>();
         }
 
         // ────────────────────────────────────────────────────────────────────────
-        // 2) bitmask_impl
-        template <std::size_t I, std::size_t N, class T>
-        XSIMD_INLINE constexpr std::uint32_t bitmask_impl() noexcept { return 0u; }
-        template <std::size_t I, std::size_t N, class T, T V0, T... Vs>
-        XSIMD_INLINE constexpr std::uint32_t bitmask_impl() noexcept
-        {
-            return (1u << (static_cast<std::uint32_t>(V0) & (N - 1)))
-                | bitmask_impl<I + 1, N, T, Vs...>();
-        }
-
-        // ────────────────────────────────────────────────────────────────────────
-        // 3) dup_lo_impl
+        // dup_lo_impl
         template <std::size_t I, std::size_t N, class T, typename std::enable_if<I == N / 2, int>::type = 0>
         XSIMD_INLINE constexpr bool dup_lo_impl() noexcept { return true; }
@@ -76,7 +67,7 @@
         }
 
         // ────────────────────────────────────────────────────────────────────────
-        // 4) dup_hi_impl
+        // dup_hi_impl
         template <std::size_t I, std::size_t N, class T, typename std::enable_if<I == N / 2, int>::type = 0>
         XSIMD_INLINE constexpr bool dup_hi_impl() noexcept { return true; }
@@ -91,6 +82,52 @@
                 && dup_hi_impl<I + 1, N, T, Vs...>();
         }
 
+        // ────────────────────────────────────────────────────────────────────────
+        // only_from_lo
+        template <std::size_t Size, std::size_t... Idx>
+        struct only_from_lo_impl;
+
+        template <std::size_t Size, std::size_t Last>
+        struct only_from_lo_impl<Size, Last>
+        {
+            static constexpr bool value = (Last < (Size / 2));
+        };
+
+        template <std::size_t Size, std::size_t First, std::size_t... Idx>
+        struct only_from_lo_impl<Size, First, Idx...>
+        {
+            static constexpr bool value = (First < (Size / 2)) && only_from_lo_impl<Size, Idx...>::value;
+        };
+
+        template <class T, T... Vs>
+        constexpr bool is_only_from_lo()
+        {
+            return only_from_lo_impl<sizeof...(Vs), static_cast<std::size_t>(Vs)...>::value;
+        }
+
+        // ────────────────────────────────────────────────────────────────────────
+        // only_from_hi
+        template <std::size_t Size, std::size_t... Idx>
+        struct only_from_hi_impl;
+
+        template <std::size_t Size, std::size_t Last>
+        struct only_from_hi_impl<Size, Last>
+        {
+            static constexpr bool value = (Last >= (Size / 2));
+        };
+
+        template <std::size_t Size, std::size_t First, std::size_t... Idx>
+        struct only_from_hi_impl<Size, First, Idx...>
+        {
+            static constexpr bool value = (First >= (Size / 2)) && only_from_hi_impl<Size, Idx...>::value;
+        };
+
+        template <class T, T... Vs>
+        constexpr bool is_only_from_hi()
+        {
+            return only_from_hi_impl<sizeof...(Vs), static_cast<std::size_t>(Vs)...>::value;
+        }
+
         // ────────────────────────────────────────────────────────────────────────
         // 1) helper to get the I-th value from the Vs pack
         template <std::size_t I, class T, T... Vs>
@@ -123,33 +160,15 @@
         {
             static constexpr bool value = false;
         };
-        template <std::size_t I, std::size_t N, class T, T... Vs>
-        XSIMD_INLINE constexpr bool no_duplicates_impl() noexcept
-        {
-            // build the bitmask of (Vs & (N-1)) across all lanes
-            return detail::bitmask_impl<0, N, T, Vs...>() == ((1u << N) - 1u);
-        }
-        template <uint32_t... Vs>
-        XSIMD_INLINE constexpr bool no_duplicates_v() noexcept
-        {
-            // forward to no_duplicates_impl
-            return no_duplicates_impl<0, sizeof...(Vs), uint32_t, Vs...>();
-        }
         template <class T, T... Vs>
         XSIMD_INLINE constexpr bool is_cross_lane() noexcept
         {
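
Note: the classifiers in this header are plain recursive folds over the index pack, so a mask is fully categorized at compile time and the dispatch branches in the AVX2 kernels further down cost nothing at runtime. A minimal standalone sketch of the only_from_lo idea (illustrative names, independent of the xsimd sources):

    #include <cstddef>

    // True when every index selects from the low half [0, Size/2).
    template <std::size_t Size>
    constexpr bool only_from_lo() { return true; }

    template <std::size_t Size, std::size_t First, std::size_t... Rest>
    constexpr bool only_from_lo()
    {
        return (First < Size / 2) && only_from_lo<Size, Rest...>();
    }

    static_assert(only_from_lo<8, 0, 1, 2, 3, 0, 1, 2, 3>(), "low-half mask");
    static_assert(!only_from_lo<8, 4, 5, 6, 7, 4, 5, 6, 7>(), "high-half mask");

    int main() {}
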
            static_assert(sizeof...(Vs) >= 1, "Need at least one lane");
             return cross_impl<0, sizeof...(Vs), sizeof...(Vs) / 2, Vs...>::value;
         }
+
         template <class T, T... Vs>
         XSIMD_INLINE constexpr bool is_identity() noexcept { return detail::identity_impl<0, T, Vs...>(); }
 
-        template <class T, T... Vs>
-        XSIMD_INLINE constexpr bool is_all_different() noexcept
-        {
-            return detail::bitmask_impl<0, sizeof...(Vs), T, Vs...>() == ((1u << sizeof...(Vs)) - 1);
-        }
-
         template <class T, T... Vs>
         XSIMD_INLINE constexpr bool is_dup_lo() noexcept { return detail::dup_lo_impl<0, sizeof...(Vs), T, Vs...>(); }
         template <class T, T... Vs>
         XSIMD_INLINE constexpr bool is_dup_hi() noexcept { return detail::dup_hi_impl<0, sizeof...(Vs), T, Vs...>(); }
         template <class T, class A, T... Vs>
         XSIMD_INLINE constexpr bool is_identity(batch_constant<T, A, Vs...>) noexcept { return is_identity<T, Vs...>(); }
         template <class T, class A, T... Vs>
-        XSIMD_INLINE constexpr bool is_all_different(batch_constant<T, A, Vs...>) noexcept { return is_all_different<T, Vs...>(); }
-        template <class T, class A, T... Vs>
         XSIMD_INLINE constexpr bool is_dup_lo(batch_constant<T, A, Vs...>) noexcept { return is_dup_lo<T, Vs...>(); }
         template <class T, class A, T... Vs>
         XSIMD_INLINE constexpr bool is_dup_hi(batch_constant<T, A, Vs...>) noexcept { return is_dup_hi<T, Vs...>(); }
         template <class T, class A, T... Vs>
-        XSIMD_INLINE constexpr bool is_cross_lane(batch_constant<T, A, Vs...>) noexcept { return detail::is_cross_lane<T, Vs...>(); }
+        XSIMD_INLINE constexpr bool is_only_from_lo(batch_constant<T, A, Vs...>) noexcept { return detail::is_only_from_lo<T, Vs...>(); }
         template <class T, class A, T... Vs>
-        XSIMD_INLINE constexpr bool no_duplicates(batch_constant<T, A, Vs...>) noexcept { return no_duplicates_impl<0, sizeof...(Vs), T, Vs...>(); }
+        XSIMD_INLINE constexpr bool is_only_from_hi(batch_constant<T, A, Vs...>) noexcept { return detail::is_only_from_hi<T, Vs...>(); }
+        template <class T, class A, T... Vs>
+        XSIMD_INLINE constexpr bool is_cross_lane(batch_constant<T, A, Vs...>) noexcept { return detail::is_cross_lane<T, Vs...>(); }
     } // namespace detail
 } // namespace kernel
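
Note on the self/cross split defined above: every mask entry either stays inside its own 128-bit lane or crosses into the other one. Setting the non-matching entries to 0x80 makes _mm256_shuffle_epi8 write zero at those positions, so the two partial shuffles can be merged with a plain OR. A scalar model of the mask construction (illustrative, using a byte-reversal mask as input):

    #include <cstdint>
    #include <cstdio>

    int main()
    {
        const int size = 32; // bytes in an AVX register
        std::uint8_t mask[32], self_mask[32], cross_mask[32];
        for (int i = 0; i < size; ++i)
            mask[i] = static_cast<std::uint8_t>(size - 1 - i); // cross-lane reversal
        for (int i = 0; i < size; ++i)
        {
            // An entry crosses when source and destination sit in different halves.
            bool cross = (i < size / 2) != (mask[i] < size / 2);
            self_mask[i] = cross ? 0x80 : mask[i] % (size / 2);
            cross_mask[i] = cross ? mask[i] % (size / 2) : 0x80;
        }
        // Each position is now covered by exactly one of the two in-lane
        // shuffles (of `self` or of the lane-swapped input); OR merges them.
        std::printf("%02x %02x\n", (unsigned)self_mask[0], (unsigned)cross_mask[0]); // 80 0f
    }
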
diff --git a/include/xsimd/arch/xsimd_avx2.hpp b/include/xsimd/arch/xsimd_avx2.hpp
index 6e5e4342e..2c44df461 100644
--- a/include/xsimd/arch/xsimd_avx2.hpp
+++ b/include/xsimd/arch/xsimd_avx2.hpp
@@ -1141,67 +1141,198 @@ namespace xsimd
             return bitwise_cast(swizzle(bitwise_cast(self), mask, req));
         }
 
+        namespace detail
+        {
+            template <class T>
+            constexpr T swizzle_val_none()
+            {
+                // Most significant bit of the byte must be 1
+                return 0x80;
+            }
+
+            template <class T>
+            constexpr bool swizzle_val_is_cross_lane(T val, T idx, T size)
+            {
+                return (idx < (size / 2)) != (val < (size / 2));
+            }
+
+            template <class T>
+            constexpr bool swizzle_val_is_defined(T val, T size)
+            {
+                return (0 <= val) && (val < size);
+            }
+
+            template <class T>
+            constexpr T swizzle_self_val(T val, T idx, T size)
+            {
+                return (swizzle_val_is_defined(val, size) && !swizzle_val_is_cross_lane(val, idx, size))
+                    ? val % (size / 2)
+                    : swizzle_val_none<T>();
+            }
+
+            template <class T, class A, T... Vals, std::size_t... Is>
+            constexpr auto swizzle_make_self_batch_impl(::xsimd::detail::index_sequence<Is...>)
+                -> batch_constant<T, A, swizzle_self_val(Vals, static_cast<T>(Is), static_cast<T>(sizeof...(Vals)))...>
+            {
+                return {};
+            }
+
+            template <class T, class A, T... Vals>
+            constexpr auto swizzle_make_self_batch()
+                -> decltype(swizzle_make_self_batch_impl<T, A, Vals...>(::xsimd::detail::make_index_sequence<sizeof...(Vals)>()))
+            {
+                return {};
+            }
+
+            template <class T>
+            constexpr T swizzle_cross_val(T val, T idx, T size)
+            {
+                return (swizzle_val_is_defined(val, size) && swizzle_val_is_cross_lane(val, idx, size))
+                    ? val % (size / 2)
+                    : swizzle_val_none<T>();
+            }
+
+            template <class T, class A, T... Vals, std::size_t... Is>
+            constexpr auto swizzle_make_cross_batch_impl(::xsimd::detail::index_sequence<Is...>)
+                -> batch_constant<T, A, swizzle_cross_val(Vals, static_cast<T>(Is), static_cast<T>(sizeof...(Vals)))...>
+            {
+                return {};
+            }
+
+            template <class T, class A, T... Vals>
+            constexpr auto swizzle_make_cross_batch()
+                -> decltype(swizzle_make_cross_batch_impl<T, A, Vals...>(::xsimd::detail::make_index_sequence<sizeof...(Vals)>()))
+            {
+                return {};
+            }
+        }
+
         // swizzle (constant mask)
+        template <class A, uint8_t... Vals>
+        XSIMD_INLINE batch<uint8_t, A> swizzle(batch<uint8_t, A> const& self, batch_constant<uint8_t, A, Vals...> mask, requires_arch<avx2>) noexcept
+        {
+            static_assert(sizeof...(Vals) == 32, "Must contain as many uint8_t as can fit in avx register");
+
+            XSIMD_IF_CONSTEXPR(detail::is_identity(mask))
+            {
+                return self;
+            }
+
+            constexpr auto lane_mask = mask % make_batch_constant<...>();
+
+            XSIMD_IF_CONSTEXPR(!detail::is_cross_lane(mask))
+            {
+                return _mm256_shuffle_epi8(self, lane_mask.as_batch());
+            }
+            XSIMD_IF_CONSTEXPR(detail::is_only_from_lo(mask))
+            {
+                __m256i broadcast = _mm256_permute2x128_si256(self, self, 0x00); // [low | low]
+                return _mm256_shuffle_epi8(broadcast, lane_mask.as_batch());
+            }
+            XSIMD_IF_CONSTEXPR(detail::is_only_from_hi(mask))
+            {
+                __m256i broadcast = _mm256_permute2x128_si256(self, self, 0x11); // [high | high]
+                return _mm256_shuffle_epi8(broadcast, lane_mask.as_batch());
+            }
+
+            // swap lanes
+            __m256i swapped = _mm256_permute2x128_si256(self, self, 0x01); // [high | low]
+
+            // We can outsmart the dynamic version by creating a compile-time mask that leaves zeros
+            // where it does not need to select data, resulting in a simple OR merge of the two batches.
+            constexpr auto self_mask = detail::swizzle_make_self_batch<uint8_t, A, Vals...>();
+            constexpr auto cross_mask = detail::swizzle_make_cross_batch<uint8_t, A, Vals...>();
+
+            // permute bytes within each lane (AVX2 only)
+            __m256i r0 = _mm256_shuffle_epi8(self, self_mask.as_batch());
+            __m256i r1 = _mm256_shuffle_epi8(swapped, cross_mask.as_batch());
+
+            return _mm256_or_si256(r0, r1);
+        }
+
         template <class A, class T, uint8_t... Vals, detail::enable_sized_t<T, 1> = 0>
-        XSIMD_INLINE batch<T, A> swizzle(batch<T, A> const& self, batch_constant<uint8_t, A, Vals...> mask, requires_arch<avx2> req) noexcept
+        XSIMD_INLINE batch<T, A> swizzle(batch<T, A> const& self, batch_constant<uint8_t, A, Vals...> const& mask, requires_arch<avx2> req) noexcept
         {
             static_assert(sizeof...(Vals) == 32, "Must contain as many uint8_t as can fit in avx register");
-            return swizzle(self, mask.as_batch(), req);
+            return bitwise_cast<T>(swizzle(bitwise_cast<uint8_t>(self), mask, req));
         }
 
+        template <
+            class A,
+            uint16_t V0, uint16_t V1, uint16_t V2, uint16_t V3,
+            uint16_t V4, uint16_t V5, uint16_t V6, uint16_t V7,
+            uint16_t V8, uint16_t V9, uint16_t V10, uint16_t V11,
+            uint16_t V12, uint16_t V13, uint16_t V14, uint16_t V15>
+        XSIMD_INLINE batch<uint16_t, A> swizzle(
+            batch<uint16_t, A> const& self,
+            batch_constant<uint16_t, A, V0, V1, V2, V3, V4, V5, V6, V7, V8, V9, V10, V11, V12, V13, V14, V15>,
+            requires_arch<avx2> req) noexcept
+        {
+            const auto self_bytes = bitwise_cast<uint8_t>(self);
+            // If a mask entry is k, we want 2k in low byte and 2k+1 in high byte
+            auto constexpr mask_2k_2kp1 = batch_constant<
+                uint8_t, A,
+                2 * V0, 2 * V0 + 1, 2 * V1, 2 * V1 + 1, 2 * V2, 2 * V2 + 1, 2 * V3, 2 * V3 + 1,
+                2 * V4, 2 * V4 + 1, 2 * V5, 2 * V5 + 1, 2 * V6, 2 * V6 + 1, 2 * V7, 2 * V7 + 1,
+                2 * V8, 2 * V8 + 1, 2 * V9, 2 * V9 + 1, 2 * V10, 2 * V10 + 1, 2 * V11, 2 * V11 + 1,
+                2 * V12, 2 * V12 + 1, 2 * V13, 2 * V13 + 1, 2 * V14, 2 * V14 + 1, 2 * V15, 2 * V15 + 1> {};
+            return bitwise_cast<uint16_t>(swizzle(self_bytes, mask_2k_2kp1, req));
+        }
 
         template <class A, class T, uint16_t... Vals, detail::enable_sized_t<T, 2> = 0>
-        XSIMD_INLINE batch<T, A> swizzle(batch<T, A> const& self, batch_constant<uint16_t, A, Vals...> mask, requires_arch<avx2> req) noexcept
+        XSIMD_INLINE batch<T, A> swizzle(batch<T, A> const& self, batch_constant<uint16_t, A, Vals...> const& mask, requires_arch<avx2> req) noexcept
         {
             static_assert(sizeof...(Vals) == 16, "Must contain as many uint16_t as can fit in avx register");
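
Note: the 16-bit overload above never needs a 16-bit shuffle unit; it rewrites word index k into the byte pair (2k, 2k+1) and reuses the 8-bit path. A scalar sketch of the expansion (illustrative 4-word example, little-endian byte order):

    #include <cstdint>
    #include <cstdio>

    int main()
    {
        const std::uint16_t word_mask[4] = { 3, 2, 1, 0 }; // reverse 4 words
        std::uint8_t byte_mask[8];
        for (int k = 0; k < 4; ++k)
        {
            byte_mask[2 * k] = static_cast<std::uint8_t>(2 * word_mask[k]);         // low byte
            byte_mask[2 * k + 1] = static_cast<std::uint8_t>(2 * word_mask[k] + 1); // high byte
        }
        for (int i = 0; i < 8; ++i)
            std::printf("%d ", byte_mask[i]); // 6 7 4 5 2 3 0 1
        std::printf("\n");
    }
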
-            return swizzle(self, mask.as_batch(), req);
+            return bitwise_cast<T>(swizzle(bitwise_cast<uint16_t>(self), mask, req));
         }
 
         template <class A, uint32_t V0, uint32_t V1, uint32_t V2, uint32_t V3, uint32_t V4, uint32_t V5, uint32_t V6, uint32_t V7>
-        XSIMD_INLINE batch<float, A> swizzle(batch<float, A> const& self, batch_constant<uint32_t, A, V0, V1, V2, V3, V4, V5, V6, V7> mask, requires_arch<avx2>) noexcept
+        XSIMD_INLINE batch<uint32_t, A> swizzle(batch<uint32_t, A> const& self, batch_constant<uint32_t, A, V0, V1, V2, V3, V4, V5, V6, V7> mask, requires_arch<avx2>) noexcept
         {
-            XSIMD_IF_CONSTEXPR(detail::is_all_different(mask) && !detail::is_identity(mask))
+            XSIMD_IF_CONSTEXPR(detail::is_identity(mask))
             {
-                // The intrinsic does NOT allow to copy the same element of the source vector to more than one element of the destination vector.
-                // one-shot 8-lane permute
-                return _mm256_permutevar8x32_ps(self, mask.as_batch());
+                return self;
             }
-            return swizzle(self, mask, avx {});
-        }
-
-        template <class A, uint64_t V0, uint64_t V1, uint64_t V2, uint64_t V3>
-        XSIMD_INLINE batch<double, A> swizzle(batch<double, A> const& self, batch_constant<uint64_t, A, V0, V1, V2, V3> mask, requires_arch<avx2>) noexcept
-        {
-            XSIMD_IF_CONSTEXPR(detail::is_identity(mask)) { return self; }
             XSIMD_IF_CONSTEXPR(!detail::is_cross_lane(mask))
             {
-                constexpr auto imm = ((V0 & 1) << 0) | ((V1 & 1) << 1) | ((V2 & 1) << 2) | ((V3 & 1) << 3);
-                return _mm256_permute_pd(self, imm);
+                constexpr auto lane_mask = mask % make_batch_constant<...>();
+                // Cheaper intrinsic when not crossing lanes.
+                // Unlike the uint64_t version below, an 8-bit immediate cannot encode a
+                // different permutation for each 128-bit lane (4 elements x 2 selector
+                // bits per lane exceed 8 bits), so we use the variable form.
+                batch<float, A> permuted = _mm256_permutevar_ps(bitwise_cast<float>(self), lane_mask.as_batch());
+                return bitwise_cast<uint32_t>(permuted);
             }
-            constexpr auto imm = detail::mod_shuffle(V0, V1, V2, V3);
-            // fallback to full 4-element permute
-            return _mm256_permute4x64_pd(self, imm);
+            return _mm256_permutevar8x32_epi32(self, mask.as_batch());
         }
 
-        template <class A, uint64_t V0, uint64_t V1, uint64_t V2, uint64_t V3>
-        XSIMD_INLINE batch<uint64_t, A> swizzle(batch<uint64_t, A> const& self, batch_constant<uint64_t, A, V0, V1, V2, V3>, requires_arch<avx2>) noexcept
-        {
-            constexpr auto mask = detail::mod_shuffle(V0, V1, V2, V3);
-            return _mm256_permute4x64_epi64(self, mask);
+        template <class A, class T, uint32_t... Vals, detail::enable_sized_t<T, 4> = 0>
+        XSIMD_INLINE batch<T, A> swizzle(batch<T, A> const& self, batch_constant<uint32_t, A, Vals...> mask, requires_arch<avx2> req) noexcept
+        {
+            return bitwise_cast<T>(swizzle(bitwise_cast<uint32_t>(self), mask, req));
         }
 
         template <class A, uint64_t V0, uint64_t V1, uint64_t V2, uint64_t V3>
-        XSIMD_INLINE batch<int64_t, A> swizzle(batch<int64_t, A> const& self, batch_constant<uint64_t, A, V0, V1, V2, V3> mask, requires_arch<avx2>) noexcept
+        XSIMD_INLINE batch<uint64_t, A> swizzle(batch<uint64_t, A> const& self, batch_constant<uint64_t, A, V0, V1, V2, V3> mask, requires_arch<avx2>) noexcept
         {
-            return bitwise_cast<int64_t>(swizzle(bitwise_cast<uint64_t>(self), mask, avx2 {}));
-        }
-        template <class A, uint32_t... Vs>
-        XSIMD_INLINE batch<uint32_t, A> swizzle(batch<uint32_t, A> const& self, batch_constant<uint32_t, A, Vs...> mask, requires_arch<avx2>) noexcept
-        {
-            return _mm256_permutevar8x32_epi32(self, mask.as_batch());
+            XSIMD_IF_CONSTEXPR(detail::is_identity(mask))
+            {
+                return self;
+            }
+            XSIMD_IF_CONSTEXPR(!detail::is_cross_lane(mask))
+            {
+                constexpr uint8_t lane_mask = (V0 % 2) | ((V1 % 2) << 1) | ((V2 % 2) << 2) | ((V3 % 2) << 3);
+                // Cheaper intrinsic when not crossing lanes
+                batch<double, A> permuted = _mm256_permute_pd(bitwise_cast<double>(self), lane_mask);
+                return bitwise_cast<uint64_t>(permuted);
+            }
+            constexpr auto mask_int = detail::mod_shuffle(V0, V1, V2, V3);
+            return _mm256_permute4x64_epi64(self, mask_int);
         }
 
-        template <class A, uint32_t... Vs>
-        XSIMD_INLINE batch<int32_t, A> swizzle(batch<int32_t, A> const& self, batch_constant<uint32_t, A, Vs...> mask, requires_arch<avx2>) noexcept
+        template <class A, class T, uint64_t... Vals, detail::enable_sized_t<T, 8> = 0>
+        XSIMD_INLINE batch<T, A> swizzle(batch<T, A> const& self, batch_constant<uint64_t, A, Vals...> mask, requires_arch<avx2> req) noexcept
         {
-            return bitwise_cast<int32_t>(swizzle(bitwise_cast<uint32_t>(self), mask, avx2 {}));
+            return bitwise_cast<T>(swizzle(bitwise_cast<uint64_t>(self), mask, req));
         }
 
         // zip_hi
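
Note on the choice of _mm256_permutevar_ps for the in-lane 32-bit case: the immediate form (_mm256_permute_ps) applies one 8-bit pattern to both 128-bit lanes, while the variable form takes a per-element selector and can therefore permute the two lanes differently. A minimal demo (assumes an AVX-capable toolchain, e.g. compile with -mavx2):

    #include <immintrin.h>
    #include <cstdio>

    int main()
    {
        __m256 v = _mm256_setr_ps(0, 1, 2, 3, 4, 5, 6, 7);
        // Per-element in-lane selectors: reverse the low lane, keep the high lane.
        __m256i sel = _mm256_setr_epi32(3, 2, 1, 0, 0, 1, 2, 3);
        __m256 r = _mm256_permutevar_ps(v, sel);
        float out[8];
        _mm256_storeu_ps(out, r);
        for (float f : out)
            std::printf("%g ", f); // 3 2 1 0 4 5 6 7
        std::printf("\n");
    }
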
diff --git a/include/xsimd/arch/xsimd_common_fwd.hpp b/include/xsimd/arch/xsimd_common_fwd.hpp
index cff63e739..2b401155b 100644
--- a/include/xsimd/arch/xsimd_common_fwd.hpp
+++ b/include/xsimd/arch/xsimd_common_fwd.hpp
@@ -60,15 +60,15 @@ namespace xsimd
         template <class T, class A, T... Vs>
         XSIMD_INLINE constexpr bool is_identity(batch_constant<T, A, Vs...>) noexcept;
         template <class T, class A, T... Vs>
-        XSIMD_INLINE constexpr bool is_all_different(batch_constant<T, A, Vs...>) noexcept;
-        template <class T, class A, T... Vs>
         XSIMD_INLINE constexpr bool is_dup_lo(batch_constant<T, A, Vs...>) noexcept;
         template <class T, class A, T... Vs>
         XSIMD_INLINE constexpr bool is_dup_hi(batch_constant<T, A, Vs...>) noexcept;
         template <class T, class A, T... Vs>
         XSIMD_INLINE constexpr bool is_cross_lane(batch_constant<T, A, Vs...>) noexcept;
         template <class T, class A, T... Vs>
-        XSIMD_INLINE constexpr bool no_duplicates(batch_constant<T, A, Vs...>) noexcept;
+        XSIMD_INLINE constexpr bool is_only_from_lo(batch_constant<T, A, Vs...>) noexcept;
+        template <class T, class A, T... Vs>
+        XSIMD_INLINE constexpr bool is_only_from_hi(batch_constant<T, A, Vs...>) noexcept;
     }
 }

diff --git a/test/test_batch_manip.cpp b/test/test_batch_manip.cpp
index dcc88c6cb..7da46e736 100644
--- a/test/test_batch_manip.cpp
+++ b/test/test_batch_manip.cpp
@@ -26,8 +26,7 @@ namespace xsimd
     // compile-time tests (identity, all-different, dup-lo, dup-hi)
     // 8-lane identity
     static_assert(is_identity<uint32_t, 0, 1, 2, 3, 4, 5, 6, 7>(), "identity failed");
-    // 8-lane reverse is all-different but not identity
-    static_assert(is_all_different<uint32_t, 7, 6, 5, 4, 3, 2, 1, 0>(), "all-diff failed");
+    // 8-lane reverse is not identity
     static_assert(!is_identity<uint32_t, 7, 6, 5, 4, 3, 2, 1, 0>(), "identity on reverse");
     // 8-lane dup-lo (repeat 0..3 twice)
     static_assert(is_dup_lo<uint32_t, 0, 1, 2, 3, 0, 1, 2, 3>(), "dup_lo failed");
@@ -35,11 +34,16 @@
     // 8-lane dup-hi (repeat 4..7 twice)
     static_assert(is_dup_hi<uint32_t, 4, 5, 6, 7, 4, 5, 6, 7>(), "dup_hi failed");
     static_assert(!is_dup_lo<uint32_t, 4, 5, 6, 7, 4, 5, 6, 7>(), "dup_lo on dup_hi");
+    // 8-lane is-only-from-hi (repeat 4..7 twice)
+    static_assert(is_only_from_hi<uint32_t, 4, 5, 6, 7, 4, 5, 6, 7>(), "only_from_hi on hi");
+    static_assert(!is_only_from_hi<uint32_t, 0, 1, 2, 3, 0, 1, 2, 3>(), "only_from_hi failed");
+    // 8-lane is-only-from-lo (repeat 0..3 twice)
+    static_assert(is_only_from_lo<uint32_t, 0, 1, 2, 3, 0, 1, 2, 3>(), "only_from_lo on lo");
+    static_assert(!is_only_from_lo<uint32_t, 4, 5, 6, 7, 4, 5, 6, 7>(), "only_from_lo failed");
     // ────────────────────────────────────────────────────────────────────────
     // 4-lane identity
     static_assert(is_identity<uint32_t, 0, 1, 2, 3>(), "4-lane identity failed");
-    // 4-lane reverse all-different but not identity
-    static_assert(is_all_different<uint32_t, 3, 2, 1, 0>(), "4-lane all-diff failed");
+    // 4-lane reverse is not identity
     static_assert(!is_identity<uint32_t, 3, 2, 1, 0>(), "4-lane identity on reverse");
     // 4-lane dup-lo (repeat 0..1 twice)
     static_assert(is_dup_lo<uint32_t, 0, 1, 0, 1>(), "4-lane dup_lo failed");
@@ -53,12 +57,6 @@
     static_assert(is_cross_lane<0, 3, 3, 3>(), "one low + rest high → crossing");
     static_assert(!is_cross_lane<1, 0, 2, 3>(), "mixed low/high → no crossing");
     static_assert(!is_cross_lane<0, 1, 2, 3>(), "mixed low/high → no crossing");
-
-    static_assert(no_duplicates_v<0, 1, 2, 3>(), "N=4: [0,1,2,3] → distinct");
-    static_assert(!no_duplicates_v<0, 1, 2, 2>(), "N=4: [0,1,2,2] → dup");
-
-    static_assert(no_duplicates_v<0, 1, 2, 3, 4, 5, 6, 7>(), "N=8: [0..7] → distinct");
-    static_assert(!no_duplicates_v<0, 1, 2, 3, 4, 5, 6, 0>(), "N=8: last repeats 0");
     }
 }
}
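
Usage sketch of the path this patch adds: a compile-time byte swizzle on an AVX2 batch, which now dispatches to the constant-mask kernel instead of the dynamic fallback. The make_batch_constant<T, Arch, Generator>() spelling and the generator's get(index, size) signature below assume the current xsimd API; adjust to your xsimd version, and the reverse_index generator is purely illustrative.

    #include <xsimd/xsimd.hpp>
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    // Illustrative generator: byte i of the mask selects byte (n - 1 - i).
    struct reverse_index
    {
        static constexpr std::uint8_t get(std::size_t i, std::size_t n)
        {
            return static_cast<std::uint8_t>(n - 1 - i);
        }
    };

    int main()
    {
        using batch = xsimd::batch<std::uint8_t, xsimd::avx2>;
        alignas(32) std::uint8_t in[32];
        for (int i = 0; i < 32; ++i)
            in[i] = static_cast<std::uint8_t>(i);
        auto v = batch::load_aligned(in);
        // Constant mask: resolved at compile time, so the identity /
        // in-lane / single-half / cross-lane choice costs nothing at runtime.
        auto r = xsimd::swizzle(v, xsimd::make_batch_constant<std::uint8_t, xsimd::avx2, reverse_index>());
        alignas(32) std::uint8_t out[32];
        r.store_aligned(out);
        std::printf("%d %d\n", out[0], out[31]); // expect 31 0
    }
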