Improve precision and performance for BFloat16 upsampling #91169

Closed · wants to merge 6 commits
58 changes: 47 additions & 11 deletions aten/src/ATen/native/UpSample.h
@@ -5,6 +5,7 @@
#include <ATen/OpMathType.h>
#include <ATen/TensorUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/DispatchStub.h>

/**
@@ -427,13 +428,25 @@ static inline scalar_t cubic_interp1d(
return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
}

template<typename scalar_t>
// when `real_input_index` becomes larger than the range the floating point
// type can accurately represent, the type casting to `int64_t` might exceed
// `input_size`, causing overflow. So we guard it with `std::min` below.
template<typename scalar_t, typename opmath_t>
static inline void guard_index_and_lambda(const opmath_t& real_input_index, const int64_t& input_size, int64_t& input_index, scalar_t& lambda) {
input_index = std::min(static_cast<int64_t>(floorf(real_input_index)), input_size - 1);
lambda = std::min(
std::max(real_input_index - input_index, static_cast<opmath_t>(0)),
static_cast<opmath_t>(1)
);
}
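
To see why guard_index_and_lambda clamps both the index and the lambda, consider a minimal standalone sketch (not part of the patch; the sizes are hypothetical). With BFloat16 inputs the index math runs in float (its opmath type), and near 1e8 a float can only represent multiples of 8, so the source index computed for the last output element can round up past the end of the input:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t input_size = 99999999, output_size = 100000000;
  // area_pixel_compute_scale with align_corners == false and no explicit scale:
  const float scale = static_cast<float>(input_size) / output_size;  // rounds to exactly 1.0f
  const int64_t i = output_size - 1;                                 // last output index
  // align_corners == false source index: (i + 0.5) * scale - 0.5
  const float real_input_index = (static_cast<float>(i) + 0.5f) * scale - 0.5f;  // == 1.0e8f

  const int64_t unguarded = static_cast<int64_t>(floorf(real_input_index));
  const int64_t guarded   = std::min(unguarded, input_size - 1);
  std::cout << unguarded << "\n";  // 100000000: past the last valid index (99999998)
  std::cout << guarded   << "\n";  // 99999998: clamped, safe to read
  return 0;
}

The same rounding pushes real_input_index - input_index outside [0, 1], which is why the helper clamps the lambda as well.
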

template<typename scalar_t, typename opmath_t>
static inline void compute_source_index_and_lambda(
int64_t& input_index0,
int64_t& input_index1,
scalar_t& lambda0,
scalar_t& lambda1,
scalar_t ratio,
opmath_t ratio,
int64_t output_index,
int64_t input_size,
int64_t output_size,
@@ -445,23 +458,46 @@ static inline void compute_source_index_and_lambda(
lambda0 = static_cast<scalar_t>(1);
lambda1 = static_cast<scalar_t>(0);
} else {
using opmath_t = at::opmath_type<scalar_t>;
const auto real_input_index =
area_pixel_compute_source_index<opmath_t>(
ratio, output_index, align_corners, /*cubic=*/false);
// when `real_input_index` becomes larger than the range the floating point
// type can accurately represent, the type casting to `int64_t` might exceed
// `input_size - 1`, causing overflow. So we guard it with `std::min` below.
input_index0 = std::min(static_cast<int64_t>(real_input_index), input_size - 1);
guard_index_and_lambda(real_input_index, input_size, input_index0, lambda1);
int64_t offset = (input_index0 < input_size - 1) ? 1 : 0;
input_index1 = input_index0 + offset;
lambda1 = std::min(
std::max(real_input_index - input_index0, static_cast<opmath_t>(0)),
static_cast<opmath_t>(1)
);
lambda0 = static_cast<scalar_t>(1.) - lambda1;
}
}

// Only the BFloat16 specialization below is ever used at runtime; this generic
// version exists solely to reject unsupported type combinations.
template <typename scalar_in, typename scalar_out>
void inline apply_grad_input(scalar_in* buffer_ptr, scalar_out* gin, int64_t size) {
TORCH_CHECK((std::is_same<scalar_out, BFloat16>::value),
"Upsample backward only support BFloat16 in the lower percision data types on CPU.")
TORCH_CHECK((std::is_same<scalar_in, float>::value),
"Upsample backward should use float as acc buffer for BFloat16 grad input on CPU.")
return;
}

template <>
void inline apply_grad_input(float* buffer_ptr, BFloat16* gin, int64_t size) {
using bVec = vec::Vectorized<BFloat16>;
using fVec = vec::Vectorized<float>;
int64_t d = 0;
for (; d < size - (size % bVec::size()); d += bVec::size()) {
bVec gin_bvec = bVec::loadu(gin + d);
fVec gin_fvec0, gin_fvec1;
std::tie(gin_fvec0, gin_fvec1) = convert_bfloat16_float(gin_bvec);
gin_fvec0 += fVec::loadu(buffer_ptr + d);
gin_fvec1 += fVec::loadu(buffer_ptr + d + fVec::size());
fVec(0).store(buffer_ptr + d);
fVec(0).store(buffer_ptr + d + fVec::size());
convert_float_bfloat16(gin_fvec0, gin_fvec1).store(gin + d);
}
for (; d < size; d++) {
gin[d] += buffer_ptr[d];
buffer_ptr[d] = 0;
}
}

} // namespace native
} // namespace at
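
A note on why apply_grad_input flushes a float buffer instead of accumulating into the BFloat16 gradient directly: BFloat16 carries only 8 bits of significand precision, so a long series of small additions stops making progress once the running sum dwarfs the increments. A standalone sketch of the effect (BFloat16 is approximated here by truncating the low 16 bits of a float; the real conversion rounds to nearest even):

#include <cstdint>
#include <cstring>
#include <iostream>

// Crude stand-in for BFloat16: keep only the top 16 bits of a float.
static float to_bf16(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFF0000u;
  std::memcpy(&x, &bits, sizeof(bits));
  return x;
}

int main() {
  const int n = 100000;
  const float grad = 0.001f;

  // Summing directly in BFloat16: once the running total reaches a few tenths,
  // each 0.001 increment falls below one BFloat16 step and is rounded away.
  float bf16_acc = 0.f;
  for (int i = 0; i < n; ++i) bf16_acc = to_bf16(bf16_acc + grad);

  // Summing in float and converting once at the end, which is what the
  // per-slice float buffer plus the apply_grad_input flush amounts to.
  float fp32_acc = 0.f;
  for (int i = 0; i < n; ++i) fp32_acc += grad;

  std::cout << bf16_acc << "\n";           // stalls far below the true value of 100
  std::cout << to_bf16(fp32_acc) << "\n";  // ~100
  return 0;
}

Accumulating each slice in float and converting once at the end keeps the gradient close to the true sum, which is the precision half of this change.
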
115 changes: 58 additions & 57 deletions aten/src/ATen/native/UpSampleBicubic2d.cpp
@@ -4,6 +4,7 @@
#include <ATen/TensorMeta.h>
#include <ATen/native/UpSample.h>
#include <c10/util/irange.h>
#include <ATen/Parallel.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -118,69 +119,65 @@ static void upsample_bicubic2d_backward_out_frame(
c10::optional<double> scales_h,
c10::optional<double> scales_w) {
channels = channels * nbatch;
auto input_slice_size = input_height * input_width;
auto output_slice_size = output_height * output_width;

// Special case: input/output same size, just copy
if (input_height == output_height && input_width == output_width) {
for (const auto output_y : c10::irange(output_height)) {
for (const auto output_x : c10::irange(output_width)) {
scalar_t* in = &idata[output_y * input_width + output_x];
scalar_t* out = &odata[output_y * output_width + output_x];
for (const auto c C10_UNUSED : c10::irange(channels)) {
in[0] = out[0];
in += input_width * input_height;
out += output_width * output_height;
}
}
}
return;
}

const scalar_t height_scale = area_pixel_compute_scale<scalar_t>(
using opmath_t = at::opmath_type<scalar_t>;
const opmath_t height_scale = area_pixel_compute_scale<opmath_t>(
input_height, output_height, align_corners, scales_h);
const scalar_t width_scale = area_pixel_compute_scale<scalar_t>(
const opmath_t width_scale = area_pixel_compute_scale<opmath_t>(
input_width, output_width, align_corners, scales_w);

for (const auto output_y : c10::irange(output_height)) {
for (const auto output_x : c10::irange(output_width)) {
scalar_t* in = idata;
scalar_t* out = odata;

const scalar_t real_x = area_pixel_compute_source_index(width_scale, output_x, align_corners, /*cubic=*/true);
int64_t input_x = floorf(real_x);
scalar_t t_x = real_x - input_x;

const scalar_t real_y = area_pixel_compute_source_index(height_scale, output_y, align_corners, /*cubic=*/true);
int64_t input_y = floorf(real_y);
scalar_t t_y = real_y - input_y;

// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
scalar_t x_coeffs[4];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
scalar_t y_coeffs[4];

get_cubic_upsample_coefficients<scalar_t>(x_coeffs, t_x);
get_cubic_upsample_coefficients<scalar_t>(y_coeffs, t_y);

for (const auto c C10_UNUSED : c10::irange(channels)) {
scalar_t out_value = out[output_y * output_width + output_x];

for (const auto i : c10::irange(4)) {
for (const auto j : c10::irange(4)) {
upsample_increment_value_bounded<scalar_t>(
in,
input_width,
input_height,
input_x - 1 + i,
input_y - 1 + j,
out_value * y_coeffs[j] * x_coeffs[i]);
at::parallel_for(0, channels, at::internal::GRAIN_SIZE / output_slice_size / 4, [&](int64_t start, int64_t end) {
opmath_t* acc_data_ptr = nullptr;
std::unique_ptr<opmath_t[]> buffer_data;
if constexpr (!std::is_same<scalar_t, opmath_t>::value) {
buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
acc_data_ptr = buffer_data.get();
memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
}
for (const auto i : c10::irange(start, end)) {
scalar_t* in = idata + i * input_slice_size;
scalar_t* out = odata + i * output_slice_size;
for (const auto output_y : c10::irange(output_height)) {
for (const auto output_x : c10::irange(output_width)) {

const opmath_t real_x = area_pixel_compute_source_index(width_scale, output_x, align_corners, /*cubic=*/true);
int64_t input_x;
opmath_t t_x;
guard_index_and_lambda(real_x, input_width, input_x, t_x);

const opmath_t real_y = area_pixel_compute_source_index(height_scale, output_y, align_corners, /*cubic=*/true);
int64_t input_y;
opmath_t t_y;
guard_index_and_lambda(real_y, input_height, input_y, t_y);

// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
opmath_t x_coeffs[4];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
opmath_t y_coeffs[4];

get_cubic_upsample_coefficients<opmath_t>(x_coeffs, t_x);
get_cubic_upsample_coefficients<opmath_t>(y_coeffs, t_y);

opmath_t out_value = out[output_y * output_width + output_x];
for (const auto ii : c10::irange(4)) {
for (const auto jj : c10::irange(4)) {
upsample_increment_value_bounded<opmath_t>(
acc_data_ptr == nullptr ? reinterpret_cast<opmath_t*>(in) : acc_data_ptr,
input_width,
input_height,
input_x - 1 + ii,
input_y - 1 + jj,
out_value * y_coeffs[jj] * x_coeffs[ii]);
}
}
}

in += input_width * input_height;
out += output_width * output_height;
}
if (acc_data_ptr != nullptr) {
apply_grad_input(acc_data_ptr, in, input_slice_size);
}
}
}
});
}

static void upsample_bicubic2d_backward_kernel(
@@ -201,7 +198,11 @@ static void upsample_bicubic2d_backward_kernel(
int64_t input_width = input_size[3];

auto grad_output = grad_output_.contiguous();

// Special case: input/output same size, just copy
if (input_height == output_height && input_width == output_width) {
grad_input.copy_(grad_output);
return;
}
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16,
grad_output.scalar_type(), "upsample_bicubic2d_backward", [&] {
scalar_t* idata = grad_input.data_ptr<scalar_t>();
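
For context, the backward frame above parallelizes over the flattened nbatch * channels dimension, and each worker that handles reduced-precision data gets its own float accumulation buffer. A minimal sketch of that pattern, assuming the ATen headers are available (the helper name backward_over_channels is made up; the grain-size expression mirrors the one in the patch):

#include <ATen/Parallel.h>
#include <cstdint>
#include <cstring>
#include <memory>

// Split [0, channels) across threads; each range gets a private float buffer
// so partial sums never round through BFloat16.
void backward_over_channels(int64_t channels, int64_t input_slice_size,
                            int64_t output_slice_size) {
  at::parallel_for(0, channels, at::internal::GRAIN_SIZE / output_slice_size / 4,
      [&](int64_t start, int64_t end) {
        auto buffer = std::make_unique<float[]>(input_slice_size);
        std::memset(buffer.get(), 0, sizeof(float) * input_slice_size);
        for (int64_t c = start; c < end; ++c) {
          // ... accumulate slice c's gradients into buffer.get(), then flush
          // them into the BFloat16 grad input (cf. apply_grad_input above).
        }
      });
}

Dividing GRAIN_SIZE by the per-slice output work keeps the amount of arithmetic per task roughly constant across image sizes.
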
20 changes: 11 additions & 9 deletions aten/src/ATen/native/cpu/UpSampleKernel.cpp
@@ -1010,7 +1010,8 @@ struct HelperInterpNearest : public HelperInterpBase {

AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, scalar_type, "compute_indices_weights_nearest", [&] {
scalar_t scale = area_pixel_compute_scale<scalar_t>(input_size, output_size, align_corners, opt_scale);
using opmath_t = at::opmath_type<scalar_t>;
opmath_t scale = area_pixel_compute_scale<opmath_t>(input_size, output_size, align_corners, opt_scale);

auto input_index_ptr = output[0].data_ptr<int64_t>();
int64_t input_index;
@@ -1020,7 +1021,6 @@
// index_f32 = (output_index) * scale
// input_index = floor(index_f32)
// Same as OpenCV INTER_NEAREST
using opmath_t = at::opmath_type<scalar_t>;
for (const auto i : c10::irange(output_size)) {
const auto real_input_index =
area_pixel_compute_source_index<opmath_t>(
@@ -1110,7 +1110,8 @@ struct HelperInterpLinear : public HelperInterpBase {
scalar_type, output, output_size, ndims, reshape_dim, HelperInterpLinear::interp_size);
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, scalar_type, "compute_indices_weights_linear", [&] {
scalar_t scale = area_pixel_compute_scale<scalar_t>(input_size, output_size, align_corners, opt_scale);
using opmath_t = at::opmath_type<scalar_t>;
opmath_t scale = area_pixel_compute_scale<opmath_t>(input_size, output_size, align_corners, opt_scale);

auto input_index0_ptr = output[0].data_ptr<int64_t>();
auto lambda0_ptr = output[1].data_ptr<scalar_t>();
Expand All @@ -1119,7 +1120,7 @@ struct HelperInterpLinear : public HelperInterpBase {

for (const auto i : c10::irange(output_size)) {

compute_source_index_and_lambda<scalar_t>(
compute_source_index_and_lambda<scalar_t, opmath_t>(
input_index0_ptr[i], input_index1_ptr[i],
lambda0_ptr[i], lambda1_ptr[i],
scale, i, input_size, output_size, align_corners
@@ -1234,22 +1235,23 @@ struct HelperInterpCubic : public HelperInterpBase {

AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, scalar_type, "compute_indices_weights_cubic", [&] {
scalar_t scale = area_pixel_compute_scale<scalar_t>(input_size, output_size, align_corners, opt_scale);
using opmath_t = at::opmath_type<scalar_t>;
opmath_t scale = area_pixel_compute_scale<opmath_t>(input_size, output_size, align_corners, opt_scale);

int64_t input_index;
int64_t zero = static_cast<int64_t>(0);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
scalar_t coeffs[4];
opmath_t coeffs[4];

int64_t * idx_ptr;
scalar_t * wt_ptr;
using opmath_t = at::opmath_type<scalar_t>;
for (const auto i : c10::irange(output_size)) {
const auto real_input_index =
area_pixel_compute_source_index<opmath_t>(
scale, i, align_corners, /*cubic=*/true);
input_index = static_cast<int64_t>(floorf(real_input_index));
get_cubic_upsample_coefficients<scalar_t>(coeffs, real_input_index - input_index);
opmath_t lambda;
guard_index_and_lambda(real_input_index, input_size, input_index, lambda);
get_cubic_upsample_coefficients<opmath_t>(coeffs, lambda);

for (const auto j : c10::irange(interp_size)) {
idx_ptr = output[2 * j + 0].data_ptr<int64_t>();
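
The common thread in the UpSampleKernel.cpp changes is that the scale and the intermediate index/weight math now run in opmath_t (float when scalar_t is BFloat16) instead of scalar_t. A standalone sketch of what goes wrong when the scale itself is rounded to BFloat16, using the nearest-neighbor rule quoted above (input_index = floor(output_index * scale); the sizes are made up):

#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t input_size = 5000, output_size = 3000;

  const float scale_f32  = static_cast<float>(input_size) / output_size;  // ~1.6666666
  const float scale_bf16 = 1.6640625f;  // nearest BFloat16 to 5/3 (8-bit significand)

  const int64_t i = output_size - 1;  // last output index
  std::cout << static_cast<int64_t>(floorf(i * scale_f32))  << "\n";  // 4998
  std::cout << static_cast<int64_t>(floorf(i * scale_bf16)) << "\n";  // 4990: eight rows off
  return 0;
}

With a BFloat16 scale the last output rows sample from source rows several entries too low; computing the scale and source index in float and then passing them through guard_index_and_lambda removes the drift and keeps the cast index in bounds.
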