[FSDP][9/N] Introduce LambdaWrapPolicy
ghstack-source-id: 3d3413d4209481f1c7248eead8a9ad92115a39a5
Pull Request resolved: #104986
awgu committed Aug 2, 2023
1 parent c6a6fbc commit bd9c35b
Showing 6 changed files with 236 additions and 26 deletions.
28 changes: 26 additions & 2 deletions test/distributed/_composable/fully_shard/test_fully_shard_init.py
@@ -9,11 +9,12 @@
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
from torch.distributed._composable import fully_shard
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import BackwardPrefetch, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _is_fsdp_flattened, clean_tensor_name
from torch.distributed.fsdp.wrap import _Policy, ModuleWrapPolicy
from torch.distributed.fsdp.wrap import _Policy, CustomPolicy, ModuleWrapPolicy
from torch.testing._internal.common_dist_composable import (
CompositeParamModel,
FakeSequential,
NestedSequentialModel,
UnitModule,
)
@@ -43,12 +44,21 @@ def world_size(self) -> int:
@skip_if_lt_x_gpu(2)
def test_policy(self):
"""Tests passing a ``policy`` for pseudo-auto-wrapping."""

def lambda_fn(module: nn.Module):
if isinstance(module, nn.Sequential):
return True
elif isinstance(module, FakeSequential):
return {"backward_prefetch": BackwardPrefetch.BACKWARD_POST}
return False

self.run_subtests(
{
"policy": [
None,
ModuleWrapPolicy({UnitModule}),
ModuleWrapPolicy({nn.Sequential}),
CustomPolicy(lambda_fn),
],
},
self._test_policy,
@@ -140,6 +150,20 @@ def _test_fully_shard_construction(
composable_module_classes.add(type(submodule))
self.assertEqual(local_module_classes, composable_module_classes)

# Check that the composable module has the same FSDP states with the
# same attributes (mainly checking backward prefetch since the lambda
# wrap policy overrides it for `FakeSequential`)
wrapper_states = traversal_utils._get_fsdp_states(fsdp_wrapped_model)
composable_states = traversal_utils._get_fsdp_states(composable_module)
self.assertEqual(len(wrapper_states), len(composable_states))
for wrapper_state, composable_state in zip(wrapper_states, composable_states):
self.assertEqual(
wrapper_state.sharding_strategy, composable_state.sharding_strategy
)
self.assertEqual(
wrapper_state.backward_prefetch, composable_state.backward_prefetch
)

@skip_if_lt_x_gpu(2)
def test_device_id(self):
"""Tests passing a ``device_id``."""
137 changes: 133 additions & 4 deletions test/distributed/fsdp/test_wrap.py
@@ -1,6 +1,7 @@
# Owner(s): ["oncall: distributed"]

import functools
import itertools
import os
import tempfile
import unittest
@@ -15,12 +16,15 @@
BackwardPrefetch,
CPUOffload,
FullyShardedDataParallel as FSDP,
MixedPrecision,
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import (
_or_policy,
_Policy,
_wrap_module_cls_individually,
always_wrap_policy,
CustomPolicy,
enable_wrap,
ModuleWrapPolicy,
size_based_auto_wrap_policy,
@@ -477,6 +481,109 @@ def _test_transformer_wrapping(self, auto_wrap_policy: Union[Callable, _Policy])
else:
self.assertFalse(isinstance(module, FSDP))

@unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
def test_custom_policy(self):
"""
Tests ``CustomPolicy`` with both a lambda function that uses uniform
kwargs (so only returns ``False`` or ``True``) and a lambda function
that uses non-uniform kwargs (so returns a dict to override the root
kwargs).
"""
for use_uniform_kwargs in [False, True]:
self._test_custom_policy(use_uniform_kwargs)

def _test_custom_policy(self, use_uniform_kwargs: bool):
print(f"use_uniform_kwargs={use_uniform_kwargs}")
model = TransformerWithSharedParams.init(
self.process_group,
FSDPInitMode.NO_FSDP,
CUDAInitMode.CUDA_BEFORE,
{},
)

if use_uniform_kwargs:

def lambda_fn(module: nn.Module):
if module is model.bn:
return True
elif isinstance(
module, (TransformerEncoderLayer, TransformerDecoderLayer)
):
return True
return False

else:

def lambda_fn(module: nn.Module):
if module is model.bn:
return {"sharding_strategy": ShardingStrategy.NO_SHARD}
elif isinstance(module, TransformerEncoderLayer):
return True
elif isinstance(module, TransformerDecoderLayer):
return {
"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
"backward_prefetch": BackwardPrefetch.BACKWARD_POST,
}
return False

policy = CustomPolicy(lambda_fn)
# Use a size-2 dummy PG to avoid clamping the sharding strategy to
# `NO_SHARD` as for a size-1 PG
process_group = DummyProcessGroup(rank=0, size=2)
fp16_mp = MixedPrecision(param_dtype=torch.float16)
fp32_mp = MixedPrecision()
model = FSDP(
model,
process_group=process_group,
auto_wrap_policy=policy,
mixed_precision=fp16_mp,
)
encoder_layers = set(model.module.transformer.encoder.layers)
decoder_layers = set(model.module.transformer.decoder.layers)
bn = model.module.bn
bn_strategy = (
ShardingStrategy.FULL_SHARD
if use_uniform_kwargs
else ShardingStrategy.NO_SHARD
)
bn_prefetch = BackwardPrefetch.BACKWARD_PRE
encoder_strategy = root_strategy = ShardingStrategy.FULL_SHARD
encoder_prefetch = root_prefetch = BackwardPrefetch.BACKWARD_PRE
decoder_strategy = (
ShardingStrategy.FULL_SHARD
if use_uniform_kwargs
else ShardingStrategy.SHARD_GRAD_OP
)
decoder_prefetch = (
BackwardPrefetch.BACKWARD_PRE
if use_uniform_kwargs
else BackwardPrefetch.BACKWARD_POST
)
for module in model.modules():
if module is bn:
self.assertTrue(isinstance(module, FSDP))
self.assertEqual(module.sharding_strategy, bn_strategy)
self.assertEqual(module.backward_prefetch, bn_prefetch)
# We currently override batch norm modules to use fp32
self.assertEqual(module.mixed_precision, fp32_mp)
elif module in encoder_layers:
self.assertTrue(isinstance(module, FSDP))
self.assertEqual(module.sharding_strategy, encoder_strategy)
self.assertEqual(module.backward_prefetch, encoder_prefetch)
self.assertEqual(module.mixed_precision, fp16_mp)
elif module in decoder_layers:
self.assertTrue(isinstance(module, FSDP))
self.assertEqual(module.sharding_strategy, decoder_strategy)
self.assertEqual(module.backward_prefetch, decoder_prefetch)
self.assertEqual(module.mixed_precision, fp16_mp)
elif module is model:
self.assertTrue(isinstance(module, FSDP))
self.assertEqual(module.sharding_strategy, root_strategy)
self.assertEqual(module.backward_prefetch, root_prefetch)
self.assertEqual(module.mixed_precision, fp16_mp)
else:
self.assertFalse(isinstance(module, FSDP))

@unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
def test_auto_wrap_api(self):
"""
@@ -705,12 +812,34 @@ def test_frozen_params(self):
Tests that mixing frozen/non-frozen parameters in an FSDP instance
raises for ``use_orig_params=False`` and warns for ``True``.
"""
for use_orig_params in [True, False]:
self._test_frozen_params(use_orig_params)
module_classes = (LoraAttention, LoraMLP, LoraDecoder)
module_wrap_policy = ModuleWrapPolicy(module_classes)

def lambda_fn_uniform(module: nn.Module):
return isinstance(module, module_classes)

def lambda_fn_nonuniform(module: nn.Module):
if isinstance(module, LoraAttention):
return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP}
elif isinstance(module, module_classes):
return True
return False

lambda_wrap_policy_uniform = CustomPolicy(lambda_fn_uniform)
lambda_wrap_policy_nonuniform = CustomPolicy(lambda_fn_nonuniform)

for use_orig_params, policy in itertools.product(
[True, False],
[
module_wrap_policy,
lambda_wrap_policy_uniform,
lambda_wrap_policy_nonuniform,
],
):
self._test_frozen_params(use_orig_params, policy)

def _test_frozen_params(self, use_orig_params: bool):
def _test_frozen_params(self, use_orig_params: bool, policy: _Policy):
model = LoraModel().cuda()
policy = ModuleWrapPolicy({LoraAttention, LoraMLP, LoraDecoder})
msg = "layers.0.attn has both parameters with requires_grad=True and False. "
if use_orig_params:
msg += "We do not recommend wrapping such modules"
6 changes: 3 additions & 3 deletions torch/distributed/_composable/fully_shard.py
@@ -67,7 +67,7 @@ def fully_shard(
_annotate_modules_for_dynamo(module, state._ignored_modules, True)
state = _init_process_group_state(state, process_group, strategy, policy)
if policy is not None:
fsdp_kwargs = {
root_kwargs = {
"process_group": process_group,
"strategy": strategy,
"mixed_precision": mixed_precision,
@@ -80,13 +80,13 @@
"ignored_states": ignored_states,
}
if strategy in HYBRID_SHARDING_STRATEGIES:
fsdp_kwargs["process_group"] = (state.process_group, state._inter_node_pg)
root_kwargs["process_group"] = (state.process_group, state._inter_node_pg)
_auto_wrap(
module,
policy,
state._ignored_modules,
state._ignored_params,
fsdp_kwargs,
root_kwargs,
fully_shard,
)
state = _init_core_state(
22 changes: 10 additions & 12 deletions torch/distributed/fsdp/_wrap_utils.py
@@ -19,7 +19,6 @@
_recursive_wrap,
_run_mixed_precision_override_policy,
_wrap_module_cls_individually,
ModuleWrapPolicy,
)


@@ -28,42 +27,41 @@ def _auto_wrap(
policy: Union[Callable, _Policy],
ignored_modules: Set[nn.Module],
ignored_params: Set[nn.Parameter],
fsdp_kwargs: Dict[str, Any],
fsdp_fn: Callable, # `FullyShardedDataParallel` or `fully_shard`
root_kwargs: Dict[str, Any],
fsdp_fn: Callable, # e.g. `FullyShardedDataParallel` or `fully_shard`
):
"""
Auto wraps modules in ``root_module`` 's tree according to ``policy``
following a post-order traversal.
Precondition: ``fsdp_kwargs`` should contain all FSDP arguments except
Precondition: ``root_kwargs`` should contain all arguments except
``module``. This function accepts the kwargs dict directly since it gets
forwarded into the post-order traversal function.
"""
mixed_precision = fsdp_kwargs["mixed_precision"]
mixed_precision = root_kwargs["mixed_precision"]
is_wrapper = inspect.isclass(fsdp_fn)
# TODO: We may relax this no-nested-wrapping constraint to support manual
# wrapping followed by auto wrapping.
_check_nested_wrapping(root_module)

# TODO: Start migration to refactored auto wrapping with `ModuleWrapPolicy`
if isinstance(policy, ModuleWrapPolicy):
fsdp_kwargs["auto_wrap_policy" if is_wrapper else "policy"] = None
if isinstance(policy, _Policy):
root_kwargs["auto_wrap_policy" if is_wrapper else "policy"] = None
target_module_to_kwargs = policy._run_policy(
root_module, ignored_modules, fsdp_kwargs
root_module, ignored_modules, root_kwargs
)
if mixed_precision is not None:
target_module_to_kwargs = _run_mixed_precision_override_policy(
root_module,
mixed_precision._module_classes_to_ignore,
ignored_modules,
fsdp_kwargs,
root_kwargs,
target_module_to_kwargs,
)
overridden_module_classes = _override_module_mixed_precision(
root_module, mixed_precision._module_classes_to_ignore
)
_warn_on_overridden_mixed_precision(overridden_module_classes)
use_orig_params = fsdp_kwargs.get("use_orig_params", False)
use_orig_params = root_kwargs.get("use_orig_params", False)
_validate_frozen_params(
root_module,
set(target_module_to_kwargs.keys()),
@@ -100,7 +98,7 @@ def _auto_wrap(
)
recursive_wrap_kwargs["auto_wrap_policy"] = policy
_warn_on_overridden_mixed_precision(overridden_module_classes)
_recursive_wrap(**recursive_wrap_kwargs, **fsdp_kwargs) # type: ignore[arg-type]
_recursive_wrap(**recursive_wrap_kwargs, **root_kwargs) # type: ignore[arg-type]


def _check_nested_wrapping(root_module: nn.Module):
6 changes: 3 additions & 3 deletions torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -434,7 +434,7 @@ def __init__(
self, process_group, sharding_strategy, auto_wrap_policy
)
if auto_wrap_policy is not None:
fsdp_kwargs = {
root_kwargs = {
"process_group": process_group,
"sharding_strategy": sharding_strategy,
"cpu_offload": cpu_offload,
@@ -452,14 +452,14 @@
# Share root process groups with children to maintain
# the invariant that all FSDP modules will have the same
# process groups.
fsdp_kwargs["process_group"] = (self.process_group, self._inter_node_pg)
root_kwargs["process_group"] = (self.process_group, self._inter_node_pg)

_auto_wrap(
module,
auto_wrap_policy,
self._ignored_modules,
self._ignored_params,
fsdp_kwargs,
root_kwargs,
FullyShardedDataParallel,
)

