[FSDP][6/N] Check valid param freezing for ModuleWrapPolicy
#104427
Conversation
🔗 Helpful Links: see artifacts and rendered test results at hud.pytorch.org/pr/104427. Note: links to docs will display an error until the doc builds have completed. ✅ No failures as of commit a03137f. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Note that this util is _not_ integrated into our auto wrapping yet. This PR just adds the util.
```
    reverse topological order to cover the full module tree. This differs from
    the ``_get_param_to_fqn()`` function meant to be called post-wrapping and
    on the full module tree in one shot. Given those differences, we do not try
    to unify the two.
```
This function really smacks of "we want to know what the module tree would look like after we do wrapping, but we need to run this before we actually start wrapping, so we have to simulate what the effect would be." That's annoying, because if you change how the wrap works, you also have to adjust this function. (Speaking of which, if this statement is true, it would be helpful to have some invariant, for example that `_get_param_to_fqn(post_wrapping_module)` should report the same as `_get_managed_param_to_fqn(pre_wrapping_module)`. I'm not sure if this holds, but it would be helpful to know if it does.)
```
    to unify the two.
    """
    param_to_fqn: Dict[nn.Parameter, str] = {}
    # Run BFS (or any tree traversal works)
```
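The traversal being discussed can be sketched as a standalone function. This is an illustrative simplification, not the actual FSDP helper (`param_to_fqn_bfs` is a hypothetical name, and the real utility additionally has to account for already-wrapped submodules):

```python
from collections import deque
from typing import Dict

import torch.nn as nn


def param_to_fqn_bfs(root: nn.Module) -> Dict[nn.Parameter, str]:
    """Map each parameter to its fully qualified name via BFS.

    As the source comment notes, any tree traversal works; BFS is used here.
    The first-visited FQN wins, so weight tying is not specially handled.
    """
    result: Dict[nn.Parameter, str] = {}
    queue = deque([(root, "")])
    while queue:
        module, prefix = queue.popleft()
        # Only this module's own parameters; children are visited separately.
        for name, param in module.named_parameters(recurse=False):
            if param not in result:
                result[param] = f"{prefix}{name}"
        for child_name, child in module.named_children():
            queue.append((child, f"{prefix}{child_name}."))
    return result
```

For example, on `nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))` this maps the first layer's weight to `"0.weight"`.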
We are assuming no weight tying, is that right?
We should not be making that assumption. However, for shared parameters, FSDP needs them to be assigned to at least the lowest common ancestor module or higher. A follow-up work item is to add this check to the auto wrapping path.
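Detecting shared (tied) parameters up front is straightforward. A hedged sketch of such a check (`find_shared_params` is an illustrative name, not the planned FSDP validation), using the real `remove_duplicate=False` kwarg so aliases are not deduplicated:

```python
from typing import Dict, List

import torch.nn as nn


def find_shared_params(root: nn.Module) -> Dict[str, List[str]]:
    """Group FQNs that alias the same Parameter object (weight tying).

    remove_duplicate=False keeps every name, including aliases, which is
    what lets us see the same Parameter under multiple FQNs.
    """
    by_param: Dict[nn.Parameter, List[str]] = {}
    for fqn, param in root.named_parameters(remove_duplicate=False):
        by_param.setdefault(param, []).append(fqn)
    # Keep only parameters that appear under more than one name.
    return {fqns[0]: fqns for fqns in by_param.values() if len(fqns) > 1}
```

A lowest-common-ancestor check for auto wrapping could then be layered on top of this grouping.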
```
    for param, fqn in param_to_fqn.items():
        if param.requires_grad:
            nonfrozen_param_fqns.append(fqn)
            nonfrozen_param_numel += param.numel()
```
instead of reporting numel, reporting nbytes seems better since that will adjust for differing dtype size (...well, unless there's some shenanigans where dtypes have to get bucketed together?)
For fp8, we will need to add support for mixing dtypes in the same flat parameter. It should still be possible to pre-compute the bytes for these tensors though. It just requires some more plumbing.
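The numel-vs-nbytes point is easy to make concrete: `Tensor.element_size()` gives the per-element byte width, so a bytes tally automatically weights differing dtypes. A small sketch (the function name is illustrative, not FSDP API):

```python
from typing import Tuple

import torch.nn as nn


def frozen_vs_nonfrozen_bytes(module: nn.Module) -> Tuple[int, int]:
    """Tally frozen and non-frozen parameter sizes in bytes.

    numel() * element_size() adjusts for dtype, e.g. fp16 parameters
    count half as much as fp32 parameters of the same shape.
    """
    frozen = sum(
        p.numel() * p.element_size()
        for p in module.parameters()
        if not p.requires_grad
    )
    nonfrozen = sum(
        p.numel() * p.element_size()
        for p in module.parameters()
        if p.requires_grad
    )
    return frozen, nonfrozen
```

For a bias-free `nn.Linear(4, 4)` in fp32 this reports 64 bytes (16 elements x 4 bytes); the same layer in fp16 reports 32.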
```
            )
        else:
            msg += " FSDP does not support wrapping such modules when use_orig_params=False. "
            msg += "If possible, wrap the frozen parameters with FSDP separately.\n"
```
In general, this requires splitting your module into multiple modules, some of which have frozen params and some of which don't, right? Is it always safe to take a module M and create two submodules beneath it M1 and M2, one of which has frozen params and the other which doesn't (even if I don't call M1.forward/M2.forward?)
> In general, this requires splitting your module into multiple modules, some of which have frozen params and some of which don't, right?

Yes.

> Is it always safe to take a module M and create two submodules beneath it M1 and M2, one of which has frozen params and the other which doesn't (even if I don't call M1.forward/M2.forward)?

Unfortunately, it is important to still call M1.forward/M2.forward in order to run pre/post-forward logic, which notably includes the corresponding all-gathers in the pre-forward. In other words, if we just organize the parameters into modules but do not change the parent's forward to actually call those modules' forward, then they will not have their parameters all-gathered.
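The splitting pattern being discussed can be illustrated without FSDP itself. `Parent` below is a hypothetical refactor, not code from this PR; the point is that `forward` routes through each child module's `__call__`, which is what lets a wrapper run its pre/post-forward hooks (e.g. the all-gather) for both children:

```python
import torch
import torch.nn as nn


class Parent(nn.Module):
    """Frozen and trainable parameters split into separate child modules."""

    def __init__(self) -> None:
        super().__init__()
        self.frozen = nn.Linear(8, 8)
        self.trainable = nn.Linear(8, 8)
        # Freeze one child wholesale so each child has uniform requires_grad,
        # which is what wrapping each child with FSDP separately would need.
        for p in self.frozen.parameters():
            p.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Calling self.frozen(x) (rather than using its weights directly)
        # is what triggers the child's forward hooks once it is wrapped.
        return self.trainable(self.frozen(x))
```

With this structure, `FSDP(parent.frozen)` and `FSDP(parent.trainable)` (or an auto-wrap policy targeting the children) would each see uniform `requires_grad`.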
```
        frozen_param_numel += param.numel()
    if len(frozen_param_fqns) > 0 and len(nonfrozen_param_fqns) > 0:
        msg = f"{module_name} has both parameters with requires_grad=True and False."
        if use_orig_params:
```
It's non-obvious to me why `not use_orig_params` doesn't work. Without using orig params, we are reconstructing the params from the flat buffer every iteration. What is the `requires_grad` on this flat buffer? Clearly it has to be `requires_grad=True`, since we want to accumulate gradients into it. So how do I reconstruct a parameter that doesn't require grad from this requires-grad parameter? The naive thing to do is detach (letting me get at the data without setting up a gradient edge). Is there something in the reduce-scatter that chokes afterwards?
For `use_orig_params=True`, mixing frozen and non-frozen parameters does work functionally. It just may use more memory than expected (hence a warning instead of an error). For the frozen original parameters, we can have them view into the flat parameter but still have `requires_grad=False`. FSDP takes care to not forward a gradient into the `.grad` attribute for those parameters.
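The "view into the flat parameter while keeping `requires_grad=False`" idea can be shown with plain tensors. This is a sketch of the aliasing mechanics, not FSDP's actual internals:

```python
import torch

# A requires_grad flat buffer standing in for the sharded FlatParameter.
flat = torch.zeros(8, requires_grad=True)

# An "original parameter" can alias the buffer's storage while opting out
# of autograd: view the detached data, so no gradient edge is created.
frozen_view = flat.detach()[:4].view(2, 2)
assert not frozen_view.requires_grad

# It really is a view: writing through the detached alias shows up in flat.
with torch.no_grad():
    frozen_view.fill_(1.0)
assert flat.detach()[:4].sum().item() == 4.0
```

The flat buffer still accumulates gradients as a whole; the framework simply never routes gradient data into the frozen region's `.grad`.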
Yes, the `use_orig_params=True` case makes sense to me. I'm wondering about `use_orig_params=False` haha
Ahh, I somehow missed the `not` in `not use_orig_params`.

I think the issue is for the optimizer. For `use_orig_params=False`, the optimizer sees the sharded `FlatParameter`s themselves. If the `FlatParameter` logically contains both frozen and non-frozen parameters, one thing we can do is enforce zeros in all gradient elements for the frozen parameters' part of the `FlatParameter`; however, this assumes that for the optimizer, a step with zero gradient is the same as no step (which I think is not always true). Nonetheless, maybe it is not unreasonable to just follow this behavior and set `flat_param.requires_grad = any(orig_param.requires_grad for orig_param in flat_param._params)`.
Ah, the interaction with the optimizer certainly makes sense.

That said, for most optimizers, if you only ever take steps with zero gradients, the result is equivalent to not updating at all (e.g., momentum stays zero). But you are losing most of the benefit of freezing in the first place here, so this is mostly an academic discussion.
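The "not always true" caveat can be checked directly with stock optimizers: SGD with momentum leaves a parameter untouched by a zero-gradient step, while AdamW's decoupled weight decay still moves it.

```python
import torch

# SGD with momentum: a zero-gradient step is a no-op
# (the momentum buffer is initialized to the zero gradient).
p1 = torch.nn.Parameter(torch.ones(4))
opt1 = torch.optim.SGD([p1], lr=0.1, momentum=0.9)
p1.grad = torch.zeros_like(p1)
opt1.step()
assert torch.equal(p1.detach(), torch.ones(4))

# AdamW: weight decay is applied regardless of the gradient,
# so a zero-gradient step still shrinks the parameter.
p2 = torch.nn.Parameter(torch.ones(4))
opt2 = torch.optim.AdamW([p2], lr=0.1, weight_decay=0.1)
p2.grad = torch.zeros_like(p2)
opt2.step()
assert not torch.equal(p2.detach(), torch.ones(4))
```

So "enforce zeros in the frozen region's gradient" would be exactly correct for plain SGD/momentum but would still decay frozen weights under AdamW-style decoupled weight decay.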
```
            ctx = self.assertWarnsRegex(UserWarning, msg)
        else:
            msg += "FSDP does not support wrapping such modules when use_orig_params=False."
            ctx = self.assertRaisesRegex(ValueError, msg)
```
nit: consider using `assertExpectedInline`! It makes updating these tests when you tweak the warning/error message much more pleasant, since you can just run the test with `EXPECTTEST_ACCEPT=1` and it will automatically update all the assertions with the new text.
Nice! Everything algorithmically looks good. The main annoyance is the logic duplication: you have to simulate the post-order wrapping process to accurately know which parameters bucket together, so this code will have to be updated in lockstep with the FSDP code if we ever change it... not that there really is much way we can change it, I guess 🤔
This PR adds improved error/warning messaging when auto wrapping with `ModuleWrapPolicy` in the presence of frozen parameters.

- For `use_orig_params=False`, FSDP requires uniform `requires_grad` for each FSDP instance. This PR adds a `ValueError` at wrapping time with a message that mentions the violating module and the frozen/non-frozen parameter names.
- For `use_orig_params=True`, FSDP allows non-uniform `requires_grad` for each FSDP instance. However, it will result in higher-than-expected gradient memory usage. This PR adds a `UserWarning` at wrapping time with a message that mentions the violating module, how much extra gradient memory will be used (in units of numel), and the frozen/non-frozen parameter names.
- There is a possibility that this warning will be spammy/verbose, but my current thinking is that it is okay for now unless users complain.

<details>
<summary>Why DFS via named_children() vs. using named_modules()</summary>

```
LoraModel(
  (embed_tokens): Embedding(100, 32)
  (layers): ModuleList(
    (0-3): 4 x LoraDecoder(
      (attn): LoraAttention(
        (q_proj): Linear(in_features=32, out_features=32, bias=False)
        (lora_A): Linear(in_features=32, out_features=8, bias=False)
        (lora_B): Linear(in_features=8, out_features=32, bias=False)
        (k_proj): Linear(in_features=32, out_features=32, bias=False)
        (v_proj): Linear(in_features=32, out_features=32, bias=False)
        (o_proj): Linear(in_features=32, out_features=32, bias=False)
      )
      (mlp): LoraMLP(
        (proj1): Linear(in_features=32, out_features=128, bias=False)
        (proj2): Linear(in_features=128, out_features=32, bias=False)
      )
      (inp_layernorm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (post_attn_layernorm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    )
  )
  (norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
)
```

Reverse topological order with stack-based DFS via `named_children()`:

```
['embed_tokens',
 'layers.0.attn.q_proj', 'layers.0.attn.lora_A', 'layers.0.attn.lora_B',
 'layers.0.attn.k_proj', 'layers.0.attn.v_proj', 'layers.0.attn.o_proj',
 'layers.0.attn', 'layers.0.mlp.proj1', 'layers.0.mlp.proj2', 'layers.0.mlp',
 'layers.0.inp_layernorm', 'layers.0.post_attn_layernorm', 'layers.0',
 'layers.1.attn.q_proj', 'layers.1.attn.lora_A', 'layers.1.attn.lora_B',
 'layers.1.attn.k_proj', 'layers.1.attn.v_proj', 'layers.1.attn.o_proj',
 'layers.1.attn', 'layers.1.mlp.proj1', 'layers.1.mlp.proj2', 'layers.1.mlp',
 'layers.1.inp_layernorm', 'layers.1.post_attn_layernorm', 'layers.1',
 'layers.2.attn.q_proj', 'layers.2.attn.lora_A', 'layers.2.attn.lora_B',
 'layers.2.attn.k_proj', 'layers.2.attn.v_proj', 'layers.2.attn.o_proj',
 'layers.2.attn', 'layers.2.mlp.proj1', 'layers.2.mlp.proj2', 'layers.2.mlp',
 'layers.2.inp_layernorm', 'layers.2.post_attn_layernorm', 'layers.2',
 'layers.3.attn.q_proj', 'layers.3.attn.lora_A', 'layers.3.attn.lora_B',
 'layers.3.attn.k_proj', 'layers.3.attn.v_proj', 'layers.3.attn.o_proj',
 'layers.3.attn', 'layers.3.mlp.proj1', 'layers.3.mlp.proj2', 'layers.3.mlp',
 'layers.3.inp_layernorm', 'layers.3.post_attn_layernorm', 'layers.3',
 'layers', 'norm', '']
```

Reverse topological order with `named_modules()`:

```
['norm',
 'layers.3.post_attn_layernorm', 'layers.3.inp_layernorm',
 'layers.3.mlp.proj2', 'layers.3.mlp.proj1', 'layers.3.mlp',
 'layers.3.attn.o_proj', 'layers.3.attn.v_proj', 'layers.3.attn.k_proj',
 'layers.3.attn.lora_B', 'layers.3.attn.lora_A', 'layers.3.attn.q_proj',
 'layers.3.attn', 'layers.3',
 'layers.2.post_attn_layernorm', 'layers.2.inp_layernorm',
 'layers.2.mlp.proj2', 'layers.2.mlp.proj1', 'layers.2.mlp',
 'layers.2.attn.o_proj', 'layers.2.attn.v_proj', 'layers.2.attn.k_proj',
 'layers.2.attn.lora_B', 'layers.2.attn.lora_A', 'layers.2.attn.q_proj',
 'layers.2.attn', 'layers.2',
 'layers.1.post_attn_layernorm', 'layers.1.inp_layernorm',
 'layers.1.mlp.proj2', 'layers.1.mlp.proj1', 'layers.1.mlp',
 'layers.1.attn.o_proj', 'layers.1.attn.v_proj', 'layers.1.attn.k_proj',
 'layers.1.attn.lora_B', 'layers.1.attn.lora_A', 'layers.1.attn.q_proj',
 'layers.1.attn', 'layers.1',
 'layers.0.post_attn_layernorm', 'layers.0.inp_layernorm',
 'layers.0.mlp.proj2', 'layers.0.mlp.proj1', 'layers.0.mlp',
 'layers.0.attn.o_proj', 'layers.0.attn.v_proj', 'layers.0.attn.k_proj',
 'layers.0.attn.lora_B', 'layers.0.attn.lora_A', 'layers.0.attn.q_proj',
 'layers.0.attn', 'layers.0',
 'layers', 'embed_tokens', '']
```

With the stack-based DFS via `named_children()`, reversing the topological order gives us each level of the module tree in its registered order, whereas with `named_modules()`, reversing the topological order gives us each level in reverse. Both are valid orders, but we prefer the former since it allows us to error/warn on the _first-registered_ module that violates the frozen/non-frozen condition.

</details>
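The stack-based DFS described above can be sketched as a standalone function (illustrative only; the actual helper lives in FSDP's wrap utilities, and `reverse_topological_fqns` is a hypothetical name):

```python
from typing import List

import torch.nn as nn


def reverse_topological_fqns(root: nn.Module) -> List[str]:
    """Reverse topological order of module FQNs via stack-based DFS.

    Children are pushed in registration order, so the last-registered
    child is visited first; reversing the visit order then yields
    children before parents with siblings in registered order (unlike
    reversing named_modules(), which flips each level).
    """
    visited: List[str] = []
    stack = [(root, "")]
    while stack:
        module, fqn = stack.pop()
        visited.append(fqn)
        for name, child in module.named_children():
            stack.append((child, f"{fqn}.{name}" if fqn else name))
    return list(reversed(visited))
```

On `nn.Sequential(nn.Linear(2, 2), nn.Sequential(nn.ReLU(), nn.Tanh()))` this yields `['0', '1.0', '1.1', '1', '']`: every child precedes its parent, and siblings keep their registered order.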
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Pull Request resolved: #104967 Approved by: https://github.com/rohan-varma ghstack dependencies: #104427
Pull Request resolved: #104999 Approved by: https://github.com/rohan-varma ghstack dependencies: #104427, #104967
This does some code organization improvement.

- It renames `_FSDPPolicy` to `_Policy` to show that it is not only for FSDP but for any module-level API.
- It formalizes the contract that such a policy should return something like `target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]]` that maps each module to wrap to its kwargs. It does so by requiring a `_run_policy` abstract method (this time private since users do not need to care about it). Then, our auto wrapping can just call `_run_policy()` to generate the dict and do any validation or post-processing.

This PR is technically BC-breaking because it removes the public `ModuleWrapPolicy.policy`. However, I do not think anyone was using that anyway, so this is a pretty safe breakage.

Pull Request resolved: #104969 Approved by: https://github.com/rohan-varma ghstack dependencies: #104427, #104967, #104999
This PR adds a new `CustomPolicy` that acts like the existing `lambda_auto_wrap_policy` except it (1) leverages the new auto wrapping infrastructure and (2) allows overriding FSDP kwargs for particular instances. (1) gives it access to the validation checks (like for frozen parameters), and (2) makes it as expressive as manual wrapping. This should allow us to effectively deprecate manual wrapping if desired.

The API is as follows:

```
def lambda_fn(module: nn.Module) -> Union[bool, Dict[str, Any]]:
    ...

policy = CustomPolicy(lambda_fn)
```

The `lambda_fn` can return:

- `False` or `{}` to indicate no wrapping
- `True` to indicate wrapping while inheriting the root's FSDP kwargs
- A non-empty `dict` to indicate wrapping while overriding the specified FSDP kwargs and inheriting the rest from the root

After this PR, the follow-up work items for auto wrapping are:

1. Add shared parameter validation
2. (Longer-term / exploratory) Add a policy that provides a reasonable auto wrapping with "minimal" user input

Pull Request resolved: #104986 Approved by: https://github.com/ezyang ghstack dependencies: #104427, #104967, #104999, #104969
Stack from ghstack (oldest at bottom):

- #104986: `CustomPolicy`
- #104969: `_FSDPPolicy.policy` replaced with `_Policy._run_policy`
- #104999: `ModuleWrapPolicy` to take `Iterable`
- #104427: Check valid param freezing for `ModuleWrapPolicy` (this PR)