# Transformer Teardown: Llama 3.1

> Trace an Inference Through Each Layer in a SOTA Transformer Using Llama 3.1 Open Source Foundation Models

In [the last Transformer Teardown](https://stickshift.github.io/2024/09/04/transformer-teardown.html), we dissected a DistilBERT text classification pipeline, tracing a single inference through the entire Transformer stack from raw data to final prediction. We learned about the main stages of a Transformer pipeline as well as fundamental Transformer concepts such as token embeddings and Multi-Head Self Attention. Exploring BERT-based models is a fantastic way to see the core Transformer concepts in action. But BERT was released 6 years ago! ChatGPT wouldn't even exist for another 4 years. It's safe to say a lot has changed since then.

In this post, we'll fast forward to present day. We'll use the same teardown process to unpack the latest [Llama 3.1](https://llama.meta.com/) open source foundation models released by Meta over the summer. We'll break the model down and walk through each step one cell at a time, giving you a close-up view of a modern LLM's inner workings. By the time we're done, you'll leave with a much stronger understanding of the core mechanisms driving the Generative AI revolution.

# Llama Foundation Models

[Llama](https://llama.meta.com/) is a family of general purpose, state-of-the-art foundation models from Meta. According to the 3.1 technical report, the latest models can "answer questions in at least 8 languages, write high quality code, solve complex reasoning problems, and use tools in a zero-shot way." (Dubey et al. 2024) Llama 3.1 includes 8B, 70B, and 405B sizes. While Meta recommends a cluster with at least 8 GPUs to run the 70B and 405B sizes, the 8B model is small enough to run on a single GPU with 20GB of memory. For more technical stats, see [the official Llama 3.1 model card](https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/MODEL_CARD.md) in GitHub.

Over the course of this post, we'll implement a complete text generation pipeline using only the research literature, pre-trained weights from the `Meta-Llama3.1-8B-Instruct` checkpoint, and Meta's reference implementation as a guide. In the next section, we'll review the stages of an end-to-end text generation pipeline. In the sections that follow, we'll walk through a detailed teardown of each stage, tracing an inference from raw data to the first output token. In the last section, we'll put all the pieces together into a custom chatbot capable of generating long-form content.

Let the teardown begin!

# Setup

In [1]:
from functools import partial
import json
import math
import os
from pathlib import Path
from sys import stdout
from textwrap import dedent
import warnings

from matplotlib import pyplot as plt
import seaborn as sns
import numpy as np
from pandas import Series
from pydantic import BaseModel, validate_call
from pytest import approx
from tqdm import tqdm

import torch
from torch import nn, Tensor
from torch.nn.functional import relu, silu, softmax

from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import RMSNorm

import stickshift as ss
from stickshift import default_arg, take
from stickshift.models import llama
from stickshift.torch import device as torch_device

In [2]:
# Ignore all warnings
warnings.filterwarnings("ignore")

# Configure gpu
device = torch_device()

In [3]:
%%html
<style>
figure > img {
    display:block;
    margin-left: auto !important;
    margin-right: auto !important;
}
figcaption {
    text-align: center;
}
</style>

In [4]:
%pprint

Pretty printing has been turned OFF


# Text Generation Pipeline

In [the last teardown](https://stickshift.github.io/2024/09/04/transformer-teardown.html), we looked at a text classification Transformer. This time we're going to dissect a *text generation* Transformer. Instead of simply applying a label to the input text, the Head stage will be responsible for *generating* new content. Don't worry—it's not as complicated as it sounds.

As Figure 1 shows, the stages of a text generation pipeline are the same as those of the text classification pipeline we saw last time. The Tokenize stage splits raw text into tokens. The Embeddings stage converts individual tokens into embedding vectors. The Context Layers stage augments the input embeddings with contextual signals drawn from the surrounding tokens. Finally, the Head stage converts the contextualized embeddings into task-specific predictions.

<figure>
<img src="transformer-pipeline.svg" width="800">
<figcaption>Figure 1: Text Generation Pipeline</figcaption>
</figure>

The key difference between text generation and text classification lies in the *task-specific* predictions of the Head stage. While text classification Transformers predict a label for the raw text, *text generation Transformers predict the next token.*

But one token is just the beginning. The magical powers driving the Generative AI explosion come from simply running the predictions in a loop. The next token predicted in each iteration is appended to the end of the input sequence, and the process repeats. Over and over again.

# Checkpoint: 8B-Instruct

We'll start by loading the configuration and pre-trained weights for the `Meta-Llama3.1-8B-Instruct` checkpoint. The weights for all Llama checkpoints can be downloaded directly from [Meta](https://llama.meta.com/) and [Hugging Face](https://huggingface.co/meta-llama).

In [5]:
# Load model config
config = llama.config("Meta-Llama3.1-8B-Instruct")

# Load pre-trained model parameters
checkpoint = torch.load(config.checkpoint_path / "consolidated.00.pth", weights_only=True, map_location=device)

config.model_dump()

{'checkpoint_path': PosixPath('/Users/andrewyoung/.llama/checkpoints/Meta-Llama3.1-8B-Instruct'), 'vocab_size': 128256, 'd_model': 4096, 'd_head': 128, 'd_ffn': 14336, 'n_layers': 32, 'n_heads': 32, 'n_kv_heads': 8, 'n_kv_groups': 4, 'rms_norm_eps': 1e-05, 'rope_theta': 500000.0, 'max_seq_len': 8192, 'temperature': 0.6, 'top_k': 50, 'top_p': 0.9, 'max_output_tokens': 500}

We'll reference a number of the settings in `config` throughout the teardown. For now, the most interesting ones to note are `d_model`, `d_ffn`, `n_layers`, and `n_heads`, which represent the main differences between the 8B, 70B, and 405B sizes.
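To make the size differences concrete, here is a back-of-the-envelope parameter count for the 8B checkpoint, using the values from `config`. The sketch assumes three things: untied input and output embeddings, bias-free linear layers, and one RMSNorm weight vector per norm. These match the modules we build in this post. `count_parameters` is a hypothetical helper and is not part of `stickshift` or `llama-models`.

```python
def count_parameters(vocab_size, d_model, d_head, d_ffn, n_layers, n_heads, n_kv_heads):
    """Rough parameter count for a Llama-style decoder (no biases, untied embeddings)."""
    embeddings = vocab_size * d_model          # token embedding table
    head = d_model * vocab_size                # output projection to the vocabulary
    attention = (
        d_model * n_heads * d_head             # w_q
        + 2 * d_model * n_kv_heads * d_head    # w_k and w_v (grouped-query attention)
        + n_heads * d_head * d_model           # attention output projection
    )
    ffn = 3 * d_model * d_ffn                  # gate, input, and output projections
    norms = 2 * d_model                        # attention and FFN RMSNorm weights
    per_layer = attention + ffn + norms
    # Add the final head RMSNorm weight (d_model) at the end
    return embeddings + n_layers * per_layer + d_model + head


total = count_parameters(
    vocab_size=128256, d_model=4096, d_head=128, d_ffn=14336,
    n_layers=32, n_heads=32, n_kv_heads=8,
)
print(f"{total:,} parameters")                 # 8,030,261,248 parameters
print(f"{total * 2 / 1e9:.1f} GB in bfloat16")  # 16.1 GB in bfloat16
```

Assuming bfloat16 weights at 2 bytes per parameter, the weights alone take roughly 16 GB. That is consistent with the 8B model fitting on a single 20GB GPU, with some headroom left for activations.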

In [6]:
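def load_pretrained_state(layer):
    # Load pre-trained weights for `layer` into the module instances
    # (normalize_attention, w_q, ffn_gates, etc.) that we define later in the teardown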
    llama.load_state(
        normalize_attention, "normalize_attention", 
        normalize_ffn, "normalize_ffn", 
        w_q, "w_q", 
        w_k, "w_k", 
        w_v, "w_v", 
        attention_outputs, "attention_outputs",
        ffn_gates, "ffn_gates",
        ffn_inputs, "ffn_inputs",
        ffn_outputs, "ffn_outputs",
        checkpoint=checkpoint,
        layer=layer,
    ) 

# Raw Text

Before we can tear anything down, we need a prompt. Since our goal is to trace an inference from raw text to the first output token, we want to start with a prompt that's specific enough to generate a consistent, one-word answer. If we do everything right, the first output token we predict should be "Boston".

In [7]:
# Prompt
prompt = "<|start_header_id|>user<|end_header_id|>\n\n"
prompt += "What is the capital of Massachusetts? Answer in one word."
prompt += "<|eot_id|>"
prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"

You can see `prompt` already includes a number of special tokens. While the chatbot we build at the end will add these for us automatically, we need to manually inject them for now. You can read more about the Llama 3.1 prompt syntax in the [Llama Prompting Guide](https://www.llama.com/docs/how-to-guides/prompting).

# Tokenize

The Tokenize stage splits raw text into tokens using a fixed vocabulary. Llama 3.1 uses a vocabulary of 128k tokens built on top of [OpenAI's tiktoken](https://github.com/openai/tiktoken) tokenizer. Rather than dig into the gory details of tokenization, we'll simply use the off-the-shelf `Tokenizer` from Meta's [llama-models](https://github.com/meta-llama/llama-models) reference implementation.

In [8]:
# Load tokenizer model from checkpoint
tokenizer = Tokenizer(str(config.checkpoint_path / "tokenizer.model"))

In [9]:
# Split raw text into tokens
token_ids = tokenizer.encode(prompt, bos=True, eos=False, allowed_special="all")
token_ids

[128000, 128006, 882, 128007, 271, 3923, 374, 279, 6864, 315, 22108, 30, 22559, 304, 832, 3492, 13, 128009, 128006, 78191, 128007, 271]

In [10]:
len(token_ids)

22

We see `tokenizer.encode` split our prompt into 22 token ids. These ids represent the index of each token in Llama 3.1's 128k token vocabulary. We can always reverse the process with `tokenizer.decode`. If you look closely at the cell output below, you'll notice the tokenizer injected another special token `(128000, '<|begin_of_text|>')` to mark the beginning of the sequence.

In [11]:
# Decode token ids back into raw text
tokenizer.decode(token_ids)

'<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nWhat is the capital of Massachusetts? Answer in one word.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'

Our last step is to convert `token_ids` into a PyTorch tensor. The `x` variable represents our "current state". We'll trace `x` through every stage of the pipeline.

In [12]:
# Load token_ids into a tensor
x = torch.tensor(token_ids, device=device)

x.shape

torch.Size([22])

# Embeddings

The Embeddings stage converts individual tokens into embeddings. Embeddings (Bengio et al. 2000) are vectors that represent tokens as unique points in a high-dimensional vector space. Embeddings serve as the fundamental data structure of Transformers: the context layers we'll look at in the next stage take embeddings as input, transform them, and produce embeddings as output. All of the token embeddings in a Transformer have dimension $d_{model}$. Llama also derives the width of its FFNs from $d_{model}$. It starts from $\frac{8}{3} d_{model}$, applies a model-specific multiplier, and rounds up, which gives $d_{ffn} = 14336 = 3.5 \, d_{model}$ for the 8B model. This makes $d_{model}$ the primary hyperparameter that distinguishes between the 8B, 70B, and 405B models.
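The FFN width recipe fits in a few lines. This sketch follows our reading of the sizing logic in Meta's reference `FeedForward` module. The values `ffn_dim_multiplier=1.3` and `multiple_of=1024` are assumptions taken from the 8B checkpoint's `params.json`; they do not come from `config` above. `ffn_dim` is a hypothetical helper.

```python
def ffn_dim(d_model, ffn_dim_multiplier, multiple_of):
    """Derive the SwiGLU hidden width from d_model (sketch of the reference recipe)."""
    hidden = 4 * d_model                       # classic 4x FFN width
    hidden = int(2 * hidden / 3)               # SwiGLU has 3 matrices, so scale by 2/3
    hidden = int(ffn_dim_multiplier * hidden)  # per-model adjustment
    # Round up to the nearest multiple of `multiple_of`
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)


d_ffn = ffn_dim(4096, ffn_dim_multiplier=1.3, multiple_of=1024)
print(d_ffn, d_ffn / 4096)  # 14336 3.5
```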

In the following cells, we start by creating a lookup table with one embedding for each of the 128k tokens in the Llama 3.1 vocabulary. Next, we load the pre-trained embedding values from the model checkpoint. Finally, we use the lookup table to map each token value to an embedding vector.

In [13]:
# Initialize embeddings lookup table
embeddings = nn.Embedding(
    num_embeddings=config.vocab_size, 
    embedding_dim=config.d_model,
    device=device,
)

# Load pre-trained state
llama.load_state(embeddings, "embeddings", checkpoint=checkpoint)

In [14]:
# Map token values to embeddings
x = embeddings(x)

x.shape

torch.Size([22, 4096])

At this point, we've split our prompt into a sequence of tokens and mapped each one to an embedding vector. From `x.shape` above, you can see the embeddings are represented as row vectors stacked together into an $n \times d_{model}$ tensor, here $22 \times 4096$.
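Under the hood, an embedding lookup is just row indexing into a $|V| \times d_{model}$ matrix. The toy numpy sketch below uses made-up sizes and random values in place of `embeddings.weight`. It shows that indexing gives the same result as multiplying one-hot vectors by the table. It also shows that repeated tokens map to identical rows until the context layers get involved.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, d_model = 10, 4                     # toy sizes, not Llama's
table = rng.normal(size=(vocab_size, d_model))  # stands in for embeddings.weight

token_ids = np.array([3, 1, 3, 7])

# Embedding lookup is just row indexing into the table...
x = table[token_ids]

# ...which is equivalent to multiplying one-hot vectors by the table
one_hot = np.eye(vocab_size)[token_ids]
assert np.allclose(x, one_hot @ table)

print(x.shape)  # (4, 4): one row vector per token
```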

In [15]:
# Show sample
x

tensor([[ 2.6512e-04, -4.9973e-04, -5.8365e-04,  ...,  3.8147e-03,
          6.3419e-05,  1.1902e-03],
        [-1.6499e-04, -2.4319e-04,  1.6403e-04,  ..., -1.5163e-04,
          3.5095e-04,  7.3242e-04],
        [ 3.5095e-03,  7.2021e-03,  5.3406e-05,  ..., -7.2479e-04,
         -1.0620e-02,  8.2779e-04],
        ...,
        [-9.7656e-03, -3.4637e-03,  1.8616e-03,  ..., -7.1411e-03,
         -4.3030e-03,  8.6060e-03],
        [-4.6158e-04, -3.9291e-04, -6.5863e-06,  ..., -6.2561e-04,
         -5.0354e-04,  6.6757e-04],
        [-2.8687e-03,  3.8910e-03, -1.7357e-04,  ...,  8.0872e-04,
          5.0354e-04,  2.3041e-03]], device='mps:0',
       grad_fn=<EmbeddingBackward0>)

# Context Layers

The Context Layers in a Transformer are responsible for infusing each token embedding with contextual signals drawn from the rest of the sequence. The mechanism works by passing the token embeddings through multiple layers of attention and feedforward blocks. The attention blocks focus on relationships between tokens, augmenting each embedding with a weighted combination of the surrounding embeddings. The feedforward blocks capitalize on the extra context, transforming each augmented embedding using a fully connected neural network. This pattern of attention and transformation is repeated over and over again, gradually converting representations of individual words into representations of abstract semantic concepts over a series of small increments.

<figure>
<img src="transformer-layers.svg" width="800">
<figcaption>Figure 2: Context Layers</figcaption>
</figure>

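Llama uses *pre-normalization*. Each attention and FFN block normalizes its input with RMSNorm (Zhang & Sennrich 2019) before transforming it, then adds the block's output back to the unnormalized residual. The original Transformer and BERT do it the other way around: they normalize *after* the residual addition.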

## Rotary Position Encoding (RoPE)

In [None]:
def rope(n):
    # Hyperparams
    base = config.rope_theta
    d = config.d_head
    
    # Compute theta_i = 1 / base^(2i/d) from i = 0 to d/2-1
    thetas = 1.0 / base**(2 * torch.arange(d // 2, device=device) / d)
    
    # Compute m * theta_i for position m in 0 to n
    frequencies = torch.stack([m*thetas for m in range(n)])
    
    # Duplicate each row
    frequencies = torch.cat((frequencies, frequencies), dim=-1)
    
    # Apply cos, sin
    cos = torch.cos(frequencies)
    sin = torch.sin(frequencies)
    
    # Sanity check
    assert cos.shape[0] == n and cos.shape[1] == config.d_head
    assert sin.shape[0] == n and sin.shape[1] == config.d_head

    return cos, sin
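One caveat applies to `rope`. Llama 3.1 supports a 128K-token context, and Meta's reference implementation rescales the RoPE frequencies before computing the rotations. The `rope` function above leaves that step out. The sketch below follows our reading of the reference `apply_scaling` function. The constants are assumptions, and `scale_rope_frequencies` is a hypothetical name. Only the slowly rotating frequencies are affected, so the difference is likely small for a 22-token prompt.

```python
import math

import numpy as np


def scale_rope_frequencies(thetas, scale_factor=8.0, low_freq_factor=1.0,
                           high_freq_factor=4.0, old_context_len=8192):
    """Sketch of Llama 3.1-style RoPE frequency scaling (constants assumed)."""
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    scaled = []
    for theta in thetas:
        wavelen = 2 * math.pi / theta
        if wavelen < high_freq_wavelen:
            # High frequencies (short wavelengths) are left untouched
            scaled.append(theta)
        elif wavelen > low_freq_wavelen:
            # Low frequencies are slowed down by scale_factor
            scaled.append(theta / scale_factor)
        else:
            # Smoothly interpolate between the two regimes
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            scaled.append((1 - smooth) * theta / scale_factor + smooth * theta)
    return np.array(scaled)


d_head, base = 128, 500000.0
thetas = 1.0 / base ** (2 * np.arange(d_head // 2) / d_head)
scaled = scale_rope_frequencies(thetas)
print(scaled[0] == thetas[0], scaled[-1] / thetas[-1])  # True 0.125
```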

In [None]:
# Compute RoPE rotation matrices
rope_cos, rope_sin = rope(len(x))

rope_cos.shape, rope_sin.shape
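Why rotate queries and keys at all? If both are rotated by an angle proportional to their position, their dot product depends only on the *offset* between the two positions. The numpy sketch below checks this property on random vectors. It assumes `llama.rotate_half` (from the author's `stickshift` package) pairs dimension $i$ with dimension $i + d/2$. That is the convention implied by the duplicated `torch.cat((frequencies, frequencies))` in `rope`. `rope_angles` and `apply_rope` are hypothetical helpers.

```python
import numpy as np


def rotate_half(x):
    # Assumed convention: pair dimension i with dimension i + d/2
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate((-x2, x1), axis=-1)


def rope_angles(positions, d, base=500000.0):
    thetas = 1.0 / base ** (2 * np.arange(d // 2) / d)
    freqs = np.outer(positions, thetas)              # m * theta_i
    freqs = np.concatenate((freqs, freqs), axis=-1)  # duplicate, as in rope()
    return np.cos(freqs), np.sin(freqs)


def apply_rope(x, positions):
    cos, sin = rope_angles(positions, x.shape[-1])
    return x * cos + rotate_half(x) * sin


rng = np.random.default_rng(0)
d = 8
q, k = rng.normal(size=d), rng.normal(size=d)

# Score between a query at position 5 and a key at position 2...
s1 = apply_rope(q, [5])[0] @ apply_rope(k, [2])[0]
# ...matches the score at positions 105 and 102: only the offset (3) matters
s2 = apply_rope(q, [105])[0] @ apply_rope(k, [102])[0]
assert np.isclose(s1, s2)

# Rotation preserves vector length
assert np.isclose(np.linalg.norm(apply_rope(q, [7])[0]), np.linalg.norm(q))
```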

## Attention

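Each attention block runs the same sequence of steps:

1. Normalize the inputs.
2. Project them into queries, keys, and values.
3. Split the projections into heads.
4. Encode positions with RoPE.
5. Expand the shared key/value heads.
6. Compute masked attention.
7. Recombine the heads.
8. Project back to $d_{model}$.
9. Add the result to the residual.

We'll walk through each step below.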

### Normalize Attention Inputs

In [None]:
# Configure attention normalization
normalize_attention = RMSNorm(config.d_model, config.rms_norm_eps).to(device)

# Load pre-trained weights
llama.load_state(normalize_attention, "normalize_attention", checkpoint=checkpoint)

In [None]:
# Normalize attention inputs
residual = x
x = normalize_attention(x)

x.shape
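RMSNorm is simple enough to write out by hand. Here's a numpy sketch of the computation we expect `RMSNorm` to perform. It divides each embedding by its root-mean-square (plus `eps` for stability), then multiplies by a learned per-dimension weight. Unlike LayerNorm, it subtracts no mean and adds no bias. The toy input is random, and `rms_norm` is a hypothetical helper.

```python
import numpy as np


def rms_norm(x, weight, eps=1e-5):
    # Scale each row by the reciprocal of its root-mean-square...
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    # ...then apply the learned per-dimension gain
    return x / rms * weight


rng = np.random.default_rng(0)
x = rng.normal(scale=3.0, size=(2, 8))   # 2 toy embeddings of dimension 8
y = rms_norm(x, weight=np.ones(8))

# With unit weights, every row ends up with an RMS of about 1
print(np.sqrt(np.mean(y ** 2, axis=-1)))
```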

### Project Queries, Keys, Values

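Llama 3.1 uses Grouped-Query Attention (GQA) (Ainslie et al. 2023). Standard multi-head attention gives every query head its own key and value heads. GQA instead lets a group of query heads share a single key/value head. In the 8B model, the `n_heads = 32` query heads share `n_kv_heads = 8` key/value heads, so each key/value head serves `n_kv_groups = 4` query heads. As a result, the key and value projections below are 4 times smaller than the query projection.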

In [None]:
# Configure query, key, value projections
w_q = nn.Linear(
    in_features=config.d_model,
    out_features=config.n_heads * config.d_head,
    bias=False,
    device=device,
)
w_k = nn.Linear(
    in_features=config.d_model,
    out_features=config.n_kv_heads * config.d_head,
    bias=False,
    device=device,
)
w_v = nn.Linear(
    in_features=config.d_model,
    out_features=config.n_kv_heads * config.d_head,
    bias=False,
    device=device,
)

# Load pre-trained weights
llama.load_state(w_q, "w_q", w_k, "w_k", w_v, "w_v", checkpoint=checkpoint)

In [None]:
# Project embeddings to query, key, value spaces
q = w_q(x)
k = w_k(x)
v = w_v(x)

q.shape, k.shape, v.shape

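The shapes confirm the effect of GQA. For our 22-token prompt, `q` is $22 \times 4096$ ($32 \times 128$), while `k` and `v` are only $22 \times 1024$ ($8 \times 128$). Fewer key/value heads means fewer parameters in `w_k` and `w_v`. More importantly, it means a smaller key/value cache when the model is served.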
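GQA's payoff shows up in memory. The arithmetic below estimates the size of a key/value cache for the 8B configuration at `max_seq_len = 8192`. It assumes keys and values are stored in bfloat16 for every layer. Our pipeline keeps no cache and recomputes the whole sequence on each step, so treat this as illustrative. `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(n_tokens, n_layers=32, n_kv_heads=8, d_head=128, bytes_per_value=2):
    # Keys and values (2 tensors) for every layer, KV head, and token
    return 2 * n_layers * n_kv_heads * d_head * bytes_per_value * n_tokens


gqa = kv_cache_bytes(8192)                  # 8 shared key/value heads
mha = kv_cache_bytes(8192, n_kv_heads=32)   # hypothetical: one K/V head per query head
print(gqa / 2**30, mha / 2**30)             # 1.0 4.0 (GiB)
```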

### Split Attention Heads

In [None]:
def split_heads(x, n_heads):
    return x.view(-1, n_heads, config.d_head).transpose(-3, -2)

In [None]:
# Split attention heads
q = split_heads(q, config.n_heads)
k = split_heads(k, config.n_kv_heads)
v = split_heads(v, config.n_kv_heads)

q.shape, k.shape, v.shape

### Encode Positions (RoPE)

In [None]:
# Encode positions by rotating queries and keys
q = (q * rope_cos) + (llama.rotate_half(q) * rope_sin)
k = (k * rope_cos) + (llama.rotate_half(k) * rope_sin)

q.shape, k.shape, v.shape

### Expand Key / Value Groups (GQA)

In [None]:
# Expand key/value groups
k = k.repeat_interleave(config.n_kv_groups, dim=0)
v = v.repeat_interleave(config.n_kv_groups, dim=0)

q.shape, k.shape, v.shape

In [None]:
# Sanity check
assert q.shape == k.shape == v.shape
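The choice of `repeat_interleave` over a plain repeat matters, because consecutive query heads must share the same key/value head. The toy numpy sketch below uses `np.repeat`, which behaves like `torch.repeat_interleave` along a single axis. It tags each key/value head with its own index so we can see which query head reads from which.

```python
import numpy as np

n_heads, n_kv_heads, seq_len, d_head = 8, 2, 3, 4   # toy sizes
n_kv_groups = n_heads // n_kv_heads

# Give each KV head a constant value equal to its index so we can track it
k = np.arange(n_kv_heads)[:, None, None] * np.ones((n_kv_heads, seq_len, d_head))

# np.repeat along axis 0 behaves like torch's repeat_interleave(dim=0)
k_expanded = np.repeat(k, n_kv_groups, axis=0)

# Query head h reads from KV head h // n_kv_groups
print([int(k_expanded[h, 0, 0]) for h in range(n_heads)])  # [0, 0, 0, 0, 1, 1, 1, 1]
```

Using `np.tile` (or `torch.Tensor.repeat`) instead would produce `[0, 1, 0, 1, ...]`, which pairs query heads with the wrong key/value heads.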

### Calculate Attention


In [None]:
# Compute attention mask M
n = len(x)
mask = torch.ones(n, n, dtype=torch.bool, device=device).tril(diagonal=0)
m = torch.zeros(n, n, device=device).masked_fill_(mask.logical_not(), float("-inf"))

m

In [None]:
# Compute attention for all heads in parallel
a = softmax(q @ k.transpose(-2, -1) / np.sqrt(config.d_head) + m, dim=-1) @ v

a.shape
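The causal mask is what makes this a *decoder*: each position can attend only to itself and earlier positions. Here's a single-head numpy sketch of the same computation, with an additive mask built the same way as `m` above. It then checks that changing the last token cannot affect the outputs at earlier positions. Sizes and values are toy data, and `causal_attention` is a hypothetical helper.

```python
import numpy as np


def causal_attention(q, k, v):
    n, d_head = q.shape
    # Additive mask: 0 on/below the diagonal, -inf above it
    m = np.where(np.tril(np.ones((n, n), dtype=bool)), 0.0, -np.inf)
    scores = q @ k.T / np.sqrt(d_head) + m
    # Numerically stable softmax over the last axis
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v


rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(5, 4)) for _ in range(3))
a = causal_attention(q, k, v)

# Changing the last token's key/value can't affect earlier positions
k2, v2 = k.copy(), v.copy()
k2[-1], v2[-1] = 100.0, 100.0
a2 = causal_attention(q, k2, v2)
assert np.allclose(a[:-1], a2[:-1])
```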

### Recombine Attention Heads

In [None]:
def combine_heads(x):
    return x.transpose(-3, -2).contiguous().view(-1, int(config.n_heads * config.d_head))

In [None]:
# Combine attention heads
a = combine_heads(a)

a.shape
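Splitting and recombining heads is pure bookkeeping: no values change, they're only regrouped. The numpy sketch below mirrors `split_heads` and `combine_heads` on toy sizes and confirms the two round-trip. numpy's `reshape` copies non-contiguous arrays automatically. The torch version has no such fallback, which is why it needs `.contiguous()` before `.view()`.

```python
import numpy as np

n, n_heads, d_head = 3, 4, 2   # toy sizes


def split_heads(x, n_heads):
    # (n, n_heads * d_head) -> (n_heads, n, d_head)
    return x.reshape(-1, n_heads, d_head).transpose(1, 0, 2)


def combine_heads(x):
    # (n_heads, n, d_head) -> (n, n_heads * d_head)
    return x.transpose(1, 0, 2).reshape(-1, x.shape[0] * x.shape[2])


x = np.arange(n * n_heads * d_head).reshape(n, n_heads * d_head)
heads = split_heads(x, n_heads)

print(heads.shape)     # (4, 3, 2)
print(heads[1, 0])     # [2 3]: head 1 owns columns 2-3 of token 0
assert np.array_equal(combine_heads(heads), x)
```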

### Project Attention Outputs

In [None]:
# Configure attention output projection
attention_outputs = nn.Linear(
    in_features=config.n_heads * config.d_head,
    out_features=config.d_model,
    bias=False,
    device=device,
)

# Load pre-trained weights
llama.load_state(attention_outputs, "attention_outputs", checkpoint=checkpoint)

In [None]:
# Project attention embeddings back to model space
a = attention_outputs(a)

a.shape

### Combine w/ Residuals

In [None]:
# Combine attention embeddings with residuals
x = residual + a

x.shape

## FFN

### Normalize FFN Inputs

In [None]:
# Configure FFN normalization
normalize_ffn = RMSNorm(config.d_model, config.rms_norm_eps).to(device)

# Load pre-trained state
llama.load_state(normalize_ffn, "normalize_ffn", checkpoint=checkpoint)

In [None]:
# Normalize FFN inputs
residual = x
x = normalize_ffn(x)

x.shape

### Transform

In [None]:
# Configure SwiGLU FFN
ffn_gates = nn.Linear(
    in_features=config.d_model,
    out_features=config.d_ffn,
    bias=False,
    device=device,
)
ffn_inputs = nn.Linear(
    in_features=config.d_model,
    out_features=config.d_ffn,
    bias=False,
    device=device,
)

# Load pre-trained weights
llama.load_state(ffn_gates, "ffn_gates", ffn_inputs, "ffn_inputs", checkpoint=checkpoint)

In [None]:
# Apply transform
f = silu(ffn_gates(x)) * ffn_inputs(x)

f.shape

### Project FFN Outputs

In [None]:
# Configure FFN output projection
ffn_outputs = nn.Linear(
    in_features=config.d_ffn,
    out_features=config.d_model,
    bias=False,
    device=device,
)

# Load pre-trained weights
llama.load_state(ffn_outputs, "ffn_outputs", checkpoint=checkpoint)

In [None]:
# Project FFN embeddings back to model space
f = ffn_outputs(f)

f.shape
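The SwiGLU transform is easier to see without the `nn.Linear` wrappers. This numpy sketch uses random toy weights laid out like `nn.Linear.weight` (shape `out_features x in_features`). It writes SiLU out explicitly as $x \cdot \sigma(x)$. `silu` and `swiglu_ffn` are hypothetical helpers.

```python
import numpy as np


def silu(x):
    # SiLU / swish: x * sigmoid(x)
    return x / (1.0 + np.exp(-x))


def swiglu_ffn(x, w_gate, w_in, w_out):
    # Gate and input projections run in parallel; the gate modulates the input
    h = silu(x @ w_gate.T) * (x @ w_in.T)
    # Project back down to d_model
    return h @ w_out.T


rng = np.random.default_rng(0)
d_model, d_ffn = 4, 10   # toy sizes
x = rng.normal(size=(3, d_model))
w_gate = rng.normal(size=(d_ffn, d_model))   # same (out, in) layout as nn.Linear.weight
w_in = rng.normal(size=(d_ffn, d_model))
w_out = rng.normal(size=(d_model, d_ffn))

y = swiglu_ffn(x, w_gate, w_in, w_out)
print(y.shape)  # (3, 4)
```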

### Combine w/ Residuals

In [None]:
# Combine FFN embeddings with residuals
x = residual + f

x.shape

## Stacking the Layers

Now that we've gone through each step, let's put all the pieces together.

In [None]:
def context_layers(x):
    # Compute RoPE rotation matrices
    rope_cos, rope_sin = rope(len(x))

    # Apply layer logic in a loop
    for layer in range(config.n_layers):
    
        # Load pre-trained state for layer
        load_pretrained_state(layer)
    
        #
        # Attention
        #
    
        # Normalize attention inputs
        residual = x
        x = normalize_attention(x)
        
        # Project embeddings to query, key, value spaces
        q = w_q(x)
        k = w_k(x)
        v = w_v(x)
        
        # Split attention heads
        q = split_heads(q, config.n_heads)
        k = split_heads(k, config.n_kv_heads)
        v = split_heads(v, config.n_kv_heads)
    
        # Encode positions by rotating queries and keys
        q = (q * rope_cos) + (llama.rotate_half(q) * rope_sin)
        k = (k * rope_cos) + (llama.rotate_half(k) * rope_sin)
        
        # Expand key/value groups
        k = k.repeat_interleave(config.n_kv_groups, dim=0)
        v = v.repeat_interleave(config.n_kv_groups, dim=0)
    
        # Compute masked attention bias M
        n = len(x)
        mask = torch.ones(n, n, dtype=torch.bool, device=device).tril(diagonal=0)
        m = torch.zeros(n, n, device=device).masked_fill_(mask.logical_not(), float("-inf"))
        
        # Compute attention for all heads in parallel
        a = softmax(q @ k.transpose(-2, -1) / np.sqrt(config.d_head) + m, dim=-1) @ v
    
        # Combine attention heads
        a = combine_heads(a)
        
        # Project attention embeddings back to model space
        a = attention_outputs(a)
        
        # Combine attention embeddings with residuals
        x = residual + a
        
        #
        # FFN
        #
    
        # Normalize FFN inputs
        residual = x
        x = normalize_ffn(x)
    
        # Apply transform
        f = silu(ffn_gates(x)) * ffn_inputs(x)
    
        # Project FFN embeddings back to model space
        f = ffn_outputs(f)
        
        # Combine FFN embeddings with residuals
        x = residual + f

    return x

In [None]:
# Start over from initial tokens
x = torch.tensor(token_ids, device=device)

# Initial embeddings
x = embeddings(x)

# Contextualized embeddings
x = context_layers(x)

x.shape

In [None]:
x

# Head

## Normalize Head Inputs

In [None]:
# Configure head normalization
normalize_head = RMSNorm(config.d_model, config.rms_norm_eps).to(device)

# Load pre-trained weights
llama.load_state(normalize_head, "normalize_head", checkpoint=checkpoint)

In [None]:
# Normalize head inputs
x = normalize_head(x)

x.shape

## Project Head Outputs

In [None]:
# Configure output projection
head_outputs = nn.Linear(
    in_features=config.d_model,
    out_features=config.vocab_size,
    bias=False,
    device=device,
)

# Load pre-trained weights
llama.load_state(head_outputs, "head_outputs", checkpoint=checkpoint)

In [None]:
# Use last embedding to represent the entire sequence
x = x[-1]

# Project outputs to token space
x = head_outputs(x)

x.shape

## Top Token

In [None]:
# Select top scoring token
token_id = x.argmax().item()

# Decode token
token = tokenizer.decode([token_id]).strip()

token

In [None]:
# Verify answer
assert token == "Boston"

## Sampling

### Temperature

In [None]:
# Hyperparameters
temperature = config.temperature
temperature

In [None]:
# Apply temperature
x = x / temperature
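Dividing logits by a temperature below 1 widens the gaps between them, so softmax puts more probability on the top tokens. The ranking itself never changes. The toy numpy sketch below compares $T = 1.0$ with the checkpoint's $T = 0.6$ on made-up logits.

```python
import numpy as np


def softmax(z):
    # Subtract the max for numerical stability
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()


logits = np.array([4.0, 3.0, 1.0])   # toy logits, not real model output

for t in (1.0, 0.6):
    print(t, softmax(logits / t).round(3))
```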

### Ranking

In [None]:
# Convert logits to probabilities
probs = softmax(x, dim=-1)

# Sort probabilities in descending order
probs, indices = probs.sort(descending=True)

### Top K

In [None]:
# Hyperparameters
top_k = config.top_k
top_k

In [None]:
# Retain top k tokens
probs = probs[:top_k]
print(f"Retained {len(probs)} of {len(x)}")

### Top P

In [None]:
# Hyperparameters
top_p = config.top_p
top_p

In [None]:
# Find cutoff where cumulative probability exceeds top_p
cumulative_mask = probs.cumsum(dim=-1) > top_p
# Cast to int so argmax returns the first position where the mask is True
threshold_index = torch.argmax(cumulative_mask.int()).item()

# Only apply threshold if top_p was exceeded
if cumulative_mask.any():
    probs = probs[:threshold_index+1]

print(f"Retained {len(probs)} of {len(x)}")
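The filtering is easier to follow on a tiny made-up distribution, with ranking, top-k, top-p, and sampling all in one place. This numpy sketch is ours, not Meta's. `np.random.Generator.choice` requires probabilities that sum to 1, so the sketch renormalizes the survivors first; `torch.multinomial` accepts unnormalized weights and needs no such step. `sample_top_k_top_p` is a hypothetical helper.

```python
import numpy as np


def sample_top_k_top_p(probs, top_k, top_p, rng):
    # Rank tokens by probability (descending)
    indices = np.argsort(probs)[::-1]
    sorted_probs = probs[indices]

    # Top k: keep at most top_k candidates
    sorted_probs = sorted_probs[:top_k]

    # Top p: keep the smallest prefix whose cumulative probability exceeds top_p
    exceeds = np.cumsum(sorted_probs) > top_p
    if exceeds.any():
        sorted_probs = sorted_probs[: np.argmax(exceeds) + 1]

    # Renormalize and sample one of the survivors
    choice = rng.choice(len(sorted_probs), p=sorted_probs / sorted_probs.sum())
    return int(indices[choice]), len(sorted_probs)


probs = np.array([0.05, 0.5, 0.3, 0.1, 0.05])   # toy distribution over 5 tokens
token_id, n_kept = sample_top_k_top_p(probs, top_k=4, top_p=0.7, rng=np.random.default_rng(0))
```

With `top_p=0.7`, the two most likely tokens (0.5 + 0.3 = 0.8) already cross the threshold, so only they survive. `top_k=4` alone would have allowed four.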

### Random Selection

In [None]:
# Print remaining token pool
for i, prob in enumerate(probs):
    print(f"token id {indices[i]}, token '{tokenizer.decode([indices[i]])}', score {prob:0.3f}")

In [None]:
# Sample from remaining tokens weighted by probability
sampled_index = torch.multinomial(probs, 1)

# Map the sampled position back to its token id
token_id = indices[sampled_index]

# Decode token
token = tokenizer.decode([token_id]).strip()

token

## Complete Head Stage

In [None]:
def head(x):
    # Normalize head inputs
    x = normalize_head(x)
    
    # Use last embedding to represent the entire sequence
    x = x[-1]
    
    # Project outputs to token space
    x = head_outputs(x)

    #
    # Temperature
    #
    
    # Apply temperature
    x = x / config.temperature

    #
    # Ranking
    #
    
    # Convert logits to probabilities
    probs = softmax(x, dim=-1)
    
    # Sort probabilities in descending order
    probs, indices = probs.sort(descending=True)

    #
    # Top K
    #
    
    # Retain top k tokens
    probs = probs[:config.top_k]

    #
    # Top P
    #
    
    # Find cutoff where cumulative probability exceeds top_p
    cumulative_mask = probs.cumsum(dim=-1) > config.top_p
    # Cast to int so argmax returns the first position where the mask is True
    threshold_index = torch.argmax(cumulative_mask.int()).item()
    
    # Only apply threshold if top_p was exceeded
    if cumulative_mask.any():
        probs = probs[:threshold_index+1]

    #
    # Random Selection
    #
    
    # Sample from remaining tokens weighted by probability
    sampled_index = torch.multinomial(probs, 1)
    
    # Map the sampled position back to its token id
    token_id = indices[sampled_index]

    return token_id.item()

# Generator

In [None]:
class Message(BaseModel):
    role: str
    content: str

In [None]:
@validate_call
def prepare_messages(messages: list[Message]):
    # Initialize prompt
    prompt = ""
    
    # Format each message
    for message in messages:
        prompt += f"<|start_header_id|>{message.role}<|end_header_id|>\n\n"
        prompt += message.content
        prompt += "<|eot_id|>"

    # Finish with the assistant role to prime the model's response
    prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"

    return prompt
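As a quick check, a dict-based version of `prepare_messages` should reproduce the prompt we assembled by hand in the Raw Text section. `format_chat` is a hypothetical stand-in that skips the pydantic validation.

```python
def format_chat(messages):
    # Same header/eot layout as prepare_messages, using plain dicts
    prompt = ""
    for message in messages:
        prompt += f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n"
        prompt += message["content"]
        prompt += "<|eot_id|>"
    # Leave an open assistant header for the model to complete
    prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return prompt


question = "What is the capital of Massachusetts? Answer in one word."
manual = (
    "<|start_header_id|>user<|end_header_id|>\n\n"
    + question
    + "<|eot_id|>"
    + "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
assert format_chat([{"role": "user", "content": question}]) == manual
```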

In [None]:
@validate_call
def generate(messages: list[Message]):
    # Format message prompt
    prompt = prepare_messages(messages)
    
    # Split raw text into tokens
    token_ids = tokenizer.encode(prompt, bos=True, eos=False, allowed_special="all")
    
    # Generate output until we get a stop token or we exceed max_output_tokens.
    for _ in range(config.max_output_tokens):
        
        # Start over from initial tokens
        x = torch.tensor(token_ids, device=device)
        
        # Initial embeddings
        x = embeddings(x)
        
        # Contextualized embeddings
        x = context_layers(x)
        
        # Head
        token_id = head(x)
        
        # Check stopping criteria
        if token_id in tokenizer.stop_tokens:
            break
    
        # Print token
        token = tokenizer.decode([token_id])
        stdout.write(token)
        
        # Append to end of sequence
        token_ids.append(token_id)

In [None]:
generate([
    {
        "role": "user",
        "content": "What is the capital of Massachusetts?",
    },
])