[dtensor] to_local backward grad placement passthrough #121474
Conversation
`to_local` accepts a `grad_placements` argument if the user chooses to pass one. Previously we enforced, for safety, that the grad_out have the "same" placements as the current DTensor. But we do not actually need this constraint: the backward placement does not have to match the forward tensor placement. This is already the case for param vs. param.grad (i.e. a param can be Replicate while its grad is Partial), so we should not impose that restriction on activations vs. activation grads either.
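For illustration, here is a minimal sketch (not taken from the PR) of what the relaxed check allows. It assumes a recent PyTorch where `Partial` is importable from `torch.distributed._tensor`, a launch via `torchrun --nproc-per-node=4` on a host with 4 GPUs, and made-up shapes and variable names.

```python
# Illustrative sketch only: assumes `torchrun --nproc-per-node=4` on a 4-GPU host.
import os

import torch
from torch.distributed._tensor import Partial, Replicate, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (4,))

# Forward: a replicated activation DTensor.
act = distribute_tensor(
    torch.randn(8, 16, device="cuda", requires_grad=True), mesh, [Replicate()]
)

# Each rank computes a loss against its own local data, so each rank's local
# gradient w.r.t. the replicated activation is only a partial sum of the true
# gradient -- which is exactly what Partial() expresses.
local_act = act.to_local(grad_placements=[Partial()])
loss = (local_act * torch.randn_like(local_act)).sum()
loss.backward()

# After this PR, act.grad is a DTensor whose placement is the declared
# Partial(); previously to_local's backward redistributed it back to the
# forward placement (Replicate()), i.e. an eager all_reduce.
print(act.grad.placements)
```

With the passthrough, the partial gradient stays partial until the user (or a later redistribute) decides how to reduce it, mirroring how param.grad is allowed to carry a different placement than the param itself.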
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/121474.
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure) As of commit fa58f6b with merge base 3d089de. BROKEN TRUNK: the following job failed but was already failing on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
SGTM!
@yoyoyocmu @xw285cornell This PR will directly produce a reduce_scatter instead of allreduce + chunk, so we don't need to add a fusion pass for this pattern in #120051.
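To make the communication difference concrete, here is a hedged sketch (again assuming `torchrun --nproc-per-node=4` on GPUs; names, shapes, and the choice of Shard(0) are illustrative) of redistributing a partial gradient straight to a sharded layout:

```python
# Illustrative sketch only: assumes `torchrun --nproc-per-node=4` on a 4-GPU host.
import os

import torch
from torch.distributed._tensor import DTensor, Partial, Shard
from torch.distributed.device_mesh import init_device_mesh

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (4,))

# A partial-sum gradient, e.g. what now flows out of to_local's backward when
# the forward call passed grad_placements=[Partial()].
grad = DTensor.from_local(torch.randn(8, 16, device="cuda"), mesh, [Partial()])

# Partial() -> Shard(0) lowers to a single reduce_scatter.
grad_sharded = grad.redistribute(mesh, [Shard(0)])

# Before this PR the gradient would already have been forced back to the
# forward placement (e.g. Replicate(), an all_reduce) inside to_local's
# backward, so reaching Shard(0) afterwards took an extra local chunk --
# i.e. all_reduce + chunk instead of a single reduce_scatter.
```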
LGTM!
Nice!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: #121474
Approved by: https://github.com/awgu, https://github.com/yoyoyocmu, https://github.com/yifuwang
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang