
Conversation

@wanchaol (Collaborator) commented Mar 8, 2024

Stack from ghstack (oldest at bottom):

`to_local` accepts a `grad_placements` argument if the user chooses to pass one; previously we enforced the grad_out to have the "same" placements as the current DTensor, for safety.

But I realized that we DO NOT need to enforce this constraint. Why? The backward placement does not need to be the same as the forward tensor placement. This is already the case for param vs. param.grad (i.e. a param can be Replicate while its grad is Partial), so we should not restrict activation vs. activation grad either.
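
A minimal sketch of the relaxed behavior, not code from this PR: it uses today's public `torch.distributed.tensor` names (the same APIs lived under `torch.distributed._tensor` when this landed), assumes a two-rank launch via torchrun, and the shapes plus the Replicate-forward / Partial-grad combination are illustrative only:

```python
# Run with e.g. `torchrun --nproc-per-node=2 demo.py` (file name is illustrative).
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Partial, Replicate, distribute_tensor

# Two-rank 1D mesh on CPU (gloo); relies on the usual torchrun environment variables.
mesh = init_device_mesh("cpu", (2,))

# A replicated DTensor "activation" that requires grad.
x = distribute_tensor(torch.randn(8, 8), mesh, [Replicate()]).requires_grad_()

# grad_placements no longer has to match the forward placements: the incoming
# local grad is wrapped as a Partial DTensor instead of being forced to Replicate.
x_local = x.to_local(grad_placements=[Partial()])
x_local.sum().backward()

# x stays Replicate while x.grad is Partial -- the placements are allowed to differ.
print(x.placements, x.grad.placements)
```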

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang

@pytorch-bot bot commented Mar 8, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/121474

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit fa58f6b with merge base 3d089de:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

wanchaol added a commit that referenced this pull request Mar 8, 2024

ghstack-source-id: bdc6344
Pull Request resolved: #121474
@github-actions bot added the oncall: distributed and ciflow/inductor labels Mar 8, 2024
@awgu (Collaborator) left a comment

SGTM!

@wanchaol (Collaborator, Author) commented Mar 8, 2024

@yoyoyocmu @xw285cornell This PR would directly produce a reduce_scatter instead of an allreduce + chunk, so we don't need to add a fusion pass to fuse this pattern in #120051
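
A hedged illustration of the collective being described, not code from this PR: `local_grad`, the shapes, and the `Shard(0)` target are made-up stand-ins, and it assumes the public `torch.distributed.tensor` names plus a two-rank torchrun launch. Once the gradient can be tagged Partial, producing a sharded gradient is a single Partial -> Shard redistribute, i.e. one reduce_scatter rather than the allreduce + chunk pattern that the fusion pass in #120051 targets:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Shard

mesh = init_device_mesh("cpu", (2,))

# Stand-in for the per-rank partial gradient coming out of the local backward.
local_grad = torch.randn(8, 8)

# Tag the gradient as Partial (what grad_placements now lets to_local's backward do).
grad = DTensor.from_local(local_grad, mesh, [Partial()], run_check=False)

# Partial -> Shard(0) lowers to a single reduce_scatter; on a 2-rank mesh the
# local piece of the 8x8 global gradient becomes 4x8 per rank.
grad_shard = grad.redistribute(mesh, [Shard(0)])
print(grad_shard.placements, grad_shard.to_local().shape)
```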

@yoyoyocmu (Contributor) left a comment

LGTM!

@wanchaol added the ciflow/trunk and release notes: distributed (dtensor) labels Mar 8, 2024
@yifuwang (Collaborator) left a comment

Nice!

@wanchaol (Collaborator, Author) commented Mar 8, 2024

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

pianpwk pushed a commit that referenced this pull request Mar 11, 2024

Pull Request resolved: #121474
Approved by: https://github.com/awgu, https://github.com/yoyoyocmu, https://github.com/yifuwang
@github-actions bot deleted the gh/wanchaol/445/head branch April 8, 2024 01:52

Labels

ciflow/inductor, ciflow/trunk, Merged, oncall: distributed, release notes: distributed (dtensor)
