-
Notifications
You must be signed in to change notification settings - Fork 21.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FSDP] Introduce ModuleWrapPolicy
for simplicity
#88450
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/88450
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit cb44a77: This comment was automatically generated by Dr. CI and updates every 15 minutes. |
ghstack-source-id: 2518526f49525a737960f99790c54f1c220137fa Pull Request resolved: #88450
**BC Breaking Change** This renames `unwrapped_params` to `nonwrapped_numel`. I prefer `nonwrapped` over `unwrapped` because "unwrap" suggests that some wrapping has been undone. I prefer `numel` over `params` because that is unit of measurement; I think we should keep "params" to refer to `nn.Parameter`s themselves. This only breaks anything that passes `unwrapped_params` as a keyword argument, but I did not see anything that did that (except the one internal benchmark file but that does not actually depend on our `pytorch` code). **Overview** This PR introduces `ModuleWrapPolicy` as a lightweight layer over the existing `transformer_auto_wrap_policy`. The most common auto wrapping paradigm is: ``` module_classes: Set[Type[nn.Module]] = ... auto_wrap_policy = functools.partial( transformer_auto_wrap_policy, transformer_layer_cls=module_classes, ) fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...) ``` Now, users can instead write: ``` auto_wrap_policy = ModuleWrapPolicy(module_classes) fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...) ``` This hides the unused arguments expected from the callable (`recurse` and `unwrapped_params`/`nonwrapped_numel`). `ModuleWrapPolicy` inherits from an abstract base class `AutoWrapPolicy` that expects an `auto_wrap_policy` property. This decouples the construct of such `AutoWrapPolicy` classes and their actual `auto_wrap_policy`, which must abide by the `_recursive_wrap` interface. This PR migrates usages of `transformer_auto_wrap_policy` within our unit test suite to `ModuleWrapPolicy` as much as possible. [ghstack-poisoned]
ghstack-source-id: 1c8e6170cf3693133d150261b4adbde7436195f3 Pull Request resolved: #88450
ghstack-source-id: 1c8e6170cf3693133d150261b4adbde7436195f3 Pull Request resolved: pytorch#88450
ghstack-source-id: 1c8e6170cf3693133d150261b4adbde7436195f3 Pull Request resolved: pytorch#88450
**BC Breaking Change** This renames `unwrapped_params` to `nonwrapped_numel`. I prefer `nonwrapped` over `unwrapped` because "unwrap" suggests that some wrapping has been undone. I prefer `numel` over `params` because that is unit of measurement; I think we should keep "params" to refer to `nn.Parameter`s themselves. This only breaks anything that passes `unwrapped_params` as a keyword argument, but I did not see anything that did that (except the one internal benchmark file but that does not actually depend on our `pytorch` code). **Overview** This PR introduces `ModuleWrapPolicy` as a lightweight layer over the existing `transformer_auto_wrap_policy`. The most common auto wrapping paradigm is: ``` module_classes: Set[Type[nn.Module]] = ... auto_wrap_policy = functools.partial( transformer_auto_wrap_policy, transformer_layer_cls=module_classes, ) fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...) ``` Now, users can instead write: ``` auto_wrap_policy = ModuleWrapPolicy(module_classes) fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...) ``` This hides the unused arguments expected from the callable (`recurse` and `unwrapped_params`/`nonwrapped_numel`). `ModuleWrapPolicy` inherits from an abstract base class `AutoWrapPolicy` that expects an `auto_wrap_policy` property. This decouples the construct of such `AutoWrapPolicy` classes and their actual `auto_wrap_policy`, which must abide by the `_recursive_wrap` interface. Any existing auto wrap policy can be rewritten as a class that inherits from `AutoWrapPolicy`, so this approach is fully backward compatible from a functionality perspective. This PR migrates usages of `transformer_auto_wrap_policy` within our unit test suite to `ModuleWrapPolicy` as much as possible. [ghstack-poisoned]
ghstack-source-id: ebbe195095265d8e837ea3f5af878cab8254f328 Pull Request resolved: pytorch#88450
**BC Breaking Change** This renames `unwrapped_params` to `nonwrapped_numel`. I prefer `nonwrapped` over `unwrapped` because "unwrap" suggests that some wrapping has been undone. I prefer `numel` over `params` because that is unit of measurement; I think we should keep "params" to refer to `nn.Parameter`s themselves. This only breaks anything that passes `unwrapped_params` as a keyword argument, but I did not see anything that did that (except the one internal benchmark file but that does not actually depend on our `pytorch` code). In a follow-up, I want to rename `min_num_params` to `min_nonwrapped_numel` in `size_based_auto_wrap_policy`, which is also BC breaking. Again, this is to differentiate between "params" being `nn.Parameter`s and "numel" being the unit for `param.numel()`. **Overview** This PR introduces `ModuleWrapPolicy` as a lightweight layer over the existing `transformer_auto_wrap_policy`. The most common auto wrapping paradigm is: ``` module_classes: Set[Type[nn.Module]] = ... auto_wrap_policy = functools.partial( transformer_auto_wrap_policy, transformer_layer_cls=module_classes, ) fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...) ``` Now, users can instead write: ``` auto_wrap_policy = ModuleWrapPolicy(module_classes) fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...) ``` This hides the unused arguments expected from the callable (`recurse` and `unwrapped_params`/`nonwrapped_numel`). `ModuleWrapPolicy` inherits from an abstract base class `FSDPPolicy` that expects a `policy` property. This decouples the construct of such `FSDPPolicy` classes and their actual `policy`, which must abide by the `_recursive_wrap` interface. Any existing auto wrap policy can be rewritten as a class that inherits from `FSDPPolicy`, so this approach is fully backward compatible from a functionality perspective. I call this base class `FSDPPolicy` to generalize over the cases where we may not want to actually perform any nested wrapping. In reality, the policy is meant for constructing `FlatParameter`s, which just happened to be induced by a nested wrapping before. Given this, I am changing the constructor argument in `fully_shard()` to simply `policy` instead of `auto_wrap_policy`. This PR migrates usages of `transformer_auto_wrap_policy` within our unit test suite to `ModuleWrapPolicy` as much as possible. [ghstack-poisoned]
ghstack-source-id: e3c5deb874ddc86a2225e1227b54d00c887a3f80 Pull Request resolved: pytorch#88450
ghstack-source-id: e3c5deb874ddc86a2225e1227b54d00c887a3f80 Pull Request resolved: pytorch#88450
**BC Breaking Change** This renames `unwrapped_params` to `nonwrapped_numel`. I prefer `nonwrapped` over `unwrapped` because "unwrap" suggests that some wrapping has been undone. I prefer `numel` over `params` because that is unit of measurement; I think we should keep "params" to refer to `nn.Parameter`s themselves. This only breaks anything that passes `unwrapped_params` as a keyword argument, but I did not see anything that did that (except the one internal benchmark file but that does not actually depend on our `pytorch` code). In a follow-up, I want to rename `min_num_params` to `min_nonwrapped_numel` in `size_based_auto_wrap_policy`, which is also BC breaking. Again, this is to differentiate between "params" being `nn.Parameter`s and "numel" being the unit for `param.numel()`. **Overview** This PR introduces `ModuleWrapPolicy` as a lightweight layer over the existing `transformer_auto_wrap_policy`. The most common auto wrapping paradigm is: ``` module_classes: Set[Type[nn.Module]] = ... auto_wrap_policy = functools.partial( transformer_auto_wrap_policy, transformer_layer_cls=module_classes, ) fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...) ``` Now, users can instead write: ``` auto_wrap_policy = ModuleWrapPolicy(module_classes) fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...) ``` This hides the unused arguments expected from the callable (`recurse` and `unwrapped_params`/`nonwrapped_numel`). `ModuleWrapPolicy` inherits from an abstract base class `FSDPPolicy` that expects a `policy` property. This decouples the construct of such `FSDPPolicy` classes and their actual `policy`, which must abide by the `_recursive_wrap` interface. Any existing auto wrap policy can be rewritten as a class that inherits from `FSDPPolicy`, so this approach is fully backward compatible from a functionality perspective. I call this base class `FSDPPolicy` to generalize over the cases where we may not want to actually perform any nested wrapping. In reality, the policy is meant for constructing `FlatParameter`s, which just happened to be induced by a nested wrapping before. Given this, I am changing the constructor argument in `fully_shard()` to simply `policy` instead of `auto_wrap_policy`. This PR migrates usages of `transformer_auto_wrap_policy` within our unit test suite to `ModuleWrapPolicy` as much as possible. [ghstack-poisoned]
ghstack-source-id: 4fa2c725988a9adbf622a212256b7e4f353b37b0 Pull Request resolved: pytorch#88450
**BC Breaking Change** This renames `unwrapped_params` to `nonwrapped_numel`. I prefer `nonwrapped` over `unwrapped` because "unwrap" suggests that some wrapping has been undone. I prefer `numel` over `params` because that is unit of measurement; I think we should keep "params" to refer to `nn.Parameter`s themselves. This only breaks anything that passes `unwrapped_params` as a keyword argument, but I did not see anything that did that (except the one internal benchmark file but that does not actually depend on our `pytorch` code). In a follow-up, I want to rename `min_num_params` to `min_nonwrapped_numel` in `size_based_auto_wrap_policy`, which is also BC breaking. Again, this is to differentiate between "params" being `nn.Parameter`s and "numel" being the unit for `param.numel()`. **Overview** This PR introduces `ModuleWrapPolicy` as a lightweight layer over the existing `transformer_auto_wrap_policy`. The most common auto wrapping paradigm is: ``` module_classes: Set[Type[nn.Module]] = ... auto_wrap_policy = functools.partial( transformer_auto_wrap_policy, transformer_layer_cls=module_classes, ) fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...) ``` Now, users can instead write: ``` auto_wrap_policy = ModuleWrapPolicy(module_classes) fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...) ``` This hides the unused arguments expected from the callable (`recurse` and `unwrapped_params`/`nonwrapped_numel`). `ModuleWrapPolicy` inherits from an abstract base class `FSDPPolicy` that expects a `policy` property. This decouples the construct of such `FSDPPolicy` classes and their actual `policy`, which must abide by the `_recursive_wrap` interface. Any existing auto wrap policy can be rewritten as a class that inherits from `FSDPPolicy`, so this approach is fully backward compatible from a functionality perspective. I call this base class `FSDPPolicy` to generalize over the cases where we may not want to actually perform any nested wrapping. In reality, the policy is meant for constructing `FlatParameter`s, which just happened to be induced by a nested wrapping before. Given this, I am changing the constructor argument in `fully_shard()` to simply `policy` instead of `auto_wrap_policy`. This PR migrates usages of `transformer_auto_wrap_policy` within our unit test suite to `ModuleWrapPolicy` as much as possible. [ghstack-poisoned]
**BC Breaking Change** This renames `unwrapped_params` to `nonwrapped_numel`. I prefer `nonwrapped` over `unwrapped` because "unwrap" suggests that some wrapping has been undone. I prefer `numel` over `params` because that is unit of measurement; I think we should keep "params" to refer to `nn.Parameter`s themselves. This only breaks anything that passes `unwrapped_params` as a keyword argument, but I did not see anything that did that (except the one internal benchmark file but that does not actually depend on our `pytorch` code). In a follow-up, I want to rename `min_num_params` to `min_nonwrapped_numel` in `size_based_auto_wrap_policy`, which is also BC breaking. Again, this is to differentiate between "params" being `nn.Parameter`s and "numel" being the unit for `param.numel()`. **Overview** This PR introduces `ModuleWrapPolicy` as a lightweight layer over the existing `transformer_auto_wrap_policy`. The most common auto wrapping paradigm is: ``` module_classes: Set[Type[nn.Module]] = ... auto_wrap_policy = functools.partial( transformer_auto_wrap_policy, transformer_layer_cls=module_classes, ) fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...) ``` Now, users can instead write: ``` auto_wrap_policy = ModuleWrapPolicy(module_classes) fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...) ``` This hides the unused arguments expected from the callable (`recurse` and `unwrapped_params`/`nonwrapped_numel`). `ModuleWrapPolicy` inherits from an abstract base class `FSDPPolicy` that expects a `policy` property. This decouples the construct of such `FSDPPolicy` classes and their actual `policy`, which must abide by the `_recursive_wrap` interface. Any existing auto wrap policy can be rewritten as a class that inherits from `FSDPPolicy`, so this approach is fully backward compatible from a functionality perspective. I call this base class `FSDPPolicy` to generalize over the cases where we may not want to actually perform any nested wrapping. In reality, the policy is meant for constructing `FlatParameter`s, which just happened to be induced by a nested wrapping before. Given this, I am changing the constructor argument in `fully_shard()` to simply `policy` instead of `auto_wrap_policy`. This PR migrates usages of `transformer_auto_wrap_policy` within our unit test suite to `ModuleWrapPolicy` as much as possible. [ghstack-poisoned]
**BC Breaking Change** This renames `unwrapped_params` to `nonwrapped_numel`. I prefer `nonwrapped` over `unwrapped` because "unwrap" suggests that some wrapping has been undone. I prefer `numel` over `params` because that is unit of measurement; I think we should keep "params" to refer to `nn.Parameter`s themselves. This only breaks anything that passes `unwrapped_params` as a keyword argument, but I did not see anything that did that (except the one internal benchmark file but that does not actually depend on our `pytorch` code). In a follow-up, I want to rename `min_num_params` to `min_nonwrapped_numel` in `size_based_auto_wrap_policy`, which is also BC breaking. Again, this is to differentiate between "params" being `nn.Parameter`s and "numel" being the unit for `param.numel()`. **Overview** This PR introduces `ModuleWrapPolicy` as a lightweight layer over the existing `transformer_auto_wrap_policy`. The most common auto wrapping paradigm is: ``` module_classes: Set[Type[nn.Module]] = ... auto_wrap_policy = functools.partial( transformer_auto_wrap_policy, transformer_layer_cls=module_classes, ) fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...) ``` Now, users can instead write: ``` auto_wrap_policy = ModuleWrapPolicy(module_classes) fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...) ``` This hides the unused arguments expected from the callable (`recurse` and `unwrapped_params`/`nonwrapped_numel`). `ModuleWrapPolicy` inherits from an abstract base class `FSDPPolicy` that expects a `policy` property. This decouples the construct of such `FSDPPolicy` classes and their actual `policy`, which must abide by the `_recursive_wrap` interface. Any existing auto wrap policy can be rewritten as a class that inherits from `FSDPPolicy`, so this approach is fully backward compatible from a functionality perspective. I call this base class `FSDPPolicy` to generalize over the cases where we may not want to actually perform any nested wrapping. In reality, the policy is meant for constructing `FlatParameter`s, which just happened to be induced by a nested wrapping before. Given this, I am changing the constructor argument in `fully_shard()` to simply `policy` instead of `auto_wrap_policy`. This PR migrates usages of `transformer_auto_wrap_policy` within our unit test suite to `ModuleWrapPolicy` as much as possible. [ghstack-poisoned]
@@ -32,7 +33,7 @@ def fully_shard( | |||
process_group: Optional[dist.ProcessGroup] = None, | |||
mixed_precision: Optional[MixedPrecision] = None, | |||
cpu_offload: Optional[CPUOffload] = None, | |||
auto_wrap_policy: Optional[Callable] = None, | |||
policy: Optional[_FSDPPolicy] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's still keep the 'auto_wrap_policy' name for now? feel policy is too general, also it is a big BC change
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is composable FSDP. In my understanding, we should be able to change the constructor?
I wanted policy
to be general because we can configure FSDP this way. This can be an entry point for different flavors of FSDP. One option to enable tensor shape preservation may be via policy
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The wrapper FullyShardedDataParallel
still calls it auto_wrap_policy
in its constructor.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh I see, that sounds good then.
I checked internal code. There is no code passing |
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 1 additional jobs have failed, first few of them are: TorchBench CI (pytorch-linux-py3.8-cu116) Details for Dev Infra teamRaised by workflow job |
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 1 additional jobs have failed, first few of them are: TorchBench CI (pytorch-linux-py3.8-cu116) Details for Dev Infra teamRaised by workflow job |
@pytorchbot merge -f "TorchBench CI (pytorch-linux-py3.8-cu116) is skipped but being incorrectly treated as failed" |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: Command
Details for Dev Infra teamRaised by workflow job |
ghstack-source-id: f4f7e36fcf30155219d94f14ea6bf185fdb9ac41 Pull Request resolved: pytorch#88450
ghstack-source-id: f4f7e36fcf30155219d94f14ea6bf185fdb9ac41 Pull Request resolved: pytorch#88450
**BC Breaking Change** This renames `unwrapped_params` to `nonwrapped_numel`. I prefer `nonwrapped` over `unwrapped` because "unwrap" suggests that some wrapping has been undone. I prefer `numel` over `params` because that is unit of measurement; I think we should keep "params" to refer to `nn.Parameter`s themselves. This only breaks anything that passes `unwrapped_params` as a keyword argument, but I did not see anything that did that (except the one internal benchmark file but that does not actually depend on our `pytorch` code). In a follow-up, I want to rename `min_num_params` to `min_nonwrapped_numel` in `size_based_auto_wrap_policy`, which is also BC breaking. Again, this is to differentiate between "params" being `nn.Parameter`s and "numel" being the unit for `param.numel()`. **Overview** This PR introduces `ModuleWrapPolicy` as a lightweight layer over the existing `transformer_auto_wrap_policy`. The most common auto wrapping paradigm is: ``` module_classes: Set[Type[nn.Module]] = ... auto_wrap_policy = functools.partial( transformer_auto_wrap_policy, transformer_layer_cls=module_classes, ) fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...) ``` Now, users can instead write: ``` auto_wrap_policy = ModuleWrapPolicy(module_classes) fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...) ``` This hides the unused arguments expected from the callable (`recurse` and `unwrapped_params`/`nonwrapped_numel`). `ModuleWrapPolicy` inherits from an abstract base class `FSDPPolicy` that expects a `policy` property. This decouples the construct of such `FSDPPolicy` classes and their actual `policy`, which must abide by the `_recursive_wrap` interface. Any existing auto wrap policy can be rewritten as a class that inherits from `FSDPPolicy`, so this approach is fully backward compatible from a functionality perspective. I call this base class `FSDPPolicy` to generalize over the cases where we may not want to actually perform any nested wrapping. In reality, the policy is meant for constructing `FlatParameter`s, which just happened to be induced by a nested wrapping before. Given this, I am changing the constructor argument in `fully_shard()` to simply `policy` instead of `auto_wrap_policy`. This PR migrates usages of `transformer_auto_wrap_policy` within our unit test suite to `ModuleWrapPolicy` as much as possible. [ghstack-poisoned]
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
ghstack-source-id: 3142eae0fb9555a38d52b97ec136d48148cdc34e Pull Request resolved: pytorch#88450
**BC Breaking Change** This renames `unwrapped_params` to `nonwrapped_numel`. I prefer `nonwrapped` over `unwrapped` because "unwrap" suggests that some wrapping has been undone. I prefer `numel` over `params` because that is unit of measurement; I think we should keep "params" to refer to `nn.Parameter`s themselves. This only breaks anything that passes `unwrapped_params` as a keyword argument, but I did not see anything that did that (except the one internal benchmark file but that does not actually depend on our `pytorch` code). In a follow-up, I want to rename `min_num_params` to `min_nonwrapped_numel` in `size_based_auto_wrap_policy`, which is also BC breaking. Again, this is to differentiate between "params" being `nn.Parameter`s and "numel" being the unit for `param.numel()`. **Overview** This PR introduces `ModuleWrapPolicy` as a lightweight layer over the existing `transformer_auto_wrap_policy`. The most common auto wrapping paradigm is: ``` module_classes: Set[Type[nn.Module]] = ... auto_wrap_policy = functools.partial( transformer_auto_wrap_policy, transformer_layer_cls=module_classes, ) fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...) ``` Now, users can instead write: ``` auto_wrap_policy = ModuleWrapPolicy(module_classes) fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...) ``` This hides the unused arguments expected from the callable (`recurse` and `unwrapped_params`/`nonwrapped_numel`). `ModuleWrapPolicy` inherits from an abstract base class `FSDPPolicy` that expects a `policy` property. This decouples the construct of such `FSDPPolicy` classes and their actual `policy`, which must abide by the `_recursive_wrap` interface. Any existing auto wrap policy can be rewritten as a class that inherits from `FSDPPolicy`, so this approach is fully backward compatible from a functionality perspective. I call this base class `FSDPPolicy` to generalize over the cases where we may not want to actually perform any nested wrapping. In reality, the policy is meant for constructing `FlatParameter`s, which just happened to be induced by a nested wrapping before. Given this, I am changing the constructor argument in `fully_shard()` to simply `policy` instead of `auto_wrap_policy`. This PR migrates usages of `transformer_auto_wrap_policy` within our unit test suite to `ModuleWrapPolicy` as much as possible. Pull Request resolved: pytorch#88450 Approved by: https://github.com/zhaojuanmao
Stack from ghstack:
ModuleWrapPolicy
#88453 [Dynamo][FSDP] Migrate toModuleWrapPolicy
ModuleWrapPolicy
for simplicity #88450 [FSDP] IntroduceModuleWrapPolicy
for simplicityBC Breaking Change
This renames
unwrapped_params
tononwrapped_numel
. I prefernonwrapped
overunwrapped
because "unwrap" suggests that some wrapping has been undone. I prefernumel
overparams
because that is unit of measurement; I think we should keep "params" to refer tonn.Parameter
s themselves.This only breaks anything that passes
unwrapped_params
as a keyword argument, but I did not see anything that did that (except the one internal benchmark file but that does not actually depend on ourpytorch
code).In a follow-up, I want to rename
min_num_params
tomin_nonwrapped_numel
insize_based_auto_wrap_policy
, which is also BC breaking. Again, this is to differentiate between "params" beingnn.Parameter
s and "numel" being the unit forparam.numel()
.Overview
This PR introduces
ModuleWrapPolicy
as a lightweight layer over the existingtransformer_auto_wrap_policy
. The most common auto wrapping paradigm is:Now, users can instead write:
This hides the unused arguments expected from the callable (
recurse
andunwrapped_params
/nonwrapped_numel
).ModuleWrapPolicy
inherits from an abstract base classFSDPPolicy
that expects apolicy
property. This decouples the construct of suchFSDPPolicy
classes and their actualpolicy
, which must abide by the_recursive_wrap
interface. Any existing auto wrap policy can be rewritten as a class that inherits fromFSDPPolicy
, so this approach is fully backward compatible from a functionality perspective.I call this base class
FSDPPolicy
to generalize over the cases where we may not want to actually perform any nested wrapping. In reality, the policy is meant for constructingFlatParameter
s, which just happened to be induced by a nested wrapping before. Given this, I am changing the constructor argument infully_shard()
to simplypolicy
instead ofauto_wrap_policy
.This PR migrates usages of
transformer_auto_wrap_policy
within our unit test suite toModuleWrapPolicy
as much as possible.