From ac5f565fa7010bd77b9e779415e8709d347234b6 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Tue, 18 Jun 2024 11:41:03 -0700 Subject: [PATCH] [FSDP2] Added `set_post_optim_event` (#128975) This PR adds `set_post_optim_event` that allows power users to provide their own CUDA event that is recorded after the optimizer step for the FSDP root module to wait the all-gather streams on. ``` def set_post_optim_event(self, event: torch.cuda.Event) -> None: ``` By default, the root would have the all-gather streams wait on the current stream (`wait_stream`), which may introduce false dependencies if there is unrelated computation after the optimizer step and before the wait. For example, this pattern can appear in recommendation models. To avoid those false dependencies while preserving the correctness guarantee, we provide this API so that the user can provide their own CUDA event to wait the all-gather streams on. We include both correctness test (`test_fully_shard_training.py`) and overlap test (`test_fully_shard_overlap.py`). --- One possible way to use the API is to register a post-step hook on the optimizer. For example: https://github.com/pytorch/pytorch/blob/12e8d1399b979b45d16f0934017f742d01ab2b8d/test/distributed/_composable/fsdp/test_fully_shard_training.py#L546-L552 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128975 Approved by: https://github.com/sanketpurandare, https://github.com/weifengpy ghstack dependencies: #128884 --- .../fsdp/test_fully_shard_overlap.py | 82 ++++++++++++++++--- .../fsdp/test_fully_shard_training.py | 41 ++++++++++ .../_composable/fsdp/_fsdp_state.py | 14 +++- .../_composable/fsdp/fully_shard.py | 19 +++++ 4 files changed, 142 insertions(+), 14 deletions(-) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_overlap.py b/test/distributed/_composable/fsdp/test_fully_shard_overlap.py index 99823883abfbb..1fca6c3f3c5a0 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_overlap.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_overlap.py @@ -1,5 +1,6 @@ # Owner(s): ["oncall: distributed"] +import functools from typing import Callable import torch @@ -7,6 +8,7 @@ import torch.nn as nn from torch.distributed._composable.fsdp import fully_shard +from torch.distributed._tensor.experimental import implicit_replication from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( FSDPTest, @@ -23,15 +25,6 @@ def world_size(self) -> int: @skip_if_lt_x_gpu(2) def test_fully_shard_training_overlap(self): - class LinearWithSleep(nn.Module): - def __init__(self, dim: int, sleep_ms: int): - super().__init__() - self.weight = nn.Parameter(torch.randn((dim, dim))) - self.sleep_ms = sleep_ms - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return nn.functional.relu(Matmul.apply(x, self.weight, self.sleep_ms)) - torch.manual_seed(42) # Use non-trivial comm. time but still shorter than compute time @@ -44,7 +37,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: fully_shard(model, reshard_after_forward=True) orig_all_gather_into_tensor = dist.all_gather_into_tensor - orig_reduce_scatter = dist.reduce_scatter_tensor + orig_reduce_scatter_tensor = dist.reduce_scatter_tensor comm_stream = torch.cuda.Stream() def delay_collective(): @@ -61,7 +54,7 @@ def delayed_all_gather(*args, **kwargs): def delayed_reduce_scatter(*args, **kwargs): delay_collective() - return orig_reduce_scatter(*args, **kwargs) + return orig_reduce_scatter_tensor(*args, **kwargs) inp = torch.randn((2, dim), device="cuda") loss = model(inp).sum() # warmup CUDA and allocator @@ -92,6 +85,63 @@ def fwd_bwd(): ) self.assertLessEqual(fwd_bwd_time, expected_fwd_time + expected_bwd_time) + @skip_if_lt_x_gpu(2) + def test_fully_shard_post_optim_event_overlap(self): + torch.manual_seed(42) + + # Use non-trivial comm. time but still shorter than compute time + dim, compute_sleep_ms, comm_sleep_ms = (4, 25, 10) + # Define the model to have a high-compute linear followed by a + # low-compute linear, where only the low-compute linear uses FSDP + model = nn.Sequential( + LinearWithSleep(dim, compute_sleep_ms), nn.Linear(dim, dim) + ).cuda() + fully_shard(model[1], reshard_after_forward=False) + optim = torch.optim.AdamW(model.parameters(), lr=1e-2) + + orig_all_gather_into_tensor = dist.all_gather_into_tensor + + def delayed_all_gather(*args, **kwargs): + torch.cuda._sleep(int(comm_sleep_ms * get_cycles_per_ms())) + return orig_all_gather_into_tensor(*args, **kwargs) + + inp = torch.randn((2, dim), device="cuda") + + def run_train_steps(num_iters: int, use_post_optim_event: bool): + for _ in range(num_iters): + optim.zero_grad() + with patch_all_gather(delayed_all_gather): + loss = model(inp).sum() + loss.backward() + with implicit_replication(): + optim.step() + if use_post_optim_event: + post_optim_event = torch.cuda.current_stream().record_event() + model[1].set_post_optim_event(post_optim_event) + + run_train_steps(1, False) # warmup CUDA and allocator + num_iters = 5 + baseline_time = self._time_fn( + functools.partial(run_train_steps, num_iters, False) + ) + test_time = self._time_fn(functools.partial(run_train_steps, num_iters, True)) + + buffer_ms = 4 # CPU delays and copies + # Baseline: FSDP all-gather is exposed since the FSDP module waits for + # the current stream and hence the high-compute linear + self.assertLessEqual( + baseline_time, + num_iters * (3 * compute_sleep_ms + comm_sleep_ms + buffer_ms), + ) + # Test: FSDP all-gather is overlapped with the high-compute linear + # since the FSDP module only waits for the post-optim event (except on + # the 1st iteration when no event has been recorded) + expected_test_time = ( + num_iters * (3 * compute_sleep_ms + buffer_ms) + comm_sleep_ms + ) + self.assertLessEqual(test_time, expected_test_time) + self.assertGreater(baseline_time, expected_test_time) + def _time_fn(self, fn: Callable): start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) @@ -123,5 +173,15 @@ def backward(ctx, grad_output: torch.Tensor): return grad_input, grad_weight, None +class LinearWithSleep(nn.Module): + def __init__(self, dim: int, sleep_ms: int): + super().__init__() + self.weight = nn.Parameter(torch.randn((dim, dim))) + self.sleep_ms = sleep_ms + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return nn.functional.relu(Matmul.apply(x, self.weight, self.sleep_ms)) + + if __name__ == "__main__": run_tests() diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py index 3dbaa65243794..abc579b40d624 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_training.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py @@ -532,6 +532,47 @@ def test_explicit_prefetching(self): _optim.step() self.assertEqual(losses[0], losses[1]) + @skip_if_lt_x_gpu(2) + def test_post_optim_event(self): + torch.manual_seed(42) + model_args = ModelArgs(dropout_p=0.0) + model = Transformer(model_args) + ref_model = replicate(copy.deepcopy(model).cuda()) + ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2) + for layer in itertools.chain(model.layers, [model]): + fully_shard(layer) + optim = torch.optim.AdamW(model.parameters(), lr=1e-2) + + def step_post_hook( + fsdp_module: FSDPModule, opt: torch.optim.Optimizer, args, kwargs + ) -> None: + post_optim_event = torch.cuda.current_stream().record_event() + fsdp_module.set_post_optim_event(post_optim_event) + + optim.register_step_post_hook(functools.partial(step_post_hook, model)) + + torch.manual_seed(42 + self.rank) + inp = torch.randint(0, model_args.vocab_size, (2, 8), device="cuda") + # Track all losses and check for equality at the end to avoid a CPU + # sync point after each iteration + ref_losses: List[torch.Tensor] = [] + losses: List[torch.Tensor] = [] + for iter_idx in range(10): + ref_optim.zero_grad() + ref_losses.append(ref_model(inp).sum()) + ref_losses[-1].backward() + ref_optim.step() + for iter_idx in range(10): + optim.zero_grad() + losses.append(model(inp).sum()) + losses[-1].backward() + optim.step() + # Sleep after the optimizer step to allow CPU to run ahead into the + # next iteration's forward, exercising the post-optim stream sync + torch.cuda._sleep(int(25 * get_cycles_per_ms())) + for ref_loss, loss in zip(ref_losses, losses): + self.assertEqual(ref_loss, loss) + class TestFullyShard1DTrainingCompose(FSDPTest): @property diff --git a/torch/distributed/_composable/fsdp/_fsdp_state.py b/torch/distributed/_composable/fsdp/_fsdp_state.py index c6cdb2b29880b..f04e6f6d09292 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_state.py +++ b/torch/distributed/_composable/fsdp/_fsdp_state.py @@ -36,6 +36,9 @@ def __init__(self): self.post_backward_final_callback_queued: bool = False # Whether to finalize backward in this backward's final callback self.is_last_backward: bool = True + # Optional user-provided event recorded after optimizer for the + # all-gather streams to wait on in the root pre-forward + self.post_optim_event: Optional[torch.cuda.Event] = None def disable_if_config_true(func): @@ -84,9 +87,14 @@ def _root_pre_forward( self._state_ctx.iter_forward_root = self with torch.profiler.record_function("FSDP::root_pre_forward"): # Wait for optimizer before implicitly prefetched all-gathers - current_stream = torch.cuda.current_stream() - self._comm_ctx.all_gather_copy_in_stream.wait_stream(current_stream) - self._comm_ctx.all_gather_stream.wait_stream(current_stream) + if (event := self._state_ctx.post_optim_event) is not None: + self._comm_ctx.all_gather_copy_in_stream.wait_event(event) + self._comm_ctx.all_gather_stream.wait_event(event) + self._state_ctx.post_optim_event = None + else: + current_stream = torch.cuda.current_stream() + self._comm_ctx.all_gather_copy_in_stream.wait_stream(current_stream) + self._comm_ctx.all_gather_stream.wait_stream(current_stream) if self._device.type == "cuda": with torch.profiler.record_function("FSDP::inputs_to_device"): args_tuple, kwargs_tuple = _to_kwargs( diff --git a/torch/distributed/_composable/fsdp/fully_shard.py b/torch/distributed/_composable/fsdp/fully_shard.py index e8ab3466118bc..88180f40f792c 100644 --- a/torch/distributed/_composable/fsdp/fully_shard.py +++ b/torch/distributed/_composable/fsdp/fully_shard.py @@ -309,6 +309,25 @@ def set_modules_to_backward_prefetch(self, modules: List["FSDPModule"]) -> None: module._get_fsdp_state() for module in modules ] + def set_post_optim_event(self, event: torch.cuda.Event) -> None: + """ + Sets a post-optimizer-step event for the root FSDP module to wait the + all-gather streams on. + + By default, the root FSDP module waits the all-gather streams on the + current stream to ensure that the optimizer step has finished before + all-gathering. However, this may introduce false dependencies if + there is unrelated computation after the optimizer step. This API + allows the user to provide their own event to wait on. After the root + waits on the event, the event is discarded, so this API should be + called with a new event each iteration. + + Args: + event (torch.cuda.Event): Event recorded after the optimizer step + to wait all-gather streams on. + """ + self._get_fsdp_state()._state_ctx.post_optim_event = event + def _get_fsdp_state(self) -> FSDPState: if (state := _get_module_fsdp_state(cast(nn.Module, self))) is None: raise AssertionError(f"No FSDP state found on {self}")