In [4]:
from typing import Optional
import torch
torch.set_default_dtype(torch.bfloat16)
torch.set_default_device('cpu')
import time
from pathlib import Path
import json
from sentencepiece import SentencePieceProcessor
from tqdm import tqdm
from dataclasses import dataclass
from typing import Optional
import math

from model import ModelArgs, Transformer

### state dict contains the following weights:

#### Input
* tok_embeddings (vocab_size, embedding_dim) = (32000,4096)
### Layer 0-31
#### Attention
* attention_norm (embedding_dim) = (4096)
* attention.wq (embedding_dim, embedding_dim) = (4096, 4096)
* attention.wk (embedding_dim, embedding_dim) = (4096, 4096)
* attention.wv (embedding_dim, embedding_dim) = (4096, 4096)
* attention.w0 (embedding_dim, embedding_dim) = (4096, 4096)
#### FeedFordward
* feed_forward.norm (embedding_dim) = (4096)
* feed_forward.w1 (embedding_dim, hidden_dim) = (4096, 11008)
* feed_forward.w3 (embedding_dim, hidden_dim) = (4096, 11008)

(w1 and w3 get both applied to the input embeddings and then element wise multiplied)
* feed_forward.w2 (hidden_dim, embedding_dim) = (11008, 4096)



## Output
* norm (embedding_dim) = (4096)
* output (embedding_dim, vocab_size) = (4096, 32000)

In [5]:
def load_llama(checkpoints_dir: str, vocab_size: int, max_seq_len: int):
    prev_time = time.time()
    checkpoints = sorted(Path(checkpoints_dir).glob("*.pth"))
    assert len(checkpoints) > 0, f"no checkpoint files found in {checkpoints_dir}"
    ckpt_path = checkpoints[0]
    print(f'Loading checkpoint "{ckpt_path}"')
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    print(f"Loaded checkpoint in {time.time() - prev_time:.2f}s")

    with open(Path(checkpoints_dir) / "params.json", "r") as f:
        params = json.loads(f.read())
        print(f"params: {params}")

    model_args = ModelArgs()
    model_args.max_seq_len = max_seq_len

    assert(model_args.dim == params['dim'])
    assert(model_args.n_layers == params['n_layers'])
    assert(model_args.vocab_size == vocab_size)
    assert(model_args.n_heads == params['n_heads'])
    assert(model_args.n_layers == params['n_layers'])

    model_args.vocab_size = vocab_size
    print(f"model_args: {model_args}")
    model = Transformer(model_args)

    del checkpoint['rope.freqs']
    model.load_state_dict(checkpoint, strict=True)
    print(f"Loaded model in {time.time() - prev_time:.2f}s")
    return model

def load_tokenizer(tokenizer_path: str):
    tokenizer = SentencePieceProcessor()
    tokenizer.load(tokenizer_path)
    return tokenizer

In [6]:
tokenizer = load_tokenizer("tokenizer.model")
llama = load_llama("llama-2-7b", tokenizer.vocab_size(), max_seq_len=1024)

Loading checkpoint "llama-2-7b/consolidated.00.pth"
Loaded checkpoint in 10.01s
params: {'dim': 4096, 'multiple_of': 256, 'n_heads': 32, 'n_layers': 32, 'norm_eps': 1e-05, 'vocab_size': -1}
model_args: ModelArgs(dim=4096, n_layers=32, n_heads=32, head_dim=128, hidden_dim=11008, vocab_size=32000, norm_eps=1e-05, max_seq_len=1024)
Loaded model in 43.96s


In [7]:
def generate(model: Transformer, tokenizer: SentencePieceProcessor, promt: str, max_toks: int = 100):
    model.eval()
    with torch.no_grad():
        input = tokenizer.encode(promt)
        output = []
        # feed the entire prompt as context
        for token in tqdm(input, desc="feeding prompt"):
            out = model(token, len(output))
            output.append(token)

        # generate the rest of the tokens
        for _ in tqdm(range(max_toks - len(output)), desc="generating"):
            out = model(output[-1], len(output))
            probs = torch.softmax(out, dim=-1)
            next_token = torch.argmax(probs, dim=-1).item()
            if (next_token == tokenizer.eos_id()):
                break
            output.append(next_token)
    return tokenizer.decode(output)


In [8]:
out = generate(llama, tokenizer, "Simply put, the theory of relativity states that ", max_toks=50)
print(out)

feeding prompt: 100%|██████████| 12/12 [00:35<00:00,  2.93s/it]
generating: 100%|██████████| 38/38 [01:52<00:00,  2.97s/it]

Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, regardless of their relative motion, and 2) the laws of physics are the same for all observers, regardless of



