
adding model builders for code-llama2 7b, 13b, and 70b #847

Merged: 15 commits merged into pytorch:main on Apr 26, 2024

Conversation

@SalmanMohammadi (Contributor) commented Apr 23, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

See #826

Changelog

Added model builders for code-llama2 7b, 13b, and 70b, based on the base llama2 params with code-llama2's extended vocab size and sequence length. Capitalised all instances of 'b' (as in '7b') in torchtune/models/llama2/_model_builders.py.
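For illustration, a minimal sketch of what one of these builders looks like. The hyperparameter values and import paths here are assumptions based on Code-Llama2 7B; the merged torchtune model files are the source of truth.

```python
# Minimal sketch of a Code-Llama2 7B model builder: reuse the Llama2
# architecture, swapping in Code-Llama2's larger vocab and longer context.
# Values and import paths are illustrative, not copied from the merged PR.
from torchtune.models.llama2._component_builders import llama2
from torchtune.modules import TransformerDecoder


def code_llama2_7b() -> TransformerDecoder:
    return llama2(
        vocab_size=32_016,   # Code-Llama2 adds special infilling tokens
        num_layers=32,
        num_heads=32,
        num_kv_heads=32,
        embed_dim=4096,
        max_seq_len=16_384,  # Code-Llama2 is trained with a 16k context
        attn_dropout=0.0,
        norm_eps=1e-5,
    )
```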

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help.)

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

pytorch-bot (bot) commented Apr 23, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/847

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 98ef8b8 with merge base bec7bab:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Apr 23, 2024
@ebsmothers (Contributor) left a comment

Overall this looks good, thanks for picking it up! One request: can you run a couple of our fine-tune scripts as a sanity check and share the results on the PR? Mainly just want to make sure that the model gets loaded in correctly, loss is decreasing, etc.

num_heads=40,
num_kv_heads=40,
embed_dim=5120,
intermediate_dim=13824,
Contributor:

nit: I think you don't need this (the math for scale_hidden_dim_for_mlp should work). Fine to keep it in just to be explicit though (honestly I am leaning towards that approach more and more cause I'm sick of all these integer-rounded calculations 😅 )
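For reference, a quick check of that math, written as a sketch of the usual Llama-style hidden-dim rounding that scale_hidden_dim_for_mlp is assumed to perform (the actual torchtune helper may differ slightly):

```python
# Sketch: take 8/3 * embed_dim (integer-truncated) and round up to a multiple of 256.
def scale_hidden_dim_for_mlp(dim: int, multiple_of: int = 256) -> int:
    hidden_dim = 4 * int(2 * dim / 3)
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)


# For the 13B embed_dim above this reproduces the explicit value:
assert scale_hidden_dim_for_mlp(5120) == 13824
```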

@rohan-varma (Member) left a comment

@ebsmothers, what's our overall approach for continuing to guarantee correctness here? Ideally we'd have tests similar to those for the Llama2 base models, where we verify equivalence against a reference implementation using a loaded checkpoint. Or are we confident enough given that the building blocks are sufficiently tested?

@SalmanMohammadi (Contributor, Author) commented Apr 23, 2024

Full fine tuning using the low memory config runs fine in colab. See the wandb run here. I'll let it run for ~30 minutes for now, unless you need information from later in training.

I can't test the LoRA fine-tuning since it hasn't been implemented yet. I can try giving full fine-tuning (without low_memory) a go if I can fit it on the GPU there.

One thing that springs to mind for testing @rohan-varma could be ensuring specific reference weights load correctly for models like these (e.g. codellama/CodeLlama-7b-Instruct-hf). In my example, the weights were in a slightly unexpected format for tune run and I manually specified the checkpoints.

A quick note: bitsandbytes wasn't installed by default for running with the low_memory config. I saw in a previous issue you were debating including it in requirements.

@rohan-varma (Member) commented:

@SalmanMohammadi I can't open the colab notebook you've shared (error includes Ask the notebook's author to reshare the notebook with download permissions enabled and try loading it again), do you mind checking that on your end? Other than that, great to see that we're able to run on colab!

@SalmanMohammadi (Contributor, Author) commented Apr 23, 2024

Try now? The wandb link should work too.
It was pretty straightforward! Unfortunately, none of the models can fit on the free GPU since it doesn't support bf16 - but otherwise super neat.

@ebsmothers (Contributor) commented:

Lots of things to reply to here 😃

what's our overall approach for continuing to guarantee correctness here?

@rohan-varma this is a good question. In this case my mental model is that this is just a variant of an existing model (Llama2) and so the individual components should already be well-tested. Some E2E test is still helpful to verify that (a) checkpoints load correctly and (b) no regressions due to other changes (e.g. I think the tokenizer has a slightly different vocab size). Admittedly this is subjective and depends on the level of granularity we define models at though (e.g. most of our models are instances of TransformerDecoder but I don't think it's sufficient to just have one test for that class and claim all new models we add are covered).
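As a concrete illustration, the kind of E2E parity check being floated might look roughly like this. This is a sketch only; `reference_model` and the import path are assumptions, not code from this PR.

```python
import torch
from torchtune.models.code_llama2 import code_llama2_7b  # path assumed post-refactor


@torch.inference_mode()
def check_forward_parity(state_dict, reference_model, seq_len: int = 128):
    # Load the same converted checkpoint into the torchtune builder and a
    # reference implementation, then compare logits on random token ids.
    model = code_llama2_7b()
    model.load_state_dict(state_dict)
    model.eval()
    tokens = torch.randint(0, 32_016, (1, seq_len))
    torch.testing.assert_close(
        model(tokens), reference_model(tokens), atol=1e-4, rtol=1e-4
    )
```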

I can't test the lora fine tuning since it hasn't been implemented yet.

@SalmanMohammadi what does this mean? Can't you just plug the code-llama2 models into our existing LoRA recipes? (Apologies if I'm missing something obvious here though)

In my example, the weights were in a slightly unexpected format for tune run and I manually specified the checkpoints.

Can you elaborate on this? Did you need to make any changes to the checkpoints themselves or just the file paths? (Feel free to just paste your CLI command or config file if that's easiest)

Unfortunately, none of the models can fit in the free GPU since bf16 isn't supported on the free GPU - but otherwise super neat.

Actually we can do QLoRA for Llama2-7B now in the free tier (though we do still OOM on checkpoint save). We also have smaller Gemma models; I haven't tested them myself, but I think these should be OK in fp32 on the free tier too.

@SalmanMohammadi (Contributor, Author) commented Apr 24, 2024

Can't you just plug the code-llama2 models into our existing LoRA recipes? (Apologies if I'm missing something obvious here though)

Sorry, by "not implemented" I just mean that the QLoRA and LoRA recipes use the llama2.lora_llama2_ and llama2.qlora_llama2_ model builders, and I need to create llama2.lora_code_llama2_ etc. I'll add that now and run it on colab.
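For the record, a rough sketch of the shape these builders take, modeled on the existing lora_llama2_7b; the signature, defaults, and values here are illustrative, and the exact definitions live in the merged code.

```python
# Rough sketch of a LoRA builder for Code-Llama2 7B, mirroring lora_llama2_7b
# but with Code-Llama2's vocab size and max sequence length. Illustrative only.
from functools import partial
from typing import List

from torchtune.models.llama2._component_builders import lora_llama2
from torchtune.modules import TransformerDecoder


def lora_code_llama2_7b(
    lora_attn_modules: List[str],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    lora_rank: int = 8,
    lora_alpha: float = 16,
    quantize_base: bool = False,
) -> TransformerDecoder:
    return lora_llama2(
        lora_attn_modules=lora_attn_modules,
        apply_lora_to_mlp=apply_lora_to_mlp,
        apply_lora_to_output=apply_lora_to_output,
        vocab_size=32_016,
        num_layers=32,
        num_heads=32,
        num_kv_heads=32,
        embed_dim=4096,
        max_seq_len=16_384,
        attn_dropout=0.0,
        norm_eps=1e-5,
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=0.05,
        quantize_base=quantize_base,
    )


# QLoRA variant reuses the LoRA builder with the base weights quantized,
# matching the partial() pattern used for the Llama2 builders.
qlora_code_llama2_7b = partial(lora_code_llama2_7b, quantize_base=True)
```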

Can you elaborate on this? Did you need to make any changes to the checkpoints themselves or just the file paths? (Feel free to just paste your CLI command or config file if that's easiest)

Just the files! My CLI command was:

!tune run full_finetune_single_device \
--config llama2/7B_full_low_memory \
checkpointer.checkpoint_dir=/tmp/CodeLlama-7b-Instruct-hf \
checkpointer.checkpoint_files=['pytorch_model-00001-of-00003.bin','pytorch_model-00002-of-00003.bin','pytorch_model-00003-of-00003.bin'] \
tokenizer.path=/tmp/CodeLlama-7b-Instruct-hf/tokenizer.model \
metric_logger=torchtune.utils.metric_logging.WandBLogger \
metric_logger.project=torchtune_codellama_testing \
model=torchtune.models.llama2.code_llama2_7b

This is because the model checkpoints I'd grabbed were in this format:

!ls /tmp/CodeLlama-7b-Instruct-hf
...
pytorch_model-00001-of-00003.bin
pytorch_model-00002-of-00003.bin
pytorch_model-00003-of-00003.bin

But tune was expecting pytorch_model-00001-of-00002.bin for the first.

@SalmanMohammadi (Contributor, Author) commented Apr 24, 2024

I've added lora_ and qlora_ code_llama2_{}b models, and also added a qlora_llama2_70b while I was at it. torchtune/models/llama2/_model_builders.py is getting pretty chunky. Do you care about this / would you want to put the LoRA model builders or the code_llama2_ builders in a separate file?

I've completed the following training tests using my colab above, and added some memory usage info for reference. You can see all the runs here:

  • code_llama2_7b with full_finetune_single_device with _low_memory - wandb run. Peak memory usage 14.5GB
  • lora_code_llama2_7b with lora_finetune_single_device - wandb run. Peak memory usage 14.1GB.
  • qlora_code_llama2_13b with lora_finetune_single_device - wandb run. Peak memory usage 9.4GB.

Let me know if there's anything else I can do :)

@kartikayk (Contributor) left a comment

@SalmanMohammadi this is just an all-round awesome PR! I think the point about the model_builder file becoming unwieldy is very accurate. Do you mind just splitting out codellama into its own top-level folder like llama2 and llama3? I think this will address both the problem at hand and make these models more discoverable. Let me know if that makes sense to you. Overall this is an awesome add to the repo!

@SalmanMohammadi (Contributor, Author) commented Apr 25, 2024

Thanks so much for the kind feedback @kartikayk :) I've always wanted to contribute to the pytorch ecosystem - it's really nice to get the opportunity to work with such a welcoming open-source community.

Sorry for so many commits. Lots of different components, lots of docs, and I couldn't test locally.

I've updated with the refactor. I also added recipe configs and registered them in torchtune/_recipe_registry.py for ease of discovery; confirmed tests/test_import_recipes.py runs OK. I took the liberty of updating the README.md model support table, I hope you don't mind :) Hopefully this helps people get started quickly fine-tuning code-llama-2 models. tune ls now outputs:

RECIPE                                   CONFIG                                  
full_finetune_single_device              ...               
                                         code_llama2/7B_full_low_memory          
                                         ...                                     
lora_finetune_single_device              ...           
                                         code_llama2/7B_lora_single_device       
                                         code_llama2/7B_qlora_single_device      
                                         ...          

I've confirmed all the recipes I added work nicely on my colab without any additional config specifications:

tune download codellama/CodeLlama-7b-hf --output-dir /tmp/CodeLlama-7b-hf
tune run full_finetune_single_device --config code_llama2/7B_full_low_memory 
tune run lora_finetune_single_device --config code_llama2/7B_lora_single_device
tune run lora_finetune_single_device --config code_llama2/7B_qlora_single_device

Note: My initial colab tests in #847 (comment) were for Code-Llama2-7b-Instruct, but I've generalized the recipes to just Code-Llama2-7b. It's hopefully trivial for users to use the instruct models instead.

lora_code_llama2_70b,
lora_code_llama2_7b,
qlora_code_llama2_13b,
qlora_code_llama2_70b,
Contributor:

I don't know that we should add any QLoRA + 70B models yet, since it will require FSDP. And as of right now the combination of these two is not yet well-supported. Cc @rohan-varma for thoughts here

quantize_base=quantize_base,
)


qlora_llama2_70b = partial(lora_llama2_70b, quantize_base=True)
Contributor:

Similar comment here re QLoRA + 70B

@ebsmothers (Contributor) commented:

Sorry for so many commits. Lots of different components, lots of docs, and I couldn't test locally.

@SalmanMohammadi I'm curious about this comment. Is this just due to particulars of your dev setup? Mainly I am wondering if there's anything we can be doing on our end to make contribution smoother (whether it be ease of testing, clearer documentation, anything like that). If you have any feedback on this front do let me know!

@SalmanMohammadi (Contributor, Author) commented Apr 26, 2024

I updated the docs, and I've just taken out the QLoRA 70B models since they're out of scope ATM, particularly for this PR.

Is this just due to particulars of your dev setup?

I think it's partly my current dev setup being unable to test full-scale model training right now, partly learning the codebase by iterating on "trying something out on colab to see what it breaks", and partly that I could've been a bit more careful updating each of the docs and recipes. So mostly on my side!

One thing I was thinking of was adding a rough workflow for common contributions that are held to a high standard. In the case of #840, I've been thinking of updating tests/torchtune/models/llama2/scripts/README.md to add some of your comments from #840 (comment) (or other useful insights you provide when we start writing the Mistral tests).

*The commit message should read "removing qlora 70b code_llama2 and llama2 models" (some weird formatting issue).

lr_scheduler:
_component_: torchtune.modules.get_cosine_schedule_with_warmup
num_warmup_steps: 100
optimizer_in_bwd: True
Contributor:

I think optimizer_in_bwd is only for our full_finetune_single_device recipe (for LoRA the memory savings it provides are reduced quite a bit), so would remove this field. (Similar comment for the QLoRA config too)

Comment on lines 8 to 10
# The default config uses an optimizer from bitsandbytes. If you do not have it installed,
# you can install it with
# pip install bitsandbytes
Contributor:

Remove this; we don't use bitsandbytes in this config. Same comment for the QLoRA config too.

# Logging
metric_logger:
_component_: torchtune.utils.metric_logging.DiskLogger
log_dir: ${output_dir}/torchtune_perf_tracing.json
Contributor:

I think this should just be ${output_dir} here and in the QLoRA config

_component_: torchtune.utils.metric_logging.DiskLogger
log_dir: ${output_dir}
output_dir: /tmp/code_llama2_finetune
log_every_n_steps: 1
Contributor:

Also please add log_peak_memory_stats: False in these configs. It won't error out without it, but right now we do a safe check on the config inside the recipe, which we'd eventually like to remove (keeping configs as the source of truth).

"""
Builder for creating a Code-Llama2 70B model with LoRA enabled.

The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.code_llama2_13b`,
Contributor:

Suggested change
The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.code_llama2_13b`,
The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.code_llama2_70b`,

quantize_base (bool): Whether to quantize base model weights

Returns:
TransformerDecoder: Instantiation of Code-Llama2 7B model with LoRA applied
Contributor:

Suggested change
TransformerDecoder: Instantiation of Code-Llama2 7B model with LoRA applied
TransformerDecoder: Instantiation of Code-Llama2 70B model with LoRA applied

@@ -92,22 +92,23 @@ def lora_llama2_7b(
norm_eps=1e-5,
lora_rank=lora_rank,
lora_alpha=lora_alpha,
lora_dropout=0.05,
lora_dropout=lora_dropout,
Contributor:

Yikes, good catch!

@ebsmothers (Contributor) commented:

OK just a few more small comments. Home stretch here! After those are addressed I think this is good to merge

@SalmanMohammadi (Contributor, Author) commented:

OK just a few more small comments. Home stretch here! After those are addressed I think this is good to merge

Hopefully all done! Thanks for your patience :)

@ebsmothers (Contributor) commented:

OK just a few more small comments. Home stretch here! After those are addressed I think this is good to merge

Hopefully all done! Thanks for your patience :)

Great! Just kicked off one more CI run now, once that is green I think this is good to merge.

@ebsmothers ebsmothers merged commit cc2dd05 into pytorch:main Apr 26, 2024
27 checks passed