[FSDP2] Added APIs for explicit fwd/bwd prefetching
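This commit adds two APIs on `fully_shard`-wrapped modules, `set_modules_to_forward_prefetch` and `set_modules_to_backward_prefetch`, which let users schedule extra unshards (all-gathers) ahead of time during forward and backward rather than relying only on the default prefetching. A minimal usage sketch, distilled from the tests added below (the transformer `model`, its `model.layers` blocks, and the prefetch depth of 2 are illustrative):

    # Assumes each `layer` in `model.layers` (and the root `model`) is wrapped
    # with `fully_shard`, as in `test_explicit_prefetching` below.
    num_to_forward_prefetch = num_to_backward_prefetch = 2
    # Forward: each layer explicitly prefetches the next two layers.
    for i, layer in enumerate(model.layers):
        if i >= len(model.layers) - num_to_forward_prefetch:
            break
        layers_to_prefetch = [
            model.layers[i + j] for j in range(1, num_to_forward_prefetch + 1)
        ]
        layer.set_modules_to_forward_prefetch(layers_to_prefetch)
    # Backward: each layer explicitly prefetches the previous two layers.
    for i, layer in enumerate(model.layers):
        if i < num_to_backward_prefetch:
            continue
        layers_to_prefetch = [
            model.layers[i - j] for j in range(1, num_to_backward_prefetch + 1)
        ]
        layer.set_modules_to_backward_prefetch(layers_to_prefetch)

As the change to `_pre_backward` below shows, registering explicit backward prefetch targets for a module disables its default one-step backward prefetch, so the listed modules fully define what it prefetches.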
ghstack-source-id: 86e664adafe8e22d99a3809ba6712d54c25dc5e5
Pull Request resolved: #128884
awgu committed Jun 17, 2024
1 parent 24443fe commit 8721000
Showing 6 changed files with 333 additions and 12 deletions.
205 changes: 204 additions & 1 deletion test/distributed/_composable/fsdp/test_fully_shard_comm.py
@@ -43,6 +43,7 @@
    FSDPTestMultiThread,
    MLP,
    patch_post_backward,
    patch_reshard,
    patch_unshard,
)
from torch.testing._internal.common_utils import run_tests
@@ -372,7 +373,7 @@ def test_manual_reshard_with_reshard_after_forward_false(self):
        )


class TestFullyShardBackwardPrefetch(FSDPTest):
class TestFullyShardPrefetch(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(4, torch.cuda.device_count())
@@ -578,6 +579,193 @@ def _test_backward_prefetch_unused_in_backward(
            self.assertEqual(events, expected_events)
            events.clear()

    @skip_if_lt_x_gpu(2)
    def test_set_modules_to_forward_prefetch(self):
        n_layers = 4
        reshard_after_forward = True
        checkpoint_impl = "utils"
        model, _, inp = self._init_transformer(
            n_layers, reshard_after_forward, checkpoint_impl
        )

        def set_forward_prefetch(model: Transformer, num_to_prefetch: int) -> None:
            # Use model-specific knowledge to configure forward prefetching:
            # each transformer block (layer) prefetches for the next few
            for i, layer in enumerate(model.layers):
                if i >= len(model.layers) - num_to_prefetch:
                    break
                layers_to_prefetch = [
                    model.layers[i + j] for j in range(1, num_to_prefetch + 1)
                ]
                layer.set_modules_to_forward_prefetch(layers_to_prefetch)

        events: List[EventType] = []
        unshard_with_record = self._get_unshard_with_record(
            FSDPParamGroup.unshard, events
        )
        reshard_with_record = self._get_reshard_with_record(
            FSDPParamGroup.reshard, events
        )
        post_backward_with_record = self._get_post_backward_with_record(
            FSDPParamGroup.post_backward, events
        )
        expected_backward_events = [
            # Default backward prefetching
            ("unshard", "layers.3", TrainingState.PRE_BACKWARD),
            ("unshard", "layers.2", TrainingState.PRE_BACKWARD),
            ("reshard", "layers.3", TrainingState.POST_BACKWARD),
            ("post_backward", "layers.3", TrainingState.POST_BACKWARD),
            ("unshard", "layers.1", TrainingState.PRE_BACKWARD),
            ("reshard", "layers.2", TrainingState.POST_BACKWARD),
            ("post_backward", "layers.2", TrainingState.POST_BACKWARD),
            ("unshard", "layers.0", TrainingState.PRE_BACKWARD),
            ("reshard", "layers.1", TrainingState.POST_BACKWARD),
            ("post_backward", "layers.1", TrainingState.POST_BACKWARD),
            ("reshard", "layers.0", TrainingState.POST_BACKWARD),
            ("post_backward", "layers.0", TrainingState.POST_BACKWARD),
            ("reshard", "", TrainingState.POST_BACKWARD),
            ("post_backward", "", TrainingState.POST_BACKWARD),
        ]
        with patch_unshard(unshard_with_record), patch_reshard(
            reshard_with_record
        ), patch_post_backward(post_backward_with_record):
            set_forward_prefetch(model, num_to_prefetch=1)
            loss = model(inp)
            expected_forward_events = [
                ("unshard", "", TrainingState.FORWARD),
                # `layers.i` prefetches `layers.i+1`
                ("unshard", "layers.0", TrainingState.FORWARD),
                ("unshard", "layers.1", TrainingState.FORWARD),
                ("reshard", "layers.0", TrainingState.FORWARD),
                ("unshard", "layers.2", TrainingState.FORWARD),
                ("reshard", "layers.1", TrainingState.FORWARD),
                ("unshard", "layers.3", TrainingState.FORWARD),
                ("reshard", "layers.2", TrainingState.FORWARD),
                ("reshard", "layers.3", TrainingState.FORWARD),
            ]
            self.assertEqual(events, expected_forward_events)
            events.clear()
            loss.sum().backward()
            self.assertEqual(events, expected_backward_events)
            events.clear()

            set_forward_prefetch(model, num_to_prefetch=2)
            loss = model(inp)
            expected_forward_events = [
                ("unshard", "", TrainingState.FORWARD),
                # `layers.i` prefetches `layers.i+1` and `layers.i+2`
                ("unshard", "layers.0", TrainingState.FORWARD),
                ("unshard", "layers.1", TrainingState.FORWARD),
                ("unshard", "layers.2", TrainingState.FORWARD),
                ("reshard", "layers.0", TrainingState.FORWARD),
                ("unshard", "layers.3", TrainingState.FORWARD),
                ("reshard", "layers.1", TrainingState.FORWARD),
                ("reshard", "layers.2", TrainingState.FORWARD),
                ("reshard", "layers.3", TrainingState.FORWARD),
            ]
            self.assertEqual(events, expected_forward_events)
            events.clear()
            loss.sum().backward()
            self.assertEqual(events, expected_backward_events)
            events.clear()

    @skip_if_lt_x_gpu(2)
    def test_set_modules_to_backward_prefetch(self):
        n_layers = 4
        reshard_after_forward = True
        checkpoint_impl = "utils"
        model, _, inp = self._init_transformer(
            n_layers, reshard_after_forward, checkpoint_impl
        )

        def set_backward_prefetch(model: Transformer, num_to_prefetch: int) -> None:
            # Use model-specific knowledge to configure backward prefetching:
            # each transformer block (layer) prefetches for the previous few
            for i, layer in enumerate(model.layers):
                if i < num_to_prefetch:
                    continue
                layers_to_prefetch = [
                    model.layers[i - j] for j in range(1, num_to_prefetch + 1)
                ]
                layer.set_modules_to_backward_prefetch(layers_to_prefetch)

        events: List[EventType] = []
        unshard_with_record = self._get_unshard_with_record(
            FSDPParamGroup.unshard, events
        )
        reshard_with_record = self._get_reshard_with_record(
            FSDPParamGroup.reshard, events
        )
        post_backward_with_record = self._get_post_backward_with_record(
            FSDPParamGroup.post_backward, events
        )
        expected_forward_events = [
            # Default forward prefetching
            ("unshard", "", TrainingState.FORWARD),  # root
            ("unshard", "layers.0", TrainingState.FORWARD),
            ("reshard", "layers.0", TrainingState.FORWARD),
            ("unshard", "layers.1", TrainingState.FORWARD),
            ("reshard", "layers.1", TrainingState.FORWARD),
            ("unshard", "layers.2", TrainingState.FORWARD),
            ("reshard", "layers.2", TrainingState.FORWARD),
            ("unshard", "layers.3", TrainingState.FORWARD),
            ("reshard", "layers.3", TrainingState.FORWARD),
        ]
        with patch_unshard(unshard_with_record), patch_reshard(
            reshard_with_record
        ), patch_post_backward(post_backward_with_record):
            set_backward_prefetch(model, num_to_prefetch=1)
            loss = model(inp)
            self.assertEqual(events, expected_forward_events)
            events.clear()
            loss.sum().backward()
            expected_backward_events = [
                # Root prefetches `layers.3` per default
                ("unshard", "layers.3", TrainingState.PRE_BACKWARD),
                # `layers.i` prefetches for `layers.i-1` (same as default)
                ("unshard", "layers.2", TrainingState.PRE_BACKWARD),
                ("reshard", "layers.3", TrainingState.POST_BACKWARD),
                ("post_backward", "layers.3", TrainingState.POST_BACKWARD),
                ("unshard", "layers.1", TrainingState.PRE_BACKWARD),
                ("reshard", "layers.2", TrainingState.POST_BACKWARD),
                ("post_backward", "layers.2", TrainingState.POST_BACKWARD),
                ("unshard", "layers.0", TrainingState.PRE_BACKWARD),
                ("reshard", "layers.1", TrainingState.POST_BACKWARD),
                ("post_backward", "layers.1", TrainingState.POST_BACKWARD),
                ("reshard", "layers.0", TrainingState.POST_BACKWARD),
                ("post_backward", "layers.0", TrainingState.POST_BACKWARD),
                ("reshard", "", TrainingState.POST_BACKWARD),
                ("post_backward", "", TrainingState.POST_BACKWARD),
            ]
            self.assertEqual(events, expected_backward_events)
            events.clear()

            set_backward_prefetch(model, num_to_prefetch=2)
            loss = model(inp)
            self.assertEqual(events, expected_forward_events)
            events.clear()
            loss.sum().backward()
            expected_backward_events = [
                # Root prefetches `layers.3` per default
                ("unshard", "layers.3", TrainingState.PRE_BACKWARD),
                # `layers.i` prefetches for `layers.i-1` and `layers.i-2`
                ("unshard", "layers.2", TrainingState.PRE_BACKWARD),
                ("unshard", "layers.1", TrainingState.PRE_BACKWARD),
                ("reshard", "layers.3", TrainingState.POST_BACKWARD),
                ("post_backward", "layers.3", TrainingState.POST_BACKWARD),
                ("unshard", "layers.0", TrainingState.PRE_BACKWARD),
                ("reshard", "layers.2", TrainingState.POST_BACKWARD),
                ("post_backward", "layers.2", TrainingState.POST_BACKWARD),
                ("reshard", "layers.1", TrainingState.POST_BACKWARD),
                ("post_backward", "layers.1", TrainingState.POST_BACKWARD),
                ("reshard", "layers.0", TrainingState.POST_BACKWARD),
                ("post_backward", "layers.0", TrainingState.POST_BACKWARD),
                ("reshard", "", TrainingState.POST_BACKWARD),
                ("post_backward", "", TrainingState.POST_BACKWARD),
            ]
            self.assertEqual(events, expected_backward_events)
            events.clear()

    def _init_transformer(
        self,
        n_layers: int,
@@ -614,6 +802,21 @@ def unshard_with_record(self, *args, **kwargs):

        return unshard_with_record

    def _get_reshard_with_record(
        self, orig_reshard: Callable, events: List[EventType]
    ) -> Callable:
        def reshard_with_record(self, *args, **kwargs):
            nonlocal events
            if (
                self._training_state == TrainingState.FORWARD
                and not self._reshard_after_forward
            ):  # skip no-ops
                return
            events.append(("reshard", self._module_fqn, self._training_state))
            return orig_reshard(self, *args, **kwargs)

        return reshard_with_record

    def _get_post_backward_with_record(
        self, orig_post_backward: Callable, events: List[EventType]
    ) -> Callable:
40 changes: 39 additions & 1 deletion test/distributed/_composable/fsdp/test_fully_shard_training.py
@@ -3,6 +3,7 @@
import contextlib
import copy
import functools
import itertools
import unittest
from typing import Iterable, List, Tuple, Type, Union

@@ -337,7 +338,6 @@ def _test_train_parity_multi_group(
            return
        assert device_type in ("cuda", "cpu"), f"{device_type}"
        torch.manual_seed(42)
        lin_dim = 32
        vocab_size = 1024
        model_args = ModelArgs(
            n_layers=3,
@@ -494,6 +494,44 @@ def forward(self, x):
                _optim.step()
            self.assertEqual(losses[0], losses[1])

    @skip_if_lt_x_gpu(2)
    def test_explicit_prefetching(self):
        torch.manual_seed(42)
        model_args = ModelArgs(n_layers=8, dropout_p=0.0)
        model = Transformer(model_args)
        ref_model = replicate(copy.deepcopy(model).cuda())
        ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
        for layer in itertools.chain(model.layers, [model]):
            fully_shard(layer)
        optim = torch.optim.AdamW(model.parameters(), lr=1e-2)

        num_to_forward_prefetch = num_to_backward_prefetch = 2
        for i, layer in enumerate(model.layers):
            if i >= len(model.layers) - num_to_forward_prefetch:
                break
            layers_to_prefetch = [
                model.layers[i + j] for j in range(1, num_to_forward_prefetch + 1)
            ]
            layer.set_modules_to_forward_prefetch(layers_to_prefetch)
        for i, layer in enumerate(model.layers):
            if i < num_to_backward_prefetch:
                continue
            layers_to_prefetch = [
                model.layers[i - j] for j in range(1, num_to_backward_prefetch + 1)
            ]
            layer.set_modules_to_backward_prefetch(layers_to_prefetch)

        torch.manual_seed(42 + self.rank)
        inp = torch.randint(0, model_args.vocab_size, (2, 8), device="cuda")
        for iter_idx in range(10):
            losses: List[torch.Tensor] = []
            for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                _optim.zero_grad()
                losses.append(_model(inp).sum())
                losses[-1].backward()
                _optim.step()
            self.assertEqual(losses[0], losses[1])


class TestFullyShard1DTrainingCompose(FSDPTest):
    @property
29 changes: 21 additions & 8 deletions torch/distributed/_composable/fsdp/_fsdp_param_group.py
@@ -283,14 +283,15 @@ def _record_post_forward(self) -> None:
        self.comm_ctx.post_forward_order.append(self)
        self._post_forward_indices.append(post_forward_index)

    def pre_backward(self, *unused: Any):
    def pre_backward(self, default_prefetch: bool, *unused: Any):
        if self._training_state == TrainingState.PRE_BACKWARD:
            return
        with record_function(self._with_fqn("FSDP::pre_backward")):
            self._training_state = TrainingState.PRE_BACKWARD
            self.unshard()  # no-op if prefetched
            self.wait_for_unshard()
            self._prefetch_unshard()
            if default_prefetch:
                self._backward_prefetch()

    def post_backward(self, *unused: Any):
        self._training_state = TrainingState.POST_BACKWARD
@@ -348,7 +349,7 @@ def finalize_backward(self):
                fsdp_param.grad_offload_event = None
        self._post_forward_indices.clear()

    def _prefetch_unshard(self):
    def _backward_prefetch(self) -> None:
        if self._training_state == TrainingState.PRE_BACKWARD:
            if not self._post_forward_indices:
                # Can be cleared if running multiple `backward`s
@@ -360,11 +361,23 @@ def _prefetch_unshard(self):
            # have mistargeted prefetches if not all modules used in forward
            # are used in this backward
            target_fsdp_param_group = self.comm_ctx.post_forward_order[target_index]
            target_fqn = target_fsdp_param_group._module_fqn
            with record_function(
                self._with_fqn(f"FSDP::backward_prefetch for {target_fqn}")
            ), target_fsdp_param_group.use_training_state(TrainingState.PRE_BACKWARD):
                target_fsdp_param_group.unshard()
            self._prefetch_unshard(target_fsdp_param_group, "backward")

    @staticmethod
    def _prefetch_unshard(
        target_fsdp_param_group: "FSDPParamGroup", pass_type: str
    ) -> None:
        if pass_type == "backward":
            training_state = TrainingState.PRE_BACKWARD
        elif pass_type == "forward":
            training_state = TrainingState.FORWARD
        else:
            raise ValueError(f"Unknown pass type: {pass_type}")
        target_fqn = target_fsdp_param_group._module_fqn
        with record_function(
            f"FSDP::{pass_type}_prefetch for {target_fqn}"
        ), target_fsdp_param_group.use_training_state(training_state):
            target_fsdp_param_group.unshard()

    # Utilities #
    def _to_sharded(self):
11 changes: 10 additions & 1 deletion torch/distributed/_composable/fsdp/_fsdp_state.py
@@ -56,6 +56,8 @@ def __init__(self):
        self._state_ctx = FSDPStateContext()
        self._comm_ctx = FSDPCommContext()
        self._training_state: TrainingState = TrainingState.IDLE
        self._states_to_forward_prefetch: List[FSDPState] = []
        self._states_to_backward_prefetch: List[FSDPState] = []

    # Define a separate init since `__init__` is called in the contract
    def init(
@@ -171,6 +173,9 @@ def _pre_forward(
            args, kwargs = tree_map(cast_fn, args), tree_map(cast_fn, kwargs)
        if self._fsdp_param_group:
            args, kwargs = self._fsdp_param_group.pre_forward(module, args, kwargs)
        for fsdp_state in self._states_to_forward_prefetch:
            if (target_param_group := fsdp_state._fsdp_param_group) is not None:
                FSDPParamGroup._prefetch_unshard(target_param_group, "forward")
        return args, kwargs

    @disable_if_config_true
@@ -205,7 +210,11 @@ def _pre_backward(self, grad: torch.Tensor) -> torch.Tensor:
        self._training_state = TrainingState.PRE_BACKWARD
        self._register_root_post_backward_final_callback()
        if self._fsdp_param_group:
            self._fsdp_param_group.pre_backward()
            default_prefetch = len(self._states_to_backward_prefetch) == 0
            self._fsdp_param_group.pre_backward(default_prefetch)
        for fsdp_state in self._states_to_backward_prefetch:
            if (target_param_group := fsdp_state._fsdp_param_group) is not None:
                FSDPParamGroup._prefetch_unshard(target_param_group, "backward")
        return grad

    def _root_post_backward_final_callback(self) -> None:
