fix: reset prefetch flag upon reshard #111354
I updated the UT to add a case when The thing that does make a difference is |
Thanks for the fix!
I think the usage of the `_prefetched` flag has changed slowly over time. If we were to rename it to `_unsharded`, we may want to check its usage more precisely, since I think it is not currently used if we do not do any prefetching.
@pytorchbot merge
The `prefetched` flag should be reset upon reshard. Otherwise, for zero2, the next access to the corresponding parameter will skip the "unshard" operation, resulting in a wrong parameter shape.

The need for unsharding is also mentioned in the comment of `FlatParameterHandle.unshard`.

As `FlatParameterHandle` already guards against unnecessary all-gather, this shouldn't incur extra communication overhead.

Personally, I also find `_prefetched` a bit misnamed; it should really be `_unsharded`.
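
To illustrate the failure mode, here is a toy sketch in plain Python. It is not the FSDP implementation: `ToyHandle`, `access`, and `run` are hypothetical names, and the "parameter" is reduced to a size counter. It only assumes what is described above, namely that an access skips unshard when the flag is set, and that unshard itself is guarded against redundant all-gathers.

```python
class ToyHandle:
    """Hypothetical stand-in for a flat-parameter handle (not the FSDP class)."""

    def __init__(self, full_size, world_size):
        self.full_size = full_size
        self.shard_size = full_size // world_size
        self.size = self.shard_size  # parameters start out sharded
        self.prefetched = False
        self.all_gathers = 0  # counts simulated all-gather calls

    def unshard(self):
        # Guard: do nothing if the parameter is already unsharded.
        if self.size == self.full_size:
            return
        self.all_gathers += 1
        self.size = self.full_size

    def reshard(self, reset_flag):
        self.size = self.shard_size
        if reset_flag:
            # The behavior this PR proposes: clear the flag on reshard.
            self.prefetched = False


def access(handle):
    """Simulated parameter access that trusts the flag to skip unshard."""
    if not handle.prefetched:
        handle.unshard()
    handle.prefetched = True
    return handle.size


def run(reset_flag):
    """Access, reshard, access again; return both observed sizes and the all-gather count."""
    handle = ToyHandle(full_size=8, world_size=4)
    first = access(handle)
    handle.reshard(reset_flag)
    second = access(handle)
    return first, second, handle.all_gathers


print(run(reset_flag=False))  # stale flag: second access sees the sharded size
print(run(reset_flag=True))   # flag reset: second access unshards again
```

In this toy model, resetting the flag while the parameter is still unsharded does not trigger a second all-gather, because `unshard` returns early. That mirrors the argument above that the reset should not add communication overhead.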