[dynamo x fsdp] Simplify stream logic handling #103902

Closed
@@ -267,7 +267,7 @@ def test_nested_fully_shard_shared_state(self):
# NOTE: This check only requires that the data structure state is
# shared. Namely, sharing the FSDP state object itself is sufficient
# but not necessary.
-data_structure_names = ["_streams", "_exec_order_data", "_free_event_queue"]
+data_structure_names = ["_exec_order_data", "_free_event_queue"]
for data_structure_name in data_structure_names:
all_structures = set()
for module in (
torch/distributed/fsdp/_init_utils.py (4 changes: 0 additions & 4 deletions)
@@ -367,10 +367,6 @@ def _init_core_state(
state._use_orig_params = use_orig_params
state.training_state = TrainingState.IDLE
state._is_root = None
-_streams: Dict[str, torch.cuda.Stream] = {}
-state._streams = _streams
-_stream_to_name: Dict[torch.cuda.Stream, str] = {}
-state._stream_to_name = _stream_to_name
state._free_event_queue = _FreeEventQueue()
state._debug_level = dist.get_debug_level()
state._exec_order_data = exec_order_utils._ExecOrderData(
torch/distributed/fsdp/_runtime_utils.py (59 changes: 28 additions & 31 deletions)
@@ -285,8 +285,17 @@ def _share_state_and_init_handle_attrs(
"set yet or should have been set to `False`",
)
fsdp_state._is_root = False
-fsdp_state._streams = root_state._streams
-fsdp_state._stream_to_name = root_state._stream_to_name
+# Stream for unshard logic, including allocating the all-gather destination
+# tensors and the all-gathers themselves.
+fsdp_state._streams_unshard = root_state._streams_unshard
+# Stream for overlapping gradient reduction with the backward pass gradient
+# computation.
+fsdp_state._streams_post_backward = root_state._streams_post_backward
+# Stream for pre-unshard logic, namely allocations and writes for CPU
+# offloading (H2D copy) and mixed precision (low precision cast).
+fsdp_state._streams_pre_unshard = root_state._streams_pre_unshard
+# Default stream for computation
+fsdp_state._streams_default = root_state._streams_default
fsdp_state._exec_order_data = root_state._exec_order_data
fsdp_state._free_event_queue = root_state._free_event_queue
fsdp_state._handles_prefetched = root_state._handles_prefetched
@@ -313,21 +322,15 @@ def _init_streams(
assert state._device_handle.is_available()
# Stream for unshard logic, including allocating the all-gather destination
# tensors and the all-gathers themselves.
-state._streams["unshard"] = state._device_handle.Stream()
+state._streams_unshard = state._device_handle.Stream()
# Stream for overlapping gradient reduction with the backward pass gradient
# computation.
-state._streams["post_backward"] = state._device_handle.Stream()
+state._streams_post_backward = state._device_handle.Stream()
# Stream for pre-unshard logic, namely allocations and writes for CPU
# offloading (H2D copy) and mixed precision (low precision cast).
-state._streams["pre_unshard"] = state._device_handle.Stream()
+state._streams_pre_unshard = state._device_handle.Stream()
# Default stream for computation
-state._streams["default"] = state._device_handle.current_stream()
-state._stream_to_name = {
-    state._device_handle.current_stream(): "default",
-    state._streams["unshard"]: "unshard",
-    state._streams["pre_unshard"]: "pre_unshard",
-    state._streams["post_backward"]: "post_backward",
-}
+state._streams_default = state._device_handle.current_stream()


@no_type_check
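Note for readers: the hunk above only renames the entries of the former `_streams` dict to dedicated attributes; the four streams and their roles are unchanged. Below is a minimal, self-contained sketch of the side-stream idiom these attributes encode, written against plain `torch.cuda` with illustrative names (not the FSDP internals).

```python
# Sketch of the unshard side-stream idiom (illustrative only; assumes a CUDA device).
import torch

if torch.cuda.is_available():
    unshard_stream = torch.cuda.Stream()          # analogous to _streams_unshard
    default_stream = torch.cuda.current_stream()  # analogous to _streams_default

    x = torch.randn(1024, device="cuda")

    # The side stream must not read `x` before the default stream has produced it.
    unshard_stream.wait_stream(default_stream)
    with torch.cuda.stream(unshard_stream):
        gathered = x * 2  # stand-in for allocating and running the all-gather

    # Computation must not consume `gathered` before the unshard work is done,
    # mirroring current_stream().wait_stream(state._streams_unshard).
    default_stream.wait_stream(unshard_stream)
    # `gathered` was allocated on unshard_stream but is consumed on the default
    # stream, so inform the caching allocator (mirroring _no_dispatch_record_stream).
    gathered.record_stream(default_stream)
    y = gathered + 1
```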
@@ -474,11 +477,9 @@ def _pre_forward_unshard(
# If the handles have been prefetched, then there is no need to call
# `_unshard()` again
if not state._handles_prefetched.get(handles_key, False):
-_unshard(
-    state, handles, state._streams["unshard"], state._streams["pre_unshard"]
-)
+_unshard(state, handles, state._streams_unshard, state._streams_pre_unshard)
state._needs_pre_forward_unshard[handles_key] = False
-state._device_handle.current_stream().wait_stream(state._streams["unshard"])
+state._device_handle.current_stream().wait_stream(state._streams_unshard)
_prefetch_handles(state, handles_key, _PrefetchMode.FORWARD)


@@ -624,8 +625,8 @@ def _root_pre_forward(
state._needs_pre_forward_unshard[handles_key] = True
_wait_for_computation_stream(
state._device_handle.current_stream(),
-state._streams["unshard"],
-state._streams["pre_unshard"],
+state._streams_unshard,
+state._streams_pre_unshard,
)
_clear_grads_if_needed(state._all_handles)
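The same renaming applies here: `_root_pre_forward` still makes the unshard and pre-unshard streams wait for the computation stream via `_wait_for_computation_stream` before the next round of collectives. A hedged sketch of what that synchronization presumably amounts to with the renamed attributes (the helper body and names below are illustrative, not copied from the PR):

```python
# Hedged sketch of the synchronization performed before new all-gathers start
# (illustrative helper; not the actual _wait_for_computation_stream body).
import torch

def wait_for_computation_stream(computation_stream, unshard_stream, pre_unshard_stream):
    # The next iteration's all-gathers and pre-unshard copies/casts should not
    # begin until the current computation on the default stream has finished.
    unshard_stream.wait_stream(computation_stream)
    pre_unshard_stream.wait_stream(computation_stream)

if torch.cuda.is_available():
    wait_for_computation_stream(
        torch.cuda.current_stream(),  # analogous to state._streams_default
        torch.cuda.Stream(),          # analogous to state._streams_unshard
        torch.cuda.Stream(),          # analogous to state._streams_pre_unshard
    )
```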

@@ -704,10 +705,10 @@ def _pre_backward_hook(
_unshard(
state,
_handles,
-state._streams["unshard"],
-state._streams["pre_unshard"],
+state._streams_unshard,
+state._streams_pre_unshard,
)
-state._device_handle.current_stream().wait_stream(state._streams["unshard"])
+state._device_handle.current_stream().wait_stream(state._streams_unshard)

# Set this to `False` to ensure that a mistargeted prefetch does not
# actually unshard these handles
@@ -780,11 +781,9 @@ def _post_backward_hook(

# Wait for all ops in the current stream (e.g. gradient
# computation) to finish before reduce-scattering the gradient
-state._streams["post_backward"].wait_stream(
-    state._device_handle.current_stream()
-)
+state._streams_post_backward.wait_stream(state._device_handle.current_stream())

-with state._device_handle.stream(state._streams["post_backward"]):
+with state._device_handle.stream(state._streams_post_backward):
autograd_computed_grad = flat_param.grad.data
if state._exec_order_data.is_first_iter: # only check once
_check_comm_hook(
@@ -867,14 +866,14 @@ def _post_backward_hook(
# post-backward stream, inform the caching allocator
_no_dispatch_record_stream(
grad_to_offload.data,
-state._streams["post_backward"],
+state._streams_post_backward,
)

# Since the unsharded gradient is produced in the computation
# stream and consumed in the post-backward stream, inform the
# caching allocator (before it goes out of scope)
_no_dispatch_record_stream(
-autograd_computed_grad, state._streams["post_backward"]
+autograd_computed_grad, state._streams_post_backward
)

if handle._use_orig_params:
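The post-backward hunks keep the existing producer/consumer pattern and only swap `state._streams["post_backward"]` for `state._streams_post_backward`: the reduction stream waits for the gradient computed on the default stream, consumes it, and the caching allocator is informed via `_no_dispatch_record_stream`. A minimal stand-alone sketch of that pattern (plain `torch.cuda`, illustrative names, not the FSDP internals):

```python
# Sketch of the post-backward producer/consumer stream pattern (illustrative only).
import torch

if torch.cuda.is_available():
    post_backward_stream = torch.cuda.Stream()  # analogous to _streams_post_backward

    # Gradient produced by backward-pass computation on the current (default) stream.
    grad = torch.randn(1024, device="cuda")

    # Wait for the gradient computation to finish before reducing it.
    post_backward_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(post_backward_stream):
        reduced = grad / 2  # stand-in for the reduce-scatter / comm hook

    # `grad` lives on the default stream but is consumed on the post-backward
    # stream, so inform the caching allocator (mirroring _no_dispatch_record_stream).
    grad.record_stream(post_backward_stream)
```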
@@ -1003,7 +1002,7 @@ def _post_backward_final_callback(

if root_state._sync_gradients:
state._device_handle.current_stream().wait_stream(
-root_state._streams["post_backward"]
+root_state._streams_post_backward
)
if root_state.cpu_offload.offload_params:
# Wait for non-blocking GPU -> CPU sharded gradient copies from the
@@ -1123,9 +1122,7 @@ def _prefetch_handles(
)
# Prefetch the next set of handles without synchronizing to allow
# the sync to happen as late as possible to maximize overlap
-_unshard(
-    state, handles_key, state._streams["unshard"], state._streams["pre_unshard"]
-)
+_unshard(state, handles_key, state._streams_unshard, state._streams_pre_unshard)
for handle, prev_training_state in zip(handles_key, prev_training_states):
handle._training_state = prev_training_state
state._handles_prefetched[handles_key] = True