# Transformer Teardown: Llama 3

> Use what we learned about BERT as a baseline to explore SOTA Llama 3

# Setup

In [1]:
from functools import partial
import math
import warnings

from matplotlib import pyplot as plt
import seaborn as sns
import numpy as np
from pandas import Series
from pytest import approx
import torch
from torch import nn
from torch.nn.functional import relu, softmax
import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv

from stickshift import default_arg, take
from stickshift.models import llama

In [2]:
# Ignore all warnings
warnings.filterwarnings("ignore")

# Configure gpu
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")

# Text Generation with Llama 3

In [3]:
# Create off-the-shelf text generation transformer
transformer = transformers.pipeline("text-generation", model="meta-llama/Meta-Llama-3.1-8B-Instruct", device=device)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [4]:
transformer("What is the capital of Massachusetts?")

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


[{'generated_text': 'What is the capital of Massachusetts? Boston.\nWhat is the capital of Massachusetts?\nA. Boston'}]

# Model Config

In [5]:
# Load model config and pre-trained parameters
config = llama.config(transformer.model)
parameters = transformer.model.state_dict()
llama_model = transformer.model.model

In [6]:
[k for k in parameters]

['model.embed_tokens.weight',
 'model.layers.0.self_attn.q_proj.weight',
 'model.layers.0.self_attn.k_proj.weight',
 'model.layers.0.self_attn.v_proj.weight',
 'model.layers.0.self_attn.o_proj.weight',
 'model.layers.0.mlp.gate_proj.weight',
 'model.layers.0.mlp.up_proj.weight',
 'model.layers.0.mlp.down_proj.weight',
 'model.layers.0.input_layernorm.weight',
 'model.layers.0.post_attention_layernorm.weight',
 'model.layers.1.self_attn.q_proj.weight',
 'model.layers.1.self_attn.k_proj.weight',
 'model.layers.1.self_attn.v_proj.weight',
 'model.layers.1.self_attn.o_proj.weight',
 'model.layers.1.mlp.gate_proj.weight',
 'model.layers.1.mlp.up_proj.weight',
 'model.layers.1.mlp.down_proj.weight',
 'model.layers.1.input_layernorm.weight',
 'model.layers.1.post_attention_layernorm.weight',
 'model.layers.2.self_attn.q_proj.weight',
 'model.layers.2.self_attn.k_proj.weight',
 'model.layers.2.self_attn.v_proj.weight',
 'model.layers.2.self_attn.o_proj.weight',
 'model.layers.2.mlp.gate_proj.w

In [7]:
def load_state(*args, layer=None):
    """Copy pre-trained checkpoint tensors into locally built modules.

    Positional arguments are passed as alternating ``module, key`` values, e.g.
    ``load_state(queries, "queries", keys, "keys")``; ``take(2, args)`` is expected to
    hand them back two at a time. Each ``key`` selects which tensor from the global
    ``parameters`` state dict is loaded into the matching ``module``.

    Args:
        *args: Alternating module / key values.
        layer: Decoder layer index used to build the names of layer-specific tensors.
            ``default_arg`` presumably substitutes 0 when this is None -- confirm
            against stickshift.

    Raises:
        ValueError: If a key is not one of the supported names.
    """
    # Defaults
    layer = default_arg(layer, lambda: 0)

    # key -> (entry name in the module's own state dict, tensor name in the checkpoint)
    sources = {
        "value_embeddings": ("weight", "model.embed_tokens.weight"),
        "normalize_inputs": ("weight", f"model.layers.{layer}.input_layernorm.weight"),
        "queries": ("weight", f"model.layers.{layer}.self_attn.q_proj.weight"),
        "keys": ("weight", f"model.layers.{layer}.self_attn.k_proj.weight"),
        "values": ("weight", f"model.layers.{layer}.self_attn.v_proj.weight"),
        "attention_outputs": ("weight", f"model.layers.{layer}.self_attn.o_proj.weight"),
        "normalize_attention": ("weight", f"model.layers.{layer}.post_attention_layernorm.weight"),
        # The gate module is an nn.Sequential, so its linear weight is stored as "0.weight"
        "gate": ("0.weight", f"model.layers.{layer}.mlp.gate_proj.weight"),
        "up": ("weight", f"model.layers.{layer}.mlp.up_proj.weight"),
        "down": ("weight", f"model.layers.{layer}.mlp.down_proj.weight"),
        "normalize_context": ("weight", "model.norm.weight"),
        "classifier": ("weight", "lm_head.weight"),
    }

    for module, key in take(2, args):
        # Only string keys can match; anything else is rejected like an unknown name
        entry = sources.get(key) if isinstance(key, str) else None
        if entry is None:
            raise ValueError(f"Unexpected key {key}")

        target_name, checkpoint_name = entry
        module.load_state_dict({target_name: parameters[checkpoint_name]})


def load_pretrained_state(layer):
    """Load the pre-trained weights of decoder layer ``layer`` into the shared layer modules.

    The notebook-level modules (input norm, query/key/value/output projections, attention
    norm, and the gate/up/down FNN projections) are refreshed through ``load_state`` so the
    same module objects can be reused for every layer.

    Args:
        layer: Index of the decoder layer whose weights are loaded.
    """
    module_keys = (
        (normalize_inputs, "normalize_inputs"),
        (queries, "queries"),
        (keys, "keys"),
        (values, "values"),
        (attention_outputs, "attention_outputs"),
        (normalize_attention, "normalize_attention"),
        (gate, "gate"),
        (up, "up"),
        (down, "down"),
    )

    # load_state takes a flat module, key, module, key, ... argument list
    flat_args = [item for pair in module_keys for item in pair]
    load_state(*flat_args, layer=layer)

In [8]:
def compare_embeddings(t, llama_t):
    """Measure how closely each embedding in ``t`` matches the reference ``llama_t``.

    Both tensors are moved to the CPU, size-1 dimensions are ignored when comparing shapes
    (the reference tensors in this notebook carry an extra leading batch dimension), and
    both are flattened into a list of embeddings along the last dimension. Each embedding
    ``e1`` is scored against its reference ``e2`` as ``dot(e1, e2) / ||e2||^2``, which is
    1.0 when they are identical; the reported error is ``|1 - score|``.

    Args:
        t: Tensor of embeddings to check; the last dimension is the embedding dimension.
        llama_t: Reference tensor whose shape equals ``t``'s once size-1 dimensions are
            dropped from both.

    Returns:
        pandas Series with one absolute error per embedding.

    Raises:
        AssertionError: If the shapes disagree after dropping size-1 dimensions.
    """
    with torch.no_grad():
        # Move both tensors to cpu
        t = t.to("cpu")
        llama_t = llama_t.to("cpu")

        # Remember the embedding width before any squeezing
        d = t.shape[-1]

        # Ignore size-1 dimensions on BOTH sides. Squeezing only the reference made a
        # single-embedding input, e.g. (1, d) vs (1, 1, d), fail the shape check.
        assert t.squeeze().shape == llama_t.squeeze().shape

        # Reshape both to be 1 long list of embeddings
        t = t.reshape(-1, d)
        llama_t = llama_t.reshape(-1, d)
        assert t.shape == llama_t.shape

        # Score every embedding at once instead of looping row by row
        scores = (t * llama_t).sum(dim=-1) / llama_t.pow(2).sum(dim=-1)
        errors = (1.0 - scores).abs()

    return Series(errors.tolist())

# Transformer Pipeline

<img src="transformer-pipeline.svg" class="stickshift-figure" width="800">

# Tokenize

In [9]:
# Extract tokenizer from transformer
tokenizer = transformer.tokenizer

In [10]:
# Tokenize sentence
batch = tokenizer("What is the capital of Massachusetts?", return_tensors="pt")

batch

{'input_ids': tensor([[128000,   3923,    374,    279,   6864,    315,  22108,     30]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}

In [11]:
[tokenizer.decode(input_id) for input_id in batch.input_ids[0]]

['<|begin_of_text|>',
 'What',
 ' is',
 ' the',
 ' capital',
 ' of',
 ' Massachusetts',
 '?']

# Embeddings

## Value Embeddings

In [12]:
# Initialize value embeddings lookup table
value_embeddings = nn.Embedding(
    num_embeddings=config.vocab_size, 
    embedding_dim=config.d_model,
)

# Load pre-trained state
load_state(value_embeddings, "value_embeddings")

In [13]:
# Calculate token values
values = torch.squeeze(batch.input_ids)

[tokenizer.decode(input_id) for input_id in values]

['<|begin_of_text|>',
 'What',
 ' is',
 ' the',
 ' capital',
 ' of',
 ' Massachusetts',
 '?']

In [14]:
n = len(values)
n

8

In [15]:
# Map token values to embeddings
v = value_embeddings(values)

v.shape

torch.Size([8, 4096])

In [16]:
# Show sample of value embeddings
v

tensor([[ 2.6512e-04, -4.9973e-04, -5.8365e-04,  ...,  3.8147e-03,
          6.3419e-05,  1.1902e-03],
        [ 2.0752e-02, -1.2894e-03,  2.8229e-03,  ...,  2.1973e-02,
          3.1128e-03,  1.0681e-02],
        [-2.6093e-03,  7.7057e-04,  2.6131e-04,  ...,  1.1902e-02,
          4.6387e-03,  9.1553e-03],
        ...,
        [ 1.2817e-03,  9.1171e-04,  2.0905e-03,  ...,  1.6251e-03,
          4.0894e-03, -4.0283e-03],
        [ 1.2146e-02,  1.1597e-02,  1.7822e-02,  ...,  1.9684e-03,
         -1.4771e-02, -2.5940e-03],
        [-4.8523e-03, -1.8005e-03,  7.2937e-03,  ...,  2.3956e-03,
         -1.3657e-03, -5.4932e-03]], grad_fn=<EmbeddingBackward0>)

In [17]:
llama_v = llama_model.embed_tokens(batch.input_ids.to(device))

assert compare_embeddings(v, llama_v).max() < 0.001

## Position Embeddings

Llama 3 uses the rotary position embedding (RoPE) algorithm. Instead of baking absolute positions into the input embeddings, RoPE rotates the query and key embeddings according to the tokens' positions in the sequence.

The match between the token embeddings at positions $m$ and $n$ is calculated as

$$
q_{m}^T k_{n} = (R_{\Theta,m}^d W_{q} x_{m})^T (R_{\Theta,n}^d W_{k} x_{n})
$$

which can be rewritten in pseudo code as

```python
weights[m][n] = transpose(rotate(query(x[m]), m)) * rotate(key(x[n]), n)
```

However, if we pack all the embeddings, queries, and keys into matrices, then we can still use the familiar SDPA equation:

$$
\begin{align}
SDPA &= softmax(\frac{QK^T}{\sqrt{d_K}})V \\
\text{where } Q &= R_{\Theta}^d W_Q X \\
              K &= R_{\Theta}^d W_K X \\
              V &= W_V X
\end{align}
$$

Note that while the query, key, and value projections are still layer-specific, the rotation matrix $R_{\Theta}^d$ is shared across all layers.

The rotation matrix is defined by a series of 2D rotations for each pair of values in the embedding vectors.

$$
\mathbf{R}_{\Theta,m}^d = 
\begin{bmatrix}
cos(m \theta_0) & -sin(m \theta_0) & 0 & 0 & \dots & 0 & 0 \\
sin(m \theta_0) & cos(m \theta_0) & 0 & 0 & \dots & 0 & 0 \\
0 & 0 & cos(m \theta_1) & -sin(m \theta_1) & \dots & 0 & 0 \\
0 & 0 & sin(m \theta_1) & cos(m \theta_1) & \dots & 0 & 0 \\
\vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\
0 & 0 & \dots & 0 & 0 & cos(m \theta_{d/2-1}) & -sin(m \theta_{d/2-1}) \\
0 & 0 & \dots & 0 & 0 & sin(m \theta_{d/2-1}) & cos(m \theta_{d/2-1}) \\
\end{bmatrix}
$$

where

$$
\theta_i = \frac{1}{\Theta^{2i/d}}
$$

which is computed using the more efficient element-wise form (where $\odot$ denotes element-wise multiplication)

$$
\mathbf{R}_{\Theta,m}^d \mathbf{x} = 
\begin{bmatrix}
x_1 \\
x_2 \\
x_3 \\
x_4 \\
\vdots \\
x_{d-1} \\
x_d \\
\end{bmatrix}
\odot
\begin{bmatrix}
cos(m \theta_0) \\
cos(m \theta_0) \\
cos(m \theta_1) \\
cos(m \theta_1) \\
\vdots \\
cos(m \theta_{d/2-1}) \\
cos(m \theta_{d/2-1}) \\
\end{bmatrix}
+
\begin{bmatrix}
-x_2 \\
x_1 \\
-x_4 \\
x_3 \\
\vdots \\
-x_d \\
x_{d-1} \\
\end{bmatrix}
\odot
\begin{bmatrix}
sin(m \theta_0) \\
sin(m \theta_0) \\
sin(m \theta_1) \\
sin(m \theta_1) \\
\vdots \\
sin(m \theta_{d/2-1}) \\
sin(m \theta_{d/2-1}) \\
\end{bmatrix}
$$

For now, we simply compute the $cos$ and $sin$ terms which rely only on base $\Theta$, head dimension $d$, and the sequence length $n$.

In [18]:
base = config.rope_base
d = config.d_head

# Compute theta_i = 1 / base^(2i/d) from i = 0 to d/2-1
thetas = 1.0 / base**(2 * torch.arange(d // 2) / d)

thetas.shape

torch.Size([64])

In [19]:
# Compute m * theta_i for every position m in 0..n-1 (one row per position)
positions = torch.arange(n, dtype=thetas.dtype)
frequencies = torch.outer(positions, thetas)

# Repeat the d/2 angles along the last dimension so each row has d_head entries,
# lining up with the (first half, second half) split used by rotate_half
frequencies = torch.cat((frequencies, frequencies), dim=-1)

frequencies.shape

torch.Size([8, 128])

In [20]:
rope_cos = torch.cos(frequencies)
rope_sin = torch.sin(frequencies)

rope_cos.shape, rope_sin.shape

(torch.Size([8, 128]), torch.Size([8, 128]))

In [21]:
# Sanity check
assert rope_cos.shape[0] == n and rope_cos.shape[1] == config.d_head
assert rope_sin.shape[0] == n and rope_sin.shape[1] == config.d_head

In [22]:
# Add extra dimension so we can multiply against multi-head q,k,v
rope_cos = rope_cos.unsqueeze(0)
rope_sin = rope_sin.unsqueeze(0)

rope_cos.shape, rope_sin.shape

(torch.Size([1, 8, 128]), torch.Size([1, 8, 128]))

## Input Embeddings

In [23]:
x = v

In [24]:
llama_x = llama_v
llama_position_ids = torch.arange(n).unsqueeze(0).to(device)
llama_rope_cos, llama_rope_sin = llama_model.rotary_emb(llama_x, llama_position_ids)

# Context

<img src="transformer-layers.svg" class="stickshift-figure" width="800">

In [25]:
llama_layer = llama_model.layers[0]

## Normalize Input

In [26]:
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned scale.

    Each vector is multiplied by the reciprocal square root of its mean squared value
    (with ``eps`` added before the square root) and scaled element-wise by ``weight``.

    Args:
        config: Object providing ``d_model`` (length of ``weight``) and ``rms_norm_eps``.
    """

    def __init__(self, config):
        super().__init__()
        self.eps = config.rms_norm_eps
        # Scale starts as all ones until pre-trained weights are loaded
        self.weight = nn.Parameter(torch.ones(config.d_model))

    def forward(self, x):
        """Return ``x`` normalized by its RMS along the last dimension, scaled by ``weight``."""
        mean_square = torch.mean(x.pow(2), dim=-1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_square + self.eps)
        return (self.weight * x) * inverse_rms

In [27]:
# Configure input normalization
normalize_inputs = RMSNorm(config=config)

# Load pre-trained state
load_state(normalize_inputs, "normalize_inputs")

In [28]:
residual = x
hidden_states = normalize_inputs(x)
hidden_states.shape

torch.Size([8, 4096])

In [29]:
hidden_states

tensor([[ 1.5436e-03, -1.1578e-02, -3.0144e-02,  ...,  3.5905e-02,
          3.0229e-04,  3.5905e-03],
        [ 1.0366e-01, -2.5630e-02,  1.2508e-01,  ...,  1.7744e-01,
          1.2730e-02,  2.7645e-02],
        [-1.7112e-02,  2.0110e-02,  1.5202e-02,  ...,  1.2618e-01,
          2.4905e-02,  3.1110e-02],
        ...,
        [ 9.2948e-03,  2.6309e-02,  1.3447e-01,  ...,  1.9051e-02,
          2.4277e-02, -1.5136e-02],
        [ 4.7014e-02,  1.7862e-01,  6.1193e-01,  ...,  1.2317e-02,
         -4.6805e-02, -5.2024e-03],
        [-3.0458e-02, -4.4975e-02,  4.0612e-01,  ...,  2.4310e-02,
         -7.0178e-03, -1.7866e-02]], grad_fn=<MulBackward0>)

In [30]:
llama_residual = llama_x
llama_hidden_states = llama_layer.input_layernorm(llama_x)
llama_hidden_states.shape

errors = compare_embeddings(hidden_states, llama_hidden_states)
assert errors.max() < 0.001

## Queries, Keys, Values

In [31]:
def split_heads(x, n_heads):
    """Split the last dimension of ``x`` into ``n_heads`` heads of width ``config.d_head``.

    The tensor is viewed as ``(-1, n_heads, config.d_head)`` and the first two of those
    axes are swapped, so the head axis comes before the token axis.
    """
    per_token = x.view(-1, n_heads, config.d_head)
    return torch.transpose(per_token, -3, -2)

def combine_heads(x):
    """Merge per-head embeddings back into rows of width ``config.n_heads * config.d_head``.

    Swaps the head and token axes back, makes the result contiguous so it can be viewed,
    and flattens the heads into the last dimension.
    """
    width = int(config.n_heads * config.d_head)
    per_token = torch.transpose(x, -3, -2).contiguous()
    return per_token.view(-1, width)

In [32]:
# Configure query, key, value projections
queries = nn.Linear(
    in_features=config.d_model, 
    out_features=config.n_heads * config.d_head,
    bias=False,
)
# Keys and values use the smaller number of key/value heads (grouped-query attention)
keys = nn.Linear(
    in_features=config.d_model,
    out_features=config.n_kv_heads * config.d_head,
    bias=False,
)
values = nn.Linear(
    in_features=config.d_model, 
    out_features=config.n_kv_heads * config.d_head,
    bias=False,
)
# The output projection consumes the combined heads, whose width is n_heads * d_head
# (this happens to equal d_model for this model), and maps them back to model space
attention_outputs = nn.Linear(
    in_features=config.n_heads * config.d_head, 
    out_features=config.d_model,
    bias=False,
)

# Load pre-trained state
load_state(queries, "queries", keys, "keys", values, "values", attention_outputs, "attention_outputs")

In [33]:
# Project token embeddings to query, key, and value spaces
q = queries(hidden_states)
k = keys(hidden_states)
v = values(hidden_states)

q.shape, k.shape, v.shape

(torch.Size([8, 4096]), torch.Size([8, 1024]), torch.Size([8, 1024]))

In [34]:
llama_q = llama_layer.self_attn.q_proj(llama_hidden_states)
llama_k = llama_layer.self_attn.k_proj(llama_hidden_states)
llama_v = llama_layer.self_attn.v_proj(llama_hidden_states)
assert compare_embeddings(q, llama_q).max() < 0.001
assert compare_embeddings(k, llama_k).max() < 0.001
assert compare_embeddings(v, llama_v).max() < 0.001

In [35]:
# Split attention heads
q = split_heads(q, config.n_heads)
k = split_heads(k, config.n_kv_heads)
v = split_heads(v, config.n_kv_heads)

q.shape, k.shape, v.shape

(torch.Size([32, 8, 128]), torch.Size([8, 8, 128]), torch.Size([8, 8, 128]))

In [36]:
bsz, q_len, _ = llama_hidden_states.size()
llama_q = llama_q.view(bsz, q_len, llama_layer.self_attn.num_heads, llama_layer.self_attn.head_dim).transpose(1, 2)
llama_k = llama_k.view(bsz, q_len, llama_layer.self_attn.num_key_value_heads, llama_layer.self_attn.head_dim).transpose(1, 2)
llama_v = llama_v.view(bsz, q_len, llama_layer.self_attn.num_key_value_heads, llama_layer.self_attn.head_dim).transpose(1, 2)
assert compare_embeddings(q, llama_q).max() < 0.001
assert compare_embeddings(k, llama_k).max() < 0.001
assert compare_embeddings(v, llama_v).max() < 0.001

In [37]:
def rotate_half(x):
    """Swap the two halves of the last dimension of ``x``, negating the second half.

    For ``x = [a, b]`` split at ``x.shape[-1] // 2`` the result is ``[-b, a]``.
    """
    midpoint = x.shape[-1] // 2
    front, back = x[..., :midpoint], x[..., midpoint:]
    return torch.cat((-back, front), dim=-1)

# Rotate queries and keys
q = (q * rope_cos) + (rotate_half(q) * rope_sin)
k = (k * rope_cos) + (rotate_half(k) * rope_sin)

q.shape, k.shape, v.shape

(torch.Size([32, 8, 128]), torch.Size([8, 8, 128]), torch.Size([8, 8, 128]))

In [38]:
llama_q, llama_k = apply_rotary_pos_emb(llama_q, llama_k, llama_rope_cos, llama_rope_sin)
assert compare_embeddings(q, llama_q).max() < 0.001
assert compare_embeddings(k, llama_k).max() < 0.001
assert compare_embeddings(v, llama_v).max() < 0.001

In [39]:
# Repeat key, value groups
k = k.repeat_interleave(config.n_kv_groups, dim=0)
v = v.repeat_interleave(config.n_kv_groups, dim=0)

q.shape, k.shape, v.shape

(torch.Size([32, 8, 128]), torch.Size([32, 8, 128]), torch.Size([32, 8, 128]))

In [40]:
llama_k = repeat_kv(llama_k, llama_layer.self_attn.num_key_value_groups)
llama_v = repeat_kv(llama_v, llama_layer.self_attn.num_key_value_groups)
assert compare_embeddings(q, llama_q).max() < 0.001
assert compare_embeddings(k, llama_k).max() < 0.001
assert compare_embeddings(v, llama_v).max() < 0.001

## Attention

In [41]:
# Compute masked attention bias
mask = torch.ones(n, n, dtype=torch.bool).tril(diagonal=0)
bias = torch.zeros(n, n)
bias.masked_fill_(mask.logical_not(), float("-inf"))

# Compute attention weights
w = softmax(q @ k.transpose(-2, -1) / np.sqrt(config.d_head) + bias, dim=-1)

w.shape

torch.Size([32, 8, 8])

In [42]:
# Compute attention for all heads in parallel
a = w @ v

a.shape

torch.Size([32, 8, 128])

In [43]:
llama_a = torch.nn.functional.scaled_dot_product_attention(
    llama_q,
    llama_k,
    llama_v,
    is_causal=True,
)

assert compare_embeddings(a, llama_a).max() < 0.02

In [44]:
# Combine attention heads
a = combine_heads(a)

a.shape

torch.Size([8, 4096])

In [45]:
llama_a = llama_a.transpose(1, 2).contiguous().view(bsz, q_len, -1)
assert compare_embeddings(a, llama_a).max() < 0.02

In [46]:
# Project attention embeddings back to model space
hidden_states = attention_outputs(a)

hidden_states.shape

torch.Size([8, 4096])

In [47]:
llama_hidden_states = llama_layer.self_attn.o_proj(llama_a)
assert compare_embeddings(hidden_states, llama_hidden_states).max() < 0.01

## Add and Normalize

In [48]:
# Combine attention with input embeddings
hidden_states = residual + hidden_states
residual = hidden_states

In [49]:
# Configure attention normalization
normalize_attention = RMSNorm(config=config)

# Load pre-trained state
load_state(normalize_attention, "normalize_attention")

In [50]:
# Normalize attention
hidden_states = normalize_attention(hidden_states)

hidden_states.shape

torch.Size([8, 4096])

In [51]:
# Combine attention with input embeddings
llama_hidden_states = llama_residual + llama_hidden_states
llama_residual = llama_hidden_states

llama_hidden_states = llama_layer.post_attention_layernorm(llama_hidden_states)
assert compare_embeddings(hidden_states, llama_hidden_states).max() < 0.01

## FNN

In [52]:
# Configure FNN layers
gate = nn.Sequential(
    nn.Linear(
        in_features=config.d_model,
        out_features=config.d_fnn,
        bias=False,
    ),
    nn.SiLU()
)
up = nn.Linear(
    in_features=config.d_model,
    out_features=config.d_fnn,
    bias=False,
)
down = nn.Linear(
    in_features=config.d_fnn,
    out_features=config.d_model,
    bias=False,
)

# Load pre-trained state
load_state(gate, "gate", up, "up", down, "down")

In [53]:
hidden_states = down(gate(hidden_states) * up(hidden_states))

In [54]:
llama_hidden_states = llama_layer.mlp(llama_hidden_states)
assert compare_embeddings(hidden_states, llama_hidden_states).max() < 0.01

## Add

In [55]:
hidden_states = residual + hidden_states

In [56]:
hidden_states.shape

torch.Size([8, 4096])

In [57]:
llama_hidden_states = llama_residual + llama_hidden_states
assert compare_embeddings(hidden_states, llama_hidden_states).max() < 0.01

## Stacking the Layers

In [58]:
# Initialize loop w/ initial input embeddings
z_i = x

# Compute masked attention bias once: it depends only on the sequence length n,
# so it is the same for every layer
mask = torch.ones(n, n, dtype=torch.bool).tril(diagonal=0)
causal_bias = torch.zeros(n, n)
causal_bias.masked_fill_(mask.logical_not(), float("-inf"))

# Apply layer logic in a loop
for layer in range(config.n_layers):
    
    # Use previous layer's outputs as inputs
    hidden_states_i = z_i

    # Load pre-trained state for layer (overwrites the weights of the shared modules)
    load_pretrained_state(layer)

    #
    # Inputs
    #

    # Normalize inputs
    residual_i = hidden_states_i
    hidden_states_i = normalize_inputs(hidden_states_i)

    #
    # Attention
    #
    
    # Project hidden_states_i to query, key, and value spaces
    q_i = queries(hidden_states_i)
    k_i = keys(hidden_states_i)
    v_i = values(hidden_states_i)
    
    # Split q, k, v into separate attention heads
    q_i = split_heads(q_i, config.n_heads)
    k_i = split_heads(k_i, config.n_kv_heads)
    v_i = split_heads(v_i, config.n_kv_heads)
    
    # Rotate queries and keys
    q_i = (q_i * rope_cos) + (rotate_half(q_i) * rope_sin)
    k_i = (k_i * rope_cos) + (rotate_half(k_i) * rope_sin)

    # Expand keys and values for GQA
    k_i = k_i.repeat_interleave(config.n_kv_groups, dim=0)
    v_i = v_i.repeat_interleave(config.n_kv_groups, dim=0)

    # Compute attention for all heads in parallel
    w_i = softmax(
        q_i @ k_i.transpose(-2, -1) / np.sqrt(config.d_head) + causal_bias, 
        dim=-1,
    )
    a_i = w_i @ v_i
    
    # Recombine attention heads
    a_i = combine_heads(a_i)
    
    # Project attention embeddings back to model space
    hidden_states_i = attention_outputs(a_i)

    # Combine attention with input embeddings
    hidden_states_i = residual_i + hidden_states_i
    residual_i = hidden_states_i
    
    # Normalize attention
    hidden_states_i = normalize_attention(hidden_states_i)

    #
    # FNN
    #

    # Transform
    hidden_states_i = down(gate(hidden_states_i) * up(hidden_states_i))

    # Combine FNN with attention embeddings
    hidden_states_i = residual_i + hidden_states_i

    #
    # Outputs
    #

    z_i = hidden_states_i

# Save outputs from last layer
z = z_i

In [59]:
# Configure context normalization
normalize_context = RMSNorm(config=config)

# Load pre-trained state
load_state(normalize_context, "normalize_context")

In [60]:
z = normalize_context(z)

In [61]:
llama_z = llama_model(input_ids=batch.input_ids.to(device)).last_hidden_state
assert compare_embeddings(z, llama_z).max() < 0.01

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


# Head

In [62]:
# Configure language modeling head
classifier = nn.Linear(
    in_features=config.d_model, 
    out_features=config.vocab_size, 
    bias=False,
)

# Load pre-trained state
load_state(classifier, "classifier")

In [63]:
# Use last embedding to represent the entire sequence
features = z[-1]

features.shape

torch.Size([4096])

In [64]:
logits = classifier(features)
logits.shape

torch.Size([128256])

In [66]:
output_id = logits.argmax()
tokenizer.decode(output_id)

' Boston'

In [67]:
# Sanity check
assert output_id == 10406