[FSDP2] Added set_post_optim_event (#128975)
This PR adds `set_post_optim_event`, which lets power users provide their own CUDA event, recorded after the optimizer step, for the FSDP root module to have the all-gather streams wait on.
```
def set_post_optim_event(self, event: torch.cuda.Event) -> None:
```
By default, the root has the all-gather streams wait on the current stream (`wait_stream`), which may introduce false dependencies if there is unrelated computation between the optimizer step and the wait. This pattern can appear in recommendation models, for example.

To avoid those false dependencies while preserving the correctness guarantee, this API lets the user supply their own CUDA event for the all-gather streams to wait on.

We include both a correctness test (`test_fully_shard_training.py`) and an overlap test (`test_fully_shard_overlap.py`).

---

One possible way to use the API is to register a post-step hook on the optimizer. For example:
https://github.com/pytorch/pytorch/blob/12e8d1399b979b45d16f0934017f742d01ab2b8d/test/distributed/_composable/fsdp/test_fully_shard_training.py#L546-L552
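For reference, here is a minimal sketch of that pattern, adapted from the test added in this PR; the `model` and `optim` names (a root module already wrapped with `fully_shard` and its optimizer) are assumed for illustration:

```
import functools

import torch


def step_post_hook(fsdp_module, opt, args, kwargs) -> None:
    # Record an event on the current stream right after the optimizer step
    # and hand it to the root FSDP module, so that the all-gather streams
    # wait only on this event rather than on the current stream.
    post_optim_event = torch.cuda.current_stream().record_event()
    fsdp_module.set_post_optim_event(post_optim_event)


# `model` is the root FSDP module and `optim` is its optimizer (illustrative).
optim.register_step_post_hook(functools.partial(step_post_hook, model))
```

Because the root discards the event after waiting on it, the hook records a fresh event on every step.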

Pull Request resolved: #128975
Approved by: https://github.com/sanketpurandare, https://github.com/weifengpy
ghstack dependencies: #128884
awgu authored and pytorchmergebot committed Jun 18, 2024
1 parent d9c294c commit ac5f565
Showing 4 changed files with 142 additions and 14 deletions.
82 changes: 71 additions & 11 deletions test/distributed/_composable/fsdp/test_fully_shard_overlap.py
@@ -1,12 +1,14 @@
# Owner(s): ["oncall: distributed"]

import functools
from typing import Callable

import torch
import torch.distributed as dist
import torch.nn as nn

from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor.experimental import implicit_replication
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
FSDPTest,
@@ -23,15 +25,6 @@ def world_size(self) -> int:

@skip_if_lt_x_gpu(2)
def test_fully_shard_training_overlap(self):
class LinearWithSleep(nn.Module):
def __init__(self, dim: int, sleep_ms: int):
super().__init__()
self.weight = nn.Parameter(torch.randn((dim, dim)))
self.sleep_ms = sleep_ms

def forward(self, x: torch.Tensor) -> torch.Tensor:
return nn.functional.relu(Matmul.apply(x, self.weight, self.sleep_ms))

torch.manual_seed(42)

# Use non-trivial comm. time but still shorter than compute time
@@ -44,7 +37,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
fully_shard(model, reshard_after_forward=True)

orig_all_gather_into_tensor = dist.all_gather_into_tensor
orig_reduce_scatter = dist.reduce_scatter_tensor
orig_reduce_scatter_tensor = dist.reduce_scatter_tensor
comm_stream = torch.cuda.Stream()

def delay_collective():
@@ -61,7 +54,7 @@ def delayed_all_gather(*args, **kwargs):

def delayed_reduce_scatter(*args, **kwargs):
delay_collective()
return orig_reduce_scatter(*args, **kwargs)
return orig_reduce_scatter_tensor(*args, **kwargs)

inp = torch.randn((2, dim), device="cuda")
loss = model(inp).sum() # warmup CUDA and allocator
@@ -92,6 +85,63 @@ def fwd_bwd():
)
self.assertLessEqual(fwd_bwd_time, expected_fwd_time + expected_bwd_time)

@skip_if_lt_x_gpu(2)
def test_fully_shard_post_optim_event_overlap(self):
torch.manual_seed(42)

# Use non-trivial comm. time but still shorter than compute time
dim, compute_sleep_ms, comm_sleep_ms = (4, 25, 10)
# Define the model to have a high-compute linear followed by a
# low-compute linear, where only the low-compute linear uses FSDP
model = nn.Sequential(
LinearWithSleep(dim, compute_sleep_ms), nn.Linear(dim, dim)
).cuda()
fully_shard(model[1], reshard_after_forward=False)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)

orig_all_gather_into_tensor = dist.all_gather_into_tensor

def delayed_all_gather(*args, **kwargs):
torch.cuda._sleep(int(comm_sleep_ms * get_cycles_per_ms()))
return orig_all_gather_into_tensor(*args, **kwargs)

inp = torch.randn((2, dim), device="cuda")

def run_train_steps(num_iters: int, use_post_optim_event: bool):
for _ in range(num_iters):
optim.zero_grad()
with patch_all_gather(delayed_all_gather):
loss = model(inp).sum()
loss.backward()
with implicit_replication():
optim.step()
if use_post_optim_event:
post_optim_event = torch.cuda.current_stream().record_event()
model[1].set_post_optim_event(post_optim_event)

run_train_steps(1, False) # warmup CUDA and allocator
num_iters = 5
baseline_time = self._time_fn(
functools.partial(run_train_steps, num_iters, False)
)
test_time = self._time_fn(functools.partial(run_train_steps, num_iters, True))

buffer_ms = 4 # CPU delays and copies
# Baseline: FSDP all-gather is exposed since the FSDP module waits for
# the current stream and hence the high-compute linear
self.assertLessEqual(
baseline_time,
num_iters * (3 * compute_sleep_ms + comm_sleep_ms + buffer_ms),
)
# Test: FSDP all-gather is overlapped with the high-compute linear
# since the FSDP module only waits for the post-optim event (except on
# the 1st iteration when no event has been recorded)
expected_test_time = (
num_iters * (3 * compute_sleep_ms + buffer_ms) + comm_sleep_ms
)
self.assertLessEqual(test_time, expected_test_time)
self.assertGreater(baseline_time, expected_test_time)

def _time_fn(self, fn: Callable):
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
@@ -123,5 +173,15 @@ def backward(ctx, grad_output: torch.Tensor):
return grad_input, grad_weight, None


class LinearWithSleep(nn.Module):
def __init__(self, dim: int, sleep_ms: int):
super().__init__()
self.weight = nn.Parameter(torch.randn((dim, dim)))
self.sleep_ms = sleep_ms

def forward(self, x: torch.Tensor) -> torch.Tensor:
return nn.functional.relu(Matmul.apply(x, self.weight, self.sleep_ms))


if __name__ == "__main__":
run_tests()
41 changes: 41 additions & 0 deletions test/distributed/_composable/fsdp/test_fully_shard_training.py
@@ -532,6 +532,47 @@ def test_explicit_prefetching(self):
_optim.step()
self.assertEqual(losses[0], losses[1])

@skip_if_lt_x_gpu(2)
def test_post_optim_event(self):
torch.manual_seed(42)
model_args = ModelArgs(dropout_p=0.0)
model = Transformer(model_args)
ref_model = replicate(copy.deepcopy(model).cuda())
ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
for layer in itertools.chain(model.layers, [model]):
fully_shard(layer)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)

def step_post_hook(
fsdp_module: FSDPModule, opt: torch.optim.Optimizer, args, kwargs
) -> None:
post_optim_event = torch.cuda.current_stream().record_event()
fsdp_module.set_post_optim_event(post_optim_event)

optim.register_step_post_hook(functools.partial(step_post_hook, model))

torch.manual_seed(42 + self.rank)
inp = torch.randint(0, model_args.vocab_size, (2, 8), device="cuda")
# Track all losses and check for equality at the end to avoid a CPU
# sync point after each iteration
ref_losses: List[torch.Tensor] = []
losses: List[torch.Tensor] = []
for iter_idx in range(10):
ref_optim.zero_grad()
ref_losses.append(ref_model(inp).sum())
ref_losses[-1].backward()
ref_optim.step()
for iter_idx in range(10):
optim.zero_grad()
losses.append(model(inp).sum())
losses[-1].backward()
optim.step()
# Sleep after the optimizer step to allow CPU to run ahead into the
# next iteration's forward, exercising the post-optim stream sync
torch.cuda._sleep(int(25 * get_cycles_per_ms()))
for ref_loss, loss in zip(ref_losses, losses):
self.assertEqual(ref_loss, loss)


class TestFullyShard1DTrainingCompose(FSDPTest):
@property
14 changes: 11 additions & 3 deletions torch/distributed/_composable/fsdp/_fsdp_state.py
@@ -36,6 +36,9 @@ def __init__(self):
self.post_backward_final_callback_queued: bool = False
# Whether to finalize backward in this backward's final callback
self.is_last_backward: bool = True
# Optional user-provided event recorded after optimizer for the
# all-gather streams to wait on in the root pre-forward
self.post_optim_event: Optional[torch.cuda.Event] = None


def disable_if_config_true(func):
@@ -84,9 +87,14 @@ def _root_pre_forward(
self._state_ctx.iter_forward_root = self
with torch.profiler.record_function("FSDP::root_pre_forward"):
# Wait for optimizer before implicitly prefetched all-gathers
current_stream = torch.cuda.current_stream()
self._comm_ctx.all_gather_copy_in_stream.wait_stream(current_stream)
self._comm_ctx.all_gather_stream.wait_stream(current_stream)
if (event := self._state_ctx.post_optim_event) is not None:
self._comm_ctx.all_gather_copy_in_stream.wait_event(event)
self._comm_ctx.all_gather_stream.wait_event(event)
self._state_ctx.post_optim_event = None
else:
current_stream = torch.cuda.current_stream()
self._comm_ctx.all_gather_copy_in_stream.wait_stream(current_stream)
self._comm_ctx.all_gather_stream.wait_stream(current_stream)
if self._device.type == "cuda":
with torch.profiler.record_function("FSDP::inputs_to_device"):
args_tuple, kwargs_tuple = _to_kwargs(
19 changes: 19 additions & 0 deletions torch/distributed/_composable/fsdp/fully_shard.py
@@ -309,6 +309,25 @@ def set_modules_to_backward_prefetch(self, modules: List["FSDPModule"]) -> None:
module._get_fsdp_state() for module in modules
]

def set_post_optim_event(self, event: torch.cuda.Event) -> None:
"""
Sets a post-optimizer-step event for the root FSDP module to wait the
all-gather streams on.
By default, the root FSDP module waits the all-gather streams on the
current stream to ensure that the optimizer step has finished before
all-gathering. However, this may introduce false dependencies if
there is unrelated computation after the optimizer step. This API
allows the user to provide their own event to wait on. After the root
waits on the event, the event is discarded, so this API should be
called with a new event each iteration.
Args:
event (torch.cuda.Event): Event recorded after the optimizer step
to wait all-gather streams on.
"""
self._get_fsdp_state()._state_ctx.post_optim_event = event

def _get_fsdp_state(self) -> FSDPState:
if (state := _get_module_fsdp_state(cast(nn.Module, self))) is None:
raise AssertionError(f"No FSDP state found on {self}")
