[inductor][triton] Update HAS_WARP_SPEC to check triton.Config params. Update Triton Hash to top of release/3.4.x stack #158459


Closed

Conversation

davidberard98 (Contributor) commented Jul 16, 2025

Stack from ghstack (oldest at bottom):

Update the Triton commit hash to 11ec6354315768a85da41032535e3b7b99c5f706, the head of the new release/3.4.x branch in triton-lang/triton.

Also, update the HAS_WARP_SPEC handling. In Triton 3.4, warp specialization will have a different interface: num_consumer_groups will be determined automatically by the compiler. This breaks the current Inductor integration, so for now HAS_WARP_SPEC checks whether triton.Config accepts num_consumer_groups and num_buffers_warp_spec as parameters.
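
For readers unfamiliar with the approach, a minimal sketch of this kind of capability check is shown below: probe triton.Config's signature for the warp-spec keyword arguments instead of keying off a Triton version. This is an illustration, not the PR's exact code; the helper name and the import fallback are assumptions.

```python
# Illustrative sketch (not the PR's exact implementation): treat warp spec as
# available only if triton.Config accepts the warp-spec keyword arguments.
import inspect

try:
    from triton import Config as TritonConfig
except ImportError:  # triton may not be installed in every environment
    TritonConfig = None


def _config_accepts_warp_spec_kwargs() -> bool:
    """Return True if triton.Config takes the warp-spec keyword arguments."""
    if TritonConfig is None:
        return False
    params = inspect.signature(TritonConfig.__init__).parameters
    return (
        "num_consumer_groups" in params
        and "num_buffers_warp_spec" in params
    )


HAS_WARP_SPEC = _config_accepts_warp_spec_kwargs()
```

With a check like this, the flag turns off automatically on Triton builds whose Config no longer accepts these arguments, matching the intent described above.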

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

In triton 3.4, warp spec will have a different interface: num_consumer_groups will be determined automatically by the compiler.

This breaks the current Inductor integration, so for now, update HAS_WARP_SPEC to check whether triton.Config takes num_consumer_groups and num_buffers_warp_spec as parameters.

[ghstack-poisoned]
davidberard98 added a commit that referenced this pull request Jul 16, 2025
In triton 3.4, warp spec will have a different interface: num_consumer_groups will be determined automatically by the compiler.

This breaks the current Inductor integration, so for now, update HAS_WARP_SPEC to check whether triton.Config takes num_consumer_groups and num_buffers_warp_spec as parameters.

ghstack-source-id: 8691928
Pull Request resolved: #158459

pytorch-bot (bot) commented Jul 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158459

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures, 7 Unrelated Failures

As of commit 93b42d6 with merge base e311886:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

atalman (Contributor) left a comment

lgtm

atalman (Contributor) commented Jul 16, 2025

@davidberard98 please set triton hash to: 11ec6354315768a85da41032535e3b7b99c5f706 as in https://github.com/pytorch/pytorch/pull/158442/files
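
As a side note, a hedged sketch of how one might sanity-check that pin locally is below. The pin-file path is an assumption for illustration (PyTorch keeps Triton's pin in a small text file under the CI config); the expected hash is the one quoted in the comment above.

```python
# Sketch only: verify the local Triton commit pin matches the requested hash.
# The pin-file location below is an assumption, not taken from this PR.
from pathlib import Path

EXPECTED_TRITON_PIN = "11ec6354315768a85da41032535e3b7b99c5f706"
PIN_FILE = Path(".ci/docker/ci_commit_pins/triton.txt")  # assumed location

if __name__ == "__main__":
    actual = PIN_FILE.read_text().strip()
    print("pin matches expected hash:", actual == EXPECTED_TRITON_PIN)
```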

…nfig params"

In triton 3.4, warp spec will have a different interface: num_consumer_groups will be determined automatically by the compiler.

This breaks the current Inductor integration, so for now, update HAS_WARP_SPEC to check whether triton.Config takes num_consumer_groups and num_buffers_warp_spec as parameters.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
davidberard98 added a commit that referenced this pull request Jul 16, 2025
In triton 3.4, warp spec will have a different interface: num_consumer_groups will be determined automatically by the compiler.

This breaks the current Inductor integration, so for now, update HAS_WARP_SPEC to check whether triton.Config takes num_consumer_groups and num_buffers_warp_spec as parameters.

ghstack-source-id: df12bd3
Pull Request resolved: #158459
@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 16, 2025
@davidberard98 davidberard98 changed the title [inductor][triton] Update HAS_WARP_SPEC to check triton.Config params [triton] New 3.4 release pin && update HAS_WARP_SPEC to check triton.Config params Jul 16, 2025
@atalman atalman changed the title [triton] New 3.4 release pin && update HAS_WARP_SPEC to check triton.Config params [inductor][triton] Update HAS_WARP_SPEC to check triton.Config params. Update Triton Hash to top of release/3.4.x stack Jul 16, 2025
davidberard98 (Contributor, Author)

@pytorchbot merge

pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

davidberard98 (Contributor, Author)

@pytorchbot rebase -b main

pytorchmergebot (Collaborator)

@pytorchbot started a rebase job onto refs/remotes/origin/main. Check the current status here

[ghstack-poisoned]
@atalman atalman added the ci-no-td Do not run TD on this PR label Jul 16, 2025
…nfig params. Update Triton Hash to top of release/3.4.x stack"


Update triton commit hash to `11ec6354315768a85da41032535e3b7b99c5f706`, which is the new release/3.4.x branch in triton-lang/triton.

Also, update HAS_WARP_SPEC handling: In triton 3.4, warp spec will have a different interface: num_consumer_groups will be determined automatically by the compiler. This breaks the current Inductor integration, so for now, update HAS_WARP_SPEC to check whether triton.Config takes num_consumer_groups and num_buffers_warp_spec as parameters.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
davidberard98 added a commit that referenced this pull request Jul 16, 2025
In triton 3.4, warp spec will have a different interface: num_consumer_groups will be determined automatically by the compiler.

This breaks the current Inductor integration, so for now, update HAS_WARP_SPEC to check whether triton.Config takes num_consumer_groups and num_buffers_warp_spec as parameters.

ghstack-source-id: d506851
Pull Request resolved: #158459
davidberard98 (Contributor, Author)

@pytorchbot merge

pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

…nfig params. Update Triton Hash to top of release/3.4.x stack"


Update triton commit hash to `11ec6354315768a85da41032535e3b7b99c5f706`, which is the new release/3.4.x branch in triton-lang/triton.

Also, update HAS_WARP_SPEC handling: In triton 3.4, warp spec will have a different interface: num_consumer_groups will be determined automatically by the compiler. This breaks the current Inductor integration, so for now, update HAS_WARP_SPEC to check whether triton.Config takes num_consumer_groups and num_buffers_warp_spec as parameters.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
davidberard98 added a commit that referenced this pull request Jul 17, 2025
In triton 3.4, warp spec will have a different interface: num_consumer_groups will be determined automatically by the compiler.

This breaks the current Inductor integration, so for now, update HAS_WARP_SPEC to check whether triton.Config takes num_consumer_groups and num_buffers_warp_spec as parameters.

ghstack-source-id: d2f46ea
Pull Request resolved: #158459
pytorchmergebot (Collaborator)

Merge failed

Reason: New commits were pushed while merging. Please rerun the merge command.

Details for Dev Infra team (raised by workflow job)

davidberard98 (Contributor, Author)

@pytorchbot merge

pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot (Collaborator)

Merge failed

Reason: 1 jobs have failed, first few of them are: linux-binary-manywheel / manywheel-py3_9-cuda12_6-test / test

Details for Dev Infra team (raised by workflow job)

atalman (Contributor) commented Jul 17, 2025

@pytorchbot merge -f "lint and workflows are green"

pytorchmergebot (Collaborator)

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

atalman (Contributor) commented Jul 17, 2025

@pytorchbot cherry-pick --onto release/2.8 -c release

pytorchbot (Collaborator)

Cherry picking #158459

Command git -C /home/runner/work/pytorch/pytorch cherry-pick -x 7892f5a007859ae02b9cac5441cd51f37147ef04 returned non-zero exit code 1

Auto-merging test/inductor/test_static_cuda_launcher.py
CONFLICT (content): Merge conflict in test/inductor/test_static_cuda_launcher.py
Auto-merging torch/_inductor/runtime/triton_compat.py
error: could not apply 7892f5a0078... [inductor][triton] Update HAS_WARP_SPEC to check triton.Config params. Update Triton Hash to top of release/3.4.x stack (#158459)
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git cherry-pick --continue".
hint: You can instead skip this commit with "git cherry-pick --skip".
hint: To abort and get back to the state before "git cherry-pick",
hint: run "git cherry-pick --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Details for Dev Infra team (raised by workflow job)

atalman pushed a commit to atalman/pytorch that referenced this pull request Jul 18, 2025
…. Update Triton Hash to top of release/3.4.x stack (pytorch#158459)

Update triton commit hash to `11ec6354315768a85da41032535e3b7b99c5f706`, which is the new release/3.4.x branch in triton-lang/triton.

Also, update HAS_WARP_SPEC handling: In triton 3.4, warp spec will have a different interface: num_consumer_groups will be determined automatically by the compiler. This breaks the current Inductor integration, so for now, update HAS_WARP_SPEC to check whether triton.Config takes num_consumer_groups and num_buffers_warp_spec as parameters.

Pull Request resolved: pytorch#158459
Approved by: https://github.com/atalman
atalman added a commit that referenced this pull request Jul 18, 2025
…Config params. Update Triton Hash to top of release/3.4.x stack (#158646)

* [inductor][triton] Update HAS_WARP_SPEC to check triton.Config params. Update Triton Hash to top of release/3.4.x stack (#158459)

Update triton commit hash to `11ec6354315768a85da41032535e3b7b99c5f706`, which is the new release/3.4.x branch in triton-lang/triton.

Also, update HAS_WARP_SPEC handling: In triton 3.4, warp spec will have a different interface: num_consumer_groups will be determined automatically by the compiler. This breaks the current Inductor integration, so for now, update HAS_WARP_SPEC to check whether triton.Config takes num_consumer_groups and num_buffers_warp_spec as parameters.

Pull Request resolved: #158459
Approved by: https://github.com/atalman

* dont_upde_hash

* Revert "dont_upde_hash"

This reverts commit 5fffb12.

* fix_docker_builds

---------

Co-authored-by: David Berard <dberard@fb.com>
Labels
- ci-no-td (Do not run TD on this PR)
- ciflow/h100
- ciflow/inductor
- ciflow/inductor-rocm (Trigger "inductor" config CI on ROCm)
- ciflow/rocm (Trigger "default" config CI on ROCm)
- ciflow/rocm-mi300 (Trigger "default" config CI on ROCm MI300)
- ciflow/triton_binaries
- ciflow/trunk (Trigger trunk jobs on your pull request)
- keep-going (Don't stop on first failure, keep running tests until the end)
- Merged
- module: inductor
- release notes: inductor