fix(fx): make all make_fx invocations isolated (opaque to higher make_fx invocations) by default #93290
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/93290. Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit 9b7b95d. This comment was automatically generated by Dr. CI and updates every 15 minutes.
I... guess we can do this? It kind of feels better not to by default (and to let get_isolated_subgraph be used for this case), because if you do it this way, there is no way to have a single trace recorded by multiple proxy tensor modes at once. This may seem like a weird thing to want, but in fact functorch nested grad does something like this (nested grad levels get recorded onto multiple tapes, one per grad level). The countervailing argument is that make_fx doesn't return its outputs, so it is not true compute and shouldn't get traced. I could be convinced by this.
Actually, when we call … So this PR will isolate graph modules by default. I believe this is a good change, as making a tape of making a tape doesn't seem sensible, and it produces the weird artifacts observed. (Unless an inner tape is used in the definition of the exterior function?)
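To make the two designs discussed above concrete, here is a toy sketch (hypothetical names, not the real make_fx internals) of an op "tape": each recorded op goes only to the innermost active tape when tapes are isolated, which is the behavior this PR makes the default.

```python
# Toy model of isolated tracing: a stack of active "tapes", where each
# recorded op lands only on the innermost tape (opaque to outer tapes).
_active_tapes = []

class Tape:
    def __init__(self):
        self.ops = []

    def __enter__(self):
        _active_tapes.append(self)
        return self

    def __exit__(self, *exc):
        _active_tapes.pop()

def record(op):
    if _active_tapes:
        # Isolated by default: only the innermost tape sees the op.
        _active_tapes[-1].ops.append(op)

with Tape() as outer:
    record("outer_op")
    with Tape() as inner:
        record("inner_op")   # does not leak into outer.ops

print(outer.ops)  # ['outer_op']
print(inner.ops)  # ['inner_op']
```

With isolation, the inner trace is invisible to the enclosing one, which avoids the "tape of a tape" artifacts; the trade-off, as noted above, is that a single op can no longer be recorded by several enclosing modes at once.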
This is actually impossible.
Could you point to an example? I tried constructing several examples where things might fail, including nesting …
There is basically no compilation pathway that has nested tapes that rely directly on an interior tape. Because it seems that …
It's not make_fx per se, but tape recording (which is like fx but a bit different). https://github.com/albanD/subclass_zoo/blob/main/simple_functorch.ipynb (search for "to compute higher order gradients"). But yeah, consider me convinced, we can land this.
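The nested-grad style of tape recording referenced here, where one computation is recorded onto several tapes at once (one per grad level), can be sketched in a few lines. This is a toy illustration of that design, not functorch's actual implementation:

```python
# Toy sketch of multi-level tape recording: every active grad level
# keeps its own tape, and a single op is recorded onto all of them.
tapes = {1: [], 2: []}  # one tape per (hypothetical) grad level

def record_all(op):
    # Non-isolated design: the op is visible to every enclosing level,
    # which is what computing higher-order gradients via tapes needs.
    for tape in tapes.values():
        tape.append(op)

record_all("mul")
record_all("add")

print(tapes[1])  # ['mul', 'add']
print(tapes[2])  # ['mul', 'add']
```

Isolating make_fx invocations by default forecloses this sharing for make_fx specifically, which is acceptable here precisely because make_fx doesn't return its traced outputs to the surrounding computation.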
@pytorchbot merge |
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 mandatory check(s) failed. Dig deeper by viewing the failures on hud. Details for Dev Infra team: raised by workflow job.
@pytorchbot merge |
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 mandatory check(s) failed. Dig deeper by viewing the failures on hud. Details for Dev Infra team: raised by workflow job.
@pytorchbot merge |
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fixes #88996 (comment)
Example code:
Before:
After:
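The original example code and its before/after graph dumps did not survive extraction. As an illustration only (assuming the public make_fx API in torch.fx.experimental.proxy_tensor; this is not the PR author's original example), the kind of nesting the change affects looks like:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def inner(x):
    return x * 2

def outer(x):
    # An inner make_fx invocation while an outer make_fx trace is active.
    # With this PR, the inner invocation is isolated: the ops executed
    # while building inner_gm do not leak into the outer graph.
    inner_gm = make_fx(inner)(torch.randn(3))
    return inner_gm(x) + 1

gm = make_fx(outer)(torch.randn(3))
print(gm.graph)  # outer graph; the inner tracing itself is opaque to it
```

Before the change, the inner invocation's tracing could be recorded by the enclosing proxy tensor mode, producing the duplicated/weird graph artifacts discussed in the conversation above.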