[FSDP][9/N] Introduce CustomPolicy
#104986
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
# Owner(s): ["oncall: distributed"] | ||
|
||
import functools | ||
import itertools | ||
import os | ||
import tempfile | ||
import unittest | ||
|
@@ -15,13 +16,15 @@ | |
BackwardPrefetch, | ||
CPUOffload, | ||
FullyShardedDataParallel as FSDP, | ||
ShardingStrategy, | ||
) | ||
from torch.distributed.fsdp.wrap import ( | ||
_or_policy, | ||
_wrap_module_cls_individually, | ||
_WrapPolicy, | ||
always_wrap_policy, | ||
enable_wrap, | ||
LambdaWrapPolicy, | ||
ModuleWrapPolicy, | ||
size_based_auto_wrap_policy, | ||
transformer_auto_wrap_policy, | ||
|
@@ -479,6 +482,96 @@ def _test_transformer_wrapping( | |
else: | ||
self.assertFalse(isinstance(module, FSDP)) | ||
|
||
@unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs") | ||
def test_lambda_wrap_policy(self): | ||
""" | ||
Tests ``LambdaWrapPolicy`` with both a lambda function that uses | ||
uniform kwargs (so only returns ``False`` or ``True``) and a lambda | ||
function that uses non-uniform kwargs (so returns a dict to override | ||
the root kwargs). | ||
""" | ||
for use_uniform_kwargs in [False, True]: | ||
self._test_lambda_wrap_policy(use_uniform_kwargs) | ||
|
||
def _test_lambda_wrap_policy(self, use_uniform_kwargs: bool): | ||
model = TransformerWithSharedParams.init( | ||
self.process_group, | ||
FSDPInitMode.NO_FSDP, | ||
CUDAInitMode.CUDA_BEFORE, | ||
{}, | ||
) | ||
|
||
if use_uniform_kwargs: | ||
|
||
def lambda_fn(module: nn.Module): | ||
Review comment: does LambdaPolicy still get overridden and individually wrap BatchNorms if mixed precision is configured? In general, what happens if my LambdaPolicy says to wrap a BN, but mixed precision overrides this? What should be the expected behavior?
Reply: Currently, our mixed precision overrides run last. I am not sure if we should special case the [...]. I was leaning toward just keeping our overrides as absolute for now.
||
if module is model.bn: | ||
return True | ||
elif isinstance( | ||
module, (TransformerEncoderLayer, TransformerDecoderLayer) | ||
): | ||
return True | ||
return False | ||
|
||
else: | ||
|
||
def lambda_fn(module: nn.Module): | ||
if module is model.bn: | ||
return {"sharding_strategy": ShardingStrategy.NO_SHARD} | ||
elif isinstance(module, TransformerEncoderLayer): | ||
return True | ||
elif isinstance(module, TransformerDecoderLayer): | ||
return { | ||
"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP, | ||
"backward_prefetch": BackwardPrefetch.BACKWARD_POST, | ||
} | ||
return False | ||
|
||
policy = LambdaWrapPolicy(lambda_fn) | ||
# Use a size-2 dummy PG to avoid clamping the sharding strategy to | ||
# `NO_SHARD` as for a size-1 PG | ||
process_group = DummyProcessGroup(rank=0, size=2) | ||
model = FSDP(model, process_group=process_group, auto_wrap_policy=policy) | ||
encoder_layers = set(model.module.transformer.encoder.layers) | ||
decoder_layers = set(model.module.transformer.decoder.layers) | ||
bn = model.module.bn | ||
bn_strategy = ( | ||
ShardingStrategy.FULL_SHARD | ||
if use_uniform_kwargs | ||
else ShardingStrategy.NO_SHARD | ||
) | ||
bn_prefetch = BackwardPrefetch.BACKWARD_PRE | ||
encoder_strategy = root_strategy = ShardingStrategy.FULL_SHARD | ||
encoder_prefetch = root_prefetch = BackwardPrefetch.BACKWARD_PRE | ||
decoder_strategy = ( | ||
ShardingStrategy.FULL_SHARD | ||
if use_uniform_kwargs | ||
else ShardingStrategy.SHARD_GRAD_OP | ||
) | ||
decoder_prefetch = ( | ||
BackwardPrefetch.BACKWARD_PRE | ||
if use_uniform_kwargs | ||
else BackwardPrefetch.BACKWARD_POST | ||
) | ||
for module in model.modules(): | ||
if module is bn: | ||
self.assertTrue(isinstance(module, FSDP)) | ||
self.assertEqual(module.sharding_strategy, bn_strategy) | ||
self.assertEqual(module.backward_prefetch, bn_prefetch) | ||
|
||
elif module in encoder_layers: | ||
self.assertTrue(isinstance(module, FSDP)) | ||
self.assertEqual(module.sharding_strategy, encoder_strategy) | ||
self.assertEqual(module.backward_prefetch, encoder_prefetch) | ||
elif module in decoder_layers: | ||
self.assertTrue(isinstance(module, FSDP)) | ||
self.assertEqual(module.sharding_strategy, decoder_strategy) | ||
self.assertEqual(module.backward_prefetch, decoder_prefetch) | ||
elif module is model: | ||
self.assertTrue(isinstance(module, FSDP)) | ||
self.assertEqual(module.sharding_strategy, root_strategy) | ||
self.assertEqual(module.backward_prefetch, root_prefetch) | ||
else: | ||
self.assertFalse(isinstance(module, FSDP)) | ||
|
||
@unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs") | ||
def test_auto_wrap_api(self): | ||
""" | ||
|
@@ -707,12 +800,34 @@ def test_frozen_params(self): | |
Tests that mixing frozen/non-frozen parameters in an FSDP instance | ||
raises for ``use_orig_params=False`` and warns for ``True``. | ||
""" | ||
for use_orig_params in [True, False]: | ||
self._test_frozen_params(use_orig_params) | ||
module_classes = (LoraAttention, LoraMLP, LoraDecoder) | ||
module_wrap_policy = ModuleWrapPolicy(module_classes) | ||
|
||
def lambda_fn_uniform(module: nn.Module): | ||
return isinstance(module, module_classes) | ||
|
||
def lambda_fn_nonuniform(module: nn.Module): | ||
if isinstance(module, LoraAttention): | ||
return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP} | ||
elif isinstance(module, module_classes): | ||
return True | ||
return False | ||
|
||
lambda_wrap_policy_uniform = LambdaWrapPolicy(lambda_fn_uniform) | ||
lambda_wrap_policy_nonuniform = LambdaWrapPolicy(lambda_fn_nonuniform) | ||
|
||
for use_orig_params, policy in itertools.product( | ||
[True, False], | ||
[ | ||
module_wrap_policy, | ||
lambda_wrap_policy_uniform, | ||
lambda_wrap_policy_nonuniform, | ||
], | ||
): | ||
self._test_frozen_params(use_orig_params, policy) | ||
|
||
def _test_frozen_params(self, use_orig_params: bool): | ||
def _test_frozen_params(self, use_orig_params: bool, policy: _WrapPolicy): | ||
Review comment: maybe add a test for the mixed precision overriding as well?
Reply: Added!
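The mixed precision test added in response is not shown in this diff excerpt. A minimal sketch of the scenario, assuming only the public FSDP constructor arguments (`MixedPrecision`, `auto_wrap_policy`) and the behavior described in the earlier discussion (mixed precision overrides run last), might look like the following; the toy model and assertions are illustrative, not the PR's actual test.

# Hypothetical sketch -- not the test added in the PR.
import torch
import torch.nn as nn
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import LambdaWrapPolicy

def bn_lambda_fn(module: nn.Module):
    # Ask the policy to wrap BatchNorm with custom kwargs.
    if isinstance(module, nn.BatchNorm1d):
        return {"sharding_strategy": ShardingStrategy.NO_SHARD}
    return False

# Assumes an initialized process group and at least one CUDA device.
toy_model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8)).cuda()
wrapped = FSDP(
    toy_model,
    auto_wrap_policy=LambdaWrapPolicy(bn_lambda_fn),
    mixed_precision=MixedPrecision(param_dtype=torch.float16),
)
# Per the discussion above, the mixed precision overrides run last, so the
# BatchNorm submodule should still end up as its own FSDP instance while
# being excluded from low-precision casting.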
||
model = LoraModel().cuda() | ||
policy = ModuleWrapPolicy({LoraAttention, LoraMLP, LoraDecoder}) | ||
msg = "layers.0.attn has both parameters with requires_grad=True and False. " | ||
if use_orig_params: | ||
msg += "We do not recommend wrapping such modules" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,7 @@ | |
Set, | ||
Tuple, | ||
Type, | ||
Union, | ||
) | ||
|
||
import torch.nn as nn | ||
|
@@ -29,6 +30,7 @@ | |
"size_based_auto_wrap_policy", | ||
"enable_wrap", | ||
"wrap", | ||
"LambdaWrapPolicy", | ||
"ModuleWrapPolicy", | ||
] | ||
|
||
|
@@ -209,6 +211,48 @@ def __repr__(self) -> str: | |
return super().__repr__() + f"({self._module_classes_str})" | ||
|
||
|
||
class LambdaWrapPolicy(_WrapPolicy): | ||
Review comment (nit): add an example?
Reply: Added!
Review comment: You solicited some name bikeshedding over WC, so I'll try my best here. Assuming we are keeping the WrapPolicy suffix, CallbackWrapPolicy or CustomWrapPolicy both convey intent here reasonably well.
Review comment: You could also try PredicateWrapPolicy. This is not a perfect match: I would call [...]
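In response to the example request, here is a minimal usage sketch of the policy as named at this commit (`LambdaWrapPolicy`; the PR title suggests it may end up as `CustomPolicy`), assuming only the False/True/kwarg-dict contract stated in the docstring below. `MyDecoderLayer` and `model` are hypothetical placeholders.

# Minimal usage sketch, assuming the semantics described in the docstring below.
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import LambdaWrapPolicy

def lambda_fn(module: nn.Module):
    if isinstance(module, nn.BatchNorm2d):
        # Wrap, overriding the root kwargs for this submodule only.
        return {"sharding_strategy": ShardingStrategy.NO_SHARD}
    if isinstance(module, MyDecoderLayer):  # hypothetical module class
        # Wrap using the root FSDP kwargs unchanged.
        return True
    # Leave every other module unwrapped.
    return False

wrapped = FSDP(model, auto_wrap_policy=LambdaWrapPolicy(lambda_fn))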
||
""" | ||
This policy takes in a lambda function that maps a given ``nn.Module`` to | ||
either ``False``, ``True``, or an FSDP kwarg dictionary. | ||
- If the function returns ``False`` or an empty dictionary, then the module | ||
is not wrapped. | ||
- If the function returns ``True``, then the module is wrapped using the | ||
root's kwargs. | ||
- If the function returns a non-empty dictionary, then the module is | ||
wrapped, and the dictionary overrides the root's kwargs. | ||
""" | ||
|
||
def __init__(self, lambda_fn: Callable[[nn.Module], Union[bool, Dict[str, Any]]]): | ||
self._lambda_fn = lambda_fn | ||
|
||
def _run_policy( | ||
self, | ||
root_module: nn.Module, | ||
Review comment: OOC, is there a way to easily determine the FQN of the module at this point? Without the root module it seems like it might be difficult, whereas the code that calls [...]
Reply: That is a good point. I do think we need access to the root module to be able to generate FQNs ("we" encompasses both in FSDP and for users).
Review comment: In a downstream package I maintain, I've been allowing users to provide module FQN-based complements to a given [...]. Anyway, nice enhancement @awgu! A pleasure as always working with this thoughtfully designed API and constantly improving package.
Reply: Thanks for the pointer! This is super helpful, especially learning more about the customer feedback. I am thinking that [...] What do you think?
Review comment: Agree completely, I think prioritizing API stability and simplicity at this juncture is the prudent approach. Adding an API that permits FQNs directly might be worth considering at some point in the future, but that feature isn't sufficiently crucial to justify its inclusion at present (especially since implementation approaches like my [...]). Thanks for the quick and thoughtful response!
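A hedged sketch of the FQN-based workaround described above (not part of this PR): because the lambda receives only the module, a caller can precompute a module-to-FQN map from the root module and close over it. The helper name, the FQN set, and `model` are illustrative assumptions.

# Hypothetical helper, not part of this PR.
import torch.nn as nn
from torch.distributed.fsdp.wrap import LambdaWrapPolicy

def make_fqn_lambda_fn(root: nn.Module, fqns_to_wrap: set):
    # Map each submodule to its fully qualified name relative to the root.
    module_to_fqn = {mod: name for name, mod in root.named_modules()}
    def lambda_fn(module: nn.Module):
        # Wrap exactly the modules whose FQN was requested.
        return module_to_fqn.get(module, "") in fqns_to_wrap
    return lambda_fn

# `model` is a placeholder root module.
policy = LambdaWrapPolicy(make_fqn_lambda_fn(model, {"transformer.encoder.layers.0"}))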
||
ignored_modules: Set[nn.Module], | ||
root_fsdp_kwargs: Dict[str, Any], | ||
) -> Dict[nn.Module, Dict[str, Any]]: | ||
target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]] = {} | ||
for module in root_module.modules(): | ||
if module in ignored_modules: | ||
continue | ||
res = self._lambda_fn(module) | ||
if not isinstance(res, (dict, bool)): | ||
raise ValueError( | ||
"The lambda_fn passed to LambdaWrapPolicy should return " | ||
f"False/True or an FSDP kwarg dict, but it returned {res}" | ||
) | ||
if not res: | ||
continue | ||
fsdp_kwargs = copy.copy(root_fsdp_kwargs) | ||
if isinstance(res, dict): | ||
# Override the root FSDP kwargs with the ones specified by the | ||
# lambda function | ||
fsdp_kwargs.update(res) | ||
target_module_to_kwargs[module] = fsdp_kwargs | ||
return target_module_to_kwargs | ||
|
||
|
||
def lambda_auto_wrap_policy( | ||
module: nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn: Callable | ||
) -> bool: | ||
|
Review comment: should we add a test to ensure we raise a good error when we return an invalid key? like {"foo": 1}
Reply: I think one option here is to use inspect to get the arg signature of the module-level API passed to the policy. However, I think the default error message from passing in an unexpected arg/kwarg is not unreasonable, so I will not worry about this change in this PR: [...]
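For illustration only, a rough sketch of the inspect-based validation idea mentioned above (deliberately not implemented in this PR); it checks a returned kwarg dict against the FSDP constructor's parameter names. The helper name is hypothetical.

# Hypothetical validation helper, not part of this PR.
import inspect
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def validate_fsdp_kwargs(kwargs: dict) -> None:
    # Parameter names accepted by the FSDP constructor, minus self/module.
    valid = set(inspect.signature(FSDP.__init__).parameters) - {"self", "module"}
    unexpected = set(kwargs) - valid
    if unexpected:
        raise ValueError(
            f"lambda_fn returned invalid FSDP kwarg keys: {sorted(unexpected)}"
        )

validate_fsdp_kwargs({"foo": 1})  # raises ValueError, matching the {"foo": 1} case above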