-
Notifications
You must be signed in to change notification settings - Fork 21.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FSDP][3/N] Unify fully_shard
auto wrap
#104408
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/104408
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit c501068: This comment was automatically generated by Dr. CI and updates every 15 minutes. |
ghstack-source-id: 9e0cf806b4bc63ef5bb5361c6a4c1cb33cd80c7c Pull Request resolved: #104408
This moves `fully_shard` to use `_auto_wrap()` just like `FullyShardedDataParallel`. This means that `fully_shard` goes through the `_init_param_handle_from_module()` path (i.e. 1 `fully_shard` per "wrap"), removing the need for `_init_param_handles_from_module()` (which was 1 `fully_shard` for all "wraps" of a given policy). `_auto_wrap()` simply calls `fully_shard` on target submodules. [ghstack-poisoned]
ghstack-source-id: dede963dea8977e420b32c5b10b0587688c9f693 Pull Request resolved: pytorch#104408
This moves `fully_shard` to use `_auto_wrap()` just like `FullyShardedDataParallel`. This means that `fully_shard` goes through the `_init_param_handle_from_module()` path (i.e. 1 `fully_shard` per "wrap"), removing the need for `_init_param_handles_from_module()` (which was 1 `fully_shard` for all "wraps" of a given policy). `_auto_wrap()` simply calls `fully_shard` on target submodules. [ghstack-poisoned]
ghstack-source-id: 457c67dc90ed19b789b827fd99a06b4bcd5951b6 Pull Request resolved: pytorch#104408
This moves `fully_shard` to use `_auto_wrap()` just like `FullyShardedDataParallel`. This means that `fully_shard` goes through the `_init_param_handle_from_module()` path (i.e. 1 `fully_shard` per "wrap"), removing the need for `_init_param_handles_from_module()` (which was 1 `fully_shard` for all "wraps" of a given policy). `_auto_wrap()` simply calls `fully_shard` on target submodules. [ghstack-poisoned]
ghstack-source-id: 457c67dc90ed19b789b827fd99a06b4bcd5951b6 Pull Request resolved: pytorch#104408
This moves `fully_shard` to use `_auto_wrap()` just like `FullyShardedDataParallel`. This means that `fully_shard` goes through the `_init_param_handle_from_module()` path (i.e. 1 `fully_shard` per "wrap"), removing the need for `_init_param_handles_from_module()` (which was 1 `fully_shard` for all "wraps" of a given policy). `_auto_wrap()` simply calls `fully_shard` on target submodules. [ghstack-poisoned]
ghstack-source-id: f811cfa160eb549249eda792e376dc0277053373 Pull Request resolved: pytorch#104408
This moves `fully_shard` to use `_auto_wrap()` just like `FullyShardedDataParallel`. This means that `fully_shard` goes through the `_init_param_handle_from_module()` path (i.e. 1 `fully_shard` per "wrap"), removing the need for `_init_param_handles_from_module()` (which was 1 `fully_shard` for all "wraps" of a given policy). `_auto_wrap()` simply calls `fully_shard` on target submodules. [ghstack-poisoned]
ghstack-source-id: 141ba57ed15da65b7052e81f706b82925a376d33 Pull Request resolved: pytorch#104408
This moves `fully_shard` to use `_auto_wrap()` just like `FullyShardedDataParallel`. This means that `fully_shard` goes through the `_init_param_handle_from_module()` path (i.e. 1 `fully_shard` per "wrap"), removing the need for `_init_param_handles_from_module()` (which was 1 `fully_shard` for all "wraps" of a given policy). `_auto_wrap()` simply calls `fully_shard` on target submodules. [ghstack-poisoned]
This moves `fully_shard` to use `_auto_wrap()` just like `FullyShardedDataParallel`. This means that `fully_shard` goes through the `_init_param_handle_from_module()` path (i.e. 1 `fully_shard` per "wrap"), removing the need for `_init_param_handles_from_module()` (which was 1 `fully_shard` for all "wraps" of a given policy). `_auto_wrap()` simply calls `fully_shard` on target submodules. [ghstack-poisoned]
ghstack-source-id: 5185a43b01c99015c96e04ea13a87f2449f39022 Pull Request resolved: pytorch#104408
This moves `fully_shard` to use `_auto_wrap()` just like `FullyShardedDataParallel`. This means that `fully_shard` goes through the `_init_param_handle_from_module()` path (i.e. 1 `fully_shard` per "wrap"), removing the need for `_init_param_handles_from_module()` (which was 1 `fully_shard` for all "wraps" of a given policy). `_auto_wrap()` simply calls `fully_shard` on target submodules. [ghstack-poisoned]
ghstack-source-id: 4ad898a9777d75b08b32a786fcaab38388900ae9 Pull Request resolved: pytorch#104408
ghstack-source-id: 4ad898a9777d75b08b32a786fcaab38388900ae9 Pull Request resolved: pytorch#104408
This moves `fully_shard` to use `_auto_wrap()` just like `FullyShardedDataParallel`. This means that `fully_shard` goes through the `_init_param_handle_from_module()` path (i.e. 1 `fully_shard` per "wrap"), removing the need for `_init_param_handles_from_module()` (which was 1 `fully_shard` for all "wraps" of a given policy). `_auto_wrap()` simply calls `fully_shard` on target submodules. This includes several important fixes: - We should register the pre/post-forward hooks on the module regardless of it has managed parameters. - We can permit `_module_handles` to return `[]` in the composable path (for when the module has no managed parameters). - We should unify the paths for `_get_buffers_and_dtypes_for_computation()` (previously, composable path was buggy in some cases). [ghstack-poisoned]
This moves `fully_shard` to use `_auto_wrap()` just like `FullyShardedDataParallel`. This means that `fully_shard` goes through the `_init_param_handle_from_module()` path (i.e. 1 `fully_shard` per "wrap"), removing the need for `_init_param_handles_from_module()` (which was 1 `fully_shard` for all "wraps" of a given policy). `_auto_wrap()` simply calls `fully_shard` on target submodules. This includes several important fixes: - We should register the pre/post-forward hooks on the module regardless of it has managed parameters. - We can permit `_module_handles` to return `[]` in the composable path (for when the module has no managed parameters). - We should unify the paths for `_get_buffers_and_dtypes_for_computation()` (previously, composable path was buggy in some cases). [ghstack-poisoned]
This moves `fully_shard` to use `_auto_wrap()` just like `FullyShardedDataParallel`. This means that `fully_shard` goes through the `_init_param_handle_from_module()` path (i.e. 1 `fully_shard` per "wrap"), removing the need for `_init_param_handles_from_module()` (which was 1 `fully_shard` for all "wraps" of a given policy). `_auto_wrap()` simply calls `fully_shard` on target submodules. This includes several important fixes: - We should register the pre/post-forward hooks on the module regardless of it has managed parameters. - We can permit `_module_handles` to return `[]` in the composable path (for when the module has no managed parameters). - We should unify the paths for `_get_buffers_and_dtypes_for_computation()` (previously, composable path was buggy in some cases). [ghstack-poisoned]
This moves `fully_shard` to use `_auto_wrap()` just like `FullyShardedDataParallel`. This means that `fully_shard` goes through the `_init_param_handle_from_module()` path (i.e. 1 `fully_shard` per "wrap"), removing the need for `_init_param_handles_from_module()` (which was 1 `fully_shard` for all "wraps" of a given policy). `_auto_wrap()` simply calls `fully_shard` on target submodules. This includes several important fixes: - We should register the pre/post-forward hooks on the module regardless of it has managed parameters. - We can permit `_module_handles` to return `[]` in the composable path (for when the module has no managed parameters). - We should unify the paths for `_get_buffers_and_dtypes_for_computation()` (previously, composable path was buggy in some cases). [ghstack-poisoned]
This moves `fully_shard` to use `_auto_wrap()` just like `FullyShardedDataParallel`. This means that `fully_shard` goes through the `_init_param_handle_from_module()` path (i.e. 1 `fully_shard` per "wrap"), removing the need for `_init_param_handles_from_module()` (which was 1 `fully_shard` for all "wraps" of a given policy). `_auto_wrap()` simply calls `fully_shard` on target submodules. This includes several important fixes: - We should register the pre/post-forward hooks on the module regardless of it has managed parameters. - We can permit `_module_handles` to return `[]` in the composable path (for when the module has no managed parameters). - We should unify the paths for `_get_buffers_and_dtypes_for_computation()` (previously, composable path was buggy in some cases). [ghstack-poisoned]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
awesome, thanks for unifying the code paths!
# A valid FSDP state may have no managed parameters and hence no | ||
# handles, meaning no entry in `_fully_sharded_module_to_handles` | ||
if len(state._handles) == 0: | ||
return [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
did we add a test for this + test to ensure that if a composable FSDP module manages no params, it is still marked as FSDP managed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the _has_fsdp_params()
check is still valid. Before, the composable path would raise an error when it did not need to, which is why I had to add this case.
In other words, this is covered by the existing tests. summon_full_params()
on a module with fully_shard()
applied but no managed parameters would error otherwise.
This moves `fully_shard` to use `_auto_wrap()` just like `FullyShardedDataParallel`. This means that `fully_shard` goes through the `_init_param_handle_from_module()` path (i.e. 1 `fully_shard` per "wrap"), removing the need for `_init_param_handles_from_module()` (which was 1 `fully_shard` for all "wraps" of a given policy). `_auto_wrap()` simply calls `fully_shard` on target submodules. This includes several important fixes: - We should register the pre/post-forward hooks on the module regardless of it has managed parameters. - We can permit `_module_handles` to return `[]` in the composable path (for when the module has no managed parameters). - We should unify the paths for `_get_buffers_and_dtypes_for_computation()` (previously, composable path was buggy in some cases). [ghstack-poisoned]
Stack from ghstack (oldest at bottom):
ModuleWrapPolicy
#104427ignored_states
+ auto wrap (for now) #104418_get_fully_sharded_module_to_states
#104409fully_shard
auto wrap #104408_auto_wrap
forfully_shard
#104407ModuleWrapPolicy
to new path #104346This moves
fully_shard
to use_auto_wrap()
just likeFullyShardedDataParallel
. This means thatfully_shard
goes through the_init_param_handle_from_module()
path (i.e. 1fully_shard
per "wrap"), removing the need for_init_param_handles_from_module()
(which was 1fully_shard
for all "wraps" of a given policy)._auto_wrap()
simply callsfully_shard
on target submodules.This includes several important fixes:
_module_handles
to return[]
in the composable path (for when the module has no managed parameters)._get_buffers_and_dtypes_for_computation()
(previously, composable path was buggy in some cases).