In [1]:
from torchtune.models.llama2 import llama2
from torchtune.modules import TransformerDecoder
from torchtune.utils.checkpoint import save_checkpoint, load_checkpoint

In [2]:
def llama2_tiny(max_batch_size = None) -> TransformerDecoder:
    """
    Builder for creating a Llama2 model initialized w/ the default 7b parameter values
    from https://arxiv.org/abs/2307.09288

    Args:
        max_batch_size (Optional[int]): Maximum batch size to be passed to KVCache.

    Returns:
        TransformerDecoder: Instantiation of Llama2 7B model
    """
    return llama2(
        vocab_size=32_000,
        num_layers=2,
        num_heads=4,
        num_kv_heads=4,
        embed_dim=128,
        max_seq_len=4096,
        max_batch_size=max_batch_size,
        attn_dropout=0.0,
        norm_eps=1e-5,
    )

In [3]:
m = llama2_tiny()

In [4]:
chpt_dict = {"model": m}


In [8]:
import torch
torch.save(m.state_dict(), "tiny_llama/model.pt")

In [5]:
save_checkpoint(chpt_dict, 'tiny_llama/model.pth')

In [6]:
load_checkpoint('tiny_llama/model.pth', m)

{'model': OrderedDict([('tok_embeddings.weight',
               tensor([[ 0.6721, -0.0063,  0.2206,  ..., -0.4694,  1.8257,  2.1110],
                       [-0.3021,  0.1397,  0.0167,  ...,  1.3691, -0.1057,  0.4563],
                       [ 0.2236,  2.0759,  3.3434,  ...,  0.3655,  0.8593,  0.4055],
                       ...,
                       [ 2.0139, -0.4313, -0.6845,  ..., -0.1600,  0.2825, -0.6131],
                       [ 0.1583,  0.9916,  1.0918,  ..., -2.1604, -1.0594,  0.0648],
                       [-1.4329, -0.8049, -0.5369,  ..., -1.2593,  2.4386,  1.2828]])),
              ('layers.0.sa_norm.scale',
               tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1