Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Traceable FSDP2] Add all_gather_into_tensor out variant #126334

Closed
wants to merge 4 commits into from

Conversation

yf225
Copy link
Contributor

@yf225 yf225 commented May 15, 2024

This PR adds torch.ops._c10d_functional.all_gather_into_tensor_out.

It's important for tracing FSDP2, because FSDP2 pre-allocates the output buffer of AllGather, and makes input buffer an alias of the output buffer, and expects both of them to be used to achieve lower memory usage. If we don't preserve this behavior and instead functionalize the AllGather op, AllGather op will then create a brand-new output buffer (instead of reusing), thus significantly increasing the memory usage.

The expectation is that we will "re-inplace" the AllGather op by switching to the out variant in Inductor post-grad stage via an FX pass, so this API is not expected to be directly used by users.

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @chauhang @d4l3k

@yf225 yf225 requested a review from yifuwang May 15, 2024 21:45
@pytorch-bot pytorch-bot bot added oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (c10d) release notes category labels May 15, 2024
Copy link

pytorch-bot bot commented May 15, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126334

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 6722c5c with merge base c312cd8 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@@ -321,6 +334,13 @@ TORCH_LIBRARY(_c10d_functional, m) {
c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_coalesced_),
{at::Tag::pt2_compliant_tag});

m.def(
"all_gather_into_tensor_(Tensor(a!) output, Tensor input, int group_size, str group_name) -> Tensor(a!)",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would really love this to be an out-variant API, this is not a inplace allgather from operator prospective, it's a out variant op, we should follow the aten naming convention to make the op be an actual out-variant

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds great! Updated.

@yf225 yf225 changed the title Add inplace all_gather_into_tensor [Compile FSDP2] Add inplace all_gather_into_tensor May 15, 2024
@yf225 yf225 changed the title [Compile FSDP2] Add inplace all_gather_into_tensor [Compile FSDP2] Add all_gather_into_tensor out variant May 15, 2024
@yf225 yf225 changed the title [Compile FSDP2] Add all_gather_into_tensor out variant [Traceable FSDP2] Add all_gather_into_tensor out variant May 15, 2024
@@ -321,6 +334,13 @@ TORCH_LIBRARY(_c10d_functional, m) {
c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_coalesced_),
{at::Tag::pt2_compliant_tag});

m.def(
"all_gather_into_tensor_out(Tensor input, int group_size, str group_name, *, Tensor(a!) out) -> Tensor(a!)",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! We can establish the convention that:

  • Collectives that modify the input are postfixed with _
  • Out-variant of collectives are postfixed with _out, and the out argument is keyword-only

Copy link
Contributor

@yifuwang yifuwang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great!

Copy link
Contributor

@wanchaol wanchaol left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm!

@yf225
Copy link
Contributor Author

yf225 commented May 16, 2024

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label May 16, 2024
@yf225 yf225 added the topic: not user facing topic category label May 16, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

ZelboK pushed a commit to ZelboK/pytorch that referenced this pull request May 19, 2024
)

This PR adds `torch.ops._c10d_functional.all_gather_into_tensor_out`.

It's important for tracing FSDP2, because FSDP2 pre-allocates the output buffer of AllGather, and makes input buffer an alias of the output buffer, and expects both of them to be used to achieve lower memory usage. If we don't preserve this behavior and instead functionalize the AllGather op, AllGather op will then create a brand-new output buffer (instead of reusing), thus significantly increasing the memory usage.

The expectation is that we will "re-inplace" the AllGather op by switching to the out variant in Inductor post-grad stage via an FX pass, so this API is not expected to be directly used by users.

Pull Request resolved: pytorch#126334
Approved by: https://github.com/yifuwang, https://github.com/wanchaol
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (c10d) release notes category topic: not user facing topic category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants