
feat: add gemma7b support #971

Merged: 7 commits merged into pytorch:main on May 31, 2024
Conversation

@Optimox (Contributor) commented May 13, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

This PR adds support for gemma 7b: #969

Changelog

Minimal changes, simply adding gemma7b configs.

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help.)

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
    • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)


pytorch-bot bot commented May 13, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/971

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 0dc0c23 with merge base dc2b991:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the "CLA Signed" label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on May 13, 2024
@Optimox (Contributor, Author) commented May 13, 2024

[WIP]
I need to run the training on alpaca and check that everything is working, but my GPU is busy at the moment.

@joecummings should I add unit tests for this PR?

@Optimox (Contributor, Author) commented May 13, 2024

I am new to using pre-commit. When running pre-commit run --all-files I see

trim trailing whitespace.................................................Passed
check python ast.........................................................Passed
check for merge conflicts................................................Passed
don't commit to branch...................................................Passed
check for added large files..............................................Passed
fix end of files.........................................................Passed
Insert license in comments...............................................Passed
flake8...................................................................Passed
Format files with µfmt...................................................Failed

hook id: ufmt
files were modified by this hook
Formatted /mnt/datasets/mytorchtune/torchtune/torchtune/models/gemma/__init__.py
✨ 1 file formatted, 191 files already formatted ✨

Could anyone tell me what ufmt format is?

@Optimox changed the title from "feat: add gemma7b support" to "[WIP] feat: add gemma7b support" on May 13, 2024
@Optimox marked this pull request as draft on May 13, 2024 16:14
@joecummings (Contributor) commented:

@joecummings should I add unit tests for this PR?

Whoops, I keep overwriting instead of quoting and replying. Let's just start with a W&B run first.

@joecummings (Contributor) commented:

I am new to using pre-commit. When running pre-commit run --all-files I see "Format files with µfmt ... Failed" (hook id: ufmt, files were modified by this hook).

Could anyone tell me what ufmt format is?

UFMT formats your code for you. So the above "error" just means that the current code was formatted and you'll have to use git add to add the newly formatted files to your commit.

@joecummings (Contributor) commented:

Thanks for hopping on this so quickly!!

@Optimox (Contributor, Author) commented May 14, 2024

@joecummings I have a working version. Things have been slightly more complicated than expected because there was a silent bug in the gemma architecture.

I ran the qlora single-GPU pipeline for 1 epoch; please find the logs attached.
log_1715681624.txt

I could not run the full training pipeline because of OOM on my GPU.

Please let me know what is left to be done to check that everything works as expected!

@Optimox marked this pull request as ready for review on May 14, 2024 13:26
@Optimox changed the title from "[WIP] feat: add gemma7b support" to "feat: add gemma7b support" on May 14, 2024
@ebsmothers (Contributor) commented May 14, 2024

Hi @Optimox, can you elaborate on the silent bug? I think @kartikayk mentioned that for Gemma 7B it may be the case that embed_dim != head_dim * num_heads. Is it related to that?

@@ -401,6 +401,7 @@ def load_checkpoint(self) -> Dict[str, Any]:
num_heads=self._config["num_attention_heads"],
num_kv_heads=self._config["num_key_value_heads"],
dim=self._config["hidden_size"],
head_dim=self._config["head_dim"],
Contributor Author:

@ebsmothers this is a "generic" silent error where head_dim was never given and always inferred to be head_dim = dim // num_heads
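
For illustration, a minimal sketch of why that inferred default breaks for the 7B variant. The hyperparameter values below follow the published Gemma configs and are assumptions on my part, not code from this diff:

def inferred_head_dim(dim: int, num_heads: int) -> int:
    # Old behaviour: head_dim was never passed explicitly and silently defaulted to this.
    return dim // num_heads

# Gemma 2B (dim=2048, 8 heads): the inferred value happens to match the real head_dim of 256.
assert inferred_head_dim(2048, 8) == 256

# Gemma 7B (dim=3072, 16 heads): the real head_dim is 256, but the default infers 192,
# so the model builds without any error while using the wrong attention shapes.
assert inferred_head_dim(3072, 16) == 192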

Contributor:

We should also confirm that this is provided in the config for all the other models we support (I think it is? just want to be sure)

Contributor Author:

You are right, I think Mistral does not have it in its config: https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
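
If some model configs indeed omit head_dim, one hypothetical way for a checkpoint loader to stay backwards-compatible would be a guarded lookup like the sketch below. This helper is illustrative only and is not part of this PR:

from typing import Any, Dict

def resolve_head_dim(config: Dict[str, Any]) -> int:
    # Hypothetical helper: prefer an explicit head_dim (present in Gemma-style configs),
    # otherwise fall back to the usual inferred value (Mistral-style configs omit the key).
    if "head_dim" in config:
        return config["head_dim"]
    return config["hidden_size"] // config["num_attention_heads"]

# Gemma 7B-style config: head_dim is given explicitly.
print(resolve_head_dim({"hidden_size": 3072, "num_attention_heads": 16, "head_dim": 256}))  # 256
# Mistral 7B-style config: head_dim is inferred (4096 // 32 == 128).
print(resolve_head_dim({"hidden_size": 4096, "num_attention_heads": 32}))  # 128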

embed_dim,
rank=lora_rank,
alpha=lora_alpha,
quantize_base=quantize_base,
)
if "output_proj" in lora_modules
else nn.Linear(embed_dim, embed_dim, bias=False)
else nn.Linear(num_heads * head_dim, embed_dim, bias=False)
Contributor Author:

@ebsmothers here there was another "silent" error since embed_dim = num_heads * head_dim for gemma2b but this is not true for gemma7b.
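
To make the shape mismatch concrete, a small sketch, again using the published Gemma hyperparameters as assumed values rather than code from this PR:

import torch.nn as nn

def attn_output_proj(embed_dim: int, num_heads: int, head_dim: int) -> nn.Linear:
    # The attention output projection maps the concatenated heads back to the embedding
    # dimension, so its input size must be num_heads * head_dim rather than embed_dim.
    return nn.Linear(num_heads * head_dim, embed_dim, bias=False)

# Gemma 2B: 8 * 256 == 2048 == embed_dim, so nn.Linear(embed_dim, embed_dim) worked by coincidence.
print(attn_output_proj(2048, 8, 256))   # Linear(in_features=2048, out_features=2048, bias=False)

# Gemma 7B: 16 * 256 == 4096 != 3072, so the old assumption silently produced the wrong shape.
print(attn_output_proj(3072, 16, 256))  # Linear(in_features=4096, out_features=3072, bias=False)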

@Optimox (Contributor, Author) commented May 16, 2024

@joecummings @ebsmothers would either of you be willing to review this PR? Let me know if I need to add something! Thanks!

@joecummings (Contributor) commented:

@joecummings @ebsmothers would either of you be willing to review this PR? Let me know if I need to add something! Thanks!

Yep, looking today

@1Krypt0 commented May 21, 2024

Hey, is there any update on this? It would be great if this was added soon.

@kartikayk (Contributor) commented:

@1Krypt0 I'll let @joecummings review the PR, but I was generally curious about why you need Gemma 7B. Looking at benchmarks, it seems like Mistral 7B and Llama3 are very competitive and have better community support (inference etc.), so I was curious about the use of Gemma 7B. Would you be able to say more about your use case? Benchmarks are definitely not comprehensive, so it would be nice to learn a bit more about the kinds of use cases where Gemma shines.

@1Krypt0 commented May 23, 2024

@1Krypt0 I'll let @joecummings review the PR, but I was generally curious about why you need Gemma 7B. Looking at benchmarks, it seems like Mistral 7B and Llama3 are very competitive and have better community support (inference etc.), so I was curious about the use of Gemma 7B. Would you be able to say more about your use case? Benchmarks are definitely not comprehensive, so it would be nice to learn a bit more about the kinds of use cases where Gemma shines.

@kartikayk Yeah, of course! It's for my Master's dissertation: I am assessing the capability of multiple models on "large" document summarization (around the 8k-token limit). It just so happens that 3 of the models I would like to test are Mistral-7B, Llama3-8B, and Gemma-7B, both in their "base" form and fine-tuned, to see if there is any improvement. They seem to be some of the more popular and capable "small" models, and they are at the limit of what my GPU can handle as well.

@kartikayk (Contributor) commented:

@1Krypt0 sounds like a very interesting dissertation! We'd love to learn more when you have some results :) In the meantime, we'll review ASAP.

@ebsmothers (Contributor) left a comment:

Apologies for the delay in getting to this one. The changes look good overall; just a few comments here and there. Are you able to run fine-tunes for these configs to confirm that the loss curves look reasonable? It'd be good to make sure that things run end-to-end and that the values in the configs are sensible. Please let me know if you need any assistance there.

@@ -0,0 +1,92 @@
# Config for multi-device QLoRA finetuning in lora_finetune_single_device.py
Contributor:

I think you need to change the filename of this one

Builder for creating a Gemma model with QLoRA enabled. Base model weights in linear layers
that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314.
Please see `lora_gemma_7b` for full API arguments.
"""
Contributor:

You may need to install and run pre-commit if you haven't (the missing newline at the end of the file makes me suspect that). Ref

Contributor Author:

I've installed it; I just manually reran pre-commit run --all-files and everything is green on my side...

Contributor:

Oh weird. Our pre-commit config should enforce exactly one newline at the end of each file (see here, here). But yeah it's green on CI too. So either there's some bug in our pre-commit config or it's a misunderstanding on my part (hopefully the latter). Either way, nothing for you to worry about on this PR.


# this run:
# tune download google/gemma-7b --hf-token <HF_TOKEN> --output-dir /tmp/gemma --ignore-patterns ""
#
# To launch on 4 devices, run the following command from root:
Contributor:

This doesn't look right; I think QLoRA should be single-device.

# Model Arguments
model:
_component_: torchtune.models.gemma.qlora_gemma_7b
lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
Contributor:

Any reason we're not also applying to output_proj here? I think this is what we do in many of our other QLoRA configs. (I notice we're not doing it in Gemma 2B though, so if your config is just based on that it's reasonable. Just wanna make sure we're giving good defaults here)

Contributor Author:

Actually I did not give it much thought but followed the proposed config for gemma2b. I don't see why the out_proj would be left out; do you know if there is a good rationale behind it? Or should I add it back to both gemma7b and gemma2b?

Contributor:

Yeah I'm actually not sure why it's not included in the 2B config, cc @solitude-alive who added that one and may have more insights here. Would you be interested in running some experiments here to determine which version performs better (at least for 7B)? I think we should try to provide some loss curves as part of the testing plan regardless (just to ensure things are working as expected). I'm happy to provide help with this if you need it, just let me know.
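
For reference, a rough sketch of the two variants being compared, expressed as calls to the builder this config points at. The keyword names and the LoRA hyperparameter values below are assumptions on my part, not values taken from the config:

from torchtune.models.gemma import qlora_gemma_7b

# Variant currently in the config: LoRA applied to the q/k/v projections only.
# lora_rank / lora_alpha are placeholder values, not the config's defaults.
model_without_output_proj = qlora_gemma_7b(
    lora_attn_modules=["q_proj", "k_proj", "v_proj"],
    lora_rank=64,
    lora_alpha=16,
)

# Candidate variant: also adapt the attention output projection.
model_with_output_proj = qlora_gemma_7b(
    lora_attn_modules=["q_proj", "k_proj", "v_proj", "output_proj"],
    lora_rank=64,
    lora_alpha=16,
)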

Contributor:

Actually, I had just consulted the settings provided under Mistral qlora at that moment, without applying to output_proj.

Contributor Author:

@ebsmothers my GPU is busy at the moment, so I'm not able to launch the comparison between with/without out_proj QLoRA. Any chance you could run it on your side?

Contributor:

Sure, I can do that in a bit. Do you also have loss curves for the other runs already? If so, can you paste them in the PR summary as a reference?

Contributor Author:

I ran the qlora pipeline for 1 epoch; I shared the logs here.

torchtune/_recipe_registry.py (conversation resolved)
Contributor:

Extra file?

recipes/configs/gemma/7B_full.yaml (outdated, resolved)
recipes/configs/gemma/2B_qlora_single_device copy.yaml (outdated, resolved)
docs/source/api_ref_models.rst (outdated, resolved)
@Optimox (Contributor, Author) commented May 29, 2024

@ebsmothers I've just pushed a new version which takes all your points into consideration. I think the only thing missing is running the other pipelines (I only ran the QLoRA single-device one).

@kartikayk (Contributor) left a comment:

Thanks for adding this change, @Optimox! Pretty awesome add to the repo :)

@ebsmothers (Contributor) left a comment:

One more request: can you update all the paths in the download commands to /tmp/gemma-7b instead of /tmp/gemma? And can you update the tokenizer path in the QLoRA config similarly? Otherwise the commands don't work out of the box. After that and green CI I will merge.

Thanks so much for contributing this, and thanks for your patience during the review process!

@Optimox (Contributor, Author) commented May 31, 2024

@ebsmothers I updated the download command as asked! Thank you for your careful review!

# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
path: /tmp/gemma/tokenizer.model
Contributor:

Last one

@ebsmothers merged commit 135cf2e into pytorch:main on May 31, 2024
29 checks passed
@Optimox (Contributor, Author) commented Jun 1, 2024

Thank you @ebsmothers for the last commit!

weifengpy pushed a commit to weifengpy/torchtune that referenced this pull request Jun 4, 2024
Co-authored-by: Evan Smothers <ebs@fb.com>
maximegmd pushed a commit to maximegmd/torchtune that referenced this pull request Jul 13, 2024
Co-authored-by: Evan Smothers <ebs@fb.com>
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
7 participants