
[FSDP][1/N] Move wrapper ModuleWrapPolicy to new path #104346

Closed
wants to merge 11 commits

Conversation

awgu
Contributor

@awgu awgu commented Jun 28, 2023

Stack from ghstack (oldest at bottom):

This PR is the first step in refactoring the auto wrapping; it affects only ModuleWrapPolicy for wrapper FullyShardedDataParallel. The end goal is to improve the auto wrapping infra to support:

  • Checking valid frozen parameters (uniform frozenness per FSDP)
  • Checking valid shared parameters (shared parameters assigned to their lowest-common-ancestor module or higher)
  • Writing auto wrapping policies that may take multiple passes over the module tree
  • Specifying different FSDP kwargs per FSDP instance (instead of enforcing the same for all FSDP instances constructed via an auto wrap policy)
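The first check above (uniform frozenness per prospective FSDP instance) can be sketched as follows. This is a hedged illustration, not the actual FSDP code: `Param` and `Module` are simplified stand-ins for their torch.nn counterparts, and `check_uniform_frozenness` is a hypothetical helper name.

```python
# Simplified stand-ins for torch.nn.Parameter / torch.nn.Module.
from dataclasses import dataclass, field
from typing import List

@dataclass
class Param:
    requires_grad: bool

@dataclass
class Module:
    name: str
    params: List[Param] = field(default_factory=list)
    children: List["Module"] = field(default_factory=list)

def all_params(module):
    # Yield this module's own params, then recurse into children.
    yield from module.params
    for child in module.children:
        yield from all_params(child)

def check_uniform_frozenness(target_modules):
    """Each module slated to become its own FSDP instance must have
    either all-frozen or all-trainable parameters."""
    invalid = []
    for mod in target_modules:
        flags = {p.requires_grad for p in all_params(mod)}
        if len(flags) > 1:  # mixed frozen/trainable -> invalid
            invalid.append(mod.name)
    return invalid

ok = Module("block0", [Param(True), Param(True)])
bad = Module("block1", [Param(True), Param(False)])
print(check_uniform_frozenness([ok, bad]))  # ['block1']
```

The shared-parameter check would be a similar pre-pass, verifying each shared parameter's assigned module is an ancestor of all modules that use it.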

The way I envision achieving this is to decouple the actual "wrapping" (_post_order_apply() in this PR) from constructing the wrapping targets and kwargs (target_module_to_kwargs in this PR). A policy then reduces to just constructing that target_module_to_kwargs mapping.
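A minimal sketch of that decoupling: the policy only builds a `{module: kwargs}` mapping, and a generic post-order pass applies a wrapping function bottom-up. The function names mirror the PR, but the module tree and kwargs here are illustrative stand-ins, not torch.nn or the real FSDP signature.

```python
class Module:
    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)

def module_wrap_policy(root, target_names):
    """Analog of ModuleWrapPolicy: pick target modules, same kwargs for each."""
    target_module_to_kwargs = {}
    def visit(m):
        for c in m.children:
            visit(c)
        if m.name in target_names:
            target_module_to_kwargs[m] = {"sharding": "full"}
    visit(root)
    return target_module_to_kwargs

def _post_order_apply(root, target_module_to_kwargs, fn_to_apply):
    """Apply fn bottom-up so children are wrapped before parents."""
    order = []
    def visit(m):
        for c in m.children:
            visit(c)
        order.append(m)  # post-order: node after its children
    visit(root)
    return [fn_to_apply(m, target_module_to_kwargs[m])
            for m in order if m in target_module_to_kwargs]

root = Module("root", [Module("layer0"), Module("layer1")])
wrapped = _post_order_apply(
    root, module_wrap_policy(root, {"layer0", "layer1"}),
    lambda m, kw: f"FSDP({m.name}, {kw['sharding']})")
print(wrapped)
```

Because the traversal is generic, swapping in a different policy only changes how the mapping is built, not the apply pass.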

I do not personally recommend the size-based policy, but if we wanted to implement it under this new organization, the tracking of wrapped/non-wrapped numel should be done in the pass over the module tree prior to the actual "wrapping". This modularization keeps the actual "wrapping" part simple.
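As a hedged sketch of how a size-based policy could fit this scheme, a separate pre-pass tallies numel and chooses targets, so the wrapping pass itself stays trivial. The threshold and names are illustrative, not the actual FSDP defaults.

```python
class Module:
    def __init__(self, name, numel=0, children=()):
        self.name = name
        self.numel = numel          # parameters owned directly by this module
        self.children = list(children)

def size_based_targets(root, min_num_params):
    """Pre-pass: pick every module whose not-yet-wrapped numel meets the
    threshold; a chosen module's numel is 'claimed' and not re-counted."""
    targets = {}
    def visit(m):
        unwrapped = m.numel + sum(visit(c) for c in m.children)
        if unwrapped >= min_num_params:
            targets[m.name] = {"min_num_params": min_num_params}
            return 0                # claimed by this new FSDP instance
        return unwrapped            # still counts toward an ancestor
    visit(root)
    return targets

root = Module("root", 10, [Module("small", 50), Module("big", 2000)])
print(size_based_targets(root, min_num_params=1000))
```

The resulting mapping can then be fed to the same post-order apply pass as any other policy.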

The change to how old_dtype is handled is mainly to avoid keeping a reference to the _override_module_mixed_precision() function closure in each hook and to allow the function to take in all module classes at once so it can return which ones were actually overridden for the downstream error message. (We can directly store the global state as a mapping.)

To-do in follow-ups (not in order):

  • Add frozen parameter check before _post_order_apply()
  • Add shared parameter check before _post_order_apply()
  • Expose a wrapping policy that allows per-module / per-module-class kwarg customization (where any unspecified kwarg adopts the root's kwarg)
  • Refactor fully_shard() auto wrap to unify it with FullyShardedDataParallel auto wrap, where the only difference should be the fn_to_apply in _post_order_apply()
    • fully_shard()'s auto wrap will call fully_shard() on the target submodules, constructing a new _FSDPState object for each, just like the wrapper path.
    • This prohibits extensions like non-module-aligned wrapping, but unifying the code paths decreases the likelihood of bugs. I do not foresee us pursuing non-module-aligned wrapping in the near term.
    • After this change, we can revisit the ignored_states with auto wrapping fix and land it without changing _unshard_params().
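The last planned policy (per-module kwarg customization with root fallback) could look like the following sketch. The kwarg names are illustrative, not the actual FSDP constructor signature, and `resolve_kwargs` is a hypothetical helper.

```python
def resolve_kwargs(root_kwargs, overrides):
    """overrides: {module_name: partial kwargs}.
    Any key a submodule does not specify adopts the root's value."""
    return {name: {**root_kwargs, **partial}
            for name, partial in overrides.items()}

root_kwargs = {"sharding_strategy": "FULL_SHARD", "mixed_precision": None}
per_module = resolve_kwargs(
    root_kwargs,
    {"encoder": {"mixed_precision": "fp16"},  # override one kwarg
     "decoder": {}})                          # inherit everything
print(per_module["encoder"])
```

Each entry in the resolved mapping would then become the kwargs for one FSDP instance in the post-order apply pass.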

@pytorch-bot pytorch-bot bot added release notes: distributed (fsdp) release notes category labels Jun 28, 2023

pytorch-bot bot commented Jun 28, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/104346

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ba95ffc:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

awgu added a commit that referenced this pull request Jun 28, 2023
ghstack-source-id: eafbb91bf2339aeea230b92a0356e52003c5ce7a
Pull Request resolved: #104346
awgu added a commit that referenced this pull request Jun 28, 2023
ghstack-source-id: 3c7ff44f030362717f1ae2fbc2b4cd9014ebedc6
Pull Request resolved: #104346
awgu added a commit that referenced this pull request Jun 28, 2023
ghstack-source-id: 892168e233a9e74b7750227c086cb4689ab23572
Pull Request resolved: #104346
awgu added a commit to awgu/pytorch that referenced this pull request Jun 29, 2023
ghstack-source-id: 892168e233a9e74b7750227c086cb4689ab23572
Pull Request resolved: pytorch#104346
awgu added a commit to awgu/pytorch that referenced this pull request Jun 29, 2023
ghstack-source-id: 304383832c2b6bc8eafbfc7730d8198365049780
Pull Request resolved: pytorch#104346
@awgu awgu added the topic: not user facing topic category label Jun 29, 2023
awgu added a commit to awgu/pytorch that referenced this pull request Jun 29, 2023
ghstack-source-id: b62d622a3c0745bc3149202a81ac730a7aefe995
Pull Request resolved: pytorch#104346
awgu added a commit to awgu/pytorch that referenced this pull request Jun 30, 2023
ghstack-source-id: b62d622a3c0745bc3149202a81ac730a7aefe995
Pull Request resolved: pytorch#104346
awgu added a commit to awgu/pytorch that referenced this pull request Jun 30, 2023
ghstack-source-id: 6600d7ad0d44834537abefee871178fd3cdd6ff7
Pull Request resolved: pytorch#104346
awgu added a commit to awgu/pytorch that referenced this pull request Jul 5, 2023
ghstack-source-id: 6600d7ad0d44834537abefee871178fd3cdd6ff7
Pull Request resolved: pytorch#104346
@awgu awgu marked this pull request as ready for review July 5, 2023 15:35
Member

@rohan-varma rohan-varma left a comment


LGTM

# NOTE: If the forward did not have any floating-point tensors,
# then the dtype will not be set for this module, and we do not
# upcast the dtype.
if module in _MODULE_TO_INP_DTYPE:
Member


so _MODULE_TO_INP_DTYPE generalizes old_dtype to be on a per-module basis?

Contributor Author


Yep!

torch/distributed/fsdp/wrap.py (review thread resolved)
Contributor

@wanchaol wanchaol left a comment


meta comment on this PR: is this really a "move wrapper policy to the new path"? The PR adds a lot of new logic, e.g. the new post_order_apply; maybe the PR name should be more descriptive?

@awgu
Contributor Author

awgu commented Jul 6, 2023

meta comment on this PR: is this really a "move wrapper policy to the new path"? The PR adds a lot of new logic, e.g. the new post_order_apply; maybe the PR name should be more descriptive?

Good point! I wonder how we can fit more info given the title character limit though :/

@awgu awgu added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 7, 2023
awgu added 2 commits July 7, 2023 14:42
@facebook-github-bot facebook-github-bot deleted the gh/awgu/409/head branch July 11, 2023 14:16
voznesenskym pushed a commit that referenced this pull request Jul 19, 2023
ghstack-source-id: 158dc95b1093450c647296ee05ef13192ed67fc9
Pull Request resolved: #104346
voznesenskym pushed a commit that referenced this pull request Jul 21, 2023
ghstack-source-id: 158dc95b1093450c647296ee05ef13192ed67fc9
Pull Request resolved: #104346
voznesenskym pushed a commit that referenced this pull request Aug 7, 2023
ghstack-source-id: 158dc95b1093450c647296ee05ef13192ed67fc9
Pull Request resolved: #104346
Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged release notes: distributed (fsdp) release notes category topic: not user facing topic category