Support autograd.Function w/ grad #99483
Conversation
… a TorchVariable and simulate user invoked allow_in_graph, 2x faster Deberta training [ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/99483
Note: Links to docs will display an error until the docs builds have been completed. ❌ 2 New Failures as of commit 36c5ca5 - the following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This is known not to be sound - you can't just allow_in_graph any old thing like this. However, this will (1) help us measure the potential impact and (2) serve as a base for figuring out how to determine soundness or, failing that, gate it behind a config.
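For context, a hedged sketch of the pattern the commit title refers to - a user opting an autograd.Function into the graph as an opaque call via allow_in_graph, accepting the soundness risk themselves (the `MulTwo` Function and `f` are purely illustrative):

```python
import torch
import torch._dynamo

class MulTwo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2

# What "simulate user invoked allow_in_graph" automates: normally the user
# would opt the Function in themselves.
torch._dynamo.allow_in_graph(MulTwo)

@torch.compile
def f(x):
    return MulTwo.apply(x)

out = f(torch.randn(4, requires_grad=True))
```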
I convinced myself that `dtype` is not actually a problem. Added some more suggestions for methods to ban.
torch/_dynamo/config.py

@@ -258,6 +258,20 @@ def is_fbcode():
    "skipfiles_inline_module_allowlist",
}

capture_autograd_function = True

_autograd_backward_strict_mode_banned_ops = [
It turns out that we don't need to ban `dtype` for safety, because it is guaranteed that the dtype of the grad_output must be the same as that of the outputs (pytorch/torch/csrc/autograd/engine.cpp, lines 834 to 843 at 794cc39):
if (c10::typeMetaToScalarType(metadata.options().dtype()) !=
    grad.scalar_type()) {
  grad = grad.to(c10::typeMetaToScalarType(metadata.options().dtype()));
}
if (grad.dtype() != metadata.dtype()) {
  std::stringstream ss;
  ss << "invalid gradient at index " << i << " - expected dtype ";
  ss << metadata.dtype() << " but got " << grad.dtype();
  AT_ERROR(format_error(ss.str()));
}
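A quick Python-level check of the invariant that C++ excerpt enforces (a sketch; `HalfOut` is illustrative): even when the forward output's dtype differs from the input's, backward receives a grad_output of the output's dtype:

```python
import torch

class HalfOut(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Output dtype (float16) intentionally differs from input (float32).
        return x.to(torch.float16)

    @staticmethod
    def backward(ctx, grad_output):
        # The engine guarantees grad_output's dtype matches the output's.
        assert grad_output.dtype == torch.float16
        return grad_output.to(torch.float32)

x = torch.randn(4, requires_grad=True)
HalfOut.apply(x).sum().backward()
```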
torch/_dynamo/config.py (Outdated)

    "requires_grad",
    "storage_offset",
    "layout",
    "is_cuda",
    "is_quantized",
    "is_meta",
    "data",
    "is_sparse",
Could we do a regex and ban all `is_*` methods? The naming implies those are Tensor properties. Alternatively, we could add all of the ones in https://pytorch.org/docs/stable/tensors.html to the list, but that might not be as robust.
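A sketch of that regex idea, with illustrative names (not the actual config mechanism):

```python
import re

# Ban any attribute that looks like a Tensor `is_*` property, rather than
# enumerating them one by one in the banned-ops list.
_IS_PROP_RE = re.compile(r"^is_\w+$")

def is_banned_in_strict_backward(name: str) -> bool:
    return bool(_IS_PROP_RE.match(name))

assert is_banned_in_strict_backward("is_cuda")
assert is_banned_in_strict_backward("is_sparse")
assert not is_banned_in_strict_backward("dtype")
```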
    jvp_fn = self.fn_cls.jvp  # type: ignore[attr-defined]
    if jvp_fn is not torch.autograd.Function.jvp:
        unimplemented("NYI - User defind jvp")
nit: defind -> defined
@pytorchbot merge -f "Broken base cse test on macos, 180 other test suites passed"

Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: Command …
Details for Dev Infra team: Raised by workflow job.
@pytorchbot merge -f "Broken base cse test on macos, 180 other test suites passed"

Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
This PR adds support for tracing autograd.Function with grad.
A few important bullet points outlining our approach:

1) Our goal is to verify soundness in order to add a call_function to the autograd.Function's `apply` to the graph.
2) We achieve (1) by either verifying or rejecting soundness, ensuring that both the forward and backward of the autograd.Function are sound.
3) For the forward, if we verify soundness, we install its guards into the graph.
4) For the backward, if we verify soundness, we throw it out. However, backward soundness verification is more onerous, and has a config-driven set of banned attrs and methods for tensors.

1-4 above are achieved by turning the forward and backward into UserDefinedFunctionVariables and inlining through them, relying on dynamo's soundness detection. If we graph break in these, we raise and treat them as unsound. As noted above, backward is stricter yet.
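For illustration, a minimal example of the kind of user-defined autograd.Function this approach aims to capture end-to-end (the Function and compiled wrapper are hypothetical, not taken from the PR's tests):

```python
import torch

class ScaleGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_output):
        # Simple enough for both forward and backward to inline soundly.
        return grad_output * ctx.scale, None

@torch.compile
def f(x):
    return ScaleGrad.apply(x, 2.0)

f(torch.randn(3, requires_grad=True)).sum().backward()
```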
For the tracing, the safety comes from dynamo's HigherOrderOperator system. That system ensures not only that we trace soundly, but that no new variables are lifted into inputs during the tracing, and that the forward and backward are entirely self-contained.
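Conversely, an illustrative case the no-new-lifted-variables check is meant to reject (a hypothetical example, not from the PR): a backward that closes over a tensor which was never an input or a saved tensor:

```python
import torch

leaked = torch.randn(3)  # free variable, unknown to the traced subgraph

class Leaky(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # Tracing this would have to lift `leaked` into the subgraph's
        # inputs, which the HigherOrderOperator machinery disallows.
        return grad_output * leaked
```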
Whenever we reject a function as unsound, we restore back, as usual.
Due to some limitations in the lifting logic, we implemented an escape hatch for tensors that are known in forward but cross into backward through `save_for_backward` (save) / `saved_tensors` (load). The escape hatch avoids having the known saved tensors coming from forward accidentally treated as lifted variables (and rejected). This is sound, but feels a little hacky.
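That is, the standard pattern the escape hatch covers - tensors saved in forward and reloaded in backward (a sketch; `SaveDemo` is illustrative):

```python
import torch

class SaveDemo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)  # tensors known in forward...
        return x * w

    @staticmethod
    def backward(ctx, grad_output):
        # ...cross into backward here, and must not be mistaken for
        # newly lifted variables.
        x, w = ctx.saved_tensors
        return grad_output * w, grad_output * x
```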
Additionally, due to some limitations in fx node removal, combined with how we produce subgraphs for the traces installed from HigherOrderOperators, we had to improve our node removal logic. In the event of a restore, we remove the old nodes from the graph, as usual in dynamo. However, because references to these nodes may exist in subgraphs, we traverse each node's users and remove those first, if and only if they are in another graph. This is always sound, because removal only happens downstream of restoration at this point.
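A rough sketch of that removal rule in fx terms (the function name and structure are illustrative, not the actual dynamo helper):

```python
import torch.fx as fx

def remove_with_subgraph_users(graph: fx.Graph, node: fx.Node) -> None:
    # Erase users that live in other (sub)graphs first, so that erasing
    # `node` from its own graph does not trip fx's "node still has users"
    # check. Removal here is only ever run after a restore, so every one
    # of these nodes is already dead.
    for user in list(node.users):
        if user.graph is not graph:
            user.graph.erase_node(user)
    graph.erase_node(node)
```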
cc @soumith @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire