
distributed: templated ring attention #124215


Closed · wants to merge 1 commit from d4l3k/template_attention

Conversation

@d4l3k (Member) commented Apr 16, 2024

This adds a templated version of the ring attention forward function and tests it with memory-efficient attention. It does not add support for memory-efficient attention in DTensor; that will be added in a follow-up PR.

This templating is also a proof of concept for supporting other attention ops, such as jagged/nested tensors, and for implementing striped attention in a scalable way.
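For illustration, here is a minimal single-process sketch of that templating idea (the names `templated_ring_forward` and `ref_attention_with_lse` are invented for this sketch and are not the PR's actual APIs): the chunk loop and the log-sum-exp merge are written once, and the per-chunk attention op is passed in as a parameter.

```python
import torch

def ref_attention_with_lse(q, k, v):
    # Reference attention that also returns the logsumexp over keys
    # (a stand-in for a real SDPA kernel that exposes its LSE).
    scores = q @ k.transpose(-2, -1) / q.size(-1) ** 0.5  # (B, H, S, S_k)
    lse = torch.logsumexp(scores, dim=-1)                 # (B, H, S)
    return torch.softmax(scores, dim=-1) @ v, lse         # (B, H, S, D)

def templated_ring_forward(attn_op, q, kv_chunks):
    # One generic loop over key/value chunks; the attention op is pluggable.
    out = lse = None
    for k, v in kv_chunks:
        chunk_out, chunk_lse = attn_op(q, k, v)
        if out is None:
            out, lse = chunk_out, chunk_lse
        else:
            # Merge the partial results using their logsumexps.
            new_lse = torch.logaddexp(lse, chunk_lse)
            out = (torch.exp(lse - new_lse).unsqueeze(-1) * out
                   + torch.exp(chunk_lse - new_lse).unsqueeze(-1) * chunk_out)
            lse = new_lse
    return out, lse

B, H, S, D = 2, 4, 8, 16
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, 2 * S, D)
v = torch.randn(B, H, 2 * S, D)

# Attending against the two key/value halves and merging matches attention
# over the full key/value sequence.
merged, _ = templated_ring_forward(
    ref_attention_with_lse,
    q,
    [(k[..., :S, :], v[..., :S, :]), (k[..., S:, :], v[..., S:, :])],
)
full, _ = ref_attention_with_lse(q, k, v)
torch.testing.assert_close(merged, full)
```

In the real implementation the key/value shards rotate around the process-group ring via collectives rather than being iterated locally, but the merge math is the same.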

Misc changes:

  • Fixes all_to_all_single autograd implementation with CUDA + adds NCCL test
  • Adds compile support to the ring attention implementations (required some tweaks to process groups)

Test plan:

pytest test/distributed/_tensor/test_attention.py
pytest test/distributed/test_functional_api.py

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang

@d4l3k requested review from yifuwang, wanchaol, and drisspg April 16, 2024 21:04

pytorch-bot bot commented Apr 16, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/124215

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 84e2834 with merge base b5d4ebe:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot added the ciflow/inductor, oncall: distributed, and release notes: distributed (c10d) labels Apr 16, 2024
@d4l3k requested a review from yoyoyocmu April 16, 2024 21:08
@d4l3k added the topic: not user facing label Apr 16, 2024
@wanchaol (Collaborator) left a comment

This looks reasonable; I have a few suggestions and questions inline.

],
)

out.sum().backward()
Collaborator:

Shall we assert that the single-GPU query.grad matches dquery.grad.full_tensor()?

Member Author (d4l3k):

I don't think we need to in this test -- the other tests validate ring attention accuracy for both the forward and backward passes. This test is intended to verify that the different attention ops have equivalent behavior, and the backward pass goes through autograd, so there's little risk of an issue.

@@ -54,6 +54,10 @@ def _merge_sdpa(
"""
assert len(chunks) == len(logsumexps)

# LSE may be padded in the sequence dimension such as with memory efficient attention.
seq_len = chunks[0].size(2)
Collaborator:

Is this a no-op for the other attention ops?

Member Author (d4l3k):

Yes, exactly -- the LSE is already of size seq_len for flash attention. For memory-efficient attention the length is aligned to 32, so it needs to be sliced back to seq_len.
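A small illustration of that point (the exact padding granularity, a multiple of 32, is taken from the comment above and treated as an assumption here):

```python
import torch

seq_len = 100
padded_len = ((seq_len + 31) // 32) * 32            # 128

lse_flash = torch.randn(2, 8, seq_len)              # already seq_len long
lse_mem_efficient = torch.randn(2, 8, padded_len)   # padded along the sequence dim

# Slicing back to seq_len trims the padding for memory-efficient attention
# and is a no-op for flash attention.
assert lse_mem_efficient[..., :seq_len].size(-1) == seq_len
assert lse_flash[..., :seq_len].shape == lse_flash.shape
```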

@@ -704,5 +705,65 @@ def run_with_backward():
self.assertIsNotNone(t.grad)


class TestFunctionalAutogradWithNCCL(MultiProcessTestCase):
Collaborator:

I think this should just inherit from TestCollectivesWithNCCL to unify the setups.

Member Author (d4l3k):

TestCollectivesWithNCCL has tests in it that we don't want to run again.
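For context on that trade-off, a small generic illustration with plain unittest (not the actual test classes): inheriting from a TestCase that defines test_* methods makes the subclass collect and run those tests again, whereas sharing setup through a mixin without test methods does not.

```python
import unittest

class _SharedSetupMixin:
    # Setup only, no test_* methods, so nothing here is collected twice.
    def setUp(self):
        self.world_size = 2

class ExistingCollectiveTests(_SharedSetupMixin, unittest.TestCase):
    def test_existing(self):
        self.assertEqual(self.world_size, 2)

class NewAutogradTests(_SharedSetupMixin, unittest.TestCase):
    # If this class inherited from ExistingCollectiveTests instead,
    # test_existing would run again under this class as well.
    def test_new(self):
        self.assertTrue(self.world_size > 0)

if __name__ == "__main__":
    unittest.main()
```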

@d4l3k force-pushed the d4l3k/template_attention branch 2 times, most recently from 0ee8bde to cf93d8b, April 17, 2024 17:20
@d4l3k requested a review from wanchaol April 17, 2024 17:22
@d4l3k force-pushed the d4l3k/template_attention branch 2 times, most recently from eecec31 to 01669f6, April 17, 2024 20:53
@wanchaol (Collaborator) left a comment

LGTM! One more comment inline.

@d4l3k force-pushed the d4l3k/template_attention branch from 01669f6 to 05c940a, April 18, 2024 00:14
@d4l3k (Member Author) commented Apr 18, 2024

@pytorchbot merge

@pytorch-bot added the ciflow/trunk label Apr 18, 2024
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator)

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team: raised by workflow job.

Failing merge rule: Core Maintainers

@d4l3k (Member Author) commented Apr 18, 2024

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator)

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team: raised by workflow job.

Failing merge rule: Core Maintainers

@d4l3k (Member Author) commented Apr 18, 2024

@pytorchbot rebase

@pytorchmergebot (Collaborator)

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator)

Successfully rebased d4l3k/template_attention onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout d4l3k/template_attention && git pull --rebase)

@pytorchmergebot force-pushed the d4l3k/template_attention branch from 05c940a to 84e2834, April 18, 2024 17:52
@d4l3k (Member Author) commented Apr 18, 2024

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator)

The merge job was canceled. If you believe this is a mistake, you can re-trigger it through pytorch-bot.

@d4l3k (Member Author) commented Apr 19, 2024

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@d4l3k deleted the d4l3k/template_attention branch April 19, 2024 17:11
sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this pull request Apr 22, 2024
Pull Request resolved: pytorch#124215
Approved by: https://github.com/wanchaol
petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
@jithunnair-amd (Collaborator)

@d4l3k We are seeing failures related to distributed ring attention on ROCm CI, and they appear to be related to this PR's changes. An example CI failure: https://github.com/pytorch/pytorch/actions/runs/9099620639/job/25016328856#step:15:10702

NotImplementedError: Could not run '_c10d_functional_autograd::all_to_all_single' with arguments from the 'CUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. '_c10d_functional_autograd::all_to_all_single' is only available for these backends: [Meta, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradMeta, AutogradNestedTensor, Tracer, AutocastCPU, AutocastXPU, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

Looks like we need to ensure the _c10d_functional_autograd::all_to_all_single operator is built for the CUDA backend even for ROCm builds. Can you point us to the relevant location where that would need to be done?

@d4l3k (Member Author) commented May 15, 2024

@jithunnair-amd all_to_all_single is registered for the generic Autograd backend, not a CUDA-specific one: https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/Functional.cpp#L562

As far as I can tell, nothing here is CUDA-specific -- unless maybe we're setting the wrong device_type in the unit test? https://github.com/pytorch/pytorch/blob/main/test/distributed/_tensor/test_attention.py#L174-L176
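As a debugging aid (hedged: torch._C._dispatch_dump is an internal API, so its availability and output format may vary), dumping the dispatcher table for the op shows which dispatch keys have registered kernels and should confirm the registration sits under the generic Autograd key rather than a device-specific one:

```python
import torch

# Print the dispatch table for the functional-collectives autograd op.
print(torch._C._dispatch_dump("_c10d_functional_autograd::all_to_all_single"))
```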

@zhangvia

@d4l3k It seems that _templated_ring_attention doesn't support attn_mask, right?

@XilunWu (Contributor) commented Feb 3, 2025

@zhangvia That's right; the only masking we currently support is causal masking. We expect to support flex attention soon, which will support attn_mask.
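For reference, a small sketch of what causal-only masking means in plain SDPA terms (ordinary torch.nn.functional usage, not the ring-attention API): is_causal=True corresponds to a lower-triangular boolean attn_mask, and arbitrary attn_mask patterns are what is not yet supported.

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

# is_causal=True is equivalent to an explicit lower-triangular boolean mask.
causal_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
tril_mask = torch.ones(4, 4, dtype=torch.bool).tril()
masked_out = F.scaled_dot_product_attention(q, k, v, attn_mask=tril_mask)

torch.testing.assert_close(causal_out, masked_out)
```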

Labels
ciflow/inductor · ciflow/trunk · Merged · oncall: distributed · release notes: distributed (c10d) · topic: not user facing