Skip to content

Commit

Permalink
Workaround for ROCm 5.6+ failing to compile with AVX2 SIMD support (k…
Browse files Browse the repository at this point in the history
…okkos#6449)

* Workaround for ROCm 5.6+ failing to compile with AVX2 SIMD support

* Introduce KOKKOS_IMPL_WORKAROUND_ROCM_AVX2_ISSUE

* Undefine KOKKOS_IMPL_WORKAROUND_ROCM_AVX2_ISSUE unconditionally
  • Loading branch information
masterleinad committed Sep 21, 2023
1 parent 2e74367 commit b9fa28c
Showing 1 changed file with 37 additions and 0 deletions.
37 changes: 37 additions & 0 deletions simd/src/Kokkos_SIMD_AVX2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@

#include <immintrin.h>

// FIXME_HIP ROCm 5.6 can't compile the AVX2 masked-load intrinsics
// (_mm_maskload_epi32, _mm256_maskload_epi64) used by the copy_from overloads
// in this header. When compiling with hipcc and the HIP version is 5.6 or any
// later release, KOKKOS_IMPL_WORKAROUND_ROCM_AVX2_ISSUE is defined so those
// overloads switch to plain unaligned loads (_mm_loadu_si128 /
// _mm256_loadu_si256) instead. The macro is #undef'd again at the end of this
// header.
#if defined(__HIPCC__) && \
((HIP_VERSION_MAJOR > 5) || \
((HIP_VERSION_MAJOR == 5) && (HIP_VERSION_MINOR >= 6)))
#define KOKKOS_IMPL_WORKAROUND_ROCM_AVX2_ISSUE
#endif

namespace Kokkos {

namespace Experimental {
Expand Down Expand Up @@ -938,7 +945,12 @@ class simd<std::int32_t, simd_abi::avx2_fixed_size<4>> {
}
// Overwrites m_value with the 128 bits of 32-bit integer data found at ptr.
// The pointer only needs element alignment (element_aligned_tag overload).
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_from(value_type const* ptr,
                                                     element_aligned_tag) {
#ifndef KOKKOS_IMPL_WORKAROUND_ROCM_AVX2_ISSUE
  // Regular path: masked load driven by mask_type(true) (presumably a mask
  // with every lane enabled).
  __m128i const lane_select = static_cast<__m128i>(mask_type(true));
  m_value                   = _mm_maskload_epi32(ptr, lane_select);
#else
  // FIXME_HIP ROCm 5.6 can't compile with the masked-load intrinsic of the
  // regular path, so use an unaligned full-width load instead.
  __m128i const* const src = reinterpret_cast<__m128i const*>(ptr);
  m_value                  = _mm_loadu_si128(src);
#endif
}
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_to(
value_type* ptr, element_aligned_tag) const {
Expand Down Expand Up @@ -1079,8 +1091,12 @@ class simd<std::int64_t, simd_abi::avx2_fixed_size<4>> {
}
// Overwrites m_value with the 256 bits of 64-bit integer data found at ptr.
// The pointer only needs element alignment (element_aligned_tag overload).
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_from(value_type const* ptr,
                                                     element_aligned_tag) {
#ifndef KOKKOS_IMPL_WORKAROUND_ROCM_AVX2_ISSUE
  // Regular path: masked load driven by mask_type(true) (presumably a mask
  // with every lane enabled). The intrinsic takes a long long pointer.
  __m256i const lane_select = static_cast<__m256i>(mask_type(true));
  long long const* const src = reinterpret_cast<long long const*>(ptr);
  m_value                    = _mm256_maskload_epi64(src, lane_select);
#else
  // FIXME_HIP ROCm 5.6 can't compile with the masked-load intrinsic of the
  // regular path, so use an unaligned full-width load instead.
  __m256i const* const src = reinterpret_cast<__m256i const*>(ptr);
  m_value                  = _mm256_loadu_si256(src);
#endif
}
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_to(
value_type* ptr, element_aligned_tag) const {
Expand Down Expand Up @@ -1232,8 +1248,12 @@ class simd<std::uint64_t, simd_abi::avx2_fixed_size<4>> {
}
// Overwrites m_value with the 256 bits of unsigned 64-bit integer data found
// at ptr. The pointer only needs element alignment (element_aligned_tag
// overload).
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_from(value_type const* ptr,
                                                     element_aligned_tag) {
#ifndef KOKKOS_IMPL_WORKAROUND_ROCM_AVX2_ISSUE
  // Regular path: masked load driven by mask_type(true) (presumably a mask
  // with every lane enabled). The intrinsic takes a long long pointer.
  __m256i const lane_select = static_cast<__m256i>(mask_type(true));
  long long const* const src = reinterpret_cast<long long const*>(ptr);
  m_value                    = _mm256_maskload_epi64(src, lane_select);
#else
  // FIXME_HIP ROCm 5.6 can't compile with the masked-load intrinsic of the
  // regular path, so use an unaligned full-width load instead.
  __m256i const* const src = reinterpret_cast<__m256i const*>(ptr);
  m_value                  = _mm256_loadu_si256(src);
#endif
}
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit operator __m256i()
const {
Expand Down Expand Up @@ -1531,7 +1551,12 @@ class where_expression<simd_mask<std::int32_t, simd_abi::avx2_fixed_size<4>>,
: const_where_expression(mask_arg, value_arg) {}
// Masked load of 32-bit integers: every lane selected by m_mask is read from
// mem[lane]; every other lane of the result is zero. Memory behind
// unselected lanes is never accessed, so mem only has to be valid for the
// selected lanes.
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
void copy_from(std::int32_t const* mem, element_aligned_tag) {
#ifdef KOKKOS_IMPL_WORKAROUND_ROCM_AVX2_ISSUE
  // FIXME_HIP ROCm 5.6 can't compile _mm_maskload_epi32. A full-width
  // _mm_loadu_si128 followed by an AND is not an equivalent replacement: it
  // touches all 16 bytes and may read out of bounds (or fault) when the
  // masked-off lanes lie past the end of the caller's buffer. Load the
  // selected lanes one by one instead.
  std::int32_t lane_mask[4];
  _mm_storeu_si128(reinterpret_cast<__m128i*>(lane_mask),
                   static_cast<__m128i>(m_mask));
  std::int32_t lanes[4] = {0, 0, 0, 0};
  for (int lane = 0; lane < 4; ++lane) {
    // Like _mm_maskload_epi32, only the most significant bit of each mask
    // element decides whether the lane is loaded.
    if (lane_mask[lane] < 0) lanes[lane] = mem[lane];
  }
  m_value =
      value_type(_mm_loadu_si128(reinterpret_cast<__m128i const*>(lanes)));
#else
  m_value = value_type(_mm_maskload_epi32(mem, static_cast<__m128i>(m_mask)));
#endif
}
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
void gather_from(
Expand Down Expand Up @@ -1613,8 +1638,13 @@ class where_expression<simd_mask<std::int64_t, simd_abi::avx2_fixed_size<4>>,
: const_where_expression(mask_arg, value_arg) {}
// Masked load of 64-bit integers: every lane selected by m_mask is read from
// mem[lane]; every other lane of the result is zero. Memory behind
// unselected lanes is never accessed, so mem only has to be valid for the
// selected lanes.
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_from(std::int64_t const* mem,
                                                     element_aligned_tag) {
#ifdef KOKKOS_IMPL_WORKAROUND_ROCM_AVX2_ISSUE
  // FIXME_HIP ROCm 5.6 can't compile _mm256_maskload_epi64. A full-width
  // _mm256_loadu_si256 followed by an AND is not an equivalent replacement:
  // it touches all 32 bytes and may read out of bounds (or fault) when the
  // masked-off lanes lie past the end of the caller's buffer. Load the
  // selected lanes one by one instead.
  std::int64_t lane_mask[4];
  _mm256_storeu_si256(reinterpret_cast<__m256i*>(lane_mask),
                      static_cast<__m256i>(m_mask));
  std::int64_t lanes[4] = {0, 0, 0, 0};
  for (int lane = 0; lane < 4; ++lane) {
    // Like _mm256_maskload_epi64, only the most significant bit of each mask
    // element decides whether the lane is loaded.
    if (lane_mask[lane] < 0) lanes[lane] = mem[lane];
  }
  m_value = value_type(
      _mm256_loadu_si256(reinterpret_cast<__m256i const*>(lanes)));
#else
  m_value = value_type(_mm256_maskload_epi64(
      reinterpret_cast<long long const*>(mem), static_cast<__m256i>(m_mask)));
#endif
}
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
void gather_from(
Expand Down Expand Up @@ -1697,8 +1727,13 @@ class where_expression<simd_mask<std::uint64_t, simd_abi::avx2_fixed_size<4>>,
: const_where_expression(mask_arg, value_arg) {}
// Masked load of unsigned 64-bit integers: every lane selected by m_mask is
// read from mem[lane]; every other lane of the result is zero. Memory behind
// unselected lanes is never accessed, so mem only has to be valid for the
// selected lanes.
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_from(std::uint64_t const* mem,
                                                     element_aligned_tag) {
#ifdef KOKKOS_IMPL_WORKAROUND_ROCM_AVX2_ISSUE
  // FIXME_HIP ROCm 5.6 can't compile _mm256_maskload_epi64. A full-width
  // _mm256_loadu_si256 followed by an AND is not an equivalent replacement:
  // it touches all 32 bytes and may read out of bounds (or fault) when the
  // masked-off lanes lie past the end of the caller's buffer. Load the
  // selected lanes one by one instead.
  std::int64_t lane_mask[4];
  _mm256_storeu_si256(reinterpret_cast<__m256i*>(lane_mask),
                      static_cast<__m256i>(m_mask));
  std::uint64_t lanes[4] = {0, 0, 0, 0};
  for (int lane = 0; lane < 4; ++lane) {
    // Like _mm256_maskload_epi64, only the most significant bit of each mask
    // element decides whether the lane is loaded (read here as the sign bit).
    if (lane_mask[lane] < 0) lanes[lane] = mem[lane];
  }
  m_value = value_type(
      _mm256_loadu_si256(reinterpret_cast<__m256i const*>(lanes)));
#else
  m_value = value_type(_mm256_maskload_epi64(
      reinterpret_cast<long long const*>(mem), static_cast<__m256i>(m_mask)));
#endif
}
KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION
void gather_from(
Expand Down Expand Up @@ -1728,4 +1763,6 @@ class where_expression<simd_mask<std::uint64_t, simd_abi::avx2_fixed_size<4>>,
} // namespace Experimental
} // namespace Kokkos

#undef KOKKOS_IMPL_WORKAROUND_ROCM_AVX2_ISSUE

#endif

0 comments on commit b9fa28c

Please sign in to comment.