distributed: templated ring attention #124215
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/124215
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 84e2834 with merge base b5d4ebe.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This looks reasonable; I have a few suggestions and questions inlined.
```
    ],
)

out.sum().backward()
```
Shall we assert that the single-GPU query.grad is the same as dquery.grad.full_tensor()?
I don't think we need to in this test -- the other tests validate the ring attention accuracy for both forwards and backwards. This is intended to test that the different attention ops have equivalent behavior. The backwards pass uses autograd so I think there's little risk of an issue.
```
@@ -54,6 +54,10 @@ def _merge_sdpa(
    """
    assert len(chunks) == len(logsumexps)

    # LSE may be padded in the sequence dimension such as with memory efficient attention.
    seq_len = chunks[0].size(2)
```
Is this a no-op for the other attention ops?
Yes, exactly -- the LSE is already of size seq_len with flash attention. With memory efficient attention the length is aligned to a multiple of 32.
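To make that concrete, here is a minimal sketch of the trimming being discussed (the helper name and tensor layout are illustrative assumptions, not the actual _merge_sdpa code):

```python
import torch


def _trim_lse(chunk: torch.Tensor, lse: torch.Tensor) -> torch.Tensor:
    # Assumed shapes: chunk is (B, H, seq_len, head_dim); lse is (B, H, padded_seq_len).
    # Flash attention already returns padded_seq_len == seq_len, so the slice is a
    # no-op there; memory efficient attention may pad the LSE to a multiple of 32.
    seq_len = chunk.size(2)
    return lse[:, :, :seq_len]
```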
```
@@ -704,5 +705,65 @@ def run_with_backward():
        self.assertIsNotNone(t.grad)


class TestFunctionalAutogradWithNCCL(MultiProcessTestCase):
```
I think this should just inherit from TestCollectivesWithNCCL to unify the setups?
TestCollectivesWithNCCL has tests in it that we don't want to run again.
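As a sketch of that trade-off (illustrative only, not the actual test file; the setup details are assumptions):

```python
from torch.testing._internal.common_distributed import MultiProcessTestCase


# Inheriting from TestCollectivesWithNCCL would make unittest re-collect and
# re-run every test defined on that base class, so the new NCCL autograd tests
# get their own MultiProcessTestCase subclass and repeat only the setup they need.
class TestFunctionalAutogradWithNCCL(MultiProcessTestCase):
    @property
    def world_size(self) -> int:
        return 2

    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()
```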
Force-pushed from 0ee8bde to cf93d8b.
Force-pushed from eecec31 to 01669f6.
lgtm! one more comment inlined
Force-pushed from 01669f6 to 05c940a.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 mandatory check(s) failed. Dig deeper by viewing the failures on hud.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 mandatory check(s) failed. Dig deeper by viewing the failures on hud.
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased.
Force-pushed from 05c940a to 84e2834.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
The merge job was canceled. If you believe this is a mistake, then you can re-trigger it through pytorch-bot.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@d4l3k We are seeing failures related to distributed ring attention on ROCm CI, and this PR seems to be related to those changes. An example of a CI failure: https://github.com/pytorch/pytorch/actions/runs/9099620639/job/25016328856#step:15:10702
Looks like we need to ensure the …
@jithunnair-amd all_to_all_single is registered for the generic … As far as I can tell there's nothing that's CUDA specific -- other than maybe we're setting the wrong device_type in the unit test? https://github.com/pytorch/pytorch/blob/main/test/distributed/_tensor/test_attention.py#L174-L176
@d4l3k It seems that _templated_ring_attention doesn't support attn_mask, right?
@zhangvia That's right, the only masking we currently support is causal masking. We expect to support flex attention soon, which will support attn_mask.
This adds a templated version of the ring attention forward function and tests it with memory efficient attention. It doesn't add support for memory efficient attention in DTensor; that will be added in a follow-up PR.
This templating is also a POC of how to support other attention ops, such as jagged/nested tensors, and of how to implement striped attention in a scalable way.
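To make the templating concrete, here is a hedged, single-process sketch of the idea (the function and op names below are illustrative assumptions, not the actual _templated_ring_attention API): the per-block attention op is passed in as a callable that returns (output, logsumexp), the loop stands in for one ring rotation per step, and partial results are merged with the usual logsumexp rescaling. Causal masking and the actual send/recv of KV shards between ranks are omitted.

```python
from typing import Callable, List, Tuple

import torch

# Any attention op that returns (output, logsumexp) can be plugged in.
AttnOp = Callable[[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]


def sdpa_op(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # Plain softmax attention, computing the logsumexp by hand for clarity.
    scores = q @ k.transpose(-2, -1) / q.size(-1) ** 0.5  # (B, H, S_q, S_kv)
    lse = torch.logsumexp(scores, dim=-1)                 # (B, H, S_q)
    out = torch.softmax(scores, dim=-1) @ v               # (B, H, S_q, D)
    return out, lse


def ring_attention_forward(
    query: torch.Tensor,
    kv_shards: List[Tuple[torch.Tensor, torch.Tensor]],
    attn_op: AttnOp = sdpa_op,
) -> torch.Tensor:
    out = lse = None
    # In the real implementation each iteration is one ring rotation: attend to
    # the local KV shard, then rotate the shards to the next rank.
    for k, v in kv_shards:
        block_out, block_lse = attn_op(query, k, v)
        if out is None:
            out, lse = block_out, block_lse
            continue
        # Merge the new block into the running result via logsumexp rescaling.
        new_lse = torch.logaddexp(lse, block_lse)
        out = (
            out * (lse - new_lse).exp().unsqueeze(-1)
            + block_out * (block_lse - new_lse).exp().unsqueeze(-1)
        )
        lse = new_lse
    return out


if __name__ == "__main__":
    torch.manual_seed(0)
    q = torch.randn(2, 4, 8, 16)
    shards = [(torch.randn(2, 4, 8, 16), torch.randn(2, 4, 8, 16)) for _ in range(3)]
    full_k = torch.cat([k for k, _ in shards], dim=2)
    full_v = torch.cat([v for _, v in shards], dim=2)
    reference, _ = sdpa_op(q, full_k, full_v)
    torch.testing.assert_close(ring_attention_forward(q, shards), reference, rtol=1e-5, atol=1e-5)
```

Because the merge only needs (output, logsumexp) from each block, swapping flash attention for memory efficient attention (or another op with the same contract) does not change the ring loop itself.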
Misc changes:

* Fixes all_to_all_single autograd implementation with CUDA + adds NCCL test
* Adds compile support to the ring attention implementations (required some tweaks to process groups)

Test plan:

```
pytest test/distributed/_tensor/test_attention.py
pytest test/distributed/test_functional_api.py
```

Pull Request resolved: pytorch#124215
Approved by: https://github.com/wanchaol
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang