# transformers from scratch
<a target="_blank" href="https://colab.research.google.com/drive/1clXCeAvjWm_MERBO-sTzkEDvSfIisU1c">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/headers/header-11.png" width="350">

# 0️⃣ introduction

this is a clean, first-principles implementation of gpt-2 in pytorch. the architectural choices closely follow those used by the [`TransformerLens`](https://github.com/TransformerLensOrg/TransformerLens) library. this tutorial and associated exercises are an accompaniment to nanda's excellent transformer [walkthrough videos](https://www.youtube.com/watch?v=bOYE6E8JrtU).

these resources are also helpful:
- [visualization](https://bbycroft.net/llm) of the gpt (see the nano-gpt internals)
- 3blue1brown, [transformers](https://www.youtube.com/watch?v=wjZofJX0v4M&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi&index=6) and [attention](https://www.youtube.com/watch?v=eMlx5fFNoYc&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi&index=8)

note: each exercise has a difficulty and importance rating. please do skip exercises / look at solutions if you don't feel like they're important enough to be worth doing, and you'd rather get to the good stuff!

## setup

```python
%pip install transformer_lens==2.11.0 einops jaxtyping circuitsvis
# %pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python # for mech interp visualizations
%pip install -U datasets
```


```python
import os
# import sys
# import math
# import webbrowser
# from collections import defaultdict
from dataclasses import dataclass
# from pathlib import Path
import datasets
import einops
import numpy as np
import torch
# import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader
import wandb
# from typing import Callable
from jaxtyping import Float, Int
from rich import print as rprint
from rich.table import Table
from tqdm.notebook import tqdm
from transformer_lens import HookedTransformer
# from transformer_lens.utils import gelu_new, tokenize_and_concatenate
from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast


device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
```

## sign in

```python
!pip install python-dotenv
!git clone https://github.com/sepiatone/wb-colab-files.git
%cd wb-colab-files
```

```python
from login import login_form
from submit import test_submit

login_form()
```

## content & learning objectives

### 1️⃣ understanding the inputs & outputs of a transformer

in this section, we'll take a first look at transformers - what their function is, how information moves inside a transformer, and what inputs & outputs they take.


### 2️⃣ clean transformer implementation

in this section, we'll understand the high level architecture of a transformer and then implement it from first principles


### 3️⃣ training a transformer

in this section, we'll learn how to train a transformer from scratch


### 4️⃣ sampling from a transformer

in this section we'll learn how to sample from a transformer

# 1️⃣ understanding the inputs & outputs of a transformer

##### learning objectives
- understand what a transformer is used for
- understand causal attention, and what a transformer's output represents
- understand what logits are, and how to use them to derive a probability distribution over the vocabulary
- learn what tokenization is, and how models do it

## 1.0 what is the point of a transformer? what is causal attention? what are logits?

**transformers exist to model text!**

we're going to focus on gpt-2 style transformers.

key feature: they generate text! you feed in language, and the model generates a probability distribution over tokens. and you can repeatedly sample from this to generate text!

(to explain this in more detail - you feed in a sequence of length $n$, then sample from the probability distribution over the $n+1$-th ~~word~~ token (explained soon!), use this to construct a new sequence of length $n+1$, then feed this new sequence into the model to get a probability distribution over the $n+2$-th token, and so on.)

### how is the model trained?

you give it a bunch of text, and train it to predict the next token.

importantly, if you give a model 100 tokens in a sequence, it predicts the next token for *each* prefix, i.e. it produces 100 logit vectors (each of which can be turned into a probability distribution) over all the tokens in our vocabulary, with the `i`-th logit vector representing the prediction for the token *following* the `i`-th token in the sequence. this is a key part of what allows transformers to be trained so efficiently; for every sequence of length $n$ we get $n$ different predictions to train on:

$$
p(x_1), \; p(x_2|x_1), \; p(x_3|x_1x_2), \; \ldots, \; p(x_n|x_1 \ldots x_{n-1})
$$
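to make the "$n$ predictions per sequence" idea concrete, here's a minimal sketch of how next-token predictions are usually lined up with their targets for a cross-entropy loss. it uses random stand-in tensors with toy sizes (not gpt-2's, and not a real model), and the variable names are illustrative. without a bos token, a length-$n$ sequence gives $n-1$ (prediction, target) pairs; prepending a bos token is one way to also get a prediction for the first real token, $p(x_1)$.

```python
import torch
import torch.nn.functional as F

# Toy sizes for illustration only (not GPT-2's real sizes).
batch, seq_len, d_vocab = 2, 6, 10

torch.manual_seed(0)
tokens = torch.randint(0, d_vocab, (batch, seq_len))  # [batch, seq_len]
logits = torch.randn(batch, seq_len, d_vocab)          # stand-in for the model's output

# The logits at position i predict token i+1, so drop the last prediction and the first target.
pred_logits = logits[:, :-1]  # [batch, seq_len-1, d_vocab]
targets = tokens[:, 1:]       # [batch, seq_len-1]

# Log-probability the model assigned to the token that actually came next, at every position.
log_probs = pred_logits.log_softmax(dim=-1)
correct_log_probs = log_probs.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)
loss = -correct_log_probs.mean()

# Same loss via PyTorch's built-in cross-entropy (expects [N, C] logits and [N] targets).
loss_builtin = F.cross_entropy(pred_logits.reshape(-1, d_vocab), targets.reshape(-1))
print(correct_log_probs.shape, loss.item(), loss_builtin.item())
```

every position contributes a training signal, which is the efficiency gain described above.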

<details>
<summary>aside - logits</summary>

if you haven't encountered the term "logits" before, here's a quick refresher.

given an arbitrary vector $x$, we can turn it into a probability distribution via the **softmax** function: $x_i \to \frac{e^{x_i}}{\sum_j e^{x_j}}$. the exponential makes everything positive; the normalization makes it add to one.

the model's output is the vector $x$ (one for each prediction it makes). we call this vector the **logits**: it isn't itself a probability distribution, but it determines one via the softmax function (each logit is a log-probability, up to an additive constant shared by the whole vector).
</details>
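a quick sketch of the softmax in pytorch, using made-up logit values: it checks the hand-written formula against the built-in `softmax`, and shows that adding the same constant to every logit doesn't change the resulting distribution.

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.1])  # made-up logits for a 3-token vocabulary

# Softmax by hand: exponentiate (makes everything positive), then normalize (sums to one).
probs_manual = logits.exp() / logits.exp().sum()
probs_builtin = logits.softmax(dim=-1)

# Shifting every logit by the same constant leaves the distribution unchanged,
# which is why logits are only meaningful up to an additive constant.
probs_shifted = (logits + 5.0).softmax(dim=-1)

print(probs_manual)
```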

how do we stop the transformer from "cheating" by just looking at the tokens it's trying to predict? answer - we make the transformer have *causal attention* (as opposed to *bidirectional attention*). causal attention only allows information to move forwards in the sequence, never backwards. the prediction of what comes after token 50 is only a function of the first 50 tokens, *not* of token 51. we say the transformer is **autoregressive**, because it only predicts future tokens based on past data.

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/transformer-overview-new.png" width="900">
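one common way to implement causal attention is to mask the attention scores before the softmax. below is a minimal sketch with random toy scores (not gpt-2's actual weights or code): every score where the source position comes *after* the destination position is set to $-\infty$, so after the softmax it gets exactly zero weight.

```python
import torch

seq_len = 4
torch.manual_seed(0)
scores = torch.randn(seq_len, seq_len)  # toy attention scores, indexed [dest, src]

# True strictly above the diagonal = source positions in the future of the destination.
mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
masked_scores = scores.masked_fill(mask, float("-inf"))

# Softmax over source positions: row `dest` is a distribution over positions <= dest.
pattern = masked_scores.softmax(dim=-1)
print(pattern)
```

the first row can only attend to position 0, so all of its weight lands there.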

## 1.1 transformer inputs - tokens

our transformer's input is natural language (i.e. a sequence of characters, strings, etc). but ml models generally take vectors as input, not language. how do we convert language to vectors?

we can factor this into 2 questions:

1. how do we split up language into small sub-units?
2. how do we convert these sub-units into vectors?

Let's start with the second of these questions.

### converting sub-units to vectors

we basically make a massive lookup table, which is called an **embedding**. it has one vector for each possible sub-unit of language we might get (we call this set of all sub-units our **vocabulary**). we label every element in our vocabulary with an integer (this labelling never changes), and we use this integer to index into the embedding.

a key intuition is that one-hot encodings (see the aside below) let you think about each integer independently. we don't bake in any relation between sub-units when we perform our embedding, because every sub-unit has a completely separate embedding vector.

<details>
<summary>aside - one-hot encodings</summary>

we sometimes think about **one-hot encodings** of sub-units. these are vectors with zeros everywhere, except for a single one in the position corresponding to the sub-unit's index in the vocabulary. this means that indexing into the embedding is equivalent to multiplying the **embedding matrix** by the one-hot encoding (where the embedding matrix is the matrix we get by stacking all the embedding vectors on top of each other).

$$
\begin{aligned}
W_E &= \begin{bmatrix}
\leftarrow v_0 \rightarrow \\
\leftarrow v_1 \rightarrow \\
\vdots \\
\leftarrow v_{d_{vocab}-1} \rightarrow \\
\end{bmatrix} \quad \text{is the embedding matrix (size }d_{vocab} \times d_{embed}\text{),} \\
\\
t_i &= (0, \dots, 0, 1, 0, \dots, 0) \quad \text{is the one-hot encoding for the }i\text{th word (length }d_{vocab}\text{)} \\
\\
v_i &= t_i W_E \quad \text{is the embedding vector for the }i\text{th word (length }d_{embed}\text{).} \\
\end{aligned}
$$

</details>
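here's a small sketch (toy sizes and random weights, purely illustrative) checking the claim in the aside: multiplying a one-hot vector by $W_E$ gives the same vector as indexing into $W_E$. indexing is what you'd actually do in practice, since it avoids a large matrix multiply.

```python
import torch
import torch.nn.functional as F

d_vocab, d_embed = 5, 3  # toy sizes
torch.manual_seed(0)
W_E = torch.randn(d_vocab, d_embed)  # one row (embedding vector) per vocabulary item

i = 2
t_i = F.one_hot(torch.tensor(i), num_classes=d_vocab).float()  # one-hot, length d_vocab

v_via_matmul = t_i @ W_E  # t_i W_E, as in the aside
v_via_index = W_E[i]      # plain lookup: picks out row i directly

# Indexing also works on a whole sequence at once: [seq_len] -> [seq_len, d_embed].
tokens = torch.tensor([4, 0, 2])
embeddings = W_E[tokens]
```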

a key point is that the embedding matrix is also learnt during model training.

now, let's answer the first question - how do we split language into sub-units?

### splitting language into sub-units

we need to define a standard way of splitting up language into a series of substrings, where each substring is a member of our **vocabulary** set.

could we use a dictionary, and have our vocabulary be the set of all words in the dictionary? no, because this couldn't handle arbitrary text (e.g. urls, punctuation, etc). we need a more general way of splitting up language.

could we just use individual characters (e.g. the 128 ascii characters, or all 256 possible byte values)? this fixes the previous problem, but it loses the structure of language - some sequences of characters are more meaningful than others. for example, "language" is a lot more meaningful than "hjksdfiu". we want "language" to be a single token, but not "hjksdfiu" - this is a more efficient use of our vocab.

what actually happens? the most common strategy is called **byte-pair encoding** (bpe).

gpt-2's tokenizer begins with the 256 possible byte values as its tokens (this covers every ascii character, and any other text once it's encoded as bytes), then repeatedly finds the most common pair of adjacent tokens and merges that pair into a new token. note that the space character is one of the initial tokens, and merges involving a space are very common. for instance, here are the first five merges for the tokenizer used by gpt-2 (you'll be able to verify this below).

```
" t"
" a"
"he"
"in"
"re"
```

note - you might see the character `Ġ` in front of some tokens. this isn't a special token: it's how gpt-2's tokenizer displays a leading space, so `Ġthe` means `" the"`. a token with a leading space is a different token (with a different id) from the same string without one.
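to make the merge procedure concrete, here's a toy version of the bpe training loop in plain python. it's a simplified sketch, not gpt-2's actual tokenizer: it starts from the characters of a tiny made-up corpus rather than from bytes, and skips real-world details such as pre-splitting text into words. the function names are hypothetical.

```python
from collections import Counter


def most_common_pair(seqs):
    """Count adjacent token pairs across all sequences and return the most frequent one.
    Ties go to the pair seen first (Counter.most_common keeps first-encountered order)."""
    counts = Counter()
    for seq in seqs:
        counts.update(zip(seq, seq[1:]))
    return counts.most_common(1)[0][0]


def merge_pair(seq, pair):
    """Replace each non-overlapping occurrence of `pair` in `seq` with a single merged token."""
    merged, i = [], 0
    while i < len(seq):
        if i + 1 < len(seq) and (seq[i], seq[i + 1]) == pair:
            merged.append(seq[i] + seq[i + 1])
            i += 2
        else:
            merged.append(seq[i])
            i += 1
    return merged


# Toy corpus, starting from single characters as the initial vocabulary.
corpus = ["the cat", "the hat", "that"]
seqs = [list(text) for text in corpus]
merges = []
for _ in range(3):  # three rounds of "find the most common pair, merge it everywhere"
    pair = most_common_pair(seqs)
    merges.append(pair)
    seqs = [merge_pair(seq, pair) for seq in seqs]

print(merges)  # each merge adds one new token to the vocabulary
print(seqs)
```

the merges never change the underlying text, only how it's split up, which is why tokenizing is (in principle) lossless.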

you can run the code below to load in the `gpt2-small` model, and see more of its tokenizer's vocabulary:

```python
reference_gpt2 = HookedTransformer.from_pretrained(
    "gpt2-small",
    fold_ln=False,
    center_unembed=False,
    center_writing_weights=False,  # you'll learn about these arguments later!
)

sorted_vocab = sorted(list(reference_gpt2.tokenizer.vocab.items()), key=lambda n: n[1])

print(sorted_vocab[:20])
print()
print(sorted_vocab[250:270])
print()
print(sorted_vocab[990:1010])
print()
```

<details>
<summary>aside - HookedTransformer</summary>

`HookedTransformer` is a class from the `TransformerLens` library that is instrumented with hooks, which let you access and potentially modify the internal activations of the model during a forward pass. this is useful for understanding how information flows through the different layers and components of the transformer.
</details>

as you get to the end of the vocabulary, you'll see some pretty weird-looking, esoteric tokens (because all of the short, frequently-occurring ones were merged earlier):

```python
print(sorted_vocab[-20:])
```

<details>
<summary>fun (completely optional) exercise - can you guess what the first-formed 3/4/5/6/7-letter encodings in gpt-2's vocabulary are?</summary>
run this code to find out:

```python
lengths = dict.fromkeys(range(3, 8), "")
for tok, idx in sorted_vocab:
    if not lengths.get(len(tok), True):
        lengths[len(tok)] = tok

for length, tok in lengths.items():
    print(f"{length}: {tok}")
```
</details>

transformers in the `TransformerLens` library have a `to_tokens` method that converts text to a tensor of token ids (integers). by default, it also prepends a special token called bos (beginning of sequence) to indicate the start of a sequence. you can disable this with the `prepend_bos=False` argument.

<details>
<summary>aside - bos token</summary>

the beginning of sequence (bos) token is a special token used to mark the beginning of the sequence. confusingly, in gpt-2, the end of sequence (eos), beginning of sequence (bos) and padding (pad) tokens are all the same, `<|endoftext|>` with index `50256`.

why is this token added? some basic intuitions are:

* it provides context that this is the start of a sequence, which can help the model generate more appropriate text.
* it can act as a "rest position" for attention heads (more on this later, when we discuss attention).

`TransformerLens` adds this token automatically (including in forward passes of transformer models, e.g. it's implicitly added when you call `model("hello world")`). you can disable this behaviour by setting the flag `prepend_bos=False` in `to_tokens`, `to_str_tokens`, `model.forward` and any other function that converts strings to multi-token tensors.

**key point: *if you get weird off-by-one errors, check whether there's an unexpected `prepend_bos`!***

why are the bos, eos and pad tokens the same? This is because gpt-2 is an autoregressive model, and uses these kinds of tokens in a slightly different way to other transformer families (e.g. bert). For instance, gpt has no need to distinguish between bos and eos tokens, because it only processes text from left to right.

</details>

### some tokenization annoyances

there are a few funky and frustrating things about tokenization, which cause it to behave differently than you might expect. for instance:

#### whether a word begins with a capital or space matters!

```python
print(reference_gpt2.to_str_tokens("Ralph"))
print(reference_gpt2.to_str_tokens(" Ralph"))
print(reference_gpt2.to_str_tokens(" ralph"))
print(reference_gpt2.to_str_tokens("ralph"))
```

#### arithmetic is a mess.

token lengths are inconsistent, and common numbers get bundled together into single tokens.

```python
print(reference_gpt2.to_str_tokens("56873+3184623=123456789-1000000000"))
```

### key takeaways

- we learn a vocabulary of tokens (sub-words).
- we (approximately) losslessly convert language to integers by tokenizing it.
- we convert integers to vectors via a lookup table.
- note: the input to the transformer is a sequence of *tokens* (i.e. integers), not vectors.

## 1.2 transformer outputs - logits (and eventually text generation)

now that we understand the basic ideas here, let's go through the entire process of text generation, from our original string to a new token which we can append to our string and plug back into the model.

#### **step 1:** convert text to tokens

the sequence gets tokenized, so it has shape `[batch, seq_len]`. Here, the batch dimension is just one (because we only have one sequence).

```python
reference_text = "I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!"
tokens = reference_gpt2.to_tokens(reference_text).to(device)
print(tokens)
print(tokens.shape)
print(reference_gpt2.to_str_tokens(tokens))
```

#### **step 2:** map tokens to logits


from our input of shape `[batch, seq_len]`, we get output of shape `[batch, seq_len, vocab_size]`. the `[i, j, :]`-th element of our output is a vector of logits representing our prediction for the `j+1`-th token in the `i`-th sequence.

```python
logits, cache = reference_gpt2.run_with_cache(tokens)
print(logits.shape)
```

(`run_with_cache` tells the model to cache all intermediate activations. this isn't important right now; we'll look at it in more detail later.)

#### **step 3:** convert the logits to a distribution with a softmax

This doesn't change the shape, it is still `[batch, seq_len, vocab_size]`.

```python
probs = logits.softmax(dim=-1)
print(probs.shape)
```

#### **bonus step:** what is the most likely next token at each position?

```python
most_likely_next_tokens = reference_gpt2.tokenizer.batch_decode(logits.argmax(dim=-1)[0])

print(list(zip(reference_gpt2.to_str_tokens(tokens), most_likely_next_tokens)))
```

we can see that, in a few cases (particularly near the end of the sequence), the model accurately predicts the next token in the sequence. We might guess that `"take over the world"` is a common phrase that the model has seen in training, which is why the model can predict it.

#### **step 4:** map distribution to a token

```python
next_token = logits[0, -1].argmax(dim=-1)
next_char = reference_gpt2.to_string(next_token)
print(repr(next_char))
```

note that we're indexing `logits[0, -1]`. this is because logits have shape `[1, sequence_length, vocab_size]`, so this indexing returns the vector of length `vocab_size` representing the model's prediction for what token follows the **last** token in the input sequence.

in this case, we can see that the model predicts the token `' I'`.

#### **step 5:** add this to the end of the input, re-run

there are more efficient ways to do this (e.g. where we cache some of the values each time we run our input, so we don't have to do as much calculation each time we generate a new value), but this doesn't matter conceptually right now.

```python
print(f"Sequence so far: {reference_gpt2.to_string(tokens)[0]!r}")

for i in range(10):
    print(f"Token #{tokens.shape[-1] + 1} = {next_char!r}")
    # Define new input sequence, by appending the previously generated token
    # (next_token is a 0-d tensor, so [None, None] reshapes it to [batch=1, seq=1])
    tokens = torch.cat([tokens, next_token[None, None]], dim=-1)
    # Pass our new sequence through the model, to get new output
    logits = reference_gpt2(tokens)
    # Get the predicted token at the end of our sequence
    next_token = logits[0, -1].argmax(dim=-1)
    # Decode it (it gets printed at the start of the next iteration)
    next_char = reference_gpt2.to_string(next_token)

print(f"Final sequence: {reference_gpt2.to_string(tokens)[0]!r}")
```

## 1.3 key takeaways

- a transformer takes in language, predicts next token (for *each* token in a causal way)
- we convert language to a sequence of integers with a tokenizer
- we convert integers to vectors with a lookup table
- the output is a vector of logits for each input token; we convert each one to a probability distribution with a softmax, and can then convert this to a token (e.g. by taking the largest logit, or by sampling)
- we append this to the input + run again to generate more text (jargon: *autoregressive*)
- meta level point: transformers are sequence operation models, they take in a sequence, do processing in parallel at each position, and use attention to move information between positions!
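the two ways of turning logits into a token mentioned above (taking the largest logit, or sampling) can be sketched as below. the `temperature` parameter is an assumption on our part (one common knob for sampling, not something introduced in this section), and `sample_next_token` is a hypothetical helper name.

```python
import torch


def sample_next_token(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    """Pick a next-token id from a [d_vocab] logit vector.

    temperature == 0 means greedy decoding (take the argmax); otherwise we divide
    the logits by the temperature and draw one sample from the resulting softmax.
    """
    if temperature == 0.0:
        return logits.argmax(dim=-1)
    probs = (logits / temperature).softmax(dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)  # 0-d tensor holding a token id


torch.manual_seed(0)
logits = torch.tensor([1.0, 3.0, 0.5, 2.0])  # made-up logits for a 4-token vocabulary
greedy = sample_next_token(logits, temperature=0.0)   # always the largest logit (index 1)
sampled = sample_next_token(logits, temperature=1.0)  # random draw from the softmax distribution
```

greedy decoding is deterministic, while sampling lets less likely tokens through occasionally; either result can be appended to the input and fed back in.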

# 2️⃣ clean transformer implementation

##### learning objectives

- understand that a transformer is composed of attention heads and mlps, with each one performing operations on the residual stream
- understand that the attention heads in a single layer operate independently, and that they have the role of calculating attention patterns (which determine where information is moved to & from in the residual stream)
- learn about & implement the following transformer modules:
  * embedding - a lookup table from tokens to residual stream vectors
  * positional embedding - a lookup table from position indices to residual stream vectors
  * layer normalization - transforming the input to have zero mean and unit variance
  * attention - the method of computing attention patterns for residual stream vectors
  * mlp - the collection of linear and nonlinear transformations which operate on each residual stream vector in the same way
  * unembedding - a matrix for converting residual stream vectors into a distribution over tokens

## 2.1 high-level architecture

watch nanda's [transformer circuits walkthrough](https://www.youtube.com/watch?v=KV5gbOmHbjU) if you want more intuitions!

(diagram is bottom to top, right-click and open for higher resolution.)

<img src="https://raw.githubusercontent.com/chloeli-15/ARENA_img/main/img/transformer-new2.png" width="950">

### tokenization & embedding

The input tokens $t$ are integers. we get them from taking a sequence, and tokenizing it (like we saw in the previous section).

the token embedding is a lookup table mapping tokens to vectors, which is implemented as a matrix $W_E$. The matrix consists of a stack of token embedding vectors (one for each token).

### residual stream

the residual stream is the sum of all previous outputs of layers of the model, and is the input to each new layer. it has shape `[batch, seq_len, d_model]` (where `d_model` is the length of a single embedding vector).

the initial value of the residual stream is denoted $x_0$ in the diagram, and $x_i$ are later values of the residual stream (after more attention and mlp layers have been applied to the residual stream).

the residual stream is *really* fundamental. it's the central object of the transformer. it's how the model remembers things and moves information between layers for composition, and it's the medium used to store the information that attention moves between positions.

<details>
<summary>aside - <b>logit lens</b></summary>

a key idea of transformers is the [residual stream as output accumulation](https://www.lesswrong.com/posts/X26ksz4p3wSyycKNB/gears-level-mental-models-of-transformer-interpretability#Residual_Stream_as_Output_Accumulation:~:text=The%20Models-,Residual%20Stream%20as%20Output%20Accumulation,-The%20residual%20stream). as we move through the layers of the model, shifting information around and processing it, the values in the residual stream represent the accumulation of all the inferences made by the transformer up to that point.

this is neatly illustrated by the **logit lens**. rather than getting predictions from the residual stream at the very end of the model, we can take the value of the residual stream midway through the model and convert it to a distribution over tokens. when we do this, we find surprisingly coherent predictions, especially in the last few layers before the end.
</details>

### transformer blocks

then we have a series of `n_layers` **transformer blocks** (also sometimes called **residual blocks**).

note: a block contains an attention layer *and* an mlp layer, but we say a transformer has $k$ layers if it has $k$ blocks (i.e. $2k$ total layers).



<img src="https://raw.githubusercontent.com/chloeli-15/ARENA_img/main/img/transformer-block2.png" width="700">
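as a rough sketch of how a block reads from and writes to the residual stream, here's a hypothetical pre-layernorm block (the arrangement gpt-2 uses, where layernorm is applied to the input of each sublayer). it uses pytorch's built-in `nn.MultiheadAttention`, `nn.LayerNorm` and `nn.Linear` as stand-ins for the components you'll implement yourself later, so it only illustrates the residual connections; it doesn't match gpt-2's weights or exact implementation.

```python
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Hypothetical pre-LayerNorm transformer block: each sublayer reads a
    normalized copy of the residual stream and adds its output back in."""

    def __init__(self, d_model: int, n_heads: int, d_mlp: int):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(nn.Linear(d_model, d_mlp), nn.GELU(), nn.Linear(d_mlp, d_model))

    def forward(self, resid: torch.Tensor) -> torch.Tensor:
        seq_len = resid.shape[1]
        # Boolean mask: True means "not allowed to attend" (future positions).
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=resid.device), diagonal=1
        )
        x = self.ln1(resid)
        attn_out, _ = self.attn(x, x, x, attn_mask=causal_mask, need_weights=False)
        resid = resid + attn_out                    # attention writes into the residual stream
        resid = resid + self.mlp(self.ln2(resid))   # then the MLP does the same
        return resid


torch.manual_seed(0)
block = ToyBlock(d_model=8, n_heads=2, d_mlp=32)
resid = torch.randn(2, 5, 8)  # [batch, seq_len, d_model]
out = block(resid)            # same shape: the block reads from and writes back to the stream
```

stacking `n_layers` of these (after the embeddings, and before a final layernorm and the unembedding) gives the overall shape of the model.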

### attention

first we have attention. this moves information from prior positions in the sequence to the current token.

we do this for *every* token in parallel using the same parameters. the only difference is that we look backwards only (to avoid "cheating"). this means later tokens have more of the sequence that they can look at.

attention layers are the only bit of a transformer that moves information between positions (i.e. between vectors at different sequence positions in the residual stream).

attention layers are made up of `n_heads` heads, each with its own parameters, its own attention pattern, and its own rule for how to copy information from source to destination. the heads act independently and additively: we just add their outputs together, and add the result back to the residual stream.

each head does the following:
- produces an **attention pattern** for each destination token: a probability distribution over prior source tokens (including the current one), weighting how much information to copy from each.
- moves information (via a linear map) in the same way from each source token to each destination token.

a few key points:

- what information we copy depends on the source token's *residual stream*, but this doesn't mean it only depends on the value of that token, because the residual stream can store more information than just the token identity (the purpose of the attention heads is to move information between vectors at different positions in the residual stream!)
- we can think of each attention head as consisting of two different **circuits**:
  * one circuit determines **where to move information to and from** (this is a function of the residual stream for the source and destination tokens)
  * the other circuit determines **what information to move** (this is a function of only the source token's residual stream)
  * for reasons which will become clear later, we refer to the first circuit as the **QK circuit**, and the second circuit as the **OV circuit**
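here's a minimal single-head sketch with random toy weights (biases omitted; all names are illustrative), split into the two circuits described above. the qk part computes the attention pattern from both destination and source residual vectors (scaled by $1/\sqrt{d_{head}}$, as in standard scaled dot-product attention), and the ov part is a linear map applied to each source vector before the pattern mixes them.

```python
import torch

torch.manual_seed(0)
seq_len, d_model, d_head = 5, 16, 4    # toy sizes
resid = torch.randn(seq_len, d_model)  # residual stream for a single sequence

# Hypothetical weights for one head (biases omitted for brevity).
W_Q, W_K, W_V = (torch.randn(d_model, d_head) * 0.1 for _ in range(3))
W_O = torch.randn(d_head, d_model) * 0.1

# QK circuit: *where* to move information (depends on dest and src residual vectors).
q, k = resid @ W_Q, resid @ W_K      # [seq, d_head]
scores = (q @ k.T) / d_head ** 0.5   # scores[dest, src]
mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
pattern = scores.masked_fill(mask, float("-inf")).softmax(dim=-1)

# OV circuit: *what* information to move (depends only on the src residual vector).
v = resid @ W_V                      # [seq, d_head]
head_out = pattern @ v @ W_O         # [seq, d_model], gets added to the residual stream
```

because matrix multiplication is associative, `head_out` equals `pattern @ (resid @ W_V @ W_O)`: each source vector is mapped by the same $W_V W_O$ regardless of the destination, and the pattern only decides how much of each to take.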

<details>
<summary>key intuition - attention as generalized convolution</summary>

we can think of attention as a kind of generalized convolution. standard convolution layers in image models work by imposing a "prior of locality", i.e. the assumption that pixels which are close together are more likely to share information. although language has some locality (two words next to each other are more likely to share information than two words 100 tokens apart), the picture is a lot more nuanced, because which tokens are relevant to which others depends on the context of the sentence. for instance, in the sentence `"When Mary and John went to the store, John gave a drink to Mary"`, the names in this sentence are the most important tokens for predicting that the final token will be `"Mary"`, and this is because of the particular context of this sentence rather than the tokens' position.

attention layers are effectively our way of saying to the transformer, "don't impose a prior of locality, but instead develop your own algorithm to figure out which tokens are important to which other tokens in any given sequence."
</details>

below is a schematic diagram of the attention layers. don't worry if you don't follow this right now, we'll go into more detail during implementation.

<img src="https://raw.githubusercontent.com/chloeli-15/ARENA_img/main/img/transformer-attn-new-v2.png" width="1050">

### mlp

the mlp layers are just a standard neural network, with a single hidden layer and a nonlinear activation function. the exact activation isn't conceptually important (gpt-2 uses [gelu](https://paperswithcode.com/method/gelu), which seems to perform best).

Our hidden dimension is normally `d_mlp = 4 * d_model`. Exactly why the ratios are what they are isn't super important (people basically cargo-cult what gpt did back in the day!).

importantly, **the mlp operates on positions in the residual stream independently, and in exactly the same way**. it doesn't move information between positions.
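a minimal sketch of the mlp using pytorch's built-in layers, with toy sizes and random weights. since `nn.Linear` acts only on the last dimension, every residual stream vector is processed separately; the last line checks this by running one position on its own. the `approximate="tanh"` gelu variant is the approximation gpt-2 uses, but any nonlinearity would illustrate the point.

```python
import torch
import torch.nn as nn

d_model = 8
d_mlp = 4 * d_model  # the usual ratio mentioned above

mlp = nn.Sequential(
    nn.Linear(d_model, d_mlp),    # W_in (plus bias)
    nn.GELU(approximate="tanh"),  # tanh approximation of GELU
    nn.Linear(d_mlp, d_model),    # W_out (plus bias)
)

torch.manual_seed(0)
resid = torch.randn(2, 5, d_model)  # [batch, seq_len, d_model]
out = mlp(resid)                    # applied independently at every (batch, seq) position

# Running a single position by itself gives the same result: nothing moves between positions.
single = mlp(resid[0, 3])
```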

intuition - once attention has moved relevant information to a single position in the residual stream, mlps can actually do computation, reasoning, look up information, etc. *What the hell is going on inside mlps* is a pretty big open problem in transformer mechanistic interpretability - see the [toy model of superposition paper](https://transformer-circuits.pub/2022/toy_model/index.html) for more on why this is hard.

<details>
<summary>key intuition - mlps as key-value pairs</summary>

we can write the mlp's output as $f(x^T W^{in})W^{out}$, where $W^{in}$ and $W^{out}$ are the different weights of the mlp (ignoring biases), $f$ is the activation function, and $x$ is a vector in the residual stream. this can be rewritten as:

$$
f(x^T W^{in}) W^{out} = \sum_{i=1}^{d_{mlp}} f(x^T W^{in}_{[:, i]}) W^{out}_{[i, :]}
$$

we can view the vectors $W^{in}_{[:, i]}$ as the **input directions**, and $W^{out}_{[i, :]}$ as the **output directions**. we say the input directions are **activated** by certain textual features, and when they are activated, vectors are written in the corresponding output direction. this is very similar to the concept of keys and values in attention layers, which is why these vectors are also sometimes called keys and values (e.g. see the paper [transformer feed-forward layers are key-value memories](https://arxiv.org/pdf/2012.14913.pdf)).

terminology note - sometimes we refer to each of these $d_{mlp}$ input-output pairs as **neurons**.

<img src="https://raw.githubusercontent.com/chloeli-15/ARENA_img/main/img/mlp-neurons-2.png" width="900">

---

here's a step-by-step breakdown of the linear algebra, if it was too fast above. we have:

$$
\begin{aligned}
x^T W^{in} &= x^T [W^{in}_{[:, 1]}\,, ...\;, W^{in}_{[:, n]}] \\
&= (x^T W^{in}_{[:, 1]}\,, \; ...\;, \; x^T W^{in}_{[:, n]})
\end{aligned}
$$

where $W^{in}_{[:, i]}$ are the columns of $W^{in}$. in other words, these values (the pre-gelu activations) are projections of $x$ along the input directions of the neurons.

if we add our activation function and the second matrix, then we get:

$$
\begin{aligned}
f(x^T W^{in})W^{out} &= (f(x^T W^{in}_{[:, 1]})\,, \; ...\;,\; f(x^T W^{in}_{[:, n]})) \begin{bmatrix} \leftarrow W^{out}_{[1, :]} \rightarrow \\ \vdots \\ \leftarrow W^{out}_{[n, :]} \rightarrow \end{bmatrix} \\
&= f(x^T W^{in}_{[:, 1]}) W^{out}_{[1, :]} + \;...\; + f(x^T W^{in}_{[:, n]}) W^{out}_{[n, :]} \\
&= \sum_{i=1}^n f(x^T W^{in}_{[:, i]}) W^{out}_{[i, :]}
\end{aligned}
$$

where $W^{out}_{[i, :]}$ are the rows of $W^{out}$. in other words, our output is a linear combination of the rows of $W^{out}$, with the coefficients of that linear combination given by the projections of $x$ along the columns of $W^{in}$.

</details>
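a quick numerical check of the neuron decomposition in the aside above, with random toy weights and gelu as $f$ (biases ignored, as in the derivation): computing $f(x^T W^{in}) W^{out}$ as one matrix product gives the same vector as summing each neuron's output direction, weighted by that neuron's activation.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_model, d_mlp = 6, 24  # toy sizes
x = torch.randn(d_model)
W_in = torch.randn(d_model, d_mlp)   # columns are the neurons' input directions
W_out = torch.randn(d_mlp, d_model)  # rows are the neurons' output directions
f = F.gelu

# Matrix form: f(x^T W_in) W_out
out_matrix = f(x @ W_in) @ W_out

# Neuron-by-neuron form: sum_i f(x^T W_in[:, i]) * W_out[i, :]
out_neurons = sum(f(x @ W_in[:, i]) * W_out[i, :] for i in range(d_mlp))
```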

<details>
<summary>key intuition - mlps as knowledge storage</summary>

we can think of mlps as where knowledge gets stored in our transformer. the attention mechanism is what moves information around between sequence positions, but the mlps are where this information is processed, and where new information (a function of the old information) is written into the residual stream.

this is deeply connected to the key-value pairs model, since you can treat key-value pairs as a kind of associative memory system (where the key serves as a unique identifier, and the value holds the related information).

another related intuition (for which there is some evidence) is **mlps as memory management**. in an idealized case, we might find that the $i$-th neuron satisfies $W^{in}_{[:, i]} \approx - W^{out}_{[i, :]} \approx \vec v$ for some unit vector $\vec v$, meaning it may be responsible for erasing the positive component of vector $\vec x$ in the direction $\vec v$ (exercise - can you show why this is the case?). this can free up space in the residual stream for other components to write to.
</details>

<img src="https://raw.githubusercontent.com/chloeli-15/ARENA_img/main/img/transformer-mlp-new-2.png" width="680">

### unembedding

finally, we unembed!

this just consists of applying a linear map $W_U$, going from final residual stream to a vector of logits - this is the output.

<details>
<summary>aside - tied embeddings</summary>

note: sometimes we use something called a **tied embedding** - this is where we use the same weights for our $W_E$ and $W_U$ matrices. in other words, to get the logit score for a particular token at some sequence position, we just take the vector in the residual stream at that sequence position and take the inner product with the corresponding token embedding vector. this is more training-efficient (because there are fewer parameters in our model), and it might seem principled at first. after all, if two words have very similar meanings, shouldn't they have similar embedding vectors because the model will treat them the same, and similar unembedding vectors because they could both be substituted for each other in most outputs?

however, this is actually not very principled, for the following main reason: **the direct path involving the embedding and unembedding should approximate bigram frequencies**.

let's break down this claim. **bigram frequencies** refers to the frequencies of pairs of words in the english language (e.g. the bigram frequency of "Barack Obama" is much higher than the product of the individual frequencies of the words "Barack" and "Obama"). if our model had no attention heads or mlp layers, then all we have is a linear map from our one-hot encoded token `T` to a probability distribution over the token following `T`. this map is represented by the linear transformation $t \to t^T W_E W_U$ (where $t$ is our one-hot encoded token vector). since the output of this transformation can only be a function of the token `T` (and no earlier tokens), the best we can do is have this map approximate the true frequency of bigrams starting with `T`, which appear in the training data. importantly, **this is not a symmetric map**. we want `T = "Barack"` to result in a high probability of the next token being `"Obama"`, but not the other way around!

even in multi-layer models, a similar principle applies. there will be more paths through the model than just the "direct path" $W_E W_U$, but because of the residual connections there will always exist a direct path, so there will always be some incentive for $W_E W_U$ to approximate bigram frequencies.
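to make the symmetry argument concrete, here's a minimal sketch with a toy vocabulary (the sizes, token ids and random weights are made up for illustration, and the final layer norm is ignored). with tied weights the direct-path logit matrix $W_E W_E^T$ is symmetric, so the logit for "Obama" after "Barack" is forced to equal the logit for "Barack" after "Obama"; an independent $W_U$ has no such constraint.

```python
import torch as t

t.manual_seed(0)
d_vocab, d_model = 10, 4  # toy sizes, purely illustrative

W_E = t.randn(d_vocab, d_model)

# tied: W_U is just the transpose of W_E
direct_path_tied = W_E @ W_E.T  # [d_vocab, d_vocab]; entry [a, b] = logit of token b following token a
# untied: W_U is an independent parameter
W_U = t.randn(d_model, d_vocab)
direct_path_untied = W_E @ W_U

barack, obama = 3, 7  # hypothetical token ids
print(direct_path_tied[barack, obama].item(), direct_path_tied[obama, barack].item())  # same value
print(t.allclose(direct_path_tied, direct_path_tied.T))  # True: the tied direct path is symmetric
print(t.allclose(direct_path_untied, direct_path_untied.T))  # False for random untied weights
```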

</details>

### layer normalization

* simple normalization function applied at the start of each layer (i.e. before each mlp, attention layer, and before the unembedding)
* converts each input vector (independently in parallel for each `(batch, seq)` residual stream vector) to have mean zero and variance 1.
* then applies an elementwise scaling and translation
* cool maths tangent: the scale ($\odot \gamma$) & translate ($+ \beta$) is just a linear (strictly, affine) map. layer normalization is only applied immediately before another linear map (either the mlp, or the query/key/value linear maps in the attention head, or the unembedding $W_U$). linear compose linear = linear, so we can just fold this into a single effective linear layer and ignore it.
    * `fold_ln=True` flag in `from_pretrained` does this for you.
* layer normalization is annoying for interpretability - it would be linear if not for the fact we divide by the standard deviation $\sqrt{\text{Var}[x] + \epsilon}$, so you can't decompose the contributions of the input to the output independently. but it's *almost* linear - if you're changing a small part of the input you can pretend $\sqrt{\text{Var}[x] + \epsilon}$ is constant, so the operation is linear, but if you're changing $x$ enough to alter the norm substantially it's not linear.

<img src="https://raw.githubusercontent.com/chloeli-15/ARENA_img/main/img/transformer-ln.png" width="750">
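here's a minimal sketch of the folding algebra mentioned above, for a single linear layer with weight `W` and bias `b_lin` (hypothetical names and toy sizes; this is not TransformerLens's actual `fold_ln` code). scaling by $\gamma$, shifting by $\beta$ and then applying $W$ gives the same result as applying the folded weight $\text{diag}(\gamma) W$ and folded bias $\beta W + b$ directly to the normalized input.

```python
import torch as t

t.manual_seed(0)
d_model, d_out, eps = 8, 5, 1e-5

x = t.randn(3, d_model)      # three residual stream vectors
gamma = t.randn(d_model)     # layernorm scale (called w in this notebook)
beta = t.randn(d_model)      # layernorm shift (called b in this notebook)
W = t.randn(d_model, d_out)  # the linear layer that follows layernorm
b_lin = t.randn(d_out)

# the normalization step itself (no learned parameters)
x_hat = (x - x.mean(-1, keepdim=True)) / (x.var(-1, keepdim=True, unbiased=False) + eps).sqrt()

# unfolded: scale and shift, then apply the linear layer
out_unfolded = (x_hat * gamma + beta) @ W + b_lin

# folded: absorb gamma into the rows of W, and beta into the bias
W_folded = gamma[:, None] * W
b_folded = beta @ W + b_lin
out_folded = x_hat @ W_folded + b_folded

print(t.allclose(out_unfolded, out_folded, atol=1e-4))  # True
```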

</details>

### positional embeddings

**problem:** attention operates over all pairs of positions. this means it's symmetric with regard to position - by default, the attention score from token 5 to token 1 is computed in exactly the same way as the score from token 5 to token 2, so if those two positions hold the same token, they get the same score.

this is dumb because nearby tokens are more relevant. there are a lot of dumb hacks for this.

- we'll focus on **learned, absolute positional embeddings**. this means we learn a lookup table mapping the index of the position of each token to a residual stream vector, and add this to the embed.
  - note that we *add* rather than concatenate. this is because the residual stream is shared memory, and likely under significant superposition (the model compresses more features in there than the model has dimensions)
  - we basically never concatenate inside a transformer, unless doing weird shit like generating text efficiently.
- this connects to **attention as generalized convolution**
  - we argued that language does still have locality, and so it's helpful for transformers to have access to the positional information so they "know" two tokens are next to each other (and hence probably relevant to each other).
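a minimal sketch of both the problem and the fix, using a single attention head with random weights (all names and sizes here are illustrative, and `W_pos` is random rather than learned). positions 1 and 2 hold the same token, so without positional information the query at position 5 scores them identically; adding a positional embedding breaks the tie.

```python
import torch as t

t.manual_seed(0)
d_vocab, d_model, d_head, n_ctx = 20, 16, 4, 8

W_E = t.randn(d_vocab, d_model)  # token embeddings
W_pos = t.randn(n_ctx, d_model)  # stand-in for learned positional embeddings
W_Q = t.randn(d_model, d_head)
W_K = t.randn(d_model, d_head)

tokens = t.tensor([5, 9, 9, 2, 7, 3])  # positions 1 and 2 hold the same token (9)

def attn_scores(resid):
    # [seq, seq] score matrix; row = query (destination), column = key (source)
    return (resid @ W_Q) @ (resid @ W_K).T

no_pos = attn_scores(W_E[tokens])
with_pos = attn_scores(W_E[tokens] + W_pos[: len(tokens)])

# scores from query position 5 to key positions 1 and 2
print(t.isclose(no_pos[5, 1], no_pos[5, 2]))      # True: indistinguishable without positions
print(t.isclose(with_pos[5, 1], with_pos[5, 2]))  # False: the positional embedding breaks the tie
```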

## 2.2 actual code!

model architecture table (this will be helpful for understanding the results you get when running the code block below):

| Parameter   | Value          |
|-------------|----------------|
| batch       | 1              |
| position    | 35             |
| d_model     | 768            |
| n_heads     | 12             |
| n_layers    | 12             |
| d_mlp       | 3072 (= 4 * `d_model`) |
| d_head      | 64 (= `d_model / n_heads`) |
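as a sanity check on the table, here's a short sketch that derives `d_head` and `d_mlp` and counts the parameters in one transformer block. it assumes the attention weights and biases used later in this notebook (`W_Q`, `W_K`, `W_V`, `W_O`, `b_Q`, `b_K`, `b_V`, `b_O`), two layernorms with a scale and bias each, and an mlp made of two weight matrices with biases; the mlp layout (`W_in`, `b_in`, `W_out`, `b_out`) is an assumption at this point, since we haven't implemented it yet. embeddings and the unembedding are not included.

```python
d_model, n_heads, n_layers = 768, 12, 12

d_head = d_model // n_heads  # 64
d_mlp = 4 * d_model          # 3072

# attention: W_Q, W_K, W_V are [n_heads, d_model, d_head]; W_O is [n_heads, d_head, d_model]
# biases: b_Q, b_K, b_V are [n_heads, d_head]; b_O is [d_model]
attn_params = 4 * n_heads * d_model * d_head + 3 * n_heads * d_head + d_model
# mlp (assumed layout): W_in [d_model, d_mlp], b_in [d_mlp], W_out [d_mlp, d_model], b_out [d_model]
mlp_params = 2 * d_model * d_mlp + d_mlp + d_model
# two layernorms per block, each with a scale and a bias of size d_model
ln_params = 2 * 2 * d_model

block_params = attn_params + mlp_params + ln_params
print(d_head, d_mlp)            # 64 3072
print(block_params)             # 7087872
print(n_layers * block_params)  # 85054464
```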

### parameters and activations

it's important to distinguish between parameters and activations in the model.

- **parameters** are the weights and biases that are learned during training
  * these don't change when the model input changes
  * they can be accessed directly from the model, e.g. `model.W_E` is the embedding matrix
- **activations** are temporary numbers calculated during a forward pass, that are functions of the input
  * we can think of these values as only existing for the duration of a single forward pass, and disappearing afterwards
  * we can use 'hooks' to access these values during a forward pass (more on hooks later), but it doesn't make sense to talk about a model's activations outside the context of some particular input
  * attention patterns and attention scores are activations (this is slightly non-intuitive because they're used in a matrix multiplication with another activation)
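here's a minimal sketch of the distinction using a plain `nn.Linear` (not the gpt-2 model) and pytorch's own `register_forward_hook`. the layer's parameters are the same tensors no matter what we feed in, while the activations captured by the hook only exist because we ran forward passes, and they change with the input. TransformerLens hooks are implemented differently, so treat this only as an analogy.

```python
import torch as t
from torch import nn

t.manual_seed(0)
layer = nn.Linear(4, 3)
captured = []  # activations recorded by the hook, one entry per forward pass

def save_output(module, inputs, output):
    # called by pytorch after every forward pass of `layer`
    captured.append(output.detach().clone())

handle = layer.register_forward_hook(save_output)

weight_before = layer.weight.detach().clone()
layer(t.randn(2, 4))
layer(t.randn(2, 4))
handle.remove()  # stop recording

# parameters: unchanged by forward passes
print(t.equal(layer.weight, weight_before))  # True
# activations: one per forward pass, and they depend on the input
print(len(captured), captured[0].shape)  # 2 torch.Size([2, 3])
print(t.equal(captured[0], captured[1]))  # False (different inputs)
```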

#### print all activation shapes of the reference model

run the following code to print all the activation shapes of the reference model:

In [None]:
for activation_name, activation in cache.items():
    # Only print for first layer
    if ".0." in activation_name or "blocks" not in activation_name:
        print(f"{activation_name:30} {tuple(activation.shape)}")

#### print all parameter shapes of the reference model

In [None]:
for name, param in reference_gpt2.named_parameters():
    # only print for first layer
    if ".0." in name or "blocks" not in name:
        print(f"{name:18} {tuple(param.shape)}")

[this diagram](https://raw.githubusercontent.com/chloeli-15/ARENA_img/main/img/full-merm.svg) shows the name of all activations and parameters in a fully general transformer model from `TransformerLens` (except for a few at the start and end, like the embedding and unembedding). lots of this won't make sense at first, but you can return to this diagram later and check that you understand most/all parts of it.

there's also an annotated version [here](https://raw.githubusercontent.com/chloeli-15/ARENA_img/main/img/transformer-full-updated.png).

### config

the config object contains all the hyperparameters of the model. we can print the config of the reference model to see what it contains:

In [None]:
# as a reference - note there's a lot of stuff we don't care about in here, to do with library internals or other architectures
print(reference_gpt2.cfg)

we define a stripped down config for our model:

In [None]:
@dataclass
class Config:
    d_model: int = 768 # width of the residual stream (the embedding dimension)
    debug: bool = True
    layer_norm_eps: float = 1e-5
    d_vocab: int = 50257 # size of the input vocabulary
    init_range: float = 0.02
    n_ctx: int = 1024 # max context length
    d_head: int = 64 # size of an attention head
    d_mlp: int = 3072 # dimension of the hidden layer of an mlp layer
    n_heads: int = 12 # num attention heads for each attention layer
    n_layers: int = 12 # number of transformer blocks / layers (each block contains an attention layer + an mlp layer)


cfg = Config()
print(cfg)
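since `d_model`, `n_heads`, `d_head` and `d_mlp` are set independently in `Config`, nothing stops them from becoming inconsistent. here's a sketch of a hypothetical `CheckedConfig` (not used anywhere else in this notebook) that catches this with a dataclass `__post_init__` hook. note that `d_model = n_heads * d_head` and `d_mlp = 4 * d_model` are gpt-2 conventions, not requirements of the architecture in general.

```python
from dataclasses import dataclass

@dataclass
class CheckedConfig:
    # hypothetical subset of Config's fields, for illustration only
    d_model: int = 768
    n_heads: int = 12
    d_head: int = 64
    d_mlp: int = 3072

    def __post_init__(self):
        # dataclasses call __post_init__ automatically at the end of the generated __init__
        if self.n_heads * self.d_head != self.d_model:
            raise ValueError(f"n_heads * d_head = {self.n_heads * self.d_head}, expected d_model = {self.d_model}")
        if self.d_mlp != 4 * self.d_model:
            raise ValueError(f"d_mlp = {self.d_mlp}, expected 4 * d_model = {4 * self.d_model}")


print(CheckedConfig())  # the gpt-2 small defaults pass both checks
try:
    CheckedConfig(d_head=32)
except ValueError as e:
    print("caught:", e)
```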

### tests

tests are great, write lightweight ones to use as you go!

**naive test:** generate random inputs of the right shape, input to your model, check whether there's an error and print the output shape.

In [None]:
def rand_float_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg).to(device)
    random_input = t.randn(shape).to(device)
    print("input shape:", random_input.shape)
    output = layer(random_input)
    if isinstance(output, tuple):
        output = output[0]
    print("output shape:", output.shape, "\n")


def rand_int_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg).to(device)
    random_input = t.randint(100, 1000, shape).to(device)
    print("input shape:", random_input.shape)
    output = layer(random_input)
    if isinstance(output, tuple):
        output = output[0]
    print("output shape:", output.shape, "\n")


def load_gpt2_test(cls, gpt2_layer, input, test_num=None):
    cfg = Config(debug=True)
    layer = cls(cfg).to(device)
    layer.load_state_dict(gpt2_layer.state_dict(), strict=False)
    print("input shape:", input.shape)
    output = layer(input)
    if isinstance(output, tuple):
        output = output[0]
    print("output shape:", output.shape)
    try:
        reference_output = gpt2_layer(input)
    except TypeError:
        # attention layers take separate query, key and value inputs
        reference_output = gpt2_layer(input, input, input)
    print("reference output shape:", reference_output.shape, "\n")
    comparison = t.isclose(output, reference_output, atol=1e-4, rtol=1e-3)
    print(f"{comparison.sum()/comparison.numel():.2%} of the values are correct\n")
    assert 1 - (comparison.sum() / comparison.numel()) < 1e-5, "more than 0.001% of the values are incorrect"
    if test_num:
        test_submit(test_num)

### exercise - implement `LayerNorm`

> ```yaml
> difficulty: 🔴🔴🔴⚪⚪
> importance: 🔵🔵🔵⚪⚪
>
> you should spend up to 10-15 minutes on this exercise.
> ```

the `LayerNorm` should do the following:

* make mean 0
* normalize to have variance 1
* scale with learned weights
* translate with learned bias

we can use the pytorch [`LayerNorm` documentation](https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html) as a reference

notes:
- the `LayerNorm` implementation always has `elementwise_affine=True`, i.e. you do learn parameters `w` and `b` (which are represented as $\gamma$ and $\beta$ respectively in the pytorch documentation).
- remember that after the centering and normalization, each vector of length `d_model` in your input should have mean 0 and variance 1.
- as the pytorch documentation page says, your variance should be computed using `unbiased=False`.
- the `layer_norm_eps` argument in your config object corresponds to the $\epsilon$ term in the pytorch documentation (it is included to avoid division-by-zero errors).
- there is a `debug` argument in your config. if `debug=True`, then you can print output like the shape of objects in your `forward` function to help you debug (this is a very useful trick to improve your coding speed).
- in the code below, the weights are initialized to ones and the bias to zeros, so that by default each output vector has mean zero and variance one ([see nanda's transformer walkthrough](https://youtu.be/dsjUDacBw8o?si=Jfnay1xZUZmjJ_pR&t=492)).

(fill in the function, where it says `raise NotImplementedError()` - this will be the basic pattern for most other exercises in this section).

In [None]:
class LayerNorm(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(t.ones(cfg.d_model))
        self.b = nn.Parameter(t.zeros(cfg.d_model))

    def forward(self, residual: Float[Tensor, "batch posn d_model"]) -> Float[Tensor, "batch posn d_model"]:
        # raise NotImplementedError()

        r_mean = residual.mean(dim=-1, keepdim=True)
        r_var = residual.var(dim=-1, keepdim=True, unbiased=False)
        r_std = (r_var + self.cfg.layer_norm_eps).sqrt()

        residual_normalized = (residual - r_mean) / r_std
        residual_normalized = (residual_normalized * self.w) + self.b

        return residual_normalized


rand_float_test(LayerNorm, [2, 4, 768])
load_gpt2_test(LayerNorm, reference_gpt2.ln_final, cache["resid_post", 11], 1)

<details><summary>solution</summary>

```python
class LayerNorm(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(t.ones(cfg.d_model))
        self.b = nn.Parameter(t.zeros(cfg.d_model))

    def forward(self, residual: Float[Tensor, "batch posn d_model"]) -> Float[Tensor, "batch posn d_model"]:
        residual_mean = residual.mean(dim=-1, keepdim=True)
        residual_std = (residual.var(dim=-1, keepdim=True, unbiased=False) + self.cfg.layer_norm_eps).sqrt()

        residual = (residual - residual_mean) / residual_std
        return residual * self.w + self.b
```
</details>
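one way to double-check a layernorm implementation independently of `load_gpt2_test` is to compare it against pytorch's built-in `torch.nn.functional.layer_norm`, which also uses the biased variance. a minimal sketch, where `manual_layer_norm` is a hypothetical standalone helper mirroring the solution's `forward`:

```python
import torch as t
import torch.nn.functional as F

def manual_layer_norm(x, w, b, eps=1e-5):
    # same steps as LayerNorm.forward above: center, divide by std (biased variance), scale, shift
    mean = x.mean(dim=-1, keepdim=True)
    std = (x.var(dim=-1, keepdim=True, unbiased=False) + eps).sqrt()
    return (x - mean) / std * w + b

t.manual_seed(0)
x = t.randn(2, 4, 768)
w, b = t.randn(768), t.randn(768)

ours = manual_layer_norm(x, w, b)
reference = F.layer_norm(x, normalized_shape=(768,), weight=w, bias=b, eps=1e-5)
print(t.allclose(ours, reference, atol=1e-4))  # True

# with w = 1 and b = 0 (the initialization used above), each vector has mean ~0 and variance ~1
normed = manual_layer_norm(x, t.ones(768), t.zeros(768))
print(normed.mean(dim=-1).abs().max().item())                       # close to 0
print((normed.var(dim=-1, unbiased=False) - 1).abs().max().item())  # close to 0
```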

### exercise - implement `Embed`

> ```yaml
> difficulty: 🔴🔴⚪⚪⚪
> importance: 🔵🔵🔵⚪⚪
>
> you should spend up to 5-10 minutes on this exercise.
> ```

this is basically a lookup table from tokens to residual stream vectors.

(hint - you can implement this in just one line, without any complicated functions. If you've been working on it for >10 mins, you're probably overthinking it!)

In [None]:
class Embed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(t.empty((cfg.d_vocab, cfg.d_model)))
        nn.init.normal_(self.W_E, std=self.cfg.init_range)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_model"]:
        # raise NotImplementedError()
        return self.W_E[tokens]


rand_int_test(Embed, [2, 4])
load_gpt2_test(Embed, reference_gpt2.embed, tokens, 2)

<details>
<summary>help - i keep getting <code>RuntimeError: CUDA error: device-side assert triggered</code>.</summary>

this is a uniquely frustrating type of error message, because it (1) forces you to restart the kernel, and (2) often won't tell you where the error message actually originated from!

you can fix the second problem by adding the line `os.environ['CUDA_LAUNCH_BLOCKING'] = "1"` to the very top of your file (after importing `os`). this won't fix your bug, but it makes sure the correct origin point is identified.

as for actually fixing the bug, this error usually ends up being the result of bad indexing, e.g. you're trying to apply an embedding layer to tokens which are larger than your maximum embedding.
</details>


<details><summary>solution</summary>

```python
class Embed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(t.empty((cfg.d_vocab, cfg.d_model)))
        nn.init.normal_(self.W_E, std=self.cfg.init_range)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_model"]:
        return self.W_E[tokens]
```
</details>
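to see why indexing is all an embedding does, here's a sketch (toy sizes, random weights) showing that `W_E[tokens]` equals multiplying one-hot token vectors by $W_E$, and matches `torch.nn.functional.embedding`. it also shows a cheap range check on the token ids: checking this up front gives a readable answer, rather than the device-side assert described above.

```python
import torch as t
import torch.nn.functional as F

t.manual_seed(0)
d_vocab, d_model = 50, 8
W_E = t.randn(d_vocab, d_model)
tokens = t.randint(0, d_vocab, (2, 4))  # [batch, position]

by_indexing = W_E[tokens]                                 # [2, 4, 8]
one_hot = F.one_hot(tokens, num_classes=d_vocab).float()  # [2, 4, 50]
by_matmul = one_hot @ W_E                                 # [2, 4, 8]
by_functional = F.embedding(tokens, W_E)

print(by_indexing.shape)  # torch.Size([2, 4, 8])
print(t.allclose(by_indexing, by_matmul), t.equal(by_indexing, by_functional))  # True True

# token ids must lie in [0, d_vocab); 60 is out of range here
bad_tokens = t.tensor([[3, 60]])
in_range = ((bad_tokens >= 0) & (bad_tokens < d_vocab)).all().item()
print(in_range)  # False
```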

### exercise - implement `PosEmbed`

> ```yaml
> difficulty: 🔴🔴⚪⚪⚪
> importance: 🔵🔵🔵⚪⚪
>
> you should spend up to 10-15 minutes on this exercise.
> ```

positional embedding can also be thought of as a lookup table, but rather than the indices being our token ids, the indices are just the numbers `0`, `1`, `2`, ..., `seq_len-1` (i.e. the position indices of the tokens in the sequence).

In [None]:
class PosEmbed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(t.empty((cfg.n_ctx, cfg.d_model)))
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_model"]:
        # raise NotImplementedError()

        # print(tokens.shape)
        batch, pos_n = tokens.shape
        return einops.repeat(self.W_pos[:pos_n], "position d_model -> batch position d_model", batch=batch)


rand_int_test(PosEmbed, [2, 4])
load_gpt2_test(PosEmbed, reference_gpt2.pos_embed, tokens, 3)

<details><summary>solution</summary>

```python
class PosEmbed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(t.empty((cfg.n_ctx, cfg.d_model)))
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_model"]:
        batch, seq_len = tokens.shape
        return einops.repeat(self.W_pos[:seq_len], "seq d_model -> batch seq d_model", batch=batch)
```
</details>
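since the positional embedding is identical for every sequence in the batch, broadcasting gives the same values as an explicit copy along the batch dimension. here's a sketch with toy sizes comparing a materialized copy (the same values the `einops.repeat` call above produces) with `expand`, which returns a view instead of copying memory, and then adding the result to the token embeddings to form the initial residual stream.

```python
import torch as t

t.manual_seed(0)
n_ctx, d_model, d_vocab = 16, 8, 50
W_pos = t.randn(n_ctx, d_model)
W_E = t.randn(d_vocab, d_model)
tokens = t.randint(0, d_vocab, (3, 5))  # [batch, position]
batch, seq_len = tokens.shape

# explicit copy along a new batch dimension
pos_copied = W_pos[:seq_len].unsqueeze(0).repeat(batch, 1, 1)  # [3, 5, 8]
# view with a broadcast batch dimension: same values, no copy
pos_view = W_pos[:seq_len].expand(batch, -1, -1)                # [3, 5, 8]
print(t.equal(pos_copied, pos_view))  # True

# the residual stream starts as token embedding + positional embedding
resid = W_E[tokens] + pos_view
print(resid.shape)  # torch.Size([3, 5, 8])
```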

### exercise - implement `apply_causal_mask`

> ```yaml
> difficulty: 🔴🔴⚪⚪⚪
> importance: 🔵🔵🔵🔵🔵
>
> you should spend up to 10-15 minutes on this exercise.
> ```

the causal mask function will be a method of the `Attention` class.
it will take in attention scores, and apply a mask to them so that the model can only attend to the current and previous positions (i.e. the model can't cheat by looking at future positions).

we will implement this function first, and test it, before moving onto the `forward` method of the `Attention` class.

notes:

- we can use [`torch.where`](https://pytorch.org/docs/stable/generated/torch.where.html), or the [`Tensor.masked_fill`](https://pytorch.org/docs/stable/generated/torch.Tensor.masked_fill.html) method (or its in-place variant `masked_fill_`) when masking the attention scores
- the [`torch.triu`](https://pytorch.org/docs/stable/generated/torch.triu.html) function is useful for creating a mask that is true for all positions we want to set probabilities to zero for
- make sure to use the `self.IGNORE` attribute to set the masked positions to negative infinity

<details>
<summary>question - why do you think we mask the attention scores by setting them to negative infinity, rather than the attention probabilities by setting them to zero?</summary>

if we masked the attention probabilities, then the probabilities would no longer sum to 1.

we want to mask the scores and *then* take softmax, so that the probabilities are still valid probabilities (i.e. they sum to 1), and the values in the masked positions have no influence on the model's output.
</details>
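a quick numerical illustration of the answer above, using random scores for a single head and a toy sequence length of 4: masking the scores with $-\infty$ before the softmax gives rows that still sum to 1, whereas zeroing the probabilities after the softmax leaves every row except the last summing to less than 1.

```python
import torch as t

t.manual_seed(0)
seq_len = 4
scores = t.randn(seq_len, seq_len)  # [query_pos, key_pos]
mask = t.triu(t.ones(seq_len, seq_len), diagonal=1).bool()  # True above the diagonal = future keys

# correct: mask the scores, then softmax
probs_masked_scores = scores.masked_fill(mask, float("-inf")).softmax(dim=-1)
# incorrect: softmax, then zero out the future positions
probs_masked_after = scores.softmax(dim=-1).masked_fill(mask, 0.0)

print(probs_masked_scores.sum(dim=-1))  # all ones (up to float rounding)
print(probs_masked_after.sum(dim=-1))   # only the last row sums to 1
```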

In [None]:
class Attention(nn.Module):
    IGNORE: Float[Tensor, ""]

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.register_buffer("IGNORE", t.tensor(float("-inf"), dtype=t.float32, device=device))

    def apply_causal_mask(
        self,
        attn_scores: Float[Tensor, "batch n_heads query_pos key_pos"],
    ) -> Float[Tensor, "batch n_heads query_pos key_pos"]:
        """
        Applies a causal mask to attention scores, and returns masked scores.
        """
        # raise NotImplementedError()

        all_ones = t.ones(attn_scores.size(-2), attn_scores.size(-1), device=attn_scores.device)
        mask = t.triu(all_ones, diagonal=1).bool()

        # print(mask)

        attn_scores = attn_scores.masked_fill(mask, self.IGNORE)

        return attn_scores

<details>
<summary>Hint (pseudocode)</summary>

```python
def apply_causal_mask(
    self, attn_scores: Float[Tensor, "batch n_heads query_pos key_pos"]
    ) -> Float[Tensor, "batch n_heads query_pos key_pos"]:

    # Define a mask that is True for all positions we want to set probabilities to zero for

    # Apply the mask to attention scores, then return the masked scores
```
</details>


<details><summary>Solution</summary>

```python
class Attention(nn.Module):
    IGNORE: Float[Tensor, ""]

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.register_buffer("IGNORE", t.tensor(float("-inf"), dtype=t.float32, device=device))

    def apply_causal_mask(
        self,
        attn_scores: Float[Tensor, "batch n_heads query_pos key_pos"],
    ) -> Float[Tensor, "batch n_heads query_pos key_pos"]:
        """
        Applies a causal mask to attention scores, and returns masked scores.
        """
        # Define a mask that is True for all positions we want to set probabilities to zero for
        all_ones = t.ones(attn_scores.size(-2), attn_scores.size(-1), device=attn_scores.device)
        mask = t.triu(all_ones, diagonal=1).bool()
        # Apply the mask to attention scores, then return the masked scores
        attn_scores.masked_fill_(mask, self.IGNORE)
        return attn_scores
```
</details>

In [None]:
from functools import partial

def test_causal_mask(apply_causal_mask):
    cfg = Config()
    attn_scores = t.randn((1, 1, 5, 5)).to(device)  # (batch, n_heads, query_pos, key_pos)
    attn = Attention(cfg)

    sol_apply_causal_mask = attn.apply_causal_mask
    apply_causal_mask = partial(apply_causal_mask, self=attn)

    expected = sol_apply_causal_mask(attn_scores=attn_scores.clone())
    actual = apply_causal_mask(attn_scores=attn_scores.clone())

    def print_scores():
        print(f"Actual Attention Probs: \n {t.softmax(actual, dim=-1)}")

    if t.any(t.isnan(actual)):
        nan_freq = t.sum(t.isnan(actual)).item() / actual.numel()
        print_scores()
        raise ValueError(
            f"Your masked attention scores contain {nan_freq * 100}% NaNs. Make sure you aren't multiplying 0 * neg inf anywhere."
        )

    attn_probs = t.softmax(
        actual, dim=-1
    )  # ignoring the scale factor, we just want to check if the mask is applied correctly

    if t.any(t.isnan(attn_probs)):
        print_scores()
        nan_freq = t.sum(t.isnan(attn_probs)).item() / attn_probs.numel()
        raise ValueError(
            f"Your post-softmax masked attention scores contain {nan_freq * 100}% NaNs. Make sure you aren't setting an entire row to neg inf."
        )

    if not t.allclose(actual, expected):
        print_scores()
        t.testing.assert_close(actual, expected)

    print("All tests in `test_causal_mask` passed!")

test_causal_mask(Attention.apply_causal_mask)

### exercise - implement `Attention`

> ```yaml
> difficulty: 🔴🔴🔴🔴⚪
> importance: 🔵🔵🔵🔵🔵
>
> you should spend up to 30-45 minutes on this exercise.
> ```

some terminology
- the destination token, i.e. the current token, or the query token
- the source token(s), i.e. the previous token(s) (including the current token), or the key token(s)

**step 1:** produce an attention pattern - for each destination token, a probability distribution over previous tokens (including current token)
- linear map from input -> query, key shape `[batch, position, head_index, d_head]`
- dot product every *pair* of queries and keys to get attn_scores `[batch, head_index, query_pos, key_pos]` (query = dest, key = source)
- **scale** and mask `attn_scores` to make it lower triangular, i.e. causal
- softmax along the `key_pos` dimension, to get a probability distribution for each query (destination) token - this is our attention pattern!

**step 2:** move information from source tokens to destination token using attention pattern (move = apply linear map)
- linear map from input -> value `[batch, key_pos, head_index, d_head]`
- mix along the `key_pos` with attn pattern to get `z`, which is a weighted average of the value vectors `[batch, query_pos, head_index, d_head]`
- map to output, `[batch, position, d_model]` (position = query_pos, we've summed over all heads)

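note: when we say **scale**, we mean dividing by `sqrt(d_head)`. if the query and key components are roughly independent with unit variance, their dot product is a sum of `d_head` such products and so has variance roughly `d_head`, meaning attention scores grow with head size. the purpose of scaling is to avoid vanishing gradients, which is a big problem when we're dealing with a function like softmax: if one of the values is much larger than all the others, the probabilities will be close to 0 or 1, and the gradients will be close to 0.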
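a quick empirical check of where the `sqrt(d_head)` factor comes from, assuming (purely for illustration) that query and key components are independent with mean 0 and variance 1: the raw dot products have variance close to `d_head`, and the scaled ones have variance close to 1.

```python
import torch as t

t.manual_seed(0)
d_head, n_samples = 64, 20_000
q = t.randn(n_samples, d_head)
k = t.randn(n_samples, d_head)

dots = (q * k).sum(dim=-1)  # one query-key dot product per sample
scaled = dots / d_head ** 0.5

print(dots.var().item())    # roughly 64
print(scaled.var().item())  # roughly 1
```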

below is a much larger, more detailed version of the attention head diagram from earlier. this should give you an idea of the actual tensor operations involved.

a few clarifications on this diagram:
- whenever there is a third dimension shown in the pictures, this refers to the `head_index` dimension. we can see that all operations within the attention layer are done independently for each head.
- the objects in the box are activations; they have a batch dimension (for simplicity, we assume the batch dimension is 1 in the diagram). The objects to the right of the box are our parameters (weights and biases); they have no batch dimension.
- we arrange the keys, queries and values as `(batch, seq_pos, head_idx, d_head)`, because the biases have shape `(head_idx, d_head)`, so this makes it convenient to add the biases (recall the rules of array broadcasting!).

<img src="https://raw.githubusercontent.com/chloeli-15/ARENA_img/main/img/transformer-attn-30.png" width="1400">
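a small sketch of why the `(batch, seq_pos, head_idx, d_head)` layout is convenient, with toy sizes and random weights: after a per-head projection, `b_Q` of shape `(head_idx, d_head)` lines up with the last two dimensions, so it broadcasts over batch and position without any reshaping. the loop over heads is only there as a cross-check.

```python
import torch as t

t.manual_seed(0)
batch, seq_len, n_heads, d_model, d_head = 2, 5, 3, 12, 4
x = t.randn(batch, seq_len, d_model)
W_Q = t.randn(n_heads, d_model, d_head)
b_Q = t.randn(n_heads, d_head)

# per-head projection, laid out as [batch, seq_pos, head_idx, d_head]
q = t.einsum("bpm,hmd->bphd", x, W_Q)
# b_Q [n_heads, d_head] matches the last two dims, so it broadcasts over batch and seq_pos
q = q + b_Q
print(q.shape)  # torch.Size([2, 5, 3, 4])

# cross-check against an explicit loop over heads
q_loop = t.stack([x @ W_Q[h] + b_Q[h] for h in range(n_heads)], dim=2)
print(t.allclose(q, q_loop, atol=1e-5))  # True
```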

<details>
<summary><b>a few extra notes on attention (optional)</b></summary>

usually we have the relation `e = n * h` (i.e. `d_model = num_heads * d_head`). there are some computational justifications for this, but mostly this is just done out of convention (just like how we usually have `d_mlp = 4 * d_model`!).

the names **keys**, **queries** and **values** come from their analogy to retrieval systems. Broadly speaking:

- the **queries** represent some information that a token is **"looking for"**
- the **keys** represent the information that a token **"contains"**
  * so the attention score being high basically means that the source (key) token contains the information which the destination (query) token **is looking for**
- the **values** represent the information that is actually taken from the source token, to be moved to the destination token

this diagram can better help us understand the difference between the **QK** and **OV** circuit. we'll discuss this just briefly here, and will go into much more detail later on.

the **QK** circuit consists of the operation of the $W_Q$ and $W_K$ matrices. in other words, it determines the attention pattern, i.e. where information is moved to and from in the residual stream. the functional form of the attention pattern $A$ is:

$$
A = \text{softmax}\left(\frac{x W_Q W_K^T x^T}{\sqrt{d_{head}}}\right)
$$

where $x$ is the residual stream (shape `[seq_len, d_model]`), and $W_Q$, $W_K$ are the weight matrices for a single head (i.e. shape `[d_model, d_head]`).

the **OV** circuit consists of the operation of the $W_V$ and $W_O$ matrices. once attention patterns are fixed, these matrices operate on the residual stream at the source position, and their output is the thing which gets moved from source to destination position.

the functional form of an entire attention head is:

$$
\begin{aligned}
\text{output} &= \text{softmax}\left(\frac{x W_Q W_K^T x^T}{\sqrt{d_{head}}}\right) (x W_V W_O) \\
    &= Ax W_V W_O
\end{aligned}
$$

where $W_V$ has shape `[d_model, d_head]`, and $W_O$ has shape `[d_head, d_model]`.

here, we can clearly see that the **QK circuit** and **OV circuit** are doing conceptually different things, and should be thought of as two distinct parts of the attention head.
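a minimal sketch checking the factored formula for a single head against the step-by-step computation, for one sequence with no biases and no causal mask (the same simplifications as the formula above; sizes and weights are random and illustrative). it also shows that $W_Q W_K^T$ is a `[d_model, d_model]` matrix whose rank is at most `d_head`.

```python
import torch as t

t.manual_seed(0)
seq_len, d_model, d_head = 5, 32, 4
x = t.randn(seq_len, d_model, dtype=t.float64)  # residual stream for one sequence
W_Q, W_K, W_V = (t.randn(d_model, d_head, dtype=t.float64) for _ in range(3))
W_O = t.randn(d_head, d_model, dtype=t.float64)

# step by step: queries, keys and values, then the pattern, then the output
q, k, v = x @ W_Q, x @ W_K, x @ W_V
A = t.softmax(q @ k.T / d_head ** 0.5, dim=-1)  # [seq_len, seq_len]
out_stepwise = (A @ v) @ W_O                     # [seq_len, d_model]

# factored: the QK and OV circuits as d_model x d_model matrices
W_QK = W_Q @ W_K.T
W_OV = W_V @ W_O
A_factored = t.softmax(x @ W_QK @ x.T / d_head ** 0.5, dim=-1)
out_factored = A_factored @ x @ W_OV

print(t.allclose(out_stepwise, out_factored))         # True
print(W_QK.shape, t.linalg.matrix_rank(W_QK).item())  # torch.Size([32, 32]) 4
```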

again, don't worry if you don't follow all of this right now - we'll go into **much** more detail on all of this in subsequent exercises. The purpose of the discussion here is just to give you a flavour of what's to come!

</details>

before implementing `Attention`, it's useful to visualize and play around with attention patterns - what exactly are we looking at here? (click on a head to lock onto just showing that head's pattern; it'll make it easier to interpret)

In [None]:
import circuitsvis as cv
from IPython.display import display

display(
    cv.attention.attention_patterns(
        tokens=reference_gpt2.to_str_tokens(reference_text), attention=cache["pattern", 0][0]
    )
)

you can also use the `attention_heads` function, which presents the data in a different way (the syntax is exactly the same as `attention_patterns`).

note: if you display this in vscode then it may exhibit a bug where the main plot continually shrinks in size - if this happens, you should instead save the html (i.e. with `html = cv.attention.attention_heads(...); with open("attn_heads.html", "w") as f: f.write(str(html))`) and open the plot in your browser.

<!-- <details>
<summary>Help - my <code>attention_heads</code> plots are behaving weirdly.</summary>

This seems to be a bug in `circuitsvis` - on VSCode, the attention head plots continually shrink in size.

Until this is fixed, one way to get around it is to open the plots in your browser. You can do this inline with the `webbrowser` library:

```python
attn_heads = cv.attention.attention_heads(
    tokens=reference_gpt2.to_str_tokens(reference_text),
    attention=cache["pattern", 0][0]
)

path = "attn_heads.html"

with open(path, "w") as f:
    f.write(str(attn_heads))

webbrowser.open(path)
```

To check exactly where this is getting saved, you can print your current working directory with `os.getcwd()`.
</details> -->

In [None]:
display(
    cv.attention.attention_heads(
        tokens=reference_gpt2.to_str_tokens(reference_text), attention=cache["pattern", 0][0]
    )
)

you should fill in the forward method for `Attention` below. you should also copy your code for `apply_causal_mask` to this new implementation of `Attention` (you can delete the rest of the old implementation code).

note: this implementation will probably be the most challenging exercise on this page, so don't worry if it takes you some time! you should look at parts of the solution if you're stuck.

a few tips:
- don't forget the attention score scaling (this should come before the masking).
- try not to combine a large number of operations into a single line of code.
- try to make your variable names descriptive (i.e. it's not just `x = some_fn_of(x), x = some_other_fn_of(x), ...`).

In [None]:
class Attention(nn.Module):
    IGNORE: Float[Tensor, ""]

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_Q = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_K = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_V = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_O = nn.Parameter(t.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        self.b_Q = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_K = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_V = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_O = nn.Parameter(t.zeros((cfg.d_model)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        self.register_buffer("IGNORE", t.tensor(float("-inf"), dtype=t.float32, device=device))

    def forward(self, normalized_resid_pre: Float[Tensor, "batch posn d_model"]) -> Float[Tensor, "batch posn d_model"]:
        # raise NotImplementedError()

        query = einops.einsum(
                  normalized_resid_pre, self.W_Q, 'batch query_pos d_model, n_heads d_model d_head -> batch query_pos n_heads d_head'
                ) + self.b_Q
        key = einops.einsum(
                  normalized_resid_pre, self.W_K, 'batch key_pos d_model, n_heads d_model d_head -> batch key_pos n_heads d_head'
                ) + self.b_K

        value = einops.einsum(
                  normalized_resid_pre, self.W_V, 'batch key_pos d_model, n_heads d_model d_head -> batch key_pos n_heads d_head'
                ) + self.b_V

        # calculate the attention scores, then scale, then apply the causal attention mask, then apply softmax to get the probabilities
        attn_scores = einops.einsum(query, key, 'batch query_pos n_heads d_head, batch key_pos n_heads d_head -> batch n_heads query_pos key_pos')
        attn_scores_scaled = attn_scores / (self.cfg.d_head ** 0.50)
        attn_scores_masked = self.apply_causal_mask(attn_scores_scaled)
        attn_pattern = attn_scores_masked.softmax(-1) # [batch n_heads query_pos key_pos]

        # calculate z, the weighted avg of the values, weighted on the basis of the attention pattern probabilities
        z = einops.einsum(value, attn_pattern, 'batch key_pos n_heads d_head, batch n_heads query_pos key_pos -> batch query_pos n_heads d_head')

        result = einops.einsum(z, self.W_O, 'batch query_pos n_heads d_head, n_heads d_head d_model -> batch query_pos d_model') + self.b_O

        return result


    def apply_causal_mask(
        self, attn_scores: Float[Tensor, "batch n_heads query_pos key_pos"]
    ) -> Float[Tensor, "batch n_heads query_pos key_pos"]:
        """
        Applies a causal mask to attention scores, and returns masked scores.
        """
        # You should copy your solution from earlier

        # raise NotImplementedError()

        all_ones = t.ones(attn_scores.size(-2), attn_scores.size(-1), device=attn_scores.device)
        mask = t.triu(all_ones, diagonal=1).bool()

        attn_scores = attn_scores.masked_fill(mask, self.IGNORE)

        return attn_scores

In [None]:
test_causal_mask(Attention.apply_causal_mask)
rand_float_test(Attention, [2, 4, 768])
load_gpt2_test(Attention, reference_gpt2.blocks[0].attn, cache["normalized", 0, "ln1"], 4)

<details>
<summary>Hint (pseudocode for the forward method)</summary>

```python
def forward(
    self, normalized_resid_pre: Float[Tensor, "batch posn d_model"]
) -> Float[Tensor, "batch posn d_model"]:

    # Calculate query, key and value vectors
    q, k, v = ...

    # Calculate attention scores, then scale and mask, and apply softmax to get probabilities
    attn_scores = ...
    attn_scores_masked = ...
    attn_pattern = ...

    # Take weighted sum of value vectors, according to attention probabilities
    z = ...

    # Calculate output (by applying matrix W_O and summing over heads, then adding bias b_O)
    attn_out = ...
    return attn_out
```
</details>


<details><summary>Solution</summary>

```python
class Attention(nn.Module):
    IGNORE: Float[Tensor, ""]

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_Q = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_K = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_V = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_O = nn.Parameter(t.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        self.b_Q = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_K = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_V = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_O = nn.Parameter(t.zeros((cfg.d_model)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        self.register_buffer("IGNORE", t.tensor(float("-inf"), dtype=t.float32, device=device))

    def forward(self, normalized_resid_pre: Float[Tensor, "batch posn d_model"]) -> Float[Tensor, "batch posn d_model"]:
        # Calculate query, key and value vectors
        q = (
            einops.einsum(
                normalized_resid_pre, self.W_Q, "batch posn d_model, nheads d_model d_head -> batch posn nheads d_head"
            )
            + self.b_Q
        )
        k = (
            einops.einsum(
                normalized_resid_pre, self.W_K, "batch posn d_model, nheads d_model d_head -> batch posn nheads d_head"
            )
            + self.b_K
        )
        v = (
            einops.einsum(
                normalized_resid_pre, self.W_V, "batch posn d_model, nheads d_model d_head -> batch posn nheads d_head"
            )
            + self.b_V
        )

        # Calculate attention scores, then scale and mask, and apply softmax to get probabilities
        attn_scores = einops.einsum(
            q, k, "batch posn_Q nheads d_head, batch posn_K nheads d_head -> batch nheads posn_Q posn_K"
        )
        attn_scores_masked = self.apply_causal_mask(attn_scores / self.cfg.d_head**0.5)
        attn_pattern = attn_scores_masked.softmax(-1)

        # Take weighted sum of value vectors, according to attention probabilities
        z = einops.einsum(
            v, attn_pattern, "batch posn_K nheads d_head, batch nheads posn_Q posn_K -> batch posn_Q nheads d_head"
        )

        # Calculate output (by applying matrix W_O and summing over heads, then adding bias b_O)
        attn_out = (
            einops.einsum(z, self.W_O, "batch posn_Q nheads d_head, nheads d_head d_model -> batch posn_Q d_model")
            + self.b_O
        )

        return attn_out

    def apply_causal_mask(
        self, attn_scores: Float[Tensor, "batch n_heads query_pos key_pos"]
    ) -> Float[Tensor, "batch n_heads query_pos key_pos"]:
        """
        Applies a causal mask to attention scores, and returns masked scores.
        """
        # Define a mask that is True for all positions we want to set probabilities to zero for
        all_ones = t.ones(attn_scores.size(-2), attn_scores.size(-1), device=attn_scores.device)
        mask = t.triu(all_ones, diagonal=1).bool()
        # Apply the mask to attention scores, then return the masked scores
        attn_scores.masked_fill_(mask, self.IGNORE)
        return attn_scores
```
</details>

### exercise - implement `MLP`

> ```yaml
> difficulty: 🔴🔴⚪⚪⚪
> importance: 🔵🔵🔵🔵⚪
>
> you should spend up to 10-15 minutes on this exercise.
> ```

implement the mlp layer, which consists of:
- a linear layer, with weight `W_in`, bias `b_in`
- a nonlinear function (we usually use gelu; the function `gelu_new` has been imported for this purpose)
- a linear layer, with weight `W_out`, bias `b_out`
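as a reference point, here is a minimal sketch of the tanh approximation of gelu, which is what we understand `gelu_new` to compute (treat that as an assumption, and compare against the imported function if in doubt). the helper name `gelu_tanh` is ours. the comparison uses pytorch's built-in `F.gelu(..., approximate="tanh")`, which needs pytorch 1.12 or newer.

```python
import math

import torch as t
import torch.nn.functional as F


def gelu_tanh(x: t.Tensor) -> t.Tensor:
    # tanh approximation of GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + t.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))))


x = t.linspace(-5, 5, steps=101)
manual = gelu_tanh(x)
builtin = F.gelu(x, approximate="tanh")  # built-in tanh approximation
print(t.allclose(manual, builtin, atol=1e-5))  # True
```

the tanh form is close to, but not identical to, the exact erf-based gelu (`F.gelu(x)` with the default `approximate="none"`).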

In [None]:
class MLP(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(t.empty((cfg.d_model, cfg.d_mlp)))
        self.W_out = nn.Parameter(t.empty((cfg.d_mlp, cfg.d_model)))
        self.b_in = nn.Parameter(t.zeros((cfg.d_mlp)))
        self.b_out = nn.Parameter(t.zeros((cfg.d_model)))
        nn.init.normal_(self.W_in, std=self.cfg.init_range)
        nn.init.normal_(self.W_out, std=self.cfg.init_range)

    def forward(self, normalized_resid_mid: Float[Tensor, "batch posn d_model"]) -> Float[Tensor, "batch posn d_model"]:
        # raise NotImplementedError()

        x = einops.einsum(normalized_resid_mid, self.W_in, 'batch posn d_model, d_model d_mlp -> batch posn d_mlp') + self.b_in
        x = gelu_new(x)
        x = einops.einsum(x, self.W_out, 'batch posn d_mlp, d_mlp d_model -> batch posn d_model') + self.b_out

        return x

rand_float_test(MLP, [2, 4, 768])
load_gpt2_test(MLP, reference_gpt2.blocks[0].mlp, cache["normalized", 0, "ln2"], 5)

<details><summary>solution</summary>

```python
class MLP(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(t.empty((cfg.d_model, cfg.d_mlp)))
        self.W_out = nn.Parameter(t.empty((cfg.d_mlp, cfg.d_model)))
        self.b_in = nn.Parameter(t.zeros((cfg.d_mlp)))
        self.b_out = nn.Parameter(t.zeros((cfg.d_model)))
        nn.init.normal_(self.W_in, std=self.cfg.init_range)
        nn.init.normal_(self.W_out, std=self.cfg.init_range)

    def forward(self, normalized_resid_mid: Float[Tensor, "batch posn d_model"]) -> Float[Tensor, "batch posn d_model"]:
        pre = (
            einops.einsum(
                normalized_resid_mid, self.W_in, "batch position d_model, d_model d_mlp -> batch position d_mlp"
            )
            + self.b_in
        )
        post = gelu_new(pre)
        mlp_out = (
            einops.einsum(post, self.W_out, "batch position d_mlp, d_mlp d_model -> batch position d_model")
            + self.b_out
        )
        return mlp_out
```
</details>

### exercise - implement `TransformerBlock`

> ```yaml
> difficulty: 🔴🔴⚪⚪⚪
> importance: 🔵🔵🔵⚪⚪
>
> you should spend up to 10-15 minutes on this exercise.
> ```

now, we can put together the attention, mlp and layernorms into a single transformer block. remember to implement the residual connections correctly!
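to see what the residual connections do, here is a hypothetical toy pre-layernorm block (not the exercise solution). `nn.LayerNorm` and an `nn.Linear` stand in for our `LayerNorm` and `Attention`, and `nn.GELU` stands in for `gelu_new`. each sublayer reads a normalized copy of the residual stream and *adds* its output back, so if both sublayers output zero the input passes straight through.

```python
import torch as t
import torch.nn as nn


class ToyPreLNBlock(nn.Module):
    """Hypothetical pre-LN residual block: sublayers read ln(resid) and add their output to resid."""

    def __init__(self, d_model: int, d_mlp: int):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.mix = nn.Linear(d_model, d_model)  # stand-in for attention
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(nn.Linear(d_model, d_mlp), nn.GELU(), nn.Linear(d_mlp, d_model))

    def forward(self, resid_pre: t.Tensor) -> t.Tensor:
        resid_mid = resid_pre + self.mix(self.ln1(resid_pre))
        resid_post = resid_mid + self.mlp(self.ln2(resid_mid))
        return resid_post


block = ToyPreLNBlock(d_model=8, d_mlp=32)
# zero the output projections, so each sublayer writes nothing to the residual stream
with t.no_grad():
    for layer in (block.mix, block.mlp[-1]):
        layer.weight.zero_()
        layer.bias.zero_()

x = t.randn(2, 4, 8)
print(t.equal(block(x), x))  # True: the residual stream passes through untouched
```

because every sublayer only adds to the stream, the residual stream acts as a shared channel that each layer reads from and writes to.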

In [None]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(self, resid_pre: Float[Tensor, "batch position d_model"]) -> Float[Tensor, "batch position d_model"]:
        # raise NotImplementedError()

        x = self.ln1(resid_pre)
        resid_mid = self.attn(x) + resid_pre
        x = self.ln2(resid_mid)
        x = self.mlp(x) + resid_mid

        return x


rand_float_test(TransformerBlock, [2, 4, 768])
load_gpt2_test(TransformerBlock, reference_gpt2.blocks[0], cache["resid_pre", 0], 6)

<details>
<summary>help - i'm getting 100% accuracy on all modules before this point, but only about 90% accuracy on this one.</summary>

this might be because your layernorm implementation divides by `std + eps` rather than `(var + eps).sqrt()`. the latter matches the implementation used by GPT-2 (and this error only shows up in these tests).
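
a small sketch of the difference, assuming $\epsilon = 10^{-5}$: the two variants agree when the variance is large relative to $\epsilon$, and disagree noticeably when it is small. both helper names are hypothetical.

```python
import torch as t

EPS = 1e-5  # assumed layernorm epsilon


def ln_sqrt_var_plus_eps(x: t.Tensor) -> t.Tensor:
    # divide by sqrt(var + eps): the form the hint says matches GPT-2
    x = x - x.mean(dim=-1, keepdim=True)
    var = x.pow(2).mean(dim=-1, keepdim=True)  # biased variance
    return x / (var + EPS).sqrt()


def ln_std_plus_eps(x: t.Tensor) -> t.Tensor:
    # divide by std + eps: the variant the hint warns about
    x = x - x.mean(dim=-1, keepdim=True)
    std = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
    return x / (std + EPS)


large = t.tensor([[-10.0, 10.0]])  # variance 100, much bigger than EPS
small = t.tensor([[0.0, 0.001]])   # variance 2.5e-7, smaller than EPS
print(t.allclose(ln_sqrt_var_plus_eps(large), ln_std_plus_eps(large)))  # True
print(t.allclose(ln_sqrt_var_plus_eps(small), ln_std_plus_eps(small)))  # False
```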

</details>


<details><summary>solution</summary>

```python
class TransformerBlock(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(self, resid_pre: Float[Tensor, "batch position d_model"]) -> Float[Tensor, "batch position d_model"]:
        resid_mid = self.attn(self.ln1(resid_pre)) + resid_pre
        resid_post = self.mlp(self.ln2(resid_mid)) + resid_mid
        return resid_post
```
</details>

### exercise - implement `Unembed`

> ```yaml
> difficulty: 🔴🔴⚪⚪⚪
> importance: 🔵🔵🔵⚪⚪
>
> you should spend up to ~10 minutes on this exercise.
> ```

the unembedding is just a linear layer (with weight `W_U` and bias `b_U`).
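concretely, "just a linear layer" means the einsum contraction over `d_model` is the same computation as a plain matrix multiply, or as an `nn.Linear` whose weight is `W_U` transposed (`nn.Linear` stores its weight with shape `(out_features, in_features)`). the sketch uses `t.einsum` and toy sizes, not gpt-2's.

```python
import torch as t
import torch.nn as nn

t.manual_seed(0)
batch, posn, d_model, d_vocab = 2, 3, 8, 20  # toy sizes
resid = t.randn(batch, posn, d_model)
W_U = t.randn(d_model, d_vocab)
b_U = t.zeros(d_vocab)

# 1. einsum, contracting over d_model
logits_einsum = t.einsum("bpd,dv->bpv", resid, W_U) + b_U
# 2. plain matmul, which broadcasts over the leading (batch, posn) dims
logits_matmul = resid @ W_U + b_U
# 3. nn.Linear, whose weight has shape (d_vocab, d_model), i.e. W_U transposed
linear = nn.Linear(d_model, d_vocab)
with t.no_grad():
    linear.weight.copy_(W_U.T)
    linear.bias.copy_(b_U)
logits_linear = linear(resid)

print(logits_einsum.shape)  # torch.Size([2, 3, 20])
```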

In [None]:
class Unembed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(t.empty((cfg.d_model, cfg.d_vocab)))
        nn.init.normal_(self.W_U, std=self.cfg.init_range)
        self.b_U = nn.Parameter(t.zeros((cfg.d_vocab), requires_grad=False))

    def forward(
        self, normalized_resid_final: Float[Tensor, "batch position d_model"]
    ) -> Float[Tensor, "batch position d_vocab"]:
        # raise NotImplementedError()

        return einops.einsum(normalized_resid_final, self.W_U, 'batch pos d_model, d_model d_vocab -> batch pos d_vocab') + self.b_U


rand_float_test(Unembed, [2, 4, 768])
load_gpt2_test(Unembed, reference_gpt2.unembed, cache["ln_final.hook_normalized"], 7)

<details><summary>solution</summary>

```python
class Unembed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(t.empty((cfg.d_model, cfg.d_vocab)))
        nn.init.normal_(self.W_U, std=self.cfg.init_range)
        self.b_U = nn.Parameter(t.zeros((cfg.d_vocab), requires_grad=False))

    def forward(
        self, normalized_resid_final: Float[Tensor, "batch position d_model"]
    ) -> Float[Tensor, "batch position d_vocab"]:
        return (
            einops.einsum(
                normalized_resid_final,
                self.W_U,
                "batch posn d_model, d_model d_vocab -> batch posn d_vocab",
            )
            + self.b_U
        )
```
</details>

### exercise - implement `DemoTransformer`

> ```yaml
> difficulty: 🔴🔴⚪⚪⚪
> importance: 🔵🔵🔵⚪⚪
>
> you should spend up to 10-15 minutes on this exercise.
> ```

In [None]:
class DemoTransformer(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_vocab"]:
        # raise NotImplementedError()

        residual = self.embed(tokens) + self.pos_embed(tokens)

        for block in self.blocks:
            residual = block(residual)

        logits = self.unembed(self.ln_final(residual))

        return logits


rand_int_test(DemoTransformer, [2, 4])
load_gpt2_test(DemoTransformer, reference_gpt2, tokens, 8)

<details><summary>solution</summary>

```python
class DemoTransformer(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_vocab"]:
        residual = self.embed(tokens) + self.pos_embed(tokens)
        for block in self.blocks:
            residual = block(residual)
        logits = self.unembed(self.ln_final(residual))
        return logits
```
</details>

**try it out!**

In [None]:
demo_gpt2 = DemoTransformer(Config(debug=False)).to(device)
demo_gpt2.load_state_dict(reference_gpt2.state_dict(), strict=False)

demo_logits = demo_gpt2(tokens)

let's take a test string, and calculate the loss!

we're using the formula for **cross-entropy loss**. The cross entropy loss between a modelled distribution $Q$ and target distribution $P$ is:

$$
-\sum_x P(x) \log Q(x)
$$

In the case where $P$ is just the empirical distribution from target classes (i.e. $P(x^*) = 1$ for the correct class $x^*$) then this becomes:

$$
-\log Q(x^*)
$$

in other words, the negative log prob of the true classification.
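a quick numerical check on a toy example: with a one-hot target $P$, the full sum $-\sum_x P(x) \log Q(x)$, the single term $-\log Q(x^*)$, and pytorch's `F.cross_entropy` (which takes raw logits) should all agree. `log_softmax` is used rather than `softmax().log()` because it's more numerically stable.

```python
import torch as t
import torch.nn.functional as F

logits = t.tensor([[2.0, 0.5, -1.0]])  # one prediction over a toy 3-token vocabulary
target = t.tensor([0])                  # index of the correct class x*

log_Q = logits.log_softmax(dim=-1)            # log Q(x) for every x
P = F.one_hot(target, num_classes=3).float()  # one-hot target distribution
full_sum = -(P * log_Q).sum()                 # -sum_x P(x) log Q(x)
single_term = -log_Q[0, target[0]]            # -log Q(x*)
builtin = F.cross_entropy(logits, target)     # expects raw logits, not probabilities

print(full_sum.item(), single_term.item(), builtin.item())
```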

In [None]:
def get_log_probs(
    logits: Float[Tensor, "batch posn d_vocab"], tokens: Int[Tensor, "batch posn"]
) -> Float[Tensor, "batch posn-1"]:
    log_probs = logits.log_softmax(dim=-1) # [batch posn d_vocab]
    # get log probs for the first seq_len-1 predictions (so we can compare them with the actual next tokens)
    log_probs_for_tokens = log_probs[:, :-1].gather(dim=-1, index=tokens[:, 1:].unsqueeze(-1)).squeeze(-1) # [batch len(tokens)-1]

    return log_probs_for_tokens


pred_log_probs = get_log_probs(demo_logits, tokens)
print(f"Avg cross entropy loss: {-pred_log_probs.mean():.4f}")
print(f"Avg cross entropy loss for uniform distribution: {math.log(demo_gpt2.cfg.d_vocab):.4f}")
print(f"Avg probability assigned to correct token: {pred_log_probs.exp().mean():.4f}")

we can also greedily generate text, by taking the most likely next token and continually appending it to our prompt before feeding it back into the model:

In [None]:
test_string = """The Total Perspective Vortex derives its picture of the whole Universe on the principle of"""
for i in tqdm(range(100)):
    test_tokens = reference_gpt2.to_tokens(test_string).to(device)
    demo_logits = demo_gpt2(test_tokens)
    test_string += reference_gpt2.tokenizer.decode(demo_logits[-1, -1].argmax())

print(test_string)

in section 4️⃣ we'll learn to generate text in slightly more interesting ways than just argmaxing the output (which can lead to unnatural patterns like repetition, as you can see above).
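the greedy loop above decodes each new token to a string and re-tokenizes the whole prompt on every step (and decoding then re-encoding isn't always guaranteed to give back the same token ids). below is a minimal sketch of the same greedy procedure working directly on token ids. `greedy_generate` and the toy model are hypothetical; the sketch assumes `model` maps a `(batch, seq)` tensor of ids to `(batch, seq, d_vocab)` logits, as `DemoTransformer` does. it doesn't truncate the input to the model's context length, so it's only suitable for short generations.

```python
import torch as t
import torch.nn.functional as F


def greedy_generate(model, tokens: t.Tensor, n_new: int) -> t.Tensor:
    """Append the argmax next token n_new times, without leaving token-id space."""
    with t.inference_mode():
        for _ in range(n_new):
            logits = model(tokens)                                   # (batch, seq, d_vocab)
            next_token = logits[:, -1].argmax(dim=-1, keepdim=True)  # (batch, 1)
            tokens = t.cat([tokens, next_token], dim=-1)
    return tokens


V = 10  # toy vocabulary size


def toy_model(tokens: t.Tensor) -> t.Tensor:
    # hypothetical model: puts all its logit mass on (current token + 1) mod V
    return F.one_hot((tokens + 1) % V, num_classes=V).float()


print(greedy_generate(toy_model, t.tensor([[7]]), n_new=4))  # tensor([[7, 8, 9, 0, 1]])
```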

# 3️⃣ training a transformer

##### learning objectives

- understand how to train a transformer from scratch
- write a basic transformer training loop
- interpret the transformer's falling cross entropy loss with reference to features of the training data (e.g. bigram frequencies)

now that we've built our transformer, and verified that it performs as expected when we load in weights, let's try training it from scratch!

this is a lightweight demonstration of how you can actually train your own gpt-2 with this code! here we train a tiny model on a tiny dataset, but it's fundamentally the same code for training a larger/more real model (though you'll need beefier gpus and data parallelism to do it remotely efficiently, and fancier parallelism for much bigger ones).

for our purposes, we'll train a 2 layer, 4 heads per layer model, with context length 256, for 20 epochs of at most 200 steps each (batch size 16), just to show what it looks like (and so the notebook doesn't melt your colab / machine!).

## create model

In [None]:
model_cfg = Config(
    debug=False,
    d_model=256,
    n_heads=4,
    d_head=64,
    d_mlp=1024,
    n_layers=2,
    n_ctx=256,
    d_vocab=reference_gpt2.cfg.d_vocab,
)
model = DemoTransformer(model_cfg)

## training args


note, for this optimization we'll be using **weight decay**.
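adamw applies weight decay *decoupled* from the gradient: each step it shrinks every parameter by a factor of $(1 - \eta\lambda)$, where $\eta$ is `lr` and $\lambda$ is `weight_decay`, separately from the adam update. plain `Adam` with `weight_decay` instead adds $\lambda p$ to the gradient, where it then gets rescaled by adam's per-parameter normalization. the sketch below (toy values `lr=0.1`, `weight_decay=0.5`, chosen to make the effect visible) takes one step with a zero loss gradient, so any change in the parameter comes from weight decay alone.

```python
import torch as t


def param_after_one_step(opt_cls, lr: float = 0.1, weight_decay: float = 0.5) -> float:
    """Take one optimizer step on a scalar parameter p = 1.0 whose loss gradient is zero."""
    p = t.nn.Parameter(t.tensor([1.0]))
    opt = opt_cls([p], lr=lr, weight_decay=weight_decay)
    p.grad = t.zeros_like(p)  # a zero (not None) gradient, so the optimizer still updates p
    opt.step()
    return p.item()


print(param_after_one_step(t.optim.AdamW))  # about 0.95, i.e. 1 * (1 - lr * weight_decay)
print(param_after_one_step(t.optim.Adam))   # about 0.90: the decay enters the gradient and is rescaled
```

with the values used in this notebook (`lr=1e-3`, `weight_decay=1e-2`), the decoupled shrinkage is a factor of $1 - 10^{-5}$ per step. the trainer passes `model.parameters()` directly, so the decay applies to every parameter, including biases and layernorm weights.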

In [None]:
@dataclass
class TransformerTrainingArgs:
    batch_size: int = 16
    epochs: int = 20
    max_steps_per_epoch: int = 200
    lr: float = 1e-3
    weight_decay: float = 1e-2
    wandb_project: str | None = "day1-demotransformer"
    wandb_name: str | None = None


args = TransformerTrainingArgs()

## create data

we load in a tiny dataset made by nanda, with the first 10K entries in the 'Pile' (inspired by Stas' version for OpenWebText!)

In [None]:
dataset = datasets.load_dataset("NeelNanda/pile-10k", split="train").remove_columns("meta")
print(dataset)
print(dataset[0]["text"][:100])

`tokenize_and_concatenate` is a useful function which takes our dataset of strings, and returns a dataset of token ids ready to feed into the model. we then create a dataloader from this tokenized dataset. the useful method `train_test_split` can give us a training and testing set.
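roughly speaking, such a function concatenates all documents into one token stream (with an end-of-text token between them), cuts that stream into rows of fixed length, and, with `add_bos_token=True`, starts each row with a bos token. the sketch below is a simplified, hypothetical version on pre-tokenized toy documents (`concat_and_chunk` is our name, not a transformer_lens function), meant to show the shape of the result rather than the library's exact implementation. for gpt-2, the bos and end-of-text tokens share the same id.

```python
import torch as t


def concat_and_chunk(docs: list[list[int]], max_length: int, bos_id: int, eos_id: int) -> t.Tensor:
    """Hypothetical: join docs into one stream, cut into equal rows, prepend BOS to each row."""
    stream: list[int] = []
    for doc in docs:
        stream.extend(doc)
        stream.append(eos_id)  # mark the document boundary
    seq_len = max_length - 1  # leave one slot per row for BOS
    n_rows = len(stream) // seq_len  # leftover tokens that can't fill a row are dropped
    body = t.tensor(stream[: n_rows * seq_len], dtype=t.long).reshape(n_rows, seq_len)
    bos = t.full((n_rows, 1), bos_id, dtype=t.long)
    return t.cat([bos, body], dim=1)


rows = concat_and_chunk([[1, 2, 3], [4, 5]], max_length=4, bos_id=9, eos_id=0)
print(rows)  # tensor([[9, 1, 2, 3], [9, 0, 4, 5]])
```

note that the second row starts mid-stream: rows don't respect document boundaries, and a trailing end-of-text token that didn't fill a whole row was dropped.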

In [None]:
tokenized_dataset = tokenize_and_concatenate(
    dataset,
    reference_gpt2.tokenizer,
    streaming=False,
    max_length=model.cfg.n_ctx,
    column_name="text",
    add_bos_token=True,
    num_proc=4,
)

dataset_dict = tokenized_dataset.train_test_split(test_size=1000)
train_loader = DataLoader(
    dataset_dict["train"], batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True
)
test_loader = DataLoader(
    dataset_dict["test"], batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True
)

when we iterate through these dataloaders, we will find dictionaries with the single key `'tokens'`, which maps to a tensor of token ids with shape `(batch, seq_len)`.

In [None]:
first_batch = train_loader.dataset[: args.batch_size]

print(first_batch.keys())
print(first_batch["tokens"].shape)

## training loop



note: if you did the material on [training loops](https://arena-ch0-fundamentals.streamlit.app/[0.3]_ResNets#training-loop) during the first week, this should all be familiar to you. if not, you can skim that section for an overview of the key concepts. the start of the **training loop** section is most important, and the subsections on [modularisation](https://arena-ch0-fundamentals.streamlit.app/[0.3]_ResNets#modularisation) and [dataclasses](https://arena-ch0-fundamentals.streamlit.app/[0.3]_ResNets#aside-dataclasses) are also very useful. lastly, we'll also be using 'Weights and Biases' (w&b) to train our model - you can read about how to use it [here](https://arena-ch0-fundamentals.streamlit.app/[0.4]_Optimization#what-is-weights-and-biases).

here are (roughly) all the things you should know for the following exercises:
- the key parts of a gradient update step are:
  * calculating the (cross-entropy) loss between a model's output and the true labels,
  * `loss.backward()` - calculate gradients of the loss with respect to the model parameters,
  * `optimizer.step()` - update the model parameters using the gradients,
  * `optimizer.zero_grad()` - zero the gradients so they don't accumulate.
- we can nicely package up training loops into a class, which includes methods for training and validation steps among other things. this helps with writing code that can be reused in different contexts.
- we can use dataclasses to store all the arguments relevant to training in one place, and then pass them to our trainer class. autocompletion is one nice bonus of this!
- be careful of scope here, you want to make sure you're referring to `self.args` within the trainer class, rather than the global `args`.
- you can use w&b to track experiments and log relevant variables. the three essential functions are:
  * `wandb.init()` - initialize a new run, takes arguments `project`, `name` and `config` (among others).
  * `wandb.log()` - log a dictionary of variables, e.g. `{"loss": loss}`. also takes a `step` argument.
  * `wandb.finish()` - called at the end of training (no arguments).
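the four steps of a gradient update, shown on a hypothetical toy classifier (a single `nn.Linear` fit to random data) rather than on our transformer:

```python
import torch as t
import torch.nn as nn
import torch.nn.functional as F

t.manual_seed(0)
toy_clf = nn.Linear(4, 3)  # hypothetical toy classifier: 4 features -> 3 classes
toy_opt = t.optim.AdamW(toy_clf.parameters(), lr=1e-2, weight_decay=1e-2)
x = t.randn(32, 4)
y = t.randint(0, 3, (32,))

losses = []
for step in range(100):
    loss = F.cross_entropy(toy_clf(x), y)  # 1. loss between model output and true labels
    loss.backward()                        # 2. gradients of the loss w.r.t. the parameters
    toy_opt.step()                         # 3. update the parameters using the gradients
    toy_opt.zero_grad()                    # 4. reset gradients so they don't accumulate
    losses.append(loss.item())

print(f"first loss: {losses[0]:.3f}, last loss: {losses[-1]:.3f}")
```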

### exercise - write training loop

> ```yaml
> difficulty: 🔴🔴🔴⚪⚪
> importance: 🔵🔵🔵🔵⚪
>
> you should spend up to 10-20 minutes on this exercise.
> ```

you should fill in the methods below. some guidance:

- remember we were able to calculate cross entropy loss using the `get_log_probs` function in the previous section
- you should use the optimizer `t.optim.AdamW` (adam with weight decay), and with hyperparameters `lr` and `weight_decay` taken from your `TransformerTrainingArgs` dataclass instance
- we've given you the argument `max_steps_per_epoch`, a hacky way of making sure the training phase in each epoch doesn't go on for too long. you can terminate each training phase after this many steps. it's set to a default value that should lead to a very short run demonstrating nontrivial model performance
- remember to move tokens to your device, via `tokens.to(device)` (this should be a global variable, defined at the top of your notebook)
- you can refer back to the training loops from the previous chapter of the course - see assignment 3, section on resnets and model training
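if you'd rather not count steps by hand, one option is `itertools.islice`, which stops iterating after a fixed number of batches and creates a fresh dataloader iterator each epoch. a toy sketch with a dummy dataset (all names here are ours):

```python
import itertools

import torch as t
from torch.utils.data import DataLoader, TensorDataset

loader = DataLoader(TensorDataset(t.arange(100)), batch_size=8, shuffle=True)  # 13 batches in total
max_steps_per_epoch = 5

steps_taken = 0
for epoch in range(2):
    for (batch,) in itertools.islice(loader, max_steps_per_epoch):
        steps_taken += 1  # a real loop would call the training step on `batch` here

print(steps_taken)  # 10: 5 steps in each of 2 epochs
```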

In [None]:
class TransformerTrainer:
    def __init__(self, args: TransformerTrainingArgs, model: DemoTransformer):
        super().__init__()
        self.model = model
        self.args = args

        self.optimizer = t.optim.AdamW(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        self.step = 0

        self.train_loader = DataLoader(
            dataset_dict["train"], batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True
        )
        self.test_loader = DataLoader(
            dataset_dict["test"], batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True
        )

    def training_step(self, batch: dict[str, Int[Tensor, "batch seq"]]) -> Float[Tensor, ""]:
        """
        Calculates the loss on the tokens in the batch, performs a gradient update step, and logs the loss.

        Remember that `batch` is a dictionary with the single key 'tokens'.
        """
        # raise NotImplementedError()

        tokens = batch["tokens"].to(device)
        logits = self.model(tokens)

        loss = -get_log_probs(logits, tokens).mean()
        loss.backward()

        self.optimizer.step()
        self.optimizer.zero_grad()

        self.step += 1
        wandb.log({"train_loss": loss}, step=self.step)

        return loss

    @t.inference_mode()
    def evaluate(self) -> float:
        """
        Evaluate the model on the test set and return the accuracy.
        """
        # raise NotImplementedError()

        self.model.eval()
        total_correct, total_samples = 0, 0

        for batch in tqdm(self.test_loader, desc="Evaluating"):
            tokens = batch["tokens"].to(device)
            logits: Tensor = self.model(tokens)[:, :-1]
            predicted_tokens = logits.argmax(dim=-1)
            total_correct += (predicted_tokens == tokens[:, 1:]).sum().item()
            total_samples += tokens.size(0) * (tokens.size(1) - 1)

        accuracy = total_correct / total_samples
        wandb.log({"accuracy": accuracy}, step=self.step)

        return accuracy

    def train(self):
        """
        Trains the model, for `self.args.epochs` epochs. Also handles wandb initialisation, and early stopping
        for each epoch at `self.args.max_steps_per_epoch` steps.
        """
        wandb.init(project=self.args.wandb_project, name=self.args.wandb_name, config=self.args)
        accuracy = np.nan

        progress_bar = tqdm(total=self.args.max_steps_per_epoch * self.args.epochs)

        for epoch in range(self.args.epochs):
            for i, batch in enumerate(self.train_loader):
                loss = self.training_step(batch)
                progress_bar.update()
                progress_bar.set_description(f"Epoch {epoch+1}, loss: {loss:.3f}, accuracy: {accuracy:.3f}")
                if i + 1 >= self.args.max_steps_per_epoch:
                    break

            accuracy = self.evaluate()

        wandb.finish()


# See the full run here: https://api.wandb.ai/links/callum-mcdougall/4xtin05h
model = DemoTransformer(model_cfg).to(device)
args = TransformerTrainingArgs()
trainer = TransformerTrainer(args, model)
trainer.train()

<details><summary>solution</summary>

```python
class TransformerTrainer:
    def __init__(self, args: TransformerTrainingArgs, model: DemoTransformer):
        super().__init__()
        self.model = model
        self.args = args

        self.optimizer = t.optim.AdamW(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        self.step = 0

        self.train_loader = DataLoader(
            dataset_dict["train"], batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True
        )
        self.test_loader = DataLoader(
            dataset_dict["test"], batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True
        )

    def training_step(self, batch: dict[str, Int[Tensor, "batch seq"]]) -> Float[Tensor, ""]:
        """
        Calculates the loss on the tokens in the batch, performs a gradient update step, and logs the loss.

        Remember that `batch` is a dictionary with the single key 'tokens'.
        """
        tokens = batch["tokens"].to(device)
        logits = self.model(tokens)
        loss = -get_log_probs(logits, tokens).mean()
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.step += 1
        wandb.log({"train_loss": loss}, step=self.step)
        return loss

    @t.inference_mode()
    def evaluate(self) -> float:
        """
        Evaluate the model on the test set and return the accuracy.
        """
        self.model.eval()
        total_correct, total_samples = 0, 0

        for batch in tqdm(self.test_loader, desc="Evaluating"):
            tokens = batch["tokens"].to(device)
            logits: Tensor = self.model(tokens)[:, :-1]
            predicted_tokens = logits.argmax(dim=-1)
            total_correct += (predicted_tokens == tokens[:, 1:]).sum().item()
            total_samples += tokens.size(0) * (tokens.size(1) - 1)

        accuracy = total_correct / total_samples
        wandb.log({"accuracy": accuracy}, step=self.step)
        return accuracy

    def train(self):
        """
        Trains the model, for `self.args.epochs` epochs. Also handles wandb initialisation, and early stopping
        for each epoch at `self.args.max_steps_per_epoch` steps.
        """
        wandb.init(project=self.args.wandb_project, name=self.args.wandb_name, config=self.args)
        accuracy = np.nan

        progress_bar = tqdm(total=self.args.max_steps_per_epoch * self.args.epochs)

        for epoch in range(self.args.epochs):
            for i, batch in enumerate(self.train_loader):
                loss = self.training_step(batch)
                progress_bar.update()
                progress_bar.set_description(f"Epoch {epoch+1}, loss: {loss:.3f}, accuracy: {accuracy:.3f}")
                if i + 1 >= self.args.max_steps_per_epoch:
                    break

            accuracy = self.evaluate()

        wandb.finish()
```
</details>

<!-- Note - this section of the course used to use PyTorch Lightning, but this has now been taken out. If you want, you can look at the old version of the training code which used PyTorch Lightning in the dropdown below.

<details>
<summary>PyTorch Lighting training loop</summary>

```python
class LitTransformer(pl.LightningModule):
	def __init__(self, args: TransformerTrainingArgs, model: DemoTransformer, data_loader: DataLoader):
		super().__init__()
		self.model = model
		self.cfg = model.cfg
		self.args = args
		self.data_loader = data_loader

	def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_vocab"]:
		logits = self.model(tokens)
		return logits

	def training_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Float[Tensor, ""]:
		'''
		Here you compute and return the training loss and some additional metrics for e.g.
		the progress bar or logger.
		'''
		tokens = batch["tokens"].to(device)
		logits = self.model(tokens)
		loss = -get_log_probs(logits, tokens).mean()
		self.log("train_loss", loss)
		return loss

	def configure_optimizers(self):
		'''
		Choose what optimizers and learning-rate schedulers to use in your optimization.
		'''
		optimizer = t.optim.AdamW(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
		return optimizer

	def train_dataloader(self):
		return self.data_loader


litmodel = LitTransformer(args, model, data_loader)
logger = WandbLogger(save_dir=args.log_dir, project=args.log_name, name=args.run_name)

trainer = pl.Trainer(
    max_epochs=args.max_epochs,
    logger=logger,
    log_every_n_steps=args.log_every_n_steps
)
trainer.fit(model=litmodel, train_dataloaders=litmodel.data_loader)
wandb.finish()
```

</details>

<details>
<summary>Explanation for why PyTorch Lightning is no longer used</summary>

TLDR - it provides nice modularization and saving of code, but it abstracts away a lot of the details of training loops, and so isn't very useful for educational purposes. Also, it imposes a lot of structure on how the training loops work without allowing for much flexibility, and lots of the code we'll write later (e.g. linear probes or RL) doesn't fit well into this framework. However, it can be a very useful tool to learn about once you've got the basics down and you're looking to benefit from the suite of extra features it provides.

</details> -->

when you run the code for the first time, you'll have to log in to w&b and paste an api key. after this is done, your w&b training run will start. it'll give you a lot of output text, one line of which will look like:

```
View run at https://wandb.ai/<USERNAME>/<PROJECT-NAME>/runs/<RUN-NAME>
```

which you can click on to visit the run page.

note: to see the plots more clearly in w&b, you can click on the **edit panel** of your plot (the small pencil symbol at the top-right), then move the **smoothing** slider to the right.

### a note on this loss curve (optional)


what's up with the shape of our loss curve? it starts at around 10-11, drops very fast, and then levels out. it turns out this is all to do with the kinds of algorithms the model learns during training.



when it starts out, your model will be outputting random noise, which might look a lot like "predict each token with approximately uniform probability", i.e. $Q(x) = 1/d_\text{vocab}$ for all $x$. this gives us a cross entropy loss of $\log (d_\text{vocab})$.

In [None]:
d_vocab = model.cfg.d_vocab

print(f"d_vocab = {d_vocab}")
print(f"cross entropy loss on uniform distribution = {math.log(d_vocab):.4f}")

the next thing we might expect the model to learn is the frequencies of words in the english language. after all, small common tokens like `" and"` or `" the"` might appear much more frequently than others. this would give us an average cross entropy loss of:

$$
- \sum_x p_x \log p_x
$$

where $p_x$ is the actual frequency of token $x$ in our training data.
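on a toy token stream, the formula can be checked directly against `t.distributions.Categorical(...).entropy()`, the method used on the real data below. tokens that never appear have $p_x = 0$ and should contribute nothing; `t.special.xlogy(p, p)` computes $p \log p$ with $0 \log 0$ treated as $0$.

```python
import torch as t

toy_tokens = t.tensor([0, 0, 1, 1, 1, 3])  # token 2 never appears
toy_vocab = 4
probs = t.bincount(toy_tokens, minlength=toy_vocab).float() / toy_tokens.numel()

# -sum_x p_x log p_x, with 0 * log 0 treated as 0
manual = -t.special.xlogy(probs, probs).sum()
builtin = t.distributions.Categorical(probs=probs).entropy()
print(manual.item(), builtin.item())
```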

we can evaluate this quantity as follows:

In [None]:
toks = tokenized_dataset[:]["tokens"].flatten()

d_vocab = model.cfg.d_vocab
freqs = t.bincount(toks, minlength=d_vocab)
probs = freqs.float() / freqs.sum()

distn = t.distributions.categorical.Categorical(probs=probs)
entropy = distn.entropy()

print(f"entropy of training data = {entropy:.3f}")

after unigram frequencies, the next thing our model usually learns is **bigram frequencies** (i.e. the frequency of pairs of adjacent tokens in the training data). for instance, `"I"` and `" am"` are common tokens, but their bigram frequency is much higher than it would be if they occurred independently. bigram frequencies actually take you pretty far, since they also help with:

* some simple grammatical rules (e.g. a full stop being followed by a capitalized word)
* weird quirks of tokenization (e.g. `" manip"` being followed by `"ulative"`)
* common names (e.g. `"Barack"` being followed by `" Obama"`)
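to put a number on how far bigrams get you, one could compute the conditional entropy of the next token given the current one, $-\sum_{x,y} p(x,y)\log p(y \mid x)$. this is the in-sample loss an ideal bigram lookup table would get on the same data it was counted from, so it's an optimistic figure. the sketch below is a hypothetical helper, not part of the notebook. it only counts pairs within each row (so chunk boundaries don't create spurious bigrams), and it uses `t.unique` on packed pair ids instead of building a full $d_\text{vocab} \times d_\text{vocab}$ table, which for gpt-2's vocabulary would have about 2.5 billion entries.

```python
import math

import torch as t


def bigram_conditional_entropy(tokens: t.Tensor, d_vocab: int) -> float:
    """Hypothetical helper: H(next | current) in nats, from an (n_rows, seq_len) tensor of token ids."""
    prev = tokens[:, :-1].flatten().long()
    nxt = tokens[:, 1:].flatten().long()
    pair_ids = prev * d_vocab + nxt  # pack each (prev, next) pair into a single int64 id
    unique_pairs, pair_counts = t.unique(pair_ids, return_counts=True)
    prev_counts = t.bincount(prev, minlength=d_vocab)  # how often each token starts a pair
    prev_of_pair = t.div(unique_pairs, d_vocab, rounding_mode="floor")
    p_joint = pair_counts.double() / pair_counts.sum()
    p_next_given_prev = pair_counts.double() / prev_counts[prev_of_pair].double()
    return -(p_joint * p_next_given_prev.log()).sum().item()


print(bigram_conditional_entropy(t.tensor([[0, 1, 0, 1, 0, 1]]), d_vocab=2))  # ~0: next token fully determined
print(bigram_conditional_entropy(t.tensor([[0, 0], [0, 1]]), d_vocab=2))      # log(2), about 0.693
```

on the notebook's data, the call would be something like `bigram_conditional_entropy(tokenized_dataset[:]["tokens"], d_vocab)`, for comparison with the unigram entropy computed above.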


after approximating bigram frequencies, we need to start using smarter techniques, like trigrams (which can only be implemented using attention heads), **induction heads** (which we'll learn a lot more about in the next set of exercises!), and fact memorization or more basic grammar and syntax rules. marginal improvements start getting harder around this point, leading to a flattening of our loss curve.

### Exercise (optional) - log completions

> ```yaml
> Difficulty: 🔴🔴🔴🔴⚪
> Importance: 🔵⚪⚪⚪⚪
>
> You should spend up to 20-40 minutes on this exercise, if you choose to attempt it.
> Note, you might want to come back to this exercise *after* you learn how sampling works.
> ```



Choose a handful of prompts, and log the model's completions on those sentences. We recommend you do this less frequently than you log the loss (e.g. once every 10-100 batches).

The `wandb` syntax for logging text is pretty simple. Firstly, you can just print output to stdout, which is also logged to Weights & Biases (you can find it under the "Logs" section of your run). Alternatively, you can log data in the form of a table, and have it appear next to your other charts:

```python
wandb.log({"completions_table": wandb.Table(
    data = data,
    columns = ["epoch", "step", "text"]
)})
```

where `data` is a list of length-3 lists, with each list containing (epoch, step, text). If you choose this option, we recommend logging the table less frequently than you're sampling from the model, to make sure you're not sending too much data (because unfortunately wandb doesn't have methods to incrementally update the table during logging).

If you want to try this before going through the sampling exercises (which are quite long!), you can use the code below to sample output from the model. Note that the `TransformerSampler` object is already in inference mode, so you don't need to worry about this.

In [None]:
def sampling_fn(model: DemoTransformer, prompt: str) -> str:
    sampler = solutions.TransformerSampler(model, reference_gpt2.tokenizer)
    output = sampler.sample(prompt, temperature=0.7, top_p=0.95, max_tokens_generated=16)
    return output


model = DemoTransformer(model_cfg).to(device)

# Should be entirely random, because it uses a newly initialized model
print(sampling_fn(model, prompt="John and Mary went to the"))

In [None]:
# YOUR CODE HERE - rewrite the TransformerTrainer.train method, so that it logs completions


prompt_list = [
    "Eliezer Shlomo Yudkowsky (born September 11, 1979) is an American decision and artificial intelligence (AI) theorist and writer, best known for",
    "In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English.",
    "John and Mary went to the",
]

model = DemoTransformer(model_cfg).to(device)
args = TransformerTrainingArgsLogText()
trainer = TransformerTrainer(args, model)
trainer.train(sampling_fn, prompt_list)
# Read full report here - https://api.wandb.ai/links/callum-mcdougall/5ex16e5w

<details><summary>Solution</summary>

```python
@dataclass
class TransformerTrainingArgsLogText(TransformerTrainingArgs):
    text_sample_freq: int = 20
    table_log_freq: int = 200

    def __post_init__(self):
        assert (
            self.table_log_freq >= self.text_sample_freq
        ), "You should log the table less frequently than you add text to it."


def train_log_text(self: TransformerTrainer, sampling_fn: Callable, prompt_list: list[str]):
    """
    Trains the model, for `self.args.epochs` epochs. Also handles wandb initialisation, and early stopping
    for each epoch at `self.args.max_steps_per_epoch` steps.

    This also takes 2 extra arguments:
        sampling_fn: function which takes model & a single prompt (i.e. text string) and returns text string output
        prompt_list: list of prompts we'll log output on
    """
    wandb.init(project=self.args.wandb_project, name=self.args.wandb_name, config=self.args)
    accuracy = np.nan
    progress_bar = tqdm(total=self.args.max_steps_per_epoch * self.args.epochs)

    # Create a list for storing data
    completions_list = []

    for epoch in range(self.args.epochs):
        for i, batch in enumerate(self.train_loader()):
            loss = self.training_step(batch)
            progress_bar.update()
            progress_bar.set_description(f"Epoch {epoch+1}, loss: {loss:.3f}, accuracy: {accuracy:.3f}")

            # Control the adding of text to the table, and the logging of text
            if self.step % self.args.text_sample_freq == 0:
                text_completions = [sampling_fn(self.model, prompt) for prompt in prompt_list]
                completions_list.append([epoch, self.step, *text_completions])
            if self.step % self.args.table_log_freq == 0:
                wandb.log(
                    {
                        "completions_table": wandb.Table(
                            data=completions_list,
                            columns=["epoch", "step", *[f"prompt_{i}" for i in range(len(prompt_list))]],
                        )
                    }
                )

            if i >= self.args.max_steps_per_epoch:
                break

        accuracy = self.evaluate()

    wandb.finish()


TransformerTrainer.train = train_log_text


prompt_list = [
    "Eliezer Shlomo Yudkowsky (born September 11, 1979) is an American decision and artificial intelligence (AI) theorist and writer, best known for",
    "In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English.",
    "John and Mary went to the",
]

model = DemoTransformer(model_cfg).to(device)
args = TransformerTrainingArgsLogText()
trainer = TransformerTrainer(args, model)
trainer.train(sampling_fn, prompt_list)
# Read full report here - https://api.wandb.ai/links/callum-mcdougall/5ex16e5w
```
</details>

You shouldn't expect to see perfect logical coherence from your model, but you should at least see that it respects basic word frequencies, and follows basic rules of grammar some of the time. Hopefully this gives some perspective on how difficult training a transformer can be!

# 4️⃣ sampling from a transformer

##### learning objectives

- learn how to sample from a transformer
  * this includes basic methods like greedy search or top-k, and more advanced methods like beam search
- learn how to cache the output of a transformer, so that it can be used to generate text more efficiently
  * a caching system can reuse computations from previous forward passes to improve the model's text generation speed.
  * optionally, rewrite the sampling functions to make use of caching methods

let's discuss how we might go about producing output from a transformer.

one obvious method to sample tokens from a distribution would be to always take the token assigned the highest probability. but this can lead to some boring and repetitive outcomes, and at worst it can lock our transformer's output into a loop.

first, you should read huggingface's blog post [how to generate text: using different decoding methods for language generation](https://huggingface.co/blog/how-to-generate). once you've done that, you can start the exercises below.

## `TransformerSampler` class

below, we've given you the `TransformerSampler` class. this contains the following important methods:

- `sample`, which is the highest-level method. it repeatedly calls `sample_next_token` to generate new tokens, until one of the termination criteria is met.
- `sample_next_token`, which samples a single new token based on some hyperparameters. this might involve various different sampling methods and techniques e.g. temperature scaling, top-k sampling, top-p sampling, etc.
- a set of other methods, which apply the previously mentioned sampling methods and techniques.

you can see how `sample_next_token` works, and as an example how greedy sampling is implemented via `greedy_search` - we just take the token with the highest logit at each step.

in the next exercise you'll implement the `sample` method, and then you'll go on to implement all the other methods.

<details>
<summary>question - why do you think <code>temperature=0.0</code> corresponds to greedy sampling?</summary>

to apply a temperature to our sampling (as we'll see later) means to scale all logits by `(1 / temperature)`. The basic intuition here is:

* a higher temperature means a smaller scale factor, so the logits all approach zero, i.e. the distribution approaches uniform, and the sampling process is a lot more random (producing more diverse and varied outputs)
* a lower temperature means a larger scale factor, so the gaps between the logits grow without bound, i.e. the distribution approaches a point mass (a dirac delta) on the highest logit, and the sampling process is a lot more deterministic (producing less varied output)

as temperature gets close to zero, the difference between the largest logit and second largest logit becomes very large, so the distribution tends to "probability of 1 on the highest-likelihood token", i.e. greedy sampling. you can derive this formally if you prefer.
</details>
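to see this numerically, here's a small sketch with arbitrary made-up logits, printing `softmax(logits / temperature)` for a few temperatures: high temperatures push it toward uniform, low temperatures push all the mass onto the largest logit.

```python
import torch as t

logits = t.tensor([2.0, 1.0, 0.5, -1.0])  # arbitrary example logits

for temperature in [100.0, 1.0, 0.1, 0.01]:
    # temperature scaling = divide the logits by the temperature, then softmax
    probs = (logits / temperature).softmax(dim=-1)
    print(f"T={temperature:>6}: {[round(p, 4) for p in probs.tolist()]}")
```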

### exercise - implement `sample`

> ```yaml
> difficulty: 🔴🔴🔴🔴⚪
> importance: 🔵🔵🔵⚪⚪
>
> you should spend up to 25-40 minutes on this exercise.
> ```

the `sample` method generates new tokens autoregressively, by repeatedly:

- passing the current sequence of tokens through the model to get logits,
- using some sampling technique to select a new token, i.e. `sample_next_token(input_ids, logits, **kwargs)`,
- appending this new token to the input sequence,
- repeating the process until one of the termination criteria is met: either we generate `max_tokens_generated` new tokens, or we generate the end-of-sequence token (which we can access via `self.tokenizer.eos_token_id`).

lastly, we use the `tokenizer.decode` method to return the sampled string. you're also invited to use the `verbose` argument, for printing the decoded sequences while they're being generated (this can help with debugging).
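to make the loop concrete, here's a minimal sketch of greedy autoregressive generation using a hypothetical `toy_model` (a fixed bigram lookup table standing in for `DemoTransformer`, so no tokenizer or real weights are involved). it shows the shape handling (adding a batch dim, taking the last position's logits), wrapping the chosen int in a tensor before concatenating, and stopping early at an end-of-sequence id. it deliberately omits things the real method needs, like truncating to the model's context length and decoding back to text.

```python
import torch as t

EOS_ID = 3  # hypothetical end-of-sequence id for this toy vocab of size 4

# hypothetical stand-in for the transformer: row i holds the logits for the token after token i
bigram_logits = t.tensor(
    [
        [0.0, 5.0, 0.0, 0.0],  # after token 0, token 1 is most likely
        [0.0, 0.0, 5.0, 0.0],  # after token 1, token 2
        [0.0, 0.0, 0.0, 5.0],  # after token 2, EOS
        [5.0, 0.0, 0.0, 0.0],  # after EOS (never reached below)
    ]
)


def toy_model(input_ids: t.Tensor) -> t.Tensor:
    """Maps (batch, seq_len) token ids to (batch, seq_len, d_vocab) logits."""
    return bigram_logits[input_ids]


def greedy_generate(input_ids: t.Tensor, max_tokens_generated: int, eos_token_id: int) -> t.Tensor:
    """input_ids is 1D; the model is always called with a batch dimension of 1."""
    for _ in range(max_tokens_generated):
        logits = toy_model(input_ids[None, :])[0, -1]  # logits at the last position, shape (d_vocab,)
        next_token = logits.argmax().item()  # greedy choice, as a Python int
        input_ids = t.cat([input_ids, t.tensor([next_token])])  # wrap the int in a 1-element tensor
        if next_token == eos_token_id:
            break
    return input_ids


print(greedy_generate(t.tensor([0]), max_tokens_generated=10, eos_token_id=EOS_ID))
```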

below is some code which tests your sampling function by performing greedy sampling (which means always choosing the most likely next token at each step).

a few hints:
- don't forget about tensor shapes! your model's input should always have a batch dimension, i.e. it should be shape `(1, seq_len)`.
- the `sample_next_token` method will return an integer, so make sure you wrap this in a tensor before concatenating it to the end of your input ids.
- also remember to have your tensors be on the same device (we have a global `device` variable).
- remember to put your model in evaluation mode, using `model.eval()`.

In [None]:
class TransformerSampler:
    def __init__(self, model: DemoTransformer, tokenizer: GPT2TokenizerFast):
        self.model = model
        self.cfg = model.cfg
        self.tokenizer = tokenizer

    @t.inference_mode()
    def sample(self, prompt: str, max_tokens_generated=100, verbose=False, **kwargs):
        """
        returns a string of autoregressively generated text, starting from the prompt.

        sampling terminates at max_tokens_generated, or when the model generates an end-of-sequence token. kwargs are
        passed to sample_next_token, to give detailed instructions on how new tokens are chosen.
        """
        self.model.eval()
        # encode once and keep working with token ids (decoding and re-encoding text each step can change tokenization)
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(device)[0]  # shape (seq_len,)

        for _ in range(max_tokens_generated):
            # add a batch dim, and don't pass in more tokens than the model's context length
            logits = self.model(input_ids[None, -self.cfg.n_ctx :])
            # only the logits at the last position matter for the next token
            logits = logits[0, -1]  # shape (d_vocab,)
            next_token = int(self.sample_next_token(input_ids, logits, **kwargs))  # as a Python int
            input_ids = t.cat([input_ids, t.tensor([next_token], device=device)])
            if verbose:
                print(self.tokenizer.decode(input_ids), end="\r")
            # stop early if the model produced the end-of-sequence token
            if next_token == getattr(self.tokenizer, "eos_token_id", None):
                break

        return self.tokenizer.decode(input_ids)


    @staticmethod
    def sample_next_token(
        input_ids: Int[Tensor, "seq_len"],
        logits: Float[Tensor, "d_vocab"],
        temperature=1.0,
        top_k=0,
        top_p=0.0,
        frequency_penalty=0.0,
        seed=None,
    ):
        assert input_ids.ndim == 1, "input_ids should be a 1D sequence of token ids"
        assert temperature >= 0, "Temperature should be non-negative"
        assert 0 <= top_p <= 1.0, "Top-p must be a probability"
        assert 0 <= top_k, "Top-k must be non-negative"
        assert not (top_p != 0 and top_k != 0), "At most one of top-p and top-k supported"

        # set random seeds for reproducibility
        if seed is not None:
            t.manual_seed(seed)
            np.random.seed(seed)

        # apply all the specialized sampling methods
        if temperature == 0:
            return TransformerSampler.greedy_search(logits)
        elif temperature != 1.0:
            logits = TransformerSampler.apply_temperature(logits, temperature)
        if frequency_penalty != 0.0:
            logits = TransformerSampler.apply_frequency_penalty(input_ids, logits, frequency_penalty)
        if top_k > 0:
            return TransformerSampler.sample_top_k(logits, top_k)
        if top_p > 0.0:
            return TransformerSampler.sample_top_p(logits, top_p)
        return TransformerSampler.sample_basic(logits)

    @staticmethod
    def greedy_search(logits: Float[Tensor, "d_vocab"]) -> int:
        """
        returns the most likely token (as an int).
        """

        return logits.argmax().item()


    @staticmethod
    def apply_temperature(logits: Float[Tensor, "d_vocab"], temperature: float) -> Float[Tensor, "d_vocab"]:
        """
        applies temperature scaling to the logits.
        """

        return logits / temperature


    @staticmethod
    def apply_frequency_penalty(
        input_ids: Int[Tensor, "seq_len"], logits: Float[Tensor, "d_vocab"], freq_penalty: float
    ) -> Float[Tensor, "d_vocab"]:
        """
        applies a frequency penalty to the logits.
        """

        f = t.bincount(input_ids, minlength = logits.shape[0])
        logits = logits - (freq_penalty * f)
        return logits


    @staticmethod
    def sample_basic(logits: Float[Tensor, "d_vocab"]) -> int:
        """
        samples from the distribution defined by the logits.
        """

        return t.distributions.categorical.Categorical(logits=logits).sample().item()


    @staticmethod
    def sample_top_k(logits: Float[Tensor, "d_vocab"], k: int) -> int:
        """
        samples from the top k most likely tokens.
        """
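        top_k_logits, top_k_ids = t.topk(logits, k)  # values and indices of the k largest logits
        # sample a position within the top-k set (Categorical accepts unnormalized logits)
        idx = t.distributions.categorical.Categorical(logits=top_k_logits).sample()
        # map that position back to a token id in the full vocabulary
        return top_k_ids[idx].item()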


    @staticmethod
    def sample_top_p(logits: Float[Tensor, "d_vocab"], top_p: float, min_tokens_to_keep: int = 1) -> int:
        """
        samples from the most likely tokens which make up at least p cumulative probability.
        """
        # sort logits in descending order, keeping track of the original token ids
        logits_sorted, indx_sorted = t.sort(logits, descending=True, stable=True)
        probs_cum = t.cumsum(logits_sorted.softmax(dim=-1), dim=-1)

        # keep everything up to and including the first token whose cumulative prob reaches top_p
        num_keep = t.searchsorted(probs_cum, top_p, side="left").item() + 1
        num_keep = max(num_keep, min_tokens_to_keep)
        idx_keep = indx_sorted[:num_keep]

        # sample among the kept tokens using their original logits (no manual renormalization needed)
        idx = t.distributions.categorical.Categorical(logits=logits[idx_keep]).sample()
        return idx_keep[idx].item()


    @t.inference_mode()
    def beam_search(
        self,
        prompt: str,
        num_return_sequences: int,
        num_beams: int,
        max_new_tokens: int,
        no_repeat_ngram_size: int | None = None,
    ) -> list[tuple[float, str]]:
        """
        implements a beam search, by repeatedly performing the `generate` and `filter` steps (starting from the initial
        prompt) until either of the two stopping criteria are met: (1) we've generated `max_new_tokens` tokens, or (2)
        we've generated `num_returns_sequences` terminating sequences.
        """
        raise NotImplementedError()


t.set_grad_enabled(False)  # gradients are not necessary for sampling

model = DemoTransformer(Config()).to(device)
model.load_state_dict(reference_gpt2.state_dict(), strict=False)
tokenizer = reference_gpt2.tokenizer
sampler = TransformerSampler(model, tokenizer)

prompt = "Jingle bells, jingle bells, jingle all the way"
print(f"Testing greedy decoding\nPrompt:   {prompt!r}")

expected = "Jingle bells, jingle bells, jingle all the way up to the top of the mountain."
output = sampler.sample(prompt, max_tokens_generated=8, temperature=0.0)

print(f"Expected: {expected!r}\nActual:   {output!r}\n")
assert output == expected
test_submit(9)

print("Tests passed!")

<details>
<summary>solution</summary>

```python
@t.inference_mode()
def sample(self, prompt: str, max_tokens_generated=100, verbose=False, **kwargs):
    """
    Returns a string of autoregressively generated text, starting from the prompt.

    Sampling terminates at max_tokens_generated, or when the model generates an end-of-sequence token. kwargs are
    passed to sample_next_token, to give detailed instructions on how new tokens are chosen.
    """
    self.model.eval()
    input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(device)[0]

    for i in range(max_tokens_generated):
        # Get new logits (make sure we don't pass in more tokens than the model's context length)
        logits = self.model(input_ids[None, -self.cfg.n_ctx :])
        # We only take logits for the last token, because this is what we're sampling
        logits = logits[0, -1]
        # Get next token (as a tensor of shape (1,) so we can concat it to input_ids)
        next_token = t.tensor([TransformerSampler.sample_next_token(input_ids, logits, **kwargs)], device=device)
        # Create new input ids, with shape (old_seq_len + 1,)
        input_ids = t.cat([input_ids, next_token], dim=-1)
        # Print out results, if required
        if verbose:
            print(self.tokenizer.decode(input_ids), end="\r")
        # If our new token was the end-of-text token, stop
        if next_token == getattr(self.tokenizer, "eos_token_id", None):
            break

    return self.tokenizer.decode(input_ids)
```

</details>

## sampling with categorical

now, we'll move into implementing specific sampling methods. In each of these cases, you should return to the class definition above and fill in the corresponding method.

pytorch provides a [`distributions`](https://pytorch.org/docs/stable/distributions.html#distribution) package with a number of convenient methods for sampling from various distributions.

for now, we just need [`t.distributions.categorical.Categorical`](https://pytorch.org/docs/stable/distributions.html#categorical). use this to implement `sample_basic`, which just samples from the provided logits (which may have already been modified by the temperature and frequency penalties).

note that this will be slow since we aren't batching the samples, but don't worry about speed for now.
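a quick sketch of the `Categorical` behaviour we rely on (the logits here are arbitrary): it normalizes logits internally, so adding a constant to all of them leaves the distribution unchanged, and `sample` can draw a whole batch of samples in one call if you pass it a shape.

```python
import torch as t

logits = t.tensor([1.0, 2.0, 3.0])
dist = t.distributions.categorical.Categorical(logits=logits)

# Categorical normalizes internally: dist.probs equals softmax(logits)
print(dist.probs)

# shifting every logit by the same constant doesn't change the distribution
shifted = t.distributions.categorical.Categorical(logits=logits + 100.0)

t.manual_seed(0)
samples = dist.sample((10_000,))  # 10,000 token indices drawn in a single call
print(t.bincount(samples, minlength=3) / 10_000)  # empirical frequencies, close to dist.probs
```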

### exercise - `sample_basic`

> ```yaml
> difficulty: 🔴🔴⚪⚪⚪
> importance: 🔵🔵⚪⚪⚪
>
> you should spend up to 5-15 minutes on this exercise.
> ```

implement basic sampling in the `TransformerSampler` class above (i.e. the `sample_basic` method), then run the code below to verify your solution works.

In [None]:
prompt = "John and Mary went to the"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
logits = model(input_ids)[0, -1]

expected_top_5 = {" church": 0.0648, " house": 0.0367, " temple": 0.0145, " same": 0.0104, " Church": 0.0097}
frequency_of_top_5 = defaultdict(int)

N = 10_000
for _ in tqdm(range(N)):
    token = TransformerSampler.sample_next_token(input_ids.squeeze(), logits)
    frequency_of_top_5[tokenizer.decode(token)] += 1

for word in expected_top_5:
    expected_freq = expected_top_5[word]
    observed_freq = frequency_of_top_5[word] / N
    print(f"Word: {word!r:<9}. Expected freq {expected_freq:.4f}, observed freq {observed_freq:.4f}")
    assert abs(observed_freq - expected_freq) < 0.01, "Try increasing N if this fails by a small amount."

print("Tests passed!")
test_submit(10)

<details>
<summary>Solution</summary>

```python
@staticmethod
def sample_basic(logits: Float[Tensor, "d_vocab"]) -> int:
    """
    Samples from the distribution defined by the logits.
    """
    sampled_token = t.distributions.categorical.Categorical(logits=logits).sample()
    return sampled_token.item()
```

</details>

### exercise - `apply_temperature`

> ```yaml
> difficulty: 🔴⚪⚪⚪⚪
> importance: 🔵🔵⚪⚪⚪
>
> you should spend up to 5-10 minutes on this exercise.
> ```

temperature sounds fancy, but it's literally just dividing the logits by the temperature. You should implement this in your `TransformerSampler` class now.
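one way to see why this works: dividing the logits by $T$ is equivalent to raising each probability to the power $1/T$ and renormalizing, since $\exp(\ell_i / T) = \exp(\ell_i)^{1/T}$ and the softmax's normalizing constant cancels. a small check with arbitrary logits and $T = 0.5$:

```python
import torch as t

logits = t.tensor([0.5, 1.5, -0.3, 2.0])  # arbitrary example logits
p = logits.softmax(dim=-1)

T = 0.5
# temperature applied in logit space
via_logits = (logits / T).softmax(dim=-1)
# the same thing in probability space: p_i^(1/T), renormalized
via_probs = p ** (1 / T) / (p ** (1 / T)).sum()
print(via_logits, via_probs)
```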

In [None]:
logits = t.tensor([1, 2]).log()

cold_logits = TransformerSampler.apply_temperature(logits, temperature=0.001)
print('A low temperature "sharpens" or "peaks" the distribution: ', cold_logits)
t.testing.assert_close(cold_logits, 1000.0 * logits)

hot_logits = TransformerSampler.apply_temperature(logits, temperature=1000.0)
print("A high temperature flattens the distribution: ", hot_logits)
t.testing.assert_close(hot_logits, 0.001 * logits)

print("Tests passed!")
test_submit(11)

<details>
<summary>solution</summary>

```python
@staticmethod
def apply_temperature(logits: Float[Tensor, "d_vocab"], temperature: float) -> Float[Tensor, "d_vocab"]:
    """
    Applies temperature scaling to the logits.
    """
    return logits / temperature
```

</details>

### exercise - `apply_frequency_penalty`

> ```yaml
> difficulty: 🔴🔴⚪⚪⚪
> importance: 🔵⚪⚪⚪⚪
>
> you should spend up to 10-15 minutes on this exercise.
> ```

the frequency penalty is simple as well: count the number of occurrences of each token in the input, then subtract `freq_penalty` from that token's logit once per occurrence. hint: use `t.bincount` (documentation [here](https://pytorch.org/docs/stable/generated/torch.bincount.html)) to do this in a vectorized way.
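here's a small illustration of why `minlength` matters (the token ids, the vocab size of 8 and the penalty of 0.5 are made up for the example): without it, `t.bincount` returns a tensor of length `max(input_ids) + 1`, which generally won't line up with the logits.

```python
import torch as t

input_ids = t.tensor([4, 1, 4, 4, 0])
d_vocab = 8

print(t.bincount(input_ids))  # tensor([1, 1, 0, 0, 3]), only length max(input_ids) + 1 = 5
counts = t.bincount(input_ids, minlength=d_vocab)  # length d_vocab, so it matches the logits
logits = t.zeros(d_vocab)
penalized = logits - 0.5 * counts  # each occurrence subtracts freq_penalty = 0.5
print(penalized)
```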

you should implement the `apply_frequency_penalty` method in your `TransformerSampler` class now, then run the cell below to check your solution.

<details>
<summary>help - i'm getting a <code>RuntimeError</code>; my tensor sizes don't match.</summary>

look at the documentation page for `t.bincount`. you might need to use the `minlength` argument - why?
</details>

In [None]:
bieber_prompt = "And I was like Baby, baby, baby, oh Like, Baby, baby, baby, no Like, Baby, baby, baby, oh I thought you'd always be mine, mine"
input_ids = tokenizer.encode(bieber_prompt, return_tensors="pt")
logits = t.ones(tokenizer.vocab_size)
penalized_logits = TransformerSampler.apply_frequency_penalty(input_ids.squeeze(), logits, 2.0)

assert penalized_logits[5156].item() == -11, "Expected 6 occurrences of ' baby' with leading space, 1-2*6=-11"
assert penalized_logits[14801].item() == -5, "Expected 3 occurrences of ' Baby' with leading space, 1-2*3=-5"

print("Tests passed!")
test_submit(12)

<details>
<summary>solution</summary>

```python
@staticmethod
def apply_frequency_penalty(
    input_ids: Int[Tensor, "seq_len"], logits: Float[Tensor, "d_vocab"], freq_penalty: float
) -> Float[Tensor, "d_vocab"]:
    """
    Applies a frequency penalty to the logits.
    """
    d_vocab = logits.size(0)
    id_freqs = t.bincount(input_ids, minlength=d_vocab)
    return logits - freq_penalty * id_freqs
```

</details>

### Sampling - Manual Testing

Run the below cell to get a sense for the `temperature` and `freq_penalty` arguments. Play with your own prompt and try other values.

Note: your model can generate newlines or non-printing characters, so calling `print` on generated text sometimes looks awkward on screen. You can call `repr` on the string before printing to have the string escaped nicely.

In [None]:
sampler = TransformerSampler(model, tokenizer)

N_RUNS = 1
your_prompt = "Jingle bells, jingle bells, jingle all the way"
cases = [
    ("High freq penalty", dict(frequency_penalty=100.0)),
    ("Negative freq penalty", dict(frequency_penalty=-3.0)),
    ("Too hot!", dict(temperature=2.0)),
    ("Pleasantly cool", dict(temperature=0.7)),
    ("Pleasantly warm", dict(temperature=0.9)),
    ("Too cold!", dict(temperature=0.01)),
]

table = Table("Name", "Kwargs", "Output", title="Sampling - Manual Testing")

for name, kwargs in cases:
    for i in range(N_RUNS):
        output = sampler.sample(your_prompt, max_tokens_generated=24, **kwargs)
        table.add_row(name, str(kwargs), repr(output) + "\n")

rprint(table)

## top-k sampling

conceptually, the steps in top-k sampling are:
- find the `top_k` largest probabilities (you can use [`torch.topk`](https://pytorch.org/docs/stable/generated/torch.topk.html))
- set all other probabilities to zero
- normalize and sample
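staying in log-space, an equivalent way to express "set all other probabilities to zero" is to set all other logits to `-inf`, since $\exp(-\infty) = 0$. a sketch of this masking variant (the helper `top_k_masked_logits` is hypothetical, and is an alternative to indexing into the output of `topk`; note that ties with the k-th largest logit would keep more than `k` tokens):

```python
import torch as t


def top_k_masked_logits(logits: t.Tensor, k: int) -> t.Tensor:
    """Keeps logits >= the k-th largest value and sets the rest to -inf."""
    kth_largest = logits.topk(k).values[-1]  # topk returns values sorted in descending order
    return logits.masked_fill(logits < kth_largest, float("-inf"))


logits = t.tensor([1.0, 3.0, 0.5, 2.0, -1.0])
masked = top_k_masked_logits(logits, k=2)
probs = masked.softmax(dim=-1)  # exp(-inf) = 0, so only the top-k tokens keep probability mass
print(probs)
```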

### exercise - `sample_top_k`

> ```yaml
> difficulty: 🔴🔴⚪⚪⚪
> importance: 🔵⚪⚪⚪⚪
>
> you should spend up to 5-10 minutes on this exercise.
> ```

implement the method `sample_top_k` now. your implementation should stay in log-space throughout (don't exponentiate to obtain probabilities). this means you don't actually need to worry about normalizing, because `Categorical` accepts unnormalised logits.

In [None]:
prompt = "John and Mary went to the"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
logits = model(input_ids)[0, -1]

expected_top_5 = {" church": 0.0648, " house": 0.0367, " temple": 0.0145, " same": 0.0104, " Church": 0.0097}
topk_5_sum = sum(expected_top_5.values())

observed_freqs = defaultdict(int)

N = 10000
for _ in tqdm(range(N)):
    token = TransformerSampler.sample_next_token(input_ids.squeeze(), logits, top_k=5)
    observed_freqs[tokenizer.decode(token)] += 1

for word in expected_top_5:
    expected_freq = expected_top_5[word] / topk_5_sum
    observed_freq = observed_freqs[word] / N
    print(f"Word: {word!r:<9}. Expected freq = {expected_freq:.4f}, observed freq = {observed_freq:.4f}")
    assert abs(observed_freq - expected_freq) < 0.01

<details>
<summary>solution</summary>

```python
@staticmethod
def sample_top_k(logits: Float[Tensor, "d_vocab"], k: int) -> int:
    """
    Samples from the top k most likely tokens.
    """
    top_k_logits, top_k_token_ids = logits.topk(k)
    # Get sampled token (which is an index corresponding to the list of top-k tokens)
    sampled_token_idx = t.distributions.categorical.Categorical(logits=top_k_logits).sample()
    # Get the actual token id, as an int
    return top_k_token_ids[sampled_token_idx].item()
```

</details>

the [gpt-2 paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) famously included an example prompt about unicorns. now it's your turn to see just how cherry-picked this example was.

the paper claims they used `top_k=40` and best of 10 samples.

In [None]:
sampler = TransformerSampler(model, tokenizer)

your_prompt = "In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English."

output = sampler.sample(your_prompt, temperature=0.7, top_k=40, max_tokens_generated=64)

rprint(f"Your model said:\n\n[bold dark_orange]{output}")

this is pretty incredible! For some perspective on how much of a paradigm shift even basic models like this represented, we recommend reading [this section from simulators](https://www.lesswrong.com/posts/vJFdjigzmcXMhNTsx/simulators#The_limit_of_sequence_modeling).

## top-p aka nucleus sampling

the basic idea is that we choose the most likely words, up until the total probability of words we've chosen crosses some threshold. then we sample from those chosen words based on their logits.

the steps are:

- sort the probabilities from largest to smallest
- find the cutoff point where the cumulative probability first equals or exceeds `top_p`. We do the cutoff inclusively, keeping the first probability that brings the total up to (or over) the threshold.
- if the number of kept probabilities is less than `min_tokens_to_keep`, keep that many tokens instead.
- set all other probabilities to zero
- normalize and sample

for example, if our probabilities were `(0.4, 0.3, 0.2, 0.1)` and our cutoff was `top_p=0.8`, then we'd sample from the first three elements (because their total probability is `0.9` which is over the threshold, but the first two only have a total prob of `0.7` which is under the threshold). once we've chosen to sample from those three, we would renormalise them by dividing by their sum, so the probabilities we use when sampling are `(0.4/0.9, 0.3/0.9, 0.2/0.9)`.
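the same worked example in code, as a sketch (it assumes the probabilities are already sorted in descending order):

```python
import torch as t

probs = t.tensor([0.4, 0.3, 0.2, 0.1])  # already sorted, largest first
top_p = 0.8

cumulative = probs.cumsum(dim=-1)  # approximately [0.4, 0.7, 0.9, 1.0]
# inclusive cutoff: keep every token whose cumulative prob is still below top_p, plus the first one that reaches it
# (this count is the same as t.searchsorted(cumulative, top_p, side="left") + 1)
n_keep = min(int((cumulative < top_p).sum().item()) + 1, len(probs))
kept = probs[:n_keep]
renormalized = kept / kept.sum()
print(n_keep, renormalized)  # 3 tokens kept, approximately [0.444, 0.333, 0.222]
```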

optionally, refer to the paper [the curious case of neural text degeneration](https://arxiv.org/pdf/1904.09751.pdf) for some comparison of different methods.

### exercise - `sample_top_p`

> ```yaml
> difficulty: 🔴🔴🔴⚪⚪
> importance: 🔵⚪⚪⚪⚪
>
> you should spend up to 15-20 minutes on this exercise.
> ```

In [None]:
prompt = "John and Mary went to the"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
logits = model(input_ids)[0, -1]

expected_top_10pct = {
    " church": 0.0648,
    " house": 0.0367,  # These are the two most likely tokens, and add up to >10%
}
top_10pct_sum = sum(expected_top_10pct.values())

observed_freqs = defaultdict(int)

N = 7500
for _ in tqdm(range(N)):
    token = TransformerSampler.sample_next_token(input_ids.squeeze(), logits, top_p=0.1)
    observed_freqs[tokenizer.decode(token)] += 1

for word in expected_top_10pct:
    expected_freq = expected_top_10pct[word] / top_10pct_sum
    observed_freq = observed_freqs[word] / N
    print(f"Word: {word!r:<9}. Expected freq {expected_freq:.4f}, observed freq {observed_freq:.4f}")
    assert abs(observed_freq - expected_freq) < 0.01, "Try increasing N if this fails by a small amount."

<details>
<summary>help - i'm stuck on how to implement this function.</summary>

first, sort the logits using the `sort(descending=True)` method (this returns values and indices). then you can get `cumulative_probs` by applying softmax to these logits and taking the cumsum. then, you can decide how many probabilities to keep by using the `t.searchsorted` function.

once you've decided which probabilities to keep, it's easiest to sample from them using the original logits (you should have preserved the indices when you called `logits.sort`). This way, you don't need to worry about renormalising like you would if you were using probabilities.
</details>

<details>
<summary>solution</summary>

```python
@staticmethod
def sample_top_p(logits: Float[Tensor, "d_vocab"], top_p: float, min_tokens_to_keep: int = 1) -> int:
    """
    Samples from the most likely tokens which make up at least p cumulative probability.
    """
    # Sort logits, and get cumulative probabilities
    logits_sorted, indices = logits.sort(descending=True, stable=True)
    cumul_probs = logits_sorted.softmax(-1).cumsum(-1)
    # Choose which tokens to keep, in the set we sample from
    n_keep = t.searchsorted(cumul_probs, top_p, side="left").item() + 1
    n_keep = max(n_keep, min_tokens_to_keep)
    keep_idx = indices[:n_keep]
    keep_logits = logits[keep_idx]
    # Perform the sampling
    sample = t.distributions.categorical.Categorical(logits=keep_logits).sample()
    return keep_idx[sample].item()
```

</details>

now, an example of top-p sampling:

In [None]:
sampler = TransformerSampler(model, tokenizer)

your_prompt = "Eliezer Shlomo Yudkowsky (born September 11, 1979) is an American decision and artificial intelligence (AI) theorist and writer, best known for"
output = sampler.sample(your_prompt, temperature=0.7, top_p=0.95, max_tokens_generated=64)
rprint(f"Your model said:\n\n[bold dark_orange]{output}")

## beam search

finally, we'll implement a more advanced way of searching over output: **beam search**.

you should read the [huggingface page](https://huggingface.co/blog/how-to-generate#beam-search) on beam search before moving on.

in beam search, we maintain a list of `num_beams` completions, namely the most likely completions so far as measured by the product of their token probabilities. since this product can become very small, we use the sum of log probabilities instead. note - log probabilities are *not* the same as your model's output logits. we get log probabilities by first taking the softmax of the logits and then taking the log. you can do this in one step with the [`log_softmax`](https://pytorch.org/docs/stable/generated/torch.nn.functional.log_softmax.html) function / tensor method, which is also more numerically stable than calling `softmax` and then `log`.

<details>
<summary>Log probabilities are equal to the logit output after being translated by some amount X (where X is a function of the original logit output). Can you prove this?</summary>

Suppose our vector of logits is $x$, and we take softmax to get a vector of probabilities $p$, then log again to get a vector of log probabilities $l$. Then the $i$-th element of this vector of logprobs is:

$$
\begin{align}
l_i &= \log p_i \\
&= \log \frac{\exp(x_i)}{\sum_j \exp(x_j)} \\
&= x_i - \log \sum_j \exp(x_j) \\
&= x_i - C
\end{align}
$$

where $C = \log \sum_j \exp(x_j)$ is the same for all elements. So we can see that $l_i$ is equal to the logit output $x_i$ after being translated by $C$.

It's important not to mix up logits and logprobs!
</details>
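the derivation above can be checked numerically. this small sketch (the logit values are arbitrary) compares pytorch's `log_softmax` against $x_i - C$, with $C = \log \sum_j \exp(x_j)$ computed by `torch.logsumexp`:

```python
import torch

x = torch.tensor([2.0, 1.0, 0.1])      # arbitrary logits
l = torch.log_softmax(x, dim=-1)        # log probabilities
C = torch.logsumexp(x, dim=-1)          # C = log sum_j exp(x_j), the same for every element

# l_i = x_i - C for every i
assert torch.allclose(l, x - C)
# exponentiating the log probabilities gives a distribution that sums to 1
assert torch.isclose(l.exp().sum(), torch.tensor(1.0))
```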

<details>
<summary>Why do you think we use log softmax rather than logit output?</summary>

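The softmax distribution is invariant to translating the logits: adding the same constant to every logit leaves the probabilities unchanged, so raw logit values have no absolute meaning. If we had two different beams and we were generating the next tokens in those beams, there would be no reasonable way to compare the two beams to each other using logits, because we could shift the logit vector for one beam by a constant amount without changing its distribution. Log probabilities are normalised (they exponentiate to probabilities summing to 1), so they can be compared and summed across beams.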

</details>
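a quick illustration of this point, using made-up logits: shifting one beam's logits by a constant changes every raw logit, but leaves its softmax distribution and its log probabilities exactly as they were.

```python
import torch

# hypothetical next-token logits for two beams whose distributions are identical
logits_a = torch.tensor([3.0, 1.0, 0.0])
logits_b = logits_a + 10.0  # same logits, shifted by a constant

# every raw logit of beam b is 10 higher, so comparing raw logits would wrongly favour beam b...
logit_gap = logits_b - logits_a
# ...but the softmax distributions, and therefore the log probabilities, are identical
probs_a, probs_b = torch.softmax(logits_a, dim=-1), torch.softmax(logits_b, dim=-1)
logprobs_a, logprobs_b = torch.log_softmax(logits_a, dim=-1), torch.log_softmax(logits_b, dim=-1)
```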

at each iteration, we run the batch of completions through the model and take the log-softmax to obtain `d_vocab` log-probs for each completion, or `num_beams * d_vocab` possible next completions in total.

if we kept all of these, then after the next iteration we would have `num_beams * d_vocab * d_vocab` completions, which is far too many. so instead we rank the candidates by their logprob sum, from best (highest) to worst (lowest), and keep only the best ones.

the illustration below might help (based on real results from this method). Here, we have the following hyperparameters:

```python
num_beams = 3
max_new_tokens = 3
num_return_sequences = 2
```

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/beam-search-3.png" width="1000">

note how after each "generate" stage, we have `num_beams ** 2` possible completions, which we then filter down to `num_beams`. this is because we need this many in order to find the best `num_beams` completions overall - for example, it's possible that all the best beams of length `n+1` come from the same beam of length `n`, in which case we'll need to keep all `num_beams` that we generated from that single beam.
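the generate-then-filter step can be sketched on toy tensors, with no model involved. this is a minimal illustration (the function `beam_step` and the numbers are hypothetical, not part of the exercise's API): each beam proposes its best `num_beams` continuations, giving `num_beams ** 2` candidates, and we keep the best `num_beams` of those overall. in the example, both survivors come from the same parent beam, which is exactly the case described above.

```python
import torch


def beam_step(logprob_sums: torch.Tensor, next_logprobs: torch.Tensor, num_beams: int):
    """
    One generate + filter step of beam search on toy tensors.

    logprob_sums:  (batch,) running logprob sums of the current beams
    next_logprobs: (batch, d_vocab) next-token log probabilities for each beam
    Returns the new logprob sums, the parent beam of each survivor, and the token it appended.
    """
    batch = logprob_sums.shape[0]
    # generate: each beam proposes its best `num_beams` continuations -> batch * num_beams candidates
    top_lp, top_tok = next_logprobs.topk(num_beams, dim=-1)  # both (batch, num_beams)
    cand_sums = (logprob_sums.unsqueeze(-1) + top_lp).flatten()  # (batch * num_beams,)
    cand_parents = torch.arange(batch).repeat_interleave(num_beams)  # parent beam of each candidate
    cand_tokens = top_tok.flatten()
    # filter: keep the best `num_beams` candidates overall
    best = cand_sums.topk(num_beams).indices
    return cand_sums[best], cand_parents[best], cand_tokens[best]


# two beams, each with probability 0.5 so far, and a 3-token vocabulary
logprob_sums = torch.log(torch.tensor([0.5, 0.5]))
next_logprobs = torch.log(torch.tensor([[0.45, 0.45, 0.10], [0.34, 0.33, 0.33]]))
sums, parents, toks = beam_step(logprob_sums, next_logprobs, num_beams=2)
# both survivors extend beam 0 (probability 0.5 * 0.45 = 0.225 each)
```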

how do we deal with sequences that terminate early (i.e. by generating an EOS token)? answer - we append them to the list of completions which we'll return at the end, and remove them from the generation tree. our algorithm terminates when either all our sequences are `max_new_tokens` tokens longer than the initial prompt, or we've generated `num_return_sequences` terminating sequences.
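the termination logic can be sketched in plain python. this is a hypothetical, model-free sketch: `step_fn` stands in for one generate + filter step, beams are `(logprob_sum, token_ids)` tuples, and a beam counts as terminated when its last token is `eos_id`.

```python
from typing import Callable

Beam = tuple[float, list[int]]  # (logprob sum, token IDs)


def run_beam_search(
    step_fn: Callable[[list[Beam]], list[Beam]],
    beams: list[Beam],
    eos_id: int,
    max_new_tokens: int,
    num_return_sequences: int,
) -> list[Beam]:
    finished: list[Beam] = []
    for _ in range(max_new_tokens):
        beams = step_fn(beams)  # generate + filter (hypothetical stand-in)
        # beams which just produced EOS leave the generation tree and join the finished list
        finished += [b for b in beams if b[1][-1] == eos_id]
        beams = [b for b in beams if b[1][-1] != eos_id]
        if len(finished) >= num_return_sequences:
            return finished[:num_return_sequences]
    # ran out of steps: top up with the best unfinished beams
    return (finished + beams)[:num_return_sequences]


def toy_step(beams: list[Beam]) -> list[Beam]:
    # toy rule: every token costs -1.0, and a beam emits EOS (id 0) once it has 3 tokens
    return [(s - 1.0, toks + [0 if len(toks) >= 3 else 1]) for s, toks in beams]


result = run_beam_search(toy_step, [(0.0, [5]), (-0.5, [5, 5])], eos_id=0, max_new_tokens=5, num_return_sequences=2)
```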

### exercise - implement `beam_search`

> ```yaml
> difficulty: 🔴🔴🔴🔴🔴
> importance: 🔵⚪⚪⚪⚪
>
> you should spend up to 30-50 minutes on this exercise.
> ```

we've given you one implementation of `beam_search` below, which calls the `generate` and `filter` methods of the `Beams` class (these correspond to the two stages in the diagram above). The `beam_search` method works as follows:

- create a list `final_logprobs_and_completions` for storing the final output, as tuples of (logprob sum, string completion).
- perform `max_new_tokens` steps of generation (producing a new set of beams) and filtering (getting the best beams from these combinations), adding any terminated beams to `final_logprobs_and_completions` and returning early once it holds `num_return_sequences` of them
- return these terminated beams plus the best ones we have at the end of the steps.

so all you need to do is fill in the `generate` and `filter` methods. below, you'll find some unit tests for the `generate` and `filter` methods. when you've passed these tests, you should be able to run the full `beam_search` function.

**important note** - by default, beam search produces a lot of repeated words / phrases / sentences. this makes sense - if the model finds some completion with a much higher logprob sum than most completions in its beam search space, then it will want to repeat this completion even if it doesn't make a lot of sense in context. a common solution is to ban repetition of n-grams, which you should also implement in the function below. in other words, rather than selecting tokens for each sequence by taking `logprobs.topk(k)` in your `generate` method, you should take the top `k` tokens after filtering out those that would produce a repeated n-gram of length `no_repeat_ngram_size`. good values of this parameter to try are 2 or 3 (although we recommend you try without this parameter first, so you can see how much of a difference it makes!).
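to make the n-gram rule concrete, here's a minimal pure-python sketch. the function names are hypothetical, and this is only one way to do it. it works on token IDs: it bans any token that would complete an n-gram already present in the sequence, then takes the top `k` of the remaining tokens.

```python
def banned_next_tokens(tokens: list[int], n: int) -> set[int]:
    """Tokens which, if appended to `tokens`, would repeat an n-gram that already appears in `tokens`."""
    if len(tokens) < n - 1:
        return set()
    prefix = tokens[len(tokens) - (n - 1) :]  # the last n-1 tokens
    banned = set()
    for i in range(len(tokens) - n + 1):
        # an earlier n-gram starting with the same n-1 tokens: its last token is now banned
        if tokens[i : i + n - 1] == prefix:
            banned.add(tokens[i + n - 1])
    return banned


def topk_non_repeating(logprobs: list[float], tokens: list[int], k: int, n: int) -> list[int]:
    """Indices of the `k` highest logprobs, skipping tokens that would repeat an n-gram."""
    banned = banned_next_tokens(tokens, n)
    ranked = sorted(range(len(logprobs)), key=lambda tok: logprobs[tok], reverse=True)
    return [tok for tok in ranked if tok not in banned][:k]


# the bigram (7, 3) has already appeared, and the sequence currently ends in 7, so 3 is banned next
tokens = [7, 3, 9, 7]
logprobs = [-5.0, -4.0, -3.0, -0.1, -2.0, -6.0, -7.0, -8.0, -9.0, -1.0]
print(banned_next_tokens(tokens, n=2))  # {3}
print(topk_non_repeating(logprobs, tokens, k=2, n=2))  # [9, 4]
```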

In [None]:
@dataclass
class Beams:
    """class to store beams during beam search"""
    model: DemoTransformer
    tokenizer: GPT2TokenizerFast
    logprob_sums: Float[Tensor, "batch"]
    tokens: Int[Tensor, "batch seq"]

    def __getitem__(self, batch_idx) -> "Beams":
        """Allows you to create new beams from old beams by slicing along batch dim (useful for `filter`)."""
        return Beams(self.model, self.tokenizer, self.logprob_sums[batch_idx], self.tokens[batch_idx])

    @property
    def logprobs_and_completions(self) -> list[tuple[float, str]]:
        """returns self as a list of logprob sums and completions (useful for getting final output)."""
        return [
            (logprob_sum.item(), self.tokenizer.decode(tokens))
            for (logprob_sum, tokens) in zip(self.logprob_sums, self.tokens)
        ]

    def best_log_probs(
        self,
        logprobs: Float[Tensor, "d_vocab"],
        tokens_old: Int[Tensor, "seq"],
        k: int,
        no_repeat_ngram_size: int | None = None,
    ) -> tuple[Float[Tensor, "k"], Int[Tensor, "k"]]:
        """
        returns the top `k` log probs (and their token IDs) for the next token of a single beam.

        if `no_repeat_ngram_size` is given, tokens which would complete an n-gram of that length already present in
        `tokens_old` are excluded. n-grams are compared as token IDs rather than decoded strings, which avoids false
        matches from substrings that straddle token boundaries.
        """
        if no_repeat_ngram_size is None:
            return logprobs.topk(k, dim=-1)

        n = no_repeat_ngram_size
        seq = tokens_old.tolist()

        # find every token that follows an earlier occurrence of the current (n-1)-token suffix
        banned = set()
        if len(seq) >= n - 1:
            prefix = seq[len(seq) - (n - 1) :]
            for i in range(len(seq) - n + 1):
                if seq[i : i + n - 1] == prefix:
                    banned.add(seq[i + n - 1])

        # mask out the banned tokens, then take the top k as usual
        logprobs = logprobs.clone()
        if banned:
            logprobs[list(banned)] = float("-inf")
        return logprobs.topk(k, dim=-1)

    def generate(self, k: int, no_repeat_ngram_size: int | None = None) -> "Beams":
        """
        starting from the current set of beams (i.e. self.tokens), returns a new set of `len(self.tokens) * k` beams,
        containing the best `k` continuations for each of the original beams. `self` is left unchanged.

        optional argument `no_repeat_ngram_size` means your model won't generate any sequences with a repeating n-gram
        of this length.
        """
        num_candidates, seq_len = self.tokens.shape  # number of current beams, current sequence length
        device = self.tokens.device

        # build the new beams in fresh tensors, rather than overwriting self.logprob_sums / self.tokens
        new_logprob_sums = t.zeros(num_candidates * k, device=device)
        new_tokens = t.zeros(num_candidates * k, seq_len + 1, dtype=t.int64, device=device)

        for i in range(num_candidates):
            logits = self.model(self.tokens[i].unsqueeze(0))  # input: [batch seq_len], output: [batch seq_len d_vocab]
            log_probs = logits[0, -1].log_softmax(dim=-1)  # [d_vocab], next-token log probs from the last position

            # get the top `k` log probs and the corresponding tokens (optionally banning repeated n-grams)
            top_log_probs, top_tokens = self.best_log_probs(log_probs, self.tokens[i], k, no_repeat_ngram_size)

            # each of the k continuations = the old tokens plus one new token, with an updated logprob sum
            for j in range(k):
                new_logprob_sums[i * k + j] = self.logprob_sums[i] + top_log_probs[j]
                new_tokens[i * k + j, :-1] = self.tokens[i]  # all the current tokens...
                new_tokens[i * k + j, -1] = top_tokens[j]  # ...plus the token generated in this step

        return Beams(self.model, self.tokenizer, new_logprob_sums, new_tokens)


    def filter(self, k: int) -> tuple["Beams", "Beams"]:
        """
        returns:
        - best_beams (Beams), filtered version of self, containing all best `k` which are also not terminated.
        - early_terminations (Beams), filtered version of self, containing all best `k` which are also terminated.
        """
        # indices of the top `k` beams by logprob sum (best first); `self` is not modified
        k = min(k, len(self.logprob_sums))
        top_beam_indices = self.logprob_sums.topk(k=k, dim=0).indices.tolist()

        # a beam has terminated if the token generated at this step is EOS
        is_terminated = (self.tokens[:, -1] == self.tokenizer.eos_token_id).tolist()

        # split the best `k` beams into those that continue and those that have terminated
        best_continuing = [i for i in top_beam_indices if not is_terminated[i]]
        best_terminated = [i for i in top_beam_indices if is_terminated[i]]

        return self[best_continuing], self[best_terminated]


    def print(self, title="best completions", max_print_chars=80, title_addl=None) -> None:
        """
        prints out a set of sequences with their corresponding logprob sums.
        """
        if len(self.tokens) == 0:
            return

        if title_addl is None: table = Table("logprob sum", "completion", title=title)
        else: table = Table("logprob sum", "completion", title=f"{title} {title_addl}")

        for logprob_sum, tokens in zip(self.logprob_sums, self.tokens):
            text = self.tokenizer.decode(tokens)
            if len(repr(text)) > max_print_chars:
                text = text[: int(0.3 * max_print_chars)] + " ... " + text[-int(0.7 * max_print_chars) :]
            table.add_row(f"{logprob_sum:>8.3f}", repr(text))

        rprint(table)


@t.inference_mode()
def beam_search(
    self: TransformerSampler,
    prompt: str,
    num_return_sequences: int,
    num_beams: int,
    max_new_tokens: int,
    no_repeat_ngram_size: int | None = None,
) -> list[tuple[float, str]]:
    """
    implements a beam search, by repeatedly performing the `generate` and `filter` steps (starting from the initial prompt) until either of the two stopping criteria are met:
    (1) we've generated `max_new_tokens` tokens, or
    (2) we've generated `num_return_sequences` terminating sequences.
    """
    assert num_return_sequences <= num_beams
    self.model.eval()

    tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(device)

    final_logprobs_and_completions = []  # we add to this list as we get terminated beams
    best_beams = Beams(self.model, self.tokenizer, t.tensor([0.0]).to(device), tokens)  # start with just 1 beam

    for _ in tqdm(range(max_new_tokens)):
        t.cuda.empty_cache()

        # generate & filter beams
        best_beams = best_beams.generate(k=num_beams, no_repeat_ngram_size=no_repeat_ngram_size)
        best_beams, best_beams_terminated = best_beams.filter(k=num_beams)

        # add terminated beams to our list, and return early if we have enough
        final_logprobs_and_completions.extend(best_beams_terminated.logprobs_and_completions)
        if len(final_logprobs_and_completions) >= num_return_sequences:
            return final_logprobs_and_completions[:num_return_sequences]

    # return terminated beams plus the best ongoing beams of length `orig_len + max_new_tokens`
    final_logprobs_and_completions.extend(best_beams.logprobs_and_completions)
    return final_logprobs_and_completions[:num_return_sequences]


TransformerSampler.beam_search = beam_search

<details>
<summary>Help - I'm stuck on the implementation of <code>no_repeat_ngram_size</code>.</summary>

Here's a method you can use in your `generate` function in place of `logprobs.topk(k)`. It filters out any token which would complete an n-gram of length `no_repeat_ngram_size` that has already appeared in `self.tokens`:

```python
def get_topk_non_repeating(
    self,
    logprobs: Float[Tensor, "batch d_vocab"],
    no_repeat_ngram_size: int | None,
    k: int,
) -> tuple[Float[Tensor, "batch k"], Int[Tensor, "batch k"]]:
    """
    logprobs:
        tensor of the log-probs for the next token
    no_repeat_ngram_size:
        size of ngram to avoid repeating
    k:
        number of top logits to return, for each beam in our collection

    Returns:
        equivalent to the output of `logprobs.topk(dim=-1)`, but makes sure that no returned tokens would produce an
        ngram of size `no_repeat_ngram_size` which has already appeared in `self.tokens`.
    """
    batch, seq_len = self.tokens.shape

    # If completion isn't long enough for a repetition, or we have no restrictions, just return topk
    if (no_repeat_ngram_size is not None) and (seq_len > no_repeat_ngram_size - 1):
        # Otherwise, we need to check for ngram repetitions
        # First, get the most recent `no_repeat_ngram_size-1` tokens
        last_ngram_prefix = self.tokens[:, seq_len - (no_repeat_ngram_size - 1) :]
        # Next, find all the tokens we're not allowed to generate, by checking all past ngrams for a match
        for i in range(seq_len - (no_repeat_ngram_size - 1)):
            ngrams = self.tokens[:, i : i + no_repeat_ngram_size]  # (batch, ngram)
            ngrams_are_repeated = (ngrams[:, :-1] == last_ngram_prefix).all(-1)  # (batch,)
            ngram_end_tokens = ngrams[:, -1]  # (batch,), so each beam is paired only with its own end token
            # Fill logprobs with a large negative value wherever the ngrams are repeated
            logprobs[range(batch), ngram_end_tokens] = t.where(
                ngrams_are_repeated, -1.0e4, logprobs[range(batch), ngram_end_tokens]
            )

    # Finally, get our actual tokens
    return logprobs.topk(k=k, dim=-1)
```

</details>

<details>
<summary>Solution</summary>

```python
def generate(self, k: int, no_repeat_ngram_size: int | None = None) -> "Beams":
    """
    Starting from the current set of beams (i.e. self.tokens), returns a new set of `len(self.tokens) * k` beams,
    containing the best `k` continuations for each of the original beams.

    Optional argument `no_repeat_ngram_size` means your model won't generate any sequences with a repeating n-gram
    of this length.
    """
    # Get the output logprobs for the next token (for every sequence in current beams)
    logprobs = self.model(self.tokens)[:, -1, :].log_softmax(-1)

    # Get the top `k` tokens for each sequence (skipping any which would repeat an n-gram)
    topk_logprobs, topk_tokenIDs = self.get_topk_non_repeating(logprobs, no_repeat_ngram_size, k=k)

    # Add new logprobs & concat new tokens. When doing this, we need to add an extra `k` dimension since our current
    # logprobs & tokens have shape (batch,) and (batch, seq), but our new ones both have shape (batch, k)
    new_logprob_sums = einops.repeat(self.logprob_sums, "b -> b k", k=k) + topk_logprobs
    new_tokens = t.concat([einops.repeat(self.tokens, "b s -> b k s", k=k), topk_tokenIDs.unsqueeze(-1)], dim=-1)

    return Beams(self.model, self.tokenizer, new_logprob_sums.flatten(), new_tokens.flatten(0, 1))

def filter(self, k: int) -> tuple["Beams", "Beams"]:
    """
    Returns:
        best_beams: Beams
            filtered version of self, containing all best `k` which are also not terminated.
        early_terminations: Beams
            filtered version of self, containing all best `k` which are also terminated.
    """
    # Get the indices of top `k` beams
    top_beam_indices = self.logprob_sums.topk(k=k, dim=0).indices.tolist()
    # Get the indices of terminated sequences
    new_tokens = self.tokens[:, -1]
    terminated_indices = t.nonzero(new_tokens == self.tokenizer.eos_token_id)

    # Get the indices of the `k` best sequences (some terminated, some not terminated)
    best_continuing = [i for i in top_beam_indices if i not in terminated_indices]
    best_terminated = [i for i in top_beam_indices if i in terminated_indices]

    # Return the beam objects from these indices
    return self[best_continuing], self[best_terminated]

def get_topk_non_repeating(
    self,
    logprobs: Float[Tensor, "batch d_vocab"],
    no_repeat_ngram_size: int | None,
    k: int,
) -> tuple[Float[Tensor, "batch k"], Int[Tensor, "batch k"]]:
    """
    logprobs:
        tensor of the log-probs for the next token
    no_repeat_ngram_size:
        size of ngram to avoid repeating
    k:
        number of top logits to return, for each beam in our collection

    Returns:
        equivalent to the output of `logprobs.topk(dim=-1)`, but makes sure that no returned tokens would produce an
        ngram of size `no_repeat_ngram_size` which has already appeared in `self.tokens`.
    """
    batch, seq_len = self.tokens.shape

    # If completion isn't long enough for a repetition, or we have no restrictions, just return topk
    if (no_repeat_ngram_size is not None) and (seq_len > no_repeat_ngram_size - 1):
        # Otherwise, we need to check for ngram repetitions
        # First, get the most recent `no_repeat_ngram_size-1` tokens
        last_ngram_prefix = self.tokens[:, seq_len - (no_repeat_ngram_size - 1) :]
        # Next, find all the tokens we're not allowed to generate, by checking all past ngrams for a match
        for i in range(seq_len - (no_repeat_ngram_size - 1)):
            ngrams = self.tokens[:, i : i + no_repeat_ngram_size]  # (batch, ngram)
            ngrams_are_repeated = (ngrams[:, :-1] == last_ngram_prefix).all(-1)  # (batch,)
            ngram_end_tokens = ngrams[:, -1]  # (batch,), so each beam is paired only with its own end token
            # Fill logprobs with a large negative value wherever the ngrams are repeated
            logprobs[range(batch), ngram_end_tokens] = t.where(
                ngrams_are_repeated, -1.0e4, logprobs[range(batch), ngram_end_tokens]
            )

    # Finally, get our actual tokens
    return logprobs.topk(k=k, dim=-1)
```

</details>

example usage of the `Beams` class, and the `print` method, corresponding to the diagram above:

In [None]:
# start with prompt "When I was", get top 3 tokens (and their logprobs), and use that to create & display the top 3 beams
prompt = "When I was"
tokens = tokenizer.encode(prompt, return_tensors="pt").to(device)  # [batch seq_len]
logprobs = model(tokens)[0, -1].log_softmax(-1)  # next-token log probs after "When I was"
top_logprobs, top_tokens = logprobs.topk(k=3, dim=-1)

new_tokens = t.concat([tokens.repeat(3, 1), top_tokens.unsqueeze(-1)], dim=-1)

beams = Beams(model, tokenizer, logprob_sums=top_logprobs, tokens=new_tokens)
beams.print()

and here are some unit tests for your `generate` and `filter` methods, starting from the prompt `"When I was"` (so your output should match the diagram above).

In [None]:
print("testing generate...")

new_beams = beams.generate(k=3, no_repeat_ngram_size=None)
new_beams.print()

expected_values = [(-3.1, "When I was a kid"), (-4.8, "When I was a child"), (-4.9, "When I was a little")]

# for i, (logprob_sum, completion) in enumerate(new_beams.logprobs_and_completions[:3]):
#  assert abs(logprob_sum - expected_values[i][0]) < 0.1, f"{i}"
#  assert completion == expected_values[i][1], f"{i}"

print("all tests for `generate` passed!")
test_submit(13)

In [None]:
print("testing `filter`...")

best_beams, terminated_beams = new_beams.filter(3)
best_beams.print()
terminated_beams.print(title_addl="terminated beams")

expected_values = [(-3.1, "When I was a kid"), (-3.2, "When I was growing up"), (-4.6, "When I was in the")]

for i, (logprob_sum, completion) in enumerate(best_beams.logprobs_and_completions):
    assert abs(logprob_sum - expected_values[i][0]) < 0.1, f"{i}"
    assert completion == expected_values[i][1], f"{i}"

assert len(terminated_beams.logprobs_and_completions) == 0

print("all tests for `filter` passed!")
test_submit(14)

lastly, we'll test the `no_repeat_ngram_size` argument. we do this by repeatedly generating new tokens from our starting beams `best_beams`, and checking whether the model repeats the `I was` n-gram (which it will by default, unless we prohibit repeating n-grams).

In [None]:
import copy

NO_REPEAT_NGRAM_SIZE = 2

print("testing `no_repeat_ngram_size`... \n\n")

best_beams.print(title="best completions - `best_beams`")
print("\n\n")

beams_ng_nr = copy.deepcopy(best_beams)
for _ in range(5): beams_ng_nr = beams_ng_nr.generate(k=1)
beams_ng_nr.print(title="best completions - generate with no restriction for n-gram size")

assert all(
  "I was" in completion.removeprefix(prompt) for _, completion in beams_ng_nr.logprobs_and_completions
), "without restriction, all beams should be completed as '...I was...'"

print("\n\n")

beams_ng_r = copy.deepcopy(best_beams)
for _ in range(5): beams_ng_r = beams_ng_r.generate(k=1, no_repeat_ngram_size=NO_REPEAT_NGRAM_SIZE)
beams_ng_r.print(title=f"best completions - with restriction of n-gram size = {NO_REPEAT_NGRAM_SIZE}")

assert all(
  "I was" not in completion.removeprefix(prompt) for _, completion in beams_ng_r.logprobs_and_completions
), "with no repeated bigrams, no beams should contain a second '...I was...'"

test_submit(15)

Once you've passed all of these unit tests, you can run the full `beam_search` function. It creates a `Beams` object from the initial prompt, and then repeatedly calls `generate` and `filter` until one of the stopping criteria is met.

In [None]:
sampler = TransformerSampler(model, tokenizer)

prompt = "The ships hung in the sky in much the same way that"
orig_len = len(tokenizer.encode(prompt))

final_logprobs_and_completions = sampler.beam_search(
    prompt=prompt,
    num_return_sequences=3,
    num_beams=10,
    max_new_tokens=15,
    no_repeat_ngram_size=2,
)

# print all the best output, with the average per-token probability (geometric mean over generated tokens)
for logprob_sum, text in final_logprobs_and_completions:
    avg_logprob_as_prob = t.tensor(logprob_sum / (len(tokenizer.encode(text)) - orig_len)).exp()
    rprint(f"Avg token prob = {avg_logprob_as_prob:.3f}\nBest output:\n[bold dark_orange]{text}\n\n")