# Transformer Teardown: Llama 3

In [the last Transformer Teardown](https://stickshift.github.io/2024/09/04/transformer-teardown.html), we dissected a DistilBERT text classification pipeline, tracing a single inference through the entire Transformer stack from raw data to final prediction. We learned about the main stages of a Transformer pipeline as well as fundamental Transformer concepts such as token embeddings and Multi-Head Self Attention.

Exploring BERT-based models is a fantastic way to see the core ideas from Vaswani et al.'s original Transformer paper in action. But BERT was released 6 years ago! It would be another 4 years before ChatGPT even existed! It's safe to say a lot has changed since then.

The goal of this post is to fast forward to present day. We'll use the same teardown process to unpack the latest [Llama 3.1](https://llama.meta.com/) foundation models released by Meta last month. By comparing Llama 3.1 (Dubey et al. 2024) with Vaswani et al.'s *Vanilla* Transformer architecture (Vaswani et al. 2017), we'll see how the Transformer architecture has grown up—what's changed and what hasn't—over the past 7 years of AI revolution.

# Setup

In [1]:
from functools import partial
import json
import math
import os
from pathlib import Path
import warnings

from matplotlib import pyplot as plt
import seaborn as sns
import numpy as np
from pandas import Series
from pytest import approx

import torch
from torch import nn
from torch.nn.functional import relu, silu, softmax

from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import RMSNorm

import stickshift as ss
from stickshift import default_arg, take
from stickshift.models import llama

In [2]:
# Ignore all warnings
warnings.filterwarnings("ignore")

# Configure gpu
device = ss.torch.device()

In [3]:
%%html
<style>
.stickshift-figure {
    display: block;
    margin-left: auto !important;
    margin-right: auto !important;
}
</style>

# Text Generation with Llama 3

In our last post, we looked at a `text-classification` Transformer where the Head stage of the pipeline used the contextualized token embeddings as features in a binary classifier. Text classification Transformers are an ideal place to start because they're the simplest and most familiar. This time we're going to dissect a `text-generation` Transformer. Instead of simply applying a label to the input text, the Head stage in a text generation pipeline is responsible for *generating* new content. Don't worry—it's not as complicated as it sounds. By the end of this teardown, you'll have a solid grasp of how all the elements in a state-of-the-art generative model fit together.

[Llama 3](https://llama.meta.com/) is a set of foundation models released by Meta over the summer that can "answer questions in at least 8 languages, write high quality code, solve complex reasoning problems, and use tools in a zero-shot way." (Dubey et al. 2024) The Llama 3 release includes 8B, 70B, and 405B sizes. While Meta recommends a cluster with at least 8 GPUs to run the 70B and 405B models, you can run the 8B model on a single GPU w/ 20GB of memory.

The goal of this teardown is to walk through each step in a Llama 3 text generation pipeline using only the research literature, weights from the `Meta-Llama3.1-8B-Instruct` checkpoint, and Meta's reference implementation as a guide. Along the way, we'll learn about the core improvements used by modern generative Transformers including:

* Pre-normalization with RMSNorm (Brown et al. 2020; Zhang and Sennrich 2019)
* Rotary Position Encoding (RoPE) (Su et al. 2021)
* SwiGLU Activation (Shazeer 2020)
* Grouped Query Attention (GQA) (Ainslie et al. 2023)

# Prompt

Before we jump in, we need a prompt! This is a text-generation tutorial after all.

In [4]:
# Prompt
prompt = "Write a haiku"

# Checkpoint: 8B-Instruct

We'll start by loading the hyperparameters and pre-trained weights for the `Meta-Llama3.1-8B-Instruct` checkpoint. The weights for all Llama checkpoints can be downloaded directly from [Meta](https://llama.meta.com/) and [Hugging Face](https://huggingface.co/meta-llama).

In [5]:
# Load model config
config = llama.config("Meta-Llama3.1-8B-Instruct")

# Load model parameters
checkpoint = torch.load(config.checkpoint_path / "consolidated.00.pth", weights_only=True, map_location=device)

config.model_dump()

{'checkpoint_path': PosixPath('/Users/andrewyoung/.llama/checkpoints/Meta-Llama3.1-8B-Instruct'),
 'vocab_size': 128256,
 'd_model': 4096,
 'd_head': 128,
 'd_ffn': 14336,
 'n_layers': 32,
 'n_heads': 32,
 'n_kv_heads': 8,
 'n_kv_groups': 4,
 'rms_norm_eps': 1e-05,
 'rope_theta': 500000.0}
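As a sanity check on these hyperparameters, we can count the model's weights by hand. This is a back-of-the-envelope sketch that assumes every projection is bias-free (matching the `nn.Linear(..., bias=False)` layers configured below) and that the output head has its own weights rather than sharing the embedding table (this notebook loads `head_outputs` separately from `embeddings`).

```python
# Hyperparameters copied from the config.model_dump() output above
vocab_size, d_model, d_head, d_ffn = 128256, 4096, 128, 14336
n_layers, n_heads, n_kv_heads = 32, 32, 8

# Derived values the config also reports
assert d_model == n_heads * d_head          # 32 heads x 128 dims per head
n_kv_groups = n_heads // n_kv_heads         # query heads sharing each key/value head

# Weights in a single context layer (no biases anywhere)
attention = (
    d_model * n_heads * d_head              # w_q
    + 2 * d_model * n_kv_heads * d_head     # w_k and w_v (only 8 heads each)
    + n_heads * d_head * d_model            # attention output projection
)
ffn = 3 * d_model * d_ffn                   # gate, input, and output projections
norms = 2 * d_model                         # RMSNorm scales before attention and FFN
per_layer = attention + ffn + norms

# All layers, plus the embedding table, the (assumed untied) output head,
# and the final RMSNorm scale
total = n_layers * per_layer + 2 * vocab_size * d_model + d_model

print(n_kv_groups)             # 4
print(f"{per_layer:,}")        # 218,112,000
print(f"{total / 1e9:.2f}B")   # 8.03B
```

Roughly 80% of the weights sit in the FFN projections, which is why `d_ffn` dominates the model's size.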

In [6]:
def load_pretrained_state(layer):    
    # Load pre-trained state
    llama.load_state(
        normalize_attention, "normalize_attention", 
        normalize_ffn, "normalize_ffn", 
        w_q, "w_q", 
        w_k, "w_k", 
        w_v, "w_v", 
        attention_outputs, "attention_outputs",
        ffn_gates, "ffn_gates",
        ffn_inputs, "ffn_inputs",
        ffn_outputs, "ffn_outputs",
        checkpoint=checkpoint,
        layer=layer,
    ) 

# Transformer Pipeline

The stages of a `text-generation` pipeline are the same as the `text-classification` pipeline we saw last time. The Tokenize stage splits raw text into tokens. The Embeddings stage converts individual tokens into embedding vectors. The Context Layers augment the input embeddings with contextual signals drawn from the surrounding tokens. The Head stage converts the contextualized embeddings into task-specific predictions.

The key difference between text generation and text classification is the *task-specific* prediction. While text classification Transformers predict a label for the raw text, text generation Transformers predict the next token. The predicted token is appended to the end of the input sequence, and the process repeats.
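The predict, append, repeat loop can be sketched independently of the model. In this minimal sketch, `generate` and `toy_next_logits` are hypothetical names: `next_token_logits` stands in for the whole Tokenize → Embeddings → Context Layers → Head pipeline, and we use greedy decoding (always take the highest-scoring token), the same argmax choice made at the end of this teardown.

```python
def generate(next_token_logits, token_ids, max_new_tokens, eos_id=None):
    """Greedy decoding: repeatedly predict one token and append it."""
    token_ids = list(token_ids)
    for _ in range(max_new_tokens):
        logits = next_token_logits(token_ids)                        # one score per vocabulary entry
        next_id = max(range(len(logits)), key=logits.__getitem__)    # argmax
        token_ids.append(next_id)                                    # grow the input sequence
        if next_id == eos_id:                                        # stop at end-of-sequence
            break
    return token_ids

# Hypothetical stand-in for the full pipeline: over a 10-token vocabulary,
# it always scores "previous token + 1" highest.
def toy_next_logits(token_ids, vocab_size=10):
    logits = [0.0] * vocab_size
    logits[(token_ids[-1] + 1) % vocab_size] = 1.0
    return logits

print(generate(toy_next_logits, [3], max_new_tokens=4))             # [3, 4, 5, 6, 7]
print(generate(toy_next_logits, [3], max_new_tokens=10, eos_id=5))  # [3, 4, 5]
```

Note that each step reruns the model on the whole (growing) sequence; that is simple and correct, just not efficient.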

<img src="transformer-pipeline.svg" class="stickshift-figure" width="800">

# Tokenize

The Tokenize stage splits raw text into tokens using a fixed vocabulary. Llama 3 uses a vocabulary of 128k tokens (`vocab_size` = 128256 in the config above) built on top of [OpenAI's tiktoken](https://github.com/openai/tiktoken) tokenizer. We'll dig into the details of the later stages, but here we'll simply use the off-the-shelf Tokenizer from Meta's [llama-models](https://github.com/meta-llama/llama-models) reference implementation.

In [7]:
# Load tokenizer model from checkpoint
tokenizer = Tokenizer(str(config.checkpoint_path / "tokenizer.model"))

In [8]:
# Split raw text into tokens
token_ids = tokenizer.encode(prompt, bos=True, eos=False)
token_ids

[128000, 8144, 264, 6520, 39342]

In [9]:
# Decode token ids back into raw text
tokenizer.decode(token_ids)

'<|begin_of_text|>Write a haiku'

# Embeddings

The Embeddings stage converts individual tokens into embedding vectors. 
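Under the hood, `nn.Embedding` is just a learned matrix with one row per vocabulary entry, and looking up a token returns its row. A toy-sized sketch (a 10-token vocabulary and 4-dimensional vectors, not Llama's 128256 × 4096 table; the `toy_` names are illustrative):

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy embedding table: 10 token ids, 4-dimensional vectors
toy_table = nn.Embedding(num_embeddings=10, embedding_dim=4)

toy_ids = torch.tensor([7, 2, 2])
toy_vectors = toy_table(toy_ids)

# The lookup is plain row indexing into the weight matrix
print(torch.equal(toy_vectors, toy_table.weight[toy_ids]))  # True
print(toy_vectors.shape)                                    # torch.Size([3, 4])
```

The same token id always maps to the same vector; any context-dependence has to be added later by the Context Layers.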

In [10]:
# Initialize embeddings lookup table
embeddings = nn.Embedding(
    num_embeddings=config.vocab_size, 
    embedding_dim=config.d_model,
    device=device,
)

# Load pre-trained weights
llama.load_state(embeddings, "embeddings", checkpoint=checkpoint)

In [11]:
# Load token_ids into a tensor
token_values = torch.tensor(token_ids, device=device)

token_values.shape

torch.Size([5])

In [12]:
# Record sequence length n
n = len(token_values)

n

5

In [13]:
# Map token values to embeddings
x = embeddings(token_values)

x.shape

torch.Size([5, 4096])

In [14]:
# Show sample
x

tensor([[ 2.6512e-04, -4.9973e-04, -5.8365e-04,  ...,  3.8147e-03,
          6.3419e-05,  1.1902e-03],
        [ 1.0925e-02, -3.4943e-03,  1.8997e-03,  ..., -1.0437e-02,
         -5.5542e-03, -1.0864e-02],
        [-1.3199e-03, -6.3324e-04, -8.8882e-04,  ..., -1.2329e-02,
         -4.8218e-03,  6.7353e-06],
        [-1.1841e-02, -4.9133e-03,  7.2021e-03,  ..., -1.1683e-04,
          6.4087e-03,  5.0964e-03],
        [-9.2773e-03,  1.1353e-02,  2.1729e-02,  ...,  1.8066e-02,
          9.7656e-04,  4.0588e-03]], device='mps:0',
       grad_fn=<EmbeddingBackward0>)

# Context Layers

The Context Layers in a Transformer are responsible for infusing each token embedding with contextual signals drawn from the rest of the sequence. The mechanism works by passing the token embeddings through multiple layers of attention and feedforward blocks. The attention blocks focus on relationships between tokens, augmenting each embedding with contextual information drawn from the surrounding embeddings. The feedforward blocks capitalize on the extra context, transforming each augmented embedding using a fully connected neural network. This pattern of attention and transformation is repeated over and over again, gradually converting representations of individual words into representations of abstract semantic concepts over a series of small increments.

<img src="transformer-layers.svg" class="stickshift-figure" width="800">

## Attention

### Normalize Inputs

In [15]:
# Configure attention normalization
normalize_attention = RMSNorm(config.d_model, config.rms_norm_eps).to(device)

# Load pre-trained weights
llama.load_state(normalize_attention, "normalize_attention", checkpoint=checkpoint)

In [16]:
# Normalize attention inputs
residual = x
x = normalize_attention(x)

x.shape

torch.Size([5, 4096])
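RMSNorm (Zhang and Sennrich 2019) rescales each embedding by the root mean square of its features and then multiplies by a learned per-feature weight. Unlike LayerNorm, it doesn't subtract the mean or add a bias. A from-scratch sketch of that formula, assuming it matches what the `RMSNorm` module used above computes (`rms_norm` is a hypothetical helper):

```python
import torch

def rms_norm(x, weight, eps=1e-5):
    # Root mean square over the feature dimension, with eps for numerical safety
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    # Normalize, then apply the learned per-feature scale
    return (x / rms) * weight

torch.manual_seed(0)
toy_x = torch.randn(5, 8) * 3.0              # 5 toy embeddings with 8 features each
toy_y = rms_norm(toy_x, weight=torch.ones(8))

# With a unit weight, every normalized embedding has an RMS of (almost exactly) 1
print(toy_y.pow(2).mean(dim=-1).sqrt())
```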

### Project Queries, Keys, Values

In [17]:
# Configure query, key, value projections
w_q = nn.Linear(
    in_features=config.d_model,
    out_features=config.n_heads * config.d_head,
    bias=False,
    device=device,
)
w_k = nn.Linear(
    in_features=config.d_model,
    out_features=config.n_kv_heads * config.d_head,
    bias=False,
    device=device,
)
w_v = nn.Linear(
    in_features=config.d_model,
    out_features=config.n_kv_heads * config.d_head,
    bias=False,
    device=device,
)

# Load pre-trained weights
llama.load_state(w_q, "w_q", w_k, "w_k", w_v, "w_v", checkpoint=checkpoint)

In [18]:
# Project embeddings to query, key, value spaces
q = w_q(x)
k = w_k(x)
v = w_v(x)

q.shape, k.shape, v.shape

(torch.Size([5, 4096]), torch.Size([5, 1024]), torch.Size([5, 1024]))

### Split Attention Heads

In [19]:
def split_heads(x, n_heads):
    return x.view(-1, n_heads, config.d_head).transpose(-3, -2)

In [20]:
# Split attention heads
q = split_heads(q, config.n_heads)
k = split_heads(k, config.n_kv_heads)
v = split_heads(v, config.n_kv_heads)

q.shape, k.shape, v.shape

(torch.Size([32, 5, 128]), torch.Size([8, 5, 128]), torch.Size([8, 5, 128]))

### Encode Positions (RoPE)

In [21]:
# Compute rope_cos and rope_sin
base = config.rope_theta
d = config.d_head

# Compute theta_i = 1 / base^(2i/d) from i = 0 to d/2-1
thetas = 1.0 / base**(2 * torch.arange(d // 2, device=device) / d)

# Compute m * theta_i for each position m from 0 to n-1
frequencies = torch.stack([m*thetas for m in range(n)])

# Concatenate each row with itself so the angles span all d dimensions
frequencies = torch.cat((frequencies, frequencies), dim=-1)

# Apply cos, sin
rope_cos = torch.cos(frequencies)
rope_sin = torch.sin(frequencies)

rope_cos.shape, rope_sin.shape

(torch.Size([5, 128]), torch.Size([5, 128]))

In [22]:
# Sanity check
assert rope_cos.shape[0] == n and rope_cos.shape[1] == config.d_head
assert rope_sin.shape[0] == n and rope_sin.shape[1] == config.d_head

In [23]:
# Encode positions by rotating queries and keys
q = (q * rope_cos) + (llama.rotate_half(q) * rope_sin)
k = (k * rope_cos) + (llama.rotate_half(k) * rope_sin)

q.shape, k.shape, v.shape

(torch.Size([32, 5, 128]), torch.Size([8, 5, 128]), torch.Size([8, 5, 128]))
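Why does rotating queries and keys encode position? Each pair of dimensions is rotated by an angle proportional to the token's position, so the dot product between a query at position m and a key at position n depends only on the offset m - n (Su et al. 2021). A small float64 sketch with an 8-dimensional head checks this. Here `rotate_half` is our own assumed equivalent of `llama.rotate_half` (the split-in-halves layout implied by the duplicated `frequencies` above), and `rope` is a hypothetical helper:

```python
import torch

def rotate_half(x):
    # Split the last dimension into halves (x1, x2) and return (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rope(x, position, dim, base=500000.0):
    # Angles position * theta_i, duplicated across both halves as in `frequencies` above
    thetas = 1.0 / base ** (2 * torch.arange(dim // 2, dtype=torch.float64) / dim)
    angles = torch.cat((position * thetas, position * thetas))
    return x * torch.cos(angles) + rotate_half(x) * torch.sin(angles)

torch.manual_seed(0)
toy_q = torch.randn(8, dtype=torch.float64)
toy_k = torch.randn(8, dtype=torch.float64)

# Rotations preserve vector length
print(torch.allclose(rope(toy_q, 3, 8).norm(), toy_q.norm()))  # True

# Query at 5 / key at 2 and query at 13 / key at 10 share the offset 3, so they score the same
score_a = rope(toy_q, 5, 8) @ rope(toy_k, 2, 8)
score_b = rope(toy_q, 13, 8) @ rope(toy_k, 10, 8)
print(torch.allclose(score_a, score_b))  # True
```

Because only relative offsets survive the dot product, the rotation is applied to queries and keys but never to values.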

### Expand Key / Value Groups (GQA)

In [24]:
# Expand key/value groups
k = k.repeat_interleave(config.n_kv_groups, dim=0)
v = v.repeat_interleave(config.n_kv_groups, dim=0)

q.shape, k.shape, v.shape

(torch.Size([32, 5, 128]), torch.Size([32, 5, 128]), torch.Size([32, 5, 128]))

In [25]:
# Sanity check
assert q.shape == k.shape == v.shape
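Grouped Query Attention (Ainslie et al. 2023) lets several query heads share one key/value head. After `repeat_interleave`, query head h is paired with key/value head h // n_kv_groups. The sketch below checks that mapping and estimates why it matters during generation, when keys and values are typically cached. The per-token arithmetic assumes bf16 storage (2 bytes per value) and a cache covering all 32 layers.

```python
import torch

n_heads, n_kv_heads, d_head, n_layers = 32, 8, 128, 32   # from the config above
n_kv_groups = n_heads // n_kv_heads

# Label each key/value head with its index, then expand it as in the cell above
kv_labels = torch.arange(n_kv_heads).repeat_interleave(n_kv_groups)

# Query head h reads key/value head h // n_kv_groups
print(kv_labels[:8].tolist())   # [0, 0, 0, 0, 1, 1, 1, 1]
assert all(kv_labels[h] == h // n_kv_groups for h in range(n_heads))

# Rough cache size per token: keys + values, all layers, 2 bytes per value
bytes_gqa = 2 * n_layers * n_kv_heads * d_head * 2
bytes_mha = 2 * n_layers * n_heads * d_head * 2   # if every query head had its own K/V
print(bytes_gqa // 1024, "KiB vs", bytes_mha // 1024, "KiB per token")  # 128 KiB vs 512 KiB per token
```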

### Calculate Attention


In [26]:
# Compute masked attention bias M
mask = torch.ones(n, n, dtype=torch.bool, device=device).tril(diagonal=0)
m = torch.zeros(n, n, device=device).masked_fill_(mask.logical_not(), float("-inf"))

m

tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0.]], device='mps:0')
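Adding `-inf` before the softmax is what makes the attention causal: `exp(-inf) = 0`, so each token assigns zero weight to later positions while its remaining weights still sum to 1. A toy sketch with all raw scores equal makes the effect easy to see (the `toy_` names are illustrative):

```python
import torch
from torch.nn.functional import softmax

toy_n = 4
toy_mask = torch.ones(toy_n, toy_n, dtype=torch.bool).tril()
toy_bias = torch.zeros(toy_n, toy_n).masked_fill(~toy_mask, float("-inf"))

# Equal raw scores: without the mask, every token would attend to all tokens equally
toy_weights = softmax(torch.zeros(toy_n, toy_n) + toy_bias, dim=-1)

# Row i spreads its weight evenly over positions 0..i and gives 0 to the future
print(toy_weights)
```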

In [27]:
# Compute attention for all heads in parallel
a = softmax(q @ k.transpose(-2, -1) / np.sqrt(config.d_head) + m, dim=-1) @ v

a.shape

torch.Size([32, 5, 128])
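The cell above is the scaled dot-product attention formula from Vaswani et al. (2017), softmax(QKᵀ / √d_head + M) V, applied to all 32 heads at once through batched matrix multiplication. PyTorch 2.0 and later also ship a fused version, `torch.nn.functional.scaled_dot_product_attention`. A quick toy check (2 heads, 5 tokens, `d_head` = 16) that the two agree when `is_causal=True`:

```python
import math

import torch
from torch.nn.functional import scaled_dot_product_attention, softmax

torch.manual_seed(0)
toy_q, toy_k, toy_v = torch.randn(3, 2, 5, 16).unbind(0)   # each: (heads, tokens, d_head)

# Explicit formula, as in the cell above
causal = torch.ones(5, 5, dtype=torch.bool).tril()
bias = torch.zeros(5, 5).masked_fill(~causal, float("-inf"))
manual = softmax(toy_q @ toy_k.transpose(-2, -1) / math.sqrt(16) + bias, dim=-1) @ toy_v

# PyTorch's fused equivalent (default scale is 1 / sqrt(d_head))
fused = scaled_dot_product_attention(toy_q, toy_k, toy_v, is_causal=True)

print(torch.allclose(manual, fused, atol=1e-5))  # True
```

Spelling the formula out is better for a teardown; the fused call is what you'd typically reach for in production code.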

### Recombine Attention Heads

In [28]:
def combine_heads(x):
    return x.transpose(-3, -2).contiguous().view(-1, int(config.n_heads * config.d_head))

In [29]:
# Combine attention heads
a = combine_heads(a)

a.shape

torch.Size([5, 4096])
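Splitting and recombining heads are pure reshapes, so no information is lost along the way. A toy sketch with 5 tokens and 3 heads of 4 dimensions; `to_heads` and `from_heads` are hypothetical standalone versions of `split_heads` and `combine_heads` that take `d_head` explicitly instead of reading `config`:

```python
import torch

def to_heads(x, n_heads, d_head):
    # (n, n_heads * d_head) -> (n_heads, n, d_head)
    return x.view(-1, n_heads, d_head).transpose(-3, -2)

def from_heads(x, n_heads, d_head):
    # (n_heads, n, d_head) -> (n, n_heads * d_head)
    return x.transpose(-3, -2).contiguous().view(-1, n_heads * d_head)

toy = torch.arange(5 * 12, dtype=torch.float32).view(5, 12)
toy_heads = to_heads(toy, n_heads=3, d_head=4)

print(toy_heads.shape)                                 # torch.Size([3, 5, 4])
print(torch.equal(toy_heads[1, 0], toy[0, 4:8]))       # True: head 1 sees features 4..7
print(torch.equal(from_heads(toy_heads, 3, 4), toy))   # True: the round trip is lossless
```

The `.contiguous()` call is needed because `view` can't reshape the non-contiguous tensor that `transpose` returns.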

### Project Outputs

In [30]:
# Configure attention output projection
attention_outputs = nn.Linear(
    in_features=config.d_model, 
    out_features=config.d_model,
    bias=False,
    device=device,
)

# Load pre-trained weights
llama.load_state(attention_outputs, "attention_outputs", checkpoint=checkpoint)

In [31]:
# Project attention embeddings back to model space
a = attention_outputs(a)

a.shape

torch.Size([5, 4096])

### Combine w/ Residuals

In [32]:
# Combine attention embeddings with residuals
x = residual + a

x.shape

torch.Size([5, 4096])
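This normalize → sublayer → add ordering is the pre-normalization pattern. The vanilla Transformer instead normalized *after* adding the residual. A toy sketch contrasting the two (the helper names are hypothetical, and `nn.LayerNorm` stands in for RMSNorm): with a sublayer that contributes nothing, the pre-norm block passes its input through untouched, while the post-norm block still rewrites it.

```python
import torch
from torch import nn

def post_norm_block(x, sublayer, norm):
    # Vanilla Transformer (Vaswani et al. 2017): add the residual, then normalize
    return norm(x + sublayer(x))

def pre_norm_block(x, sublayer, norm):
    # Llama 3 style: normalize only the sublayer input; the residual path is untouched
    return x + sublayer(norm(x))

def zero_sublayer(t):
    # A sublayer that contributes nothing
    return torch.zeros_like(t)

torch.manual_seed(0)
toy_x = torch.randn(5, 16)
toy_norm = nn.LayerNorm(16)   # stand-in for RMSNorm

print(torch.equal(pre_norm_block(toy_x, zero_sublayer, toy_norm), toy_x))   # True
print(torch.equal(post_norm_block(toy_x, zero_sublayer, toy_norm), toy_x))  # False
```

That clean residual path is commonly credited with making deep stacks of layers easier to train.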

## FFN

### Normalize Inputs

In [33]:
# Configure FFN normalization
normalize_ffn = RMSNorm(config.d_model, config.rms_norm_eps).to(device)

# Load pre-trained state
llama.load_state(normalize_ffn, "normalize_ffn", checkpoint=checkpoint)

In [34]:
# Normalize FFN inputs
residual = x
x = normalize_ffn(x)

x.shape

torch.Size([5, 4096])

### Transform

In [35]:
# Configure SwiGLU FFN
ffn_gates = nn.Linear(
    in_features=config.d_model,
    out_features=config.d_ffn,
    bias=False,
    device=device,
)
ffn_inputs = nn.Linear(
    in_features=config.d_model,
    out_features=config.d_ffn,
    bias=False,
    device=device,
)

# Load pre-trained weights
llama.load_state(ffn_gates, "ffn_gates", ffn_inputs, "ffn_inputs", checkpoint=checkpoint)

In [36]:
# Apply FFN
f = silu(ffn_gates(x)) * ffn_inputs(x)

f.shape

torch.Size([5, 14336])
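SwiGLU (Shazeer 2020) replaces the single ReLU layer of the vanilla FFN with two parallel projections: the one passed through SiLU acts as an elementwise gate on the other. A compact sketch with plain weight matrices stored as (in, out), unlike `nn.Linear`, which stores (out, in). `swiglu_ffn` is a hypothetical helper that folds this cell and the output projection below into one function:

```python
import torch
from torch.nn.functional import silu

# SiLU (Swish with beta = 1) is the input times its own sigmoid
toy_z = torch.linspace(-4, 4, 9)
assert torch.allclose(silu(toy_z), toy_z * torch.sigmoid(toy_z))

def swiglu_ffn(x, w_gate, w_in, w_out):
    # Gate branch through SiLU, multiplied elementwise with the linear input branch,
    # then projected back down to d_model
    return (silu(x @ w_gate) * (x @ w_in)) @ w_out

torch.manual_seed(0)
d_model_toy, d_ffn_toy = 8, 24
toy_x = torch.randn(5, d_model_toy)
toy_out = swiglu_ffn(
    toy_x,
    torch.randn(d_model_toy, d_ffn_toy),
    torch.randn(d_model_toy, d_ffn_toy),
    torch.randn(d_ffn_toy, d_model_toy),
)
print(toy_out.shape)  # torch.Size([5, 8])
```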

### Project Outputs

In [37]:
# Configure FFN output projection
ffn_outputs = nn.Linear(
    in_features=config.d_ffn,
    out_features=config.d_model,
    bias=False,
    device=device,
)

# Load pre-trained weights
llama.load_state(ffn_outputs, "ffn_outputs", checkpoint=checkpoint)

In [38]:
# Project FFN embeddings back to model space
f = ffn_outputs(f)

f.shape

torch.Size([5, 4096])

### Combine w/ Residuals

In [39]:
# Combine FFN embeddings with residuals
x = residual + f

x.shape

torch.Size([5, 4096])

## Stacking the Layers

In [40]:
# Start over from initial token embeddings
x = embeddings(token_values)

x.shape

torch.Size([5, 4096])

In [41]:
# Apply layer logic in a loop
for layer in range(config.n_layers):

    # Load pre-trained state for layer
    load_pretrained_state(layer)

    #
    # Attention
    #

    # Normalize attention inputs
    residual = x
    x = normalize_attention(x)
    
    # Project embeddings to query, key, value spaces
    q = w_q(x)
    k = w_k(x)
    v = w_v(x)
    
    # Split attention heads
    q = split_heads(q, config.n_heads)
    k = split_heads(k, config.n_kv_heads)
    v = split_heads(v, config.n_kv_heads)

    # Encode positions by rotating queries and keys
    q = (q * rope_cos) + (llama.rotate_half(q) * rope_sin)
    k = (k * rope_cos) + (llama.rotate_half(k) * rope_sin)
    
    # Expand key/value groups
    k = k.repeat_interleave(config.n_kv_groups, dim=0)
    v = v.repeat_interleave(config.n_kv_groups, dim=0)

    # Compute masked attention bias M
    mask = torch.ones(n, n, dtype=torch.bool, device=device).tril(diagonal=0)
    m = torch.zeros(n, n, device=device).masked_fill_(mask.logical_not(), float("-inf"))
    
    # Compute attention for all heads in parallel
    a = softmax(q @ k.transpose(-2, -1) / np.sqrt(config.d_head) + m, dim=-1) @ v

    # Combine attention heads
    a = combine_heads(a)
    
    # Project attention embeddings back to model space
    a = attention_outputs(a)
    
    # Combine attention embeddings with residuals
    x = residual + a
    
    #
    # FFN
    #

    # Normalize FFN inputs
    residual = x
    x = normalize_ffn(x)

    # Apply FFN
    f = silu(ffn_gates(x)) * ffn_inputs(x)

    # Project FFN embeddings back to model space
    f = ffn_outputs(f)
    
    # Combine FFN embeddings with residuals
    x = residual + f

# Head

## Normalize Inputs

In [42]:
# Configure head normalization
normalize_head = RMSNorm(config.d_model, config.rms_norm_eps).to(device)

# Load pre-trained weights
llama.load_state(normalize_head, "normalize_head", checkpoint=checkpoint)

In [43]:
# Normalize head inputs
x = normalize_head(x)

x.shape

torch.Size([5, 4096])

## Predict Next Token

In [44]:
# Configure output projection
head_outputs = nn.Linear(
    in_features=config.d_model,
    out_features=config.vocab_size,
    bias=False,
    device=device,
)

# Load pre-trained weights
llama.load_state(head_outputs, "head_outputs", checkpoint=checkpoint)

In [45]:
# Use the last embedding to represent the entire sequence:
# under the causal mask, it's the only position that has attended to every token
features = x[-1]

features.shape

torch.Size([4096])

In [46]:
# Predict next token
logits = head_outputs(features)
token_id = logits.argmax().item()

tokenizer.decode([token_id])

' about'
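Taking the argmax is greedy decoding. Passing the logits through softmax turns them into a probability distribution over the vocabulary, which is what sampling-based decoders draw from instead. A toy sketch with made-up logits for a hypothetical 4-token vocabulary:

```python
import torch
from torch.nn.functional import softmax

toy_scores = torch.tensor([2.0, 1.0, 0.5, -1.0])   # made-up logits
probs = softmax(toy_scores, dim=-1)

# Softmax preserves ordering, so greedy decoding picks the same token either way
print(bool(probs.argmax() == toy_scores.argmax()))  # True

# A sampling decoder draws a token id in proportion to its probability
torch.manual_seed(0)
sampled_id = torch.multinomial(probs, num_samples=1).item()
```

Because the ordering is preserved, greedy decoding can skip the softmax entirely, as the cell above does.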