Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FSDP][6/N] Check valid param freezing for ModuleWrapPolicy #104427

Closed
wants to merge 19 commits into from

Conversation

awgu
Copy link
Contributor

@awgu awgu commented Jun 29, 2023

Stack from ghstack (oldest at bottom):

This PR adds improved error/warning messaging when auto wrapping with ModuleWrapPolicy in the presence of frozen parameters.

  • For use_orig_params=False, FSDP requires uniform requires_grad for each FSDP instance. This PR adds a ValueError at wrapping time with a message that mentions the violating module and the frozen/non-frozen parameter names.
  • For use_orig_params=True, FSDP allows non-uniform requires_grad for each FSDP instance. However, it will result in higher-than-expected gradient memory usage. This PR adds a UserWarning at wrapping time with a message that mentions the violating module, how much extra gradient memory will be used (in units of numel), and the frozen/non-frozen parameter names.
    • There is a possibility that this warning will be spammy/verbose, but my current thinking is that it is okay for now unless users complain.
Why DFS via named_children() vs. Using named_modules()
LoraModel(
  (embed_tokens): Embedding(100, 32)
  (layers): ModuleList(
    (0-3): 4 x LoraDecoder(
      (attn): LoraAttention(
        (q_proj): Linear(in_features=32, out_features=32, bias=False)
        (lora_A): Linear(in_features=32, out_features=8, bias=False)
        (lora_B): Linear(in_features=8, out_features=32, bias=False)
        (k_proj): Linear(in_features=32, out_features=32, bias=False)
        (v_proj): Linear(in_features=32, out_features=32, bias=False)
        (o_proj): Linear(in_features=32, out_features=32, bias=False)
      )
      (mlp): LoraMLP(
        (proj1): Linear(in_features=32, out_features=128, bias=False)
        (proj2): Linear(in_features=128, out_features=32, bias=False)
      )
      (inp_layernorm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (post_attn_layernorm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    )
  )
  (norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
)

Reverse topological order with stack-based DFS via named_children():

[
  'embed_tokens',
  'layers.0.attn.q_proj', 'layers.0.attn.lora_A', 'layers.0.attn.lora_B', 'layers.0.attn.k_proj', 'layers.0.attn.v_proj', 'layers.0.attn.o_proj', 'layers.0.attn', 'layers.0.mlp.proj1', 'layers.0.mlp.proj2', 'layers.0.mlp', 'layers.0.inp_layernorm', 'layers.0.post_attn_layernorm', 'layers.0',
  'layers.1.attn.q_proj', 'layers.1.attn.lora_A', 'layers.1.attn.lora_B', 'layers.1.attn.k_proj', 'layers.1.attn.v_proj', 'layers.1.attn.o_proj', 'layers.1.attn', 'layers.1.mlp.proj1', 'layers.1.mlp.proj2', 'layers.1.mlp', 'layers.1.inp_layernorm', 'layers.1.post_attn_layernorm', 'layers.1',
  'layers.2.attn.q_proj', 'layers.2.attn.lora_A', 'layers.2.attn.lora_B', 'layers.2.attn.k_proj', 'layers.2.attn.v_proj', 'layers.2.attn.o_proj', 'layers.2.attn', 'layers.2.mlp.proj1', 'layers.2.mlp.proj2', 'layers.2.mlp', 'layers.2.inp_layernorm', 'layers.2.post_attn_layernorm', 'layers.2',
  'layers.3.attn.q_proj', 'layers.3.attn.lora_A', 'layers.3.attn.lora_B', 'layers.3.attn.k_proj', 'layers.3.attn.v_proj', 'layers.3.attn.o_proj', 'layers.3.attn', 'layers.3.mlp.proj1', 'layers.3.mlp.proj2', 'layers.3.mlp', 'layers.3.inp_layernorm', 'layers.3.post_attn_layernorm', 'layers.3',
  'layers', 'norm', ''
]

Reverse topological order with named_modules():

[
  'norm',
  'layers.3.post_attn_layernorm', 'layers.3.inp_layernorm', 'layers.3.mlp.proj2', 'layers.3.mlp.proj1', 'layers.3.mlp', 'layers.3.attn.o_proj', 'layers.3.attn.v_proj', 'layers.3.attn.k_proj', 'layers.3.attn.lora_B', 'layers.3.attn.lora_A', 'layers.3.attn.q_proj', 'layers.3.attn', 'layers.3',
  'layers.2.post_attn_layernorm', 'layers.2.inp_layernorm', 'layers.2.mlp.proj2', 'layers.2.mlp.proj1', 'layers.2.mlp', 'layers.2.attn.o_proj', 'layers.2.attn.v_proj', 'layers.2.attn.k_proj', 'layers.2.attn.lora_B', 'layers.2.attn.lora_A', 'layers.2.attn.q_proj', 'layers.2.attn', 'layers.2',
  'layers.1.post_attn_layernorm', 'layers.1.inp_layernorm', 'layers.1.mlp.proj2', 'layers.1.mlp.proj1', 'layers.1.mlp', 'layers.1.attn.o_proj', 'layers.1.attn.v_proj', 'layers.1.attn.k_proj', 'layers.1.attn.lora_B', 'layers.1.attn.lora_A', 'layers.1.attn.q_proj', 'layers.1.attn', 'layers.1', 'layers.0.post_attn_layernorm', 'layers.0.inp_layernorm', 'layers.0.mlp.proj2', 'layers.0.mlp.proj1', 'layers.0.mlp', 'layers.0.attn.o_proj', 'layers.0.attn.v_proj', 'layers.0.attn.k_proj', 'layers.0.attn.lora_B', 'layers.0.attn.lora_A', 'layers.0.attn.q_proj', 'layers.0.attn', 'layers.0',
  'layers', 'embed_tokens', ''
]

With the stack-based DFS via named_children(), reversing the topological order gives us each level in the module tree in the registered order, wheres with named_modules(), reversing the topological order gives us each level in reverse. Both are valid orders, but we prefer the former since it allows us to error/warn on the first-registered module that violates the frozen/non-frozen condition.

@pytorch-bot
Copy link

pytorch-bot bot commented Jun 29, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/104427

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a03137f:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Note that this util is _not_ integrated into our auto wrapping yet. This PR just adds the util.

[ghstack-poisoned]
Note that this util is _not_ integrated into our auto wrapping yet. This PR just adds the util.

[ghstack-poisoned]
Note that this util is _not_ integrated into our auto wrapping yet. This PR just adds the util.

[ghstack-poisoned]
Note that this util is _not_ integrated into our auto wrapping yet. This PR just adds the util.

[ghstack-poisoned]
Note that this util is _not_ integrated into our auto wrapping yet. This PR just adds the util.

[ghstack-poisoned]
Note that this util is _not_ integrated into our auto wrapping yet. This PR just adds the util.

[ghstack-poisoned]
awgu added a commit that referenced this pull request Jun 30, 2023
ghstack-source-id: 2ec596050c5c974ce924faa3146e773064e1c8ea
Pull Request resolved: #104427
Note that this util is _not_ integrated into our auto wrapping yet. This PR just adds the util.

[ghstack-poisoned]
reverse topological order to cover the full module tree. This differs from
the ``_get_param_to_fqn()`` function meant to be called post-wrapping and
on the full module tree in one shot. Given those differences, we do not try
to unify the two.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function really smacks of "we want to know what the module tree would look like after we do wrapping, but we need to run this before we actually start wrapping, so we have to simulate what the effect would be." That's annoying, because if you change how the wrap works, you also have to adjust this function too (speaking of which, if this statement is true, it would be helpful to have some invariant, like for example _get_param_to_fqn(post_wrapping_module) should report the same as _get_managed_param_to_fqn(pre_wrapping_module)--I'm not sure if this holds, but it would be helpful to know if it is.)

to unify the two.
"""
param_to_fqn: Dict[nn.Parameter, str] = {}
# Run BFS (or any tree traversal works)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are assuming no weight tying, is that right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not be making that assumption. However, for shared parameters, FSDP needs them to be assigned to at least the lowest common ancestor module or higher. A follow-up work item is to add this check to the auto wrapping path.

for param, fqn in param_to_fqn.items():
if param.requires_grad:
nonfrozen_param_fqns.append(fqn)
nonfrozen_param_numel += param.numel()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of reporting numel, reporting nbytes seems better since that will adjust for differing dtype size (...well, unless there's some shenanigans where dtypes have to get bucketed together?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For fp8, we will need to add support for mixing dtypes in the same flat parameter. It should still be possible to pre-compute the bytes for these tensors though. It just requires some more plumbing.

)
else:
msg += " FSDP does not support wrapping such modules when use_orig_params=False. "
msg += "If possible, wrap the frozen parameters with FSDP separately.\n"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general, this requires splitting your module into multiple modules, some of which have frozen params and some of which don't, right? Is it always safe to take a module M and create two submodules beneath it M1 and M2, one of which has frozen params and the other which doesn't (even if I don't call M1.forward/M2.forward?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general, this requires splitting your module into multiple modules, some of which have frozen params and some of which don't, right?

Yes.

Is it always safe to take a module M and create two submodules beneath it M1 and M2, one of which has frozen params and the other which doesn't (even if I don't call M1.forward/M2.forward?)

Unfortunately, it is important to still call M1.forward/M2.forward in order to run pre/post-forward logic, which would notably include the corresponding all-gathers in the pre-forward. In other words, if we just organize the parameters into modules but do not change the parent's forward to actually call those modules' forward, then they will not have their parameters all-gathered.

frozen_param_numel += param.numel()
if len(frozen_param_fqns) > 0 and len(nonfrozen_param_fqns) > 0:
msg = f"{module_name} has both parameters with requires_grad=True and False."
if use_orig_params:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's non-obvious to me why not use_orig_params doesn't work. Without using orig params, we are reconstructing the params from the flat buffer every iteration. What is the requires_grad on this flat buffer? Clearly it has to be requires_grad=True, since we want to accumulate gradients into it. So how do I reconstruct a parameter that doesn't require grad from this requires grad parameter? The naive thing to do is detach (letting me get at the data without setting up a gradient edge.) Is there something in the reduce-scatter that chokes afterwards?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For use_orig_params=True, mixing frozen and non-frozen does work functionally. It just may use memory higher than expected (hence a warning instead of an error).

For the frozen original parameters, we can just have them view into the flat parameter but still have requires_grad=False. FSDP takes care to not forward a gradient into the .grad attribute for those parameters.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the use_orig_params=True case makes sense to me. I'm wondering about use_orig_params=False haha

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh, I somehow missed the not in not use_orig_params.

I think the issue is for the optimizer. For use_orig_params=False, the optimizer sees the sharded FlatParameters themselves. If the FlatParameter logically contains both frozen and non-frozen parameters, one thing we can do is enforce zeros in all gradient elements for frozen parameters part of the FlatParameter; however, this assumes that for the optimizer, a step with zero gradient is the same as no step (which I think is not always true). Nonetheless, maybe it is not unreasonable to just follow this behavior and set the flat_param.requires_grad = any(orig_param.requires_grad for orig_param in flat_param._params).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, the interaction with optimizer certainly makes sense.

That being said, most optimizers, if you only ever have steps with zero gradients, it will be equivalent to not updating (e.g., momentum will always be zero). But you are losing most of the benefit from freezing in the first place here, so this is mostly an academic discussion.

ctx = self.assertWarnsRegex(UserWarning, msg)
else:
msg += "FSDP does not support wrapping such modules when use_orig_params=False."
ctx = self.assertRaisesRegex(ValueError, msg)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: consider using assertExpectedInline! It makes updating these tests if you tweak the warning/error message much more pleasant, since you can just run the test with EXPECTTEST_ACCEPT=1 and it will automatically update all the assertions with the new text.

Copy link
Contributor

@ezyang ezyang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Everything algorithmically looks good. The main annoyance is logic duplication when you have to simulate the post-order wrapping process to accurately know which parameters bucket together. Will have to simultaneously update this code as well as the FSDP code if we ever change it... not that there really is much way we can change it, I guess 🤔

…cy`"



This PR adds improved error/warning messaging when auto wrapping with `ModuleWrapPolicy` in the presence of frozen parameters.
- For `use_orig_params=False`, FSDP requires uniform `requires_grad` for each FSDP instance. This PR adds a `ValueError` at wrapping time with a message that mentions the violating module and the frozen/non-frozen parameter names.
- For `use_orig_params=True`, FSDP allows non-uniform `requires_grad` for each FSDP instance. However, it will result in higher-than-expected gradient memory usage. This PR adds a `UserWarning` at wrapping time with a message that mentions the violating module, how much extra gradient memory will be used (in units of numel), and the frozen/non-frozen parameter names.
    - There is a possibility that this warning will be spammy/verbose, but my current thinking is that it is okay for now unless users complain.


<details>
<summary> Why DFS via named_children() vs. Using named_modules()</summary>

```
LoraModel(
  (embed_tokens): Embedding(100, 32)
  (layers): ModuleList(
    (0-3): 4 x LoraDecoder(
      (attn): LoraAttention(
        (q_proj): Linear(in_features=32, out_features=32, bias=False)
        (lora_A): Linear(in_features=32, out_features=8, bias=False)
        (lora_B): Linear(in_features=8, out_features=32, bias=False)
        (k_proj): Linear(in_features=32, out_features=32, bias=False)
        (v_proj): Linear(in_features=32, out_features=32, bias=False)
        (o_proj): Linear(in_features=32, out_features=32, bias=False)
      )
      (mlp): LoraMLP(
        (proj1): Linear(in_features=32, out_features=128, bias=False)
        (proj2): Linear(in_features=128, out_features=32, bias=False)
      )
      (inp_layernorm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (post_attn_layernorm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    )
  )
  (norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
)
```
Reverse topological order with stack-based DFS via `named_children()`:
```
[
  'embed_tokens',
  'layers.0.attn.q_proj', 'layers.0.attn.lora_A', 'layers.0.attn.lora_B', 'layers.0.attn.k_proj', 'layers.0.attn.v_proj', 'layers.0.attn.o_proj', 'layers.0.attn', 'layers.0.mlp.proj1', 'layers.0.mlp.proj2', 'layers.0.mlp', 'layers.0.inp_layernorm', 'layers.0.post_attn_layernorm', 'layers.0',
  'layers.1.attn.q_proj', 'layers.1.attn.lora_A', 'layers.1.attn.lora_B', 'layers.1.attn.k_proj', 'layers.1.attn.v_proj', 'layers.1.attn.o_proj', 'layers.1.attn', 'layers.1.mlp.proj1', 'layers.1.mlp.proj2', 'layers.1.mlp', 'layers.1.inp_layernorm', 'layers.1.post_attn_layernorm', 'layers.1',
  'layers.2.attn.q_proj', 'layers.2.attn.lora_A', 'layers.2.attn.lora_B', 'layers.2.attn.k_proj', 'layers.2.attn.v_proj', 'layers.2.attn.o_proj', 'layers.2.attn', 'layers.2.mlp.proj1', 'layers.2.mlp.proj2', 'layers.2.mlp', 'layers.2.inp_layernorm', 'layers.2.post_attn_layernorm', 'layers.2',
  'layers.3.attn.q_proj', 'layers.3.attn.lora_A', 'layers.3.attn.lora_B', 'layers.3.attn.k_proj', 'layers.3.attn.v_proj', 'layers.3.attn.o_proj', 'layers.3.attn', 'layers.3.mlp.proj1', 'layers.3.mlp.proj2', 'layers.3.mlp', 'layers.3.inp_layernorm', 'layers.3.post_attn_layernorm', 'layers.3',
  'layers', 'norm', ''
]
```
Reverse topological order with `named_modules()`:
```
[
  'norm',
  'layers.3.post_attn_layernorm', 'layers.3.inp_layernorm', 'layers.3.mlp.proj2', 'layers.3.mlp.proj1', 'layers.3.mlp', 'layers.3.attn.o_proj', 'layers.3.attn.v_proj', 'layers.3.attn.k_proj', 'layers.3.attn.lora_B', 'layers.3.attn.lora_A', 'layers.3.attn.q_proj', 'layers.3.attn', 'layers.3',
  'layers.2.post_attn_layernorm', 'layers.2.inp_layernorm', 'layers.2.mlp.proj2', 'layers.2.mlp.proj1', 'layers.2.mlp', 'layers.2.attn.o_proj', 'layers.2.attn.v_proj', 'layers.2.attn.k_proj', 'layers.2.attn.lora_B', 'layers.2.attn.lora_A', 'layers.2.attn.q_proj', 'layers.2.attn', 'layers.2',
  'layers.1.post_attn_layernorm', 'layers.1.inp_layernorm', 'layers.1.mlp.proj2', 'layers.1.mlp.proj1', 'layers.1.mlp', 'layers.1.attn.o_proj', 'layers.1.attn.v_proj', 'layers.1.attn.k_proj', 'layers.1.attn.lora_B', 'layers.1.attn.lora_A', 'layers.1.attn.q_proj', 'layers.1.attn', 'layers.1', 'layers.0.post_attn_layernorm', 'layers.0.inp_layernorm', 'layers.0.mlp.proj2', 'layers.0.mlp.proj1', 'layers.0.mlp', 'layers.0.attn.o_proj', 'layers.0.attn.v_proj', 'layers.0.attn.k_proj', 'layers.0.attn.lora_B', 'layers.0.attn.lora_A', 'layers.0.attn.q_proj', 'layers.0.attn', 'layers.0',
  'layers', 'embed_tokens', ''
]
```
With the stack-based DFS via `named_children()`, reversing the topological order gives us each level in the module tree in the registered order, wheres with `named_modules()`, reversing the topological order gives us each level in reverse. Both are valid orders, but we prefer the former since it allows us to error/warn on the _first-registered_ module that violates the frozen/non-frozen condition.


</details>





[ghstack-poisoned]
@awgu awgu added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 2, 2023
…cy`"



This PR adds improved error/warning messaging when auto wrapping with `ModuleWrapPolicy` in the presence of frozen parameters.
- For `use_orig_params=False`, FSDP requires uniform `requires_grad` for each FSDP instance. This PR adds a `ValueError` at wrapping time with a message that mentions the violating module and the frozen/non-frozen parameter names.
- For `use_orig_params=True`, FSDP allows non-uniform `requires_grad` for each FSDP instance. However, it will result in higher-than-expected gradient memory usage. This PR adds a `UserWarning` at wrapping time with a message that mentions the violating module, how much extra gradient memory will be used (in units of numel), and the frozen/non-frozen parameter names.
    - There is a possibility that this warning will be spammy/verbose, but my current thinking is that it is okay for now unless users complain.


<details>
<summary> Why DFS via named_children() vs. Using named_modules()</summary>

```
LoraModel(
  (embed_tokens): Embedding(100, 32)
  (layers): ModuleList(
    (0-3): 4 x LoraDecoder(
      (attn): LoraAttention(
        (q_proj): Linear(in_features=32, out_features=32, bias=False)
        (lora_A): Linear(in_features=32, out_features=8, bias=False)
        (lora_B): Linear(in_features=8, out_features=32, bias=False)
        (k_proj): Linear(in_features=32, out_features=32, bias=False)
        (v_proj): Linear(in_features=32, out_features=32, bias=False)
        (o_proj): Linear(in_features=32, out_features=32, bias=False)
      )
      (mlp): LoraMLP(
        (proj1): Linear(in_features=32, out_features=128, bias=False)
        (proj2): Linear(in_features=128, out_features=32, bias=False)
      )
      (inp_layernorm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (post_attn_layernorm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    )
  )
  (norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
)
```
Reverse topological order with stack-based DFS via `named_children()`:
```
[
  'embed_tokens',
  'layers.0.attn.q_proj', 'layers.0.attn.lora_A', 'layers.0.attn.lora_B', 'layers.0.attn.k_proj', 'layers.0.attn.v_proj', 'layers.0.attn.o_proj', 'layers.0.attn', 'layers.0.mlp.proj1', 'layers.0.mlp.proj2', 'layers.0.mlp', 'layers.0.inp_layernorm', 'layers.0.post_attn_layernorm', 'layers.0',
  'layers.1.attn.q_proj', 'layers.1.attn.lora_A', 'layers.1.attn.lora_B', 'layers.1.attn.k_proj', 'layers.1.attn.v_proj', 'layers.1.attn.o_proj', 'layers.1.attn', 'layers.1.mlp.proj1', 'layers.1.mlp.proj2', 'layers.1.mlp', 'layers.1.inp_layernorm', 'layers.1.post_attn_layernorm', 'layers.1',
  'layers.2.attn.q_proj', 'layers.2.attn.lora_A', 'layers.2.attn.lora_B', 'layers.2.attn.k_proj', 'layers.2.attn.v_proj', 'layers.2.attn.o_proj', 'layers.2.attn', 'layers.2.mlp.proj1', 'layers.2.mlp.proj2', 'layers.2.mlp', 'layers.2.inp_layernorm', 'layers.2.post_attn_layernorm', 'layers.2',
  'layers.3.attn.q_proj', 'layers.3.attn.lora_A', 'layers.3.attn.lora_B', 'layers.3.attn.k_proj', 'layers.3.attn.v_proj', 'layers.3.attn.o_proj', 'layers.3.attn', 'layers.3.mlp.proj1', 'layers.3.mlp.proj2', 'layers.3.mlp', 'layers.3.inp_layernorm', 'layers.3.post_attn_layernorm', 'layers.3',
  'layers', 'norm', ''
]
```
Reverse topological order with `named_modules()`:
```
[
  'norm',
  'layers.3.post_attn_layernorm', 'layers.3.inp_layernorm', 'layers.3.mlp.proj2', 'layers.3.mlp.proj1', 'layers.3.mlp', 'layers.3.attn.o_proj', 'layers.3.attn.v_proj', 'layers.3.attn.k_proj', 'layers.3.attn.lora_B', 'layers.3.attn.lora_A', 'layers.3.attn.q_proj', 'layers.3.attn', 'layers.3',
  'layers.2.post_attn_layernorm', 'layers.2.inp_layernorm', 'layers.2.mlp.proj2', 'layers.2.mlp.proj1', 'layers.2.mlp', 'layers.2.attn.o_proj', 'layers.2.attn.v_proj', 'layers.2.attn.k_proj', 'layers.2.attn.lora_B', 'layers.2.attn.lora_A', 'layers.2.attn.q_proj', 'layers.2.attn', 'layers.2',
  'layers.1.post_attn_layernorm', 'layers.1.inp_layernorm', 'layers.1.mlp.proj2', 'layers.1.mlp.proj1', 'layers.1.mlp', 'layers.1.attn.o_proj', 'layers.1.attn.v_proj', 'layers.1.attn.k_proj', 'layers.1.attn.lora_B', 'layers.1.attn.lora_A', 'layers.1.attn.q_proj', 'layers.1.attn', 'layers.1', 'layers.0.post_attn_layernorm', 'layers.0.inp_layernorm', 'layers.0.mlp.proj2', 'layers.0.mlp.proj1', 'layers.0.mlp', 'layers.0.attn.o_proj', 'layers.0.attn.v_proj', 'layers.0.attn.k_proj', 'layers.0.attn.lora_B', 'layers.0.attn.lora_A', 'layers.0.attn.q_proj', 'layers.0.attn', 'layers.0',
  'layers', 'embed_tokens', ''
]
```
With the stack-based DFS via `named_children()`, reversing the topological order gives us each level in the module tree in the registered order, wheres with `named_modules()`, reversing the topological order gives us each level in reverse. Both are valid orders, but we prefer the former since it allows us to error/warn on the _first-registered_ module that violates the frozen/non-frozen condition.


</details>





[ghstack-poisoned]
@awgu awgu requested a review from penguinwu as a code owner August 2, 2023 16:17
awgu added a commit to awgu/pytorch that referenced this pull request Aug 2, 2023
ghstack-source-id: 874b21ead066fa1cd21a9d527e2449886ecbf718
Pull Request resolved: pytorch#104427
…cy`"



This PR adds improved error/warning messaging when auto wrapping with `ModuleWrapPolicy` in the presence of frozen parameters.
- For `use_orig_params=False`, FSDP requires uniform `requires_grad` for each FSDP instance. This PR adds a `ValueError` at wrapping time with a message that mentions the violating module and the frozen/non-frozen parameter names.
- For `use_orig_params=True`, FSDP allows non-uniform `requires_grad` for each FSDP instance. However, it will result in higher-than-expected gradient memory usage. This PR adds a `UserWarning` at wrapping time with a message that mentions the violating module, how much extra gradient memory will be used (in units of numel), and the frozen/non-frozen parameter names.
    - There is a possibility that this warning will be spammy/verbose, but my current thinking is that it is okay for now unless users complain.


<details>
<summary> Why DFS via named_children() vs. Using named_modules()</summary>

```
LoraModel(
  (embed_tokens): Embedding(100, 32)
  (layers): ModuleList(
    (0-3): 4 x LoraDecoder(
      (attn): LoraAttention(
        (q_proj): Linear(in_features=32, out_features=32, bias=False)
        (lora_A): Linear(in_features=32, out_features=8, bias=False)
        (lora_B): Linear(in_features=8, out_features=32, bias=False)
        (k_proj): Linear(in_features=32, out_features=32, bias=False)
        (v_proj): Linear(in_features=32, out_features=32, bias=False)
        (o_proj): Linear(in_features=32, out_features=32, bias=False)
      )
      (mlp): LoraMLP(
        (proj1): Linear(in_features=32, out_features=128, bias=False)
        (proj2): Linear(in_features=128, out_features=32, bias=False)
      )
      (inp_layernorm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (post_attn_layernorm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    )
  )
  (norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
)
```
Reverse topological order with stack-based DFS via `named_children()`:
```
[
  'embed_tokens',
  'layers.0.attn.q_proj', 'layers.0.attn.lora_A', 'layers.0.attn.lora_B', 'layers.0.attn.k_proj', 'layers.0.attn.v_proj', 'layers.0.attn.o_proj', 'layers.0.attn', 'layers.0.mlp.proj1', 'layers.0.mlp.proj2', 'layers.0.mlp', 'layers.0.inp_layernorm', 'layers.0.post_attn_layernorm', 'layers.0',
  'layers.1.attn.q_proj', 'layers.1.attn.lora_A', 'layers.1.attn.lora_B', 'layers.1.attn.k_proj', 'layers.1.attn.v_proj', 'layers.1.attn.o_proj', 'layers.1.attn', 'layers.1.mlp.proj1', 'layers.1.mlp.proj2', 'layers.1.mlp', 'layers.1.inp_layernorm', 'layers.1.post_attn_layernorm', 'layers.1',
  'layers.2.attn.q_proj', 'layers.2.attn.lora_A', 'layers.2.attn.lora_B', 'layers.2.attn.k_proj', 'layers.2.attn.v_proj', 'layers.2.attn.o_proj', 'layers.2.attn', 'layers.2.mlp.proj1', 'layers.2.mlp.proj2', 'layers.2.mlp', 'layers.2.inp_layernorm', 'layers.2.post_attn_layernorm', 'layers.2',
  'layers.3.attn.q_proj', 'layers.3.attn.lora_A', 'layers.3.attn.lora_B', 'layers.3.attn.k_proj', 'layers.3.attn.v_proj', 'layers.3.attn.o_proj', 'layers.3.attn', 'layers.3.mlp.proj1', 'layers.3.mlp.proj2', 'layers.3.mlp', 'layers.3.inp_layernorm', 'layers.3.post_attn_layernorm', 'layers.3',
  'layers', 'norm', ''
]
```
Reverse topological order with `named_modules()`:
```
[
  'norm',
  'layers.3.post_attn_layernorm', 'layers.3.inp_layernorm', 'layers.3.mlp.proj2', 'layers.3.mlp.proj1', 'layers.3.mlp', 'layers.3.attn.o_proj', 'layers.3.attn.v_proj', 'layers.3.attn.k_proj', 'layers.3.attn.lora_B', 'layers.3.attn.lora_A', 'layers.3.attn.q_proj', 'layers.3.attn', 'layers.3',
  'layers.2.post_attn_layernorm', 'layers.2.inp_layernorm', 'layers.2.mlp.proj2', 'layers.2.mlp.proj1', 'layers.2.mlp', 'layers.2.attn.o_proj', 'layers.2.attn.v_proj', 'layers.2.attn.k_proj', 'layers.2.attn.lora_B', 'layers.2.attn.lora_A', 'layers.2.attn.q_proj', 'layers.2.attn', 'layers.2',
  'layers.1.post_attn_layernorm', 'layers.1.inp_layernorm', 'layers.1.mlp.proj2', 'layers.1.mlp.proj1', 'layers.1.mlp', 'layers.1.attn.o_proj', 'layers.1.attn.v_proj', 'layers.1.attn.k_proj', 'layers.1.attn.lora_B', 'layers.1.attn.lora_A', 'layers.1.attn.q_proj', 'layers.1.attn', 'layers.1', 'layers.0.post_attn_layernorm', 'layers.0.inp_layernorm', 'layers.0.mlp.proj2', 'layers.0.mlp.proj1', 'layers.0.mlp', 'layers.0.attn.o_proj', 'layers.0.attn.v_proj', 'layers.0.attn.k_proj', 'layers.0.attn.lora_B', 'layers.0.attn.lora_A', 'layers.0.attn.q_proj', 'layers.0.attn', 'layers.0',
  'layers', 'embed_tokens', ''
]
```
With the stack-based DFS via `named_children()`, reversing the topological order gives us each level in the module tree in the registered order, wheres with `named_modules()`, reversing the topological order gives us each level in reverse. Both are valid orders, but we prefer the former since it allows us to error/warn on the _first-registered_ module that violates the frozen/non-frozen condition.


</details>





[ghstack-poisoned]
@awgu
Copy link
Contributor Author

awgu commented Aug 2, 2023

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorchmergebot pushed a commit that referenced this pull request Aug 2, 2023
pytorchmergebot pushed a commit that referenced this pull request Aug 2, 2023
pytorchmergebot pushed a commit that referenced this pull request Aug 3, 2023
…104969)

This does some code organization improvement.
- It renames `_FSDPPolicy` to `_Policy` to show that it is not only for FSDP but for any module-level API.
- It formalizes the contract that such a policy should return something like `target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]]` that maps each module to wrap to its kwargs. It does so by requiring a `_run_policy` abstract method (this time private since users do not need to care about it). Then, our auto wrapping can just call `_run_policy()` to generate the dict and do any validation or post-processing.

This PR is technically BC-breaking because it removes the public `ModuleWrapPolicy.policy`. However, I do not think anyone was using that anyway, so this is a pretty safe breakage.

Pull Request resolved: #104969
Approved by: https://github.com/rohan-varma
ghstack dependencies: #104427, #104967, #104999
pytorchmergebot pushed a commit that referenced this pull request Aug 3, 2023
This PR adds a new `CustomPolicy` that acts like the existing `lambda_auto_wrap_policy` except it (1) leverages the new auto wrapping infrastructure and (2) allows overriding FSDP kwargs for particular instances. (1) gives it access to the validation checks (like for frozen parameters), and (2) makes it as expressive as manual wrapping. This should allow us to effectively deprecate manual wrapping if desired.

The API is as follows:
```
def lambda_fn(module: nn.Module) -> Union[bool, Dict[str, Any]]:
    ...
policy = CustomPolicy(lambda_fn)
```
The `lambda_fn` can return:
- `False` or `{}` to indicate no wrapping
- `True` to indicate wrapping while inheriting the root's FSDP kwargs
- Non-empty `dict` to indicate wrapping while overriding the specified FSDP kwargs and inheriting the rest from the root

---

After this PR, the follow-up work items for auto wrapping are:
1. Add shared parameter validation
2. (Longer-term / exploratory) Add a policy that provides a reasonable auto wrapping with "minimal" user input

Pull Request resolved: #104986
Approved by: https://github.com/ezyang
ghstack dependencies: #104427, #104967, #104999, #104969
@facebook-github-bot facebook-github-bot deleted the gh/awgu/415/head branch August 6, 2023 14:16
voznesenskym pushed a commit that referenced this pull request Aug 7, 2023
ghstack-source-id: c16f55b9ae3232062b2daa284aec7ccbb781c45f
Pull Request resolved: #104427
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged release notes: distributed (fsdp) release notes category topic: improvements topic category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

7 participants