# Transformer Teardown: Llama 3

> Trace an Inference from Raw Data to Prediction Through a Llama 3 text-generation Pipeline

In [the last Transformer Teardown](https://stickshift.github.io/2024/09/04/transformer-teardown.html), we dissected a BERT-based text classification pipeline, tracing a single inference through the entire stack from raw data to final prediction. Although BERT is no longer state-of-the-art, we started with it to build a strong foundation in Transformer fundamentals. In this post, we'll apply what we learned to unpack [the latest Llama 3 models](https://llama.meta.com/) released by Meta last month.

Following the same teardown process as last time, we'll start by creating an off-the-shelf Llama 3 `text-generation` pipeline, apply the pipeline to generate an answer to an arbitrary question, and then manually regenerate the same answer step by step, using clear, minimal Python code.

# Setup

In [1]:
from functools import partial
import math
import os
import sys
import warnings

from matplotlib import pyplot as plt
import seaborn as sns
import numpy as np
from pandas import Series
from pytest import approx
import torch
from torch import nn
from torch.nn.functional import relu, softmax
import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv

from stickshift import default_arg, take
from stickshift.models import llama

In [2]:
# Ignore all warnings
warnings.filterwarnings("ignore")

# Disable stderr
# sys.stderr = open(os.devnull, 'w')

# Configure gpu
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")

# Text Generation with Llama 3

We start by using Hugging Face's `pipeline` API to create an end-to-end text generation pipeline around a Llama 3 8B model (the `Meta-Llama-3.1-8B-Instruct` checkpoint), and then use it to answer the question "What is the capital of Massachusetts?" Our goal throughout the rest of the post will be to recreate the steps from raw text to the answer "Boston".

In [3]:
# Create off-the-shelf Llama 3 text generation pipeline
transformer = transformers.pipeline(
    "text-generation", 
    model="meta-llama/Meta-Llama-3.1-8B-Instruct", 
    device=device,
)

In [5]:
# Answer question
transformer("What is the capital of Massachusetts?")

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


[{'generated_text': 'What is the capital of Massachusetts? Boston\nWhat is the population of Massachusetts? 6.'}]
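A text-generation pipeline produces its answer one token at a time: score every candidate next token, append one, and repeat until a stop token or a length limit is reached. The sketch below illustrates that loop under loud assumptions: `greedy_generate`, `toy_model`, and `TOY_SCORES` are hypothetical stand-ins, not part of Hugging Face's API, and always picking the top-scoring token (greedy decoding) may not match the pipeline's default settings. It is consistent with the output above continuing past "Boston" until generation is cut off.

```python
def greedy_generate(next_token_scores, prompt, max_new_tokens, stop_token=None):
    """Append the highest-scoring next token until a stop token or length limit."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        scores = next_token_scores(tokens)
        best = max(scores, key=scores.get)
        if best == stop_token:
            break
        tokens.append(best)
    return tokens


# Hypothetical stand-in for a language model: scores depend only on the last token.
TOY_SCORES = {
    "?": {" Boston": 0.9, " Springfield": 0.1},
    " Boston": {"\n": 0.8, "<stop>": 0.2},
    "\n": {"What": 0.7, "<stop>": 0.3},
    "What": {"<stop>": 1.0},
}


def toy_model(tokens):
    return TOY_SCORES[tokens[-1]]


prompt = ["What", " is", " the", " capital", " of", " Massachusetts", "?"]

# With a small length limit, generation stops mid-thought
short = greedy_generate(toy_model, prompt, max_new_tokens=2, stop_token="<stop>")
print(short[len(prompt):])
```

In the real pipeline the scores come from the Head stage of the model rather than a lookup table.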

# Model Config

In [6]:
# Load model config and pre-trained parameters
config = llama.config(transformer.model)
parameters = transformer.model.state_dict()
llama_model = transformer.model.model

In [7]:
# List the names of the pre-trained parameters
list(parameters)

['model.embed_tokens.weight',
 'model.layers.0.self_attn.q_proj.weight',
 'model.layers.0.self_attn.k_proj.weight',
 'model.layers.0.self_attn.v_proj.weight',
 'model.layers.0.self_attn.o_proj.weight',
 'model.layers.0.mlp.gate_proj.weight',
 'model.layers.0.mlp.up_proj.weight',
 'model.layers.0.mlp.down_proj.weight',
 'model.layers.0.input_layernorm.weight',
 'model.layers.0.post_attention_layernorm.weight',
 'model.layers.1.self_attn.q_proj.weight',
 'model.layers.1.self_attn.k_proj.weight',
 'model.layers.1.self_attn.v_proj.weight',
 'model.layers.1.self_attn.o_proj.weight',
 'model.layers.1.mlp.gate_proj.weight',
 'model.layers.1.mlp.up_proj.weight',
 'model.layers.1.mlp.down_proj.weight',
 'model.layers.1.input_layernorm.weight',
 'model.layers.1.post_attention_layernorm.weight',
 'model.layers.2.self_attn.q_proj.weight',
 'model.layers.2.self_attn.k_proj.weight',
 'model.layers.2.self_attn.v_proj.weight',
 'model.layers.2.self_attn.o_proj.weight',
 'model.layers.2.mlp.gate_proj.w

In [8]:
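def load_state(*args, layer=None):
    """Copy pre-trained weights into modules.

    `args` alternates between a module and the key naming its weights,
    e.g. `load_state(queries, "queries", keys, "keys", layer=0)`.
    """
    # Defaults
    layer = default_arg(layer, lambda: 0)

    # Walk the arguments as (module, key) pairs
    for module, key in take(2, args):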
        match key:
            case "value_embeddings":
                module.load_state_dict({
                    "weight": parameters["model.embed_tokens.weight"],
                })
            case "normalize_inputs":
                module.load_state_dict({
                    "weight": parameters[f"model.layers.{layer}.input_layernorm.weight"],
                })
            case "queries":
                module.load_state_dict({
                    "weight": parameters[f"model.layers.{layer}.self_attn.q_proj.weight"],
                })
            case "keys":
                module.load_state_dict({
                    "weight": parameters[f"model.layers.{layer}.self_attn.k_proj.weight"],
                })
            case "values":
                module.load_state_dict({
                    "weight": parameters[f"model.layers.{layer}.self_attn.v_proj.weight"],
                })                
            case "attention_outputs":
                module.load_state_dict({
                    "weight": parameters[f"model.layers.{layer}.self_attn.o_proj.weight"],
                })                
            case "normalize_attention":
                module.load_state_dict({
                    "weight": parameters[f"model.layers.{layer}.post_attention_layernorm.weight"],
                })
            case "gate":
                # The "0." prefix implies `gate` wraps its projection as submodule 0
                module.load_state_dict({
                    "0.weight": parameters[f"model.layers.{layer}.mlp.gate_proj.weight"],
                })
            case "up":
                module.load_state_dict({
                    "weight": parameters[f"model.layers.{layer}.mlp.up_proj.weight"],
                })
            case "down":
                module.load_state_dict({
                    "weight": parameters[f"model.layers.{layer}.mlp.down_proj.weight"],
                })
            case "normalize_context":
                module.load_state_dict({
                    "weight": parameters["model.norm.weight"],
                })
            case "classifier":
                module.load_state_dict({
                    "weight": parameters["lm_head.weight"],
                })
            case _:
                raise ValueError(f"Unexpected key {key}")


def load_pretrained_state(layer):    
    # Load pre-trained state
    load_state(
        normalize_inputs, "normalize_inputs", 
        queries, "queries", 
        keys, "keys", 
        values, "values", 
        attention_outputs, "attention_outputs", 
        normalize_attention, "normalize_attention",
        gate, "gate",
        up, "up",
        down, "down",
        layer=layer,
    )                

In [9]:
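def compare_embeddings(t, llama_t):
    """Return the per-embedding error between our tensor and Llama's.

    Each error is |1 - (e1 . e2) / ||e2||^2|, which is 0 when the
    two embeddings match.
    """
    errors = []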

    with torch.no_grad():
        # Move both tensors to cpu
        t = t.to("cpu")
        llama_t = llama_t.to("cpu")
    
        # Squeeze llama
        llama_t = llama_t.squeeze()
        assert t.shape == llama_t.shape

        # Reshape both to be 1 long list of embeddings
        t = t.reshape(-1, t.shape[-1])
        llama_t = llama_t.reshape(-1, llama_t.shape[-1])
        assert t.shape == llama_t.shape

        # Compare each embedding
        for i in range(t.shape[0]):
            e1 = t[i]
            e2 = llama_t[i]
            score = torch.dot(e1, e2) / torch.norm(e2)**2
            error = 1.0 - score
            errors.append(error.abs().item())

    return Series(errors)

# Transformer Pipeline

Llama 3 follows the same multi-stage Transformer pipeline that we saw with BERT. We'll walk through each stage in turn, using BERT as a baseline and focusing the discussion on the key changes in Llama 3. For more detail on the fundamentals, please see [the previous Transformer Teardown](https://stickshift.github.io/2024/09/04/transformer-teardown.html).

<img src="transformer-pipeline.svg" class="stickshift-figure" width="800">

# Tokenize

The Tokenize stage is responsible for breaking raw data into a sequence of "tokens". The following cells use an algorithm known as Byte Pair Encoding (BPE) (REFERENCE) to split the sentence "What is the capital of Massachusetts?" into token ids `[128000, 3923, 374, 279, 6864, 315, 22108, 30]`.

Since our primary interest is in the Transformer layers that come later, we'll use Hugging Face's off-the-shelf tokenizer implementation here.
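To give a feel for what BPE does, here is a minimal toy sketch, not the Llama 3 tokenizer. It learns merges over characters from a tiny corpus and then applies them in order. The function names (`learn_merges`, `merge_pair`, `tokenize`) and the corpus are made up for illustration, and real tokenizers typically operate on bytes with extra pre-tokenization rules that are left out here.

```python
from collections import Counter


def most_frequent_pair(words):
    """Count adjacent symbol pairs across all words; return the most common."""
    pairs = Counter()
    for symbols, freq in words.items():
        for a, b in zip(symbols, symbols[1:]):
            pairs[(a, b)] += freq
    return max(pairs, key=pairs.get) if pairs else None


def merge_pair(symbols, pair):
    """Replace each occurrence of `pair` in `symbols` with one merged symbol."""
    merged, i = [], 0
    while i < len(symbols):
        if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
            merged.append(symbols[i] + symbols[i + 1])
            i += 2
        else:
            merged.append(symbols[i])
            i += 1
    return tuple(merged)


def learn_merges(corpus, num_merges):
    """Repeatedly merge the most frequent adjacent pair in the corpus."""
    words = Counter(tuple(word) for word in corpus.split())
    merges = []
    for _ in range(num_merges):
        pair = most_frequent_pair(words)
        if pair is None:
            break
        merges.append(pair)
        merged_words = Counter()
        for symbols, freq in words.items():
            merged_words[merge_pair(symbols, pair)] += freq
        words = merged_words
    return merges


def tokenize(word, merges):
    """Split a word into characters, then apply the learned merges in order."""
    symbols = tuple(word)
    for pair in merges:
        symbols = merge_pair(symbols, pair)
    return list(symbols)


merges = learn_merges("low low low lower lowest", num_merges=2)
print(merges)                      # [('l', 'o'), ('lo', 'w')]
print(tokenize("lowest", merges))  # ['low', 'e', 's', 't']
```

Frequent character sequences such as "low" end up as single tokens, while rarer suffixes stay split into smaller pieces.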

In [10]:
# Extract tokenizer from transformer
tokenizer = transformer.tokenizer

In [11]:
# Tokenize sentence
batch = tokenizer("What is the capital of Massachusetts?", return_tensors="pt")

batch

{'input_ids': tensor([[128000,   3923,    374,    279,   6864,    315,  22108,     30]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}

We can also use the tokenizer to decode the token ids back into raw text.

In [12]:
[tokenizer.decode(input_id) for input_id in batch.input_ids[0]]

['<|begin_of_text|>',
 'What',
 ' is',
 ' the',
 ' capital',
 ' of',
 ' Massachusetts',
 '?']

The following table highlights the differences in tokenization hyperparameters between BERT and Llama 3.

|                | BERT       | Llama 3 |
| -------------: | ---------- | ------- |
| **Algorithm**  | WordPiece  | BPE     |
| **Vocab Size** | 30k        | 128k    |

# Embeddings

The Embeddings stage of the Transformer pipeline converts each token id into an "embedding". Here we hit the first big difference between BERT and Llama 3. While BERT used learned embeddings to encode the tokens' absolute positions, Llama 3 uses Rotary Position Embedding (RoPE) (REFERENCE) to encode the tokens' *relative* positions. RoPE moves responsibility for position encoding from the Embeddings stage to the attention calculation in the Context stage, so for now the Embeddings stage only needs to worry about the token values.
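The "relative" property of RoPE can be checked with a few lines of plain Python. The sketch below rotates adjacent pairs of dimensions by position-dependent angles and shows that the query-key dot product depends only on the distance between the two positions. The helper names `rope` and `dot` are hypothetical, and the pairing of adjacent dimensions and the base of 10000 are assumptions for illustration; Llama's actual implementation may pair dimensions and choose the base differently.

```python
import math


def rope(vec, position, base=10000.0):
    """Rotate each adjacent pair of dimensions by a position-dependent angle."""
    d = len(vec)
    rotated = []
    for i in range(0, d, 2):
        # Lower dimensions rotate quickly, higher dimensions slowly
        theta = position * base ** (-i / d)
        x, y = vec[i], vec[i + 1]
        rotated.append(x * math.cos(theta) - y * math.sin(theta))
        rotated.append(x * math.sin(theta) + y * math.cos(theta))
    return rotated


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


q = [1.0, 0.5, -0.3, 2.0]
k = [0.2, -1.0, 0.7, 0.4]

# Same distance between query and key (3 positions), different absolute positions
score_near = dot(rope(q, 5), rope(k, 2))
score_far = dot(rope(q, 105), rope(k, 102))

# A different distance (1 position) gives a different score
score_other = dot(rope(q, 5), rope(k, 4))

print(math.isclose(score_near, score_far, abs_tol=1e-9))
print(math.isclose(score_near, score_other, abs_tol=1e-3))
```

Because the rotation is applied to queries and keys just before their dot product, no separate position embedding has to be added to the token embeddings.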

In [13]:
# Initialize value embeddings lookup table
value_embeddings = nn.Embedding(
    num_embeddings=config.vocab_size, 
    embedding_dim=config.d_model,
)

# Load pre-trained state
load_state(value_embeddings, "value_embeddings")

In [14]:
# Calculate token values
token_values = torch.squeeze(batch.input_ids)

[tokenizer.decode(token_id) for token_id in token_values]

['<|begin_of_text|>',
 'What',
 ' is',
 ' the',
 ' capital',
 ' of',
 ' Massachusetts',
 '?']

In [15]:
# Record sequence length
n = len(token_values)

n

8

In [17]:
# Map token values to embeddings
v = value_embeddings(token_values)

v.shape

torch.Size([8, 4096])

In [18]:
# Show sample of value embeddings
v

tensor([[ 2.6512e-04, -4.9973e-04, -5.8365e-04,  ...,  3.8147e-03,
          6.3419e-05,  1.1902e-03],
        [ 2.0752e-02, -1.2894e-03,  2.8229e-03,  ...,  2.1973e-02,
          3.1128e-03,  1.0681e-02],
        [-2.6093e-03,  7.7057e-04,  2.6131e-04,  ...,  1.1902e-02,
          4.6387e-03,  9.1553e-03],
        ...,
        [ 1.2817e-03,  9.1171e-04,  2.0905e-03,  ...,  1.6251e-03,
          4.0894e-03, -4.0283e-03],
        [ 1.2146e-02,  1.1597e-02,  1.7822e-02,  ...,  1.9684e-03,
         -1.4771e-02, -2.5940e-03],
        [-4.8523e-03, -1.8005e-03,  7.2937e-03,  ...,  2.3956e-03,
         -1.3657e-03, -5.4932e-03]], grad_fn=<EmbeddingBackward0>)

In [19]:
# Save value embeddings as our "input embeddings"
hidden_states = v

# Context Layers

The Context Layers in a Transformer are responsible for infusing each token embedding with contextual signals drawn from the rest of the sequence. The mechanism works by passing the token embeddings through multiple layers of attention and feedforward blocks. The attention blocks focus on relationships between tokens, augmenting each embedding with contextual information drawn from the surrounding embeddings. The feedforward blocks capitalize on the extra context, transforming each augmented embedding using a fully connected neural network. This pattern of attention and transformation is repeated over and over again, gradually converting representations of individual words into representations of abstract semantic concepts over a series of small increments.

<img src="transformer-layers.svg" class="stickshift-figure" width="800">
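The attention-then-transform pattern described above can be sketched in plain Python. This is a heavily simplified illustration under stated assumptions, not the Llama 3 implementation: it uses a single attention head, omits RoPE and the learned normalization gains, and fills the weights with random numbers. The function names and the parameter keys (`q`, `k`, `v`, `o`, `gate`, `up`, `down`) are hypothetical, loosely mirroring the projection names in the pre-trained parameters listed earlier.

```python
import math
import random


def rms_norm(x, eps=1e-5):
    """Rescale a vector to unit root-mean-square (learned gain omitted)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def matvec(w, x):
    """Multiply matrix w (a list of rows) by vector x."""
    return [sum(a * b for a, b in zip(row, x)) for row in w]


def add(x, y):
    return [a + b for a, b in zip(x, y)]


def attention(xs, p):
    """Single-head causal attention: token i only sees tokens 0..i."""
    qs = [matvec(p["q"], x) for x in xs]
    ks = [matvec(p["k"], x) for x in xs]
    vs = [matvec(p["v"], x) for x in xs]
    scale = math.sqrt(len(qs[0]))
    outputs = []
    for i, q in enumerate(qs):
        scores = [sum(a * b for a, b in zip(q, ks[j])) / scale for j in range(i + 1)]
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]  # softmax numerators
        total = sum(weights)
        mixed = [0.0] * len(q)
        for j, w in enumerate(weights):
            mixed = add(mixed, [w / total * v for v in vs[j]])
        outputs.append(matvec(p["o"], mixed))
    return outputs


def ffn(x, p):
    """Gated feedforward block: down(silu(gate(x)) * up(x))."""
    gate = [g / (1.0 + math.exp(-g)) for g in matvec(p["gate"], x)]
    up = matvec(p["up"], x)
    return matvec(p["down"], [g * u for g, u in zip(gate, up)])


def context_layer(xs, p):
    """One layer: attention, then feedforward, each added back to its input."""
    attended = attention([rms_norm(x) for x in xs], p)
    xs = [add(x, a) for x, a in zip(xs, attended)]
    return [add(x, ffn(rms_norm(x), p)) for x in xs]


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
d_model, d_ffn = 4, 8
params = {name: random_matrix(d_model, d_model, rng) for name in ("q", "k", "v", "o")}
params["gate"] = random_matrix(d_ffn, d_model, rng)
params["up"] = random_matrix(d_ffn, d_model, rng)
params["down"] = random_matrix(d_model, d_ffn, rng)

# Three toy token embeddings in, three contextualized embeddings out
tokens = [random_matrix(1, d_model, rng)[0] for _ in range(3)]
outputs = context_layer(tokens, params)
print(len(outputs), len(outputs[0]))
```

Each block adds its result back onto its input rather than replacing it, which is what lets a stack of layers refine the embeddings in small increments.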

## Normalize Inputs

## Attention

## Normalize Attention

## FFN

## Normalize FFN

## Stacking the Layers

# Head

# Recap

# References