
Conversation

@kwen2501 (Contributor) commented Apr 25, 2024

This is a helper function which:
1. computes the gradients for the stage inputs, and
2. accumulates gradients for the stage module's parameters.

A unit test for this function is also added.

cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k
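For readers unfamiliar with the pattern, a minimal sketch of what such a stage-backward helper could look like is shown below, assuming it wraps torch.autograd.backward and that the stage inputs are local tensors that require grad. The function name and signature here are hypothetical illustrations, not the PR's actual API:

```python
import torch


def stage_backward_sketch(stage_outputs, output_grads, stage_inputs):
    """Run backward for one pipeline stage (illustrative sketch only).

    Computes gradients for the stage inputs (to be sent to the previous
    stage) and lets parameter gradients accumulate into .grad as usual.
    """
    # Only outputs that require grad participate in the backward call.
    pairs = [
        (out, g)
        for out, g in zip(stage_outputs, output_grads)
        if out.requires_grad
    ]
    torch.autograd.backward(
        [out for out, _ in pairs],
        grad_tensors=[g for _, g in pairs],
    )

    # Gather gradients for the stage inputs; inputs that do not require
    # grad map to None, mirroring the `grad_inputs.append(None)` branch
    # visible in the diff discussed later in this thread.
    grad_inputs = []
    for inp in stage_inputs:
        if isinstance(inp, torch.Tensor) and inp.requires_grad:
            grad_inputs.append(inp.grad)
        else:
            grad_inputs.append(None)
    return grad_inputs


# Tiny usage example: a one-layer "stage".
inp = torch.randn(4, 8, requires_grad=True)
layer = torch.nn.Linear(8, 8)
out = layer(inp)
(grad_in,) = stage_backward_sketch([out], [torch.ones_like(out)], [inp])
# grad_in holds d(loss)/d(inp); layer.weight.grad and layer.bias.grad are populated.
```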

@pytorch-bot bot commented Apr 25, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/124958

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 218ac4c with merge base c82fcb7:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot bot added the ci-td-distributed and oncall: distributed labels Apr 25, 2024
@kwen2501 requested review from H-Huang and wconstab April 25, 2024 19:01
@kwen2501 added the topic: not user facing label Apr 25, 2024
kwen2501 added a commit that referenced this pull request Apr 25, 2024
ghstack-source-id: fc32dda
Pull Request resolved: #124958
kwen2501 added a commit that referenced this pull request Apr 30, 2024
Add document: distributed.pipelining.rst

ghstack-source-id: d79b483
Pull Request resolved: #124958
kwen2501 added a commit that referenced this pull request Apr 30, 2024
Add document: distributed.pipelining.rst

Move some modules to private

ghstack-source-id: fc59a93
Pull Request resolved: #124958
kwen2501 added a commit that referenced this pull request Apr 30, 2024
Add document: distributed.pipelining.rst

ghstack-source-id: 7ee3b0e
Pull Request resolved: #124958

# TODO: handling requires_grad=False dynamically. Can we analyze this during initial
# IR emission?
def _null_coalesce_accumulate(lhs, rhs):
Contributor (reviewer):
Where will this be used?

kwen2501 (Contributor, author):
Yeah, on second thought, I think it is okay to remove it.

else:
    grad_inputs.append(None)

# Alternative impl: `torch.autograd.grad`.
Contributor (reviewer):
What are the trade-offs? Is there a reason to pick one over the other?

Well, to answer my own question: we will want to use .grad if we implement zero bubble, as we discussed.

kwen2501 (Contributor, author):
Yeah, we can create another util function to hold the .grad impl. Well, maybe two, because we will need two calls for zero bubble.
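For context, a hedged sketch of the torch.autograd.grad route discussed above: splitting the stage backward into one call for input gradients and one for parameter gradients, which is what a zero-bubble schedule would need. The function names and the manual accumulation below are illustrative assumptions, not the PR's implementation:

```python
import torch


def backward_input_grads(stage_outputs, output_grads, stage_inputs):
    # Gradients w.r.t. the stage inputs only; retain the graph so the
    # parameter-gradient pass can run later in the schedule.
    return torch.autograd.grad(
        stage_outputs,
        stage_inputs,
        grad_outputs=output_grads,
        retain_graph=True,
    )


def backward_weight_grads(stage_outputs, output_grads, parameters):
    # Gradients w.r.t. the parameters, computed at a later point in the
    # schedule and accumulated manually, mirroring what .backward() does.
    params = list(parameters)
    grads = torch.autograd.grad(
        stage_outputs,
        params,
        grad_outputs=output_grads,
    )
    for p, g in zip(params, grads):
        p.grad = g if p.grad is None else p.grad + g


# In a zero-bubble schedule, backward_input_grads would run right after the
# output grads arrive (to unblock the previous stage), while
# backward_weight_grads is deferred to fill otherwise-idle pipeline bubbles.
```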

.. role:: hidden
    :class: hidden-section

Pipeline Parallelism
Contributor (reviewer):
After adding this, is there a plan to dedup with the README added in the first PR? It seems like we wouldn't need that anymore, but I'm not sure the content is 100% the same.

kwen2501 (Contributor, author):
Good point. Nice to reduce maintenance load.

@kwen2501 (Contributor, author) commented May 1, 2024

@pytorchbot merge

@pytorch-bot bot added the ciflow/trunk label May 1, 2024
@pytorchmergebot (Collaborator):
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request May 2, 2024
petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
Pull Request resolved: pytorch#124958
Approved by: https://github.com/wconstab
ghstack dependencies: pytorch#124776, pytorch#124875
@github-actions bot deleted the gh/kwen2501/17/head branch June 4, 2024 02:00