[FSDP2] Fixed 2D clip grad norm test #126497

Closed · wants to merge 5 commits
Changes from 4 commits
1 change: 1 addition & 0 deletions .ci/pytorch/test.sh
@@ -326,6 +326,7 @@ test_inductor_distributed() {
   python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_frozen.py --verbose
   python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_mixed_precision.py -k test_compute_dtype --verbose
   python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_mixed_precision.py -k test_reduce_dtype --verbose
+  python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_clip_grad_norm.py -k test_clip_grad_norm_2d --verbose
   python test/run_test.py -i distributed/fsdp/test_fsdp_tp_integration.py -k test_fsdp_tp_integration --verbose
 
   # this runs on both single-gpu and multi-gpu instance. It should be smart about skipping tests that aren't supported
test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm.py
@@ -12,7 +12,7 @@
 from torch.distributed._tensor.debug import CommDebugMode
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
-from torch.testing._internal.common_fsdp import FSDPTest
+from torch.testing._internal.common_fsdp import FSDPTest, MLPStack
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     ModelArgs,
@@ -30,12 +30,12 @@ def _test_clip_grad_norm(
         ref_optim: torch.optim.Optimizer,
         model: nn.Module,
         optim: torch.optim.Optimizer,
+        inp: torch.Tensor,
         dp_mesh: Optional[DeviceMesh] = None,
     ):
         vector_norm_fn = functools.partial(torch.linalg.vector_norm, ord=norm_type)
         dp_mesh = dp_mesh or init_device_mesh("cuda", (self.world_size,))
         torch.manual_seed(42 + dp_mesh.get_local_rank() + 1)
-        inp = torch.randint(0, model.model_args.vocab_size, (3, 16), device="cuda")
         for iter_idx in range(10):
             ref_optim.zero_grad()
             ref_model(inp).sum().backward()
@@ -53,11 +53,11 @@
                     continue
                 self.assertEqual(ref_grad, param.grad.full_tensor())
 
-            # Check that all gradients have norm greater than the max norm
-            # before clipping to ensure the clipping is not vacuous
-            self.assertTrue(all(vector_norm_fn(g).item() > max_norm for g in ref_grads))
+            # Check that at least one gradient has norm greater than the max
+            # norm before clipping to ensure the clipping is not vacuous
+            self.assertTrue(any(vector_norm_fn(g).item() > max_norm for g in ref_grads))
             self.assertTrue(
-                all(vector_norm_fn(g).item() > max_norm for g in local_grads)
+                any(vector_norm_fn(g).item() > max_norm for g in local_grads)
             )
 
             # Check gradient norm clipping via total norm and individual
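The assertion relaxation above (all → any) is the substantive change in this hunk: the check stays non-vacuous as long as at least one per-parameter gradient norm already exceeds max_norm, without requiring every parameter to exceed it. A minimal single-process sketch, not part of the PR and using only stock PyTorch with a made-up toy model, of the same pre-clip check followed by the clip itself:

import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy stand-in for the test model; any module with a few parameters works here.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
model(torch.randn(8, 16)).sum().backward()

max_norm, norm_type = 0.5, 2.0
pre_clip = [torch.linalg.vector_norm(p.grad, ord=norm_type) for p in model.parameters()]
# Non-vacuity check mirroring the test: at least one per-parameter gradient
# norm must already exceed max_norm, otherwise clipping would be a no-op.
assert any(n.item() > max_norm for n in pre_clip)

total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm, norm_type=norm_type)
post_clip = torch.linalg.vector_norm(
    torch.stack([torch.linalg.vector_norm(p.grad, ord=norm_type) for p in model.parameters()]),
    ord=norm_type,
)
print(f"total grad norm {total_norm:.3f} -> {post_clip:.3f} (max_norm={max_norm})")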
@@ -111,7 +111,10 @@ def test_clip_grad_norm_1d(self):
                     fully_shard(module)
             fully_shard(model)
             optim = torch.optim.Adam(model.parameters(), lr=1e-2)
-            self._test_clip_grad_norm(1, norm_type, ref_model, ref_optim, model, optim)
+            inp = torch.randint(0, model.model_args.vocab_size, (3, 16), device="cuda")
+            self._test_clip_grad_norm(
+                1, norm_type, ref_model, ref_optim, model, optim, inp
+            )
 
 
 class TestClipGradNormWorldSize4(_TestClipGradNormBase):
@@ -130,20 +133,23 @@ def test_clip_grad_norm_2d(self):
             )
             dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
             torch.manual_seed(42)
-            model_args = ModelArgs(dropout_p=0.0)
-            model = Transformer(model_args)
+            # Test using an MLP stack, not a transformer, since the transformer
+            # has some more significant numeric differences from the TP
+            model = MLPStack(16, with_seq_parallel=True)
             ref_model = replicate(
                 copy.deepcopy(model).cuda(), process_group=dp_mesh.get_group()
             )
             ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
-            model = Transformer.parallelize(model, tp_mesh, use_seq_parallel=True)
-            for module in model.modules():
-                if isinstance(module, TransformerBlock):
-                    fully_shard(module, mesh=dp_mesh)
-            fully_shard(model, mesh=dp_mesh)
+            model.parallelize(
+                tp_mesh,
+                dp_mesh,
+                use_activation_checkpointing=False,
+                reshard_after_forward=True,
+            )
             optim = torch.optim.Adam(model.parameters(), lr=1e-2)
+            inp = torch.randn(2, 16, device="cuda")
             self._test_clip_grad_norm(
-                1, norm_type, ref_model, ref_optim, model, optim, dp_mesh
+                0.5, norm_type, ref_model, ref_optim, model, optim, inp, dp_mesh
             )
 
 
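The global_mesh["dp"] / global_mesh["tp"] slicing above comes from a 2D device mesh built a few lines earlier in the test, outside this hunk. A hedged sketch of that setup, assuming a 4-GPU torchrun launch and the 2x2 dp-by-tp factorization implied by TestClipGradNormWorldSize4 (the exact shape and dim names are assumptions, not shown in the diff):

# Run under: torchrun --nproc_per_node=4 sketch.py  (assumes 4 CUDA devices)
import torch
from torch.distributed.device_mesh import init_device_mesh

# 4 ranks factored into a 2-way data-parallel dim and a 2-way tensor-parallel dim.
global_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]

# fully_shard (FSDP2) shards parameters over dp_mesh, while the tensor-parallel
# plan in MLPStack.parallelize splits the individual linears over tp_mesh.
print(f"dp mesh: {dp_mesh}, tp mesh: {tp_mesh}")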
61 changes: 39 additions & 22 deletions torch/testing/_internal/common_fsdp.py
@@ -10,7 +10,17 @@
 from copy import deepcopy
 from enum import auto, Enum
 from functools import partial, wraps
-from typing import Any, Callable, Dict, no_type_check, Optional, Tuple, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    no_type_check,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)
 from unittest import mock
 
 import torch
@@ -39,6 +49,7 @@
     ColwiseParallel,
     parallelize_module,
     RowwiseParallel,
+    SequenceParallel,
 )
 from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
 from torch.nn.parallel.distributed import DistributedDataParallel as DDP
@@ -865,15 +876,17 @@ def reset_parameters(self):
 
 
 class MLPStack(nn.Sequential):
-    def __init__(self, mlp_dim: int):
-        modules = [
-            nn.LayerNorm(mlp_dim, bias=False),
+    def __init__(self, mlp_dim: int, *, with_seq_parallel: bool = False):
+        modules: List[nn.Module] = [
             # Use multiplier of 3 to exercise uneven case
             MLP(mlp_dim, dim_multiplier=3),
             MLP(mlp_dim),
             MLP(mlp_dim, dim_multiplier=3),
         ]
+        if with_seq_parallel:
+            modules.append(nn.LayerNorm(mlp_dim, bias=False))
         super().__init__(*modules)
+        self.with_seq_parallel = with_seq_parallel
 
     def parallelize(
         self,
@@ -882,25 +895,29 @@ def parallelize(
         use_activation_checkpointing: bool,
         reshard_after_forward: bool,
     ) -> "MLPStack":
-        parallelize_module(
-            self,
-            device_mesh=tp_mesh,
-            # Leave the layer norm as implicitly replicated
-            parallelize_plan={
-                # Pass `use_local_output=False` to keep as DTensor to preserve
-                # uneven activation dims
-                "1.in_proj": ColwiseParallel(use_local_output=False),
-                "1.out_proj": RowwiseParallel(use_local_output=False),
-                "2.in_proj": ColwiseParallel(use_local_output=False),
-                "2.out_proj": RowwiseParallel(use_local_output=False),
-                "3.in_proj": ColwiseParallel(use_local_output=False),
-                "3.out_proj": RowwiseParallel(),
-            },
-        )
-        for mlp in self:
+        parallelize_plan = {
+            # Pass `use_local_output=False` to keep as DTensor to preserve
+            # uneven activation dims
+            "0.in_proj": ColwiseParallel(use_local_output=False),
+            "0.out_proj": RowwiseParallel(use_local_output=False),
+            "1.in_proj": ColwiseParallel(use_local_output=False),
+            "1.out_proj": RowwiseParallel(use_local_output=False),
+            "2.in_proj": ColwiseParallel(use_local_output=False),
+            "2.out_proj": RowwiseParallel(output_layouts=Shard(1))
+            if self.with_seq_parallel
+            else RowwiseParallel(),
+        }
+        if self.with_seq_parallel:
+            parallelize_plan["3"] = SequenceParallel(sequence_dim=1)
+        parallelize_module(self, device_mesh=tp_mesh, parallelize_plan=parallelize_plan)
+        for module in self:
+            if isinstance(module, nn.LayerNorm):
+                continue
             if use_activation_checkpointing:
-                checkpoint(mlp)
-            fully_shard(mlp, mesh=dp_mesh, reshard_after_forward=reshard_after_forward)
+                checkpoint(module)
+            fully_shard(
+                module, mesh=dp_mesh, reshard_after_forward=reshard_after_forward
+            )
         fully_shard(self, mesh=dp_mesh, reshard_after_forward=reshard_after_forward)
         return self
 
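A hedged usage sketch tying the new MLPStack options to the updated 2D test above; it assumes a 4-GPU torchrun launch, the 2x2 dp/tp mesh from the previous sketch, and that the returned total norm is a DTensor to be materialized with .full_tensor() (common FSDP2 usage, not code from this diff):

# Run under: torchrun --nproc_per_node=4 sketch.py  (assumes 4 CUDA devices)
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.testing._internal.common_fsdp import MLPStack

global_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]

model = MLPStack(16, with_seq_parallel=True)
# Apply tensor parallelism over tp_mesh and FSDP (fully_shard) over dp_mesh via
# the helper updated in this file, exactly as the 2D test does.
model.parallelize(
    tp_mesh,
    dp_mesh,
    use_activation_checkpointing=False,
    reshard_after_forward=True,
)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)

# Same input within each TP group, different across DP ranks (mirrors the
# test's seeding scheme).
torch.manual_seed(42 + dp_mesh.get_local_rank() + 1)
model(torch.randn(2, 16, device="cuda")).sum().backward()
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
if hasattr(total_norm, "full_tensor"):
    # With DTensor gradients the total norm comes back as a DTensor (assumption).
    total_norm = total_norm.full_tensor()
print(f"clipped total grad norm: {total_norm}")
optim.step()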