Add a single shared function for guarding the index and remove the unused function
CaoE committed Jan 30, 2023
1 parent e871b90 commit fad835a
Showing 4 changed files with 41 additions and 44 deletions.
53 changes: 31 additions & 22 deletions aten/src/ATen/native/UpSample.h
@@ -428,6 +428,18 @@ static inline scalar_t cubic_interp1d(
   return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
 }
 
+// when `real_input_index` becomes larger than the range the floating point
+// type can accurately represent, the type casting to `int64_t` might exceed
+// `input_size`, causing overflow. So we guard it with `std::min` below.
+template<typename scalar_t, typename opmath_t>
+static inline void guard_index_and_lambda(const opmath_t& real_input_index, const int64_t& input_size, int64_t& input_index, scalar_t& lambda) {
+  input_index = std::min(static_cast<int64_t>(real_input_index), input_size - 1);
+  lambda = std::min(
+      std::max(real_input_index - input_index, static_cast<opmath_t>(0)),
+      static_cast<opmath_t>(1)
+    );
+}
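An aside for readers, not part of this commit: a minimal standalone sketch of the failure mode the comment above describes, using a hypothetical input_size near float's 2^24 precision limit.

// Editorial sketch, not from the commit: floats at this magnitude are spaced
// 2 apart, so a source index just below input_size can round up to input_size.
#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t input_size = 16777218;        // hypothetical; valid indices 0..16777217
  const float real_input_index = 16777217.2f; // rounds to 16777218.0f in float
  const int64_t unguarded = static_cast<int64_t>(real_input_index); // 16777218: out of bounds
  const int64_t input_index =
      std::min(static_cast<int64_t>(real_input_index), input_size - 1); // 16777217: clamped
  const float lambda =
      std::min(std::max(real_input_index - input_index, 0.0f), 1.0f);   // clamped into [0, 1]
  std::cout << unguarded << " " << input_index << " " << lambda << "\n"; // 16777218 16777217 1
}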

 template<typename scalar_t, typename opmath_t>
 static inline void compute_source_index_and_lambda(
     int64_t& input_index0,
@@ -449,44 +461,41 @@ static inline void compute_source_index_and_lambda(
     const auto real_input_index =
         area_pixel_compute_source_index<opmath_t>(
             ratio, output_index, align_corners, /*cubic=*/false);
-    // when `real_input_index` becomes larger than the range the floating point
-    // type can accurately represent, the type casting to `int64_t` might exceed
-    // `input_size - 1`, causing overflow. So we guard it with `std::min` below.
-    input_index0 = std::min(static_cast<int64_t>(real_input_index), input_size - 1);
+    guard_index_and_lambda(real_input_index, input_size, input_index0, lambda1);
     int64_t offset = (input_index0 < input_size - 1) ? 1 : 0;
     input_index1 = input_index0 + offset;
-    lambda1 = std::min(
-        std::max(real_input_index - input_index0, static_cast<opmath_t>(0)),
-        static_cast<opmath_t>(1)
-      );
     lambda0 = static_cast<scalar_t>(1.) - lambda1;
   }
 }

-// For compilation, and it will not be used by data types other than BFloat16.
+// It will not be used by data types other than BFloat16.
 template <typename scalar_in, typename scalar_out>
-void inline apply_grad_input(scalar_out* buffer_ptr, scalar_in* gin, int64_t size) {
-  return;
-}
+void inline apply_grad_input(scalar_in* buffer_ptr, scalar_out* gin, int64_t size) {
+  constexpr bool is_bf16 = std::is_same<scalar_out, BFloat16>::value;
+  TORCH_CHECK(is_bf16,
+      "Upsample backward only supports BFloat16 in the lower-precision data types on CPU.")
+  constexpr bool is_fp32 = std::is_same<scalar_in, float>::value;
+  TORCH_CHECK(is_fp32,
+      "Upsample backward should use float as the acc buffer for BFloat16 grad input on CPU.")

-template <>
-void inline apply_grad_input(float* buffer_ptr, BFloat16* gin, int64_t size) {
   using bVec = vec::Vectorized<BFloat16>;
   using fVec = vec::Vectorized<float>;
+  auto f_buffer_ptr = reinterpret_cast<float*>(buffer_ptr);
+  auto b_gin = reinterpret_cast<BFloat16*>(gin);
   int64_t d = 0;
   for (; d < size - (size % bVec::size()); d += bVec::size()) {
-    bVec gin_bvec = bVec::loadu(gin + d);
+    bVec gin_bvec = bVec::loadu(b_gin + d);
     fVec gin_fvec0, gin_fvec1;
     std::tie(gin_fvec0, gin_fvec1) = convert_bfloat16_float(gin_bvec);
-    gin_fvec0 += fVec::loadu(buffer_ptr + d);
-    gin_fvec1 += fVec::loadu(buffer_ptr + d + fVec::size());
-    fVec(0).store(buffer_ptr + d);
-    fVec(0).store(buffer_ptr + d + fVec::size());
-    convert_float_bfloat16(gin_fvec0, gin_fvec1).store(gin + d);
+    gin_fvec0 += fVec::loadu(f_buffer_ptr + d);
+    gin_fvec1 += fVec::loadu(f_buffer_ptr + d + fVec::size());
+    fVec(0).store(f_buffer_ptr + d);
+    fVec(0).store(f_buffer_ptr + d + fVec::size());
+    convert_float_bfloat16(gin_fvec0, gin_fvec1).store(b_gin + d);
   }
   for (; d < size; d++) {
-    gin[d] += buffer_ptr[d];
-    buffer_ptr[d] = 0;
+    b_gin[d] += f_buffer_ptr[d];
+    f_buffer_ptr[d] = 0;
   }
 }
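A hypothetical usage sketch (the caller names below are assumed, not taken from this commit): the BFloat16 backward pass accumulates gradients in a float buffer, and apply_grad_input then folds the buffer into the BFloat16 grad input and zeroes it for reuse.

// Editorial sketch with assumed names; only apply_grad_input is from the diff.
std::vector<float> acc_buffer(grad_numel, 0.0f);      // float accumulator
// ... backward kernel adds partial gradients into acc_buffer ...
BFloat16* gin = grad_input_ptr;                       // assumed pointer into the grad tensor
apply_grad_input(acc_buffer.data(), gin, grad_numel); // gin[d] += buffer[d]; buffer[d] = 0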

19 changes: 6 additions & 13 deletions aten/src/ATen/native/UpSampleBicubic2d.cpp
@@ -142,21 +142,14 @@ static void upsample_bicubic2d_backward_out_frame(
     for (const auto output_x : c10::irange(output_width)) {
 
       const opmath_t real_x = area_pixel_compute_source_index(width_scale, output_x, align_corners, /*cubic=*/true);
-      // when `real_x` becomes larger than the range the floating point
-      // type can accurately represent, the type casting to `int64_t` might exceed
-      // `input_width - 1`. So we guard it with `std::min` below.
-      int64_t input_x = std::min(static_cast<int64_t>(floorf(real_x)), input_width - 1);
-      opmath_t t_x = std::min(
-          std::max(real_x - input_x, static_cast<opmath_t>(0)),
-          static_cast<opmath_t>(1)
-        );
+      int64_t input_x;
+      opmath_t t_x;
+      guard_index_and_lambda(real_x, input_width, input_x, t_x);
 
       const opmath_t real_y = area_pixel_compute_source_index(height_scale, output_y, align_corners, /*cubic=*/true);
-      int64_t input_y = std::min(static_cast<int64_t>(floorf(real_y)), input_height - 1);
-      opmath_t t_y = std::min(
-          std::max(real_y - input_y, static_cast<opmath_t>(0)),
-          static_cast<opmath_t>(1)
-        );
+      int64_t input_y;
+      opmath_t t_y;
+      guard_index_and_lambda(real_y, input_height, input_y, t_y);
 
       // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
       opmath_t x_coeffs[4];
11 changes: 3 additions & 8 deletions aten/src/ATen/native/cpu/UpSampleKernel.cpp
@@ -980,14 +980,9 @@ struct HelperInterpCubic : public HelperInterpBase {
         const auto real_input_index =
             area_pixel_compute_source_index<opmath_t>(
                 scale, i, align_corners, /*cubic=*/true);
-        // when `real_input_index` becomes larger than the range the floating point
-        // type can accurately represent, the type casting to `int64_t` might exceed
-        // `input_size - 1`. So we guard it with `std::min` below.
-        input_index = std::min(static_cast<int64_t>(floorf(real_input_index)), input_size - 1);
-        auto lambda = std::min(
-            std::max(real_input_index - input_index, static_cast<opmath_t>(0)),
-            static_cast<opmath_t>(1)
-          );
-
+
+        opmath_t lambda;
+        guard_index_and_lambda(real_input_index, input_size, input_index, lambda);
         get_cubic_upsample_coefficients<opmath_t>(coeffs, lambda);
 
         for (const auto j : c10::irange(interp_size)) {
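For context (an editorial sketch, not part of the diff): the clamped lambda produced by guard_index_and_lambda is the fractional offset t that get_cubic_upsample_coefficients turns into four Keys convolution weights (A = -0.75), which cubic_interp1d combines as shown in the first hunk above.

// Editorial sketch of the A = -0.75 Keys cubic weights derived from t = lambda.
// Near taps use w(x) = (A+2)x^3 - (A+3)x^2 + 1; far taps use
// w(x) = Ax^3 - 5Ax^2 + 8Ax - 4A, evaluated at each tap's distance from the source point.
void cubic_coeffs(float coeffs[4], float t) {
  const float A = -0.75f;
  auto near = [=](float x) { return ((A + 2.0f) * x - (A + 3.0f)) * x * x + 1.0f; };
  auto far = [=](float x) { return ((A * x - 5.0f * A) * x + 8.0f * A) * x - 4.0f * A; };
  coeffs[0] = far(t + 1.0f);  // tap at distance t + 1
  coeffs[1] = near(t);        // tap at distance t
  coeffs[2] = near(1.0f - t); // tap at distance 1 - t
  coeffs[3] = far(2.0f - t);  // tap at distance 2 - t
  // e.g. t = 0 yields {0, 1, 0, 0}: the output lands exactly on an input sample.
}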
2 changes: 1 addition & 1 deletion test/test_nn.py
@@ -6338,8 +6338,8 @@ def helper(size, scale_factor, mode, device, memory_format=torch.contiguous_format):
 
             ginput = torch.randn(out.shape, device=device, dtype=dtype).to(memory_format=memory_format)
             ginputf = ginput.to(torch.float32).to(memory_format=torch.contiguous_format)
-            out.backward(ginput)
             outf.backward(ginputf)
+            out.backward(ginput)
             self.assertEqual(input.grad.to(torch.float32), inputf.grad, atol=0.01, rtol=0.01)
 
         for device in ['cpu']:
