[FSDP] Use correct handle training state when prefetching
ghstack-source-id: 99da08f26f09bb65fc27ecfbfc24281bb2f8186d
Pull Request resolved: #98249
awgu committed Apr 4, 2023
1 parent 6ae607d commit e4736dd
Showing 2 changed files with 129 additions and 5 deletions.
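
In short: `_prefetch_handles` now takes an explicit `_PrefetchMode` and, for each handle about to be prefetched, saves the handle's `_training_state`, temporarily sets it to `FORWARD` or `BACKWARD_PRE` to match the phase being prefetched, issues the `_unshard`, and then restores the saved state. This way `_use_unsharded_views()` sees the training state that will actually be in effect when the prefetched parameters are used, so it creates ``Tensor`` views rather than parameter views. A minimal sketch of the save/override/restore pattern, using simplified stand-ins rather than the actual FSDP internals:

# Illustrative sketch only: simplified stand-ins, not FSDP's real classes.
from enum import Enum, auto
from typing import Iterable, List


class TrainingState(Enum):  # stand-in for HandleTrainingState
    IDLE = auto()
    FORWARD = auto()
    BACKWARD_PRE = auto()


class Handle:  # stand-in for FlatParamHandle
    def __init__(self) -> None:
        self.training_state = TrainingState.IDLE

    def unshard(self) -> None:
        # In FSDP, the unshard path consults the training state to decide
        # whether to expose parameter views or plain Tensor views.
        print(f"unshard with training_state={self.training_state.name}")


def prefetch(handles: Iterable[Handle], mode: TrainingState) -> None:
    handles = list(handles)
    # Save the current states, emulate the prefetched phase, then restore.
    prev_states: List[TrainingState] = [h.training_state for h in handles]
    for h in handles:
        h.training_state = mode
    try:
        for h in handles:
            h.unshard()
    finally:
        for h, prev in zip(handles, prev_states):
            h.training_state = prev


handles = [Handle(), Handle()]
prefetch(handles, TrainingState.BACKWARD_PRE)  # prints BACKWARD_PRE twice
assert all(h.training_state is TrainingState.IDLE for h in handles)

In the actual change (see `_prefetch_handles` in torch/distributed/fsdp/_runtime_utils.py below), the previous states are restored right after `_unshard` is issued, and `state._handles_prefetched[handles_key]` is set so the pre-forward path can skip a redundant `_unshard()` call for handles that were already prefetched.
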
99 changes: 98 additions & 1 deletion test/distributed/fsdp/test_fsdp_core.py
@@ -1,19 +1,24 @@
# Owner(s): ["oncall: distributed"]

import contextlib
import functools
import itertools
import sys
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional
from unittest import mock

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import CPUOffload, MixedPrecision
from torch.distributed.fsdp.flat_param import FlatParamHandle
from torch.distributed.fsdp.fully_sharded_data_parallel import (
BackwardPrefetch,
FullyShardedDataParallel as FSDP,
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.distributed.utils import _p_assert
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
AlwaysWrapNestedWrappedModule,
@@ -415,6 +420,98 @@ def test_transformer_no_grad(self, mixed_precision):
self.assertEqual(ref_output, no_grad_output)


class TestAutograd(FSDPTest):
@skip_if_lt_x_gpu(2)
def test_unshard_params_as_tensors(
self,
):
"""
Tests that FSDP always unshards the logical parameters as ``Tensor``
views during forward and backward computation even when forward and/or
        backward prefetching is enabled.
"""
self.run_subtests(
{
"sharding_strategy": [
ShardingStrategy.FULL_SHARD,
ShardingStrategy.SHARD_GRAD_OP
# Skip testing `NO_SHARD` since it doubly uses
# `_use_unsharded_views()` for sharded views. Testing
# `FULL_SHARD` and `SHARD_GRAD_OP` provides good confidence
# that the `as_params` logic is correct.
],
"use_orig_params": [False, True],
"forward_prefetch": [False, True],
"backward_prefetch": [
BackwardPrefetch.BACKWARD_PRE,
BackwardPrefetch.BACKWARD_POST,
None,
],
},
self._test_unshard_params_as_tensors,
)

def _test_unshard_params_as_tensors(
self,
sharding_strategy: ShardingStrategy,
use_orig_params: bool,
forward_prefetch: bool,
backward_prefetch: Optional[BackwardPrefetch],
):
NUM_MY_MODULES = 5

class MyModule(nn.Module):
def __init__(self, device: torch.device) -> None:
super().__init__()
self.p0 = nn.Parameter(torch.randn((3, 3), device=device))
self.p1 = nn.Parameter(torch.randn((3, 3), device=device))

def forward(self, x: torch.Tensor) -> torch.Tensor:
z = x
for p in (self.p0, self.p1):
z = z @ p
z = torch.nn.functional.relu(z)
return z

orig_use_unsharded_views = FlatParamHandle._use_unsharded_views

def _use_unsharded_views_assert_as_tensors(
self: FlatParamHandle, as_params: bool
) -> None:
_p_assert(
not as_params, "Expects to use Tensor views but using parameter views"
)
return orig_use_unsharded_views(self, as_params)

fsdp_kwargs = {
"sharding_strategy": sharding_strategy,
"use_orig_params": use_orig_params,
"forward_prefetch": forward_prefetch,
"backward_prefetch": backward_prefetch,
"auto_wrap_policy": ModuleWrapPolicy({MyModule}),
}
device = torch.device("cuda")
# Define a model with enough FSDP instances to exercise prefetching
model = nn.Sequential(*[MyModule(device) for _ in range(NUM_MY_MODULES)])
fsdp_model = FSDP(model, **fsdp_kwargs)
for _ in range(3):
inp = torch.randn((2, 3), device=device)
with self._patch_use_unsharded_views(
_use_unsharded_views_assert_as_tensors
):
loss = fsdp_model(inp).sum()
loss.backward()

@contextlib.contextmanager
def _patch_use_unsharded_views(self, new_use_unsharded_views: Callable):
orig_use_unsharded_views = FlatParamHandle._use_unsharded_views
FlatParamHandle._use_unsharded_views = new_use_unsharded_views
try:
yield
finally:
FlatParamHandle._use_unsharded_views = orig_use_unsharded_views


instantiate_parametrized_tests(TestHooks)
instantiate_parametrized_tests(TestParityWithDDP)
instantiate_parametrized_tests(TestNoGrad)
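
A note on the test's approach: it monkeypatches `FlatParamHandle._use_unsharded_views` with a wrapper that asserts `as_params` is `False` and then delegates to the original method, and the `_patch_use_unsharded_views` context manager restores the original in a `finally` block so later tests are unaffected. A generic version of that patching helper could look like the sketch below (`patch_method` and `Widget` are hypothetical names used only for illustration):

import contextlib
from typing import Callable, Iterator


@contextlib.contextmanager
def patch_method(cls: type, name: str, replacement: Callable) -> Iterator[Callable]:
    """Temporarily replace ``cls.<name>`` with ``replacement``; restore it on exit."""
    original = getattr(cls, name)
    setattr(cls, name, replacement)
    try:
        yield original  # hand back the original so the wrapper can delegate to it
    finally:
        setattr(cls, name, original)


# Usage: assert on an argument while still calling through to the original.
class Widget:
    def render(self, as_params: bool = False) -> str:
        return f"render(as_params={as_params})"


orig_render = Widget.render


def checked_render(self: Widget, as_params: bool = False) -> str:
    assert not as_params, "expected a Tensor-view style call"
    return orig_render(self, as_params)


with patch_method(Widget, "render", checked_render):
    print(Widget().render())  # fine: as_params defaults to False
assert Widget.render is orig_render  # original restored after the block
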
35 changes: 31 additions & 4 deletions torch/distributed/fsdp/_runtime_utils.py
@@ -1,4 +1,5 @@
import functools
from enum import auto, Enum
from typing import (
Any,
Callable,
@@ -49,6 +50,11 @@
)


class _PrefetchMode(Enum):
BACKWARD = auto()
FORWARD = auto()


def _get_fsdp_root_states_with_modules(
module: nn.Module,
) -> Tuple[List[_FSDPState], List[nn.Module]]:
@@ -428,11 +434,16 @@ def _pre_forward_unshard(
"""Unshards parameters in the pre-forward."""
if not handles:
return
_unshard(state, handles, state._streams["unshard"], state._streams["pre_unshard"])
handles_key = tuple(handles)
# If the handles have been prefetched, then there is no need to call
# `_unshard()` again
if not state._handles_prefetched.get(handles_key, False):
_unshard(
state, handles, state._streams["unshard"], state._streams["pre_unshard"]
)
state._needs_pre_forward_unshard[handles_key] = False
torch.cuda.current_stream().wait_stream(state._streams["unshard"])
_prefetch_handles(state, handles_key)
_prefetch_handles(state, handles_key, _PrefetchMode.FORWARD)


@no_type_check
@@ -639,7 +650,7 @@ def _pre_backward_hook(
# Set this to `False` to ensure that a mistargeted prefetch does not
# actually unshard these handles
state._needs_pre_backward_unshard[_handles_key] = False
_prefetch_handles(state, _handles_key)
_prefetch_handles(state, _handles_key, _PrefetchMode.BACKWARD)
for handle in _handles:
handle.prepare_gradient_for_backward()
state._ran_pre_backward_hook[_handles_key] = True
@@ -693,7 +704,7 @@ def _post_backward_hook(
# per module case since the post-backward hook runs per handle, not per
# group of handles.
handles_key = (handle,)
_prefetch_handles(state, handles_key)
_prefetch_handles(state, handles_key, _PrefetchMode.BACKWARD)

if not state._sync_gradients:
if handle._use_orig_params:
@@ -994,6 +1005,7 @@ def _finalize_params(
def _prefetch_handles(
state: _FSDPState,
current_handles_key: _HandlesKey,
prefetch_mode: _PrefetchMode,
) -> None:
"""
Prefetches the next handles if needed (without synchronization). An empty
@@ -1003,11 +1015,26 @@ def _prefetch_handles(
return
handles_to_prefetch = _get_handles_to_prefetch(state, current_handles_key)
for handles_key in handles_to_prefetch:
# Temporarily emulate the training state while calling `_unshard` to
# ensure the correct `as_params` for `_use_unsharded_views()`
prev_training_states: List[HandleTrainingState] = []
for handle in handles_key:
prev_training_states.append(handle._training_state)
if prefetch_mode == _PrefetchMode.BACKWARD:
handle._training_state = HandleTrainingState.BACKWARD_PRE
elif prefetch_mode == _PrefetchMode.FORWARD:
handle._training_state = HandleTrainingState.FORWARD
else:
raise ValueError(
f"Invalid prefetch mode on rank {state.rank}: {prefetch_mode}"
)
# Prefetch the next set of handles without synchronizing to allow
# the sync to happen as late as possible to maximize overlap
_unshard(
state, handles_key, state._streams["unshard"], state._streams["pre_unshard"]
)
for handle, prev_training_state in zip(handles_key, prev_training_states):
handle._training_state = prev_training_state
state._handles_prefetched[handles_key] = True

