diff --git a/test/distributed/_composable/fsdp/test_fully_shard_overlap.py b/test/distributed/_composable/fsdp/test_fully_shard_overlap.py index 99823883abfb..1fca6c3f3c5a 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_overlap.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_overlap.py @@ -1,5 +1,6 @@ # Owner(s): ["oncall: distributed"] +import functools from typing import Callable import torch @@ -7,6 +8,7 @@ import torch.nn as nn from torch.distributed._composable.fsdp import fully_shard +from torch.distributed._tensor.experimental import implicit_replication from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( FSDPTest, @@ -23,15 +25,6 @@ def world_size(self) -> int: @skip_if_lt_x_gpu(2) def test_fully_shard_training_overlap(self): - class LinearWithSleep(nn.Module): - def __init__(self, dim: int, sleep_ms: int): - super().__init__() - self.weight = nn.Parameter(torch.randn((dim, dim))) - self.sleep_ms = sleep_ms - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return nn.functional.relu(Matmul.apply(x, self.weight, self.sleep_ms)) - torch.manual_seed(42) # Use non-trivial comm. time but still shorter than compute time @@ -44,7 +37,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: fully_shard(model, reshard_after_forward=True) orig_all_gather_into_tensor = dist.all_gather_into_tensor - orig_reduce_scatter = dist.reduce_scatter_tensor + orig_reduce_scatter_tensor = dist.reduce_scatter_tensor comm_stream = torch.cuda.Stream() def delay_collective(): @@ -61,7 +54,7 @@ def delayed_all_gather(*args, **kwargs): def delayed_reduce_scatter(*args, **kwargs): delay_collective() - return orig_reduce_scatter(*args, **kwargs) + return orig_reduce_scatter_tensor(*args, **kwargs) inp = torch.randn((2, dim), device="cuda") loss = model(inp).sum() # warmup CUDA and allocator @@ -92,6 +85,63 @@ def fwd_bwd(): ) self.assertLessEqual(fwd_bwd_time, expected_fwd_time + expected_bwd_time) + @skip_if_lt_x_gpu(2) + def test_fully_shard_post_optim_event_overlap(self): + torch.manual_seed(42) + + # Use non-trivial comm. 
time but still shorter than compute time + dim, compute_sleep_ms, comm_sleep_ms = (4, 25, 10) + # Define the model to have a high-compute linear followed by a + # low-compute linear, where only the low-compute linear uses FSDP + model = nn.Sequential( + LinearWithSleep(dim, compute_sleep_ms), nn.Linear(dim, dim) + ).cuda() + fully_shard(model[1], reshard_after_forward=False) + optim = torch.optim.AdamW(model.parameters(), lr=1e-2) + + orig_all_gather_into_tensor = dist.all_gather_into_tensor + + def delayed_all_gather(*args, **kwargs): + torch.cuda._sleep(int(comm_sleep_ms * get_cycles_per_ms())) + return orig_all_gather_into_tensor(*args, **kwargs) + + inp = torch.randn((2, dim), device="cuda") + + def run_train_steps(num_iters: int, use_post_optim_event: bool): + for _ in range(num_iters): + optim.zero_grad() + with patch_all_gather(delayed_all_gather): + loss = model(inp).sum() + loss.backward() + with implicit_replication(): + optim.step() + if use_post_optim_event: + post_optim_event = torch.cuda.current_stream().record_event() + model[1].set_post_optim_event(post_optim_event) + + run_train_steps(1, False) # warmup CUDA and allocator + num_iters = 5 + baseline_time = self._time_fn( + functools.partial(run_train_steps, num_iters, False) + ) + test_time = self._time_fn(functools.partial(run_train_steps, num_iters, True)) + + buffer_ms = 4 # CPU delays and copies + # Baseline: FSDP all-gather is exposed since the FSDP module waits for + # the current stream and hence the high-compute linear + self.assertLessEqual( + baseline_time, + num_iters * (3 * compute_sleep_ms + comm_sleep_ms + buffer_ms), + ) + # Test: FSDP all-gather is overlapped with the high-compute linear + # since the FSDP module only waits for the post-optim event (except on + # the 1st iteration when no event has been recorded) + expected_test_time = ( + num_iters * (3 * compute_sleep_ms + buffer_ms) + comm_sleep_ms + ) + self.assertLessEqual(test_time, expected_test_time) + self.assertGreater(baseline_time, expected_test_time) + def _time_fn(self, fn: Callable): start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) @@ -123,5 +173,15 @@ def backward(ctx, grad_output: torch.Tensor): return grad_input, grad_weight, None +class LinearWithSleep(nn.Module): + def __init__(self, dim: int, sleep_ms: int): + super().__init__() + self.weight = nn.Parameter(torch.randn((dim, dim))) + self.sleep_ms = sleep_ms + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return nn.functional.relu(Matmul.apply(x, self.weight, self.sleep_ms)) + + if __name__ == "__main__": run_tests() diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py index 3dbaa6524379..abc579b40d62 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_training.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py @@ -532,6 +532,47 @@ def test_explicit_prefetching(self): _optim.step() self.assertEqual(losses[0], losses[1]) + @skip_if_lt_x_gpu(2) + def test_post_optim_event(self): + torch.manual_seed(42) + model_args = ModelArgs(dropout_p=0.0) + model = Transformer(model_args) + ref_model = replicate(copy.deepcopy(model).cuda()) + ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2) + for layer in itertools.chain(model.layers, [model]): + fully_shard(layer) + optim = torch.optim.AdamW(model.parameters(), lr=1e-2) + + def step_post_hook( + fsdp_module: FSDPModule, opt: torch.optim.Optimizer, args, kwargs 
+ ) -> None: + post_optim_event = torch.cuda.current_stream().record_event() + fsdp_module.set_post_optim_event(post_optim_event) + + optim.register_step_post_hook(functools.partial(step_post_hook, model)) + + torch.manual_seed(42 + self.rank) + inp = torch.randint(0, model_args.vocab_size, (2, 8), device="cuda") + # Track all losses and check for equality at the end to avoid a CPU + # sync point after each iteration + ref_losses: List[torch.Tensor] = [] + losses: List[torch.Tensor] = [] + for iter_idx in range(10): + ref_optim.zero_grad() + ref_losses.append(ref_model(inp).sum()) + ref_losses[-1].backward() + ref_optim.step() + for iter_idx in range(10): + optim.zero_grad() + losses.append(model(inp).sum()) + losses[-1].backward() + optim.step() + # Sleep after the optimizer step to allow CPU to run ahead into the + # next iteration's forward, exercising the post-optim stream sync + torch.cuda._sleep(int(25 * get_cycles_per_ms())) + for ref_loss, loss in zip(ref_losses, losses): + self.assertEqual(ref_loss, loss) + class TestFullyShard1DTrainingCompose(FSDPTest): @property diff --git a/torch/distributed/_composable/fsdp/_fsdp_state.py b/torch/distributed/_composable/fsdp/_fsdp_state.py index c6cdb2b29880..f04e6f6d0929 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_state.py +++ b/torch/distributed/_composable/fsdp/_fsdp_state.py @@ -36,6 +36,9 @@ def __init__(self): self.post_backward_final_callback_queued: bool = False # Whether to finalize backward in this backward's final callback self.is_last_backward: bool = True + # Optional user-provided event recorded after optimizer for the + # all-gather streams to wait on in the root pre-forward + self.post_optim_event: Optional[torch.cuda.Event] = None def disable_if_config_true(func): @@ -84,9 +87,14 @@ def _root_pre_forward( self._state_ctx.iter_forward_root = self with torch.profiler.record_function("FSDP::root_pre_forward"): # Wait for optimizer before implicitly prefetched all-gathers - current_stream = torch.cuda.current_stream() - self._comm_ctx.all_gather_copy_in_stream.wait_stream(current_stream) - self._comm_ctx.all_gather_stream.wait_stream(current_stream) + if (event := self._state_ctx.post_optim_event) is not None: + self._comm_ctx.all_gather_copy_in_stream.wait_event(event) + self._comm_ctx.all_gather_stream.wait_event(event) + self._state_ctx.post_optim_event = None + else: + current_stream = torch.cuda.current_stream() + self._comm_ctx.all_gather_copy_in_stream.wait_stream(current_stream) + self._comm_ctx.all_gather_stream.wait_stream(current_stream) if self._device.type == "cuda": with torch.profiler.record_function("FSDP::inputs_to_device"): args_tuple, kwargs_tuple = _to_kwargs( diff --git a/torch/distributed/_composable/fsdp/fully_shard.py b/torch/distributed/_composable/fsdp/fully_shard.py index e8ab3466118b..88180f40f792 100644 --- a/torch/distributed/_composable/fsdp/fully_shard.py +++ b/torch/distributed/_composable/fsdp/fully_shard.py @@ -309,6 +309,25 @@ def set_modules_to_backward_prefetch(self, modules: List["FSDPModule"]) -> None: module._get_fsdp_state() for module in modules ] + def set_post_optim_event(self, event: torch.cuda.Event) -> None: + """ + Sets a post-optimizer-step event for the root FSDP module to wait the + all-gather streams on. + + By default, the root FSDP module waits the all-gather streams on the + current stream to ensure that the optimizer step has finished before + all-gathering. 
However, this may introduce false dependencies if + there is unrelated computation after the optimizer step. This API + allows the user to provide their own event to wait on. After the root + waits on the event, the event is discarded, so this API should be + called with a new event each iteration. + + Args: + event (torch.cuda.Event): Event recorded after the optimizer step + to wait all-gather streams on. + """ + self._get_fsdp_state()._state_ctx.post_optim_event = event + def _get_fsdp_state(self) -> FSDPState: if (state := _get_module_fsdp_state(cast(nn.Module, self))) is None: raise AssertionError(f"No FSDP state found on {self}")
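
For context, a minimal usage sketch of the new FSDPModule.set_post_optim_event API follows, mirroring the optimizer post-step-hook pattern exercised by test_post_optim_event above. The toy Sequential model, its sizes, the learning rate, and the assumption that the process group / device mesh and CUDA are already initialized are illustrative placeholders, not part of this PR.

# Minimal sketch: record a post-optimizer-step event and hand it to the root
# FSDP module so the next iteration's all-gather streams wait only on that
# event. Assumes torch.distributed is already initialized (e.g. inside an
# FSDPTest-style multi-GPU harness); the model and sizes are placeholders.
import functools

import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import FSDPModule, fully_shard

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)).cuda()
for layer in model:
    fully_shard(layer)
fully_shard(model)  # the root module now implements FSDPModule
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)


def step_post_hook(
    fsdp_module: FSDPModule, opt: torch.optim.Optimizer, args, kwargs
) -> None:
    # Record an event on the current stream right after optimizer.step() and
    # pass it to the root FSDP module. In the next root pre-forward, the
    # all-gather streams wait on this event instead of on the current stream.
    post_optim_event = torch.cuda.current_stream().record_event()
    fsdp_module.set_post_optim_event(post_optim_event)


optim.register_step_post_hook(functools.partial(step_post_hook, model))

inp = torch.randn((8, 16), device="cuda")
for _ in range(10):
    optim.zero_grad()
    model(inp).sum().backward()
    optim.step()  # the hook records and sets a fresh event each iteration

Because the root discards the event after waiting on it, the hook must set a new event every step; with the event in place, unrelated work queued on the current stream after optim.step() no longer delays the implicitly prefetched all-gathers.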