[FSDP] Use correct handle training state when prefetching #98249

Closed
wants to merge 6 commits
34 changes: 30 additions & 4 deletions torch/distributed/fsdp/_runtime_utils.py
@@ -1,4 +1,5 @@
import functools
from enum import auto, Enum
from typing import (
Any,
Callable,
@@ -49,6 +50,11 @@
)


class _PrefetchMode(Enum):
BACKWARD = auto()
FORWARD = auto()


def _get_fsdp_root_states_with_modules(
module: nn.Module,
) -> Tuple[List[_FSDPState], List[nn.Module]]:
@@ -428,11 +434,16 @@ def _pre_forward_unshard(
"""Unshards parameters in the pre-forward."""
if not handles:
return
_unshard(state, handles, state._streams["unshard"], state._streams["pre_unshard"])
handles_key = tuple(handles)
# If the handles have been prefetched, then there is no need to call
# `_unshard()` again
if not state._handles_prefetched.get(handles_key, False):
_unshard(
state, handles, state._streams["unshard"], state._streams["pre_unshard"]
)
state._needs_pre_forward_unshard[handles_key] = False
torch.cuda.current_stream().wait_stream(state._streams["unshard"])
_prefetch_handles(state, handles_key)
_prefetch_handles(state, handles_key, _PrefetchMode.FORWARD)

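For readers following the hunk above outside the diff view: here is a minimal, self-contained sketch of the guard that `_pre_forward_unshard` now applies, skipping the blocking unshard when a prior prefetch already issued it. This is a toy model, not the real FSDP code; `prefetched` and `pre_forward_unshard` are hypothetical stand-ins for `state._handles_prefetched` and the real pre-forward hook.

```python
# Toy sketch of the new pre-forward guard; `prefetched` stands in for
# state._handles_prefetched and the print stands in for _unshard().
from typing import Dict, Tuple

prefetched: Dict[Tuple[str, ...], bool] = {}

def pre_forward_unshard(handles_key: Tuple[str, ...]) -> None:
    # Only issue the unshard if a prefetch has not already done so.
    if not prefetched.get(handles_key, False):
        print(f"{handles_key}: issuing unshard")
    else:
        print(f"{handles_key}: already prefetched, skipping unshard")
    # In the real code, the current stream then waits on the unshard stream
    # and forward prefetching is kicked off for the next handles.

prefetched[("h2",)] = True
pre_forward_unshard(("h1",))  # issues unshard
pre_forward_unshard(("h2",))  # skips; the prefetch already unsharded
```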

@no_type_check
@@ -639,7 +650,7 @@ def _pre_backward_hook(
# Set this to `False` to ensure that a mistargeted prefetch does not
# actually unshard these handles
state._needs_pre_backward_unshard[_handles_key] = False
_prefetch_handles(state, _handles_key)
_prefetch_handles(state, _handles_key, _PrefetchMode.BACKWARD)
for handle in _handles:
handle.prepare_gradient_for_backward()
state._ran_pre_backward_hook[_handles_key] = True
@@ -693,7 +704,7 @@ def _post_backward_hook(
# per module case since the post-backward hook runs per handle, not per
# group of handles.
handles_key = (handle,)
_prefetch_handles(state, handles_key)
_prefetch_handles(state, handles_key, _PrefetchMode.BACKWARD)

if not state._sync_gradients:
if handle._use_orig_params:
@@ -994,6 +1005,7 @@ def _finalize_params(
def _prefetch_handles(
state: _FSDPState,
current_handles_key: _HandlesKey,
prefetch_mode: _PrefetchMode,
) -> None:
"""
Prefetches the next handles if needed (without synchronization). An empty
@@ -1003,11 +1015,25 @@ def _prefetch_handles(
return
handles_to_prefetch = _get_handles_to_prefetch(state, current_handles_key)
for handles_key in handles_to_prefetch:
# Temporarily emulate the training state while calling `_unshard`
Member:
what improvement does doing this provide if there's no functionality difference?

Contributor Author:
For `_use_unsharded_views(as_params=...)`, `as_params` depends on the handle's `_training_state`.

For example, suppose we are backward prefetching with `BACKWARD_PRE` and handle h1 prefetches handle h2.

  • h2._training_state is IDLE.
  • h1._training_state is BACKWARD_PRE.
  • With this PR's change, we set h2._training_state = BACKWARD_PRE while we prefetch the unshard, so that `_use_unsharded_views()` correctly uses as_params=False instead of True.
  • Without this change, `_use_unsharded_views()` would use as_params=True, which is incorrect for reentrant checkpointing.

I need to investigate more, but I think our FSDP <> AC unit tests were too weak to catch this bug, which I introduced in #97981. Before #97981, we would simply override the prefetched unshard with the correct `_use_unsharded_views(as_params=...)`. After #97981, we skip that second, overriding `_use_unsharded_views(as_params=...)` call. With this PR, we still skip the second `_use_unsharded_views()`, but that is no longer an issue, since the first, prefetched `_use_unsharded_views()` now uses the correct training state and hence the correct `as_params`.

prev_training_states: List[HandleTrainingState] = []
for handle in handles_key:
prev_training_states.append(handle._training_state)
if prefetch_mode == _PrefetchMode.BACKWARD:
handle._training_state = HandleTrainingState.BACKWARD_PRE
elif prefetch_mode == _PrefetchMode.FORWARD:
handle._training_state = HandleTrainingState.FORWARD
else:
raise ValueError(
f"Invalid prefetch mode on rank {state.rank}: {prefetch_mode}"
)
# Prefetch the next set of handles without synchronizing to allow
# the sync to happen as late as possible to maximize overlap
_unshard(
state, handles_key, state._streams["unshard"], state._streams["pre_unshard"]
)
for handle, prev_training_state in zip(handles_key, prev_training_states):
handle._training_state = prev_training_state
state._handles_prefetched[handles_key] = True
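To make the reviewer exchange above concrete, here is a minimal, self-contained sketch under the simplifying assumption (taken from the author's reply) that view setup keys off the handle's training state. `ToyHandle`, `unshard`, and `prefetch` are hypothetical stand-ins, not FSDP's actual `FlatParamHandle`, `_unshard`, or `_use_unsharded_views` APIs.

```python
# Toy model (NOT the real FSDP internals) of the change in _prefetch_handles:
# temporarily borrow the prefetch mode's training state so the prefetched
# unshard sets up its views with the right `as_params`.
from enum import Enum, auto
from typing import List, Tuple


class HandleTrainingState(Enum):
    IDLE = auto()
    FORWARD = auto()
    BACKWARD_PRE = auto()


class _PrefetchMode(Enum):
    BACKWARD = auto()
    FORWARD = auto()


class ToyHandle:
    """Hypothetical stand-in for FSDP's FlatParamHandle."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._training_state = HandleTrainingState.IDLE

    def unshard(self) -> None:
        # Toy rule mirroring the comment thread: only a handle that looks like
        # it is in forward/backward gets as_params=False; an IDLE handle would
        # get as_params=True.
        as_params = self._training_state == HandleTrainingState.IDLE
        print(f"{self.name}: unshard views with as_params={as_params}")


def prefetch(handles_key: Tuple[ToyHandle, ...], mode: _PrefetchMode) -> None:
    # Temporarily emulate the training state while unsharding, then restore it.
    prev_states: List[HandleTrainingState] = [
        h._training_state for h in handles_key
    ]
    for h in handles_key:
        if mode == _PrefetchMode.BACKWARD:
            h._training_state = HandleTrainingState.BACKWARD_PRE
        elif mode == _PrefetchMode.FORWARD:
            h._training_state = HandleTrainingState.FORWARD
        else:
            raise ValueError(f"Invalid prefetch mode: {mode}")
    for h in handles_key:
        h.unshard()  # sees the emulated state, so as_params=False
    for h, prev in zip(handles_key, prev_states):
        h._training_state = prev  # restore, e.g. back to IDLE


# h1 (in BACKWARD_PRE) prefetches h2, which is still IDLE:
h2 = ToyHandle("h2")
prefetch((h2,), _PrefetchMode.BACKWARD)  # prints as_params=False
print(h2._training_state)                # HandleTrainingState.IDLE
```

Without the temporary state swap, the prefetched h2 would run its view setup while still IDLE (as_params=True), which is the case the PR fixes for backward prefetching under reentrant checkpointing.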

