-
Notifications
You must be signed in to change notification settings - Fork 21.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FSDP] Do not _unshard
if already prefetched
#97981
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/97981
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit c789061: This comment was automatically generated by Dr. CI and updates every 15 minutes. |
ghstack-source-id: d8149d2f96f14b8a60393371ce16c328264753ca Pull Request resolved: #97981
[ghstack-poisoned]
ghstack-source-id: 9d76f160eb1e37ede7485773d6344e632a61042b Pull Request resolved: #97981
[ghstack-poisoned]
ghstack-source-id: dd1c57e7790d0bca8ff3a780b17168391d60ba3a Pull Request resolved: #97981
[ghstack-poisoned]
ghstack-source-id: b952c9ea2446a3d4c11e235854a169b60c41488d Pull Request resolved: #97981
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
- This undoes #97981 due to a minor concern: The unshard logic (namely the `_use_unsharded_flat_parameter()` -> `_use_unsharded_views()` logic) for a handle depends on its own training state, not on that of the handle that prefetched it. Therefore, we actually want the handle itself to run `_use_unsharded_flat_parameter()`, not the handle that is prefetching. Otherwise, the handle could still be in `IDLE` but is being prefetched from another handle in `BACKWARD_PRE`, in which case the original handle does not correctly use the logic as if it were in `BACKWARD_PRE`. - Instead, we only run the `_use_unsharded_flat_parameter()` when the handle itself is running `unshard()`, which we differentiate using the `is_prefetch: bool` flag. [ghstack-poisoned]
Stack from ghstack (oldest at bottom):
_unshard
if already prefetched #97981_runtime_utils.py
#97980use_orig_params=True
#97667FSDPParamInfo
to useFlatParamHandle
#97665use_orig_params: bool
#97664