[FSDP2] Factored out MLPStack to de-dup code (#126070)
Pull Request resolved: #126070
Approved by: https://github.com/wanchaol
ghstack dependencies: #126067
awgu authored and pytorchmergebot committed May 14, 2024
1 parent 48f98bc commit 4ded666
Showing 2 changed files with 56 additions and 59 deletions.
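
In short, this change moves the model setup duplicated across the two 2D-MLP test helpers into a reusable `MLPStack` helper in `common_fsdp.py`. A condensed before/after sketch of the test-side change (details elided; see the full diff below):

    # Before (repeated in _test_train_parity_2d_mlp and _test_2d_mlp_with_nd_mesh):
    # build the stack, spell out the tensor-parallel plan, then apply
    # activation checkpointing and fully_shard per MLP plus once at the top level.
    model = nn.Sequential(
        nn.LayerNorm(mlp_dim, bias=False),
        MLP(mlp_dim, dim_multiplier=3),
        MLP(mlp_dim),
        MLP(mlp_dim, dim_multiplier=3),
    )
    parallelize_module(model, device_mesh=tp_mesh, parallelize_plan={...})
    for mlp in model:
        ...
    fully_shard(model, mesh=dp_mesh, reshard_after_forward=reshard_after_forward)

    # After: construction and the TP/AC/FSDP wiring live behind one helper.
    model = MLPStack(mlp_dim)
    model.parallelize(tp_mesh, dp_mesh, use_activation_checkpointing, reshard_after_forward)
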
66 changes: 7 additions & 59 deletions test/distributed/_composable/fsdp/test_fully_shard_training.py
@@ -30,18 +30,14 @@
     get_optimizer_state_dict,
 )
 from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.tensor.parallel import (
-    ColwiseParallel,
-    parallelize_module,
-    RowwiseParallel,
-)
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
     check_sharded_parity,
     FSDPTest,
     FSDPTestMultiThread,
     MLP,
+    MLPStack,
     patch_all_gather,
     patch_reduce_scatter,
     test_compiled_fsdp,
@@ -915,37 +911,13 @@ def _test_train_parity_2d_mlp(
         dp_pg = dp_mesh.get_group()  # used for `replicate()`

         torch.manual_seed(42)
-        model = nn.Sequential(
-            nn.LayerNorm(mlp_dim, bias=False),
-            # Use multiplier of 3 to exercise uneven case
-            MLP(mlp_dim, dim_multiplier=3),
-            MLP(mlp_dim),
-            MLP(mlp_dim, dim_multiplier=3),
-        )
+        model = MLPStack(mlp_dim)
         ref_model = copy.deepcopy(model).cuda()
         replicate(ref_model, device_ids=[self.rank], process_group=dp_pg)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
-
-        model = parallelize_module(
-            model,
-            device_mesh=tp_mesh,
-            # Leave the layer norm as implicitly replicated
-            parallelize_plan={
-                # Pass `use_local_output=False` to keep as DTensor to preserve
-                # uneven activation dims
-                "1.in_proj": ColwiseParallel(use_local_output=False),
-                "1.out_proj": RowwiseParallel(use_local_output=False),
-                "2.in_proj": ColwiseParallel(use_local_output=False),
-                "2.out_proj": RowwiseParallel(use_local_output=False),
-                "3.in_proj": ColwiseParallel(use_local_output=False),
-                "3.out_proj": RowwiseParallel(),
-            },
+        model.parallelize(
+            tp_mesh, dp_mesh, use_activation_checkpointing, reshard_after_forward
         )
-        for mlp in model:
-            if use_activation_checkpointing:
-                checkpoint(mlp)
-            fully_shard(mlp, mesh=dp_mesh, reshard_after_forward=reshard_after_forward)
-        fully_shard(model, mesh=dp_mesh, reshard_after_forward=reshard_after_forward)
         optim = torch.optim.Adam(model.parameters(), lr=1e-2)

         torch.manual_seed(42 + dp_pg.rank() + 1)
@@ -1108,37 +1080,13 @@ def _test_2d_mlp_with_nd_mesh(
         dp_pg = dp_mesh.get_group()  # used for `replicate()`

         torch.manual_seed(42)
-        model = nn.Sequential(
-            nn.LayerNorm(mlp_dim, bias=False),
-            # Use multiplier of 3 to exercise uneven case
-            MLP(mlp_dim, dim_multiplier=3),
-            MLP(mlp_dim),
-            MLP(mlp_dim, dim_multiplier=3),
-        )
+        model = MLPStack(mlp_dim)
         ref_model = copy.deepcopy(model).cuda()
         replicate(ref_model, device_ids=[self.rank], process_group=dp_pg)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
-
-        model = parallelize_module(
-            model,
-            device_mesh=tp_mesh,
-            # Leave the layer norm as implicitly replicated
-            parallelize_plan={
-                # Pass `use_local_output=False` to keep as DTensor to preserve
-                # uneven activation dims
-                "1.in_proj": ColwiseParallel(use_local_output=False),
-                "1.out_proj": RowwiseParallel(use_local_output=False),
-                "2.in_proj": ColwiseParallel(use_local_output=False),
-                "2.out_proj": RowwiseParallel(use_local_output=False),
-                "3.in_proj": ColwiseParallel(use_local_output=False),
-                "3.out_proj": RowwiseParallel(),
-            },
+        model.parallelize(
+            tp_mesh, dp_mesh, use_activation_checkpointing, reshard_after_forward
         )
-        for mlp in model:
-            if use_activation_checkpointing:
-                checkpoint(mlp)
-            fully_shard(mlp, mesh=dp_mesh, reshard_after_forward=reshard_after_forward)
-        fully_shard(model, mesh=dp_mesh, reshard_after_forward=reshard_after_forward)
         optim = torch.optim.Adam(model.parameters(), lr=1e-2)

         torch.manual_seed(42 + dp_pg.rank() + 1)
49 changes: 49 additions & 0 deletions torch/testing/_internal/common_fsdp.py
@@ -17,11 +17,14 @@
 import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.distributed._composable import checkpoint
+from torch.distributed._composable.fsdp import fully_shard
 from torch.distributed._composable.fsdp._fsdp_param_group import (
     FSDPParamGroup,
     RegisterPostBackwardFunction,
 )
 from torch.distributed._tensor import distribute_tensor, DTensor, Shard
+from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp._common_utils import TrainingState
 from torch.distributed.fsdp._init_utils import NO_RESHARD_AFTER_FORWARD_STRATEGIES
@@ -32,6 +35,11 @@
 )
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from torch.distributed.fsdp.wrap import always_wrap_policy, ModuleWrapPolicy, wrap
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    RowwiseParallel,
+)
 from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
 from torch.nn.parallel.distributed import DistributedDataParallel as DDP
 from torch.testing._internal.common_distributed import (
@@ -856,6 +864,47 @@ def reset_parameters(self):
         torch.nn.init.normal_(self.buffer)


+class MLPStack(nn.Sequential):
+    def __init__(self, mlp_dim: int):
+        modules = [
+            nn.LayerNorm(mlp_dim, bias=False),
+            # Use multiplier of 3 to exercise uneven case
+            MLP(mlp_dim, dim_multiplier=3),
+            MLP(mlp_dim),
+            MLP(mlp_dim, dim_multiplier=3),
+        ]
+        super().__init__(*modules)
+
+    def parallelize(
+        self,
+        tp_mesh: DeviceMesh,
+        dp_mesh: DeviceMesh,
+        use_activation_checkpointing: bool,
+        reshard_after_forward: bool,
+    ) -> "MLPStack":
+        parallelize_module(
+            self,
+            device_mesh=tp_mesh,
+            # Leave the layer norm as implicitly replicated
+            parallelize_plan={
+                # Pass `use_local_output=False` to keep as DTensor to preserve
+                # uneven activation dims
+                "1.in_proj": ColwiseParallel(use_local_output=False),
+                "1.out_proj": RowwiseParallel(use_local_output=False),
+                "2.in_proj": ColwiseParallel(use_local_output=False),
+                "2.out_proj": RowwiseParallel(use_local_output=False),
+                "3.in_proj": ColwiseParallel(use_local_output=False),
+                "3.out_proj": RowwiseParallel(),
+            },
+        )
+        for mlp in self:
+            if use_activation_checkpointing:
+                checkpoint(mlp)
+            fully_shard(mlp, mesh=dp_mesh, reshard_after_forward=reshard_after_forward)
+        fully_shard(self, mesh=dp_mesh, reshard_after_forward=reshard_after_forward)
+        return self
+
+
 class DoubleLinear(nn.Module):
     """
     This can be used for returning multiple outputs from a module
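
For orientation, here is a minimal usage sketch of the new helper. It is not part of this diff: it assumes a 4-GPU `torchrun` launch and mirrors the 2x2 DP-by-TP mesh layout the tests above construct; the mesh creation, sizes, and input shapes are illustrative only.

    # Hypothetical standalone script, e.g. `torchrun --nproc_per_node=4 demo.py`.
    import os

    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.testing._internal.common_fsdp import MLPStack

    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))  # set by torchrun
    mlp_dim = 16
    # 2D mesh: outer dim for FSDP data parallelism, inner dim for tensor parallelism
    mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
    model = MLPStack(mlp_dim).parallelize(
        mesh["tp"],
        mesh["dp"],
        use_activation_checkpointing=False,
        reshard_after_forward=True,
    )
    optim = torch.optim.Adam(model.parameters(), lr=1e-2)

    inp = torch.randn(8, mlp_dim, device="cuda")  # (batch, mlp_dim)
    model(inp).sum().backward()  # TP within "tp", FSDP gradient reduction over "dp"
    optim.step()
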
