
[FSDP][3/N] Unify fully_shard auto wrap #104408

Closed · wants to merge 15 commits

Conversation

awgu (Contributor) commented Jun 29, 2023

Stack from ghstack (oldest at bottom):

This moves `fully_shard` to use `_auto_wrap()` just like `FullyShardedDataParallel`. This means that `fully_shard` goes through the `_init_param_handle_from_module()` path (i.e. 1 `fully_shard` per "wrap"), removing the need for `_init_param_handles_from_module()` (which was 1 `fully_shard` for all "wraps" of a given policy). `_auto_wrap()` simply calls `fully_shard` on target submodules.

This includes several important fixes:

  • We should register the pre/post-forward hooks on the module regardless of whether it has managed parameters.
  • We can permit `_module_handles` to return `[]` in the composable path (for when the module has no managed parameters).
  • We should unify the paths for `_get_buffers_and_dtypes_for_computation()` (previously, the composable path was buggy in some cases).
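For intuition, the per-submodule wrapping that `_auto_wrap()` performs can be sketched in plain Python. This is a hedged illustration only: `Module`, the policy callable, and `apply_fn` are stand-ins, not PyTorch's actual types or the real `_auto_wrap()` implementation.

```python
# Minimal sketch of an auto-wrap traversal: a policy decides, per
# submodule, whether to apply the wrapping function, so each matched
# submodule gets its own fully_shard-style "wrap".
class Module:
    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)
        self.wrapped = False

def auto_wrap(module, policy, apply_fn):
    # Recurse into children first (post-order), mirroring how nested
    # wrapped instances are constructed innermost-first.
    for child in module.children:
        auto_wrap(child, policy, apply_fn)
    if policy(module):
        apply_fn(module)

def mark_wrapped(module):
    # Stand-in for calling fully_shard on the matched submodule.
    module.wrapped = True

# Wrap every submodule whose name starts with "layer".
root = Module("root", [Module("layer0"), Module("layer1"), Module("head")])
auto_wrap(root, lambda m: m.name.startswith("layer"), mark_wrapped)
```

After the traversal, `layer0` and `layer1` are each wrapped individually while `head` and `root` are untouched, which is the "1 `fully_shard` per wrap" shape described above.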

pytorch-bot bot commented Jun 29, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/104408

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c501068:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: distributed (fsdp) release notes category label Jun 29, 2023
awgu added a commit that referenced this pull request Jun 29, 2023
ghstack-source-id: 9e0cf806b4bc63ef5bb5361c6a4c1cb33cd80c7c
Pull Request resolved: #104408

awgu added a commit to awgu/pytorch that referenced this pull request Jun 29, 2023
ghstack-source-id: dede963dea8977e420b32c5b10b0587688c9f693
Pull Request resolved: pytorch#104408

awgu added a commit to awgu/pytorch that referenced this pull request Jun 29, 2023
ghstack-source-id: 457c67dc90ed19b789b827fd99a06b4bcd5951b6
Pull Request resolved: pytorch#104408

@awgu awgu added the topic: not user facing topic category label Jun 29, 2023

awgu added a commit to awgu/pytorch that referenced this pull request Jun 29, 2023
ghstack-source-id: 457c67dc90ed19b789b827fd99a06b4bcd5951b6
Pull Request resolved: pytorch#104408

awgu added a commit to awgu/pytorch that referenced this pull request Jun 30, 2023
ghstack-source-id: f811cfa160eb549249eda792e376dc0277053373
Pull Request resolved: pytorch#104408

awgu added a commit to awgu/pytorch that referenced this pull request Jun 30, 2023
ghstack-source-id: 141ba57ed15da65b7052e81f706b82925a376d33
Pull Request resolved: pytorch#104408

awgu added 2 commits June 30, 2023 00:43

awgu added a commit to awgu/pytorch that referenced this pull request Jun 30, 2023
ghstack-source-id: 5185a43b01c99015c96e04ea13a87f2449f39022
Pull Request resolved: pytorch#104408

awgu added a commit to awgu/pytorch that referenced this pull request Jul 5, 2023
ghstack-source-id: 4ad898a9777d75b08b32a786fcaab38388900ae9
Pull Request resolved: pytorch#104408

awgu added a commit to awgu/pytorch that referenced this pull request Jul 5, 2023
ghstack-source-id: 4ad898a9777d75b08b32a786fcaab38388900ae9
Pull Request resolved: pytorch#104408
@awgu awgu marked this pull request as ready for review July 5, 2023 15:36
@awgu awgu requested a review from mrshenli as a code owner July 5, 2023 15:36
voznesenskym and others added 3 commits July 6, 2023 20:25
rohan-varma (Member) left a comment

awesome, thanks for unifying the code paths!

```python
# A valid FSDP state may have no managed parameters and hence no
# handles, meaning no entry in `_fully_sharded_module_to_handles`
if len(state._handles) == 0:
    return []
```
rohan-varma (Member) commented:

did we add a test for this + test to ensure that if a composable FSDP module manages no params, it is still marked as FSDP managed?

awgu (Contributor, Author) replied:

I think the `_has_fsdp_params()` check is still valid. Before, the composable path would raise an error when it did not need to, which is why I had to add this case.

In other words, this is covered by the existing tests: `summon_full_params()` on a module with `fully_shard()` applied but no managed parameters would error otherwise.

@awgu awgu added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 7, 2023
Labels: ciflow/trunk (Trigger trunk jobs on your pull request), release notes: distributed (fsdp), topic: not user facing

4 participants