[fbsync] [2/2] Added backward pass on CUDA for interpolation with anti-alias option (#4211)

Summary:
* WIP on backward op interpolation with AA

* Removed CUDA tests and reformatted C++ code

* Fixed clang wrong formatting

* Added channels last test case

* Added CUDA support for backward pass, interpolation with AA

* Removed unused buffers

Reviewed By: NicolasHug

Differential Revision: D30417194

fbshipit-source-id: 4aab5bc21621859cfc4254da6a230e0c8a8cffc2

Co-authored-by: vfdev-5 <vfdev-5@gmail.com>
2 people authored and facebook-github-bot committed Aug 19, 2021
1 parent 7994bb8 commit be4db52
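
For context (not part of the diff below): a minimal sketch of what this change enables, namely backpropagation through antialiased resize on a CUDA tensor via the public F.resize entry point exercised by the test; the CUDA device and the exact tensor shapes are illustrative assumptions.

import torch
from torchvision.transforms import functional as F
from torchvision.transforms import InterpolationMode

# Sketch only: antialiased resize should now support backward on CUDA as well as CPU.
x = torch.rand(1, 3, 32, 32, device="cuda", requires_grad=True)
y = F.resize(x, size=[10, 7], interpolation=InterpolationMode.BILINEAR, antialias=True)
y.sum().backward()
print(x.grad.shape)  # expected: torch.Size([1, 3, 32, 32])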
Showing 2 changed files with 270 additions and 45 deletions.
6 changes: 2 additions & 4 deletions test/test_functional_tensor.py
@@ -579,13 +579,11 @@ def test_assert_resize_antialias(interpolation):
F.resize(tensor, size=(5, 5), interpolation=interpolation, antialias=True)


@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dt', [torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize('size', [[10, 7], [10, 42], [42, 7]])
@pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC])
def test_interpolate_antialias_backward(dt, size, interpolation):

# temporarily hard-code device as CPU, CUDA support will be done later
device = "cpu"
def test_interpolate_antialias_backward(device, dt, size, interpolation):

if dt == torch.float16 and device == "cpu":
# skip float16 on CPU case
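The hunk cuts off after the first lines of the test body; what follows is a rough, hypothetical sketch of what a gradient check over devices and dtypes could look like (the real test body may differ, and the helper name, input shape, and tolerances are assumptions):

import torch
from torchvision.transforms import functional as F

def check_interpolate_antialias_backward(device, dt, size, interpolation):
    # Hypothetical sketch; the actual test body is not shown in the hunk above.
    if dt == torch.float16 and device == "cpu":
        return  # float16 resize is skipped on CPU, as in the test
    x = torch.rand(1, 3, 32, 29, dtype=dt, device=device, requires_grad=True)
    out = F.resize(x, size=size, interpolation=interpolation, antialias=True)
    out.sum().backward()
    assert x.grad is not None and x.grad.shape == x.shape
    if dt == torch.float64:
        # A numerical gradient check is only meaningful in double precision.
        def fn(t):
            return F.resize(t, size=size, interpolation=interpolation, antialias=True)
        torch.autograd.gradcheck(fn, (x,), eps=1e-6, atol=1e-4)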
309 changes: 268 additions & 41 deletions torchvision/csrc/ops/cuda/interpolate_aa_kernels.cu
@@ -165,49 +165,32 @@ __global__ void upsample_gen2d_out_frame(
// Compute weights
int xmin, xsize, ymin, ysize;
typedef scalar_t (*filter_fn_t)(scalar_t);
filter_fn_t filter_fn;
if (interp_size == 2) {
_compute_weights<scalar_t, accscalar_t, filter_fn_t>(
w2,
width1,
rwidth,
support_w,
wx,
interp_width,
bilinear_filter,
xmin,
xsize);
_compute_weights<scalar_t, accscalar_t, filter_fn_t>(
h2,
height1,
rheight,
support_h,
wy,
interp_height,
bilinear_filter,
ymin,
ysize);
filter_fn = bilinear_filter;
} else if (interp_size == 4) {
_compute_weights<scalar_t, accscalar_t, filter_fn_t>(
w2,
width1,
rwidth,
support_w,
wx,
interp_width,
bicubic_filter,
xmin,
xsize);
_compute_weights<scalar_t, accscalar_t, filter_fn_t>(
h2,
height1,
rheight,
support_h,
wy,
interp_height,
bicubic_filter,
ymin,
ysize);
filter_fn = bicubic_filter;
}
_compute_weights<scalar_t, accscalar_t, filter_fn_t>(
w2,
width1,
rwidth,
support_w,
wx,
interp_width,
filter_fn,
xmin,
xsize);
_compute_weights<scalar_t, accscalar_t, filter_fn_t>(
h2,
height1,
rheight,
support_h,
wy,
interp_height,
filter_fn,
ymin,
ysize);

for (int n = 0; n < batchsize; n++) {
for (int c = 0; c < channels; ++c) {
@@ -239,6 +222,8 @@ static void upsample_gen2d_out_cuda_template(
bool align_corners,
c10::optional<double> scales_h,
c10::optional<double> scales_w) {
// Copied and adapted from
// UpSampleBicubic2d.cu::upsample_bicubic2d_out_cuda_template
TensorArg input_arg{input, "input", 1}, output_arg{output, "output", 2};
checkAllSameGPU("upsample_gen2d_out_cuda", {input_arg, output_arg});

@@ -256,7 +241,7 @@ static void upsample_gen2d_out_cuda_template(
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "upsample_bilinear2d_out_frame", [&] {
input.scalar_type(), "upsample_gen2d_out_frame", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;

auto idata = input.packed_accessor64<scalar_t, 4>();
@@ -287,6 +272,174 @@ static void upsample_gen2d_out_cuda_template(
});
}

// Backward (adjoint) operation 1 <- 2 (accumulates)
template <typename scalar_t, typename accscalar_t, int interp_size>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_gen2d_backward_out_frame(
const int num_elements,
const accscalar_t height_scale,
const accscalar_t width_scale,
const bool align_corners,
PackedTensorAccessor64<scalar_t, 4> idata,
const PackedTensorAccessor64<scalar_t, 4> odata) {
int index = threadIdx.x + blockIdx.x * blockDim.x;

const int batchsize = idata.size(0);
const int channels = idata.size(1);
const int input_height = idata.size(2);
const int input_width = idata.size(3);
const int output_height = odata.size(2);
const int output_width = odata.size(3);

if (index >= num_elements) {
return;
}

const int output_x = index % output_width;
const int output_y = index / output_width;
// special case: output just copy
if (input_height == output_height && input_width == output_width) {
for (int n = 0; n < batchsize; n++) {
for (int c = 0; c < channels; ++c) {
const scalar_t val = odata[n][c][output_y][output_x];
idata[n][c][output_y][output_x] = val;
}
}
return;
}

const accscalar_t support_h = static_cast<accscalar_t>(
(height_scale >= 1.0) ? (interp_size * 0.5) * height_scale
: interp_size * 0.5);
const accscalar_t support_w = static_cast<accscalar_t>(
(width_scale >= 1.0) ? (interp_size * 0.5) * width_scale
: interp_size * 0.5);

const int interp_height = (int)ceilf(support_h) * 2 + 1;
const int interp_width = (int)ceilf(support_w) * 2 + 1;

// Setup local buffers
// TODO: maybe we can specify dynamic shared memory size before calling the
// cuda code, however we should then ensure that device has enough shared
// memory
scalar_t wx[256];
scalar_t wy[256];

// Compute weights
int xmin, xsize, ymin, ysize;
typedef scalar_t (*filter_fn_t)(scalar_t);
filter_fn_t filter_fn;
if (interp_size == 2) {
filter_fn = bilinear_filter;
} else if (interp_size == 4) {
filter_fn = bicubic_filter;
}
_compute_weights<scalar_t, accscalar_t, filter_fn_t>(
output_x,
input_width,
width_scale,
support_w,
wx,
interp_width,
filter_fn,
xmin,
xsize);
_compute_weights<scalar_t, accscalar_t, filter_fn_t>(
output_y,
input_height,
height_scale,
support_h,
wy,
interp_height,
filter_fn,
ymin,
ysize);

for (int n = 0; n < batchsize; n++) {
for (int c = 0; c < channels; ++c) {
scalar_t out_value = odata[n][c][output_y][output_x];
for (int y = 0; y < ysize; y++) {
for (int x = 0; x < xsize; x++) {
upsample_increment_value_bounded<scalar_t, accscalar_t>(
idata,
n,
c,
input_height,
input_width,
ymin + y,
xmin + x,
wx[x] * wy[y] * out_value);
}
}
}
}
}
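
For intuition, the support/window arithmetic at the top of this kernel decides how many filter taps each output pixel scatters back into the gradient. A small standalone Python illustration of those formulas (not part of the commit):

import math

def interp_window(interp_size, scale):
    # Mirrors the kernel: the filter support widens with the downscaling factor.
    support = interp_size * 0.5 * scale if scale >= 1.0 else interp_size * 0.5
    return 2 * math.ceil(support) + 1

print(interp_window(2, 4.0))  # bilinear, 4x downscale -> 9 weights per axis
print(interp_window(4, 4.0))  # bicubic,  4x downscale -> 17 weights per axis
print(interp_window(2, 0.5))  # bilinear, 2x upscale   -> 3 weights per axis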

template <int interp_size>
static void upsample_gen2d_backward_out_cuda_template(
const Tensor& grad_input,
const Tensor& grad_output_,
IntArrayRef output_size,
IntArrayRef input_size,
bool align_corners,
c10::optional<double> scales_h,
c10::optional<double> scales_w) {
// Copied and adapted from
// UpSampleBicubic2d.cu::upsample_bicubic2d_backward_out_cuda_template
TensorArg grad_input_arg{grad_input, "grad_input", 1},
grad_output_arg{grad_output_, "grad_output_", 2};
checkAllSameGPU(
"upsample_gen2d_backward_out_cuda", {grad_output_arg, grad_input_arg});

int output_height = output_size[0];
int output_width = output_size[1];

int nbatch = input_size[0];
int channels = input_size[1];
int input_height = input_size[2];
int input_width = input_size[3];

Tensor grad_output = grad_output_.contiguous();

grad_input.zero_();

const int num_kernels = output_height * output_width;
const int num_threads = std::min(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_output.scalar_type(), "upsample_gen2d_backward_out_frame", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;

auto idata = grad_input.packed_accessor64<scalar_t, 4>();
auto odata = grad_output.packed_accessor64<scalar_t, 4>();

const accscalar_t rheight = area_pixel_compute_scale<accscalar_t>(
input_height, output_height, align_corners, scales_h);
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
input_width, output_width, align_corners, scales_w);

// We are using static buffer memory of 256 * sizeof(float) per thread
// to store weights. Size of weights array is
// interp_size = scale * 2 + 1 for bilinear mode
TORCH_CHECK(
rheight < (255 / interp_size),
"Max supported scale factor is 127 (bilinear), 63 (bicubic)");
TORCH_CHECK(
rwidth < (255 / interp_size),
"Max supported scale factor is 127 (bilinear), 63 (bicubic)");

upsample_gen2d_backward_out_frame<scalar_t, accscalar_t, interp_size>
<<<cuda::ATenCeilDiv(num_kernels, num_threads),
num_threads,
0,
stream>>>(
num_kernels, rheight, rwidth, align_corners, idata, odata);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
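
The two TORCH_CHECKs above follow from the fixed 256-element per-thread weight buffers (wx[256], wy[256]) in the kernel: each axis needs roughly interp_size * scale + 1 weights, which bounds the supported downscale factor. A quick back-of-the-envelope check in Python (not part of the commit):

# Window size per axis is roughly interp_size * scale + 1 entries, and each
# thread only has room for 256 weights per axis, so scale < 255 / interp_size.
for name, interp_size in [("bilinear", 2), ("bicubic", 4)]:
    print(name, "-> max supported downscale factor ~", 255 // interp_size)
# bilinear -> max supported downscale factor ~ 127
# bicubic -> max supported downscale factor ~ 63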

} // namespace internal_upsample
} // namespace native
} // namespace at
@@ -371,6 +524,56 @@ at::Tensor interpolate_gen2d_aa_forward_kernel(
return output;
}

template <int interp_size>
at::Tensor interpolate_gen2d_aa_backward_kernel(
const at::Tensor& grad_output,
at::IntArrayRef output_size,
at::IntArrayRef input_size,
bool align_corners) {
c10::optional<c10::ArrayRef<double>> scale_factors = {};

// Copied from UpSampleBicubic2d.cpp::upsample_bicubic2d_backward
auto grad_input = at::empty({0}, grad_output.options());
auto osize = at::native::upsample::compute_output_size(
input_size, output_size, scale_factors);
auto scale_h = at::native::upsample_cuda::get_scale_value(scale_factors, 0);
auto scale_w = at::native::upsample_cuda::get_scale_value(scale_factors, 1);

auto full_output_size = upsample_2d_common_check(input_size, osize);

TORCH_CHECK(
grad_output.dim() == 4,
"Expected grad_output to be a tensor of dimension 4 but got: dimension ",
grad_output.dim());

for (int i = 0; i < 4; ++i) {
TORCH_CHECK(
grad_output.size(i) == full_output_size[i],
"Expected grad_output to have the same shape as output;",
" output.size(",
i,
") = ",
full_output_size[i],
" but got grad_output.size(",
i,
") = ",
grad_output.size(i));
}

grad_input.resize_(input_size, grad_output.suggest_memory_format());

at::native::internal_upsample::upsample_gen2d_backward_out_cuda_template<
interp_size>(
grad_input,
grad_output,
{full_output_size[2], full_output_size[3]},
input_size,
align_corners,
scale_h,
scale_w);
return grad_input;
}
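
Note that grad_input is resized with grad_output.suggest_memory_format(), so channels-last gradients are handled as well (matching the "Added channels last test case" bullet in the summary). A hedged sketch of how one might observe this from Python; the printed result depends on how autograd lays out the incoming gradient, so it is not asserted here:

import torch
from torchvision.transforms import functional as F
from torchvision.transforms import InterpolationMode

# Sketch only: check whether the gradient of an antialiased resize keeps the
# channels_last layout of the input.
x = torch.rand(1, 3, 32, 32, device="cuda").to(memory_format=torch.channels_last)
x.requires_grad_(True)
y = F.resize(x, size=[10, 7], interpolation=InterpolationMode.BILINEAR, antialias=True)
y.sum().backward()
print(x.grad.is_contiguous(memory_format=torch.channels_last))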

at::Tensor interpolate_bilinear2d_aa_forward_kernel(
const at::Tensor& input,
at::IntArrayRef output_size,
@@ -387,6 +590,24 @@ at::Tensor interpolate_bicubic2d_aa_forward_kernel(
input, output_size, align_corners);
}

at::Tensor interpolate_bilinear2d_aa_backward_kernel(
const at::Tensor& grad_output,
at::IntArrayRef output_size,
at::IntArrayRef input_size,
bool align_corners) {
return interpolate_gen2d_aa_backward_kernel<2>(
grad_output, output_size, input_size, align_corners);
}

at::Tensor interpolate_bicubic2d_aa_backward_kernel(
const at::Tensor& grad_output,
at::IntArrayRef output_size,
at::IntArrayRef input_size,
bool align_corners) {
return interpolate_gen2d_aa_backward_kernel<4>(
grad_output, output_size, input_size, align_corners);
}

} // namespace

TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
@@ -396,6 +617,12 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic2d_aa"),
TORCH_FN(interpolate_bicubic2d_aa_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_interpolate_bilinear2d_aa_backward"),
TORCH_FN(interpolate_bilinear2d_aa_backward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic2d_aa_backward"),
TORCH_FN(interpolate_bicubic2d_aa_backward_kernel));
}
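
These registrations expose the CUDA backward kernels to the dispatcher; in practice they are invoked by torchvision's autograd wrapper rather than by user code. Assuming the registered schema mirrors the kernel signature above (grad_output, output_size, input_size, align_corners), which is an assumption since the schema definition lives in another file, a direct call would look roughly like:

import torch
import torchvision  # importing torchvision loads the torchvision::* custom ops

grad_output = torch.rand(1, 3, 10, 7, device="cuda")
grad_input = torch.ops.torchvision._interpolate_bilinear2d_aa_backward(
    grad_output,
    [10, 7],         # output_size
    [1, 3, 32, 32],  # input_size
    False,           # align_corners
)
print(grad_input.shape)  # expected: torch.Size([1, 3, 32, 32])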

} // namespace ops
