In [1]:
from typing import Optional
import torch
torch.set_default_dtype(torch.bfloat16)
torch.set_default_device('cpu')
import time
from pathlib import Path
import json
from sentencepiece import SentencePieceProcessor
from tqdm import tqdm
from dataclasses import dataclass
from typing import Optional
import math

from model import ModelArgs, Transformer

### The state dict contains the following weights:

#### Input
* tok_embeddings (vocab_size, embedding_dim) = (32000,4096)
### Layer 0-31
#### Attention
* attention_norm (embedding_dim) = (4096)
* attention.wq (embedding_dim, embedding_dim) = (4096, 4096)
* attention.wk (embedding_dim, embedding_dim) = (4096, 4096)
* attention.wv (embedding_dim, embedding_dim) = (4096, 4096)
* attention.wo (embedding_dim, embedding_dim) = (4096, 4096)
#### FeedForward
* ffn_norm (embedding_dim) = (4096)
* feed_forward.w1 (embedding_dim, hidden_dim) = (4096, 11008)
* feed_forward.w3 (embedding_dim, hidden_dim) = (4096, 11008)

(w1 and w3 are both applied to the input embeddings, and their outputs are then multiplied element-wise)
* feed_forward.w2 (hidden_dim, embedding_dim) = (11008, 4096)



## Output
* norm (embedding_dim) = (4096)
* output (embedding_dim, vocab_size) = (4096, 32000)

In [2]:
def load_llama(checkpoints_dir: str, vocab_size: int, max_seq_len: int):
    """Build a ``Transformer`` and load Llama checkpoint weights into it.

    ``checkpoints_dir`` must contain at least one ``*.pth`` file and a
    ``params.json``. Only the first ``*.pth`` file (in sorted order) is loaded;
    multi-shard checkpoints are not merged.

    The architecture comes from the default ``ModelArgs``. The asserts below
    check that those defaults agree with ``params.json`` and with the
    tokenizer's vocabulary size before any weights are loaded.

    Args:
        checkpoints_dir: directory holding the ``*.pth`` checkpoint and ``params.json``.
        vocab_size: vocabulary size of the tokenizer the model will be used with.
        max_seq_len: maximum sequence length to configure the model with.

    Returns:
        The ``Transformer`` with the checkpoint loaded via ``load_state_dict(strict=True)``.

    Raises:
        AssertionError: if no checkpoint is found or ``ModelArgs`` disagrees with
            ``params.json`` / ``vocab_size``.
    """
    start_time = time.time()
    ckpt_dir = Path(checkpoints_dir)
    checkpoints = sorted(ckpt_dir.glob("*.pth"))
    assert len(checkpoints) > 0, f"no checkpoint files found in {checkpoints_dir}"
    ckpt_path = checkpoints[0]
    print(f'Loading checkpoint "{ckpt_path}"')
    # weights_only=True restricts unpickling to tensors and plain containers, so a
    # tampered .pth file cannot run arbitrary code while being loaded.
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    print(f"Loaded checkpoint in {time.time() - start_time:.2f}s")

    with open(ckpt_dir / "params.json", "r") as f:
        params = json.load(f)
    print(f"params: {params}")

    model_args = ModelArgs()
    model_args.max_seq_len = max_seq_len

    # ModelArgs' hard-coded defaults define the architecture; make sure they
    # describe the same model as the checkpoint.
    for key in ("dim", "n_layers", "n_heads", "norm_eps"):
        expected = params[key]
        actual = getattr(model_args, key)
        assert actual == expected, (
            f"ModelArgs.{key}={actual} does not match params.json {key}={expected}"
        )
    # llama-2-7b's params.json stores vocab_size=-1, so compare against the tokenizer.
    assert model_args.vocab_size == vocab_size, (
        f"ModelArgs.vocab_size={model_args.vocab_size} does not match tokenizer vocab_size={vocab_size}"
    )

    print(f"model_args: {model_args}")
    model = Transformer(model_args)

    # The original code removes 'rope.freqs' before the strict load, presumably
    # because this Transformer has no matching entry. pop() also tolerates
    # checkpoints that do not ship the key at all.
    checkpoint.pop("rope.freqs", None)
    model.load_state_dict(checkpoint, strict=True)
    print(f"Loaded model in {time.time() - start_time:.2f}s")
    return model

def load_tokenizer(tokenizer_path: str):
    """Load a SentencePiece tokenizer from a ``.model`` file.

    Args:
        tokenizer_path: path to the SentencePiece model file, as ``str`` or
            ``pathlib.Path``. It is converted to ``str`` because SentencePiece's
            loader only accepts string paths.

    Returns:
        The loaded ``SentencePieceProcessor``.
    """
    tokenizer = SentencePieceProcessor()
    tokenizer.load(str(tokenizer_path))
    return tokenizer

In [3]:
# The tokenizer supplies the vocabulary size (params.json stores vocab_size=-1).
tokenizer = load_tokenizer("tokenizer.model")
llama = load_llama(
    "llama-2-7b",
    tokenizer.vocab_size(),
    max_seq_len=1024,
)

Loading checkpoint "llama-2-7b/consolidated.00.pth"
Loaded checkpoint in 11.88s
params: {'dim': 4096, 'multiple_of': 256, 'n_heads': 32, 'n_layers': 32, 'norm_eps': 1e-05, 'vocab_size': -1}
model_args: ModelArgs(dim=4096, n_layers=32, n_heads=32, head_dim=128, hidden_dim=11008, vocab_size=32000, norm_eps=1e-05, max_seq_len=1024)
Loaded model in 46.14s


In [5]:
# Peek at the first decoder layer's feed-forward w1 weight.
layer0_w1 = llama.layers[0].feed_forward.w1.weight
print(layer0_w1)

Parameter containing:
tensor([[ 1.5747e-02,  1.7090e-02,  3.1494e-02,  ..., -1.5869e-02,
          6.5002e-03,  1.5869e-02],
        [-2.1667e-03, -6.0120e-03,  5.6458e-03,  ...,  1.6113e-02,
         -8.6670e-03,  9.8877e-03],
        [ 6.8359e-03, -2.1606e-02,  2.0508e-02,  ..., -1.3000e-02,
          1.8921e-02,  1.9409e-02],
        ...,
        [ 1.4126e-05, -3.2227e-02,  5.7983e-03,  ..., -8.9111e-03,
         -1.3489e-02,  4.0283e-02],
        [ 2.6611e-02,  2.0142e-02, -1.7090e-02,  ..., -3.4332e-03,
         -6.4087e-03, -1.8921e-02],
        [-5.9891e-04, -1.1353e-02, -2.3682e-02,  ...,  1.1063e-03,
          5.9204e-03, -2.4780e-02]], requires_grad=True)


In [4]:
# One forward step: token id 3439 (the first token of the "Simply put, ..." prompt) at position 0.
probe_token = 3439
llama(probe_token, 0)

embedding: tensor([-0.0156,  0.0029,  0.0171,  ..., -0.0131,  0.0103,  0.0302],
       grad_fn=<SliceBackward0>)
scores: tensor([[[1.]],

        [[1.]],

        [[1.]],

        [[1.]],

        [[1.]],

        [[1.]],

        [[1.]],

        [[1.]],

        [[1.]],

        [[1.]],

        [[1.]],

        [[1.]],

        [[1.]],

        [[1.]],

        [[1.]],

        [[1.]],

        [[1.]],

        [[1.]],

        [[1.]],

        [[1.]],

        [[1.]],

        [[1.]],

        [[1.]],

        [[1.]],

        [[1.]],

        [[1.]],

        [[1.]],

        [[1.]],

        [[1.]],

        [[1.]],

        [[1.]],

        [[1.]]], grad_fn=<SoftmaxBackward0>)
values: tensor([[[-3.4943e-03, -1.4343e-03,  9.7656e-04,  ..., -6.9427e-04,
           1.0681e-03, -1.5335e-03]],

        [[-1.1444e-03, -2.4319e-05, -6.0797e-05,  ...,  1.7853e-03,
           9.3384e-03, -2.9755e-03]],

        [[-1.2695e-02, -3.9368e-03, -3.7994e-03,  ..., -8.3618e-03,
           2.4128

tensor([-9.8750, -9.6875, -1.9219,  ..., -6.3125, -6.8125, -6.1562],
       dtype=torch.float32, grad_fn=<ToCopyBackward0>)

In [3]:
# Encode the prompt, then decode a 50-token sequence that starts with that
# prompt (presumably the output of a max_toks=50 generation run).
tokenizer = load_tokenizer("tokenizer.model")
encoded = tokenizer.encode("Simply put, the theory of relativity states that ")
print(encoded)

toks = [
    3439, 17632, 1925, 29892, 278, 6368, 310, 14215, 537, 5922,
    393, 29871, 29896, 29897, 278, 6210, 310, 3578, 338, 4868,
    297, 599, 297, 814, 616, 3407, 16608, 29892, 322, 29871,
    29906, 29897, 278, 14243, 310, 17558, 526, 278, 1021, 363,
    599, 5366, 874, 29892, 17126, 310, 1009, 6198, 10884, 470,
]
print(tokenizer.decode(toks))

[3439, 17632, 1925, 29892, 278, 6368, 310, 14215, 537, 5922, 393, 29871]
Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial reference frames, and 


In [7]:
def generate(model: "Transformer", tokenizer: "SentencePieceProcessor", promt: str, max_toks: int = 100):
    """Greedily continue ``promt`` and return the decoded text, prompt included.

    The model is driven one token at a time as ``model(token_id, position)``.
    The logits returned for position ``i`` choose the token at position ``i + 1``.
    Generation stops at the tokenizer's EOS token, or once the sequence holds
    ``max_toks`` tokens in total (prompt tokens count toward this budget).

    Args:
        model: model to run; it is switched to eval mode.
        tokenizer: SentencePiece tokenizer used to encode the prompt and decode the result.
        promt: prompt text (parameter name kept for backward compatibility).
        max_toks: maximum total number of tokens, prompt included.

    Returns:
        The decoded prompt plus the generated continuation (EOS is not included).

    Raises:
        ValueError: if the prompt encodes to no tokens.
    """
    model.eval()
    # NOTE(review): no BOS token is prepended here. Meta's reference Llama
    # generation encodes prompts with a leading BOS (tokenizer.bos_id()) — confirm.
    prompt_tokens = tokenizer.encode(promt)
    if not prompt_tokens:
        raise ValueError("prompt must encode to at least one token")
    eos_id = tokenizer.eos_id()
    output = list(prompt_tokens)
    with torch.no_grad():
        # Feed the prompt token by token. The model presumably keeps per-position
        # state (e.g. a KV cache — TODO confirm), so each token goes at its own index.
        for pos, token in enumerate(tqdm(prompt_tokens, desc="feeding prompt")):
            logits = model(token, pos)

        for step in tqdm(range(max_toks - len(output)), desc="generating"):
            if step > 0:
                # Feed the token chosen in the previous step at the index it
                # occupies in `output`. On the first step, the logits of the
                # last prompt token already predict the next token, so
                # re-feeding that token would duplicate it.
                logits = model(output[-1], len(output) - 1)
            # softmax is monotonic, so argmax over raw logits picks the same token.
            next_token = torch.argmax(logits, dim=-1).item()
            if next_token == eos_id:
                break
            output.append(next_token)
    return tokenizer.decode(output)


In [4]:
prompt = "Simply put, the theory of relativity states that "
out = generate(llama, tokenizer, prompt, max_toks=50)
print(out)

NameError: name 'generate' is not defined

In [3]:
import struct


def serialize_fp32(file, tensor):
    """Write ``tensor`` to ``file`` as a flat run of float32 values.

    Values are written in row-major (C) order using the machine's native byte
    order. These are the same bytes ``struct.pack(f"{n}f", *values)`` would produce.

    Args:
        file: binary file-like object opened for writing (e.g. ``open(path, "wb")``).
        tensor: torch tensor of any dtype and shape. It is detached, moved to CPU,
            flattened and cast to float32 before writing.
    """
    # reshape(-1) also handles non-contiguous tensors, where view(-1) raises.
    flat = tensor.detach().cpu().reshape(-1).to(torch.float32).numpy()
    # tobytes() dumps the native-endian float32 buffer in one call. This avoids
    # struct.pack(*values), which turns every element into a Python argument and
    # is prohibitively slow and memory-hungry for multi-million-element weights.
    file.write(flat.tobytes())


def serialize(filename: str, model: "Transformer"):
    """Write ``model``'s weights to ``filename`` in this notebook's binary format.

    Layout (all values in native byte order and size):
      * magic ``0x7fdd7f7f`` (struct ``'I'``) followed by format version ``1`` (struct ``'i'``);
      * six struct ``'i'`` header fields: dim, hidden_dim, n_heads, n_layers,
        vocab_size, max_seq_len (read from ``model.params``);
      * zero padding up to byte offset 256;
      * each weight tensor as flat float32 (via ``serialize_fp32``), grouped by
        kind across all layers: attention norms, FFN norms, final norm, token
        embeddings, wq, wk, wv, wo, w1, w2, w3, then the output projection.

    Args:
        filename: destination path; the file is created or overwritten.
        model: model whose ``params`` and weights are exported.
    """
    version = 1
    magic = 0x7fdd7f7f
    p = model.params
    layers = model.layers
    weights = [
        *[layer.attention_norm.weight for layer in layers],
        *[layer.ffn_norm.weight for layer in layers],
        model.norm.weight,
        model.tok_embeddings.weight,
        *[layer.attention.wq.weight for layer in layers],
        *[layer.attention.wk.weight for layer in layers],
        *[layer.attention.wv.weight for layer in layers],
        *[layer.attention.wo.weight for layer in layers],
        *[layer.feed_forward.w1.weight for layer in layers],
        *[layer.feed_forward.w2.weight for layer in layers],
        *[layer.feed_forward.w3.weight for layer in layers],
        model.output.weight,
    ]
    # `with` guarantees the file is closed even if a write fails midway.
    with open(filename, "wb") as out_file:
        out_file.write(struct.pack("I", magic))
        out_file.write(struct.pack("i", version))
        header = struct.pack("iiiiii", p.dim, p.hidden_dim, p.n_heads, p.n_layers, p.vocab_size, p.max_seq_len)
        out_file.write(header)
        # Weights start at a fixed 256-byte offset.
        pad = 256 - out_file.tell()
        assert pad >= 0
        out_file.write(b"\0" * pad)
        for w in weights:
            serialize_fp32(out_file, w)


In [2]:
# Model built from the default ModelArgs; no checkpoint weights are loaded into it.
default_args = ModelArgs()
model = Transformer(default_args)

In [3]:
# Smoke-test a single forward step (token id 1, position 1) on the unloaded model.
step_out = model(1, 1)
print(step_out)

torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([

In [5]:
# NOTE(review): `model` is the Transformer built from default ModelArgs without
# loading a checkpoint. To export the loaded Llama weights, pass `llama` instead
# — confirm intent.
serialize(filename="llama.bin", model=model)