[FSDP] `summon_full_params()` in computation stream #86836
Conversation
[ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/86836. Note: Links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit c777cde. This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: 2bc70f8 Pull Request resolved: pytorch#86836
ghstack-source-id: 9463308 Pull Request resolved: pytorch/pytorch#86836
LGTM, but curious why we need this change.
free_unsharded_flat_params = [handle.needs_unshard() for handle in self._handles]
self._unshard(self._handles)
self._streams["computation"].wait_stream(self._streams["unshard"])
# No need to call `wait_stream()` since we unshard in the computation stream
Curious, why would we want to move this to the computation stream?
This allows us to use caching allocator blocks from the computation stream for these all-gathers, which should help avoid over-allocating blocks to the unshard stream.
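For context, here is a minimal sketch of the two patterns being compared. It is not the FSDP internals: the helper names and the `repeat()` stand-in for the real all-gather are invented for illustration, and it assumes a CUDA device is available. The point is that caching-allocator blocks are associated with the stream they were allocated on, so outputs allocated on a side "unshard" stream cannot be freely reused by the computation stream, and each stream grows its own pool.

```python
import torch

assert torch.cuda.is_available(), "this sketch assumes a CUDA device"

computation = torch.cuda.current_stream()
unshard = torch.cuda.Stream()

def gather_on_unshard_stream(shard: torch.Tensor) -> torch.Tensor:
    # Roughly the old pattern: run the "all-gather" on the unshard stream and
    # make the computation stream wait on it afterward. The output block is
    # allocated in the unshard stream's pool.
    unshard.wait_stream(computation)   # shard was produced on the computation stream
    with torch.cuda.stream(unshard):
        full = shard.repeat(4)         # stand-in for the all-gather output
    computation.wait_stream(unshard)
    full.record_stream(computation)    # tensor is consumed on a different stream
    return full

def gather_on_computation_stream(shard: torch.Tensor) -> torch.Tensor:
    # Roughly the new pattern: run the "all-gather" directly on the computation
    # stream, so blocks already cached for that stream can be reused and no
    # cross-stream wait is needed.
    return shard.repeat(4)

shard = torch.randn(1 << 20, device="cuda")
out_old = gather_on_unshard_stream(shard)
out_new = gather_on_computation_stream(shard)
torch.cuda.synchronize()
```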
This should help with memory usage. [ghstack-poisoned]
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Hey @awgu.
This should help with memory usage. In particular, this allows FSDP to use caching allocator blocks from the computation stream for the `summon_full_params()` all-gathers, which should help avoid over-allocating blocks to the unshard stream. Pull Request resolved: pytorch#86836 Approved by: https://github.com/rohan-varma
Stack from ghstack:
- #86836 [FSDP] `summon_full_params()` in computation stream
- #87308 [FSDP][2/N] Fix grad zero vs. `None` edge case
- #87314 [FSDP][1/N] Update `summon_full_params(with_grads)` `None` gradient

This should help with memory usage. In particular, this allows FSDP to use caching allocator blocks from the computation stream for the `summon_full_params()` all-gathers, which should help avoid over-allocating blocks to the unshard stream.
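For readers less familiar with the API in question, a hedged usage sketch of the context manager this PR touches. The inspection helper is invented; it assumes the default process group is already initialized and that `model` is FSDP-wrapped.

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def inspect_full_params(model: FSDP) -> None:
    # Inside the context, the sharded parameters are all-gathered so each rank
    # sees the full (unsharded) parameters; with this PR, those all-gathers run
    # in the computation stream rather than a separate unshard stream.
    with FSDP.summon_full_params(model, writeback=False):
        for name, param in model.named_parameters():
            print(name, tuple(param.shape))
```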