Hook up general convolution to convolution_backward (#69584)
Summary: Pull Request resolved: #69584

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D32936380

Pulled By: jbschlosser

fbshipit-source-id: c6fdd88db33bd1a9d0eabea47ae09a4d5b170e92
jbschlosser authored and facebook-github-bot committed Dec 13, 2021
1 parent 0420de3 commit fc37e5b
Showing 4 changed files with 20 additions and 24 deletions.
4 changes: 4 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -1276,6 +1276,8 @@
   manual_cpp_binding: True
 
 - func: convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: convolution
 
 - func: convolution_backward(Tensor grad_output, Tensor input, Tensor weight, int[]? bias_sizes, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
   dispatch:
@@ -1290,6 +1292,8 @@
     CompositeExplicitAutograd: convolution_backward_overrideable
 
 - func: _convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: _convolution
 
 - func: _convolution.deprecated(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor
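With these entries, aten::convolution and aten::_convolution get explicit CompositeExplicitAutograd kernels, so their autograd now comes from the derivatives.yaml formulas added below rather than a hand-written autograd.Function. A minimal sketch, not part of this commit, of calling the op with the schema above (shapes and hyperparameters are illustrative):

import torch

input = torch.randn(2, 3, 8, 8, requires_grad=True)
weight = torch.randn(4, 3, 3, 3, requires_grad=True)
bias = torch.randn(4, requires_grad=True)

# Argument order follows the schema:
# (input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)
out = torch.ops.aten.convolution(
    input, weight, bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1)
out.sum().backward()  # with this commit, gradients are produced by convolution_backward
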
25 changes: 1 addition & 24 deletions test/test_nn.py
@@ -13473,31 +13473,8 @@ def _make_noncontiguous(inp):
         backend_actual = torch._C._select_conv_backend(*inputs)
         self.assertEqual(backend_actual, backend_expected)
 
-        # Autograd function to hook up the general convolution function to convolution_backward
-        # without a derivatives.yaml entry. TODO: Once general forward + backward are hooked up together,
-        # remove this.
-        class MyConv(torch.autograd.Function):
-            @staticmethod
-            def forward(ctx, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups):
-                ctx.save_for_backward(input, weight, bias)
-                ctx.stuff = (stride, padding, dilation, transposed, output_padding, groups)
-                return torch.convolution(input, weight, bias, stride, padding, dilation, transposed,
-                                         output_padding, groups)
-
-            @staticmethod
-            def backward(ctx, grad_output):
-                input, weight, bias = ctx.saved_tensors
-                stride, padding, dilation, transposed, output_padding, groups = ctx.stuff
-                grad_input, grad_weight, grad_bias = torch.ops.aten.convolution_backward(
-                    grad_output, input, weight, None if bias is None else bias.shape, stride, padding, dilation,
-                    transposed, output_padding, groups,
-                    list(ctx.needs_input_grad[:2]) + [False if bias is None else True])
-                return grad_input, grad_weight, None if bias is None else grad_bias, None, \
-                    None, None, None, None, None
-
-        convolution = MyConv.apply
-
         # Ensure backward call succeeds.
+        convolution = torch.ops.aten.convolution
         output = convolution(*inputs)
         grad_output = torch.randn(output.shape, device=device, dtype=dtype)
         if not contiguous:
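The removed MyConv workaround routed backward through torch.ops.aten.convolution_backward by hand; with the new derivative entry the test can call the aten op directly. For reference, a standalone sketch (not taken from the test; shapes are illustrative) of the direct convolution_backward call the workaround used to make:

import torch

input = torch.randn(2, 3, 8, 8)
weight = torch.randn(4, 3, 3, 3)
bias = torch.randn(4)
stride, padding, dilation = [1, 1], [0, 0], [1, 1]
transposed, output_padding, groups = False, [0, 0], 1

output = torch.ops.aten.convolution(input, weight, bias, stride, padding, dilation,
                                    transposed, output_padding, groups)
grad_output = torch.randn_like(output)

# bias_sizes is the bias shape (or None when there is no bias), and output_mask
# selects which of (grad_input, grad_weight, grad_bias) to compute.
grad_input, grad_weight, grad_bias = torch.ops.aten.convolution_backward(
    grad_output, input, weight, bias.shape, stride, padding, dilation,
    transposed, output_padding, groups, [True, True, True])
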
9 changes: 9 additions & 0 deletions tools/autograd/derivatives.yaml
@@ -1937,6 +1937,15 @@
   self: max_unpool3d_backward(grad, self, indices, output_size, stride, padding)
   indices: non_differentiable
 
+- name: convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor
+  input, weight, bias: "grad.defined() ? convolution_backward(grad, input, weight, bias->sizes(), stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
+
+# TorchScript serializes calls to _convolution, so this entry is present until that is changed to use convolution.
+# Note that the benchmark, deterministic, cudnn_enabled, and allow_tf32 flags are queried from the global context
+# by convolution_backward instead of being passed along from the forward pass.
+- name: _convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor
+  input, weight, bias: "grad.defined() ? convolution_backward(grad, input, weight, bias->sizes(), stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
+
 - name: convolution_backward(Tensor grad_output, Tensor input, Tensor weight, int[]? bias_sizes, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
   grad_output, input, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, input, stride, padding, dilation, transposed, output_padding, groups, at::globalContext().benchmarkCuDNN(), at::globalContext().deterministicCuDNN() || at::globalContext().deterministicAlgorithms(), at::globalContext().userEnabledCuDNN(), at::globalContext().allowTF32CuDNN(), grad_input_mask)
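These entries make convolution differentiable through convolution_backward, and convolution_backward doubly differentiable through _convolution_double_backward. A small sketch (an assumption on my part, not part of the commit's test plan) of checking that wiring numerically with torch.autograd.gradcheck:

import torch

input = torch.randn(1, 2, 5, 5, dtype=torch.double, requires_grad=True)
weight = torch.randn(3, 2, 3, 3, dtype=torch.double, requires_grad=True)
bias = torch.randn(3, dtype=torch.double, requires_grad=True)

def conv(x, w, b):
    # Same signature as the convolution entry above.
    return torch.convolution(x, w, b, [1, 1], [1, 1], [1, 1], False, [0, 0], 1)

torch.autograd.gradcheck(conv, (input, weight, bias))      # exercises convolution_backward
torch.autograd.gradgradcheck(conv, (input, weight, bias))  # exercises _convolution_double_backward
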
6 changes: 6 additions & 0 deletions tools/autograd/load_derivatives.py
@@ -503,6 +503,12 @@ def stride_expr(name: str) -> str:
             'suffix': '_sizes',
             'nctype': lambda name: NamedCType(name, BaseCType(intArrayRefT)),
         }),
+        # replace self->sizes() with self_sizes_opt
+        (r'{}->sizes\(\)', {
+            'suffix': '_sizes_opt',
+            'nctype': lambda name: NamedCType(name, OptionalCType(BaseCType(intArrayRefT))),
+            'expr': lambda name: f'{name}.has_value() ? c10::optional<IntArrayRef>({name}->sizes()) : c10::nullopt',
+        }),
         # replace self.options() with self_options
         (r'{}.options\(\)', {
             'suffix': '_options',
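This codegen change lets a derivative formula reference `<name>->sizes()` on an optional tensor argument (such as bias in the formulas above): the reference is rewritten to a saved variable `<name>_sizes_opt` of type c10::optional<IntArrayRef>, guarded against an absent tensor. A simplified sketch (not the actual codegen) of what the new pattern produces:

import re

# Hypothetical formula text referencing an optional tensor named `bias`.
formula = "convolution_backward(grad, input, weight, bias->sizes(), stride, padding, ...)"
name = "bias"

# The reference in the formula is rewritten to the saved variable `bias_sizes_opt` ...
rewritten = re.sub(r'{}->sizes\(\)'.format(name), f'{name}_sizes_opt', formula)

# ... and the value saved for backward is computed by a C++ expression like this,
# which maps an absent bias to c10::nullopt instead of dereferencing it.
saved_expr = f'{name}.has_value() ? c10::optional<IntArrayRef>({name}->sizes()) : c10::nullopt'

print(rewritten)   # convolution_backward(grad, input, weight, bias_sizes_opt, stride, padding, ...)
print(saved_expr)
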
