In [1]:
import warnings
warnings.filterwarnings('ignore')

import os
os.chdir('..')

In [None]:
import os

import circuitsvis as cv
import numpy as np
import torch
from IPython.display import HTML, display
from transformer_lens import HookedTransformer, patching, utils

from plotly_utils import imshow

In [3]:
os.chdir('..')
print('Changed working directory to parent directory')

with open(os.path.expanduser('~/.huggingface/token')) as f:
    os.environ['HF_TOKEN'] = f.read().strip()
    # double quotes inside the f-string: reusing single quotes is a syntax error before Python 3.12
    print(f'Hugging Face token loaded: {os.environ["HF_TOKEN"][:3]}...')

torch.set_grad_enabled(False)

if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Changed working directory to parent directory
Hugging Face token loaded: hf_...
Using device: mps


In [4]:
model = HookedTransformer.from_pretrained(
    'gpt2-small',
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
    device=device,
)

Loaded pretrained model gpt2-small into HookedTransformer


In [5]:
prompt_format = [
    'When John and Mary went to the shops, {} gave the bag to',
    'When Tom and James went to the park, {} gave the ball to',
    'When Dan and Sid went to the shops, {} gave an apple to',
    'After Martin and Amy went to the park, {} gave a drink to',
]

In [6]:
name_pairs = [
    (' Mary', ' John'),
    (' Tom', ' James'),
    (' Dan', ' Sid'),
    (' Martin', ' Amy'),
]

prompts = [prompt.format(name) for (prompt, names) in zip(prompt_format, name_pairs) for name in names[::-1]]
answers = [names[::i] for names in name_pairs for i in (1, -1)]
answer_tokens = torch.concat([model.to_tokens(names, prepend_bos=False).T for names in answers])

In [7]:
tokens = model.to_tokens(prompts, prepend_bos=True)
tokens = tokens.to(device)

In [8]:
def logits_to_ave_logit_diff(
    logits,
    answer_tokens = answer_tokens,
    per_prompt = False,
):
    final_logits = logits[:, -1, :]
    answer_logits = final_logits.gather(dim=-1, index=answer_tokens)
    correct_logits, incorrect_logits = answer_logits.unbind(dim=-1)
    answer_logit_diff = correct_logits - incorrect_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()

# 1. Introduction

The setup of activation patching is to take two runs of the model on two different inputs, the clean run and the corrupted run. The clean run outputs the correct answer and the corrupted run does not. The key idea is that we give the model the corrupted input, but then intervene on a specific activation and patch in the corresponding activation from the clean run (i.e., replace the corrupted activation with the clean activation), and then continue the run. We then measure how much the output has moved towards the correct answer.

We can then iterate over many possible activations and look at how much they affect the corrupted run. If patching in an activation significantly increases the probability of the correct answer, this allows us to localise which activations matter.

The ability to localise is a key move in mechanistic interpretability - if the computation is diffuse and spread across the entire model, it is likely much harder to form a clean mechanistic story for what's going on. But if we can identify precisely which parts of the model matter, we can then zoom in and determine what they represent and how they connect up with each other, and ultimately reverse engineer the underlying circuit that they represent.

One natural thing to patch in is the residual stream at a specific layer and specific position. For example, the model is likely to first do some processing on the early tokens, and only afterwards use attention to move that information to the end token. So patching in the residual stream at the end token will likely matter a lot in later layers but not at all in early layers.
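The idea can be illustrated on a toy model before applying it to GPT-2. The sketch below is not a transformer: `process`, `move_to_end`, the random weights and the three-position input are all made-up stand-ins, chosen so that information sits at one position until an attention-like step moves it to the last position. It sweeps over every (layer, position) pair, patching the clean residual stream into a corrupted run.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, n_pos = 4, 3
W_proc = rng.normal(size=(d_model, d_model))  # weights for position-wise processing
W_move = rng.normal(size=(d_model, d_model))  # weights for moving info to the last position
unembed = rng.normal(size=d_model)            # reads a single logit off the last position


def process(resid):
    # Position-wise layer: each position is updated using only itself.
    return resid + np.tanh(resid @ W_proc)


def move_to_end(resid):
    # Attention-like layer: adds a summary of the earlier positions to the last one.
    out = resid.copy()
    out[-1] += resid[:-1].mean(axis=0) @ W_move
    return out


LAYERS = [process, move_to_end, process]


def run(embed, patch=None):
    """Run the toy model. `patch=(layer, pos, vector)` overwrites the residual
    stream at the input of `layer`, at position `pos`."""
    resid, cache = embed.copy(), []
    for layer_idx, layer in enumerate(LAYERS):
        if patch is not None and patch[0] == layer_idx:
            resid[patch[1]] = patch[2]
        cache.append(resid.copy())  # residual stream at the start of this layer
        resid = layer(resid)
    return float(resid[-1] @ unembed), cache


# Clean and corrupted inputs differ only at position 1 (the "subject" position).
clean_embed = rng.normal(size=(n_pos, d_model))
corrupted_embed = clean_embed.copy()
corrupted_embed[1] = rng.normal(size=d_model)

clean_logit, clean_cache = run(clean_embed)
corrupted_logit, _ = run(corrupted_embed)

# Corrupted run with one clean activation patched in at a time.
recovered = np.zeros((len(LAYERS), n_pos))
for layer in range(len(LAYERS)):
    for pos in range(n_pos):
        patched_logit, _ = run(corrupted_embed, patch=(layer, pos, clean_cache[layer][pos]))
        recovered[layer, pos] = (patched_logit - corrupted_logit) / (clean_logit - corrupted_logit)

print(recovered.round(3))
```

In this toy, patching fully recovers the clean output at position 1 in the first two layers and at the last position in the final layer, and has no effect anywhere else, which mirrors the "early at the source token, late at the end token" pattern described above.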

### 1.1. Noising vs. denoising

The algorithm described above is a type of denoising: we run the model on the corrupted input and remove noise by patching in activations from the clean run. We can also consider the opposite algorithm, noising, where we run the model on the clean input and add noise by patching in activations from the corrupted run.

When would you use noising vs denoising? It depends on your goals. The results of denoising are much stronger, because showing that a component or set of components is sufficient for a task is a big deal. On the other hand, the complexity of transformers and interdependence of components means that noising a model can have unpredictable consequences. If loss goes up when we ablate a component, it doesn't necessarily mean that this component was necessary for the task. As an example, ablating `MLP0` in **gpt2-small** seems to make performance much worse on basically any task, but only because it acts as a kind of extended embedding. It isn't doing anything important that is specific to the IOI task.
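The difference between the two directions can be made concrete with a deliberately tiny example. The `model` function below is hypothetical and has nothing to do with GPT-2: it has two redundant components, `a` and `b`, and its output fires if either of them does. Denoising and noising component `a` then answer different questions (sufficiency and necessity respectively).

```python
def model(signal, patch_a=None):
    """Toy model with two redundant components, `a` and `b`, that both detect the
    signal. The output fires if either one does. `patch_a` overrides component a."""
    a = signal if patch_a is None else patch_a
    b = signal
    return max(a, b)


clean_signal, corrupted_signal = 1.0, 0.0
clean_out = model(clean_signal)
corrupted_out = model(corrupted_signal)

# Denoising: corrupted run, with the clean value of `a` patched in.
denoised_out = model(corrupted_signal, patch_a=clean_signal)

# Noising: clean run, with the corrupted value of `a` patched in.
noised_out = model(clean_signal, patch_a=corrupted_signal)

print(f'clean={clean_out}, corrupted={corrupted_out}')
print(f'denoised={denoised_out}, noised={noised_out}')
```

Denoising restores the clean output, so `a` is sufficient. Noising leaves the clean output untouched because `b` compensates, so `a` is not necessary. Real models sit somewhere in between, which is why the two methods can disagree.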

Here our clean input will be the original sentences (e.g. `"When Mary and John went to the store, John gave a drink to"`) and our corrupted input will have the subject token flipped (e.g. `"When Mary and John went to the store, Mary gave a drink to"`). Patching by replacing corrupted residual stream values with clean values is a causal intervention which will allow us to understand precisely which parts of the network are identifying the indirect object. If a component is sufficient to recover performance, rather than necessary, then patching in (replacing that component's corrupted output with its clean output) will reverse the signal that this component produces, hence making performance much better.

On this note, we could instead have our corrupted sentence be `"When John and Mary went to the store, Mary gave a drink to"` (i.e. flip all 3 occurrences of names in the sentence). We don't do this here, because the model could point to the indirect object `Mary` in two different ways:
1. via **token information**, i.e. the model has learned that `Mary` is the indirect object, and
2. via **position information**, i.e. the model has learned that the indirect object is at a particular position in the sentence (here, the first name).

If we patch in the clean value for `Mary` from the original sentence, we will be patching in the token information, but not the position information. This means that the model will still have to rely on position information to identify `Mary` as the indirect object, which is not what we want. We want to patch in both token and position information, so we only flip the subject token.
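The two corruption schemes discussed above can be compared with plain strings. This is only a sketch: whitespace-separated words stand in for tokens, and `make_prompt` and `differing_positions` are hypothetical helpers, not part of the notebook.

```python
template = 'When{io} and{s1} went to the store,{s2} gave a drink to'


def make_prompt(io, s1, s2):
    return template.format(io=io, s1=s1, s2=s2)


def differing_positions(a, b):
    # Indices of the words at which two equally long prompts differ.
    return [i for i, (x, y) in enumerate(zip(a.split(), b.split())) if x != y]


clean = make_prompt(' Mary', ' John', ' John')
subject_flip = make_prompt(' Mary', ' John', ' Mary')  # only the second subject changes
all_flip = make_prompt(' John', ' Mary', ' Mary')      # all three names change

print(differing_positions(clean, subject_flip))
print(differing_positions(clean, all_flip))
```

Flipping only the second subject changes a single word, so any effect of patching can be traced back to that one position. Flipping all three names changes three positions at once.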

# 2. Defining a metric

As our metric we use a linear function of the logit difference between the correct answer and the incorrect answer, rescaled so that it measures how far patching has moved the model's output towards the correct answer. Therefore:

- a value of zero means no change (from the performance on the corrupted prompt)
- a value of one means clean performance has been completely recovered

This is because we're performing a **denoising algorithm**; we're looking for activations which are sufficient for recovering a model's performance (i.e. activations which have enough information to recover the correct answer from the corrupted input). Our "null hypothesis" is that the component isn't sufficient, and so patching it by replacing corrupted with clean values doesn't recover any performance.
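As a quick arithmetic check of this scale, here is a minimal sketch of the rescaling on plain floats. The function name `normalised_metric` and the logit differences `3.0` and `-3.0` are illustrative assumptions, not values from the model.

```python
def normalised_metric(patched, clean, corrupted):
    """0.0 means no better than the corrupted run, 1.0 means clean performance recovered."""
    return (patched - corrupted) / (clean - corrupted)


clean_diff, corrupted_diff = 3.0, -3.0  # illustrative logit differences

for patched_diff in (-4.5, -3.0, 0.0, 1.5, 3.0):
    score = normalised_metric(patched_diff, clean_diff, corrupted_diff)
    print(f'patched logit diff {patched_diff:+.1f} -> metric {score:+.2f}')
```

Values below zero are possible: they mean the patch made the output worse than the corrupted baseline.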

In [9]:
clean_tokens = tokens

# swap each adjacent pair to get corrupted tokens
indices = [i + 1 if i % 2 == 0 else i - 1 for i in range(len(tokens))]
corrupted_tokens = clean_tokens[indices]

print(
    "Clean string 0:    ",
    model.to_string(clean_tokens[0]),
    "\nCorrupted string 0:",
    model.to_string(corrupted_tokens[0]),
)

clean_logits, clean_cache = model.run_with_cache(clean_tokens)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)

clean_logit_diff = logits_to_ave_logit_diff(clean_logits, answer_tokens)
print(f"Clean logit diff: {clean_logit_diff:.4f}")

corrupted_logit_diff = logits_to_ave_logit_diff(corrupted_logits, answer_tokens)
print(f"Corrupted logit diff: {corrupted_logit_diff:.4f}")

Clean string 0:     <|endoftext|>When John and Mary went to the shops,  John gave the bag to 
Corrupted string 0: <|endoftext|>When John and Mary went to the shops,  Mary gave the bag to
Clean logit diff: 2.7098
Corrupted logit diff: -2.7098


In [10]:
def ioi_metric(
    logits,
    answer_tokens = answer_tokens,
    corrupted_logit_diff: float = corrupted_logit_diff,
    clean_logit_diff: float = clean_logit_diff,
):
    patched_logit_diff = logits_to_ave_logit_diff(logits, answer_tokens)
    return (patched_logit_diff - corrupted_logit_diff) / (clean_logit_diff - corrupted_logit_diff)

torch.testing.assert_close(ioi_metric(clean_logits).item(), 1.0)
torch.testing.assert_close(ioi_metric(corrupted_logits).item(), 0.0)
torch.testing.assert_close(ioi_metric((clean_logits + corrupted_logits) / 2).item(), 0.5)

# 3. Residual stream patching

Patch in the residual stream at the start of each layer and for each token position. Tokens and their indexes from the first prompt are on the x-axis. In an abuse of notation, the difference is averaged over all 8 prompts, while the labels only come from the first prompt.

It is striking that the computation is highly localized - the relevant information for choosing the correct indirect object over the subject is initially stored at the second subject token and then moved to the END token without taking any detours. The model is basically done after layer 8, and the rest of the layers actually slightly impede performance on this particular task.
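Conceptually, each cell of the heatmap comes from one forward pass on the corrupted tokens with a hook that overwrites a single position of the residual stream. The sketch below shows what such a hook could look like. It is an illustration, not the library's internal code: `patch_resid_position` is a hypothetical name, and the demo uses NumPy arrays and a stand-in `hook` object so that it runs without a model.

```python
from types import SimpleNamespace

import numpy as np


def patch_resid_position(activation, hook, pos, clean_cache):
    """Overwrite one sequence position of `activation` (batch, pos, d_model) with
    the clean activation cached under the same hook name."""
    activation[:, pos, :] = clean_cache[hook.name][:, pos, :]
    return activation


# Stand-in demo: zeros play the corrupted activations, ones the clean ones.
hook = SimpleNamespace(name='blocks.0.hook_resid_pre')
clean_cache = {hook.name: np.ones((2, 4, 3))}
corrupted_resid = np.zeros((2, 4, 3))
patched = patch_resid_position(corrupted_resid, hook, pos=1, clean_cache=clean_cache)
print(patched[0])

# With TransformerLens the same function could be attached roughly like this
# (assuming `model`, `corrupted_tokens`, `clean_cache`, `layer` and `pos` exist):
#
#   from functools import partial
#   patched_logits = model.run_with_hooks(
#       corrupted_tokens,
#       fwd_hooks=[(
#           utils.get_act_name('resid_pre', layer),
#           partial(patch_resid_position, pos=pos, clean_cache=clean_cache),
#       )],
#   )
```

Looping this over every layer and position, and applying the metric to each `patched_logits`, would give a grid like the one plotted below.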

In [11]:
act_patch_resid_pre = patching.get_act_patch_resid_pre(
    model=model, corrupted_tokens=corrupted_tokens, clean_cache=clean_cache, patching_metric=ioi_metric
)

labels = [f'{tok} {i}' for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))]

imshow(
    act_patch_resid_pre,
    labels={'x': 'Position', 'y': 'Layer'},
    x=labels,
    title='resid_pre Activation Patching',
    width=700,
)

100%|██████████| 192/192 [00:11<00:00, 17.39it/s]


# 4. Patching in residual stream by block

Rather than just patching to the residual stream in each layer, we can also patch just after the attention layer or just after the MLP. This gives us a slightly more refined view of which tokens matter and when.

The function `patching.get_act_patch_block_every` works just like `get_act_patch_resid_pre`, but rather than just patching to the residual stream, it patches to `resid_pre`, `attn_out` and `mlp_out`, and returns a tensor of shape `(3, n_layers, seq_len)`. It cycles through `resid_pre`, `attn_out` and `mlp_out` and only patches one of them at a time, rather than patching all three at once.

We see that several attention layers are significant: early layers matter on the second subject and later layers matter on END, while no layer matters on any other token. This is extremely localized and consistent with the residual stream results.

With respect to the attention layers, layer 9 is positive and layers 10 and 11 are not, suggesting that the late layers only matter for direct logit effects, but we also see that layers 7 and 8 matter significantly. Presumably these layers contain the heads that move information about which name is duplicated from the subject to END.

In contrast, the MLP layers do not matter much. This makes sense, since this is more a task about moving information than about processing it, and the MLP layers specialise in processing information. The one exception is MLP0, which matters a lot, but this is a generally true statement about MLP0 rather than being about the circuit on this task.
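A result with the layout `(3, n_layers, seq_len)` can also be read numerically rather than only from the plot. The following sketch uses a small made-up array in place of the real patching output and finds the strongest (layer, position) site for each component. The shapes and values are illustrative assumptions.

```python
import numpy as np

n_layers, seq_len = 4, 5
components = ['resid_pre', 'attn_out', 'mlp_out']

# Made-up patching results with the layout (component, layer, position).
results = np.zeros((len(components), n_layers, seq_len))
results[0, 1, 2] = 0.9  # residual stream: layer 1, position 2
results[1, 3, 4] = 0.6  # attention output: layer 3, position 4
results[2, 0, 2] = 0.4  # MLP output: layer 0, position 2

strongest = {}
for name, grid in zip(components, results):
    layer, pos = np.unravel_index(np.argmax(grid), grid.shape)
    strongest[name] = (int(layer), int(pos), float(grid[layer, pos]))

for name, (layer, pos, value) in strongest.items():
    print(f'{name}: layer {layer}, position {pos}, metric {value:.2f}')
```

With a torch tensor, converting it first (for example with `utils.to_numpy`) should allow the same indexing.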

In [12]:
act_patch_block_every = patching.get_act_patch_block_every(model, corrupted_tokens, clean_cache, ioi_metric)

imshow(
    act_patch_block_every,
    x=labels,
    facet_col=0,
    facet_labels=['Residual Stream', 'Attn Output', 'MLP Output'],
    title='Logit Difference From Patched Attn Head Output',
    labels={'x': 'Sequence Position', 'y': 'Layer'},
    width=1200,
)

100%|██████████| 192/192 [00:10<00:00, 18.81it/s]
100%|██████████| 192/192 [00:10<00:00, 18.93it/s]
100%|██████████| 192/192 [00:10<00:00, 18.92it/s]


### 4.1. Tied embeddings, or what is the MLP doing?

It's often observed on GPT-2 Small that MLP0 matters a lot, and that ablating it utterly destroys performance. A common hypothesis is that the first MLP layer is essentially acting as an **extension of the embedding**, and that when later layers want to access the input tokens they mostly read in the output of the first MLP layer, rather than the token embeddings. Within this frame, the first attention layer doesn't do much.

In this framing, it makes sense that MLP0 matters on the second subject, because that's the one position with a different input token, i.e., it has a different extended embedding between the two prompt versions, and all other tokens will have basically the same extended embeddings.

Why does this happen? It seems like most of the effect comes from the fact that **the embedding and unembedding matrices in GPT-2 Small are tied**, i.e. one equals the transpose of the other. On one hand this seems principled - if two words mean similar things (e.g. "big" and "large") then they should be substitutable, i.e. have similar embeddings and unembeddings. This would seem to suggest that the geometric structure of the embedding and unembedding spaces should be related. On the other hand, there's one major reason why this isn't as principled as it seems - the embedding and the unembedding together form the direct path (if we had no other components then the transformer would just be the linear map $x \rightarrow x^T W_E W_U$), and we do not want this to be symmetric because bigram prediction is not symmetric. As an example, if $W_E = W_U^T$, then in order to predict "Barack Obama" as a probable bigram, we'd also have to predict "Obama Barack" with an equally high logit, which obviously shouldn't happen. So it makes sense that the first MLP layer might be used in part to break this unwanted symmetry: we now think of $MLP_0(x^T W_E)W_U$ as the direct path, which is no longer symmetric when $W_E$ and $W_U$ are tied.
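The symmetry argument can be checked numerically on random matrices. In the sketch below, the sizes, the random weights and the one-layer ReLU `mlp0` are assumptions made for illustration; nothing here is loaded from GPT-2.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab, d_model, d_mlp = 6, 4, 8

W_E = rng.normal(size=(vocab, d_model))
W_U = W_E.T  # tied embeddings: the unembedding is the transpose of the embedding

# Direct path: entry (i, j) is the logit for token j following token i.
direct = W_E @ W_U
direct_is_symmetric = np.allclose(direct, direct.T)

# Hypothetical stand-in for MLP0: a single ReLU layer with random weights.
W_in = rng.normal(size=(d_model, d_mlp))
W_out = rng.normal(size=(d_mlp, d_model))


def mlp0(x):
    return np.maximum(x @ W_in, 0) @ W_out


with_mlp = mlp0(W_E) @ W_U
with_mlp_is_symmetric = np.allclose(with_mlp, with_mlp.T)

print('direct path symmetric:', direct_is_symmetric)
print('path through MLP symmetric:', with_mlp_is_symmetric)
```

With tied weights the direct path gives "A then B" and "B then A" the same logit, while routing the embedding through even a random nonlinear layer removes that constraint.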

### 4.2. Knowledge storage in the MLP

There is evidence that "facts" and "knowledge" are largely stored in the MLP layers (see Meng et al. in the Sources). If the prompt `The White House is where the` is fed to **GPT-2**, we expect it to guess `president`, as part of the completion `The White House is where the president lives`. Given the prompt `The Haunted House is where the`, we expect the guess `ghosts`, as part of the completion `The Haunted House is where the ghosts live`.

If this is the case, how does the model do this? Somewhere it has to have the association between White House/President and Haunted House/Ghosts.

In [13]:
clean_prompt, clean_answer = 'The White House is where the', ' president'
corrupted_prompt, corrupted_answer = 'The Haunted House is where the', ' ghosts'

clean_tokens_mlp = model.to_tokens(clean_prompt)
corrupted_tokens_mlp = model.to_tokens(corrupted_prompt)

assert clean_tokens_mlp.shape == corrupted_tokens_mlp.shape, 'Clean and corrupted tokens must have same shape.'

In [14]:
clean_token_mlp = model.to_single_token(clean_answer)
utils.test_prompt(clean_prompt, clean_answer, model)

Tokenized prompt: ['<|endoftext|>', 'The', ' White', ' House', ' is', ' where', ' the']
Tokenized answer: [' president']


Top 0th token. Logit: 13.75 Prob:  6.90% Token: | president|
Top 1th token. Logit: 13.18 Prob:  3.90% Token: | most|
Top 2th token. Logit: 12.61 Prob:  2.20% Token: | world|
Top 3th token. Logit: 12.55 Prob:  2.09% Token: | real|
Top 4th token. Logit: 12.10 Prob:  1.32% Token: | President|
Top 5th token. Logit: 12.04 Prob:  1.25% Token: | nation|
Top 6th token. Logit: 12.02 Prob:  1.22% Token: | biggest|
Top 7th token. Logit: 11.91 Prob:  1.10% Token: | White|
Top 8th token. Logit: 11.87 Prob:  1.06% Token: | Republican|
Top 9th token. Logit: 11.82 Prob:  1.00% Token: | American|


In [15]:
corrupted_token_mlp = model.to_single_token(corrupted_answer)
utils.test_prompt(corrupted_prompt, corrupted_answer, model)

Tokenized prompt: ['<|endoftext|>', 'The', ' Haunted', ' House', ' is', ' where', ' the']
Tokenized answer: [' ghosts']


Top 0th token. Logit: 12.65 Prob:  6.63% Token: | ghosts|
Top 1th token. Logit: 11.50 Prob:  2.10% Token: | ghost|
Top 2th token. Logit: 11.41 Prob:  1.93% Token: | most|
Top 3th token. Logit: 11.17 Prob:  1.51% Token: | story|
Top 4th token. Logit: 10.96 Prob:  1.23% Token: | original|
Top 5th token. Logit: 10.96 Prob:  1.22% Token: | real|
Top 6th token. Logit: 10.87 Prob:  1.12% Token: | spirits|
Top 7th token. Logit: 10.79 Prob:  1.04% Token: | world|
Top 8th token. Logit: 10.70 Prob:  0.94% Token: | Haunted|
Top 9th token. Logit: 10.67 Prob:  0.91% Token: | first|


In [16]:
def answer_metric(
    logits,
    clean_token_mlp = clean_token_mlp,
    corrupted_token_mlp = corrupted_token_mlp,
):
    return logits[:, -1, clean_token_mlp] - logits[:, -1, corrupted_token_mlp]

clean_logits_mlp, clean_cache_mlp = model.run_with_cache(clean_tokens_mlp)
act_patch_block_every = patching.get_act_patch_block_every(model, corrupted_tokens_mlp, clean_cache_mlp, answer_metric)

100%|██████████| 84/84 [00:03<00:00, 24.32it/s]
100%|██████████| 84/84 [00:03<00:00, 25.15it/s]
100%|██████████| 84/84 [00:03<00:00, 27.24it/s]


In [17]:
imshow(
    act_patch_block_every,
    x=['<endoftext>','The', 'White/Haunted', 'House', 'is', 'where', 'the'],
    facet_col=0,
    facet_labels=['Residual Stream', 'Attn Output', 'MLP Output'],
    title='Logit Difference (president - ghosts)',
    labels={'x': 'Sequence Position', 'y': 'Layer'},
    width=1200,
)

# 5. Head patching

In order to properly patch in the attention heads, three dimensions are needed: the layer, the head and the position. The code below patches a head's output over all sequence positions, and returns the results (for each head in the model).

In the plot, we see some late heads with large scores (e.g. 9.9 having a large positive score, and 10.7 having a large negative score). But we can also see some other important heads, for instance:

- in layers 7-8, there are several important heads. We might deduce that these are the ones responsible for moving information from the second subject to END.
- in the earlier layers, there are some more important heads (e.g. 3.0 and 5.5). We might guess these are performing some primitive logic, e.g. causing the second " John" token to attend to previous instances of itself.

In [18]:
act_patch_attn_head_out_all_pos = patching.get_act_patch_attn_head_out_all_pos(
    model, corrupted_tokens, clean_cache, ioi_metric
)

imshow(
    act_patch_attn_head_out_all_pos,
    labels={'y': 'Layer', 'x': 'Head'},
    title='attn_head_out Activation Patching (All Pos)',
    width=600,
)

100%|██████████| 144/144 [00:08<00:00, 16.98it/s]


# 6. Decomposing heads

Decomposing attention layers into patching in individual heads has already helped us localize the behaviour a lot. But we can understand it further by decomposing heads. An attention head consists of two semi-independent operations - calculating where to move information from and to (represented by the attention pattern and implemented via the QK-circuit) and calculating what information to move (represented by the value vectors and implemented by the OV circuit). We can disentangle which of these is important by patching in just the attention pattern or the value vectors.

Rather than just patching on head output, the function `patching.get_act_patch_attn_head_all_pos_every` patches on:

- output, equivalent to patching the value the head writes to the residual stream
- queries, without changing the key or value vectors
- keys, without changing the query or value vectors
- values, without changing the query or key vectors
- attention patterns

The plot below shows at least three different groups of heads:

- earlier heads (3.0, 5.5, 6.9): matter only through their query vectors and attention patterns.
- middle heads (7.3, 7.9, 8.6, 8.10): matter only through their value vectors.
- later heads (9.9, 10.0): matter because of their query vectors, which improve the logit difference.
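The difference between pattern patching and value patching can be seen on a single toy attention head. The sketch below is a simplified head (no causal mask, random inputs) written for illustration. By construction, the clean and corrupted runs share queries and keys and differ only in one value vector, which loosely mimics the "middle heads" case.

```python
import numpy as np


def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def head(q, k, v, pattern=None, values=None):
    """One attention head without a causal mask. `pattern` and `values` optionally
    override what the head computes, i.e. patch them in from another run."""
    pat = softmax(q @ k.T / np.sqrt(q.shape[-1])) if pattern is None else pattern
    val = v if values is None else values
    return pat @ val, pat


rng = np.random.default_rng(0)
n_pos, d_head = 4, 3
q = rng.normal(size=(n_pos, d_head))
k = rng.normal(size=(n_pos, d_head))
v_clean = rng.normal(size=(n_pos, d_head))
v_corrupted = v_clean.copy()
v_corrupted[1] = rng.normal(size=d_head)  # only the content being moved differs

out_clean, pat_clean = head(q, k, v_clean)
out_corrupted, _ = head(q, k, v_corrupted)

# Corrupted run with only the clean values, or only the clean pattern, patched in.
out_value_patched, _ = head(q, k, v_corrupted, values=v_clean)
out_pattern_patched, _ = head(q, k, v_corrupted, pattern=pat_clean)

print('value patching recovers clean output:', np.allclose(out_value_patched, out_clean))
print('pattern patching changes nothing:', np.allclose(out_pattern_patched, out_corrupted))
```

When a head's effect shows up under value patching but not under pattern patching, as here, the difference between the runs lies in what the head moves rather than in where it looks.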

In [19]:
act_patch_attn_head_all_pos_every = patching.get_act_patch_attn_head_all_pos_every(
    model, corrupted_tokens, clean_cache, ioi_metric
)

imshow(
    act_patch_attn_head_all_pos_every,
    facet_col=0,
    facet_labels=['Output', 'Query', 'Key', 'Value', 'Pattern'],
    title='Activation Patching Per Head (All Pos)',
    labels={'x': 'Head', 'y': 'Layer'},
    width=1200,
)

100%|██████████| 144/144 [00:07<00:00, 18.11it/s]
100%|██████████| 144/144 [00:08<00:00, 17.73it/s]
100%|██████████| 144/144 [00:08<00:00, 17.72it/s]
100%|██████████| 144/144 [00:08<00:00, 17.53it/s]
100%|██████████| 144/144 [00:07<00:00, 18.00it/s]


What is the significance of the results above for the middle heads, i.e. the important ones in layers 7 and 8? How should we interpret the fact that value patching has a much bigger effect than query or key patching?

The attention patterns show us that these heads attend from END to the second subject, so we can guess that they're responsible for moving information from the subject to END, which is used to determine the answer. This agrees with our earlier results, when we saw that most of the information gets moved over layers 7 and 8. The fact that value patching is the most important thing for them suggests that the interesting computation goes into what information they move from the subject to END, rather than why END attends to the subject.

In [None]:
def topk_of_Nd_tensor(tensor, k):
    i = torch.topk(tensor.flatten(), k).indices
    return np.array(np.unravel_index(utils.to_numpy(i), tensor.shape)).T.tolist()


# get heads with largest value patching
# from plot above, these are the 4 heads in layers 7 and 8
k = 4
top_heads = topk_of_Nd_tensor(act_patch_attn_head_all_pos_every[3], k=k)

# get all their attention patterns
attn_patterns_for_important_heads = torch.stack([
    clean_cache['pattern', layer][:, head].mean(0)
        for layer, head in top_heads
])

# display results
display(HTML(f'<h2>Top {k} Logit Attribution Heads (from value-patching)</h2>'))
display(cv.attention.attention_patterns(
    attention = attn_patterns_for_important_heads,
    tokens = model.to_str_tokens(tokens[0]),
))

# 7. Conclusions

- heads `9.9`, `9.6`, and `10.0` are the most important heads in terms of directly writing to the residual stream. In all these heads, the `END` attends strongly to the `IO`.
  - visualizing this is done by taking the values written by each head in each layer to the residual stream, and projecting them along the logit diff direction by using `residual_stack_to_logit_diff`.
  - this suggests that these heads are copying `IO` to `END`, to use it as the predicted next token
  - the question then becomes *"how do these heads know to attend to this token, and not attend to the first subject?"*

- all the action is on the `second subject` until layer 7 and then transitions to `END`.
  - attention layers matter a lot, MLP layers not so much, apart from MLP0, likely an extended embedding.
  - visualizing this is done by patching activations on `resid_pre`, `attn_out`, and `mlp_out`.
  - this suggests that there is a cluster of heads in layers 7 and 8 which move information from the `second subject` to `END`.
  - this information may be how heads `9.9`, `9.6` and `10.0` know to attend to `IO`.
  - the question then becomes *"what is this information, how does it end up in the second subject token, and how does END know to attend to it?"*

- the significant heads in layers 7 and 8 are 7.3, 7.9, 8.6, 8.10. 
  - these heads have high activation patching values for their value vectors, less so for their queries and keys.
  - visualizing this is done by patching activations on the value inputs for these heads.
  - this supports the previous observation, and it tells us that the interesting computation goes into what gets moved from the `second subject` to `END`, rather than the fact that `END` attends to the `second subject`.
  - the question then becomes *"what is this information, and how does it end up in the `second subject` token?"*

- besides the late heads and the cluster of heads in layers 7 and 8, there's a third cluster of important heads:
  - early heads (3.0, 5.5, 6.9) have query vectors that are particularly important for getting good performance.
  - visualizing this is done by patching activations on the query inputs for these heads.

# Sources

1. [Ground truth - Arena::Activation Patching](https://arena-chapter1-transformer-interp.streamlit.app/[1.4.1]_Indirect_Object_Identification#keeping-track-of-your-guesses-predictions)
2. [Locating and Editing Factual Associations in GPT, by Meng, K., et al.](https://arxiv.org/pdf/2202.05262)
3. [A Mathematical Framework for Transformer Circuits, by Chris Olah, Neel Nanda, et al.](https://transformer-circuits.pub/2021/framework/index.html)
4. [Neel Nanda's walkthrough of A Mathematical Framework](https://www.youtube.com/watch?v=KV5gbOmHbjU)