In [1]:
from torchtune.models.llama2 import llama2
from torchtune.modules import TransformerDecoder

In [2]:
def llama2_tiny() -> TransformerDecoder:
    """
    Build a deliberately small Llama2-style decoder.

    The function takes no arguments: every hyperparameter is fixed in the
    config below and forwarded unchanged to the ``llama2`` builder. The
    architecture is the one described in https://arxiv.org/abs/2307.09288,
    but the sizes used here (2 layers, 4 heads, 128-dim embeddings) are far
    smaller than any released Llama2 configuration.

    Returns:
        TransformerDecoder: the object returned by ``llama2`` for this
        configuration.
    """
    tiny_config = dict(
        vocab_size=32_000,
        num_layers=2,
        num_heads=4,
        # Same count as num_heads, i.e. one KV head per attention head.
        num_kv_heads=4,
        embed_dim=128,
        max_seq_len=4096,
        attn_dropout=0.0,
        norm_eps=1e-5,
    )
    return llama2(**tiny_config)

In [3]:
# Seed the global RNG before building the model: the weights are randomly
# initialised, so without a fixed seed every "Restart & Run All" would produce
# (and later save) a different set of parameters.
import torch

torch.manual_seed(0)

# Instantiate the tiny decoder defined by llama2_tiny().
m = llama2_tiny()

In [4]:
# Wrap the model object itself (not its state_dict) under the "model" key.
# NOTE(review): no other cell shown here reads chpt_dict; if a checkpoint
# dict was intended, it presumably should hold m.state_dict() — confirm.
chpt_dict = dict(model=m)


In [5]:
from pathlib import Path

import torch

# torch.save does not create missing parent directories, so make sure the
# output folder exists; otherwise this cell fails on a fresh checkout.
checkpoint_path = Path("tiny_llama/model.pt")
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

# Persist only the parameter tensors (the state_dict), not the module object.
torch.save(m.state_dict(), checkpoint_path)