[Traceable FSDP2] Add all_gather_into_tensor out variant #126334
Conversation
🔗 Helpful links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/126334. Note: links to docs will display an error until the docs builds have completed.
✅ No failures as of commit 6722c5c with merge base c312cd8. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Force-pushed from 5f18b32 to 172bc95.
@@ -321,6 +334,13 @@ TORCH_LIBRARY(_c10d_functional, m) {
          c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_coalesced_),
      {at::Tag::pt2_compliant_tag});

  m.def(
      "all_gather_into_tensor_(Tensor(a!) output, Tensor input, int group_size, str group_name) -> Tensor(a!)",
I would really love this to be an out-variant API. From an operator perspective this is not an in-place allgather; it's an out-variant op, so we should follow the aten naming convention and make the op an actual out variant.
Sounds great! Updated.
@@ -321,6 +334,13 @@ TORCH_LIBRARY(_c10d_functional, m) {
          c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_coalesced_),
      {at::Tag::pt2_compliant_tag});

  m.def(
      "all_gather_into_tensor_out(Tensor input, int group_size, str group_name, *, Tensor(a!) out) -> Tensor(a!)",
Nice! We can establish the convention that:
- Collectives that modify the input are postfixed with `_`
- Out variants of collectives are postfixed with `_out`, and the `out` argument is keyword-only
Looks great!
lgtm!
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
This PR adds `torch.ops._c10d_functional.all_gather_into_tensor_out`.

It's important for tracing FSDP2, because FSDP2 pre-allocates the output buffer of AllGather, makes the input buffer an alias of the output buffer, and expects both of them to be used to achieve lower memory usage. If we don't preserve this behavior and instead functionalize the AllGather op, the op will allocate a brand-new output buffer (instead of reusing the pre-allocated one), significantly increasing memory usage.
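For illustration, a minimal sketch of how the out variant might be called once registered. This assumes an initialized process group; obtaining the registered group name via `ProcessGroup.group_name` is an assumption about the surrounding plumbing, not something this PR specifies:

    import torch
    import torch.distributed as dist

    # Assumes dist.init_process_group(...) has already run on every rank.
    group = dist.group.WORLD
    world_size = dist.get_world_size(group)

    inp = torch.randn(1024, device="cuda")
    # Pre-allocate the gather output, as FSDP2 does; the out variant
    # writes into this buffer instead of allocating a fresh one.
    out = torch.empty(world_size * inp.numel(), device="cuda")

    torch.ops._c10d_functional.all_gather_into_tensor_out(
        inp, world_size, group.group_name, out=out  # `out` is keyword-only
    )
    # Native functional collectives are asynchronous; wait before reading.
    torch.ops._c10d_functional.wait_tensor(out)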
The expectation is that we will "re-inplace" the AllGather op by switching to the out variant in the Inductor post-grad stage via an FX pass, so this API is not expected to be used directly by users.
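For intuition, here is a rough, purely illustrative sketch of what such a re-inplacing FX pass could look like. `output_buffer_for` is a hypothetical stand-in for the alias/lifetime analysis a real pass would perform; this is not the actual Inductor implementation:

    import torch
    from torch.fx import GraphModule, Node

    def output_buffer_for(node: Node):
        # Hypothetical placeholder: the real pass must prove that some graph
        # value is a pre-allocated buffer the gather may safely write into.
        return None

    def reinplace_all_gather(gm: GraphModule) -> None:
        for node in list(gm.graph.nodes):
            if (node.op == "call_function"
                    and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default):
                buf = output_buffer_for(node)
                if buf is None:
                    continue
                # Swap the functional op for the out variant, writing into
                # the pre-allocated buffer instead of a fresh allocation.
                with gm.graph.inserting_before(node):
                    new_node = gm.graph.call_function(
                        torch.ops._c10d_functional.all_gather_into_tensor_out.default,
                        args=node.args,  # (input, group_size, group_name)
                        kwargs={"out": buf},
                    )
                node.replace_all_uses_with(new_node)
                gm.graph.erase_node(node)
        gm.graph.lint()
        gm.recompile()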
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @chauhang @d4l3k

Pull Request resolved: pytorch#126334
Approved by: https://github.com/yifuwang, https://github.com/wanchaol