Add allowlist for complex backward #45461
Conversation
tools/autograd/gen_variable_type.py
Outdated
'eq_', 'ne_', 'add', '__radd__', 'sum', '_conj', 'sin', 'cos', 'mul', 'sinh',
'cosh', '__rmul__', 'sgn', 'view_as_real', 'real', 'imag', 'asin', 'acos', 'sub',
'div', 'cat', 'view_as_complex', 'neg', 'complex', 'select', '_s_where', 'as_strided',
'_fft_with_size'
`_fft_with_size` doesn't use complex; it uses (..., 2)-shaped real tensors.
fft failures are real, though.
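(A small illustration, not part of the PR, of the (..., 2) layout mentioned above: the real-pair representation and complex dtypes are interchangeable via `view_as_real` / `view_as_complex`, both of which appear in the allowlist.)

```python
import torch

# Illustration only: legacy fft kernels such as _fft_with_size operate on real
# tensors with a trailing dimension of size 2 rather than on complex dtypes.
z = torch.randn(3, dtype=torch.cdouble)
r = torch.view_as_real(z)      # float64 tensor of shape (3, 2): [..., 0] real, [..., 1] imag
z2 = torch.view_as_complex(r)  # back to a complex128 tensor of shape (3,)
assert torch.equal(z, z2)
```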
💊 CI failures summary and remediations (as of commit 5d2c5fe; more details on the Dr. CI page):
❄️ 1 failure tentatively classified as flaky, but reruns have not yet been triggered to confirm.
This PR disables autograd for all C -> C, R -> C functions which are not included in the whitelist `GRADIENT_IMPLEMENTED_FOR_COMPLEX`. In practice, there will be a RuntimeError during forward computation when the outputs are differentiable:

```
>>> x=torch.randn(4, 4, requires_grad=True, dtype=torch.cdouble)
>>> x.pow(3)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: pow does not support automatic differentiation for outputs with complex dtype.
```

The implicit assumption here is that all the C -> R functions have correct backward definitions. So before merging this PR, the following functions must be tested and verified to have correct backward definitions: `torch.abs`, `torch.angle`, `torch.norm`, `torch.irfft`, `torch.istft`.

[ghstack-poisoned]
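For contrast with the error above, a minimal sketch (illustrative, not part of the PR) of a chain built only from allowlisted or verified pieces, which therefore still supports backward on complex inputs:

```python
import torch

# sin (C -> C) is in GRADIENT_IMPLEMENTED_FOR_COMPLEX, and abs (C -> R) is one of
# the functions the description above requires to have a correct backward, so the
# whole chain is differentiable.
x = torch.randn(4, 4, dtype=torch.cdouble, requires_grad=True)
loss = x.sin().abs().sum()   # real-valued scalar loss
loss.backward()              # no RuntimeError here
print(x.grad.dtype)          # torch.complex128
```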
I did not check the C -> R functions, nor whether the functions in the list are actually properly implemented.
The rest of the codegen looks good except for the TensorList support.
Also, it would be nice to show what the new generated code looks like for a sample function.
tools/autograd/gen_variable_type.py
Outdated
if base_name in GRADIENT_IMPLEMENTED_FOR_COMPLEX:
    return body
for arg in differentiable_outputs:
    if arg['type'] == 'Tensor':
What about TensorList? In particular, functions like `unbind()` will return such objects.
The functions that are differentiable and return TensorList are `torch.unbind` and `torch.split` (both of which have correct backward definitions for complex), so I think it's OK to just not do anything for that case. However, I'll add a check for the TensorList type to cover any functions that may be added in the future.
There are no others? Sounds good.
But yes, a check is nice to make sure we don't break this in the future.
There's `_cudnn_rnn_backward`, but it's non-differentiable.
Added `split`, `split_with_sizes`, `unsafe_split`, and `split_with_sizes_backward` to the list, and also added a check to error out for TensorList otherwise.
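To make the outcome of this thread concrete, here is a rough, hypothetical sketch (the function name `emit_complex_check`, the abbreviated allowlist, and the emitted guard strings are illustrative, not the exact code that was merged) of how the gate in `tools/autograd/gen_variable_type.py` could branch on `Tensor` versus `TensorList` outputs:

```python
# Abbreviated stand-in for the real allowlist reviewed above.
GRADIENT_IMPLEMENTED_FOR_COMPLEX = {'sin', 'cos', 'mul', 'split', 'split_with_sizes', 'unbind'}

def emit_complex_check(base_name, differentiable_outputs, body):
    """Prepend runtime complex-autograd guards unless base_name is allowlisted."""
    if base_name in GRADIENT_IMPLEMENTED_FOR_COMPLEX:
        return body  # verified complex backward: no runtime guard needed
    checks = []
    for arg in differentiable_outputs:
        if arg['type'] == 'Tensor':
            # guard a single differentiable Tensor output
            checks.append('throw_error_for_complex_fns_backward_not_implemented({}, "{}");'
                          .format(arg['name'], base_name))
        elif arg['type'] == 'TensorList':
            # guard every tensor in a TensorList output, erroring out if complex
            checks.append('for (const auto& t : {}) {{ '
                          'throw_error_for_complex_fns_backward_not_implemented(t, "{}"); }}'
                          .format(arg['name'], base_name))
    return checks + body
```

For example, `emit_complex_check('pow', [{'type': 'Tensor', 'name': 'result'}], [])` would emit the single-tensor guard, while an allowlisted op gets its body back untouched.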
@@ -67,6 +67,14 @@ inline void throw_error_out_requires_grad(const char* name) {
      "but one of the arguments requires grad.");
}

inline void throw_error_for_complex_fns_backward_not_implemented(const Tensor& tensor, const char* name) {
nit: the name looks overly verbose.
@@ -67,6 +67,14 @@ inline void throw_error_out_requires_grad(const char* name) {
      "but one of the arguments requires grad.");
}

inline void throw_error_for_complex_fns_backward_not_implemented(const Tensor& tensor, const char* name) {
  if (tensor.requires_grad() && tensor.is_complex()) {
Not a `TORCH_CHECK` here?
lgtm
Thanks, looks good. cc @robieta: I'm expecting a mild increase in instruction count here for AD benchmarks; it will be more pronounced for operators on tensor lists.
I think the test failures are OK. I'm not sure why the XLA build is failing, but it doesn't look related to this PR.
@ezyang I think (almost) all the TensorList ops we have are actually in the allowlist, so the impact should be minimal.
@anjali411 merged this pull request in 415ed43.
Stack from ghstack:
This PR disables autograd for all C -> C, R -> C functions which are not included in the allowlist `GRADIENT_IMPLEMENTED_FOR_COMPLEX`. In practice, there will be a RuntimeError during forward computation when the outputs are differentiable (see the example in the description above).

The implicit assumption here is that all the C -> R functions have correct backward definitions. So before merging this PR, the following functions must be tested and verified to have correct backward definitions: `torch.abs` (updated in #39955), `torch.angle`, `torch.norm`, `torch.irfft`, `torch.istft`.

Differential Revision: [D23998156](https://our.internmc.facebook.com/intern/diff/D23998156)
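As a footnote on the "tested and verified" requirement above: a minimal sketch of how one of the C -> R backward definitions could be spot-checked numerically. It assumes a PyTorch build whose `torch.autograd.gradcheck` accepts complex inputs (recent releases do); the tolerances here are arbitrary.

```python
import torch

# Illustrative check only, not part of the PR: gradcheck compares the analytic
# backward of a C -> R function (here torch.abs) against finite differences.
x = torch.randn(5, dtype=torch.cdouble, requires_grad=True)
assert torch.autograd.gradcheck(lambda t: t.abs(), (x,), eps=1e-6, atol=1e-4)
```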