@@ -11,16 +11,6 @@ namespace at {
1111namespace native {
1212namespace {
1313
14- // use float as accumulation type for BFloat16
15- template <typename scalar_t > struct AccType { using type = scalar_t ; };
16- template <> struct AccType <BFloat16> { using type = float ; };
17-
18- template <typename scalar_t > struct AccType <Vectorized<scalar_t >> { using type = Vectorized<scalar_t >; };
19- template <> struct AccType <Vectorized<BFloat16>> { using type = Vec2f; };
20-
21- template <typename scalar_t >
22- using acc_type = typename AccType<scalar_t >::type;
23-
2414template <typename scalar_t >
2515struct LoadPolicy {
2616 static scalar_t load (const char * C10_RESTRICT data, int64_t stride, int64_t index) {
@@ -217,9 +207,8 @@ std::array<scalar_t, nrows> multi_row_sum(
217207 const int64_t level_mask = level_step - 1 ;
218208
219209 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
220- using accscalar_t = acc_type<scalar_t >;
221- accscalar_t acc[num_levels][nrows];
222- std::fill_n (&acc[0 ][0 ], num_levels * nrows, accscalar_t (0 ));
210+ scalar_t acc[num_levels][nrows];
211+ std::fill_n (&acc[0 ][0 ], num_levels * nrows, scalar_t (0 ));
223212
224213 int64_t i = 0 ;
225214 for (; i + level_step <= size;) {
@@ -239,7 +228,7 @@ std::array<scalar_t, nrows> multi_row_sum(
239228 #endif
240229 for (int64_t k = 0 ; k < nrows; ++k) {
241230 acc[j][k] += acc[j-1 ][k];
242- acc[j-1 ][k] = accscalar_t (0 );
231+ acc[j-1 ][k] = scalar_t (0 );
243232 }
244233
245234 const auto mask = (level_mask << (j * level_power));
@@ -271,7 +260,7 @@ std::array<scalar_t, nrows> multi_row_sum(
271260 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
272261 std::array<scalar_t , nrows> ret;
273262 for (int64_t k = 0 ; k < nrows; ++k) {
274- ret[k] = scalar_t ( acc[0 ][k]) ;
263+ ret[k] = acc[0 ][k];
275264 }
276265 return ret;
277266}
0 commit comments