# Transformer Teardown: Llama 3.1

> Trace an Inference Through Each Layer in a SOTA Transformer Using Llama 3.1 Open Source Foundation Models

In [the last Transformer Teardown](https://stickshift.github.io/2024/09/04/transformer-teardown.html), we dissected a DistilBERT text classification pipeline, tracing a single inference through the entire Transformer stack from raw data to final prediction. We learned about the main stages of a Transformer pipeline as well as fundamental Transformer concepts such as token embeddings and Multi-Head Self Attention. Exploring BERT-based models is a fantastic way to see the core Transformer concepts in action. But BERT was released 6 years ago! ChatGPT wouldn't even exist for another 4 years. It's safe to say a lot has changed since then.

In this post, we'll fast forward to present day. We'll use the same teardown process to unpack the latest [Llama 3.1](https://llama.meta.com/) open source foundation models released by Meta over the summer. We'll break the model down and walk through each step one cell at a time, giving you a close-up view of a modern LLM's inner workings. By the time we're done, you'll leave with a much stronger understanding of the core mechanisms driving the Generative AI revolution.

# Llama Foundation Models

[Llama](https://llama.meta.com/) is a family of general purpose, state-of-the-art open source foundation models from Meta. According to the 3.1 technical report, the latest models can "answer questions in at least 8 languages, write high quality code, solve complex reasoning problems, and use tools in a zero-shot way." (Dubey et al. 2024) 

Llama 3.1 includes 8B, 70B, and 405B sizes. While Meta recommends a cluster with at least 8 GPUs to run the 70B and 405B sizes, the 8B model is small enough to run on a single GPU with 20GB of memory. For more technical stats, see [the official Llama 3.1 model card](https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/MODEL_CARD.md) in GitHub.

Over the course of this post, we'll implement a complete text generation pipeline using only the research literature, pre-trained weights from the `Meta-Llama3.1-8B-Instruct` checkpoint, and Meta's reference implementation as a guide. In the next section, we'll review the stages of an end-to-end text generation pipeline. In the sections that follow, we'll walk through a detailed teardown of each stage, tracing an inference from raw data to the first output token. In the last section, we'll put all the pieces together into a custom chatbot capable of generating long-form content.

Let the teardown begin!

# Setup

In [1]:
from functools import partial
import json
import math
import os
from pathlib import Path
from sys import stdout
from textwrap import dedent
import warnings

from matplotlib import pyplot as plt
import seaborn as sns
import numpy as np
from pandas import Series
from pydantic import BaseModel, validate_call
from pytest import approx
from tqdm import tqdm

import torch
from torch import nn, Tensor
from torch.nn.functional import relu, silu, softmax

from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import RMSNorm

import stickshift as ss
from stickshift import default_arg, take
from stickshift.models import llama
from stickshift.torch import device as torch_device

In [2]:
# Ignore all warnings
warnings.filterwarnings("ignore")

# Configure gpu
device = torch_device()

In [3]:
%%html
<style>
figure > img {
    display:block;
    margin-left: auto !important;
    margin-right: auto !important;
}
figcaption {
    text-align: center;
}
</style>

In [4]:
%pprint

Pretty printing has been turned OFF


# Text Generation Pipeline

In [the last teardown](https://stickshift.github.io/2024/09/04/transformer-teardown.html), we looked at a text classification Transformer. This time we're going to dissect a *text generation* Transformer. Instead of simply applying a label to the input text, the Head stage will be responsible for *generating* new content. Don't worry—it's not as complicated as it sounds.

As Figure 1 shows, the stages of a text generation pipeline are the same as the text classification pipeline we saw last time. The Tokenize stage splits raw text into tokens. The Embeddings stage converts individual tokens into embedding vectors. The Context Layers stage augments the input embeddings with contextual signals drawn from the surrounding tokens. Finally, the Head stage converts the contextualized embeddings into task-specific predictions.

<figure>
<img src="transformer-pipeline.svg" width="800">
<figcaption>Figure 1: Text Generation Pipeline</figcaption>
</figure>

The key difference between text generation and text classification is the *task-specific* prediction in the Head stage. While text classification Transformers predict a label for the raw text, *text generation Transformers predict the next token.*

But one token is just the beginning. The magical powers driving the Generative AI explosion come from simply running the predictions in a loop. The next token predicted in each iteration is appended to the end of the input sequence, and the process repeats. Over and over again.
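
To make the loop concrete, here is a minimal sketch in plain Python. The `next_token` callable is a hypothetical stand-in for the full pipeline we build in this post, and `toy_next_token` is a made-up "model" that exists only so the example runs.

```python
def generate_tokens(next_token, prompt_ids, stop_id, max_new_tokens):
    """Autoregressive loop: predict one token, append it, repeat."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        # Stand-in for Tokenize -> Embeddings -> Context Layers -> Head
        token_id = next_token(ids)
        if token_id == stop_id:
            break
        # The prediction becomes part of the input for the next iteration
        ids.append(token_id)
    return ids[len(prompt_ids):]


# Hypothetical toy "model": always predicts the last id plus one
def toy_next_token(ids):
    return ids[-1] + 1


new_ids = generate_tokens(toy_next_token, [1, 2, 3], stop_id=7, max_new_tokens=10)
print(new_ids)  # [4, 5, 6]
```

Notice the loop has only two exits: a stop token or the output budget. The chatbot at the end of the post follows the same shape.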

# Model Checkpoint

We'll start by loading the configuration and pre-trained weights for the `Meta-Llama3.1-8B-Instruct` checkpoint. The weights for all Llama checkpoints can be downloaded directly from [Meta](https://llama.meta.com/) and [Hugging Face](https://huggingface.co/meta-llama).

In [5]:
# Load model config
config = llama.config("Meta-Llama3.1-8B-Instruct")

# Load pre-trained model parameters
checkpoint = torch.load(
    config.checkpoint_path / "consolidated.00.pth", 
    weights_only=True, 
    map_location=device,
)

config.model_dump()

{'checkpoint_path': PosixPath('/Users/andrewyoung/.llama/checkpoints/Meta-Llama3.1-8B-Instruct'), 'vocab_size': 128256, 'd_model': 4096, 'd_head': 128, 'd_ffn': 14336, 'n_layers': 32, 'n_heads': 32, 'n_kv_heads': 8, 'n_kv_groups': 4, 'rms_norm_eps': 1e-05, 'rope_theta': 500000.0, 'max_seq_len': 8192, 'temperature': 0.6, 'top_k': 50, 'top_p': 0.9, 'max_output_tokens': 500}

We'll reference a number of the settings in `config` throughout the teardown. For now, a few interesting ones to note are `d_model`, `d_ffn`, `n_layers`, and `n_heads`. These represent the main differences between the 8B, 70B, and 405B sizes.

In [6]:
def load_pretrained_state(layer):    
    # Load pre-trained state
    llama.load_state(
        normalize_attention, "normalize_attention", 
        normalize_ffn, "normalize_ffn", 
        w_q, "w_q", 
        w_k, "w_k", 
        w_v, "w_v", 
        attention_outputs, "attention_outputs",
        ffn_gates, "ffn_gates",
        ffn_inputs, "ffn_inputs",
        ffn_outputs, "ffn_outputs",
        checkpoint=checkpoint,
        layer=layer,
    ) 

# Raw Text

Before we can tear anything down, we need a prompt. Since our goal is to trace an inference from raw text to the first output token, we want to start with a prompt that's specific enough to generate a consistent, one-word answer. If we do everything right, the first output token we predict should be "Boston".

In [7]:
# Prompt
prompt = "<|start_header_id|>user<|end_header_id|>\n\n"
prompt += "What is the capital of Massachusetts? Answer in one word."
prompt += "<|eot_id|>"
prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"

You can see `prompt` already includes a number of special tokens. While the chatbot we build at the end will add these for us automatically, we need to manually inject them for now. You can read more about the Llama 3.1 prompt syntax in the [Llama Prompting Guide](https://www.llama.com/docs/how-to-guides/prompting).

# Tokenize

The Tokenize stage splits raw text into tokens using a fixed vocabulary. Llama uses a vocabulary of 128k tokens built on top of [OpenAI's tiktoken](https://github.com/openai/tiktoken) tokenizer. We'll dig into the gory details in the later stages, but here we'll simply use the off-the-shelf Tokenizer from Meta's [llama-models](https://github.com/meta-llama/llama-models) reference implementation.

In [8]:
# Load tokenizer model from checkpoint
tokenizer = Tokenizer(str(config.checkpoint_path / "tokenizer.model"))

In [9]:
# Split raw text into tokens
token_ids = tokenizer.encode(prompt, bos=True, eos=False, allowed_special="all")
token_ids

[128000, 128006, 882, 128007, 271, 3923, 374, 279, 6864, 315, 22108, 30, 22559, 304, 832, 3492, 13, 128009, 128006, 78191, 128007, 271]

In [10]:
len(token_ids)

22

We see `tokenizer.encode` split our prompt into 22 token ids. These ids represent the index of each token in Llama's 128k token vocabulary. We can always reverse the process with `tokenizer.decode`. If you look closely at the cell output below, you'll notice the tokenizer injected another special token `(128000, '<|begin_of_text|>')` to mark the beginning of the sequence.

In [11]:
# Decode token ids back into raw text
tokenizer.decode(token_ids)

'<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nWhat is the capital of Massachusetts? Answer in one word.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'

Our last step is to convert `token_ids` into a PyTorch tensor. The `x` variable represents our "current state". We'll trace `x` through every stage of the pipeline.

In [12]:
# Load token_ids into a tensor
x = torch.tensor(token_ids, device=device)

x.shape

torch.Size([22])

# Embeddings

The Embeddings stage converts individual tokens into embeddings. Embeddings (Bengio et al. 2000) are vectors that represent tokens as unique points in a multi-dimensional vector space. Embeddings serve as the fundamental data structure of Transformers. The context layers we'll look at in the next stage take embeddings as input, transform them, and produce embeddings as output.

In the following cells, we start by creating a lookup table with one embedding for each of the tokens in the Llama vocabulary. Next, we load the pre-trained embedding values from the model checkpoint. Finally, we use the lookup table to map each token value to an embedding vector.
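
Before we load the real weights, it may help to see that an embedding lookup is nothing more than row selection. The sketch below uses toy sizes and random values in plain Python. The names `table` and `embed` are ours for illustration and are not part of the notebook.

```python
import random

vocab_size, d_model = 10, 4  # toy sizes; the 8B config uses 128256 and 4096

# Lookup table: one row (embedding vector) per token id
random.seed(0)
table = [[random.gauss(0.0, 1.0) for _ in range(d_model)] for _ in range(vocab_size)]


def embed(token_ids):
    # Select the row for each token id, preserving sequence order
    return [table[t] for t in token_ids]


x_toy = embed([3, 1, 3])
print(len(x_toy), len(x_toy[0]))  # 3 4
```

Repeated token ids map to the same row, which is why the embeddings alone carry no information about position or context.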

In [13]:
# Initialize embeddings lookup table
embeddings = nn.Embedding(
    num_embeddings=config.vocab_size, 
    embedding_dim=config.d_model,
    device=device,
)

# Load pre-trained state
llama.load_state(embeddings, "embeddings", checkpoint=checkpoint)

In [14]:
# Map token values to embeddings
x = embeddings(x)

x.shape

torch.Size([22, 4096])

From `x.shape` above, you can see the embeddings are represented as row vectors stacked together into an $n \times d_{model}$ tensor.

In [15]:
# Show sample
x

tensor([[ 2.6512e-04, -4.9973e-04, -5.8365e-04,  ...,  3.8147e-03,
          6.3419e-05,  1.1902e-03],
        [-1.6499e-04, -2.4319e-04,  1.6403e-04,  ..., -1.5163e-04,
          3.5095e-04,  7.3242e-04],
        [ 3.5095e-03,  7.2021e-03,  5.3406e-05,  ..., -7.2479e-04,
         -1.0620e-02,  8.2779e-04],
        ...,
        [-9.7656e-03, -3.4637e-03,  1.8616e-03,  ..., -7.1411e-03,
         -4.3030e-03,  8.6060e-03],
        [-4.6158e-04, -3.9291e-04, -6.5863e-06,  ..., -6.2561e-04,
         -5.0354e-04,  6.6757e-04],
        [-2.8687e-03,  3.8910e-03, -1.7357e-04,  ...,  8.0872e-04,
          5.0354e-04,  2.3041e-03]], device='mps:0',
       grad_fn=<EmbeddingBackward0>)

# Context Layers

The Context Layers in a Transformer are responsible for infusing each token embedding with contextual signals drawn from the rest of the sequence. The mechanism works by passing the token embeddings through multiple layers of attention and feedforward blocks. The attention blocks focus on relationships between tokens, augmenting each embedding with a weighted combination of the surrounding embeddings. The feedforward blocks capitalize on the extra context, transforming each augmented embedding using a fully connected neural network. This pattern of attention and transformation is repeated over and over again, gradually converting representations of individual words into representations of abstract semantic concepts over a series of small increments.

<figure>
<img src="transformer-layers.svg" width="800">
<figcaption>Figure 2: Context Layers</figcaption>
</figure>

Figure 2 illustrates the components of each context layer in more detail. The input embeddings are first passed to the Attention block and then the FFN block. Note the Residual Input and Residual Attention paths that are combined with the Attention and FFN outputs. These residuals are critical for providing a stable path for gradient flow during training (He et al. 2015).
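
The data flow in Figure 2 can be sketched in a few lines of plain Python. This is only a skeleton under our own assumptions: the block functions are passed in as arguments, each block is preceded by a normalization step (as in the cells later in this post), and `identity`, `double`, and `zeros` are hypothetical stand-ins so the sketch runs on its own.

```python
def add(u, v):
    # Elementwise sum of two equal-length vectors
    return [a + b for a, b in zip(u, v)]


def context_layer(x, normalize_attention, attention, normalize_ffn, ffn):
    # Attention block: normalize, attend, then add the block's input back in
    x = add(x, attention(normalize_attention(x)))
    # FFN block: same pattern, starting from the attention block's output
    x = add(x, ffn(normalize_ffn(x)))
    return x


# Hypothetical stand-ins so the sketch runs without a checkpoint
def identity(v):
    return list(v)


def double(v):
    return [2.0 * a for a in v]


def zeros(v):
    return [0.0 for _ in v]


out = context_layer([1.0, 2.0], identity, double, identity, zeros)
print(out)  # [3.0, 6.0]
```

Even when a block contributes nothing (the `zeros` FFN here), its input passes through untouched. That pass-through is the residual path.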

## Position Encoding

The relevance of one token to another often depends on their proximity, making token positions critical to all Transformers. While early Transformer models like Vanilla and BERT encoded the absolute token positions directly in the input representation, more recent models have adopted *relative* position encoding schemes shown to perform better on longer sequences. Inspired by GPTNeo, Llama replaces absolute position encoding with the Rotary Position Embedding (RoPE) approach from Su et al. (2021) (Touvron et al. 2023).

### Rotary Position Embedding (RoPE)

One of the reasons embeddings are so powerful is they represent the semantic meaning of abstract concepts in a form that you can do math with. For example, if you create embeddings for the words "king", "queen", "man", and "woman" using Word2Vec, you can show that the closest embedding to $E_{king} - E_{man} + E_{woman}$ is $E_{queen}$. I know that's a bit of a mind bender, but the point of the story is the power of embeddings often comes from the "distance" between them.

While there are multiple ways to measure distance between vectors, Transformers compare embeddings with dot products, so the angle between two vectors serves as the working notion of distance. If the angle is small, the embeddings are close to each other. If the angle is large, the embeddings are further apart.
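
As a quick illustration, the angle between two vectors can be recovered from their dot product. The helper below is a plain Python sketch of our own and is not part of the notebook.

```python
import math


def angle_between(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    # Clamp to guard against rounding slightly outside [-1, 1]
    cos_sim = max(-1.0, min(1.0, dot / (norm_u * norm_v)))
    return math.acos(cos_sim)


same_direction = angle_between([1.0, 0.0], [2.0, 0.0])
perpendicular = angle_between([1.0, 0.0], [0.0, 3.0])
print(same_direction)                # 0.0
print(perpendicular == math.pi / 2)  # True
```

Vector length drops out of the result: `[1.0, 0.0]` and `[2.0, 0.0]` are a distance of zero apart by this measure.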

> RoPE works by converting the distance between token positions into angular distance between embedding vectors, fitting perfectly into the Transformer's mental model of "distance".

<figure>
<img src="rope-concept.svg" width="800">
<figcaption>Figure 3: RoPE Concept in 2D</figcaption>
</figure>

Figure 3 illustrates the intuition behind RoPE in 2 dimensions. On the left, we have a sequence of 2-dimensional vectors $\begin{bmatrix}\mathbf{x}_0 & \mathbf{x}_1 & \mathbf{x}_2\end{bmatrix}^T$ with their 2-dimensional geometric interpretation plotted on the right. The matrix in the center represents a rotational transformation. Given a token's position $m$, the idea behind RoPE is to rotate the token's embedding by an angle of $m \theta$. Following this idea, $\mathbf{x}_0$ would stay the same, $\mathbf{x}_1$ would be rotated by $\theta$, and $\mathbf{x}_2$ would be rotated by $2 \theta$.
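
The payoff of rotating by position shows up when two rotated vectors meet in a dot product: the result depends only on how far apart the positions are, not on where they sit in the sequence. Here is a small plain Python check of that property. The values of `theta`, `q`, and `k` are arbitrary toy numbers.

```python
import math


def rotate(vec, angle):
    # Standard 2D rotation matrix applied to (x, y)
    x, y = vec
    return (x * math.cos(angle) - y * math.sin(angle),
            x * math.sin(angle) + y * math.cos(angle))


def dot(u, v):
    return u[0] * v[0] + u[1] * v[1]


theta = 0.3
q, k = (1.0, 2.0), (0.5, -1.0)

# Query at position 5, key at position 2: offset of 3
score_a = dot(rotate(q, 5 * theta), rotate(k, 2 * theta))
# Query at position 13, key at position 10: same offset of 3
score_b = dot(rotate(q, 13 * theta), rotate(k, 10 * theta))
# Query at position 6, key at position 2: offset of 4
score_c = dot(rotate(q, 6 * theta), rotate(k, 2 * theta))

print(math.isclose(score_a, score_b))  # True
print(math.isclose(score_a, score_c))  # False
```

This is what makes RoPE a *relative* scheme even though each vector is rotated by its absolute position.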

While Figure 3 illustrates RoPE conceptually in 2D, implementing it is a little more complicated. In practice, RoPE splits the components of each token embedding into pairs, e.g. {::nomarkdown}$\Set{(x_1, x_2), (x_3, x_4), \dots}${:/}, and then applies a 2D rotation to each pair using the rotation matrix $\mathbf{R}_{\Theta,m}^d$.

Given a hyperparameter $\Theta$,

$$
\begin{align}
\mathbf{R}_{\Theta,m}^d &= 
\begin{bmatrix}
\cos(m \theta_0) & -\sin(m \theta_0) & 0 & 0 & \dots & 0 & 0 \\
\sin(m \theta_0) & \cos(m \theta_0) & 0 & 0 & \dots & 0 & 0 \\
0 & 0 & \cos(m \theta_1) & -\sin(m \theta_1) & \dots & 0 & 0 \\
0 & 0 & \sin(m \theta_1) & \cos(m \theta_1) & \dots & 0 & 0 \\
\vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\
0 & 0 & \dots & 0 & 0 & \cos(m \theta_{d/2-1}) & -\sin(m \theta_{d/2-1}) \\
0 & 0 & \dots & 0 & 0 & \sin(m \theta_{d/2-1}) & \cos(m \theta_{d/2-1}) \\
\end{bmatrix} \\
\text{where } \theta_i &= \frac{1}{\Theta^{2i/d}} \text{ for } i = 0, 1, \dots, d/2 - 1 \text{ and } d = d_{head}
\end{align}
$$

Luckily, Su et al. (2021) give us a more efficient approach to apply $\mathbf{R}_{\Theta,m}^d$ in practice.

$$
\mathbf{R}_{\Theta,m}^d \mathbf{x} = 
\begin{bmatrix}
x_1 \\
x_2 \\
x_3 \\
x_4 \\
\vdots \\
x_{d-1} \\
x_d \\
\end{bmatrix}
\odot
\begin{bmatrix}
\cos(m \theta_0) \\
\cos(m \theta_0) \\
\cos(m \theta_1) \\
\cos(m \theta_1) \\
\vdots \\
\cos(m \theta_{d/2-1}) \\
\cos(m \theta_{d/2-1}) \\
\end{bmatrix}
+
\begin{bmatrix}
-x_2 \\
x_1 \\
-x_4 \\
x_3 \\
\vdots \\
-x_d \\
x_{d-1} \\
\end{bmatrix}
\odot
\begin{bmatrix}
\sin(m \theta_0) \\
\sin(m \theta_0) \\
\sin(m \theta_1) \\
\sin(m \theta_1) \\
\vdots \\
\sin(m \theta_{d/2-1}) \\
\sin(m \theta_{d/2-1}) \\
\end{bmatrix}
$$

Here $\odot$ denotes elementwise multiplication.
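
If you'd like to convince yourself that the cheaper form really matches the block-diagonal rotation, the plain Python sketch below compares the two on a toy 4-dimensional vector. It follows the consecutive-pair layout written in the equations. The notebook's `llama.rotate_half` helper isn't shown in this post, so its layout may differ. The function names and the toy `base` value are our own.

```python
import math


def rope_matrix_form(x, m, thetas):
    # Rotate each consecutive pair (x[2i], x[2i+1]) by the angle m * thetas[i]
    out = []
    for i, theta in enumerate(thetas):
        a, b = x[2 * i], x[2 * i + 1]
        c, s = math.cos(m * theta), math.sin(m * theta)
        out += [a * c - b * s, a * s + b * c]
    return out


def rope_elementwise_form(x, m, thetas):
    # Each angle appears twice, once per element of its pair
    cos = [math.cos(m * t) for t in thetas for _ in range(2)]
    sin = [math.sin(m * t) for t in thetas for _ in range(2)]
    # Swap each pair and negate its first entry: (a, b) -> (-b, a)
    swapped = []
    for i in range(0, len(x), 2):
        swapped += [-x[i + 1], x[i]]
    return [xi * c + si * s for xi, c, si, s in zip(x, cos, swapped, sin)]


d, base, m = 4, 10000.0, 3  # toy sizes; m is the token position
thetas = [base ** (-2 * i / d) for i in range(d // 2)]
x_toy = [1.0, 2.0, 3.0, 4.0]

via_matrix = rope_matrix_form(x_toy, m, thetas)
via_elementwise = rope_elementwise_form(x_toy, m, thetas)
print(all(math.isclose(p, q) for p, q in zip(via_matrix, via_elementwise)))  # True
```

The elementwise form needs only two multiplications and one addition per element, with no $d \times d$ matrix in sight.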

Which finally brings us back to the teardown. Our goal here is to calculate the $cos$ and $sin$ vectors in the equation above. Since these depend only on the sequence length and the head dimension, not on the embedding values, we can compute them once and reuse them across all the context layers.

In [16]:
def rope(n):
    # Hyperparams
    base = config.rope_theta
    d = config.d_head
    
    # Compute theta_i = 1 / base^(2i/d) from i = 0 to d/2-1
    thetas = 1.0 / base**(2 * torch.arange(d // 2, device=device) / d)
    
    # Compute m * theta_i for each position m in 0 to n-1
    frequencies = torch.stack([m*thetas for m in range(n)])
    
    # Concatenate two copies along the last dim so each row has d entries
    frequencies = torch.cat((frequencies, frequencies), dim=-1)
    
    # Apply cos, sin
    cos = torch.cos(frequencies)
    sin = torch.sin(frequencies)
    
    # Sanity check
    assert cos.shape[0] == n and cos.shape[1] == config.d_head
    assert sin.shape[0] == n and sin.shape[1] == config.d_head

    return cos, sin

In [17]:
# Compute RoPE rotation matrices
rope_cos, rope_sin = rope(len(x))

rope_cos.shape, rope_sin.shape

(torch.Size([22, 128]), torch.Size([22, 128]))

## Attention

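Attention lets each token pull in information from itself and the tokens before it. Every embedding is projected into a query, a key, and a value. Each query is scored against the keys, the scores are normalized with softmax, and the output is the correspondingly weighted sum of the values. The steps below walk through that computation for a single context layer.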

### Normalize Attention Inputs

One difference between Llama and Vanilla architectures is in the approach to normalization. While Vanilla normalizes the outputs of each block, Llama uses a pre-normalization approach inspired by GPT3 (Touvron et al. 2023). In addition, Llama replaces the LayerNorm algorithm with the less computationally expensive RMSNorm algorithm from Zhang and Sennrich (2019) (Touvron et al. 2023).
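
For intuition, here is a plain Python sketch of RMSNorm following the formula in Zhang and Sennrich (2019): divide the vector by its root mean square, then scale by a learned per-dimension gain. The notebook itself uses Meta's `RMSNorm` class. This standalone `rms_norm` function is our own illustration, and the placement of `eps` inside the square root is an assumption on our part.

```python
import math


def rms_norm(x, weight, eps=1e-5):
    # Root mean square of the vector (eps guards against division by zero)
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    # Scale each element by 1/rms, then by the learned per-dimension gain
    return [w * v / rms for v, w in zip(x, weight)]


y = rms_norm([3.0, 4.0], weight=[1.0, 1.0], eps=0.0)
mean_square = sum(v * v for v in y) / len(y)
print(math.isclose(mean_square, 1.0))  # True
```

Unlike LayerNorm, there is no mean subtraction and no bias term, which is where the savings come from.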

In [18]:
# Configure attention normalization
normalize_attention = RMSNorm(config.d_model, config.rms_norm_eps).to(device)

# Load pre-trained weights
llama.load_state(normalize_attention, "normalize_attention", checkpoint=checkpoint)

In [19]:
# Normalize attention inputs
residual = x
x = normalize_attention(x)

x.shape

torch.Size([22, 4096])

### Project Queries, Keys, Values

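Llama 3.1 uses Grouped Query Attention (GQA) (Ainslie et al. 2023). Instead of giving every query head its own key and value heads, GQA shares each key/value head across a group of query heads. In the 8B config, `n_heads` is 32 and `n_kv_heads` is 8, so each key/value head serves a group of `n_kv_groups` = 4 query heads.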

In [20]:
# Configure query, key, value projections
w_q = nn.Linear(
    in_features=config.d_model,
    out_features=config.n_heads * config.d_head,
    bias=False,
    device=device,
)
w_k = nn.Linear(
    in_features=config.d_model,
    out_features=config.n_kv_heads * config.d_head,
    bias=False,
    device=device,
)
w_v = nn.Linear(
    in_features=config.d_model,
    out_features=config.n_kv_heads * config.d_head,
    bias=False,
    device=device,
)

# Load pre-trained weights
llama.load_state(w_q, "w_q", w_k, "w_k", w_v, "w_v", checkpoint=checkpoint)

In [21]:
# Project embeddings to query, key, value spaces
q = w_q(x)
k = w_k(x)
v = w_v(x)

q.shape, k.shape, v.shape

(torch.Size([22, 4096]), torch.Size([22, 1024]), torch.Size([22, 1024]))

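The output shapes show GQA at work: the queries span all 32 heads (32 × 128 = 4096), while the keys and values span only 8 heads (8 × 128 = 1024).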

### Split Attention Heads

In [22]:
def split_heads(x, n_heads):
    return x.view(-1, n_heads, config.d_head).transpose(-3, -2)

In [23]:
# Split attention heads
q = split_heads(q, config.n_heads)
k = split_heads(k, config.n_kv_heads)
v = split_heads(v, config.n_kv_heads)

q.shape, k.shape, v.shape

(torch.Size([32, 22, 128]), torch.Size([8, 22, 128]), torch.Size([8, 22, 128]))

### Encode Positions (RoPE)

In [24]:
# Encode positions by rotating queries and keys
q = (q * rope_cos) + (llama.rotate_half(q) * rope_sin)
k = (k * rope_cos) + (llama.rotate_half(k) * rope_sin)

q.shape, k.shape, v.shape

(torch.Size([32, 22, 128]), torch.Size([8, 22, 128]), torch.Size([8, 22, 128]))

### Expand Key / Value Groups (GQA)

In [25]:
# Expand key/value groups
k = k.repeat_interleave(config.n_kv_groups, dim=0)
v = v.repeat_interleave(config.n_kv_groups, dim=0)

q.shape, k.shape, v.shape

(torch.Size([32, 22, 128]), torch.Size([32, 22, 128]), torch.Size([32, 22, 128]))

In [26]:
# Sanity check
assert q.shape == k.shape == v.shape
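
To see which query heads end up sharing which key/value head, here is a plain Python analogue of the `repeat_interleave` call above. The sizes are toy values and the string labels are placeholders for the actual key/value tensors.

```python
n_heads, n_kv_heads = 8, 2           # toy sizes; the 8B config uses 32 and 8
n_kv_groups = n_heads // n_kv_heads  # query heads per key/value head


def repeat_interleave(items, repeats):
    # Pure Python analogue of tensor.repeat_interleave(repeats, dim=0)
    return [item for item in items for _ in range(repeats)]


kv_heads = [f"kv{i}" for i in range(n_kv_heads)]
expanded = repeat_interleave(kv_heads, n_kv_groups)
print(expanded)
# ['kv0', 'kv0', 'kv0', 'kv0', 'kv1', 'kv1', 'kv1', 'kv1']

# Equivalent view: query head h reads key/value head h // n_kv_groups
print(all(expanded[h] == kv_heads[h // n_kv_groups] for h in range(n_heads)))  # True
```

Consecutive query heads form a group, so the expansion only copies references to the same key/value data. No new parameters are involved.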

### Calculate Attention


In [27]:
# Compute attention mask M
n = len(x)
mask = torch.ones(n, n, dtype=torch.bool, device=device).tril(diagonal=0)
m = torch.zeros(n, n, device=device).masked_fill_(mask.logical_not(), float("-inf"))

m

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        ...

In [28]:
# Compute attention for all heads in parallel
a = softmax(q @ k.transpose(-2, -1) / np.sqrt(config.d_head) + m, dim=-1) @ v

a.shape

torch.Size([32, 22, 128])
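
The one-liner above packs in a lot, so here is the same idea unrolled for a single head in plain Python: scaled dot-product scores, a causal mask of `-inf` for future positions, and a row-wise softmax. This sketch stops at the attention weights and uses toy inputs of our own. Multiplying the weights by `v` would complete the computation.

```python
import math


def softmax_row(row):
    # Subtract the max for numerical stability; exp(-inf) evaluates to 0.0
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def causal_attention_weights(q, k):
    n, d = len(q), len(q[0])
    weights = []
    for i in range(n):
        scores = []
        for j in range(n):
            score = sum(a * b for a, b in zip(q[i], k[j])) / math.sqrt(d)
            # Mask out future positions (j > i) before the softmax
            scores.append(score if j <= i else float("-inf"))
        weights.append(softmax_row(scores))
    return weights


q_toy = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k_toy = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
w = causal_attention_weights(q_toy, k_toy)

print(w[0])  # [1.0, 0.0, 0.0]
```

The first token can only attend to itself, so its row is always one-hot. Every later row spreads its weight over the positions up to and including its own.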

### Recombine Attention Heads

In [29]:
def combine_heads(x):
    return x.transpose(-3, -2).contiguous().view(-1, int(config.n_heads * config.d_head))

In [30]:
# Combine attention heads
a = combine_heads(a)

a.shape

torch.Size([22, 4096])

### Project Attention Outputs

In [31]:
# Configure attention output projection
attention_outputs = nn.Linear(
    in_features=config.d_model, 
    out_features=config.d_model,
    bias=False,
    device=device,
)

# Load pre-trained weights
llama.load_state(attention_outputs, "attention_outputs", checkpoint=checkpoint)

In [32]:
# Project attention embeddings back to model space
a = attention_outputs(a)

a.shape

torch.Size([22, 4096])

### Combine w/ Residuals

In [33]:
# Combine attention embeddings with residuals
x = residual + a

x.shape

torch.Size([22, 4096])

## FFN

### Normalize FFN Inputs

In [34]:
# Configure FFN normalization
normalize_ffn = RMSNorm(config.d_model, config.rms_norm_eps).to(device)

# Load pre-trained state
llama.load_state(normalize_ffn, "normalize_ffn", checkpoint=checkpoint)

In [35]:
# Normalize FFN inputs
residual = x
x = normalize_ffn(x)

x.shape

torch.Size([22, 4096])

### Transform

In [36]:
# Configure SwiGLU FFN
ffn_gates = nn.Linear(
    in_features=config.d_model,
    out_features=config.d_ffn,
    bias=False,
    device=device,
)
ffn_inputs = nn.Linear(
    in_features=config.d_model,
    out_features=config.d_ffn,
    bias=False,
    device=device,
)

# Load pre-trained weights
llama.load_state(ffn_gates, "ffn_gates", ffn_inputs, "ffn_inputs", checkpoint=checkpoint)

In [37]:
# Apply transform
f = silu(ffn_gates(x)) * ffn_inputs(x)

f.shape

torch.Size([22, 14336])
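
The transform above is a gated unit: one projection decides how much of the other gets through. Here is a plain Python sketch of that gating on toy numbers, where `gate` plays the role of `ffn_gates(x)` and `value` plays the role of `ffn_inputs(x)`. The function names are our own.

```python
import math


def silu(z):
    # SiLU (also called swish): z * sigmoid(z)
    return z / (1.0 + math.exp(-z))


def swiglu(gate, value):
    # Elementwise: the activated gate scales the value
    return [silu(g) * v for g, v in zip(gate, value)]


out = swiglu([0.0, 10.0, -10.0], [5.0, 5.0, 5.0])
print([round(o, 2) for o in out])
```

A gate near zero or strongly negative suppresses its value almost entirely, while a large positive gate passes the value through scaled by roughly the gate itself.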

### Project FFN Outputs

In [38]:
# Configure FFN output projection
ffn_outputs = nn.Linear(
    in_features=config.d_ffn,
    out_features=config.d_model,
    bias=False,
    device=device,
)

# Load pre-trained weights
llama.load_state(ffn_outputs, "ffn_outputs", checkpoint=checkpoint)

In [39]:
# Project FFN embeddings back to model space
f = ffn_outputs(f)

f.shape

torch.Size([22, 4096])

### Combine w/ Residuals

In [40]:
# Combine FFN embeddings with residuals
x = residual + f

x.shape

torch.Size([22, 4096])

## Stacking the Layers

Now that we've gone through each step, let's put all the pieces together.

In [41]:
def context_layers(x):
    # Compute RoPE rotation matrices
    rope_cos, rope_sin = rope(len(x))

    # Apply layer logic in a loop
    for layer in range(config.n_layers):
    
        # Load pre-trained state for layer
        load_pretrained_state(layer)
    
        #
        # Attention
        #
    
        # Normalize attention inputs
        residual = x
        x = normalize_attention(x)
        
        # Project embeddings to query, key, value spaces
        q = w_q(x)
        k = w_k(x)
        v = w_v(x)
        
        # Split attention heads
        q = split_heads(q, config.n_heads)
        k = split_heads(k, config.n_kv_heads)
        v = split_heads(v, config.n_kv_heads)
    
        # Encode positions by rotating queries and keys
        q = (q * rope_cos) + (llama.rotate_half(q) * rope_sin)
        k = (k * rope_cos) + (llama.rotate_half(k) * rope_sin)
        
        # Expand key/value groups
        k = k.repeat_interleave(config.n_kv_groups, dim=0)
        v = v.repeat_interleave(config.n_kv_groups, dim=0)
    
        # Compute masked attention bias M
        n = len(x)
        mask = torch.ones(n, n, dtype=torch.bool, device=device).tril(diagonal=0)
        m = torch.zeros(n, n, device=device).masked_fill_(mask.logical_not(), float("-inf"))
        
        # Compute attention for all heads in parallel
        a = softmax(q @ k.transpose(-2, -1) / np.sqrt(config.d_head) + m, dim=-1) @ v
    
        # Combine attention heads
        a = combine_heads(a)
        
        # Project attention embeddings back to model space
        a = attention_outputs(a)
        
        # Combine attention embeddings with residuals
        x = residual + a
        
        #
        # FFN
        #
    
        # Normalize FFN inputs
        residual = x
        x = normalize_ffn(x)
    
        # Apply transform
        f = silu(ffn_gates(x)) * ffn_inputs(x)
    
        # Project FFN embeddings back to model space
        f = ffn_outputs(f)
        
        # Combine FFN embeddings with residuals
        x = residual + f

    return x

In [42]:
# Start over from initial tokens
x = torch.tensor(token_ids, device=device)

# Initial embeddings
x = embeddings(x)

# Contextualized embeddings
x = context_layers(x)

x.shape

torch.Size([22, 4096])

In [43]:
x

tensor([[ 0.8842,  1.9047,  1.0641,  ..., -1.3221,  2.1526,  1.3637],
        [ 0.5711, -0.4364, -0.1372,  ..., -0.0924, -0.2381, -0.1401],
        [-0.2568, -0.5273, -0.4703,  ...,  0.2887,  1.0332, -1.0900],
        ...,
        [-0.0088, -0.0946, -0.1153,  ...,  0.2482,  0.1270, -0.0051],
        [ 0.2867, -0.0427,  0.3964,  ..., -0.1481,  0.2981, -0.3256],
        [-0.9179,  0.6267,  0.4772,  ...,  0.0447,  1.8657, -0.3242]],
       device='mps:0', grad_fn=<AddBackward0>)

# Head

## Normalize Head Inputs

In [44]:
# Configure head normalization
normalize_head = RMSNorm(config.d_model, config.rms_norm_eps).to(device)

# Load pre-trained weights
llama.load_state(normalize_head, "normalize_head", checkpoint=checkpoint)

In [45]:
# Normalize head inputs
x = normalize_head(x)

x.shape

torch.Size([22, 4096])

## Project Head Outputs

In [46]:
# Configure output projection
head_outputs = nn.Linear(
    in_features=config.d_model,
    out_features=config.vocab_size,
    bias=False,
    device=device,
)

# Load pre-trained weights
llama.load_state(head_outputs, "head_outputs", checkpoint=checkpoint)

In [47]:
# Use last embedding to represent the entire sequence
x = x[-1]

# Project outputs to token space
x = head_outputs(x)

x.shape

torch.Size([128256])

## Top Token

In [48]:
# Select top scoring token
token_id = x.argmax()

# Decode token
token = tokenizer.decode([token_id]).strip()

token

'Boston'

In [49]:
# Verify answer
assert token == "Boston"

## Sampling

### Temperature

In [50]:
# Hyperparameters
temperature = config.temperature
temperature

0.6

In [51]:
# Apply temperature
x = x / temperature
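
To see what dividing by the temperature does, here is a plain Python comparison on three toy logits. The helper name and the logit values are ours. The 0.6 matches the `temperature` in `config`.

```python
import math


def softmax_with_temperature(logits, temperature):
    scaled = [v / temperature for v in logits]
    # Subtract the max for numerical stability
    peak = max(scaled)
    exps = [math.exp(v - peak) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]
neutral = softmax_with_temperature(logits, 1.0)
sharp = softmax_with_temperature(logits, 0.6)

print(sharp[0] > neutral[0])  # True
```

A temperature below 1 stretches the gaps between logits, so the top token takes a larger share of the probability. A temperature above 1 would flatten the distribution instead.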

### Ranking

In [52]:
# Convert logits to probabilities
probs = softmax(x, dim=-1)

# Sort probabilities in descending order
probs, indices = probs.sort(descending=True)

### Top K

In [53]:
# Hyperparameters
top_k = config.top_k
top_k

50

In [54]:
# Retain top k tokens
probs = probs[:top_k]
print(f"Retained {len(probs)} of {len(x)}")

Retained 50 of 128256


### Top P

In [55]:
# Hyperparameters
top_p = config.top_p
top_p

0.9

In [56]:
# Find cutoff where cumulative probability exceeds top_p
cumulative_mask = probs.cumsum(dim=-1) > top_p
threshold_index = torch.argmax(cumulative_mask).item()

# Only apply threshold if top_p was exceeded
if cumulative_mask.any():
    probs = probs[:threshold_index+1]

print(f"Retained {len(probs)} of {len(x)}")

Retained 2 of 128256
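
The Top K and Top P cells can be folded into one small function. The plain Python sketch below mirrors their logic on a list that is already sorted in descending order. The function name is our own.

```python
def filter_top_k_top_p(sorted_probs, top_k, top_p):
    """Filter a descending list of probabilities by count, then by mass."""
    # Top K: keep at most top_k candidates
    kept = sorted_probs[:top_k]
    # Top P: cut after the first candidate that pushes the running total past top_p
    cumulative = 0.0
    for i, p in enumerate(kept):
        cumulative += p
        if cumulative > top_p:
            return kept[:i + 1]
    # Threshold never exceeded: keep everything that survived Top K
    return kept


print(filter_top_k_top_p([0.5, 0.3, 0.1, 0.06, 0.04], top_k=4, top_p=0.7))
# [0.5, 0.3]
print(filter_top_k_top_p([0.5, 0.3, 0.1, 0.06, 0.04], top_k=4, top_p=0.99))
# [0.5, 0.3, 0.1, 0.06]
```

Top K caps how many candidates survive, while Top P adapts to the model's confidence: the more probability the leading tokens hold, the fewer candidates remain.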


### Random Selection

In [57]:
# Print remaining token pool
for i, prob in enumerate(probs):
    print(f"token id {indices[i]}, token '{tokenizer.decode([indices[i]])}', score {prob:0.3f}")

token id 65432, token 'Boston', score 0.661
token id 791, token 'The', score 0.338


In [58]:
# Sample from remaining tokens weighted by probability
sampled_index = torch.multinomial(probs, 1)

# Map sampled index back to the original token id
token_id = indices[sampled_index]

# Decode token
token = tokenizer.decode([token_id]).strip()

token

'Boston'
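
You may have noticed the two remaining scores no longer sum to 1. That's fine for `torch.multinomial`, which treats its input as weights rather than strict probabilities. The sketch below shows the same weighted draw using Python's standard library, with the two scores printed above. It uses a seeded `random.Random` instance of our own so the run is repeatable.

```python
import random

rng = random.Random(0)
tokens = ["Boston", "The"]
weights = [0.661, 0.338]  # sums to 0.999, not 1.0

# Draw many samples to estimate how often each token would be chosen
draws = rng.choices(tokens, weights=weights, k=10_000)
share = draws.count("Boston") / len(draws)

print(f"Boston share: {share:.2f}")
```

Over many draws "Boston" wins about two times out of three, so the sampled answer will occasionally start with "The" instead.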

## Complete Head Stage

In [59]:
def head(x):
    # Normalize head inputs
    x = normalize_head(x)
    
    # Use last embedding to represent the entire sequence
    x = x[-1]
    
    # Project outputs to token space
    x = head_outputs(x)

    #
    # Temperature
    #
    
    # Apply temperature
    x = x / config.temperature

    #
    # Ranking
    #
    
    # Convert logits to probabilities
    probs = softmax(x, dim=-1)
    
    # Sort probabilities in descending order
    probs, indices = probs.sort(descending=True)

    #
    # Top K
    #
    
    # Retain top k tokens
    probs = probs[:config.top_k]

    #
    # Top P
    #
    
    # Find cutoff where cumulative probability exceeds top_p
    cumulative_mask = probs.cumsum(dim=-1) > config.top_p
    threshold_index = torch.argmax(cumulative_mask).item()
    
    # Only apply threshold if top_p was exceeded
    if cumulative_mask.any():
        probs = probs[:threshold_index+1]

    #
    # Random Selection
    #
    
    # Sample from remaining tokens weighted by probability
    sampled_index = torch.multinomial(probs, 1)
    
    # Map sampled index back to the original token id
    token_id = indices[sampled_index]

    return token_id.item()

# Generator

In [60]:
class Message(BaseModel):
    role: str
    content: str

In [61]:
@validate_call
def prepare_messages(messages: list[Message]):
    # Initialize prompt
    prompt = ""
    
    # Format each message
    for message in messages:
        prompt += f"<|start_header_id|>{message.role}<|end_header_id|>\n\n"
        prompt += message.content
        prompt += "<|eot_id|>"

    # Finish with the assistant role to prime the model's response
    prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"

    return prompt

In [62]:
@validate_call
def generate(messages: list[Message]):
    # Format message prompt
    prompt = prepare_messages(messages)
    
    # Split raw text into tokens
    token_ids = tokenizer.encode(prompt, bos=True, eos=False, allowed_special="all")
    
    # Generate output until we get a stop token or we exceed max_output_tokens.
    for _ in range(config.max_output_tokens):
        
        # Start over from initial tokens
        x = torch.tensor(token_ids, device=device)
        
        # Initial embeddings
        x = embeddings(x)
        
        # Contextualized embeddings
        x = context_layers(x)
        
        # Head
        token_id = head(x)
        
        # Check stopping criteria
        if token_id in tokenizer.stop_tokens:
            break
    
        # Print token
        token = tokenizer.decode([token_id])
        stdout.write(token)
        
        # Append to end of sequence
        token_ids.append(token_id)

In [63]:
generate([
    {
        "role": "user",
        "content": "What is the capital of Massachusetts?",
    },
])

The capital of Massachusetts is Boston.