New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FSDP][9/N] Introduce CustomPolicy
#104986
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/104986
Note: Links to docs will display an error until the docs builds have been completed. ⏳ 1 Pending, 1 Unrelated FailureAs of commit 3fdf50f: UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
ghstack-source-id: 2a97189df60fc83b9310f08928cfe3df0198297e Pull Request resolved: #104986
This PR adds a new `LambdaWrapPolicy` that acts like the existing `lambda_auto_wrap_policy` except it (1) leverages the new auto wrapping infrastructure and (2) allows overriding FSDP kwargs for particular instances. (1) gives it access to the validation checks (like for frozen parameters), and (2) makes it as expressive as manual wrapping. This should allow us to effectively deprecate manual wrapping if desired. The API is as follows: ``` def lambda_fn(module: nn.Module) -> Union[bool, Dict[str, Any]]: ... policy = LambdaWrapPolicy(lambda_fn) ``` The `lambda_fn` can return: - `False` or `{}` to indicate no wrapping - `True` to indicate wrapping while inheriting the root's FSDP kwargs - Non-empty `dict` to indicate wrapping while overriding the specified FSDP kwargs and inheriting the rest from the root [ghstack-poisoned]
ghstack-source-id: aaeb5aa27683ba2e349bd1926345c2a6bdf8ea9b Pull Request resolved: #104986
ghstack-source-id: aaeb5aa27683ba2e349bd1926345c2a6bdf8ea9b Pull Request resolved: pytorch#104986
This PR adds a new `LambdaWrapPolicy` that acts like the existing `lambda_auto_wrap_policy` except it (1) leverages the new auto wrapping infrastructure and (2) allows overriding FSDP kwargs for particular instances. (1) gives it access to the validation checks (like for frozen parameters), and (2) makes it as expressive as manual wrapping. This should allow us to effectively deprecate manual wrapping if desired. The API is as follows: ``` def lambda_fn(module: nn.Module) -> Union[bool, Dict[str, Any]]: ... policy = LambdaWrapPolicy(lambda_fn) ``` The `lambda_fn` can return: - `False` or `{}` to indicate no wrapping - `True` to indicate wrapping while inheriting the root's FSDP kwargs - Non-empty `dict` to indicate wrapping while overriding the specified FSDP kwargs and inheriting the rest from the root [ghstack-poisoned]
ghstack-source-id: 316176d9ee4b5903c8d0bf860e2dcf3417fb5f0c Pull Request resolved: #104986
ghstack-source-id: 316176d9ee4b5903c8d0bf860e2dcf3417fb5f0c Pull Request resolved: pytorch#104986
This PR adds a new `LambdaWrapPolicy` that acts like the existing `lambda_auto_wrap_policy` except it (1) leverages the new auto wrapping infrastructure and (2) allows overriding FSDP kwargs for particular instances. (1) gives it access to the validation checks (like for frozen parameters), and (2) makes it as expressive as manual wrapping. This should allow us to effectively deprecate manual wrapping if desired. The API is as follows: ``` def lambda_fn(module: nn.Module) -> Union[bool, Dict[str, Any]]: ... policy = LambdaWrapPolicy(lambda_fn) ``` The `lambda_fn` can return: - `False` or `{}` to indicate no wrapping - `True` to indicate wrapping while inheriting the root's FSDP kwargs - Non-empty `dict` to indicate wrapping while overriding the specified FSDP kwargs and inheriting the rest from the root [ghstack-poisoned]
ghstack-source-id: 6dc615032a8544d4e70666c7587dad2f88adb489 Pull Request resolved: #104986
ghstack-source-id: 6dc615032a8544d4e70666c7587dad2f88adb489 Pull Request resolved: pytorch#104986
This PR adds a new `LambdaWrapPolicy` that acts like the existing `lambda_auto_wrap_policy` except it (1) leverages the new auto wrapping infrastructure and (2) allows overriding FSDP kwargs for particular instances. (1) gives it access to the validation checks (like for frozen parameters), and (2) makes it as expressive as manual wrapping. This should allow us to effectively deprecate manual wrapping if desired. The API is as follows: ``` def lambda_fn(module: nn.Module) -> Union[bool, Dict[str, Any]]: ... policy = LambdaWrapPolicy(lambda_fn) ``` The `lambda_fn` can return: - `False` or `{}` to indicate no wrapping - `True` to indicate wrapping while inheriting the root's FSDP kwargs - Non-empty `dict` to indicate wrapping while overriding the specified FSDP kwargs and inheriting the rest from the root --- After this PR, the follow-up work items for auto wrapping are: 1. Add shared parameter validation 2. (Longer-term / exploratory) Add a policy that provides a reasonable auto wrapping with "minimal" user input [ghstack-poisoned]
ghstack-source-id: 4ac5da42fcdcc1b44ed8d9fb6677295a3fb9d432 Pull Request resolved: #104986
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great overall, thanks for adding this! Just have a few questions.
if isinstance(module, nn.Sequential): | ||
return True | ||
elif isinstance(module, FakeSequential): | ||
return {"backward_prefetch": BackwardPrefetch.BACKWARD_POST} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should we add a test to ensure we raise a good error when we return an invalid key? like {"foo": 1}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think one option here is to use inspect
to get the arg signature of the module-level API passed to the policy. However, I think the default error message from passing in an unexpected arg/kwarg is not unreasonable, so I will not worry about this change in this PR:
TypeError: __init__() got an unexpected keyword argument 'foo'
|
||
if use_uniform_kwargs: | ||
|
||
def lambda_fn(module: nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does LambdaPolicy still get overriden and individually wrap BatchNorms if mixed precision is configured?
In general, what happens if my LambdaPolicy tells to wrap a BN, but mixed precision overrides this? What should be the expected behavior?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently, our mixed precision overrides run last. I am not sure if we should special case the LambdaWrapPolicy
to be highest priority. What do you think?
I was leaning toward just keeping our overrides as absolute for now.
test/distributed/fsdp/test_wrap.py
Outdated
|
||
def _test_frozen_params(self, use_orig_params: bool): | ||
def _test_frozen_params(self, use_orig_params: bool, policy: _WrapPolicy): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe add a test for the mixed precision overriding as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added!
torch/distributed/fsdp/wrap.py
Outdated
@@ -209,6 +211,48 @@ def __repr__(self) -> str: | |||
return super().__repr__() + f"({self._module_classes_str})" | |||
|
|||
|
|||
class LambdaWrapPolicy(_WrapPolicy): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: add an example?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added!
ghstack-source-id: 4ac5da42fcdcc1b44ed8d9fb6677295a3fb9d432 Pull Request resolved: pytorch#104986
ghstack-source-id: 4ac5da42fcdcc1b44ed8d9fb6677295a3fb9d432 Pull Request resolved: pytorch#104986
ghstack-source-id: 4ac5da42fcdcc1b44ed8d9fb6677295a3fb9d432 Pull Request resolved: pytorch#104986
torch/distributed/fsdp/wrap.py
Outdated
@@ -209,6 +211,48 @@ def __repr__(self) -> str: | |||
return super().__repr__() + f"({self._module_classes_str})" | |||
|
|||
|
|||
class LambdaWrapPolicy(_WrapPolicy): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You solicited some name bikeshedding over WC, so I'll try my best here. Assuming we are keeping WrapPolicy suffix, CallbackWrapPolicy or CustomWrapPolicy both convey intent here reasonably well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You could also try PredicateWrapPolicy. This is not a perfect match: I would call lambda_fn
a predicate if it only returned a bool, but you return a little bit extra information sometimes.
|
||
def _run_policy( | ||
self, | ||
root_module: nn.Module, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OOC, is there a way to easily determine the FQN of the module at this point? Without the root module it seems like it might be difficult, whereas, the code that calls _run_policy
would be able to generate it without much difficulty.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is a good point. I do think we need access to the root module to be able to generate FQNs ("we" encompasses both in FSDP and for users).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In a downstream package I maintain, I've been allowing users to provide module FQN-based complements to a given auto_wrap_policy
using a custom _FSDPPolicy
(NameDrivenPolicy). I should be able to continue offering that functionality with some minor refactoring after this PR, but I thought it might be worth mentioning in the context of this PR because:
- To call attention to this pattern since I've received positive user feedback regarding the utility FQN-based wrapping (to accommodate more flexible fine-tuning requirements in this case)
- In case the team considers the functionality sufficiently useful to offer it natively at some point in the future (I'm not necessarily advocating that, just thinking aloud... either way this PR should help make such extensions cleaner)
Anyway, nice enhancement @awgu! A pleasure as always working with this thoughtfully designed API and constantly improving package.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the pointer! This is super helpful, especially learning more about the customer feedback.
I am thinking that CustomPolicy
can stay generic, and we can push specific logic to the lambda_fn
construction. If the user wants to use FQNs, then the user can use the same trick as your implementation where you translate FQNs to Python id()
s and then use the id()
s in the lambda_fn
, which will see the module
as argument. Like you said, after some refactoring, you should still be able to make use of CustomPolicy
. The question is whether we should offer some API that permits FQNs directly. I think that is something that we can wait and see. For now, we can just add CustomPolicy
that can serve as a base for other extensions.
What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The question is whether we should offer some API that permits FQNs directly. I think that is something that we can wait and see. For now, we can just add
CustomPolicy
that can serve as a base for other extensions.
Agree completely, I think prioritizing API stability and simplicity at this juncture is the prudent approach. Adding an API that permits FQNs directly might be worth considering at some point in the future but that feature isn't sufficiently crucial to justify its inclusion at present (especially since implementation approaches like my id()
translation one we discussed aren't too onerous).
Thanks for the quick and thoughtful response!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems reasonable.
ghstack-source-id: 4ac5da42fcdcc1b44ed8d9fb6677295a3fb9d432 Pull Request resolved: pytorch#104986
LambdaWrapPolicy
CustomPolicy
This PR adds a new `LambdaWrapPolicy` that acts like the existing `lambda_auto_wrap_policy` except it (1) leverages the new auto wrapping infrastructure and (2) allows overriding FSDP kwargs for particular instances. (1) gives it access to the validation checks (like for frozen parameters), and (2) makes it as expressive as manual wrapping. This should allow us to effectively deprecate manual wrapping if desired. The API is as follows: ``` def lambda_fn(module: nn.Module) -> Union[bool, Dict[str, Any]]: ... policy = LambdaWrapPolicy(lambda_fn) ``` The `lambda_fn` can return: - `False` or `{}` to indicate no wrapping - `True` to indicate wrapping while inheriting the root's FSDP kwargs - Non-empty `dict` to indicate wrapping while overriding the specified FSDP kwargs and inheriting the rest from the root --- After this PR, the follow-up work items for auto wrapping are: 1. Add shared parameter validation 2. (Longer-term / exploratory) Add a policy that provides a reasonable auto wrapping with "minimal" user input [ghstack-poisoned]
ghstack-source-id: 16789fae18bb459d11a17988761b6fb754c2a91c Pull Request resolved: #104986
This PR adds a new `LambdaWrapPolicy` that acts like the existing `lambda_auto_wrap_policy` except it (1) leverages the new auto wrapping infrastructure and (2) allows overriding FSDP kwargs for particular instances. (1) gives it access to the validation checks (like for frozen parameters), and (2) makes it as expressive as manual wrapping. This should allow us to effectively deprecate manual wrapping if desired. The API is as follows: ``` def lambda_fn(module: nn.Module) -> Union[bool, Dict[str, Any]]: ... policy = LambdaWrapPolicy(lambda_fn) ``` The `lambda_fn` can return: - `False` or `{}` to indicate no wrapping - `True` to indicate wrapping while inheriting the root's FSDP kwargs - Non-empty `dict` to indicate wrapping while overriding the specified FSDP kwargs and inheriting the rest from the root --- After this PR, the follow-up work items for auto wrapping are: 1. Add shared parameter validation 2. (Longer-term / exploratory) Add a policy that provides a reasonable auto wrapping with "minimal" user input [ghstack-poisoned]
ghstack-source-id: f2047acf56ad947009524531f997efad6fe5914a Pull Request resolved: #104986
ghstack-source-id: f2047acf56ad947009524531f997efad6fe5914a Pull Request resolved: pytorch#104986
This PR adds a new `LambdaWrapPolicy` that acts like the existing `lambda_auto_wrap_policy` except it (1) leverages the new auto wrapping infrastructure and (2) allows overriding FSDP kwargs for particular instances. (1) gives it access to the validation checks (like for frozen parameters), and (2) makes it as expressive as manual wrapping. This should allow us to effectively deprecate manual wrapping if desired. The API is as follows: ``` def lambda_fn(module: nn.Module) -> Union[bool, Dict[str, Any]]: ... policy = LambdaWrapPolicy(lambda_fn) ``` The `lambda_fn` can return: - `False` or `{}` to indicate no wrapping - `True` to indicate wrapping while inheriting the root's FSDP kwargs - Non-empty `dict` to indicate wrapping while overriding the specified FSDP kwargs and inheriting the rest from the root --- After this PR, the follow-up work items for auto wrapping are: 1. Add shared parameter validation 2. (Longer-term / exploratory) Add a policy that provides a reasonable auto wrapping with "minimal" user input [ghstack-poisoned]
ghstack-source-id: 3d3413d4209481f1c7248eead8a9ad92115a39a5 Pull Request resolved: #104986
Unstable and unrelated: |
@pytorchbot merge -f "unrelated failures" |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Stack from ghstack (oldest at bottom):
CustomPolicy
#104986_FSDPPolicy.policy
with_Policy._run_policy
#104969ModuleWrapPolicy
to takeIterable
#104999ModuleWrapPolicy
#104427This PR adds a new
CustomPolicy
that acts like the existinglambda_auto_wrap_policy
except it (1) leverages the new auto wrapping infrastructure and (2) allows overriding FSDP kwargs for particular instances. (1) gives it access to the validation checks (like for frozen parameters), and (2) makes it as expressive as manual wrapping. This should allow us to effectively deprecate manual wrapping if desired.The API is as follows:
The
lambda_fn
can return:False
or{}
to indicate no wrappingTrue
to indicate wrapping while inheriting the root's FSDP kwargsdict
to indicate wrapping while overriding the specified FSDP kwargs and inheriting the rest from the rootAfter this PR, the follow-up work items for auto wrapping are: