In [1]:
import warnings
warnings.filterwarnings('ignore')

import os
os.chdir('..')

In [None]:
import os

import circuitsvis as cv
import einops
import numpy as np
import torch
from IPython.display import HTML, display
from transformer_lens import HookedTransformer
from transformer_lens import utils

from plotly_utils import line, imshow

In [None]:
os.chdir('..')
print('Changed working directory to parent directory')

with open(os.path.expanduser('~/.huggingface/token')) as f:
    os.environ['HF_TOKEN'] = f.read().strip()
    # Double quotes outside so the nested single quotes parse on Python < 3.12
    print(f"Hugging Face token loaded: {os.environ['HF_TOKEN'][:3]}...")

torch.set_grad_enabled(False)

if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

In [None]:
model = HookedTransformer.from_pretrained(
    'gpt2-small',
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
    device=device,
)

In [None]:
prompt_format = [
    # No space before {}: each name in name_pairs already carries its leading space
    'When John and Mary went to the shops,{} gave the bag to',
    'When Tom and James went to the park,{} gave the ball to',
    'When Dan and Sid went to the shops,{} gave an apple to',
    'After Martin and Amy went to the park,{} gave a drink to',
]

In [None]:
name_pairs = [
    (' Mary', ' John'),
    (' Tom', ' James'),
    (' Dan', ' Sid'),
    (' Martin', ' Amy'),
]

prompts = [prompt.format(name) for (prompt, names) in zip(prompt_format, name_pairs) for name in names[::-1]]
answers = [names[::i] for names in name_pairs for i in (1, -1)]
answer_tokens = torch.concat([model.to_tokens(names, prepend_bos=False).T for names in answers])

In [None]:
# get logits and cache of all internal activations for later analysis
tokens = model.to_tokens(prompts, prepend_bos=True)
tokens = tokens.to(device)
logits, cache = model.run_with_cache(tokens)

In [None]:
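def logits_to_ave_logit_diff(
    logits,
    answer_tokens=answer_tokens,
    per_prompt=False,
):
    """Return the logit difference (correct minus incorrect answer) at the final position.

    Returns one value per prompt if per_prompt is True, otherwise the mean over prompts.
    """
    # Only the final position matters: that is where the model predicts the answer
    final_logits = logits[:, -1, :]
    # Pick out the logits for the (correct, incorrect) answer tokens of each prompt
    answer_logits = final_logits.gather(dim=-1, index=answer_tokens)
    correct_logits, incorrect_logits = answer_logits.unbind(dim=-1)
    answer_logit_diff = correct_logits - incorrect_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()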

ave_logit_diff = logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False)

# 1. Introduction

Activation patching takes two runs of the model on two different inputs: a clean run and a corrupted run. The clean run outputs the correct answer and the corrupted run does not. The key idea is that we give the model the corrupted input, but then intervene on a specific activation and patch in the corresponding activation from the clean run (i.e., replace the corrupted activation with the clean activation), and then continue the run. We then measure how much the output has moved towards the correct answer.

We can then iterate over many possible activations and look at how much they affect the corrupted run. If patching in an activation significantly increases the probability of the correct answer, this allows us to localise which activations matter.
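The sweep described above can be sketched on a toy network without loading a model. This is only an illustration under invented assumptions: the weights, the inputs and the helper `forward` are made up, and the scalar output stands in for the logit difference computed by `logits_to_ave_logit_diff`. Each clean hidden activation is patched into the corrupted run in turn, and the result is normalised so that 0 means "no better than the corrupted run" and 1 means "clean performance restored".

```python
# Toy two-layer network: 2 inputs -> 3 ReLU hidden units -> 1 output.
# All weights and inputs are invented for illustration.
W_IN = [[2.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # one row per hidden unit
W_OUT = [3.0, 0.0, 0.5]                       # one weight per hidden unit


def forward(x, patch=None):
    """Run the toy network; `patch` is an optional (hidden_index, value) override."""
    hidden = [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in W_IN]
    if patch is not None:
        index, value = patch
        hidden[index] = value  # overwrite one activation, then continue the run
    output = sum(w * h for w, h in zip(W_OUT, hidden))
    return output, hidden


clean_input, corrupted_input = [1.0, 1.0], [0.0, 1.0]
clean_out, clean_hidden = forward(clean_input)
corrupted_out, _ = forward(corrupted_input)

# Patch each clean hidden activation into the corrupted run, one at a time.
recovered = []
for i, clean_value in enumerate(clean_hidden):
    patched_out, _ = forward(corrupted_input, patch=(i, clean_value))
    # 0.0 = no better than the corrupted run, 1.0 = clean performance restored
    recovered.append((patched_out - corrupted_out) / (clean_out - corrupted_out))

print([round(r, 3) for r in recovered])  # [0.923, 0.0, 0.077]
```

In this toy, hidden unit 0 accounts for most of the recovered output, unit 1 for none of it, and unit 2 for a small remainder, which is the kind of localisation the sweep is meant to reveal.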

The ability to localise is a key move in mechanistic interpretability: if the computation is diffuse and spread across the entire model, it is likely much harder to form a clean mechanistic story for what's going on. But if we can identify precisely which parts of the model matter, we can then zoom in and determine what they represent and how they connect up with each other, and ultimately reverse engineer the underlying circuit that they represent.

The diagrams below demonstrate activation patching on an abstract neural network (the nodes represent activations, and the arrows between them are weight connections).

A regular forward pass on the clean input looks like:

<img src="../../assets/act-patch-regular.jpg" width="300"/>

And activation patching from a corrupted input (green) into a forward pass for the clean input (black) looks like:

<img src="../../assets/act-patch-corrupted.jpg" width="400"/>

where the dotted line represents patching in a value (i.e. during the forward pass on the clean input, we replace node $D$ with the value it takes on the corrupted input). Nodes $H$, $G$ and $F$ are colored orange to show that they now follow a distribution that matches neither the clean nor the corrupted run.

We can patch into a transformer in many different ways, e.g. values of the residual stream, the MLP, or attention heads' output. We can also get even more granular by patching at particular sequence positions.

<img src="../../assets/act-patch-examples.jpg" width="800"/>
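To make position-level patching concrete, here is a minimal sketch on a toy residual stream rather than on GPT-2. Everything in it is an assumption made for illustration: the two-layer `run` function, the single invented "head" that moves information from position 0 to the final position, and the input values. Clean activations are patched into the corrupted run at every (layer, position) pair, and the fraction of clean output recovered is recorded in a grid.

```python
N_LAYERS = 2


def run(tokens, patch=None):
    """Toy residual stream over a sequence.

    Layer 0 embeds each token as itself. Layer 1 has one invented "head" that
    adds position 0's value into the final position. `patch` is an optional
    ((layer, position), value) override applied right after that layer runs.
    """
    resid = list(tokens)  # layer 0: the "embedding"
    cache = {}
    for layer in range(N_LAYERS):
        if layer == 1:
            resid = resid[:-1] + [resid[-1] + resid[0]]  # move info to the end
        if patch is not None and patch[0][0] == layer:
            resid[patch[0][1]] = patch[1]
        for pos, value in enumerate(resid):
            cache[(layer, pos)] = value
    return resid[-1], cache  # the prediction is read off the final position


clean_tokens, corrupted_tokens = [5.0, 1.0, 2.0], [0.0, 1.0, 2.0]
clean_out, clean_cache = run(clean_tokens)
corrupted_out, _ = run(corrupted_tokens)

# grid[layer][pos] = fraction of clean output recovered by patching that cell
grid = []
for layer in range(N_LAYERS):
    row = []
    for pos in range(len(clean_tokens)):
        patched_out, _ = run(corrupted_tokens, patch=((layer, pos), clean_cache[(layer, pos)]))
        row.append((patched_out - corrupted_out) / (clean_out - corrupted_out))
    grid.append(row)

print(grid)  # [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
```

The grid shows where the relevant information lives at each depth in this toy: at position 0 before the head runs, and at the final position afterwards.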

### 1.1. Noising vs. denoising

The algorithm shown in the diagrams above is a type of noising, since we run the model on a clean input and add noise by patching in from the corrupted input. The opposite algorithm, denoising, runs the model on a corrupted input and removes noise by patching in from the clean input; this is the version described at the start of this section.

When would you use noising vs denoising? It depends on your goals. Denoising gives much stronger evidence, because showing that a component or set of components is sufficient for a task is a big deal. On the other hand, the complexity of transformers and the interdependence of their components mean that noising a model can have unpredictable consequences. If loss goes up when we ablate a component, it doesn't necessarily mean that this component was necessary for the task. As an example, ablating `MLP0` in **gpt2-small** seems to make performance much worse on basically any task, but only because it acts as a kind of extended embedding. It is not doing anything important that is specific to the indirect object identification (IOI) task studied here.
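The two directions can give different answers for the same component. The sketch below is a toy under invented assumptions, not a claim about GPT-2: two hypothetical components `a` and `b` each copy the input, and a `combine` function merges them, either redundantly (`max`) or jointly (`min`).

```python
CLEAN, CORRUPTED = 1.0, 0.0


def model(x, combine, patch=None):
    """Two components `a` and `b` each copy the input; `combine` merges them."""
    acts = {"a": x, "b": x}
    if patch is not None:
        acts.update(patch)  # overwrite one component's activation
    return combine(acts["a"], acts["b"])


def denoise(combine, name):
    """Corrupted run with the clean activation of `name` patched in."""
    return model(CORRUPTED, combine, {name: CLEAN})


def noise(combine, name):
    """Clean run with the corrupted activation of `name` patched in."""
    return model(CLEAN, combine, {name: CORRUPTED})


for label, combine in [("redundant (max)", max), ("joint (min)", min)]:
    print(label, "| denoise a:", denoise(combine, "a"), "| noise a:", noise(combine, "a"))
# redundant (max) | denoise a: 1.0 | noise a: 1.0
# joint (min) | denoise a: 0.0 | noise a: 0.0
```

In the redundant case, denoising `a` restores the clean output, so `a` alone is sufficient, while noising `a` changes nothing because `b` compensates. In the joint case, denoising `a` alone recovers nothing, while noising `a` breaks the output. Which experiment is informative therefore depends on how the components interact.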

# Sources

1. [Ground truth - Arena::Activation Patching](https://arena-chapter1-transformer-interp.streamlit.app/[1.4.1]_Indirect_Object_Identification#keeping-track-of-your-guesses-predictions)
2. [Locating and Editing Factual Associations in GPT, by Meng, K. et al.](https://arxiv.org/pdf/2202.05262)