From 932ac7bd713fbac63d96d86571ba4529ad0423f0 Mon Sep 17 00:00:00 2001 From: Eli Uriegas <1700823+seemethere@users.noreply.github.com> Date: Fri, 10 Dec 2021 11:42:03 -0800 Subject: [PATCH] [release/1.10] Remove fgrad_input from slow_conv2d (#64280) (#69622) Co-authored-by: Peter Bell --- aten/src/ATen/core/aten_interned_strings.h | 4 +- aten/src/ATen/native/ConvolutionMM2d.cpp | 103 +++------- aten/src/ATen/native/cuda/ConvolutionMM2d.cu | 190 ++++++------------ aten/src/ATen/native/native_functions.yaml | 8 +- .../check_backward_compatibility.py | 2 + test/cpp/jit/test_misc.cpp | 13 +- tools/autograd/derivatives.yaml | 6 +- tools/autograd/gen_variable_type.py | 2 +- torch/csrc/jit/runtime/autodiff.cpp | 7 +- 9 files changed, 105 insertions(+), 230 deletions(-) diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h index d766c69963be188..c0a8fe1ddff6e7c 100644 --- a/aten/src/ATen/core/aten_interned_strings.h +++ b/aten/src/ATen/core/aten_interned_strings.h @@ -695,8 +695,8 @@ _(aten, th_resize_as) \ _(aten, th_tensor) \ _(aten, th_zero) \ _(aten, thnn_conv2d) \ -_(aten, thnn_conv2d_backward) \ -_(aten, thnn_conv2d_forward) \ +_(aten, _slow_conv2d_backward) \ +_(aten, _slow_conv2d_forward) \ _(aten, tile) \ _(aten, slow_conv3d) \ _(aten, slow_conv3d_backward) \ diff --git a/aten/src/ATen/native/ConvolutionMM2d.cpp b/aten/src/ATen/native/ConvolutionMM2d.cpp index f06a8c2b82d4974..496c96aa6a9b485 100644 --- a/aten/src/ATen/native/ConvolutionMM2d.cpp +++ b/aten/src/ATen/native/ConvolutionMM2d.cpp @@ -210,7 +210,7 @@ void slow_conv2d_backward_update_grad_input_frame( int64_t pad_width) { auto grad_output_2d = grad_output.reshape( {grad_output.size(0), grad_output.size(1) * grad_output.size(2)}); - fgrad_input.addmm_(weight, grad_output_2d, 0, 1); + at::mm_out(fgrad_input, weight, grad_output_2d); grad_input.zero_(); unfolded2d_acc_stub( @@ -236,7 +236,6 @@ void slow_conv2d_backward_out_cpu_template( const Tensor& input_, const Tensor& weight_, const Tensor& finput, - Tensor& fgrad_input, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding) { @@ -264,22 +263,20 @@ void slow_conv2d_backward_out_cpu_template( const Tensor input = input_.contiguous(); const Tensor grad_output = grad_output_.contiguous(); grad_input.resize_as_(input); - fgrad_input.resize_as_(finput); - fgrad_input.zero_(); const Tensor tweight = weight.transpose(0, 1); const int64_t batch_size = input.size(0); at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) { NoGradGuard no_grad; AutoDispatchBelowADInplaceOrView non_variable_type_mode; + auto fgrad_input = at::empty(finput.sizes().slice(1), finput.options()); for (int64_t t = start; t < end; t++) { Tensor grad_input_t = grad_input[t]; Tensor grad_output_t = grad_output[t]; - Tensor fgrad_input_t = fgrad_input[t]; slow_conv2d_backward_update_grad_input_frame( grad_input_t, grad_output_t, tweight, - fgrad_input_t, + fgrad_input, kernel_height, kernel_width, stride_height, @@ -290,51 +287,26 @@ void slow_conv2d_backward_out_cpu_template( }); } -void slow_conv2d_backward_parameters_frame( +void slow_conv2d_backward_weight_frame( Tensor& grad_weight, - Tensor& grad_bias, Tensor& grad_output, const Tensor& finput) { auto grad_output_2d = grad_output.view( {grad_output.size(0), grad_output.size(1) * grad_output.size(2)}); - if (grad_weight.defined()) { - const Tensor tfinput = finput.transpose(0, 1); - grad_weight.addmm_(grad_output_2d, tfinput); - } - - if (grad_bias.defined()) { - AT_DISPATCH_FLOATING_TYPES_AND( - at::ScalarType::BFloat16, - grad_output.scalar_type(), - "slow_conv2d_backward_parameters", - [&] { - auto grad_output_2d_acc = grad_output_2d.accessor(); - auto grad_bias_acc = grad_bias.accessor(); - const auto sz = grad_output_2d.size(1); - for (int64_t i = 0; i < grad_bias.size(0); i++) { - scalar_t sum = 0; - for (int64_t k = 0; k < sz; k++) { - sum += grad_output_2d_acc[i][k]; - } - grad_bias_acc[i] += sum; - } - }); - } + const Tensor tfinput = finput.transpose(0, 1); + grad_weight.addmm_(grad_output_2d, tfinput); } -static void slow_conv2d_backward_parameters_out_cpu_template( +static void slow_conv2d_backward_weight_out_cpu_template( Tensor& grad_weight, - Tensor& grad_bias, const Tensor& input_, const Tensor& grad_output_, const Tensor& finput, - Tensor fgrad_input, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding) { CheckedFrom c = "slow_conv2d_backward_parameters_cpu"; auto grad_weight_arg = TensorArg(grad_weight, "grad_weight_arg", 0); - auto grad_bias_arg = TensorArg(grad_bias, "grad_bias_arg", 0); const int64_t kernel_height = kernel_size[0]; const int64_t kernel_width = kernel_size[1]; @@ -344,20 +316,14 @@ static void slow_conv2d_backward_parameters_out_cpu_template( const int64_t stride_width = stride[1]; Tensor grad_weight_2d; - if (grad_weight.defined()) { - checkContiguous(c, grad_weight_arg); - grad_weight_2d = view_weight_2d(grad_weight); - } - - if (grad_bias.defined()) { - checkContiguous(c, grad_bias_arg); - } + checkContiguous(c, grad_weight_arg); + grad_weight_2d = view_weight_2d(grad_weight); slow_conv2d_shape_check( input_, grad_output_, grad_weight_2d, - grad_bias, + {}, kernel_height, kernel_width, stride_height, @@ -377,21 +343,21 @@ static void slow_conv2d_backward_parameters_out_cpu_template( finput_t = finput[t]; } - slow_conv2d_backward_parameters_frame( - grad_weight_2d, grad_bias, grad_output_t, finput_t); + slow_conv2d_backward_weight_frame( + grad_weight_2d, grad_output_t, finput_t); } } } // namespace -std::tuple slow_conv2d_forward_out_cpu(const Tensor& self, +std::tuple slow_conv2d_forward_out_cpu( + const Tensor& self, const Tensor& weight_, IntArrayRef kernel_size, const c10::optional& bias_opt, IntArrayRef stride, IntArrayRef padding, Tensor& output, - Tensor& finput, - Tensor& fgrad_input) { + Tensor& finput) { // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); const Tensor& bias = *bias_maybe_owned; @@ -474,10 +440,10 @@ std::tuple slow_conv2d_forward_out_cpu(const Tensor& } }); - return std::tuple(output, finput, fgrad_input); + return std::tuple(output, finput); } -std::tuple slow_conv2d_forward_cpu( +std::tuple slow_conv2d_forward_cpu( const Tensor& self, const Tensor& weight, IntArrayRef kernel_size, const c10::optional& bias_opt, @@ -489,7 +455,6 @@ std::tuple slow_conv2d_forward_cpu( auto output = at::empty({0}, self.options()); auto finput = at::empty({0}, self.options()); - auto fgrad_input = at::empty({0}, self.options()); at::native::slow_conv2d_forward_out_cpu( self, weight, @@ -498,19 +463,18 @@ std::tuple slow_conv2d_forward_cpu( stride, padding, output, - finput, - fgrad_input); - return std::make_tuple(output, finput, fgrad_input); + finput); + return std::make_tuple(output, finput); } -std::tuple slow_conv2d_backward_out_cpu(const Tensor& grad_output, +std::tuple slow_conv2d_backward_out_cpu( + const Tensor& grad_output, const Tensor& self, const Tensor& weight, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, const Tensor& finput, - const Tensor& fgrad_input, Tensor& grad_input, Tensor& grad_weight, Tensor& grad_bias) { @@ -521,31 +485,23 @@ std::tuple slow_conv2d_backward_out_cpu(const Tensor& self, weight, finput, - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - const_cast(fgrad_input), // cast away auto-generated const of buffer kernel_size, stride, padding); } - if (grad_weight.defined()) { - grad_weight.resize_(weight.sizes()); - grad_weight.zero_(); - } - if (grad_bias.defined()) { - grad_bias.resize_({grad_output.size(1)}); - grad_bias.zero_(); + at::sum_out(grad_bias, grad_output, IntArrayRef{0, 2, 3}); } - if (grad_weight.defined() || grad_bias.defined()) { - slow_conv2d_backward_parameters_out_cpu_template( + if (grad_weight.defined()) { + grad_weight.resize_(weight.sizes()); + grad_weight.zero_(); + slow_conv2d_backward_weight_out_cpu_template( grad_weight, - grad_bias, self, grad_output, finput, - fgrad_input, kernel_size, stride, padding); @@ -563,7 +519,6 @@ std::tuple slow_conv2d_backward_cpu( IntArrayRef stride, IntArrayRef padding, const Tensor& finput, - const Tensor& fgrad_input, std::array output_mask) { Tensor grad_input; Tensor grad_weight; @@ -589,7 +544,6 @@ std::tuple slow_conv2d_backward_cpu( stride, padding, finput, - fgrad_input, grad_input, grad_weight, grad_bias); @@ -603,8 +557,7 @@ Tensor & thnn_conv2d_out(const Tensor & self, const Tensor & weight, IntArrayRef const Tensor& bias = *bias_maybe_owned; Tensor finput = at::empty({0}, self.options()); - Tensor fgrad_input = at::empty({0}, self.options()); - return std::get<0>(at::thnn_conv2d_forward_out(output, finput, fgrad_input, self, weight, kernel_size, bias, stride, padding)); + return std::get<0>(at::_slow_conv2d_forward_out(output, finput, self, weight, kernel_size, bias, stride, padding)); } Tensor thnn_conv2d(const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, const c10::optional& bias_opt, IntArrayRef stride, IntArrayRef padding) { @@ -612,7 +565,7 @@ Tensor thnn_conv2d(const Tensor & self, const Tensor & weight, IntArrayRef kerne c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); const Tensor& bias = *bias_maybe_owned; - return std::get<0>(at::thnn_conv2d_forward(self, weight, kernel_size, bias, stride, padding)); + return std::get<0>(at::_slow_conv2d_forward(self, weight, kernel_size, bias, stride, padding)); } } // namespace native diff --git a/aten/src/ATen/native/cuda/ConvolutionMM2d.cu b/aten/src/ATen/native/cuda/ConvolutionMM2d.cu index bf3f8ac0a6effeb..9ccb2214325caff 100644 --- a/aten/src/ATen/native/cuda/ConvolutionMM2d.cu +++ b/aten/src/ATen/native/cuda/ConvolutionMM2d.cu @@ -117,7 +117,6 @@ void slow_conv2d_forward( const Tensor &weight_, const Tensor &bias, const Tensor &columns, - const Tensor &ones_, int64_t kH, int64_t kW, int64_t dH, int64_t dW, int64_t padH, int64_t padW) { @@ -125,9 +124,6 @@ void slow_conv2d_forward( slow_conv2d_shape_check( input, {}, weight, bias, kH, kW, dH, dW, padH, padW, /*weight_nullable*/false); - TORCH_CHECK(!bias.defined() || bias.is_contiguous(), - "bias tensor has to be contiguous"); - constexpr int ndim = 4; constexpr int dimf = 1; constexpr int dimh = 2; @@ -148,16 +144,18 @@ void slow_conv2d_forward( // Resize temporary columns resize_output(columns, {nInputPlane * kW * kH, outputHeight * outputWidth}); - // Define a buffer of ones, for bias accumulation - // Note: this buffer can be shared with other modules, it only ever gets increased, - // and always contains ones. - Tensor ones; - if (bias.defined()) { - ones = at::ones({outputHeight, outputWidth}, input.options()); - } const bool requires_columns = ( kW != 1 || kH != 1 || dW != 1 || dH != 1 || padH != 0 || padW != 0); + if (bias.defined()) { + TORCH_CHECK(bias.scalar_type() == input.scalar_type(), + "Expected bias to have type ", input.scalar_type(), + " but got ", bias.scalar_type()); + output.copy_(bias.view({-1, 1, 1})); + } else { + output.zero_(); + } + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "slow_conv2d_cuda", [&] { // For each elt in batch, do: @@ -166,28 +164,6 @@ void slow_conv2d_forward( auto input_n = input.select(0, elt); auto output_n = output.select(0, elt); - // Do Bias first: - // M,N,K are dims of matrix A and B - // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) - int64_t m_ = nOutputPlane; - int64_t n_ = outputHeight * outputWidth; - int64_t k_ = 1; - - // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices) - if (bias.defined()) { - at::cuda::blas::gemm( - 't', 'n', - n_, m_, k_, - scalar_t(1), - ones.data_ptr(), k_, - bias.data_ptr(), k_, - scalar_t(0), - output_n.data_ptr(), n_ - ); - } else { - output_n.zero_(); - } - if (requires_columns) { // Extract columns: at::native::im2col( @@ -230,7 +206,6 @@ void slow_conv2d_backward( const Tensor &grad_input, const Tensor &weight_, const Tensor &grad_columns, - const Tensor &ones, int kH, int kW, int dH, int dW, int padH, int padW) { @@ -300,26 +275,17 @@ void slow_conv2d_backward( }); } -void slow_conv2d_grad_weight_bias( +void slow_conv2d_grad_weight( const Tensor &input, const Tensor &grad_output, const Tensor &grad_weight_, - const Tensor &grad_bias, const Tensor &columns, - const Tensor &ones, int64_t kH, int64_t kW, int64_t dH, int64_t dW, int64_t padH, int64_t padW) { - if (grad_weight_.defined()) { - TORCH_CHECK(grad_weight_.is_contiguous(), "grad_weight needs to be contiguous"); - } - if (grad_bias.defined()) { - TORCH_CHECK(grad_bias.is_contiguous(), "grad_bias needs to be contiguous"); - TORCH_CHECK(ones.is_contiguous(), "ones needs to be contiguous"); - } - + TORCH_CHECK(grad_weight_.is_contiguous(), "grad_weight needs to be contiguous"); auto grad_weight = new_view_weight_MM2d(grad_weight_); - slow_conv2d_shape_check(input, grad_output, grad_weight, grad_bias, + slow_conv2d_shape_check(input, grad_output, grad_weight, {}, kH, kW, dH, dW, padH, padW, /*weight_nullable=*/true); // Params @@ -338,12 +304,6 @@ void slow_conv2d_grad_weight_bias( // Batch size + input planes int64_t batchSize = input_sizes[0]; - // Define a buffer of ones, for bias accumulation - if (ones.defined() && ones.numel() < outputHeight * outputWidth) { - ones.resize_({outputHeight, outputWidth}); - ones.fill_(1); - } - // Resize temporary columns resize_output(columns, {nInputPlane * kH * kW, outputHeight * outputWidth}); @@ -351,69 +311,47 @@ void slow_conv2d_grad_weight_bias( kW != 1 || kH != 1 || dW != 1 || dH != 1 || padH != 0 || padW != 0); AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), - "slow_conv2d_grad_weight_bias_cuda", [&] { + "slow_conv2d_grad_weight_cuda", [&] { // For each elt in batch, do: for (int elt = 0; elt < batchSize; elt ++) { // Matrix mulitply per output: auto grad_output_n = grad_output.select(0, elt); - // Do Weight: - if (grad_weight.defined()) { - // Matrix mulitply per output: - auto input_n = input.select(0, elt); - - if (requires_columns) { - // Extract columns: - at::native::im2col( - c10::cuda::getCurrentCUDAStream(), - input_n.data_ptr(), - nInputPlane, inputHeight, inputWidth, - outputHeight, outputWidth, - kH, kW, padH, padW, dH, dW, - 1, 1, - columns.data_ptr() - ); - } - - // M,N,K are dims of matrix A and B - // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) - int64_t m = nOutputPlane; - int64_t n = nInputPlane*kW*kH; - int64_t k = columns.sizes()[1]; - - // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices) - auto gemm_in_ptr = requires_columns ? - columns.data_ptr() : - input_n.data_ptr(); - at::cuda::blas::gemm( - 't', 'n', - n, m, k, - scalar_t(1), - gemm_in_ptr, k, - grad_output_n.data_ptr(), k, - scalar_t(1), - grad_weight.data_ptr(), n - ); - } + // Matrix mulitply per output: + auto input_n = input.select(0, elt); - // Do Bias: - if (grad_bias.defined()) { - // M,N,K are dims of matrix A and B - // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) - int64_t m_ = nOutputPlane; - int64_t k_ = outputHeight * outputWidth; - - // Do GEMV (note: this is a bit confusing because gemv assumes column-major matrices) - at::cuda::blas::gemv( - 't', - k_, m_, - scalar_t(1), - grad_output_n.data_ptr(), k_, - ones.data_ptr(), 1, - scalar_t(1), - grad_bias.data_ptr(), 1 + if (requires_columns) { + // Extract columns: + at::native::im2col( + c10::cuda::getCurrentCUDAStream(), + input_n.data_ptr(), + nInputPlane, inputHeight, inputWidth, + outputHeight, outputWidth, + kH, kW, padH, padW, dH, dW, + 1, 1, + columns.data_ptr() ); } + + // M,N,K are dims of matrix A and B + // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) + int64_t m = nOutputPlane; + int64_t n = nInputPlane*kW*kH; + int64_t k = columns.sizes()[1]; + + // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices) + auto gemm_in_ptr = requires_columns ? + columns.data_ptr() : + input_n.data_ptr(); + at::cuda::blas::gemm( + 't', 'n', + n, m, k, + scalar_t(1), + gemm_in_ptr, k, + grad_output_n.data_ptr(), k, + scalar_t(1), + grad_weight.data_ptr(), n + ); } }); } @@ -421,7 +359,7 @@ void slow_conv2d_grad_weight_bias( } // namespace (anonymous) -std::tuple slow_conv2d_forward_out_cuda( +std::tuple slow_conv2d_forward_out_cuda( const Tensor &self_, const Tensor &weight_, IntArrayRef kernel_size, @@ -429,8 +367,7 @@ std::tuple slow_conv2d_forward_out_cuda( IntArrayRef stride, IntArrayRef padding, Tensor &output, - Tensor &finput, - Tensor &fgrad_input) { + Tensor &finput) { TORCH_CHECK(kernel_size.size() == 2); TORCH_CHECK(stride.size() == 2); TORCH_CHECK(padding.size() == 2); @@ -450,16 +387,14 @@ std::tuple slow_conv2d_forward_out_cuda( *weight, *bias, finput, - fgrad_input, kernel_size[0], kernel_size[1], stride[0], stride[1], padding[0], padding[1] ); - return std::tuple{ - output, finput, fgrad_input}; + return std::tuple{output, finput}; } -std::tuple slow_conv2d_forward_cuda( +std::tuple slow_conv2d_forward_cuda( const Tensor &self, const Tensor &weight, IntArrayRef kernel_size, @@ -468,9 +403,8 @@ std::tuple slow_conv2d_forward_cuda( IntArrayRef padding) { auto output = at::empty({0}, self.options()); auto finput = at::empty({0}, self.options()); - auto fgrad_input = at::empty({0}, self.options()); return slow_conv2d_forward_out_cuda( - self, weight, kernel_size, bias, stride, padding, output, finput, fgrad_input); + self, weight, kernel_size, bias, stride, padding, output, finput); } std::tuple slow_conv2d_backward_out_cuda( @@ -481,18 +415,9 @@ std::tuple slow_conv2d_backward_out_cuda( IntArrayRef stride, IntArrayRef padding, const Tensor& finput, - const Tensor& fgrad_input, Tensor& grad_input, Tensor& grad_weight, Tensor& grad_bias) { - if (grad_weight.defined()) { - resize_output(grad_weight, weight_.sizes()); - grad_weight.zero_(); - } - if (grad_bias.defined()) { - resize_output(grad_bias, {weight_.sizes()[0]}); - grad_bias.zero_(); - } auto grad_output = grad_output_.expect_contiguous(); if (grad_input.defined()) { resize_output(grad_input, self_.sizes()); @@ -501,20 +426,23 @@ std::tuple slow_conv2d_backward_out_cuda( slow_conv2d_backward( self_, *grad_output, grad_input, *weight, - finput, fgrad_input, + finput, kernel_size[0], kernel_size[1], stride[0], stride[1], padding[0], padding[1]); } - if (grad_weight.defined() || grad_bias.defined()) { + if (grad_bias.defined()) { + at::sum_out(grad_bias, *grad_output, IntArrayRef{0, 2, 3}); + } + if (grad_weight.defined()) { + resize_output(grad_weight, weight_.sizes()); + grad_weight.zero_(); auto self = self_.expect_contiguous(); - slow_conv2d_grad_weight_bias( + slow_conv2d_grad_weight( *self, *grad_output, grad_weight, - grad_bias, finput, - fgrad_input, kernel_size[0], kernel_size[1], stride[0], stride[1], padding[0], padding[1] @@ -532,7 +460,6 @@ std::tuple slow_conv2d_backward_cuda( IntArrayRef stride, IntArrayRef padding, const Tensor& finput, - const Tensor& fgrad_input, std::array output_mask) { Tensor grad_input; Tensor grad_weight; @@ -558,7 +485,6 @@ std::tuple slow_conv2d_backward_cuda( stride, padding, finput, - fgrad_input, grad_input, grad_weight, grad_bias); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index afa9af3df697da8..50fd82f521ffcda 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -9543,25 +9543,25 @@ - func: thnn_conv2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0) -> Tensor python_module: nn -- func: thnn_conv2d_forward.output(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding, *, Tensor(a!) output, Tensor(b!) finput, Tensor(c!) fgrad_input) -> (Tensor(a!), Tensor(b!), Tensor(c!)) +- func: _slow_conv2d_forward.output(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding, *, Tensor(a!) output, Tensor(b!) finput) -> (Tensor(a!), Tensor(b!)) python_module: nn dispatch: CPU: slow_conv2d_forward_out_cpu CUDA: slow_conv2d_forward_out_cuda -- func: thnn_conv2d_forward(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding) -> (Tensor output, Tensor finput, Tensor fgrad_input) +- func: _slow_conv2d_forward(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding) -> (Tensor output, Tensor finput) python_module: nn dispatch: CPU: slow_conv2d_forward_cpu CUDA: slow_conv2d_forward_cuda -- func: thnn_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, Tensor finput, Tensor fgrad_input, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) +- func: _slow_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, Tensor finput, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) python_module: nn dispatch: CPU: slow_conv2d_backward_out_cpu CUDA: slow_conv2d_backward_out_cuda -- func: thnn_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, Tensor finput, Tensor fgrad_input, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) +- func: _slow_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, Tensor finput, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) python_module: nn dispatch: CPU: slow_conv2d_backward_cpu diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py index 9ae5185a7e26b58..f884ff7622a292c 100644 --- a/test/backward_compatibility/check_backward_compatibility.py +++ b/test/backward_compatibility/check_backward_compatibility.py @@ -50,6 +50,8 @@ ("aten::adaptive_avg_pool3d_backward", datetime.date(9999, 1, 1)), ("aten::_embedding_bag_dense_backward", datetime.date(9999, 1, 1)), ("aten::randperm", datetime.date(9999, 1, 1)), + ("aten::thnn_conv2d_forward", datetime.date(2021, 9, 30)), + ("aten::thnn_conv2d_backward", datetime.date(2021, 9, 30)), ] ALLOW_LIST_COMPILED = [ diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp index e03920fcfca5ba9..5cfe6f98515e9f1 100644 --- a/test/cpp/jit/test_misc.cpp +++ b/test/cpp/jit/test_misc.cpp @@ -152,8 +152,8 @@ TEST(THNNConvTest, Basic) { at::Tensor bias = torch::randn({out_channels}); // run forward eagerly - at::Tensor output, finput, fgradinput; - std::tie(output, finput, fgradinput) = at::thnn_conv2d_forward( + at::Tensor output, finput; + std::tie(output, finput) = at::_slow_conv2d_forward( input, weight, kernel_size, bias, stride, padding); // make grad_outputs @@ -161,12 +161,10 @@ TEST(THNNConvTest, Basic) { torch::randn_like(output, at::MemoryFormat::Preserve); at::Tensor grad_finput = torch::zeros_like(finput, at::MemoryFormat::Preserve); - at::Tensor grad_fgradinput = - torch::zeros_like(fgradinput, at::MemoryFormat::Preserve); // run backward eagerly at::Tensor grad_input, grad_weight, grad_bias; - std::tie(grad_input, grad_weight, grad_bias) = at::thnn_conv2d_backward( + std::tie(grad_input, grad_weight, grad_bias) = at::_slow_conv2d_backward( grad_output, input, weight, @@ -174,7 +172,6 @@ TEST(THNNConvTest, Basic) { stride, padding, finput, - fgradinput, {true, true, true}); // make JIT graph @@ -188,7 +185,7 @@ TEST(THNNConvTest, Basic) { auto biasg = graph->addInput("bias"); Value* conv = graph->insert( - aten::thnn_conv2d_forward, + aten::_slow_conv2d_forward, {inputg, weightg, ksz_val, biasg, kst_val, pad_val}); auto outputs = conv->node()->outputs(); for (auto output : outputs) { @@ -212,7 +209,6 @@ TEST(THNNConvTest, Basic) { tensor_list tensor_grads_in; tensor_grads_in.push_back(grad_output); tensor_grads_in.push_back(grad_finput); - tensor_grads_in.push_back(grad_fgradinput); // Get outputs from the interpreter tensor_list tensors_out, tensor_grads_out; @@ -223,7 +219,6 @@ TEST(THNNConvTest, Basic) { tensor_list expected_tensors_out, expected_tensor_grads_out; expected_tensors_out.push_back(output); expected_tensors_out.push_back(finput); - expected_tensors_out.push_back(fgradinput); expected_tensor_grads_out.push_back(grad_input); expected_tensor_grads_out.push_back(grad_weight); expected_tensor_grads_out.push_back(grad_bias); diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index fb72c654b0afce9..0e805a4b6c18454 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1840,10 +1840,10 @@ - name: slow_conv_transpose3d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[3] kernel_size, int[3] stride, int[3] padding, int[3] output_padding, int[3] dilation, Tensor finput, Tensor fgrad_input, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, true, output_padding, 1, false, false, false, false, grad_input_mask) -- name: thnn_conv2d_forward(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding) -> (Tensor output, Tensor finput, Tensor fgrad_input) - self, weight, bias: "grad.defined() ? thnn_conv2d_backward(grad, self, weight, kernel_size, stride, padding, finput, fgrad_input, grad_input_mask) : std::tuple()" +- name: _slow_conv2d_forward(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding) -> (Tensor output, Tensor finput) + self, weight, bias: "grad.defined() ? _slow_conv2d_backward(grad, self, weight, kernel_size, stride, padding, finput, grad_input_mask) : std::tuple()" -- name: thnn_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, Tensor finput, Tensor fgrad_input, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) +- name: _slow_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, Tensor finput, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, {{1, 1}}, false, {{0, 0}}, 1, false, false, false, false, grad_input_mask) - name: _conv_depthwise2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding, int[2] dilation) -> Tensor diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index d17072f3b1c354c..5f2176ac3fd4dbd 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -240,7 +240,7 @@ DONT_ENFORCE_STORAGE_IMPL_USE_COUNT = { # These non-view functions return tensors with storage use_count != 1 - 'thnn_conv2d_forward', 'slow_conv3d_forward', 'channel_shuffle', + '_slow_conv2d_forward', 'slow_conv3d_forward', 'channel_shuffle', # If an input is returned as-is in output, we cannot guarantee its storage_impl # use count to be 1 either. diff --git a/torch/csrc/jit/runtime/autodiff.cpp b/torch/csrc/jit/runtime/autodiff.cpp index 0c54f46b7f5c2a3..23bcc2a29cc7efe 100644 --- a/torch/csrc/jit/runtime/autodiff.cpp +++ b/torch/csrc/jit/runtime/autodiff.cpp @@ -49,7 +49,7 @@ bool needTrimGrad(Node* n) { bool isDifferentiable(const Node* n) { // TODO: scalar-tensor ops should be canonicalized static OperatorSet differentiable_ops = { - "aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)", + "aten::_slow_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor)", "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", }; @@ -236,10 +236,10 @@ class GradientHelper { return {}; } else if ( node->matches( - "aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)")) { + "aten::_slow_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor)")) { auto graph = node->owningGraph(); auto backward_value = graph->insert( - aten::thnn_conv2d_backward, + aten::_slow_conv2d_backward, {grad_values.at(0), inputs.at(0), inputs.at(1), @@ -247,7 +247,6 @@ class GradientHelper { node->namedInput(attr::stride), node->namedInput(attr::padding), outputs.at(1), - outputs.at(2), graph->insertConstant(c10::List({true, true, true}))}); // graph->insert returns a tuple automatically if multiple outputs are // returned. So unpack them again.