[FSDP2] Added APIs for explicit fwd/bwd prefetching (#128884)
This PR adds two APIs, `set_modules_to_forward_prefetch` and `set_modules_to_backward_prefetch`, to enable explicit forward and backward all-gather prefetching, respectively.

```
def set_modules_to_forward_prefetch(self, modules: List[FSDPModule]) -> None
def set_modules_to_backward_prefetch(self, modules: List[FSDPModule]) -> None
```
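
For example, a minimal usage sketch mirroring the `test_explicit_prefetching` test added in this PR (it assumes an existing `Transformer` instance `model` whose blocks live in `model.layers`, and the FSDP2 `fully_shard` import path as of this commit):

```
import itertools

from torch.distributed._composable.fsdp import fully_shard

# FSDP2 per-module wrapping: each transformer block, then the root model
for layer in itertools.chain(model.layers, [model]):
    fully_shard(layer)

num_to_forward_prefetch = num_to_backward_prefetch = 2
for i, layer in enumerate(model.layers):
    if i >= len(model.layers) - num_to_forward_prefetch:
        break
    # `layers.i` explicitly prefetches the all-gathers of `layers.i+1` and `layers.i+2`
    layer.set_modules_to_forward_prefetch(
        [model.layers[i + j] for j in range(1, num_to_forward_prefetch + 1)]
    )
for i, layer in enumerate(model.layers):
    if i < num_to_backward_prefetch:
        continue
    # `layers.i` explicitly prefetches the all-gathers of `layers.i-1` and `layers.i-2`
    layer.set_modules_to_backward_prefetch(
        [model.layers[i - j] for j in range(1, num_to_backward_prefetch + 1)]
    )
```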

**Motivation**
FSDP2 implements _reasonable defaults_ for forward and backward prefetching. In forward, it uses implicit prefetching and allows two all-gather output tensors to be alive at once (so that the current all-gather copy-out can overlap with the next all-gather). In backward, it uses explicit prefetching based on the reverse post-forward order.

However, there may be cases where, with expert knowledge, we can reduce communication bubbles by scheduling all-gathers manually. One way to enable such control is to expose _prefetching limits_, i.e. integers that configure how many outstanding all-gathers/all-gather output tensors can be alive at once. In my opinion, this leans toward _easy_, not _simple_ (see [PyTorch design principles](https://pytorch.org/docs/stable/community/design.html#principle-2-simple-over-easy)).

The crux of the problem is that there may be special cases where manual intervention gives better performance. Exposing a prefetching limit and allowing users to pass a value >1 only smooths over the problem, since such a limit would generally apply to the entire model even though it possibly should not. An expert user may then find a specific all-gather that should deviate from that limit, and there is little we can do to accommodate it.

Thus, we instead choose to expose the most primitive extension point: every `FSDPModule` can explicitly prefetch other modules' all-gathers in forward and in backward. How to leverage this extension point is fully up to the user. In particular, a prefetching limit can be implemented on top of it: record the post-forward order yourself using forward hooks, iterate over that order, and call the `set_modules_to_forward_prefetch` / `set_modules_to_backward_prefetch` APIs accordingly (see the sketch below).
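
For instance, here is a hypothetical sketch of such a user-side prefetch limit built on these APIs (`fsdp_modules`, `model`, `inp`, and `limit` are illustrative assumptions, not part of the API):

```
from typing import List

import torch.nn as nn

# Record the post-forward order once via forward hooks, then configure
# explicit prefetching so that up to `limit` all-gathers are issued ahead.
post_forward_order: List[nn.Module] = []

def record_post_forward(module: nn.Module, args, output) -> None:
    post_forward_order.append(module)

# `fsdp_modules` is assumed to be the modules previously passed to `fully_shard`
handles = [m.register_forward_hook(record_post_forward) for m in fsdp_modules]
model(inp)  # one recording forward pass
for handle in handles:
    handle.remove()

limit = 2
for i, module in enumerate(post_forward_order):
    # Forward: prefetch the next `limit` modules in post-forward order
    module.set_modules_to_forward_prefetch(post_forward_order[i + 1 : i + 1 + limit])
    # Backward: prefetch the previous `limit` modules, i.e. reverse post-forward order
    module.set_modules_to_backward_prefetch(post_forward_order[max(0, i - limit) : i][::-1])
```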

Differential Revision: [D58700346](https://our.internmc.facebook.com/intern/diff/D58700346)
Pull Request resolved: #128884
Approved by: https://github.com/ckluk2, https://github.com/weifengpy
awgu authored and pytorchmergebot committed Jun 18, 2024
1 parent 3dd5f0e commit f2805a0
Showing 6 changed files with 334 additions and 12 deletions.
205 changes: 204 additions & 1 deletion test/distributed/_composable/fsdp/test_fully_shard_comm.py
@@ -43,6 +43,7 @@
FSDPTestMultiThread,
MLP,
patch_post_backward,
patch_reshard,
patch_unshard,
)
from torch.testing._internal.common_utils import run_tests
@@ -372,7 +373,7 @@ def test_manual_reshard_with_reshard_after_forward_false(self):
)


class TestFullyShardBackwardPrefetch(FSDPTest):
class TestFullyShardPrefetch(FSDPTest):
@property
def world_size(self) -> int:
return min(4, torch.cuda.device_count())
@@ -578,6 +579,193 @@ def _test_backward_prefetch_unused_in_backward(
self.assertEqual(events, expected_events)
events.clear()

@skip_if_lt_x_gpu(2)
def test_set_modules_to_forward_prefetch(self):
n_layers = 4
reshard_after_forward = True
checkpoint_impl = "utils"
model, _, inp = self._init_transformer(
n_layers, reshard_after_forward, checkpoint_impl
)

def set_forward_prefetch(model: Transformer, num_to_prefetch: int) -> None:
# Use model-specific knowledge to configure forward prefetching:
# each transformer block (layer) prefetches for the next few
for i, layer in enumerate(model.layers):
if i >= len(model.layers) - num_to_prefetch:
break
layers_to_prefetch = [
model.layers[i + j] for j in range(1, num_to_prefetch + 1)
]
layer.set_modules_to_forward_prefetch(layers_to_prefetch)

events: List[EventType] = []
unshard_with_record = self._get_unshard_with_record(
FSDPParamGroup.unshard, events
)
reshard_with_record = self._get_reshard_with_record(
FSDPParamGroup.reshard, events
)
post_backward_with_record = self._get_post_backward_with_record(
FSDPParamGroup.post_backward, events
)
expected_backward_events = [
# Default backward prefetching
("unshard", "layers.3", TrainingState.PRE_BACKWARD),
("unshard", "layers.2", TrainingState.PRE_BACKWARD),
("reshard", "layers.3", TrainingState.POST_BACKWARD),
("post_backward", "layers.3", TrainingState.POST_BACKWARD),
("unshard", "layers.1", TrainingState.PRE_BACKWARD),
("reshard", "layers.2", TrainingState.POST_BACKWARD),
("post_backward", "layers.2", TrainingState.POST_BACKWARD),
("unshard", "layers.0", TrainingState.PRE_BACKWARD),
("reshard", "layers.1", TrainingState.POST_BACKWARD),
("post_backward", "layers.1", TrainingState.POST_BACKWARD),
("reshard", "layers.0", TrainingState.POST_BACKWARD),
("post_backward", "layers.0", TrainingState.POST_BACKWARD),
("reshard", "", TrainingState.POST_BACKWARD),
("post_backward", "", TrainingState.POST_BACKWARD),
]
with patch_unshard(unshard_with_record), patch_reshard(
reshard_with_record
), patch_post_backward(post_backward_with_record):
set_forward_prefetch(model, num_to_prefetch=1)
loss = model(inp)
expected_forward_events = [
("unshard", "", TrainingState.FORWARD),
# `layers.i` prefetches `layers.i+1`
("unshard", "layers.0", TrainingState.FORWARD),
("unshard", "layers.1", TrainingState.FORWARD),
("reshard", "layers.0", TrainingState.FORWARD),
("unshard", "layers.2", TrainingState.FORWARD),
("reshard", "layers.1", TrainingState.FORWARD),
("unshard", "layers.3", TrainingState.FORWARD),
("reshard", "layers.2", TrainingState.FORWARD),
("reshard", "layers.3", TrainingState.FORWARD),
]
self.assertEqual(events, expected_forward_events)
events.clear()
loss.sum().backward()
self.assertEqual(events, expected_backward_events)
events.clear()

set_forward_prefetch(model, num_to_prefetch=2)
loss = model(inp)
expected_forward_events = [
("unshard", "", TrainingState.FORWARD),
# `layers.i` prefetches `layers.i+1` and `layers.i+2`
("unshard", "layers.0", TrainingState.FORWARD),
("unshard", "layers.1", TrainingState.FORWARD),
("unshard", "layers.2", TrainingState.FORWARD),
("reshard", "layers.0", TrainingState.FORWARD),
("unshard", "layers.3", TrainingState.FORWARD),
("reshard", "layers.1", TrainingState.FORWARD),
("reshard", "layers.2", TrainingState.FORWARD),
("reshard", "layers.3", TrainingState.FORWARD),
]
self.assertEqual(events, expected_forward_events)
events.clear()
loss.sum().backward()
self.assertEqual(events, expected_backward_events)
events.clear()

@skip_if_lt_x_gpu(2)
def test_set_modules_to_backward_prefetch(self):
n_layers = 4
reshard_after_forward = True
checkpoint_impl = "utils"
model, _, inp = self._init_transformer(
n_layers, reshard_after_forward, checkpoint_impl
)

def set_backward_prefetch(model: Transformer, num_to_prefetch: int) -> None:
# Use model-specific knowledge to configure backward prefetching:
# each transformer block (layer) prefetches for the previous few
for i, layer in enumerate(model.layers):
if i < num_to_prefetch:
continue
layers_to_prefetch = [
model.layers[i - j] for j in range(1, num_to_prefetch + 1)
]
layer.set_modules_to_backward_prefetch(layers_to_prefetch)

events: List[EventType] = []
unshard_with_record = self._get_unshard_with_record(
FSDPParamGroup.unshard, events
)
reshard_with_record = self._get_reshard_with_record(
FSDPParamGroup.reshard, events
)
post_backward_with_record = self._get_post_backward_with_record(
FSDPParamGroup.post_backward, events
)
expected_forward_events = [
# Default forward prefetching
("unshard", "", TrainingState.FORWARD), # root
("unshard", "layers.0", TrainingState.FORWARD),
("reshard", "layers.0", TrainingState.FORWARD),
("unshard", "layers.1", TrainingState.FORWARD),
("reshard", "layers.1", TrainingState.FORWARD),
("unshard", "layers.2", TrainingState.FORWARD),
("reshard", "layers.2", TrainingState.FORWARD),
("unshard", "layers.3", TrainingState.FORWARD),
("reshard", "layers.3", TrainingState.FORWARD),
]
with patch_unshard(unshard_with_record), patch_reshard(
reshard_with_record
), patch_post_backward(post_backward_with_record):
set_backward_prefetch(model, num_to_prefetch=1)
loss = model(inp)
self.assertEqual(events, expected_forward_events)
events.clear()
loss.sum().backward()
expected_backward_events = [
# Root prefetches `layers.3` per default
("unshard", "layers.3", TrainingState.PRE_BACKWARD),
# `layers.i` prefetches for `layers.i-1` (same as default)
("unshard", "layers.2", TrainingState.PRE_BACKWARD),
("reshard", "layers.3", TrainingState.POST_BACKWARD),
("post_backward", "layers.3", TrainingState.POST_BACKWARD),
("unshard", "layers.1", TrainingState.PRE_BACKWARD),
("reshard", "layers.2", TrainingState.POST_BACKWARD),
("post_backward", "layers.2", TrainingState.POST_BACKWARD),
("unshard", "layers.0", TrainingState.PRE_BACKWARD),
("reshard", "layers.1", TrainingState.POST_BACKWARD),
("post_backward", "layers.1", TrainingState.POST_BACKWARD),
("reshard", "layers.0", TrainingState.POST_BACKWARD),
("post_backward", "layers.0", TrainingState.POST_BACKWARD),
("reshard", "", TrainingState.POST_BACKWARD),
("post_backward", "", TrainingState.POST_BACKWARD),
]
self.assertEqual(events, expected_backward_events)
events.clear()

set_backward_prefetch(model, num_to_prefetch=2)
loss = model(inp)
self.assertEqual(events, expected_forward_events)
events.clear()
loss.sum().backward()
expected_backward_events = [
# Root prefetches `layers.3` per default
("unshard", "layers.3", TrainingState.PRE_BACKWARD),
# `layers.i` prefetches for `layers.i-1` and `layers.i-2`
("unshard", "layers.2", TrainingState.PRE_BACKWARD),
("unshard", "layers.1", TrainingState.PRE_BACKWARD),
("reshard", "layers.3", TrainingState.POST_BACKWARD),
("post_backward", "layers.3", TrainingState.POST_BACKWARD),
("unshard", "layers.0", TrainingState.PRE_BACKWARD),
("reshard", "layers.2", TrainingState.POST_BACKWARD),
("post_backward", "layers.2", TrainingState.POST_BACKWARD),
("reshard", "layers.1", TrainingState.POST_BACKWARD),
("post_backward", "layers.1", TrainingState.POST_BACKWARD),
("reshard", "layers.0", TrainingState.POST_BACKWARD),
("post_backward", "layers.0", TrainingState.POST_BACKWARD),
("reshard", "", TrainingState.POST_BACKWARD),
("post_backward", "", TrainingState.POST_BACKWARD),
]
self.assertEqual(events, expected_backward_events)
events.clear()

def _init_transformer(
self,
n_layers: int,
@@ -614,6 +802,21 @@ def unshard_with_record(self, *args, **kwargs):

return unshard_with_record

def _get_reshard_with_record(
self, orig_reshard: Callable, events: List[EventType]
) -> Callable:
def reshard_with_record(self, *args, **kwargs):
nonlocal events
if (
self._training_state == TrainingState.FORWARD
and not self._reshard_after_forward
): # skip no-ops
return
events.append(("reshard", self._module_fqn, self._training_state))
return orig_reshard(self, *args, **kwargs)

return reshard_with_record

def _get_post_backward_with_record(
self, orig_post_backward: Callable, events: List[EventType]
) -> Callable:
40 changes: 39 additions & 1 deletion test/distributed/_composable/fsdp/test_fully_shard_training.py
@@ -3,6 +3,7 @@
import contextlib
import copy
import functools
import itertools
import unittest
from typing import Iterable, List, Tuple, Type, Union

@@ -337,7 +338,6 @@ def _test_train_parity_multi_group(
return
assert device_type in ("cuda", "cpu"), f"{device_type}"
torch.manual_seed(42)
lin_dim = 32
vocab_size = 1024
model_args = ModelArgs(
n_layers=3,
@@ -494,6 +494,44 @@ def forward(self, x):
_optim.step()
self.assertEqual(losses[0], losses[1])

@skip_if_lt_x_gpu(2)
def test_explicit_prefetching(self):
torch.manual_seed(42)
model_args = ModelArgs(n_layers=8, dropout_p=0.0)
model = Transformer(model_args)
ref_model = replicate(copy.deepcopy(model).cuda())
ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
for layer in itertools.chain(model.layers, [model]):
fully_shard(layer)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)

num_to_forward_prefetch = num_to_backward_prefetch = 2
for i, layer in enumerate(model.layers):
if i >= len(model.layers) - num_to_forward_prefetch:
break
layers_to_prefetch = [
model.layers[i + j] for j in range(1, num_to_forward_prefetch + 1)
]
layer.set_modules_to_forward_prefetch(layers_to_prefetch)
for i, layer in enumerate(model.layers):
if i < num_to_backward_prefetch:
continue
layers_to_prefetch = [
model.layers[i - j] for j in range(1, num_to_backward_prefetch + 1)
]
layer.set_modules_to_backward_prefetch(layers_to_prefetch)

torch.manual_seed(42 + self.rank)
inp = torch.randint(0, model_args.vocab_size, (2, 8), device="cuda")
for iter_idx in range(10):
losses: List[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
_optim.zero_grad()
losses.append(_model(inp).sum())
losses[-1].backward()
_optim.step()
self.assertEqual(losses[0], losses[1])


class TestFullyShard1DTrainingCompose(FSDPTest):
@property
29 changes: 21 additions & 8 deletions torch/distributed/_composable/fsdp/_fsdp_param_group.py
@@ -283,14 +283,15 @@ def _record_post_forward(self) -> None:
self.comm_ctx.post_forward_order.append(self)
self._post_forward_indices.append(post_forward_index)

def pre_backward(self, *unused: Any):
def pre_backward(self, default_prefetch: bool, *unused: Any):
if self._training_state == TrainingState.PRE_BACKWARD:
return
with record_function(self._with_fqn("FSDP::pre_backward")):
self._training_state = TrainingState.PRE_BACKWARD
self.unshard() # no-op if prefetched
self.wait_for_unshard()
self._prefetch_unshard()
if default_prefetch:
self._backward_prefetch()

def post_backward(self, *unused: Any):
self._training_state = TrainingState.POST_BACKWARD
@@ -348,7 +349,7 @@ def finalize_backward(self):
fsdp_param.grad_offload_event = None
self._post_forward_indices.clear()

def _prefetch_unshard(self):
def _backward_prefetch(self) -> None:
if self._training_state == TrainingState.PRE_BACKWARD:
if not self._post_forward_indices:
# Can be cleared if running multiple `backward`s
@@ -360,11 +361,23 @@ def _prefetch_unshard(self):
# have mistargeted prefetches if not all modules used in forward
# are used in this backward
target_fsdp_param_group = self.comm_ctx.post_forward_order[target_index]
target_fqn = target_fsdp_param_group._module_fqn
with record_function(
self._with_fqn(f"FSDP::backward_prefetch for {target_fqn}")
), target_fsdp_param_group.use_training_state(TrainingState.PRE_BACKWARD):
target_fsdp_param_group.unshard()
self._prefetch_unshard(target_fsdp_param_group, "backward")

@staticmethod
def _prefetch_unshard(
target_fsdp_param_group: "FSDPParamGroup", pass_type: str
) -> None:
if pass_type == "backward":
training_state = TrainingState.PRE_BACKWARD
elif pass_type == "forward":
training_state = TrainingState.FORWARD
else:
raise ValueError(f"Unknown pass type: {pass_type}")
target_fqn = target_fsdp_param_group._module_fqn
with record_function(
f"FSDP::{pass_type}_prefetch for {target_fqn}"
), target_fsdp_param_group.use_training_state(training_state):
target_fsdp_param_group.unshard()

# Utilities #
def _to_sharded(self):
11 changes: 10 additions & 1 deletion torch/distributed/_composable/fsdp/_fsdp_state.py
@@ -56,6 +56,8 @@ def __init__(self):
self._state_ctx = FSDPStateContext()
self._comm_ctx = FSDPCommContext()
self._training_state: TrainingState = TrainingState.IDLE
self._states_to_forward_prefetch: List[FSDPState] = []
self._states_to_backward_prefetch: List[FSDPState] = []

# Define a separate init since `__init__` is called in the contract
def init(
@@ -171,6 +173,9 @@ def _pre_forward(
args, kwargs = tree_map(cast_fn, args), tree_map(cast_fn, kwargs)
if self._fsdp_param_group:
args, kwargs = self._fsdp_param_group.pre_forward(module, args, kwargs)
for fsdp_state in self._states_to_forward_prefetch:
if (target_param_group := fsdp_state._fsdp_param_group) is not None:
FSDPParamGroup._prefetch_unshard(target_param_group, "forward")
return args, kwargs

@disable_if_config_true
@@ -205,7 +210,11 @@ def _pre_backward(self, grad: torch.Tensor) -> torch.Tensor:
self._training_state = TrainingState.PRE_BACKWARD
self._register_root_post_backward_final_callback()
if self._fsdp_param_group:
self._fsdp_param_group.pre_backward()
default_prefetch = len(self._states_to_backward_prefetch) == 0
self._fsdp_param_group.pre_backward(default_prefetch)
for fsdp_state in self._states_to_backward_prefetch:
if (target_param_group := fsdp_state._fsdp_param_group) is not None:
FSDPParamGroup._prefetch_unshard(target_param_group, "backward")
return grad

def _root_post_backward_final_callback(self) -> None:
(Diffs for the remaining two changed files are not rendered here.)
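
For context, the definitions of the new `FSDPModule.set_modules_to_forward_prefetch` / `set_modules_to_backward_prefetch` methods live in one of the files not rendered above. Based on the `_states_to_forward_prefetch` / `_states_to_backward_prefetch` lists added to `_fsdp_state.py`, they presumably resolve each passed `FSDPModule` to its `FSDPState` and store that list on the caller's state. A rough, hypothetical sketch follows; the `_get_fsdp_state` helper is assumed for illustration, and this is not the actual implementation:

```
from typing import List

class FSDPModule:
    def set_modules_to_forward_prefetch(self, modules: List["FSDPModule"]) -> None:
        # Hypothetical sketch: store the target states that `_pre_forward` reads
        self._get_fsdp_state()._states_to_forward_prefetch = [
            module._get_fsdp_state() for module in modules
        ]

    def set_modules_to_backward_prefetch(self, modules: List["FSDPModule"]) -> None:
        # Hypothetical sketch: store the target states that `_pre_backward` reads
        self._get_fsdp_state()._states_to_backward_prefetch = [
            module._get_fsdp_state() for module in modules
        ]
```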
