From d38fec2efa5a8e8951388797e0f3c32e2decb54d Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Tue, 23 Apr 2024 21:00:17 +0100 Subject: [PATCH 01/15] adding model builders for code-llama2 7b, 13b, and 70b --- torchtune/models/llama2/__init__.py | 3 + torchtune/models/llama2/_model_builders.py | 73 ++++++++++++++++++++-- 2 files changed, 71 insertions(+), 5 deletions(-) diff --git a/torchtune/models/llama2/__init__.py b/torchtune/models/llama2/__init__.py index 1391e61d4..76b34e5a4 100644 --- a/torchtune/models/llama2/__init__.py +++ b/torchtune/models/llama2/__init__.py @@ -7,6 +7,9 @@ from ._component_builders import llama2, lora_llama2 from ._model_builders import ( # noqa + code_llama2_13b, + code_llama2_70b, + code_llama2_7b, llama2_13b, llama2_70b, llama2_7b, diff --git a/torchtune/models/llama2/_model_builders.py b/torchtune/models/llama2/_model_builders.py index 15db69e14..98d34092d 100644 --- a/torchtune/models/llama2/_model_builders.py +++ b/torchtune/models/llama2/_model_builders.py @@ -22,7 +22,7 @@ def llama2_7b() -> TransformerDecoder: """ - Builder for creating a Llama2 model initialized w/ the default 7b parameter values + Builder for creating a Llama2 model initialized w/ the default 7B parameter values from https://arxiv.org/abs/2307.09288 Returns: @@ -96,6 +96,7 @@ def lora_llama2_7b( quantize_base=quantize_base, ) + qlora_llama2_7b = partial(lora_llama2_7b, quantize_base=True) qlora_llama2_7b.__doc__ = """ @@ -107,7 +108,7 @@ def lora_llama2_7b( def llama2_13b() -> TransformerDecoder: """ - Builder for creating a Llama2 model initialized w/ the default 13b parameter values + Builder for creating a Llama2 model initialized w/ the default 13B parameter values from https://arxiv.org/abs/2307.09288 Returns: @@ -123,7 +124,6 @@ def llama2_13b() -> TransformerDecoder: max_seq_len=4096, attn_dropout=0.0, norm_eps=1e-5, - ) @@ -177,16 +177,17 @@ def lora_llama2_13b( quantize_base=quantize_base, ) + qlora_llama2_13b = partial(lora_llama2_13b, quantize_base=True) def llama2_70b() -> TransformerDecoder: """ - Builder for creating a Llama2 model initialized w/ the default 70 parameter values + Builder for creating a Llama2 model initialized w/ the default 70B parameter values from https://arxiv.org/abs/2307.09288 Returns: - TransformerDecoder: Instantiation of Llama2 70 model + TransformerDecoder: Instantiation of Llama2 70B model """ return llama2( vocab_size=32_000, @@ -250,3 +251,65 @@ def lora_llama2_70b( lora_dropout=0.05, quantize_base=quantize_base, ) + + +def code_llama2_7b() -> TransformerDecoder: + """ + Builder for creating a Code-Llama2 model initialized w/ the default 7B parameter values + from https://arxiv.org/pdf/2308.12950.pdf + + Returns: + TransformerDecoder: Instantiation of Code-Llama2 7B model + """ + return llama2( + vocab_size=32_016, + num_layers=32, + num_heads=32, + num_kv_heads=32, + embed_dim=4096, + max_seq_len=16384, + attn_dropout=0.0, + norm_eps=1e-5, + ) + + +def code_llama2_13b() -> TransformerDecoder: + """ + Builder for creating a Code-Llama2 model initialized w/ the default 13B parameter values + from https://arxiv.org/pdf/2308.12950.pdf + + Returns: + TransformerDecoder: Instantiation of Code-Llama2 13B model + """ + return llama2( + vocab_size=32_016, + num_layers=40, + num_heads=40, + num_kv_heads=40, + embed_dim=5120, + intermediate_dim=13824, + max_seq_len=16384, + attn_dropout=0.0, + norm_eps=1e-5, + ) + + +def code_llama2_70b() -> TransformerDecoder: + """ + Builder for creating a Code-Llama2 model initialized w/ the default 70B 
parameter values + from https://arxiv.org/pdf/2308.12950.pdf + + Returns: + TransformerDecoder: Instantiation of Code-Llama2 70B model + """ + return llama2( + vocab_size=32_016, + num_layers=80, + num_heads=64, + num_kv_heads=8, + embed_dim=8192, + intermediate_dim=28672, + max_seq_len=16384, + attn_dropout=0.0, + norm_eps=1e-5, + ) From 366ddc99239d1e9ea3c260038d98ed8aac53d660 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Wed, 24 Apr 2024 13:05:43 +0100 Subject: [PATCH 02/15] Added lora and qlora code-llama models. Added a lora and qlora base llama 70B. Small documentation fixes for consistency --- torchtune/models/llama2/_model_builders.py | 205 ++++++++++++++++++++- 1 file changed, 199 insertions(+), 6 deletions(-) diff --git a/torchtune/models/llama2/_model_builders.py b/torchtune/models/llama2/_model_builders.py index 98d34092d..632c43fe2 100644 --- a/torchtune/models/llama2/_model_builders.py +++ b/torchtune/models/llama2/_model_builders.py @@ -92,7 +92,7 @@ def lora_llama2_7b( norm_eps=1e-5, lora_rank=lora_rank, lora_alpha=lora_alpha, - lora_dropout=0.05, + lora_dropout=lora_dropout, quantize_base=quantize_base, ) @@ -100,7 +100,7 @@ def lora_llama2_7b( qlora_llama2_7b = partial(lora_llama2_7b, quantize_base=True) qlora_llama2_7b.__doc__ = """ -Builder for creating a Llama2 model with QLoRA enabled. Base model weights in linear layers +Builder for creating a Llama2 7B model with QLoRA enabled. Base model weights in linear layers that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. Please see `lora_llama2_7b` for full API arguments. """ @@ -133,6 +133,7 @@ def lora_llama2_13b( apply_lora_to_output: bool = False, lora_rank: int = 8, lora_alpha: float = 16, + lora_dropout: float = 0.05, quantize_base: bool = False, ) -> TransformerDecoder: """ @@ -167,18 +168,23 @@ def lora_llama2_13b( num_heads=40, num_kv_heads=40, embed_dim=5120, - max_seq_len=4096, intermediate_dim=13824, + max_seq_len=4096, attn_dropout=0.0, norm_eps=1e-5, lora_rank=lora_rank, lora_alpha=lora_alpha, - lora_dropout=0.05, + lora_dropout=lora_dropout, quantize_base=quantize_base, ) qlora_llama2_13b = partial(lora_llama2_13b, quantize_base=True) +qlora_llama2_13b.__doc__ = """ +Builder for creating a Llama2 13B model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_llama2_13b` for full API arguments. +""" def llama2_70b() -> TransformerDecoder: @@ -214,7 +220,7 @@ def lora_llama2_70b( """ Builder for creating a Llama2 70B model with LoRA enabled. - The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.llama2_7b`, + The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.llama2_70b`, while LoRA default params are based on https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. @@ -248,11 +254,19 @@ def lora_llama2_70b( norm_eps=1e-5, lora_rank=lora_rank, lora_alpha=lora_alpha, - lora_dropout=0.05, + lora_dropout=lora_dropout, quantize_base=quantize_base, ) +qlora_llama2_70b = partial(lora_llama2_70b, quantize_base=True) +qlora_llama2_70b.__doc__ = """ +Builder for creating a Llama2 70B model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_llama2_70b` for full API arguments. 
+""" + + def code_llama2_7b() -> TransformerDecoder: """ Builder for creating a Code-Llama2 model initialized w/ the default 7B parameter values @@ -273,6 +287,65 @@ def code_llama2_7b() -> TransformerDecoder: ) +def lora_code_llama2_7b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + lora_dropout: float = 0.05, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Code-Llama2 7B model with LoRA enabled. + + The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.code_llama2_7b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Code-Llama2 7B model with LoRA applied + """ + return lora_llama2( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + vocab_size=32_016, + num_layers=32, + num_heads=32, + num_kv_heads=32, + embed_dim=4096, + max_seq_len=16384, + attn_dropout=0.0, + norm_eps=1e-5, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + + +qlora_code_llama2_7b = partial(lora_code_llama2_7b, quantize_base=True) + +qlora_code_llama2_7b.__doc__ = """ +Builder for creating a Code-Llama2 7B model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_code_llama2_7b` for full API arguments. +""" + + def code_llama2_13b() -> TransformerDecoder: """ Builder for creating a Code-Llama2 model initialized w/ the default 13B parameter values @@ -294,6 +367,66 @@ def code_llama2_13b() -> TransformerDecoder: ) +def lora_code_llama2_13b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + lora_dropout: float = 0.05, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Code-Llama2 13B model with LoRA enabled. + + The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.code_llama2_13b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. 
+ Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Code-Llama2 7B model with LoRA applied + """ + return lora_llama2( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + vocab_size=32_016, + num_layers=40, + num_heads=40, + num_kv_heads=40, + embed_dim=5120, + intermediate_dim=13824, + max_seq_len=16384, + attn_dropout=0.0, + norm_eps=1e-5, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + + +qlora_code_llama2_13b = partial(lora_code_llama2_13b, quantize_base=True) + +qlora_code_llama2_13b.__doc__ = """ +Builder for creating a Code-Llama2 13B model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_code_llama2_13b` for full API arguments. +""" + + def code_llama2_70b() -> TransformerDecoder: """ Builder for creating a Code-Llama2 model initialized w/ the default 70B parameter values @@ -313,3 +446,63 @@ def code_llama2_70b() -> TransformerDecoder: attn_dropout=0.0, norm_eps=1e-5, ) + + +def lora_code_llama2_70b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + lora_dropout: float = 0.05, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Code-Llama2 70B model with LoRA enabled. + + The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.code_llama2_13b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Code-Llama2 7B model with LoRA applied + """ + return lora_llama2( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + vocab_size=32_016, + num_layers=80, + num_heads=64, + num_kv_heads=8, + embed_dim=8192, + intermediate_dim=28672, + max_seq_len=16384, + attn_dropout=0.0, + norm_eps=1e-5, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + + +qlora_code_llama2_70b = partial(lora_code_llama2_70b, quantize_base=True) + +qlora_code_llama2_70b.__doc__ = """ +Builder for creating a Code-Llama2 70B model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_code_llama2_70b` for full API arguments. 
+""" From 8c11e2afe9569d81f37a784e9c973f8ad6646e19 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Wed, 24 Apr 2024 13:14:28 +0100 Subject: [PATCH 03/15] updating __init__.py --- torchtune/models/llama2/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/torchtune/models/llama2/__init__.py b/torchtune/models/llama2/__init__.py index 76b34e5a4..4f6d4114d 100644 --- a/torchtune/models/llama2/__init__.py +++ b/torchtune/models/llama2/__init__.py @@ -14,9 +14,15 @@ llama2_70b, llama2_7b, llama2_tokenizer, + lora_code_llama2_13b, + lora_code_llama2_70b, + lora_code_llama2_7b, lora_llama2_13b, lora_llama2_70b, lora_llama2_7b, + qlora_code_llama2_13b, + qlora_code_llama2_70b, + qlora_code_llama2_7b, qlora_llama2_13b, qlora_llama2_7b, ) From ba0dd4094c06429b74a77983896a7c00a60b5132 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 25 Apr 2024 12:08:01 +0100 Subject: [PATCH 04/15] refacting code_llama2 into own folder. adding recipes for code_llama2 for ease of use. --- .../code_llama2/7B_full_low_memory.yaml | 78 ++++++ .../code_llama2/7B_lora_single_device.yaml | 89 +++++++ .../code_llama2/7b_qlora_single_device.yaml | 96 +++++++ torchtune/models/code_llama2/__init__.py | 5 + .../models/code_llama2/_model_builders.py | 29 +++ torchtune/models/llama2/_model_builders.py | 241 ------------------ 6 files changed, 297 insertions(+), 241 deletions(-) create mode 100644 recipes/configs/code_llama2/7B_full_low_memory.yaml create mode 100644 recipes/configs/code_llama2/7B_lora_single_device.yaml create mode 100644 recipes/configs/code_llama2/7b_qlora_single_device.yaml create mode 100644 torchtune/models/code_llama2/__init__.py create mode 100644 torchtune/models/code_llama2/_model_builders.py diff --git a/recipes/configs/code_llama2/7B_full_low_memory.yaml b/recipes/configs/code_llama2/7B_full_low_memory.yaml new file mode 100644 index 000000000..3ee282834 --- /dev/null +++ b/recipes/configs/code_llama2/7B_full_low_memory.yaml @@ -0,0 +1,78 @@ +# Config for single device full finetuning in full_finetune_single_device.py +# using a Code-Llama2 7B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download codellama/CodeLlama-7b-Instruct-hf --output-dir /tmp/CodeLlama-7b-Instruct-hf +# The default config uses an optimizer from bitsandbytes. If you do not have it installed, +# you can install it with +# pip install bitsandbytes +# +# To launch on a single device, run the following command from root: +# tune run full_finetune_single_device --config code_llama2/7B_full_low_memory +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run full_finetune_single_device --config code_llama2/7B_full_low_memory checkpointer.checkpoint_dir= +# +# This config works only for training on single device. 
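+#
+# Memory use in this config is kept low by the combination of the
+# bitsandbytes PagedAdamW optimizer, running the optimizer step during the
+# backward pass (optimizer_in_bwd: True), bf16 precision, and activation
+# checkpointing, as set in the sections below.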
+ + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/CodeLlama-7b-hf/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset + train_on_input: True +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.code_llama2.code_llama2_7b + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/CodeLlama-7b-hf + checkpoint_files: [ + pytorch_model-00001-of-00003.bin, + pytorch_model-00002-of-00003.bin, + pytorch_model-00003-of-00003.bin + ] + recipe_checkpoint: null + output_dir: /tmp/CodeLlama-7b-hf + model_type: LLAMA2 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 3 +optimizer: + _component_: bitsandbytes.optim.PagedAdamW + lr: 2e-5 +optimizer_in_bwd: True +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 +compile: False + +# Training environment +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/code_llama2_finetune +log_every_n_steps: null diff --git a/recipes/configs/code_llama2/7B_lora_single_device.yaml b/recipes/configs/code_llama2/7B_lora_single_device.yaml new file mode 100644 index 000000000..a731d5bfa --- /dev/null +++ b/recipes/configs/code_llama2/7B_lora_single_device.yaml @@ -0,0 +1,89 @@ +# Config for single device full finetuning in full_finetune_single_device.py +# using a Llama2 7B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download codellama/CodeLlama-7b-Instruct-hf --output-dir /tmp/CodeLlama-7b-Instruct-hf +# +# The default config uses an optimizer from bitsandbytes. If you do not have it installed, +# you can install it with +# pip install bitsandbytes +# +# To launch on a single device, run the following command from root: +# tune run full_finetune_single_device --config code_llama2/7B_lora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run full_finetune_single_device --config llama2/7B_lora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. 
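+#
+# The model section below applies LoRA adapters of rank 8 (alpha 16) to the
+# q_proj and v_proj attention projections only; the MLP and final output
+# projections are left without adapters.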
+ +# Model Arguments +model: + _component_: torchtune.models.lora_code_llama2.code_llama2_7b + lora_attn_modules: ['q_proj', 'v_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/CodeLlama-7b-hf/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset + train_on_input: True +seed: null +shuffle: True + + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/CodeLlama-7b-hf + checkpoint_files: [ + pytorch_model-00001-of-00003.bin, + pytorch_model-00002-of-00003.bin, + pytorch_model-00003-of-00003.bin + ] + adapter_checkpoint: null + recipe_checkpoint: null + output_dir: /tmp/CodeLlama-7b-hf + model_type: LLAMA2 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 64 +compile: False + +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 +optimizer_in_bwd: True + +loss: + _component_: torch.nn.CrossEntropyLoss + + +# Training environment +device: cuda +enable_activation_checkpointing: True +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir}/torchtune_perf_tracing.json +output_dir: /tmp/lora_code_llama2_finetune_output +log_every_n_steps: null diff --git a/recipes/configs/code_llama2/7b_qlora_single_device.yaml b/recipes/configs/code_llama2/7b_qlora_single_device.yaml new file mode 100644 index 000000000..51d3a6794 --- /dev/null +++ b/recipes/configs/code_llama2/7b_qlora_single_device.yaml @@ -0,0 +1,96 @@ +# Config for single device QLoRA finetuning in lora_finetune_single_device.py +# using a Code-Llama2 7B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download codellama/CodeLlama-7b-Instruct-hf --output-dir /tmp/CodeLlama-7b-Instruct-hf +# +# The default config uses an optimizer from bitsandbytes. If you do not have it installed, +# you can install it with +# pip install bitsandbytes +# +# To launch on a single device, run the following command from root: +# tune run full_finetune_single_device --config code_llama2/7B_qlora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run full_finetune_single_device --config code_llama2/7B_qlora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. 
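+#
+# QLoRA quantizes the frozen base weights of the linear layers that receive
+# LoRA adapters (rank 8, alpha 16 on q_proj and v_proj below), while the
+# adapter weights themselves are trained in bf16.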
+ +# Model Arguments +model: + _component_: torchtune.models.lora_code_llama2.qlora_code_llama2_7b + lora_attn_modules: ['q_proj', 'v_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/CodeLlama-7b-hf/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset + train_on_input: True +seed: null +shuffle: True + + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/CodeLlama-7b-hf + checkpoint_files: [ + pytorch_model-00001-of-00003.bin, + pytorch_model-00002-of-00003.bin, + pytorch_model-00003-of-00003.bin + ] + adapter_checkpoint: null + recipe_checkpoint: null + output_dir: /tmp/CodeLlama-7b-hf + model_type: LLAMA2 +resume_from_checkpoint: False + +# Fine-tuning arguments and training +batch_size: 2 +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 64 +compile: False + +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 +optimizer_in_bwd: True + +loss: + _component_: torch.nn.CrossEntropyLoss + + +# Training environment +device: cuda +enable_activation_checkpointing: True +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir}/torchtune_perf_tracing.json +output_dir: /tmp/qlora_code_llama2_finetune_output +log_every_n_steps: null + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.utils.profiler + enabled: False + output_dir: ${output_dir}/torchtune_perf_tracing.json diff --git a/torchtune/models/code_llama2/__init__.py b/torchtune/models/code_llama2/__init__.py new file mode 100644 index 000000000..2e41cd717 --- /dev/null +++ b/torchtune/models/code_llama2/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchtune/models/code_llama2/_model_builders.py b/torchtune/models/code_llama2/_model_builders.py new file mode 100644 index 000000000..b334f6f4a --- /dev/null +++ b/torchtune/models/code_llama2/_model_builders.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from ._model_builders import ( # noqa + code_llama2_13b, + code_llama2_70b, + code_llama2_7b, + lora_code_llama2_13b, + lora_code_llama2_70b, + lora_code_llama2_7b, + qlora_code_llama2_13b, + qlora_code_llama2_70b, + qlora_code_llama2_7b, +) + +__all__ = [ + "code_llama2_13b", + "code_llama2_70b", + "code_llama2_7b", + "lora_code_llama2_13b", + "lora_code_llama2_70b", + "lora_code_llama2_7b", + "qlora_code_llama2_13b", + "qlora_code_llama2_70b", + "qlora_code_llama2_7b", +] diff --git a/torchtune/models/llama2/_model_builders.py b/torchtune/models/llama2/_model_builders.py index 632c43fe2..0602b7ec0 100644 --- a/torchtune/models/llama2/_model_builders.py +++ b/torchtune/models/llama2/_model_builders.py @@ -265,244 +265,3 @@ def lora_llama2_70b( that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. Please see `lora_llama2_70b` for full API arguments. """ - - -def code_llama2_7b() -> TransformerDecoder: - """ - Builder for creating a Code-Llama2 model initialized w/ the default 7B parameter values - from https://arxiv.org/pdf/2308.12950.pdf - - Returns: - TransformerDecoder: Instantiation of Code-Llama2 7B model - """ - return llama2( - vocab_size=32_016, - num_layers=32, - num_heads=32, - num_kv_heads=32, - embed_dim=4096, - max_seq_len=16384, - attn_dropout=0.0, - norm_eps=1e-5, - ) - - -def lora_code_llama2_7b( - lora_attn_modules: List[LORA_ATTN_MODULES], - apply_lora_to_mlp: bool = False, - apply_lora_to_output: bool = False, - lora_rank: int = 8, - lora_alpha: float = 16, - lora_dropout: float = 0.05, - quantize_base: bool = False, -) -> TransformerDecoder: - """ - Builder for creating a Code-Llama2 7B model with LoRA enabled. - - The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.code_llama2_7b`, - while LoRA default params are based on - https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. - - Args: - lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers - LoRA should be applied to in each self-attention block. Options are - ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. - apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. - Default: False - apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. - Default: False - lora_rank (int): rank of each low-rank approximation - lora_alpha (float): scaling factor for the low-rank approximation - quantize_base (bool): Whether to quantize base model weights - - Returns: - TransformerDecoder: Instantiation of Code-Llama2 7B model with LoRA applied - """ - return lora_llama2( - lora_attn_modules=lora_attn_modules, - apply_lora_to_mlp=apply_lora_to_mlp, - apply_lora_to_output=apply_lora_to_output, - vocab_size=32_016, - num_layers=32, - num_heads=32, - num_kv_heads=32, - embed_dim=4096, - max_seq_len=16384, - attn_dropout=0.0, - norm_eps=1e-5, - lora_rank=lora_rank, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - quantize_base=quantize_base, - ) - - -qlora_code_llama2_7b = partial(lora_code_llama2_7b, quantize_base=True) - -qlora_code_llama2_7b.__doc__ = """ -Builder for creating a Code-Llama2 7B model with QLoRA enabled. Base model weights in linear layers -that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. -Please see `lora_code_llama2_7b` for full API arguments. 
-""" - - -def code_llama2_13b() -> TransformerDecoder: - """ - Builder for creating a Code-Llama2 model initialized w/ the default 13B parameter values - from https://arxiv.org/pdf/2308.12950.pdf - - Returns: - TransformerDecoder: Instantiation of Code-Llama2 13B model - """ - return llama2( - vocab_size=32_016, - num_layers=40, - num_heads=40, - num_kv_heads=40, - embed_dim=5120, - intermediate_dim=13824, - max_seq_len=16384, - attn_dropout=0.0, - norm_eps=1e-5, - ) - - -def lora_code_llama2_13b( - lora_attn_modules: List[LORA_ATTN_MODULES], - apply_lora_to_mlp: bool = False, - apply_lora_to_output: bool = False, - lora_rank: int = 8, - lora_alpha: float = 16, - lora_dropout: float = 0.05, - quantize_base: bool = False, -) -> TransformerDecoder: - """ - Builder for creating a Code-Llama2 13B model with LoRA enabled. - - The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.code_llama2_13b`, - while LoRA default params are based on - https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. - - Args: - lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers - LoRA should be applied to in each self-attention block. Options are - ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. - apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. - Default: False - apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. - Default: False - lora_rank (int): rank of each low-rank approximation - lora_alpha (float): scaling factor for the low-rank approximation - quantize_base (bool): Whether to quantize base model weights - - Returns: - TransformerDecoder: Instantiation of Code-Llama2 7B model with LoRA applied - """ - return lora_llama2( - lora_attn_modules=lora_attn_modules, - apply_lora_to_mlp=apply_lora_to_mlp, - apply_lora_to_output=apply_lora_to_output, - vocab_size=32_016, - num_layers=40, - num_heads=40, - num_kv_heads=40, - embed_dim=5120, - intermediate_dim=13824, - max_seq_len=16384, - attn_dropout=0.0, - norm_eps=1e-5, - lora_rank=lora_rank, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - quantize_base=quantize_base, - ) - - -qlora_code_llama2_13b = partial(lora_code_llama2_13b, quantize_base=True) - -qlora_code_llama2_13b.__doc__ = """ -Builder for creating a Code-Llama2 13B model with QLoRA enabled. Base model weights in linear layers -that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. -Please see `lora_code_llama2_13b` for full API arguments. -""" - - -def code_llama2_70b() -> TransformerDecoder: - """ - Builder for creating a Code-Llama2 model initialized w/ the default 70B parameter values - from https://arxiv.org/pdf/2308.12950.pdf - - Returns: - TransformerDecoder: Instantiation of Code-Llama2 70B model - """ - return llama2( - vocab_size=32_016, - num_layers=80, - num_heads=64, - num_kv_heads=8, - embed_dim=8192, - intermediate_dim=28672, - max_seq_len=16384, - attn_dropout=0.0, - norm_eps=1e-5, - ) - - -def lora_code_llama2_70b( - lora_attn_modules: List[LORA_ATTN_MODULES], - apply_lora_to_mlp: bool = False, - apply_lora_to_output: bool = False, - lora_rank: int = 8, - lora_alpha: float = 16, - lora_dropout: float = 0.05, - quantize_base: bool = False, -) -> TransformerDecoder: - """ - Builder for creating a Code-Llama2 70B model with LoRA enabled. 
- - The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.code_llama2_13b`, - while LoRA default params are based on - https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. - - Args: - lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers - LoRA should be applied to in each self-attention block. Options are - ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. - apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. - Default: False - apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. - Default: False - lora_rank (int): rank of each low-rank approximation - lora_alpha (float): scaling factor for the low-rank approximation - quantize_base (bool): Whether to quantize base model weights - - Returns: - TransformerDecoder: Instantiation of Code-Llama2 7B model with LoRA applied - """ - return lora_llama2( - lora_attn_modules=lora_attn_modules, - apply_lora_to_mlp=apply_lora_to_mlp, - apply_lora_to_output=apply_lora_to_output, - vocab_size=32_016, - num_layers=80, - num_heads=64, - num_kv_heads=8, - embed_dim=8192, - intermediate_dim=28672, - max_seq_len=16384, - attn_dropout=0.0, - norm_eps=1e-5, - lora_rank=lora_rank, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - quantize_base=quantize_base, - ) - - -qlora_code_llama2_70b = partial(lora_code_llama2_70b, quantize_base=True) - -qlora_code_llama2_70b.__doc__ = """ -Builder for creating a Code-Llama2 70B model with QLoRA enabled. Base model weights in linear layers -that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. -Please see `lora_code_llama2_70b` for full API arguments. -""" From 329fed5dc2ed580301beaa393269fe6c3aa657d0 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 25 Apr 2024 12:13:09 +0100 Subject: [PATCH 05/15] removing unused imports in llama2/__init__.py --- torchtune/models/llama2/__init__.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/torchtune/models/llama2/__init__.py b/torchtune/models/llama2/__init__.py index 4f6d4114d..13f38ffdd 100644 --- a/torchtune/models/llama2/__init__.py +++ b/torchtune/models/llama2/__init__.py @@ -7,23 +7,15 @@ from ._component_builders import llama2, lora_llama2 from ._model_builders import ( # noqa - code_llama2_13b, - code_llama2_70b, - code_llama2_7b, llama2_13b, llama2_70b, llama2_7b, llama2_tokenizer, - lora_code_llama2_13b, - lora_code_llama2_70b, - lora_code_llama2_7b, lora_llama2_13b, lora_llama2_70b, lora_llama2_7b, - qlora_code_llama2_13b, - qlora_code_llama2_70b, - qlora_code_llama2_7b, qlora_llama2_13b, + qlora_llama2_70b, qlora_llama2_7b, ) from ._model_utils import scale_hidden_dim_for_mlp From b88baded6a540236ace802ce3877b78efe3f8ea9 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 25 Apr 2024 12:21:51 +0100 Subject: [PATCH 06/15] Updating README.md, fixing mis-copied files --- README.md | 1 + torchtune/models/code_llama2/__init__.py | 24 ++ .../models/code_llama2/_model_builders.py | 271 ++++++++++++++++-- 3 files changed, 273 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 65d31ca55..5b2994d77 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,7 @@ torchtune currently supports the following models. 
|-----------------------------------------------|-----------| | [Llama3](https://llama.meta.com/llama3) | 8B, 70B [[models](torchtune/models/llama3/_model_builders.py), [configs](recipes/configs/llama3/)] | | [Llama2](https://llama.meta.com/llama2/) | 7B, 13B, 70B [[models](torchtune/models/llama2/_model_builders.py), [configs](recipes/configs/llama2/)] | +| [Code-Llama2](https://huggingface.co/codellama/CodeLlama-34b-hf) | 7B, 13B, 70B [[model](torchtune/models/code_llama2/_model_builders.py), [configs](recipes/configs/code_llama2/)] | | [Mistral](https://huggingface.co/mistralai) | 7B [[model](torchtune/models/mistral/_model_builders.py), [configs](recipes/configs/mistral/)] | | [Gemma](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b) | 2B [[model](torchtune/models/gemma/_model_builders.py), [configs](recipes/configs/gemma/)] | diff --git a/torchtune/models/code_llama2/__init__.py b/torchtune/models/code_llama2/__init__.py index 2e41cd717..b334f6f4a 100644 --- a/torchtune/models/code_llama2/__init__.py +++ b/torchtune/models/code_llama2/__init__.py @@ -3,3 +3,27 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + +from ._model_builders import ( # noqa + code_llama2_13b, + code_llama2_70b, + code_llama2_7b, + lora_code_llama2_13b, + lora_code_llama2_70b, + lora_code_llama2_7b, + qlora_code_llama2_13b, + qlora_code_llama2_70b, + qlora_code_llama2_7b, +) + +__all__ = [ + "code_llama2_13b", + "code_llama2_70b", + "code_llama2_7b", + "lora_code_llama2_13b", + "lora_code_llama2_70b", + "lora_code_llama2_7b", + "qlora_code_llama2_13b", + "qlora_code_llama2_70b", + "qlora_code_llama2_7b", +] diff --git a/torchtune/models/code_llama2/_model_builders.py b/torchtune/models/code_llama2/_model_builders.py index b334f6f4a..6270a9079 100644 --- a/torchtune/models/code_llama2/_model_builders.py +++ b/torchtune/models/code_llama2/_model_builders.py @@ -4,26 +4,251 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
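+
+# Re-exported so that configs can reference the builders by their dotted
+# path, e.g. ``torchtune.models.code_llama2.lora_code_llama2_7b``.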
-from ._model_builders import ( # noqa - code_llama2_13b, - code_llama2_70b, - code_llama2_7b, - lora_code_llama2_13b, - lora_code_llama2_70b, - lora_code_llama2_7b, - qlora_code_llama2_13b, - qlora_code_llama2_70b, - qlora_code_llama2_7b, -) - -__all__ = [ - "code_llama2_13b", - "code_llama2_70b", - "code_llama2_7b", - "lora_code_llama2_13b", - "lora_code_llama2_70b", - "lora_code_llama2_7b", - "qlora_code_llama2_13b", - "qlora_code_llama2_70b", - "qlora_code_llama2_7b", -] +from typing import List +from functools import partial + +from torchtune.models.llama2._component_builders import llama2, lora_llama2 + +from torchtune.modules import TransformerDecoder +from torchtune.modules.peft import LORA_ATTN_MODULES + + +def code_llama2_7b() -> TransformerDecoder: + """ + Builder for creating a Code-Llama2 model initialized w/ the default 7B parameter values + from https://arxiv.org/pdf/2308.12950.pdf + + Returns: + TransformerDecoder: Instantiation of Code-Llama2 7B model + """ + return llama2( + vocab_size=32_016, + num_layers=32, + num_heads=32, + num_kv_heads=32, + embed_dim=4096, + max_seq_len=16384, + attn_dropout=0.0, + norm_eps=1e-5, + ) + + +def lora_code_llama2_7b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + lora_dropout: float = 0.05, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Code-Llama2 7B model with LoRA enabled. + + The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.code_llama2_7b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Code-Llama2 7B model with LoRA applied + """ + return lora_llama2( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + vocab_size=32_016, + num_layers=32, + num_heads=32, + num_kv_heads=32, + embed_dim=4096, + max_seq_len=16384, + attn_dropout=0.0, + norm_eps=1e-5, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + + +qlora_code_llama2_7b = partial(lora_code_llama2_7b, quantize_base=True) + +qlora_code_llama2_7b.__doc__ = """ +Builder for creating a Code-Llama2 7B model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_code_llama2_7b` for full API arguments. 
+""" + + +def code_llama2_13b() -> TransformerDecoder: + """ + Builder for creating a Code-Llama2 model initialized w/ the default 13B parameter values + from https://arxiv.org/pdf/2308.12950.pdf + + Returns: + TransformerDecoder: Instantiation of Code-Llama2 13B model + """ + return llama2( + vocab_size=32_016, + num_layers=40, + num_heads=40, + num_kv_heads=40, + embed_dim=5120, + intermediate_dim=13824, + max_seq_len=16384, + attn_dropout=0.0, + norm_eps=1e-5, + ) + + +def lora_code_llama2_13b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + lora_dropout: float = 0.05, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Code-Llama2 13B model with LoRA enabled. + + The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.code_llama2_13b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Code-Llama2 7B model with LoRA applied + """ + return lora_llama2( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + vocab_size=32_016, + num_layers=40, + num_heads=40, + num_kv_heads=40, + embed_dim=5120, + intermediate_dim=13824, + max_seq_len=16384, + attn_dropout=0.0, + norm_eps=1e-5, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + + +qlora_code_llama2_13b = partial(lora_code_llama2_13b, quantize_base=True) + +qlora_code_llama2_13b.__doc__ = """ +Builder for creating a Code-Llama2 13B model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_code_llama2_13b` for full API arguments. +""" + + +def code_llama2_70b() -> TransformerDecoder: + """ + Builder for creating a Code-Llama2 model initialized w/ the default 70B parameter values + from https://arxiv.org/pdf/2308.12950.pdf + + Returns: + TransformerDecoder: Instantiation of Code-Llama2 70B model + """ + return llama2( + vocab_size=32_016, + num_layers=80, + num_heads=64, + num_kv_heads=8, + embed_dim=8192, + intermediate_dim=28672, + max_seq_len=16384, + attn_dropout=0.0, + norm_eps=1e-5, + ) + + +def lora_code_llama2_70b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + lora_dropout: float = 0.05, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Code-Llama2 70B model with LoRA enabled. 
+ + The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.code_llama2_13b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Code-Llama2 7B model with LoRA applied + """ + return lora_llama2( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + vocab_size=32_016, + num_layers=80, + num_heads=64, + num_kv_heads=8, + embed_dim=8192, + intermediate_dim=28672, + max_seq_len=16384, + attn_dropout=0.0, + norm_eps=1e-5, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + + +qlora_code_llama2_70b = partial(lora_code_llama2_70b, quantize_base=True) + +qlora_code_llama2_70b.__doc__ = """ +Builder for creating a Code-Llama2 70B model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_code_llama2_70b` for full API arguments. +""" From 36e1754f66852ef9a095d7302c47bbb193241bda Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 25 Apr 2024 12:33:06 +0100 Subject: [PATCH 07/15] adding missing import in torchtune/models/__init__.py --- torchtune/models/__init__.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torchtune/models/__init__.py b/torchtune/models/__init__.py index f57ff7e27..33605b74a 100644 --- a/torchtune/models/__init__.py +++ b/torchtune/models/__init__.py @@ -4,4 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from torchtune.models import convert_weights, gemma, llama2, mistral # noqa +from torchtune.models import ( # noqa + code_llama2, + convert_weights, + gemma, + llama2, + mistral, +) From 257b5defc8409e142cac782c13db9a38d1e54856 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 25 Apr 2024 12:40:58 +0100 Subject: [PATCH 08/15] adding code_llama2 recipes to _recipe_registry.py --- torchtune/_recipe_registry.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index 039a106ac..a5e5396d7 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -31,6 +31,10 @@ class Recipe: name="llama2/7B_full_low_memory", file_path="llama2/7B_full_low_memory.yaml", ), + Config( + name="code_llama2/7B_full_low_memory", + file_path="code_llama2/7B_full_low_memory.yaml", + ), Config( name="llama3/8B_full_single_device", file_path="llama3/8B_full_single_device.yaml", @@ -66,6 +70,14 @@ class Recipe: name="llama2/7B_qlora_single_device", file_path="llama2/7B_qlora_single_device.yaml", ), + Config( + name="code_llama2/7B_lora_single_device", + file_path="code_llama2/7B_lora_single_device.yaml", + ), + Config( + name="code_llama2/7B_qlora_single_device", + file_path="llama2/7B_qlora_single_device.yaml", + ), Config( name="llama3/8B_lora_single_device", file_path="llama3/8B_lora_single_device.yaml", From 789005111850155d57433201be995cd99f492933 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 25 Apr 2024 13:10:33 +0100 Subject: [PATCH 09/15] Fixing bug in lora and qlora code_llama2 recipes using the wrong model reference --- recipes/configs/code_llama2/7B_lora_single_device.yaml | 2 +- recipes/configs/code_llama2/7b_qlora_single_device.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/recipes/configs/code_llama2/7B_lora_single_device.yaml b/recipes/configs/code_llama2/7B_lora_single_device.yaml index a731d5bfa..b8ddd585e 100644 --- a/recipes/configs/code_llama2/7B_lora_single_device.yaml +++ b/recipes/configs/code_llama2/7B_lora_single_device.yaml @@ -21,7 +21,7 @@ # Model Arguments model: - _component_: torchtune.models.lora_code_llama2.code_llama2_7b + _component_: torchtune.models.code_llama2.lora_code_llama2_7b lora_attn_modules: ['q_proj', 'v_proj'] apply_lora_to_mlp: False apply_lora_to_output: False diff --git a/recipes/configs/code_llama2/7b_qlora_single_device.yaml b/recipes/configs/code_llama2/7b_qlora_single_device.yaml index 51d3a6794..960fbedd9 100644 --- a/recipes/configs/code_llama2/7b_qlora_single_device.yaml +++ b/recipes/configs/code_llama2/7b_qlora_single_device.yaml @@ -21,7 +21,7 @@ # Model Arguments model: - _component_: torchtune.models.lora_code_llama2.qlora_code_llama2_7b + _component_: torchtune.models.code_llama2.qlora_code_llama2_7b lora_attn_modules: ['q_proj', 'v_proj'] apply_lora_to_mlp: False apply_lora_to_output: False From 2e06912f47aaeb87abd67dbcc6b7afa4970221f1 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 25 Apr 2024 13:17:14 +0100 Subject: [PATCH 10/15] missing profiler configs in lora_code_llama2_7b --- recipes/configs/code_llama2/7B_lora_single_device.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/recipes/configs/code_llama2/7B_lora_single_device.yaml b/recipes/configs/code_llama2/7B_lora_single_device.yaml index b8ddd585e..fe2730e76 100644 --- a/recipes/configs/code_llama2/7B_lora_single_device.yaml +++ b/recipes/configs/code_llama2/7B_lora_single_device.yaml @@ -87,3 +87,8 @@ metric_logger: log_dir: 
${output_dir}/torchtune_perf_tracing.json output_dir: /tmp/lora_code_llama2_finetune_output log_every_n_steps: null + +profiler: + _component_: torchtune.utils.profiler + enabled: False + output_dir: ${output_dir}/torchtune_perf_tracing.json From ff4f6f5985e86dfbd0bdd8ae5b6a45b90f04f69c Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 25 Apr 2024 13:30:26 +0100 Subject: [PATCH 11/15] Fixing typos in recipe docs, removing references to instruct models, and a typo to refernce to qlora config for code_llama2 in recipe registry --- recipes/configs/code_llama2/7B_full_low_memory.yaml | 2 +- recipes/configs/code_llama2/7B_lora_single_device.yaml | 5 ++--- recipes/configs/code_llama2/7b_qlora_single_device.yaml | 2 +- torchtune/_recipe_registry.py | 2 +- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/recipes/configs/code_llama2/7B_full_low_memory.yaml b/recipes/configs/code_llama2/7B_full_low_memory.yaml index 3ee282834..fe25f180d 100644 --- a/recipes/configs/code_llama2/7B_full_low_memory.yaml +++ b/recipes/configs/code_llama2/7B_full_low_memory.yaml @@ -3,7 +3,7 @@ # # This config assumes that you've run the following command before launching # this run: -# tune download codellama/CodeLlama-7b-Instruct-hf --output-dir /tmp/CodeLlama-7b-Instruct-hf +# tune download codellama/CodeLlama-7b-hf --output-dir /tmp/CodeLlama-7b-hf # The default config uses an optimizer from bitsandbytes. If you do not have it installed, # you can install it with # pip install bitsandbytes diff --git a/recipes/configs/code_llama2/7B_lora_single_device.yaml b/recipes/configs/code_llama2/7B_lora_single_device.yaml index fe2730e76..62fe98143 100644 --- a/recipes/configs/code_llama2/7B_lora_single_device.yaml +++ b/recipes/configs/code_llama2/7B_lora_single_device.yaml @@ -1,9 +1,9 @@ # Config for single device full finetuning in full_finetune_single_device.py -# using a Llama2 7B model +# using a Code-Llama2 7B model # # This config assumes that you've run the following command before launching # this run: -# tune download codellama/CodeLlama-7b-Instruct-hf --output-dir /tmp/CodeLlama-7b-Instruct-hf +# tune download codellama/CodeLlama-7b-hf --output-dir /tmp/CodeLlama-7b-hf # # The default config uses an optimizer from bitsandbytes. If you do not have it installed, # you can install it with @@ -53,7 +53,6 @@ checkpointer: recipe_checkpoint: null output_dir: /tmp/CodeLlama-7b-hf model_type: LLAMA2 -resume_from_checkpoint: False # Fine-tuning arguments batch_size: 2 diff --git a/recipes/configs/code_llama2/7b_qlora_single_device.yaml b/recipes/configs/code_llama2/7b_qlora_single_device.yaml index 960fbedd9..439585f39 100644 --- a/recipes/configs/code_llama2/7b_qlora_single_device.yaml +++ b/recipes/configs/code_llama2/7b_qlora_single_device.yaml @@ -3,7 +3,7 @@ # # This config assumes that you've run the following command before launching # this run: -# tune download codellama/CodeLlama-7b-Instruct-hf --output-dir /tmp/CodeLlama-7b-Instruct-hf +# tune download codellama/CodeLlama-7b-hf --output-dir /tmp/CodeLlama-7b-hf # # The default config uses an optimizer from bitsandbytes. 
If you do not have it installed, # you can install it with diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index a5e5396d7..b1610c37f 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -76,7 +76,7 @@ class Recipe: ), Config( name="code_llama2/7B_qlora_single_device", - file_path="llama2/7B_qlora_single_device.yaml", + file_path="code_llama2/7B_qlora_single_device.yaml", ), Config( name="llama3/8B_lora_single_device", From 64907b4b96c138e67c0de57b56ef3aee72504379 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 25 Apr 2024 13:38:16 +0100 Subject: [PATCH 12/15] git not picking up case change in filename... --- .../{7b_qlora_single_device.yaml => 7B_qlora_single_device.yaml} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename recipes/configs/code_llama2/{7b_qlora_single_device.yaml => 7B_qlora_single_device.yaml} (100%) diff --git a/recipes/configs/code_llama2/7b_qlora_single_device.yaml b/recipes/configs/code_llama2/7B_qlora_single_device.yaml similarity index 100% rename from recipes/configs/code_llama2/7b_qlora_single_device.yaml rename to recipes/configs/code_llama2/7B_qlora_single_device.yaml From c8d958474d0dc352443d8b36b5da1aa0cc4a9af6 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 25 Apr 2024 14:01:13 +0100 Subject: [PATCH 13/15] updating reference to codellama huggingface repo --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 5b2994d77..cb0ce5489 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,7 @@ torchtune currently supports the following models. |-----------------------------------------------|-----------| | [Llama3](https://llama.meta.com/llama3) | 8B, 70B [[models](torchtune/models/llama3/_model_builders.py), [configs](recipes/configs/llama3/)] | | [Llama2](https://llama.meta.com/llama2/) | 7B, 13B, 70B [[models](torchtune/models/llama2/_model_builders.py), [configs](recipes/configs/llama2/)] | -| [Code-Llama2](https://huggingface.co/codellama/CodeLlama-34b-hf) | 7B, 13B, 70B [[model](torchtune/models/code_llama2/_model_builders.py), [configs](recipes/configs/code_llama2/)] | +| [Code-Llama2](https://huggingface.co/codellama) | 7B, 13B, 70B [[model](torchtune/models/code_llama2/_model_builders.py), [configs](recipes/configs/code_llama2/)] | | [Mistral](https://huggingface.co/mistralai) | 7B [[model](torchtune/models/mistral/_model_builders.py), [configs](recipes/configs/mistral/)] | | [Gemma](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b) | 2B [[model](torchtune/models/gemma/_model_builders.py), [configs](recipes/configs/gemma/)] | From 83153e5a9cf7300b7e4bc829ab923026c2c6b27f Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Fri, 26 Apr 2024 15:07:36 +0100 Subject: [PATCH 14/15] updating docs, removing qlora 70b and models --- recipes/configs/code_llama2/7B_full_low_memory.yaml | 2 +- recipes/configs/code_llama2/7B_lora_single_device.yaml | 6 +++--- recipes/configs/code_llama2/7B_qlora_single_device.yaml | 6 +++--- torchtune/models/code_llama2/__init__.py | 2 -- torchtune/models/code_llama2/_model_builders.py | 9 --------- torchtune/models/llama2/__init__.py | 1 - torchtune/models/llama2/_model_builders.py | 8 -------- 7 files changed, 7 insertions(+), 27 deletions(-) diff --git a/recipes/configs/code_llama2/7B_full_low_memory.yaml b/recipes/configs/code_llama2/7B_full_low_memory.yaml index fe25f180d..6b4c88d55 100644 --- a/recipes/configs/code_llama2/7B_full_low_memory.yaml +++ 
b/recipes/configs/code_llama2/7B_full_low_memory.yaml @@ -75,4 +75,4 @@ metric_logger: _component_: torchtune.utils.metric_logging.DiskLogger log_dir: ${output_dir} output_dir: /tmp/code_llama2_finetune -log_every_n_steps: null +log_every_n_steps: 1 diff --git a/recipes/configs/code_llama2/7B_lora_single_device.yaml b/recipes/configs/code_llama2/7B_lora_single_device.yaml index 62fe98143..07bc5bc71 100644 --- a/recipes/configs/code_llama2/7B_lora_single_device.yaml +++ b/recipes/configs/code_llama2/7B_lora_single_device.yaml @@ -10,12 +10,12 @@ # pip install bitsandbytes # # To launch on a single device, run the following command from root: -# tune run full_finetune_single_device --config code_llama2/7B_lora_single_device +# tune run lora_finetune_single_device --config code_llama2/7B_lora_single_device # # You can add specific overrides through the command line. For example # to override the checkpointer directory while launching training # you can run: -# tune run full_finetune_single_device --config llama2/7B_lora_single_device checkpointer.checkpoint_dir= +# tune run lora_finetune_single_device --config code_llama2/7B_lora_single_device checkpointer.checkpoint_dir= # # This config works only for training on single device. @@ -85,7 +85,7 @@ metric_logger: _component_: torchtune.utils.metric_logging.DiskLogger log_dir: ${output_dir}/torchtune_perf_tracing.json output_dir: /tmp/lora_code_llama2_finetune_output -log_every_n_steps: null +log_every_n_steps: 1 profiler: _component_: torchtune.utils.profiler diff --git a/recipes/configs/code_llama2/7B_qlora_single_device.yaml b/recipes/configs/code_llama2/7B_qlora_single_device.yaml index 439585f39..1509cb966 100644 --- a/recipes/configs/code_llama2/7B_qlora_single_device.yaml +++ b/recipes/configs/code_llama2/7B_qlora_single_device.yaml @@ -10,12 +10,12 @@ # pip install bitsandbytes # # To launch on a single device, run the following command from root: -# tune run full_finetune_single_device --config code_llama2/7B_qlora_single_device +# tune run lora_finetune_single_device --config code_llama2/7B_qlora_single_device # # You can add specific overrides through the command line. For example # to override the checkpointer directory while launching training # you can run: -# tune run full_finetune_single_device --config code_llama2/7B_qlora_single_device checkpointer.checkpoint_dir= +# tune run lora_finetune_single_device --config code_llama2/7B_qlora_single_device checkpointer.checkpoint_dir= # # This config works only for training on single device. 
@@ -86,7 +86,7 @@ metric_logger: _component_: torchtune.utils.metric_logging.DiskLogger log_dir: ${output_dir}/torchtune_perf_tracing.json output_dir: /tmp/qlora_code_llama2_finetune_output -log_every_n_steps: null +log_every_n_steps: 1 # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/torchtune/models/code_llama2/__init__.py b/torchtune/models/code_llama2/__init__.py index b334f6f4a..40aea996e 100644 --- a/torchtune/models/code_llama2/__init__.py +++ b/torchtune/models/code_llama2/__init__.py @@ -12,7 +12,6 @@ lora_code_llama2_70b, lora_code_llama2_7b, qlora_code_llama2_13b, - qlora_code_llama2_70b, qlora_code_llama2_7b, ) @@ -24,6 +23,5 @@ "lora_code_llama2_70b", "lora_code_llama2_7b", "qlora_code_llama2_13b", - "qlora_code_llama2_70b", "qlora_code_llama2_7b", ] diff --git a/torchtune/models/code_llama2/_model_builders.py b/torchtune/models/code_llama2/_model_builders.py index 6270a9079..9b9d078b2 100644 --- a/torchtune/models/code_llama2/_model_builders.py +++ b/torchtune/models/code_llama2/_model_builders.py @@ -243,12 +243,3 @@ def lora_code_llama2_70b( lora_dropout=lora_dropout, quantize_base=quantize_base, ) - - -qlora_code_llama2_70b = partial(lora_code_llama2_70b, quantize_base=True) - -qlora_code_llama2_70b.__doc__ = """ -Builder for creating a Code-Llama2 70B model with QLoRA enabled. Base model weights in linear layers -that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. -Please see `lora_code_llama2_70b` for full API arguments. -""" diff --git a/torchtune/models/llama2/__init__.py b/torchtune/models/llama2/__init__.py index 13f38ffdd..1391e61d4 100644 --- a/torchtune/models/llama2/__init__.py +++ b/torchtune/models/llama2/__init__.py @@ -15,7 +15,6 @@ lora_llama2_70b, lora_llama2_7b, qlora_llama2_13b, - qlora_llama2_70b, qlora_llama2_7b, ) from ._model_utils import scale_hidden_dim_for_mlp diff --git a/torchtune/models/llama2/_model_builders.py b/torchtune/models/llama2/_model_builders.py index 0602b7ec0..e83ea97fa 100644 --- a/torchtune/models/llama2/_model_builders.py +++ b/torchtune/models/llama2/_model_builders.py @@ -257,11 +257,3 @@ def lora_llama2_70b( lora_dropout=lora_dropout, quantize_base=quantize_base, ) - - -qlora_llama2_70b = partial(lora_llama2_70b, quantize_base=True) -qlora_llama2_70b.__doc__ = """ -Builder for creating a Llama2 70B model with QLoRA enabled. Base model weights in linear layers -that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. -Please see `lora_llama2_70b` for full API arguments. 
-""" From 98ef8b8d734bca8512b66c3526675f8d2ea6ccea Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Fri, 26 Apr 2024 19:00:12 +0100 Subject: [PATCH 15/15] Updating typos in docstrings, fixing recipe config errors --- recipes/configs/code_llama2/7B_full_low_memory.yaml | 1 + recipes/configs/code_llama2/7B_lora_single_device.yaml | 8 ++------ recipes/configs/code_llama2/7B_qlora_single_device.yaml | 8 ++------ torchtune/models/code_llama2/_model_builders.py | 4 ++-- 4 files changed, 7 insertions(+), 14 deletions(-) diff --git a/recipes/configs/code_llama2/7B_full_low_memory.yaml b/recipes/configs/code_llama2/7B_full_low_memory.yaml index 6b4c88d55..75023994e 100644 --- a/recipes/configs/code_llama2/7B_full_low_memory.yaml +++ b/recipes/configs/code_llama2/7B_full_low_memory.yaml @@ -76,3 +76,4 @@ metric_logger: log_dir: ${output_dir} output_dir: /tmp/code_llama2_finetune log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/recipes/configs/code_llama2/7B_lora_single_device.yaml b/recipes/configs/code_llama2/7B_lora_single_device.yaml index 07bc5bc71..5798ad7ea 100644 --- a/recipes/configs/code_llama2/7B_lora_single_device.yaml +++ b/recipes/configs/code_llama2/7B_lora_single_device.yaml @@ -5,10 +5,6 @@ # this run: # tune download codellama/CodeLlama-7b-hf --output-dir /tmp/CodeLlama-7b-hf # -# The default config uses an optimizer from bitsandbytes. If you do not have it installed, -# you can install it with -# pip install bitsandbytes -# # To launch on a single device, run the following command from root: # tune run lora_finetune_single_device --config code_llama2/7B_lora_single_device # @@ -69,7 +65,6 @@ optimizer: lr_scheduler: _component_: torchtune.modules.get_cosine_schedule_with_warmup num_warmup_steps: 100 -optimizer_in_bwd: True loss: _component_: torch.nn.CrossEntropyLoss @@ -83,9 +78,10 @@ dtype: bf16 # Logging metric_logger: _component_: torchtune.utils.metric_logging.DiskLogger - log_dir: ${output_dir}/torchtune_perf_tracing.json + log_dir: ${output_dir} output_dir: /tmp/lora_code_llama2_finetune_output log_every_n_steps: 1 +log_peak_memory_stats: False profiler: _component_: torchtune.utils.profiler diff --git a/recipes/configs/code_llama2/7B_qlora_single_device.yaml b/recipes/configs/code_llama2/7B_qlora_single_device.yaml index 1509cb966..50d41d024 100644 --- a/recipes/configs/code_llama2/7B_qlora_single_device.yaml +++ b/recipes/configs/code_llama2/7B_qlora_single_device.yaml @@ -5,10 +5,6 @@ # this run: # tune download codellama/CodeLlama-7b-hf --output-dir /tmp/CodeLlama-7b-hf # -# The default config uses an optimizer from bitsandbytes. 
If you do not have it installed, -# you can install it with -# pip install bitsandbytes -# # To launch on a single device, run the following command from root: # tune run lora_finetune_single_device --config code_llama2/7B_qlora_single_device # @@ -70,7 +66,6 @@ optimizer: lr_scheduler: _component_: torchtune.modules.get_cosine_schedule_with_warmup num_warmup_steps: 100 -optimizer_in_bwd: True loss: _component_: torch.nn.CrossEntropyLoss @@ -84,9 +79,10 @@ dtype: bf16 # Logging metric_logger: _component_: torchtune.utils.metric_logging.DiskLogger - log_dir: ${output_dir}/torchtune_perf_tracing.json + log_dir: ${output_dir} output_dir: /tmp/qlora_code_llama2_finetune_output log_every_n_steps: 1 +log_peak_memory_stats: False # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/torchtune/models/code_llama2/_model_builders.py b/torchtune/models/code_llama2/_model_builders.py index 9b9d078b2..684958210 100644 --- a/torchtune/models/code_llama2/_model_builders.py +++ b/torchtune/models/code_llama2/_model_builders.py @@ -206,7 +206,7 @@ def lora_code_llama2_70b( """ Builder for creating a Code-Llama2 70B model with LoRA enabled. - The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.code_llama2_13b`, + The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.code_llama2_70b`, while LoRA default params are based on https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. @@ -223,7 +223,7 @@ def lora_code_llama2_70b( quantize_base (bool): Whether to quantize base model weights Returns: - TransformerDecoder: Instantiation of Code-Llama2 7B model with LoRA applied + TransformerDecoder: Instantiation of Code-Llama2 70B model with LoRA applied """ return lora_llama2( lora_attn_modules=lora_attn_modules,
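(Editorial sketch, not part of the patch.) With the docstring corrections above, the Code-Llama2 LoRA builders mirror the Llama2 ones, and QLoRA variants remain only for 7B and 13B. The snippet below is a rough usage sketch: it assumes the torchtune.models.code_llama2 layout introduced in this series, and the projection names and rank/alpha values are illustrative rather than taken from the merged defaults, which should be checked against _model_builders.py.

    # Rough usage sketch, not part of the patch series.
    # Builder names come from torchtune/models/code_llama2/__init__.py above;
    # the attention projection names and rank/alpha values are assumptions.
    from torchtune.models.code_llama2 import lora_code_llama2_7b, qlora_code_llama2_7b

    lora_model = lora_code_llama2_7b(
        lora_attn_modules=["q_proj", "v_proj"],
        lora_rank=8,
        lora_alpha=16,
    )

    # The QLoRA variant is the same builder with quantize_base=True pre-applied
    # via functools.partial, so the base weights of LoRA-adapted linear layers
    # are quantized per the QLoRA paper.
    qlora_model = qlora_code_llama2_7b(lora_attn_modules=["q_proj", "v_proj"])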