From 87210008dcd0b17329082ce0784ab09b76e3db0d Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Mon, 17 Jun 2024 13:16:21 -0700 Subject: [PATCH] [FSDP2] Added APIs for explicit fwd/bwd prefetching ghstack-source-id: 86e664adafe8e22d99a3809ba6712d54c25dc5e5 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128884 --- .../_composable/fsdp/test_fully_shard_comm.py | 205 +++++++++++++++++- .../fsdp/test_fully_shard_training.py | 40 +++- .../_composable/fsdp/_fsdp_param_group.py | 29 ++- .../_composable/fsdp/_fsdp_state.py | 11 +- .../_composable/fsdp/fully_shard.py | 47 +++- torch/testing/_internal/common_fsdp.py | 13 ++ 6 files changed, 333 insertions(+), 12 deletions(-) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_comm.py b/test/distributed/_composable/fsdp/test_fully_shard_comm.py index 5acb9d895b413..c0e3fbc9aea88 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_comm.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_comm.py @@ -43,6 +43,7 @@ FSDPTestMultiThread, MLP, patch_post_backward, + patch_reshard, patch_unshard, ) from torch.testing._internal.common_utils import run_tests @@ -372,7 +373,7 @@ def test_manual_reshard_with_reshard_after_forward_false(self): ) -class TestFullyShardBackwardPrefetch(FSDPTest): +class TestFullyShardPrefetch(FSDPTest): @property def world_size(self) -> int: return min(4, torch.cuda.device_count()) @@ -578,6 +579,193 @@ def _test_backward_prefetch_unused_in_backward( self.assertEqual(events, expected_events) events.clear() + @skip_if_lt_x_gpu(2) + def test_set_modules_to_forward_prefetch(self): + n_layers = 4 + reshard_after_forward = True + checkpoint_impl = "utils" + model, _, inp = self._init_transformer( + n_layers, reshard_after_forward, checkpoint_impl + ) + + def set_forward_prefetch(model: Transformer, num_to_prefetch: int) -> None: + # Use model-specific knowledge to configure forward prefetching: + # each transformer block (layer) prefetches for the next few + for i, layer in enumerate(model.layers): + if i >= len(model.layers) - num_to_prefetch: + break + layers_to_prefetch = [ + model.layers[i + j] for j in range(1, num_to_prefetch + 1) + ] + layer.set_modules_to_forward_prefetch(layers_to_prefetch) + + events: List[EventType] = [] + unshard_with_record = self._get_unshard_with_record( + FSDPParamGroup.unshard, events + ) + reshard_with_record = self._get_reshard_with_record( + FSDPParamGroup.reshard, events + ) + post_backward_with_record = self._get_post_backward_with_record( + FSDPParamGroup.post_backward, events + ) + expected_backward_events = [ + # Default backward prefetching + ("unshard", "layers.3", TrainingState.PRE_BACKWARD), + ("unshard", "layers.2", TrainingState.PRE_BACKWARD), + ("reshard", "layers.3", TrainingState.POST_BACKWARD), + ("post_backward", "layers.3", TrainingState.POST_BACKWARD), + ("unshard", "layers.1", TrainingState.PRE_BACKWARD), + ("reshard", "layers.2", TrainingState.POST_BACKWARD), + ("post_backward", "layers.2", TrainingState.POST_BACKWARD), + ("unshard", "layers.0", TrainingState.PRE_BACKWARD), + ("reshard", "layers.1", TrainingState.POST_BACKWARD), + ("post_backward", "layers.1", TrainingState.POST_BACKWARD), + ("reshard", "layers.0", TrainingState.POST_BACKWARD), + ("post_backward", "layers.0", TrainingState.POST_BACKWARD), + ("reshard", "", TrainingState.POST_BACKWARD), + ("post_backward", "", TrainingState.POST_BACKWARD), + ] + with patch_unshard(unshard_with_record), patch_reshard( + reshard_with_record + ), 
patch_post_backward(post_backward_with_record): + set_forward_prefetch(model, num_to_prefetch=1) + loss = model(inp) + expected_forward_events = [ + ("unshard", "", TrainingState.FORWARD), + # `layers.i` prefetches `layers.i+1` + ("unshard", "layers.0", TrainingState.FORWARD), + ("unshard", "layers.1", TrainingState.FORWARD), + ("reshard", "layers.0", TrainingState.FORWARD), + ("unshard", "layers.2", TrainingState.FORWARD), + ("reshard", "layers.1", TrainingState.FORWARD), + ("unshard", "layers.3", TrainingState.FORWARD), + ("reshard", "layers.2", TrainingState.FORWARD), + ("reshard", "layers.3", TrainingState.FORWARD), + ] + self.assertEqual(events, expected_forward_events) + events.clear() + loss.sum().backward() + self.assertEqual(events, expected_backward_events) + events.clear() + + set_forward_prefetch(model, num_to_prefetch=2) + loss = model(inp) + expected_forward_events = [ + ("unshard", "", TrainingState.FORWARD), + # `layers.i` prefetches `layers.i+1` and `layers.i+2` + ("unshard", "layers.0", TrainingState.FORWARD), + ("unshard", "layers.1", TrainingState.FORWARD), + ("unshard", "layers.2", TrainingState.FORWARD), + ("reshard", "layers.0", TrainingState.FORWARD), + ("unshard", "layers.3", TrainingState.FORWARD), + ("reshard", "layers.1", TrainingState.FORWARD), + ("reshard", "layers.2", TrainingState.FORWARD), + ("reshard", "layers.3", TrainingState.FORWARD), + ] + self.assertEqual(events, expected_forward_events) + events.clear() + loss.sum().backward() + self.assertEqual(events, expected_backward_events) + events.clear() + + @skip_if_lt_x_gpu(2) + def test_set_modules_to_backward_prefetch(self): + n_layers = 4 + reshard_after_forward = True + checkpoint_impl = "utils" + model, _, inp = self._init_transformer( + n_layers, reshard_after_forward, checkpoint_impl + ) + + def set_backward_prefetch(model: Transformer, num_to_prefetch: int) -> None: + # Use model-specific knowledge to configure backward prefetching: + # each transformer block (layer) prefetches for the previous few + for i, layer in enumerate(model.layers): + if i < num_to_prefetch: + continue + layers_to_prefetch = [ + model.layers[i - j] for j in range(1, num_to_prefetch + 1) + ] + layer.set_modules_to_backward_prefetch(layers_to_prefetch) + + events: List[EventType] = [] + unshard_with_record = self._get_unshard_with_record( + FSDPParamGroup.unshard, events + ) + reshard_with_record = self._get_reshard_with_record( + FSDPParamGroup.reshard, events + ) + post_backward_with_record = self._get_post_backward_with_record( + FSDPParamGroup.post_backward, events + ) + expected_forward_events = [ + # Default forward prefetching + ("unshard", "", TrainingState.FORWARD), # root + ("unshard", "layers.0", TrainingState.FORWARD), + ("reshard", "layers.0", TrainingState.FORWARD), + ("unshard", "layers.1", TrainingState.FORWARD), + ("reshard", "layers.1", TrainingState.FORWARD), + ("unshard", "layers.2", TrainingState.FORWARD), + ("reshard", "layers.2", TrainingState.FORWARD), + ("unshard", "layers.3", TrainingState.FORWARD), + ("reshard", "layers.3", TrainingState.FORWARD), + ] + with patch_unshard(unshard_with_record), patch_reshard( + reshard_with_record + ), patch_post_backward(post_backward_with_record): + set_backward_prefetch(model, num_to_prefetch=1) + loss = model(inp) + self.assertEqual(events, expected_forward_events) + events.clear() + loss.sum().backward() + expected_backward_events = [ + # Root prefetches `layers.3` per default + ("unshard", "layers.3", TrainingState.PRE_BACKWARD), + # `layers.i` prefetches for 
`layers.i-1` (same as default) + ("unshard", "layers.2", TrainingState.PRE_BACKWARD), + ("reshard", "layers.3", TrainingState.POST_BACKWARD), + ("post_backward", "layers.3", TrainingState.POST_BACKWARD), + ("unshard", "layers.1", TrainingState.PRE_BACKWARD), + ("reshard", "layers.2", TrainingState.POST_BACKWARD), + ("post_backward", "layers.2", TrainingState.POST_BACKWARD), + ("unshard", "layers.0", TrainingState.PRE_BACKWARD), + ("reshard", "layers.1", TrainingState.POST_BACKWARD), + ("post_backward", "layers.1", TrainingState.POST_BACKWARD), + ("reshard", "layers.0", TrainingState.POST_BACKWARD), + ("post_backward", "layers.0", TrainingState.POST_BACKWARD), + ("reshard", "", TrainingState.POST_BACKWARD), + ("post_backward", "", TrainingState.POST_BACKWARD), + ] + self.assertEqual(events, expected_backward_events) + events.clear() + + set_backward_prefetch(model, num_to_prefetch=2) + loss = model(inp) + self.assertEqual(events, expected_forward_events) + events.clear() + loss.sum().backward() + expected_backward_events = [ + # Root prefetches `layers.3` per default + ("unshard", "layers.3", TrainingState.PRE_BACKWARD), + # `layers.i` prefetches for `layers.i-1` and `layers.i-2` + ("unshard", "layers.2", TrainingState.PRE_BACKWARD), + ("unshard", "layers.1", TrainingState.PRE_BACKWARD), + ("reshard", "layers.3", TrainingState.POST_BACKWARD), + ("post_backward", "layers.3", TrainingState.POST_BACKWARD), + ("unshard", "layers.0", TrainingState.PRE_BACKWARD), + ("reshard", "layers.2", TrainingState.POST_BACKWARD), + ("post_backward", "layers.2", TrainingState.POST_BACKWARD), + ("reshard", "layers.1", TrainingState.POST_BACKWARD), + ("post_backward", "layers.1", TrainingState.POST_BACKWARD), + ("reshard", "layers.0", TrainingState.POST_BACKWARD), + ("post_backward", "layers.0", TrainingState.POST_BACKWARD), + ("reshard", "", TrainingState.POST_BACKWARD), + ("post_backward", "", TrainingState.POST_BACKWARD), + ] + self.assertEqual(events, expected_backward_events) + events.clear() + def _init_transformer( self, n_layers: int, @@ -614,6 +802,21 @@ def unshard_with_record(self, *args, **kwargs): return unshard_with_record + def _get_reshard_with_record( + self, orig_reshard: Callable, events: List[EventType] + ) -> Callable: + def reshard_with_record(self, *args, **kwargs): + nonlocal events + if ( + self._training_state == TrainingState.FORWARD + and not self._reshard_after_forward + ): # skip no-ops + return + events.append(("reshard", self._module_fqn, self._training_state)) + return orig_reshard(self, *args, **kwargs) + + return reshard_with_record + def _get_post_backward_with_record( self, orig_post_backward: Callable, events: List[EventType] ) -> Callable: diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py index 836013f7fb243..3dbaa65243794 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_training.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py @@ -3,6 +3,7 @@ import contextlib import copy import functools +import itertools import unittest from typing import Iterable, List, Tuple, Type, Union @@ -337,7 +338,6 @@ def _test_train_parity_multi_group( return assert device_type in ("cuda", "cpu"), f"{device_type}" torch.manual_seed(42) - lin_dim = 32 vocab_size = 1024 model_args = ModelArgs( n_layers=3, @@ -494,6 +494,44 @@ def forward(self, x): _optim.step() self.assertEqual(losses[0], losses[1]) + @skip_if_lt_x_gpu(2) + def test_explicit_prefetching(self): + 
torch.manual_seed(42) + model_args = ModelArgs(n_layers=8, dropout_p=0.0) + model = Transformer(model_args) + ref_model = replicate(copy.deepcopy(model).cuda()) + ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2) + for layer in itertools.chain(model.layers, [model]): + fully_shard(layer) + optim = torch.optim.AdamW(model.parameters(), lr=1e-2) + + num_to_forward_prefetch = num_to_backward_prefetch = 2 + for i, layer in enumerate(model.layers): + if i >= len(model.layers) - num_to_forward_prefetch: + break + layers_to_prefetch = [ + model.layers[i + j] for j in range(1, num_to_forward_prefetch + 1) + ] + layer.set_modules_to_forward_prefetch(layers_to_prefetch) + for i, layer in enumerate(model.layers): + if i < num_to_backward_prefetch: + continue + layers_to_prefetch = [ + model.layers[i - j] for j in range(1, num_to_backward_prefetch + 1) + ] + layer.set_modules_to_backward_prefetch(layers_to_prefetch) + + torch.manual_seed(42 + self.rank) + inp = torch.randint(0, model_args.vocab_size, (2, 8), device="cuda") + for iter_idx in range(10): + losses: List[torch.Tensor] = [] + for _model, _optim in ((ref_model, ref_optim), (model, optim)): + _optim.zero_grad() + losses.append(_model(inp).sum()) + losses[-1].backward() + _optim.step() + self.assertEqual(losses[0], losses[1]) + class TestFullyShard1DTrainingCompose(FSDPTest): @property diff --git a/torch/distributed/_composable/fsdp/_fsdp_param_group.py b/torch/distributed/_composable/fsdp/_fsdp_param_group.py index 63142466f001f..06fa90e060e70 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param_group.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param_group.py @@ -283,14 +283,15 @@ def _record_post_forward(self) -> None: self.comm_ctx.post_forward_order.append(self) self._post_forward_indices.append(post_forward_index) - def pre_backward(self, *unused: Any): + def pre_backward(self, default_prefetch: bool, *unused: Any): if self._training_state == TrainingState.PRE_BACKWARD: return with record_function(self._with_fqn("FSDP::pre_backward")): self._training_state = TrainingState.PRE_BACKWARD self.unshard() # no-op if prefetched self.wait_for_unshard() - self._prefetch_unshard() + if default_prefetch: + self._backward_prefetch() def post_backward(self, *unused: Any): self._training_state = TrainingState.POST_BACKWARD @@ -348,7 +349,7 @@ def finalize_backward(self): fsdp_param.grad_offload_event = None self._post_forward_indices.clear() - def _prefetch_unshard(self): + def _backward_prefetch(self) -> None: if self._training_state == TrainingState.PRE_BACKWARD: if not self._post_forward_indices: # Can be cleared if running multiple `backward`s @@ -360,11 +361,23 @@ def _prefetch_unshard(self): # have mistargeted prefetches if not all modules used in forward # are used in this backward target_fsdp_param_group = self.comm_ctx.post_forward_order[target_index] - target_fqn = target_fsdp_param_group._module_fqn - with record_function( - self._with_fqn(f"FSDP::backward_prefetch for {target_fqn}") - ), target_fsdp_param_group.use_training_state(TrainingState.PRE_BACKWARD): - target_fsdp_param_group.unshard() + self._prefetch_unshard(target_fsdp_param_group, "backward") + + @staticmethod + def _prefetch_unshard( + target_fsdp_param_group: "FSDPParamGroup", pass_type: str + ) -> None: + if pass_type == "backward": + training_state = TrainingState.PRE_BACKWARD + elif pass_type == "forward": + training_state = TrainingState.FORWARD + else: + raise ValueError(f"Unknown pass type: {pass_type}") + target_fqn = 
target_fsdp_param_group._module_fqn + with record_function( + f"FSDP::{pass_type}_prefetch for {target_fqn}" + ), target_fsdp_param_group.use_training_state(training_state): + target_fsdp_param_group.unshard() # Utilities # def _to_sharded(self): diff --git a/torch/distributed/_composable/fsdp/_fsdp_state.py b/torch/distributed/_composable/fsdp/_fsdp_state.py index f080e75503384..79a09342704ff 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_state.py +++ b/torch/distributed/_composable/fsdp/_fsdp_state.py @@ -56,6 +56,8 @@ def __init__(self): self._state_ctx = FSDPStateContext() self._comm_ctx = FSDPCommContext() self._training_state: TrainingState = TrainingState.IDLE + self._states_to_forward_prefetch: List[FSDPState] = [] + self._states_to_backward_prefetch: List[FSDPState] = [] # Define a separate init since `__init__` is called in the contract def init( @@ -171,6 +173,9 @@ def _pre_forward( args, kwargs = tree_map(cast_fn, args), tree_map(cast_fn, kwargs) if self._fsdp_param_group: args, kwargs = self._fsdp_param_group.pre_forward(module, args, kwargs) + for fsdp_state in self._states_to_forward_prefetch: + if (target_param_group := fsdp_state._fsdp_param_group) is not None: + FSDPParamGroup._prefetch_unshard(target_param_group, "forward") return args, kwargs @disable_if_config_true @@ -205,7 +210,11 @@ def _pre_backward(self, grad: torch.Tensor) -> torch.Tensor: self._training_state = TrainingState.PRE_BACKWARD self._register_root_post_backward_final_callback() if self._fsdp_param_group: - self._fsdp_param_group.pre_backward() + default_prefetch = len(self._states_to_backward_prefetch) == 0 + self._fsdp_param_group.pre_backward(default_prefetch) + for fsdp_state in self._states_to_backward_prefetch: + if (target_param_group := fsdp_state._fsdp_param_group) is not None: + FSDPParamGroup._prefetch_unshard(target_param_group, "backward") return grad def _root_post_backward_final_callback(self) -> None: diff --git a/torch/distributed/_composable/fsdp/fully_shard.py b/torch/distributed/_composable/fsdp/fully_shard.py index d3e70b38eac91..4961eab12fa32 100644 --- a/torch/distributed/_composable/fsdp/fully_shard.py +++ b/torch/distributed/_composable/fsdp/fully_shard.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import functools -from typing import Any, cast, NoReturn, Optional, Union +from typing import Any, cast, Iterable, List, NoReturn, Optional, Union import torch import torch.nn as nn @@ -270,6 +270,45 @@ def set_reshard_after_backward( if fsdp_param_group := state._fsdp_param_group: fsdp_param_group.reshard_after_backward = reshard_after_backward + def set_modules_to_forward_prefetch(self, modules: List["FSDPModule"]) -> None: + """ + Sets the FSDP modules for which this FSDP module should explicitly + prefetch all-gathers in forward. The prefetching runs after this + module's all-gather copy-out. + + Passing a singleton list containing the next FSDP module gives the same + all-gather overlap behavior as the default overlap behavior, except the + prefetched all-gather is issued earlier from the CPU. Passing a list + with at least length two is required for more aggressive overlap. + + Args: + modules (List[FSDPModule]): FSDP modules to prefetch. 
+        """
+        _assert_all_fsdp_modules(modules)
+        self._get_fsdp_state()._states_to_forward_prefetch = [
+            module._get_fsdp_state() for module in modules
+        ]
+
+    def set_modules_to_backward_prefetch(self, modules: List["FSDPModule"]) -> None:
+        """
+        Sets the FSDP modules for which this FSDP module should explicitly
+        prefetch all-gathers in backward. This overrides the default backward
+        prefetching implementation that prefetches the next FSDP module based
+        on the reverse post-forward order.
+
+        Passing a singleton list containing the previous FSDP module gives the
+        same all-gather overlap behavior as the default overlap behavior.
+        Passing a list with at least length two is required for more aggressive
+        overlap.
+
+        Args:
+            modules (List[FSDPModule]): FSDP modules to prefetch.
+        """
+        _assert_all_fsdp_modules(modules)
+        self._get_fsdp_state()._states_to_backward_prefetch = [
+            module._get_fsdp_state() for module in modules
+        ]
+
     def _get_fsdp_state(self) -> FSDPState:
         if (state := _get_module_fsdp_state(cast(nn.Module, self))) is None:
             raise AssertionError(f"No FSDP state found on {self}")
@@ -350,3 +389,9 @@ def wrapped_method(self, *args, **kwargs):
             method_name,
             wrapped_method.__get__(module, type(module)),  # type:ignore[attr-defined]
         )
+
+
+def _assert_all_fsdp_modules(modules: Iterable[Any]) -> None:
+    for module in modules:
+        if not isinstance(module, FSDPModule):
+            raise ValueError(f"Expects FSDPModule but got {type(module)}: {module}")
diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py
index 2b5fdc613c2e2..cfa16307da334 100644
--- a/torch/testing/_internal/common_fsdp.py
+++ b/torch/testing/_internal/common_fsdp.py
@@ -997,6 +997,19 @@ def patch_unshard(new_unshard: Callable):
         FSDPParamGroup.unshard = orig_unshard
 
 
+@no_type_check
+@contextlib.contextmanager
+def patch_reshard(new_reshard: Callable):
+    orig_reshard = FSDPParamGroup.reshard
+    dist.barrier()
+    FSDPParamGroup.reshard = new_reshard
+    try:
+        yield
+    finally:
+        dist.barrier()
+        FSDPParamGroup.reshard = orig_reshard
+
+
 @no_type_check
 @contextlib.contextmanager
 def patch_post_backward(new_post_backward: Callable):
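
Usage sketch: the snippet below shows one way to wire up the new
set_modules_to_forward_prefetch / set_modules_to_backward_prefetch APIs,
following the pattern in test_explicit_prefetching above. The helper name
shard_with_explicit_prefetching, the model.layers attribute, and the prefetch
depth of 2 are assumptions for illustration, not part of this patch.

    import torch.nn as nn
    from torch.distributed._composable.fsdp import fully_shard


    def shard_with_explicit_prefetching(
        model: nn.Module, num_to_prefetch: int = 2
    ) -> None:
        # Assumes `model.layers` holds the repeated blocks (e.g. transformer
        # layers); the attribute name and prefetch depth are illustrative.
        for layer in model.layers:
            fully_shard(layer)
        fully_shard(model)  # root module

        layers = list(model.layers)
        for i, layer in enumerate(layers):
            # Forward: `layers[i]` prefetches the all-gathers of the next
            # `num_to_prefetch` layers after its own all-gather copy-out.
            if i < len(layers) - num_to_prefetch:
                layer.set_modules_to_forward_prefetch(
                    [layers[i + j] for j in range(1, num_to_prefetch + 1)]
                )
            # Backward: `layers[i]` prefetches the all-gathers of the previous
            # `num_to_prefetch` layers, overriding the default prefetching of
            # one module in reverse post-forward order.
            if i >= num_to_prefetch:
                layer.set_modules_to_backward_prefetch(
                    [layers[i - j] for j in range(1, num_to_prefetch + 1)]
                )

Prefetching more than one module at a time trades higher peak memory (more
parameters unsharded concurrently) for more aggressive overlap of all-gather
communication with computation.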