Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[inductor][codegen] Codegen constexpr globals and constexpr annotated globals correctly. #126195

Closed
wants to merge 10 commits into from

Conversation

amjames
Copy link
Collaborator

@amjames amjames commented May 14, 2024

Stack from ghstack (oldest at bottom):

Triton #3762
disallows access to globals which are not tl.constexpr

Triton has always treated captured globals this way, but they now
require it be explicit in user code.

Updated codegen to make sure these variables are defined before writing
the kernel source when compiling a user defined triton kernel.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @desertfire @chauhang

[ghstack-poisoned]
@amjames amjames mentioned this pull request May 14, 2024
Copy link

pytorch-bot bot commented May 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126195

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 6 Unrelated Failures

As of commit 15e5551 with merge base ade0754 (image):

NEW FAILURE - The following job has failed:

  • trunk / macos-13-py3-arm64 / build (gh)
    /Users/ec2-user/runner/_work/pytorch/pytorch/c10/util/StringUtil.cpp:45:8: error: 'wstring_convert<std::codecvt_utf8_utf16<wchar_t>>' is deprecated [-Werror,-Wdeprecated-declarations]

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@amjames
Copy link
Collaborator Author

amjames commented May 14, 2024

This is needed after moving the pin in #126098 I split it out since it does not seem to break anything with the current pin.

[ghstack-poisoned]
amjames added a commit to amjames/pytorch that referenced this pull request May 16, 2024
…ser defined kernel source

[Triton pytorch#3762](triton-lang/triton#3762)
disallows access to globals which are not `tl.constexpr`

Triton has always treated captured globals this way, but they now
require it be explicit in user code.

Updated codegen to make sure these variables are defined before writing
the kernel source when compiling a user defined triton kernel.

ghstack-source-id: 5fc2a44f6d39a699d704ca1f311eb84770af8647
Pull Request resolved: pytorch#126195
[ghstack-poisoned]
[ghstack-poisoned]
@amjames amjames changed the title [inductor][codegen] Define constexpr globals with annotation before user defined kernel source [inductor][codegen] Codegen constexpr globals and constexpr annotated globals correctly. May 16, 2024
[ghstack-poisoned]
@amjames amjames requested a review from peterbell10 May 17, 2024 12:46
Copy link
Contributor

@alexbaden alexbaden left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested this latest version locally with the CUDA backend and XPU backends and both now pass the inductor test_triton_kernels/test_triton_kernel_constants test with a recent Triton/TritonXPU commit.

test/inductor/test_triton_kernels.py Outdated Show resolved Hide resolved
test/inductor/test_triton_kernels.py Outdated Show resolved Hide resolved
torch/_inductor/codegen/wrapper.py Outdated Show resolved Hide resolved
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
@bertmaher
Copy link
Contributor

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 4, 2024
@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team Raised by workflow job

@bertmaher
Copy link
Contributor

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / macos-13-py3-arm64 / build

Details for Dev Infra team Raised by workflow job

@bertmaher
Copy link
Contributor

@pytorchbot merge -f "failure looks unrelated, I don't think triton can be involved with the macOS build"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

TharinduRusira pushed a commit to TharinduRusira/pytorch that referenced this pull request Jun 14, 2024
… globals correctly. (pytorch#126195)

[Triton pytorch#3762](triton-lang/triton#3762)
disallows access to globals which are not `tl.constexpr`

Triton has always treated captured globals this way, but they now
require it be explicit in user code.

Updated codegen to make sure these variables are defined before writing
the kernel source when compiling a user defined triton kernel.

Pull Request resolved: pytorch#126195
Approved by: https://github.com/alexbaden, https://github.com/bertmaher
@github-actions github-actions bot deleted the gh/amjames/22/head branch July 7, 2024 01:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants