# Prelims

<b style="color: red">To use this notebook, go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.</b>

This is because later cells move the prompt and answer tensors to the GPU (`.cuda()`), so whole batches of prompts run through the model in parallel.

# Setup
(No need to read)

In [None]:
%%capture
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as it's edited without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"

In [None]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [None]:
import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off to save GPU memory, as this notebook focuses on model inference, not model training.

In [None]:
torch.set_grad_enabled(False)


Plotting helper functions:

In [None]:
def imshow(tensor, renderer=None, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

# Analyze GPT-2-Small
About 85M non-embedding parameters (124M including the embedding matrices)

## Loading and Running Models

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
model = HookedTransformer.from_pretrained("gpt2-small", device=device)


Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


## Test prompts

In [None]:
example_prompt = "Alice is a teacher. Bob is a student. The child is Bob. Carol is a teacher. David is a student. The child is"
example_answer = " David"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'Alice', ' is', ' a', ' teacher', '.', ' Bob', ' is', ' a', ' student', '.', ' The', ' child', ' is', ' Bob', '.', ' Carol', ' is', ' a', ' teacher', '.', ' David', ' is', ' a', ' student', '.', ' The', ' child', ' is']
Tokenized answer: [' David']


Top 0th token. Logit: 19.13 Prob: 90.02% Token: | David|
Top 1th token. Logit: 14.97 Prob:  1.41% Token: | Dave|
Top 2th token. Logit: 13.91 Prob:  0.49% Token: | Bob|
Top 3th token. Logit: 13.42 Prob:  0.30% Token: | Carol|
Top 4th token. Logit: 13.19 Prob:  0.24% Token: | D|
Top 5th token. Logit: 13.17 Prob:  0.23% Token: | Dr|
Top 6th token. Logit: 13.04 Prob:  0.20% Token: | Ed|
Top 7th token. Logit: 13.00 Prob:  0.20% Token: | Don|
Top 8th token. Logit: 12.82 Prob:  0.16% Token: | the|
Top 9th token. Logit: 12.54 Prob:  0.12% Token: | Ted|


# GPT2-small analogous inputs analysis

## Try N=1 first

In [None]:
prompts = [
    "Alice is a teacher. Bob is a student. The child is Bob. Carol is a teacher. David is a student. The child is",
]
# List of answers, in the format (correct, incorrect)
answers = [
    (" David", " Carol"),
]

answer_tokens = []
for answer in answers:
    correct_token = model.to_single_token(answer[0])
    incorrect_token = model.to_single_token(answer[1])
    answer_tokens.append((correct_token, incorrect_token))
# Keep the answer tokens on the same device as the model so they can index its logits
answer_tokens = torch.tensor(answer_tokens, device=device)

tokens = model.to_tokens(prompts, prepend_bos=True).to(device)  # Move the tokens to the model's device
original_logits, cache = model.run_with_cache(tokens) # Run the model and cache all activations

### Average logit difference attribution: layers and heads

In [None]:
def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
    # Only the final logits are relevant for the answer
    final_logits = logits[:, -1, :]
    answer_logits = final_logits.gather(dim=-1, index=answer_tokens)
    answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]
    if per_prompt:
        return answer_logit_diff
    else:
        return answer_logit_diff.mean()

print("Per prompt logit difference:", logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True))
original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)
print("Average logit difference:", logits_to_ave_logit_diff(original_logits, answer_tokens).item())

Per prompt logit difference: tensor([5.7112])
Average logit difference: 5.711185455322266
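As a sanity check on the metric, here is a minimal standalone sketch of the same computation on made-up logits: a batch of 2 prompts, 3 positions and a 5-token vocabulary, with all numbers invented for illustration. `gather` picks out the logits of the (correct, incorrect) token ids at the final position, and the metric is their difference. The helper name `logit_diff` is hypothetical. The notebook's own function is `logits_to_ave_logit_diff`.

```python
import torch


def logit_diff(logits: torch.Tensor, answer_tokens: torch.Tensor) -> torch.Tensor:
    """Per-prompt logit difference at the final position.

    logits: [batch, pos, d_vocab]; answer_tokens: [batch, 2] of (correct, incorrect) token ids.
    """
    final_logits = logits[:, -1, :]                                   # [batch, d_vocab]
    answer_logits = final_logits.gather(dim=-1, index=answer_tokens)  # [batch, 2]
    return answer_logits[:, 0] - answer_logits[:, 1]                  # [batch]


# Made-up logits; only the final position matters for the metric
toy_logits = torch.zeros(2, 3, 5)
toy_logits[0, -1] = torch.tensor([0.0, 4.0, 1.0, 0.0, 0.0])
toy_logits[1, -1] = torch.tensor([2.0, 0.0, 0.0, 3.0, 0.0])
toy_answer_tokens = torch.tensor([[1, 2], [0, 3]])  # (correct, incorrect) per prompt

per_prompt = logit_diff(toy_logits, toy_answer_tokens)
print(per_prompt)         # 3.0 for the first prompt, -1.0 for the second
print(per_prompt.mean())  # 1.0
```

A negative entry means the model prefers the incorrect token. This is what happens in the corrupted runs further down.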


In [None]:
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)
logit_diff_directions = answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
print("Logit difference directions shape:", logit_diff_directions.shape)

def residual_stack_to_logit_diff(residual_stack: Float[torch.Tensor, "components batch d_model"], cache: ActivationCache) -> Float[torch.Tensor, "components"]:
    # Apply the final LayerNorm (using the cached scale) so each component lives in the space the unembedding reads from
    scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer=-1, pos_slice=-1)
    # Project onto the logit-difference direction and average over prompts (uses the globals logit_diff_directions and prompts)
    return einsum("... batch d_model, batch d_model -> ...", scaled_residual_stack, logit_diff_directions)/len(prompts)

accumulated_residual, labels = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1, return_labels=True)
logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, cache)
line(logit_lens_logit_diffs, x=np.arange(model.cfg.n_layers*2+1)/2, hover_name=labels, title="Logit Difference From Accumulated Residual Stream")

Answer residual directions shape: torch.Size([1, 2, 768])
Logit difference directions shape: torch.Size([1, 768])


In [None]:
per_head_residual, labels = cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)
per_head_logit_diffs = einops.rearrange(per_head_logit_diffs, "(layer head_index) -> layer head_index", layer=model.cfg.n_layers, head_index=model.cfg.n_heads)
imshow(per_head_logit_diffs, labels={"x":"Head", "y":"Layer"}, title="Logit Difference From Each Head")

Tried to stack head results when they weren't cached. Computing head results now


Get the (layer, head) pairs with the largest logit differences (row = layer, column = head).

In [None]:
# Number of top (layer, head) pairs to retrieve
num_top_values = 5

# Flatten the [layer, head] tensor and take the top-k entries with a single call
flattened_tensor = per_head_logit_diffs.flatten()
top_values, top_indices = torch.topk(flattened_tensor, num_top_values)

# Convert each flat index back to (row, column) = (layer, head)
n_heads = per_head_logit_diffs.size(1)
rows = torch.div(top_indices, n_heads, rounding_mode="floor")
cols = top_indices % n_heads

for i in range(num_top_values):
    print(f"Top value {i+1}: Row={rows[i].item()}, Column={cols[i].item()}, Value={top_values[i].item()}")

top_indices_tuple = [(row.item(), col.item()) for row, col in zip(rows, cols)]
print(top_indices_tuple)

Top value 1: Row=9, Column=9, Value=2.5532712936401367
Top value 2: Row=8, Column=11, Value=2.1216535568237305
Top value 3: Row=10, Column=6, Value=1.6274135112762451
Top value 4: Row=11, Column=1, Value=0.37464624643325806
Top value 5: Row=8, Column=6, Value=0.36867403984069824
[(9, 9), (8, 11), (10, 6), (11, 1), (8, 6)]
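The flatten, top-k and unflatten pattern above can also be written as a compact standalone helper. The sketch below runs it on a tiny made-up score matrix, where rows are layers and columns are heads. The name `top_layer_heads` is hypothetical.

```python
import torch


def top_layer_heads(scores: torch.Tensor, k: int) -> list:
    """Return the k largest entries of a [layer, head] tensor as (layer, head, value) tuples."""
    n_heads = scores.shape[1]
    values, flat_idx = torch.topk(scores.flatten(), k)
    return [
        (i // n_heads, i % n_heads, v)  # divmod of the flat index recovers (layer, head)
        for i, v in zip(flat_idx.tolist(), values.tolist())
    ]


# Made-up scores for 2 layers x 3 heads
scores = torch.tensor([[0.1, 0.9, 0.0],
                       [0.5, -1.0, 2.0]])
for layer_idx, head_idx, value in top_layer_heads(scores, k=2):
    print(f"L{layer_idx}H{head_idx}: {value:.2f}")  # L1H2: 2.00, then L0H1: 0.90
```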


In [None]:
def get_top_indices(tensor, num_top_values):
    # Flatten the 2D [layer, head] tensor
    flattened_tensor = tensor.flatten()

    # Get the indices of the top elements
    top_indices = torch.topk(flattened_tensor, num_top_values).indices

    # Convert each flat index back to (row, column) = (layer, head)
    rows = torch.div(top_indices, tensor.size(1), rounding_mode="floor")
    cols = top_indices % tensor.size(1)

    # Return a list of (row, column) tuples
    return [(row.item(), col.item()) for row, col in zip(rows, cols)]

top_indices = get_top_indices(per_head_logit_diffs, 10)
for row, col in top_indices:
    print(f"Row={row}, Column={col}")
top_indices

Row=9, Column=9
Row=8, Column=11
Row=10, Column=6
Row=11, Column=1
Row=8, Column=6
Row=9, Column=2
Row=2, Column=10
Row=11, Column=8
Row=8, Column=8
Row=5, Column=11


[(9, 9),
 (8, 11),
 (10, 6),
 (11, 1),
 (8, 6),
 (9, 2),
 (2, 10),
 (11, 8),
 (8, 8),
 (5, 11)]

### Activation patching by swapping subjects

In [None]:
corrupted_prompts = ["Alice is a teacher. Bob is a student. The child is Bob. David is a teacher. Carol is a student. The child is"]

In [None]:
# corrupted_prompts = []
# for i in range(0, len(prompts), 2):
#     corrupted_prompts.append(prompts[i+1])
#     corrupted_prompts.append(prompts[i])
corrupted_tokens = model.to_tokens(corrupted_prompts, prepend_bos=True)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens, return_type="logits")
corrupted_average_logit_diff = logits_to_ave_logit_diff(corrupted_logits, answer_tokens)
print("Corrupted Average Logit Diff", corrupted_average_logit_diff)
print("Clean Average Logit Diff", original_average_logit_diff)

Corrupted Average Logit Diff tensor(-4.2183)
Clean Average Logit Diff tensor(5.7112)


In [None]:
def patch_residual_component(
    corrupted_residual_component: Float[torch.Tensor, "batch pos d_model"],
    hook,
    pos,
    clean_cache):
    corrupted_residual_component[:, pos, :] = clean_cache[hook.name][:, pos, :]
    return corrupted_residual_component

def normalize_patched_logit_diff(patched_logit_diff):
    # Subtract the corrupted logit diff to measure the improvement, then divide by the total change from corrupted to clean to normalise
    # 0 means zero change, negative means actively made worse, 1 means totally recovered clean performance, >1 means actively *improved* on clean performance
    return (patched_logit_diff - corrupted_average_logit_diff)/(original_average_logit_diff - corrupted_average_logit_diff)

patched_residual_stream_diff = torch.zeros(model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32)
for layer in range(model.cfg.n_layers):
    for position in range(tokens.shape[1]):
        hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks = [(utils.get_act_name("resid_pre", layer),
                hook_fn)],
            return_type="logits"
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_residual_stream_diff[layer, position] = normalize_patched_logit_diff(patched_logit_diff)
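Under the hood, the patching loop runs the corrupted prompt while a hook overwrites one position of one activation with the value cached from the clean run. The sketch below shows the same mechanics in plain PyTorch, using `register_forward_hook` instead of TransformerLens hook points. It runs on a hypothetical toy model; the names `ToyModel` and `run_with_patch` are illustrative.

The toy model's layers act on each position independently, and its readout only looks at the final position. So patching the final position restores the clean output exactly, while patching an uncorrupted position changes nothing.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class ToyModel(nn.Module):
    """Hypothetical toy model: per-position layers, then a readout from the final position only."""

    def __init__(self, d: int = 4):
        super().__init__()
        self.layer1 = nn.Linear(d, d)
        self.layer2 = nn.Linear(d, d)
        self.readout = nn.Linear(d, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: [batch, pos, d]
        h = self.layer2(torch.relu(self.layer1(x)))
        return self.readout(h[:, -1, :]).squeeze(-1)  # [batch]


def run_with_patch(net, x, module, clean_act, pos):
    """Run `net` on `x`, overwriting `module`'s output at position `pos` with the clean activation."""

    def hook(mod, inputs, output):
        patched = output.clone()
        patched[:, pos, :] = clean_act[:, pos, :]
        return patched  # returning a tensor from a forward hook replaces the module's output

    handle = module.register_forward_hook(hook)
    try:
        return net(x)
    finally:
        handle.remove()


toy_model = ToyModel()
clean_x = torch.randn(1, 3, 4)
corrupted_x = clean_x.clone()
corrupted_x[:, -1, :] = torch.randn(4)  # corrupt only the final position

with torch.no_grad():
    # Cache layer1's activation on the clean input
    toy_cache = {}
    handle = toy_model.layer1.register_forward_hook(lambda mod, inp, out: toy_cache.__setitem__("layer1", out.clone()))
    clean_out = toy_model(clean_x)
    handle.remove()

    corrupted_out = toy_model(corrupted_x)
    patched_last = run_with_patch(toy_model, corrupted_x, toy_model.layer1, toy_cache["layer1"], pos=-1)
    patched_first = run_with_patch(toy_model, corrupted_x, toy_model.layer1, toy_cache["layer1"], pos=0)

print(clean_out, corrupted_out, patched_last, patched_first)
```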

We can immediately see that the relevant computation originally happens on the LAST subject token, and that at around layer 9 the information is moved to the final token. Patching in the clean residual stream at the correct position almost *exactly* recovers performance!

For reference, the x-axis labels are the tokens of the prompt, each suffixed with its position index. With N=1 there is only a single prompt here. When the difference is averaged over several prompts, the labels still come only from the *first* prompt, which is an abuse of notation.

To make the results easier to interpret, we normalise the logit difference. We subtract the corrupted logit difference, then divide by the total change in logit difference between the corrupted and clean runs. On this scale, 0 means zero change, negative means patching actively made things worse, 1 means patching totally recovered clean performance, and >1 means patching actively *improved* on clean performance.
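The normalisation is a simple affine rescaling. Below is a minimal sketch of it as a standalone function that takes the clean and corrupted baselines as arguments instead of reading globals. The name `normalize_logit_diff` and its signature are illustrative. It is evaluated with the N=1 averages printed above.

```python
def normalize_logit_diff(patched: float, clean: float, corrupted: float) -> float:
    """0 = no better than corrupted, 1 = fully recovers clean, <0 = worse, >1 = better than clean."""
    return (patched - corrupted) / (clean - corrupted)


clean, corrupted = 5.7112, -4.2183  # average logit differences printed above for the N=1 prompt

print(normalize_logit_diff(clean, clean, corrupted))      # 1.0
print(normalize_logit_diff(corrupted, clean, corrupted))  # 0.0
print(normalize_logit_diff(0.0, clean, corrupted))        # about 0.42: a logit diff of 0 is ~42% of the way to clean
```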

In [None]:
prompt_position_labels = [f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(tokens[0]))]
imshow(patched_residual_stream_diff, x=prompt_position_labels, title="Logit Difference From Patched Residual Stream", labels={"x":"Position", "y":"Layer"})

### Activation patching layers heatmap

We can apply exactly the same idea, but this time patching in attention or MLP layers. These are also residual components with identical shapes to the residual stream terms, so we can reuse the same hooks.

In [None]:
patched_attn_diff = torch.zeros(model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32)
patched_mlp_diff = torch.zeros(model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32)
for layer in range(model.cfg.n_layers):
    for position in range(tokens.shape[1]):
        hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)
        patched_attn_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks = [(utils.get_act_name("attn_out", layer),
                hook_fn)],
            return_type="logits"
        )
        patched_attn_logit_diff = logits_to_ave_logit_diff(patched_attn_logits, answer_tokens)
        patched_mlp_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks = [(utils.get_act_name("mlp_out", layer),
                hook_fn)],
            return_type="logits"
        )
        patched_mlp_logit_diff = logits_to_ave_logit_diff(patched_mlp_logits, answer_tokens)

        patched_attn_diff[layer, position] = normalize_patched_logit_diff(patched_attn_logit_diff)
        patched_mlp_diff[layer, position] = normalize_patched_logit_diff(patched_mlp_logit_diff)

We see that several attention layers are significant. Matching the residual stream results, early layers matter on the second subject token, later layers matter on the final token, and layers essentially don't matter on any other token. Extremely localised! As with direct logit attribution, layer 9 is positive and layers 10 and 11 are not, suggesting that the late layers only matter for direct logit effects. We also see that layers 7 and 8 matter significantly. Presumably they contain the heads that move information about which name is duplicated from the second subject token to the final token.

In [None]:
imshow(patched_attn_diff, x=prompt_position_labels, title="Logit Difference From Patched Attention Layer", labels={"x":"Position", "y":"Layer"})

In contrast, the MLP layers do not matter much. This makes sense, since this is more a task about moving information than about processing it, and the MLP layers specialise in processing information.

The one exception is MLP 0, which matters a lot, but I think this is misleading and just a generally true statement about MLP 0 rather than being about the circuit on this task.

<details> <summary>My takes on MLP0</summary>
It's often observed on GPT-2 Small that MLP0 matters a lot, and that ablating it utterly destroys performance. My current best guess is that the first MLP layer is essentially acting as an extension of the embedding (for whatever reason) and that when later layers want to access the input tokens they mostly read in the output of the first MLP layer, rather than the token embeddings. Within this frame, the first attention layer doesn't do much.

In this framing, it makes sense that MLP0 matters on the second subject token, because that's the one position with a different input token!

I'm not entirely sure why this happens, but I would guess that it's because the embedding and unembedding matrices in GPT-2 Small are tied (the unembedding is the transpose of the embedding). This is pretty unprincipled, as the tasks of embedding and unembedding tokens are <i>not</i> inverses, but this is common practice, and plausibly models want to dedicate some parameters to overcoming this.

I only have suggestive evidence of this, and would love to see someone look into this properly!
</details>

In [None]:
imshow(patched_mlp_diff, x=prompt_position_labels, title="Logit Difference From Patched MLP Layer", labels={"x":"Position", "y":"Layer"})

### Activation patching head heatmap

We can refine the above analysis by patching in individual heads! This is somewhat more annoying, because there are now three dimensions (head_index, position and layer), so for now let's patch in a head's output across all positions.

The easiest way to do this is to patch in the activation `z`, the "mixed value" of the attention head. That is the average of all previous values weighted by the attention pattern, i.e. the activation that is then multiplied by `W_O`, the output weights.
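Patching `z` isolates a single head because the attention output is a sum over heads: each head's `z` is multiplied by its own slice of `W_O`. Below is a sketch with random stand-in tensors. The shapes follow TransformerLens's `[batch, pos, head_index, d_head]` convention for `z` and `[n_heads, d_head, d_model]` for `W_O`. The output bias is omitted because it is identical in both runs.

```python
import torch

torch.manual_seed(0)
batch, n_pos, n_heads, d_head, d_model = 2, 5, 4, 3, 6

# Random stand-ins for the clean/corrupted z activations and the output weights W_O
clean_z = torch.randn(batch, n_pos, n_heads, d_head)
corrupted_z = torch.randn(batch, n_pos, n_heads, d_head)
W_O = torch.randn(n_heads, d_head, d_model)


def patch_one_head(corrupted: torch.Tensor, clean: torch.Tensor, head: int) -> torch.Tensor:
    """Copy one head's z from the clean run into the corrupted run, across all positions."""
    patched = corrupted.clone()
    patched[:, :, head, :] = clean[:, :, head, :]
    return patched


def attn_out(z: torch.Tensor) -> torch.Tensor:
    # Each head's z is multiplied by its own slice of W_O, and the results are summed over heads
    return torch.einsum("bphd,hdm->bpm", z, W_O)


head = 2
patched_z = patch_one_head(corrupted_z, clean_z, head)
delta = attn_out(patched_z) - attn_out(corrupted_z)
expected = torch.einsum("bpd,dm->bpm", clean_z[:, :, head] - corrupted_z[:, :, head], W_O[head])
print(torch.allclose(delta, expected, atol=1e-5))  # the change is exactly head 2's own contribution
```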

In [None]:
def patch_head_vector(
    corrupted_head_vector: Float[torch.Tensor, "batch pos head_index d_head"],
    hook,
    head_index,
    clean_cache):
    corrupted_head_vector[:, :, head_index, :] = clean_cache[hook.name][:, :, head_index, :]
    return corrupted_head_vector


patched_head_z_diff = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32)
for layer in range(model.cfg.n_layers):
    for head_index in range(model.cfg.n_heads):
        hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks = [(utils.get_act_name("z", layer, "attn"),
                hook_fn)],
            return_type="logits"
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_head_z_diff[layer, head_index] = normalize_patched_logit_diff(patched_logit_diff)

We can now see that, in addition to the name mover heads identified before, in mid-late layers the heads L8H6, L8H10, L7H9 matter and are presumably responsible for moving information from the second subject to the final token. And heads L5H5, L6H9, L3H0 also matter a lot, and are presumably involved in detecting duplicated tokens.

In [None]:
imshow(patched_head_z_diff, title="Logit Difference From Patched Head Output", labels={"x":"Head", "y":"Layer"})

## Try N=50 examples

Use a GPU. Tokenization itself runs on the CPU, but running the model on a batch of 50 prompts, and the many patched forward passes below, is much faster on a GPU.

In [None]:
#@title Names list
names = [
    "Michael",
    "Christopher",
    "Jessica",
    "Matthew",
    "Ashley",
    "Jennifer",
    "Joshua",
    "Amanda",
    "Daniel",
    "David",
    "James",
    "Robert",
    "John",
    "Joseph",
    "Andrew",
    "Ryan",
    "Brandon",
    "Jason",
    "Justin",
    "Sarah",
    "William",
    "Jonathan",
    "Stephanie",
    "Brian",
    "Nicole",
    "Nicholas",
    "Anthony",
    "Heather",
    "Eric",
    "Elizabeth",
    "Adam",
    "Megan",
    "Melissa",
    "Kevin",
    "Steven",
    "Thomas",
    "Timothy",
    "Christina",
    "Kyle",
    "Rachel",
    "Laura",
    "Lauren",
    "Amber",
    "Brittany",
    "Danielle",
    "Richard",
    "Kimberly",
    "Jeffrey",
    "Amy",
    "Crystal",
    "Michelle",
    "Tiffany",
    "Jeremy",
    "Benjamin",
    "Mark",
    "Emily",
    "Aaron",
    "Charles",
    "Rebecca",
    "Jacob",
    "Stephen",
    "Patrick",
    "Sean",
    "Erin",
    "Jamie",
    "Kelly",
    "Samantha",
    "Nathan",
    "Sara",
    "Dustin",
    "Paul",
    "Angela",
    "Tyler",
    "Scott",
    "Katherine",
    "Andrea",
    "Gregory",
    "Erica",
    "Mary",
    "Travis",
    "Lisa",
    "Kenneth",
    "Bryan",
    "Lindsey",
    "Kristen",
    "Jose",
    "Alexander",
    "Jesse",
    "Katie",
    "Lindsay",
    "Shannon",
    "Vanessa",
    "Courtney",
    "Christine",
    "Alicia",
    "Cody",
    "Allison",
    "Bradley",
    "Samuel",
]

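def filter_names(names):
    # Keep names that are a single token both at the start of a prompt (no leading space)
    # and mid-sentence (with a leading space), so every prompt tokenizes to the same length
    return [
        name for name in names
        if len(model.tokenizer.tokenize(name)) == 1 and len(model.tokenizer.tokenize(" " + name)) == 1
    ]
names = filter_names(names)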

In [None]:
import random

def make_latestS_prompts(names, template, num_sentences):
    sentences = []
    generated_set = set() # Ensure none of the generated sentences are the same
    while len(sentences) < num_sentences:
        unique_names = random.sample(names, k=4)
        temp_template = template
        sentence_dict = {}
        for i, name in enumerate(unique_names, start=1):
            temp_template = temp_template.replace(f"[S{i}]", name)
            sentence_dict[f'S{i}'] = name
        sentence_dict['text'] = temp_template
        if sentence_dict['text'] not in generated_set:
            generated_set.add(sentence_dict['text'])
            sentences.append(sentence_dict)
    return sentences

template = "[S1] is a teacher. [S2] is a student. The child is [S2]. [S3] is a teacher. [S4] is a student. The child is"
N = 50
latestS_prompts = make_latestS_prompts(names, template, N)
# latestS_prompts

### Avg Logit Diff

Each generated prompt is a dict holding the filled-in `text` and the names in slots `S1` to `S4`. Extract `text` as the prompt and (`S4`, `S3`) as the (correct, incorrect) answer pair, which become the prompts and answer tokens.

In [None]:
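prompts = []
answers = []
for prompt in latestS_prompts:
    prompts.append(prompt['text'])
    # (correct, incorrect) = (S4, S3); prepend a space because the model predicts the mid-sentence " Name" token
    answers.append((" " + prompt['S4'], " " + prompt['S3']))
prompts[1], answers[1]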

('Steven is a teacher. Kelly is a student. The child is Kelly. Rachel is a teacher. Charles is a student. The child is',
 ('Charles', 'Rachel'))

In [None]:
# prompts = [
#     "Alice is a teacher. Bob is a student. The child is Bob. Carol is a teacher. David is a student. The child is",
# ]
# # List of answers, in the format (correct, incorrect)
# answers = [
#     (" David", " Carol"),
# ]

answer_tokens = []
for answer in answers:
    correct_token = model.to_single_token(answer[0])
    incorrect_token = model.to_single_token(answer[1])
    answer_tokens.append((correct_token, incorrect_token))
answer_tokens = torch.tensor(answer_tokens, device=device)  # keep the answer tokens on the same device as the logits

tokens = model.to_tokens(prompts, prepend_bos=True).to(device)  # move the tokens to the model's device (the GPU, if available)
original_logits, cache = model.run_with_cache(tokens) # Run the model and cache all activations

In [None]:
def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
    # Only the final logits are relevant for the answer
    final_logits = logits[:, -1, :]
    answer_logits = final_logits.gather(dim=-1, index=answer_tokens)
    answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]
    if per_prompt:
        return answer_logit_diff
    else:
        return answer_logit_diff.mean()

print("Per prompt logit difference:", logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True))
original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)
print("Average logit difference:", logits_to_ave_logit_diff(original_logits, answer_tokens).item())

Per prompt logit difference: tensor([3.8833, 4.5183, 2.2390, 5.7353, 2.5049, 2.6313, 3.8555, 3.1197, 4.9303,
        4.1759, 3.1439, 2.4993, 2.8341, 2.5784, 2.7652, 3.2179, 0.2949, 1.0065,
        4.2282, 1.5995, 4.6615, 3.0637, 4.4992, 4.8377, 0.9897, 2.1557, 4.2508,
        3.5323, 3.0236, 3.1798, 3.5135, 1.5404, 3.1957, 1.6865, 1.4431, 2.5355,
        2.5747, 5.8563, 1.3548, 2.6309, 2.2194, 0.8228, 1.3434, 6.3136, 3.0683,
        4.2229, 3.1763, 2.9900, 3.5717, 2.6901], device='cuda:0')
Average logit difference: 3.054112195968628


### Activation Patching

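The corruption changes only the prompts, not the answers. Swapping S3 and S4 in the prompt makes the original wrong answer (S3) the student. Keeping the original (correct, incorrect) answer pair therefore means the corrupted run scores as wrong. If the answers were swapped too, the corrupted run would also count as correct, and patching would have nothing to recover.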

In [None]:
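corrupted_prompts = []
for prompt in latestS_prompts:
    # Rebuild from the template with S3 and S4 swapped; str.replace on the filled-in text
    # breaks when one name contains another (e.g. "Eric" inside "Erica")
    corrupted_text = (
        template.replace("[S1]", prompt['S1'])
        .replace("[S2]", prompt['S2'])
        .replace("[S3]", prompt['S4'])
        .replace("[S4]", prompt['S3'])
    )
    corrupted_prompts.append(corrupted_text)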
prompts[1], corrupted_prompts[1]

('Steven is a teacher. Kelly is a student. The child is Kelly. Rachel is a teacher. Charles is a student. The child is',
 'Steven is a teacher. Kelly is a student. The child is Kelly. Charles is a teacher. Rachel is a student. The child is')
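Building the corrupted prompt by string replacement on the filled-in text can go wrong when one name contains another. The names list includes pairs such as Eric/Erica, Daniel/Danielle, Sara/Sarah and Ryan/Bryan, if both members survive the single-token filter.

A safer sketch rebuilds both prompts from the template and swaps S3 and S4 in the slot assignment instead. `fill_template`, `clean_and_corrupted` and `naive_swap` are hypothetical helpers. The example names are chosen to trigger the problem.

```python
def fill_template(template: str, slots: dict) -> str:
    """Replace each [S1], [S2], ... placeholder with the corresponding name."""
    text = template
    for slot, name in slots.items():
        text = text.replace(f"[{slot}]", name)
    return text


def clean_and_corrupted(template: str, slots: dict) -> tuple:
    """Build the clean prompt and a corrupted prompt with S3 and S4 swapped."""
    swapped = dict(slots, S3=slots["S4"], S4=slots["S3"])
    return fill_template(template, slots), fill_template(template, swapped)


def naive_swap(text: str, s3: str, s4: str) -> str:
    """Swap two names by replacement on the filled-in text (breaks when a name contains another)."""
    text = text.replace(s3, "TEMP_NAME")
    text = text.replace(s4, s3)
    return text.replace("TEMP_NAME", s4)


toy_template = "[S1] is a teacher. [S2] is a student. The child is [S2]. [S3] is a teacher. [S4] is a student. The child is"
toy_slots = {"S1": "Erica", "S2": "Kelly", "S3": "Eric", "S4": "Charles"}

clean_text, corrupted_text = clean_and_corrupted(toy_template, toy_slots)
print(corrupted_text)
print(naive_swap(clean_text, "Eric", "Charles"))  # "Erica" is mangled into "Charlesa"
```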

In [None]:
corrupted_tokens = model.to_tokens(corrupted_prompts, prepend_bos=True)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens, return_type="logits")
corrupted_average_logit_diff = logits_to_ave_logit_diff(corrupted_logits, answer_tokens)
print("Corrupted Average Logit Diff", corrupted_average_logit_diff)
print("Clean Average Logit Diff", original_average_logit_diff)

Corrupted Average Logit Diff tensor(-3.2767, device='cuda:0')
Clean Average Logit Diff tensor(3.0541, device='cuda:0')


The corrupted logit difference is negative because, after the swap, the model ranks the original wrong answer (S3) above the original right answer (S4).

The clean logit difference is positive because the model ranks the right answer (S4) above the wrong answer (S3).

In [None]:
def patch_residual_component(
    corrupted_residual_component: Float[torch.Tensor, "batch pos d_model"],
    hook,
    pos,
    clean_cache):
    corrupted_residual_component[:, pos, :] = clean_cache[hook.name][:, pos, :]
    return corrupted_residual_component

def normalize_patched_logit_diff(patched_logit_diff):
    # Subtract the corrupted logit diff to measure the improvement, then divide by the total change from corrupted to clean to normalise
    # 0 means zero change, negative means actively made worse, 1 means totally recovered clean performance, >1 means actively *improved* on clean performance
    return (patched_logit_diff - corrupted_average_logit_diff)/(original_average_logit_diff - corrupted_average_logit_diff)

patched_residual_stream_diff = torch.zeros(model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32)
for layer in range(model.cfg.n_layers):
    for position in range(tokens.shape[1]):
        hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks = [(utils.get_act_name("resid_pre", layer),
                hook_fn)],
            return_type="logits"
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_residual_stream_diff[layer, position] = normalize_patched_logit_diff(patched_logit_diff)

In [None]:
prompt_position_labels = [f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(tokens[0]))]
imshow(patched_residual_stream_diff, x=prompt_position_labels, title="Logit Difference From Patched Residual Stream", labels={"x":"Position", "y":"Layer"})

Open question: what is going on at position `is_17` (the ` is` token right after S3) in layer 5?

### Activation patching layers heatmap

We can apply exactly the same idea, but this time patching in attention or MLP layers. These are also residual components with identical shapes to the residual stream terms, so we can reuse the same hooks.

In [None]:
patched_attn_diff = torch.zeros(model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32)
patched_mlp_diff = torch.zeros(model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32)
for layer in range(model.cfg.n_layers):
    for position in range(tokens.shape[1]):
        hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)
        patched_attn_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks = [(utils.get_act_name("attn_out", layer),
                hook_fn)],
            return_type="logits"
        )
        patched_attn_logit_diff = logits_to_ave_logit_diff(patched_attn_logits, answer_tokens)
        patched_mlp_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks = [(utils.get_act_name("mlp_out", layer),
                hook_fn)],
            return_type="logits"
        )
        patched_mlp_logit_diff = logits_to_ave_logit_diff(patched_mlp_logits, answer_tokens)

        patched_attn_diff[layer, position] = normalize_patched_logit_diff(patched_attn_logit_diff)
        patched_mlp_diff[layer, position] = normalize_patched_logit_diff(patched_mlp_logit_diff)

We see that several attention layers are significant. Matching the residual stream results, early layers matter on the second subject token, later layers matter on the final token, and layers essentially don't matter on any other token. Extremely localised! As with direct logit attribution, layer 9 is positive and layers 10 and 11 are not, suggesting that the late layers only matter for direct logit effects. We also see that layers 7 and 8 matter significantly. Presumably they contain the heads that move information about which name is duplicated from the second subject token to the final token.

In [None]:
imshow(patched_attn_diff, x=prompt_position_labels, title="Logit Difference From Patched Attention Layer", labels={"x":"Position", "y":"Layer"})

In contrast, the MLP layers do not matter much. This makes sense, since this is more a task about moving information than about processing it, and the MLP layers specialise in processing information.

The one exception is MLP 0, which matters a lot, but I think this is misleading and just a generally true statement about MLP 0 rather than being about the circuit on this task.


In [None]:
imshow(patched_mlp_diff, x=prompt_position_labels, title="Logit Difference From Patched MLP Layer", labels={"x":"Position", "y":"Layer"})

### Activation patching head heatmap

We can refine the above analysis by patching in individual heads! This is somewhat more annoying, because there are now three dimensions (head_index, position and layer), so for now let's patch in a head's output across all positions.

The easiest way to do this is to patch in the activation `z`, the "mixed value" of the attention head. That is, the average of all previous values weighted by the attention pattern, ie the activation that is then multiplied by `W_O`, the output weights.
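To make the definition of `z` concrete, here is a minimal numpy sketch for a single attention head. The shapes, the random weights and the names `pattern`, `v` and `W_O` are illustrative assumptions (no biases, not GPT-2's real parameters); it just shows `z` as the attention-weighted average of value vectors, which is then multiplied by the output weights.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pos, d_head, d_model = 5, 4, 8

# Hypothetical value vectors for one head: one d_head vector per source position
v = rng.normal(size=(n_pos, d_head))

# Causal attention pattern: row q is a probability distribution over key positions <= q
scores = rng.normal(size=(n_pos, n_pos))
mask = np.tril(np.ones((n_pos, n_pos), dtype=bool))
scores = np.where(mask, scores, -np.inf)
pattern = np.exp(scores - scores.max(axis=-1, keepdims=True))
pattern = pattern / pattern.sum(axis=-1, keepdims=True)

# z ("mixed value"): for each destination position, the pattern-weighted average of values
z = pattern @ v  # [n_pos, d_head]

# The head's contribution to the residual stream is z multiplied by the output weights W_O
W_O = rng.normal(size=(d_head, d_model))
head_out = z @ W_O  # [n_pos, d_model]
```

Since the first position can only attend to itself, its `z` is exactly its own value vector.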

In [None]:
def patch_head_vector(
    corrupted_head_vector: Float[torch.Tensor, "batch pos head_index d_head"],
    hook,
    head_index,
    clean_cache):
    corrupted_head_vector[:, :, head_index, :] = clean_cache[hook.name][:, :, head_index, :]
    return corrupted_head_vector

# Storing the results on the GPU makes little difference, as long as the inputs and tokens are already on CUDA
patched_head_z_diff = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, device="cuda", dtype=torch.float32)
# CPU alternative: patched_head_z_diff = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, dtype=torch.float32)
for layer in range(model.cfg.n_layers):
    for head_index in range(model.cfg.n_heads):
        hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks = [(utils.get_act_name("z", layer, "attn"),
                hook_fn)],
            return_type="logits"
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_head_z_diff[layer, head_index] = normalize_patched_logit_diff(patched_logit_diff)
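The hook above overwrites only one head's slice of `z`. A minimal sketch of that behaviour, using numpy arrays as stand-ins for the torch tensors and a `SimpleNamespace` as a stand-in for the hook object (only its `.name` attribute is read). The cache key `"blocks.0.attn.hook_z"` is an assumed example of a hook name, not taken from this notebook.

```python
import numpy as np
from types import SimpleNamespace

def patch_head_vector(corrupted_head_vector, hook, head_index, clean_cache):
    # Overwrite a single head's z (all batch items and positions) with its clean value
    corrupted_head_vector[:, :, head_index, :] = clean_cache[hook.name][:, :, head_index, :]
    return corrupted_head_vector

batch, pos, n_heads, d_head = 2, 3, 4, 5
hook = SimpleNamespace(name="blocks.0.attn.hook_z")  # stand-in for a HookPoint
clean_cache = {hook.name: np.ones((batch, pos, n_heads, d_head))}
corrupted_z = np.zeros((batch, pos, n_heads, d_head))

patched = patch_head_vector(corrupted_z, hook, head_index=2, clean_cache=clean_cache)
```

Because every position is patched at once, each cell of the resulting heatmap measures a head's total effect across the whole prompt rather than at one position.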

We can now see that, in addition to the name mover heads (identified via direct logit attribution below), the mid-late layer heads L8H6, L8H10 and L7H9 matter, and are presumably responsible for moving information from the second subject token to the final token. Heads L5H5, L6H9 and L3H0 also matter a lot, and are presumably involved in detecting duplicated tokens.

In [None]:
imshow(patched_head_z_diff, title="Logit Difference From Patched Head Output", labels={"x":"Head", "y":"Layer"})

### Direct Logit Attr

The easiest part of the model to understand is the output - this is what the model is trained to optimize, and so it can always be directly interpreted! Often the right approach to reverse engineering a circuit is to start at the end, understand how the model produces the right answer, and then work backwards. The main technique used to do this is called **direct logit attribution**.

**Background:** The central object of a transformer is the **residual stream**. This is the sum of the outputs of each layer and of the original token and positional embedding. Importantly, this means that any linear function of the residual stream can be perfectly decomposed into the contribution of each layer of the transformer. Further, each attention layer's output can be broken down into the sum of the output of each head (See [A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html) for details), and each MLP layer's output can be broken down into the sum of the output of each neuron (and a bias term for each layer).

The logits of a model are `logits=Unembed(LayerNorm(final_residual_stream))`. The Unembed is a linear map, and LayerNorm is approximately a linear map, so we can decompose the logits into the sum of the contributions of each component, and look at which components contribute the most to the logit of the correct token! This is called **direct logit attribution**. Here we look at the direct attribution to the logit difference!
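Why the decomposition works: once the LayerNorm scale is held fixed, the logit difference is a dot product with a fixed direction, so it splits exactly into per-component terms. A minimal numpy sketch with random stand-in components (not real activations); `scale` plays the role of the frozen LayerNorm scaling factor, `W_U`, `mary` and `john` are hypothetical, and centering is ignored for simplicity.

```python
import numpy as np

rng = np.random.default_rng(0)
n_components, d_model, d_vocab = 6, 16, 10

# Stand-ins for the outputs of each component (embeddings, heads, MLPs) at the final position
components = rng.normal(size=(n_components, d_model))
resid = components.sum(axis=0)  # the residual stream is the sum of all component outputs

W_U = rng.normal(size=(d_model, d_vocab))  # hypothetical unembedding matrix
mary, john = 3, 7                          # hypothetical token ids
logit_diff_direction = W_U[:, mary] - W_U[:, john]

# Freeze the LayerNorm scale factor: it is computed from the *full* residual stream
scale = np.sqrt(np.mean(resid ** 2))

# Total logit difference, computed from the full residual stream
logits = (resid / scale) @ W_U
total = logits[mary] - logits[john]

# Per-component direct logit attribution: same scale for every component, then project
per_component = (components / scale) @ logit_diff_direction
```

The per-component attributions add up exactly to the total logit difference.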

<details> <summary>(*) <b>Background and motivation of the logit difference</b></summary>

Logit difference is actually a *really* nice and elegant metric and is a particularly nice aspect of the setup of Indirect Object Identification. In general, there are two natural ways to interpret the model's outputs: the output logits, or the output log probabilities (or probabilities).

The logits are much nicer and easier to understand, as noted above. However, the model is trained to optimize the cross-entropy loss (the average negative log probability of the correct token). This means it does not directly optimize the logits, and indeed if the model adds an arbitrary constant to every logit, the log probabilities are unchanged.

But `log_probs == logits.log_softmax(dim=-1) == logits - logsumexp(logits)`, and so `log_probs(" Mary") - log_probs(" John") = logits(" Mary") - logits(" John")` - the ability to add an arbitrary constant cancels out!

Further, the metric helps us isolate the precise capability we care about - figuring out *which* name is the Indirect Object. There are many other components of the task - deciding whether to return an article (the) or pronoun (her) or name, realising that the sentence wants a person next at all, etc. By taking the logit difference we control for all of that.

Our metric is further refined, because each prompt appears twice, once for each possible indirect object. This controls for irrelevant behaviour such as the model learning that John is a more frequent token than Mary (this actually happens! The final layernorm bias increases the John logit by 1 relative to the Mary logit).

</details>
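A quick numerical check of the log-prob identity from the details above, using arbitrary made-up logits (the indices standing in for " Mary" and " John" are hypothetical): adding a constant to every logit leaves the log probabilities unchanged, and the log-prob difference equals the logit difference.

```python
import numpy as np

def log_softmax(x):
    # Numerically stable log_softmax: x - logsumexp(x)
    m = x.max()
    return x - (m + np.log(np.exp(x - m).sum()))

logits = np.array([2.0, -1.0, 0.5, 3.0])
mary, john = 3, 0  # hypothetical indices of " Mary" and " John"

log_probs = log_softmax(logits)
shifted_log_probs = log_softmax(logits + 10.0)  # add an arbitrary constant to every logit

logit_diff = logits[mary] - logits[john]
log_prob_diff = log_probs[mary] - log_probs[john]
```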

<details> <summary>Ignoring LayerNorm</summary>

LayerNorm is the normalization technique transformers use; it is analogous to BatchNorm but friendlier to massive parallelization. Every time a transformer layer reads information from the residual stream, it applies a LayerNorm to normalize the vector at each position (translating to set the mean to 0 and scaling to set the variance to 1) and then applies a learned vector of weights and biases to scale and translate the normalized vector. This is *almost* a linear map, apart from the scaling step, because that divides by the norm of the vector and the norm is not a linear function. (The `fold_ln` flag when loading a model factors out all the linear parts.)

But if we fixed the scale factor, the LayerNorm would be fully linear. The scale of the residual stream is a global property that's a function of *all* components of the stream, while in practice there are normally only a few directions relevant to any particular component, so freezing the scale is an acceptable approximation. So when doing direct logit attribution we use the `cache` to apply the global layernorm scaling factor to each component (the `apply_ln` flag, or `cache.apply_ln_to_stack` as in the cells below). See [my clean GPT-2 implementation](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/clean-transformer-demo/Clean_Transformer_Demo.ipynb#scrollTo=Clean_Transformer_Implementation) for more on LayerNorm.
</details>
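A small check of the "freeze the scale and it becomes linear" argument. This is a standard LayerNorm written out by hand in numpy with random stand-in data, not TransformerLens's implementation: once the scale is computed from the full residual stream and reused, pushing each component through separately and adding the bias once reproduces the LayerNorm of the sum exactly.

```python
import numpy as np

rng = np.random.default_rng(1)
n_components, d_model, eps = 5, 16, 1e-5
components = rng.normal(size=(n_components, d_model))
resid = components.sum(axis=0)
w = rng.normal(size=d_model)  # learned LayerNorm weight
b = rng.normal(size=d_model)  # learned LayerNorm bias

def layer_norm(x, w, b, eps):
    # Center, divide by the standard deviation, then apply the learned affine map
    x = x - x.mean(axis=-1, keepdims=True)
    scale = np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)
    return x / scale * w + b

# The nonlinear part: the scale depends on the whole residual stream
centered = resid - resid.mean()
scale = np.sqrt((centered ** 2).mean() + eps)

# With the scale frozen, centering, dividing and multiplying by w are all linear,
# so each component can go through separately; the bias is added once at the end
per_component = (components - components.mean(axis=-1, keepdims=True)) / scale * w
reconstructed = per_component.sum(axis=0) + b
```

Normalizing each component with its *own* scale instead would not add up to the same answer, which is why the global scale factor is used.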

Getting an output logit is equivalent to projecting onto a direction in the residual stream. We use `model.tokens_to_residual_directions` to map the answer tokens to their directions, and then convert these into a logit difference direction for each prompt in the batch.
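My understanding (an assumption, not checked against the library source here) is that a token's residual direction is its column of the unembedding matrix. A numpy sketch of the same shape bookkeeping as the cell below, with a random `W_U` and made-up answer token ids, checking that projecting onto the difference direction gives the logit difference:

```python
import numpy as np

rng = np.random.default_rng(2)
batch, d_model, d_vocab = 4, 16, 50
W_U = rng.normal(size=(d_model, d_vocab))  # hypothetical unembedding matrix

# answer_tokens[i] = [correct token id, incorrect token id] for prompt i (made-up ids)
answer_tokens = rng.integers(0, d_vocab, size=(batch, 2))

# Look up each token's unembedding column: [batch, 2, d_model]
answer_residual_directions = W_U.T[answer_tokens]
# Correct minus incorrect direction: [batch, d_model]
logit_diff_directions = answer_residual_directions[:, 0] - answer_residual_directions[:, 1]

# Projecting a (normalized) final residual stream onto these directions gives the logit differences
final_resid = rng.normal(size=(batch, d_model))
logits = final_resid @ W_U  # [batch, d_vocab]
projected = np.einsum("bd,bd->b", final_resid, logit_diff_directions)
rows = np.arange(batch)
direct = logits[rows, answer_tokens[:, 0]] - logits[rows, answer_tokens[:, 1]]
```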

In [None]:
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)
logit_diff_directions = answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
print("Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([50, 2, 768])
Logit difference directions shape: torch.Size([50, 768])


To verify that this works, we can apply this to the final residual stream for our cached prompts (after applying LayerNorm scaling) and verify that we get the same answer.

<details> <summary>Technical details</summary>

`logits = Unembed(LayerNorm(final_residual_stream))`, so we technically need to account for the centering, and then learned translation and scaling of the layernorm, not just the variance 1 scaling.

The centering is accounted for with the preprocessing flag `center_writing_weights` which ensures that every weight matrix writing to the residual stream has mean zero.

The learned scaling is folded into the unembedding weights `model.unembed.W_U` via `W_U_fold = layer_norm.weights[:, None] * unembed.W_U`

The learned translation is folded to `model.unembed.b_U`, a bias added to the logits (note that GPT-2 is not trained with an existing `b_U`). This roughly represents unigram statistics. But we can ignore this because each prompt occurs twice with names in the opposite order, so this perfectly cancels out.

Note that rather than using layernorm scaling we could just study `cache["ln_final.hook_normalized"]`.

</details>
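The folding described in the technical details can be checked numerically. A numpy sketch with random stand-ins (`x_hat` is the residual stream after centering and variance-1 scaling; `w_ln`, `b_ln`, `W_U` are hypothetical): multiplying the LayerNorm weight into the rows of `W_U` and turning the LayerNorm bias into a logit bias `b_U` gives identical logits, and `b_U` drops out of a logit difference averaged over a name-swapped pair.

```python
import numpy as np

rng = np.random.default_rng(3)
d_model, d_vocab = 16, 10
x_hat = rng.normal(size=d_model)  # residual stream after centering and scaling (stand-in)
w_ln = rng.normal(size=d_model)   # learned LayerNorm weight
b_ln = rng.normal(size=d_model)   # learned LayerNorm bias
W_U = rng.normal(size=(d_model, d_vocab))

# Unfolded: apply the learned LayerNorm affine map, then unembed
logits = (x_hat * w_ln + b_ln) @ W_U

# Folded: the weight is absorbed into W_U, the bias becomes a logit bias b_U
W_U_fold = w_ln[:, None] * W_U
b_U = b_ln @ W_U
logits_folded = x_hat @ W_U_fold + b_U

# b_U's contribution to the logit difference cancels across the two orderings of the names
io, s = 2, 5  # hypothetical token ids for the two names
bias_effect_pair = (b_U[io] - b_U[s]) + (b_U[s] - b_U[io])
```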

In [None]:
# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].
final_residual_stream = cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)
final_token_residual_stream = final_residual_stream[:, -1, :]
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = cache.apply_ln_to_stack(final_token_residual_stream, layer = -1, pos_slice=-1)

average_logit_diff = einsum("batch d_model, batch d_model -> ", scaled_final_token_residual_stream, logit_diff_directions)/len(prompts)
print("Calculated average logit diff:", average_logit_diff.item())
print("Original logit difference:",original_average_logit_diff.item())

Final residual stream shape: torch.Size([50, 29, 768])
Calculated average logit diff: 2.994062662124634
Original logit difference: 3.054112195968628


### Logit Lens

We can now decompose the residual stream! First we apply a technique called the [**logit lens**](https://www.alignmentforum.org/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens) - this looks at the residual stream after each layer and calculates the logit difference from that. This simulates what happens if we delete all subsequent layers.
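A toy numpy version of the logit lens, with random stand-ins for the embeddings, sublayer outputs and logit difference direction (LayerNorm scaling omitted to keep it short). The accumulated residual stream is a running sum, so consecutive logit lens readings differ by exactly the contribution of the sublayer in between; the x positions mirror the `np.arange(model.cfg.n_layers*2+1)/2` used for plotting below.

```python
import numpy as np

rng = np.random.default_rng(4)
n_layers, d_model = 3, 16
embed = rng.normal(size=d_model)
# Stand-in outputs of the attention and MLP sublayer of each layer, in order
sublayer_outputs = rng.normal(size=(2 * n_layers, d_model))
logit_diff_direction = rng.normal(size=d_model)  # hypothetical

# Accumulated residual stream: embeddings, then add one sublayer output at a time
# (one entry per point in the stream, e.g. 0_pre, 0_mid, 1_pre, ..., then the final output)
accumulated = np.cumsum(np.vstack([embed, sublayer_outputs]), axis=0)  # [2 * n_layers + 1, d_model]

# Logit lens: read off the logit difference at every point, as if all later layers were deleted
logit_lens = accumulated @ logit_diff_direction
x_positions = np.arange(2 * n_layers + 1) / 2
```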

In [None]:
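def residual_stack_to_logit_diff(residual_stack: Float[torch.Tensor, "components batch d_model"], cache: ActivationCache) -> Float[torch.Tensor, "components"]:
    # Apply the final LayerNorm scaling factor from the cache to every component, at the final position
    scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer=-1, pos_slice=-1)
    # Project onto each prompt's logit difference direction and average over the batch
    return einsum("... batch d_model, batch d_model -> ...", scaled_residual_stack, logit_diff_directions)/len(prompts)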


Fascinatingly, we see that the model is utterly unable to do the task until layer 7, almost all performance comes from attention layer 9, and performance actually *decreases* from there.

**Note:** Hover over each data point to see what residual stream position it's from!

<details> <summary>Details on `accumulated_resid`</summary>
**Key:** `n_pre` means the residual stream at the start of layer n, `n_mid` means the residual stream after the attention part of layer n (`n_post` is the same as `n+1_pre` so is not included)

* `layer` is the layer for which we input the residual stream (this is used to identify *which* layer norm scaling factor we want)
* `incl_mid` is whether to include the residual stream in the middle of a layer, ie after attention & before MLP
* `pos_slice` is the subset of the positions used. See `utils.Slice` for details on the syntax.
* `return_labels` is whether to return the labels for each component returned (useful for plotting)
</details>

In [None]:
accumulated_residual, labels = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1, return_labels=True)
logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, cache)
line(logit_lens_logit_diffs, x=np.arange(model.cfg.n_layers*2+1)/2, hover_name=labels, title="Logit Difference From Accumulated Residual Stream")

### Logit diff head attribution

We can further break down the output of each attention layer into the sum of the outputs of each attention head. Each attention layer consists of 12 heads, which each act independently and additively.

<details> <summary>Decomposing attention output into sums of heads</summary>
The standard way to compute the output of an attention layer is by concatenating the mixed values of each head, and multiplying by a big output weight matrix. But as described in [A Mathematical Framework](https://transformer-circuits.pub/2021/framework/index.html) this is equivalent to splitting the output weight matrix into a per-head output (here `model.blocks[k].attn.W_O`) and adding them up (including an overall bias term for the entire layer).
</details>
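The equivalence in the details above is easy to verify numerically. A numpy sketch with random weights, assuming a per-head `W_O` layout of `[n_heads, d_head, d_model]` (my understanding of how the notebook's `W_O` is stored; treat it as an assumption): concatenating heads and multiplying by one big matrix gives the same result as summing each head's own output plus the layer bias.

```python
import numpy as np

rng = np.random.default_rng(5)
n_pos, n_heads, d_head, d_model = 4, 3, 2, 6
z = rng.normal(size=(n_pos, n_heads, d_head))      # mixed values of each head
W_O = rng.normal(size=(n_heads, d_head, d_model))  # per-head output weights
b_O = rng.normal(size=d_model)                     # one bias for the whole layer

# Standard view: concatenate heads, multiply by one big [n_heads * d_head, d_model] matrix
big_W_O = W_O.reshape(n_heads * d_head, d_model)
attn_out_concat = z.reshape(n_pos, n_heads * d_head) @ big_W_O + b_O

# Per-head view: each head writes z[:, h] @ W_O[h] into the residual stream; add them up
per_head = np.einsum("phd,hdm->hpm", z, W_O)  # [n_heads, n_pos, d_model]
attn_out_sum = per_head.sum(axis=0) + b_O
```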

We see that only a few heads really matter - heads L9H6 and L9H9 contribute a lot positively (explaining why attention layer 9 is so important), while heads L10H7 and L11H10 contribute a lot negatively (explaining why attention layer 10 and layer 11 are actively harmful). These correspond to (some of) the name movers and negative name movers discussed in the paper. There are also several heads that matter positively or negatively but less strongly (other name movers and backup name movers).

There are a few meta observations worth making here - our model has 144 heads, yet we could localise this behaviour to a handful of specific heads, using straightforward, general techniques. This supports the claim in [A Mathematical Framework](https://transformer-circuits.pub/2021/framework/index.html) that attention heads are the right level of abstraction to understand attention. It's also really surprising that there are *negative* heads - eg L10H7 makes the incorrect logit 7x *more* likely. I'm not sure what's going on there, though the paper discusses some possibilities.

In [None]:
per_head_residual, labels = cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)
per_head_logit_diffs = einops.rearrange(per_head_logit_diffs, "(layer head_index) -> layer head_index", layer=model.cfg.n_layers, head_index=model.cfg.n_heads)
imshow(per_head_logit_diffs, labels={"x":"Head", "y":"Layer"}, title="Logit Difference From Each Head")

### Attention Patterns

This section depends on `per_head_logit_diffs` from "Logit diff head attribution", which in turn requires running the "Logit Lens" and "Direct Logit Attr" cells first.

We now study which positions these heads move information from and to. Since we're looking at the direct effect on the logits, we need only look at the attention patterns from the final token.

We visualize the top 3 positive and negative heads by direct logit attribution, and show these for the first prompt.
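Heads are identified by a flat index in `[0, n_layers * n_heads)`, as in `visualize_attention_patterns`. A numpy sketch of picking the top-k positive and negative heads and converting flat indices back to `L{layer}H{head}` labels, using random scores in place of `per_head_logit_diffs` and `np.argsort` in place of `torch.topk`; the helper `head_label` is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(6)
n_layers, n_heads, top_k = 12, 12, 3
# Stand-in for per_head_logit_diffs: one score per (layer, head)
per_head_scores = rng.normal(size=(n_layers, n_heads))

flat = per_head_scores.flatten()          # index = layer * n_heads + head
top_positive = np.argsort(-flat)[:top_k]  # largest scores first
top_negative = np.argsort(flat)[:top_k]   # most negative scores first

def head_label(flat_index, n_heads):
    # Invert the flattening: layer = i // n_heads, head = i % n_heads
    layer, head = divmod(int(flat_index), n_heads)
    return f"L{layer}H{head}"

positive_labels = [head_label(i, n_heads) for i in top_positive]
negative_labels = [head_label(i, n_heads) for i in top_negative]
```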

<details> <summary>Interpreting Attention Patterns</summary>
An easy mistake to make when looking at attention patterns is thinking that they must convey information about the <i>token</i> looked at (maybe accounting for the context of the token). But actually, all we can confidently say is that it moves information from the *residual stream position* corresponding to that input token. Especially later on in the model, there may be components in the residual stream that have nothing to do with the input token! Eg the period at the end of a sentence may contain summary information for that sentence, and the head may solely move that, rather than caring about whether it ends in ".", "!" or "?".
</details>

In [None]:
from IPython.display import HTML, display

def visualize_attention_patterns(
    heads: Union[List[int], int, Float[torch.Tensor, "heads"]],
    local_cache: Optional[ActivationCache]=None,
    local_tokens: Optional[torch.Tensor]=None,
    title: str=""):
    # Heads are given as a list of integers or a single integer in [0, n_layers * n_heads)
    if isinstance(heads, int):
        heads = [heads]
    elif isinstance(heads, list) or isinstance(heads, torch.Tensor):
        heads = utils.to_numpy(heads)
    # Cache defaults to the original activation cache
    if local_cache is None:
        local_cache = cache
    # Tokens defaults to the tokenization of the first prompt (including the BOS token)
    if local_tokens is None:
        # The tokens of the first prompt
        local_tokens = tokens[0]

    labels = []
    patterns = []
    batch_index = 0
    for head in heads:
        layer = head // model.cfg.n_heads
        head_index = head % model.cfg.n_heads
        # Get the attention patterns for the head
        # Attention patterns have shape [batch, head_index, query_pos, key_pos]
        patterns.append(local_cache["attn", layer][batch_index, head_index])
        labels.append(f"L{layer}H{head_index}")
    str_tokens = model.to_str_tokens(local_tokens)
    patterns = torch.stack(patterns, dim=-1)
    # Plot the attention patterns
    attention_vis = pysvelte.AttentionMulti(attention=patterns, tokens=str_tokens, head_labels=labels)
    display(HTML(f"<h3>{title}</h3>"))
    attention_vis.show()

Inspecting the patterns, we can see that both types of name movers attend to the indirect object - this suggests they're simply copying the name attended to (with the OV circuit) and that the interesting part is the circuit behind the attention pattern that calculates *where* to move information from (the QK circuit).

**Instructions for using the attention visualization**:
* Hover over a token to see what tokens it attends to (ie which previous tokens it copies information from) and click on a token to lock focus onto it
* By default, colours show the max attention over all tokens, with each head's colour mixed together.
* Each head has its own colour. Hover over a head to just show that head's colour and click to lock.
* The grid in the top left shows the attention patterns as a heat map (it's a lower triangular grid because GPT-2 has causal attention - tokens can only attend backwards!). Y axis is the destination token, X axis is the source token.
* Click the tick box to flip the direction - now hovering over a token shows which tokens *attend to it* (ie which tokens it copies information to).
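In array terms, the default view reads a *row* of the pattern (which sources a destination attends to) and the flipped view reads a *column* (which destinations attend to a source). A numpy sketch with a toy prompt and a random causal pattern (made-up data, not model output), which also shows why the grid is lower triangular:

```python
import numpy as np

str_tokens = ["<|endoftext|>", "When", " John", " and", " Mary"]  # toy prompt
n = len(str_tokens)
rng = np.random.default_rng(7)
scores = rng.normal(size=(n, n))
# Causal mask: destination (row) q may only attend to sources (columns) k <= q
scores[np.triu_indices(n, k=1)] = -np.inf
pattern = np.exp(scores - scores.max(axis=-1, keepdims=True))
pattern /= pattern.sum(axis=-1, keepdims=True)

dest = n - 1
# Default view: a row, i.e. which source tokens the final token reads from
attends_to = dict(zip(str_tokens, pattern[dest]))
# Flipped view: a column, i.e. which destination tokens read from a given source
source = 2
attended_by = dict(zip(str_tokens, pattern[:, source]))
```

Tokens before " John" can never attend to it, so their entries in the flipped view are zero.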

In [None]:
top_k = 3
top_positive_logit_attr_heads = torch.topk(per_head_logit_diffs.flatten(), k=top_k).indices
visualize_attention_patterns(top_positive_logit_attr_heads, title=f"Top {top_k} Positive Logit Attribution Heads")
top_negative_logit_attr_heads = torch.topk(-per_head_logit_diffs.flatten(), k=top_k).indices
visualize_attention_patterns(top_negative_logit_attr_heads, title=f"Top {top_k} Negative Logit Attribution Heads")



### Further attention visualization

The following differs from the previous visualization because it uses activation patching results to choose the heads. It also takes the top 9 positive heads, split into early, mid and late categories, instead of just the top 3 positive and top 3 negative heads.

https://colab.research.google.com/drive/1nD6tfM33StbAqXG5HnYPlC40hKSj8mzD#scrollTo=1p2zyJ_6IAif