
Commit

Merge pull request #925 from xtensor-stack/feature/shuffle-float
Provide xsimd::shuffle of floating point batches
JohanMabille committed May 23, 2023
2 parents 8ffcae8 + daf0ce9 commit 8e2189b
Showing 7 changed files with 467 additions and 1 deletion.
3 changes: 2 additions & 1 deletion .github/workflows/linux.yml
@@ -16,7 +16,8 @@ jobs:
- { compiler: 'gcc', version: '7', flags: 'force_no_instr_set' }
- { compiler: 'gcc', version: '8', flags: 'enable_xtl_complex' }
- { compiler: 'gcc', version: '9', flags: 'avx' }
- { compiler: 'gcc', version: '10', flags: 'avx512' }
#- { compiler: 'gcc', version: '10', flags: 'avx512' } buggy
- { compiler: 'gcc', version: '11', flags: 'avx512' }
- { compiler: 'gcc', version: '11', flags: 'i386' }
- { compiler: 'gcc', version: '11', flags: 'avx512cd' }
- { compiler: 'clang', version: '8', flags: 'force_no_instr_set' }
99 changes: 99 additions & 0 deletions include/xsimd/arch/generic/xsimd_generic_memory.hpp
@@ -24,6 +24,9 @@ namespace xsimd
template <class batch_type, typename batch_type::value_type... Values>
struct batch_constant;

template <class batch_type, bool... Values>
struct batch_bool_constant;

namespace kernel
{

@@ -286,6 +289,102 @@ namespace xsimd
kernel::scatter<A>(tmp, dst, index, A {});
}

// shuffle
namespace detail
{
constexpr bool is_swizzle_fst(size_t)
{
return true;
}
template <typename ITy, typename... ITys>
constexpr bool is_swizzle_fst(size_t bsize, ITy index, ITys... indices)
{
return index < bsize && is_swizzle_fst(bsize, indices...);
}
constexpr bool is_swizzle_snd(size_t)
{
return true;
}
template <typename ITy, typename... ITys>
constexpr bool is_swizzle_snd(size_t bsize, ITy index, ITys... indices)
{
return index >= bsize && is_swizzle_snd(bsize, indices...);
}

constexpr bool is_zip_lo(size_t)
{
return true;
}

template <typename ITy0, typename ITy1, typename... ITys>
constexpr bool is_zip_lo(size_t bsize, ITy0 index0, ITy1 index1, ITys... indices)
{
return index0 == (bsize - (sizeof...(indices) + 2)) / 2 && index1 == bsize + (bsize - (sizeof...(indices) + 2)) / 2 && is_zip_lo(bsize, indices...);
}

constexpr bool is_zip_hi(size_t)
{
return true;
}

template <typename ITy0, typename ITy1, typename... ITys>
constexpr bool is_zip_hi(size_t bsize, ITy0 index0, ITy1 index1, ITys... indices)
{
return index0 == bsize / 2 + (bsize - (sizeof...(indices) + 2)) / 2 && index1 == bsize + bsize / 2 + (bsize - (sizeof...(indices) + 2)) / 2 && is_zip_hi(bsize, indices...);
}

constexpr bool is_select(size_t)
{
return true;
}

template <typename ITy, typename... ITys>
constexpr bool is_select(size_t bsize, ITy index, ITys... indices)
{
return (index < bsize ? index : index - bsize) == (bsize - sizeof...(ITys) - 1) && is_select(bsize, indices...);
}

}
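
As a sketch of how these detectors classify a mask — hypothetical values, assuming 4-element batches and compilation inside namespace xsimd::kernel, where concatenated indices 0-3 name elements of x and 4-7 elements of y:

// Compile-time classification of some example masks (bsize == 4).
static_assert(detail::is_swizzle_fst(4, 2, 0, 3, 1), "only x referenced");
static_assert(detail::is_swizzle_snd(4, 6, 4, 7, 5), "only y referenced");
static_assert(detail::is_zip_lo(4, 0, 4, 1, 5), "x0 y0 x1 y1");
static_assert(detail::is_zip_hi(4, 2, 6, 3, 7), "x2 y2 x3 y3");
static_assert(detail::is_select(4, 0, 5, 2, 7), "per-position pick, no reordering");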

template <class A, typename T, typename ITy, ITy... Indices>
inline batch<T, A> shuffle(batch<T, A> const& x, batch<T, A> const& y, batch_constant<batch<ITy, A>, Indices...>, requires_arch<generic>) noexcept
{
constexpr size_t bsize = sizeof...(Indices);

// Detect common patterns
XSIMD_IF_CONSTEXPR(detail::is_swizzle_fst(bsize, Indices...))
{
return swizzle(x, batch_constant<batch<ITy, A>, ((Indices >= bsize) ? 0 /* never happens */ : Indices)...>());
}

XSIMD_IF_CONSTEXPR(detail::is_swizzle_snd(bsize, Indices...))
{
return swizzle(y, batch_constant<batch<ITy, A>, ((Indices >= bsize) ? (Indices - bsize) : 0 /* never happens */)...>());
}

XSIMD_IF_CONSTEXPR(detail::is_zip_lo(bsize, Indices...))
{
return zip_lo(x, y);
}

XSIMD_IF_CONSTEXPR(detail::is_zip_hi(bsize, Indices...))
{
return zip_hi(x, y);
}

XSIMD_IF_CONSTEXPR(detail::is_select(bsize, Indices...))
{
return select(batch_bool_constant<batch<T, A>, (Indices < bsize)...>(), x, y);
}

// Fall back to a generic pattern: swizzle both operands with the folded
// mask, then select element-wise. It is suboptimal but clang optimizes
// this pretty well.
batch<T, A> x_lane = swizzle(x, batch_constant<batch<ITy, A>, ((Indices >= bsize) ? (Indices - bsize) : Indices)...>());
batch<T, A> y_lane = swizzle(y, batch_constant<batch<ITy, A>, ((Indices >= bsize) ? (Indices - bsize) : Indices)...>());
batch_bool_constant<batch<T, A>, (Indices < bsize)...> select_x_lane;
return select(select_x_lane, x_lane, y_lane);
}
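
For a concrete trace of this fallback — hypothetical mask, 4-element batches, Indices = 1, 4, 6, 3 (neither a swizzle, a zip, nor a select pattern):

// folded mask (Indices modulo bsize)  -> 1, 0, 2, 3
// x_lane = swizzle(x, {1, 0, 2, 3})   -> x1 x0 x2 x3
// y_lane = swizzle(y, {1, 0, 2, 3})   -> y1 y0 y2 y3
// select_x_lane = (Indices < bsize)   -> true, false, false, true
// result                              -> x1 y0 y2 x3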

// store
template <class T, class A>
inline void store(batch_bool<T, A> const& self, bool* mem, requires_arch<generic>) noexcept
31 changes: 31 additions & 0 deletions include/xsimd/arch/xsimd_avx.hpp
@@ -1263,6 +1263,37 @@ namespace xsimd
return _mm256_castsi256_pd(set(batch<int64_t, A>(), A {}, static_cast<int64_t>(values ? -1LL : 0LL)...).data);
}

// shuffle
template <class A, class ITy, ITy I0, ITy I1, ITy I2, ITy I3, ITy I4, ITy I5, ITy I6, ITy I7>
inline batch<float, A> shuffle(batch<float, A> const& x, batch<float, A> const& y, batch_constant<batch<ITy, A>, I0, I1, I2, I3, I4, I5, I6, I7> mask, requires_arch<avx>) noexcept
{
constexpr uint32_t smask = detail::mod_shuffle(I0, I1, I2, I3);
// shuffle within lane
if (I4 == (I0 + 4) && I5 == (I1 + 4) && I6 == (I2 + 4) && I7 == (I3 + 4) && I0 < 4 && I1 < 4 && I2 >= 8 && I2 < 12 && I3 >= 8 && I3 < 12)
return _mm256_shuffle_ps(x, y, smask);

// shuffle within opposite lane
if (I4 == (I0 + 4) && I5 == (I1 + 4) && I6 == (I2 + 4) && I7 == (I3 + 4) && I2 < 4 && I3 < 4 && I0 >= 8 && I0 < 12 && I1 >= 8 && I1 < 12)
return _mm256_shuffle_ps(y, x, smask);

return shuffle(x, y, mask, generic {});
}
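
A sketch of a mask that takes the in-lane fast path — hypothetical values; x occupies concatenated indices 0-7, y indices 8-15:

// Mask 1, 0, 9, 8, 5, 4, 13, 12: I4..I7 equal I0..I3 + 4, I0 and I1 pick
// from x, I2 and I3 from y's low lane, so the call lowers to a single
// _mm256_shuffle_ps(x, y, detail::mod_shuffle(1, 0, 9, 8)), producing
// x1 x0 y1 y0 | x5 x4 y5 y4 in concatenated-index terms.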

template <class A, class ITy, ITy I0, ITy I1, ITy I2, ITy I3>
inline batch<double, A> shuffle(batch<double, A> const& x, batch<double, A> const& y, batch_constant<batch<ITy, A>, I0, I1, I2, I3> mask, requires_arch<avx>) noexcept
{
constexpr uint32_t smask = (I0 & 0x1) | ((I1 & 0x1) << 1) | ((I2 & 0x1) << 2) | ((I3 & 0x1) << 3);
// shuffle within lane
if (I0 < 2 && I1 >= 4 && I1 < 6 && I2 >= 2 && I2 < 4 && I3 >= 6)
return _mm256_shuffle_pd(x, y, smask);

// shuffle within opposite lane
if (I1 < 2 && I0 >= 4 && I0 < 6 && I3 >= 2 && I3 < 4 && I2 >= 6)
return _mm256_shuffle_pd(y, x, smask);

return shuffle(x, y, mask, generic {});
}
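
Similarly for doubles — a hypothetical mask hitting the first branch:

// Mask 1, 5, 3, 7 (odd elements of x interleaved with odd elements of y)
// satisfies I0 < 2, 4 <= I1 < 6, 2 <= I2 < 4, I3 >= 6, so it lowers to
// _mm256_shuffle_pd(x, y, 0xF), i.e. x1 y1 x3 y3.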

// slide_left
template <size_t N, class A, class T>
inline batch<T, A> slide_left(batch<T, A> const& x, requires_arch<avx>) noexcept
34 changes: 34 additions & 0 deletions include/xsimd/arch/xsimd_avx512f.hpp
@@ -1642,6 +1642,40 @@ namespace xsimd
return r;
}

// shuffle
template <class A, class ITy, ITy I0, ITy I1, ITy I2, ITy I3, ITy I4, ITy I5, ITy I6, ITy I7, ITy I8, ITy I9, ITy I10, ITy I11, ITy I12, ITy I13, ITy I14, ITy I15>
inline batch<float, A> shuffle(batch<float, A> const& x, batch<float, A> const& y,
batch_constant<batch<ITy, A>, I0, I1, I2, I3, I4, I5, I6, I7, I8, I9, I10, I11, I12, I13, I14, I15> mask,
requires_arch<avx512f>) noexcept
{
constexpr uint32_t smask = (I0 & 0x3) | ((I1 & 0x3) << 2) | ((I2 & 0x3) << 4) | ((I3 & 0x3) << 6);

// shuffle within lane
if ((I4 == I0 + 4) && (I5 == I1 + 4) && (I6 == I2 + 4) && (I7 == I3 + 4) && (I8 == I0 + 8) && (I9 == I1 + 8) && (I10 == I2 + 8) && (I11 == I3 + 8) && (I12 == I0 + 12) && (I13 == I1 + 12) && (I14 == I2 + 12) && (I15 == I3 + 12) && I0 < 4 && I1 < 4 && I2 >= 16 && I2 < 20 && I3 >= 16 && I3 < 20)
return _mm512_shuffle_ps(x, y, smask);

// shuffle within opposite lane
if ((I4 == I0 + 4) && (I5 == I1 + 4) && (I6 == I2 + 4) && (I7 == I3 + 4) && (I8 == I0 + 8) && (I9 == I1 + 8) && (I10 == I2 + 8) && (I11 == I3 + 8) && (I12 == I0 + 12) && (I13 == I1 + 12) && (I14 == I2 + 12) && (I15 == I3 + 12) && I2 < 4 && I3 < 4 && I0 >= 16 && I0 < 20 && I1 >= 16 && I1 < 20)
return _mm512_shuffle_ps(y, x, smask);

return shuffle(x, y, mask, generic {});
}
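
A hypothetical mask for the 16-float in-lane path, mirroring the AVX case — x occupies concatenated indices 0-15, y indices 16-31:

// Mask 1, 0, 17, 16, then the same pattern shifted by 4, 8 and 12
// (..., 5, 4, 21, 20, 9, 8, 25, 24, 13, 12, 29, 28) passes both the
// per-lane replication checks and the range checks on I0..I3, so it
// lowers to a single _mm512_shuffle_ps(x, y, 0x11).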

template <class A, class ITy, ITy I0, ITy I1, ITy I2, ITy I3, ITy I4, ITy I5, ITy I6, ITy I7>
inline batch<double, A> shuffle(batch<double, A> const& x, batch<double, A> const& y, batch_constant<batch<ITy, A>, I0, I1, I2, I3, I4, I5, I6, I7> mask, requires_arch<avx512f>) noexcept
{
constexpr uint32_t smask = (I0 & 0x1) | ((I1 & 0x1) << 1) | ((I2 & 0x1) << 2) | ((I3 & 0x1) << 3) | ((I4 & 0x1) << 4) | ((I5 & 0x1) << 5) | ((I6 & 0x1) << 6) | ((I7 & 0x1) << 7);
// shuffle within lane
if (I0 < 2 && I1 >= 8 && I1 < 10 && I2 >= 2 && I2 < 4 && I3 >= 10 && I3 < 12 && I4 >= 4 && I4 < 6 && I5 >= 12 && I5 < 14 && I6 >= 6 && I6 < 8 && I7 >= 14)
return _mm512_shuffle_pd(x, y, smask);

// shuffle within opposite lane
if (I1 < 2 && I0 >= 8 && I0 < 10 && I3 >= 2 && I3 < 4 && I2 >= 10 && I2 < 12 && I5 >= 4 && I5 < 6 && I4 >= 12 && I4 < 14 && I7 >= 6 && I7 < 8 && I6 >= 14)
return _mm512_shuffle_pd(y, x, smask);

return shuffle(x, y, mask, generic {});
}
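
And for the 8-double variant — a hypothetical mask taking the first branch:

// Mask 1, 9, 3, 11, 5, 13, 7, 15 (odd elements of x interleaved with odd
// elements of y, lane by lane) passes every range check with (Ik & 1) == 1
// throughout, so it lowers to _mm512_shuffle_pd(x, y, 0xFF).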

// slide_left
template <size_t N, class A, class T>
inline batch<T, A> slide_left(batch<T, A> const&, requires_arch<avx512f>) noexcept
41 changes: 41 additions & 0 deletions include/xsimd/arch/xsimd_sse2.hpp
@@ -43,11 +43,23 @@ namespace xsimd
{
return (y << 1) | x;
}

constexpr uint32_t mod_shuffle(uint32_t w, uint32_t x, uint32_t y, uint32_t z)
{
return shuffle(w % 4, x % 4, y % 4, z % 4);
}

constexpr uint32_t mod_shuffle(uint32_t w, uint32_t x)
{
return shuffle(w % 2, x % 2);
}
}
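
For instance — assuming the four-argument detail::shuffle above packs its fields as (z << 6) | (y << 4) | (x << 2) | w, the way its two-argument sibling packs (y << 1) | x:

// mod_shuffle folds concatenated indices into the 2-bit fields of an SSE
// shuffle immediate: mod_shuffle(1, 0, 9, 8) == shuffle(1, 0, 1, 0) == 0x11,
// ready for _mm_shuffle_ps and friends.
static_assert(detail::mod_shuffle(1, 0, 9, 8) == 0x11, "2-bit fields, packed low to high");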

// fwd
template <class A, class T, size_t I>
inline batch<T, A> insert(batch<T, A> const& self, T val, index<I>, requires_arch<generic>) noexcept;
template <class A, typename T, typename ITy, ITy... Indices>
inline batch<T, A> shuffle(batch<T, A> const& x, batch<T, A> const& y, batch_constant<batch<ITy, A>, Indices...>, requires_arch<generic>) noexcept;

// abs
template <class A>
@@ -1312,6 +1324,35 @@ namespace xsimd
return _mm_or_pd(_mm_and_pd(cond, true_br), _mm_andnot_pd(cond, false_br));
}

// shuffle
template <class A, class ITy, ITy I0, ITy I1, ITy I2, ITy I3>
inline batch<float, A> shuffle(batch<float, A> const& x, batch<float, A> const& y, batch_constant<batch<ITy, A>, I0, I1, I2, I3> mask, requires_arch<sse2>) noexcept
{
constexpr uint32_t smask = detail::mod_shuffle(I0, I1, I2, I3);
// shuffle within lane
if (I0 < 4 && I1 < 4 && I2 >= 4 && I3 >= 4)
return _mm_shuffle_ps(x, y, smask);

// shuffle within opposite lane
if (I0 >= 4 && I1 >= 4 && I2 < 4 && I3 < 4)
return _mm_shuffle_ps(y, x, smask);
return shuffle(x, y, mask, generic {});
}

template <class A, class ITy, ITy I0, ITy I1>
inline batch<double, A> shuffle(batch<double, A> const& x, batch<double, A> const& y, batch_constant<batch<ITy, A>, I0, I1> mask, requires_arch<sse2>) noexcept
{
constexpr uint32_t smask = detail::mod_shuffle(I0, I1);
// shuffle within lane
if (I0 < 2 && I1 >= 2)
return _mm_shuffle_pd(x, y, smask);

// shuffle within opposite lane
if (I0 >= 2 && I1 < 2)
return _mm_shuffle_pd(y, x, smask);
return shuffle(x, y, mask, generic {});
}
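
A usage sketch (hypothetical helper; assumes <xsimd/xsimd.hpp> and <cstdint> are included): the mask 0, 1, 4, 5 has I0, I1 < 4 and I2, I3 >= 4, so it compiles to a single _mm_shuffle_ps(x, y, detail::mod_shuffle(0, 1, 4, 5)):

// Concatenate the low halves of two 4-float batches: x0 x1 y0 y1.
inline xsimd::batch<float, xsimd::sse2> lo_halves(xsimd::batch<float, xsimd::sse2> const& x,
                                                  xsimd::batch<float, xsimd::sse2> const& y)
{
    return xsimd::shuffle(x, y,
                          xsimd::batch_constant<xsimd::batch<uint32_t, xsimd::sse2>, 0, 1, 4, 5> {});
}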

// sqrt
template <class A>
inline batch<float, A> sqrt(batch<float, A> const& val, requires_arch<sse2>) noexcept
25 changes: 25 additions & 0 deletions include/xsimd/types/xsimd_api.hpp
@@ -1829,6 +1829,31 @@ namespace xsimd
return kernel::select<A>(cond, true_br, false_br, A {});
}

/**
* @ingroup batch_data_transfer
*
* Combine elements from \c x and \c y according to selector \c mask
* @param x batch
* @param y batch
* @param mask constant batch mask of integer elements of the same size as
* the elements of \c x and \c y. Each element of the mask indexes into the
* vector formed by the concatenation of \c x and \c y. For instance
* \code{.cpp}
* batch_constant<batch<uint32_t, sse2>, 0, 4, 3, 7>
* \endcode
* picks \c x[0], \c y[0], \c x[3], \c y[3].
*
* @return combined batch
*/
template <class T, class A, class Vt, Vt... Values>
inline typename std::enable_if<std::is_arithmetic<T>::value, batch<T, A>>::type
shuffle(batch<T, A> const& x, batch<T, A> const& y, batch_constant<batch<Vt, A>, Values...> mask) noexcept
{
static_assert(sizeof(T) == sizeof(Vt), "consistent mask");
detail::static_check_supported_config<T, A>();
return kernel::shuffle<A>(x, y, mask, A {});
}
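
A usage sketch of the documented mask (hypothetical values; assumes an SSE2 target and <xsimd/xsimd.hpp>):

xsimd::batch<float, xsimd::sse2> x { 1.f, 2.f, 3.f, 4.f };
xsimd::batch<float, xsimd::sse2> y { 5.f, 6.f, 7.f, 8.f };
auto r = xsimd::shuffle(x, y,
                        xsimd::batch_constant<xsimd::batch<uint32_t, xsimd::sse2>, 0, 4, 3, 7> {});
// r == { 1.f, 5.f, 4.f, 8.f }, i.e. x[0], y[0], x[3], y[3]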

/**
* @ingroup batch_miscellaneous
*
