-
Notifications
You must be signed in to change notification settings - Fork 25.4k
[Inductor-FX] Support unbacked symbol definitions #163729
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163729
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 3ab437c with merge base 5f90e8c ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
call_args = add_constants_to_call_args(call_args, kernel_config) | ||
call_args, grid = tuner._interpret_args_grid(call_args, kernel_config) | ||
call_kwargs = dict(zip(signature, call_args)) | ||
assert not any(kwarg in kernel_config.kwargs for kwarg in call_kwargs), ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test case in this PR exposed a very tricky bug related to arg/kwarg handling. Although this should no longer happen with the updated method, this assert guards against that bug if by some off chance we missed something.
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Problem
Inductor sometimes generates unbacked symints to handle things like mismatched branches of
torch.cond
. This code is represented bypytree.KeyPath
, with special codegen logic to convert it to Python and C++. This was not previously supported by the FX backend.Feature
This PR adds support for unbacked symbol declarations to the FX backend. The implementation is fairly straightforward.
UnbackedSymbolDefsLine
. This contains all the information needed to generate the Python and C++ code.UnbackedSymbolDefsLine.codegen()
.triton_meta
. This PR rewrites the relevant helper function to do this in a more principled way.Test plan
This PR imports an existing control flow test to the FX backend's test suite. The test uses unbacked symbol definitions to handle mismatched dynamic shapes coming from
torch.cond
branches.cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben