
Conversation

quanta42 (Contributor) commented Jun 5, 2024

This commit improves the FullyShardedDataParallel (FSDP) class in PyTorch by reusing a pre-allocated zero tensor, which avoids unnecessary GPU synchronizations.

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @penguinwu @tianyu-l @yf225 @chauhang
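
For context, here is a minimal, standalone sketch of the pattern the description refers to: allocating a scalar zero tensor on the device inside a hot path versus allocating it once and reusing it. This is an illustration only, not the actual FSDP code path; the function names and the toy norm computation are made up for the example.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-call version: a fresh zero scalar is materialized on the device
# every time the function is called with an empty gradient list.
def grad_norm_per_call(grads):
    if len(grads) == 0:
        return torch.tensor(0.0, device=device)
    return torch.linalg.vector_norm(torch.stack([g.norm() for g in grads]))

# Reuse version: allocate the zero scalar once up front and return the
# cached tensor, so repeated calls skip the extra allocation and
# host-to-device transfer.
_ZERO = torch.tensor(0.0, device=device)

def grad_norm_cached(grads):
    if len(grads) == 0:
        return _ZERO
    return torch.linalg.vector_norm(torch.stack([g.norm() for g in grads]))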

pytorch-bot (bot) commented Jun 5, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/128069

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 204485d with merge base 92151c8:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

linux-foundation-easycla (bot) commented Jun 5, 2024

CLA Signed

The committers listed above are authorized under a signed CLA.

pytorch-bot added the oncall: distributed and release notes: distributed (sharded) labels on Jun 5, 2024
# implemented using post-save and pre-load hooks
_init_state_dict_state(self)
_register_all_state_dict_hooks(self)
# new line in this PR: pre-allocate a zero scalar on the compute device for reuse
self.zero = torch.tensor(0.0, device=self.compute_device)
Collaborator commented on the diff:

Should we initialize this lazily (upon calling clip_grad_norm_ and actually needing to use this)?

Maybe we can put it in a private attribute like self._zero_scalar?
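
For reference, a minimal sketch of what the lazily initialized version could look like, following the suggestion above. The attribute name _zero_scalar comes from this comment; the _get_zero_scalar helper and the mixin framing are hypothetical, not the actual FSDP implementation.

from typing import Optional

import torch

class _LazyZeroScalarMixin:
    # Hypothetical mixin: `compute_device` is assumed to be set
    # elsewhere, as it is in FSDP's __init__.
    compute_device: torch.device
    _zero_scalar: Optional[torch.Tensor] = None

    def _get_zero_scalar(self) -> torch.Tensor:
        # Allocate the scalar only on first use (e.g. the first call to
        # clip_grad_norm_), then reuse the cached tensor afterwards.
        if self._zero_scalar is None:
            self._zero_scalar = torch.tensor(0.0, device=self.compute_device)
        return self._zero_scalar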

awgu (Collaborator) commented Jun 5, 2024

Thanks @quanta42 for the contribution! Could you sign the CLA?

quanta42 (Contributor, Author) commented Jun 5, 2024

> Thanks @quanta42 for the contribution! Could you sign the CLA?

I did sign it, but I work for Adobe, so I need authorization from the administrators.
I also cc'd them on Slack; hopefully they will get back to me soon.

Thanks for the prompt review!

awgu (Collaborator) left a comment

LGTM! Could you fix lint?

awgu added the release notes: distributed (fsdp) label and removed the release notes: distributed (sharded) label on Jun 6, 2024
github-actions (bot) commented Aug 6, 2024

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

quanta42 (Contributor, Author) commented:

@awgu I was wondering why this hasn't been merged after I fixed the linting.
I don't have the authorization to merge it myself, and it seems that tests are now failing, probably because of upstream code changes.
What's the process for merging this type of change?

awgu (Collaborator) commented Aug 20, 2024

Sorry, this fell through the cracks. Let me rebase and try to re-merge.

awgu (Collaborator) commented Aug 20, 2024

@pytorchbot rebase -s

pytorchmergebot (Collaborator) commented:

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

pytorchmergebot (Collaborator) commented:

Successfully rebased patch-1 onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout patch-1 && git pull --rebase)

quanta42 (Contributor, Author) commented:

> Sorry, this fell through the cracks. Let me rebase and try to re-merge.

Thank you

awgu (Collaborator) commented Aug 20, 2024

@pytorchbot merge -i

pytorch-bot added the ciflow/trunk label on Aug 20, 2024
pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged while ignoring the following check: pull / linux-focal-py3.8-clang10-onnx / test (default, 2, 2, amz2023.linux.2xlarge)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

