[FSDP] Use correct handle training state when prefetching #98249
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/98249
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 58bfb34. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Unittest?
```diff
@@ -1003,11 +1015,25 @@ def _prefetch_handles(
         return
     handles_to_prefetch = _get_handles_to_prefetch(state, current_handles_key)
     for handles_key in handles_to_prefetch:
+        # Temporarily emulate the training state while calling `_unshard`
```
what improvement does doing this provide if there's no functionality difference?
For `_use_unsharded_views(as_params=...)`, `as_params` depends on the handle's `_training_state`.

For example, suppose we are backward prefetching with `BACKWARD_PRE` and handle `h1` prefetches handle `h2`:
- `h2._training_state` is `IDLE`.
- `h1._training_state` is `BACKWARD_PRE`.
- With this PR's change, we set `h2._training_state = BACKWARD_PRE` while we prefetch the unshard, so that `as_params` in `_use_unsharded_views()` will correctly be `False` instead of `True`.
- Without this change, `_use_unsharded_views()` will use `as_params=True`, which is actually incorrect for reentrant checkpointing.

I need to investigate more, but I think our FSDP <> AC unit tests were too weak to catch this bug, which I introduced in #97981. Before #97981, we would simply override the prefetched unshard with a second call to `_use_unsharded_views(as_params=...)` using the correct `as_params`. After #97981, we skip that second overriding call. With this PR, we still skip the second call, but that is no longer an issue since the first, prefetched `_use_unsharded_views()` uses the correct training state and hence the correct `as_params`.
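A minimal sketch of this save/override/restore pattern (the helper `_emulate_training_state` and the local enum are hypothetical stand-ins for the private FSDP internals, not the actual implementation in `_prefetch_handles`):

```python
import contextlib
from enum import Enum, auto

class HandleTrainingState(Enum):
    # Local stand-in for FSDP's private handle-state enum.
    IDLE = auto()
    FORWARD = auto()
    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()

@contextlib.contextmanager
def _emulate_training_state(handle, training_state):
    # Hypothetical helper: temporarily set the handle's training state so
    # that the prefetched unshard computes `as_params` as if it were not
    # prefetched, then restore the previous state afterwards.
    prev_state = handle._training_state
    handle._training_state = training_state
    try:
        yield
    finally:
        handle._training_state = prev_state

# In the example above, `h1` (in BACKWARD_PRE) prefetches `h2` (IDLE):
#     with _emulate_training_state(h2, HandleTrainingState.BACKWARD_PRE):
#         h2.unshard()  # `_use_unsharded_views()` now sees BACKWARD_PRE
```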
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
- `_use_sharded_views()` for `SHARD_GRAD_OP` #98250
- `requires_grad` for `use_orig_params=True` #98221

This PR ensures that when prefetching a `FlatParamHandle.unshard()`, we temporarily set the `FlatParamHandle._training_state` to the expected training state as if the `unshard()` were not prefetched, since the `as_params` argument to `_use_unsharded_views()` depends on the handle's training state.
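To make that dependence concrete, here is a hedged sketch of the `as_params` predicate implied by the review thread above (a simplified, self-contained stand-in; treat the exact rule as an assumption about the private FSDP source):

```python
from enum import Enum, auto

class HandleTrainingState(Enum):
    # Local stand-in for FSDP's private handle-state enum.
    IDLE = auto()
    FORWARD = auto()
    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()

def _as_params(training_state: HandleTrainingState) -> bool:
    # Consistent with the thread: an unshard that happens outside of
    # forward and pre-backward registers the unsharded views as
    # nn.Parameters (`as_params=True`); otherwise it uses plain Tensors.
    in_forward = training_state == HandleTrainingState.FORWARD
    in_pre_backward = training_state == HandleTrainingState.BACKWARD_PRE
    return not in_forward and not in_pre_backward

# A prefetched pre-backward unshard left in IDLE would wrongly get
# `as_params=True`; emulating BACKWARD_PRE yields the correct `False`.
assert _as_params(HandleTrainingState.IDLE) is True
assert _as_params(HandleTrainingState.BACKWARD_PRE) is False
```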