Initial Flash Attention support on ROCM #114309
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/114309
Note: Links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit f75038b with merge base 310f6ab. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label ciflow/rocm
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased |
Force-pushed from 5693b15 to d38d585
Force-pushed from d38d585 to f252833
CI failures are likely unrelated to this PR.
@malfet Hi Nikita, please add anyone else who would be relevant as a reviewer.
@pytorchbot revert -m "trunk tests are failing" -c nosignal
@pytorchbot successfully started a revert job. Check the current status here.
Reverting PR 114309 failed. Reason: 1 mandatory check(s) are pending/not yet run. Dig deeper by viewing the pending checks on hud.
/easycla
@pytorchbot revert -m "trunk tests are failing" -c nosignal
@pytorchbot successfully started a revert job. Check the current status here.
Reverting PR 114309 failed. Reason: 1 mandatory check(s) failed. Dig deeper by viewing the failures on hud.
This reverts commit 5bddbed. Pull Request resolved: #115975 Approved by: https://github.com/atalman, https://github.com/malfet
This pull request adds initial Flash Attention support for the AMD/ROCm platform. It adds a specialized Triton repository/branch as a compile-time dependency for the Flash Attention math library on AMD/ROCm. This Triton submodule is not used at runtime and will not be shipped in the final PyTorch package. We plan to release this specialized Triton as a separate project.

Known limitations:
- [ ] Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
- [ ] Only supports power-of-two sequence lengths.
- [ ] No support for varlen APIs.
- [ ] Only supports head dimensions 16, 32, 64, and 128.
- [ ] Performance is still being optimized.

Fixes pytorch#112997
Pull Request resolved: pytorch#114309
Approved by: https://github.com/jeffdaily, https://github.com/malfet
---------
Co-authored-by: Joseph Groenenboom <joseph.groenenboom@amd.com>
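To make the MI200-only constraint concrete, here is a minimal Python sketch of how a caller might gate on the reported architecture before relying on this path. The helper name `is_flash_attention_supported` is hypothetical; `gcnArchName` is the ROCm device property the limitation above refers to, and it is absent on CUDA builds, hence the `getattr` default.

```python
import torch

def is_flash_attention_supported(device_index: int = 0) -> bool:
    """Hypothetical helper: True only on MI200-series (gfx90a) GPUs,
    the sole architecture this initial ROCm enablement supports."""
    if not torch.cuda.is_available():
        return False
    props = torch.cuda.get_device_properties(device_index)
    # gcnArchName is only populated on ROCm builds of PyTorch;
    # on CUDA builds the attribute does not exist.
    arch = getattr(props, "gcnArchName", "")
    return arch.startswith("gfx90a")
```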
Note about the updates. This PR:
1. Skips more flash attention related UTs on MI200 (see the sketch after this note).
2. Fixes additional ATen compiling errors after hipification.
3. Fixes the author "root" of a specific commit.
4. Includes the patch from Nikita in favor of block-level static initialization.

CAVEAT: This revised PR has a commit that modifies the CI to force it to run on MI200 nodes. That specific commit must be reverted before merge.

Original PR (#114309) note: This pull request adds initial Flash Attention support for the AMD/ROCm platform. It adds a specialized Triton repository/branch as a compile-time dependency for the Flash Attention math library on AMD/ROCm. This Triton submodule is not used at runtime and will not be shipped in the final PyTorch package. We plan to release this specialized Triton as a separate project.

Known limitations:
- Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
- Only supports power-of-two sequence lengths.
- No support for varlen APIs.
- Only supports head dimensions 16, 32, 64, and 128.
- Performance is still being optimized.

Fixes #112997
Pull Request resolved: #115981
Approved by: https://github.com/malfet
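A minimal sketch of the kind of UT gating item 1 describes, assuming a unittest-style suite; `on_mi200`, the test class, and the test name are illustrative, not the actual decorators or tests in the PyTorch test suite.

```python
import unittest
import torch

def on_mi200() -> bool:
    # Probe the GCN architecture string; empty on non-ROCm builds.
    if not torch.cuda.is_available():
        return False
    arch = getattr(torch.cuda.get_device_properties(0), "gcnArchName", "")
    return arch.startswith("gfx90a")

class TestFlashAttention(unittest.TestCase):
    @unittest.skipIf(on_mi200(), "flash attention UT skipped on MI200 for now")
    def test_flash_attention_varlen(self):
        # varlen APIs are not supported by the initial ROCm enablement
        ...

if __name__ == "__main__":
    unittest.main()
```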
@xinyazhang It got
Hi, this is a bug from the build system changes related to AOTriton. Our internal branch already has the fixes, along with the AOTriton V2 integration. They will be sent upstream as a PR when the whole feature set is complete.
@xinyazhang Thanks for the information! Any ETA, please?
This pull request adds initial Flash Attention support for the AMD/ROCm platform. It adds a specialized Triton repository/branch as a compile-time dependency for the Flash Attention math library on AMD/ROCm. This Triton submodule is not used at runtime and will not be shipped in the final PyTorch package. We plan to release this specialized Triton as a separate project.
Known limitations:
- Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
- Only supports power-of-two sequence lengths.
- No support for varlen APIs.
- Only supports head dimensions 16, 32, 64, and 128.
- Performance is still being optimized.

Fixes #112997
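A minimal usage sketch under the limitations above, assuming a PyTorch 2.x ROCm build on an MI200 GPU: the shapes use a power-of-two sequence length and a supported head dimension, and `torch.backends.cuda.sdp_kernel` forces the flash backend so the new path is actually exercised rather than silently falling back.

```python
import torch
import torch.nn.functional as F

# Power-of-two sequence length (1024) and a supported head dim (64),
# matching the limitations listed above.
q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)

# Force the flash backend; with math and mem-efficient disabled, an
# unsupported configuration errors out instead of silently falling back.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([2, 8, 1024, 64])
```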
cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang