enable LoRA + FSDP2 #855
Conversation
See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/855. ✅ No failures as of commit 8fbbc4b with merge base 135cf2e.
I think one main question for discussion is what the torchtune folks feel about defining an explicit initializer for the RoPE theta buffer, which would be required to do meta-device init (which should speed up initialization!).
recipes/lora_finetune_distributed.py (outdated)
```python
    model, auto_wrap_policy={modules.TransformerDecoderLayer}
)

for m in reversed(list(model.modules())):
```
My noob understanding here is that you want to account for the Linear module before the TransformerDecoderLayer - is that right? If so, is there a better way to do this? It wasn't immediately obvious to me that reversing this list gets me that outcome - but maybe that's because I haven't played around with the modules() function? Disregard if this is a pretty standard way to achieve this.
I wanted to quickly mention my opinion: what we want is a post-order traversal of the modules so that we visit lower modules before higher ones. nn.Module.modules() always gives reverse post-order, so reversing it gives post-order.

FSDP1 hides all of this under the auto_wrap_policy argument, which does the post-order traversal for the user. Personally, I did not want to have an auto_wrap_policy argument for FSDP2 because the name is a misnomer: auto wrapping does not automatically choose which modules to wrap for the user -- rather, it just performs a post-order traversal and applies a user-defined policy to decide whether each module is wrapped or not. What I am open to, though, is having some other higher-level utility (say, apply_fsdp()) that does the same thing as auto_wrap_policy but lives outside of fully_shard.
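As a hedged sketch of what this wrapping loop looks like with FSDP2 (the fully_shard import path is the torch 2.4 prototype location; module names are taken from this PR, and the exact recipe code may differ):

```python
from torch import nn
from torch.distributed._composable.fsdp import fully_shard
from torchtune import modules

# reversed(model.modules()) visits children before their parents, so the
# trainable LoRA linears are sharded first, then the transformer layers that
# contain them, and finally the root model.
for m in reversed(list(model.modules())):
    if isinstance(m, nn.Linear) and m.weight.requires_grad:
        fully_shard(m)  # lora_a / lora_b, matching FSDP1's wrapping for parity
    elif isinstance(m, modules.TransformerDecoderLayer):
        fully_shard(m)
fully_shard(model)
```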
Thanks for this thorough explanation, it helps a lot in understanding this code block. I agree it'd be nice to have an apply_fsdp utility, but that's not a major concern here. Can we add code comments here to get the point you described across? (Basically that we are iterating from lower-level modules to higher-level modules and wrapping individual transformer layers. And I assume the separate wrapping of trainable LoRA weights is more related to the point you mentioned today about lower memory, rather than the flat-param rationale of grads being allocated per shard?) Users will take this recipe code as a starting point, so the more explicit we are here, the easier they'll find it to extend.
recipes/lora_finetune_distributed.py (outdated)
```python
utils.load_from_full_model_state_dict(
    model, lora_weights_state_dict, self._device, self._is_rank_zero
)
```
Got you, thanks so much for the explanation! I think something that would be super helpful would be to document here, in the form of comments, the relationship between:
- the modules on which we call fully_shard
- init on meta device
- calling initialize_parameters and reset_parameters

Also, I think there was a technical reason with FSDP1 to call the function reset_parameters. Is that still true? Or can we standardize this with initialize_parameters in the modules code? Happy to chat about this offline!
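For reference, a minimal sketch of the flow being discussed, assuming the recipe builds the model on the meta device and that LoRALinear exposes initialize_parameters() as in the diff above; build_lora_model is a hypothetical builder and the LoRALinear import path is assumed:

```python
import torch

from torchtune.modules.peft import LoRALinear  # assumed import path

device = torch.device("cuda")

# 1) Construct on the meta device: no storage is allocated and no init kernels
#    run, which is what makes initialization fast.
with torch.device("meta"):
    model = build_lora_model(cfg)  # hypothetical builder

# 2) Apply fully_shard bottom-up (see the wrapping loop above); parameters
#    become sharded DTensors, still on the meta device.

# 3) Load the sharded state dict for everything that comes from the checkpoint,
#    then materialize and initialize only the leftover params (LoRA a/b, RoPE).
for m in model.modules():
    if isinstance(m, LoRALinear):
        m.lora_a.to_empty(device=device)  # allocate real, uninitialized storage
        m.lora_b.to_empty(device=device)
        m.initialize_parameters()         # module-defined init, vs. FSDP1's reset_parameters
```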
recipes/lora_finetune_distributed.py (outdated)
```python
for m in reversed(list(model.modules())):
    if (
        isinstance(m, nn.Linear)
```
Does this need to be LoRALinear? Or am I misunderstanding?
The nn.Linear modules with requires_grad=True are lora_a and lora_b. FSDP1 on trunk wraps lora_a and lora_b separately, so we do the same wrapping here for parity. But as you mentioned, we could wrap LoRALinear instead so that lora_a and lora_b are communicated together.
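A hedged sketch of that alternative, i.e. sharding LoRALinear so both adapter matrices land in one FSDP2 parameter group (LoRALinear import path assumed; other imports as in the earlier sketch):

```python
from torchtune.modules.peft import LoRALinear  # assumed import path

for m in reversed(list(model.modules())):
    if isinstance(m, LoRALinear):
        # lora_a and lora_b now share one FSDP2 param group, so their
        # all-gathers and reduce-scatters are issued together rather than
        # as two separate (tiny) collectives per adapter.
        fully_shard(m)
    elif isinstance(m, modules.TransformerDecoderLayer):
        fully_shard(m)
```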
recipes/lora_finetune_distributed.py (outdated)
```python
        m.lora_a.to_empty(device=self._device)
        m.lora_b.to_empty(device=self._device)
        m.initialize_parameters()
    if isinstance(m, modules.RotaryPositionalEmbeddings):
```
Just to clarify: we special-case RoPE because the buffer is not loaded from the state dict, right?
that's correct
Similar comment here, let's document what's happening so that users can easily understand why we initialize these modules separately.
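For illustration, a hedged sketch of that special case; rope_init below is a stand-in for whatever explicit initializer ends up being defined for the RoPE theta buffer (see the discussion above), not an existing torchtune method:

```python
for m in model.modules():
    if isinstance(m, modules.RotaryPositionalEmbeddings):
        # The RoPE cache is a buffer computed at construction time and is never
        # saved to (or loaded from) the checkpoint, so after meta-device init it
        # has to be materialized and recomputed on the training device by hand.
        m.to_empty(device=self._device)
        m.rope_init()  # hypothetical explicit initializer for the theta buffer
```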
```diff
@@ -242,6 +250,79 @@ def lora_wrap_fsdp(module: nn.Module, recurse: bool, **kwargs):
     return lora_wrap_fsdp


 def load_from_full_model_state_dict(
```
This is a bit unfortunate, but we have two different meanings of "full model" IIUC:
- One is related to FSDP, i.e. full vs. sharded - is that right?
- The other is LoRA, i.e. full model vs. LoRA adapters

The current function is a bit confusing since we pass in the adapter state_dict and it is referred to as full_sd in the function itself.
Good catch. You're right, "full" here is opposed to "sharded". What do you think about renaming it to "local", i.e. load_from_local_model_state_dict?
Can we add docstrings for the functions added in this file?
torchtune/utils/_distributed.py (outdated)
```python
for param_name, full_tensor in full_sd.items():
    sharded_meta_param = meta_sharded_sd.get(param_name)
    sharded_tensor = distribute_tensor(
        full_tensor, sharded_meta_param.device_mesh, sharded_meta_param.placements
```
Where do the device_mesh and placements come from?
It's from fully_shard(model, mesh). We are using the default mesh, where every rank serves FSDP, since there is no 2D/3D parallelism involved. After fully_shard(model, mesh), model.parameters() is converted from plain tensors to DTensors carrying that mesh.
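A hedged sketch of where the mesh and placements come from (import paths are the FSDP2/DTensor prototype locations in torch 2.4; model is assumed to be an already-constructed module):

```python
import torch.distributed as dist
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import init_device_mesh

# Default 1-D mesh: every rank participates in FSDP (no 2D/3D parallelism).
mesh = init_device_mesh("cuda", (dist.get_world_size(),))
fully_shard(model, mesh=mesh)

# After fully_shard, parameters are DTensors that record the mesh and how they
# are sharded; distribute_tensor reuses exactly this information when loading.
param = next(model.parameters())
assert isinstance(param, DTensor)
print(param.device_mesh)  # the 1-D mesh above
print(param.placements)   # (Shard(0),): dim-0 sharding used by FSDP2
```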
synced with @kartikayk on action items
This is looking great! Most of my comments are minor and around documentation.
I saw there was also some discussion around the location for this recipe. I am inclined to agree with the point you and @kartikayk discussed around putting it in recipes/dev. I don't want that to be a long-term home for it, but also want to make sure we don't break folks who are using the current recipe and aren't on the latest version of PyTorch yet.
recipes/lora_finetune_distributed.py (outdated)
```python
# LoRA hyper-params needed for merging weights while saving checkpoints
self._lora_rank = cfg_model.lora_rank
self._lora_alpha = cfg_model.lora_alpha
```
Just FYI, this might need a merge; I think these are actually now defined in the version of the recipe we have in main.
```python
@pytest.mark.skipif(
    version.parse(torch.__version__).base_version < "2.4.0", reason=""
)
```
To make sure I understand, our distributed LoRA recipe will now only work on torch >= 2.4.0, is that correct?
That's correct. FSDP2 will be released in 2.4.0, although it has been in the nightlies for a while.
Thanks for clarifying! In that case I think it makes sense to create dev/recipes/lora_finetune_fsdp2.py (or something like that) in the short term, then migrate to replace recipes/lora_finetune_distributed.py once we've socialized and gotten enough users onto >= 2.4.0. Let me know how that sounds to you.
Proposing a middle ground: type hint with "FSDPModule" (as a string) without try...except:
```python
def load_from_full_model_state_dict(
    model: "FSDPModule",
```
```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
```
We do also have gpu_test in torchtune, can we use that here for the sake of consistency?
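For instance, a hedged sketch of what the consistent decorator stack could look like, assuming torchtune's gpu_test helper in tests/test_utils.py and the integration_test marker used by the recipe tests:

```python
import pytest
import torch
from packaging import version

from tests.test_utils import gpu_test  # assumed torchtune test helper


@pytest.mark.integration_test
@pytest.mark.skipif(
    version.parse(torch.__version__).base_version < "2.4.0",
    reason="FSDP2 (fully_shard) requires torch >= 2.4.0",
)
@gpu_test(gpu_count=2)
def test_lora_fsdp2_state_dict() -> None:
    ...
```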
```python
    )
    sharded_sd[param_name] = nn.Parameter(sharded_tensor)
# choose `assign=True` since we cannot call `copy_` on meta tensor
model.load_state_dict(sharded_sd, strict=False, assign=True)
```
If we catch missing and unexpected keys from load_state_dict with strict=False, what format will the keys be in? Previously with FSDP1 the keys contained all the info about FSDP wrapping, e.g. model.layers.0._fsdp_flat_param.attn.q_proj.weight (probably not exactly right but something like that). Will that still be the case here?
For FSDP2, it's clean FQNs without the FSDP prefix, for example layers.0.attn.q_proj.lora_a.weight. FSDP2 is clean because 1) fully_shard registers hooks instead of wrapping the nn.Module, and 2) fully_shard sets module.__class__ = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct): https://fburl.com/i20yr3s2
This is great! I think this means we can actually do validation of LoRA state dict load more cleanly (note that we actually have two separate utilities for this for the single-device vs distributed case because of this FSDP prefix issue). Not a concern for this PR but this will allow us to clean up our code a bit
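To illustrate the clean-FQN point (the key names below are examples in the spirit of the one above, not output from a real run):

```python
# FSDP2: load_state_dict returns missing/unexpected keys under their original
# fully-qualified names, so they are directly comparable to checkpoint keys.
missing, unexpected = model.load_state_dict(sharded_sd, strict=False, assign=True)
# e.g. `missing` might contain "layers.0.attn.q_proj.lora_a.weight"

# FSDP1 wrapped submodules, so the same key would surface with wrapper prefixes
# (roughly "_fsdp_wrapped_module.layers.0...."), which is why separate
# state-dict validation utilities exist for single-device vs. distributed.
```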
torchtune/utils/_distributed.py (outdated)
```python
for pid, full_pid in zip(param_group[PARAMS], full_param_group[PARAMS]):
    # Skip param ids this rank's optimizer does not hold state for.
    if pid not in state:
        continue
    param_state = state[pid]
    full_param_state = full_state[full_pid]
    for attr, full_tensor in full_param_state.items():
        sharded_tensor = param_state[attr]
        if isinstance(sharded_tensor, DTensor):
            # Tensor state (e.g. exp_avg, exp_avg_sq) is sharded the same way
            # as the parameter itself.
            param_state[attr] = distribute_tensor(
                full_tensor,
                sharded_tensor.device_mesh,
                sharded_tensor.placements,
            )
        else:
            # Non-sharded state (e.g. the scalar `step`) is copied as-is.
            param_state[attr] = full_tensor
```
Might be useful to add some comments for this code block
A couple small comments on the configs but otherwise no major concerns. Really excited to see this in our library!
```yaml
# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config llama2/13B_lora
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config llama2/13B_lora checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
```
Probably need to do a find-and-replace of lora_finetune_distributed -> lora_finetune_fsdp2 in all three config files.
```yaml
# Config for multi-device LoRA in lora_finetune_distributed.py
# using a Llama2 13B model
```
Can update this header to mention that this config is for the recipe using FSDP2 (I know the config file is the same, but nice visibility to just explicitly call it out at the top of the file)
Good catch! Updated to lora_finetune_fsdp2 and mentioned FSDP2.
```python
# iterating from lower modules to higher,
# e.g. grouping LoRA adapters before the transformer block
for m in reversed(list(model.modules())):
```
By the way, another option here is to just make two passes if that is clearer:
```python
for module in model.modules():
    if isinstance(module, LoRALinear):  # LoRA adapter
        fully_shard(module)
for module in model.modules():
    if isinstance(module, modules.TransformerDecoderLayer):  # transformer block
        fully_shard(module)
```
How to run it:
```bash
tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --hf-token <HF_TOKEN> \
  && tune run --nproc_per_node 8 lora_finetune_fsdp2 --config llama2/7B_lora

tune download meta-llama/Llama-2-70b-hf --output-dir /tmp/Llama-2-70b-hf --hf-token <HF_TOKEN> \
  && tune run --nproc_per_node 8 lora_finetune_fsdp2 --config llama2/70B_lora
```

Recipe tests:
```bash
pytest tests/recipes/test_lora_finetune_fsdp2.py -m integration_test
```

Unit test:
```bash
pytest tests/torchtune/utils/test_distributed.py -k test_state_dict
```
Highlights

FSDP2 changes
- init the model on meta device instead of cpu

checkpoint changes
- load the checkpoint with mmap=True; convert plain tensors into DTensor when loading
- convert DTensor back into plain tensors when saving

optional
- set optimizer.foreach=True, since the pytorch PR has not landed yet: [DTensor] Turn on foreach implementation of optimizer for DTensor by default pytorch#123394
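A hedged sketch of the checkpoint-side conversions listed above (checkpoint_path, meta_sharded_sd, and the key name are placeholders, not the recipe's exact variables):

```python
import torch
from torch import nn
from torch.distributed._tensor import distribute_tensor

# Load the full checkpoint with mmap=True so the state dict is memory-mapped
# from disk rather than fully materialized in CPU RAM on every rank.
full_sd = torch.load(checkpoint_path, map_location="cpu", mmap=True, weights_only=True)

# Loading: plain tensor -> DTensor, sharded onto the same mesh/placements as the
# (meta) sharded parameter it will replace.
name = "layers.0.attn.q_proj.lora_a.weight"
meta_param = meta_sharded_sd[name]
sharded_dtensor = distribute_tensor(
    full_sd[name], meta_param.device_mesh, meta_param.placements
)
sharded_param = nn.Parameter(sharded_dtensor)

# Saving: DTensor -> plain tensor by gathering the full value (typically only
# materialized/written on rank zero).
full_tensor = sharded_dtensor.full_tensor()
```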