-
Notifications
You must be signed in to change notification settings - Fork 392
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
adding model builders for code-llama2 7b, 13b, and 70b #847
Merged
ebsmothers
merged 15 commits into
pytorch:main
from
SalmanMohammadi:code-llama-model-builder
Apr 26, 2024
Merged
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
d38fec2
adding model builders for code-llama2 7b, 13b, and 70b
366ddc9
Added lora and qlora code-llama models. Added a lora and qlora base l…
SalmanMohammadi 8c11e2a
updating __init__.py
SalmanMohammadi ba0dd40
refacting code_llama2 into own folder. adding recipes for code_llama2…
SalmanMohammadi 329fed5
removing unused imports in llama2/__init__.py
SalmanMohammadi b88bade
Updating README.md, fixing mis-copied files
SalmanMohammadi 36e1754
adding missing import in torchtune/models/__init__.py
SalmanMohammadi 257b5de
adding code_llama2 recipes to _recipe_registry.py
SalmanMohammadi 7890051
Fixing bug in lora and qlora code_llama2 recipes using the wrong mode…
SalmanMohammadi 2e06912
missing profiler configs in lora_code_llama2_7b
SalmanMohammadi ff4f6f5
Fixing typos in recipe docs, removing references to instruct models, …
SalmanMohammadi 64907b4
git not picking up case change in filename...
SalmanMohammadi c8d9584
updating reference to codellama huggingface repo
SalmanMohammadi 83153e5
updating docs, removing qlora 70b and models
SalmanMohammadi 98ef8b8
Updating typos in docstrings, fixing recipe config errors
SalmanMohammadi File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# Config for single device full finetuning in full_finetune_single_device.py | ||
# using a Code-Llama2 7B model | ||
# | ||
# This config assumes that you've run the following command before launching | ||
# this run: | ||
# tune download codellama/CodeLlama-7b-hf --output-dir /tmp/CodeLlama-7b-hf | ||
# The default config uses an optimizer from bitsandbytes. If you do not have it installed, | ||
# you can install it with | ||
# pip install bitsandbytes | ||
# | ||
# To launch on a single device, run the following command from root: | ||
# tune run full_finetune_single_device --config code_llama2/7B_full_low_memory | ||
# | ||
# You can add specific overrides through the command line. For example | ||
# to override the checkpointer directory while launching training | ||
# you can run: | ||
# tune run full_finetune_single_device --config code_llama2/7B_full_low_memory checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR> | ||
# | ||
# This config works only for training on single device. | ||
|
||
|
||
# Tokenizer | ||
tokenizer: | ||
_component_: torchtune.models.llama2.llama2_tokenizer | ||
path: /tmp/CodeLlama-7b-hf/tokenizer.model | ||
|
||
# Dataset | ||
dataset: | ||
_component_: torchtune.datasets.alpaca_dataset | ||
train_on_input: True | ||
seed: null | ||
shuffle: True | ||
|
||
# Model Arguments | ||
model: | ||
_component_: torchtune.models.code_llama2.code_llama2_7b | ||
|
||
checkpointer: | ||
_component_: torchtune.utils.FullModelHFCheckpointer | ||
checkpoint_dir: /tmp/CodeLlama-7b-hf | ||
checkpoint_files: [ | ||
pytorch_model-00001-of-00003.bin, | ||
pytorch_model-00002-of-00003.bin, | ||
pytorch_model-00003-of-00003.bin | ||
] | ||
recipe_checkpoint: null | ||
output_dir: /tmp/CodeLlama-7b-hf | ||
model_type: LLAMA2 | ||
resume_from_checkpoint: False | ||
|
||
# Fine-tuning arguments | ||
batch_size: 2 | ||
epochs: 3 | ||
optimizer: | ||
_component_: bitsandbytes.optim.PagedAdamW | ||
lr: 2e-5 | ||
optimizer_in_bwd: True | ||
loss: | ||
_component_: torch.nn.CrossEntropyLoss | ||
max_steps_per_epoch: null | ||
gradient_accumulation_steps: 1 | ||
compile: False | ||
|
||
# Training environment | ||
device: cuda | ||
|
||
# Memory management | ||
enable_activation_checkpointing: True | ||
|
||
# Reduced precision | ||
dtype: bf16 | ||
|
||
# Logging | ||
metric_logger: | ||
_component_: torchtune.utils.metric_logging.DiskLogger | ||
log_dir: ${output_dir} | ||
output_dir: /tmp/code_llama2_finetune | ||
log_every_n_steps: 1 | ||
log_peak_memory_stats: False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
# Config for single device full finetuning in full_finetune_single_device.py | ||
# using a Code-Llama2 7B model | ||
# | ||
# This config assumes that you've run the following command before launching | ||
# this run: | ||
# tune download codellama/CodeLlama-7b-hf --output-dir /tmp/CodeLlama-7b-hf | ||
# | ||
# To launch on a single device, run the following command from root: | ||
# tune run lora_finetune_single_device --config code_llama2/7B_lora_single_device | ||
# | ||
# You can add specific overrides through the command line. For example | ||
# to override the checkpointer directory while launching training | ||
# you can run: | ||
# tune run lora_finetune_single_device --config code_llama2/7B_lora_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR> | ||
# | ||
# This config works only for training on single device. | ||
|
||
# Model Arguments | ||
model: | ||
_component_: torchtune.models.code_llama2.lora_code_llama2_7b | ||
lora_attn_modules: ['q_proj', 'v_proj'] | ||
apply_lora_to_mlp: False | ||
apply_lora_to_output: False | ||
lora_rank: 8 | ||
lora_alpha: 16 | ||
|
||
# Tokenizer | ||
tokenizer: | ||
_component_: torchtune.models.llama2.llama2_tokenizer | ||
path: /tmp/CodeLlama-7b-hf/tokenizer.model | ||
|
||
# Dataset | ||
dataset: | ||
_component_: torchtune.datasets.alpaca_cleaned_dataset | ||
train_on_input: True | ||
seed: null | ||
shuffle: True | ||
|
||
|
||
checkpointer: | ||
_component_: torchtune.utils.FullModelHFCheckpointer | ||
checkpoint_dir: /tmp/CodeLlama-7b-hf | ||
checkpoint_files: [ | ||
pytorch_model-00001-of-00003.bin, | ||
pytorch_model-00002-of-00003.bin, | ||
pytorch_model-00003-of-00003.bin | ||
] | ||
adapter_checkpoint: null | ||
recipe_checkpoint: null | ||
output_dir: /tmp/CodeLlama-7b-hf | ||
model_type: LLAMA2 | ||
|
||
# Fine-tuning arguments | ||
batch_size: 2 | ||
epochs: 1 | ||
max_steps_per_epoch: null | ||
gradient_accumulation_steps: 64 | ||
compile: False | ||
|
||
optimizer: | ||
_component_: torch.optim.AdamW | ||
weight_decay: 0.01 | ||
lr: 3e-4 | ||
|
||
lr_scheduler: | ||
_component_: torchtune.modules.get_cosine_schedule_with_warmup | ||
num_warmup_steps: 100 | ||
|
||
loss: | ||
_component_: torch.nn.CrossEntropyLoss | ||
|
||
|
||
# Training environment | ||
device: cuda | ||
enable_activation_checkpointing: True | ||
dtype: bf16 | ||
|
||
# Logging | ||
metric_logger: | ||
_component_: torchtune.utils.metric_logging.DiskLogger | ||
log_dir: ${output_dir} | ||
output_dir: /tmp/lora_code_llama2_finetune_output | ||
log_every_n_steps: 1 | ||
log_peak_memory_stats: False | ||
|
||
profiler: | ||
_component_: torchtune.utils.profiler | ||
enabled: False | ||
output_dir: ${output_dir}/torchtune_perf_tracing.json |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# Config for single device QLoRA finetuning in lora_finetune_single_device.py | ||
# using a Code-Llama2 7B model | ||
# | ||
# This config assumes that you've run the following command before launching | ||
# this run: | ||
# tune download codellama/CodeLlama-7b-hf --output-dir /tmp/CodeLlama-7b-hf | ||
# | ||
# To launch on a single device, run the following command from root: | ||
# tune run lora_finetune_single_device --config code_llama2/7B_qlora_single_device | ||
# | ||
# You can add specific overrides through the command line. For example | ||
# to override the checkpointer directory while launching training | ||
# you can run: | ||
# tune run lora_finetune_single_device --config code_llama2/7B_qlora_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR> | ||
# | ||
# This config works only for training on single device. | ||
|
||
# Model Arguments | ||
model: | ||
_component_: torchtune.models.code_llama2.qlora_code_llama2_7b | ||
lora_attn_modules: ['q_proj', 'v_proj'] | ||
apply_lora_to_mlp: False | ||
apply_lora_to_output: False | ||
lora_rank: 8 | ||
lora_alpha: 16 | ||
|
||
# Tokenizer | ||
tokenizer: | ||
_component_: torchtune.models.llama2.llama2_tokenizer | ||
path: /tmp/CodeLlama-7b-hf/tokenizer.model | ||
|
||
# Dataset | ||
dataset: | ||
_component_: torchtune.datasets.alpaca_cleaned_dataset | ||
train_on_input: True | ||
seed: null | ||
shuffle: True | ||
|
||
|
||
checkpointer: | ||
_component_: torchtune.utils.FullModelHFCheckpointer | ||
checkpoint_dir: /tmp/CodeLlama-7b-hf | ||
checkpoint_files: [ | ||
pytorch_model-00001-of-00003.bin, | ||
pytorch_model-00002-of-00003.bin, | ||
pytorch_model-00003-of-00003.bin | ||
] | ||
adapter_checkpoint: null | ||
recipe_checkpoint: null | ||
output_dir: /tmp/CodeLlama-7b-hf | ||
model_type: LLAMA2 | ||
resume_from_checkpoint: False | ||
|
||
# Fine-tuning arguments and training | ||
batch_size: 2 | ||
epochs: 1 | ||
max_steps_per_epoch: null | ||
gradient_accumulation_steps: 64 | ||
compile: False | ||
|
||
optimizer: | ||
_component_: torch.optim.AdamW | ||
weight_decay: 0.01 | ||
lr: 3e-4 | ||
|
||
lr_scheduler: | ||
_component_: torchtune.modules.get_cosine_schedule_with_warmup | ||
num_warmup_steps: 100 | ||
|
||
loss: | ||
_component_: torch.nn.CrossEntropyLoss | ||
|
||
|
||
# Training environment | ||
device: cuda | ||
enable_activation_checkpointing: True | ||
dtype: bf16 | ||
|
||
# Logging | ||
metric_logger: | ||
_component_: torchtune.utils.metric_logging.DiskLogger | ||
log_dir: ${output_dir} | ||
output_dir: /tmp/qlora_code_llama2_finetune_output | ||
log_every_n_steps: 1 | ||
log_peak_memory_stats: False | ||
|
||
# Show case the usage of pytorch profiler | ||
# Set enabled to False as it's only needed for debugging training | ||
profiler: | ||
_component_: torchtune.utils.profiler | ||
enabled: False | ||
output_dir: ${output_dir}/torchtune_perf_tracing.json |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from ._model_builders import ( # noqa | ||
code_llama2_13b, | ||
code_llama2_70b, | ||
code_llama2_7b, | ||
lora_code_llama2_13b, | ||
lora_code_llama2_70b, | ||
lora_code_llama2_7b, | ||
qlora_code_llama2_13b, | ||
qlora_code_llama2_7b, | ||
) | ||
|
||
__all__ = [ | ||
"code_llama2_13b", | ||
"code_llama2_70b", | ||
"code_llama2_7b", | ||
"lora_code_llama2_13b", | ||
"lora_code_llama2_70b", | ||
"lora_code_llama2_7b", | ||
"qlora_code_llama2_13b", | ||
"qlora_code_llama2_7b", | ||
] |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also please add
log_peak_memory_stats: False
in these configs. It won't error out without it, but rn we do a safe check on the config inside the recipe, which we'd eventually like to remove (keeping configs as the source of truth).