[FSDP] Use correct handle training state when prefetching (#98249)
This PR ensures that when prefetching a `FlatParamHandle.unshard()`, we temporarily set `FlatParamHandle._training_state` to the training state it would have if the `unshard()` were not prefetched, since the `as_params` argument to `_use_unsharded_views()` depends on the handle's training state.
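
For illustration, below is a minimal, self-contained sketch of the save/override/restore pattern this change applies around a prefetched unshard. The classes here (`Handle`, `HandleTrainingState`, `_PrefetchMode`) are simplified stand-ins for the FSDP internals, and the context manager exists only for exposition; the actual change saves and restores the previous training states inline in `_prefetch_handles()` around the `_unshard()` call.

```python
from contextlib import contextmanager
from enum import auto, Enum
from typing import Iterator, List


class HandleTrainingState(Enum):
    # Simplified stand-in for FSDP's handle training states
    IDLE = auto()
    FORWARD = auto()
    BACKWARD_PRE = auto()


class _PrefetchMode(Enum):
    BACKWARD = auto()
    FORWARD = auto()


class Handle:
    """Simplified stand-in for ``FlatParamHandle``."""

    def __init__(self) -> None:
        self._training_state = HandleTrainingState.IDLE

    def unshard(self) -> None:
        # In FSDP, the unshard path chooses ``as_params`` for
        # ``_use_unsharded_views()`` based on ``_training_state``, which is
        # why a prefetched unshard must see the "right" state.
        print(f"unsharding with training state {self._training_state.name}")


@contextmanager
def _emulate_training_state(
    handles: List[Handle], prefetch_mode: _PrefetchMode
) -> Iterator[None]:
    """Temporarily sets each handle's training state to what it would be if
    the unshard were not prefetched, then restores the previous states."""
    prev_training_states = [handle._training_state for handle in handles]
    target = (
        HandleTrainingState.BACKWARD_PRE
        if prefetch_mode == _PrefetchMode.BACKWARD
        else HandleTrainingState.FORWARD
    )
    for handle in handles:
        handle._training_state = target
    try:
        yield
    finally:
        for handle, prev in zip(handles, prev_training_states):
            handle._training_state = prev


if __name__ == "__main__":
    handles = [Handle(), Handle()]
    with _emulate_training_state(handles, _PrefetchMode.BACKWARD):
        for handle in handles:
            handle.unshard()  # sees BACKWARD_PRE, like a non-prefetched unshard
    assert all(h._training_state is HandleTrainingState.IDLE for h in handles)
```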

Pull Request resolved: #98249
Approved by: https://github.com/rohan-varma
awgu authored and pytorchmergebot committed Apr 4, 2023
1 parent 950431c commit 0b31f87
Showing 2 changed files with 118 additions and 5 deletions.
88 changes: 87 additions & 1 deletion test/distributed/fsdp/test_fsdp_core.py
@@ -1,19 +1,24 @@
# Owner(s): ["oncall: distributed"]

import contextlib
import functools
import itertools
import sys
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional
from unittest import mock

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import CPUOffload, MixedPrecision
from torch.distributed.fsdp.flat_param import FlatParamHandle
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    BackwardPrefetch,
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.distributed.utils import _p_assert
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    AlwaysWrapNestedWrappedModule,
@@ -415,6 +420,87 @@ def test_transformer_no_grad(self, mixed_precision):
        self.assertEqual(ref_output, no_grad_output)


class TestAutograd(FSDPTest):
    @skip_if_lt_x_gpu(2)
    def test_unshard_params_as_tensors(
        self,
    ):
        """
        Tests that FSDP always unshards the logical parameters as ``Tensor``
        views during forward and backward computation even when forward and/or
        backward prefetching.
        """
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP
                    # Skip testing `NO_SHARD` since it doubly uses
                    # `_use_unsharded_views()` for sharded views. Testing
                    # `FULL_SHARD` and `SHARD_GRAD_OP` provides good confidence
                    # that the `as_params` logic is correct.
                ],
                "use_orig_params": [False, True],
                "forward_prefetch": [False, True],
                "backward_prefetch": [
                    BackwardPrefetch.BACKWARD_PRE,
                    BackwardPrefetch.BACKWARD_POST,
                    None,
                ],
            },
            self._test_unshard_params_as_tensors,
        )

    def _test_unshard_params_as_tensors(
        self,
        sharding_strategy: ShardingStrategy,
        use_orig_params: bool,
        forward_prefetch: bool,
        backward_prefetch: Optional[BackwardPrefetch],
    ):
        orig_use_unsharded_views = FlatParamHandle._use_unsharded_views

        def _use_unsharded_views_assert_as_tensors(
            self: FlatParamHandle, as_params: bool
        ) -> None:
            _p_assert(
                not as_params, "Expects to use Tensor views but using parameter views"
            )
            return orig_use_unsharded_views(self, as_params)

        fsdp_kwargs = {
            "sharding_strategy": sharding_strategy,
            "use_orig_params": use_orig_params,
            "forward_prefetch": forward_prefetch,
            "backward_prefetch": backward_prefetch,
            "auto_wrap_policy": ModuleWrapPolicy({nn.Linear}),
        }
        device = torch.device("cuda")
        # Define a model with enough FSDP instances to exercise prefetching
        NUM_LINEARS = 5
        model = nn.Sequential(
            *[nn.Linear(3, 3, device=device) for _ in range(NUM_LINEARS)]
        )
        fsdp_model = FSDP(model, **fsdp_kwargs)
        self.assertEqual(len(list(FSDP.fsdp_modules(fsdp_model))), NUM_LINEARS + 1)
        for _ in range(3):
            inp = torch.randn((2, 3), device=device)
            with self._patch_use_unsharded_views(
                _use_unsharded_views_assert_as_tensors
            ):
                loss = fsdp_model(inp).sum()
                loss.backward()

    @contextlib.contextmanager
    def _patch_use_unsharded_views(self, new_use_unsharded_views: Callable):
        orig_use_unsharded_views = FlatParamHandle._use_unsharded_views
        FlatParamHandle._use_unsharded_views = new_use_unsharded_views
        try:
            yield
        finally:
            FlatParamHandle._use_unsharded_views = orig_use_unsharded_views


instantiate_parametrized_tests(TestHooks)
instantiate_parametrized_tests(TestParityWithDDP)
instantiate_parametrized_tests(TestNoGrad)
35 changes: 31 additions & 4 deletions torch/distributed/fsdp/_runtime_utils.py
@@ -1,4 +1,5 @@
import functools
from enum import auto, Enum
from typing import (
    Any,
    Callable,
@@ -49,6 +50,11 @@
)


class _PrefetchMode(Enum):
    BACKWARD = auto()
    FORWARD = auto()


def _get_fsdp_root_states_with_modules(
    module: nn.Module,
) -> Tuple[List[_FSDPState], List[nn.Module]]:
@@ -428,11 +434,16 @@ def _pre_forward_unshard(
"""Unshards parameters in the pre-forward."""
if not handles:
return
_unshard(state, handles, state._streams["unshard"], state._streams["pre_unshard"])
handles_key = tuple(handles)
# If the handles have been prefetched, then there is no need to call
# `_unshard()` again
if not state._handles_prefetched.get(handles_key, False):
_unshard(
state, handles, state._streams["unshard"], state._streams["pre_unshard"]
)
state._needs_pre_forward_unshard[handles_key] = False
torch.cuda.current_stream().wait_stream(state._streams["unshard"])
_prefetch_handles(state, handles_key)
_prefetch_handles(state, handles_key, _PrefetchMode.FORWARD)


@no_type_check
@@ -639,7 +650,7 @@ def _pre_backward_hook(
        # Set this to `False` to ensure that a mistargeted prefetch does not
        # actually unshard these handles
        state._needs_pre_backward_unshard[_handles_key] = False
        _prefetch_handles(state, _handles_key)
        _prefetch_handles(state, _handles_key, _PrefetchMode.BACKWARD)
        for handle in _handles:
            handle.prepare_gradient_for_backward()
        state._ran_pre_backward_hook[_handles_key] = True
@@ -693,7 +704,7 @@ def _post_backward_hook(
        # per module case since the post-backward hook runs per handle, not per
        # group of handles.
        handles_key = (handle,)
        _prefetch_handles(state, handles_key)
        _prefetch_handles(state, handles_key, _PrefetchMode.BACKWARD)

        if not state._sync_gradients:
            if handle._use_orig_params:
@@ -994,6 +1005,7 @@ def _finalize_params(
def _prefetch_handles(
    state: _FSDPState,
    current_handles_key: _HandlesKey,
    prefetch_mode: _PrefetchMode,
) -> None:
    """
    Prefetches the next handles if needed (without synchronization). An empty
@@ -1003,11 +1015,26 @@ def _prefetch_handles(
        return
    handles_to_prefetch = _get_handles_to_prefetch(state, current_handles_key)
    for handles_key in handles_to_prefetch:
        # Temporarily emulate the training state while calling `_unshard` to
        # ensure the correct `as_params` for `_use_unsharded_views()`
        prev_training_states: List[HandleTrainingState] = []
        for handle in handles_key:
            prev_training_states.append(handle._training_state)
            if prefetch_mode == _PrefetchMode.BACKWARD:
                handle._training_state = HandleTrainingState.BACKWARD_PRE
            elif prefetch_mode == _PrefetchMode.FORWARD:
                handle._training_state = HandleTrainingState.FORWARD
            else:
                raise ValueError(
                    f"Invalid prefetch mode on rank {state.rank}: {prefetch_mode}"
                )
        # Prefetch the next set of handles without synchronizing to allow
        # the sync to happen as late as possible to maximize overlap
        _unshard(
            state, handles_key, state._streams["unshard"], state._streams["pre_unshard"]
        )
        for handle, prev_training_state in zip(handles_key, prev_training_states):
            handle._training_state = prev_training_state
        state._handles_prefetched[handles_key] = True


