Remove backward ops for cuDNN transposed convolution
ghstack-source-id: 46cf61dd8e77d3c82fcecf4bc4cca5b81248c742
Pull Request resolved: #69902
jbschlosser committed Dec 14, 2021
1 parent 810f317 commit 2857965
Showing 9 changed files with 27 additions and 31 deletions.
4 changes: 0 additions & 4 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -272,10 +272,6 @@ _(aten, cudnn_batch_norm) \
 _(aten, cudnn_batch_norm_backward) \
 _(aten, cudnn_convolution) \
 _(aten, cudnn_convolution_transpose) \
-_(aten, cudnn_convolution_transpose_backward) \
-_(aten, cudnn_convolution_transpose_backward_bias) \
-_(aten, cudnn_convolution_transpose_backward_input) \
-_(aten, cudnn_convolution_transpose_backward_weight) \
 _(aten, cudnn_convolution_relu) \
 _(aten, cudnn_convolution_add_relu) \
 _(aten, cudnn_grid_sampler) \
4 changes: 4 additions & 0 deletions aten/src/ATen/native/ConvUtils.h
@@ -9,7 +9,11 @@ namespace at { namespace native {
 using cudnn_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor>(*)(
     const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
     at::IntArrayRef, int64_t, bool, bool, bool, std::array<bool,2>);
+using cudnn_convolution_transpose_backward_fn = std::tuple<at::Tensor,at::Tensor>(*)(
+    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
+    at::IntArrayRef, at::IntArrayRef, int64_t, bool, bool, bool, std::array<bool,2>);
 DECLARE_DISPATCH(cudnn_convolution_backward_fn, cudnn_convolution_backward_stub);
+DECLARE_DISPATCH(cudnn_convolution_transpose_backward_fn, cudnn_convolution_transpose_backward_stub);

 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
 struct ConvParams {
12 changes: 9 additions & 3 deletions aten/src/ATen/native/Convolution.cpp
@@ -23,8 +23,10 @@ constexpr int MIOPEN_DIM_MAX = 5;
 namespace at { namespace native {

 DEFINE_DISPATCH(cudnn_convolution_backward_stub);
-REGISTER_NO_CPU_DISPATCH(cudnn_convolution_backward_stub, cudnn_convolution_backward_fn);
+DEFINE_DISPATCH(cudnn_convolution_transpose_backward_stub);
 DEFINE_DISPATCH(convolution_depthwise3x3_winograd_stub);
+REGISTER_NO_CPU_DISPATCH(cudnn_convolution_backward_stub, cudnn_convolution_backward_fn);
+REGISTER_NO_CPU_DISPATCH(cudnn_convolution_transpose_backward_stub, cudnn_convolution_transpose_backward_fn);

 std::ostream& operator<<(std::ostream & out, const ConvParams& params) {
   out << "ConvParams {"
@@ -1565,12 +1567,16 @@ std::tuple<Tensor, Tensor, Tensor> convolution_backward(
       break;
     }
     case ConvBackend::CudnnTranspose:
+    {
       check_input_same_type_as_parameters(input, weight);
-      std::tie(backend_grad_input, backend_grad_weight) = at::cudnn_convolution_transpose_backward(
+      std::array<bool, 2> input_weight_output_mask = {output_mask[0], output_mask[1]};
+      std::tie(backend_grad_input, backend_grad_weight) = cudnn_convolution_transpose_backward_stub(
+        input.device().type(),
         input.contiguous(backend_memory_format), grad_output, weight, params.padding, params.output_padding,
         params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32,
-        {output_mask[0], output_mask[1]});
+        input_weight_output_mask);
       break;
+    }
     case ConvBackend::Empty:
       if (output_mask[0]) {
         backend_grad_input = at::zeros_like(input);
1 change: 1 addition & 0 deletions aten/src/ATen/native/cudnn/ConvShared.cpp
@@ -630,6 +630,7 @@ Tensor cudnn_convolution_add_relu(
 }

 REGISTER_CUDA_DISPATCH(cudnn_convolution_backward_stub, &cudnn_convolution_backward);
+REGISTER_CUDA_DISPATCH(cudnn_convolution_transpose_backward_stub, &cudnn_convolution_transpose_backward);

 }}

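The three C++ changes above replace the public backward operator with ATen's internal DispatchStub plumbing: ConvUtils.h declares a stub for the transposed-convolution backward, Convolution.cpp defines it (and registers an error for CPU, since only cuDNN implements it), and ConvShared.cpp points it at the cuDNN implementation. For readers unfamiliar with the pattern, here is a minimal standalone sketch of the declare/define/register/call cycle — not the actual ATen macros or types, just hypothetical stand-ins illustrating the idea:

```cpp
#include <array>
#include <cstdio>
#include <stdexcept>

// Hypothetical stand-in for at::DeviceType; not the real ATen enum.
enum class DeviceType { CPU, CUDA };

// Roughly what DECLARE_DISPATCH sets up: a named stub holding per-device function pointers.
using conv_transpose_backward_fn = std::array<float, 2> (*)(float grad, float input, float weight);

struct ConvTransposeBackwardStub {
  conv_transpose_backward_fn cuda_impl = nullptr;

  // Call sites pass the device type first, as in Convolution.cpp above.
  std::array<float, 2> operator()(DeviceType dev, float grad, float input, float weight) const {
    if (dev == DeviceType::CUDA && cuda_impl != nullptr) {
      return cuda_impl(grad, input, weight);
    }
    // REGISTER_NO_CPU_DISPATCH amounts to "error out when no backend is registered".
    throw std::runtime_error("no backend registered for this device");
  }
};

// Roughly what DEFINE_DISPATCH does: create the single global stub instance.
ConvTransposeBackwardStub cudnn_convolution_transpose_backward_stub;

// A toy "CUDA" kernel for a 1x1 convolution: returns {grad_input, grad_weight}.
std::array<float, 2> toy_cuda_impl(float grad, float input, float weight) {
  return {grad * weight, grad * input};
}

int main() {
  // Roughly what REGISTER_CUDA_DISPATCH does: point the stub at the backend kernel.
  cudnn_convolution_transpose_backward_stub.cuda_impl = &toy_cuda_impl;

  auto grads = cudnn_convolution_transpose_backward_stub(
      DeviceType::CUDA, /*grad=*/1.0f, /*input=*/2.0f, /*weight=*/3.0f);
  std::printf("grad_input=%f grad_weight=%f\n", grads[0], grads[1]);
  return 0;
}
```

The call site in convolution_backward above follows the same shape: the first argument selects the device, and the stub forwards to whatever implementation was registered for it.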
14 changes: 0 additions & 14 deletions aten/src/ATen/native/native_functions.yaml
@@ -1442,20 +1442,6 @@
   dispatch:
     CUDA: cudnn_convolution_transpose

-# NB: output_padding not strictly needed here, but it's helpful for the float
-# backwards
-- func: cudnn_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32, bool[2] output_mask) -> (Tensor, Tensor)
-  dispatch:
-    CUDA: cudnn_convolution_transpose_backward
-
-- func: cudnn_convolution_transpose_backward_input(Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor
-  dispatch:
-    CUDA: cudnn_convolution_transpose_backward_input
-
-- func: cudnn_convolution_transpose_backward_weight(int[] weight_size, Tensor grad_output, Tensor self, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor
-  dispatch:
-    CUDA: cudnn_convolution_transpose_backward_weight
-
 - func: cudnn_convolution_relu(Tensor self, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor
   dispatch:
     CUDA: cudnn_convolution_relu
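Deleting these schemas removes cudnn_convolution_transpose_backward and its _input/_weight variants from the public operator registry; the cuDNN kernels themselves remain and are now reached only through the stub registered above. For reference, judging from the cudnn_convolution_transpose_backward_fn typedef added in ConvUtils.h, the native function the stub points at should have roughly this C++ shape (a sketch only; the argument names are guesses, not copied from ConvShared.h):

```cpp
// Sketch of the signature implied by the removed schema and the new stub typedef.
std::tuple<at::Tensor, at::Tensor> cudnn_convolution_transpose_backward(
    const at::Tensor& self, const at::Tensor& grad_output, const at::Tensor& weight,
    at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride,
    at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic,
    bool allow_tf32, std::array<bool, 2> output_mask);
```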
3 changes: 3 additions & 0 deletions test/backward_compatibility/check_backward_compatibility.py
@@ -54,6 +54,9 @@
     ("aten::cudnn_convolution_backward", datetime.date(2022, 1, 31)),
     ("aten::cudnn_convolution_backward_input", datetime.date(2022, 1, 31)),
     ("aten::cudnn_convolution_backward_weight", datetime.date(2022, 1, 31)),
+    ("aten::cudnn_convolution_transpose_backward", datetime.date(2022, 1, 31)),
+    ("aten::cudnn_convolution_transpose_backward_input", datetime.date(2022, 1, 31)),
+    ("aten::cudnn_convolution_transpose_backward_weight", datetime.date(2022, 1, 31)),
     ("aten::_slow_conv2d_forward", datetime.date(2022, 1, 31)),
     ("aten::_slow_conv2d_backward", datetime.date(2022, 1, 31)),
     ("aten::slow_conv3d_forward", datetime.date(2022, 1, 31)),
7 changes: 2 additions & 5 deletions tools/autograd/derivatives.yaml
@@ -2253,13 +2253,10 @@
   log_probs: _cudnn_ctc_loss_backward(grad, result0, result1, zero_infinity)

 - name: cudnn_convolution_transpose(Tensor self, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor
-  self, weight: "grad.defined() ? cudnn_convolution_transpose_backward(self, grad, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, grad_input_mask) : std::tuple<Tensor, Tensor>()"
-
-- name: cudnn_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32, bool[2] output_mask) -> (Tensor, Tensor)
-  grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], Tensor(), grad_output, weight, self, stride, padding, dilation, true, output_padding, groups, benchmark, deterministic, true, allow_tf32, grad_input_mask)
+  self, weight: "_cudnn_convolution_backward(self, grad, weight, padding, output_padding, stride, dilation, true, groups, {grad_input_mask[0], grad_input_mask[1]})"

 - name: cudnn_convolution(Tensor self, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor
-  self, weight: "_cudnn_convolution_backward(self, grad, weight, padding, stride, dilation, groups, {grad_input_mask[0], grad_input_mask[1]})"
+  self, weight: "_cudnn_convolution_backward(self, grad, weight, padding, std::vector<int64_t>(padding.size(), 0), stride, dilation, false, groups, {grad_input_mask[0], grad_input_mask[1]})"

 - name: cudnn_grid_sampler(Tensor self, Tensor grid) -> Tensor output
   self, grid: "grad.defined() ? cudnn_grid_sampler_backward(self, grid, grad) : std::tuple<Tensor, Tensor>()"
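With the dedicated backward ops gone, the derivatives of both cudnn_convolution and cudnn_convolution_transpose funnel into the same _cudnn_convolution_backward helper (extended in FunctionsManual.cpp below); the only differences are the transposed flag and whether a real output_padding is forwarded. A sketch of what the two formulas amount to, assuming grad is the incoming gradient and grad_input_mask the usual two-entry autograd mask already in scope, as they are in generated autograd code:

```cpp
// Sketch only: the two derivative expressions from the YAML above, written side by side.

// cudnn_convolution_transpose: forward the real output_padding and set transposed = true.
std::tuple<Tensor, Tensor> transpose_grads = _cudnn_convolution_backward(
    self, grad, weight, padding, output_padding, stride, dilation,
    /*transposed=*/true, groups, {grad_input_mask[0], grad_input_mask[1]});

// cudnn_convolution: a plain convolution has no output_padding, so pass zeros and transposed = false.
std::tuple<Tensor, Tensor> conv_grads = _cudnn_convolution_backward(
    self, grad, weight, padding, std::vector<int64_t>(padding.size(), 0),
    stride, dilation, /*transposed=*/false, groups,
    {grad_input_mask[0], grad_input_mask[1]});
```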
10 changes: 6 additions & 4 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -4747,18 +4747,20 @@ Tensor warn_backwards(const Tensor &grad_output) {
 }

 // This function only exists because cuDNN does not support bias gradient computation and it's not easy
-// to slice a std::tuple to return only grad_input / grad_weight from convolution_backward.
+// to slice a std::tuple to return only grad_input / grad_weight from convolution_backward. It will
+// be removed when the cudnn_convolution and cudnn_convolution_transpose go away.
 std::tuple<Tensor, Tensor> _cudnn_convolution_backward(
     const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, at::IntArrayRef padding,
-    at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, ::std::array<bool,2> output_mask) {
+    at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, bool transposed, int64_t groups,
+    ::std::array<bool,2> output_mask) {
   if (!grad_output.defined()) {
     return std::tuple<Tensor, Tensor>();
   }

   // Just call the general backward and ignore the bias gradient part.
   std::tuple<Tensor, Tensor, Tensor> grad_inputs = at::native::convolution_backward(
-      grad_output, self, weight, c10::nullopt, stride, padding, dilation, /*transposed=*/ false,
-      std::vector<int64_t>(padding.size(), 0), groups, {output_mask[0], output_mask[1], false});
+      grad_output, self, weight, c10::nullopt, stride, padding, dilation, transposed,
+      output_padding, groups, {output_mask[0], output_mask[1], false});
   std::tuple<Tensor, Tensor> result = std::make_tuple(std::get<0>(grad_inputs), std::get<1>(grad_inputs));
   return result;
 }
3 changes: 2 additions & 1 deletion torch/csrc/autograd/FunctionsManual.h
@@ -388,7 +388,8 @@ Tensor warn_backwards(const Tensor &grad_output);

 std::tuple<Tensor, Tensor> _cudnn_convolution_backward(
     const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, at::IntArrayRef padding,
-    at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, ::std::array<bool,2> output_mask);
+    at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, bool transposed, int64_t groups,
+    ::std::array<bool,2> output_mask);

 } // namespace details
 } // namespace generated
