[RFC][FSDP2] Renamed FSDP to FSDPModule (#124955)
This PR renames the `FSDP` class to `FSDPModule`. This is a BC-breaking change. The rationale is that `FSDPModule` is more descriptive: since `fully_shard` is a module-level API (applied to a `module` arg), the class will always correspond to a module.

Also, users commonly import `FullyShardedDataParallel` as `FSDP`, so the rename helps avoid name conflicts in those cases.
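
For context, here is a minimal sketch of what the rename means for user code. The model below is illustrative, and the sketch assumes `torch.distributed` is already initialized with a device mesh, which `fully_shard` requires:

import torch.nn as nn
from torch.distributed._composable.fsdp import FSDPModule, fully_shard

# Illustrative model; `fully_shard` is applied to an `nn.Module` instance.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
fully_shard(model)  # swaps the instance's class to a dynamic subclass, e.g. `FSDPSequential`

# The sharded module is now an instance of `FSDPModule` (previously `FSDP`),
# so existing `isinstance(module, FSDP)` checks must be updated.
assert isinstance(model, FSDPModule)

# FSDP-specific methods are reached through `FSDPModule`, e.g. the explicit
# prefetching used in the tests changed below.
for submodule in model.modules():
    if isinstance(submodule, FSDPModule):
        submodule.unshard(async_op=True)

As the state test below verifies, slicing a sharded `nn.Sequential` (e.g. `model[:2]`) intentionally does not preserve the `FSDPModule` type.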

Pull Request resolved: #124955
Approved by: https://github.com/wanchaol, https://github.com/wconstab
ghstack dependencies: #124651, #124741, #124767, #124768, #124780, #124787
awgu authored and pytorchmergebot committed Apr 29, 2024
1 parent da44d2f commit 935a946
Showing 5 changed files with 19 additions and 19 deletions.
8 changes: 4 additions & 4 deletions test/distributed/_composable/fsdp/test_fully_shard_comm.py
@@ -12,7 +12,7 @@
 
 from torch.distributed._composable import checkpoint, replicate
 from torch.distributed._composable.fsdp import (
-    FSDP,
+    FSDPModule,
     fully_shard,
     MixedPrecisionPolicy,
     OffloadPolicy,
@@ -630,13 +630,13 @@ def __init__(self, dim: int, mesh: DeviceMesh):
 
     def forward(self, x: torch.Tensor):
         y1, work1 = self.reduce_module1(x)
-        if isinstance(self.mlps.mlp1, FSDP):
+        if isinstance(self.mlps.mlp1, FSDPModule):
             self.mlps.mlp1.unshard(async_op=True)
         y2, work2 = self.reduce_module2(x)
-        if isinstance(self.mlps.mlp2, FSDP):
+        if isinstance(self.mlps.mlp2, FSDPModule):
             self.mlps.mlp2.unshard(async_op=True)
         y3, work3 = self.reduce_module3(x)
-        if isinstance(self.mlps.mlp3, FSDP):
+        if isinstance(self.mlps.mlp3, FSDPModule):
             self.mlps.mlp3.unshard(async_op=True)
         return self.mlps([y1, y2, y3], [work1, work2, work3])

10 changes: 5 additions & 5 deletions test/distributed/_composable/fsdp/test_fully_shard_state.py
@@ -4,7 +4,7 @@
 import unittest
 
 import torch.nn as nn
-from torch.distributed._composable.fsdp import FSDP, fully_shard
+from torch.distributed._composable.fsdp import FSDPModule, fully_shard
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_fsdp import FSDPTestMultiThread, MLP
 from torch.testing._internal.common_utils import run_tests
@@ -47,22 +47,22 @@ def test_fully_shard_cls(self):
         model = MLP(8)
         fully_shard(model)
         self.assertTrue(isinstance(model, MLP))
-        self.assertTrue(isinstance(model, FSDP))
+        self.assertTrue(isinstance(model, FSDPModule))
         self.assertEqual(model.__class__.__name__, "FSDPMLP")
         for module in model.modules():
             if module is model:
                 continue
-            self.assertFalse(isinstance(module, FSDP))
+            self.assertFalse(isinstance(module, FSDPModule))
 
         # Check that slicing into a `Sequential` does not preserve FSDP
         model = nn.Sequential(*[MLP(8) for _ in range(3)])
         fully_shard(model)
         self.assertTrue(isinstance(model, nn.Sequential))
-        self.assertTrue(isinstance(model, FSDP))
+        self.assertTrue(isinstance(model, FSDPModule))
         self.assertEqual(model.__class__.__name__, "FSDPSequential")
         sliced_model = model[:2]
         self.assertTrue(isinstance(sliced_model, nn.Sequential))
-        self.assertFalse(isinstance(sliced_model, FSDP))
+        self.assertFalse(isinstance(sliced_model, FSDPModule))
 
     @unittest.skipIf(not TEST_CUDA, "no cuda")
     def test_fully_shard_unsupported_module_cls(self):
6 changes: 3 additions & 3 deletions
@@ -13,7 +13,7 @@
 from torch.distributed._composable import checkpoint, replicate
 from torch.distributed._composable.fsdp import (
     CPUOffloadPolicy,
-    FSDP,
+    FSDPModule,
     fully_shard,
     OffloadPolicy,
 )
@@ -144,7 +144,7 @@ def test_param_registration_after_forward(self):
         self._assert_tensor_params(root_params)
         self._assert_same_params(model.parameters(), ref_model.parameters())
         for module in model.modules():
-            if isinstance(module, FSDP):
+            if isinstance(module, FSDPModule):
                 module.reshard()  # however, we can manually reshard
         self._assert_dtensor_params(model.parameters())
         self._assert_same_params(model.parameters(), ref_model.parameters())
Expand Down Expand Up @@ -854,7 +854,7 @@ def _test_1f1b_microbatching(
# memory usage since we do not reshard after forward
if use_explicit_unshard:
for module in model.modules():
if isinstance(module, FSDP):
if isinstance(module, FSDPModule):
module.unshard(async_op=True)

# Emulate the 1f1b pipeline schedule and only reduce gradients on the
2 changes: 1 addition & 1 deletion torch/distributed/_composable/fsdp/__init__.py
@@ -1,2 +1,2 @@
 from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
-from .fully_shard import FSDP, fully_shard
+from .fully_shard import FSDPModule, fully_shard
12 changes: 6 additions & 6 deletions torch/distributed/_composable/fsdp/fully_shard.py
@@ -136,7 +136,7 @@ def fully_shard(
     # Place FSDP leftmost for highest priority in the method resolution order
     cls = module.__class__
     dct = {"__deepcopy__": unimplemented_deepcopy}
-    new_cls = type(f"FSDP{cls.__name__}", (FSDP, cls), dct)
+    new_cls = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct)
     module.__class__ = new_cls
     return module

@@ -147,14 +147,14 @@ def unimplemented_deepcopy(*args: Any, **kwargs: Any) -> typing_extensions.Never
     )
 
 
-class FSDP:
+class FSDPModule:
     def __new__(cls, *args, **kwargs):
         """
         Override ``__new__`` to remove the FSDP class and directly construct
         the original class for cases like indexing into a container module.
         """
         # Use index 2 since 0 is the dynamically constructed `FSDP<...>` class
-        # and index 1 is the `FSDP` class itself
+        # and index 1 is the `FSDPModule` class itself
         orig_cls = cls.__mro__[2]
         self = orig_cls.__new__(orig_cls, *args, **kwargs)
         self.__init__(*args, **kwargs)
@@ -223,7 +223,7 @@ def set_requires_gradient_sync(
         self_module = cast(nn.Module, self)
         modules = list(self_module.modules()) if recurse else [self_module]
         for module in modules:
-            if isinstance(module, FSDP):
+            if isinstance(module, FSDPModule):
                 state = module._get_fsdp_state()
                 if fsdp_param_group := state._fsdp_param_group:
                     fsdp_param_group.reduce_grads = requires_gradient_sync
@@ -243,7 +243,7 @@ def set_requires_all_reduce(
         self_module = cast(nn.Module, self)
         modules = list(self_module.modules()) if recurse else [self_module]
         for module in modules:
-            if isinstance(module, FSDP):
+            if isinstance(module, FSDPModule):
                 state = module._get_fsdp_state()
                 if fsdp_param_group := state._fsdp_param_group:
                     fsdp_param_group.all_reduce_grads = requires_all_reduce
@@ -265,7 +265,7 @@ def set_reshard_after_backward(
         self_module = cast(nn.Module, self)
         modules = list(self_module.modules()) if recurse else [self_module]
         for module in modules:
-            if isinstance(module, FSDP):
+            if isinstance(module, FSDPModule):
                 state = module._get_fsdp_state()
                 if fsdp_param_group := state._fsdp_param_group:
                     fsdp_param_group.reshard_after_backward = reshard_after_backward
