[FSDP] Use correct handle training state when prefetching
ghstack-source-id: 867bf2c20716db0f80ca3e0f0d373f32b7da2d5a
Pull Request resolved: #98249
awgu committed Apr 3, 2023
1 parent 1b6720d commit 4107ce7
Showing 1 changed file with 30 additions and 4 deletions.
34 changes: 30 additions & 4 deletions torch/distributed/fsdp/_runtime_utils.py
@@ -1,4 +1,5 @@
 import functools
+from enum import auto, Enum
 from typing import (
     Any,
     Callable,
@@ -49,6 +50,11 @@
 )


+class _PrefetchMode(Enum):
+    BACKWARD = auto()
+    FORWARD = auto()
+
+
 def _get_fsdp_root_states_with_modules(
     module: nn.Module,
 ) -> Tuple[List[_FSDPState], List[nn.Module]]:
@@ -428,11 +434,16 @@ def _pre_forward_unshard(
     """Unshards parameters in the pre-forward."""
     if not handles:
         return
-    _unshard(state, handles, state._streams["unshard"], state._streams["pre_unshard"])
     handles_key = tuple(handles)
+    # If the handles have been prefetched, then there is no need to call
+    # `_unshard()` again
+    if not state._handles_prefetched.get(handles_key, False):
+        _unshard(
+            state, handles, state._streams["unshard"], state._streams["pre_unshard"]
+        )
     state._needs_pre_forward_unshard[handles_key] = False
     torch.cuda.current_stream().wait_stream(state._streams["unshard"])
-    _prefetch_handles(state, handles_key)
+    _prefetch_handles(state, handles_key, _PrefetchMode.FORWARD)


 @no_type_check
@@ -639,7 +650,7 @@ def _pre_backward_hook(
         # Set this to `False` to ensure that a mistargeted prefetch does not
         # actually unshard these handles
         state._needs_pre_backward_unshard[_handles_key] = False
-        _prefetch_handles(state, _handles_key)
+        _prefetch_handles(state, _handles_key, _PrefetchMode.BACKWARD)
         for handle in _handles:
             handle.prepare_gradient_for_backward()
         state._ran_pre_backward_hook[_handles_key] = True
@@ -693,7 +704,7 @@ def _post_backward_hook(
         # per module case since the post-backward hook runs per handle, not per
         # group of handles.
         handles_key = (handle,)
-        _prefetch_handles(state, handles_key)
+        _prefetch_handles(state, handles_key, _PrefetchMode.BACKWARD)

         if not state._sync_gradients:
             if handle._use_orig_params:
@@ -994,6 +1005,7 @@ def _finalize_params(
 def _prefetch_handles(
     state: _FSDPState,
     current_handles_key: _HandlesKey,
+    prefetch_mode: _PrefetchMode,
 ) -> None:
     """
     Prefetches the next handles if needed (without synchronization). An empty
@@ -1003,11 +1015,25 @@ def _prefetch_handles(
         return
     handles_to_prefetch = _get_handles_to_prefetch(state, current_handles_key)
     for handles_key in handles_to_prefetch:
+        # Temporarily emulate the training state while calling `_unshard`
+        prev_training_states: List[HandleTrainingState] = []
+        for handle in handles_key:
+            prev_training_states.append(handle._training_state)
+            if prefetch_mode == _PrefetchMode.BACKWARD:
+                handle._training_state = HandleTrainingState.BACKWARD_PRE
+            elif prefetch_mode == _PrefetchMode.FORWARD:
+                handle._training_state = HandleTrainingState.FORWARD
+            else:
+                raise ValueError(
+                    f"Invalid prefetch mode on rank {state.rank}: {prefetch_mode}"
+                )
         # Prefetch the next set of handles without synchronizing to allow
         # the sync to happen as late as possible to maximize overlap
         _unshard(
             state, handles_key, state._streams["unshard"], state._streams["pre_unshard"]
         )
+        for handle, prev_training_state in zip(handles_key, prev_training_states):
+            handle._training_state = prev_training_state
         state._handles_prefetched[handles_key] = True

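Note on the change, with an illustrative sketch: the core of this commit is a save/override/restore pattern in `_prefetch_handles` — each prefetched handle's `_training_state` is temporarily set to match the pass being prefetched for (FORWARD or BACKWARD_PRE) while `_unshard` runs, then restored. The standalone sketch below shows only that pattern; `Handle`, `TrainingState`, `PrefetchMode`, `unshard`, and `prefetch` are placeholder names for illustration, not the FSDP internals.

from enum import Enum, auto
from typing import List


class TrainingState(Enum):
    IDLE = auto()
    FORWARD = auto()
    BACKWARD_PRE = auto()


class PrefetchMode(Enum):
    FORWARD = auto()
    BACKWARD = auto()


class Handle:
    def __init__(self) -> None:
        self.training_state = TrainingState.IDLE


def unshard(handles: List[Handle]) -> None:
    # Stand-in for the real all-gather; the actual `_unshard` consults each
    # handle's training state when it rebuilds unsharded parameter views.
    pass


def prefetch(handles: List[Handle], mode: PrefetchMode) -> None:
    # Save the current training states so they can be restored afterward.
    prev_states = [h.training_state for h in handles]
    target = (
        TrainingState.BACKWARD_PRE
        if mode == PrefetchMode.BACKWARD
        else TrainingState.FORWARD
    )
    for h in handles:
        h.training_state = target
    try:
        # Unshard while each handle reports the state of the pass being
        # prefetched for, so any state-dependent logic behaves as if this
        # were the real pre-forward / pre-backward unshard.
        unshard(handles)
    finally:
        # Restore the saved training states.
        for h, prev in zip(handles, prev_states):
            h.training_state = prev

Unlike this sketch, the committed code restores the saved states unconditionally after `_unshard` returns rather than via `try`/`finally`.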
