[inductor] Bypass FX graph cache when we have HigherOrderOperators #123325
Conversation
Summary: The initial motivation was to avoid caching when we have triton higher order ops, but it's probably safer to avoid the cache for all higher order ops and allow/implement if/when we find it necessary. Test Plan: Unit test cribbed from: https://docs-preview.pytorch.org/pytorch/tutorials/2783/recipes/torch_compile_user_defined_triton_kernel_tutorial.html?highlight=triton [ghstack-poisoned]
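For context, a minimal sketch of what such a bypass check might look like (the helper and call site below are illustrative, not the PR's exact code; the real cache lives in torch/_inductor/codecache.py):

import torch
from torch._ops import HigherOrderOperator

def _has_higher_order_ops(gm: torch.fx.GraphModule) -> bool:
    # HOPs (e.g. torch.cond, the user-defined Triton kernel wrapper)
    # appear as HigherOrderOperator node targets in the FX graph.
    return any(
        isinstance(node.target, HigherOrderOperator) for node in gm.graph.nodes
    )

def load_with_cache(gm, example_inputs, compile_fn):
    if _has_higher_order_ops(gm):
        # Bypass the cache entirely and compile normally.
        return compile_fn(gm, example_inputs)
    # ... otherwise compute a cache key and look up / store as usual ...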
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/123325
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit f4dde59 with merge base 19f5033.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
…rators" Summary: The initial motivation was to avoid caching when we have triton higher order ops, but it's probably safer to avoid the cache for all higher order ops and allow/implement if/when we find it necessary. Test Plan: Unit test cribbed from: https://docs-preview.pytorch.org/pytorch/tutorials/2783/recipes/torch_compile_user_defined_triton_kernel_tutorial.html?highlight=triton cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
…perators" Summary: The initial motivation was to avoid caching when we have triton higher order ops, but it's probably safer to avoid the cache for all higher order ops and allow/implement if/when we find it necessary. Test Plan: Unit test cribbed from: https://docs-preview.pytorch.org/pytorch/tutorials/2783/recipes/torch_compile_user_defined_triton_kernel_tutorial.html?highlight=triton cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
Summary: The initial motivation was to avoid caching when we have triton higher order ops, but it's probably safer to avoid the cache for all higher order ops and allow/implement if/when we find it necessary. Test Plan: Unit test cribbed from: https://docs-preview.pytorch.org/pytorch/tutorials/2783/recipes/torch_compile_user_defined_triton_kernel_tutorial.html?highlight=triton ghstack-source-id: 8496a9d Pull Request resolved: #123325
…perators" Summary: The initial motivation was to avoid caching when we have triton higher order ops, but it's probably safer to avoid the cache for all higher order ops and allow/implement if/when we find it necessary. Test Plan: Unit test cribbed from: https://docs-preview.pytorch.org/pytorch/tutorials/2783/recipes/torch_compile_user_defined_triton_kernel_tutorial.html?highlight=triton cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
Summary: The initial motivation was to avoid caching when we have triton higher order ops, but it's probably safer to avoid the cache for all higher order ops and allow/implement if/when we find it necessary. Test Plan: Unit test cribbed from: https://docs-preview.pytorch.org/pytorch/tutorials/2783/recipes/torch_compile_user_defined_triton_kernel_tutorial.html?highlight=triton ghstack-source-id: 5b8bc4c Pull Request resolved: #123325
…perators" Summary: The initial motivation was to avoid caching when we have triton higher order ops, but it's probably safer to avoid the cache for all higher order ops and allow/implement if/when we find it necessary. Test Plan: Unit test cribbed from: https://docs-preview.pytorch.org/pytorch/tutorials/2783/recipes/torch_compile_user_defined_triton_kernel_tutorial.html?highlight=triton cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
Summary: The initial motivation was to avoid caching when we have triton higher order ops, but it's probably safer to avoid the cache for all higher order ops and allow/implement if/when we find it necessary. Test Plan: Unit test cribbed from: https://docs-preview.pytorch.org/pytorch/tutorials/2783/recipes/torch_compile_user_defined_triton_kernel_tutorial.html?highlight=triton ghstack-source-id: 2018983 Pull Request resolved: #123325
test/inductor/test_codecache.py (outdated)

@triton.jit
def add_kernel(
no need to redefine this, import it from torch/testing/_internal/triton_utils.py
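For example, the test could reuse the existing helper instead of redefining it (assuming the kernel is importable under this name, as the reviewer suggests):

from torch.testing._internal.triton_utils import add_kernel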
# HigherOrderOperators should be handled on a case-by-case basis.
# Currently, we just skip caching if we have any.
should we be skipping all of them by default? i assume cond/map etc should just work out of the box
This was @eellison's suggestion. Elias, what do you think? What is the set of higher order ops? Is it currently just triton + cond/map, or are there others as well? And if there are others, is there a criteria to distinguish which we expect to work vs. not?
A subset of all higher order ops: https://fburl.com/biggrep/h06kfrgg
@oulgen Ah thanks. So there are several. So do you think I should allowlist cond/map here specifically, i.e., because they're important? Or were those just two semi-random examples and you're advocating to allow more?
I suspect they are more popular HOPs but perhaps @zou3519 would have a more definitive answer on what is safe here
If I'm understanding correctly, we cache the compiled FX graph from inductor, and this PR is saying that a graph that contains HOPs cannot be cached by default.
If so, that sounds fine to me -- caching for things like torch.cond gets a bit complicated because you also want to walk through the subgraphs. cc @aakhundov
@zou3519 so Sam added a cache from FX graph to the inductor-generated Python file. This does not work for triton HOPs because the index into the dynamo-side kernel table needs to match, which is tricky. But for torch.cond etc. everything is in the graph, so I think that should be fine to cache.
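A rough illustration of that hazard (the table below is a stand-in for the real one in torch/_higher_order_ops/triton_kernel_wrap.py; all names here are illustrative):

# Dynamo registers each user-defined Triton kernel in a process-local
# side table; the FX graph only carries the integer index.
_kernel_side_table = []

def register_kernel(kernel) -> int:
    _kernel_side_table.append(kernel)
    return len(_kernel_side_table) - 1

# A cached artifact baked against, say, idx=3 in one process may point at
# a different kernel (or nothing) in a fresh process, because the table
# is rebuilt from scratch on every run -- hence the cache bypass.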
Is the cache from FX GraphModule to inductor generated python file? (And do we recurse into the submodules of the FX GraphModule? That's what's needed to get other HOPS like cond to work)
Is the cache from FX GraphModule to inductor generated python file?
Exactly.
This PR is saying that a graph that contains HOPs cannot be cached by default.
Correct. In other words, if we find any HOPs, then we bypass the cache altogether and compile normally.
And do we recurse into the submodules of the FX GraphModule?
Hmmm, I guess I still don't know enough about inductor compilation to answer that. Maybe @eellison can help give a definite answer. Is the question about how we compute the key to look up entries or the makeup of the cache entries themselves? I don't know if a code pointer will help, but this is where the cache plugs in: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/compile_fx.py#L464-L471. So at a high level, we compute the cache key from the FX GraphModule (the string representation of the code actually) and the inputs, and then we cache the output of fx_codegen_and_compile()
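Concretely, a sketch of that key computation (simplified; the real implementation hashes more state, e.g. inductor config and torch version):

import hashlib
import torch

def fx_graph_cache_key(gm: torch.fx.GraphModule, example_inputs) -> str:
    # Key on the printed FX code plus input metadata. Note that a
    # submodule invoked via get_attr (as cond/map subgraphs are) does not
    # appear in gm.code, which is why HOP subgraphs need special handling.
    h = hashlib.sha256()
    h.update(gm.code.encode())
    for inp in example_inputs:
        h.update(repr((inp.shape, inp.stride(), inp.dtype)).encode())
    return h.hexdigest()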
torch.cond and torch.while could be supported, but they don't work out of the box: we'd need to recurse into the subgraphs. Maybe we can do that automatically, but it does not work as-is today. I think it's fine to add safety in this PR and follow up to add support; see the sketch below.
It could also make more sense for the folks adding a HOP to add its cache key than Sam -- they will know better about the details of what needs to be cached.
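A follow-up along those lines might recurse like this (illustrative only; it assumes HOP subgraphs are attached as child GraphModules reachable via named_modules, as they are for cond):

import torch

def collect_graph_code(gm: torch.fx.GraphModule) -> str:
    # Include every nested GraphModule (e.g. the true/false branches that
    # torch.cond stores as submodules) so that two graphs differing only
    # inside a subgraph hash differently.
    parts = [gm.code]
    for _, submod in gm.named_modules():
        if isinstance(submod, torch.fx.GraphModule) and submod is not gm:
            parts.append(submod.code)
    return "\n".join(parts)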
@pytorchbot merge
Merge failed. Reason: This PR needs a "release notes:" label. If your changes are user facing and intended to be a part of release notes, please use a label starting with "release notes:". If not, please add the "topic: not user facing" label. To add a label, you can comment to pytorchbot, for example: @pytorchbot label "topic: not user facing". For more information, see the PyTorch AutoLabel Bot wiki. Details for Dev Infra team: raised by workflow job.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…ytorch#123325) Summary: The initial motivation was to avoid caching when we have triton higher order ops, but it's probably safer to avoid the cache for all higher order ops and allow/implement if/when we find it necessary. Test Plan: Unit test cribbed from: https://docs-preview.pytorch.org/pytorch/tutorials/2783/recipes/torch_compile_user_defined_triton_kernel_tutorial.html?highlight=triton Pull Request resolved: pytorch#123325 Approved by: https://github.com/eellison
Stack from ghstack (oldest at bottom):
Summary: The initial motivation was to avoid caching when we have triton higher order ops, but it's probably safer to avoid the cache for all higher order ops and allow/implement if/when we find it necessary.
Test Plan: Unit test cribbed from: https://docs-preview.pytorch.org/pytorch/tutorials/2783/recipes/torch_compile_user_defined_triton_kernel_tutorial.html?highlight=triton
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire @chauhang