adding model builders for code-llama2 7b, 13b, and 70b #847
Changes from 3 commits
@@ -22,7 +22,7 @@

def llama2_7b() -> TransformerDecoder:
    """
-    Builder for creating a Llama2 model initialized w/ the default 7b parameter values
+    Builder for creating a Llama2 model initialized w/ the default 7B parameter values
    from https://arxiv.org/abs/2307.09288

    Returns:
@@ -92,22 +92,23 @@ def lora_llama2_7b(
        norm_eps=1e-5,
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
-        lora_dropout=0.05,
+        lora_dropout=lora_dropout,
        quantize_base=quantize_base,
    )


qlora_llama2_7b = partial(lora_llama2_7b, quantize_base=True)

qlora_llama2_7b.__doc__ = """
-Builder for creating a Llama2 model with QLoRA enabled. Base model weights in linear layers
+Builder for creating a Llama2 7B model with QLoRA enabled. Base model weights in linear layers
that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314.
Please see `lora_llama2_7b` for full API arguments.
"""


def llama2_13b() -> TransformerDecoder:
    """
-    Builder for creating a Llama2 model initialized w/ the default 13b parameter values
+    Builder for creating a Llama2 model initialized w/ the default 13B parameter values
    from https://arxiv.org/abs/2307.09288

    Returns:
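For context on the change above, the LoRA dropout probability is now forwarded from the builder's `lora_dropout` argument instead of being hard-coded to 0.05. A minimal, hypothetical usage sketch (the import path is taken from the :func:`~torchtune.models.llama2.llama2_7b` references in the docstrings; argument values are illustrative only):

from torchtune.models.llama2 import lora_llama2_7b

# Hypothetical example: build a LoRA-enabled Llama2 7B with a non-default dropout,
# which only takes effect once lora_dropout is passed through as in the diff above.
model = lora_llama2_7b(
    lora_attn_modules=["q_proj", "v_proj"],
    lora_rank=8,
    lora_alpha=16,
    lora_dropout=0.1,
)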
@@ -123,7 +124,6 @@ def llama2_13b() -> TransformerDecoder:
        max_seq_len=4096,
        attn_dropout=0.0,
        norm_eps=1e-5,

    )
@@ -133,6 +133,7 @@ def lora_llama2_13b(
    apply_lora_to_output: bool = False,
    lora_rank: int = 8,
    lora_alpha: float = 16,
+    lora_dropout: float = 0.05,
    quantize_base: bool = False,
) -> TransformerDecoder:
    """
@@ -167,26 +168,32 @@ def lora_llama2_13b(
        num_heads=40,
        num_kv_heads=40,
        embed_dim=5120,
-        max_seq_len=4096,
+        intermediate_dim=13824,
+        max_seq_len=4096,
        attn_dropout=0.0,
        norm_eps=1e-5,
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
-        lora_dropout=0.05,
+        lora_dropout=lora_dropout,
        quantize_base=quantize_base,
    )


qlora_llama2_13b = partial(lora_llama2_13b, quantize_base=True)
+qlora_llama2_13b.__doc__ = """
+Builder for creating a Llama2 13B model with QLoRA enabled. Base model weights in linear layers
+that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314.
+Please see `lora_llama2_13b` for full API arguments.
+"""


def llama2_70b() -> TransformerDecoder:
    """
-    Builder for creating a Llama2 model initialized w/ the default 70 parameter values
+    Builder for creating a Llama2 model initialized w/ the default 70B parameter values
    from https://arxiv.org/abs/2307.09288

    Returns:
-        TransformerDecoder: Instantiation of Llama2 70 model
+        TransformerDecoder: Instantiation of Llama2 70B model
    """
    return llama2(
        vocab_size=32_000,
@@ -213,7 +220,7 @@ def lora_llama2_70b(
    """
    Builder for creating a Llama2 70B model with LoRA enabled.

-    The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.llama2_7b`,
+    The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.llama2_70b`,
    while LoRA default params are based on
    https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43.
@@ -247,6 +254,255 @@ def lora_llama2_70b(
        norm_eps=1e-5,
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
-        lora_dropout=0.05,
+        lora_dropout=lora_dropout,
        quantize_base=quantize_base,
    )


qlora_llama2_70b = partial(lora_llama2_70b, quantize_base=True)

Reviewer comment on `qlora_llama2_70b`: Similar comment here re QLoRA + 70B

qlora_llama2_70b.__doc__ = """
Builder for creating a Llama2 70B model with QLoRA enabled. Base model weights in linear layers
that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314.
Please see `lora_llama2_70b` for full API arguments.
"""

def code_llama2_7b() -> TransformerDecoder:
    """
    Builder for creating a Code-Llama2 model initialized w/ the default 7B parameter values
    from https://arxiv.org/pdf/2308.12950.pdf

    Returns:
        TransformerDecoder: Instantiation of Code-Llama2 7B model
    """
    return llama2(
        vocab_size=32_016,
        num_layers=32,
        num_heads=32,
        num_kv_heads=32,
        embed_dim=4096,
        max_seq_len=16384,
        attn_dropout=0.0,
        norm_eps=1e-5,
    )


def lora_code_llama2_7b(
    lora_attn_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    lora_rank: int = 8,
    lora_alpha: float = 16,
    lora_dropout: float = 0.05,
    quantize_base: bool = False,
) -> TransformerDecoder:
    """
    Builder for creating a Code-Llama2 7B model with LoRA enabled.

    The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.code_llama2_7b`,
    while LoRA default params are based on
    https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43.

    Args:
        lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
            LoRA should be applied to in each self-attention block. Options are
            ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
        apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
            Default: False
        apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
            Default: False
        lora_rank (int): rank of each low-rank approximation
        lora_alpha (float): scaling factor for the low-rank approximation
        quantize_base (bool): Whether to quantize base model weights

    Returns:
        TransformerDecoder: Instantiation of Code-Llama2 7B model with LoRA applied
    """
    return lora_llama2(
        lora_attn_modules=lora_attn_modules,
        apply_lora_to_mlp=apply_lora_to_mlp,
        apply_lora_to_output=apply_lora_to_output,
        vocab_size=32_016,
        num_layers=32,
        num_heads=32,
        num_kv_heads=32,
        embed_dim=4096,
        max_seq_len=16384,
        attn_dropout=0.0,
        norm_eps=1e-5,
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        quantize_base=quantize_base,
    )


qlora_code_llama2_7b = partial(lora_code_llama2_7b, quantize_base=True)

qlora_code_llama2_7b.__doc__ = """
Builder for creating a Code-Llama2 7B model with QLoRA enabled. Base model weights in linear layers
that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314.
Please see `lora_code_llama2_7b` for full API arguments.
"""
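A hypothetical usage sketch of the new builders, assuming they are exported from torchtune.models.llama2 alongside the existing ones (as the :func:`~torchtune.models.llama2.code_llama2_7b` cross-references above suggest); argument values are illustrative only:

from torchtune.models.llama2 import lora_code_llama2_7b, qlora_code_llama2_7b

# LoRA on all attention projections plus the MLP, per the Args section above.
lora_model = lora_code_llama2_7b(
    lora_attn_modules=["q_proj", "k_proj", "v_proj", "output_proj"],
    apply_lora_to_mlp=True,
    lora_rank=8,
    lora_alpha=16,
)

# The QLoRA variant is the same builder with quantize_base=True baked in via partial.
qlora_model = qlora_code_llama2_7b(lora_attn_modules=["q_proj", "v_proj"])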
def code_llama2_13b() -> TransformerDecoder:
    """
    Builder for creating a Code-Llama2 model initialized w/ the default 13B parameter values
    from https://arxiv.org/pdf/2308.12950.pdf

    Returns:
        TransformerDecoder: Instantiation of Code-Llama2 13B model
    """
    return llama2(
        vocab_size=32_016,
        num_layers=40,
        num_heads=40,
        num_kv_heads=40,
        embed_dim=5120,
        intermediate_dim=13824,
        max_seq_len=16384,
        attn_dropout=0.0,
        norm_eps=1e-5,
    )

Reviewer comment on `intermediate_dim=13824`: nit: I think you don't need this (the math for scale_hidden_dim_for_mlp should work). Fine to keep it in just to be explicit though (honestly I am leaning towards that approach more and more cause I'm sick of all these integer-rounded calculations 😅)
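To illustrate the reviewer's nit above: if torchtune's scale_hidden_dim_for_mlp follows the standard Llama-style FFN sizing (the helper name comes from the comment; the formula below is a re-derivation for illustration, not the library's code), the default scaling already lands on 13824 for embed_dim=5120, so the explicit value is redundant for 13B but not for 70B:

def llama_style_mlp_hidden_dim(dim: int, multiple_of: int = 256) -> int:
    # Standard Llama sizing: 4 * dim, scaled by 2/3, rounded up to a multiple of 256.
    hidden = 4 * int(2 * dim / 3)
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

assert llama_style_mlp_hidden_dim(5120) == 13824  # 13B: matches the explicit intermediate_dim
assert llama_style_mlp_hidden_dim(8192) == 22016  # 70B: default would not give 28672, so the
                                                  # explicit intermediate_dim below is still needed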
def lora_code_llama2_13b(
    lora_attn_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    lora_rank: int = 8,
    lora_alpha: float = 16,
    lora_dropout: float = 0.05,
    quantize_base: bool = False,
) -> TransformerDecoder:
    """
    Builder for creating a Code-Llama2 13B model with LoRA enabled.

    The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.code_llama2_13b`,
    while LoRA default params are based on
    https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43.

    Args:
        lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
            LoRA should be applied to in each self-attention block. Options are
            ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
        apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
            Default: False
        apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
            Default: False
        lora_rank (int): rank of each low-rank approximation
        lora_alpha (float): scaling factor for the low-rank approximation
        quantize_base (bool): Whether to quantize base model weights

    Returns:
        TransformerDecoder: Instantiation of Code-Llama2 13B model with LoRA applied
    """
    return lora_llama2(
        lora_attn_modules=lora_attn_modules,
        apply_lora_to_mlp=apply_lora_to_mlp,
        apply_lora_to_output=apply_lora_to_output,
        vocab_size=32_016,
        num_layers=40,
        num_heads=40,
        num_kv_heads=40,
        embed_dim=5120,
        intermediate_dim=13824,
        max_seq_len=16384,
        attn_dropout=0.0,
        norm_eps=1e-5,
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        quantize_base=quantize_base,
    )


qlora_code_llama2_13b = partial(lora_code_llama2_13b, quantize_base=True)

qlora_code_llama2_13b.__doc__ = """
Builder for creating a Code-Llama2 13B model with QLoRA enabled. Base model weights in linear layers
that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314.
Please see `lora_code_llama2_13b` for full API arguments.
"""
def code_llama2_70b() -> TransformerDecoder:
    """
    Builder for creating a Code-Llama2 model initialized w/ the default 70B parameter values
    from https://arxiv.org/pdf/2308.12950.pdf

    Returns:
        TransformerDecoder: Instantiation of Code-Llama2 70B model
    """
    return llama2(
        vocab_size=32_016,
        num_layers=80,
        num_heads=64,
        num_kv_heads=8,
        embed_dim=8192,
        intermediate_dim=28672,
        max_seq_len=16384,
        attn_dropout=0.0,
        norm_eps=1e-5,
    )
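One configuration detail worth noting in the 70B builders above and below: num_kv_heads=8 is smaller than num_heads=64, i.e. grouped-query attention, where several query heads share each key/value head. A quick illustrative arithmetic check (not part of the PR):

num_heads, num_kv_heads = 64, 8
assert num_heads % num_kv_heads == 0
queries_per_kv_head = num_heads // num_kv_heads  # 8 query heads share each KV head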
def lora_code_llama2_70b(
    lora_attn_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    lora_rank: int = 8,
    lora_alpha: float = 16,
    lora_dropout: float = 0.05,
    quantize_base: bool = False,
) -> TransformerDecoder:
    """
    Builder for creating a Code-Llama2 70B model with LoRA enabled.

    The Llama2 defaults are the same as in :func:`~torchtune.models.llama2.code_llama2_70b`,
    while LoRA default params are based on
    https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43.

    Args:
        lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
            LoRA should be applied to in each self-attention block. Options are
            ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
        apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
            Default: False
        apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
            Default: False
        lora_rank (int): rank of each low-rank approximation
        lora_alpha (float): scaling factor for the low-rank approximation
        quantize_base (bool): Whether to quantize base model weights

    Returns:
        TransformerDecoder: Instantiation of Code-Llama2 70B model with LoRA applied
    """
    return lora_llama2(
        lora_attn_modules=lora_attn_modules,
        apply_lora_to_mlp=apply_lora_to_mlp,
        apply_lora_to_output=apply_lora_to_output,
        vocab_size=32_016,
        num_layers=80,
        num_heads=64,
        num_kv_heads=8,
        embed_dim=8192,
        intermediate_dim=28672,
        max_seq_len=16384,
        attn_dropout=0.0,
        norm_eps=1e-5,
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        quantize_base=quantize_base,
    )


qlora_code_llama2_70b = partial(lora_code_llama2_70b, quantize_base=True)

qlora_code_llama2_70b.__doc__ = """
Builder for creating a Code-Llama2 70B model with QLoRA enabled. Base model weights in linear layers
that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314.
Please see `lora_code_llama2_70b` for full API arguments.
"""
Reviewer comment: Yikes, good catch!