[FSDP][9/N] Introduce CustomPolicy
#104986
@@ -1,6 +1,7 @@
# Owner(s): ["oncall: distributed"]

import functools
import itertools
import os
import tempfile
import unittest

@@ -15,12 +16,15 @@
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import (
    _or_policy,
    _Policy,
    _wrap_module_cls_individually,
    always_wrap_policy,
    CustomPolicy,
    enable_wrap,
    ModuleWrapPolicy,
    size_based_auto_wrap_policy,

@@ -477,6 +481,109 @@ def _test_transformer_wrapping(self, auto_wrap_policy: Union[Callable, _Policy])
            else:
                self.assertFalse(isinstance(module, FSDP))

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    def test_custom_policy(self):
        """
        Tests ``CustomPolicy`` with both a lambda function that uses uniform
        kwargs (so only returns ``False`` or ``True``) and a lambda function
        that uses non-uniform kwargs (so returns a dict to override the root
        kwargs).
        """
        for use_uniform_kwargs in [False, True]:
            self._test_custom_policy(use_uniform_kwargs)

    def _test_custom_policy(self, use_uniform_kwargs: bool):
        print(f"use_uniform_kwargs={use_uniform_kwargs}")
        model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            {},
        )

        if use_uniform_kwargs:

            def lambda_fn(module: nn.Module):

[Review comment] Does LambdaPolicy still get overridden and individually wrap BatchNorms if mixed precision is configured? In general, what happens if my LambdaPolicy tells FSDP to wrap a BN, but mixed precision overrides this? What should be the expected behavior?

[Author reply] Currently, our mixed precision overrides run last. I am not sure if we should special case the [...]. I was leaning toward just keeping our overrides as absolute for now. (A minimal sketch of this behavior follows this diff hunk.)

                if module is model.bn:
                    return True
                elif isinstance(
                    module, (TransformerEncoderLayer, TransformerDecoderLayer)
                ):
                    return True
                return False

        else:

            def lambda_fn(module: nn.Module):
                if module is model.bn:
                    return {"sharding_strategy": ShardingStrategy.NO_SHARD}
                elif isinstance(module, TransformerEncoderLayer):
                    return True
                elif isinstance(module, TransformerDecoderLayer):
                    return {
                        "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
                        "backward_prefetch": BackwardPrefetch.BACKWARD_POST,
                    }
                return False

        policy = CustomPolicy(lambda_fn)
        # Use a size-2 dummy PG to avoid clamping the sharding strategy to
        # `NO_SHARD` as for a size-1 PG
        process_group = DummyProcessGroup(rank=0, size=2)
        fp16_mp = MixedPrecision(param_dtype=torch.float16)
        fp32_mp = MixedPrecision()
        model = FSDP(
            model,
            process_group=process_group,
            auto_wrap_policy=policy,
            mixed_precision=fp16_mp,
        )
        encoder_layers = set(model.module.transformer.encoder.layers)
        decoder_layers = set(model.module.transformer.decoder.layers)
        bn = model.module.bn
        bn_strategy = (
            ShardingStrategy.FULL_SHARD
            if use_uniform_kwargs
            else ShardingStrategy.NO_SHARD
        )
        bn_prefetch = BackwardPrefetch.BACKWARD_PRE
        encoder_strategy = root_strategy = ShardingStrategy.FULL_SHARD
        encoder_prefetch = root_prefetch = BackwardPrefetch.BACKWARD_PRE
        decoder_strategy = (
            ShardingStrategy.FULL_SHARD
            if use_uniform_kwargs
            else ShardingStrategy.SHARD_GRAD_OP
        )
        decoder_prefetch = (
            BackwardPrefetch.BACKWARD_PRE
            if use_uniform_kwargs
            else BackwardPrefetch.BACKWARD_POST
        )
        for module in model.modules():
            if module is bn:
                self.assertTrue(isinstance(module, FSDP))
                self.assertEqual(module.sharding_strategy, bn_strategy)
                self.assertEqual(module.backward_prefetch, bn_prefetch)
                # We currently override batch norm modules to use fp32
                self.assertEqual(module.mixed_precision, fp32_mp)
            elif module in encoder_layers:
                self.assertTrue(isinstance(module, FSDP))
                self.assertEqual(module.sharding_strategy, encoder_strategy)
                self.assertEqual(module.backward_prefetch, encoder_prefetch)
                self.assertEqual(module.mixed_precision, fp16_mp)
            elif module in decoder_layers:
                self.assertTrue(isinstance(module, FSDP))
                self.assertEqual(module.sharding_strategy, decoder_strategy)
                self.assertEqual(module.backward_prefetch, decoder_prefetch)
                self.assertEqual(module.mixed_precision, fp16_mp)
            elif module is model:
                self.assertTrue(isinstance(module, FSDP))
                self.assertEqual(module.sharding_strategy, root_strategy)
                self.assertEqual(module.backward_prefetch, root_prefetch)
                self.assertEqual(module.mixed_precision, fp16_mp)
            else:
                self.assertFalse(isinstance(module, FSDP))

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    def test_auto_wrap_api(self):
        """

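A minimal sketch of the behavior discussed in the review thread above, assuming a hypothetical MyModel with a bn BatchNorm submodule and an already-initialized process group; it is illustrative only and not part of this PR's diff:

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import CustomPolicy


def lambda_fn(module: nn.Module) -> bool:
    # Ask FSDP to wrap every BatchNorm module as its own unit.
    return isinstance(module, nn.BatchNorm1d)


# `MyModel` is a hypothetical model containing `self.bn = nn.BatchNorm1d(...)`.
model = FSDP(
    MyModel(),
    auto_wrap_policy=CustomPolicy(lambda_fn),
    mixed_precision=MixedPrecision(param_dtype=torch.float16),
)
# Per the assertions in `_test_custom_policy` above, the BatchNorm unit ends up
# with the default fp32 `MixedPrecision()` while the other units use fp16: the
# batch norm mixed precision override runs last and is treated as absolute.
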
@@ -705,12 +812,34 @@ def test_frozen_params(self):
        Tests that mixing frozen/non-frozen parameters in an FSDP instance
        raises for ``use_orig_params=False`` and warns for ``True``.
        """
        for use_orig_params in [True, False]:
            self._test_frozen_params(use_orig_params)
        module_classes = (LoraAttention, LoraMLP, LoraDecoder)
        module_wrap_policy = ModuleWrapPolicy(module_classes)

        def lambda_fn_uniform(module: nn.Module):
            return isinstance(module, module_classes)

        def lambda_fn_nonuniform(module: nn.Module):
            if isinstance(module, LoraAttention):
                return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP}
            elif isinstance(module, module_classes):
                return True
            return False

        lambda_wrap_policy_uniform = CustomPolicy(lambda_fn_uniform)
        lambda_wrap_policy_nonuniform = CustomPolicy(lambda_fn_nonuniform)

        for use_orig_params, policy in itertools.product(
            [True, False],
            [
                module_wrap_policy,
                lambda_wrap_policy_uniform,
                lambda_wrap_policy_nonuniform,
            ],
        ):
            self._test_frozen_params(use_orig_params, policy)

    def _test_frozen_params(self, use_orig_params: bool):
    def _test_frozen_params(self, use_orig_params: bool, policy: _Policy):
        model = LoraModel().cuda()
        policy = ModuleWrapPolicy({LoraAttention, LoraMLP, LoraDecoder})
        msg = "layers.0.attn has both parameters with requires_grad=True and False. "
        if use_orig_params:
            msg += "We do not recommend wrapping such modules"

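As a hedged illustration of the frozen/non-frozen parameter mixing that _test_frozen_params exercises (the Block module below is hypothetical and simpler than the LoraModel test helper):

import torch.nn as nn


class Block(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.base = nn.Linear(8, 8)     # frozen, like a pretrained weight
        self.adapter = nn.Linear(8, 8)  # trainable, like a LoRA adapter
        for p in self.base.parameters():
            p.requires_grad = False


# Wrapping a `Block` as one FSDP unit puts requires_grad=True and False
# parameters into the same flat parameter; per the docstring above, this raises
# for use_orig_params=False and only warns for use_orig_params=True.
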
[Review comment] Should we add a test to ensure we raise a good error when we return an invalid key, like {"foo": 1}?

[Author reply] I think one option here is to use `inspect` to get the arg signature of the module-level API passed to the policy. However, I think the default error message from passing in an unexpected arg/kwarg is not unreasonable, so I will not worry about this change in this PR.
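
A minimal sketch of the inspect-based validation mentioned above; it is not implemented in this PR, and the helper name and error message below are hypothetical:

import inspect

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def _check_policy_kwargs(kwargs: dict) -> None:
    # Compare the keys returned by a CustomPolicy lambda against the FSDP
    # constructor's parameter names.
    valid_keys = set(inspect.signature(FSDP.__init__).parameters) - {"self", "module"}
    invalid_keys = set(kwargs) - valid_keys
    if invalid_keys:
        raise ValueError(
            f"Invalid FSDP kwargs returned by the auto wrap policy: {sorted(invalid_keys)}"
        )


# The reviewer's example {"foo": 1} would be rejected here with an explicit
# error instead of surfacing later as a TypeError from FSDP's constructor.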