Skip to content

Commit

Permalink
__half2_raw initiation error workaround for transformer inference on …
Browse files Browse the repository at this point in the history
…ROCm (microsoft#60)
  • Loading branch information
rraminen committed May 12, 2023
1 parent 82991df commit 995f3db
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions csrc/includes/reduction_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -273,21 +273,33 @@ DS_D_INLINE __half init<ROpType::Max>()
template <>
DS_D_INLINE __half2 init<ROpType::Add>()
{
#if defined(__HIP_PLATFORM_HCC__)
constexpr __half2_raw zero = {_Float16_2{0x0000,0x0000}};
#else
constexpr __half2_raw zero = {0x0000, 0x0000};
#endif
return __half2(zero);
}

template <>
DS_D_INLINE __half2 init<ROpType::Min>()
{
#if defined(__HIP_PLATFORM_HCC__)
constexpr __half2_raw inf = {_Float16_2{0x7C00,0x7C00}};
#else
constexpr __half2_raw inf = {0x7C00, 0x7C00};
#endif
return __half2(inf);
}

template <>
DS_D_INLINE __half2 init<ROpType::Max>()
{
#if defined(__HIP_PLATFORM_HCC__)
constexpr __half2_raw neg_inf = {_Float16_2{0xFC00,0xFC00}};
#else
constexpr __half2_raw neg_inf = {0xFC00, 0xFC00};
#endif
return __half2(neg_inf);
}

Expand Down

0 comments on commit 995f3db

Please sign in to comment.