Structured kernel definition for upsample_nearest2d (#50189)
Summary:
See the structured kernel definition [RFC](pytorch/rfcs#9) for context.

Pull Request resolved: #50189

Reviewed By: mrshenli

Differential Revision: D25903846

Pulled By: soulitzer

fbshipit-source-id: 0059fda9b7d86f596ca35d830562dd4b859293a0
soulitzer authored and facebook-github-bot committed Jan 14, 2021
1 parent fc9f013 commit 19a8e68
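
This ports upsample_nearest2d and upsample_nearest2d_backward (CPU and CUDA) to the structured kernel format described in the RFC: shape checks move into shared TORCH_META_FUNC definitions, the per-backend kernels become TORCH_IMPL_FUNCs that receive a pre-allocated output, and the hand-written functional wrappers are deleted in favor of ones generated from the structured_delegate entries in native_functions.yaml. A simplified sketch of the split, condensed from the CPU changes below (not verbatim source):

// Simplified sketch, condensed from aten/src/ATen/native/UpSampleNearest2d.cpp below.

// Shape checking + output declaration, shared by every backend (namespace at::meta):
TORCH_META_FUNC(upsample_nearest2d) (
  const Tensor& input, IntArrayRef output_size, c10::optional<double> scales_h, c10::optional<double> scales_w
) {
  auto full_output_size = upsample_nearest2d_common_check(input.sizes(), output_size);
  // set_output allocates (or resizes) the result for both the functional and the
  // out= variants, so the impl function never has to do it itself.
  set_output(full_output_size, input.options());
}

// Per-backend kernel body (namespace at::native); `output` arrives pre-sized:
TORCH_IMPL_FUNC(upsample_nearest2d_out_cpu) (
    const Tensor& input,
    IntArrayRef output_size,
    c10::optional<double> scales_h,
    c10::optional<double> scales_w,
    Tensor& output
) {
  upsample_nearest2d_kernel(kCPU, output, input, scales_h, scales_w);
}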
Showing 4 changed files with 86 additions and 125 deletions.
156 changes: 64 additions & 92 deletions aten/src/ATen/native/UpSampleNearest2d.cpp
@@ -3,51 +3,9 @@
 #include <ATen/native/UpSample.h>

 namespace at {
-namespace native {
-namespace {

-static void upsample_nearest2d_out_cpu_template(
-    Tensor& output,
-    const Tensor& input,
-    IntArrayRef output_size,
-    c10::optional<double> scales_h,
-    c10::optional<double> scales_w) {
-  TORCH_CHECK(
-      output_size.size() == 2,
-      "It is expected output_size equals to 2, but got size ",
-      output_size.size());

-  int64_t output_height = output_size[0];
-  int64_t output_width = output_size[1];

-  int64_t nbatch = input.size(0);
-  int64_t channels = input.size(1);
-  int64_t input_height = input.size(2);
-  int64_t input_width = input.size(3);
+namespace meta {

-  upsample_2d_shape_check(
-      input,
-      Tensor(),
-      nbatch,
-      channels,
-      input_height,
-      input_width,
-      output_height,
-      output_width);

-  output.resize_({nbatch, channels, output_height, output_width}, input.suggest_memory_format());

-  AT_ASSERT(input_width > 0 && output_width > 0);
-  upsample_nearest2d_kernel(kCPU, output, input, scales_h, scales_w);
-}

-static void upsample_nearest2d_backward_out_cpu_template(
-    Tensor& grad_input,
-    const Tensor& grad_output,
-    IntArrayRef output_size,
-    IntArrayRef input_size,
-    c10::optional<double> scales_h,
-    c10::optional<double> scales_w) {
+static std::array<int64_t, 4> upsample_nearest2d_common_check(IntArrayRef input_size, IntArrayRef output_size) {
   TORCH_CHECK(
       output_size.size() == 2,
       "It is expected output_size equals to 2, but got size ",
@@ -66,65 +24,84 @@ static void upsample_nearest2d_backward_out_cpu_template
   int64_t input_height = input_size[2];
   int64_t input_width = input_size[3];

-  upsample_2d_shape_check(
-      Tensor(),
-      grad_output,
-      nbatch,
-      channels,
+  TORCH_CHECK(
+      input_height > 0 && input_width > 0 && output_height > 0 &&
+          output_width > 0,
+      "Input and output sizes should be greater than 0,"
+      " but got input (H: ",
       input_height,
+      ", W: ",
       input_width,
+      ") output (H: ",
       output_height,
-      output_width);
+      ", W: ",
+      output_width,
+      ")");

-  grad_input.resize_({nbatch, channels, input_height, input_width});
-  grad_input.zero_();

-  upsample_nearest2d_backward_kernel(kCPU, grad_input, grad_output, scales_h, scales_w);
+  return {nbatch, channels, output_height, output_width};
 }
-} // namespace

-Tensor& upsample_nearest2d_out_cpu(
-    Tensor& output,
-    const Tensor& input,
-    IntArrayRef output_size,
-    c10::optional<double> scales_h,
-    c10::optional<double> scales_w) {
-  upsample_nearest2d_out_cpu_template(output, input, output_size, scales_h, scales_w);
-  return output;
+TORCH_META_FUNC(upsample_nearest2d) (
+  const Tensor& input, IntArrayRef output_size, c10::optional<double> scales_h, c10::optional<double> scales_w
+) {
+  auto full_output_size = upsample_nearest2d_common_check(input.sizes(), output_size);

+  // Allow for empty batch size but not other dimensions
+  TORCH_CHECK(
+      input.numel() != 0 || prod_intlist(input.sizes().begin() + 1, input.sizes().end()),
+      "Non-empty 4D data tensor expected but got a tensor with sizes ",
+      input.sizes());

+  set_output(full_output_size, input.options());
 }

-Tensor upsample_nearest2d_cpu(
-    const Tensor& input,
-    IntArrayRef output_size,
-    c10::optional<double> scales_h,
-    c10::optional<double> scales_w) {
-  auto output = at::empty({0}, input.options());
-  upsample_nearest2d_out_cpu_template(output, input, output_size, scales_h, scales_w);
-  return output;
+TORCH_META_FUNC(upsample_nearest2d_backward) (
+  const Tensor& grad_output,
+  IntArrayRef output_size,
+  IntArrayRef input_size,
+  c10::optional<double> scales_h,
+  c10::optional<double> scales_w
+) {
+  auto full_output_size = upsample_nearest2d_common_check(input_size, output_size);

+  TORCH_CHECK(
+      grad_output.dim() == 4,
+      "Expected grad_output to be a tensor of dimension 4 but got: dimension ", grad_output.dim());

+  for (int i = 0; i < 4; ++i) {
+    TORCH_CHECK(
+        grad_output.size(i) == full_output_size[i],
+        "Expected grad_output to have the same shape as output;",
+        " output.size(", i, ") = ", full_output_size[i],
+        " but got grad_output.size(", i, ") = ", grad_output.size(i));
+  }

+  set_output(input_size, grad_output.options());
 }

-Tensor& upsample_nearest2d_backward_out_cpu(
-    Tensor& grad_input,
-    const Tensor& grad_output,
+} // namespace meta

+namespace native {

+TORCH_IMPL_FUNC(upsample_nearest2d_out_cpu) (
+    const Tensor& input,
     IntArrayRef output_size,
-    IntArrayRef input_size,
     c10::optional<double> scales_h,
-    c10::optional<double> scales_w) {
-  upsample_nearest2d_backward_out_cpu_template(
-      grad_input, grad_output, output_size, input_size, scales_h, scales_w);
-  return grad_input;
+    c10::optional<double> scales_w,
+    Tensor& output
+) {
+  upsample_nearest2d_kernel(kCPU, output, input, scales_h, scales_w);
 }

-Tensor upsample_nearest2d_backward_cpu(
+TORCH_IMPL_FUNC(upsample_nearest2d_backward_out_cpu) (
     const Tensor& grad_output,
     IntArrayRef output_size,
     IntArrayRef input_size,
     c10::optional<double> scales_h,
-    c10::optional<double> scales_w) {
-  auto grad_input = at::zeros(input_size, grad_output.options());
-  upsample_nearest2d_backward_out_cpu_template(
-      grad_input, grad_output, output_size, input_size, scales_h, scales_w);
-  return grad_input;
+    c10::optional<double> scales_w,
+    Tensor& grad_input) {
+  grad_input.zero_();
+  upsample_nearest2d_backward_kernel(kCPU, grad_input, grad_output, scales_h, scales_w);
 }

 using at::native::upsample::compute_output_size;
@@ -134,12 +111,10 @@ Tensor upsample_nearest2d_cpu(
     const Tensor& input,
     c10::optional<IntArrayRef> output_size,
     c10::optional<ArrayRef<double>> scale_factors) {
-  auto output = at::empty({0}, input.options());
   auto osize = compute_output_size(input.sizes(), output_size, scale_factors);
   auto scale_h = get_scale_value(scale_factors, 0);
   auto scale_w = get_scale_value(scale_factors, 1);
-  upsample_nearest2d_out_cpu_template(output, input, osize, scale_h, scale_w);
-  return output;
+  return at::upsample_nearest2d(input, osize, scale_h, scale_w);
 }

 Tensor upsample_nearest2d_backward_cpu(
@@ -150,10 +125,7 @@ Tensor upsample_nearest2d_backward_cpu(
   auto osize = compute_output_size(input_size, output_size, scale_factors);
   auto scale_h = get_scale_value(scale_factors, 0);
   auto scale_w = get_scale_value(scale_factors, 1);
-  auto grad_input = at::zeros(input_size, grad_output.options());
-  upsample_nearest2d_backward_out_cpu_template(
-      grad_input, grad_output, osize, input_size, scale_h, scale_w);
-  return grad_input;
+  return at::upsample_nearest2d_backward(grad_output, osize, input_size, scale_h, scale_w);
 }

 DEFINE_DISPATCH(upsample_nearest2d_kernel);
32 changes: 6 additions & 26 deletions aten/src/ATen/native/cuda/UpSampleNearest2d.cu
@@ -292,44 +292,24 @@ static void upsample_nearest2d_backward_out_cuda_template(

 } // namespace

-Tensor& upsample_nearest2d_out_cuda(
-    Tensor& output,
+TORCH_IMPL_FUNC(upsample_nearest2d_out_cuda) (
     const Tensor& input,
     IntArrayRef output_size,
     c10::optional<double> scales_h,
-    c10::optional<double> scales_w) {
+    c10::optional<double> scales_w,
+    Tensor& output) {
   upsample_nearest2d_out_cuda_template(output, input, output_size, scales_h, scales_w);
-  return output;
 }

-Tensor upsample_nearest2d_cuda(const Tensor& input, IntArrayRef output_size, c10::optional<double> scales_h, c10::optional<double> scales_w) {
-  Tensor output = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  upsample_nearest2d_out_cuda_template(output, input, output_size, scales_h, scales_w);
-  return output;
-}

-Tensor& upsample_nearest2d_backward_out_cuda(
-    Tensor& grad_input,
+TORCH_IMPL_FUNC(upsample_nearest2d_backward_out_cuda) (
     const Tensor& grad_output,
     IntArrayRef output_size,
     IntArrayRef input_size,
     c10::optional<double> scales_h,
-    c10::optional<double> scales_w) {
+    c10::optional<double> scales_w,
+    Tensor& grad_input) {
   upsample_nearest2d_backward_out_cuda_template(
       grad_input, grad_output, output_size, input_size, scales_h, scales_w);
-  return grad_input;
 }

-Tensor upsample_nearest2d_backward_cuda(
-    const Tensor& grad_output,
-    IntArrayRef output_size,
-    IntArrayRef input_size,
-    c10::optional<double> scales_h,
-    c10::optional<double> scales_w) {
-  Tensor grad_input = at::empty_like(grad_output, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  upsample_nearest2d_backward_out_cuda_template(
-      grad_input, grad_output, output_size, input_size, scales_h, scales_w);
-  return grad_input;
-}

 using at::native::upsample::compute_output_size;
11 changes: 4 additions & 7 deletions aten/src/ATen/native/native_functions.yaml
@@ -8325,31 +8325,28 @@
   structured_delegate: upsample_nearest1d_backward.grad_input

 - func: upsample_nearest2d.out(Tensor self, int[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
-  use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
   python_module: nn
+  structured: True
   dispatch:
     CPU: upsample_nearest2d_out_cpu
     CUDA: upsample_nearest2d_out_cuda

 - func: upsample_nearest2d(Tensor self, int[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor
   python_module: nn
+  structured_delegate: upsample_nearest2d.out
   dispatch:
-    CPU: upsample_nearest2d_cpu
-    CUDA: upsample_nearest2d_cuda
     QuantizedCPU: upsample_nearest2d_quantized_cpu

 - func: upsample_nearest2d_backward.grad_input(Tensor grad_output, int[2] output_size, int[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
-  use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
   python_module: nn
+  structured: True
   dispatch:
     CPU: upsample_nearest2d_backward_out_cpu
     CUDA: upsample_nearest2d_backward_out_cuda

 - func: upsample_nearest2d_backward(Tensor grad_output, int[2] output_size, int[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor
   python_module: nn
-  dispatch:
-    CPU: upsample_nearest2d_backward_cpu
-    CUDA: upsample_nearest2d_backward_cuda
+  structured_delegate: upsample_nearest2d_backward.grad_input

 - func: upsample_nearest3d.out(Tensor self, int[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
   use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
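
With the kernels structured, the four entries above pair up: each out/grad_input overload is marked structured: True and keeps the per-backend dispatch (now satisfied by the TORCH_IMPL_FUNCs), while the functional overloads become structured_delegates of those out variants, so the functional wrappers that were hand-written in the .cpp/.cu files are generated by codegen instead. Abridged sketch of the resulting forward pair, copied from the entries above (the backward pair is wired the same way):

- func: upsample_nearest2d.out(Tensor self, int[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
  python_module: nn
  structured: True
  dispatch:
    CPU: upsample_nearest2d_out_cpu
    CUDA: upsample_nearest2d_out_cuda

- func: upsample_nearest2d(Tensor self, int[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor
  python_module: nn
  structured_delegate: upsample_nearest2d.out
  dispatch:
    QuantizedCPU: upsample_nearest2d_quantized_cpu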
12 changes: 12 additions & 0 deletions test/test_nn.py
@@ -8594,6 +8594,18 @@ def test_upsamplingNearest2d(self):
             self.assertEqual(torch.ones(1, 1, 4, 4).contiguous(memory_format=memory_format), out_t.data)
             self.assertEqual(torch.ones(1, 1, 4, 4, dtype=torch.uint8).contiguous(memory_format=memory_format), out_uint8_t.data)

+            # test forward when input's height is not same as width
+            m = nn.Upsample(size=(4, 2), mode='nearest')
+            in_t = torch.ones(1, 1, 2, 1).contiguous(memory_format=memory_format)
+            with warnings.catch_warnings(record=True) as w:
+                out_t = m(in_t)
+            self.assertEqual(torch.ones(1, 1, 4, 2).contiguous(memory_format=memory_format), out_t.data)
+
+            # test backward when input's height is not same as width
+            input = torch.ones(1, 1, 2, 1, requires_grad=True).contiguous(memory_format=memory_format)
+            gradcheck(lambda x: F.interpolate(x, size=(4, 2), mode='nearest'), [input])
+            gradgradcheck(lambda x: F.interpolate(x, size=(4, 2), mode='nearest'), [input])
+
             input = torch.randn(1, 1, 2, 2, requires_grad=True).contiguous(memory_format=memory_format)
             self.assertEqual(
                 F.interpolate(input, 4, mode='nearest'),
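
For reference, a rough ATen-level equivalent of the new Python test above: a sketch assuming a standard ATen/libtorch build; the main function and the printed shapes are illustrative and not part of the change.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Forward: nearest-neighbor upsample where height and width differ (2x1 -> 4x2),
  // the case the new test exercises.
  at::Tensor input = at::ones({1, 1, 2, 1});
  at::Tensor out = at::upsample_nearest2d(
      input, /*output_size=*/{4, 2}, /*scales_h=*/c10::nullopt, /*scales_w=*/c10::nullopt);
  std::cout << out.sizes() << "\n";  // [1, 1, 4, 2]

  // Backward: the structured backward op takes grad_output plus output_size/input_size
  // and produces grad_input with the input's shape.
  at::Tensor grad_out = at::ones_like(out);
  at::Tensor grad_in = at::upsample_nearest2d_backward(
      grad_out, /*output_size=*/{4, 2}, /*input_size=*/{1, 1, 2, 1},
      /*scales_h=*/c10::nullopt, /*scales_w=*/c10::nullopt);
  std::cout << grad_in.sizes() << "\n";  // [1, 1, 2, 1]
  return 0;
}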
