diff --git a/README.md b/README.md index 65d31ca55b..cb0ce54899 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,7 @@ torchtune currently supports the following models. |-----------------------------------------------|-----------| | [Llama3](https://llama.meta.com/llama3) | 8B, 70B [[models](torchtune/models/llama3/_model_builders.py), [configs](recipes/configs/llama3/)] | | [Llama2](https://llama.meta.com/llama2/) | 7B, 13B, 70B [[models](torchtune/models/llama2/_model_builders.py), [configs](recipes/configs/llama2/)] | +| [Code-Llama2](https://huggingface.co/codellama) | 7B, 13B, 70B [[model](torchtune/models/code_llama2/_model_builders.py), [configs](recipes/configs/code_llama2/)] | | [Mistral](https://huggingface.co/mistralai) | 7B [[model](torchtune/models/mistral/_model_builders.py), [configs](recipes/configs/mistral/)] | | [Gemma](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b) | 2B [[model](torchtune/models/gemma/_model_builders.py), [configs](recipes/configs/gemma/)] | diff --git a/recipes/configs/code_llama2/7B_full_low_memory.yaml b/recipes/configs/code_llama2/7B_full_low_memory.yaml new file mode 100644 index 0000000000..75023994e3 --- /dev/null +++ b/recipes/configs/code_llama2/7B_full_low_memory.yaml @@ -0,0 +1,79 @@ +# Config for single device full finetuning in full_finetune_single_device.py +# using a Code-Llama2 7B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download codellama/CodeLlama-7b-hf --output-dir /tmp/CodeLlama-7b-hf +# The default config uses an optimizer from bitsandbytes. If you do not have it installed, +# you can install it with +# pip install bitsandbytes +# +# To launch on a single device, run the following command from root: +# tune run full_finetune_single_device --config code_llama2/7B_full_low_memory +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run full_finetune_single_device --config code_llama2/7B_full_low_memory checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/CodeLlama-7b-hf/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset + train_on_input: True +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.code_llama2.code_llama2_7b + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/CodeLlama-7b-hf + checkpoint_files: [ + pytorch_model-00001-of-00003.bin, + pytorch_model-00002-of-00003.bin, + pytorch_model-00003-of-00003.bin + ] + recipe_checkpoint: null + output_dir: /tmp/CodeLlama-7b-hf + model_type: LLAMA2 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 3 +optimizer: + _component_: bitsandbytes.optim.PagedAdamW + lr: 2e-5 +optimizer_in_bwd: True +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 +compile: False + +# Training environment +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/code_llama2_finetune +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/recipes/configs/code_llama2/7B_lora_single_device.yaml b/recipes/configs/code_llama2/7B_lora_single_device.yaml new file mode 100644 index 0000000000..5798ad7ea0 --- /dev/null +++ b/recipes/configs/code_llama2/7B_lora_single_device.yaml @@ -0,0 +1,89 @@ +# Config for single device full finetuning in full_finetune_single_device.py +# using a Code-Llama2 7B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download codellama/CodeLlama-7b-hf --output-dir /tmp/CodeLlama-7b-hf +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config code_llama2/7B_lora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config code_llama2/7B_lora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Model Arguments +model: + _component_: torchtune.models.code_llama2.lora_code_llama2_7b + lora_attn_modules: ['q_proj', 'v_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/CodeLlama-7b-hf/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset + train_on_input: True +seed: null +shuffle: True + + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/CodeLlama-7b-hf + checkpoint_files: [ + pytorch_model-00001-of-00003.bin, + pytorch_model-00002-of-00003.bin, + pytorch_model-00003-of-00003.bin + ] + adapter_checkpoint: null + recipe_checkpoint: null + output_dir: /tmp/CodeLlama-7b-hf + model_type: LLAMA2 + +# Fine-tuning arguments +batch_size: 2 +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 64 +compile: False + +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + + +# Training environment +device: cuda +enable_activation_checkpointing: True +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/lora_code_llama2_finetune_output +log_every_n_steps: 1 +log_peak_memory_stats: False + +profiler: + _component_: torchtune.utils.profiler + enabled: False + output_dir: ${output_dir}/torchtune_perf_tracing.json diff --git a/recipes/configs/code_llama2/7B_qlora_single_device.yaml b/recipes/configs/code_llama2/7B_qlora_single_device.yaml new file mode 100644 index 0000000000..50d41d024a --- /dev/null +++ b/recipes/configs/code_llama2/7B_qlora_single_device.yaml @@ -0,0 +1,92 @@ +# Config for single device QLoRA finetuning in lora_finetune_single_device.py +# using a Code-Llama2 7B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download codellama/CodeLlama-7b-hf --output-dir /tmp/CodeLlama-7b-hf +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config code_llama2/7B_qlora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config code_llama2/7B_qlora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Model Arguments +model: + _component_: torchtune.models.code_llama2.qlora_code_llama2_7b + lora_attn_modules: ['q_proj', 'v_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/CodeLlama-7b-hf/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset + train_on_input: True +seed: null +shuffle: True + + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/CodeLlama-7b-hf + checkpoint_files: [ + pytorch_model-00001-of-00003.bin, + pytorch_model-00002-of-00003.bin, + pytorch_model-00003-of-00003.bin + ] + adapter_checkpoint: null + recipe_checkpoint: null + output_dir: /tmp/CodeLlama-7b-hf + model_type: LLAMA2 +resume_from_checkpoint: False + +# Fine-tuning arguments and training +batch_size: 2 +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 64 +compile: False + +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss + + +# Training environment +device: cuda +enable_activation_checkpointing: True +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/qlora_code_llama2_finetune_output +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.utils.profiler + enabled: False + output_dir: ${output_dir}/torchtune_perf_tracing.json diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index 039a106acc..b1610c37f4 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -31,6 +31,10 @@ class Recipe: name="llama2/7B_full_low_memory", file_path="llama2/7B_full_low_memory.yaml", ), + Config( + name="code_llama2/7B_full_low_memory", + file_path="code_llama2/7B_full_low_memory.yaml", + ), Config( name="llama3/8B_full_single_device", file_path="llama3/8B_full_single_device.yaml", @@ -66,6 +70,14 @@ class Recipe: name="llama2/7B_qlora_single_device", file_path="llama2/7B_qlora_single_device.yaml", ), + Config( + name="code_llama2/7B_lora_single_device", + file_path="code_llama2/7B_lora_single_device.yaml", + ), + Config( + name="code_llama2/7B_qlora_single_device", + file_path="code_llama2/7B_qlora_single_device.yaml", + ), Config( name="llama3/8B_lora_single_device", file_path="llama3/8B_lora_single_device.yaml", diff --git a/torchtune/models/__init__.py b/torchtune/models/__init__.py index f57ff7e27b..33605b74a0 100644 --- a/torchtune/models/__init__.py +++ b/torchtune/models/__init__.py @@ -4,4 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from torchtune.models import convert_weights, gemma, llama2, mistral # noqa +from torchtune.models import ( # noqa + code_llama2, + convert_weights, + gemma, + llama2, + mistral, +) diff --git a/torchtune/models/code_llama2/__init__.py b/torchtune/models/code_llama2/__init__.py new file mode 100644 index 0000000000..40aea996eb --- /dev/null +++ b/torchtune/models/code_llama2/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from ._model_builders import ( # noqa + code_llama2_13b, + code_llama2_70b, + code_llama2_7b, + lora_code_llama2_13b, + lora_code_llama2_70b, + lora_code_llama2_7b, + qlora_code_llama2_13b, + qlora_code_llama2_7b, +) + +__all__ = [ + "code_llama2_13b", + "code_llama2_70b", + "code_llama2_7b", + "lora_code_llama2_13b", + "lora_code_llama2_70b", + "lora_code_llama2_7b", + "qlora_code_llama2_13b", + "qlora_code_llama2_7b", +] diff --git a/torchtune/models/code_llama2/_model_builders.py b/torchtune/models/code_llama2/_model_builders.py new file mode 100644 index 0000000000..6849582105 --- /dev/null +++ b/torchtune/models/code_llama2/_model_builders.py @@ -0,0 +1,245 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List +from functools import partial + +from torchtune.models.llama2._component_builders import llama2, lora_llama2 + +from torchtune.modules import TransformerDecoder +from torchtune.modules.peft import LORA_ATTN_MODULES + + +def code_llama2_7b() -> TransformerDecoder: + """ + Builder for creating a Code-Llama2 model initialized w/ the default 7B parameter values + from https://arxiv.org/pdf/2308.12950.pdf + + Returns: + TransformerDecoder: Instantiation of Code-Llama2 7B model + """ + return llama2( + vocab_size=32_016, + num_layers=32, + num_heads=32, + num_kv_heads=32, + embed_dim=4096, + max_seq_len=16384, + attn_dropout=0.0, + norm_eps=1e-5, + ) + + +def lora_code_llama2_7b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + lora_dropout: float = 0.05, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Code-Llama2 7B model with LoRA enabled. + + The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.code_llama2_7b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Code-Llama2 7B model with LoRA applied + """ + return lora_llama2( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + vocab_size=32_016, + num_layers=32, + num_heads=32, + num_kv_heads=32, + embed_dim=4096, + max_seq_len=16384, + attn_dropout=0.0, + norm_eps=1e-5, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + + +qlora_code_llama2_7b = partial(lora_code_llama2_7b, quantize_base=True) + +qlora_code_llama2_7b.__doc__ = """ +Builder for creating a Code-Llama2 7B model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_code_llama2_7b` for full API arguments. +""" + + +def code_llama2_13b() -> TransformerDecoder: + """ + Builder for creating a Code-Llama2 model initialized w/ the default 13B parameter values + from https://arxiv.org/pdf/2308.12950.pdf + + Returns: + TransformerDecoder: Instantiation of Code-Llama2 13B model + """ + return llama2( + vocab_size=32_016, + num_layers=40, + num_heads=40, + num_kv_heads=40, + embed_dim=5120, + intermediate_dim=13824, + max_seq_len=16384, + attn_dropout=0.0, + norm_eps=1e-5, + ) + + +def lora_code_llama2_13b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + lora_dropout: float = 0.05, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Code-Llama2 13B model with LoRA enabled. + + The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.code_llama2_13b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Code-Llama2 7B model with LoRA applied + """ + return lora_llama2( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + vocab_size=32_016, + num_layers=40, + num_heads=40, + num_kv_heads=40, + embed_dim=5120, + intermediate_dim=13824, + max_seq_len=16384, + attn_dropout=0.0, + norm_eps=1e-5, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) + + +qlora_code_llama2_13b = partial(lora_code_llama2_13b, quantize_base=True) + +qlora_code_llama2_13b.__doc__ = """ +Builder for creating a Code-Llama2 13B model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_code_llama2_13b` for full API arguments. +""" + + +def code_llama2_70b() -> TransformerDecoder: + """ + Builder for creating a Code-Llama2 model initialized w/ the default 70B parameter values + from https://arxiv.org/pdf/2308.12950.pdf + + Returns: + TransformerDecoder: Instantiation of Code-Llama2 70B model + """ + return llama2( + vocab_size=32_016, + num_layers=80, + num_heads=64, + num_kv_heads=8, + embed_dim=8192, + intermediate_dim=28672, + max_seq_len=16384, + attn_dropout=0.0, + norm_eps=1e-5, + ) + + +def lora_code_llama2_70b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + lora_dropout: float = 0.05, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Code-Llama2 70B model with LoRA enabled. + + The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.code_llama2_70b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Code-Llama2 70B model with LoRA applied + """ + return lora_llama2( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + apply_lora_to_output=apply_lora_to_output, + vocab_size=32_016, + num_layers=80, + num_heads=64, + num_kv_heads=8, + embed_dim=8192, + intermediate_dim=28672, + max_seq_len=16384, + attn_dropout=0.0, + norm_eps=1e-5, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quantize_base=quantize_base, + ) diff --git a/torchtune/models/llama2/_model_builders.py b/torchtune/models/llama2/_model_builders.py index 15db69e146..e83ea97fac 100644 --- a/torchtune/models/llama2/_model_builders.py +++ b/torchtune/models/llama2/_model_builders.py @@ -22,7 +22,7 @@ def llama2_7b() -> TransformerDecoder: """ - Builder for creating a Llama2 model initialized w/ the default 7b parameter values + Builder for creating a Llama2 model initialized w/ the default 7B parameter values from https://arxiv.org/abs/2307.09288 Returns: @@ -92,14 +92,15 @@ def lora_llama2_7b( norm_eps=1e-5, lora_rank=lora_rank, lora_alpha=lora_alpha, - lora_dropout=0.05, + lora_dropout=lora_dropout, quantize_base=quantize_base, ) + qlora_llama2_7b = partial(lora_llama2_7b, quantize_base=True) qlora_llama2_7b.__doc__ = """ -Builder for creating a Llama2 model with QLoRA enabled. Base model weights in linear layers +Builder for creating a Llama2 7B model with QLoRA enabled. Base model weights in linear layers that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. Please see `lora_llama2_7b` for full API arguments. """ @@ -107,7 +108,7 @@ def lora_llama2_7b( def llama2_13b() -> TransformerDecoder: """ - Builder for creating a Llama2 model initialized w/ the default 13b parameter values + Builder for creating a Llama2 model initialized w/ the default 13B parameter values from https://arxiv.org/abs/2307.09288 Returns: @@ -123,7 +124,6 @@ def llama2_13b() -> TransformerDecoder: max_seq_len=4096, attn_dropout=0.0, norm_eps=1e-5, - ) @@ -133,6 +133,7 @@ def lora_llama2_13b( apply_lora_to_output: bool = False, lora_rank: int = 8, lora_alpha: float = 16, + lora_dropout: float = 0.05, quantize_base: bool = False, ) -> TransformerDecoder: """ @@ -167,26 +168,32 @@ def lora_llama2_13b( num_heads=40, num_kv_heads=40, embed_dim=5120, - max_seq_len=4096, intermediate_dim=13824, + max_seq_len=4096, attn_dropout=0.0, norm_eps=1e-5, lora_rank=lora_rank, lora_alpha=lora_alpha, - lora_dropout=0.05, + lora_dropout=lora_dropout, quantize_base=quantize_base, ) + qlora_llama2_13b = partial(lora_llama2_13b, quantize_base=True) +qlora_llama2_13b.__doc__ = """ +Builder for creating a Llama2 13B model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_llama2_13b` for full API arguments. +""" def llama2_70b() -> TransformerDecoder: """ - Builder for creating a Llama2 model initialized w/ the default 70 parameter values + Builder for creating a Llama2 model initialized w/ the default 70B parameter values from https://arxiv.org/abs/2307.09288 Returns: - TransformerDecoder: Instantiation of Llama2 70 model + TransformerDecoder: Instantiation of Llama2 70B model """ return llama2( vocab_size=32_000, @@ -213,7 +220,7 @@ def lora_llama2_70b( """ Builder for creating a Llama2 70B model with LoRA enabled. - The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.llama2_7b`, + The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.llama2_70b`, while LoRA default params are based on https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. @@ -247,6 +254,6 @@ def lora_llama2_70b( norm_eps=1e-5, lora_rank=lora_rank, lora_alpha=lora_alpha, - lora_dropout=0.05, + lora_dropout=lora_dropout, quantize_base=quantize_base, )