
[FSDP] Use correct handle training state when prefetching #98249

Closed
wants to merge 6 commits

Conversation


awgu (Contributor) commented on Apr 3, 2023:

Stack from ghstack (oldest at bottom):

This PR ensures that when prefetching a `FlatParamHandle.unshard()`, we temporarily set the `FlatParamHandle._training_state` to the training state expected if the `unshard()` were not prefetched, since the `as_params` argument to `_use_unsharded_views()` depends on the handle's training state.
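
A minimal sketch of this pattern (the helper `_emulate_training_state` and its name are illustrative, not the actual PyTorch internals):

```python
import contextlib

@contextlib.contextmanager
def _emulate_training_state(handle, training_state):
    # Hypothetical helper: temporarily give the handle the training state it
    # would have if this unshard were not prefetched, then restore the old one.
    prev_state = handle._training_state
    handle._training_state = training_state
    try:
        yield
    finally:
        handle._training_state = prev_state
```

Inside the prefetch loop, the prefetched `unshard()` would then run under the emulated state, so `_use_unsharded_views()` sees the same state it would see in a non-prefetched unshard.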

pytorch-bot (bot) commented on Apr 3, 2023:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/98249

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 58bfb34:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot added the `release notes: distributed (fsdp)` label on Apr 3, 2023
awgu added a commit that referenced this pull request Apr 3, 2023
ghstack-source-id: 867bf2c20716db0f80ca3e0f0d373f32b7da2d5a
Pull Request resolved: #98249
rohan-varma (Member) left a comment:

Unittest?

@@ -1003,11 +1015,25 @@ def _prefetch_handles(
        return
    handles_to_prefetch = _get_handles_to_prefetch(state, current_handles_key)
    for handles_key in handles_to_prefetch:
        # Temporarily emulate the training state while calling `_unshard`
rohan-varma (Member):
what improvement does doing this provide if there's no functionality difference?

awgu (Contributor, Author) replied:
For `_use_unsharded_views(as_params=...)`, `as_params` depends on the handle's `_training_state`.

For example, suppose we are backward prefetching with `BACKWARD_PRE` and handle `h1` prefetches handle `h2`:

  • `h2._training_state` is `IDLE`.
  • `h1._training_state` is `BACKWARD_PRE`.
  • With this PR's change, we set `h2._training_state = BACKWARD_PRE` while we prefetch the unshard, so that `_use_unsharded_views()` correctly uses `as_params=False` instead of `True`.
  • Without this change, `_use_unsharded_views()` uses `as_params=True`, which is incorrect for reentrant checkpointing.

I need to investigate more, but I think our FSDP <> AC unit tests were too weak to catch this bug, which I introduced in #97981. Before #97981, we would simply override the prefetched unshard with the correct `_use_unsharded_views(as_params=...)`. After #97981, we skip that second, overriding `_use_unsharded_views(as_params=...)` call. With this PR, we still skip the second call, but that is no longer an issue: the first, prefetched `_use_unsharded_views()` now runs under the correct training state and hence uses the correct `as_params`.
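
For concreteness, here is a hedged reconstruction of that dependency. The enum mirrors FSDP's handle training states, but the selection logic is paraphrased from this discussion rather than copied from the FSDP source:

```python
from enum import Enum, auto

class HandleTrainingState(Enum):
    IDLE = auto()
    FORWARD = auto()
    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()

def _as_params_for_unshard(training_state: HandleTrainingState) -> bool:
    # During forward and pre-backward computation, the unsharded flat
    # parameter should be exposed through plain Tensor views
    # (as_params=False); in other states (e.g. IDLE), nn.Parameter views are
    # used (as_params=True). A prefetched unshard that runs while the target
    # handle is still IDLE therefore picks the wrong branch unless the
    # expected training state is emulated first.
    return training_state not in (
        HandleTrainingState.FORWARD,
        HandleTrainingState.BACKWARD_PRE,
    )
```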

awgu added a commit to awgu/pytorch that referenced this pull request Apr 3, 2023
ghstack-source-id: 5e5806e2a1a89b9c607f401c3427dacc0be7e6a7
Pull Request resolved: pytorch#98249
awgu added a commit that referenced this pull request Apr 4, 2023
ghstack-source-id: 99da08f26f09bb65fc27ecfbfc24281bb2f8186d
Pull Request resolved: #98249
awgu added a commit that referenced this pull request Apr 4, 2023
ghstack-source-id: 4d96d46f45b20f7b3da0280ebfce79217670e1d8
Pull Request resolved: #98249
awgu added a commit that referenced this pull request Apr 4, 2023
ghstack-source-id: 2f0081349be96d0c4467cb2c78e94895d00b0510
Pull Request resolved: #98249
awgu marked this pull request as ready for review on April 4, 2023 at 10:57
awgu added the `topic: improvements`, `topic: bug fixes`, and `ciflow/trunk` labels on Apr 4, 2023
awgu (Contributor, Author) commented on Apr 4, 2023:

@pytorchbot merge

pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Labels: `ciflow/trunk` · `Merged` · `release notes: distributed (fsdp)` · `topic: bug fixes` · `topic: improvements`

3 participants