
LoRA fine-tuning recipe #266

Merged: 18 commits merged into pytorch:main on Feb 6, 2024

Conversation

ebsmothers (Contributor) commented on Jan 30, 2024

Update 2/5

After closing some of the gaps described in detail further down in the description, I am putting this PR up for review now. This PR adds a new recipe, lora_finetune.py, which is pretty similar to the existing full_finetune.py, but using lora_llama2_7b (defined in this PR: a LoRA-wrapped version of llama2_7b, where LoRA is applied to the Q and V projections of each layer's self-attention).
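For readers unfamiliar with LoRA, here is a minimal sketch of the idea behind the wrapped linears (the rank/alpha defaults below are illustrative; torchtune's actual LoRALinear has more options):

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen base linear plus a trainable low-rank update: y = Wx + (alpha / rank) * B(A(x))."""

    def __init__(self, in_dim: int, out_dim: int, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim, bias=False)
        self.linear.weight.requires_grad = False  # base weight stays frozen
        self.lora_a = nn.Linear(in_dim, rank, bias=False)   # trainable
        self.lora_b = nn.Linear(rank, out_dim, bias=False)  # trainable
        nn.init.zeros_(self.lora_b.weight)  # LoRA path starts as a no-op
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x) + self.scaling * self.lora_b(self.lora_a(x))

In lora_llama2_7b, modules like this stand in for the Q and V projection linears of each attention layer; only lora_a and lora_b receive gradients.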

Enumerating some of the todos here:

(1) Proper support for saving LoRA weights only in save_checkpoint and resuming a run by loading in the LoRA weights + base model weights + LoRA optimizer states (a rough sketch of the save/load side follows this list).
(2) Integration test a la test_finetune_llm.py
(3) Proper evaluation via eleuther harness and comparison to lit-gpt
(4) Integrate gradient accumulation support
(5) LoRA integration with MLP linears
(6) Some UX cleanup, especially around FSDP APIs and usage of get_optimizer.
(Probably more I'm missing, will add as I think of them)
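For (1), a rough sketch of what LoRA-only checkpointing could look like (filtering keys on a "lora_" substring is an assumption here, not the final torchtune convention):

import torch
import torch.nn as nn

def save_lora_only(model: nn.Module, path: str) -> None:
    # keep only the adapter weights (lora_a / lora_b) so the checkpoint stays small
    lora_state = {k: v for k, v in model.state_dict().items() if "lora_" in k}
    torch.save(lora_state, path)

def load_lora_into(model: nn.Module, path: str) -> None:
    # base model weights are loaded separately; strict=False tolerates the missing base keys
    missing, unexpected = model.load_state_dict(torch.load(path), strict=False)
    assert not unexpected, f"unexpected keys in LoRA checkpoint: {unexpected}"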

Test plan

torchrun --nnodes 1 --nproc_per_node 8 recipes/lora_finetune.py \
--config recipes/configs/alpaca_llama2_lora_finetune.yaml --override \
model_checkpoint='/data/users/ebs/checkpoints/lora-debug/llama2-7b-01242024' seed=18 \
tokenizer_checkpoint='/data/users/ebs/checkpoints/lora-debug/tokenizer.model' \
output_dir='/data/users/ebs/checkpoints/lora-debug' \
metric_logger_type='wandb' project='lora-debug' log_every_n_steps=5

Loss curve

[Screenshot: loss curve from the updated run, 2024-02-06]

Also added a unit test for the state dict load validation utility:

python -m pytest -v tests/torchtune/modules/peft/test_peft_utils.py
...
======= 9 passed, 11 warnings in 4.97s =========

Addendum: details from the draft version of this PR

Very ugly first version of LoRA FT script.

The initial version of this script (and our full fine-tuning script) had really noisy loss curves with a pretty weak trend.
[Screenshot: noisy loss curve from the initial version, 2024-02-03]

After comparing to lit-gpt, we saw that the magnitude of the loss trend (especially relative to the amount of noise) was much smaller in our version than in theirs (sample lit-gpt loss curve below).
[Screenshot: sample lit-gpt loss curve, 2024-02-03]

Additionally, lit-gpt's LoRA fine-tuning shows a clear improvement early in fine-tuning (within the first couple hundred iterations) that our LoRA/full fine-tuning runs did not.

After a bit of investigation, we found a couple gaps between our version and theirs:

(1) LR warmup with cosine annealing (lit-gpt) or linear decay (alpaca-lora), which we don't yet support (a rough sketch of a warmup + cosine schedule follows this list).
(2) Grad accumulation (a known feature gap)
(3) Weight decay
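For (1), a warmup-plus-cosine schedule could look roughly like the sketch below (the helper name and step counts are illustrative, not lit-gpt's exact schedule):

import math
import torch

def warmup_cosine(warmup_steps: int, total_steps: int):
    # returns an LR multiplier: linear warmup, then cosine decay toward zero
    def fn(step: int) -> float:
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return fn

# usage: scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine(100, 1000))
# and call scheduler.step() once per optimizer step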

However, none of these 3 made a material difference in the loss curve discrepancy. Finally, we found the real causes:

(4) We are training on the raw alpaca dataset, but most other implementations (lit-gpt, alpaca-lora, etc) train on a cleaned version of alpaca.
(5) Our dataset transform masks everything up to the end of the response tag for our labels (ref). But many existing implementations do not do this (e.g. lit-gpt, alpaca-lora, and a related issue on llama-factory). A sketch of this masking transform follows the list.
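To make (5) concrete, here is a minimal sketch of the prompt-masking transform being described (the ignore index matches the torch.nn.CrossEntropyLoss default; the helper itself is illustrative, not torchtune's exact implementation):

from typing import List

CROSS_ENTROPY_IGNORE_IDX = -100  # ignored by torch.nn.CrossEntropyLoss by default

def mask_prompt(tokens: List[int], prompt_len: int) -> List[int]:
    # supervise only the response tokens; everything up to the end of the response tag is masked
    labels = list(tokens)
    labels[:prompt_len] = [CROSS_ENTROPY_IGNORE_IDX] * prompt_len
    return labels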

After changing all of the above, we can run the following command.

with-proxy torchrun --nnodes 1 --nproc_per_node 8 recipes/lora_finetune.py --config recipes/configs/alpaca_llama2_lora_finetune.yaml --model-checkpoint /data/users/ebs/checkpoints/lora-debug/llama2-7b-01242024 --seed 18 --tokenizer-checkpoint /data/users/ebs/checkpoints/lora-debug/tokenizer.model --output-dir /data/users/ebs/checkpoints/lora-debug

So far, our loss curve looks like this:

[Screenshot: loss curve after the above changes, 2024-02-03]

Note that plot scales, log frequency, and batch size are not 100% the same across the above figures, but hopefully the proximity of this new loss curve to the lit-gpt one (especially relative to our previous loss curve) is evident. Eventually we can run some proper evals, but haven't gotten that far yet.

facebook-github-bot added the "CLA Signed" label (authors need to sign the CLA before a PR can be reviewed) on Jan 30, 2024
netlify bot commented on Jan 30, 2024

Deploy Preview for torchtune-preview failed.

Latest commit: ecab693
Latest deploy log: https://app.netlify.com/sites/torchtune-preview/deploys/65c2c56abe99780008c67cdf

@@ -25,6 +25,31 @@
# Modules from CausalSelfAttention that LoRA can be applied to
LORA_ATTN_MODULES = Literal["q_proj", "k_proj", "v_proj", "output_proj"]

def lora_llama2_7b(
lora_attn_modules: List[LORA_ATTN_MODULES],
Member:

shall we set a default here?

Contributor Author (ebsmothers):

I discussed this with @joecummings as well in the original PR where we added this. I don't think we should set a default list, I was thinking of just having the default be None, but ultimately settled on making this mandatory. Btw another path around this is to have bools for each individual param, but I don't love that either

_set_trainable_params(model)

# TODO: move this somewhere else
def lora_custom_auto_wrap_policy(
Contributor:

Call me crazy, but I like the function here. Two reasons:

  • Clearly this is specific to this particular model (for now). We can think of generalizing it into a util later.
  • Can't think of a reason why this should be visible to the entire class

Member:

you're crazy! :D

Contributor Author (ebsmothers):

I wanna bump this debate so we can decide where to put this. @rohan-varma I know you mentioned that having a policies folder/file elsewhere is one path. Personally I do not love this, and if we are gonna move it I would rather keep it near the model definition (e.g. torchtune/models/lora_llama2.py) where it's easier to find. Thoughts? @kartikayk

Contributor Author (ebsmothers):

Not to do a complete 180, but on the flip side one advantage of @rohan-varma's suggestion to use a separate policies file/folder is that we can maintain a mapping model -> FSDP policy there, which is definitely cleaner on the config side of things (then I can just pass my model_name and infer the appropriate FSDP wrapping if needed). Anyways I am gonna leave it here for the time being with a TODO

):
    if isinstance(module, modules.TransformerDecoderLayer):
        return True
    if hasattr(module, 'weight') and module.weight.requires_grad == True:
Member:

curious, why are we doing this sort of wrapping?

Member:

Oh I guess this works because only the lora_a and lora_b will be trainable / have requires_grad set to True the way we're setting things up. But this sort of relies on a good amount of assumptions and ordering of when we're wrapping with FSDP versus not. Also, the wrapping will change if users play around with what parts of the network are trainable or not, which is IMO an unexpected side effect.

Proposal for the wrapping function:

def lora_wrap(module, recurse, **kwargs):
    if recurse: return True
    if isinstance(module, TransformerDecoderLayer):
        return True
    if hasattr(module, '_lora_module'):
        return True

and then we mark the LoRA module (i.e. https://github.com/pytorch-labs/torchtune/blob/6fb9fbc4441a5057710f84a8bbb3218053ca742a/torchtune/modules/peft/lora.py#L59) with the _lora_module attribute.
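For context, a custom callable passed as FSDP's auto_wrap_policy receives (module, recurse, nonwrapped_numel) and returns whether to wrap; wiring up the proposal might look roughly like this sketch (illustrative, not the exact torchtune code):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torchtune import modules

def lora_wrap_policy(module, recurse, nonwrapped_numel) -> bool:
    if recurse:
        return True  # keep traversing into children
    # wrap transformer layers plus anything explicitly tagged as a LoRA module
    return isinstance(module, modules.TransformerDecoderLayer) or hasattr(module, "_lora_module")

# usage, assuming each LoRALinear instance was tagged with `_lora_module = True` at construction:
# model = FSDP(model, auto_wrap_policy=lora_wrap_policy)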

Contributor Author (ebsmothers):

So this is why I was hoping to get the adapter_params method working on modules as well. Then we have a single point of definition for (a) checkpoint save/load, (b) trainable params, (c) FSDP wrap logic. Otherwise we end up with a single property spread out across multiple places, which I don't love.

# TODO: move this somewhere else
def lora_custom_auto_wrap_policy(
module: nn.Module,
recurse: bool,
Member:

nit: since these are not used, just make it *args / **kwargs?

# TODO: move this somewhere else
def lora_custom_auto_wrap_policy(
module: nn.Module,
recurse: bool,
Member:

Oh, this recurse flag needs to be checked, i.e. make sure to have an if recurse: return True to tell FSDP's wrapping function to recurse.

_set_trainable_params(model)

# TODO: move this somewhere else
def lora_custom_auto_wrap_policy(
Member:

I would generalize this into a policies folder / file where we can have training specific FSDP wrapping policies.

Contributor Author (ebsmothers):

I will let you and @kartikayk duke it out over that one 😄

return getattr(torch.optim, optimizer)(model.parameters(), lr=lr)
trainable_params = [p for n, p in model.named_parameters() if p.requires_grad]
return getattr(torch.optim, optimizer)(
trainable_params, lr=lr, weight_decay=weight_decay
Contributor Author (ebsmothers):

Note: many, but not all, optimizers in torch.optim support the weight_decay param. Based on this we could also consider going the kwargs route here
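A possible kwargs-based variant is sketched below (the signature is illustrative; only kwargs accepted by the chosen optimizer should be forwarded):

import torch
from torch.optim import Optimizer

def get_optimizer(optimizer: str, model: torch.nn.Module, lr: float, **optim_kwargs) -> Optimizer:
    # only pass trainable params, and forward optional args like weight_decay when provided
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    return getattr(torch.optim, optimizer)(trainable_params, lr=lr, **optim_kwargs)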

ebsmothers changed the title from "[DRAFT] first version of LoRA FT (very hacky)" to "LoRA fine-tuning recipe" on Feb 6, 2024
ebsmothers marked this pull request as ready for review on February 6, 2024 at 05:09
# If we're resuming from checkpoint, the recipe's state should be updated before
# initializing the training components. This ensures that the seed is correctly
# propagated to the relevant components
if self._resume_from_checkpoint:
Member:

The TODO mentions that seed and epoch are not checkpointed; if we support resume in this recipe, not having this could lead to silent incorrectness. Shall we disable the resume_from_checkpoint flag until we've implemented this?

Contributor Author (ebsmothers):

I am gonna do some fast follow on checkpoint save regardless (just to save LoRA params). I think it is straightforward to add seed and epoch checkpointing in this PR so will do that. Personally I would bias towards just leaving resume_from_checkpoint support in to avoid thrash given that I will test it thoroughly when making the checkpoint save changes (prob later today). Lmk what you think

Member:

Okay with this as long as we do it asap later today / tomorrow


self._model = self._setup_model(
model=params.model,
lora_attn_modules=params.lora_attn_modules,
Member:

Do we envision this list of args getting too long? Seems fine for now but wanted to put it on your radar.

Contributor Author (ebsmothers):

I agree. There is def a tradeoff between configurability and code readability. I def wouldn't add more args here. Another thing we can do is separate out the construction of the basic LoRA model class from wrapping with fsdp, activation checkpointing, and state dict loading.


# training setup
self._autocast = utils.get_autocast(self._dtype, self._device)
self._grad_scaler = None
Member:

grad_scaler always gets defined in the if/else block, so is this not needed?

# for logging and tracking training state. This should be computed after the dataloader
# has been setup
steps_per_epoch = len(self._dataloader)
if self.max_steps_per_epoch and self.max_steps_per_epoch < len(
Member:

if self.max_steps_per_epoch -> if self.max_steps_per_epoch is not None to be more explicit (and support the 0 case)

Contributor Author (ebsmothers):

When would self.max_steps_per_epoch be 0? Either way I am good with this change

Member:

I have done it a few times just to immediately take a checkpoint

Resolved review threads: recipes/params.py (two threads), torchtune/models/lora_llama2.py
@@ -104,6 +104,7 @@ def wrap_fsdp(
dtype: torch.dtype,
strategy: Optional[str] = None,
auto_wrap_policy: Optional[Set[Type]] = None,
custom_wrap_policy: Optional[Callable[[nn.Module, bool, int], bool]] = None,
Member:

both custom_wrap_policy and auto_wrap_policy can get confusing, maybe less so if we clearly document which one takes precedence.

I would just have one arg: custom_wrap_policy that takes in a callable or set of nn.Modules.

Contributor:

I'm already confused.

@@ -156,6 +160,8 @@ def wrap_fsdp(
strategy = "FULL_SHARD"
_validate_device_from_env(device)
wrap_policy = ModuleWrapPolicy(auto_wrap_policy or set())
if custom_wrap_policy:
Member:

This overriding logic should be documented if we do go with 2 policies in the function (not a fan of that though)

Contributor Author (ebsmothers):

I guess if we go with the util you added in #317 we don't need to modify this method, is that right? If that's the case I am good with that approach

Member:

Yep!

@@ -24,7 +27,10 @@ def get_optimizer(optimizer: str, model: torch.nn.Module, lr: float) -> Optimize
ValueError: if the optimizer is not a valid optimizer from torch.optim.
"""
try:
return getattr(torch.optim, optimizer)(model.parameters(), lr=lr)
trainable_params = [p for n, p in model.named_parameters() if p.requires_grad]
Contributor:

I assume this is LoRA specific?

Contributor Author (ebsmothers):

In general we do not wanna optimize any parameters that do not require grad. So I think this is general enough to put in the get_optimizer method

self._metric_logger.log_dict(
    {
        "loss": loss.item(),
        "lr": self._optimizer.param_groups[0]["lr"],
Contributor:

Did you test that this outputs the correct LR? Not sure that it will.

Contributor Author (ebsmothers):

Why do you say that? Attaching a sample from one of my recent runs

[Screenshot: logged learning rate from a recent run, 2024-02-06]

rohan-varma self-requested a review on February 6, 2024 at 23:21
rohan-varma (Member) left a review comment:

LG overall, some small comments, will stamp after that!

Args:
model (str): String specifying model architecture to fine-tune. See ``torchtune.models.get_model`` for options.
model_checkpoint (str): Local path to load model checkpoint from.
lora_attn_modules (List[str]): List of attention modules to use for LoRA.
Member:

Would change this to "List describing which attention modules to apply LoRA to", and also add examples / the supported strings.
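Something along these lines (the supported strings come from LORA_ATTN_MODULES above; the exact wording is only a suggestion):

lora_attn_modules (List[str]): List describing which attention modules to apply LoRA to.
    Supported values are "q_proj", "k_proj", "v_proj", and "output_proj",
    e.g. lora_attn_modules=["q_proj", "v_proj"].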


@dataclass
class LoRAFinetuneParams:
"""Arguments for the finetune_lora recipe.
Member:

Add a caveat somewhere that mentions applying LoRA to MLP is not supported, and file issue for this?


@@ -92,3 +91,30 @@ def set_trainable_params(model: nn.Module, adapter_params: Dict[str, Any]) -> No
"""
for k, v in model.named_parameters():
v.requires_grad_(k in adapter_params)


def validate_state_dict_for_lora(
Member:

I'd eventually want us to update this to be a full FQN check to guard against the case where there is a partial match but the state_dict is not valid (I know I'm being kind of paranoid). Can file this as a follow-up issue.

"""
for x in missing_keys:
if not any([k in x for k in lora_modules]):
raise AssertionError(f"{k} is not a LoRA module {lora_modules}")
Member:

isn't it x that is not a lora module? k in this case is in lora_modules?

@@ -92,3 +91,30 @@ def set_trainable_params(model: nn.Module, adapter_params: Dict[str, Any]) -> No
"""
for k, v in model.named_parameters():
v.requires_grad_(k in adapter_params)


def validate_state_dict_for_lora(
Member:

Add a unittest for this function

for x in missing_keys:
    if not any([k in x for k in lora_modules]):
        raise AssertionError(f"{k} is not a LoRA module {lora_modules}")
    if unexpected_keys:
Member:

This can be moved outside of the loop
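Putting the two comments above together, the corrected block might look like this sketch (it reports the offending key x and checks unexpected_keys once, outside the loop; the second error message is illustrative):

for x in missing_keys:
    if not any(k in x for k in lora_modules):
        raise AssertionError(f"{x} is not a LoRA module {lora_modules}")
if unexpected_keys:
    raise AssertionError(f"Unexpected keys in state dict: {unexpected_keys}")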

rohan-varma (Member) left a review comment:

LGTM, thanks so much!

@@ -92,3 +91,30 @@ def set_trainable_params(model: nn.Module, adapter_params: Dict[str, Any]) -> No
"""
for k, v in model.named_parameters():
v.requires_grad_(k in adapter_params)


def validate_state_dict_for_lora(
Member:

TODO: have to add this to docs?

ebsmothers merged commit 7898961 into pytorch:main on Feb 6, 2024
11 of 15 checks passed
ebsmothers deleted the lora-ft-recipe branch on February 7, 2024 at 00:02