Improve precision and performance for BFloat16 upsampling (#91169)
### Description
- Fix the precision issue for BFloat16 upsampling: #89212. Interpolation scales and weights are now computed in float (`opmath_t`), and backward gradients are accumulated in a float buffer, rather than doing both directly in BFloat16; a standalone sketch of the accumulation issue follows this list.
- Improve performance for BFloat16 upsampling (e.g. by parallelizing the bicubic backward kernel and vectorizing the gradient accumulation into BFloat16).
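
A minimal, standalone sketch of the accumulation half of the issue (not part of this commit; BFloat16 is simulated here by truncating a float to its top 16 bits, while real conversions round to nearest even):

```cpp
// Why backward accumulation now happens in a float buffer: adding many small
// gradient contributions directly into a BFloat16 value stalls once the
// increment drops below one BFloat16 ULP, while accumulating in float and
// converting once at the end keeps the result close to the true sum.
#include <cstdint>
#include <cstdio>
#include <cstring>

// Simulate BFloat16 storage by keeping only the top 16 bits of a float
// (truncation; real conversions round to nearest even).
static float to_bf16(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFF0000u;
  std::memcpy(&x, &bits, sizeof(x));
  return x;
}

int main() {
  const int n = 100000;
  const float g = 1e-3f;   // one small gradient contribution; true sum is 100

  float bf_acc = 0.f;      // accumulate in (simulated) BFloat16
  float f_acc = 0.f;       // accumulate in float, convert once at the end
  for (int i = 0; i < n; ++i) {
    bf_acc = to_bf16(bf_acc + g);
    f_acc += g;
  }
  std::printf("BFloat16 accumulation:            %f\n", bf_acc);          // stalls far below 100
  std::printf("float accumulation, cast to bf16: %f\n", to_bf16(f_acc));  // close to 100
  return 0;
}
```
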
### Testing
Data type: BFloat16 (an illustrative timing-harness sketch follows the tables below)

- Single core

contiguous:
mode | scale_factor | shape | backward time before / ms | backward time after / ms
-- | -- | -- | -- | --
nearest | 2 | [10, 3, 200, 200] | 14.47 | 8.34
linear | 2 | [3, 200, 200] | 3.69 | 2.74
bilinear | 2 | [3, 5, 200, 200] | 87.99 | 49.05
trilinear | 2 | [3, 3, 3, 100, 100]  | 171.02 | 72.53
bicubic | 2 | [3, 3, 200, 200 ] | 176.29 | 78

channels last:
mode | scale_factor | shape | backward time before / ms | backward time after / ms
-- | -- | -- | -- | --
nearest | 2 | [10, 3, 200, 200] | 17.70 | 10.30
linear | 2 | [3, 200, 200] | \ | \
bilinear | 2 | [3, 5, 200, 200] | 50.90 | 18.83
trilinear | 2 | [3, 3, 3, 100, 100] | 121.56 | 42.60
bicubic | 2 | [3, 3, 200, 200 ] | 179.40 | 80

- 20 cores

contiguous:
mode | scale_factor | shape | backward time before / ms | backward time after / ms
-- | -- | -- | -- | --
nearest | 2 | [10, 3, 200, 200] | 1.17 | 1.01
linear | 2 | [3, 200, 200] | 0.41 | 0.26
bilinear | 2 | [3, 5, 200, 200] | 7.19 | 4.07
trilinear | 2 | [3, 3, 3, 100, 100]  | 21.32 | 9.33
bicubic | 2 | [3, 3, 200, 200 ] | 178.67 | 10

channels last:
mode | scale_factor | shape | backward time before / ms | backward time after / ms
-- | -- | -- | -- | --
nearest | 2 | [10, 3, 200, 200] |  2.25 | 1.55
linear | 2 | [3, 200, 200] | \ | \
bilinear | 2 | [3, 5, 200, 200] |  20.17 | 7.20
trilinear | 2 | [3, 3, 3, 100, 100] | 43.33 | 15.66
bicubic | 2 | [3, 3, 200, 200 ] | 176.76 | 10
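
For context, a minimal libtorch timing harness in the spirit of the tables above (an illustrative sketch, not the script that produced these numbers; it mirrors the first contiguous row: nearest mode, scale factor 2, shape [10, 3, 200, 200], single thread):

```cpp
// Illustrative benchmark sketch (assumption: not the actual measurement script).
// Times the backward pass of a BFloat16 nearest-neighbor upsample on CPU.
#include <torch/torch.h>
#include <ATen/Parallel.h>
#include <chrono>
#include <cstdio>

int main() {
  namespace F = torch::nn::functional;
  at::set_num_threads(1);  // single-core rows; use 20 for the multi-core rows

  auto x = torch::randn({10, 3, 200, 200}, torch::dtype(torch::kBFloat16));
  x.requires_grad_(true);
  auto opts = F::InterpolateFuncOptions()
                  .scale_factor(std::vector<double>{2.0, 2.0})
                  .mode(torch::kNearest);

  auto out = F::interpolate(x, opts);
  auto grad = torch::ones_like(out);
  out.backward(grad, /*retain_graph=*/true);  // warm-up

  const int iters = 100;
  auto t0 = std::chrono::steady_clock::now();
  for (int i = 0; i < iters; ++i) {
    out.backward(grad, /*retain_graph=*/true);
  }
  auto t1 = std::chrono::steady_clock::now();
  double ms = std::chrono::duration<double, std::milli>(t1 - t0).count() / iters;
  std::printf("nearest x2 BFloat16 backward: %.2f ms/iter\n", ms);
  return 0;
}
```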

Pull Request resolved: #91169
Approved by: https://github.com/jgong5, https://github.com/mingfeima, https://github.com/Skylion007
CaoE authored and pytorchmergebot committed May 29, 2023
1 parent 040d2cc commit af1d437
Showing 5 changed files with 426 additions and 186 deletions.
58 changes: 47 additions & 11 deletions aten/src/ATen/native/UpSample.h
@@ -5,6 +5,7 @@
#include <ATen/OpMathType.h>
#include <ATen/TensorUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/DispatchStub.h>

/**
@@ -427,13 +428,25 @@ static inline scalar_t cubic_interp1d(
return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
}

template<typename scalar_t>
// When `real_input_index` exceeds the range the floating-point type can
// represent accurately, the cast to `int64_t` may exceed `input_size - 1`,
// producing an out-of-bounds index. So we guard it with `std::min` below.
template<typename scalar_t, typename opmath_t>
static inline void guard_index_and_lambda(const opmath_t& real_input_index, const int64_t& input_size, int64_t& input_index, scalar_t& lambda) {
input_index = std::min(static_cast<int64_t>(floorf(real_input_index)), input_size - 1);
lambda = std::min(
std::max(real_input_index - input_index, static_cast<opmath_t>(0)),
static_cast<opmath_t>(1)
);
}

template<typename scalar_t, typename opmath_t>
static inline void compute_source_index_and_lambda(
int64_t& input_index0,
int64_t& input_index1,
scalar_t& lambda0,
scalar_t& lambda1,
scalar_t ratio,
opmath_t ratio,
int64_t output_index,
int64_t input_size,
int64_t output_size,
@@ -445,23 +458,46 @@ static inline void compute_source_index_and_lambda(
lambda0 = static_cast<scalar_t>(1);
lambda1 = static_cast<scalar_t>(0);
} else {
using opmath_t = at::opmath_type<scalar_t>;
const auto real_input_index =
area_pixel_compute_source_index<opmath_t>(
ratio, output_index, align_corners, /*cubic=*/false);
// when `real_input_index` becomes larger than the range the floating point
// type can accurately represent, the type casting to `int64_t` might exceed
// `input_size - 1`, causing overflow. So we guard it with `std::min` below.
input_index0 = std::min(static_cast<int64_t>(real_input_index), input_size - 1);
guard_index_and_lambda(real_input_index, input_size, input_index0, lambda1);
int64_t offset = (input_index0 < input_size - 1) ? 1 : 0;
input_index1 = input_index0 + offset;
lambda1 = std::min(
std::max(real_input_index - input_index0, static_cast<opmath_t>(0)),
static_cast<opmath_t>(1)
);
lambda0 = static_cast<scalar_t>(1.) - lambda1;
}
}

// Only used for BFloat16: this unspecialized template merely validates the types; the BFloat16/float specialization below does the real work.
template <typename scalar_in, typename scalar_out>
void inline apply_grad_input(scalar_in* buffer_ptr, scalar_out* gin, int64_t size) {
TORCH_CHECK((std::is_same<scalar_out, BFloat16>::value),
"Upsample backward only supports BFloat16 among the lower precision data types on CPU.")
TORCH_CHECK((std::is_same<scalar_in, float>::value),
"Upsample backward should use float as acc buffer for BFloat16 grad input on CPU.")
return;
}

template <>
void inline apply_grad_input(float* buffer_ptr, BFloat16* gin, int64_t size) {
using bVec = vec::Vectorized<BFloat16>;
using fVec = vec::Vectorized<float>;
int64_t d = 0;
for (; d < size - (size % bVec::size()); d += bVec::size()) {
bVec gin_bvec = bVec::loadu(gin + d);
fVec gin_fvec0, gin_fvec1;
std::tie(gin_fvec0, gin_fvec1) = convert_bfloat16_float(gin_bvec);
gin_fvec0 += fVec::loadu(buffer_ptr + d);
gin_fvec1 += fVec::loadu(buffer_ptr + d + fVec::size());
fVec(0).store(buffer_ptr + d);
fVec(0).store(buffer_ptr + d + fVec::size());
convert_float_bfloat16(gin_fvec0, gin_fvec1).store(gin + d);
}
for (; d < size; d++) {
gin[d] += buffer_ptr[d];
buffer_ptr[d] = 0;
}
}

} // namespace native
} // namespace at
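
The other half of the precision fix is visible in the signatures above: scales, source indices, and interpolation weights are now carried in `opmath_t` (float for BFloat16 inputs) rather than in the tensor's own type. A standalone illustration of why (not PyTorch code; BFloat16 is simulated by truncating a float to its top 16 bits):

```cpp
// BFloat16 has an 8-bit significand, so integers above 256 and most fractional
// source indices are rounded coarsely; computing interpolation indices and
// weights in BFloat16 therefore lands on the wrong input element for larger
// images. Carrying them in float (opmath_t) avoids this.
#include <cstdint>
#include <cstdio>
#include <cstring>

static float to_bf16(float x) {  // keep only the top 16 bits of the float
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFF0000u;
  std::memcpy(&x, &bits, sizeof(x));
  return x;
}

int main() {
  const float real_index = 513.25f;  // a source index late in a large row
  std::printf("float: %.2f   bf16: %.2f\n", real_index, to_bf16(real_index));
  // prints: float: 513.25   bf16: 512.00
  return 0;
}
```
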
115 changes: 58 additions & 57 deletions aten/src/ATen/native/UpSampleBicubic2d.cpp
@@ -4,6 +4,7 @@
#include <ATen/TensorMeta.h>
#include <ATen/native/UpSample.h>
#include <c10/util/irange.h>
#include <ATen/Parallel.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -118,69 +119,65 @@ static void upsample_bicubic2d_backward_out_frame(
c10::optional<double> scales_h,
c10::optional<double> scales_w) {
channels = channels * nbatch;
auto input_slice_size = input_height * input_width;
auto output_slice_size = output_height * output_width;

// Special case: input/output same size, just copy
if (input_height == output_height && input_width == output_width) {
for (const auto output_y : c10::irange(output_height)) {
for (const auto output_x : c10::irange(output_width)) {
scalar_t* in = &idata[output_y * input_width + output_x];
scalar_t* out = &odata[output_y * output_width + output_x];
for (const auto c C10_UNUSED : c10::irange(channels)) {
in[0] = out[0];
in += input_width * input_height;
out += output_width * output_height;
}
}
}
return;
}

const scalar_t height_scale = area_pixel_compute_scale<scalar_t>(
using opmath_t = at::opmath_type<scalar_t>;
const opmath_t height_scale = area_pixel_compute_scale<opmath_t>(
input_height, output_height, align_corners, scales_h);
const scalar_t width_scale = area_pixel_compute_scale<scalar_t>(
const opmath_t width_scale = area_pixel_compute_scale<opmath_t>(
input_width, output_width, align_corners, scales_w);

for (const auto output_y : c10::irange(output_height)) {
for (const auto output_x : c10::irange(output_width)) {
scalar_t* in = idata;
scalar_t* out = odata;

const scalar_t real_x = area_pixel_compute_source_index(width_scale, output_x, align_corners, /*cubic=*/true);
int64_t input_x = floorf(real_x);
scalar_t t_x = real_x - input_x;

const scalar_t real_y = area_pixel_compute_source_index(height_scale, output_y, align_corners, /*cubic=*/true);
int64_t input_y = floorf(real_y);
scalar_t t_y = real_y - input_y;

// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
scalar_t x_coeffs[4];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
scalar_t y_coeffs[4];

get_cubic_upsample_coefficients<scalar_t>(x_coeffs, t_x);
get_cubic_upsample_coefficients<scalar_t>(y_coeffs, t_y);

for (const auto c C10_UNUSED : c10::irange(channels)) {
scalar_t out_value = out[output_y * output_width + output_x];

for (const auto i : c10::irange(4)) {
for (const auto j : c10::irange(4)) {
upsample_increment_value_bounded<scalar_t>(
in,
input_width,
input_height,
input_x - 1 + i,
input_y - 1 + j,
out_value * y_coeffs[j] * x_coeffs[i]);
at::parallel_for(0, channels, at::internal::GRAIN_SIZE / output_slice_size / 4, [&](int64_t start, int64_t end) {
opmath_t* acc_data_ptr = nullptr;
std::unique_ptr<opmath_t[]> buffer_data;
if constexpr (!std::is_same<scalar_t, opmath_t>::value) {
buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
acc_data_ptr = buffer_data.get();
memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
}
for (const auto i : c10::irange(start, end)) {
scalar_t* in = idata + i * input_slice_size;
scalar_t* out = odata + i * output_slice_size;
for (const auto output_y : c10::irange(output_height)) {
for (const auto output_x : c10::irange(output_width)) {

const opmath_t real_x = area_pixel_compute_source_index(width_scale, output_x, align_corners, /*cubic=*/true);
int64_t input_x;
opmath_t t_x;
guard_index_and_lambda(real_x, input_width, input_x, t_x);

const opmath_t real_y = area_pixel_compute_source_index(height_scale, output_y, align_corners, /*cubic=*/true);
int64_t input_y;
opmath_t t_y;
guard_index_and_lambda(real_y, input_height, input_y, t_y);

// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
opmath_t x_coeffs[4];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
opmath_t y_coeffs[4];

get_cubic_upsample_coefficients<opmath_t>(x_coeffs, t_x);
get_cubic_upsample_coefficients<opmath_t>(y_coeffs, t_y);

opmath_t out_value = out[output_y * output_width + output_x];
for (const auto ii : c10::irange(4)) {
for (const auto jj : c10::irange(4)) {
upsample_increment_value_bounded<opmath_t>(
acc_data_ptr == nullptr ? reinterpret_cast<opmath_t*>(in) : acc_data_ptr,
input_width,
input_height,
input_x - 1 + ii,
input_y - 1 + jj,
out_value * y_coeffs[jj] * x_coeffs[ii]);
}
}
}

in += input_width * input_height;
out += output_width * output_height;
}
if (acc_data_ptr != nullptr) {
apply_grad_input(acc_data_ptr, in, input_slice_size);
}
}
}
});
}

static void upsample_bicubic2d_backward_kernel(
@@ -201,7 +198,11 @@ static void upsample_bicubic2d_backward_kernel(
int64_t input_width = input_size[3];

auto grad_output = grad_output_.contiguous();

// Special case: input/output same size, just copy
if (input_height == output_height && input_width == output_width) {
grad_input.copy_(grad_output);
return;
}
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16,
grad_output.scalar_type(), "upsample_bicubic2d_backward", [&] {
scalar_t* idata = grad_input.data_ptr<scalar_t>();
20 changes: 11 additions & 9 deletions aten/src/ATen/native/cpu/UpSampleKernel.cpp
@@ -1010,7 +1010,8 @@ struct HelperInterpNearest : public HelperInterpBase {

AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, scalar_type, "compute_indices_weights_nearest", [&] {
scalar_t scale = area_pixel_compute_scale<scalar_t>(input_size, output_size, align_corners, opt_scale);
using opmath_t = at::opmath_type<scalar_t>;
opmath_t scale = area_pixel_compute_scale<opmath_t>(input_size, output_size, align_corners, opt_scale);

auto input_index_ptr = output[0].data_ptr<int64_t>();
int64_t input_index;
@@ -1020,7 +1021,6 @@ struct HelperInterpNearest : public HelperInterpBase {
// index_f32 = (output_index) * scale
// input_index = floor(index_f32)
// Same as OpenCV INTER_NEAREST
using opmath_t = at::opmath_type<scalar_t>;
for (const auto i : c10::irange(output_size)) {
const auto real_input_index =
area_pixel_compute_source_index<opmath_t>(
@@ -1110,7 +1110,8 @@ struct HelperInterpLinear : public HelperInterpBase {
scalar_type, output, output_size, ndims, reshape_dim, HelperInterpLinear::interp_size);
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, scalar_type, "compute_indices_weights_linear", [&] {
scalar_t scale = area_pixel_compute_scale<scalar_t>(input_size, output_size, align_corners, opt_scale);
using opmath_t = at::opmath_type<scalar_t>;
opmath_t scale = area_pixel_compute_scale<opmath_t>(input_size, output_size, align_corners, opt_scale);

auto input_index0_ptr = output[0].data_ptr<int64_t>();
auto lambda0_ptr = output[1].data_ptr<scalar_t>();
Expand All @@ -1119,7 +1120,7 @@ struct HelperInterpLinear : public HelperInterpBase {

for (const auto i : c10::irange(output_size)) {

compute_source_index_and_lambda<scalar_t>(
compute_source_index_and_lambda<scalar_t, opmath_t>(
input_index0_ptr[i], input_index1_ptr[i],
lambda0_ptr[i], lambda1_ptr[i],
scale, i, input_size, output_size, align_corners
@@ -1234,22 +1235,23 @@ struct HelperInterpCubic : public HelperInterpBase {

AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, scalar_type, "compute_indices_weights_cubic", [&] {
scalar_t scale = area_pixel_compute_scale<scalar_t>(input_size, output_size, align_corners, opt_scale);
using opmath_t = at::opmath_type<scalar_t>;
opmath_t scale = area_pixel_compute_scale<opmath_t>(input_size, output_size, align_corners, opt_scale);

int64_t input_index;
int64_t zero = static_cast<int64_t>(0);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
scalar_t coeffs[4];
opmath_t coeffs[4];

int64_t * idx_ptr;
scalar_t * wt_ptr;
using opmath_t = at::opmath_type<scalar_t>;
for (const auto i : c10::irange(output_size)) {
const auto real_input_index =
area_pixel_compute_source_index<opmath_t>(
scale, i, align_corners, /*cubic=*/true);
input_index = static_cast<int64_t>(floorf(real_input_index));
get_cubic_upsample_coefficients<scalar_t>(coeffs, real_input_index - input_index);
opmath_t lambda;
guard_index_and_lambda(real_input_index, input_size, input_index, lambda);
get_cubic_upsample_coefficients<opmath_t>(coeffs, lambda);

for (const auto j : c10::irange(interp_size)) {
idx_ptr = output[2 * j + 0].data_ptr<int64_t>();
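
Finally, a standalone sketch (not PyTorch code) that replicates the clamped source-index/weight computation the linear and cubic helpers above rely on, for a 1-D upsample from 4 to 8 elements with align_corners=false; the last output element shows the clamp at work, with both source indices collapsing to the final input element:

```cpp
// Mirrors guard_index_and_lambda + compute_source_index_and_lambda for a
// 4 -> 8 linear upsample (ratio = input/output = 0.5, align_corners=false).
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const std::int64_t input_size = 4, output_size = 8;
  const float ratio = static_cast<float>(input_size) / output_size;  // 0.5

  for (std::int64_t i = 0; i < output_size; ++i) {
    // align_corners=false source index, clamped below at 0 (non-cubic path).
    float real = ratio * (static_cast<float>(i) + 0.5f) - 0.5f;
    if (real < 0.f) real = 0.f;

    // Clamp the index to the last element and the weight to [0, 1], so a
    // source index that rounds past the end can never read out of bounds.
    std::int64_t idx0 = std::min(
        static_cast<std::int64_t>(std::floor(real)), input_size - 1);
    float lambda1 = std::min(std::max(real - static_cast<float>(idx0), 0.f), 1.f);
    std::int64_t idx1 = idx0 + ((idx0 < input_size - 1) ? 1 : 0);

    std::printf("out %lld -> in (%lld, %lld), weights (%.2f, %.2f)\n",
                static_cast<long long>(i), static_cast<long long>(idx0),
                static_cast<long long>(idx1), 1.f - lambda1, lambda1);
  }
  return 0;
}
```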
