[FSDP] Introduce ModuleWrapPolicy for simplicity #88450

Closed
wants to merge 9 commits

Conversation


@awgu awgu commented Nov 3, 2022

Stack from ghstack:

**BC Breaking Change**
This renames `unwrapped_params` to `nonwrapped_numel`. I prefer `nonwrapped` over `unwrapped` because "unwrap" suggests that some wrapping has been undone, and I prefer `numel` over `params` because that is the unit of measurement; I think we should keep "params" to refer to `nn.Parameter`s themselves.

This only breaks code that passes `unwrapped_params` as a keyword argument, and I did not find any code that does (except one internal benchmark file, which does not actually depend on our `pytorch` code).
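
For illustration, only keyword call sites are affected by the rename; positional calls keep working. A minimal sketch (not code from this PR), assuming the public `size_based_auto_wrap_policy` from `torch.distributed.fsdp.wrap`:
```python
import torch.nn as nn
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

module = nn.Linear(8, 8)  # 8 * 8 + 8 = 72 elements

# Positional call sites are unaffected by the rename:
size_based_auto_wrap_policy(module, False, 72, min_num_params=1)

# Keyword call sites using the old name now raise a TypeError:
# size_based_auto_wrap_policy(module, recurse=False, unwrapped_params=72)
```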

In a follow-up, I want to rename `min_num_params` to `min_nonwrapped_numel` in `size_based_auto_wrap_policy`, which is also BC breaking. Again, this is to differentiate between "params" being `nn.Parameter`s and "numel" being the unit for `param.numel()`.
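
To make the unit distinction concrete, a small sketch (assuming the same `size_based_auto_wrap_policy` import): an `nn.Linear(1024, 1024)` has two `nn.Parameter`s but over a million elements, and the size threshold is compared against the element count.
```python
import functools

import torch.nn as nn
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

linear = nn.Linear(1024, 1024)
num_params = sum(1 for _ in linear.parameters())     # 2 nn.Parameters (weight, bias)
numel = sum(p.numel() for p in linear.parameters())  # 1024 * 1024 + 1024 = 1_049_600

# The threshold is currently spelled `min_num_params`, but it is compared
# against a numel, hence the proposed `min_nonwrapped_numel` rename.
auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=int(1e6))
```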

**Overview**
This PR introduces `ModuleWrapPolicy` as a lightweight layer over the existing `transformer_auto_wrap_policy`. The most common auto wrapping paradigm is:
```
module_classes: Set[Type[nn.Module]] = ...
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls=module_classes,
)
fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...)
```
Now, users can instead write:
```
auto_wrap_policy = ModuleWrapPolicy(module_classes)
fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...)
```
This hides the unused arguments expected from the callable (`recurse` and `unwrapped_params`/`nonwrapped_numel`).
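
For context, a policy callable passed as `auto_wrap_policy` follows roughly this interface (a sketch; `my_auto_wrap_policy` is a hypothetical example, and the exact signature lives in `torch/distributed/fsdp/wrap.py`):
```python
import torch.nn as nn

def my_auto_wrap_policy(
    module: nn.Module,
    recurse: bool,
    nonwrapped_numel: int,  # previously named `unwrapped_params`
) -> bool:
    # While recursing, returning True keeps traversing children; once
    # `recurse` is False, returning True wraps `module` in its own FSDP unit.
    if recurse:
        return True
    return isinstance(module, nn.TransformerEncoderLayer)
```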

`ModuleWrapPolicy` inherits from an abstract base class `FSDPPolicy` that expects a `policy` property. This decouples the construction of such `FSDPPolicy` classes from their actual `policy`, which must abide by the `_recursive_wrap` interface. Any existing auto wrap policy can be rewritten as a class that inherits from `FSDPPolicy`, so this approach is fully backward compatible from a functionality perspective.
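
As a rough illustration of that rewrite, here is a sketch of a class exposing the existing size-based callable through the `policy` property. The subclass name is hypothetical, and the import of the base class from `torch.distributed.fsdp.wrap` (where `ModuleWrapPolicy` lives) is an assumption; in the code the base class is the private `_FSDPPolicy`.
```python
import functools

from torch.distributed.fsdp.wrap import (  # assumed import location
    _FSDPPolicy,
    size_based_auto_wrap_policy,
)


class SizeBasedPolicy(_FSDPPolicy):  # hypothetical subclass for illustration
    """Wraps the existing size-based callable behind the class interface."""

    def __init__(self, min_nonwrapped_numel: int) -> None:
        self._policy = functools.partial(
            size_based_auto_wrap_policy, min_num_params=min_nonwrapped_numel
        )

    @property
    def policy(self):
        # Must return a callable that abides by the `_recursive_wrap` interface.
        return self._policy
```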

I call this base class `FSDPPolicy` to generalize over the cases where we may not want to actually perform any nested wrapping. In reality, the policy is meant for constructing `FlatParameter`s, which previously just happened to be induced by nested wrapping. Given this, I am changing the constructor argument in `fully_shard()` to simply `policy` instead of `auto_wrap_policy`.
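
A minimal usage sketch of the composable API with the new argument name (assuming `fully_shard` is importable from `torch.distributed._composable` and omitting process-group setup):
```python
import torch.nn as nn
from torch.distributed._composable import fully_shard  # assumed import path
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

# (distributed initialization, e.g. init_process_group, omitted for brevity)
model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=64, nhead=4), num_layers=2
)
# `fully_shard` now takes `policy=...` rather than `auto_wrap_policy=...`,
# so each TransformerEncoderLayer gets its own FlatParameter.
fully_shard(model, policy=ModuleWrapPolicy({nn.TransformerEncoderLayer}))
```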

This PR migrates usages of `transformer_auto_wrap_policy` within our unit test suite to `ModuleWrapPolicy` as much as possible.


pytorch-bot bot commented Nov 3, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/88450

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit cb44a77:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

awgu added a commit that referenced this pull request Nov 3, 2022
ghstack-source-id: 2518526f49525a737960f99790c54f1c220137fa
Pull Request resolved: #88450
@awgu awgu added the topic: bc breaking and topic: improvements labels Nov 3, 2022
@awgu awgu requested a review from fegin November 3, 2022 21:12
awgu added a commit that referenced this pull request Nov 3, 2022
ghstack-source-id: 1c8e6170cf3693133d150261b4adbde7436195f3
Pull Request resolved: #88450
awgu added a commit to awgu/pytorch that referenced this pull request Nov 10, 2022
ghstack-source-id: ebbe195095265d8e837ea3f5af878cab8254f328
Pull Request resolved: pytorch#88450
awgu added a commit to awgu/pytorch that referenced this pull request Nov 10, 2022
ghstack-source-id: e3c5deb874ddc86a2225e1227b54d00c887a3f80
Pull Request resolved: pytorch#88450
awgu added a commit to awgu/pytorch that referenced this pull request Nov 10, 2022
ghstack-source-id: 4fa2c725988a9adbf622a212256b7e4f353b37b0
Pull Request resolved: pytorch#88450
@@ -32,7 +33,7 @@ def fully_shard(
process_group: Optional[dist.ProcessGroup] = None,
mixed_precision: Optional[MixedPrecision] = None,
cpu_offload: Optional[CPUOffload] = None,
- auto_wrap_policy: Optional[Callable] = None,
+ policy: Optional[_FSDPPolicy] = None,
Contributor

Let's still keep the 'auto_wrap_policy' name for now? I feel 'policy' is too general, and it is also a big BC change.

Contributor Author

This is composable FSDP. In my understanding, we should be able to change the constructor?

I wanted policy to be general because we can configure FSDP this way. This can be an entry point for different flavors of FSDP. One option to enable tensor shape preservation may be via policy.

Contributor Author

The wrapper FullyShardedDataParallel still calls it auto_wrap_policy in its constructor.

Contributor

oh I see, that sounds good then.

@awgu awgu added the ciflow/trunk (Trigger trunk jobs on your pull request) label Nov 11, 2022

awgu commented Nov 11, 2022

I checked internal code. There is no code passing `unwrapped_params` as a kwarg, so renaming it should not break anything.


awgu commented Nov 11, 2022

@pytorchbot merge

@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot

Merge failed

Reason: 1 additional jobs have failed, first few of them are: TorchBench CI (pytorch-linux-py3.8-cu116)

Details for Dev Infra team (raised by workflow job)


awgu commented Nov 11, 2022

@pytorchbot merge

@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot

Merge failed

Reason: 1 additional jobs have failed, first few of them are: TorchBench CI (pytorch-linux-py3.8-cu116)

Details for Dev Infra team (raised by workflow job)


awgu commented Nov 11, 2022

[Screenshot of CI check results, 2022-11-11 at 5:38 PM]

All tests are passing. Only TorchBench CI (pytorch-linux-py3.8-cu116) is skipped (not failing). However, merging is failing due to the skipped test.


awgu commented Nov 11, 2022

@pytorchbot merge -f "TorchBench CI (pytorch-linux-py3.8-cu116) is skipped but being incorrectly treated as failed"

@pytorchmergebot

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot

Merge failed

Reason: Command git -C /home/runner/work/pytorch/pytorch cherry-pick -x e410e8dcfe12b2a70f7143df05e110a37992046d returned non-zero exit code 1

Auto-merging test/distributed/fsdp/test_fsdp_state_dict.py
CONFLICT (content): Merge conflict in test/distributed/fsdp/test_fsdp_state_dict.py
Auto-merging torch/distributed/fsdp/_init_utils.py
Auto-merging torch/distributed/fsdp/fully_sharded_data_parallel.py
error: could not apply e410e8dcfe... [FSDP] Introduce `ModuleWrapPolicy` for simplicity
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git cherry-pick --continue".
hint: You can instead skip this commit with "git cherry-pick --skip".
hint: To abort and get back to the state before "git cherry-pick",
hint: run "git cherry-pick --abort".
Details for Dev Infra team (raised by workflow job)

awgu added a commit to awgu/pytorch that referenced this pull request Nov 12, 2022
ghstack-source-id: f4f7e36fcf30155219d94f14ea6bf185fdb9ac41
Pull Request resolved: pytorch#88450

awgu commented Nov 12, 2022

@pytorchbot merge

@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

awgu added a commit to awgu/pytorch that referenced this pull request Nov 12, 2022
ghstack-source-id: 3142eae0fb9555a38d52b97ec136d48148cdc34e
Pull Request resolved: pytorch#88450
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
Pull Request resolved: pytorch#88450
Approved by: https://github.com/zhaojuanmao
@facebook-github-bot facebook-github-bot deleted the gh/awgu/192/head branch June 8, 2023 15:26
Labels
ciflow/trunk · Merged · release notes: distributed (fsdp) · topic: bc breaking · topic: improvements