
fix: reset prefetch flag upon reshard #111354

Closed

Conversation

0x804d8000 (Contributor)

The prefetched flag should be reset upon reshard. Otherwise, for ZeRO-2, the next access to the corresponding parameter will skip the "unshard" operation and result in a wrong parameter shape.

The need for unsharding is also mentioned in the comment of FlatParameterHandle.unshard.

Since FlatParameterHandle already guards against unnecessary all-gathers, this shouldn't incur extra communication overhead.

Personally, I also find _prefetched a bit misnamed; it should really be _unsharded.
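
For reference, here is a minimal sketch of the idea behind the fix, assuming a reshard helper along the lines of the one in torch/distributed/fsdp/_runtime_utils.py; the surrounding rate-limiter and event logic is elided and the signature simplified, so this is an illustration rather than the exact diff:

```python
# Sketch only: simplified reshard helper; surrounding logic is omitted.
def _reshard(state, handle, free_unsharded_flat_param: bool) -> None:
    handle.reshard(free_unsharded_flat_param)
    handle.post_reshard()
    if free_unsharded_flat_param:
        # The fix: once the unsharded flat parameter is freed, the handle can
        # no longer be treated as prefetched, so the next access must run
        # unshard (all-gather) again instead of skipping it.
        handle._prefetched = False
```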


pytorch-bot bot commented Oct 16, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/111354

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit adef22c with merge base e0e15a4:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

0x804d8000 (Contributor, Author) commented Oct 16, 2023

I updated the UT to add a case where no_grad is used in the forward pass. This is not needed to reproduce the failure, though without it we'd see a different error reporting a loss mismatch between FSDP and DDP. I suspect the reason is somehow related to https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_flat_param.py#L1681.
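
To illustrate what "no_grad used in the forward pass" refers to here, a hypothetical toy module (not the actual test model) might look like this:

```python
import torch
import torch.nn as nn

class NoGradBranchModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.frozen = nn.Linear(8, 8)
        self.trainable = nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            # This branch runs without autograd tracking, so self.frozen
            # never receives gradients even though it is gathered for forward.
            x = self.frozen(x)
        return self.trainable(x)
```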

The thing that does make a difference is forward_prefetch. Formerly, test_hooks_multi_traversal did not specify this argument and left it at its default of False. This prevented _prefetch_handle from ever running and kept handle._prefetched always False. Therefore _unshard was always called by _pre_forward_unshard, and no error was raised. Now, with forward_prefetch=True added to the test, the error shows up.
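
A rough, hypothetical sketch of the configuration described above (toy model and dimensions, not the actual test_hooks_multi_traversal test), assuming a distributed process group has already been initialized:

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

def build_fsdp_model(device: torch.device) -> FSDP:
    model = nn.Sequential(
        nn.Linear(8, 8, device=device),
        nn.Linear(8, 8, device=device),
    )
    return FSDP(
        model,
        sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,  # ZeRO-2-style sharding
        forward_prefetch=True,  # lets _prefetch_handle set handle._prefetched
    )

# Without the fix, running two forward/backward iterations on this wrapper
# would fail on the second forward with a wrong parameter shape, because the
# stale _prefetched flag makes _pre_forward_unshard skip _unshard.
```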

awgu (Contributor) left a comment


Thanks for the fix!

I think the usage of the _prefetched flag has changed slowly over time. If we were to rename it to _unsharded, we may want to check its usage more precisely since I think it is not currently used if we do not do any prefetching.

awgu (Contributor) commented Oct 16, 2023

@pytorchbot merge

pytorch-bot bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) on Oct 16, 2023
pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Labels: ciflow/trunk (Trigger trunk jobs on your pull request), Merged, open source, release notes: distributed (fsdp)