
[FSDP][9/N] Introduce CustomPolicy #104986

Closed
wants to merge 9 commits
28 changes: 26 additions & 2 deletions test/distributed/_composable/fully_shard/test_fully_shard_init.py
@@ -9,11 +9,12 @@
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
from torch.distributed._composable import fully_shard
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import BackwardPrefetch, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _is_fsdp_flattened, clean_tensor_name
from torch.distributed.fsdp.wrap import _WrapPolicy, ModuleWrapPolicy
from torch.distributed.fsdp.wrap import _WrapPolicy, LambdaWrapPolicy, ModuleWrapPolicy
from torch.testing._internal.common_dist_composable import (
CompositeParamModel,
FakeSequential,
NestedSequentialModel,
UnitModule,
)
@@ -43,12 +44,21 @@ def world_size(self) -> int:
@skip_if_lt_x_gpu(2)
def test_policy(self):
"""Tests passing a ``policy`` for pseudo-auto-wrapping."""

def lambda_fn(module: nn.Module):
if isinstance(module, nn.Sequential):
return True
elif isinstance(module, FakeSequential):
return {"backward_prefetch": BackwardPrefetch.BACKWARD_POST}
Member:

should we add a test to ensure we raise a good error when we return an invalid key? like {"foo": 1}

Contributor Author:

I think one option here is to use inspect to get the arg signature of the module-level API passed to the policy. However, I think the default error message from passing in an unexpected arg/kwarg is not unreasonable, so I will not worry about this change in this PR:

TypeError: __init__() got an unexpected keyword argument 'foo'
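
For illustration, here is a minimal sketch of the inspect-based validation mentioned above (a hypothetical helper, not part of this PR; it checks against the wrapper-class FSDP API, while the composable path would check fully_shard's signature instead):

import inspect

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def _validate_override_keys(overrides: dict) -> None:
    # Compare the override keys returned by lambda_fn against the FSDP
    # constructor signature and fail early with a clearer message than the
    # default TypeError.
    valid_keys = set(inspect.signature(FSDP.__init__).parameters) - {"self", "module"}
    invalid_keys = set(overrides) - valid_keys
    if invalid_keys:
        raise ValueError(
            f"Invalid FSDP kwargs returned by lambda_fn: {sorted(invalid_keys)}"
        )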

return False

self.run_subtests(
{
"policy": [
None,
ModuleWrapPolicy({UnitModule}),
ModuleWrapPolicy({nn.Sequential}),
LambdaWrapPolicy(lambda_fn),
],
},
self._test_policy,
@@ -140,6 +150,20 @@ def _test_fully_shard_construction(
composable_module_classes.add(type(submodule))
self.assertEqual(local_module_classes, composable_module_classes)

# Check that the composable module has FSDP states with the same
# attributes as the wrapper module (mainly checking backward prefetch, since
# the lambda wrap policy overrides it for `FakeSequential`)
wrapper_states = traversal_utils._get_fsdp_states(fsdp_wrapped_model)
composable_states = traversal_utils._get_fsdp_states(composable_module)
self.assertEqual(len(wrapper_states), len(composable_states))
for wrapper_state, composable_state in zip(wrapper_states, composable_states):
self.assertEqual(
wrapper_state.sharding_strategy, composable_state.sharding_strategy
)
self.assertEqual(
wrapper_state.backward_prefetch, composable_state.backward_prefetch
)

@skip_if_lt_x_gpu(2)
def test_device_id(self):
"""Tests passing a ``device_id``."""
123 changes: 119 additions & 4 deletions test/distributed/fsdp/test_wrap.py
@@ -1,6 +1,7 @@
# Owner(s): ["oncall: distributed"]

import functools
import itertools
import os
import tempfile
import unittest
@@ -15,13 +16,15 @@
BackwardPrefetch,
CPUOffload,
FullyShardedDataParallel as FSDP,
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import (
_or_policy,
_wrap_module_cls_individually,
_WrapPolicy,
always_wrap_policy,
enable_wrap,
LambdaWrapPolicy,
ModuleWrapPolicy,
size_based_auto_wrap_policy,
transformer_auto_wrap_policy,
@@ -479,6 +482,96 @@ def _test_transformer_wrapping(
else:
self.assertFalse(isinstance(module, FSDP))

@unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
def test_lambda_wrap_policy(self):
"""
Tests ``LambdaWrapPolicy`` with both a lambda function that uses
uniform kwargs (so only returns ``False`` or ``True``) and a lambda
function that uses non-uniform kwargs (so returns a dict to override
the root kwargs).
"""
for use_uniform_kwargs in [False, True]:
self._test_lambda_wrap_policy(use_uniform_kwargs)

def _test_lambda_wrap_policy(self, use_uniform_kwargs: bool):
model = TransformerWithSharedParams.init(
self.process_group,
FSDPInitMode.NO_FSDP,
CUDAInitMode.CUDA_BEFORE,
{},
)

if use_uniform_kwargs:

def lambda_fn(module: nn.Module):
Member:

Does LambdaPolicy still get overridden to individually wrap BatchNorms if mixed precision is configured?

In general, what happens if my LambdaPolicy says to wrap a BN, but mixed precision overrides this? What should the expected behavior be?

Contributor Author:

Currently, our mixed precision overrides run last. I am not sure if we should special case the LambdaWrapPolicy to be highest priority. What do you think?

I was leaning toward just keeping our overrides as absolute for now.

if module is model.bn:
return True
elif isinstance(
module, (TransformerEncoderLayer, TransformerDecoderLayer)
):
return True
return False

else:

def lambda_fn(module: nn.Module):
if module is model.bn:
return {"sharding_strategy": ShardingStrategy.NO_SHARD}
elif isinstance(module, TransformerEncoderLayer):
return True
elif isinstance(module, TransformerDecoderLayer):
return {
"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
"backward_prefetch": BackwardPrefetch.BACKWARD_POST,
}
return False

policy = LambdaWrapPolicy(lambda_fn)
# Use a size-2 dummy PG to avoid clamping the sharding strategy to
# `NO_SHARD` as for a size-1 PG
process_group = DummyProcessGroup(rank=0, size=2)
model = FSDP(model, process_group=process_group, auto_wrap_policy=policy)
encoder_layers = set(model.module.transformer.encoder.layers)
decoder_layers = set(model.module.transformer.decoder.layers)
bn = model.module.bn
bn_strategy = (
ShardingStrategy.FULL_SHARD
if use_uniform_kwargs
else ShardingStrategy.NO_SHARD
)
bn_prefetch = BackwardPrefetch.BACKWARD_PRE
encoder_strategy = root_strategy = ShardingStrategy.FULL_SHARD
encoder_prefetch = root_prefetch = BackwardPrefetch.BACKWARD_PRE
decoder_strategy = (
ShardingStrategy.FULL_SHARD
if use_uniform_kwargs
else ShardingStrategy.SHARD_GRAD_OP
)
decoder_prefetch = (
BackwardPrefetch.BACKWARD_PRE
if use_uniform_kwargs
else BackwardPrefetch.BACKWARD_POST
)
for module in model.modules():
if module is bn:
self.assertTrue(isinstance(module, FSDP))
self.assertEqual(module.sharding_strategy, bn_strategy)
self.assertEqual(module.backward_prefetch, bn_prefetch)
elif module in encoder_layers:
self.assertTrue(isinstance(module, FSDP))
self.assertEqual(module.sharding_strategy, encoder_strategy)
self.assertEqual(module.backward_prefetch, encoder_prefetch)
elif module in decoder_layers:
self.assertTrue(isinstance(module, FSDP))
self.assertEqual(module.sharding_strategy, decoder_strategy)
self.assertEqual(module.backward_prefetch, decoder_prefetch)
elif module is model:
self.assertTrue(isinstance(module, FSDP))
self.assertEqual(module.sharding_strategy, root_strategy)
self.assertEqual(module.backward_prefetch, root_prefetch)
else:
self.assertFalse(isinstance(module, FSDP))

@unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
def test_auto_wrap_api(self):
"""
@@ -707,12 +800,34 @@ def test_frozen_params(self):
Tests that mixing frozen/non-frozen parameters in an FSDP instance
raises for ``use_orig_params=False`` and warns for ``True``.
"""
for use_orig_params in [True, False]:
self._test_frozen_params(use_orig_params)
module_classes = (LoraAttention, LoraMLP, LoraDecoder)
module_wrap_policy = ModuleWrapPolicy(module_classes)

def lambda_fn_uniform(module: nn.Module):
return isinstance(module, module_classes)

def lambda_fn_nonuniform(module: nn.Module):
if isinstance(module, LoraAttention):
return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP}
elif isinstance(module, module_classes):
return True
return False

lambda_wrap_policy_uniform = LambdaWrapPolicy(lambda_fn_uniform)
lambda_wrap_policy_nonuniform = LambdaWrapPolicy(lambda_fn_nonuniform)

for use_orig_params, policy in itertools.product(
[True, False],
[
module_wrap_policy,
lambda_wrap_policy_uniform,
lambda_wrap_policy_nonuniform,
],
):
self._test_frozen_params(use_orig_params, policy)

def _test_frozen_params(self, use_orig_params: bool):
def _test_frozen_params(self, use_orig_params: bool, policy: _WrapPolicy):
Member:

maybe add a test for the mixed precision overriding as well?

Contributor Author:

Added!
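
For context, a hedged sketch of what such a mixed precision override test might look like (the test name, model choice, and assertion are illustrative, not necessarily what the PR added; it assumes MixedPrecision is imported from torch.distributed.fsdp and that, per the discussion above, the override wraps BatchNorm modules individually even when the lambda does not select them):

def _test_lambda_wrap_policy_mixed_precision(self):
    model = TransformerWithSharedParams.init(
        self.process_group, FSDPInitMode.NO_FSDP, CUDAInitMode.CUDA_BEFORE, {}
    )

    def lambda_fn(module: nn.Module):
        # Never ask to wrap the BatchNorm module
        return isinstance(module, (TransformerEncoderLayer, TransformerDecoderLayer))

    fsdp_model = FSDP(
        model,
        auto_wrap_policy=LambdaWrapPolicy(lambda_fn),
        mixed_precision=MixedPrecision(param_dtype=torch.float16),
    )
    # The mixed precision override runs last, so the BatchNorm is expected to
    # be wrapped individually anyway.
    self.assertTrue(isinstance(fsdp_model.module.bn, FSDP))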

model = LoraModel().cuda()
policy = ModuleWrapPolicy({LoraAttention, LoraMLP, LoraDecoder})
msg = "layers.0.attn has both parameters with requires_grad=True and False. "
if use_orig_params:
msg += "We do not recommend wrapping such modules"
4 changes: 1 addition & 3 deletions torch/distributed/fsdp/_wrap_utils.py
@@ -17,7 +17,6 @@
_run_mixed_precision_override_policy,
_wrap_module_cls_individually,
_WrapPolicy,
ModuleWrapPolicy,
)


@@ -43,8 +42,7 @@ def _auto_wrap(
# wrapping followed by auto wrapping.
_check_nested_wrapping(root_module)

# TODO: Start migration to refactored auto wrapping with `ModuleWrapPolicy`
if isinstance(policy, ModuleWrapPolicy):
if isinstance(policy, _WrapPolicy):
fsdp_kwargs["auto_wrap_policy" if is_wrapper else "policy"] = None
target_module_to_kwargs = policy._run_policy(
root_module, ignored_modules, fsdp_kwargs
44 changes: 44 additions & 0 deletions torch/distributed/fsdp/wrap.py
@@ -18,6 +18,7 @@
Set,
Tuple,
Type,
Union,
)

import torch.nn as nn
@@ -29,6 +30,7 @@
"size_based_auto_wrap_policy",
"enable_wrap",
"wrap",
"LambdaWrapPolicy",
"ModuleWrapPolicy",
]

@@ -209,6 +211,48 @@ def __repr__(self) -> str:
return super().__repr__() + f"({self._module_classes_str})"


class LambdaWrapPolicy(_WrapPolicy):
Member:

nit: add an example?

Contributor Author:

Added!
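
For readers of this thread, a minimal usage sketch of the policy (illustrative only; the module classes and override kwargs are arbitrary, and the example added to the docstring in the PR may differ):

import torch.nn as nn
from torch.distributed.fsdp import BackwardPrefetch, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import LambdaWrapPolicy


def lambda_fn(module: nn.Module):
    if isinstance(module, nn.TransformerEncoderLayer):
        return True  # wrap using the root FSDP kwargs
    if isinstance(module, nn.BatchNorm1d):
        # wrap, overriding only the backward prefetch setting
        return {"backward_prefetch": BackwardPrefetch.BACKWARD_POST}
    return False  # leave every other module unwrapped


# Assumes `model` is an already-constructed nn.Module on GPU and the default
# process group has been initialized.
sharded_model = FSDP(model, auto_wrap_policy=LambdaWrapPolicy(lambda_fn))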

Contributor:

You solicited some name bikeshedding over WC, so I'll try my best here. Assuming we are keeping the WrapPolicy suffix, CallbackWrapPolicy or CustomWrapPolicy both convey intent here reasonably well.

Contributor:

You could also try PredicateWrapPolicy. This is not a perfect match: I would call lambda_fn a predicate if it only returned a bool, but you sometimes return a little extra information.

"""
This policy takes in a lambda function that maps a given ``nn.Module`` to
either ``False``, ``True``, or an FSDP kwarg dictionary.
- If the function returns ``False`` or an empty dictionary, then the module
is not wrapped.
- If the function returns ``True``, then the module is wrapped using the
root's kwargs.
- If the function returns a non-empty dictionary, then the module is
wrapped, and the dictionary overrides the root's kwargs.
"""

def __init__(self, lambda_fn: Callable[[nn.Module], Union[bool, Dict[str, Any]]]):
self._lambda_fn = lambda_fn

def _run_policy(
self,
root_module: nn.Module,
Contributor:

OOC, is there a way to easily determine the FQN of the module at this point? Without the root module it seems like it might be difficult, whereas the code that calls _run_policy could generate it without much difficulty.

Contributor Author:

That is a good point. I do think we need access to the root module to be able to generate FQNs ("we" here covers both FSDP internals and users).

Contributor:

In a downstream package I maintain, I've been allowing users to provide module FQN-based complements to a given auto_wrap_policy using a custom _FSDPPolicy (NameDrivenPolicy). I should be able to continue offering that functionality with some minor refactoring after this PR, but I thought it might be worth mentioning in the context of this PR because:

  1. To call attention to this pattern, since I've received positive user feedback regarding the utility of FQN-based wrapping (to accommodate more flexible fine-tuning requirements in this case)
  2. In case the team considers the functionality sufficiently useful to offer it natively at some point in the future (I'm not necessarily advocating that, just thinking aloud... either way this PR should help make such extensions cleaner)

Anyway, nice enhancement @awgu! A pleasure as always working with this thoughtfully designed API and constantly improving package.

Contributor Author:

Thanks for the pointer! This is super helpful, especially learning more about the customer feedback.

I am thinking that CustomPolicy can stay generic, and we can push specific logic to the lambda_fn construction. If the user wants to use FQNs, then the user can use the same trick as your implementation where you translate FQNs to Python id()s and then use the id()s in the lambda_fn, which will see the module as argument. Like you said, after some refactoring, you should still be able to make use of CustomPolicy. The question is whether we should offer some API that permits FQNs directly. I think that is something that we can wait and see. For now, we can just add CustomPolicy that can serve as a base for other extensions.

What do you think?
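
A rough sketch of the FQN-to-id() translation trick described above (the helper name and arguments are illustrative, not an API from this PR or the downstream package):

import torch.nn as nn
from torch.distributed.fsdp.wrap import LambdaWrapPolicy


def build_fqn_policy(root_module: nn.Module, fqns_to_wrap: set) -> LambdaWrapPolicy:
    # Resolve the requested FQNs to module ids up front, since lambda_fn only
    # receives the module object, not its name
    target_ids = {
        id(module)
        for name, module in root_module.named_modules()
        if name in fqns_to_wrap
    }
    return LambdaWrapPolicy(lambda module: id(module) in target_ids)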

Contributor:

> The question is whether we should offer some API that permits FQNs directly. I think that is something that we can wait and see. For now, we can just add CustomPolicy that can serve as a base for other extensions.

Agree completely; I think prioritizing API stability and simplicity at this juncture is the prudent approach. Adding an API that permits FQNs directly might be worth considering at some point in the future, but that feature isn't sufficiently crucial to justify its inclusion at present (especially since implementation approaches like my id() translation one we discussed aren't too onerous).

Thanks for the quick and thoughtful response!

ignored_modules: Set[nn.Module],
root_fsdp_kwargs: Dict[str, Any],
) -> Dict[nn.Module, Dict[str, Any]]:
target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]] = {}
for module in root_module.modules():
if module in ignored_modules:
continue
res = self._lambda_fn(module)
if not isinstance(res, (dict, bool)):
raise ValueError(
"The lambda_fn passed to LambdaWrapPolicy should return "
f"False/True or an FSDP kwarg dict, but it returned {res}"
)
if not res:
continue
fsdp_kwargs = copy.copy(root_fsdp_kwargs)
if isinstance(res, dict):
# Override the root FSDP kwargs with the ones specified by the
# lambda function
fsdp_kwargs.update(res)
target_module_to_kwargs[module] = fsdp_kwargs
return target_module_to_kwargs


def lambda_auto_wrap_policy(
module: nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn: Callable
) -> bool: