
[FSDP][9/N] Introduce CustomPolicy #104986

Closed
wants to merge 9 commits into from
28 changes: 26 additions & 2 deletions test/distributed/_composable/fully_shard/test_fully_shard_init.py
@@ -9,11 +9,12 @@
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
from torch.distributed._composable import fully_shard
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import BackwardPrefetch, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _is_fsdp_flattened, clean_tensor_name
from torch.distributed.fsdp.wrap import _Policy, ModuleWrapPolicy
from torch.distributed.fsdp.wrap import _Policy, CustomPolicy, ModuleWrapPolicy
from torch.testing._internal.common_dist_composable import (
CompositeParamModel,
FakeSequential,
NestedSequentialModel,
UnitModule,
)
@@ -43,12 +44,21 @@ def world_size(self) -> int:
@skip_if_lt_x_gpu(2)
def test_policy(self):
"""Tests passing a ``policy`` for pseudo-auto-wrapping."""

def lambda_fn(module: nn.Module):
if isinstance(module, nn.Sequential):
return True
elif isinstance(module, FakeSequential):
return {"backward_prefetch": BackwardPrefetch.BACKWARD_POST}
Member: Should we add a test to ensure that we raise a good error when the lambda returns an invalid key, like {"foo": 1}?

Contributor Author: One option here is to use inspect to get the argument signature of the module-level API passed to the policy. However, the default error message from passing an unexpected arg/kwarg is not unreasonable, so I will not worry about that in this PR:

TypeError: __init__() got an unexpected keyword argument 'foo'
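
For example, a lambda that returns an unsupported key fails at wrap time with exactly that message. The snippet below is an illustrative sketch, not part of this PR's diff; it assumes a default process group is already initialized (e.g. via torchrun) and a GPU is available:

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import CustomPolicy

def bad_lambda_fn(module: nn.Module):
    # "foo" is not an FSDP constructor kwarg, so wrapping the matched
    # submodules raises:
    # TypeError: __init__() got an unexpected keyword argument 'foo'
    if isinstance(module, nn.Linear):
        return {"foo": 1}
    return False

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4)).cuda()
FSDP(model, auto_wrap_policy=CustomPolicy(bad_lambda_fn))  # raises TypeError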

return False

self.run_subtests(
{
"policy": [
None,
ModuleWrapPolicy({UnitModule}),
ModuleWrapPolicy({nn.Sequential}),
CustomPolicy(lambda_fn),
],
},
self._test_policy,
@@ -140,6 +150,20 @@ def _test_fully_shard_construction(
composable_module_classes.add(type(submodule))
self.assertEqual(local_module_classes, composable_module_classes)

# Check that the composable module has the same FSDP states with the
# same attributes (mainly checking backward prefetch since the lambda
# wrap policy overrides it for `FakeSequential`)
wrapper_states = traversal_utils._get_fsdp_states(fsdp_wrapped_model)
composable_states = traversal_utils._get_fsdp_states(composable_module)
self.assertEqual(len(wrapper_states), len(composable_states))
for wrapper_state, composable_state in zip(wrapper_states, composable_states):
self.assertEqual(
wrapper_state.sharding_strategy, composable_state.sharding_strategy
)
self.assertEqual(
wrapper_state.backward_prefetch, composable_state.backward_prefetch
)

@skip_if_lt_x_gpu(2)
def test_device_id(self):
"""Tests passing a ``device_id``."""
137 changes: 133 additions & 4 deletions test/distributed/fsdp/test_wrap.py
@@ -1,6 +1,7 @@
# Owner(s): ["oncall: distributed"]

import functools
import itertools
import os
import tempfile
import unittest
@@ -15,12 +16,15 @@
BackwardPrefetch,
CPUOffload,
FullyShardedDataParallel as FSDP,
MixedPrecision,
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import (
_or_policy,
_Policy,
_wrap_module_cls_individually,
always_wrap_policy,
CustomPolicy,
enable_wrap,
ModuleWrapPolicy,
size_based_auto_wrap_policy,
@@ -477,6 +481,109 @@ def _test_transformer_wrapping(self, auto_wrap_policy: Union[Callable, _Policy])
else:
self.assertFalse(isinstance(module, FSDP))

@unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
def test_custom_policy(self):
"""
Tests ``CustomPolicy`` with both a lambda function that uses uniform
kwargs (so only returns ``False`` or ``True``) and a lambda function
that uses non-uniform kwargs (so returns a dict to override the root
kwargs).
"""
for use_uniform_kwargs in [False, True]:
self._test_custom_policy(use_uniform_kwargs)

def _test_custom_policy(self, use_uniform_kwargs: bool):
print(f"use_uniform_kwargs={use_uniform_kwargs}")
model = TransformerWithSharedParams.init(
self.process_group,
FSDPInitMode.NO_FSDP,
CUDAInitMode.CUDA_BEFORE,
{},
)

if use_uniform_kwargs:

def lambda_fn(module: nn.Module):
Member: Does LambdaPolicy still get overridden to individually wrap BatchNorms if mixed precision is configured?

In general, what happens if my LambdaPolicy says to wrap a BatchNorm but mixed precision overrides this? What should the expected behavior be?

Contributor Author: Currently, our mixed precision overrides run last. I am not sure whether we should special-case the LambdaWrapPolicy to have the highest priority. What do you think?

I was leaning toward keeping our overrides as absolute for now.
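
Concretely, the current ordering means the lambda cannot opt a BatchNorm back into low precision. The snippet below is an illustrative sketch, not from this diff; the module names are made up, and it assumes an initialized process group and a GPU:

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import CustomPolicy

class TinyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.lin = nn.Linear(8, 8)
        self.bn = nn.BatchNorm1d(8)

# The lambda asks to wrap both the Linear and the BatchNorm.
policy = CustomPolicy(lambda m: isinstance(m, (nn.Linear, nn.BatchNorm1d)))
model = FSDP(
    TinyModel().cuda(),
    auto_wrap_policy=policy,
    mixed_precision=MixedPrecision(param_dtype=torch.float16),
)
# The BatchNorm child is still wrapped, but its wrapper ends up with
# MixedPrecision() (i.e. fp32), while the Linear child casts params to fp16.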

if module is model.bn:
return True
elif isinstance(
module, (TransformerEncoderLayer, TransformerDecoderLayer)
):
return True
return False

else:

def lambda_fn(module: nn.Module):
if module is model.bn:
return {"sharding_strategy": ShardingStrategy.NO_SHARD}
elif isinstance(module, TransformerEncoderLayer):
return True
elif isinstance(module, TransformerDecoderLayer):
return {
"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
"backward_prefetch": BackwardPrefetch.BACKWARD_POST,
}
return False

policy = CustomPolicy(lambda_fn)
# Use a size-2 dummy PG to avoid clamping the sharding strategy to
# `NO_SHARD` as for a size-1 PG
process_group = DummyProcessGroup(rank=0, size=2)
fp16_mp = MixedPrecision(param_dtype=torch.float16)
fp32_mp = MixedPrecision()
model = FSDP(
model,
process_group=process_group,
auto_wrap_policy=policy,
mixed_precision=fp16_mp,
)
encoder_layers = set(model.module.transformer.encoder.layers)
decoder_layers = set(model.module.transformer.decoder.layers)
bn = model.module.bn
bn_strategy = (
ShardingStrategy.FULL_SHARD
if use_uniform_kwargs
else ShardingStrategy.NO_SHARD
)
bn_prefetch = BackwardPrefetch.BACKWARD_PRE
encoder_strategy = root_strategy = ShardingStrategy.FULL_SHARD
encoder_prefetch = root_prefetch = BackwardPrefetch.BACKWARD_PRE
decoder_strategy = (
ShardingStrategy.FULL_SHARD
if use_uniform_kwargs
else ShardingStrategy.SHARD_GRAD_OP
)
decoder_prefetch = (
BackwardPrefetch.BACKWARD_PRE
if use_uniform_kwargs
else BackwardPrefetch.BACKWARD_POST
)
for module in model.modules():
if module is bn:
self.assertTrue(isinstance(module, FSDP))
self.assertEqual(module.sharding_strategy, bn_strategy)
self.assertEqual(module.backward_prefetch, bn_prefetch)
# We currently override batch norm modules to use fp32
self.assertEqual(module.mixed_precision, fp32_mp)
elif module in encoder_layers:
self.assertTrue(isinstance(module, FSDP))
self.assertEqual(module.sharding_strategy, encoder_strategy)
self.assertEqual(module.backward_prefetch, encoder_prefetch)
self.assertEqual(module.mixed_precision, fp16_mp)
elif module in decoder_layers:
self.assertTrue(isinstance(module, FSDP))
self.assertEqual(module.sharding_strategy, decoder_strategy)
self.assertEqual(module.backward_prefetch, decoder_prefetch)
self.assertEqual(module.mixed_precision, fp16_mp)
elif module is model:
self.assertTrue(isinstance(module, FSDP))
self.assertEqual(module.sharding_strategy, root_strategy)
self.assertEqual(module.backward_prefetch, root_prefetch)
self.assertEqual(module.mixed_precision, fp16_mp)
else:
self.assertFalse(isinstance(module, FSDP))

@unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
def test_auto_wrap_api(self):
"""
@@ -705,12 +812,34 @@ def test_frozen_params(self):
Tests that mixing frozen/non-frozen parameters in an FSDP instance
raises for ``use_orig_params=False`` and warns for ``True``.
"""
for use_orig_params in [True, False]:
self._test_frozen_params(use_orig_params)
module_classes = (LoraAttention, LoraMLP, LoraDecoder)
module_wrap_policy = ModuleWrapPolicy(module_classes)

def lambda_fn_uniform(module: nn.Module):
return isinstance(module, module_classes)

def lambda_fn_nonuniform(module: nn.Module):
if isinstance(module, LoraAttention):
return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP}
elif isinstance(module, module_classes):
return True
return False

lambda_wrap_policy_uniform = CustomPolicy(lambda_fn_uniform)
lambda_wrap_policy_nonuniform = CustomPolicy(lambda_fn_nonuniform)

for use_orig_params, policy in itertools.product(
[True, False],
[
module_wrap_policy,
lambda_wrap_policy_uniform,
lambda_wrap_policy_nonuniform,
],
):
self._test_frozen_params(use_orig_params, policy)

def _test_frozen_params(self, use_orig_params: bool):
def _test_frozen_params(self, use_orig_params: bool, policy: _Policy):
model = LoraModel().cuda()
policy = ModuleWrapPolicy({LoraAttention, LoraMLP, LoraDecoder})
msg = "layers.0.attn has both parameters with requires_grad=True and False. "
if use_orig_params:
msg += "We do not recommend wrapping such modules"
6 changes: 3 additions & 3 deletions torch/distributed/_composable/fully_shard.py
@@ -67,7 +67,7 @@ def fully_shard(
_annotate_modules_for_dynamo(module, state._ignored_modules, True)
state = _init_process_group_state(state, process_group, strategy, policy)
if policy is not None:
fsdp_kwargs = {
root_kwargs = {
"process_group": process_group,
"strategy": strategy,
"mixed_precision": mixed_precision,
@@ -80,13 +80,13 @@
"ignored_states": ignored_states,
}
if strategy in HYBRID_SHARDING_STRATEGIES:
fsdp_kwargs["process_group"] = (state.process_group, state._inter_node_pg)
root_kwargs["process_group"] = (state.process_group, state._inter_node_pg)
_auto_wrap(
module,
policy,
state._ignored_modules,
state._ignored_params,
fsdp_kwargs,
root_kwargs,
fully_shard,
)
state = _init_core_state(
22 changes: 10 additions & 12 deletions torch/distributed/fsdp/_wrap_utils.py
@@ -19,7 +19,6 @@
_recursive_wrap,
_run_mixed_precision_override_policy,
_wrap_module_cls_individually,
ModuleWrapPolicy,
)


@@ -28,42 +27,41 @@ def _auto_wrap(
policy: Union[Callable, _Policy],
ignored_modules: Set[nn.Module],
ignored_params: Set[nn.Parameter],
fsdp_kwargs: Dict[str, Any],
fsdp_fn: Callable, # `FullyShardedDataParallel` or `fully_shard`
root_kwargs: Dict[str, Any],
fsdp_fn: Callable, # e.g. `FullyShardedDataParallel` or `fully_shard`
):
"""
Auto wraps modules in ``root_module`` 's tree according to ``policy``
following a post-order traversal.

Precondition: ``fsdp_kwargs`` should contain all FSDP arguments except
Precondition: ``root_kwargs`` should contain all arguments except
``module``. This function accepts the kwargs dict directly since it gets
forwarded into the post-order traversal function.
"""
mixed_precision = fsdp_kwargs["mixed_precision"]
mixed_precision = root_kwargs["mixed_precision"]
is_wrapper = inspect.isclass(fsdp_fn)
# TODO: We may relax this no-nested-wrapping constraint to support manual
# wrapping followed by auto wrapping.
_check_nested_wrapping(root_module)

# TODO: Start migration to refactored auto wrapping with `ModuleWrapPolicy`
if isinstance(policy, ModuleWrapPolicy):
fsdp_kwargs["auto_wrap_policy" if is_wrapper else "policy"] = None
if isinstance(policy, _Policy):
root_kwargs["auto_wrap_policy" if is_wrapper else "policy"] = None
target_module_to_kwargs = policy._run_policy(
root_module, ignored_modules, fsdp_kwargs
root_module, ignored_modules, root_kwargs
)
if mixed_precision is not None:
target_module_to_kwargs = _run_mixed_precision_override_policy(
root_module,
mixed_precision._module_classes_to_ignore,
ignored_modules,
fsdp_kwargs,
root_kwargs,
target_module_to_kwargs,
)
overridden_module_classes = _override_module_mixed_precision(
root_module, mixed_precision._module_classes_to_ignore
)
_warn_on_overridden_mixed_precision(overridden_module_classes)
use_orig_params = fsdp_kwargs.get("use_orig_params", False)
use_orig_params = root_kwargs.get("use_orig_params", False)
_validate_frozen_params(
root_module,
set(target_module_to_kwargs.keys()),
@@ -100,7 +98,7 @@ def _auto_wrap(
)
recursive_wrap_kwargs["auto_wrap_policy"] = policy
_warn_on_overridden_mixed_precision(overridden_module_classes)
_recursive_wrap(**recursive_wrap_kwargs, **fsdp_kwargs) # type: ignore[arg-type]
_recursive_wrap(**recursive_wrap_kwargs, **root_kwargs) # type: ignore[arg-type]


def _check_nested_wrapping(root_module: nn.Module):
6 changes: 3 additions & 3 deletions torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -434,7 +434,7 @@ def __init__(
self, process_group, sharding_strategy, auto_wrap_policy
)
if auto_wrap_policy is not None:
fsdp_kwargs = {
root_kwargs = {
"process_group": process_group,
"sharding_strategy": sharding_strategy,
"cpu_offload": cpu_offload,
@@ -452,14 +452,14 @@
# Share root process groups with children to maintain
# the invariant that all FSDP modules will have the same
# process groups.
fsdp_kwargs["process_group"] = (self.process_group, self._inter_node_pg)
root_kwargs["process_group"] = (self.process_group, self._inter_node_pg)

_auto_wrap(
module,
auto_wrap_policy,
self._ignored_modules,
self._ignored_params,
fsdp_kwargs,
root_kwargs,
FullyShardedDataParallel,
)
