# Setup
(No need to change anything)

In [None]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-iyfz16f9
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-iyfz16f9
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit 0825c5eb4196e7ad72d28bcf4e615306b3897490
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting accelerate>=0.23.0 (from transformer-lens==0.0.0)
  Downloading accelerate-0.25.0-py3-none-any.whl (265 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m265.7/265.7 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting beartype<0.15.0,>=0.14.1 (from transformer-lens==0.0.0)
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)
[2K     [90m━━━━━━━━

In [None]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [None]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x79c220bf1b70>

## Load Model

In [None]:
# Load pretrained GPT-2 small as a HookedTransformer, with every weight-processing
# option listed below switched on.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-small into HookedTransformer


## Import functions from repo

In [None]:
# Fetch the ARENA 2.0 repository; later cells import modules from inside it.
!git clone https://github.com/callummcdougall/ARENA_2.0.git

Cloning into 'ARENA_2.0'...
remote: Enumerating objects: 9262, done.[K
remote: Counting objects: 100% (1976/1976), done.[K
remote: Compressing objects: 100% (344/344), done.[K
remote: Total 9262 (delta 1727), reused 1732 (delta 1628), pack-reused 7286[K
Receiving objects: 100% (9262/9262), 156.79 MiB | 20.12 MiB/s, done.
Resolving deltas: 100% (5620/5620), done.


In [None]:
cd ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification

/content/ARENA_2.0


In [None]:
import ioi_circuit_extraction as ioi_circuit_extraction

In [None]:
from ioi_dataset import IOIDataset

# Generate dataset with multiple prompts

In [None]:
class Dataset:
    '''
    Tokenised collection of prompts plus the bookkeeping used for mean-ablation.

    prompts
        Non-empty list of dicts. Each dict needs the keys "text", "corr" and
        "incorr"; every other key names a target slot (e.g. "S1").

    pos_dict
        Maps each target-slot key to its token position. The same position is
        recorded for every prompt.

    tokenizer
        Tokenizer used for all tokenisation in this class. It must be callable
        (returning an object with `input_ids`) and provide `encode` and `tokenize`.

    S1_is_first
        Currently unused; kept so existing callers keep working.

    Attributes set here: prompts, tokenizer, N, max_len, groups, toks,
    corr_tokenIDs, incorr_tokenIDs and word_idx.

    Raises ValueError if `prompts` is empty.
    '''

    def __init__(self, prompts, pos_dict, tokenizer, S1_is_first=False):
        if len(prompts) == 0:
            raise ValueError("Dataset requires at least one prompt")

        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)
        self.max_len = max(
            len(self.tokenizer(prompt["text"]).input_ids) for prompt in self.prompts
        )

        # Every prompt shares template id 0 (only 1 template), so this yields a
        # single group holding all prompt indices.
        # all_ids = [prompt["TEMPLATE_IDX"] for prompt in self.ioi_prompts]
        all_ids_ar = np.array([0 for _ in self.prompts])
        self.groups = [
            np.where(all_ids_ar == template_id)[0]
            for template_id in sorted(set(all_ids_ar.tolist()))
        ]

        texts = [prompt["text"] for prompt in self.prompts]
        # Build the integer tensor directly instead of round-tripping through float32.
        self.toks = torch.tensor(
            self.tokenizer(texts, padding=True).input_ids, dtype=torch.int
        )

        # First token of " <answer>" for the correct / incorrect continuation.
        self.corr_tokenIDs = [
            self.tokenizer.encode(" " + prompt["corr"])[0] for prompt in self.prompts
        ]
        self.incorr_tokenIDs = [
            self.tokenizer.encode(" " + prompt["incorr"])[0] for prompt in self.prompts
        ]

        # word_idx: one tensor per target slot, with one entry per prompt giving that
        # slot's token position. Positions come straight from pos_dict, so no
        # tokenisation is needed here (the old code tokenised with the global
        # `model` and then discarded the result).
        non_target_keys = ("text", "corr", "incorr")
        self.word_idx = {
            targ: torch.tensor([pos_dict[targ]] * self.N)
            for targ in self.prompts[0].keys()
            if targ not in non_target_keys
        }

        # "end" is the index of the last token of each prompt.
        self.word_idx["end"] = torch.tensor(
            [len(self.tokenizer.tokenize(prompt["text"])) - 1 for prompt in self.prompts]
        )

    def __len__(self):
        '''Number of prompts in the dataset.'''
        return self.N

In [None]:
# Token position of each named slot within a prompt. Dataset records these
# positions in word_idx, using the same value for every prompt.
pos_dict = {
    'S1': 0,
    'S2': 1,
    'S3': 2,
    'S4': 3,
}

In [None]:
def generate_prompts_list(x, y):
    '''
    Build one prompt dict for every start value i in range(x, y).

    Each prompt is the run of four consecutive integers "i i+1 i+2 i+3". The
    slots S1..S4 hold those integers as strings, 'corr' is the next integer
    (i+4) and 'incorr' repeats the last integer shown (i+3).
    '''
    def build_prompt(start):
        first, second, third, fourth = (str(start + step) for step in range(4))
        return {
            'S1': first,
            'S2': second,
            'S3': third,
            'S4': fourth,
            'corr': str(start + 4),
            'incorr': fourth,
            'text': f"{first} {second} {third} {fourth}",
        }

    return [build_prompt(start) for start in range(x, y)]

# 100 sequential prompts, with start values 1..100 (the range end is exclusive).
prompts_list = generate_prompts_list(1, 101)
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, S1_is_first=True)

In [None]:
import random

def generate_prompts_list_corr(x, y, seed=None):
    '''
    Build one randomised prompt dict for every i in range(x, y).

    Each prompt holds four integers drawn uniformly from 1..100. The last pair
    is redrawn until S4 is not exactly S3 + 1. 'corr' and 'incorr' are still
    derived from the loop index i (i+4 and i+3), not from the random numbers.

    seed
        Optional. When given, a private random.Random(seed) is used, so the
        output is reproducible and the global random state is left untouched.
        When None (the default), the global `random` module is used as before.
    '''
    rng = random if seed is None else random.Random(seed)

    prompts_list = []
    for i in range(x, y):
        r1 = rng.randint(1, 100)
        r2 = rng.randint(1, 100)
        r3, r4 = rng.randint(1, 100), rng.randint(1, 100)
        # Reject draws whose final pair is still consecutive.
        while r4 - 1 == r3:
            r3, r4 = rng.randint(1, 100), rng.randint(1, 100)
        prompts_list.append({
            'S1': str(r1),
            'S2': str(r2),
            'S3': str(r3),
            'S4': str(r4),
            'corr': str(i+4),
            'incorr': str(i+3),
            'text': f"{r1} {r2} {r3} {r4}"
        })
    return prompts_list

# Randomised counterpart of `dataset`; later cells pass it in as the means dataset
# for mean-ablation.
prompts_list_2 = generate_prompts_list_corr(1, 101)
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, S1_is_first=True)

In [None]:
# NOTE(review): rebuilds dataset_2 with the same arguments as the previous cell; looks redundant.
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, S1_is_first=True)

# IOI dataset

In [None]:
# Build 100 IOI prompts with the repo's IOIDataset, then derive a flipped-prompt variant.
# NOTE(review): IOIDataset is imported from the cloned ARENA repo; its exact
# behaviour is not visible in this notebook.
ioi_dataset = IOIDataset(
    prompt_type="mixed",
    N=100,
    tokenizer=model.tokenizer,
    prepend_bos=False,
    seed=1,
)

abc_dataset = ioi_dataset.gen_flipped_prompts("ABB->XYZ, BAB->XYZ")

In [None]:
# Forward pass over the flipped prompts, keeping only the activation cache.
# The names_filter line is commented out, so no filter is passed to run_with_cache.
_, means_cache = model.run_with_cache(
        abc_dataset.toks.long(),
        return_type=None,
        # names_filter=lambda name: name.endswith("z"),
    )

In [None]:
from torch import Tensor

# Pull layer 0's attention `z` activations out of the cache so their shape can be inspected.
layer=0
z_for_this_layer: Float[Tensor, "batch seq head d_head"] = means_cache[utils.get_act_name("z", layer)]
# Reference: the per-template averaging loop this exploration is building towards.
# for template_group in means_dataset.groups:
#     z_for_this_template = z_for_this_layer[template_group]
#     z_means_for_this_template = einops.reduce(z_for_this_template, "batch seq head d_head -> seq head d_head", "mean")
#     means[layer, template_group] = z_means_for_this_template

In [None]:
z_for_this_layer.size()  # numbatch, seqlen, numhead, d_head

torch.Size([100, 21, 12, 64])

In [None]:
abc_dataset.groups

[array([ 0,  2,  6, 25, 37, 41, 81, 88]),
 array([30, 45, 68, 74]),
 array([15, 38, 42, 46, 56, 97, 99]),
 array([ 5, 33, 35, 39, 59, 64, 69, 75, 86]),
 array([ 1, 14, 17, 26, 29, 36, 53, 70, 77, 78, 83, 93]),
 array([13, 47, 50, 52]),
 array([16, 19, 43, 49, 61, 66, 80]),
 array([10, 12, 20, 40, 72]),
 array([ 8, 21, 44, 58, 73, 79, 85, 89, 95]),
 array([ 9, 22, 48, 51, 54, 62, 84, 92]),
 array([ 3,  7, 11, 27, 34, 63, 71, 94]),
 array([28, 32, 57, 82, 87, 90, 91]),
 array([ 4, 18, 60, 65, 67, 96]),
 array([23, 24, 31, 55, 76, 98])]

In [None]:
# Index the layer's activations with the first template group's prompt indices.
template_group = abc_dataset.groups[0]
z_for_this_template = z_for_this_layer[template_group]
print(template_group)
z_for_this_template.size()  # this just obtains batches at inds in template_group. no other dims change

[ 0  2  6 25 37 41 81 88]


torch.Size([8, 21, 12, 64])

So if all prompts in the batch belong to the same template group, you can use z_for_this_layer directly; there is no need to index out a subset of the batch to get z_for_this_template.

# MLP ablation

In [None]:
from torch import Tensor

def logits_to_ave_logit_diff(logits: Float[Tensor, "batch seq d_vocab"], dataset: Dataset, per_prompt=False):
    '''
    Logit of the correct answer minus logit of the incorrect answer, read at
    each prompt's "end" position.

    The answer tokens come from dataset.corr_tokenIDs and dataset.incorr_tokenIDs,
    and the positions from dataset.word_idx["end"].

    With per_prompt=True the per-prompt differences are returned; otherwise
    their mean is returned.
    '''
    prompt_rows = range(logits.size(0))
    end_positions = dataset.word_idx["end"]

    def logit_at_end(token_ids):
        # One logit per prompt: (prompt, that prompt's end position, that prompt's token id).
        return logits[prompt_rows, end_positions, token_ids]

    per_prompt_diff = logit_at_end(dataset.corr_tokenIDs) - logit_at_end(dataset.incorr_tokenIDs)
    if per_prompt:
        return per_prompt_diff
    return per_prompt_diff.mean()

## Exploratory tests

In [None]:
from typing import Dict, Tuple, List
from jaxtyping import Float, Bool
import torch as t

# lst = [(layer, head) for layer in range(12) for head in range(12)]
# For the MLP experiments a circuit entry is a list of layer indices (not (layer, head) pairs).
lst = [layer for layer in range(12)]
# Layers whose MLP output is kept (not mean-ablated) for each entry; here every
# entry keeps all 12 layers.
CIRCUIT = {
    "number mover": lst,
    # "number mover 4": lst,
    "number mover 3": lst,
    "number mover 2": lst,
    "number mover 1": lst,
}

# For each entry, the Dataset.word_idx key naming the sequence position that is kept.
SEQ_POS_TO_KEEP = {
    "number mover": "end",
    # "number mover 4": "S4",
    "number mover 3": "S3",
    "number mover 2": "S2",
    "number mover 1": "S1",
}

In [None]:
lst

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]

In [None]:
model.cfg

HookedTransformerConfig:
{'act_fn': 'gelu_new',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 3072,
 'd_model': 768,
 'd_vocab': 50257,
 'd_vocab_out': 50257,
 'default_prepend_bos': True,
 'device': device(type='cuda'),
 'dtype': torch.float32,
 'eps': 1e-05,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.02886751345948129,
 'model_name': 'gpt2',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 12,
 'n_layers': 12,
 'n_params': 84934656,
 'normalization_type': 'LNPre',
 'original_architecture': 'GPT2LMHeadModel',
 'parallel_attn_mlp': False,
 'positional_embedding_type': 'standard',
 'post_embedding_ln': False,
 'rotary_dim': None,
 'scale_attn_by_inverse_layer_idx': False,
 'seed': None,
 'tokenizer_name': 'gpt2',
 'tokenizer_prepends_bos': False,
 'use_attn_in': F

In [None]:
# Re-run the cache on dataset_2. This rebinds `means_cache`, replacing the cache
# built from abc_dataset earlier in the notebook.
_, means_cache = model.run_with_cache(
        dataset_2.toks.long(),
        return_type=None,
        # names_filter=lambda name: name.endswith("z"),
    )
means_cache

ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'blocks.2.hook_re

In [None]:
utils.get_act_name("z", 0)

'blocks.0.attn.hook_z'

In [None]:
utils.get_act_name("mlp_out", 0)

'blocks.0.hook_mlp_out'

hook_post has last dimension d_mlp, the number of neurons. It therefore holds the activation of every neuron in the MLP's hidden (middle) layer, before the projection to the MLP output.

In [None]:
means_cache['blocks.0.mlp.hook_post'].size()

torch.Size([100, 4, 3072])

In [None]:
means_cache[utils.get_act_name("mlp_out", 0)].size()

torch.Size([100, 4, 768])

In [None]:
mlp_output_for_this_layer: Float[Tensor, "batch seq d_model"] = means_cache[utils.get_act_name("mlp_out", 0)]

In [None]:
mlp_output_for_this_layer.size()

torch.Size([100, 4, 768])

In [None]:
mlp_output_means_for_this_template = einops.reduce(mlp_output_for_this_layer, "batch seq d_model -> seq d_model", "mean")
mlp_output_means_for_this_template.size()

torch.Size([4, 768])

"mlp_output_means_for_this_template" gives the mean for each (pos, d_model) activation. There are 4 * 768 activations, each with its own mean.

In [None]:
dataset_2.groups

[array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
        34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
        51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67,
        68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84,
        85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99])]

There is only one group; all indices belong to one group

In [None]:
# this is how to create template groups:

# all_ids = [prompt["TEMPLATE_IDX"] for prompt in self.ioi_prompts]
# all_ids_ar = np.array(all_ids)
# self.groups = []
# for id in list(set(all_ids)):
#     self.groups.append(np.where(all_ids_ar == id)[0])

So group size is <= batch size. Each template group yields its own copy of mlp_output_for_this_layer, restricted to that group's batch indices.

## MLP ablation fns

In [None]:
def compute_means_by_template_MLP(
    means_dataset: Dataset,
    model: HookedTransformer
) -> Float[Tensor, "layer batch seq d_model"]:
    '''
    Returns the mean MLP output of every layer over the means dataset. This mean
    is computed separately for each group of prompts with the same template
    (these are given by means_dataset.groups).

    The result has shape (n_layers, batch, seq, d_model). Every prompt in a
    template group is assigned the same (seq, d_model) mean, so the tensor can
    be indexed per prompt inside an ablation hook.
    '''
    # Cache only the MLP outputs; nothing else is read from the cache below.
    _, means_cache = model.run_with_cache(
        means_dataset.toks.long(),
        return_type=None,
        names_filter=lambda name: name.endswith("mlp_out"),
    )

    # Create tensor to store means
    n_layers, d_model = model.cfg.n_layers, model.cfg.d_model
    batch, seq_len = len(means_dataset), means_dataset.max_len
    means = t.zeros(size=(n_layers, batch, seq_len, d_model), device=model.cfg.device)

    for layer in range(n_layers):
        mlp_output_for_this_layer: Float[Tensor, "batch seq d_model"] = means_cache[utils.get_act_name("mlp_out", layer)]
        for template_group in means_dataset.groups:
            mlp_output_for_this_template = mlp_output_for_this_layer[template_group]
            # Average over the prompts in this group ...
            mlp_output_means_for_this_template = einops.reduce(
                mlp_output_for_this_template, "batch seq d_model -> seq d_model", "mean"
            )
            # ... and give every prompt in the group that same mean.
            means[layer, template_group] = mlp_output_means_for_this_template

    return means

In [None]:
# Start from a clean model so the means are computed without any ablation hook attached.
model.reset_hooks(including_permanent=True)

# Compute the mean MLP output of every layer on dataset_2, grouped by template
means = compute_means_by_template_MLP(dataset_2, model)
means.size()

torch.Size([12, 100, 4, 768])

In [None]:
# def get_heads_and_posns_to_keep(
#     means_dataset: Dataset,
#     model: HookedTransformer,
#     circuit: Dict[str, List[Tuple[int, int]]],
#     seq_pos_to_keep: Dict[str, str],
# ) -> Dict[int, Bool[Tensor, "batch seq head"]]:
#     '''
#     Returns a dictionary mapping layers to a boolean mask giving the indices of the
#     z output which *shouldn't* be mean-ablated.

#     The output of this function will be used for the hook function that does ablation.
#     '''
#     heads_and_posns_to_keep = {}
#     batch, seq, n_heads = len(means_dataset), means_dataset.max_len, model.cfg.n_heads

#     for layer in range(model.cfg.n_layers):

#         mask = t.zeros(size=(batch, seq, n_heads))

#         for (head_type, head_list) in circuit.items():
#             seq_pos = seq_pos_to_keep[head_type]
#             indices = means_dataset.word_idx[seq_pos] # modify this for key vs query pos. curr, this is query
#             for (layer_idx, head_idx) in head_list:
#                 if layer_idx == layer:
#                     mask[:, indices, head_idx] = 1

#         heads_and_posns_to_keep[layer] = mask.bool()

#     return heads_and_posns_to_keep

In [None]:
def get_mlp_outputs_and_posns_to_keep(
    means_dataset: Dataset,
    model: HookedTransformer,
    circuit: Dict[str, List[int]],  # Adjusted to hold list of layers instead of (layer, head) tuples
    seq_pos_to_keep: Dict[str, str],
) -> Dict[int, Bool[Tensor, "batch seq"]]:  # Adjusted the return type to "batch seq"
    '''
    Returns a dictionary mapping layers to a boolean mask giving the indices of the
    MLP output which *shouldn't* be mean-ablated.

    circuit
        Maps an entry name to the list of layers whose MLP output is kept.

    seq_pos_to_keep
        Maps the same entry names to a key of means_dataset.word_idx, i.e. the
        per-prompt sequence position that is kept.

    Each prompt keeps only its own position: row b of the mask is set at
    word_idx[...][b]. (Indexing with `mask[:, indices]` would instead mark the
    union of every prompt's positions on every row.)

    The output of this function will be used for the hook function that does ablation.
    '''
    mlp_outputs_and_posns_to_keep = {}
    batch, seq = len(means_dataset), means_dataset.max_len
    prompt_rows = t.arange(batch)

    for layer in range(model.cfg.n_layers):
        mask = t.zeros(size=(batch, seq))

        for (mlp_type, layer_list) in circuit.items():
            seq_pos = seq_pos_to_keep[mlp_type]
            indices = t.as_tensor(means_dataset.word_idx[seq_pos], dtype=t.long)
            if layer in layer_list:  # Check if the current layer is in the layer list for this mlp_type
                # Pair each prompt row with that prompt's own position.
                mask[prompt_rows, indices] = 1

        mlp_outputs_and_posns_to_keep[layer] = mask.bool()

    return mlp_outputs_and_posns_to_keep

In [None]:
# Keep-mask for the module-level CIRCUIT / SEQ_POS_TO_KEEP, sized for dataset_2.
mlp_outputs_and_posns_to_keep = get_mlp_outputs_and_posns_to_keep(dataset_2, model, CIRCUIT, SEQ_POS_TO_KEEP)

In [None]:
# def hook_fn_mask_z(
#     z: Float[Tensor, "batch seq head d_head"],
#     hook: HookPoint,
#     heads_and_posns_to_keep: Dict[int, Bool[Tensor, "batch seq head"]],
#     means: Float[Tensor, "layer batch seq head d_head"],
# ) -> Float[Tensor, "batch seq head d_head"]:
#     '''
#     Hook function which masks the z output of a transformer head.

#     heads_and_posns_to_keep
#         Dict created with the get_heads_and_posns_to_keep function. This tells
#         us where to mask.

#     means
#         Tensor of mean z values of the means_dataset over each group of prompts
#         with the same template. This tells us what values to mask with.
#     '''
#     # Get the mask for this layer, and add d_head=1 dimension so it broadcasts correctly
#     mask_for_this_layer = heads_and_posns_to_keep[hook.layer()].unsqueeze(-1).to(z.device)

#     # Set z values to the mean
#     z = t.where(mask_for_this_layer, z, means[hook.layer()])

#     return z

In [None]:
def hook_fn_mask_mlp_out(
    mlp_out: Float[Tensor, "batch seq d_model"],
    hook: HookPoint,
    mlp_outputs_and_posns_to_keep: Dict[int, Bool[Tensor, "batch seq"]],
    means: Float[Tensor, "layer batch seq d_model"],
) -> Float[Tensor, "batch seq d_model"]:
    '''
    Hook function which masks the MLP output of a transformer layer.

    The hooked activation is the MLP *output*, so its last dimension is d_model
    (the residual-stream width), not d_mlp.

    mlp_outputs_and_posns_to_keep
        Dict created with the get_mlp_outputs_and_posns_to_keep function. This tells
        us where to mask.

    means
        Tensor of mean MLP output values of the means_dataset over each group of prompts
        with the same template. This tells us what values to mask with.

    Returns the activation with every position where the mask is False replaced
    by the corresponding mean.
    '''
    layer = hook.layer()

    # Get the mask for this layer, with a trailing size-1 dim so it broadcasts over d_model
    mask_for_this_layer = mlp_outputs_and_posns_to_keep[layer].unsqueeze(-1).to(mlp_out.device)

    # Keep the means on the activation's device so t.where never mixes devices
    means_for_this_layer = means[layer].to(mlp_out.device)

    # Set MLP output values to the mean where necessary
    return t.where(mask_for_this_layer, mlp_out, means_for_this_layer)

In [None]:
# def add_mean_ablation_hook(
#     model: HookedTransformer,
#     means_dataset: Dataset,
#     circuit: Dict[str, List[Tuple[int, int]]] = CIRCUIT,
#     seq_pos_to_keep: Dict[str, str] = SEQ_POS_TO_KEEP,
#     is_permanent: bool = True,
# ) -> HookedTransformer:
#     '''
#     Adds a permanent hook to the model, which ablates according to the circuit and
#     seq_pos_to_keep dictionaries.

#     In other words, when the model is run on ioi_dataset, every head's output will
#     be replaced with the mean over means_dataset for sequences with the same template,
#     except for a subset of heads and sequence positions as specified by the circuit
#     and seq_pos_to_keep dicts.
#     '''

#     model.reset_hooks(including_permanent=True)

#     # Compute the mean of each head's output on the ABC dataset, grouped by template
#     means = compute_means_by_template(means_dataset, model)

#     # Convert this into a boolean map
#     heads_and_posns_to_keep = get_heads_and_posns_to_keep(means_dataset, model, circuit, seq_pos_to_keep)

#     # Get a hook function which will patch in the mean z values for each head, at
#     # all positions which aren't important for the circuit
#     hook_fn = partial(
#         hook_fn_mask_z,
#         heads_and_posns_to_keep=heads_and_posns_to_keep,
#         means=means
#     )

#     # Apply hook
#     model.add_hook(lambda name: name.endswith("z"), hook_fn, is_permanent=is_permanent)

#     return model

In [None]:
# Baseline: run the model with all hooks removed on the sequential dataset and
# record its average logit difference.
model.reset_hooks(including_permanent=True)
logits_orig, orig_cache = model.run_with_cache(dataset.toks)
orig_score = logits_to_ave_logit_diff(logits_orig, dataset)

In [None]:
model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

# Get a hook function which will patch in the mean MLP output values for each MLP layer, at
# all positions which aren't important for the circuit
hook_fn = partial(
    hook_fn_mask_mlp_out,
    mlp_outputs_and_posns_to_keep = mlp_outputs_and_posns_to_keep,
    means=means
)

# Apply hook (permanent: it stays attached until reset_hooks(including_permanent=True) is called)
model.add_hook(lambda name: name.endswith("mlp_out"), hook_fn, is_permanent=True)

new_logits = model(dataset.toks)
# new_logits.size()

new_score = logits_to_ave_logit_diff(new_logits, dataset)

# NOTE(review): CIRCUIT keeps every layer at S1, S2, S3 and end. If each prompt
# tokenises to exactly four tokens, "end" coincides with the S4 slot and nothing
# is ablated, which would explain a ratio of 100% here - confirm.
print(f"Average logit difference (dataset, using entire model): {orig_score:.4f}")
print(f"Average logit difference (dataset, only using circuit): {new_score:.4f}")
print(f"Average logit difference (circuit / full) %: {100 * new_score / orig_score:.4f}")

Average logit difference (dataset, using entire model): 5.8226
Average logit difference (dataset, only using circuit): 5.8226
Average logit difference (circuit / full) %: 100.0000


In [None]:
# `hook_fn` is a functools.partial around hook_fn_mask_mlp_out and has no `.layer()`;
# inside the hook function, `.layer()` is called on the `hook` argument instead.
# Catch and display the error so that Run All is not interrupted by this cell.
try:
    hook_fn.layer()
except AttributeError as err:
    print(f"AttributeError: {err}")

AttributeError: ignored

In [None]:
def add_mean_ablation_hook_MLP(
    model: HookedTransformer,
    means_dataset: Dataset,
    circuit: Dict[str, List[int]] = CIRCUIT,
    seq_pos_to_keep: Dict[str, str] = SEQ_POS_TO_KEEP,
    is_permanent: bool = True,
) -> HookedTransformer:
    '''
    Adds a hook to the model which mean-ablates MLP outputs according to the
    circuit and seq_pos_to_keep dictionaries.

    In other words, on later forward passes every layer's MLP output is replaced
    with its mean over means_dataset (for sequences with the same template),
    except at the layers and sequence positions specified by the circuit and
    seq_pos_to_keep dicts.

    circuit
        Maps an entry name to the list of layers whose MLP output is kept.

    seq_pos_to_keep
        Maps the same entry names to the means_dataset.word_idx key that is kept.

    is_permanent
        Passed through to model.add_hook (default True).

    All existing hooks, including permanent ones, are removed first. The same
    model object is returned.
    '''

    model.reset_hooks(including_permanent=True)

    # Compute the mean MLP output of each layer on the means dataset, grouped by template
    means = compute_means_by_template_MLP(means_dataset, model)

    # Convert this into a boolean map
    mlp_outputs_and_posns_to_keep = get_mlp_outputs_and_posns_to_keep(means_dataset, model, circuit, seq_pos_to_keep)

    # Get a hook function which will patch in the mean MLP output for each layer, at
    # all positions which aren't important for the circuit
    hook_fn = partial(
        hook_fn_mask_mlp_out,
        mlp_outputs_and_posns_to_keep=mlp_outputs_and_posns_to_keep,
        means=means
    )

    # Apply hook, honouring the caller's is_permanent choice (it was hard-coded to True before)
    model.add_hook(lambda name: name.endswith("mlp_out"), hook_fn, is_permanent=is_permanent)

    return model

In [None]:
def mean_ablate_by_lst_MLP(lst, model, orig_score, print_output=True):
    '''
    Mean-ablate MLP outputs outside the given layers and report how much of the
    original score survives.

    lst
        Layers whose MLP output is kept at the positions "end", "S3", "S2" and "S1".

    orig_score
        Score of the unablated model; the result is expressed relative to it.

    Means are taken from the global `dataset_2` and the score is measured on the
    global `dataset`. The ablation hook is still attached to `model` when this
    function returns.

    Returns 100 * ablated_score / orig_score, and prints it when print_output is True.
    '''
    # Entry name -> word_idx key of the position to keep.
    # (The "number mover 4" / "S4" entry stays disabled, as in the module-level dicts.)
    kept_positions = {
        "number mover": "end",
        "number mover 3": "S3",
        "number mover 2": "S2",
        "number mover 1": "S1",
    }
    # Every entry keeps the same list of layers.
    kept_layers = dict.fromkeys(kept_positions, lst)

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

    ablated_model = add_mean_ablation_hook_MLP(
        model,
        means_dataset=dataset_2,
        circuit=kept_layers,
        seq_pos_to_keep=kept_positions,
    )
    ablated_score = logits_to_ave_logit_diff(ablated_model(dataset.toks), dataset)

    if print_output:
        print(f"Average logit difference (circuit / full) %: {100 * ablated_score / orig_score:.4f}")
    return 100 * ablated_score / orig_score

In [None]:
# Mean-ablate MLP 9 only: every other layer is in the keep list.
lst = [layer for layer in range(12) if layer !=9]
perc_of_orig = mean_ablate_by_lst_MLP(lst, model, orig_score, print_output=False).item()
perc_of_orig

7.071479320526123

In [None]:
# Leave-one-out sweep: drop a single MLP layer from the keep list at a time and
# print the percentage of the original score that remains.
for i in range(12):
    lst = [layer for layer in range(12) if layer != i]
    perc_of_orig = mean_ablate_by_lst_MLP(lst, model, orig_score, print_output=False).item()
    print(i, perc_of_orig)

0 37.15055847167969
1 64.20915985107422
2 69.72237396240234
3 85.45144653320312
4 67.83340454101562
5 86.49969482421875
6 68.98082733154297
7 81.7919692993164
8 81.44670104980469
9 7.071479320526123
10 66.96712493896484
11 73.90493774414062


In [None]:
# Keep MLPs 0, 1, 2, 4, 6, 9, 10, 11; layers 3, 5, 7 and 8 are left out of the keep list.
mlp_to_not_ablate = [0, 1, 2, 4, 6, 9, 10, 11]
perc_of_orig = mean_ablate_by_lst_MLP(mlp_to_not_ablate, model, orig_score, print_output=False).item()
print(perc_of_orig)

43.83709716796875


In [None]:
# Keep MLPs 0-9; layers 10 and 11 are left out of the keep list.
mlp_to_not_ablate = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
perc_of_orig = mean_ablate_by_lst_MLP(mlp_to_not_ablate, model, orig_score, print_output=False).item()
print(perc_of_orig)

45.26451110839844


In [None]:
# Keep MLPs 0-10; only layer 11 is left out (the same configuration as i = 11 in the sweep).
mlp_to_not_ablate = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
perc_of_orig = mean_ablate_by_lst_MLP(mlp_to_not_ablate, model, orig_score, print_output=False).item()
print(perc_of_orig)

73.90493774414062


In [None]:
# Keep MLPs 0 and 8-11; layers 1-7 are left out of the keep list.
mlp_to_not_ablate = [0, 8, 9, 10, 11]
perc_of_orig = mean_ablate_by_lst_MLP(mlp_to_not_ablate, model, orig_score, print_output=False).item()
print(perc_of_orig)

-2.684004068374634


# MLP and Head together

In [None]:
def get_heads_and_posns_to_keep(
    means_dataset: Dataset,
    model: HookedTransformer,
    circuit: Dict[str, List[Tuple[int, int]]],
    seq_pos_to_keep: Dict[str, str],
) -> Dict[int, Bool[Tensor, "batch seq head"]]:
    '''
    Returns a dictionary mapping layers to a boolean mask giving the indices of the
    z output which *shouldn't* be mean-ablated.

    circuit
        Maps a head-type name to a list of (layer, head) pairs to keep.

    seq_pos_to_keep
        Maps the same names to a key of means_dataset.word_idx, i.e. the
        per-prompt sequence position at which those heads are kept.

    Each prompt keeps only its own position: row b of the mask is set at
    word_idx[...][b]. (Indexing with `mask[:, indices, head]` would instead mark
    the union of every prompt's positions on every row.)

    The output of this function will be used for the hook function that does ablation.
    '''
    heads_and_posns_to_keep = {}
    batch, seq, n_heads = len(means_dataset), means_dataset.max_len, model.cfg.n_heads
    prompt_rows = t.arange(batch)

    for layer in range(model.cfg.n_layers):

        mask = t.zeros(size=(batch, seq, n_heads))

        for (head_type, head_list) in circuit.items():
            seq_pos = seq_pos_to_keep[head_type]
            # modify this for key vs query pos. curr, this is query
            indices = t.as_tensor(means_dataset.word_idx[seq_pos], dtype=t.long)
            for (layer_idx, head_idx) in head_list:
                if layer_idx == layer:
                    # Pair each prompt row with that prompt's own position.
                    mask[prompt_rows, indices, head_idx] = 1

        heads_and_posns_to_keep[layer] = mask.bool()

    return heads_and_posns_to_keep

def hook_fn_mask_z(
    z: Float[Tensor, "batch seq head d_head"],
    hook: HookPoint,
    heads_and_posns_to_keep: Dict[int, Bool[Tensor, "batch seq head"]],
    means: Float[Tensor, "layer batch seq head d_head"],
) -> Float[Tensor, "batch seq head d_head"]:
    '''
    Hook that mean-ablates attention-head z outputs outside the circuit.

    heads_and_posns_to_keep
        Per-layer boolean masks from get_heads_and_posns_to_keep. True means
        "leave this (prompt, position, head) untouched".

    means
        Per-layer mean z values over the means dataset, computed per template
        group. These are the replacement values.

    Returns z with every entry whose mask is False replaced by the matching mean.
    '''
    layer = hook.layer()

    # A trailing size-1 axis lets the (batch, seq, head) mask broadcast over d_head.
    keep_here = heads_and_posns_to_keep[layer].unsqueeze(-1).to(z.device)

    return t.where(keep_here, z, means[layer])

def compute_means_by_template(
    means_dataset: Dataset,
    model: HookedTransformer
) -> Float[Tensor, "layer batch seq head_idx d_head"]:
    '''
    Computes, for every layer, the mean z activation of each head over the
    means dataset, averaging separately within each group of prompts listed in
    means_dataset.groups (prompts sharing a template).

    Every prompt in a group is assigned that group's mean, so the result keeps
    a batch axis of length len(means_dataset).
    '''
    def is_z_activation(name):
        return name.endswith("z")

    # Single forward pass that caches only the z activations
    _, z_cache = model.run_with_cache(
        means_dataset.toks.long(),
        return_type=None,
        names_filter=is_z_activation,
    )

    cfg = model.cfg
    out_shape = (
        cfg.n_layers,
        len(means_dataset),
        means_dataset.max_len,
        cfg.n_heads,
        cfg.d_head,
    )
    means = t.zeros(size=out_shape, device=cfg.device)

    for layer in range(cfg.n_layers):
        layer_z: Float[Tensor, "batch seq head d_head"] = z_cache[utils.get_act_name("z", layer)]
        for group in means_dataset.groups:
            # Average over the prompts in this group, then write that mean
            # back to every prompt index belonging to the group
            group_mean = einops.reduce(
                layer_z[group],
                "batch seq head d_head -> seq head d_head",
                "mean",
            )
            means[layer, group] = group_mean

    return means

def add_mean_ablation_hook(
    model: HookedTransformer,
    means_dataset: Dataset,
    circuit: Dict[str, List[Tuple[int, int]]] = CIRCUIT,
    seq_pos_to_keep: Dict[str, str] = SEQ_POS_TO_KEEP,
    is_permanent: bool = True,
) -> HookedTransformer:
    '''
    Clears every existing hook on the model (permanent ones included) and then
    attaches a mean-ablation hook to all z activations. The hook is permanent
    unless is_permanent is False.

    Once the hook is attached, each head's output is replaced by the mean over
    means_dataset for sequences with the same template, except at the heads
    and sequence positions selected by the circuit and seq_pos_to_keep dicts.

    Returns the same model object, with the hook attached.
    '''
    # Start from a hook-free model
    model.reset_hooks(including_permanent=True)

    # Per-template mean of each head's output on the means dataset
    template_means = compute_means_by_template(means_dataset, model)

    # Boolean masks marking which (position, head) entries must be left alone
    keep_masks = get_heads_and_posns_to_keep(means_dataset, model, circuit, seq_pos_to_keep)

    def is_z_activation(name):
        return name.endswith("z")

    # Bind masks and means so the hook only needs the (activation, hook) pair
    ablate_z = partial(hook_fn_mask_z, heads_and_posns_to_keep=keep_masks, means=template_means)
    model.add_hook(is_z_activation, ablate_z, is_permanent=is_permanent)

    return model

In [None]:
def add_mean_ablation_hook_MLP_head(
    model: HookedTransformer,
    means_dataset: Dataset,
    heads_lst, mlp_lst,
    is_permanent: bool = True,
) -> HookedTransformer:
    '''
    Clears all hooks on the model, then attaches two mean-ablation hooks: one
    on every attention z activation and one on every mlp_out activation.

    model
        Model to hook; the same object is returned.

    means_dataset
        Dataset whose activations supply the replacement means.

    heads_lst
        (layer, head) pairs that should NOT be ablated.

    mlp_lst
        Entries describing which MLP outputs should NOT be ablated; passed
        unchanged to get_mlp_outputs_and_posns_to_keep (callers in this
        notebook pass layer indices).

    is_permanent
        Whether the hooks survive a non-permanent hook reset. Applies to both
        the attention hook and the MLP hook.

    The kept components are protected at the sequence positions "end", "S3",
    "S2" and "S1"; everything else is replaced by the mean.
    '''
    # Sequence position protected for each group; heads and MLPs share the mapping
    seq_pos_by_group = {
        "number mover": "end",
        "number mover 3": "S3",
        "number mover 2": "S2",
        "number mover 1": "S1",
    }
    # Every group protects the same list of components
    head_circuit = {group: heads_lst for group in seq_pos_by_group}
    mlp_circuit = {group: mlp_lst for group in seq_pos_by_group}

    model.reset_hooks(including_permanent=True)

    # ---- Attention heads ----
    # Mean of each head's output on the means dataset, grouped by template
    z_means = compute_means_by_template(means_dataset, model)

    # Boolean masks of the (position, head) entries to leave untouched
    heads_and_posns_to_keep = get_heads_and_posns_to_keep(
        means_dataset, model, head_circuit, seq_pos_by_group
    )

    z_hook_fn = partial(
        hook_fn_mask_z,
        heads_and_posns_to_keep=heads_and_posns_to_keep,
        means=z_means,
    )
    model.add_hook(lambda name: name.endswith("z"), z_hook_fn, is_permanent=is_permanent)

    # ---- MLP outputs ----
    # NOTE(review): the z hook registered above is already attached when these
    # MLP means are computed, so they may come from the head-ablated model
    # rather than the clean one -- confirm this is intended.
    mlp_means = compute_means_by_template_MLP(means_dataset, model)

    # Boolean masks of the MLP outputs / positions to leave untouched
    mlp_outputs_and_posns_to_keep = get_mlp_outputs_and_posns_to_keep(
        means_dataset, model, mlp_circuit, seq_pos_by_group
    )

    mlp_hook_fn = partial(
        hook_fn_mask_mlp_out,
        mlp_outputs_and_posns_to_keep=mlp_outputs_and_posns_to_keep,
        means=mlp_means,
    )
    # Honour is_permanent here as well (this hook used to be unconditionally
    # permanent, unlike the attention hook)
    model.add_hook(lambda name: name.endswith("mlp_out"), mlp_hook_fn, is_permanent=is_permanent)

    return model

In [None]:
# Sanity check: protect all 12 x 12 (layer, head) pairs and MLPs 0-11, then
# compare the hooked model's logit difference against the original score.
heads_not_ablate = [(lyr, hd) for lyr in range(12) for hd in range(12)]
mlps_not_ablate = list(range(12))

model = add_mean_ablation_hook_MLP_head(
    model,
    dataset_2,
    heads_not_ablate,
    mlps_not_ablate,
)
new_logits = model(dataset.toks)
new_score = logits_to_ave_logit_diff(new_logits, dataset)
print(f"Average logit difference (circuit / full) %: {100 * new_score / orig_score:.4f}")

Average logit difference (circuit / full) %: 100.0000


In [None]:
# Ablate the whole of layer 11: only heads in layers 0-10 and MLPs 0-10 are
# protected, so both the attention heads and the MLP of layer 11 get replaced
# by their means.
heads_not_ablate = [(lyr, hd) for lyr in range(11) for hd in range(12)]
mlps_not_ablate = list(range(11))

model = add_mean_ablation_hook_MLP_head(
    model,
    dataset_2,
    heads_not_ablate,
    mlps_not_ablate,
)
new_logits = model(dataset.toks)
new_score = logits_to_ave_logit_diff(new_logits, dataset)
print(f"Average logit difference (circuit / full) %: {100 * new_score / orig_score:.4f}")

Average logit difference (circuit / full) %: 74.3485


In [None]:
# Ablate only the attention heads of layer 11: heads in layers 0-10 are
# protected, while MLPs 0-11 (including the layer-11 MLP) are all protected.
heads_not_ablate = [(lyr, hd) for lyr in range(11) for hd in range(12)]
mlps_not_ablate = list(range(12))

model = add_mean_ablation_hook_MLP_head(
    model,
    dataset_2,
    heads_not_ablate,
    mlps_not_ablate,
)
new_logits = model(dataset.toks)
new_score = logits_to_ave_logit_diff(new_logits, dataset)
print(f"Average logit difference (circuit / full) %: {100 * new_score / orig_score:.4f}")

Average logit difference (circuit / full) %: 99.3207


Try ablating with means computed on the original dataset (`dataset`) instead of `dataset_2`.

In [None]:
# Same experiment, but the replacement means come from `dataset` itself
# instead of `dataset_2`.
# Protected components: heads with layer index 0-9 AND head index 0-9, plus
# MLPs 0-9. Everything in layers 10-11 is ablated.
# NOTE(review): range(10) on the head index also ablates heads 10 and 11 of
# every protected layer -- confirm this is intended and not a typo for range(12).
heads_not_ablate = [(lyr, hd) for lyr in range(10) for hd in range(10)]
mlps_not_ablate = list(range(10))

model = add_mean_ablation_hook_MLP_head(
    model,
    dataset,
    heads_not_ablate,
    mlps_not_ablate,
)
new_logits = model(dataset.toks)
new_score = logits_to_ave_logit_diff(new_logits, dataset)
print(f"Average logit difference (circuit / full) %: {100 * new_score / orig_score:.4f}")

Average logit difference (circuit / full) %: 30.9183


# Head Ablation Experiment Functions

In [None]:
# model.reset_hooks(including_permanent=True)
# ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)
# orig_score = logits_to_ave_logit_diff(ioi_logits_original, dataset)

In [None]:
def mean_ablate_by_lst(lst, model, orig_score, print_output=True):
    '''
    Mean-ablates the model everywhere except the heads in lst, then reports
    how much of the original score is retained.

    lst
        (layer, head) pairs to protect from ablation. The same list is
        protected at the "end", "S3", "S2" and "S1" sequence positions.

    model
        Model to hook. Its hooks (permanent ones included) are reset first.

    orig_score
        Score of the unablated model; used as the denominator.

    print_output
        If True, print the percentage as well as returning it.

    Returns 100 * (ablated score) / orig_score. The replacement means come
    from the global dataset_2; the score is measured on the global dataset.
    '''
    # Sequence position protected for each group of heads
    seq_pos_by_group = {
        "number mover": "end",
        "number mover 3": "S3",
        "number mover 2": "S2",
        "number mover 1": "S1",
    }
    # Every group protects the same list of heads
    circuit_by_group = dict.fromkeys(seq_pos_by_group, lst)

    # Clear hooks left behind by any earlier mean-ablation run
    model.reset_hooks(including_permanent=True)

    model = ioi_circuit_extraction.add_mean_ablation_hook(
        model,
        means_dataset=dataset_2,
        circuit=circuit_by_group,
        seq_pos_to_keep=seq_pos_by_group,
    )
    ablated_logits = model(dataset.toks)
    new_score = logits_to_ave_logit_diff(ablated_logits, dataset)

    if print_output:
        print(f"Average logit difference (circuit / full) %: {100 * new_score / orig_score:.4f}")

    return 100 * new_score / orig_score

In [None]:
def find_circuit_forw(curr_circuit=None, orig_score=100, threshold=10):
    '''
    Greedy pruning pass over attention heads, visiting layers 0 -> 11 and
    heads 0 -> 11 within each layer.

    Each head still in the circuit is tentatively removed and the remaining
    circuit is scored with mean_ablate_by_lst (using the global model). The
    removal is made permanent when the loss relative to 100%, i.e.
    (100 - score), is below threshold.

    curr_circuit
        List of (layer, head) pairs to start from. None or an empty list means
        "start from the full 12 x 12 circuit". A non-empty list is modified in
        place and is also the list that gets returned.

    orig_score
        Score of the unablated model, forwarded to mean_ablate_by_lst.

    threshold
        Maximum tolerated loss, in percentage points, for a head to be removed.

    Returns (circuit, score). score is the percentage reported for the last
    candidate that was evaluated, whether or not that candidate was removed,
    or None if no candidate was evaluated.
    '''
    if not curr_circuit:
        # Covers both the None default and an explicit empty list
        curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]

    # Defined up front so the return cannot hit an unbound name when no
    # candidate is evaluated
    new_score = None

    for layer in range(0, 12):
        for head in range(12):
            candidate = (layer, head)
            if candidate not in curr_circuit:
                continue

            # Score a copy with the candidate taken out; the real circuit is
            # only changed if the removal is accepted
            trial_circuit = curr_circuit.copy()
            trial_circuit.remove(candidate)

            new_score = mean_ablate_by_lst(trial_circuit, model, orig_score, print_output=False).item()

            # Accept the removal when it costs less than threshold points
            if (100 - new_score) < threshold:
                curr_circuit.remove(candidate)

                print("\nRemoved:", candidate)
                print(new_score)

    return curr_circuit, new_score

In [None]:
def find_circuit_backw(curr_circuit=None, orig_score=100, threshold=10):
    '''
    Greedy pruning pass over attention heads, visiting layers 11 -> 0 and
    heads 0 -> 11 within each layer.

    Each head still in the circuit is tentatively removed and the remaining
    circuit is scored with mean_ablate_by_lst (using the global model). The
    removal is made permanent when the loss relative to 100%, i.e.
    (100 - score), is below threshold.

    curr_circuit
        List of (layer, head) pairs to start from. None or an empty list means
        "start from the full 12 x 12 circuit". A non-empty list is modified in
        place and is also the list that gets returned.

    orig_score
        Score of the unablated model, forwarded to mean_ablate_by_lst.

    threshold
        Maximum tolerated loss, in percentage points, for a head to be removed.

    Returns (circuit, score). score is the percentage reported for the last
    candidate that was evaluated, whether or not that candidate was removed,
    or None if no candidate was evaluated.
    '''
    if not curr_circuit:
        # Covers both the None default and an explicit empty list
        curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]

    # Defined up front so the return cannot hit an unbound name when no
    # candidate is evaluated
    new_score = None

    # Last layer first; all heads of a layer are tried before moving down
    for layer in range(11, -1, -1):
        for head in range(12):
            candidate = (layer, head)
            if candidate not in curr_circuit:
                continue

            # Score a copy with the candidate taken out; the real circuit is
            # only changed if the removal is accepted
            trial_circuit = curr_circuit.copy()
            trial_circuit.remove(candidate)

            new_score = mean_ablate_by_lst(trial_circuit, model, orig_score, print_output=False).item()

            # Accept the removal when it costs less than threshold points
            if (100 - new_score) < threshold:
                curr_circuit.remove(candidate)

                print("\nRemoved:", candidate)
                print(new_score)

    return curr_circuit, new_score