LoRA fine-tuning recipe #266
Conversation
❌ Deploy Preview for torchtune-preview failed.
@@ -25,6 +25,31 @@
 # Modules from CausalSelfAttention that LoRA can be applied to
 LORA_ATTN_MODULES = Literal["q_proj", "k_proj", "v_proj", "output_proj"]
+
+
+def lora_llama2_7b(
+    lora_attn_modules: List[LORA_ATTN_MODULES],
shall we set a default here?
I discussed this with @joecummings as well in the original PR where we added this. I don't think we should set a default list; I was thinking of just having the default be None, but ultimately settled on making this mandatory. Btw, another path around this is to have bools for each individual param, but I don't love that either.
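To illustrate the mandatory-argument approach settled on above, here is a minimal, hypothetical sketch (the names mirror the PR, but the builder and its validation logic are illustrative, not the actual torchtune implementation):

```python
from typing import List, Literal

# Mirrors LORA_ATTN_MODULES from the PR; the builder below is a sketch.
LORA_ATTN_MODULES = Literal["q_proj", "k_proj", "v_proj", "output_proj"]
_VALID_MODULES = {"q_proj", "k_proj", "v_proj", "output_proj"}


def lora_llama2_7b_sketch(lora_attn_modules: List[LORA_ATTN_MODULES]) -> dict:
    # No default value: callers must state explicitly which attention
    # projections LoRA is applied to.
    invalid = set(lora_attn_modules) - _VALID_MODULES
    if invalid:
        raise ValueError(f"Unsupported LoRA attention modules: {sorted(invalid)}")
    return {"lora_attn_modules": list(lora_attn_modules)}
```

Making the argument mandatory pushes the choice into every config, which is the point of the debate above: explicitness at the cost of a slightly noisier call site.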
recipes/lora_finetune.py
Outdated
_set_trainable_params(model)

# TODO: move this somewhere else
def lora_custom_auto_wrap_policy(
Call me crazy, but I like the function here. Two reasons:
- Clearly this is specific to this particular model (for now). We can think of generalizing it as a util later.
- Can't think of a reason why this should be visible to the entire class
you're crazy! :D
I wanna bump this debate so we can decide where to put this. @rohan-varma I know you mentioned that having a policies folder/file elsewhere is one path. Personally I do not love this, and if we are gonna move it I would rather keep it near the model definition (e.g. torchtune/models/lora_llama2.py) where it's easier to find. Thoughts? @kartikayk
Not to do a complete 180, but on the flip side one advantage of @rohan-varma's suggestion to use a separate policies file/folder is that we can maintain a mapping model -> FSDP policy there, which is definitely cleaner on the config side of things (then I can just pass my model_name and infer the appropriate FSDP wrapping if needed). Anyways I am gonna leave it here for the time being with a TODO
recipes/lora_finetune.py
Outdated
):
    if isinstance(module, modules.TransformerDecoderLayer):
        return True
    if hasattr(module, 'weight') and module.weight.requires_grad == True:
curious, why are we doing this sort of wrapping?
Oh I guess this works because only lora_a and lora_b will be trainable / have requires_grad set to True the way we're setting things up. But this relies on a number of assumptions about the order in which we wrap with FSDP versus set trainability. Also, the wrapping will change if users play around with which parts of the network are trainable, which is IMO an unexpected side effect.
Proposal for the wrapping function:
def lora_wrap(module, recurse, *args, **kwargs):
    if recurse:
        return True
    if isinstance(module, TransformerDecoderLayer):
        return True
    if hasattr(module, '_lora_module'):
        return True
    return False
and then we mark the LoRA module (i.e. https://github.com/pytorch-labs/torchtune/blob/6fb9fbc4441a5057710f84a8bbb3218053ca742a/torchtune/modules/peft/lora.py#L59) with the _lora_module attribute.
So this is why I was hoping to get the adapter_params method working on modules as well. Then we'd have a single point of definition for (a) checkpoint save/load, (b) trainable params, (c) FSDP wrap logic. The way it is now, a single property is spread out across multiple places, which I don't love.
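A torch-free sketch of that "single point of definition" idea (all names here are hypothetical; the real version would live on nn.Module subclasses): one adapter_params() method drives checkpoint filtering, trainable-param selection, and wrap decisions.

```python
class LoRALinearSketch:
    """Hypothetical stand-in for a LoRA linear layer."""

    def __init__(self):
        # One frozen base weight plus two trainable low-rank factors.
        self.param_names = ["weight", "lora_a", "lora_b"]

    def adapter_params(self):
        # Single source of truth for which params belong to the adapter.
        return ["lora_a", "lora_b"]


def trainable_param_names(module):
    # (b) trainable params derive from the same definition
    return set(module.adapter_params())


def checkpoint_keys(module):
    # (a) checkpoint save/load filters to the same set
    return [n for n in module.param_names if n in module.adapter_params()]


def should_wrap(module):
    # (c) FSDP wrap logic keys off the same definition
    return hasattr(module, "adapter_params")
```

All three consumers stay in sync automatically when the adapter definition changes, which is the advantage being argued for above.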
recipes/lora_finetune.py
Outdated
# TODO: move this somewhere else
def lora_custom_auto_wrap_policy(
    module: nn.Module,
    recurse: bool,
nit: since these are not used, just make it *args / **kwargs?
recipes/lora_finetune.py
Outdated
# TODO: move this somewhere else
def lora_custom_auto_wrap_policy(
    module: nn.Module,
    recurse: bool,
oh, this recurse flag needs to be checked, i.e. make sure to have an "if recurse: return True" to tell FSDP's wrapping function to recurse into child modules.
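To make the recurse semantics concrete, here is a simplified, torch-free sketch of how FSDP consults an auto-wrap policy (the two-phase protocol is how FSDP policies behave; the toy module and attribute name are illustrative):

```python
# FSDP calls the policy per module in two phases: with recurse=True
# ("should I keep descending into this module's children?") and with
# recurse=False ("should I wrap this exact module?"). Returning True
# in the recurse phase keeps traversal going.
def lora_wrap_policy(module, recurse, nonwrapped_numel=0):
    if recurse:
        return True  # always descend into children
    return getattr(module, "is_decoder_layer", False)


class ToyModule:
    def __init__(self, is_decoder_layer=False):
        self.is_decoder_layer = is_decoder_layer
```

Without the recurse branch, the policy would stop FSDP from ever visiting nested layers, which is the bug being pointed out above.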
recipes/lora_finetune.py
Outdated
_set_trainable_params(model)

# TODO: move this somewhere else
def lora_custom_auto_wrap_policy(
I would generalize this into a policies folder / file where we can have training-specific FSDP wrapping policies.
I will let you and @kartikayk duke it out over that one 😄
-        return getattr(torch.optim, optimizer)(model.parameters(), lr=lr)
+        trainable_params = [p for n, p in model.named_parameters() if p.requires_grad]
+        return getattr(torch.optim, optimizer)(
+            trainable_params, lr=lr, weight_decay=weight_decay
Note: many, but not all, optimizers in torch.optim support the weight_decay param. Based on this we could also consider going the kwargs route here
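One way the kwargs route could look, sketched without torch (filter_optimizer_kwargs and the fake optimizer classes are hypothetical stand-ins for torch.optim classes): inspect the optimizer's signature and forward weight_decay only when it is supported.

```python
import inspect


def filter_optimizer_kwargs(optimizer_cls, **kwargs):
    # Keep only kwargs that the optimizer's __init__ actually accepts.
    accepted = inspect.signature(optimizer_cls.__init__).parameters
    return {k: v for k, v in kwargs.items() if k in accepted}


# Stand-ins for torch.optim classes with and without weight_decay support.
class FakeSGD:
    def __init__(self, params, lr, weight_decay=0.0):
        self.defaults = {"lr": lr, "weight_decay": weight_decay}


class FakeNoDecayOpt:
    def __init__(self, params, lr):
        self.defaults = {"lr": lr}
```

This keeps get_optimizer generic: unsupported kwargs are silently dropped rather than raising a TypeError at construction time (whether silent dropping is desirable is itself a design choice).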
# If we're resuming from checkpoint, the recipe's state should be updated before
# initializing the training components. This ensures that the seed is correctly
# propagated to the relevant components
if self._resume_from_checkpoint:
The TODO mentions that seed and epoch are not checkpointed; if we support resume in this recipe, not having this could lead to silent incorrectness. Shall we disable the resume_from_checkpoint flag until we've implemented this?
I am gonna do some fast follow on checkpoint save regardless (just to save LoRA params). I think it is straightforward to add seed and epoch checkpointing in this PR so will do that. Personally I would bias towards just leaving resume_from_checkpoint support in to avoid thrash given that I will test it thoroughly when making the checkpoint save changes (prob later today). Lmk what you think
Okay with this as long as we do it asap later today / tomorrow
self._model = self._setup_model(
    model=params.model,
    lora_attn_modules=params.lora_attn_modules,
do we envision this list of args getting too long? Seems fine for now but wanted to put it on your radar
I agree. There is def a tradeoff between configurability and code readability. I def wouldn't add more args here. Another thing we can do is separate out the construction of the basic LoRA model class from wrapping with FSDP, activation checkpointing, and state dict loading.
recipes/lora_finetune.py
Outdated
# training setup
self._autocast = utils.get_autocast(self._dtype, self._device)
self._grad_scaler = None
grad_scaler always gets defined in the if/else block, so this is not needed?
recipes/lora_finetune.py
Outdated
# for logging and tracking training state. This should be computed after the dataloader
# has been setup
steps_per_epoch = len(self._dataloader)
if self.max_steps_per_epoch and self.max_steps_per_epoch < len(
if self.max_steps_per_epoch -> if self.max_steps_per_epoch is not None, to be more explicit (and support the 0 case)
When would self.max_steps_per_epoch be 0? Either way I am good with this change
I have done it a few times just to immediately take a checkpoint
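A quick sketch of why the explicit None check matters here (illustrative function, not the recipe's actual code):

```python
def effective_steps(steps_per_epoch, max_steps_per_epoch=None):
    # `is not None` rather than truthiness: 0 is falsy, so a bare
    # `if max_steps_per_epoch` would silently ignore the
    # "checkpoint immediately" use case mentioned above.
    if max_steps_per_epoch is not None and max_steps_per_epoch < steps_per_epoch:
        return max_steps_per_epoch
    return steps_per_epoch
```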
torchtune/utils/distributed.py
Outdated
@@ -104,6 +104,7 @@ def wrap_fsdp(
     dtype: torch.dtype,
     strategy: Optional[str] = None,
     auto_wrap_policy: Optional[Set[Type]] = None,
+    custom_wrap_policy: Optional[Callable[[nn.Module, bool, int], bool]] = None,
both custom_wrap_policy and auto_wrap_policy can get confusing, maybe less so if we clearly document which one takes precedence. I would just have one arg, custom_wrap_policy, that takes in a callable or a set of nn.Module types.
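A hedged sketch of that single-argument idea, torch-free (the normalization helper and its name are hypothetical): accept either a callable policy or a set of module types, and normalize to a callable internally.

```python
from typing import Callable, Set, Union


def normalize_wrap_policy(policy: Union[Callable, Set[type], None]) -> Callable:
    # Always return a callable with the (module, recurse, ...) signature.
    if policy is None:
        return lambda module, recurse, **kw: bool(recurse)
    if callable(policy):
        return policy
    # A set of types behaves like ModuleWrapPolicy: wrap exact type matches.
    types = tuple(policy)
    return lambda module, recurse, **kw: True if recurse else isinstance(module, types)


# Toy module types for illustration.
class DecoderLayer: ...
class Embedding: ...
```

With one argument there is no precedence question to document: the caller picks exactly one form, and wrap_fsdp only ever sees a callable.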
I'm already confused.
torchtune/utils/distributed.py
Outdated
@@ -156,6 +160,8 @@ def wrap_fsdp(
     strategy = "FULL_SHARD"
     _validate_device_from_env(device)
     wrap_policy = ModuleWrapPolicy(auto_wrap_policy or set())
+    if custom_wrap_policy:
This overriding logic should be documented if we do go with 2 policies in the function (not a fan of that though)
I guess if we go with the util you added in #317 we don't need to modify this method, is that right? If that's the case I am good with that approach
Yep!
@@ -24,7 +27,10 @@ def get_optimizer(optimizer: str, model: torch.nn.Module, lr: float) -> Optimizer:
         ValueError: if the optimizer is not a valid optimizer from torch.optim.
     """
     try:
-        return getattr(torch.optim, optimizer)(model.parameters(), lr=lr)
+        trainable_params = [p for n, p in model.named_parameters() if p.requires_grad]
I assume this is LoRA specific?
In general we do not wanna optimize any parameters that do not require grad. So I think this is general enough to put in the get_optimizer method
self._metric_logger.log_dict(
    {
        "loss": loss.item(),
        "lr": self._optimizer.param_groups[0]["lr"],
Did you test that this outputs the correct LR? Not sure that it will.
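The concern above is about when param_groups reflects the scheduled value. A toy, torch-free illustration (toy classes, but the mechanism matches how torch LR schedulers work: step() writes the new LR into the optimizer's param_groups):

```python
class ToyOptimizer:
    def __init__(self, lr):
        self.param_groups = [{"lr": lr}]


class ToyHalvingScheduler:
    """Stand-in for an LR scheduler: step() mutates param_groups in place."""

    def __init__(self, optimizer):
        self.optimizer = optimizer

    def step(self):
        for group in self.optimizer.param_groups:
            group["lr"] *= 0.5
```

So reading param_groups[0]["lr"] after scheduler.step() reports the LR for the next step; whether that is the value that was actually used for the logged loss depends on where in the loop the read happens, which is worth verifying.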
LG overall, some small comments, will stamp after that!
recipes/params.py
Outdated
Args:
    model (str): String specifying model architecture to fine-tune. See ``torchtune.models.get_model`` for options.
    model_checkpoint (str): Local path to load model checkpoint from.
    lora_attn_modules (List[str]): List of attention modules to use for LoRA.
Would change this to "List describing which attention modules to apply LoRA to", and also add examples / the supported strings.
recipes/params.py
Outdated
@dataclass
class LoRAFinetuneParams:
    """Arguments for the finetune_lora recipe.
Add a caveat somewhere that mentions applying LoRA to MLP layers is not supported, and file an issue for this?
@@ -92,3 +91,30 @@ def set_trainable_params(model: nn.Module, adapter_params: Dict[str, Any]) -> None:
     """
     for k, v in model.named_parameters():
         v.requires_grad_(k in adapter_params)
+
+
+def validate_state_dict_for_lora(
I'd eventually want us to update this to be a full FQN check to guard against if somehow there is a partial match, but the state_dict is not valid (I know I'm being kind of paranoid). Can file this as a follow up issue
torchtune/modules/peft/peft_utils.py
Outdated
""" | ||
for x in missing_keys: | ||
if not any([k in x for k in lora_modules]): | ||
raise AssertionError(f"{k} is not a LoRA module {lora_modules}") |
isn't it x that is not a LoRA module? k in this case is in lora_modules?
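For reference, a corrected sketch of that loop (a hypothetical standalone version of the PR's helper): use x (the missing key) in the error message, and check unexpected_keys once rather than inside the loop.

```python
def validate_state_dict_for_lora_sketch(lora_modules, missing_keys, unexpected_keys):
    for x in missing_keys:
        # `x` is the missing state-dict key; `k` iterates over LoRA module names.
        if not any(k in x for k in lora_modules):
            raise AssertionError(f"{x} is not a LoRA module in {lora_modules}")
    # Checked once, outside the loop over missing keys.
    if unexpected_keys:
        raise AssertionError(f"Unexpected keys {unexpected_keys} in state dict")
```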
@@ -92,3 +91,30 @@ def set_trainable_params(model: nn.Module, adapter_params: Dict[str, Any]) -> None:
     """
     for k, v in model.named_parameters():
         v.requires_grad_(k in adapter_params)
+
+
+def validate_state_dict_for_lora(
Add a unittest for this function
torchtune/modules/peft/peft_utils.py
Outdated
for x in missing_keys:
    if not any([k in x for k in lora_modules]):
        raise AssertionError(f"{k} is not a LoRA module {lora_modules}")
if unexpected_keys:
This can be moved outside of the loop
LGTM, thanks so much!
@@ -92,3 +91,30 @@ def set_trainable_params(model: nn.Module, adapter_params: Dict[str, Any]) -> None:
     """
     for k, v in model.named_parameters():
         v.requires_grad_(k in adapter_params)
+
+
+def validate_state_dict_for_lora(
TODO: have to add this to docs?
Update 2/5

After closing some of the gaps described in detail further down in the description, I am putting this PR up for review now. This PR adds a new recipe, lora_finetune.py, which is pretty similar to the existing full_finetune.py, but using lora_llama2_7b (defined in this PR: a LoRA-wrapped version of llama2_7b, where LoRA is applied to the Q and V projections of each layer's self-attention).

Enumerating some of the todos here:

(1) Proper support for saving LoRA weights only in save_checkpoint, and resuming a run by loading in the LoRA weights + base model weights + LoRA optimizer states.
(2) Integration test a la test_finetune_llm.py
(3) Proper evaluation via eleuther harness and comparison to lit-gpt
(4) Integrate gradient accumulation support
(5) LoRA integration with MLP linears
(6) Some UX cleanup, especially around FSDP APIs and usage of get_optimizer.

(Probably more I'm missing, will add as I think of them)
Test plan
Loss curve
Also add unit test for state dict load validation utility:
Addendum: details from the draft version of this PR
Very ugly first version of LoRA FT script.
The initial version of this script (and our full fine-tuning script) had really noisy loss curves with a pretty weak trend.
After comparing to lit-gpt, we saw that the magnitude of the loss trend (especially relative to the amount of noise), was much smaller in our version than in theirs (sample lit-gpt loss curve below).
Additionally, lit-gpt's LoRA fine-tuning shows a clear improvement early in fine-tuning (within the first couple hundred iterations) that our LoRA/full fine-tuning runs did not.
After a bit of investigation, we found a couple gaps between our version and theirs:
(1) LR warmup with cosine annealing (lit-gpt) or linear decay (alpaca-lora), which we don't yet support.
(2) Grad accumulation (a known feature gap)
(3) Weight decay
However, none of these 3 made a material difference in the loss curve discrepancy. Finally, we found the real causes:
(4) We are training on the raw alpaca dataset, but most other implementations (lit-gpt, alpaca-lora, etc) train on a cleaned version of alpaca.
(5) Our dataset transform masks everything up to the end of the response tag for our labels (ref). But actually, many existing implementations do not do this (e.g. lit-gpt here and here, alpaca-lora here, this issue on llama-factory).
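A minimal sketch of the masking difference in (5) (illustrative only; the actual transforms differ in details like tokenization and response-tag handling): with masking, prompt tokens are replaced by an ignore index so the loss is computed only on the response.

```python
IGNORE_INDEX = -100  # conventional ignore index for cross-entropy loss


def build_labels(token_ids, prompt_len, train_on_input=False):
    # With masking (our transform): loss ignores everything up to the end
    # of the response tag. Without masking (lit-gpt, alpaca-lora): labels
    # are the full token sequence, so the prompt also contributes to loss.
    if train_on_input:
        return list(token_ids)
    return [IGNORE_INDEX] * prompt_len + list(token_ids[prompt_len:])
```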
After changing all of the above, we can run the following command.
So far, our loss curve looks like this:
Note that plot scales, log frequency, and batch size are not 100% the same across the above figures, but hopefully the proximity of this new loss curve to the lit-gpt one (especially relative to our previous loss curve) is evident. Eventually we can run some proper evals, but haven't gotten that far yet.