# Transformer Mechanistic Interpretability

Following along with [this LessWrong guide](https://www.lesswrong.com/posts/hnzHrdqn3nrjveayv/how-to-transformer-mechanistic-interpretability-in-50-lines) by [Stefan Heimersheim](https://www.lesswrong.com/users/stefan42).

In [1]:
from functools import partial

import plotly.express as px
import torch
from transformer_lens import ActivationCache, HookedTransformer, loading, utils

In [2]:
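def imshow(
    tensor,
    xlabel="X",
    ylabel="Y",
    zlabel=None,
    xticks=None,
    yticks=None,
    c_midpoint=0.0,
    c_scale="RdBu",
    show=True,
    **kwargs
):
    """Plot a 2D tensor as a heatmap with a diverging color scale."""
    tensor = utils.to_numpy(tensor)
    # Tick labels are optional; only stringify them when they are provided.
    if xticks is not None:
        xticks = [str(x) for x in xticks]
    if yticks is not None:
        yticks = [str(y) for y in yticks]
    labels = {"x": xlabel, "y": ylabel}
    if zlabel is not None:
        labels["color"] = zlabel
    fig = px.imshow(
        tensor,
        x=xticks,
        y=yticks,
        labels=labels,
        color_continuous_midpoint=c_midpoint,
        color_continuous_scale=c_scale,
        **kwargs
    )
    # Honor the `show` flag: display the figure, or hand it back to the caller.
    if show:
        fig.show()
    else:
        return fig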

## Step 1: Grab a model

### List available models

In [3]:
print("Available models:")
print("\n".join(loading.OFFICIAL_MODEL_NAMES))

Available models:
gpt2
gpt2-medium
gpt2-large
gpt2-xl
distilgpt2
facebook/opt-125m
facebook/opt-1.3b
facebook/opt-2.7b
facebook/opt-6.7b
facebook/opt-13b
facebook/opt-30b
facebook/opt-66b
EleutherAI/gpt-neo-125M
EleutherAI/gpt-neo-1.3B
EleutherAI/gpt-neo-2.7B
EleutherAI/gpt-j-6B
EleutherAI/gpt-neox-20b
stanford-crfm/alias-gpt2-small-x21
stanford-crfm/battlestar-gpt2-small-x49
stanford-crfm/caprica-gpt2-small-x81
stanford-crfm/darkmatter-gpt2-small-x343
stanford-crfm/expanse-gpt2-small-x777
stanford-crfm/arwen-gpt2-medium-x21
stanford-crfm/beren-gpt2-medium-x49
stanford-crfm/celebrimbor-gpt2-medium-x81
stanford-crfm/durin-gpt2-medium-x343
stanford-crfm/eowyn-gpt2-medium-x777
EleutherAI/pythia-14m
EleutherAI/pythia-31m
EleutherAI/pythia-70m
EleutherAI/pythia-160m
EleutherAI/pythia-410m
EleutherAI/pythia-1b
EleutherAI/pythia-1.4b
EleutherAI/pythia-2.8b
EleutherAI/pythia-6.9b
EleutherAI/pythia-12b
EleutherAI/pythia-70m-deduped
EleutherAI/pythia-160m-deduped
EleutherAI/pythia-410m-deduped
E

### Create a model

In [4]:
model = HookedTransformer.from_pretrained("gpt2-small")
model.eval()

Loaded pretrained model gpt2-small into HookedTransformer


### Make a prediction

In [5]:
example_prompt = "Famous computer scientist Alan"
logits, cache = model.run_with_cache(example_prompt)
next_token_logits = logits[0, -1]
next_token_prediction = next_token_logits.argmax()
next_word_prediction = model.tokenizer.decode(next_token_prediction)
print(next_word_prediction)

 Turing


### Show dictionary of internal activations

In [6]:
for key, value in cache.items():
    print(f"{key}: {value.shape}, {value.sum()}")

hook_embed: torch.Size([1, 6, 768]), -2.980232238769531e-07
hook_pos_embed: torch.Size([1, 6, 768]), 4.76837158203125e-07
blocks.0.hook_resid_pre: torch.Size([1, 6, 768]), 5.960464477539062e-07
blocks.0.ln1.hook_scale: torch.Size([1, 6, 1]), 1.3880438804626465
blocks.0.ln1.hook_normalized: torch.Size([1, 6, 768]), 3.5762786865234375e-06
blocks.0.attn.hook_q: torch.Size([1, 6, 12, 64]), -255.4600830078125
blocks.0.attn.hook_k: torch.Size([1, 6, 12, 64]), 132.67422485351562
blocks.0.attn.hook_v: torch.Size([1, 6, 12, 64]), 2.7353835105895996
blocks.0.attn.hook_attn_scores: torch.Size([1, 12, 6, 6]), -inf
blocks.0.attn.hook_pattern: torch.Size([1, 12, 6, 6]), 72.0
blocks.0.attn.hook_z: torch.Size([1, 6, 12, 64]), -11.567606925964355
blocks.0.hook_attn_out: torch.Size([1, 6, 768]), 3.552436828613281e-05
blocks.0.hook_resid_mid: torch.Size([1, 6, 768]), 3.266334533691406e-05
blocks.0.ln2.hook_scale: torch.Size([1, 6, 1]), 6.629210948944092
blocks.0.ln2.hook_normalized: torch.Size([1, 6, 768

## Step 2: Analyze behavior

In this example the model must:
1. Recognize that the last token "Alex" is a repetition of a previous occurrence.
2. Copy the last name from after the first "Alex" occurrence to the last token (as a prediction).
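
The two steps above can be sketched as a plain-Python lookup over the prompt tokens. This only illustrates the behavior we expect from the model, not how the model implements it; the function name `copy_after_previous_occurrence` is hypothetical and the token list omits the leading `<|endoftext|>` token.

```python
def copy_after_previous_occurrence(tokens):
    """Return the token that followed the most recent earlier occurrence
    of the last token, or None if the last token has not appeared before."""
    last = tokens[-1]
    # Scan earlier positions from right to left, skipping the final token itself.
    for i in range(len(tokens) - 2, -1, -1):
        if tokens[i] == last:
            return tokens[i + 1]
    return None


tokens = ["Her", " name", " was", " Alex", " Hart", ".", " When", " Alex"]
print(repr(copy_after_previous_occurrence(tokens)))
```

For this prompt the lookup finds the first " Alex" and returns the token after it, which is the continuation we test for in the next cell.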

In [18]:
baseline_prompt = "Her name was Alex Hart. When Alex"
baseline_continuation = " Hart"
utils.test_prompt(baseline_prompt, baseline_continuation, model)

Tokenized prompt: ['<|endoftext|>', 'Her', ' name', ' was', ' Alex', ' Hart', '.', ' When', ' Alex']
Tokenized answer: [' Hart']


Top 0th token. Logit: 16.23 Prob: 27.03% Token: | was|
Top 1th token. Logit: 15.53 Prob: 13.43% Token: | Hart|
Top 2th token. Logit: 14.35 Prob:  4.12% Token: | died|
Top 3th token. Logit: 14.05 Prob:  3.03% Token: | came|
Top 4th token. Logit: 14.02 Prob:  2.96% Token: |'s|
Top 5th token. Logit: 13.96 Prob:  2.78% Token: | and|
Top 6th token. Logit: 13.75 Prob:  2.26% Token: | arrived|
Top 7th token. Logit: 13.74 Prob:  2.23% Token: | went|
Top 8th token. Logit: 13.72 Prob:  2.19% Token: | started|
Top 9th token. Logit: 13.48 Prob:  1.73% Token: | left|
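
As a quick sanity check on the table above: because probabilities come from a softmax over the logits, the ratio of two probabilities should equal the exponential of their logit difference. The sketch below uses only the rounded values printed for " was" and " Hart", so the two ratios agree only up to rounding.

```python
import math

# Values copied from the test_prompt output above (rounded as printed).
logit_was, prob_was = 16.23, 0.2703
logit_hart, prob_hart = 15.53, 0.1343

# Under a softmax, p_a / p_b == exp(logit_a - logit_b).
ratio_from_logits = math.exp(logit_was - logit_hart)
ratio_from_probs = prob_was / prob_hart

print(f"from logits: {ratio_from_logits:.3f}")
print(f"from probs:  {ratio_from_probs:.3f}")
```

This is why a logit difference is a convenient metric later on: it is the log of the probability ratio between two candidate tokens, independent of every other token in the vocabulary.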


## Method 1: Residual stream patching

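The model relies on three pieces of information in the prompt: the first occurrence of the first name, the last name, and the second occurrence of the first name. Changing one of them at a time gives three variations ("corruptions") of the baseline ("clean") prompt: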
1. [Baseline] Her name was Alex Hart. When Alex
2. Her name was Alex Carroll. When Alex
3. Her name was Sarah Hart. When Alex
4. Her name was Alex Hart. When Sarah
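
One way to keep the corruptions systematic is to build them from a template in which each corruption overrides exactly one slot. This is a minimal sketch; the names `TEMPLATE`, `baseline_slots`, and `corruptions` are illustrative and do not appear in the original guide.

```python
TEMPLATE = "Her name was {first_1} {last}. When {first_2}"
baseline_slots = {"first_1": "Alex", "last": "Hart", "first_2": "Alex"}

# Each corruption overrides exactly one slot of the baseline.
corruptions = [
    {"last": "Carroll"},
    {"first_1": "Sarah"},
    {"first_2": "Sarah"},
]

baseline = TEMPLATE.format(**baseline_slots)
corrupt_prompts = [
    TEMPLATE.format(**{**baseline_slots, **change}) for change in corruptions
]

for prompt in [baseline, *corrupt_prompts]:
    print(prompt)
```

Changing a single slot per prompt keeps the comparison clean: any difference in the patched output can be attributed to that one piece of information.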

In [19]:
corrupt_prompt_1 = "Her name was Alex Carroll. When Alex"
corrupt_prompt_2 = "Her name was Sarah Hart. When Alex"
corrupt_prompt_3 = "Her name was Alex Hart. When Sarah"
corrupt_continuation = " Carroll"

### Get activations from corrupt prompt

In [9]:
_, corrupt_cache_1 = model.run_with_cache(corrupt_prompt_1)
_, corrupt_cache_2 = model.run_with_cache(corrupt_prompt_2)
_, corrupt_cache_3 = model.run_with_cache(corrupt_prompt_3)

### Patch model to use corrupt activations

In [10]:
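default_layer = "blocks.6.hook_resid_post"
default_pos = 5  # token index of the last name (" Hart"), counting the leading <|endoftext|>

def patch_residual_stream(activations, hook, cache=corrupt_cache_1, layer=default_layer, pos=default_pos):
    # Overwrite the residual stream at one position with the cached corrupt-run activation.
    activations[:, pos, :] = cache[layer][:, pos, :]
    return activations

# The hook stays active for subsequent forward passes until model.reset_hooks() is called.
model.add_hook(default_layer, patch_residual_stream)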
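
To make the patching idea concrete without loading a model, here is a toy sketch in plain Python. The "model" is an assumption made purely for illustration: the residual stream holds one number per position, each layer adds the running mean of the positions seen so far, and hooks are plain functions keyed by layer name (unlike TransformerLens hooks, they take no `hook` argument).

```python
def run(tokens, hooks=None, cache=None):
    """Toy two-layer 'model' over a 1-d residual stream (one float per position)."""
    hooks = hooks or {}
    resid = [float(t) for t in tokens]
    for name in ("layer0", "layer1"):
        # Toy layer: each position adds the mean of all positions up to itself.
        resid = [x + sum(resid[: i + 1]) / (i + 1) for i, x in enumerate(resid)]
        if name in hooks:
            resid = hooks[name](resid)  # a hook may replace the activations
        if cache is not None:
            cache[name] = list(resid)  # store a copy for later patching
    return resid[-1]  # "prediction" read off the last position


clean, corrupt = [1, 2, 3], [1, 5, 3]

# 1. Run the corrupt input once and cache its activations.
corrupt_cache = {}
corrupt_out = run(corrupt, cache=corrupt_cache)


# 2. Re-run the clean input, overwriting one (layer, position) from the cache.
def patch(resid, layer="layer0", pos=1):
    patched = list(resid)
    patched[pos] = corrupt_cache[layer][pos]
    return patched


clean_out = run(clean)
patched_out = run(clean, hooks={"layer0": patch})
print(clean_out, patched_out, corrupt_out)
```

Patching a single position moves the output part of the way from the clean result toward the corrupt one, which is the kind of signal the heatmaps later in this notebook measure for every layer and position.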

### See how the corrupt activations affect the output of the model

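With the hook active, the layer-6 residual stream at position 5 comes from the corrupted run, which should shift the prediction from " Hart" toward " Carroll". Note that the output recorded below is identical to the unpatched baseline, which suggests the cell was last re-run after the hooks had been reset.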

In [20]:
utils.test_prompt(baseline_prompt, baseline_continuation, model)

Tokenized prompt: ['<|endoftext|>', 'Her', ' name', ' was', ' Alex', ' Hart', '.', ' When', ' Alex']
Tokenized answer: [' Hart']


Top 0th token. Logit: 16.23 Prob: 27.03% Token: | was|
Top 1th token. Logit: 15.53 Prob: 13.43% Token: | Hart|
Top 2th token. Logit: 14.35 Prob:  4.12% Token: | died|
Top 3th token. Logit: 14.05 Prob:  3.03% Token: | came|
Top 4th token. Logit: 14.02 Prob:  2.96% Token: |'s|
Top 5th token. Logit: 13.96 Prob:  2.78% Token: | and|
Top 6th token. Logit: 13.75 Prob:  2.26% Token: | arrived|
Top 7th token. Logit: 13.74 Prob:  2.23% Token: | went|
Top 8th token. Logit: 13.72 Prob:  2.19% Token: | started|
Top 9th token. Logit: 13.48 Prob:  1.73% Token: | left|


### Reset the hooks

In [12]:
model.reset_hooks()

### Plotting the effects

In [21]:
clean_tokens = model.to_str_tokens(baseline_prompt)
layers = ["blocks.0.hook_resid_pre", *[f"blocks.{i}.hook_resid_post" for i in range(model.cfg.n_layers)]]
n_layers = len(layers)
n_pos = len(clean_tokens)
clean_answer_index = model.tokenizer.encode(baseline_continuation)[0]

In [23]:
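def plot_logit_diff(answer: str, title: str, cache: ActivationCache) -> None:
    """Patch every (layer, position) of the clean run from `cache` and plot the logit difference."""
    corrupt_answer_index = model.tokenizer.encode(answer)[0]

    patching_effect = torch.zeros(n_layers, n_pos)
    for layer_idx, layer in enumerate(layers):
        for pos in range(n_pos):
            # Temporary hook: overwrite one position of one layer with the cached activation.
            fwd_hooks = [
                (
                    layer,
                    partial(
                        patch_residual_stream, cache=cache, layer=layer, pos=pos
                    ),
                )
            ]
            prediction_logits = model.run_with_hooks(baseline_prompt, fwd_hooks=fwd_hooks)[
                0, -1
            ]
            # Positive values favor the clean answer; negative values favor the corrupt one.
            patching_effect[layer_idx, pos] = (
                prediction_logits[clean_answer_index]
                - prediction_logits[corrupt_answer_index]
            )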

    token_labels = [f"(pos {i}) {t}" for i, t in enumerate(clean_tokens)]
    imshow(
        patching_effect,
        xticks=token_labels,
        yticks=layers,
        xlabel="pos",
        ylabel="layer",
        zlabel="Logit difference",
        title=title,
        width=800,
        height=600,
    )
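
The function above relies on `functools.partial` to pre-fill the `cache`, `layer`, and `pos` arguments, leaving a callable that still accepts `(activations, hook)`. The sketch below shows that binding on a toy stand-in; `patch_position` and `toy_cache` are hypothetical names, and activations are plain lists rather than tensors.

```python
from functools import partial


def patch_position(activations, hook, cache, layer, pos):
    # Toy stand-in for a patching hook: copy one position from the cache.
    patched = list(activations)
    patched[pos] = cache[layer][pos]
    return patched


toy_cache = {"layer_a": [10, 20, 30]}

# One hook per position; partial fixes the keyword arguments at creation time.
hooks = [
    partial(patch_position, cache=toy_cache, layer="layer_a", pos=p)
    for p in range(3)
]

# Each hook is still called with just (activations, hook).
results = [h([0, 0, 0], None) for h in hooks]
print(results)
```

Because `partial` captures the value of `pos` when the hook is created, each hook in the loop patches its own position.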

In [24]:
plot_logit_diff(corrupt_continuation, "Patching with Alex/Carroll/Alex", corrupt_cache_1)

In [25]:
plot_logit_diff(corrupt_continuation, "Patching with Sarah/Hart/Alex", corrupt_cache_2)

In [27]:
plot_logit_diff(corrupt_continuation, "Patching with Alex/Hart/Sarah", corrupt_cache_3)

## Method 2: Attention head patching