[FSDP][1/N] Move wrapper `ModuleWrapPolicy` to new path #104346
The change to how `old_dtype` is handled is mainly to avoid keeping a reference to the `_override_module_mixed_precision()` function closure in each hook, and to allow the function to take in all module classes at once and return which ones actually got overridden for the downstream error message. (We can directly store the global state as a mapping.)

To-do in follow-ups (not in order):
- Add frozen parameter check before `_post_order_apply()`
- Add shared parameter check before `_post_order_apply()`
- Expose wrapping policy that allows per module / per module class kwarg customization (where any unspecified kwarg adopts the root's kwarg)
- Refactor `fully_shard()` auto wrap to unify with `FullyShardedDataParallel` auto wrap, where the only difference should be `fn_to_apply` in `_post_order_apply()`
  - This means that for `fully_shard()`'s auto wrap, it will call `fully_shard()` on the target submodules, constructing a new `_FSDPState` object for each just like for the wrapper path.
  - This change prohibits extensions like non-module-aligned wrapping, but it allows for unifying the code paths to decrease the likelihood of bugs. I do not foresee us pursuing non-module-aligned wrapping in the near term.
  - After this change, we can then revisit the `ignored_states` with auto wrapping fix and land that without changing `_unshard_params()`.
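The per module class kwarg customization in the to-do list (any unspecified kwarg adopts the root's kwarg) could look roughly like the sketch below. This is only an illustration: `resolve_kwargs`, `Block`, and `Embedding` are hypothetical names, and the string values stand in for the real FSDP option objects.

```python
from typing import Any, Dict, Type


def resolve_kwargs(
    module_cls: Type,
    root_kwargs: Dict[str, Any],
    cls_to_overrides: Dict[Type, Dict[str, Any]],
) -> Dict[str, Any]:
    """Layer the overrides for one module class on top of the root's kwargs."""
    resolved = dict(root_kwargs)  # copy so the root's kwargs are never mutated
    resolved.update(cls_to_overrides.get(module_cls, {}))
    return resolved


class Block:  # hypothetical module class without overrides
    pass


class Embedding:  # hypothetical module class with one override
    pass


# Plain strings stand in for the real option values.
root_kwargs = {"sharding_strategy": "FULL_SHARD", "use_orig_params": False}
overrides = {Embedding: {"sharding_strategy": "NO_SHARD"}}

block_kwargs = resolve_kwargs(Block, root_kwargs, overrides)
embedding_kwargs = resolve_kwargs(Embedding, root_kwargs, overrides)
```

Here `Block` inherits every kwarg from the root, while `Embedding` changes only the one kwarg it specifies.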
This PR is the first in refactoring the auto wrapping, only affecting `ModuleWrapPolicy` for wrapper `FullyShardedDataParallel`. The end goal is to improve the auto wrapping infra to support:
- Checking valid frozen parameters (uniform frozenness per FSDP)
- Checking valid shared parameters (shared parameters assigned to their lowest-common-ancestor module or higher)
- Writing auto wrapping policies that may take multiple passes over the module tree
- Specifying different FSDP kwargs per FSDP instance (instead of enforcing the same for all FSDP instances constructed via an auto wrap policy)

The way I envision achieving this is that we decouple the actual "wrapping" (which is `_post_order_apply()` in this PR) from constructing the wrapping targets and kwargs (which is `target_module_to_kwargs` in this PR). In that way, a policy reduces to just constructing the latter `target_module_to_kwargs` mapping.

I do not personally recommend the size-based policy, but if we wanted to implement that under this new organization, the tracking of wrapped/nonwrapped numel should be done in the pass over the module tree prior to the actual "wrapping". This modularization keeps the actual "wrapping" part simple.

The change to how `old_dtype` is handled is mainly to avoid keeping a reference to the `_override_module_mixed_precision()` function closure in each hook, and to allow the function to take in all module classes at once and return which ones actually got overridden for the downstream error message. (We can directly store the global state as a mapping.)

To-do in follow-ups (not in order):
- Add frozen parameter check before `_post_order_apply()`
- Add shared parameter check before `_post_order_apply()`
- Expose wrapping policy that allows per module / per module class kwarg customization (where any unspecified kwarg adopts the root's kwarg)
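The decoupling described above (one pass that builds `target_module_to_kwargs`, then a separate post-order pass that applies the wrapper) can be sketched without PyTorch. This is not the PR's implementation: `Module`, `Block`, `Wrapped`, `class_policy`, and `post_order_apply` are hypothetical stand-ins, and the sketch assumes a module tree with no shared submodules and leaves the root to be wrapped by the caller.

```python
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Type


class Module:
    """Hypothetical stand-in for torch.nn.Module: a name plus child modules."""

    def __init__(self, name: str, children: Optional[List["Module"]] = None) -> None:
        self.name = name
        self.children: List["Module"] = children or []


class Block(Module):
    """A module class that the policy below targets."""


class Wrapped(Module):
    """Hypothetical stand-in for a wrapper instance around one module."""

    def __init__(self, inner: Module, **kwargs: Any) -> None:
        super().__init__(f"Wrapped({inner.name})", [inner])
        self.kwargs = kwargs


def iter_modules(root: Module) -> Iterator[Module]:
    """Yield the root and all of its descendants."""
    yield root
    for child in root.children:
        yield from iter_modules(child)


def class_policy(
    root: Module,
    module_classes: Set[Type[Module]],
    root_kwargs: Dict[str, Any],
) -> Dict[Module, Dict[str, Any]]:
    """Pass 1: pick the targets and their kwargs without touching the tree."""
    classes = tuple(module_classes)
    return {
        module: dict(root_kwargs)
        for module in iter_modules(root)
        if module is not root and isinstance(module, classes)
    }


def post_order_apply(
    root: Module,
    target_module_to_kwargs: Dict[Module, Dict[str, Any]],
    fn_to_apply: Callable[..., Module],
) -> None:
    """Pass 2: replace each target with fn_to_apply(target), children first."""
    for index, child in enumerate(root.children):
        post_order_apply(child, target_module_to_kwargs, fn_to_apply)
        if child in target_module_to_kwargs:
            root.children[index] = fn_to_apply(child, **target_module_to_kwargs[child])


model = Module("root", [Block("b0", [Module("lin0")]), Module("head"), Block("b1")])
targets = class_policy(model, {Block}, {"mixed_precision": None})
post_order_apply(model, targets, Wrapped)
names = [child.name for child in model.children]
```

Because the policy only returns a mapping, checks such as frozen or shared parameter validation could run on that mapping before any module is replaced.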
LGTM
# NOTE: If the forward did not have any floating-point tensors,
# then the dtype will not be set for this module, and we do not
# upcast the dtype.
if module in _MODULE_TO_INP_DTYPE:
So `_MODULE_TO_INP_DTYPE` generalizes `old_dtype` to a per-module basis?
Yep!
Meta comment on this PR: is this really a "move wrapper policy to the new path"? The PR adds a lot of new logic, e.g. the new `_post_order_apply()`. Maybe the PR name should be more descriptive?
Good point! I wonder how we can fit more info given the title character limit though :/
Stack from ghstack (oldest at bottom):
- `ModuleWrapPolicy` #104427
- `ignored_states` + auto wrap (for now) #104418
- `_get_fully_sharded_module_to_states` #104409
- `fully_shard` auto wrap #104408
- `_auto_wrap` for `fully_shard` #104407
- `ModuleWrapPolicy` to new path #104346