
Conversation

@anjali411 (Contributor) commented Sep 28, 2020

Stack from ghstack:

This PR disables autograd for all C -> C and R -> C functions that are not included in the allowlist `GRADIENT_IMPLEMENTED_FOR_COMPLEX`. In practice, a RuntimeError is raised during the forward computation when such a function produces a differentiable complex output:

>>> x=torch.randn(4, 4, requires_grad=True, dtype=torch.cdouble)
>>> x.pow(3)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: pow does not support automatic differentiation for outputs with complex dtype.

The implicit assumption here is that all the C -> R functions have correct backward definitions. So before merging this PR, the following functions must be tested and verified to have correct backward definitions:
torch.abs (updated in #39955 ), torch.angle, torch.norm, torch.irfft, torch.istft.
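A minimal illustration of the intended split (a sketch only; it assumes `torch.abs` keeps the complex-aware backward from #39955): a C -> R function with a correct backward definition still differentiates on complex inputs, while non-allowlisted ops raise the error shown above.

```python
import torch

# Sketch: abs maps complex -> real and has a complex backward definition,
# so autograd through it is allowed.
x = torch.randn(4, 4, requires_grad=True, dtype=torch.cdouble)
loss = x.abs().sum()   # real scalar output
loss.backward()
print(x.grad.dtype)    # torch.complex128: gradients of complex leaves are complex
```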

Differential Revision: D23998156

@anjali411 anjali411 added the module: complex Related to complex number support in PyTorch label Sep 28, 2020
@anjali411 anjali411 added this to the 1.7.0 milestone Sep 28, 2020
'eq_', 'ne_', 'add', '__radd__', 'sum', '_conj', 'sin', 'cos', 'mul', 'sinh',
'cosh', '__rmul__', 'sgn', 'view_as_real', 'real', 'imag', 'asin', 'acos', 'sub',
'div', 'cat', 'view_as_complex', 'neg', 'complex', 'select', '_s_where', 'as_strided',
'_fft_with_size'
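For context, a hedged illustration of what inclusion in this list means for a C -> C entry such as `sin` (illustrative only; it assumes `sin` remains on the allowlist): the complex output is differentiable once an explicit complex `grad_outputs` is supplied.

```python
import torch

# Sketch: sin is a C -> C function on the allowlist, so its output stays
# differentiable; complex outputs need an explicit grad_outputs.
x = torch.randn(3, dtype=torch.cdouble, requires_grad=True)
y = torch.sin(x)
grad_x, = torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y))
print(grad_x.dtype)  # torch.complex128
```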

Collaborator:

_fft_with_size doesn't use complex tensors; it uses (..., 2)-shaped real tensors.

Collaborator:

fft failures are real, though.

@dr-ci dr-ci bot commented Sep 28, 2020

💊 CI failures summary and remediations

As of commit 5d2c5fe (more details on the Dr. CI page):



❄️ 1 failure tentatively classified as flaky

but reruns have not yet been triggered to confirm:

See CircleCI build pytorch_ios_11_2_1_x86_64_build (1/1)

Step: "Update Homebrew" (full log | diagnosis details | 🔁 rerun) ❄️

+ cd /usr/local/Homebrew/Library/Taps/homebrew/homebrew-core/.git/..
+ git fetch --depth=1 origin
fatal: Could not read from remote repository.
Please make sure you have the correct access rights
and the repository exists.
Connection to github.com closed by remote host.

🚧 1 fixed upstream failure:

These were probably caused by upstream breakages that were already fixed.

Please rebase on the viable/strict branch:

Since your merge base is older than viable/strict, run these commands:

git fetch https://github.com/pytorch/pytorch viable/strict
git rebase FETCH_HEAD

Check out the recency history of this "viable master" tracking branch.


Extra GitHub checks: 1 failed



anjali411 added a commit that referenced this pull request Sep 28, 2020
ghstack-source-id: d734ad1
Pull Request resolved: #45461
@albanD (Collaborator) left a comment

I did not check the C->R functions, nor whether the functions in the list are actually properly implemented.

The rest of the codegen looks good except the TensorList support.
Also it would be nice to show what the new generated code looks like for a sample function.

if base_name in GRADIENT_IMPLEMENTED_FOR_COMPLEX:
    return body
for arg in differentiable_outputs:
    if arg['type'] == 'Tensor':
Collaborator:

What about TensorList?
In particular functions like unbind() will return such objects.

Contributor Author:

The functions that are differentiable and return TensorList are torch.unbind and torch.split (both of which have correct backward definitions for complex). So I think it's OK to just not do anything for that case. However, I'll add a check for the TensorList type for any functions that may be added in the future.
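As a hedged illustration of this point (assuming `unbind` and `abs` keep their complex backward definitions), backward through a TensorList-returning op already works when the final output is a real scalar:

```python
import torch

# Sketch: unbind returns a TensorList; its backward (stacking the output
# grads) is valid for complex tensors, so nothing extra is needed here.
x = torch.randn(3, 2, dtype=torch.cdouble, requires_grad=True)
a, b, c = torch.unbind(x)
loss = (a.abs() + b.abs() + c.abs()).sum()  # reduce to a real scalar
loss.backward()
print(x.grad.shape, x.grad.dtype)           # torch.Size([3, 2]) torch.complex128
```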

Collaborator:

There are no others? Sounds good.
But yes, a check is nice to make sure we don't break this in the future.

@anjali411 (Contributor Author) commented Sep 29, 2020

There's _cudnn_rnn_backward, but it's non-differentiable.
Added split, split_with_sizes, unsafe_split, and split_with_sizes_backward to the list, and also added a check to error out for TensorList outputs otherwise.
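For readers following along, a rough sketch of how such a codegen-side check could be structured (the function name, dictionary keys, and abridged allowlist below are illustrative assumptions, not the exact helpers in tools/autograd/gen_variable_type.py; only `throw_error_for_complex_fns_backward_not_implemented` and `GRADIENT_IMPLEMENTED_FOR_COMPLEX` come from this PR):

```python
# Hypothetical sketch: skip allowlisted functions, otherwise emit a guard
# for every differentiable Tensor/TensorList output of the generated op.
GRADIENT_IMPLEMENTED_FOR_COMPLEX = {'add', 'mul', 'abs', 'split', 'unbind'}  # abridged stand-in

def emit_complex_autograd_guards(base_name, differentiable_outputs):
    if base_name in GRADIENT_IMPLEMENTED_FOR_COMPLEX:
        return []  # allowlisted: no guard needed
    guards = []
    for arg in differentiable_outputs:
        if arg['type'] == 'Tensor':
            guards.append(
                'throw_error_for_complex_fns_backward_not_implemented({}, "{}");'.format(
                    arg['name'], base_name))
        elif arg['type'] == 'TensorList':
            # Guard each tensor in the returned list so new TensorList-returning
            # ops error out unless they are added to the allowlist.
            guards.append(
                'for (const auto& _t : {}) throw_error_for_complex_fns_backward_not_implemented(_t, "{}");'.format(
                    arg['name'], base_name))
    return guards
```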

@@ -67,6 +67,14 @@ inline void throw_error_out_requires_grad(const char* name) {
"but one of the arguments requires grad.");
}

inline void throw_error_for_complex_fns_backward_not_implemented(const Tensor& tensor, const char* name) {
Collaborator:

nit: name looks overly verbose.

anjali411 added a commit that referenced this pull request Sep 29, 2020
ghstack-source-id: 557f6e9
Pull Request resolved: #45461
@mruberry mruberry removed the request for review from apaszke September 29, 2020 11:09
@@ -67,6 +67,14 @@ inline void throw_error_out_requires_grad(const char* name) {
"but one of the arguments requires grad.");
}

inline void throw_error_for_complex_fns_backward_not_implemented(const Tensor& tensor, const char* name) {
if (tensor.requires_grad() && tensor.is_complex()) {
Collaborator:

Not a TORCH_CHECK here?

anjali411 added a commit that referenced this pull request Sep 29, 2020
ghstack-source-id: 59c3aaf
Pull Request resolved: #45461
anjali411 added a commit that referenced this pull request Sep 29, 2020
ghstack-source-id: 0ca1478
Pull Request resolved: #45461
@albanD (Collaborator) left a comment

lgtm

@ezyang (Contributor) commented Sep 29, 2020

Thanks, looks good.

cc @robieta I'm expecting a mild increase in instruction count here for AD benchmarks; it will be more pronounced for operators on tensor lists.

@mruberry (Collaborator):

I think the test failures are OK. I'm not sure why the XLA build is failing but it doesn't look related to this PR.

@albanD (Collaborator) commented Sep 29, 2020

@ezyang I think (almost) all the TensorList ops we have are actually in the allowlist, so the impact should be minimal.

anjali411 added a commit that referenced this pull request Sep 30, 2020
ghstack-source-id: 916aa69
Pull Request resolved: #45461
anjali411 added a commit that referenced this pull request Sep 30, 2020
ghstack-source-id: a3aaa9b
Pull Request resolved: #45461
@facebook-github-bot (Contributor):
@anjali411 merged this pull request in 415ed43.

malfet pushed a commit that referenced this pull request Sep 30, 2020
ghstack-source-id: a3aaa9b
Pull Request resolved: #45461
This was referenced Oct 1, 2020
@facebook-github-bot facebook-github-bot deleted the gh/anjali411/59/head branch October 4, 2020 14:18
@anjali411 anjali411 changed the title Add whitelist for complex backward Add allowlist for complex backward Jan 13, 2021