feat: add gemma7b support #971
Conversation
Helpful links: artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/971.
No failures as of commit 0dc0c23 with merge base dc2b991. This comment was automatically generated by Dr. CI and updates every 15 minutes.
[WIP] @joecummings should I add unit tests for this PR?
Force-pushed from a4b6867 to 5c7454c (compare).
I am new to using pre-commit. When running pre-commit run --all-files I see output like "trim trailing whitespace ... Passed" followed by "hook id: ufmt". Could anyone tell me what ufmt is?
Whoops, I keep overwriting instead of quoting and replying. Let's just start with the W&B run first.
UFMT formats your code for you. So the above "error" just means that the current code was reformatted and you'll have to commit the reformatted files.
Thanks for hopping on this so quickly!!
@joecummings I have a working version. Things have been slightly more complicated than expected because there was a silent bug in the gemma architecture. I ran the QLoRA single-GPU pipeline for 1 epoch; please find the logs attached. I could not run the full training pipeline because of OOM on my GPU. Please let me know what is left to be done to check that everything works as expected!
Hi @Optimox can you elaborate on the silent bug? I think @kartikayk mentioned that for Gemma 7B it may be the case that num_heads * head_dim != embed_dim.
@@ -401,6 +401,7 @@ def load_checkpoint(self) -> Dict[str, Any]:
num_heads=self._config["num_attention_heads"],
num_kv_heads=self._config["num_key_value_heads"],
dim=self._config["hidden_size"],
head_dim=self._config["head_dim"],
@ebsmothers this is a "generic" silent error: head_dim was never passed in, so it was always inferred as head_dim = dim // num_heads.
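For illustration, here is a minimal sketch of that failure mode (the function name and signature are made up for this example, not torchtune's actual builder code; the Gemma dimensions are quoted from the HF configs, so treat them as approximate):

from typing import Optional

def resolve_head_dim(dim: int, num_heads: int, head_dim: Optional[int] = None) -> int:
    # Silent fallback: correct whenever dim == num_heads * head_dim,
    # e.g. Gemma 2B (2048 // 8 == 256), but wrong for Gemma 7B, whose config
    # sets head_dim=256 while 3072 // 16 == 192.
    if head_dim is None:
        head_dim = dim // num_heads
    return head_dim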
We should also confirm that this is provided in the config for all the other models we support (I think it is? just want to be sure)
You are right, I think Mistral does not have it in its config: https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
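If so, one option (just a sketch, not the checkpointer's actual logic) is to fall back to the inferred value only when the key is absent from the HF config:

def head_dim_from_config(config: dict) -> int:
    # Prefer an explicit "head_dim" (present for Gemma), otherwise infer it
    # the old way (works for Mistral/Llama-style configs that omit the key).
    return config.get(
        "head_dim",
        config["hidden_size"] // config["num_attention_heads"],
    )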
embed_dim,
rank=lora_rank,
alpha=lora_alpha,
quantize_base=quantize_base,
)
if "output_proj" in lora_modules
else nn.Linear(embed_dim, embed_dim, bias=False)
else nn.Linear(num_heads * head_dim, embed_dim, bias=False)
@ebsmothers here there was another "silent" error, since embed_dim = num_heads * head_dim holds for Gemma 2B but not for Gemma 7B.
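To make the mismatch concrete (dimensions quoted from the Hugging Face configs, so treat them as approximate):

import torch.nn as nn

# Gemma 2B: hidden_size=2048, num_heads=8,  head_dim=256 -> 8 * 256 == 2048 == embed_dim
# Gemma 7B: hidden_size=3072, num_heads=16, head_dim=256 -> 16 * 256 == 4096 != embed_dim
embed_dim, num_heads, head_dim = 3072, 16, 256

# So the attention output projection has to map num_heads * head_dim back to embed_dim:
output_proj = nn.Linear(num_heads * head_dim, embed_dim, bias=False)  # 4096 -> 3072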
@joecummings @ebsmothers would either of you be willing to review this PR? Let me know if I need to add anything! Thanks!
Yep, looking today.
Hey, is there any update on this? It would be great if this was added soon.
@1Krypt0 I'll let @joecummings review the PR, but I was generally curious about why you need Gemma 7B. Looking at benchmarks, it seems like Mistral 7B and Llama3 are very competitive and have better community support (inference etc.), so I was curious about the use of Gemma 7B. Would you be able to say more about your use case? Benchmarks are definitely not comprehensive, so it would be nice to learn a bit more about the kinds of use cases where Gemma shines.
@kartikayk Yeah, of course! It's for my Master's dissertation: I am assessing the capability of multiple models on "large" document summarization (around the 8k-token limit). It just so happens that three of the models I would like to test are Mistral-7B, Llama3-8B, and Gemma-7B, both in their base form and fine-tuned, to see if there is any improvement. They seem to be some of the more popular and capable "small" models, and they are at the limit of what my GPU can handle as well.
@1Krypt0 sounds like a very interesting dissertation! We'd love to learn more when you have some results :) In the meantime we'll review ASAP.
Apologies for the delay in getting to this one. The changes look good overall, just a few comments here and there. Are you able to run fine-tunes for these configs to confirm that the loss curves look reasonable? It'd be good to make sure that things run end-to-end and the values in the configs are sensible. Please let me know if you need any assistance there.
@@ -0,0 +1,92 @@
# Config for multi-device QLoRA finetuning in lora_finetune_single_device.py
I think you need to change the filename of this one
Builder for creating a Gemma model with QLoRA enabled. Base model weights in linear layers
that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314.
Please see `lora_gemma_7b` for full API arguments.
"""
You may need to install and run pre-commit if you haven't (the missing newline at the end of the file makes me suspect that). Ref
I've installed it; I just reran pre-commit run --all-files manually and everything is green on my side.
Oh weird. Our pre-commit config should enforce exactly one newline at the end of each file (see here, here). But yeah it's green on CI too. So either there's some bug in our pre-commit config or it's a misunderstanding on my part (hopefully the latter). Either way, nothing for you to worry about on this PR.
# this run:
# tune download google/gemma-7b --hf-token <HF_TOKEN> --output-dir /tmp/gemma --ignore-patterns ""
#
# To launch on 4 devices, run the following command from root:
This doesn't look right, I think QLoRA should be single device
# Model Arguments
model:
  _component_: torchtune.models.gemma.qlora_gemma_7b
  lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
Any reason we're not also applying to output_proj here? I think this is what we do in many of our other QLoRA configs. (I notice we're not doing it in Gemma 2B though, so if your config is just based on that it's reasonable. Just wanna make sure we're giving good defaults here.)
Actually I did not give it much thought but followed the proposed config for gemma2b. I don't see why out_proj would be left out; do you know if there is a good rationale behind it? Or should I add it back to both gemma7b and gemma2b?
Yeah I'm actually not sure why it's not included in the 2B config, cc @solitude-alive who added that one and may have more insights here. Would you be interested in running some experiments here to determine which version performs better (at least for 7B)? I think we should try to provide some loss curves as part of the testing plan regardless (just to ensure things are working as expected). I'm happy to provide help with this if you need it, just let me know.
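For the comparison, something like the sketch below should be enough; it assumes the qlora_gemma_7b builder added in this PR accepts lora_attn_modules the same way the config above passes it, with the remaining arguments left at their defaults:

from torchtune.models.gemma import qlora_gemma_7b

# Baseline: LoRA applied to q/k/v only, as in the current config.
baseline = qlora_gemma_7b(lora_attn_modules=["q_proj", "k_proj", "v_proj"])

# Variant: also adapt the attention output projection.
with_output_proj = qlora_gemma_7b(
    lora_attn_modules=["q_proj", "k_proj", "v_proj", "output_proj"]
)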
Actually, I had just consulted the settings provided under the Mistral QLoRA config at that moment, without applying it to output_proj.
@ebsmothers my GPU is busy at the moment, so I'm not able to launch the QLoRA comparison between with and without out_proj. Any chance you could run it on your side?
Sure, I can do that in a bit. Do you already have loss curves for the other runs? If so, can you paste them in the PR summary as a reference?
I ran the QLoRA pipeline for 1 epoch; I shared the logs here.
Extra file?
Force-pushed from cbf1046 to 3de7ba9 (compare).
@ebsmothers I've just pushed a new version which takes all your points into consideration. I think the only thing missing is running the other pipelines (I only ran QLoRA single device).
Thanks for adding this change, @Optimox! Pretty awesome add to the repo :)
One more request: can you update all the paths in the download commands to /tmp/gemma-7b instead of /tmp/gemma? And can you update the tokenizer path in the QLoRA config similarly? Otherwise the commands don't work out of the box. After that and green CI I will merge.
Thanks so much for contributing this, and thanks for your patience during the review process!
@ebsmothers I updated the download command as asked! Thank you for your careful review!
# Tokenizer
tokenizer:
  _component_: torchtune.models.gemma.gemma_tokenizer
  path: /tmp/gemma/tokenizer.model
Last one: this tokenizer path still needs the same update.
Thank you @ebsmothers for the last commit!
Co-authored-by: Evan Smothers <ebs@fb.com>
Context
What is the purpose of this PR?
This PR adds support for Gemma 7B: #969
Changelog
Minimal changes, simply adding gemma7b configs.
Test plan
Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help.)
pre-commit install
pytest tests
pytest tests -m integration_test