Update base for Update on "Factor vector intrinsics out of SumKernel.cpp"


This will make it simpler to support AVX512, which is upcoming in #56992; see #56992 (comment) for reference.

[ghstack-poisoned]
peterbell10 committed Jul 15, 2021
1 parent 04946d7 commit d531bed
Showing 1 changed file with 89 additions and 124 deletions.
213 changes: 89 additions & 124 deletions aten/src/ATen/native/cpu/SumKernel.cpp
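At a glance, the diff factors the two near-identical parallel_reduce lambdas out of sum_kernel_impl and nansum_kernel_impl into a single cascade_sum<ignore_nan, scalar_t> helper. Condensed from the hunks below (the integral-type path of sum_kernel_impl is unchanged and omitted), the call sites become:

void sum_kernel_impl(TensorIterator &iter) {
  // ... integral dtypes keep their existing dispatch (not shown) ...
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      ScalarType::BFloat16, ScalarType::Half, iter.dtype(), "sum_cpu", [&] {
        cascade_sum</*ignore_nan=*/false, scalar_t>(iter);
      });
}

void nansum_kernel_impl(TensorIterator &iter) {
  AT_DISPATCH_FLOATING_TYPES_AND(
      ScalarType::Half, iter.dtype(), "nansum_cpu", [&] {
        cascade_sum</*ignore_nan=*/true, scalar_t>(iter);
      });
}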
@@ -11,8 +11,8 @@ namespace at {
namespace native {
namespace {

// Load vector from a smaller type to a larger type, reducing neighboring
// elements until it fits into the vector size.
// Load vector from a smaller type (more elements) to a larger type (fewer elements),
// reducing neighboring elements until it fits into the vector size.
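// e.g. for scalar_t = BFloat16 and acc_t = float, a Vectorized<float> holds half
// as many lanes as a Vectorized<BFloat16>, so neighboring pairs of input elements
// are combined with `reduce` to produce a single Vectorized<float>.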
template <typename acc_t, typename scalar_t, typename F>
Vectorized<acc_t> load_reduce_vec(const scalar_t* data, F reduce, acc_t ident) {
using vec_t = Vectorized<scalar_t>;
@@ -511,6 +511,87 @@ void scalar_outer_sum(
}
}

// Custom floating point sum for better accuracy
template <bool ignore_nan, typename scalar_t>
void cascade_sum(TensorIterator &iter) {
iter.output().fill_(scalar_t(0));
iter.parallel_reduce(
[&](char** data, const int64_t* strides, int64_t size0, int64_t size1) {
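// parallel_reduce presents a 2-D loop: data[0]/data[1] are the output and input
// base pointers, and strides are interleaved per dimension as
// { out dim0, in dim0, out dim1, in dim1 }.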
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
int64_t in_strides[] = { strides[1], strides[3] };
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
int64_t out_strides[] = { strides[0], strides[2] };

// Move reduction to be the 1st dim
if (out_strides[0] != 0 && out_strides[1] == 0) {
std::swap(in_strides[0], in_strides[1]);
std::swap(out_strides[0], out_strides[1]);
std::swap(size0, size1);
}

// Special case? - not a true reduction
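// Neither output stride is zero, so no dimension is actually reduced over: each
// output element pairs with a single input element, and basic_loop simply
// accumulates element-wise (treating NaN as zero when ignore_nan is set).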
if (out_strides[0] != 0 && out_strides[1] != 0) {
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
int64_t outer_strides[] = { strides[2], strides[3] };
UNARY_OUTER_LOOP(data, outer_strides, size1, [&] {
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
char* ptrs[3] = { data[0], data[0], data[1] };
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
int64_t inner_strides[3] = { strides[0], strides[0], strides[1] };
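// Select the accumulation op at compile time: the nansum variant replaces NaN
// operands with zero before adding.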
c10::guts::if_constexpr<ignore_nan>(
[&](auto) {
basic_loop(ptrs, inner_strides, 0, size0, [](scalar_t a, scalar_t b) {
auto a_notnan = at::_isnan(a) ? scalar_t(0) : a;
auto b_notnan = at::_isnan(b) ? scalar_t(0) : b;
return a_notnan + b_notnan;
});
},
[&](auto) {
basic_loop(ptrs, inner_strides, 0, size0,
[](scalar_t a, scalar_t b) { return a + b; });
});
});
return;
}

const int64_t out_stride = out_strides[1];
TORCH_INTERNAL_ASSERT(out_strides[0] == 0);

using vec_t = Vectorized<scalar_t>;
using acc_t = at::acc_type<scalar_t, true>;
using vacc_t = Vectorized<acc_t>;
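// The load policies widen scalar_t to the accumulator type acc_t on load (and,
// for nansum, zero out NaNs); the store policy casts the accumulated value back
// and adds it into the output.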
using ScalarLoadPolicy = std::conditional_t<
ignore_nan,
NanSumCastLoadPolicy<scalar_t, acc_t>,
CastLoadPolicy<scalar_t, acc_t>>;
using StorePolicy = CastStoreAccumulate<scalar_t, acc_t>;

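// Choose a kernel: vectorized paths when the reduced (inner) or un-reduced
// (outer) dimension is contiguous and at least one vector long, otherwise a
// scalar fallback picked by which input stride is smaller.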
if (in_strides[0] == sizeof(scalar_t) && size0 >= vec_t::size()) {
// Contiguous inner reduction
using VecLoadPolicy = std::conditional_t<
ignore_nan,
InnerNanSumCastLoadPolicy<vec_t, vacc_t>,
InnerSumCastLoadPolicy<vec_t, vacc_t>>;
vectorized_inner_sum<acc_t, VecLoadPolicy, ScalarLoadPolicy, StorePolicy>(
data, in_strides[1], out_stride, size0, size1);
} else if (in_strides[1] == sizeof(scalar_t) && size1 >= vec_t::size()) {
// Contiguous outer reduction
using VecLoadPolicy = std::conditional_t<
ignore_nan,
OuterNanSumCastLoadPolicy<vec_t, vacc_t>,
OuterSumCastLoadPolicy<vec_t, vacc_t>>;
vectorized_outer_sum<acc_t, VecLoadPolicy, ScalarLoadPolicy, StorePolicy>(
data, in_strides[0], out_stride, size0, size1);
} else if (in_strides[0] < in_strides[1]) {
scalar_inner_sum<acc_t, ScalarLoadPolicy, StorePolicy>(
data, in_strides, out_stride, size0, size1);
} else {
scalar_outer_sum<acc_t, ScalarLoadPolicy, StorePolicy>(
data, in_strides, out_stride, size0, size1);
}
});
}

void sum_kernel_impl(TensorIterator &iter) {
if (isIntegralType(iter.dtype(), /*includeBool=*/ true)) {
AT_DISPATCH_INTEGRAL_TYPES_AND(ScalarType::Bool, iter.dtype(), "sum_cpu",
@@ -522,133 +603,17 @@ void sum_kernel_impl(TensorIterator &iter) {
return;
}

// Custom floating point sum for better accuracy
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
ScalarType::BFloat16, ScalarType::Half, iter.dtype(), "sum_cpu",
[&] {
iter.output().fill_(scalar_t(0));
iter.parallel_reduce(
[&](char** data, const int64_t* strides, int64_t size0, int64_t size1) {
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
int64_t in_strides[] = { strides[1], strides[3] };
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
int64_t out_strides[] = { strides[0], strides[2] };

// Move reduction to be the 1st dim
if (out_strides[0] != 0 && out_strides[1] == 0) {
std::swap(in_strides[0], in_strides[1]);
std::swap(out_strides[0], out_strides[1]);
std::swap(size0, size1);
}

// Special case? - not a true reduction
if (out_strides[0] != 0 && out_strides[1] != 0) {
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
int64_t outer_strides[] = { strides[2], strides[3] };
UNARY_OUTER_LOOP(data, outer_strides, size1, [&] {
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
char* ptrs[3] = { data[0], data[0], data[1] };
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
int64_t inner_strides[3] = { strides[0], strides[0], strides[1] };
basic_loop(ptrs, inner_strides, 0, size0, [](scalar_t a, scalar_t b) { return a + b; });
});
return;
}

const int64_t out_stride = out_strides[1];
TORCH_INTERNAL_ASSERT(out_strides[0] == 0);

using vec_t = Vectorized<scalar_t>;
using acc_t = at::acc_type<scalar_t, true>;
using vacc_t = Vectorized<acc_t>;
using ScalarLoadPolicy = CastLoadPolicy<scalar_t, acc_t>;
using StorePolicy = CastStoreAccumulate<scalar_t, acc_t>;

if (in_strides[0] == sizeof(scalar_t) && size0 >= vec_t::size()) {
// Contiguous inner reduction
using VecLoadPolicy = InnerSumCastLoadPolicy<vec_t, vacc_t>;
vectorized_inner_sum<acc_t, VecLoadPolicy, ScalarLoadPolicy, StorePolicy>(
data, in_strides[1], out_stride, size0, size1);
} else if (in_strides[1] == sizeof(scalar_t) && size1 >= vec_t::size()) {
// Contiguous outer reduction
using VecLoadPolicy = OuterSumCastLoadPolicy<vec_t, vacc_t>;
vectorized_outer_sum<acc_t, VecLoadPolicy, ScalarLoadPolicy, StorePolicy>(
data, in_strides[0], out_stride, size0, size1);
} else if (in_strides[0] < in_strides[1]) {
scalar_inner_sum<acc_t, ScalarLoadPolicy, StorePolicy>(
data, in_strides, out_stride, size0, size1);
} else {
scalar_outer_sum<acc_t, ScalarLoadPolicy, StorePolicy>(
data, in_strides, out_stride, size0, size1);
}
});
});
ScalarType::BFloat16, ScalarType::Half, iter.dtype(), "sum_cpu", [&] {
cascade_sum</*ignore_nan=*/false, scalar_t>(iter);
});
}

void nansum_kernel_impl(TensorIterator &iter) {
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::Half, iter.dtype(), "nansum_cpu",
[&] {
iter.output().fill_(scalar_t(0));
iter.parallel_reduce(
[&](char** data, const int64_t* strides, int64_t size0, int64_t size1) {
std::array<int64_t, 2> in_strides{{ strides[1], strides[3] }};
std::array<int64_t, 2> out_strides{{ strides[0], strides[2] }};

// Move reduction to be the 1st dim
if (out_strides[0] != 0 && out_strides[1] == 0) {
std::swap(in_strides[0], in_strides[1]);
std::swap(out_strides[0], out_strides[1]);
std::swap(size0, size1);
}

// Special case? - not a true reduction
if (out_strides[0] != 0 && out_strides[1] != 0) {
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
int64_t outer_strides[] = { strides[2], strides[3] };
UNARY_OUTER_LOOP(data, outer_strides, size1, [&] {
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
char* ptrs[] = { data[0], data[0], data[1] };
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
int64_t inner_strides[] = { strides[0], strides[0], strides[1] };
basic_loop(ptrs, inner_strides, 0, size0,
[](scalar_t a, scalar_t b) {
auto notnan_a = at::_isnan(a) ? scalar_t(0) : a;
auto notnan_b = at::_isnan(b) ? scalar_t(0) : b;
return notnan_a + notnan_b;
});
});
return;
}

const int64_t out_stride = out_strides[1];
TORCH_INTERNAL_ASSERT(out_strides[0] == 0);

using vec_t = Vectorized<scalar_t>;
using acc_t = at::acc_type<scalar_t, true>;
using vacc_t = Vectorized<acc_t>;
using ScalarLoadPolicy = NanSumCastLoadPolicy<scalar_t, acc_t>;
using StorePolicy = CastStoreAccumulate<scalar_t, acc_t>;

if (in_strides[0] == sizeof(scalar_t) && size0 >= vec_t::size()) {
// Contiguous inner reduction
using VecLoadPolicy = InnerNanSumCastLoadPolicy<vec_t, vacc_t>;
vectorized_inner_sum<acc_t, VecLoadPolicy, ScalarLoadPolicy, StorePolicy>(
data, in_strides[1], out_stride, size0, size1);
} else if (in_strides[1] == sizeof(scalar_t) && size1 >= vec_t::size()) {
// Contiguous outer reduction
using VecLoadPolicy = OuterNanSumCastLoadPolicy<vec_t, vacc_t>;
vectorized_outer_sum<acc_t, VecLoadPolicy, ScalarLoadPolicy, StorePolicy>(
data, in_strides[0], out_stride, size0, size1);
} else if (in_strides[0] < in_strides[1]) {
scalar_inner_sum<acc_t, ScalarLoadPolicy, StorePolicy>(
data, in_strides.data(), out_stride, size0, size1);
} else {
scalar_outer_sum<acc_t, ScalarLoadPolicy, StorePolicy>(
data, in_strides.data(), out_stride, size0, size1);
}
});
});
ScalarType::Half, iter.dtype(), "nansum_cpu", [&] {
cascade_sum</*ignore_nan=*/true, scalar_t>(iter);
});
}

} // namespace (anonymous)
