fix inference mode / PyDispatcher / Functionalize interaction #103275
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/103275.
Note: links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 8871c77. This comment was automatically generated by Dr. CI and updates every 15 minutes.
OHHH this is a good catch
torch/_ops.py
Outdated
```python
# to use it, instead of the C++ decomp. We can't though, because Functionalize
# isn't part of the CompositeImplicitAutograd alias set.
# (open question: will we eventually need to do this for functorch transform keys too?)
self.py_kernels[torch._C.DispatchKey.Functionalize] = fn
```
A less hacky version of this would be to modify our py_impl CompositeImplicitAutograd sites to call some higher level function which takes care of doing both registrations. This would be good because the hack as written violates the invariant that one py_impl does one registration.
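For instance, something like this hypothetical wrapper (the name `py_impl_composite_implicit` is invented for illustration; it is not an actual `torch._ops` API):

```python
import torch

# Hypothetical convenience decorator (not actual PyTorch API): one call site
# performs both registrations explicitly, so each underlying py_impl call
# still does exactly one registration.
def py_impl_composite_implicit(op):
    def inner(fn):
        op.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)(fn)
        # Functionalize isn't in the CompositeImplicitAutograd alias set,
        # so mirror the registration there explicitly.
        op.py_impl(torch._C.DispatchKey.Functionalize)(fn)
        return fn
    return inner
```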
Hmm, the only thing I'm worried about is that people might forget to use that new API and continue to register future ops with `py_impl(CompositeImplicitAutograd)` (which will now always be subtly wrong). If you're not worried about that, though, then I can change it (or... maybe just have `py_impl()` give you a nice error message if you pass in that key?)
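The error-message variant could be as simple as a guard at the top of `py_impl` (a sketch only; the function name and the `allow_raw_composite` escape hatch are invented):

```python
import torch

def _reject_raw_composite_key(k, allow_raw_composite=False):
    # Hypothetical guard: nudge people toward the dual-registration helper
    # instead of a bare py_impl(CompositeImplicitAutograd), which would
    # silently miss the matching Functionalize registration.
    if (k == torch._C.DispatchKey.CompositeImplicitAutograd
            and not allow_raw_composite):
        raise RuntimeError(
            "Register CompositeImplicitAutograd python decomps through the "
            "dual-registration helper so Functionalize is not skipped."
        )
```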
torch/_ops.py
Outdated
```diff
@@ -122,6 +122,14 @@ def inner(fn):
         f"Trying to override a python impl for {k} on operator {self.name()}"
     )
     self.py_kernels[k] = fn
+    if k == torch._C.DispatchKey.CompositeImplicitAutograd and torch._C.DispatchKey.Functionalize not in self.py_kernels:
```
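Putting this fragment together with the snippet higher up, the registration path presumably reads roughly like this (a sketch reconstructed from the two diff fragments, not the verbatim source):

```python
import torch

class _OpRegistrySketch:
    """Minimal sketch of the relevant bit of torch/_ops.py."""

    def __init__(self, name):
        self._name = name
        self.py_kernels = {}

    def py_impl(self, k):
        def inner(fn):
            if k in self.py_kernels:
                raise RuntimeError(
                    f"Trying to override a python impl for {k} on operator {self._name}"
                )
            self.py_kernels[k] = fn
            # Functionalization codegen registers C++ CompositeImplicit decomps
            # directly to the Functionalize key, which isn't part of the
            # CompositeImplicitAutograd alias set, so mirror the python
            # registration to keep the python decomp winning.
            if (k == torch._C.DispatchKey.CompositeImplicitAutograd
                    and torch._C.DispatchKey.Functionalize not in self.py_kernels):
                self.py_kernels[torch._C.DispatchKey.Functionalize] = fn
            return fn
        return inner
```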
I'm trying to remember now why we didn't just add `Functionalize` (and all of the other functorch transform keys) directly to the `CompositeImplicitAutograd` alias keyset, but I can't remember (@zou3519 any chance you remember?).

This PR is good as-is, but if we want to be less hacky we should just make `Functionalize` a part of the `CompositeImplicitAutograd` alias keyset (and maybe we should rename it, since it is now about more than just autograd?)
Not all CompositeImplicitAutograd operations work with vmap (their decompositions don't always preserve Tensor subclass-ness), which is why vmap isn't in that alias set. Although we've fixed most of these cases, I am wary of actually adding vmap to the CompositeImplicitAutograd set because it is unclear what percentage of aten ops our OpInfos actually cover.
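As a toy illustration of that failure mode (an invented example, not a real aten decomp):

```python
import torch
from torch.func import vmap

# A CompositeImplicitAutograd-style decomposition that allocates a plain
# tensor and mutates it in place. torch.empty(x.shape) knows nothing about
# the hidden batch dimension vmap threads through, so the in-place copy
# can't preserve the batched "subclass-ness" of the input.
def clone_via_empty(x):
    out = torch.empty(x.shape)  # plain tensor, no batch dim
    out.copy_(x)                # vmap raises here: `out` is unbatched
    return out

clone_via_empty(torch.randn(3))             # fine eagerly
# vmap(clone_via_empty)(torch.randn(4, 3))  # RuntimeError under vmap
```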
This suggestion, if it works, is better!
Agreed - I'll try it and see if there's fallout.

There was actually a comment about it (hooray!) here. At the time, `at::ones()` and friends were all `CompositeImplicitAutograd`, and decomposed into `empty()` + `fill_()`. I think we didn't want functionalization to decompose `ones()`, since it would result in a bunch of unnecessary functionalization logic running. But this shouldn't really matter anymore, since those factory functions all got changed to be `CompositeExplicitAutograd`.
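Concretely, the old decomposition amounted to something like this (schematic; not the actual codegen output):

```python
import torch

# Schematic of the old CompositeImplicitAutograd-era decomposition of
# at::ones(): the in-place fill_() is exactly the mutation that
# functionalization would have had to intercept and rewrite, for no benefit.
def ones_decomp(size):
    out = torch.empty(size)
    out.fill_(1)  # mutation -> extra functionalization work if decomposed
    return out

assert torch.equal(ones_decomp((2, 3)), torch.ones(2, 3))
```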
…ion" Fixes #103132 This is kind of annoying: Functionalization (and also vmap, I think?) manually figures out which ops have C++ CompositeImplicit decomps, and directly registers them to the Functionalize key. This is a problem for the PyDispatcher: We normally want the PyDispatcher to take precedence over the regular dispatcher. But in this case, we have a python decomp registered to `CompositeImplicitAutograd`, and a C++ decomp registered *directly* to the `Functionalize` key, so the C++ decomp gets precedence over the python decomp. The way this showed up was that a model was running `matmul()` under inference mode, so we never hit the autograd dispatch key, and go straight to the functionalize dispatch key. Matmul has both a python decomp and a c++ decomp, but we were running the C++ decomp. That C++ decomp isn't meant to be used with dynamic shapes, so we were failing with the "tried to call `.sizes()` on a tensor with dynamic shapes" error. For now, I had the PyDispatcher mimic the behavior of functionalization codegen: when you register a python decomp to the `CompositeImplicitAutograd` key, this PR just automatically registers that decomp to the `Functionalize` key at the same time. I'm trying to remember now why we didn't just add `Functionalize` (and all of the other functorch transform keys) directly to the `CompositeImplicitAutograd` alias keyset, but I couldn't remember (zou3519 any chance you remember?). cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy aakhundov [ghstack-poisoned]
ghstack-source-id: 834bb63738908aee2ad4cf0471d73d76c8dd9a7f Pull Request resolved: #103275
…ion" Fixes #103132 This is kind of annoying: Functionalization (and also vmap, I think?) manually figures out which ops have C++ CompositeImplicit decomps, and directly registers them to the Functionalize key. This is a problem for the PyDispatcher: We normally want the PyDispatcher to take precedence over the regular dispatcher. But in this case, we have a python decomp registered to `CompositeImplicitAutograd`, and a C++ decomp registered *directly* to the `Functionalize` key, so the C++ decomp gets precedence over the python decomp. The way this showed up was that a model was running `matmul()` under inference mode, so we never hit the autograd dispatch key, and go straight to the functionalize dispatch key. Matmul has both a python decomp and a c++ decomp, but we were running the C++ decomp. That C++ decomp isn't meant to be used with dynamic shapes, so we were failing with the "tried to call `.sizes()` on a tensor with dynamic shapes" error. For now, I had the PyDispatcher mimic the behavior of functionalization codegen: when you register a python decomp to the `CompositeImplicitAutograd` key, this PR just automatically registers that decomp to the `Functionalize` key at the same time. I'm trying to remember now why we didn't just add `Functionalize` (and all of the other functorch transform keys) directly to the `CompositeImplicitAutograd` alias keyset, but I couldn't remember (zou3519 any chance you remember?). cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy aakhundov [ghstack-poisoned]
ghstack-source-id: cf06aa9bd5ce15c2d6ebd36aea239514c5a38e96 Pull Request resolved: #103275
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
…ion" Fixes #103132 This is kind of annoying: Functionalization (and also vmap, I think?) manually figures out which ops have C++ CompositeImplicit decomps, and directly registers them to the Functionalize key. This is a problem for the PyDispatcher: We normally want the PyDispatcher to take precedence over the regular dispatcher. But in this case, we have a python decomp registered to `CompositeImplicitAutograd`, and a C++ decomp registered *directly* to the `Functionalize` key, so the C++ decomp gets precedence over the python decomp. The way this showed up was that a model was running `matmul()` under inference mode, so we never hit the autograd dispatch key, and go straight to the functionalize dispatch key. Matmul has both a python decomp and a c++ decomp, but we were running the C++ decomp. That C++ decomp isn't meant to be used with dynamic shapes, so we were failing with the "tried to call `.sizes()` on a tensor with dynamic shapes" error. For now, I had the PyDispatcher mimic the behavior of functionalization codegen: when you register a python decomp to the `CompositeImplicitAutograd` key, this PR just automatically registers that decomp to the `Functionalize` key at the same time. I'm trying to remember now why we didn't just add `Functionalize` (and all of the other functorch transform keys) directly to the `CompositeImplicitAutograd` alias keyset, but I couldn't remember (zou3519 any chance you remember?). cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78 aakhundov [ghstack-poisoned]
Successfully rebased.
@pytorchbot merge
Merge failed. Reason: This PR needs a `release notes:` label. If your changes are user facing and intended to be a part of release notes, please use a label starting with `release notes:`. If not, please add the `topic: not user facing` label. To add a label, you can comment to pytorchbot, for example `@pytorchbot label "topic: not user facing"`. For more information, see the wiki. (Details for Dev Infra team: raised by workflow job.)
@pytorchbot merge
Merge failed. Reason: This PR needs a `release notes:` label. If your changes are user facing and intended to be a part of release notes, please use a label starting with `release notes:`. If not, please add the `topic: not user facing` label. To add a label, you can comment to pytorchbot, for example `@pytorchbot label "topic: not user facing"`. For more information, see the wiki. (Details for Dev Infra team: raised by workflow job.)
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fixes #103132

This is kind of annoying: Functionalization (and also vmap, I think?) manually figures out which ops have C++ CompositeImplicit decomps, and directly registers them to the Functionalize key. This is a problem for the PyDispatcher: we normally want the PyDispatcher to take precedence over the regular dispatcher. But in this case, we have a python decomp registered to `CompositeImplicitAutograd` and a C++ decomp registered directly to the `Functionalize` key, so the C++ decomp gets precedence over the python decomp.

The way this showed up was that a model was running `matmul()` under inference mode, so we never hit the autograd dispatch key and go straight to the functionalize dispatch key. Matmul has both a python decomp and a C++ decomp, but we were running the C++ decomp. That C++ decomp isn't meant to be used with dynamic shapes, so we were failing with the "tried to call `.sizes()` on a tensor with dynamic shapes" error.

For now, I had the PyDispatcher mimic the behavior of functionalization codegen: when you register a python decomp to the `CompositeImplicitAutograd` key, this PR just automatically registers that decomp to the `Functionalize` key at the same time.

I'm trying to remember now why we didn't just add `Functionalize` (and all of the other functorch transform keys) directly to the `CompositeImplicitAutograd` alias keyset, but I can't remember (@zou3519 any chance you remember?).
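For context, a sketch of the kind of call pattern that hit this (an assumed repro; the actual model in #103132 may differ):

```python
import torch

def f(x, y):
    # matmul has both a python decomp and a C++ CompositeImplicit decomp.
    # Under inference mode we skip the Autograd keys and hit Functionalize
    # directly, where (before this fix) the C++ decomp ran and then failed
    # on dynamic shapes.
    return torch.matmul(x, y)

with torch.inference_mode():
    compiled = torch.compile(f, dynamic=True)
    out = compiled(torch.randn(4, 8), torch.randn(8, 4))
```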
cc @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @aakhundov