Skip to content

Commit

Permalink
Add a unified helper function for guarding the index and remove the unused function
Browse files Browse the repository at this point in the history
  • Loading branch information
CaoE committed May 25, 2023
1 parent a399f92 commit 40655a1
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 58 deletions.
29 changes: 19 additions & 10 deletions aten/src/ATen/native/UpSample.h
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,18 @@ static inline scalar_t cubic_interp1d(
return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
}

// When `real_input_index` becomes larger than the range the floating point
// type can accurately represent, the cast to `int64_t` might exceed
// `input_size - 1`, causing overflow. So we guard it with `std::min` below,
// and clamp the interpolation weight `lambda` into [0, 1] to match.
//
// real_input_index: source-space coordinate computed from the output index.
// input_size:       number of elements along the interpolated dimension.
// input_index (out): integer source index, clamped to [*, input_size - 1].
// lambda (out):      fractional offset of the source coordinate, in [0, 1].
template<typename scalar_t, typename opmath_t>
static inline void guard_index_and_lambda(const opmath_t& real_input_index, const int64_t& input_size, int64_t& input_index, scalar_t& lambda) {
  // Use std::floor rather than floorf: floorf narrows a double
  // `real_input_index` to float, losing precision for exactly the large
  // indices this guard is meant to protect against.
  input_index = std::min(static_cast<int64_t>(std::floor(real_input_index)), input_size - 1);
  lambda = std::min(
      std::max(real_input_index - input_index, static_cast<opmath_t>(0)),
      static_cast<opmath_t>(1)
  );
}

template<typename scalar_t, typename opmath_t>
static inline void compute_source_index_and_lambda(
int64_t& input_index0,
Expand All @@ -449,23 +461,20 @@ static inline void compute_source_index_and_lambda(
const auto real_input_index =
area_pixel_compute_source_index<opmath_t>(
ratio, output_index, align_corners, /*cubic=*/false);
// when `real_input_index` becomes larger than the range the floating point
// type can accurately represent, the type casting to `int64_t` might exceed
// `input_size - 1`, causing overflow. So we guard it with `std::min` below.
input_index0 = std::min(static_cast<int64_t>(real_input_index), input_size - 1);
guard_index_and_lambda(real_input_index, input_size, input_index0, lambda1);
int64_t offset = (input_index0 < input_size - 1) ? 1 : 0;
input_index1 = input_index0 + offset;
lambda1 = std::min(
std::max(real_input_index - input_index0, static_cast<opmath_t>(0)),
static_cast<opmath_t>(1)
);
lambda0 = static_cast<scalar_t>(1.) - lambda1;
}
}

// For compilation, and it will not be used by data types other than BFloat16.
// It will not be used by data types other than BFloat16.
template <typename scalar_in, typename scalar_out>
void inline apply_grad_input(scalar_out* buffer_ptr, scalar_in* gin, int64_t size) {
void inline apply_grad_input(scalar_in* buffer_ptr, scalar_out* gin, int64_t size) {
TORCH_CHECK((std::is_same<scalar_out, BFloat16>::value),
"Upsample backward only support BFloat16 in the lower percision data types on CPU.")
TORCH_CHECK((std::is_same<scalar_in, float>::value),
"Upsample backward should use float as acc buffer for BFloat16 grad input on CPU.")
return;
}

Expand Down
21 changes: 7 additions & 14 deletions aten/src/ATen/native/UpSampleBicubic2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ static void upsample_bicubic2d_backward_out_frame(
at::parallel_for(0, channels, at::internal::GRAIN_SIZE / output_slice_size / 4, [&](int64_t start, int64_t end) {
opmath_t* acc_data_ptr = nullptr;
std::unique_ptr<opmath_t[]> buffer_data;
if (!std::is_same<scalar_t, opmath_t>::value) {
if constexpr (!std::is_same<scalar_t, opmath_t>::value) {
buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
acc_data_ptr = buffer_data.get();
memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
Expand All @@ -142,21 +142,14 @@ static void upsample_bicubic2d_backward_out_frame(
for (const auto output_x : c10::irange(output_width)) {

const opmath_t real_x = area_pixel_compute_source_index(width_scale, output_x, align_corners, /*cubic=*/true);
// when `real_x` becomes larger than the range the floating point
// type can accurately represent, the type casting to `int64_t` might exceed
// `input_width - 1`. So we guard it with `std::min` below.
int64_t input_x = std::min(static_cast<int64_t>(floorf(real_x)), input_width - 1);
opmath_t t_x = std::min(
std::max(real_x - input_x, static_cast<opmath_t>(0)),
static_cast<opmath_t>(1)
);
int64_t input_x;
opmath_t t_x;
guard_index_and_lambda(real_x, input_width, input_x, t_x);

const opmath_t real_y = area_pixel_compute_source_index(height_scale, output_y, align_corners, /*cubic=*/true);
int64_t input_y = std::min(static_cast<int64_t>(floorf(real_y)), input_height - 1);
opmath_t t_y = std::min(
std::max(real_y - input_y, static_cast<opmath_t>(0)),
static_cast<opmath_t>(1)
);
int64_t input_y;
opmath_t t_y;
guard_index_and_lambda(real_y, input_height, input_y, t_y);

// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
opmath_t x_coeffs[4];
Expand Down
10 changes: 2 additions & 8 deletions aten/src/ATen/native/cpu/UpSampleKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1249,14 +1249,8 @@ struct HelperInterpCubic : public HelperInterpBase {
const auto real_input_index =
area_pixel_compute_source_index<opmath_t>(
scale, i, align_corners, /*cubic=*/true);
// when `real_input_index` becomes larger than the range the floating point
// type can accurately represent, the type casting to `int64_t` might exceed
// `input_size - 1`. So we guard it with `std::min` below.
input_index = std::min(static_cast<int64_t>(floorf(real_input_index)), input_size - 1);
auto lambda = std::min(
std::max(real_input_index - input_index, static_cast<opmath_t>(0)),
static_cast<opmath_t>(1)
);
opmath_t lambda;
guard_index_and_lambda(real_input_index, input_size, input_index, lambda);
get_cubic_upsample_coefficients<opmath_t>(coeffs, lambda);

for (const auto j : c10::irange(interp_size)) {
Expand Down
56 changes: 30 additions & 26 deletions aten/src/ATen/native/cpu/UpSampleMoreKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ namespace {

using scale_t = std::vector<c10::optional<double>>;

template <typename scalar_in, typename scalar_out>
void inline nearest_channels_last_acc(scalar_in* gin, scalar_out* gout, int64_t size) {
using Vec = vec::Vectorized<scalar_in>;
template <typename acc_t, typename scalar_t>
void inline nearest_channels_last_acc(acc_t* gin, scalar_t* gout, int64_t size) {
TORCH_CHECK((std::is_same<acc_t, scalar_t>::value),
"acc data type of Upsample backward should be same as scalar_t for float or double on CPU.")
using Vec = vec::Vectorized<acc_t>;
int64_t d = 0;
for (; d < size - (size % Vec::size()); d += Vec::size()) {
Vec gin_vec = Vec::loadu(gin + d) + Vec::loadu(gout + d);
Expand Down Expand Up @@ -46,9 +48,11 @@ void inline nearest_channels_last_acc(float* gin, BFloat16* gout, int64_t size)
}
}

template <typename scalar_in, typename scalar_out>
void inline linear_channels_last_acc(scalar_in* gin, scalar_out* gout, scalar_in w, int64_t size) {
using Vec = vec::Vectorized<scalar_in>;
template <typename acc_t, typename scalar_t>
void inline linear_channels_last_acc(acc_t* gin, scalar_t* gout, acc_t w, int64_t size) {
TORCH_CHECK((std::is_same<acc_t, scalar_t>::value),
"acc data type of Upsample backward should be same as scalar_t for float or double on CPU.")
using Vec = vec::Vectorized<acc_t>;
int64_t d = 0;
for (; d < size - (size % Vec::size()); d += Vec::size()) {
Vec gin_vec = Vec::loadu(gin + d) + Vec(w) * Vec::loadu(gout + d);
Expand Down Expand Up @@ -111,7 +115,7 @@ void cpu_upsample_nearest_backward(
auto loop1d = [&](int64_t begin, int64_t end) {
opmath_t* acc_data_ptr = nullptr;
std::unique_ptr<opmath_t[]> buffer_data;
if (!std::is_same<scalar_t, opmath_t>::value) {
if constexpr (!std::is_same<scalar_t, opmath_t>::value) {
buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
acc_data_ptr = buffer_data.get();
memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
Expand All @@ -126,7 +130,7 @@ void cpu_upsample_nearest_backward(
int64_t output_offset = c * output_slice_size + ow;
acc_data_ptr[input_offset + iw] += grad_output_data[output_offset];
}
if (!std::is_same<scalar_t, opmath_t>::value) {
if constexpr (!std::is_same<scalar_t, opmath_t>::value) {
auto gin = grad_input_data + c * input_slice_size;
apply_grad_input(acc_data_ptr, gin, input_slice_size);
}
Expand All @@ -136,7 +140,7 @@ void cpu_upsample_nearest_backward(
auto loop2d = [&](int64_t begin, int64_t end) {
opmath_t* acc_data_ptr = nullptr;
std::unique_ptr<opmath_t[]> buffer_data;
if (!std::is_same<scalar_t, opmath_t>::value) {
if constexpr (!std::is_same<scalar_t, opmath_t>::value) {
buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
acc_data_ptr = buffer_data.get();
memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
Expand All @@ -154,7 +158,7 @@ void cpu_upsample_nearest_backward(
acc_data_ptr[input_offset + ih * input_width + iw] += grad_output_data[output_offset];
}
}
if (!std::is_same<scalar_t, opmath_t>::value) {
if constexpr (!std::is_same<scalar_t, opmath_t>::value) {
auto gin = grad_input_data + c * input_slice_size;
apply_grad_input(acc_data_ptr, gin, input_slice_size);
}
Expand All @@ -164,7 +168,7 @@ void cpu_upsample_nearest_backward(
auto loop3d = [&](int64_t begin, int64_t end) {
opmath_t* acc_data_ptr = nullptr;
std::unique_ptr<opmath_t[]> buffer_data;
if (!std::is_same<scalar_t, opmath_t>::value) {
if constexpr (!std::is_same<scalar_t, opmath_t>::value) {
buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
acc_data_ptr = buffer_data.get();
memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
Expand All @@ -187,7 +191,7 @@ void cpu_upsample_nearest_backward(
}
}
}
if (!std::is_same<scalar_t, opmath_t>::value) {
if constexpr (!std::is_same<scalar_t, opmath_t>::value) {
auto gin = grad_input_data + c * input_slice_size;
apply_grad_input(acc_data_ptr, gin, input_slice_size);
}
Expand Down Expand Up @@ -246,7 +250,7 @@ void cpu_upsample_nearest_backward_channels_last(
auto loop2d = [&](int64_t begin, int64_t end) {
opmath_t* acc_data_ptr = nullptr;
std::unique_ptr<opmath_t[]> buffer_data;
if (!std::is_same<scalar_t, opmath_t>::value) {
if constexpr (!std::is_same<scalar_t, opmath_t>::value) {
buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
acc_data_ptr = buffer_data.get();
memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
Expand All @@ -266,7 +270,7 @@ void cpu_upsample_nearest_backward_channels_last(
nearest_channels_last_acc(buffer_ptr, grad_output_ptr, channels);
}
}
if (!std::is_same<scalar_t, opmath_t>::value) {
if constexpr (!std::is_same<scalar_t, opmath_t>::value) {
auto gin = grad_input_data + n * input_slice_size;
apply_grad_input(acc_data_ptr, gin, input_slice_size);
}
Expand All @@ -277,7 +281,7 @@ void cpu_upsample_nearest_backward_channels_last(
auto loop3d = [&](int64_t begin, int64_t end) {
opmath_t* acc_data_ptr = nullptr;
std::unique_ptr<opmath_t[]> buffer_data;
if (!std::is_same<scalar_t, opmath_t>::value) {
if constexpr (!std::is_same<scalar_t, opmath_t>::value) {
buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
acc_data_ptr = buffer_data.get();
memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
Expand All @@ -302,7 +306,7 @@ void cpu_upsample_nearest_backward_channels_last(
}
}
}
if (!std::is_same<scalar_t, opmath_t>::value) {
if constexpr (!std::is_same<scalar_t, opmath_t>::value) {
auto gin = grad_input_data + n * input_slice_size;
apply_grad_input(acc_data_ptr, gin, input_slice_size);
}
Expand Down Expand Up @@ -429,7 +433,7 @@ void cpu_upsample_linear_backward(
auto loop1d = [&](int64_t begin, int64_t end) {
opmath_t* acc_data_ptr = nullptr;
std::unique_ptr<opmath_t[]> buffer_data;
if (!std::is_same<scalar_t, opmath_t>::value) {
if constexpr (!std::is_same<scalar_t, opmath_t>::value) {
buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
acc_data_ptr = buffer_data.get();
memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
Expand All @@ -452,7 +456,7 @@ void cpu_upsample_linear_backward(
acc_data_ptr[input_offset + iw0] += w0lambda * grad_output_value; /* i0 */
acc_data_ptr[input_offset + iw1] += w1lambda * grad_output_value; /* i1*/
}
if (!std::is_same<scalar_t, opmath_t>::value) {
if constexpr (!std::is_same<scalar_t, opmath_t>::value) {
auto gin = grad_input_data + c * input_slice_size;
apply_grad_input(acc_data_ptr, gin, input_slice_size);
}
Expand All @@ -462,7 +466,7 @@ void cpu_upsample_linear_backward(
auto loop2d = [&](int64_t begin, int64_t end) {
opmath_t* acc_data_ptr = nullptr;
std::unique_ptr<opmath_t[]> buffer_data;
if (!std::is_same<scalar_t, opmath_t>::value) {
if constexpr (!std::is_same<scalar_t, opmath_t>::value) {
buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
acc_data_ptr = buffer_data.get();
memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
Expand Down Expand Up @@ -493,7 +497,7 @@ void cpu_upsample_linear_backward(
acc_data_ptr[input_offset + ih1 * input_width + iw1] += h1lambda * w1lambda * grad_output_value; /* i11 */
}
}
if (!std::is_same<scalar_t, opmath_t>::value) {
if constexpr (!std::is_same<scalar_t, opmath_t>::value) {
auto gin = grad_input_data + c * input_slice_size;
apply_grad_input(acc_data_ptr, gin, input_slice_size);
}
Expand All @@ -503,7 +507,7 @@ void cpu_upsample_linear_backward(
auto loop3d = [&](int64_t begin, int64_t end) {
opmath_t* acc_data_ptr = nullptr;
std::unique_ptr<opmath_t[]> buffer_data;
if (!std::is_same<scalar_t, opmath_t>::value) {
if constexpr (!std::is_same<scalar_t, opmath_t>::value) {
buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
acc_data_ptr = buffer_data.get();
memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
Expand Down Expand Up @@ -545,7 +549,7 @@ void cpu_upsample_linear_backward(
}
}
}
if (!std::is_same<scalar_t, opmath_t>::value) {
if constexpr (!std::is_same<scalar_t, opmath_t>::value) {
auto gin = grad_input_data + c * input_slice_size;
apply_grad_input(acc_data_ptr, gin, input_slice_size);
}
Expand Down Expand Up @@ -605,7 +609,7 @@ void cpu_upsample_linear_backward_channels_last(
auto loop2d = [&](int64_t begin, int64_t end) {
opmath_t* acc_data_ptr = nullptr;
std::unique_ptr<opmath_t[]> buffer_data;
if (!std::is_same<scalar_t, opmath_t>::value) {
if constexpr (!std::is_same<scalar_t, opmath_t>::value) {
buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
acc_data_ptr = buffer_data.get();
memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
Expand Down Expand Up @@ -641,7 +645,7 @@ void cpu_upsample_linear_backward_channels_last(
linear_channels_last_acc(input_indexr(n, ih1, iw1, input_offset), grad_output_ptr, h1lambda * w1lambda, channels); /* i11 */
}
}
if (!std::is_same<scalar_t, opmath_t>::value) {
if constexpr (!std::is_same<scalar_t, opmath_t>::value) {
auto gin = grad_input_data + n * input_slice_size;
apply_grad_input(acc_data_ptr, gin, input_slice_size);
}
Expand All @@ -652,7 +656,7 @@ void cpu_upsample_linear_backward_channels_last(
auto loop3d = [&](int64_t begin, int64_t end) {
opmath_t* acc_data_ptr = nullptr;
std::unique_ptr<opmath_t[]> buffer_data;
if (!std::is_same<scalar_t, opmath_t>::value) {
if constexpr (!std::is_same<scalar_t, opmath_t>::value) {
buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
acc_data_ptr = buffer_data.get();
memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
Expand Down Expand Up @@ -698,7 +702,7 @@ void cpu_upsample_linear_backward_channels_last(
}
}
}
if (!std::is_same<scalar_t, opmath_t>::value) {
if constexpr (!std::is_same<scalar_t, opmath_t>::value) {
auto gin = grad_input_data + n * input_slice_size;
apply_grad_input(acc_data_ptr, gin, input_slice_size);
}
Expand Down

0 comments on commit 40655a1

Please sign in to comment.