[FSDP][9/N] Introduce LambdaWrapPolicy
ghstack-source-id: 2a97189df60fc83b9310f08928cfe3df0198297e
Pull Request resolved: #104986
awgu committed Jul 11, 2023
1 parent 4327d5b commit 9b4d944
Showing 3 changed files with 164 additions and 7 deletions.
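For context, a minimal usage sketch of the new policy as exercised by the tests below. The toy Block/Model classes are illustrative only (not from this PR), and the call assumes an already-initialized distributed process group (e.g., launched via torchrun) with a CUDA device available:

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import LambdaWrapPolicy


# Toy modules used only for illustration.
class Block(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.lin = nn.Linear(16, 16)


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.blocks = nn.ModuleList(Block() for _ in range(4))


# The lambda returns True to wrap a module with the root FSDP kwargs and
# False to leave it unwrapped (the "uniform kwargs" case).
policy = LambdaWrapPolicy(lambda module: isinstance(module, Block))

# Assumes torch.distributed has been initialized and a GPU is available.
model = FSDP(Model().cuda(), auto_wrap_policy=policy)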
123 changes: 119 additions & 4 deletions test/distributed/fsdp/test_wrap.py
@@ -1,6 +1,7 @@
# Owner(s): ["oncall: distributed"]

import functools
import itertools
import os
import tempfile
import unittest
@@ -15,13 +16,15 @@
BackwardPrefetch,
CPUOffload,
FullyShardedDataParallel as FSDP,
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import (
_FSDPPolicy,
_or_policy,
_wrap_module_cls_individually,
always_wrap_policy,
enable_wrap,
LambdaWrapPolicy,
ModuleWrapPolicy,
size_based_auto_wrap_policy,
transformer_auto_wrap_policy,
@@ -479,6 +482,96 @@ def _test_transformer_wrapping(
else:
self.assertFalse(isinstance(module, FSDP))

@unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
def test_lambda_wrap_policy(self):
"""
Tests ``LambdaWrapPolicy`` with both a lambda function that uses
uniform kwargs (so only returns ``False`` or ``True``) and a lambda
function that uses non-uniform kwargs (so returns a dict to override
the root kwargs).
"""
for use_uniform_kwargs in [False, True]:
self._test_lambda_wrap_policy(use_uniform_kwargs)

def _test_lambda_wrap_policy(self, use_uniform_kwargs: bool):
model = TransformerWithSharedParams.init(
self.process_group,
FSDPInitMode.NO_FSDP,
CUDAInitMode.CUDA_BEFORE,
{},
)

if use_uniform_kwargs:

def lambda_fn(module: nn.Module):
if module is model.bn:
return True
elif isinstance(
module, (TransformerEncoderLayer, TransformerDecoderLayer)
):
return True
return False

else:

def lambda_fn(module: nn.Module):
if module is model.bn:
return {"sharding_strategy": ShardingStrategy.NO_SHARD}
elif isinstance(module, TransformerEncoderLayer):
return True
elif isinstance(module, TransformerDecoderLayer):
return {
"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
"backward_prefetch": BackwardPrefetch.BACKWARD_POST,
}
return False

policy = LambdaWrapPolicy(lambda_fn)
# Use a size-2 dummy PG to avoid clamping the sharding strategy to
# `NO_SHARD` as for a size-1 PG
process_group = DummyProcessGroup(rank=0, size=2)
model = FSDP(model, process_group=process_group, auto_wrap_policy=policy)
encoder_layers = set(model.module.transformer.encoder.layers)
decoder_layers = set(model.module.transformer.decoder.layers)
bn = model.module.bn
bn_strategy = (
ShardingStrategy.FULL_SHARD
if use_uniform_kwargs
else ShardingStrategy.NO_SHARD
)
bn_prefetch = BackwardPrefetch.BACKWARD_PRE
encoder_strategy = root_strategy = ShardingStrategy.FULL_SHARD
encoder_prefetch = root_prefetch = BackwardPrefetch.BACKWARD_PRE
decoder_strategy = (
ShardingStrategy.FULL_SHARD
if use_uniform_kwargs
else ShardingStrategy.SHARD_GRAD_OP
)
decoder_prefetch = (
BackwardPrefetch.BACKWARD_PRE
if use_uniform_kwargs
else BackwardPrefetch.BACKWARD_POST
)
for module in model.modules():
if module is bn:
self.assertTrue(isinstance(module, FSDP))
self.assertEqual(module.sharding_strategy, bn_strategy)
self.assertEqual(module.backward_prefetch, bn_prefetch)
elif module in encoder_layers:
self.assertTrue(isinstance(module, FSDP))
self.assertEqual(module.sharding_strategy, encoder_strategy)
self.assertEqual(module.backward_prefetch, encoder_prefetch)
elif module in decoder_layers:
self.assertTrue(isinstance(module, FSDP))
self.assertEqual(module.sharding_strategy, decoder_strategy)
self.assertEqual(module.backward_prefetch, decoder_prefetch)
elif module is model:
self.assertTrue(isinstance(module, FSDP))
self.assertEqual(module.sharding_strategy, root_strategy)
self.assertEqual(module.backward_prefetch, root_prefetch)
else:
self.assertFalse(isinstance(module, FSDP))

@unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
def test_auto_wrap_api(self):
"""
@@ -707,12 +800,34 @@ def test_frozen_params(self):
Tests that mixing frozen/non-frozen parameters in an FSDP instance
raises for ``use_orig_params=False`` and warns for ``True``.
"""
for use_orig_params in [True, False]:
self._test_frozen_params(use_orig_params)
module_classes = (LoraAttention, LoraMLP, LoraDecoder)
module_wrap_policy = ModuleWrapPolicy(module_classes)

def lambda_fn_uniform(module: nn.Module):
return isinstance(module, module_classes)

def lambda_fn_nonuniform(module: nn.Module):
if isinstance(module, LoraAttention):
return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP}
elif isinstance(module, module_classes):
return True
return False

lambda_wrap_policy_uniform = LambdaWrapPolicy(lambda_fn_uniform)
lambda_wrap_policy_nonuniform = LambdaWrapPolicy(lambda_fn_nonuniform)

for use_orig_params, policy in itertools.product(
[True, False],
[
module_wrap_policy,
lambda_wrap_policy_uniform,
lambda_wrap_policy_nonuniform,
],
):
self._test_frozen_params(use_orig_params, policy)

def _test_frozen_params(self, use_orig_params: bool):
def _test_frozen_params(self, use_orig_params: bool, policy: _FSDPPolicy):
model = LoraModel().cuda()
policy = ModuleWrapPolicy({LoraAttention, LoraMLP, LoraDecoder})
msg = "layers.0.attn has both parameters with requires_grad=True and False. "
if use_orig_params:
msg += "We do not recommend wrapping such modules"
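As a compact summary of the behavior tested above (a sketch only, not part of the diff): the lambda function may return True, False, or a per-module kwarg-override dict.

from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer

from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
from torch.distributed.fsdp.wrap import LambdaWrapPolicy


def lambda_fn(module):
    if isinstance(module, TransformerDecoderLayer):
        # A non-empty dict wraps the module and overrides the root FSDP kwargs.
        return {
            "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
            "backward_prefetch": BackwardPrefetch.BACKWARD_POST,
        }
    if isinstance(module, TransformerEncoderLayer):
        # True wraps the module with the root kwargs unchanged.
        return True
    # False (or an empty dict) leaves the module unwrapped.
    return False


policy = LambdaWrapPolicy(lambda_fn)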
4 changes: 1 addition & 3 deletions torch/distributed/fsdp/_wrap_utils.py
@@ -17,7 +16,6 @@
_recursive_wrap,
_run_mixed_precision_override_policy,
_wrap_module_cls_individually,
ModuleWrapPolicy,
)


@@ -43,8 +42,7 @@ def _auto_wrap(
# wrapping followed by auto wrapping.
_check_nested_wrapping(root_module)

# TODO: Start migration to refactored auto wrapping with `ModuleWrapPolicy`
if isinstance(policy, ModuleWrapPolicy):
if isinstance(policy, _FSDPPolicy):
fsdp_kwargs["auto_wrap_policy" if is_wrapper else "policy"] = None
target_module_to_kwargs = policy._run_policy(
root_module, ignored_modules, fsdp_kwargs
44 changes: 44 additions & 0 deletions torch/distributed/fsdp/wrap.py
@@ -19,6 +19,7 @@
Set,
Tuple,
Type,
Union,
)

import torch.nn as nn
@@ -30,6 +31,7 @@
"size_based_auto_wrap_policy",
"enable_wrap",
"wrap",
"LambdaWrapPolicy",
"ModuleWrapPolicy",
]

@@ -214,6 +216,48 @@ def __repr__(self) -> str:
return super().__repr__() + f"({self._module_classes_str})"


class LambdaWrapPolicy(_FSDPPolicy):
"""
This policy takes in a lambda function that maps a given ``nn.Module`` to
    either ``False``, ``True``, or an FSDP kwarg dictionary.
- If the function returns ``False`` or an empty dictionary, then the module
is not wrapped.
- If the function returns ``True``, then the module is wrapped using the
root's kwargs.
- If the function returns a non-empty dictionary, then the module is
wrapped, and the dictionary overrides the root's kwargs.
"""

def __init__(self, lambda_fn: Callable[[nn.Module], Union[bool, Dict[str, Any]]]):
self._lambda_fn = lambda_fn

def _run_policy(
self,
root_module: nn.Module,
ignored_modules: Set[nn.Module],
root_fsdp_kwargs: Dict[str, Any],
) -> Dict[nn.Module, Dict[str, Any]]:
target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]] = {}
for module in root_module.modules():
if module in ignored_modules:
continue
res = self._lambda_fn(module)
if not isinstance(res, (dict, bool)):
raise ValueError(
"The lambda_fn passed to LambdaWrapPolicy should return "
f"False/True or an FSDP kwarg dict, but it returned {res}"
)
if not res:
continue
fsdp_kwargs = copy.copy(root_fsdp_kwargs)
if isinstance(res, dict):
# Override the root FSDP kwargs with the ones specified by the
# lambda function
fsdp_kwargs.update(res)
target_module_to_kwargs[module] = fsdp_kwargs
return target_module_to_kwargs


def lambda_auto_wrap_policy(
module: nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn: Callable
) -> bool:
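To make the mapping concrete, here is a small CPU-only sketch (illustrative, not part of the diff) of what `_run_policy` produces for a toy module tree, using the class added above:

import torch.nn as nn

from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.wrap import LambdaWrapPolicy

# Toy model used only for illustration.
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))

# Wrap every Linear, overriding the root sharding strategy with NO_SHARD.
policy = LambdaWrapPolicy(
    lambda m: {"sharding_strategy": ShardingStrategy.NO_SHARD}
    if isinstance(m, nn.Linear)
    else False
)

# `_run_policy` walks `model.modules()`, skips ignored modules, and maps each
# selected module to a copy of the root kwargs with any overrides applied.
target_module_to_kwargs = policy._run_policy(
    root_module=model,
    ignored_modules=set(),
    root_fsdp_kwargs={"sharding_strategy": ShardingStrategy.FULL_SHARD},
)
for module, kwargs in target_module_to_kwargs.items():
    print(type(module).__name__, kwargs["sharding_strategy"])
# Both Linear modules map to ShardingStrategy.NO_SHARD; the root Sequential and
# the ReLU are not selected by the lambda, so they do not appear in the mapping.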
