[FSDP2] Added APIs for explicit fwd/bwd prefetching #128884

Closed · wants to merge 4 commits
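
This PR adds two methods on `fully_shard`-wrapped modules, `set_modules_to_forward_prefetch` and `set_modules_to_backward_prefetch`, so users can explicitly choose which FSDP modules to unshard (all-gather) ahead of time in forward and backward. The sketch below is adapted from `test_explicit_prefetching` in this PR; the transformer layout and a prefetch depth of 2 are illustrative:

```python
from torch.distributed._composable.fsdp import fully_shard

# Assumed setup: `model` is a transformer-style nn.Module whose blocks live in
# `model.layers`; wrap each block and then the root with fully_shard.
for layer in model.layers:
    fully_shard(layer)
fully_shard(model)

# Each layer prefetches (unshards) the next two layers at the start of its forward.
num_to_forward_prefetch = 2
for i, layer in enumerate(model.layers):
    if i >= len(model.layers) - num_to_forward_prefetch:
        break
    layers_to_prefetch = [
        model.layers[i + j] for j in range(1, num_to_forward_prefetch + 1)
    ]
    layer.set_modules_to_forward_prefetch(layers_to_prefetch)

# Each layer prefetches the previous two layers at the start of its backward,
# replacing the default prefetch-one-module-ahead behavior for that layer.
num_to_backward_prefetch = 2
for i, layer in enumerate(model.layers):
    if i < num_to_backward_prefetch:
        continue
    layers_to_prefetch = [
        model.layers[i - j] for j in range(1, num_to_backward_prefetch + 1)
    ]
    layer.set_modules_to_backward_prefetch(layers_to_prefetch)
```

Setting an explicit backward list disables the default backward prefetch for that layer (the `default_prefetch` flag threaded through `pre_backward` below); the forward list simply issues extra unshards from that layer's pre-forward hook.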
test/distributed/_composable/fsdp/test_fully_shard_comm.py (204 additions, 1 deletion)
@@ -43,6 +43,7 @@
FSDPTestMultiThread,
MLP,
patch_post_backward,
patch_reshard,
patch_unshard,
)
from torch.testing._internal.common_utils import run_tests
@@ -372,7 +373,7 @@ def test_manual_reshard_with_reshard_after_forward_false(self):
)


class TestFullyShardBackwardPrefetch(FSDPTest):
class TestFullyShardPrefetch(FSDPTest):
@property
def world_size(self) -> int:
return min(4, torch.cuda.device_count())
@@ -578,6 +579,193 @@ def _test_backward_prefetch_unused_in_backward(
self.assertEqual(events, expected_events)
events.clear()

@skip_if_lt_x_gpu(2)
def test_set_modules_to_forward_prefetch(self):
n_layers = 4
reshard_after_forward = True
checkpoint_impl = "utils"
model, _, inp = self._init_transformer(
n_layers, reshard_after_forward, checkpoint_impl
)

def set_forward_prefetch(model: Transformer, num_to_prefetch: int) -> None:
# Use model-specific knowledge to configure forward prefetching:
# each transformer block (layer) prefetches for the next few
for i, layer in enumerate(model.layers):
if i >= len(model.layers) - num_to_prefetch:
break
layers_to_prefetch = [
model.layers[i + j] for j in range(1, num_to_prefetch + 1)
]
layer.set_modules_to_forward_prefetch(layers_to_prefetch)

events: List[EventType] = []
unshard_with_record = self._get_unshard_with_record(
FSDPParamGroup.unshard, events
)
reshard_with_record = self._get_reshard_with_record(
FSDPParamGroup.reshard, events
)
post_backward_with_record = self._get_post_backward_with_record(
FSDPParamGroup.post_backward, events
)
expected_backward_events = [
# Default backward prefetching
("unshard", "layers.3", TrainingState.PRE_BACKWARD),
("unshard", "layers.2", TrainingState.PRE_BACKWARD),
("reshard", "layers.3", TrainingState.POST_BACKWARD),
("post_backward", "layers.3", TrainingState.POST_BACKWARD),
("unshard", "layers.1", TrainingState.PRE_BACKWARD),
("reshard", "layers.2", TrainingState.POST_BACKWARD),
("post_backward", "layers.2", TrainingState.POST_BACKWARD),
("unshard", "layers.0", TrainingState.PRE_BACKWARD),
("reshard", "layers.1", TrainingState.POST_BACKWARD),
("post_backward", "layers.1", TrainingState.POST_BACKWARD),
("reshard", "layers.0", TrainingState.POST_BACKWARD),
("post_backward", "layers.0", TrainingState.POST_BACKWARD),
("reshard", "", TrainingState.POST_BACKWARD),
("post_backward", "", TrainingState.POST_BACKWARD),
]
with patch_unshard(unshard_with_record), patch_reshard(
reshard_with_record
), patch_post_backward(post_backward_with_record):
set_forward_prefetch(model, num_to_prefetch=1)
loss = model(inp)
expected_forward_events = [
("unshard", "", TrainingState.FORWARD),
# `layers.i` prefetches `layers.i+1`
("unshard", "layers.0", TrainingState.FORWARD),
("unshard", "layers.1", TrainingState.FORWARD),
("reshard", "layers.0", TrainingState.FORWARD),
("unshard", "layers.2", TrainingState.FORWARD),
("reshard", "layers.1", TrainingState.FORWARD),
("unshard", "layers.3", TrainingState.FORWARD),
("reshard", "layers.2", TrainingState.FORWARD),
("reshard", "layers.3", TrainingState.FORWARD),
]
self.assertEqual(events, expected_forward_events)
events.clear()
loss.sum().backward()
self.assertEqual(events, expected_backward_events)
events.clear()

set_forward_prefetch(model, num_to_prefetch=2)
loss = model(inp)
expected_forward_events = [
("unshard", "", TrainingState.FORWARD),
# `layers.i` prefetches `layers.i+1` and `layers.i+2`
("unshard", "layers.0", TrainingState.FORWARD),
("unshard", "layers.1", TrainingState.FORWARD),
("unshard", "layers.2", TrainingState.FORWARD),
("reshard", "layers.0", TrainingState.FORWARD),
("unshard", "layers.3", TrainingState.FORWARD),
("reshard", "layers.1", TrainingState.FORWARD),
("reshard", "layers.2", TrainingState.FORWARD),
("reshard", "layers.3", TrainingState.FORWARD),
]
self.assertEqual(events, expected_forward_events)
events.clear()
loss.sum().backward()
self.assertEqual(events, expected_backward_events)
events.clear()

@skip_if_lt_x_gpu(2)
def test_set_modules_to_backward_prefetch(self):
n_layers = 4
reshard_after_forward = True
checkpoint_impl = "utils"
model, _, inp = self._init_transformer(
n_layers, reshard_after_forward, checkpoint_impl
)

def set_backward_prefetch(model: Transformer, num_to_prefetch: int) -> None:
# Use model-specific knowledge to configure backward prefetching:
# each transformer block (layer) prefetches for the previous few
for i, layer in enumerate(model.layers):
if i < num_to_prefetch:
continue
layers_to_prefetch = [
model.layers[i - j] for j in range(1, num_to_prefetch + 1)
]
layer.set_modules_to_backward_prefetch(layers_to_prefetch)

events: List[EventType] = []
unshard_with_record = self._get_unshard_with_record(
FSDPParamGroup.unshard, events
)
reshard_with_record = self._get_reshard_with_record(
FSDPParamGroup.reshard, events
)
post_backward_with_record = self._get_post_backward_with_record(
FSDPParamGroup.post_backward, events
)
expected_forward_events = [
# Default forward prefetching
("unshard", "", TrainingState.FORWARD), # root
("unshard", "layers.0", TrainingState.FORWARD),
("reshard", "layers.0", TrainingState.FORWARD),
("unshard", "layers.1", TrainingState.FORWARD),
("reshard", "layers.1", TrainingState.FORWARD),
("unshard", "layers.2", TrainingState.FORWARD),
("reshard", "layers.2", TrainingState.FORWARD),
("unshard", "layers.3", TrainingState.FORWARD),
("reshard", "layers.3", TrainingState.FORWARD),
]
with patch_unshard(unshard_with_record), patch_reshard(
reshard_with_record
), patch_post_backward(post_backward_with_record):
set_backward_prefetch(model, num_to_prefetch=1)
loss = model(inp)
self.assertEqual(events, expected_forward_events)
events.clear()
loss.sum().backward()
expected_backward_events = [
# Root prefetches `layers.3` per default
("unshard", "layers.3", TrainingState.PRE_BACKWARD),
# `layers.i` prefetches for `layers.i-1` (same as default)
("unshard", "layers.2", TrainingState.PRE_BACKWARD),
("reshard", "layers.3", TrainingState.POST_BACKWARD),
("post_backward", "layers.3", TrainingState.POST_BACKWARD),
("unshard", "layers.1", TrainingState.PRE_BACKWARD),
("reshard", "layers.2", TrainingState.POST_BACKWARD),
("post_backward", "layers.2", TrainingState.POST_BACKWARD),
("unshard", "layers.0", TrainingState.PRE_BACKWARD),
("reshard", "layers.1", TrainingState.POST_BACKWARD),
("post_backward", "layers.1", TrainingState.POST_BACKWARD),
("reshard", "layers.0", TrainingState.POST_BACKWARD),
("post_backward", "layers.0", TrainingState.POST_BACKWARD),
("reshard", "", TrainingState.POST_BACKWARD),
("post_backward", "", TrainingState.POST_BACKWARD),
]
self.assertEqual(events, expected_backward_events)
events.clear()

set_backward_prefetch(model, num_to_prefetch=2)
loss = model(inp)
self.assertEqual(events, expected_forward_events)
events.clear()
loss.sum().backward()
expected_backward_events = [
# Root prefetches `layers.3` per default
("unshard", "layers.3", TrainingState.PRE_BACKWARD),
# `layers.i` prefetches for `layers.i-1` and `layers.i-2`
("unshard", "layers.2", TrainingState.PRE_BACKWARD),
("unshard", "layers.1", TrainingState.PRE_BACKWARD),
("reshard", "layers.3", TrainingState.POST_BACKWARD),
("post_backward", "layers.3", TrainingState.POST_BACKWARD),
("unshard", "layers.0", TrainingState.PRE_BACKWARD),
("reshard", "layers.2", TrainingState.POST_BACKWARD),
("post_backward", "layers.2", TrainingState.POST_BACKWARD),
("reshard", "layers.1", TrainingState.POST_BACKWARD),
("post_backward", "layers.1", TrainingState.POST_BACKWARD),
("reshard", "layers.0", TrainingState.POST_BACKWARD),
("post_backward", "layers.0", TrainingState.POST_BACKWARD),
("reshard", "", TrainingState.POST_BACKWARD),
("post_backward", "", TrainingState.POST_BACKWARD),
]
self.assertEqual(events, expected_backward_events)
events.clear()

def _init_transformer(
self,
n_layers: int,
@@ -614,6 +802,21 @@ def unshard_with_record(self, *args, **kwargs):

return unshard_with_record

def _get_reshard_with_record(
self, orig_reshard: Callable, events: List[EventType]
) -> Callable:
def reshard_with_record(self, *args, **kwargs):
nonlocal events
if (
self._training_state == TrainingState.FORWARD
and not self._reshard_after_forward
): # skip no-ops
return
events.append(("reshard", self._module_fqn, self._training_state))
return orig_reshard(self, *args, **kwargs)

return reshard_with_record

def _get_post_backward_with_record(
self, orig_post_backward: Callable, events: List[EventType]
) -> Callable:
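
For reference, `patch_reshard` (newly imported above alongside `patch_unshard` and `patch_post_backward`) lives in `torch.testing._internal.common_fsdp` and is not part of this diff. A minimal sketch of what such a patching context manager looks like, assuming it simply swaps the method on `FSDPParamGroup` for the duration of the test:

```python
import contextlib
from typing import Callable

from torch.distributed._composable.fsdp._fsdp_param_group import FSDPParamGroup


@contextlib.contextmanager
def patch_reshard(new_reshard: Callable):
    # Sketch: temporarily swap FSDPParamGroup.reshard so the test's
    # reshard_with_record wrapper sees every reshard call.
    orig_reshard = FSDPParamGroup.reshard
    FSDPParamGroup.reshard = new_reshard
    try:
        yield
    finally:
        FSDPParamGroup.reshard = orig_reshard
```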
test/distributed/_composable/fsdp/test_fully_shard_training.py (39 additions, 1 deletion)
@@ -3,6 +3,7 @@
import contextlib
import copy
import functools
import itertools
import unittest
from typing import Iterable, List, Tuple, Type, Union

@@ -337,7 +338,6 @@ def _test_train_parity_multi_group(
return
assert device_type in ("cuda", "cpu"), f"{device_type}"
torch.manual_seed(42)
lin_dim = 32
vocab_size = 1024
model_args = ModelArgs(
n_layers=3,
@@ -494,6 +494,44 @@ def forward(self, x):
_optim.step()
self.assertEqual(losses[0], losses[1])

@skip_if_lt_x_gpu(2)
def test_explicit_prefetching(self):
torch.manual_seed(42)
model_args = ModelArgs(n_layers=8, dropout_p=0.0)
model = Transformer(model_args)
ref_model = replicate(copy.deepcopy(model).cuda())
ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
for layer in itertools.chain(model.layers, [model]):
fully_shard(layer)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)

num_to_forward_prefetch = num_to_backward_prefetch = 2
for i, layer in enumerate(model.layers):
if i >= len(model.layers) - num_to_forward_prefetch:
break
layers_to_prefetch = [
model.layers[i + j] for j in range(1, num_to_forward_prefetch + 1)
]
layer.set_modules_to_forward_prefetch(layers_to_prefetch)
for i, layer in enumerate(model.layers):
if i < num_to_backward_prefetch:
continue
layers_to_prefetch = [
model.layers[i - j] for j in range(1, num_to_backward_prefetch + 1)
]
layer.set_modules_to_backward_prefetch(layers_to_prefetch)

torch.manual_seed(42 + self.rank)
inp = torch.randint(0, model_args.vocab_size, (2, 8), device="cuda")
for iter_idx in range(10):
losses: List[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
_optim.zero_grad()
losses.append(_model(inp).sum())
losses[-1].backward()
_optim.step()
self.assertEqual(losses[0], losses[1])


class TestFullyShard1DTrainingCompose(FSDPTest):
@property
torch/distributed/_composable/fsdp/_fsdp_param_group.py (21 additions, 8 deletions)
@@ -283,14 +283,15 @@ def _record_post_forward(self) -> None:
self.comm_ctx.post_forward_order.append(self)
self._post_forward_indices.append(post_forward_index)

def pre_backward(self, *unused: Any):
def pre_backward(self, default_prefetch: bool, *unused: Any):
if self._training_state == TrainingState.PRE_BACKWARD:
return
with record_function(self._with_fqn("FSDP::pre_backward")):
self._training_state = TrainingState.PRE_BACKWARD
self.unshard() # no-op if prefetched
self.wait_for_unshard()
self._prefetch_unshard()
if default_prefetch:
self._backward_prefetch()

def post_backward(self, *unused: Any):
self._training_state = TrainingState.POST_BACKWARD
@@ -348,7 +349,7 @@ def finalize_backward(self):
fsdp_param.grad_offload_event = None
self._post_forward_indices.clear()

def _prefetch_unshard(self):
def _backward_prefetch(self) -> None:
if self._training_state == TrainingState.PRE_BACKWARD:
if not self._post_forward_indices:
# Can be cleared if running multiple `backward`s
@@ -360,11 +361,23 @@ def _prefetch_unshard(self):
# have mistargeted prefetches if not all modules used in forward
# are used in this backward
target_fsdp_param_group = self.comm_ctx.post_forward_order[target_index]
target_fqn = target_fsdp_param_group._module_fqn
with record_function(
self._with_fqn(f"FSDP::backward_prefetch for {target_fqn}")
), target_fsdp_param_group.use_training_state(TrainingState.PRE_BACKWARD):
target_fsdp_param_group.unshard()
self._prefetch_unshard(target_fsdp_param_group, "backward")

@staticmethod
def _prefetch_unshard(
target_fsdp_param_group: "FSDPParamGroup", pass_type: str
) -> None:
if pass_type == "backward":
training_state = TrainingState.PRE_BACKWARD
elif pass_type == "forward":
training_state = TrainingState.FORWARD
else:
raise ValueError(f"Unknown pass type: {pass_type}")
target_fqn = target_fsdp_param_group._module_fqn
with record_function(
f"FSDP::{pass_type}_prefetch for {target_fqn}"
), target_fsdp_param_group.use_training_state(training_state):
target_fsdp_param_group.unshard()

# Utilities #
def _to_sharded(self):
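
With this refactor, both the default backward prefetch and the new explicit forward/backward prefetching funnel into the shared `_prefetch_unshard` static method, so every issued prefetch is wrapped in a `record_function` range named `FSDP::forward_prefetch for <fqn>` or `FSDP::backward_prefetch for <fqn>`. One way to inspect the resulting schedule, reusing the `model`/`inp` from the usage sketch above (the trace filename is arbitrary):

```python
from torch.profiler import ProfilerActivity, profile

# Profile one iteration; prefetch issue points appear as CPU ranges named
# "FSDP::forward_prefetch for <module fqn>" / "FSDP::backward_prefetch for <module fqn>".
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    loss = model(inp).sum()
    loss.backward()
prof.export_chrome_trace("fsdp_explicit_prefetch_trace.json")
```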
torch/distributed/_composable/fsdp/_fsdp_state.py (10 additions, 1 deletion)
@@ -56,6 +56,8 @@ def __init__(self):
self._state_ctx = FSDPStateContext()
self._comm_ctx = FSDPCommContext()
self._training_state: TrainingState = TrainingState.IDLE
self._states_to_forward_prefetch: List[FSDPState] = []
self._states_to_backward_prefetch: List[FSDPState] = []

# Define a separate init since `__init__` is called in the contract
def init(
@@ -171,6 +173,9 @@ def _pre_forward(
args, kwargs = tree_map(cast_fn, args), tree_map(cast_fn, kwargs)
if self._fsdp_param_group:
args, kwargs = self._fsdp_param_group.pre_forward(module, args, kwargs)
for fsdp_state in self._states_to_forward_prefetch:
if (target_param_group := fsdp_state._fsdp_param_group) is not None:
FSDPParamGroup._prefetch_unshard(target_param_group, "forward")
return args, kwargs

@disable_if_config_true
@@ -205,7 +210,11 @@ def _pre_backward(self, grad: torch.Tensor) -> torch.Tensor:
self._training_state = TrainingState.PRE_BACKWARD
self._register_root_post_backward_final_callback()
if self._fsdp_param_group:
self._fsdp_param_group.pre_backward()
default_prefetch = len(self._states_to_backward_prefetch) == 0
self._fsdp_param_group.pre_backward(default_prefetch)
for fsdp_state in self._states_to_backward_prefetch:
if (target_param_group := fsdp_state._fsdp_param_group) is not None:
FSDPParamGroup._prefetch_unshard(target_param_group, "backward")
return grad

def _root_post_backward_final_callback(self) -> None:
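
The user-facing setters exercised in the tests (`set_modules_to_forward_prefetch` / `set_modules_to_backward_prefetch`) are defined on the FSDP module class in `fully_shard.py`, which this excerpt does not show; conceptually they resolve each target module to its `FSDPState` and fill the `_states_to_forward_prefetch` / `_states_to_backward_prefetch` lists introduced above. A rough sketch under that assumption (`_get_fsdp_state` and the validation helper are placeholders for whatever the real methods use):

```python
from typing import List


def _validate_fsdp_modules(modules: List) -> None:
    # Hypothetical validation helper: every prefetch target must itself be
    # managed by fully_shard.
    for module in modules:
        if not hasattr(module, "_get_fsdp_state"):
            raise ValueError(f"Expected an FSDP-managed module, got {type(module)}")


def set_modules_to_forward_prefetch(self, modules: List) -> None:
    # Sketch: store the target modules' FSDP states so _pre_forward can issue
    # explicit unshards for them (see _states_to_forward_prefetch above).
    _validate_fsdp_modules(modules)
    self._get_fsdp_state()._states_to_forward_prefetch = [
        module._get_fsdp_state() for module in modules
    ]


def set_modules_to_backward_prefetch(self, modules: List) -> None:
    # Sketch: mirror of the forward case; _pre_backward skips the default
    # prefetch when this list is non-empty.
    _validate_fsdp_modules(modules)
    self._get_fsdp_state()._states_to_backward_prefetch = [
        module._get_fsdp_state() for module in modules
    ]
```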