
Generalize configs and add Llama2 13B + Mistral 7B #571

Merged: 9 commits merged into main from add_13b on Mar 24, 2024
Conversation

@kartikayk (Contributor) commented Mar 24, 2024

Context

All relevant TorchTune components (recipes, checkpointers, and the CLI) can easily be generalized to models beyond Llama2 7B. In this PR, I show that adding Llama2 13B and Mistral 7B is as simple as adding some new builder functions. The Llama2 13B support currently assumes HF-format checkpoints (see configs) since I still need to add support for Meta's sharded checkpoints; I'll do that in a follow-up PR.
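To make the "just add a builder function" claim concrete, here is a minimal sketch of what a 13B model builder could look like. This is my illustration, not the code in this PR: the llama2 component-builder import path and argument names are assumptions, while the hyperparameters (40 layers, 40 heads, embed dim 5120, FFN dim 13824, vocab 32000) are the published Llama2 13B architecture values.

# Hypothetical sketch of a 13B model builder; import path and argument
# names are assumptions, not quotes of this PR.
from torchtune.models.llama2 import llama2  # assumed component builder
from torchtune.modules import TransformerDecoder

def llama2_13b() -> TransformerDecoder:
    # Published Llama2 13B architecture values; 13B uses standard
    # multi-head attention, so kv heads == heads (no GQA).
    return llama2(
        vocab_size=32_000,
        num_layers=40,
        num_heads=40,
        num_kv_heads=40,
        embed_dim=5120,
        intermediate_dim=13_824,
        max_seq_len=4096,
        attn_dropout=0.0,
        norm_eps=1e-5,
    )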

Accompanying this addition are some cosmetic changes which make the repo more user-friendly for multiple models. This includes better organizing our configs and models, and shortening config names so they're not cumbersome to type out.

I manually tested all of the commands in the different READMEs and docstrings (except QLoRA, which might currently be impacted by the torchao-nightly change).

Training speed for 13B is quite competitive. Quick comparisons showed us to be 2.5x faster than some competitors without any change to recipe code. The next section shows the correctness checks.

Note: Mistral 7B requires some data preprocessing changes for Alpaca finetuning. I'll follow up with those changes in a separate PR. In this PR, I add support for the model and show numerical parity with HF.

Changelog

  • Add llama2_13b and lora_llama2_13b builder functions to support the Llama2 13B model (a usage sketch follows this list)
  • Move all existing configs (except alpaca_generate.yaml) from configs/ to configs/llama2/.
  • Shorten the names of the configs
  • Add some documentation to each config so users understand when to use which config
  • Make all associated changes to tests, CLI, docs and READMEs
  • Move _convert_weights out of models/llama2 and make it a public module that I expect users to use.
  • Add _component_builders.py and _model_builders.py under models/mistral to support the mistral_7b model.
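For reference, a quick usage sketch of the new builder. Again, this is my illustration rather than code from the PR; the import path is assumed from the file layout described above, and I'm assuming the decoder's forward accepts a tensor of token ids.

import torch
from torchtune.models.llama2 import llama2_13b  # assumed import path

# Instantiate the model and run a dummy forward pass; with real
# checkpoints you would load weights through the checkpointer first.
model = llama2_13b()
tokens = torch.randint(0, 32_000, (1, 8))  # toy batch of token ids
logits = model(tokens)                     # expected shape: (1, 8, 32000)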

Correctness Checks for Llama 13B

Numeric Parity of Llama2 13B with HF's implementation

[screenshot]

Eval using Eleuther's Harness on truthfulqa_mc2

Baseline: 36.9% vs Finetuned: 47.1%

[screenshots]

Loss Curve (loss is comparable to what some forums shared)

[screenshot]

Correctness Checks for Mistral 7B

[screenshot]

Test plan

  • All tests pass: pytest tests
  • 13B full-finetune on 4 devices
    [screenshot]
  • 13B LoRA on 4 devices
    tune --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config llama2/13B_lora.yaml
    [screenshot]


pytorch-bot bot commented Mar 24, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/571

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 2a06b8f with merge base 49b523c:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Mar 24, 2024. (This label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed.)

netlify bot commented Mar 24, 2024

Deploy Preview for torchtune-preview ready!

🔨 Latest commit: 2a06b8f
🔍 Latest deploy log: https://app.netlify.com/sites/torchtune-preview/deploys/6600a30c2564a800083e2206
😎 Deploy Preview: https://deploy-preview-571--torchtune-preview.netlify.app

@@ -66,7 +66,7 @@ def main(parser):
epilog=textwrap.dedent(
"""\
examples:
$ tune cp lora_finetune_distributed.yaml ./my_custom_llama2_lora.yaml
$ tune cp llama2/7B_lora.yaml ./my_custom_llama2_lora.yaml
Contributor Author

@joecummings I think appending the model name (e.g. llama2) is fine and in fact will be critical as we add more models. I tested this and it should work with our pkg structure. Let me know what you think.

@@ -6,7 +6,7 @@

Recipes are the primary entry points for TorchTune users. These can be thought of as end-to-end pipelines for training and optionally evaluating LLMs. Each recipe consists of three components:

- **Configurable parameters**, specified through yaml configs [example](https://github.com/pytorch/torchtune/blob/main/recipes/configs/full_finetune_distributed.yaml) and command-line overrides
- **Configurable parameters**, specified through yaml configs [example](https://github.com/pytorch/torchtune/blob/main/recipes/configs/7B_full.yaml) and command-line overrides
Contributor

I think this path is wrong?

Contributor Author

Oh yeah, sorry, let me fix this.

# tune --nnodes 1 --nproc_per_node 1 full_finetune_distributed \
# --config full_finetune_distributed \
# checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR> ...
# This config assumes that you've run the following command before launching
Contributor

Why add this huge block in the config?

Contributor Author

Currently there's no documentation on the configs at all. Once we have live docs available, we can add these to the docs. But for now, I'd like to give some understanding to users about when and how to use each config.

Contributor

I think there's still some information in the README on configs and we can make that more clear. I think cluttering up the configs can be overwhelming.

Contributor Author

Why would it be overwhelming? Isn't it just documentation? I don't think we would be able to add config-level info to the README?

@@ -161,7 +161,7 @@ will list out all the locations where an error was found.

.. code-block:: bash

tune validate --config recipes/configs/full_finetune_single_device.yaml batch_size=4
tune validate --config recipes/configs/llama2/7B_full.yaml batch_size=4
Contributor

One thing to watch out for: there can be lots of issues when filenames start with a number. I've definitely seen it as a problem with Python imports, maybe it will be ok with YAML files? But something to keep in mind

Contributor Author

This is interesting - what sort of issues? But yeah, I don't expect us to be importing the configs anymore?

Comment on lines +21 to +24
# This config should be used with 2+ GPUs. Single device full fine-tuning
# requires several memory optimizations which are exposed through
# 7B_full_single_device.yaml. Please update the model and checkpoints to 13B
# in that config.
Contributor

I guess implicit in our choice of naming here is that >1 device is kind of now the "default", right? While I understand that we are doing more memory optimizations in the single device recipes now, we've obviously seen that FSDP comes with its own nuances too. So I do wonder if it's now hard for someone to just come in and say "give me a simple single-device recipe to get started on"

This is also a bit weird for QLoRA imo where we currently only support single device

Contributor Author

Yeah, this is a good question. I did this primarily for two reasons:

  • With the distributed CI testing sorted out, I removed the constraint on the distributed recipes. We can now run those on single device but without the memory optimizations.
  • As we go to larger models, the single device setting will be less frequent. So when I thought about the default, distributed seemed like the more natural one.

Does this make sense?

Comment on lines 47 to 48
full_finetune_distributed.py llama2/7B_full, llama2/13B_full
lora_finetune_distributed.py llama2/7B_lora, llama2/13B_lora
Contributor

Happy to finally see more than one config per recipe. Nit: maybe we can split the configs over separate lines? E.g.

RECIPE                                           CONFIG
full_finetune_distributed.py                     llama2/7B_full
                                                 llama2/13B_full
lora_finetune_distributed                        llama2/7B_lora
                                                 llama2/13B_lora

Contributor

Yep, my vision was for this to span multiple lines, and I think that's actually how it works right now?

Contributor Author

Yup, sorry I just need to update this example.

@rohan-varma (Member) left a comment

Awesome to see this! Wonder if we have a parity check for 13b model, comparing to HF forward outputs and/or Meta model outputs?

@kartikayk (Contributor Author)

@rohan-varma yup! I don't think I'll compare with the Meta implementation since I don't have the code for sharded checkpointing in a decent state (will do that comparison in a follow-up PR). But I'll add a comparison with the HF implementation.

@kartikayk (Contributor Author)

Added numerical parity checks and e2e eval comparisons for Llama2 13B to the context section. Thanks @joecummings for the help on this!

@kartikayk changed the title from "Generalize configs and add 13B model" to "Generalize configs and add Llama2 13B + Mistral 7B" on Mar 24, 2024
@@ -60,7 +60,7 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface):
The following configs can be used to run this recipe:
>>> tune ls
RECIPE CONFIG
lora_finetune_distributed lora_finetune_distributed
Contributor

Have you actually run this command? I think it goes on multiple lines?

Contributor Author

yup!
[screenshot]

@joecummings (Contributor) left a comment

Awesome change! LGTM

):
raise RuntimeError("Full bf16 training is not supported on this hardware.")

world_size, rank = utils.get_world_size_and_rank()
Contributor Author

Now that we have our distributed tests working as expected, I'm removing this constraint.

@kartikayk merged commit b65426d into main on Mar 24, 2024
21 checks passed
@kartikayk deleted the add_13b branch on March 24, 2024 22:14
model_type: LLAMA2
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 2
batch_size: 32
Member

These config changes will make for high memory consumption for Mistral, right?
