-
Notifications
You must be signed in to change notification settings - Fork 21.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[inductor][codegen] Codegen constexpr globals and constexpr annotated globals correctly. #126195
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126195
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure, 6 Unrelated FailuresAs of commit 15e5551 with merge base ade0754 ( NEW FAILURE - The following job has failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and has been marked as unstable:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This is needed after moving the pin in #126098 I split it out since it does not seem to break anything with the current pin. |
…ser defined kernel source [Triton pytorch#3762](triton-lang/triton#3762) disallows access to globals which are not `tl.constexpr` Triton has always treated captured globals this way, but they now require it be explicit in user code. Updated codegen to make sure these variables are defined before writing the kernel source when compiling a user defined triton kernel. ghstack-source-id: 5fc2a44f6d39a699d704ca1f311eb84770af8647 Pull Request resolved: pytorch#126195
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tested this latest version locally with the CUDA backend and XPU backends and both now pass the inductor test_triton_kernels
/test_triton_kernel_constants
test with a recent Triton/TritonXPU commit.
@pytorchbot merge |
Merge failedReason: This PR needs a If not, please add the To add a label, you can comment to pytorchbot, for example For more information, see Details for Dev Infra teamRaised by workflow job |
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 1 jobs have failed, first few of them are: trunk / macos-13-py3-arm64 / build Details for Dev Infra teamRaised by workflow job |
@pytorchbot merge -f "failure looks unrelated, I don't think triton can be involved with the macOS build" |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
… globals correctly. (pytorch#126195) [Triton pytorch#3762](triton-lang/triton#3762) disallows access to globals which are not `tl.constexpr` Triton has always treated captured globals this way, but they now require it be explicit in user code. Updated codegen to make sure these variables are defined before writing the kernel source when compiling a user defined triton kernel. Pull Request resolved: pytorch#126195 Approved by: https://github.com/alexbaden, https://github.com/bertmaher
Stack from ghstack (oldest at bottom):
Triton #3762
disallows access to globals which are not
tl.constexpr
Triton has always treated captured globals this way, but they now
require it be explicit in user code.
Updated codegen to make sure these variables are defined before writing
the kernel source when compiling a user defined triton kernel.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @desertfire @chauhang