<a href="https://colab.research.google.com/github/shahved25/shahved25/blob/main/demos/Activation_Patching_in_TL_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


 # Activation Patching in TransformerLens Demo
 This notebook accompanies the [Exploratory Analysis Demo](https://neelnanda.io/exploratory-analysis-demo), which explains some basic techniques for mechanistic interpretability of networks, including an overview of activation patching ([summary here](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=qeWBvs-R-taFfcCq-S_hgMqx)). This notebook demonstrates how to use the activation patching utilities in TransformerLens.


 <b style="color: red">To use this notebook, go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.</b>

 **Tips for reading this Colab:**
 * You can run all this code for yourself!
 * The graphs are interactive!
 * Use the table of contents pane in the sidebar to navigate
 * Collapse irrelevant sections with the dropdown arrows
 * Search the page using the search in the sidebar, not CTRL+F

 ## Setup (Ignore)

In [2]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/TransformerLensOrg/TransformerLens.git
    # Install my janky personal plotting utils
    %pip install git+https://github.com/neelnanda-io/neel-plotly.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Automatically reload the HookedTransformer code as it's edited, without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Running as a Colab notebook

In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"

NameError: name 'IN_COLAB' is not defined

In [2]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [3]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

 We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [4]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7b38bee68cd0>

 Plotting helper functions come from a janky personal library of plotting utils. The library is not documented and I recommend against trying to read it; just use your preferred plotting library if you want to do anything non-obvious:

In [5]:
from neel_plotly import line, imshow, scatter

In [6]:
import transformer_lens.patching as patching

 ## Activation Patching Setup
 This just copies the relevant setup from the Exploratory Analysis Demo, and isn't very important.

In [35]:
!huggingface-cli login



In [24]:
!pip install hf_xet
!pip install flash-attention
import hf_xet

ERROR: Operation cancelled by user


In [36]:
model = HookedTransformer.from_pretrained("pythia-2.8b")

OSError: There was a specific connection error when trying to load EleutherAI/pythia-2.8b:
401 Client Error: Unauthorized for url: https://huggingface.co/EleutherAI/pythia-2.8b/resolve/main/config.json (Request ID: Root=1-68363d81-06b5f5d63bb193be49b4636e;054e7d16-2b4e-4f97-b058-3e0b8007fe99)

Invalid credentials in Authorization header

In [17]:
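# Tokenize up front: the patching utilities below expect token tensors, not raw strings
clean_tokens = model.to_tokens("3(2+1)=")
corrupted_tokens = model.to_tokens("4(2+1)=")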

In [18]:
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)

My Patching

In [19]:

# Function to patch activations at a given hook
def make_patch_hook(hook_name):
    def patch_hook(value, hook):
        return clean_cache[hook.name]
    return (hook_name, patch_hook)

# Get clean/corrupted predictions
def get_predicted_token_str(logits):
    token_id = logits[0, -1].argmax().item()
    return model.tokenizer.decode(token_id)

print("Clean Output:    ", get_predicted_token_str(clean_logits))
print("Corrupted Output:", get_predicted_token_str(corrupted_logits))

# Patch the residual stream at the output of each layer (hook_resid_post), all positions at once
print("\nPatched Outputs (resid_post, all positions):")
for layer in range(model.cfg.n_layers):
    hook_name = f"blocks.{layer}.hook_resid_post"
    patched_logits = model.run_with_hooks(
        corrupted_tokens,
        fwd_hooks=[make_patch_hook(hook_name)]
    )
    patched_token = get_predicted_token_str(patched_logits)
    print(f"Layer {layer}: {patched_token}")

Clean Output:     1
Corrupted Output: 1

Patched Outputs (resid_post, all positions):
Layer 0: 1
Layer 1: 1
Layer 2: 1
Layer 3: 1
Layer 4: 1
Layer 5: 1
Layer 6: 1
Layer 7: 1
Layer 8: 1
Layer 9: 1
Layer 10: 1
Layer 11: 1


 ## Patching
 In the following cells, we use the `patching` module to call the activation patching utilities.

In [30]:
# Whether to do the runs by head and by position, which are much slower
DO_SLOW_RUNS = True

 ### Patching Single Activation Types
 We start by patching a single type of activation at a time.
 The functions follow a common pattern: each is named `get_act_patch_...` and takes `(model, corrupted_tokens, clean_cache, patching_metric)`.

 We can patch the residual stream at the start of each block, over each layer and position.
 Replacing `resid_pre` in the function name with `attn_out`, `mlp_out`, or `resid_mid` also works.
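
 The calls in the cells that follow pass a metric named `ioi_metric`, which this notebook never defines; it presumably belongs to the setup copied from the Exploratory Analysis Demo. A patching metric maps a logits tensor of shape `[batch, pos, d_vocab]` to a scalar. Below is a minimal sketch of one common choice, a normalised logit difference. It is not the original definition: `make_normalised_logit_diff`, `correct_id`, `incorrect_id`, and the two baseline arguments are hypothetical names, and the token ids and baselines would have to come from your own clean and corrupted runs.

```python
def make_normalised_logit_diff(correct_id, incorrect_id, clean_baseline, corrupted_baseline):
    """Build a metric that is 0.0 at the corrupted baseline and 1.0 at the clean baseline.

    All names here are hypothetical; nothing in this block is part of TransformerLens.
    """

    def logit_diff(logits):
        # Final position of the first prompt. This indexing works for a torch tensor
        # of shape [batch, pos, d_vocab] and for plain nested lists alike.
        final = logits[0][-1]
        return final[correct_id] - final[incorrect_id]

    def metric(logits):
        return (logit_diff(logits) - corrupted_baseline) / (clean_baseline - corrupted_baseline)

    return metric


# Toy check with nested lists standing in for logits: 1 prompt, 2 positions, 3 vocab entries
toy_clean = [[[0.0, 0.0, 0.0], [0.5, 3.0, 1.0]]]      # final-position logit diff: 3.0 - 1.0 = 2.0
toy_corrupted = [[[0.0, 0.0, 0.0], [0.5, 1.0, 2.0]]]  # final-position logit diff: 1.0 - 2.0 = -1.0
toy_metric = make_normalised_logit_diff(
    correct_id=1, incorrect_id=2, clean_baseline=2.0, corrupted_baseline=-1.0
)
```

 With real logits, the baselines would be the un-normalised logit differences of `clean_logits` and `corrupted_logits`, and the returned function could be bound to the name `ioi_metric`. The division is undefined when the two baselines are equal, so the clean and corrupted prompts need to produce different logit differences for this normalisation to be usable.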

In [None]:
resid_pre_act_patch_results = patching.get_act_patch_resid_pre(model, corrupted_tokens, clean_cache, ioi_metric)
imshow(resid_pre_act_patch_results,
       yaxis="Layer",
       xaxis="Position",
       x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
       title="resid_pre Activation Patching")
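
 Conceptually, a scan of this kind runs the corrupted input once per (layer, position), overwrites that single activation with its clean counterpart, and records the metric. Below is a toy sketch of that loop in plain Python. It is not the TransformerLens implementation: the two-layer "model", `run`, and `patch_scan` are hypothetical stand-ins, with one float per position playing the role of the residual stream.

```python
from typing import Callable, Dict, List, Optional, Tuple

Resid = List[float]  # toy residual stream: one float per position


def prefix_sum(resid: Resid) -> Resid:
    # Causal mixing: position i sees positions 0..i (a stand-in for attention)
    total, out = 0.0, []
    for value in resid:
        total += value
        out.append(total)
    return out


def double(resid: Resid) -> Resid:
    # Position-wise transform (a stand-in for an MLP)
    return [2.0 * value for value in resid]


LAYERS: List[Callable[[Resid], Resid]] = [prefix_sum, double]


def run(tokens: List[int],
        patch: Optional[Tuple[int, int, float]] = None,
        cache: Optional[Dict[int, Resid]] = None) -> float:
    """Run the toy model; optionally overwrite resid_pre at (layer, pos) and/or record it."""
    resid = [float(t) for t in tokens]
    for layer, layer_fn in enumerate(LAYERS):
        if patch is not None and patch[0] == layer:
            _, pos, value = patch
            resid = resid[:pos] + [value] + resid[pos + 1:]
        if cache is not None:
            cache[layer] = list(resid)  # input to this layer, i.e. "resid_pre"
        resid = layer_fn(resid)
    return resid[-1]  # "logit" read off the final position


def patch_scan(corrupted: List[int], clean_cache: Dict[int, Resid],
               metric: Callable[[float], float]) -> List[List[float]]:
    """Return a [layer][pos] grid of metric values, patching one activation per run."""
    return [
        [metric(run(corrupted, patch=(layer, pos, clean_cache[layer][pos])))
         for pos in range(len(corrupted))]
        for layer in range(len(LAYERS))
    ]


clean, corrupted = [3, 1, 2], [4, 1, 2]
clean_cache: Dict[int, Resid] = {}
clean_out = run(clean, cache=clean_cache)
corrupted_out = run(corrupted)
# 1.0 means the patch fully restored the clean output, 0.0 means it had no effect
grid = patch_scan(corrupted, clean_cache,
                  lambda out: (corrupted_out - out) / (corrupted_out - clean_out))
```

 Here `grid` comes out as `[[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]`: before layer 0 the difference between the inputs lives at position 0, and by layer 1 the mixing step has moved it to the final position. A real `resid_pre` heatmap is read the same way, as a map of where the task-relevant information sits at each layer.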

 We can patch head outputs over each head in each layer, patching across all positions at once.
 Replacing `out` in the function name with `q`, `k`, `v`, or `pattern` also works.

In [None]:
attn_head_out_all_pos_act_patch_results = patching.get_act_patch_attn_head_out_all_pos(model, corrupted_tokens, clean_cache, ioi_metric)
imshow(attn_head_out_all_pos_act_patch_results,
       yaxis="Layer",
       xaxis="Head",
       title="attn_head_out Activation Patching (All Pos)")

 We can patch head outputs over each head in each layer, patching each position in turn.
 Replacing `out` with `q`, `k`, `v`, or `pattern` also works, though note that `pattern` has output shape `[layer, pos, head]`.
 We reshape the result so that it plots nicely.

In [None]:
ALL_HEAD_LABELS = [f"L{i}H{j}" for i in range(model.cfg.n_layers) for j in range(model.cfg.n_heads)]
if DO_SLOW_RUNS:
    attn_head_out_act_patch_results = patching.get_act_patch_attn_head_out_by_pos(model, corrupted_tokens, clean_cache, ioi_metric)
    attn_head_out_act_patch_results = einops.rearrange(attn_head_out_act_patch_results, "layer pos head -> (layer head) pos")
    imshow(attn_head_out_act_patch_results,
        yaxis="Head Label",
        xaxis="Pos",
        x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
        y=ALL_HEAD_LABELS,
        title="attn_head_out Activation Patching By Pos")
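
 The row labels only line up if the `(layer head)` flattening uses the same order as `ALL_HEAD_LABELS`. As a sanity check, here is a plain-Python sketch of the same reshape on nested lists, assuming (as einops does for a composed axis) that the first-named axis, `layer`, varies slowest. `flatten_layer_head` is a hypothetical helper, not part of einops or TransformerLens.

```python
from typing import List


def flatten_layer_head(result: List[List[List[float]]]) -> List[List[float]]:
    """Mimic "layer pos head -> (layer head) pos" on nested lists."""
    n_layers, n_pos, n_heads = len(result), len(result[0]), len(result[0][0])
    return [
        [result[layer][pos][head] for pos in range(n_pos)]
        for layer in range(n_layers)   # slowest-varying index of the merged axis
        for head in range(n_heads)     # fastest-varying index of the merged axis
    ]


n_layers, n_pos, n_heads = 2, 3, 2
# Encode each entry as layer*100 + pos*10 + head so its origin stays readable
result = [[[layer * 100 + pos * 10 + head for head in range(n_heads)]
           for pos in range(n_pos)]
          for layer in range(n_layers)]
rows = flatten_layer_head(result)
labels = [f"L{i}H{j}" for i in range(n_layers) for j in range(n_heads)]
```

 Row 1 holds the values for layer 0, head 1 and row 2 those for layer 1, head 0, which matches the `L0H1`, `L1H0` order produced by the label comprehension.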

 ### Patching multiple activation types
 Some utilities are provided to patch multiple activation types *in turn*. Note that this is *not* a utility to patch multiple activations at once; it is just a useful scan to get a sense of what's going on in a model.
 By block: we patch the residual stream at the start of each block, the attention output, and the MLP output, over each layer and position.

In [None]:
every_block_result = patching.get_act_patch_block_every(model, corrupted_tokens, clean_cache, ioi_metric)
imshow(every_block_result, facet_col=0, facet_labels=["Residual Stream", "Attn Output", "MLP Output"], title="Activation Patching Per Block", xaxis="Position", yaxis="Layer", zmax=1, zmin=-1, x= [f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))])

 By head: For each head we patch the output, query, key, value or pattern. We do all positions at once so it's not super slow.

In [None]:
every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(model, corrupted_tokens, clean_cache, ioi_metric)
imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=["Output", "Query", "Key", "Value", "Pattern"], title="Activation Patching Per Head (All Pos)", xaxis="Head", yaxis="Layer", zmax=1, zmin=-1)

 We can also patch by head *and* by position. This is a bit slow, but it can give useful, fine-grained detail.

In [None]:
if DO_SLOW_RUNS:
    every_head_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(model, corrupted_tokens, clean_cache, ioi_metric)
    every_head_act_patch_result = einops.rearrange(every_head_act_patch_result, "act_type layer pos head -> act_type (layer head) pos")
    imshow(every_head_act_patch_result, facet_col=0, facet_labels=["Output", "Query", "Key", "Value", "Pattern"], title="Activation Patching Per Head (By Pos)", xaxis="Position", yaxis="Layer & Head", zmax=1, zmin=-1, x= [f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))], y=ALL_HEAD_LABELS)

 ## Induction Patching
 To show how easy it is, let's do that again with induction heads in a 2L attention-only model.
 The input will be repeated random tokens, e.g. BOS 1 5 8 9 2 1 5 8 9 2, and we judge the model's ability to predict the second repetition with its induction heads.
 Let's call A, B and C different (non-repeated) random sequences. We'll start with clean tokens AA and corrupted tokens AB, and see how well the model can predict the second A given the first A.

 ### Setup

In [None]:
attn_only = HookedTransformer.from_pretrained("attn-only-2l")
batch = 4
seq_len = 20
rand_tokens_A = torch.randint(100, 10000, (batch, seq_len)).to(attn_only.cfg.device)
rand_tokens_B = torch.randint(100, 10000, (batch, seq_len)).to(attn_only.cfg.device)
rand_tokens_C = torch.randint(100, 10000, (batch, seq_len)).to(attn_only.cfg.device)
bos = torch.tensor([attn_only.tokenizer.bos_token_id]*batch)[:, None].to(attn_only.cfg.device)
clean_tokens_induction = torch.cat([bos, rand_tokens_A, rand_tokens_A], dim=1).to(attn_only.cfg.device)
corrupted_tokens_induction = torch.cat([bos, rand_tokens_A, rand_tokens_B], dim=1).to(attn_only.cfg.device)

In [None]:
clean_logits_induction, clean_cache_induction = attn_only.run_with_cache(clean_tokens_induction)
corrupted_logits_induction, corrupted_cache_induction = attn_only.run_with_cache(corrupted_tokens_induction)

 We define our metric as negative loss on the second half (negative loss so that higher is better).
 This time we won't normalise our metric.

In [None]:
def induction_loss(logits, answer_token_indices=rand_tokens_A):
    seq_len = answer_token_indices.shape[1]

    # logits: batch x seq_len x vocab_size
    # Take the logits for the answers, cut off the final element to get the predictions for all but the first element of the answers (which can't be predicted)
    final_logits = logits[:, -seq_len:-1]
    final_log_probs = final_logits.log_softmax(-1)
    return final_log_probs.gather(-1, answer_token_indices[:, 1:].unsqueeze(-1)).mean()
CLEAN_BASELINE_INDUCTION = induction_loss(clean_logits_induction).item()
print("Clean baseline:", CLEAN_BASELINE_INDUCTION)
CORRUPTED_BASELINE_INDUCTION = induction_loss(corrupted_logits_induction).item()
print("Corrupted baseline:", CORRUPTED_BASELINE_INDUCTION)
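
 The slice `logits[:, -seq_len:-1]` pairs with `answer_token_indices[:, 1:]` because the logits at position `p` are a prediction for the token at position `p + 1`. Here is a small index-only sketch of that alignment, using plain lists and a hypothetical `seq_len` of 4 in place of real tensors:

```python
seq_len = 4
bos = ["BOS"]
first_half = [f"A{i}" for i in range(seq_len)]
tokens = bos + first_half + first_half  # same layout as [bos, rand_tokens_A, rand_tokens_A]

positions = list(range(len(tokens)))
# Positions whose logits are scored: the last seq_len positions, minus the final one
scored_positions = positions[-seq_len:-1]
# The logits at position p are a prediction for the token at position p + 1
predicted_tokens = [tokens[p + 1] for p in scored_positions]
targets = first_half[1:]  # mirrors answer_token_indices[:, 1:]
```

 The first token of the second repetition (`A0` here) is left out of the targets, which is what the code comment means when it says the first element of the answers can't be predicted.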

 ### Patching

In [None]:
every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(attn_only, corrupted_tokens_induction, clean_cache_induction, induction_loss)
imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=["Output", "Query", "Key", "Value", "Pattern"], title="Activation Patching Per Head (All Pos)", xaxis="Head", yaxis="Layer", zmax=CLEAN_BASELINE_INDUCTION)

if DO_SLOW_RUNS:
    every_head_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(attn_only, corrupted_tokens_induction, clean_cache_induction, induction_loss)
    every_head_act_patch_result = einops.rearrange(every_head_act_patch_result, "act_type layer pos head -> act_type (layer head) pos")
    imshow(every_head_act_patch_result, facet_col=0, facet_labels=["Output", "Query", "Key", "Value", "Pattern"], title="Activation Patching Per Head (By Pos)", xaxis="Position", yaxis="Layer & Head", zmax=CLEAN_BASELINE_INDUCTION, x= [f"{tok}_{i}" for i, tok in enumerate(attn_only.to_str_tokens(clean_tokens_induction[0]))], y=[f"L{l}H{h}" for l in range(attn_only.cfg.n_layers) for h in range(attn_only.cfg.n_heads)])

 ### Changing the Corrupted Baseline
 We can also change the corrupted baseline easily to check what things look like! We'll keep clean as AA, but rather than corrupted as AB, we'll try out:
 * BA - This has a corrupted first half, so we expect both keys *and* values to matter. Head output patching should work, but patching value, key, or pattern alone won't.
 * BB - This is still inductiony but with different tokens. So keys, queries and patterns don't matter, while head output patching and value patching will work.
 * BC - This is just random tokens, so everything is corrupted! The induction head needs queries, keys *and* values, so only output will work.

In [None]:
corrupted_tokens_induction_BA = torch.cat([bos, rand_tokens_B, rand_tokens_A], dim=1).to(attn_only.cfg.device)
corrupted_tokens_induction_BB = torch.cat([bos, rand_tokens_B, rand_tokens_B], dim=1).to(attn_only.cfg.device)
corrupted_tokens_induction_BC = torch.cat([bos, rand_tokens_B, rand_tokens_C], dim=1).to(attn_only.cfg.device)

In [None]:
every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(attn_only, corrupted_tokens_induction_BA, clean_cache_induction, induction_loss)
imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=["Output", "Query", "Key", "Value", "Pattern"], title="Activation Patching Per Head on BA (All Pos)", xaxis="Head", yaxis="Layer", zmax=CLEAN_BASELINE_INDUCTION)
every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(attn_only, corrupted_tokens_induction_BB, clean_cache_induction, induction_loss)
imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=["Output", "Query", "Key", "Value", "Pattern"], title="Activation Patching Per Head on BB (All Pos)", xaxis="Head", yaxis="Layer", zmax=CLEAN_BASELINE_INDUCTION)
every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(attn_only, corrupted_tokens_induction_BC, clean_cache_induction, induction_loss)
imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=["Output", "Query", "Key", "Value", "Pattern"], title="Activation Patching Per Head on BC (All Pos)", xaxis="Head", yaxis="Layer", zmax=CLEAN_BASELINE_INDUCTION)