[AOTAutograd] Use _set_grad_enabled instead of no_grad #128183
Conversation
This saves ~1us of overhead from each inductor graph call.
```diff
 t._dynamo_weak_dynamic_indices = o.dynamic_dims.copy()
 if runtime_metadata.grad_enabled_mutation is not None:
-    torch.set_grad_enabled(runtime_metadata.grad_enabled_mutation)
+    torch._C._set_grad_enabled(runtime_metadata.grad_enabled_mutation)
```
Is there really that much overhead in calling `torch.set_grad_enabled` vs. calling the C API directly?
```python
In [1]: import torch
   ...: %timeit torch.set_grad_enabled(False)
   ...: %timeit torch._C._set_grad_enabled(False)
536 ns ± 3.02 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
217 ns ± 1.28 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
```
It's 2.5x slower, mostly because `torch.set_grad_enabled` is a context-manager object that pretends to be a normal function.
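For context, a simplified sketch of where that extra time goes — this mirrors the class-based structure of `torch.set_grad_enabled`, but is not the exact source:

```python
import torch

# Simplified sketch (not the exact PyTorch source): torch.set_grad_enabled
# is a class whose __init__ flips grad mode eagerly, so it works both as a
# plain call and as a context manager.
class SetGradEnabledSketch:
    def __init__(self, mode: bool) -> None:
        self.prev = torch.is_grad_enabled()  # saved for context-manager use
        self.mode = mode
        torch._C._set_grad_enabled(mode)     # the actual state flip

    def __enter__(self) -> None:
        torch._C._set_grad_enabled(self.mode)

    def __exit__(self, *exc) -> None:
        torch._C._set_grad_enabled(self.prev)

# "Function-style" use still allocates the object and runs __init__,
# which is where the extra ~300 ns per call goes:
SetGradEnabledSketch(False)

# The C binding does only the state flip:
torch._C._set_grad_enabled(True)
```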
the wonders of python
`gen_alias_from_base` spends ~0.5 us in this import statement, which is called for each view in the graph output. Pull Request resolved: #128184 Approved by: https://github.com/lezcano ghstack dependencies: #128183
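The cost is easy to reproduce with any function-local import; the module below is a placeholder, not the import the PR actually moved:

```python
import timeit

def with_local_import():
    # Even for an already-imported module, this re-runs the sys.modules
    # lookup and local name binding on every call.
    import functools  # placeholder module, not the one the PR moved
    return functools

import functools  # hoisted: resolved once at module import time

def with_hoisted_import():
    return functools

print(timeit.timeit(with_local_import, number=1_000_000))   # slower
print(timeit.timeit(with_hoisted_import, number=1_000_000)) # faster
```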
Going through the dispatcher + pybind11 + torch.ops adds about 2 us overhead per call compared to `PyArgParser`. Note that views of inputs are reconstructed by AOTAutograd before being returned to the python code, so dispatching for autograd's sake shouldn't be required here. Pull Request resolved: #128185 Approved by: https://github.com/lezcano ghstack dependencies: #128183, #128184
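A rough way to observe that gap, using `as_strided` as an example op (an assumption for illustration, not necessarily the op the PR changed; timings are machine-dependent):

```python
import torch

t = torch.randn(4, 4)
size, stride = (2, 2), (4, 1)

# torch.ops path: goes through pybind11 plus a full dispatcher round trip.
torch.ops.aten.as_strided(t, size, stride)

# Tensor-method path: bound via the PythonArgParser, skipping the
# torch.ops Python machinery.
t.as_strided(size, stride)

# In IPython, compare with:
#   %timeit torch.ops.aten.as_strided(t, size, stride)
#   %timeit t.as_strided(size, stride)
```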
This saves ~1us of overhead from each inductor graph call. Pull Request resolved: pytorch#128183 Approved by: https://github.com/lezcano
… from C++ (pytorch#128187) Marginal overhead reduction when calling through the `torch.ops` API. Pull Request resolved: pytorch#128187 Approved by: https://github.com/lezcano ghstack dependencies: pytorch#128183, pytorch#128184, pytorch#128185