
[2/2] Added backward pass on CUDA for interpolation with anti-alias option #4211

Merged
merged 11 commits on Aug 4, 2021
6 changes: 2 additions & 4 deletions test/test_functional_tensor.py
@@ -579,13 +579,11 @@ def test_assert_resize_antialias(interpolation):
F.resize(tensor, size=(5, 5), interpolation=interpolation, antialias=True)


@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('dt', [torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize('size', [[10, 7], [10, 42], [42, 7]])
@pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC])
def test_interpolate_antialias_backward(dt, size, interpolation):

# temporarily hard-code device as CPU, CUDA support will be done later
device = "cpu"
def test_interpolate_antialias_backward(device, dt, size, interpolation):

if dt == torch.float16 and device == "cpu":
# skip float16 on CPU case
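For context, here is a minimal sketch of how the new CUDA backward path can be exercised from Python. It is not the test body (which is truncated above); the tensor shape, target size, and interpolation mode are illustrative choices, while the F.resize(..., antialias=True) call mirrors the one used in the tests.

import torch
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as F

# Illustrative input: any floating-point NCHW tensor on a CUDA device.
x = torch.rand(1, 3, 42, 7, device="cuda", requires_grad=True)

# Forward pass: anti-aliased bilinear resize.
y = F.resize(x, size=[10, 7], interpolation=InterpolationMode.BILINEAR, antialias=True)

# Backward pass: with this PR the gradient is computed on CUDA as well,
# instead of being restricted to CPU as in the removed test lines above.
y.sum().backward()
print(x.grad.shape)  # torch.Size([1, 3, 42, 7])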
311 changes: 270 additions & 41 deletions torchvision/csrc/ops/cuda/interpolate_aa_kernels.cu
@@ -165,49 +165,32 @@ __global__ void upsample_gen2d_out_frame(
// Compute weights
int xmin, xsize, ymin, ysize;
typedef scalar_t (*filter_fn_t)(scalar_t);
filter_fn_t filter_fn;
if (interp_size == 2) {
_compute_weights<scalar_t, accscalar_t, filter_fn_t>(
w2,
width1,
rwidth,
support_w,
wx,
interp_width,
bilinear_filter,
xmin,
xsize);
_compute_weights<scalar_t, accscalar_t, filter_fn_t>(
h2,
height1,
rheight,
support_h,
wy,
interp_height,
bilinear_filter,
ymin,
ysize);
filter_fn = bilinear_filter;
} else if (interp_size == 4) {
_compute_weights<scalar_t, accscalar_t, filter_fn_t>(
w2,
width1,
rwidth,
support_w,
wx,
interp_width,
bicubic_filter,
xmin,
xsize);
_compute_weights<scalar_t, accscalar_t, filter_fn_t>(
h2,
height1,
rheight,
support_h,
wy,
interp_height,
bicubic_filter,
ymin,
ysize);
filter_fn = bicubic_filter;
}
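// With the filter function selected once above, the two _compute_weights calls
// below are shared by the bilinear and bicubic paths (replacing the duplicated
// per-branch calls removed in this hunk).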
_compute_weights<scalar_t, accscalar_t, filter_fn_t>(
w2,
width1,
rwidth,
support_w,
wx,
interp_width,
filter_fn,
xmin,
xsize);
_compute_weights<scalar_t, accscalar_t, filter_fn_t>(
h2,
height1,
rheight,
support_h,
wy,
interp_height,
filter_fn,
ymin,
ysize);

for (int n = 0; n < batchsize; n++) {
for (int c = 0; c < channels; ++c) {
@@ -239,6 +222,8 @@ static void upsample_gen2d_out_cuda_template(
bool align_corners,
c10::optional<double> scales_h,
c10::optional<double> scales_w) {
// Copied and adapted from
// UpSampleBicubic2d.cu::upsample_bicubic2d_out_cuda_template
TensorArg input_arg{input, "input", 1}, output_arg{output, "output", 2};
checkAllSameGPU("upsample_gen2d_out_cuda", {input_arg, output_arg});

@@ -256,7 +241,7 @@
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "upsample_bilinear2d_out_frame", [&] {
input.scalar_type(), "upsample_gen2d_out_frame", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;

auto idata = input.packed_accessor64<scalar_t, 4>();
@@ -287,6 +272,176 @@ static void upsample_gen2d_out_cuda_template(
});
}

// Backward (adjoint) operation: accumulate grad_output (2) into grad_input (1)
template <typename scalar_t, typename accscalar_t, int interp_size>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_gen2d_backward_out_frame(
const int num_elements,
const accscalar_t height_scale,
const accscalar_t width_scale,
const bool align_corners,
PackedTensorAccessor64<scalar_t, 4> idata,
const PackedTensorAccessor64<scalar_t, 4> odata) {
int index = threadIdx.x + blockIdx.x * blockDim.x;

const int batchsize = idata.size(0);
const int channels = idata.size(1);
const int input_height = idata.size(2);
const int input_width = idata.size(3);
const int output_height = odata.size(2);
const int output_width = odata.size(3);

if (index >= num_elements) {
return;
}

const int output_x = index % output_width;
const int output_y = index / output_width;
// special case: input and output are the same size, just copy
if (input_height == output_height && input_width == output_width) {
for (int n = 0; n < batchsize; n++) {
for (int c = 0; c < channels; ++c) {
const scalar_t val = odata[n][c][output_y][output_x];
idata[n][c][output_y][output_x] = val;
}
}
return;
}

const accscalar_t support_h = static_cast<accscalar_t>(
(height_scale >= 1.0) ? (interp_size * 0.5) * height_scale
: interp_size * 0.5);
const accscalar_t support_w = static_cast<accscalar_t>(
(width_scale >= 1.0) ? (interp_size * 0.5) * width_scale
: interp_size * 0.5);

const int interp_height = (int)ceilf(support_h) * 2 + 1;
const int interp_width = (int)ceilf(support_w) * 2 + 1;
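// Worked example (illustrative numbers): for bicubic (interp_size = 4) with a
// 3x downscale (scale = 3.0), support = (4 * 0.5) * 3.0 = 6.0, so each axis
// needs ceilf(6.0) * 2 + 1 = 13 weights, comfortably within the 256-element
// buffers below.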

// Set up local buffers
// TODO: we could request dynamic shared memory before launching the kernel,
// but we would then have to check that the device has enough shared memory
// available
scalar_t wx[256];
scalar_t wy[256];
scalar_t buffer1[256];
scalar_t buffer2[256];

// Compute weights
int xmin, xsize, ymin, ysize;
typedef scalar_t (*filter_fn_t)(scalar_t);
filter_fn_t filter_fn;
if (interp_size == 2) {
filter_fn = bilinear_filter;
} else if (interp_size == 4) {
filter_fn = bicubic_filter;
}
_compute_weights<scalar_t, accscalar_t, filter_fn_t>(
output_x,
input_width,
width_scale,
support_w,
wx,
interp_width,
filter_fn,
xmin,
xsize);
_compute_weights<scalar_t, accscalar_t, filter_fn_t>(
output_y,
input_height,
height_scale,
support_h,
wy,
interp_height,
filter_fn,
ymin,
ysize);
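// The weights computed above match the ones the forward kernel uses for this
// output location. The loop below applies the adjoint: grad_output at
// (output_y, output_x) is scattered into the grad_input window
// [ymin, ymin + ysize) x [xmin, xmin + xsize) with weights wx[x] * wy[y].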

for (int n = 0; n < batchsize; n++) {
for (int c = 0; c < channels; ++c) {
scalar_t out_value = odata[n][c][output_y][output_x];
for (int y = 0; y < ysize; y++) {
for (int x = 0; x < xsize; x++) {
upsample_increment_value_bounded<scalar_t, accscalar_t>(
idata,
n,
c,
input_height,
input_width,
ymin + y,
xmin + x,
wx[x] * wy[y] * out_value);
}
}
}
}
}

template <int interp_size>
static void upsample_gen2d_backward_out_cuda_template(
const Tensor& grad_input,
const Tensor& grad_output_,
IntArrayRef output_size,
IntArrayRef input_size,
bool align_corners,
c10::optional<double> scales_h,
c10::optional<double> scales_w) {
// Copied and adapted from
// UpSampleBicubic2d.cu::upsample_bicubic2d_backward_out_cuda_template
TensorArg grad_input_arg{grad_input, "grad_input", 1},
grad_output_arg{grad_output_, "grad_output_", 2};
checkAllSameGPU(
"upsample_gen2d_backward_out_cuda", {grad_output_arg, grad_input_arg});

int output_height = output_size[0];
int output_width = output_size[1];

int nbatch = input_size[0];
int channels = input_size[1];
int input_height = input_size[2];
int input_width = input_size[3];

Tensor grad_output = grad_output_.contiguous();

grad_input.zero_();

const int num_kernels = output_height * output_width;
const int num_threads = std::min(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_output.scalar_type(), "upsample_gen2d_backward_out_frame", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;

auto idata = grad_input.packed_accessor64<scalar_t, 4>();
auto odata = grad_output.packed_accessor64<scalar_t, 4>();

const accscalar_t rheight = area_pixel_compute_scale<accscalar_t>(
input_height, output_height, align_corners, scales_h);
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
input_width, output_width, align_corners, scales_w);

// Each thread stores its weights in fixed-size local buffers of 256 elements.
// The per-axis weights array needs about interp_size * scale + 1 entries
// (2 * scale + 1 for bilinear, 4 * scale + 1 for bicubic), hence the checks
// below.
TORCH_CHECK(
rheight < (255 / interp_size),
"Max supported scale factor is 127 (bilinear), 63 (bicubic)");
TORCH_CHECK(
rwidth < (255 / interp_size),
"Max supported scale factor is 127 (bilinear), 63 (bicubic)");

upsample_gen2d_backward_out_frame<scalar_t, accscalar_t, interp_size>
<<<cuda::ATenCeilDiv(num_kernels, num_threads),
num_threads,
0,
stream>>>(
num_kernels, rheight, rwidth, align_corners, idata, odata);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}

} // namespace internal_upsample
} // namespace native
} // namespace at
@@ -371,6 +526,56 @@ at::Tensor interpolate_gen2d_aa_forward_kernel(
return output;
}

template <int interp_size>
at::Tensor interpolate_gen2d_aa_backward_kernel(
const at::Tensor& grad_output,
at::IntArrayRef output_size,
at::IntArrayRef input_size,
bool align_corners) {
c10::optional<c10::ArrayRef<double>> scale_factors = {};

// Copied from UpSampleBicubic2d.cpp::upsample_bicubic2d_backward
auto grad_input = at::empty({0}, grad_output.options());
auto osize = at::native::upsample::compute_output_size(
input_size, output_size, scale_factors);
auto scale_h = at::native::upsample_cuda::get_scale_value(scale_factors, 0);
auto scale_w = at::native::upsample_cuda::get_scale_value(scale_factors, 1);

auto full_output_size = upsample_2d_common_check(input_size, osize);

TORCH_CHECK(
grad_output.dim() == 4,
"Expected grad_output to be a tensor of dimension 4 but got: dimension ",
grad_output.dim());

for (int i = 0; i < 4; ++i) {
TORCH_CHECK(
grad_output.size(i) == full_output_size[i],
"Expected grad_output to have the same shape as output;",
" output.size(",
i,
") = ",
full_output_size[i],
" but got grad_output.size(",
i,
") = ",
grad_output.size(i));
}

grad_input.resize_(input_size, grad_output.suggest_memory_format());

at::native::internal_upsample::upsample_gen2d_backward_out_cuda_template<
interp_size>(
grad_input,
grad_output,
{full_output_size[2], full_output_size[3]},
input_size,
align_corners,
scale_h,
scale_w);
return grad_input;
}

at::Tensor interpolate_bilinear2d_aa_forward_kernel(
const at::Tensor& input,
at::IntArrayRef output_size,
@@ -387,6 +592,24 @@ at::Tensor interpolate_bicubic2d_aa_forward_kernel(
input, output_size, align_corners);
}

at::Tensor interpolate_bilinear2d_aa_backward_kernel(
const at::Tensor& grad_output,
at::IntArrayRef output_size,
at::IntArrayRef input_size,
bool align_corners) {
return interpolate_gen2d_aa_backward_kernel<2>(
grad_output, output_size, input_size, align_corners);
}

at::Tensor interpolate_bicubic2d_aa_backward_kernel(
const at::Tensor& grad_output,
at::IntArrayRef output_size,
at::IntArrayRef input_size,
bool align_corners) {
return interpolate_gen2d_aa_backward_kernel<4>(
grad_output, output_size, input_size, align_corners);
}

} // namespace

TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
@@ -396,6 +619,12 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic2d_aa"),
TORCH_FN(interpolate_bicubic2d_aa_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_interpolate_bilinear2d_aa_backward"),
TORCH_FN(interpolate_bilinear2d_aa_backward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic2d_aa_backward"),
TORCH_FN(interpolate_bicubic2d_aa_backward_kernel));
}

} // namespace ops