# Transformer Teardown: Llama 3

> Look mom no hands!

# Setup

In [1]:
from functools import partial
import json
import math
import os
from pathlib import Path
import warnings

from matplotlib import pyplot as plt
import seaborn as sns
import numpy as np
from pandas import Series
from pytest import approx

import torch
from torch import nn
from torch.nn.functional import relu, silu, softmax

from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import RMSNorm

import stickshift as ss
from stickshift import default_arg, take
from stickshift.models import llama

In [2]:
# Ignore all warnings
warnings.filterwarnings("ignore")

# Configure gpu
device = ss.torch.device()

In [3]:
%%html
<style>
.stickshift-figure {
    display: block;
    margin-left: auto !important;
    margin-right: auto !important;
}
</style>

# Text Generation with Llama 3

Our goal is to build a Llama 3 text generation pipeline using the pre-trained weights and Meta's reference implementation as a guide. By the time we're done, our model should be able to correctly answer the question "What is the capital of Massachusetts?"

In [4]:
question = "What is the capital of Massachusetts?"

# Load Checkpoint

We'll start by loading the hyperparameters and pre-trained weights published by Meta.

In [5]:
config, checkpoint = llama.load_checkpoint("Meta-Llama3.1-8B-Instruct", device=device)

In [6]:
config.model_dump()

{'checkpoint_path': PosixPath('/Users/andrewyoung/.llama/checkpoints/Meta-Llama3.1-8B-Instruct'),
 'vocab_size': 128256,
 'd_model': 4096,
 'd_head': 128,
 'd_ffn': 14336,
 'n_layers': 32,
 'n_heads': 32,
 'n_kv_heads': 8,
 'n_kv_groups': 4,
 'rms_norm_eps': 1e-05,
 'rope_theta': 500000.0}

In [7]:
def load_pretrained_state(layer):
    """Load the checkpoint weights for one transformer layer into the shared layer modules.

    The notebook builds a single set of per-layer modules (the two RMSNorms, the
    q/k/v and output projections, and the three FFN projections) and reuses them
    for every layer. Calling this with a layer index loads that layer's
    pre-trained state into those same module objects via `llama.load_state`.

    Args:
        layer: Index of the transformer layer to load (0 .. config.n_layers - 1).
    """
    # (module, checkpoint key) pairs for every module that belongs to a layer
    layer_modules = [
        (normalize_attention, "normalize_attention"),
        (normalize_ffn, "normalize_ffn"),
        (w_q, "w_q"),
        (w_k, "w_k"),
        (w_v, "w_v"),
        (attention_outputs, "attention_outputs"),
        (ffn_gates, "ffn_gates"),
        (ffn_inputs, "ffn_inputs"),
        (ffn_outputs, "ffn_outputs"),
    ]

    # llama.load_state takes its targets as alternating module, key, module, key, ...
    flat_args = [item for pair in layer_modules for item in pair]

    llama.load_state(*flat_args, checkpoint=checkpoint, layer=layer)

In [8]:
# [k for k in checkpoint]

# Transformer Pipeline

<img src="transformer-pipeline.svg" class="stickshift-figure" width="800">

# Tokenize

The Tokenize stage transforms raw data into a sequence of "tokens".

In [9]:
tokenizer = Tokenizer(str(config.checkpoint_path / "tokenizer.model"))

In [10]:
token_ids = tokenizer.encode(question, bos=True, eos=False)
token_ids

[128000, 3923, 374, 279, 6864, 315, 22108, 30]

In [11]:
tokenizer.decode(token_ids)

'<|begin_of_text|>What is the capital of Massachusetts?'

# Embeddings

The Embeddings stage converts tokens into "embeddings".

In [12]:
# Initialize embeddings lookup table
embeddings = nn.Embedding(
    num_embeddings=config.vocab_size, 
    embedding_dim=config.d_model,
    device=device,
)

# Load pre-trained weights
llama.load_state(embeddings, "embeddings", checkpoint=checkpoint)

In [13]:
# Load token_ids into a tensor
token_values = torch.tensor(token_ids, device=device)

token_values.shape

torch.Size([8])

In [14]:
# Record sequence length
n = len(token_values)

n

8

In [15]:
# Map token values to embeddings
x = embeddings(token_values)

x.shape

torch.Size([8, 4096])

In [16]:
# Show sample
x

tensor([[ 2.6512e-04, -4.9973e-04, -5.8365e-04,  ...,  3.8147e-03,
          6.3419e-05,  1.1902e-03],
        [ 2.0752e-02, -1.2894e-03,  2.8229e-03,  ...,  2.1973e-02,
          3.1128e-03,  1.0681e-02],
        [-2.6093e-03,  7.7057e-04,  2.6131e-04,  ...,  1.1902e-02,
          4.6387e-03,  9.1553e-03],
        ...,
        [ 1.2817e-03,  9.1171e-04,  2.0905e-03,  ...,  1.6251e-03,
          4.0894e-03, -4.0283e-03],
        [ 1.2146e-02,  1.1597e-02,  1.7822e-02,  ...,  1.9684e-03,
         -1.4771e-02, -2.5940e-03],
        [-4.8523e-03, -1.8005e-03,  7.2937e-03,  ...,  2.3956e-03,
         -1.3657e-03, -5.4932e-03]], device='mps:0',
       grad_fn=<EmbeddingBackward0>)

# Context Layers

The Context Layers in a Transformer are responsible for infusing each token embedding with contextual signals drawn from the rest of the sequence. The mechanism works by passing the token embeddings through multiple layers of attention and feedforward blocks. The attention blocks focus on relationships between tokens, augmenting each embedding with contextual information drawn from the surrounding embeddings. The feedforward blocks capitalize on the extra context, transforming each augmented embedding using a fully connected neural network. This pattern of attention and transformation is repeated over and over again, gradually converting representations of individual words into representations of abstract semantic concepts over a series of small increments.

<img src="transformer-layers.svg" class="stickshift-figure" width="800">

## Attention

The attention block augments each token representation with additional context drawn from the surrounding tokens.

<img src="attention.svg" class="stickshift-figure" width="500">

Attention starts with the token embeddings stacked as the rows of an $n \times d_{model}$ matrix $\mathbf{X}$. For each embedding $\mathbf{x}_i \in \mathbf{X}$, we'll generate a new attention embedding $\mathbf{a}_i$ using a weighted combination of all embeddings in $\mathbf{X}$. The equation below emphasizes the fact that the weight for each embedding $\mathbf{x}_j$ is calculated as a function $f_{w}$ of the embedding values $\mathbf{x}_i$, $\mathbf{x}_j$ and their positions $i$, $j$.

$$
\mathbf{A} = \Set{\mathbf{a}_i | \mathbf{a}_i = \sum_{\mathbf{x}_j \in \mathbf{X}} f_{w}(\mathbf{x}_i, \mathbf{x}_j, i, j) \mathbf{x}_j}
$$

The attention function used by Llama 3 can be rewritten as

$$
\begin{align}
\mathbf{A} &= softmax\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d}} + \mathbf{M}\right)\mathbf{V} \\
\end{align}
$$

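where $d = d_{head}$ is the size of each attention head and $\mathbf{M}$ is the causal mask ($0$ on and below the diagonal, $-\infty$ above it). Writing $\mathbf{R}_{\Theta}^d$ for the position-dependent RoPE rotation applied to each row, this expands to

$$
\begin{align}
\mathbf{A} &= softmax\left(\frac{\mathbf{R}_{\Theta}^d(\mathbf{X}\mathbf{W}_Q)\left(\mathbf{R}_{\Theta}^d(\mathbf{X}\mathbf{W}_K)\right)^T}{\sqrt{d}} + \mathbf{M}\right)\mathbf{X}\mathbf{W}_V \\
\end{align}
$$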


### Normalize Inputs

In [17]:
# Configure attention normalization
normalize_attention = RMSNorm(config.d_model, config.rms_norm_eps).to(device)

# Load pre-trained weights
llama.load_state(normalize_attention, "normalize_attention", checkpoint=checkpoint)

In [18]:
# Normalize attention inputs
residual = x
x = normalize_attention(x)

x.shape

torch.Size([8, 4096])

### Project Queries, Keys, Values

In [19]:
# Configure query, key, value projections
w_q = nn.Linear(
    in_features=config.d_model,
    out_features=config.n_heads * config.d_head,
    bias=False,
    device=device,
)
w_k = nn.Linear(
    in_features=config.d_model,
    out_features=config.n_kv_heads * config.d_head,
    bias=False,
    device=device,
)
w_v = nn.Linear(
    in_features=config.d_model,
    out_features=config.n_kv_heads * config.d_head,
    bias=False,
    device=device,
)

# Load pre-trained weights
llama.load_state(w_q, "w_q", w_k, "w_k", w_v, "w_v", checkpoint=checkpoint)

In [20]:
# Project embeddings to query, key, value spaces
q = w_q(x)
k = w_k(x)
v = w_v(x)

q.shape, k.shape, v.shape

(torch.Size([8, 4096]), torch.Size([8, 1024]), torch.Size([8, 1024]))

### Split Attention Heads

In [21]:
def split_heads(x, n_heads):
    """Split the feature axis of `x` into `n_heads` attention heads.

    Maps ``(..., n, n_heads * d_head)`` to ``(..., n_heads, n, d_head)``. The head
    size is derived from ``x`` itself instead of the global ``config``, so the helper
    works for both queries (``config.n_heads``) and keys/values (``config.n_kv_heads``)
    and keeps any leading batch axes intact. A 1-D input is treated as a single
    token (``n = 1``).

    Args:
        x: Tensor whose last axis has ``n_heads * d_head`` features and whose
            second-to-last axis (if present) indexes sequence positions.
        n_heads: Number of heads to split into.

    Returns:
        A view of ``x`` shaped ``(..., n_heads, n, d_head)``.

    Raises:
        RuntimeError: (from ``Tensor.view``) if a non-empty feature axis does not
            split evenly into ``n_heads`` heads.
    """
    if x.dim() == 1:
        # A lone feature vector: give it a sequence axis of length 1
        x = x.unsqueeze(0)

    # Size of each head; a remainder makes the element count mismatch and view() raise
    d_head = x.shape[-1] // n_heads

    # (..., n, n_heads * d_head) -> (..., n, n_heads, d_head) -> (..., n_heads, n, d_head)
    return x.view(*x.shape[:-1], n_heads, d_head).transpose(-3, -2)

In [22]:
# Split attention heads
q = split_heads(q, config.n_heads)
k = split_heads(k, config.n_kv_heads)
v = split_heads(v, config.n_kv_heads)

q.shape, k.shape, v.shape

(torch.Size([32, 8, 128]), torch.Size([8, 8, 128]), torch.Size([8, 8, 128]))

### Encode Positions (RoPE)

Given $\Theta = \text{base}$, $m = \text{position}$ and $d = d_{head}$, applying the RoPE rotation matrix $\mathbf{R}_{\Theta,m}^d$ to a $d$-dimensional query or key vector $\mathbf{x}$ can be calculated element-wise ($\odot$) as:

$$
\mathbf{R}_{\Theta,m}^d \mathbf{x} = 
\begin{bmatrix}
x_0 \\
x_1 \\
x_2 \\
x_3 \\
\vdots \\
x_{d-2} \\
x_{d-1} \\
\end{bmatrix}
\odot
\begin{bmatrix}
\cos(m \theta_0) \\
\cos(m \theta_0) \\
\cos(m \theta_1) \\
\cos(m \theta_1) \\
\vdots \\
\cos(m \theta_{d/2-1}) \\
\cos(m \theta_{d/2-1}) \\
\end{bmatrix}
+
\begin{bmatrix}
-x_1 \\
x_0 \\
-x_3 \\
x_2 \\
\vdots \\
-x_{d-1} \\
x_{d-2} \\
\end{bmatrix}
\odot
\begin{bmatrix}
\sin(m \theta_0) \\
\sin(m \theta_0) \\
\sin(m \theta_1) \\
\sin(m \theta_1) \\
\vdots \\
\sin(m \theta_{d/2-1}) \\
\sin(m \theta_{d/2-1}) \\
\end{bmatrix}
$$

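where

$$
\theta_i = \frac{1}{\Theta^{2i/d}}, \quad i = 0, \dots, d/2 - 1
$$

The code below stores the angles in the "rotate half" layout rather than the interleaved layout shown above. `rope_cos` and `rope_sin` hold $[\theta_0, \dots, \theta_{d/2-1}, \theta_0, \dots, \theta_{d/2-1}]$, scaled by $m$, instead of repeating each angle twice in a row. As a result, each $x_i$ is paired with $x_{i+d/2}$ rather than with its neighbor. This assumes `llama.rotate_half` returns $[-x_{d/2}, \dots, -x_{d-1}, x_0, \dots, x_{d/2-1}]$. The two layouts are equivalent up to a fixed permutation of the query/key feature dimensions.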

In [23]:
# Compute rope_cos and rope_sin: cos/sin of (m * theta) for every position m and head slot
base = config.rope_theta
d = config.d_head

# Compute theta_i = 1 / base^(2i/d) for i = 0 to d/2-1 (one rotation speed per feature pair)
thetas = 1.0 / base**(2 * torch.arange(d // 2, device=device) / d)

# Compute m * theta_i for every position m in 0 to n-1 -> shape (n, d/2).
# torch.outer replaces the per-position Python loop and also handles n == 0.
positions = torch.arange(n, device=device, dtype=thetas.dtype)
frequencies = torch.outer(positions, thetas)

# Repeat the d/2 angles along the feature axis -> shape (n, d), laid out as
# [theta_0 .. theta_{d/2-1}, theta_0 .. theta_{d/2-1}] (the "rotate half" layout,
# presumably what llama.rotate_half expects -- it is not shown here)
frequencies = torch.cat((frequencies, frequencies), dim=-1)

# Apply cos, sin
rope_cos = torch.cos(frequencies)
rope_sin = torch.sin(frequencies)

rope_cos.shape, rope_sin.shape

(torch.Size([8, 128]), torch.Size([8, 128]))

In [24]:
# Sanity check
assert rope_cos.shape[0] == n and rope_cos.shape[1] == config.d_head
assert rope_sin.shape[0] == n and rope_sin.shape[1] == config.d_head

In [25]:
# Encode positions by rotating queries and keys.
# rope_cos / rope_sin are (n, d_head) and broadcast across the leading head axis of
# q (n_heads, n, d_head) and k (n_kv_heads, n, d_head); row m rotates token m.
# v is left unrotated -- only q and k feed the position-dependent attention scores.
q = (q * rope_cos) + (llama.rotate_half(q) * rope_sin)
k = (k * rope_cos) + (llama.rotate_half(k) * rope_sin)

q.shape, k.shape, v.shape

(torch.Size([32, 8, 128]), torch.Size([8, 8, 128]), torch.Size([8, 8, 128]))

### Expand Key / Value Groups (GQA)

In [26]:
# Expand key/value groups
k = k.repeat_interleave(config.n_kv_groups, dim=0)
v = v.repeat_interleave(config.n_kv_groups, dim=0)

q.shape, k.shape, v.shape

(torch.Size([32, 8, 128]), torch.Size([32, 8, 128]), torch.Size([32, 8, 128]))

In [27]:
# Sanity check
assert q.shape == k.shape == v.shape

### Calculate Attention

$$
\begin{align}
\mathbf{A} &= softmax\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d}} + \mathbf{M}\right)\mathbf{V} \\
\end{align}
$$

In [28]:
# Compute masked attention bias M.
# mask[i, j] is True when j <= i (lower triangle including the diagonal), i.e. token i
# may attend to token j. The additive bias m is 0 there and -inf above the diagonal,
# so after softmax each token gives zero weight to later (future) tokens.
mask = torch.ones(n, n, dtype=torch.bool, device=device).tril(diagonal=0)
m = torch.zeros(n, n, device=device).masked_fill_(mask.logical_not(), float("-inf"))

m

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]], device='mps:0')

In [29]:
# Compute attention for all heads in parallel:
#   q @ k^T          -> (n_heads, n, n) raw scores, one n x n matrix per head
#   / sqrt(d_head)   -> scale the scores by the head size
#   + m              -> causal bias, broadcast across heads
#   softmax(dim=-1)  -> each query's weights over the keys sum to 1
#   @ v              -> (n_heads, n, d_head) weighted sums of the values
a = softmax(q @ k.transpose(-2, -1) / np.sqrt(config.d_head) + m, dim=-1) @ v

a.shape

torch.Size([32, 8, 128])

### Recombine Attention Heads

In [30]:
def combine_heads(x):
    """Merge attention heads back into a single feature axis (inverse of ``split_heads``).

    Maps ``(..., n_heads, n, d_head)`` to ``(..., n, n_heads * d_head)``. Sizes come
    from ``x`` rather than the global ``config``, so any head count works and leading
    batch axes are preserved instead of being folded into the sequence axis.

    Args:
        x: Tensor whose last three axes are (heads, sequence positions, head features).

    Returns:
        Tensor shaped ``(..., n, n_heads * d_head)``.
    """
    # (..., n_heads, n, d_head) -> (..., n, n_heads, d_head)
    x = x.transpose(-3, -2)

    # Flatten (n_heads, d_head) into one axis. reshape copies only when the transposed
    # layout cannot be viewed; the explicit size (not -1) keeps n == 0 working.
    *leading, n_heads, d_head = x.shape
    return x.reshape(*leading, n_heads * d_head)

In [31]:
# Combine attention heads
a = combine_heads(a)

a.shape

torch.Size([8, 4096])

### Project Outputs

In [32]:
# Configure attention output projection
attention_outputs = nn.Linear(
    in_features=config.d_model, 
    out_features=config.d_model,
    bias=False,
    device=device,
)

# Load pre-trained weights
llama.load_state(attention_outputs, "attention_outputs", checkpoint=checkpoint)

In [33]:
# Project attention embeddings back to model space
a = attention_outputs(a)

a.shape

torch.Size([8, 4096])

### Combine w/ Residuals

In [34]:
# Combine attention embeddings with residuals
x = residual + a

x.shape

torch.Size([8, 4096])

## FFN

### Normalize Inputs

In [35]:
# Configure FFN normalization
normalize_ffn = RMSNorm(config.d_model, config.rms_norm_eps).to(device)

# Load pre-trained state
llama.load_state(normalize_ffn, "normalize_ffn", checkpoint=checkpoint)

In [36]:
# Normalize FFN inputs
residual = x
x = normalize_ffn(x)

x.shape

torch.Size([8, 4096])

### Transform

In [37]:
# Configure SwiGLU FFN
ffn_gates = nn.Linear(
    in_features=config.d_model,
    out_features=config.d_ffn,
    bias=False,
    device=device,
)
ffn_inputs = nn.Linear(
    in_features=config.d_model,
    out_features=config.d_ffn,
    bias=False,
    device=device,
)
ffn_outputs = nn.Linear(
    in_features=config.d_ffn,
    out_features=config.d_model,
    bias=False,
    device=device,
)

# Load pre-trained weights
llama.load_state(ffn_gates, "ffn_gates", ffn_inputs, "ffn_inputs", ffn_outputs, "ffn_outputs", checkpoint=checkpoint)

In [38]:
# Apply FFN
f = ffn_outputs(silu(ffn_gates(x)) * ffn_inputs(x))

f.shape

torch.Size([8, 4096])

### Combine w/ Residuals

In [39]:
# Combine FFN embeddings with residuals
x = residual + f

x.shape

torch.Size([8, 4096])

## Stacking the Layers

In [40]:
# Start over from initial token embeddings
x = embeddings(token_values)

x.shape

torch.Size([8, 4096])

In [41]:
# Apply layer logic in a loop.
#
# The causal mask depends only on the sequence length n, not on the layer, so it is
# built once here instead of being rebuilt on every iteration.
mask = torch.ones(n, n, dtype=torch.bool, device=device).tril(diagonal=0)
m = torch.zeros(n, n, device=device).masked_fill_(mask.logical_not(), float("-inf"))

# Forward pass only: skip autograd bookkeeping. The same layer modules are reloaded
# with different weights on every iteration and nothing here calls backward, so a
# graph recorded across all layers would only keep intermediate activations alive.
with torch.no_grad():
    for layer in range(config.n_layers):

        # Load pre-trained state for layer into the shared layer modules
        load_pretrained_state(layer)

        #
        # Attention
        #

        # Normalize attention inputs
        residual = x
        x = normalize_attention(x)

        # Project embeddings to query, key, value spaces
        q = w_q(x)
        k = w_k(x)
        v = w_v(x)

        # Split attention heads
        q = split_heads(q, config.n_heads)
        k = split_heads(k, config.n_kv_heads)
        v = split_heads(v, config.n_kv_heads)

        # Encode positions by rotating queries and keys
        q = (q * rope_cos) + (llama.rotate_half(q) * rope_sin)
        k = (k * rope_cos) + (llama.rotate_half(k) * rope_sin)

        # Expand key/value groups (GQA): each kv head is shared by n_kv_groups query heads
        k = k.repeat_interleave(config.n_kv_groups, dim=0)
        v = v.repeat_interleave(config.n_kv_groups, dim=0)

        # Sanity check
        assert q.shape == k.shape == v.shape

        # Compute attention for all heads in parallel (m is the hoisted causal bias)
        a = softmax(q @ k.transpose(-2, -1) / np.sqrt(config.d_head) + m, dim=-1) @ v

        # Combine attention heads
        a = combine_heads(a)

        # Project attention embeddings back to model space
        a = attention_outputs(a)

        # Combine attention embeddings with residuals
        x = residual + a

        #
        # FFN
        #

        # Normalize FFN inputs
        residual = x
        x = normalize_ffn(x)

        # Apply SwiGLU FFN
        f = ffn_outputs(silu(ffn_gates(x)) * ffn_inputs(x))

        # Combine FFN embeddings with residuals
        x = residual + f

# Head

## Normalize Inputs

In [42]:
# Configure head normalization
normalize_head = RMSNorm(config.d_model, config.rms_norm_eps).to(device)

# Load pre-trained weights
llama.load_state(normalize_head, "normalize_head", checkpoint=checkpoint)

In [43]:
# Normalize head inputs
x = normalize_head(x)

x.shape

torch.Size([8, 4096])

## Predict Next Token

In [44]:
# Configure output projection
head_outputs = nn.Linear(
    in_features=config.d_model,
    out_features=config.vocab_size,
    bias=False,
    device=device,
)

# Load pre-trained weights
llama.load_state(head_outputs, "head_outputs", checkpoint=checkpoint)

In [45]:
# Use last embedding to represent the entire sequence
features = x[-1]

features.shape

torch.Size([4096])

In [50]:
# Predict next token: score every vocabulary entry, then take the most likely one (greedy)
logits = head_outputs(features)

# .item() turns the 0-d index tensor into a plain Python int, matching the integer
# token ids the tokenizer works with (encode() above returned a list of ints)
token_id = logits.argmax().item()

tokenizer.decode([token_id])

' Boston'

In [51]:
# Verify the answer is Boston
assert token_id == 10406