Revert "[FSDP2] Fixed 2D clip grad norm test (#126497)"
This reverts commit 3f28906.

Reverted #126497 on behalf of https://github.com/jeanschmidt to check whether it might have introduced inductor cuda 12 issues ([comment](#126497 (comment)))
pytorchmergebot committed May 17, 2024
1 parent 95b2766 commit d782e43
Showing 3 changed files with 37 additions and 61 deletions.
1 change: 0 additions & 1 deletion .ci/pytorch/test.sh
@@ -326,7 +326,6 @@ test_inductor_distributed() {
python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_frozen.py --verbose
python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_mixed_precision.py -k test_compute_dtype --verbose
python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_mixed_precision.py -k test_reduce_dtype --verbose
-python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_clip_grad_norm.py -k test_clip_grad_norm_2d --verbose
python test/run_test.py -i distributed/fsdp/test_fsdp_tp_integration.py -k test_fsdp_tp_integration --verbose

# this runs on both single-gpu and multi-gpu instance. It should be smart about skipping tests that aren't supported
36 changes: 15 additions & 21 deletions test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm.py
@@ -12,7 +12,7 @@
from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
-from torch.testing._internal.common_fsdp import FSDPTest, MLPStack
+from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
ModelArgs,
@@ -30,12 +30,12 @@ def _test_clip_grad_norm(
ref_optim: torch.optim.Optimizer,
model: nn.Module,
optim: torch.optim.Optimizer,
-inp: torch.Tensor,
dp_mesh: Optional[DeviceMesh] = None,
):
vector_norm_fn = functools.partial(torch.linalg.vector_norm, ord=norm_type)
dp_mesh = dp_mesh or init_device_mesh("cuda", (self.world_size,))
torch.manual_seed(42 + dp_mesh.get_local_rank() + 1)
+inp = torch.randint(0, model.model_args.vocab_size, (3, 16), device="cuda")
for iter_idx in range(10):
ref_optim.zero_grad()
ref_model(inp).sum().backward()
@@ -53,11 +53,11 @@ def _test_clip_grad_norm(
continue
self.assertEqual(ref_grad, param.grad.full_tensor())

-# Check that at least one gradient has norm greater than the max
-# norm before clipping to ensure the clipping is not vacuous
-self.assertTrue(any(vector_norm_fn(g).item() > max_norm for g in ref_grads))
+# Check that all gradients have norm greater than the max norm
+# before clipping to ensure the clipping is not vacuous
+self.assertTrue(all(vector_norm_fn(g).item() > max_norm for g in ref_grads))
self.assertTrue(
-any(vector_norm_fn(g).item() > max_norm for g in local_grads)
+all(vector_norm_fn(g).item() > max_norm for g in local_grads)
)

# Check gradient norm clipping via total norm and individual
@@ -111,10 +111,7 @@ def test_clip_grad_norm_1d(self):
fully_shard(module)
fully_shard(model)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
-inp = torch.randint(0, model.model_args.vocab_size, (3, 16), device="cuda")
-self._test_clip_grad_norm(
-1, norm_type, ref_model, ref_optim, model, optim, inp
-)
+self._test_clip_grad_norm(1, norm_type, ref_model, ref_optim, model, optim)


class TestClipGradNormWorldSize4(_TestClipGradNormBase):
@@ -133,23 +130,20 @@ def test_clip_grad_norm_2d(self):
)
dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
torch.manual_seed(42)
-# Test using an MLP stack, not a transformer, since the transformer
-# has some more significant numeric differences from the TP
-model = MLPStack(16, with_seq_parallel=True)
+model_args = ModelArgs(dropout_p=0.0)
+model = Transformer(model_args)
ref_model = replicate(
copy.deepcopy(model).cuda(), process_group=dp_mesh.get_group()
)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
-model.parallelize(
-tp_mesh,
-dp_mesh,
-use_activation_checkpointing=False,
-reshard_after_forward=True,
-)
+model = Transformer.parallelize(model, tp_mesh, use_seq_parallel=True)
+for module in model.modules():
+if isinstance(module, TransformerBlock):
+fully_shard(module, mesh=dp_mesh)
+fully_shard(model, mesh=dp_mesh)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
-inp = torch.randn(2, 16, device="cuda")
self._test_clip_grad_norm(
-0.5, norm_type, ref_model, ref_optim, model, optim, inp, dp_mesh
+1, norm_type, ref_model, ref_optim, model, optim, dp_mesh
)


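For context on the `any`/`all` change in test_fully_shard_clip_grad_norm.py above: the test guards against vacuous clipping by asserting that per-parameter gradient norms already exceed `max_norm` before the clipping step (the fix being reverted here had relaxed "all" to "any"). Below is a minimal single-process sketch of that invariant, using plain `nn.Linear` layers instead of the FSDP2/TP models from the test; the model, shapes, and thresholds are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
model(torch.randn(8, 16)).sum().backward()

max_norm, norm_type = 1e-2, 2  # tiny max_norm so clipping is guaranteed to act
grads = [p.grad for p in model.parameters() if p.grad is not None]

# Non-vacuous check: every per-parameter gradient norm already exceeds
# max_norm before clipping (the reverted fix asserted this only for "any").
assert all(
    torch.linalg.vector_norm(g, ord=norm_type).item() > max_norm for g in grads
)

# clip_grad_norm_ scales gradients in place and returns the pre-clip total norm.
total_norm = torch.nn.utils.clip_grad_norm_(
    model.parameters(), max_norm, norm_type=norm_type
)
assert total_norm.item() > max_norm

# After clipping, the total gradient norm is at most max_norm (up to fp error).
clipped_total = torch.linalg.vector_norm(
    torch.stack([torch.linalg.vector_norm(g, ord=norm_type) for g in grads]),
    ord=norm_type,
)
assert clipped_total.item() <= max_norm * 1.01
```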
61 changes: 22 additions & 39 deletions torch/testing/_internal/common_fsdp.py
@@ -10,17 +10,7 @@
from copy import deepcopy
from enum import auto, Enum
from functools import partial, wraps
-from typing import (
-Any,
-Callable,
-Dict,
-List,
-no_type_check,
-Optional,
-Tuple,
-Type,
-Union,
-)
+from typing import Any, Callable, Dict, no_type_check, Optional, Tuple, Type, Union
from unittest import mock

import torch
@@ -49,7 +39,6 @@
ColwiseParallel,
parallelize_module,
RowwiseParallel,
-SequenceParallel,
)
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
@@ -876,17 +865,15 @@ def reset_parameters(self):


class MLPStack(nn.Sequential):
-def __init__(self, mlp_dim: int, *, with_seq_parallel: bool = False):
-modules: List[nn.Module] = [
+def __init__(self, mlp_dim: int):
+modules = [
nn.LayerNorm(mlp_dim, bias=False),
# Use multiplier of 3 to exercise uneven case
MLP(mlp_dim, dim_multiplier=3),
MLP(mlp_dim),
MLP(mlp_dim, dim_multiplier=3),
]
-if with_seq_parallel:
-modules.append(nn.LayerNorm(mlp_dim, bias=False))
super().__init__(*modules)
-self.with_seq_parallel = with_seq_parallel

def parallelize(
self,
@@ -895,29 +882,25 @@ def parallelize(
use_activation_checkpointing: bool,
reshard_after_forward: bool,
) -> "MLPStack":
-parallelize_plan = {
-# Pass `use_local_output=False` to keep as DTensor to preserve
-# uneven activation dims
-"0.in_proj": ColwiseParallel(use_local_output=False),
-"0.out_proj": RowwiseParallel(use_local_output=False),
-"1.in_proj": ColwiseParallel(use_local_output=False),
-"1.out_proj": RowwiseParallel(use_local_output=False),
-"2.in_proj": ColwiseParallel(use_local_output=False),
-"2.out_proj": RowwiseParallel(output_layouts=Shard(1))
-if self.with_seq_parallel
-else RowwiseParallel(),
-}
-if self.with_seq_parallel:
-parallelize_plan["3"] = SequenceParallel(sequence_dim=1)
-parallelize_module(self, device_mesh=tp_mesh, parallelize_plan=parallelize_plan)
-for module in self:
-if isinstance(module, nn.LayerNorm):
-continue
+parallelize_module(
+self,
+device_mesh=tp_mesh,
+# Leave the layer norm as implicitly replicated
+parallelize_plan={
+# Pass `use_local_output=False` to keep as DTensor to preserve
+# uneven activation dims
+"1.in_proj": ColwiseParallel(use_local_output=False),
+"1.out_proj": RowwiseParallel(use_local_output=False),
+"2.in_proj": ColwiseParallel(use_local_output=False),
+"2.out_proj": RowwiseParallel(use_local_output=False),
+"3.in_proj": ColwiseParallel(use_local_output=False),
+"3.out_proj": RowwiseParallel(),
+},
+)
+for mlp in self:
if use_activation_checkpointing:
-checkpoint(module)
-fully_shard(
-module, mesh=dp_mesh, reshard_after_forward=reshard_after_forward
-)
+checkpoint(mlp)
+fully_shard(mlp, mesh=dp_mesh, reshard_after_forward=reshard_after_forward)
fully_shard(self, mesh=dp_mesh, reshard_after_forward=reshard_after_forward)
return self

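One FSDP2-specific detail the test relies on (unchanged by this revert): after `fully_shard`, parameters and their gradients are DTensors, which is why the hunks above compare `param.grad.full_tensor()` against the replicated reference model, and the total norm returned by `clip_grad_norm_` may itself come back as a DTensor. A heavily simplified, hypothetical single-rank sketch, assuming at least one CUDA GPU and a `torchrun --nproc_per_node=1` launch; the tiny `nn.Sequential` and file name stand in for the test's Transformer/MLPStack.

```python
# Run with: torchrun --nproc_per_node=1 fsdp2_clip_sketch.py
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cuda", (1,))  # trivial 1-rank mesh for illustration
model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)).cuda()
for layer in model:
    fully_shard(layer, mesh=mesh)  # shard each submodule...
fully_shard(model, mesh=mesh)      # ...then the root, as the tests above do

model(torch.randn(8, 16, device="cuda")).sum().backward()
# Sharded parameters and gradients are DTensors under FSDP2.
assert isinstance(next(model.parameters()).grad, DTensor)

total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
if isinstance(total_norm, DTensor):
    # Materialize the full (replicated) value before logging or comparing,
    # analogous to param.grad.full_tensor() in the test above.
    total_norm = total_norm.full_tensor()
print(f"total grad norm before clipping: {total_norm.item():.4f}")
```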
