Forward AD formulas batch 2 #57863
Conversation
💊 CI failures summary and remediations: as of commit 697868e (more details on the Dr. CI page).
ghstack-source-id: 5b32b08 Pull Request resolved: pytorch#57863
@albanD has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Slow gradcheck also passes for this PR and can be found here: #57976. Differential Revision: [D28387763](https://our.internmc.facebook.com/intern/diff/D28387763)
Haven't looked at the actual formulas, but here are some comments on the rest
@@ -5369,9 +5369,10 @@ def gradgradcheck_method_precision_override(test_name):
    return override

def run_grad_and_gradgrad_checks(test_case, name, test_name, apply_method, output_variable,
-                                input_variables, run_gradgradcheck=True, check_batched_grad=True):
+                                input_variables, run_gradgradcheck=True, check_batched_grad=True,
These changes to test_autograd are for the method tests? Maybe you could add a similar check, so that check_forward_ad=True would raise an error unless it's explicitly set to True?
This will require larger changes because some other tests will need to be updated as well. I'll do that in a follow-up PR.
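For readers following along, the kind of guard suggested above would amount to something like the following. This is a hypothetical, simplified sketch, not the PR's code; the helper name is made up, and the exact exception raised for a missing forward-AD formula is an assumption.

```python
import torch

# Hypothetical sketch of the suggested guard (not the PR's code): when a legacy
# method test does not explicitly opt into forward AD, assert that forward-mode
# gradcheck fails, so an op that silently gains a formula forces the flag flip.
def run_forward_ad_check(test_case, fn, inputs, check_forward_ad=False):
    if check_forward_ad:
        test_case.assertTrue(
            torch.autograd.gradcheck(fn, inputs, check_forward_ad=True))
    else:
        # Assumption: ops without a forward derivative raise when used under
        # forward-mode AD; the exact exception type may differ.
        with test_case.assertRaises((NotImplementedError, RuntimeError)):
            torch.autograd.gradcheck(fn, inputs, check_forward_ad=True)
```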
@@ -137,7 +137,7 @@ def postprocess_forward_derivatives(
    def find_required_inputs(formula: str, postfix: str) -> Tuple[str, ...]:
        required_inputs = set()
        for arg in args_with_derivatives:
-            if arg.type == 'TensorList':
+            if arg.type == 'at::TensorList':
Is this a bugfix? (Should this be arg.type == 'at::TensorList' or arg.type == 'TensorList'?)
This is indeed a bugfix. This code wasn't actually used before; it is now that we have the "cat" formula.
Should we assert 'TensorList' not in arg.type after this if statement? It would make it easier to catch bugs like this in the future.
I don't think so. It was just me not knowing that the namespace is included in this name. I think there should be methods on the type to check this directly, but I couldn't find one.
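To make the suggestion concrete, here is a hypothetical, heavily simplified version of the helper showing where such an assert would sit. It is not the actual load_derivatives.py code: the helper is written as a free function instead of a closure, the per-element handling of list inputs is elided, and the identifier-matching regex is only an approximation.

```python
import re
from typing import Sequence, Tuple

# Simplified sketch only: `arg` objects are assumed to expose .name and .type
# (the C++ type rendered as a string, namespace included, e.g. 'at::TensorList').
def find_required_inputs(args_with_derivatives: Sequence, formula: str,
                         postfix: str) -> Tuple[str, ...]:
    required_inputs = set()
    for arg in args_with_derivatives:
        if arg.type == 'at::TensorList':
            # List inputs appear in formulas element by element; the real
            # codegen handles them separately (details elided here).
            continue
        # The suggested safety net: fail loudly if a differently spelled
        # TensorList type ever slips past the check above.
        assert 'TensorList' not in arg.type, f"unhandled list type: {arg.type}"
        if re.search(rf'\b{arg.name}{postfix}\b', formula):
            required_inputs.add(arg.name)
    return tuple(sorted(required_inputs))
```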
The formula for cross looks fishy; other than that, the formulas LGTM.
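For context, the expected forward-mode derivative of cross follows the product rule: if y = cross(a, b), the tangent is cross(a_t, b) + cross(a, b_t). A quick numerical sanity check using the public forward-AD API (assuming that formula; this is not the PR's derivatives.yaml entry) might look like:

```python
import torch
import torch.autograd.forward_ad as fwAD

a, b = torch.randn(3, dtype=torch.double), torch.randn(3, dtype=torch.double)
a_t, b_t = torch.randn(3, dtype=torch.double), torch.randn(3, dtype=torch.double)

# Push dual numbers (primal, tangent) through torch.cross and read back the JVP.
with fwAD.dual_level():
    out = torch.cross(fwAD.make_dual(a, a_t), fwAD.make_dual(b, b_t), dim=0)
    _, jvp = fwAD.unpack_dual(out)

# Product-rule JVP for the cross product.
expected = torch.cross(a_t, b, dim=0) + torch.cross(a, b_t, dim=0)
assert torch.allclose(jvp, expected)
```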
Don't forget to revert the changes to the submodules; otherwise this LGTM.
It still seems weird to me that the Tensor variant of clamp does not raise an error, since we added formulas for both the scalar and tensor variants but only added support_forward_ad=True for the scalar variant in OpInfos. Any idea what's up with that? LGTM otherwise though.
We did not actually add the formula for the Tensor variant, only the multiple scalar versions, so this is expected.
ghstack-source-id: 46ae249 Pull Request resolved: pytorch#57863
Aren't clamp_max.Tensor and clamp_min.Tensor tensor variants? Or maybe I'm missing something?
They are, but those are for the clamp_min/clamp_max functions. The Tensor variant for clamp is a separate overload that did not get a formula in this PR.
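For context on the opt-in mechanism being discussed: each OpInfo variant declares whether it supports forward AD (the flag written support_forward_ad above; the kwarg is assumed to be supports_forward_ad in the test suite), and only variants that opt in are run through gradcheck's forward-mode checks. A standalone approximation of what the scalar clamp variant ends up running is roughly:

```python
import torch
from torch.autograd import gradcheck

x = torch.randn(5, dtype=torch.double, requires_grad=True)

# Scalar min/max variant of clamp: this is the variant whose OpInfo entry opts
# into forward-AD checking, so the forward-mode gradcheck is expected to pass.
assert gradcheck(lambda t: torch.clamp(t, -0.5, 0.5), (x,), check_forward_ad=True)

# The Tensor-min/max overload is a separate entry; at the time of this PR it had
# no forward-AD formula, so the same check on it would raise instead of pass.
```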
Summary: Pull Request resolved: pytorch#57863
Test Plan: Imported from OSS
Reviewed By: zou3519
Differential Revision: D28387763
Pulled By: albanD
fbshipit-source-id: e1b60ab728bb05b9e3323ee0dc7e401aaf5b8817
Slow gradcheck also passes for this PR and can be found here: #57976
Stack from ghstack:
Differential Revision: D28387763