
Initial Flash Attention support on ROCM #114309

Merged (166 commits, Dec 14, 2023)

Conversation

@xinyazhang (Contributor) commented Nov 21, 2023

This pull request adds initial Flash Attention support for the AMD/ROCm platform. It adds a specialized Triton repository/branch as a compile-time dependency for the Flash Attention math library on AMD/ROCm. This Triton submodule is not used at runtime and will not be shipped in the final PyTorch package. We plan to release this specialized Triton as a separate project.

Known limitations:

  • Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
  • Only supports power-of-two sequence lengths.
  • No support for varlen APIs.
  • Only supports head dimensions 16, 32, 64, and 128.
  • Performance is still being optimized.
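
For illustration, here is a minimal usage sketch (not part of this PR) showing how the new backend might be exercised within the limitations above. It assumes a ROCm build of PyTorch with this change on an MI200-series GPU, and uses the `torch.backends.cuda.sdp_kernel` context manager that was current at the time:

```python
import torch
import torch.nn.functional as F

# Illustrative only: ROCm devices appear under the "cuda" device type in PyTorch.
device = "cuda"

# Shapes respect the stated limits: power-of-two sequence length (256) and a
# supported head dimension (64); flash attention requires fp16 or bf16 inputs.
q = torch.randn(2, 8, 256, 64, device=device, dtype=torch.float16)  # (batch, heads, seq, head_dim)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Restrict SDPA to the flash backend so a silent fallback to the math or
# memory-efficient kernels does not mask missing support.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([2, 8, 256, 64])
```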

Fixes #112997

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang


pytorch-bot bot commented Nov 21, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/114309

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f75038b with merge base 310f6ab:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot added the ciflow/rocm and module: rocm labels Nov 21, 2023
@jithunnair-amd added the release notes: rocm label Nov 21, 2023
@xinyazhang (Contributor Author)

@pytorchbot label ciflow/rocm

@jithunnair-amd added the keep-going label (don't stop on first failure, keep running tests until the end) Nov 28, 2023
@jeffdaily (Collaborator)

@pytorchbot rebase

@pytorchmergebot (Collaborator)

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator)

Successfully rebased xinyazhang/up-fa-mathaot onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout xinyazhang/up-fa-mathaot && git pull --rebase)

@xinyazhang (Contributor Author) commented Nov 29, 2023

CI failures are likely unrelated to this PR.

@jithunnair-amd (Collaborator)

@malfet Hi Nikita, please add anyone else who would be relevant as a reviewer.

@mikaylagawarecki added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Dec 1, 2023
@xinyazhang (Contributor Author)

@mruberry @ngimel, can you review this PR? These failures (dynamo, inductor, etc.) are unlikely to be related to these changes.

CMakeLists.txt (outdated review thread, resolved)
test/test_transformers.py (outdated review thread, resolved)
@atalman (Contributor) commented Dec 16, 2023

@pytorchbot revert -m "trunk tests are failing" -c nosignal

@pytorchmergebot (Collaborator)

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot (Collaborator)

Reverting PR 114309 failed

Reason: 1 mandatory check(s) are pending/not yet run. The first few are:

  • EasyCLA

Dig deeper by viewing the pending checks on hud

Details for the Dev Infra team (raised by workflow job)

Failing merge rule: Core Maintainers

@atalman (Contributor) commented Dec 16, 2023

/easycla

@atalman (Contributor) commented Dec 16, 2023

@pytorchbot revert -m "trunk tests are failing" -c nosignal


CLA check: Missing ID, CLA Not Signed

@pytorchmergebot (Collaborator)

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot (Collaborator)

Reverting PR 114309 failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for the Dev Infra team (raised by workflow job)

Failing merge rule: Core Maintainers

jeffdaily added a commit to ROCm/pytorch that referenced this pull request Dec 16, 2023
pytorchmergebot pushed a commit that referenced this pull request Dec 16, 2023
guilhermeleobas pushed a commit to guilhermeleobas/pytorch that referenced this pull request Dec 18, 2023
Pull Request resolved: pytorch#114309
Approved by: https://github.com/jeffdaily, https://github.com/malfet
Co-authored-by: Joseph Groenenboom <joseph.groenenboom@amd.com>
guilhermeleobas pushed a commit to guilhermeleobas/pytorch that referenced this pull request Dec 18, 2023
dmenig pushed a commit to dmenig/pytorch that referenced this pull request Dec 21, 2023
dmenig pushed a commit to dmenig/pytorch that referenced this pull request Dec 21, 2023
dmenig pushed a commit to dmenig/pytorch that referenced this pull request Dec 21, 2023
dmenig pushed a commit to dmenig/pytorch that referenced this pull request Dec 21, 2023
ZhiweiYan-96 pushed a commit to ZhiweiYan-96/pytorch that referenced this pull request Dec 22, 2023
ZhiweiYan-96 pushed a commit to ZhiweiYan-96/pytorch that referenced this pull request Dec 22, 2023
pytorchmergebot pushed a commit that referenced this pull request Jan 4, 2024
Note about the updates:

This PR:
1. Skips more flash-attention-related unit tests on MI200 (see the sketch after this list).
2. Fixes additional ATen compilation errors after hipification.
3. Fixes the author "root" of a specific commit.
4. Includes the patch from Nikita in favor of block-level static initialization.
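
For item 1, a hedged sketch of what such a unit-test skip might look like; the class name, test name, and reason string are illustrative, not the actual code from this PR (the real skips live in test/test_transformers.py):

```python
import unittest

from torch.testing._internal.common_utils import TEST_WITH_ROCM, TestCase, run_tests


class TestSDPAFlashROCm(TestCase):
    # Hypothetical test: skipped on ROCm builds until the feature gap is closed.
    @unittest.skipIf(TEST_WITH_ROCM, "varlen flash attention is not yet supported on ROCm")
    def test_flash_attention_varlen(self):
        ...


if __name__ == "__main__":
    run_tests()
```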

CAVEAT: This revised PR has a commit that modifies the CI to force it to run on MI200 nodes. That specific commit must be reverted before merge.


Fixes #112997

Pull Request resolved: #115981
Approved by: https://github.com/malfet
@netw0rkf10w

@xinyazhang I got `UserWarning: 1Torch was not compiled with flash attention` with the latest nightly on an MI250x. Could you tell me whether this is expected? Thanks!

@xinyazhang (Contributor Author) commented Feb 12, 2024

> @xinyazhang I got `UserWarning: 1Torch was not compiled with flash attention` with the latest nightly on an MI250x. Could you tell me whether this is expected? Thanks!

Hi, this is a bug from the build-system changes related to AOTriton. Our internal branch already has fixes, along with the AOTriton V2 integration. They will be sent upstream as a PR when the whole feature set is complete.
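
Until that fix lands, a hedged way to see what an installed build reports is sketched below; it uses only public PyTorch APIs and is not code from this PR or the internal branch:

```python
import warnings

import torch
import torch.nn.functional as F

print("torch:", torch.__version__)
print("HIP/ROCm:", torch.version.hip)  # None on CUDA builds, a version string on ROCm builds
# Whether the flash backend is currently enabled for SDPA dispatch (a user toggle,
# not proof that it was compiled in).
print("flash SDP enabled:", torch.backends.cuda.flash_sdp_enabled())

# Dispatch SDPA once and surface any backend-selection warnings.
q = torch.randn(1, 1, 128, 64, device="cuda", dtype=torch.float16)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    F.scaled_dot_product_attention(q, q, q)
for w in caught:
    print(w.message)
```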

@netw0rkf10w

@xinyazhang Thanks for the information! Any ETA please?

Labels:
  • ciflow/binaries_wheel (trigger binary build and upload jobs for wheel on the PR)
  • ciflow/rocm
  • ciflow/trunk (trigger trunk jobs on your pull request)
  • keep-going (don't stop on first failure, keep running tests until the end)
  • module: rocm (AMD GPU support for PyTorch)
  • open source
  • release notes: rocm
  • triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Projects: none yet
Development

Successfully merging this pull request may close these issues.

Add support for Flash Attention for AMD/ROCm