[FSDP][9/N] Introduce CustomPolicy #104986

Closed
awgu wants to merge 9 commits

Contributor

@awgu awgu commented Jul 11, 2023

This PR adds a new CustomPolicy that acts like the existing lambda_auto_wrap_policy except it (1) leverages the new auto wrapping infrastructure and (2) allows overriding FSDP kwargs for particular instances. (1) gives it access to the validation checks (like for frozen parameters), and (2) makes it as expressive as manual wrapping. This should allow us to effectively deprecate manual wrapping if desired.

The API is as follows:

def lambda_fn(module: nn.Module) -> Union[bool, Dict[str, Any]]:
    ...
policy = CustomPolicy(lambda_fn)

The lambda_fn can return:

  • False or {} to indicate no wrapping
  • True to indicate wrapping while inheriting the root's FSDP kwargs
  • Non-empty dict to indicate wrapping while overriding the specified FSDP kwargs and inheriting the rest from the root
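As a rough illustration of the three return cases above, the following sketch shows how a policy might turn a `lambda_fn` result into per-module kwargs. It is not the FSDP implementation: `resolve_wrap_kwargs` and the string-valued `root_kwargs` are hypothetical names and placeholder values, introduced here so the example runs without PyTorch.

```python
import copy
from typing import Any, Dict, Optional, Union

LambdaResult = Union[bool, Dict[str, Any]]


def resolve_wrap_kwargs(
    result: LambdaResult, root_kwargs: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
    """Return the kwargs to wrap a module with, or None if it is not wrapped."""
    if not isinstance(result, (bool, dict)):
        raise ValueError(
            f"lambda_fn must return a bool or dict, got {type(result).__name__}"
        )
    if not result:
        # False or {}: do not wrap this module.
        return None
    # True: start from (a shallow copy of) the root's kwargs.
    kwargs = copy.copy(root_kwargs)
    if isinstance(result, dict):
        # Non-empty dict: override only the specified keys.
        kwargs.update(result)
    return kwargs


# Placeholder root kwargs; real FSDP kwargs would use enum values, not strings.
root_kwargs = {"backward_prefetch": "BACKWARD_PRE", "use_orig_params": True}

no_wrap = resolve_wrap_kwargs(False, root_kwargs)
also_no_wrap = resolve_wrap_kwargs({}, root_kwargs)
inherited = resolve_wrap_kwargs(True, root_kwargs)
overridden = resolve_wrap_kwargs({"backward_prefetch": "BACKWARD_POST"}, root_kwargs)
```

Copying the root kwargs before updating keeps one module's overrides from leaking into the root or into other modules.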

After this PR, the follow-up work items for auto wrapping are:

  1. Add shared parameter validation
  2. (Longer-term / exploratory) Add a policy that provides a reasonable auto wrapping with "minimal" user input

pytorch-bot bot commented Jul 11, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/104986

Note: Links to docs will display an error until the docs builds have been completed.

⏳ 1 Pending, 1 Unrelated Failure

As of commit 3fdf50f:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: distributed (fsdp) release notes category label Jul 11, 2023
awgu added a commit that referenced this pull request Jul 11, 2023
ghstack-source-id: 2a97189df60fc83b9310f08928cfe3df0198297e
Pull Request resolved: #104986
@awgu awgu added the topic: improvements topic category label Jul 11, 2023
@awgu awgu marked this pull request as ready for review July 11, 2023 16:13
awgu added a commit that referenced this pull request Jul 11, 2023
ghstack-source-id: aaeb5aa27683ba2e349bd1926345c2a6bdf8ea9b
Pull Request resolved: #104986
awgu added a commit to awgu/pytorch that referenced this pull request Jul 11, 2023
ghstack-source-id: aaeb5aa27683ba2e349bd1926345c2a6bdf8ea9b
Pull Request resolved: pytorch#104986
awgu added a commit that referenced this pull request Jul 11, 2023
ghstack-source-id: 316176d9ee4b5903c8d0bf860e2dcf3417fb5f0c
Pull Request resolved: #104986
awgu added a commit to awgu/pytorch that referenced this pull request Jul 11, 2023
ghstack-source-id: 316176d9ee4b5903c8d0bf860e2dcf3417fb5f0c
Pull Request resolved: pytorch#104986
awgu added a commit that referenced this pull request Jul 11, 2023
ghstack-source-id: 6dc615032a8544d4e70666c7587dad2f88adb489
Pull Request resolved: #104986
awgu added a commit to awgu/pytorch that referenced this pull request Jul 11, 2023
ghstack-source-id: 6dc615032a8544d4e70666c7587dad2f88adb489
Pull Request resolved: pytorch#104986
awgu added a commit that referenced this pull request Jul 11, 2023
ghstack-source-id: 4ac5da42fcdcc1b44ed8d9fb6677295a3fb9d432
Pull Request resolved: #104986
Member

@rohan-varma rohan-varma left a comment

Looks great overall, thanks for adding this! Just have a few questions.

if isinstance(module, nn.Sequential):
    return True
elif isinstance(module, FakeSequential):
    return {"backward_prefetch": BackwardPrefetch.BACKWARD_POST}
Member

Should we add a test to ensure we raise a good error when the lambda_fn returns a dict with an invalid key, like {"foo": 1}?

Contributor Author

I think one option here is to use inspect to get the arg signature of the module-level API passed to the policy. However, I think the default error message from passing in an unexpected arg/kwarg is not unreasonable, so I will not worry about this change in this PR:

TypeError: __init__() got an unexpected keyword argument 'foo'
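A minimal sketch of the `inspect`-based option mentioned above, assuming the policy can see the wrapper class whose constructor receives the overrides. `validate_override_keys` and `FakeWrapper` are hypothetical names used only for illustration; they are not part of FSDP.

```python
import inspect
from typing import Any, Callable, Dict


def validate_override_keys(
    wrapper_cls: Callable[..., Any], overrides: Dict[str, Any]
) -> None:
    """Raise ValueError if `overrides` names a keyword `wrapper_cls` does not accept."""
    params = inspect.signature(wrapper_cls).parameters
    # A **kwargs parameter accepts any keyword, so nothing can be rejected.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return
    unknown = sorted(key for key in overrides if key not in params)
    if unknown:
        raise ValueError(
            f"Unexpected override keys {unknown} for {wrapper_cls.__name__}; "
            f"valid keys are {sorted(params)}"
        )


class FakeWrapper:
    """Hypothetical stand-in for the module-level wrapper class."""

    def __init__(self, module, backward_prefetch=None, use_orig_params=False):
        self.module = module


# A known key passes silently.
validate_override_keys(FakeWrapper, {"backward_prefetch": "BACKWARD_POST"})

# An unknown key is reported before any wrapper is constructed.
try:
    validate_override_keys(FakeWrapper, {"foo": 1})
except ValueError as exc:
    print(exc)
```

Compared with the default `TypeError`, a check like this could name the offending key and list the valid ones up front, at the cost of extra code in the policy.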


if use_uniform_kwargs:

    def lambda_fn(module: nn.Module):
Member

Does LambdaPolicy still get overridden, with BatchNorms wrapped individually, if mixed precision is configured?

In general, what happens if my LambdaPolicy says to wrap a BN, but mixed precision overrides this? What should be the expected behavior?

Contributor Author

Currently, our mixed precision overrides run last. I am not sure if we should special case the LambdaWrapPolicy to be highest priority. What do you think?

I was leaning toward just keeping our overrides as absolute for now.

test/distributed/fsdp/test_wrap.py

-    def _test_frozen_params(self, use_orig_params: bool):
+    def _test_frozen_params(self, use_orig_params: bool, policy: _WrapPolicy):
Member

Maybe add a test for the mixed precision overriding as well?

Contributor Author

Added!

@@ -209,6 +211,48 @@ def __repr__(self) -> str:
        return super().__repr__() + f"({self._module_classes_str})"


class LambdaWrapPolicy(_WrapPolicy):
Member

nit: add an example?

Contributor Author

Added!

awgu added a commit to awgu/pytorch that referenced this pull request Jul 19, 2023
ghstack-source-id: 4ac5da42fcdcc1b44ed8d9fb6677295a3fb9d432
Pull Request resolved: pytorch#104986
awgu added a commit to awgu/pytorch that referenced this pull request Jul 21, 2023
ghstack-source-id: 4ac5da42fcdcc1b44ed8d9fb6677295a3fb9d432
Pull Request resolved: pytorch#104986
awgu added a commit to awgu/pytorch that referenced this pull request Jul 26, 2023
ghstack-source-id: 4ac5da42fcdcc1b44ed8d9fb6677295a3fb9d432
Pull Request resolved: pytorch#104986
@@ -209,6 +211,48 @@ def __repr__(self) -> str:
        return super().__repr__() + f"({self._module_classes_str})"


class LambdaWrapPolicy(_WrapPolicy):
Contributor

You solicited some name bikeshedding over WC, so I'll try my best here. Assuming we are keeping WrapPolicy suffix, CallbackWrapPolicy or CustomWrapPolicy both convey intent here reasonably well.

Contributor

You could also try PredicateWrapPolicy. This is not a perfect match: I would call lambda_fn a predicate if it only returned a bool, but it sometimes returns a little extra information.


def _run_policy(
    self,
    root_module: nn.Module,
Contributor

Out of curiosity, is there a way to easily determine the FQN of the module at this point? Without the root module it seems like it might be difficult, whereas the code that calls _run_policy would be able to generate it without much difficulty.

Contributor Author

That is a good point. I do think we need access to the root module to be able to generate FQNs ("we" here covers both FSDP itself and users).

Contributor

In a downstream package I maintain, I've been allowing users to provide module FQN-based complements to a given auto_wrap_policy using a custom _FSDPPolicy (NameDrivenPolicy). I should be able to continue offering that functionality with some minor refactoring after this PR, but I thought it worth mentioning here for two reasons:

  1. To call attention to this pattern, since I've received positive user feedback regarding the utility of FQN-based wrapping (to accommodate more flexible fine-tuning requirements in this case)
  2. In case the team considers the functionality sufficiently useful to offer it natively at some point in the future (I'm not necessarily advocating that, just thinking aloud; either way this PR should help make such extensions cleaner)

Anyway, nice enhancement @awgu! A pleasure as always working with this thoughtfully designed API and constantly improving package.

Contributor Author

Thanks for the pointer! This is super helpful, especially learning more about the customer feedback.

I am thinking that CustomPolicy can stay generic, and we can push specific logic to the lambda_fn construction. If the user wants to use FQNs, then the user can use the same trick as your implementation, where you translate FQNs to Python id()s and then use the id()s in the lambda_fn, which receives the module as its argument. Like you said, after some refactoring, you should still be able to make use of CustomPolicy. The question is whether we should offer some API that permits FQNs directly. I think that is something that we can wait and see. For now, we can just add CustomPolicy, which can serve as a base for other extensions.

What do you think?
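The FQN-to-id() trick described above could look roughly like the sketch below. It assumes only that the root exposes an `nn.Module`-style `named_modules()` that yields `(fqn, module)` pairs; `ToyModule` is a hypothetical stand-in so the example runs without PyTorch, and `make_fqn_lambda` is an illustrative helper, not an FSDP or downstream-package API.

```python
from typing import Any, Callable, Dict, Iterator, Tuple, Union

LambdaResult = Union[bool, Dict[str, Any]]


class ToyModule:
    """Hypothetical stand-in exposing an nn.Module-like named_modules()."""

    def __init__(self, **children: "ToyModule") -> None:
        self.children = children

    def named_modules(self, prefix: str = "") -> Iterator[Tuple[str, "ToyModule"]]:
        # Yield this module first, then recurse into children with dotted names.
        yield prefix, self
        for name, child in self.children.items():
            child_prefix = f"{prefix}.{name}" if prefix else name
            yield from child.named_modules(child_prefix)


def make_fqn_lambda(
    root: Any, fqn_to_result: Dict[str, LambdaResult]
) -> Callable[[Any], LambdaResult]:
    """Build a lambda_fn that looks modules up by identity instead of by name."""
    modules = dict(root.named_modules())
    result_by_id: Dict[int, LambdaResult] = {}
    for fqn, result in fqn_to_result.items():
        if fqn not in modules:
            raise KeyError(f"No submodule named {fqn!r}")
        # id() stays valid because `root` keeps its submodules alive.
        result_by_id[id(modules[fqn])] = result

    def lambda_fn(module: Any) -> LambdaResult:
        # Modules not listed by FQN are left unwrapped.
        return result_by_id.get(id(module), False)

    return lambda_fn


model = ToyModule(encoder=ToyModule(layer0=ToyModule()), head=ToyModule())
lambda_fn = make_fqn_lambda(
    model, {"encoder.layer0": True, "head": {"use_orig_params": True}}
)
```

The resulting `lambda_fn` only ever sees the module object, which matches the callable signature in the PR description, so the name lookup has to happen once up front while the root is available.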

Contributor

> The question is whether we should offer some API that permits FQNs directly. I think that is something that we can wait and see. For now, we can just add CustomPolicy that can serve as a base for other extensions.

Agree completely. I think prioritizing API stability and simplicity at this juncture is the prudent approach. Adding an API that permits FQNs directly might be worth considering at some point in the future, but that feature isn't sufficiently crucial to justify its inclusion at present (especially since implementation approaches like the id() translation one we discussed aren't too onerous).

Thanks for the quick and thoughtful response!

Contributor

@ezyang ezyang left a comment

Seems reasonable.

awgu added a commit to awgu/pytorch that referenced this pull request Jul 31, 2023
ghstack-source-id: 4ac5da42fcdcc1b44ed8d9fb6677295a3fb9d432
Pull Request resolved: pytorch#104986
@awgu awgu changed the title from "[FSDP][9/N] Introduce LambdaWrapPolicy" to "[FSDP][9/N] Introduce CustomPolicy" Aug 2, 2023
awgu added a commit that referenced this pull request Aug 2, 2023
ghstack-source-id: 16789fae18bb459d11a17988761b6fb754c2a91c
Pull Request resolved: #104986
@awgu awgu added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 2, 2023
awgu added a commit that referenced this pull request Aug 2, 2023
ghstack-source-id: f2047acf56ad947009524531f997efad6fe5914a
Pull Request resolved: #104986
awgu added a commit to awgu/pytorch that referenced this pull request Aug 2, 2023
ghstack-source-id: f2047acf56ad947009524531f997efad6fe5914a
Pull Request resolved: pytorch#104986
awgu added a commit that referenced this pull request Aug 2, 2023
ghstack-source-id: 3d3413d4209481f1c7248eead8a9ad92115a39a5
Pull Request resolved: #104986
Contributor Author

awgu commented Aug 3, 2023

@pytorchbot merge -f "unrelated failures"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

@facebook-github-bot facebook-github-bot deleted the gh/awgu/420/head branch August 6, 2023 14:16
Labels: ciflow/trunk, Merged, release notes: distributed (fsdp), topic: improvements
5 participants