<a href="https://colab.research.google.com/github/wlg100/numseqcont_circuit_expms/blob/main/notebook_templates/minimal_circuit_template.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

# Setup
(No need to change anything)

In [None]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-ftngo3hg
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-ftngo3hg
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit fa287750606075574df2c538058e67d648e2f952
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting accelerate>=0.23.0 (from transformer-lens==0.0.0)
  Downloading accelerate-0.24.0-py3-none-any.whl (260 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m261.0/261.0 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting beartype<0.15.0,>=0.14.1 (from transformer-lens==0.0.0)
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)
[2K     [90m━━━━━━━━

In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

# Plotly graphics show up either in Colab OR in VSCode depending on the renderer,
# which is bad for developing demos. The "png" renderer is therefore only used
# when debugging locally (DEBUG_MODE set and not running in Colab).
pio.renderers.default = "colab" if (IN_COLAB or not DEBUG_MODE) else "png"

In [None]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [None]:
# import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [None]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fe9e08ab5e0>

Plotting helper functions:

In [None]:
def imshow(tensor, renderer=None, **kwargs):
    """Show `tensor` as a heatmap on the "RdBu" colour scale with its midpoint at 0.

    `tensor` is converted with `utils.to_numpy`; extra keyword arguments are
    forwarded to `px.imshow`, and `renderer` is passed to the figure's `show`.
    """
    fig = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    )
    fig.show(renderer)

def line(tensor, renderer=None, **kwargs):
    """Plot the values of `tensor` as a line chart (used as the y values).

    Extra keyword arguments are forwarded to `px.line`; `renderer` is passed
    to the figure's `show`.
    """
    fig = px.line(y=utils.to_numpy(tensor), **kwargs)
    fig.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Scatter-plot `y` against `x`.

    `xaxis`, `yaxis` and `caxis` are the display labels for the x, y and
    colour dimensions. Extra keyword arguments are forwarded to `px.scatter`;
    `renderer` is passed to the figure's `show`.
    """
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    axis_labels = {"x":xaxis, "y":yaxis, "color":caxis}
    fig = px.scatter(y=y_values, x=x_values, labels=axis_labels, **kwargs)
    fig.show(renderer)

## Load Model

Decide which model to use (eg. gpt2-small vs -medium)

In [None]:
# Load the pretrained "gpt2-small" checkpoint as a HookedTransformer.
# The four flags request weight post-processing at load time.
# NOTE(review): presumably these are the standard TransformerLens
# simplifications (LayerNorm folding, weight/unembed centering, refactored
# attention matrices) — confirm exact semantics in the TransformerLens docs.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-small into HookedTransformer


## Import functions from repo

In [None]:
!git clone https://github.com/callummcdougall/ARENA_2.0.git

Cloning into 'ARENA_2.0'...
remote: Enumerating objects: 9106, done.[K
remote: Counting objects: 100% (1820/1820), done.[K
remote: Compressing objects: 100% (289/289), done.[K
remote: Total 9106 (delta 1614), reused 1608 (delta 1528), pack-reused 7286[K
Receiving objects: 100% (9106/9106), 155.60 MiB | 38.50 MiB/s, done.
Resolving deltas: 100% (5507/5507), done.


In [None]:
cd ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification

/content/ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification


In [None]:
import ioi_circuit_extraction as ioi_circuit_extraction

# Generate dataset with multiple prompts

In [None]:
class Dataset:
    """Tokenised prompts plus the token positions and answer ids used for ablation.

    Attributes set by ``__init__``:
        prompts:     the prompt dicts as given.
        tokenizer:   the tokenizer as given.
        N:           number of prompts.
        max_len:     longest prompt length in tokens.
        groups:      list of index arrays, one per template (only one template here).
        toks:        ``torch.int`` tensor of padded token ids, one row per prompt.
        io_tokenIDs: first token id of ``" " + prompt["corr"]`` for each prompt.
        s_tokenIDs:  first token id of ``" " + prompt["incorr"]`` for each prompt.
        word_idx:    dict mapping every position key (plus ``"end"``) to a tensor
                     holding that position's token index for each prompt.
    """

    def __init__(self, prompts, pos_dict, tokenizer, S1_is_first=False):
        """Build the dataset.

        Args:
            prompts: non-empty list of dicts. Each dict has the keys ``"text"``,
                ``"corr"`` and ``"incorr"``; every other key of the first prompt
                is treated as a named position (e.g. ``"S1"``).
            pos_dict: maps each named position to its token index in the prompt.
            tokenizer: object that is callable on a string or a list of strings
                (returning something with ``.input_ids``) and that provides
                ``encode`` and ``tokenize``.
            S1_is_first: unused; kept so existing callers keep working.

        Raises:
            ValueError: if ``prompts`` is empty.
            KeyError: if a named position is missing from ``pos_dict``.
        """
        if len(prompts) == 0:
            raise ValueError("Dataset requires at least one prompt")

        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)
        self.max_len = max(
            len(self.tokenizer(prompt["text"]).input_ids) for prompt in self.prompts
        )

        # Only one template is used, so every prompt gets template id 0
        # (a per-prompt "TEMPLATE_IDX" could be used here instead).
        template_ids = np.array([0 for _ in self.prompts])
        self.groups = [
            np.where(template_ids == template_id)[0]
            for template_id in set(template_ids.tolist())
        ]

        texts = [prompt["text"] for prompt in self.prompts]
        # Build the integer tensor directly rather than converting via a float tensor.
        self.toks = torch.tensor(
            self.tokenizer(texts, padding=True).input_ids, dtype=torch.int
        )
        self.io_tokenIDs = [
            self.tokenizer.encode(" " + prompt["corr"])[0] for prompt in self.prompts
        ]
        self.s_tokenIDs = [
            self.tokenizer.encode(" " + prompt["incorr"])[0] for prompt in self.prompts
        ]

        # word_idx[key][p] is the token index of position `key` in prompt p.
        # Positions come straight from pos_dict, so they are identical for every
        # prompt; no tokenizer call (and no global `model`) is needed here.
        position_keys = [
            key for key in self.prompts[0].keys()
            if key not in ("text", "corr", "incorr")
        ]
        self.word_idx = {}
        for key in position_keys:
            self.word_idx[key] = torch.tensor([pos_dict[key] for _ in self.prompts])

        # "end" is the index of the last token of each prompt.
        self.word_idx["end"] = torch.tensor(
            [len(self.tokenizer.tokenize(prompt["text"])) - 1 for prompt in self.prompts]
        )

    def __len__(self):
        """Return the number of prompts."""
        return self.N

Replace io_tokens with the correct answer (next, which is '5') and s_tokens with the incorrect answer (current, which repeats).

In [None]:
# Token index of each named sequence member inside a prompt. Dataset copies
# pos_dict[key] into word_idx[key] for every prompt.
# NOTE(review): assumes each number word is exactly one token, so S1..S4 sit at
# indices 0..3 — confirm for the tokenizer in use.
pos_dict = {
    'S1': 0,
    'S2': 1,
    'S3': 2,
    'S4': 3,
}

In [None]:
def generate_prompts_list(x ,y):
    """Build one prompt dict per start index in ``range(x, y)``.

    Each prompt is a run of four consecutive number words starting at the given
    index into the word list; ``corr`` is the fifth consecutive word and
    ``incorr`` is the first word of the run (an arbitrary choice).

    Raises:
        IndexError: if a start index needs a word past the end of the list.
    """
    words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten']

    def _prompt_starting_at(start):
        # Explicit indexing (not slicing) so an out-of-range start still raises.
        first, second, third, fourth = (
            words[start], words[start+1], words[start+2], words[start+3]
        )
        return {
            'S1': first,
            'S2': second,
            'S3': third,
            'S4': fourth,
            'corr': words[start+4],
            'incorr': first,  # this is arbitrary
            'text': f"{first} {second} {third} {fourth}"
        }

    return [_prompt_starting_at(start) for start in range(x, y)]

prompts_list = generate_prompts_list(1, 6)
dataset = Dataset(prompts_list, pos_dict, model.tokenizer, S1_is_first=True)

In [None]:
import random

def generate_prompts_list_corr(x ,y, seed=None):
    """Build one corrupted prompt per index in ``range(x, y)``.

    Each prompt consists of four number words drawn independently and uniformly
    at random (with replacement) from the word list.

    Args:
        x, y: bounds of the index range; one prompt is produced per index.
        seed: optional seed. When given, the draws come from a private
            ``random.Random(seed)`` so the result is reproducible and the global
            random state is left untouched. When ``None`` (default) the global
            ``random`` module is used, exactly as before.

    Returns:
        A list of dicts with keys ``S1``..``S4``, ``corr``, ``incorr``, ``text``.
        ``corr`` repeats the first drawn word. ``incorr`` is the digit string
        ``str(i+4)``, not a number word.
        NOTE(review): presumably a placeholder, since this dataset looks to be
        used only for ablation means — confirm.
    """
    words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten']
    # The `random` module and a `random.Random` instance both expose `choice`.
    rng = random if seed is None else random.Random(seed)
    prompts_list = []
    for i in range(x, y):
        # Draw order (r1, r2, r3, r4) is kept so globally-seeded runs are unchanged.
        r1 = rng.choice(words)
        r2 = rng.choice(words)
        r3 = rng.choice(words)
        r4 = rng.choice(words)
        prompts_list.append({
            'S1': r1,
            'S2': r2,
            'S3': r3,
            'S4': r4,
            'corr': r1,
            'incorr': str(i+4),
            'text': f"{r1} {r2} {r3} {r4}"
        })
    return prompts_list

prompts_list_2 = generate_prompts_list_corr(1, 6)
dataset_2 = Dataset(prompts_list_2, pos_dict, model.tokenizer, S1_is_first=True)
prompts_list_2

[{'S1': 'four',
  'S2': 'one',
  'S3': 'four',
  'S4': 'eight',
  'corr': 'four',
  'incorr': '5',
  'text': 'four one four eight'},
 {'S1': 'nine',
  'S2': 'five',
  'S3': 'three',
  'S4': 'seven',
  'corr': 'nine',
  'incorr': '6',
  'text': 'nine five three seven'},
 {'S1': 'five',
  'S2': 'two',
  'S3': 'four',
  'S4': 'three',
  'corr': 'five',
  'incorr': '7',
  'text': 'five two four three'},
 {'S1': 'two',
  'S2': 'one',
  'S3': 'nine',
  'S4': 'eight',
  'corr': 'two',
  'incorr': '8',
  'text': 'two one nine eight'},
 {'S1': 'eight',
  'S2': 'seven',
  'S3': 'ten',
  'S4': 'seven',
  'corr': 'eight',
  'incorr': '9',
  'text': 'eight seven ten seven'}]

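Logit diff is the logit of the correct token minus the logit of the incorrect token. Here, correct is S5, and incorr is intended to be S4 (note: `generate_prompts_list` above currently sets `incorr` to `words[i]`, i.e. S1, and marks that choice as arbitrary).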

Because of this, it's possible for a pruned circuit to have a logit diff HIGHER than the "full circuit": the correct token can still be in first place while the gap between the two logits is simply bigger (perhaps because the incorrect token is scored even lower in the non-full circuit).

# Ablation Expm Functions

In [None]:
from torch import Tensor

def logits_to_ave_logit_diff_2(logits: Float[Tensor, "batch seq d_vocab"], dataset: Dataset, per_prompt=False):
    '''
    Logit of the correct answer minus logit of the incorrect answer, read at
    each prompt's "end" position.

    The correct / incorrect token ids come from dataset.io_tokenIDs and
    dataset.s_tokenIDs, and the positions from dataset.word_idx["end"].

    Returns the per-prompt differences when per_prompt=True, otherwise their mean.
    '''
    prompt_rows = range(logits.size(0))
    end_positions = dataset.word_idx["end"]

    # One logit per prompt: (prompt, its end position, its answer token).
    correct_logits: Float[Tensor, "batch"] = logits[prompt_rows, end_positions, dataset.io_tokenIDs]
    incorrect_logits: Float[Tensor, "batch"] = logits[prompt_rows, end_positions, dataset.s_tokenIDs]

    per_prompt_diff = correct_logits - incorrect_logits
    if per_prompt:
        return per_prompt_diff
    return per_prompt_diff.mean()

In [None]:
def mean_ablate_by_lst(lst, model, print_output=True, orig_score=None):
    """Score a candidate circuit as a percentage of the full model's logit diff.

    Relies on the notebook globals `dataset` (prompts to score), `dataset_2`
    (prompts passed as `means_dataset`) and `ioi_circuit_extraction`.

    Args:
        lst: list of (layer, head) tuples; the same list is used for every
            circuit group below.
        model: the HookedTransformer to evaluate. All hooks on it, including
            permanent ones, are cleared both before and after the evaluation,
            so the shared model is never left in an ablated state.
        print_output: when True, print the resulting percentage.
        orig_score: optional precomputed average logit diff of the un-ablated
            model on `dataset`. Pass it to skip the redundant full-model
            forward pass when scoring many circuits. When None (default) it is
            computed here.

    Returns:
        100 * (logit diff with the ablation hooks) / (logit diff of full model).
    """
    CIRCUIT = {
        "number mover": lst,
        # "number mover 4": lst,
        "number mover 3": lst,
        "number mover 2": lst,
        "number mover 1": lst,
    }

    # Sequence position associated with each circuit group.
    # NOTE(review): presumably add_mean_ablation_hook keeps the listed heads
    # un-ablated only at these positions — confirm against the ARENA source.
    SEQ_POS_TO_KEEP = {
        "number mover": "end",
        # "number mover 4": "S4",
        "number mover 3": "S3",
        "number mover 2": "S2",
        "number mover 1": "S1",
    }

    # Hooks from a previous mean-ablation run must be cleared before the clean pass.
    model.reset_hooks(including_permanent=True)

    if orig_score is None:
        # Plain forward pass: no activation cache is needed for the score.
        ioi_logits_original = model(dataset.toks)
        orig_score = logits_to_ave_logit_diff_2(ioi_logits_original, dataset)

    try:
        model = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
        ioi_logits_minimal = model(dataset.toks)
    finally:
        # Do not leak the ablation hooks to whatever uses the model next.
        model.reset_hooks(including_permanent=True)

    new_score = logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset)
    percent_of_full = 100 * new_score / orig_score
    if print_output:
        print(f"Average logit difference (circuit / full) %: {percent_of_full:.4f}")
    return percent_of_full

We can also prevent redundant computation of the full circuit score by storing it and just passing it in to the function.

# Ablate the model and compare with original

### try full circuit from repeatLast iter fb

In [None]:
curr_circuit = [(0, 1), (0, 3), (0, 5), (0, 7), (0, 8), (0, 9), (0, 10), (1, 0), (1, 5), (3, 0), (3, 3), (3, 7), (3, 10), (3, 11), (4, 4), (4, 6), (4, 7), (4, 8), (4, 10), (4, 11), (5, 4), (5, 5), (5, 9), (6, 1), (6, 6), (6, 10), (7, 6), (7, 10), (7, 11), (8, 1), (8, 2), (8, 6), (8, 8), (9, 1), (9, 5), (10, 7), (11, 10)]
mean_ablate_by_lst(curr_circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 64.6226


64.62264251708984

In [None]:
curr_circuit = [(9, 1)]
mean_ablate_by_lst(curr_circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 24.9522


24.952171325683594

## compare with repeatRandElem

In [None]:
repeatRand_backw_3 = [(0, 0), (0, 1), (0, 2), (0, 3), (0, 8), (0, 9), (0, 10), (1, 0), (1, 5), (1, 7), (2, 0), (2, 2), (2, 9), (3, 0), (4, 2), (4, 4), (4, 7), (4, 8), (4, 9), (4, 10), (5, 0), (5, 3), (5, 4), (5, 5), (5, 6), (6, 1), (6, 3), (6, 4), (6, 6), (6, 8), (6, 10), (7, 10), (7, 11), (8, 11), (9, 1), (10, 1)]
mean_ablate_by_lst(repeatRand_backw_3, model, print_output=True).item()

Average logit difference (circuit / full) %: 61.2811


61.2811164855957

In [None]:
repeatRand_backw_20 = [(0, 1), (0, 2), (0, 3), (0, 9), (0, 10), (0, 11), (1, 0), (1, 5), (1, 7), (1, 8), (2, 0), (2, 2), (2, 7), (2, 9), (3, 0), (3, 3), (4, 4), (4, 6), (4, 7), (4, 9), (4, 10), (4, 11), (5, 0), (5, 3), (5, 4), (5, 5), (5, 6), (5, 8), (6, 1), (6, 6), (6, 9), (6, 10), (7, 10), (7, 11), (9, 1)]
mean_ablate_by_lst(repeatRand_backw_20, model, print_output=True).item()

Average logit difference (circuit / full) %: 53.9477


53.94770050048828

In [None]:
repeatRand_fb_3 = [(0, 1), (0, 2), (0, 8), (0, 9), (0, 10), (1, 0), (1, 5), (2, 0), (2, 2), (2, 6), (2, 7), (3, 0), (3, 11), (4, 4), (4, 6), (4, 7), (4, 8), (4, 9), (4, 10), (5, 0), (5, 3), (5, 4), (5, 5), (5, 6), (6, 1), (6, 6), (6, 10), (7, 10), (7, 11), (8, 8), (8, 9), (8, 11), (9, 1), (11, 10)]
mean_ablate_by_lst(repeatRand_fb_3, model, print_output=True).item()

Average logit difference (circuit / full) %: 67.9914


67.99136352539062

## Prune backwards

In [None]:
# Greedy backward pruning: start from all 12x12 (layer, head) pairs and try to
# drop them one at a time, visiting the last layer first.
curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]
threshold = 3  # T, in percentage points: a head is removed if dropping it leaves (100 - score) below T

for layer in range(11, -1, -1):  # go thru all heads in a layer first
    for head in range(12):
        # Work on a copy so the candidate removal can be evaluated without
        # committing it to curr_circuit.
        copy_circuit = curr_circuit.copy()

        # Candidate circuit = current circuit without this head
        copy_circuit.remove((layer, head))

        # Score of the candidate circuit as a percentage of the full model's logit diff
        new_score = mean_ablate_by_lst(copy_circuit, model, print_output=False).item()

        # print((layer,head), new_score)
        # If the performance drop (100 - score) is below the threshold, the head
        # is not needed: remove it from the real circuit for good.
        if (100 - new_score) < threshold:
            curr_circuit.remove((layer, head))

            print("Removed:", (layer, head))
            print(new_score)
            print("\n")

Removed: (11, 0)
100.10379028320312


Removed: (11, 1)
98.97869873046875


Removed: (11, 2)
99.02860260009766


Removed: (11, 3)
98.74519348144531


Removed: (11, 4)
99.08023834228516


Removed: (11, 5)
99.12744903564453


Removed: (11, 6)
99.15473937988281


Removed: (11, 7)
99.03898620605469


Removed: (11, 8)
97.15733337402344


Removed: (11, 9)
97.16557312011719


Removed: (11, 11)
97.33262634277344


Removed: (10, 0)
97.36923217773438


Removed: (10, 1)
97.38016510009766


Removed: (10, 3)
97.48831176757812


Removed: (10, 4)
97.46089172363281


Removed: (10, 6)
97.41268157958984


Removed: (10, 7)
98.00159454345703


Removed: (10, 8)
98.22784423828125


Removed: (10, 9)
98.02510833740234


Removed: (10, 10)
98.12550354003906


Removed: (10, 11)
98.09323120117188


Removed: (9, 0)
98.04710388183594


Removed: (9, 2)
97.6371841430664


Removed: (9, 3)
97.615234375


Removed: (9, 4)
97.75285339355469


Removed: (9, 6)
98.1245346069336


Removed: (9, 7)
98.0626449584961


Removed: (9

In [None]:
mean_ablate_by_lst(curr_circuit, model, print_output=True)

Average logit difference (circuit / full) %: 100.4750


tensor(100.4750, device='cuda:0')

In [None]:
backw_3 = curr_circuit.copy()
backw_3

[(0, 1),
 (1, 5),
 (2, 4),
 (3, 2),
 (4, 4),
 (4, 8),
 (4, 10),
 (5, 0),
 (5, 1),
 (5, 2),
 (5, 3),
 (5, 4),
 (5, 5),
 (5, 6),
 (5, 7),
 (5, 8),
 (6, 1),
 (6, 3),
 (6, 4),
 (6, 7),
 (6, 9),
 (6, 10),
 (7, 6),
 (7, 7),
 (7, 10),
 (7, 11),
 (8, 0),
 (8, 1),
 (8, 5),
 (8, 6),
 (8, 8),
 (8, 9),
 (8, 11),
 (9, 1),
 (9, 5),
 (10, 2),
 (10, 5),
 (11, 10)]

In [None]:
len(backw_3)

38

Now try 10% threshold:

In [None]:
def find_circuit_backw(threshold=10):
    """Greedily prune heads from the full 12x12 circuit, last layer first.

    Each (layer, head) is visited once, from layer 11 down to layer 0 and, within
    a layer, from head 0 to head 11. A head is dropped when the circuit without
    it still scores within `threshold` percentage points of 100, i.e. when
    ``(100 - score) < threshold``. Scores come from `mean_ablate_by_lst`,
    evaluated on the notebook's global `model`.

    Returns:
        The list of (layer, head) tuples that were kept, in ascending order.
    """
    kept_heads = [(layer, head) for layer in range(12) for head in range(12)]
    visit_order = [(layer, head) for layer in reversed(range(12)) for head in range(12)]

    for candidate in visit_order:
        # Trial circuit: everything currently kept except the candidate head.
        trial_circuit = [kept for kept in kept_heads if kept != candidate]
        new_score = mean_ablate_by_lst(trial_circuit, model, print_output=False).item()

        performance_drop = 100 - new_score
        if performance_drop < threshold:
            # The head is dispensable: commit the removal.
            kept_heads.remove(candidate)

            print("Removed:", candidate)
            print(new_score)
            print("\n")

    return kept_heads

In [None]:
curr_circuit = find_circuit_backw(10)

Removed: (11, 0)
100.10379028320312


Removed: (11, 1)
98.97869873046875


Removed: (11, 2)
99.02860260009766


Removed: (11, 3)
98.74519348144531


Removed: (11, 4)
99.08023834228516


Removed: (11, 5)
99.12744903564453


Removed: (11, 6)
99.15473937988281


Removed: (11, 7)
99.03898620605469


Removed: (11, 8)
97.15733337402344


Removed: (11, 9)
97.16557312011719


Removed: (11, 10)
95.91690063476562


Removed: (11, 11)
96.08448791503906


Removed: (10, 0)
96.11646270751953


Removed: (10, 1)
96.08088684082031


Removed: (10, 3)
96.18419647216797


Removed: (10, 4)
96.1608657836914


Removed: (10, 5)
95.62007904052734


Removed: (10, 6)
95.57417297363281


Removed: (10, 7)
95.9581069946289


Removed: (10, 8)
96.15215301513672


Removed: (10, 9)
95.95096588134766


Removed: (10, 10)
96.05656433105469


Removed: (10, 11)
96.03401947021484


Removed: (9, 0)
96.00611114501953


Removed: (9, 2)
95.5905532836914


Removed: (9, 3)
95.58920288085938


Removed: (9, 4)
95.7115707397461


Remo

In [None]:
backw_10 = curr_circuit.copy()
backw_10

[(0, 1),
 (0, 5),
 (1, 5),
 (2, 4),
 (3, 2),
 (4, 4),
 (4, 8),
 (4, 10),
 (4, 11),
 (5, 0),
 (5, 2),
 (5, 3),
 (5, 4),
 (5, 5),
 (5, 6),
 (5, 7),
 (5, 8),
 (6, 1),
 (6, 2),
 (6, 3),
 (6, 4),
 (6, 7),
 (6, 8),
 (6, 10),
 (7, 6),
 (7, 8),
 (7, 10),
 (7, 11),
 (8, 6),
 (8, 8),
 (8, 9),
 (8, 11),
 (9, 1),
 (10, 2)]

In [None]:
mean_ablate_by_lst(curr_circuit, model, print_output=True)

Average logit difference (circuit / full) %: 93.5710


tensor(93.5710, device='cuda:0')

In [None]:
len(backw_10)

34

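Now try 20% threshold. Note: in the saved run the search cell below was interrupted (`KeyboardInterrupt`), so `curr_circuit` was never reassigned and `backw_20` is just a copy of the 10% circuit (same length, same score, empty set difference). Re-run the cell to completion before interpreting any 20% result.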

In [None]:
%%capture
curr_circuit = find_circuit_backw(20)

KeyboardInterrupt: ignored

In [None]:
backw_20 = curr_circuit.copy()
backw_20

In [None]:
mean_ablate_by_lst(curr_circuit, model, print_output=True)

Average logit difference (circuit / full) %: 93.5710


tensor(93.5710, device='cuda:0')

In [None]:
len(backw_20)

34

### set diffs of the three perf lvls

In [None]:
set(backw_3) - set(backw_10)

{(5, 1), (6, 9), (7, 7), (8, 0), (8, 1), (8, 5), (9, 5), (10, 5), (11, 10)}

In [None]:
set(backw_10) - set(backw_3)

{(0, 5), (4, 11), (6, 2), (6, 8), (7, 8)}

In [None]:
set(backw_3) - set(backw_20)

{(5, 1), (6, 9), (7, 7), (8, 0), (8, 1), (8, 5), (9, 5), (10, 5), (11, 10)}

In [None]:
set(backw_10) - set(backw_20)

set()

In [None]:
mean_ablate_by_lst(backw_20, model, print_output=True)

Average logit difference (circuit / full) %: 93.5710


tensor(93.5710, device='cuda:0')

In [None]:
mean_ablate_by_lst(backw_20 + [(10, 2)], model, print_output=True)

Average logit difference (circuit / full) %: 93.5710


tensor(93.5710, device='cuda:0')

In [None]:
mean_ablate_by_lst([x for x in backw_20 if x != (9, 1)], model, print_output=True)

Average logit difference (circuit / full) %: 71.2875


tensor(71.2875, device='cuda:0')

In [None]:
mean_ablate_by_lst([x for x in backw_20 if x != (9, 1)] + [(10, 2)], model, print_output=True)

Average logit difference (circuit / full) %: 71.2875


tensor(71.2875, device='cuda:0')

### set diff w repeatLast and repeatFirstAll circs

In [None]:
repeatFirstAll_backw_3 = [(0, 1), (0, 9), (1, 0), (1, 5), (2, 2), (2, 9), (2, 10), (3, 0), (3, 3), (3, 6), (3, 7), (4, 4), (4, 8), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 8), (6, 1), (6, 3), (6, 4), (6, 6), (6, 9), (6, 10), (7, 1), (7, 2), (7, 6), (7, 7), (7, 10), (7, 11), (8, 0), (8, 1), (8, 6), (8, 8), (8, 11), (9, 1), (9, 9), (10, 1), (10, 2), (11, 8), (11, 9), (11, 10)]
repeatFirstAll_backw_10 = [(0, 1), (0, 9), (1, 0), (1, 5), (1, 6), (2, 2), (2, 8), (2, 9), (3, 0), (3, 2), (3, 3), (3, 7), (3, 8), (3, 10), (4, 4), (5, 1), (5, 4), (5, 5), (5, 6), (5, 7), (5, 8), (5, 10), (6, 0), (6, 1), (6, 3), (6, 4), (6, 6), (6, 9), (6, 10), (6, 11), (7, 0), (7, 6), (7, 8), (7, 10), (7, 11), (8, 0), (8, 1), (8, 6), (8, 8), (8, 11), (9, 1), (10, 2)]
repeatFirstAll_backw_20 = [(0, 1), (0, 9), (1, 0), (1, 5), (1, 6), (2, 2), (2, 8), (2, 9), (2, 10), (3, 0), (3, 2), (3, 3), (3, 7), (3, 10), (4, 4), (4, 10), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 8), (5, 10), (6, 1), (6, 3), (6, 4), (6, 6), (6, 9), (6, 10), (7, 2), (7, 6), (7, 7), (7, 10), (7, 11), (8, 0), (8, 6), (8, 8), (8, 11), (9, 1)]

In [None]:
mean_ablate_by_lst(repeatFirstAll_backw_3, model, print_output=True)

Average logit difference (circuit / full) %: 78.6232


tensor(78.6232, device='cuda:0')

In [None]:
mean_ablate_by_lst(repeatFirstAll_backw_10, model, print_output=True)

Average logit difference (circuit / full) %: 78.4179


tensor(78.4179, device='cuda:0')

In [None]:
mean_ablate_by_lst(repeatFirstAll_backw_20, model, print_output=True)

Average logit difference (circuit / full) %: 74.1122


tensor(74.1122, device='cuda:0')

In [None]:
repLast_backw_3 = [(0, 1), (0, 3), (0, 5), (0, 7), (0, 9), (0, 10), (1, 0), (1, 4), (1, 5), (2, 2), (2, 8), (2, 9), (3, 0), (3, 2), (3, 3), (3, 7), (4, 4), (4, 7), (4, 10), (5, 1), (5, 3), (5, 4), (5, 5), (5, 6), (5, 8), (5, 9), (6, 1), (6, 4), (6, 6), (6, 10), (6, 11), (7, 6), (7, 10), (7, 11), (8, 0), (8, 5), (8, 6), (8, 8), (9, 1), (10, 7)]
set(backw_3) - set(repLast_backw_3)

{(2, 4),
 (4, 8),
 (5, 0),
 (5, 2),
 (5, 7),
 (6, 3),
 (6, 7),
 (6, 9),
 (7, 7),
 (8, 1),
 (8, 9),
 (8, 11),
 (9, 5),
 (10, 2),
 (10, 5),
 (11, 10)}

In [None]:
mean_ablate_by_lst(repLast_backw_3, model, print_output=True)

Average logit difference (circuit / full) %: 65.7443


tensor(65.7443, device='cuda:0')

In [None]:
set(repLast_backw_3) - set(backw_3)

{(0, 3),
 (0, 5),
 (0, 7),
 (0, 9),
 (0, 10),
 (1, 0),
 (1, 4),
 (2, 2),
 (2, 8),
 (2, 9),
 (3, 0),
 (3, 3),
 (3, 7),
 (4, 7),
 (5, 9),
 (6, 6),
 (6, 11),
 (10, 7)}

In [None]:
set(backw_3) - set(repeatFirstAll_backw_3)

{(2, 4),
 (3, 2),
 (4, 10),
 (5, 0),
 (5, 7),
 (6, 7),
 (8, 5),
 (8, 9),
 (9, 5),
 (10, 5)}

In [None]:
set(repeatFirstAll_backw_3) - set(backw_3)

{(0, 9),
 (1, 0),
 (2, 2),
 (2, 9),
 (2, 10),
 (3, 0),
 (3, 3),
 (3, 6),
 (3, 7),
 (6, 6),
 (7, 1),
 (7, 2),
 (9, 9),
 (10, 1),
 (11, 8),
 (11, 9)}

## Prune forwards

In [None]:
# # Start with full circuit
# curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]
# threshold = 3  # This is T, a %. if performance is less than T%, allow its removal

# for layer in range(0, 12):
#     for head in range(12):
#         # Copying the curr_circuit so we can iterate over one and modify the other
#         copy_circuit = curr_circuit.copy()

#         # Temporarily removing the current tuple from the copied circuit
#         copy_circuit.remove((layer, head))

#         new_score = mean_ablate_by_lst(copy_circuit, model, print_output=False).item()

#         # print((layer,head), new_score)
#         # If the result is less than the threshold, remove the tuple from the original list
#         if (100 - new_score) < threshold:
#             curr_circuit.remove((layer, head))

#             print("Removed:", (layer, head))
#             print(new_score)
#             print("\n")

## Prune fwds-backwds iteratively

In [None]:
def find_circuit_forw(curr_circuit=None, threshold=10):
    """Greedily prune heads from a circuit, visiting layer 0 first.

    Each (layer, head) present in the circuit is visited once, from layer 0 up
    to layer 11 and, within a layer, from head 0 to head 11. A head is dropped
    when the circuit without it still scores within `threshold` percentage
    points of 100, i.e. when ``(100 - score) < threshold``. Scores come from
    `mean_ablate_by_lst`, evaluated on the notebook's global `model`.

    Args:
        curr_circuit: list of (layer, head) tuples to start from. ``None`` or an
            empty list means "start from the full 12x12 circuit". The argument
            itself is never modified; a pruned copy is returned.
        threshold: T, in percentage points.

    Returns:
        ``(circuit, score)``: the pruned circuit and the score of that circuit
        (not the score of the last candidate that happened to be tried).
    """
    if not curr_circuit:
        # Start with full circuit (covers both the None default and []).
        circuit = [(layer, head) for layer in range(12) for head in range(12)]
    else:
        # Copy so the caller's list is not mutated behind its back.
        circuit = list(curr_circuit)

    circuit_score = None  # score of `circuit`; set whenever a removal is accepted

    for layer in range(0, 12):
        for head in range(12):
            candidate = (layer, head)
            if candidate not in circuit:
                continue

            # Trial circuit: the current circuit without the candidate head.
            trial_circuit = circuit.copy()
            trial_circuit.remove(candidate)

            new_score = mean_ablate_by_lst(trial_circuit, model, print_output=False).item()

            # Performance drop below the threshold: the head is dispensable.
            if (100 - new_score) < threshold:
                circuit.remove(candidate)
                circuit_score = new_score

                print("\nRemoved:", candidate)
                print(new_score)

    if circuit_score is None:
        # Nothing was removed, so the circuit itself has not been scored yet.
        circuit_score = mean_ablate_by_lst(circuit, model, print_output=False).item()

    return circuit, circuit_score

In [None]:
def find_circuit_backw(curr_circuit=None, threshold=10):
    """Greedily prune heads from a circuit, visiting layer 11 first.

    NOTE: this definition replaces the earlier one-argument `find_circuit_backw`
    in this notebook; it takes the starting circuit as its first argument and
    returns a ``(circuit, score)`` tuple instead of a bare list.

    Each (layer, head) present in the circuit is visited once, from layer 11
    down to layer 0 and, within a layer, from head 0 to head 11. A head is
    dropped when the circuit without it still scores within `threshold`
    percentage points of 100, i.e. when ``(100 - score) < threshold``. Scores
    come from `mean_ablate_by_lst`, evaluated on the notebook's global `model`.

    Args:
        curr_circuit: list of (layer, head) tuples to start from. ``None`` or an
            empty list means "start from the full 12x12 circuit". The argument
            itself is never modified; a pruned copy is returned.
        threshold: T, in percentage points.

    Returns:
        ``(circuit, score)``: the pruned circuit and the score of that circuit
        (not the score of the last candidate that happened to be tried).
    """
    if not curr_circuit:
        # Start with full circuit (covers both the None default and []).
        circuit = [(layer, head) for layer in range(12) for head in range(12)]
    else:
        # Copy so the caller's list is not mutated behind its back.
        circuit = list(curr_circuit)

    circuit_score = None  # score of `circuit`; set whenever a removal is accepted

    for layer in range(11, -1, -1):  # go thru all heads in a layer first
        for head in range(12):
            candidate = (layer, head)
            if candidate not in circuit:
                continue

            # Trial circuit: the current circuit without the candidate head.
            trial_circuit = circuit.copy()
            trial_circuit.remove(candidate)

            new_score = mean_ablate_by_lst(trial_circuit, model, print_output=False).item()

            # Performance drop below the threshold: the head is dispensable.
            if (100 - new_score) < threshold:
                circuit.remove(candidate)
                circuit_score = new_score

                print("\nRemoved:", candidate)
                print(new_score)

    if circuit_score is None:
        # Nothing was removed, so the circuit itself has not been scored yet.
        circuit_score = mean_ablate_by_lst(circuit, model, print_output=False).item()

    return circuit, circuit_score

### iter fwd backw, threshold 3

In [None]:
# Alternate forward and backward pruning passes until a pass removes nothing.
threshold = 3
curr_circuit = []  # an empty list makes the first pass start from the full circuit
iteration = 1      # not named `iter`, which would shadow the builtin for later cells

# Termination is decided only by "did the circuit change?". The previous
# `while prev_score != new_score` guard compared against a constant 100 that was
# never updated, so it could stop the search early whenever a score happened to
# be exactly 100.0.
while True:
    print('\nfwd prune, iter ', str(iteration))
    old_circuit = curr_circuit.copy()  # snapshot to detect whether the pass changed anything
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break

    print('\nbackw prune, iter ', str(iteration))
    old_circuit = curr_circuit.copy()  # snapshot to detect whether the pass changed anything
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break

    iteration += 1


fwd prune, iter  1

Removed: (0, 0)
100.2175064086914

Removed: (0, 2)
101.17543029785156

Removed: (0, 3)
99.70558166503906

Removed: (0, 4)
100.49411010742188

Removed: (0, 5)
98.54776763916016

Removed: (0, 6)
97.98517608642578

Removed: (0, 7)
97.66582489013672

Removed: (0, 8)
97.72859191894531

Removed: (0, 9)
98.5828628540039

Removed: (0, 10)
100.9132080078125

Removed: (0, 11)
101.22256469726562

Removed: (1, 0)
102.34335327148438

Removed: (1, 1)
101.18263244628906

Removed: (1, 2)
101.21923828125

Removed: (1, 3)
101.04120635986328

Removed: (1, 4)
100.89093017578125

Removed: (1, 6)
101.30673217773438

Removed: (1, 7)
102.4504165649414

Removed: (1, 8)
102.81678009033203

Removed: (1, 9)
102.7670669555664

Removed: (1, 10)
101.745849609375

Removed: (1, 11)
101.41844177246094

Removed: (2, 0)
101.97945404052734

Removed: (2, 1)
104.06961059570312

Removed: (2, 2)
103.98907470703125

Removed: (2, 3)
101.8916244506836

Removed: (2, 4)
99.25298309326172

Removed: (2, 5)
100.0

In [None]:
fb_3 = curr_circuit.copy()
fb_3

[(0, 1),
 (1, 5),
 (3, 4),
 (4, 4),
 (4, 5),
 (4, 7),
 (4, 8),
 (4, 10),
 (5, 0),
 (5, 2),
 (5, 4),
 (5, 6),
 (5, 8),
 (6, 4),
 (6, 6),
 (6, 7),
 (6, 8),
 (6, 9),
 (6, 10),
 (7, 2),
 (7, 5),
 (7, 6),
 (7, 7),
 (7, 10),
 (7, 11),
 (8, 1),
 (8, 6),
 (8, 8),
 (8, 9),
 (8, 11),
 (9, 1),
 (9, 5),
 (10, 2),
 (11, 8),
 (11, 10)]

In [None]:
mean_ablate_by_lst(fb_3, model, print_output=True)

Average logit difference (circuit / full) %: 97.0344


tensor(97.0344, device='cuda:0')

In [None]:
mean_ablate_by_lst(fb_3 + [(6, 9)], model, print_output=True)

Average logit difference (circuit / full) %: 97.0344


tensor(97.0344, device='cuda:0')

#### compare

In [None]:
set(backw_3) - set(fb_3)

{(2, 4),
 (3, 2),
 (5, 1),
 (5, 3),
 (5, 5),
 (5, 7),
 (6, 1),
 (6, 3),
 (8, 0),
 (8, 5),
 (10, 5)}

In [None]:
set(fb_3) - set(backw_3)

{(3, 4), (4, 5), (4, 7), (6, 6), (6, 8), (7, 2), (7, 5), (11, 8)}

### iter fwd backw, threshold 20

In [None]:
# threshold = 20
# curr_circuit = []
# prev_score = 100
# new_score = 0
# iter = 1
# while prev_score != new_score:
#     print('\nfwd prune, iter ', str(iter))
#     # track changes in circuit as for some reason it doesn't work with scores
#     old_circuit = curr_circuit.copy() # save old before finding new one
#     curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
#     if curr_circuit == old_circuit:
#         break
#     print('\nbackw prune, iter ', str(iter))
#     # prev_score = new_score # save old score before finding new one
#     old_circuit = curr_circuit.copy() # save old before finding new one
#     curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
#     if curr_circuit == old_circuit:
#         break
#     iter += 1

In [None]:
# curr_circuit

## Prune backwds-fwds iteratively

### iter backw fwd, threshold 3

In [None]:
# Alternate backward and forward pruning passes until a pass removes nothing.
threshold = 3
curr_circuit = []  # an empty list makes the first pass start from the full circuit
iteration = 1      # not named `iter`, which would shadow the builtin for later cells

# Termination is decided only by "did the circuit change?". The previous
# `while prev_score != new_score` guard compared against a constant 100 that was
# never updated, so it could stop the search early whenever a score happened to
# be exactly 100.0.
while True:
    print('\nbackw prune, iter ', str(iteration))
    old_circuit = curr_circuit.copy()  # snapshot to detect whether the pass changed anything
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break

    print('\nfwd prune, iter ', str(iteration))
    old_circuit = curr_circuit.copy()  # snapshot to detect whether the pass changed anything
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break

    iteration += 1


backw prune, iter  1

Removed: (11, 0)
100.10379028320312

Removed: (11, 1)
98.97869873046875

Removed: (11, 2)
99.02860260009766

Removed: (11, 3)
98.74519348144531

Removed: (11, 4)
99.08023834228516

Removed: (11, 5)
99.12744903564453

Removed: (11, 6)
99.15473937988281

Removed: (11, 7)
99.03898620605469

Removed: (11, 8)
97.15733337402344

Removed: (11, 9)
97.16557312011719

Removed: (11, 11)
97.33262634277344

Removed: (10, 0)
97.36923217773438

Removed: (10, 1)
97.38016510009766

Removed: (10, 3)
97.48831176757812

Removed: (10, 4)
97.46089172363281

Removed: (10, 6)
97.41268157958984

Removed: (10, 7)
98.00159454345703

Removed: (10, 8)
98.22784423828125

Removed: (10, 9)
98.02510833740234

Removed: (10, 10)
98.12550354003906

Removed: (10, 11)
98.09323120117188

Removed: (9, 0)
98.04710388183594

Removed: (9, 2)
97.6371841430664

Removed: (9, 3)
97.615234375

Removed: (9, 4)
97.75285339355469

Removed: (9, 6)
98.1245346069336

Removed: (9, 7)
98.0626449584961

Removed: (9, 8)

In [None]:
bf_3 = curr_circuit.copy()
bf_3

[(0, 1),
 (1, 5),
 (3, 2),
 (4, 4),
 (4, 8),
 (4, 10),
 (5, 4),
 (5, 5),
 (5, 6),
 (5, 8),
 (6, 1),
 (6, 3),
 (6, 4),
 (6, 7),
 (6, 9),
 (6, 10),
 (7, 6),
 (7, 7),
 (7, 10),
 (7, 11),
 (8, 0),
 (8, 1),
 (8, 5),
 (8, 6),
 (8, 8),
 (8, 9),
 (8, 11),
 (9, 1),
 (9, 5),
 (10, 2),
 (10, 5),
 (11, 10)]

#### compare

In [None]:
len(bf_3)

32

In [None]:
len(fb_3)

35

In [None]:
set(backw_3) - set(bf_3)

{(2, 4), (5, 0), (5, 1), (5, 2), (5, 3), (5, 7)}

In [None]:
set(bf_3) - set(backw_3)

set()

In [None]:
set(fb_3) - (set(fb_3) - set(bf_3))

{(0, 1),
 (1, 5),
 (4, 4),
 (4, 8),
 (4, 10),
 (5, 4),
 (5, 6),
 (5, 8),
 (6, 4),
 (6, 7),
 (6, 9),
 (6, 10),
 (7, 6),
 (7, 7),
 (7, 10),
 (7, 11),
 (8, 1),
 (8, 6),
 (8, 8),
 (8, 9),
 (8, 11),
 (9, 1),
 (9, 5),
 (10, 2),
 (11, 10)}

In [None]:
set(bf_3) - set(fb_3)

{(3, 2), (5, 5), (6, 1), (6, 3), (8, 0), (8, 5), (10, 5)}

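Get the score of fb_3 without the nodes it has that bf_3 doesn't have, i.e. the heads that fb_3 and bf_3 have in common.

`A - (A - B)` is the set intersection, which can also be written `A & B`: https://chat.openai.com/c/c15f48a7-226b-4c89-8ad9-a39a471867f5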

In [None]:
mean_ablate_by_lst(list(set(fb_3) - (set(fb_3) - set(bf_3))), model, print_output=True)

Average logit difference (circuit / full) %: 89.4520


tensor(89.4520, device='cuda:0')

In [None]:
mean_ablate_by_lst(list(set(bf_3) - (set(bf_3) - set(fb_3))), model, print_output=True)

Average logit difference (circuit / full) %: 89.4520


tensor(89.4520, device='cuda:0')

In [None]:
(set(fb_3) - (set(fb_3) - set(bf_3))) == (set(bf_3) - (set(bf_3) - set(fb_3)))

True