Skip to content

Commit cb7d813

Browse files
VitalyFedyunin authored and facebook-github-bot committed
Revert D28836794: SumKernel (BFloat16): use float as accumulation type
Test Plan: revert-hammer
Differential Revision: D28836794 (4f5c688)
Original commit changeset: 46ed3a862c2b
fbshipit-source-id: 3b586eeb752b7cdee909fa97a4c78876a6014770
1 parent: 11dca2e; commit: cb7d813

File tree

2 files changed: 4 additions and 37 deletions (+4 -37 lines changed)

aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h

Lines changed: 0 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -742,26 +742,4 @@ inline Vectorized<BFloat16> convert_float_bfloat16(const Vectorized<float>& a, c
742742

743743
#endif
744744

745-
struct Vec2f {
746-
Vectorized<float> val0, val1;
747-
Vec2f() {}
748-
Vec2f(float v) : val0(v), val1(v) {}
749-
Vec2f(Vectorized<float> v0, Vectorized<float> v1) : val0(v0), val1(v1) {}
750-
operator Vectorized<BFloat16>() const {
751-
return convert_float_bfloat16(val0, val1);
752-
}
753-
};
754-
inline Vec2f& operator+= (Vec2f& a, const Vec2f& b) {
755-
a.val0 += b.val0;
756-
a.val1 += b.val1;
757-
return a;
758-
}
759-
inline Vec2f& operator+= (Vec2f& a, const Vectorized<BFloat16>& b) {
760-
Vectorized<float> b0, b1;
761-
std::tie(b0, b1) = convert_bfloat16_float(b);
762-
a.val0 += b0;
763-
a.val1 += b1;
764-
return a;
765-
}
766-
767745
}}}

aten/src/ATen/native/cpu/SumKernel.cpp

Lines changed: 4 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -11,16 +11,6 @@ namespace at {
1111
namespace native {
1212
namespace {
1313

14-
// use float as accumulation type for BFloat16
15-
template <typename scalar_t> struct AccType { using type = scalar_t; };
16-
template <> struct AccType<BFloat16> { using type = float; };
17-
18-
template <typename scalar_t> struct AccType<Vectorized<scalar_t>> { using type = Vectorized<scalar_t>; };
19-
template <> struct AccType<Vectorized<BFloat16>> { using type = Vec2f; };
20-
21-
template <typename scalar_t>
22-
using acc_type = typename AccType<scalar_t>::type;
23-
2414
template <typename scalar_t>
2515
struct LoadPolicy {
2616
static scalar_t load(const char * C10_RESTRICT data, int64_t stride, int64_t index) {
@@ -217,9 +207,8 @@ std::array<scalar_t, nrows> multi_row_sum(
217207
const int64_t level_mask = level_step - 1;
218208

219209
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
220-
using accscalar_t = acc_type<scalar_t>;
221-
accscalar_t acc[num_levels][nrows];
222-
std::fill_n(&acc[0][0], num_levels * nrows, accscalar_t(0));
210+
scalar_t acc[num_levels][nrows];
211+
std::fill_n(&acc[0][0], num_levels * nrows, scalar_t(0));
223212

224213
int64_t i = 0;
225214
for (; i + level_step <= size;) {
@@ -239,7 +228,7 @@ std::array<scalar_t, nrows> multi_row_sum(
239228
#endif
240229
for (int64_t k = 0; k < nrows; ++k) {
241230
acc[j][k] += acc[j-1][k];
242-
acc[j-1][k] = accscalar_t(0);
231+
acc[j-1][k] = scalar_t(0);
243232
}
244233

245234
const auto mask = (level_mask << (j * level_power));
@@ -271,7 +260,7 @@ std::array<scalar_t, nrows> multi_row_sum(
271260
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
272261
std::array<scalar_t, nrows> ret;
273262
for (int64_t k = 0; k < nrows; ++k) {
274-
ret[k] = scalar_t(acc[0][k]);
263+
ret[k] = acc[0][k];
275264
}
276265
return ret;
277266
}

0 commit comments

Comments
 (0)