[FSDP][Easy] Rename streams; add back stream sharing test #104966
Conversation
[ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/104966
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit b57126f.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -283,17 +283,10 @@ def _share_state_and_init_handle_attrs(
        "set yet or should have been set to `False`",
    )
    fsdp_state._is_root = False
    # Stream for unshard logic, including allocating the all-gather destination
We do not need to duplicate these comments from the initial construction in _init_streams(), since the comments may become stale and here we are only sharing the state.
pytorch/torch/distributed/fsdp/_runtime_utils.py
Lines 305 to 324 in b57126f
def _init_streams(
    state: _FSDPState,
) -> _FSDPState:
    """
    Initializes CUDA streams for overlapping communication, computation, and
    data transfers. The streams should be shared across FSDP instances.
    """
    assert state._is_root
    assert state._device_handle.is_available()
    # Stream for unshard logic, including allocating the all-gather destination
    # tensors and the all-gathers themselves.
    state._unshard_stream = state._device_handle.Stream()
    # Stream for overlapping gradient reduction with the backward pass gradient
    # computation.
    state._post_backward_stream = state._device_handle.Stream()
    # Stream for pre-unshard logic, namely allocations and writes for CPU
    # offloading (H2D copy) and mixed precision (low precision cast).
    state._pre_unshard_stream = state._device_handle.Stream()
    # Default stream for computation
    state._default_stream = state._device_handle.current_stream()
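For context, a minimal sketch of what the sharing amounts to (the helper name and loop below are illustrative, not the exact code in _share_state_and_init_handle_attrs): each non-root state simply aliases the root state's stream objects, so the per-stream comments only need to live in _init_streams().

def _share_streams_with_root(root_state, fsdp_states):
    # Illustrative only: point every non-root FSDP state at the root's
    # streams so all instances enqueue work on the same CUDA streams.
    for fsdp_state in fsdp_states:
        if fsdp_state is root_state:
            continue
        fsdp_state._unshard_stream = root_state._unshard_stream
        fsdp_state._post_backward_stream = root_state._post_backward_stream
        fsdp_state._pre_unshard_stream = root_state._pre_unshard_stream
        fsdp_state._default_stream = root_state._default_stream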
@@ -276,7 +276,14 @@ def _test_nested_fully_shard_shared_state(self, use_policy: bool):
    # NOTE: This check only requires that the data structure state is
    # shared. Namely, sharing the FSDP state object itself is sufficient
    # but not necessary.
We used to test that state._streams was shared. This adds back testing that each stream is shared.
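A rough sketch of what such a check can look like (attribute names follow this PR's renaming; the helper below is hypothetical, not the actual test code):

def assert_streams_shared(fsdp_states):
    # Hypothetical check: every stream attribute must be the same object
    # across all FSDP states, i.e. shared rather than merely equal.
    root = fsdp_states[0]
    stream_names = (
        "_unshard_stream",
        "_post_backward_stream",
        "_pre_unshard_stream",
        "_default_stream",
    )
    for other in fsdp_states[1:]:
        for name in stream_names:
            assert getattr(other, name) is getattr(root, name)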
ghstack-source-id: 623474b088ab2d5ade75a1f952ae666d7e8e1f04
Pull Request resolved: pytorch#104966
LGTM
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):

- CustomPolicy #104986
- _FSDPPolicy.policy with _Policy._run_policy #104969
- ModuleWrapPolicy to take Iterable #104999
- ModuleWrapPolicy #104427

Purely out of preference, this PR renames the streams to _unshard_stream instead of _streams_unshard, etc., since the former reads more naturally. The PR also removes some duplicated comments and adds back a unit test that the streams are shared.
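As a small illustration of the new naming (only the _streams_unshard to _unshard_stream pair is spelled out above; the other old names are assumed to follow the same _streams_<x> pattern):

import torch

def run_unshard_on_stream(state, unshard_fn):
    # Hypothetical usage with the renamed attribute; previously this stream
    # was reached as `state._streams_unshard` (other old names assumed to
    # follow the `_streams_<x>` pattern).
    with torch.cuda.stream(state._unshard_stream):
        unshard_fn()  # e.g., allocate all-gather buffers and launch all-gathers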