diff --git a/aten/src/ATen/native/cpu/SumKernel.cpp b/aten/src/ATen/native/cpu/SumKernel.cpp
index e47a417983d8..32daf2cc84b5 100644
--- a/aten/src/ATen/native/cpu/SumKernel.cpp
+++ b/aten/src/ATen/native/cpu/SumKernel.cpp
@@ -11,8 +11,8 @@ namespace at {
 namespace native {
 namespace {

-// Load vector from a smaller type to a larger type, reducing neighboring
-// elements until it fits into the vector size.
+// Load vector from a smaller type (more elements) to a larger type (fewer elements),
+// reducing neighboring elements until it fits into the vector size.
 template <typename acc_t, typename scalar_t, typename F>
 Vectorized<acc_t> load_reduce_vec(const scalar_t* data, F reduce, acc_t ident) {
   using vec_t = Vectorized<scalar_t>;
@@ -511,6 +511,87 @@ void scalar_outer_sum(
   }
 }

+// Custom floating point sum for better accuracy
+template <bool ignore_nan, typename scalar_t>
+void cascade_sum(TensorIterator &iter) {
+  iter.output().fill_(scalar_t(0));
+  iter.parallel_reduce(
+    [&](char** data, const int64_t* strides, int64_t size0, int64_t size1) {
+      // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
+      int64_t in_strides[] = { strides[1], strides[3] };
+      // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
+      int64_t out_strides[] = { strides[0], strides[2] };
+
+      // Move reduction to be the 1st dim
+      if (out_strides[0] != 0 && out_strides[1] == 0) {
+        std::swap(in_strides[0], in_strides[1]);
+        std::swap(out_strides[0], out_strides[1]);
+        std::swap(size0, size1);
+      }
+
+      // Special case? - not a true reduction
+      if (out_strides[0] != 0 && out_strides[1] != 0) {
+        // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
+        int64_t outer_strides[] = { strides[2], strides[3] };
+        UNARY_OUTER_LOOP(data, outer_strides, size1, [&] {
+          // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
+          char* ptrs[3] = { data[0], data[0], data[1] };
+          // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
+          int64_t inner_strides[3] = { strides[0], strides[0], strides[1] };
+          c10::guts::if_constexpr<ignore_nan>(
+            [&](auto) {
+              basic_loop(ptrs, inner_strides, 0, size0, [](scalar_t a, scalar_t b) {
+                auto a_notnan = at::_isnan(a) ? scalar_t(0) : a;
+                auto b_notnan = at::_isnan(b) ? scalar_t(0) : b;
+                return a_notnan + b_notnan;
+              });
+            },
+            [&](auto) {
+              basic_loop(ptrs, inner_strides, 0, size0,
+                  [](scalar_t a, scalar_t b) { return a + b; });
+            });
+        });
+        return;
+      }
+
+      const int64_t out_stride = out_strides[1];
+      TORCH_INTERNAL_ASSERT(out_strides[0] == 0);
+
+      using vec_t = Vectorized<scalar_t>;
+      using acc_t = at::acc_type<scalar_t, true>;
+      using vacc_t = Vectorized<acc_t>;
+      using ScalarLoadPolicy = std::conditional_t<
+          ignore_nan,
+          NanSumCastLoadPolicy<scalar_t, acc_t>,
+          CastLoadPolicy<scalar_t, acc_t>>;
+      using StorePolicy = CastStoreAccumulate<scalar_t, acc_t>;
+
+      if (in_strides[0] == sizeof(scalar_t) && size0 >= vec_t::size()) {
+        // Contiguous inner reduction
+        using VecLoadPolicy = std::conditional_t<
+            ignore_nan,
+            InnerNanSumCastLoadPolicy<vec_t, vacc_t>,
+            InnerSumCastLoadPolicy<vec_t, vacc_t>>;
+        vectorized_inner_sum<scalar_t, VecLoadPolicy, ScalarLoadPolicy, StorePolicy>(
+            data, in_strides[1], out_stride, size0, size1);
+      } else if (in_strides[1] == sizeof(scalar_t) && size1 >= vec_t::size()) {
+        // Contiguous outer reduction
+        using VecLoadPolicy = std::conditional_t<
+            ignore_nan,
+            OuterNanSumCastLoadPolicy<vec_t, vacc_t>,
+            OuterSumCastLoadPolicy<vec_t, vacc_t>>;
+        vectorized_outer_sum<scalar_t, VecLoadPolicy, ScalarLoadPolicy, StorePolicy>(
+            data, in_strides[0], out_stride, size0, size1);
+      } else if (in_strides[0] < in_strides[1]) {
+        scalar_inner_sum<scalar_t, ScalarLoadPolicy, StorePolicy>(
+            data, in_strides, out_stride, size0, size1);
+      } else {
+        scalar_outer_sum<scalar_t, ScalarLoadPolicy, StorePolicy>(
+            data, in_strides, out_stride, size0, size1);
+      }
+    });
+}
+
 void sum_kernel_impl(TensorIterator &iter) {
   if (isIntegralType(iter.dtype(), /*includeBool=*/ true)) {
     AT_DISPATCH_INTEGRAL_TYPES_AND(ScalarType::Bool, iter.dtype(), "sum_cpu",
@@ -522,133 +603,17 @@ void sum_kernel_impl(TensorIterator &iter) {
     return;
   }

-  // Custom floating point sum for better accuracy
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
-    ScalarType::BFloat16, ScalarType::Half, iter.dtype(), "sum_cpu",
-    [&] {
-      iter.output().fill_(scalar_t(0));
-      iter.parallel_reduce(
-        [&](char** data, const int64_t* strides, int64_t size0, int64_t size1) {
-          // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-          int64_t in_strides[] = { strides[1], strides[3] };
-          // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-          int64_t out_strides[] = { strides[0], strides[2] };
-
-          // Move reduction to be the 1st dim
-          if (out_strides[0] != 0 && out_strides[1] == 0) {
-            std::swap(in_strides[0], in_strides[1]);
-            std::swap(out_strides[0], out_strides[1]);
-            std::swap(size0, size1);
-          }
-
-          // Special case? - not a true reduction
-          if (out_strides[0] != 0 && out_strides[1] != 0) {
-            // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-            int64_t outer_strides[] = { strides[2], strides[3] };
-            UNARY_OUTER_LOOP(data, outer_strides, size1, [&] {
-              // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-              char* ptrs[3] = { data[0], data[0], data[1] };
-              // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-              int64_t inner_strides[3] = { strides[0], strides[0], strides[1] };
-              basic_loop(ptrs, inner_strides, 0, size0, [](scalar_t a, scalar_t b) { return a + b; });
-            });
-            return;
-          }
-
-          const int64_t out_stride = out_strides[1];
-          TORCH_INTERNAL_ASSERT(out_strides[0] == 0);
-
-          using vec_t = Vectorized<scalar_t>;
-          using acc_t = at::acc_type<scalar_t, true>;
-          using vacc_t = Vectorized<acc_t>;
-          using ScalarLoadPolicy = CastLoadPolicy<scalar_t, acc_t>;
-          using StorePolicy = CastStoreAccumulate<scalar_t, acc_t>;
-
-          if (in_strides[0] == sizeof(scalar_t) && size0 >= vec_t::size()) {
-            // Contiguous inner reduction
-            using VecLoadPolicy = InnerSumCastLoadPolicy<vec_t, vacc_t>;
-            vectorized_inner_sum<scalar_t, VecLoadPolicy, ScalarLoadPolicy, StorePolicy>(
-                data, in_strides[1], out_stride, size0, size1);
-          } else if (in_strides[1] == sizeof(scalar_t) && size1 >= vec_t::size()) {
-            // Contiguous outer reduction
-            using VecLoadPolicy = OuterSumCastLoadPolicy<vec_t, vacc_t>;
-            vectorized_outer_sum<scalar_t, VecLoadPolicy, ScalarLoadPolicy, StorePolicy>(
-                data, in_strides[0], out_stride, size0, size1);
-          } else if (in_strides[0] < in_strides[1]) {
-            scalar_inner_sum<scalar_t, ScalarLoadPolicy, StorePolicy>(
-                data, in_strides, out_stride, size0, size1);
-          } else {
-            scalar_outer_sum<scalar_t, ScalarLoadPolicy, StorePolicy>(
-                data, in_strides, out_stride, size0, size1);
-          }
-        });
-    });
+    ScalarType::BFloat16, ScalarType::Half, iter.dtype(), "sum_cpu", [&] {
+      cascade_sum</*ignore_nan=*/false, scalar_t>(iter);
+    });
 }

 void nansum_kernel_impl(TensorIterator &iter) {
   AT_DISPATCH_FLOATING_TYPES_AND(
-    ScalarType::Half, iter.dtype(), "nansum_cpu",
-    [&] {
-      iter.output().fill_(scalar_t(0));
-      iter.parallel_reduce(
-        [&](char** data, const int64_t* strides, int64_t size0, int64_t size1) {
-          std::array<int64_t, 2> in_strides{{ strides[1], strides[3] }};
-          std::array<int64_t, 2> out_strides{{ strides[0], strides[2] }};
-
-          // Move reduction to be the 1st dim
-          if (out_strides[0] != 0 && out_strides[1] == 0) {
-            std::swap(in_strides[0], in_strides[1]);
-            std::swap(out_strides[0], out_strides[1]);
-            std::swap(size0, size1);
-          }
-
-          // Special case? - not a true reduction
-          if (out_strides[0] != 0 && out_strides[1] != 0) {
-            // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-            int64_t outer_strides[] = { strides[2], strides[3] };
-            UNARY_OUTER_LOOP(data, outer_strides, size1, [&] {
-              // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-              char* ptrs[] = { data[0], data[0], data[1] };
-              // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-              int64_t inner_strides[] = { strides[0], strides[0], strides[1] };
-              basic_loop(ptrs, inner_strides, 0, size0,
-                [](scalar_t a, scalar_t b) {
-                  auto notnan_a = at::_isnan(a) ? scalar_t(0) : a;
-                  auto notnan_b = at::_isnan(b) ? scalar_t(0) : b;
-                  return notnan_a + notnan_b;
-                });
-            });
-            return;
-          }
-
-          const int64_t out_stride = out_strides[1];
-          TORCH_INTERNAL_ASSERT(out_strides[0] == 0);
-
-          using vec_t = Vectorized<scalar_t>;
-          using acc_t = at::acc_type<scalar_t, true>;
-          using vacc_t = Vectorized<acc_t>;
-          using ScalarLoadPolicy = NanSumCastLoadPolicy<scalar_t, acc_t>;
-          using StorePolicy = CastStoreAccumulate<scalar_t, acc_t>;
-
-          if (in_strides[0] == sizeof(scalar_t) && size0 >= vec_t::size()) {
-            // Contiguous inner reduction
-            using VecLoadPolicy = InnerNanSumCastLoadPolicy<vec_t, vacc_t>;
-            vectorized_inner_sum<scalar_t, VecLoadPolicy, ScalarLoadPolicy, StorePolicy>(
-                data, in_strides[1], out_stride, size0, size1);
-          } else if (in_strides[1] == sizeof(scalar_t) && size1 >= vec_t::size()) {
-            // Contiguous outer reduction
-            using VecLoadPolicy = OuterNanSumCastLoadPolicy<vec_t, vacc_t>;
-            vectorized_outer_sum<scalar_t, VecLoadPolicy, ScalarLoadPolicy, StorePolicy>(
-                data, in_strides[0], out_stride, size0, size1);
-          } else if (in_strides[0] < in_strides[1]) {
-            scalar_inner_sum<scalar_t, ScalarLoadPolicy, StorePolicy>(
-                data, in_strides.data(), out_stride, size0, size1);
-          } else {
-            scalar_outer_sum<scalar_t, ScalarLoadPolicy, StorePolicy>(
-                data, in_strides.data(), out_stride, size0, size1);
-          }
-        });
-    });
+    ScalarType::Half, iter.dtype(), "nansum_cpu", [&] {
+      cascade_sum</*ignore_nan=*/true, scalar_t>(iter);
+    });
 }

 }  // namespace (anonymous)
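
Note: the patch above collapses the duplicated sum / nansum reduction bodies into the single cascade_sum template, with the compile-time ignore_nan flag choosing between plain addition and NaN-skipping addition at each dispatch site. The standalone sketch below (toy_sum is a hypothetical name, not part of ATen) illustrates only that flag mechanism plus accumulation in a wider type; it does not reproduce the cascaded, vectorized accumulation performed by vectorized_inner_sum and the other helpers in the diff.

// Minimal sketch, assuming nothing beyond the standard library: one templated
// reduction compiled twice, once as "sum" (NaN propagates) and once as
// "nansum" (NaN treated as absent), accumulating in double.
#include <cmath>
#include <cstdio>
#include <vector>

template <bool ignore_nan, typename scalar_t, typename acc_t = double>
acc_t toy_sum(const std::vector<scalar_t>& values) {
  acc_t acc = acc_t(0);
  for (scalar_t v : values) {
    if (ignore_nan && std::isnan(v)) {
      continue;  // nansum behaviour: skip NaN elements instead of propagating them
    }
    acc += static_cast<acc_t>(v);  // widened accumulation, analogous to at::acc_type
  }
  return acc;
}

int main() {
  const std::vector<float> x = {1.0f, 2.0f, NAN, 4.0f};
  std::printf("sum    = %f\n", toy_sum</*ignore_nan=*/false, float>(x));  // prints nan
  std::printf("nansum = %f\n", toy_sum</*ignore_nan=*/true, float>(x));   // prints 7.0
  return 0;
}

With any C++11 compiler, the first call propagates the NaN while the second ignores it and prints 7.0, mirroring the sum_cpu / nansum_cpu split in the dispatch code.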