<a href="https://colab.research.google.com/github/wlg100/numseqcont_circuit_expms/blob/main/notebook_templates/minimal_circuit_template.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

# Setup
(No need to change anything)

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-vy0sjotk
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-vy0sjotk
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit 3f929b1d142b8f82bfbb8ae30e69bab7f76cadf3
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting beartype<0.15.0,>=0.14.1 (from transformer-lens==0.0.0)
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m739.7/739.7 kB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets>=2.7.1 (from transformer-lens==0.0.0)
  Downloading datasets-2.14.5-py3-none-any.whl (519 kB)
[2K     [90m━━━━━━━━━━━━━━

In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"

In [3]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [4]:
# import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [5]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x79040f63b7f0>

Plotting helper functions:

In [6]:
def imshow(tensor, renderer=None, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

## Load Model

Decide which model to use (eg. gpt2-small vs -medium)

In [7]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


## Import functions from repo

In [8]:
!git clone https://github.com/callummcdougall/ARENA_2.0.git

Cloning into 'ARENA_2.0'...
remote: Enumerating objects: 9063, done.[K
remote: Counting objects: 100% (9063/9063), done.[K
remote: Compressing objects: 100% (3540/3540), done.[K
remote: Total 9063 (delta 5508), reused 8890 (delta 5425), pack-reused 0[K
Receiving objects: 100% (9063/9063), 155.49 MiB | 25.01 MiB/s, done.
Resolving deltas: 100% (5508/5508), done.


In [9]:
cd ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification

/content/ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification


In [10]:
import ioi_circuit_extraction as ioi_circuit_extraction

# Generate dataset with multiple prompts

In [11]:
class Dataset:
    def __init__(self, prompts, tokenizer, S1_is_first=False):
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)
        self.max_len = max(
            [
                len(self.tokenizer(prompt["text"]).input_ids)
                for prompt in self.prompts
            ]
        )
        # all_ids = [prompt["TEMPLATE_IDX"] for prompt in self.ioi_prompts]
        all_ids = [0 for prompt in self.prompts] # only 1 template
        all_ids_ar = np.array(all_ids)
        self.groups = []
        for id in list(set(all_ids)):
            self.groups.append(np.where(all_ids_ar == id)[0])

        texts = [ prompt["text"] for prompt in self.prompts ]
        self.toks = torch.Tensor(self.tokenizer(texts, padding=True).input_ids).type(
            torch.int
        )
        self.io_tokenIDs = [
            self.tokenizer.encode(" " + prompt["S5"])[0] for prompt in self.prompts
        ]
        self.s_tokenIDs = [
            self.tokenizer.encode(" " + prompt["S4"])[0] for prompt in self.prompts
        ]

        # word_idx: for every prompt, find the token index of each target token and "end"
        # word_idx is a tensor with an element for each prompt. The element is the targ token's ind at that prompt
        self.word_idx = {}
        for targ in [key for key in self.prompts[0].keys() if (key != 'text' and key != 'S5')]:
            targ_lst = []
            for prompt in self.prompts:
                input_text = prompt["text"]
                tokens = model.tokenizer.tokenize(input_text)
                if S1_is_first and targ == "S1":  # only use this if first token doesn't have space Ġ in front
                    target_token = prompt[targ]
                else:
                    target_token = "Ġ" + prompt[targ]
                target_index = tokens.index(target_token)
                targ_lst.append(target_index)
            self.word_idx[targ] = torch.tensor(targ_lst)

        targ_lst = []
        for prompt in self.prompts:
            input_text = prompt["text"]
            tokens = self.tokenizer.tokenize(input_text)
            end_token_index = len(tokens) - 1
            targ_lst.append(end_token_index)
        self.word_idx["end"] = torch.tensor(targ_lst)

    def __len__(self):
        return self.N

Repalce io_tokens with correct answer (next, which is '5') and s_tokens with incorrect (current, which repeats)

In [12]:
def generate_prompts_list(x ,y):
    prompts_list = []
    for i in range(x, y):
        prompt_dict = {
            'S1': str(i),
            'S2': str(i+1),
            'S3': str(i+2),
            'S4': str(i+3),
            'S5': str(i+4),
            'text': f"{i} {i+1} {i+2} {i+3}"
        }
        prompts_list.append(prompt_dict)
    return prompts_list

prompts_list = generate_prompts_list(1, 11)
dataset = Dataset(prompts_list, model.tokenizer, S1_is_first=True)

In [13]:
def generate_prompts_list_corr(x ,y):
    prompts_list = []
    for i in range(x, y):
        prompt_dict = {
            'S1': str(i),
            'S2': str(i+1),
            'S3': str(i+2),
            'S4': str(i+2),
            'S5': str(i+3),
            'text': f"{i} {i+1} {i+2} {i+2}"
        }
        prompts_list.append(prompt_dict)
    return prompts_list

prompts_list_2 = generate_prompts_list_corr(1, 11)
dataset_2 = Dataset(prompts_list_2, model.tokenizer, S1_is_first=True)

Logit diff is correct - incorr token. Here, correct is S5, and incorr is S4.

Because of this, it's possible to have logit diffs HIGHER than the "full circuit" because the correct token will still be at first place, but the logit scores assigned will just be bigger (perhaps incorrect is scored even lower in the non-full circuit with a higher logit diff score)?

# Ablation Expm Functions

In [14]:
from torch import Tensor

def logits_to_ave_logit_diff_2(logits: Float[Tensor, "batch seq d_vocab"], dataset: Dataset, per_prompt=False):
    '''
    Returns logit difference between the correct and incorrect answer.

    If per_prompt=True, return the array of differences rather than the average.
    '''

    # Only the final logits are relevant for the answer
    # Get the logits corresponding to the indirect object / subject tokens respectively
    io_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), dataset.word_idx["end"], dataset.io_tokenIDs]
    s_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), dataset.word_idx["end"], dataset.s_tokenIDs]
    # Find logit difference
    answer_logit_diff = io_logits - s_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()

In [15]:
def mean_ablate_by_lst(lst, model, print_output=True):
    CIRCUIT = {
        "number mover": lst,
        "number mover 4": lst,
        "number mover 3": lst,
        "number mover 2": lst,
        "number mover 1": lst,
    }

    SEQ_POS_TO_KEEP = {
        "number mover": "end",
        "number mover 4": "S4",
        "number mover 3": "S3",
        "number mover 2": "S2",
        "number mover 1": "S1",
    }

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

    ioi_logits_original, ioi_cache = model.run_with_cache(dataset.toks)

    model = ioi_circuit_extraction.add_mean_ablation_hook(model, means_dataset=dataset_2, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP)
    ioi_logits_minimal = model(dataset.toks)

    orig_score = logits_to_ave_logit_diff_2(ioi_logits_original, dataset)
    new_score = logits_to_ave_logit_diff_2(ioi_logits_minimal, dataset)
    if print_output:
        # print(f"Average logit difference (IOI dataset, using entire model): {orig_score:.4f}")
        # print(f"Average logit difference (IOI dataset, only using circuit): {new_score:.4f}")
        print(f"Average logit difference (circuit / full) %: {100 * new_score / orig_score:.4f}")
    # return new_score
    return 100 * new_score / orig_score

We can also prevent redundant computation of the full circuit score by storing it and just passing it in to the function.

# Ablate the model and compare with original

## Work backwards

https://www.notion.so/wlg1/Search-Methods-brainstorm-15a3020ab00b40adb79b0acf3622f5f4?pvs=4#dd6b43247d4945eda1d70ca4d4bae01d

In [None]:
# Start with full circuit
curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]
threshold = 3  # This is T, a %. if performance is less than T%, allow its removal

for layer in range(11, -1, -1):  # go thru all heads in a layer first
    for head in range(12):
        # Copying the curr_circuit so we can iterate over one and modify the other
        copy_circuit = curr_circuit.copy()

        # Temporarily removing the current tuple from the copied circuit
        copy_circuit.remove((layer, head))

        new_score = mean_ablate_by_lst(copy_circuit, model, print_output=False).item()

        # print((layer,head), new_score)
        # If the result is less than the threshold, remove the tuple from the original list
        if (100 - new_score) < threshold:
            curr_circuit.remove((layer, head))

            print("Removed:", (layer, head))
            print(new_score)
            print("\n")

Removed: (11, 0)
99.02855682373047


Removed: (11, 1)
99.0681381225586


Removed: (11, 2)
99.27287292480469


Removed: (11, 3)
99.61582946777344


Removed: (11, 4)
100.09127807617188


Removed: (11, 5)
100.09712982177734


Removed: (11, 6)
100.05851745605469


Removed: (11, 7)
99.9634017944336


Removed: (11, 8)
99.3410873413086


Removed: (11, 9)
99.123291015625


Removed: (11, 10)
98.09922790527344


Removed: (11, 11)
99.79859161376953


Removed: (10, 0)
99.7418212890625


Removed: (10, 1)
98.1563491821289


Removed: (10, 2)
100.61833953857422


Removed: (10, 3)
100.79714965820312


Removed: (10, 4)
100.48069763183594


Removed: (10, 5)
100.26615142822266


Removed: (10, 6)
100.29136657714844


Removed: (10, 8)
100.43941497802734


Removed: (10, 9)
100.6727294921875


Removed: (10, 10)
101.22868347167969


Removed: (10, 11)
100.94214630126953


Removed: (9, 0)
100.90746307373047


Removed: (9, 2)
100.99788665771484


Removed: (9, 3)
102.2463150024414


Removed: (9, 4)
101.65393829345

In [None]:
mean_ablate_by_lst(curr_circuit, model, print_output=True).item()

Average logit difference (circuit / full) %: 97.2790


97.27898406982422

In [None]:
curr_circuit

[(0, 1),
 (0, 3),
 (0, 5),
 (0, 7),
 (0, 9),
 (0, 10),
 (1, 0),
 (1, 4),
 (1, 5),
 (2, 2),
 (2, 8),
 (2, 9),
 (3, 0),
 (3, 2),
 (3, 3),
 (3, 7),
 (4, 4),
 (4, 7),
 (4, 10),
 (5, 1),
 (5, 3),
 (5, 4),
 (5, 5),
 (5, 6),
 (5, 8),
 (5, 9),
 (6, 1),
 (6, 4),
 (6, 6),
 (6, 10),
 (6, 11),
 (7, 6),
 (7, 10),
 (7, 11),
 (8, 0),
 (8, 5),
 (8, 6),
 (8, 8),
 (9, 1),
 (10, 7)]

Now try 10% threshold:

In [16]:
def find_circuit_backw(threshold=10):
    # threshold is T, a %. if performance is less than T%, allow its removal
    # Start with full circuit
    curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]

    for layer in range(11, -1, -1):  # go thru all heads in a layer first
        for head in range(12):
            # Copying the curr_circuit so we can iterate over one and modify the other
            copy_circuit = curr_circuit.copy()

            # Temporarily removing the current tuple from the copied circuit
            copy_circuit.remove((layer, head))

            new_score = mean_ablate_by_lst(copy_circuit, model, print_output=False).item()

            # If the result is less than the threshold, remove the tuple from the original list
            if (100 - new_score) < threshold:
                curr_circuit.remove((layer, head))

                print("Removed:", (layer, head))
                print(new_score)
                print("\n")

    return curr_circuit

In [None]:
curr_circuit = find_circuit_backw(10)

Removed: (11, 0)
99.02855682373047


Removed: (11, 1)
99.0681381225586


Removed: (11, 2)
99.27287292480469


Removed: (11, 3)
99.61582946777344


Removed: (11, 4)
100.09127807617188


Removed: (11, 5)
100.09712982177734


Removed: (11, 6)
100.05851745605469


Removed: (11, 7)
99.9634017944336


Removed: (11, 8)
99.3410873413086


Removed: (11, 9)
99.123291015625


Removed: (11, 10)
98.09922790527344


Removed: (11, 11)
99.79859161376953


Removed: (10, 0)
99.7418212890625


Removed: (10, 1)
98.1563491821289


Removed: (10, 2)
100.61833953857422


Removed: (10, 3)
100.79714965820312


Removed: (10, 4)
100.48069763183594


Removed: (10, 5)
100.26615142822266


Removed: (10, 6)
100.29136657714844


Removed: (10, 8)
100.43941497802734


Removed: (10, 9)
100.6727294921875


Removed: (10, 10)
101.22868347167969


Removed: (10, 11)
100.94214630126953


Removed: (9, 0)
100.90746307373047


Removed: (9, 2)
100.99788665771484


Removed: (9, 3)
102.2463150024414


Removed: (9, 4)
101.65393829345

Try this method on greater-than task to see if recovers circuit similar to paper.

In [None]:
curr_circuit

[(0, 1),
 (0, 3),
 (0, 5),
 (0, 7),
 (0, 9),
 (0, 10),
 (1, 0),
 (1, 5),
 (2, 2),
 (2, 9),
 (3, 0),
 (3, 3),
 (3, 7),
 (4, 4),
 (4, 7),
 (4, 8),
 (4, 10),
 (5, 1),
 (5, 4),
 (5, 5),
 (5, 6),
 (5, 8),
 (5, 9),
 (5, 10),
 (6, 1),
 (6, 3),
 (6, 4),
 (6, 6),
 (6, 10),
 (7, 2),
 (7, 6),
 (7, 10),
 (7, 11),
 (8, 8),
 (9, 1),
 (10, 7)]

In [None]:
%%capture
curr_circuit = find_circuit_backw(20)

In [None]:
curr_circuit

[(0, 1),
 (0, 3),
 (0, 5),
 (0, 9),
 (0, 10),
 (1, 0),
 (1, 5),
 (2, 2),
 (2, 9),
 (3, 0),
 (3, 2),
 (3, 3),
 (3, 7),
 (4, 4),
 (4, 6),
 (4, 7),
 (4, 8),
 (4, 10),
 (5, 0),
 (5, 1),
 (5, 4),
 (5, 5),
 (5, 6),
 (5, 8),
 (5, 9),
 (5, 11),
 (6, 1),
 (6, 6),
 (6, 8),
 (7, 10),
 (7, 11),
 (8, 6),
 (8, 8),
 (9, 1)]

## mean ablation the circuit pruned by iterative path patching

From:

https://colab.research.google.com/drive/1onREXMNmc9ks0xpwDslUX2pdG0RSYtWS#scrollTo=ehsYSXYO_25N&line=6&uniqifier=1

In [None]:
test_circ = [(0,1), (3,0), (4,4), (5,5), (5,8), (6,6), (7,11), (9,1)]

mean_ablate_by_lst(test_circ, model, print_output=True).item()

Average logit difference (circuit / full) %: 13.2641


13.264147758483887

From:

https://colab.research.google.com/drive/1onREXMNmc9ks0xpwDslUX2pdG0RSYtWS#scrollTo=V8JWdlVokmpL&line=6&uniqifier=1

In [None]:
test_circ = [(0, 1), (0, 5), (0, 10), (1, 5), (3, 0), (4, 4), (4, 8), (5, 1), (5, 4), (5, 5), (5, 8), (6, 1), (6, 6), (6, 9), (6, 10), (7, 6), (7, 10), (7, 11), (8, 8), (9, 1), (10, 7)]
mean_ablate_by_lst(test_circ, model, print_output=True).item()

Average logit difference (circuit / full) %: 40.4605


40.46047592163086

## Prune forwards

In [None]:
# Start with full circuit
curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]
threshold = 3  # This is T, a %. if performance is less than T%, allow its removal

for layer in range(0, 12):
    for head in range(12):
        # Copying the curr_circuit so we can iterate over one and modify the other
        copy_circuit = curr_circuit.copy()

        # Temporarily removing the current tuple from the copied circuit
        copy_circuit.remove((layer, head))

        new_score = mean_ablate_by_lst(copy_circuit, model, print_output=False).item()

        # print((layer,head), new_score)
        # If the result is less than the threshold, remove the tuple from the original list
        if (100 - new_score) < threshold:
            curr_circuit.remove((layer, head))

            print("Removed:", (layer, head))
            print(new_score)
            print("\n")

Removed: (0, 0)
100.00466918945312


Removed: (0, 2)
98.07209777832031


Removed: (0, 4)
97.96208953857422


Removed: (0, 6)
97.41792297363281


Removed: (0, 11)
98.04102325439453


Removed: (1, 1)
97.67564392089844


Removed: (1, 2)
97.67819213867188


Removed: (1, 3)
97.88668823242188


Removed: (1, 4)
97.89542388916016


Removed: (1, 6)
97.8697509765625


Removed: (1, 7)
98.2431640625


Removed: (1, 8)
98.43437194824219


Removed: (1, 9)
98.68045806884766


Removed: (1, 10)
98.94314575195312


Removed: (1, 11)
99.24425506591797


Removed: (2, 0)
99.28617858886719


Removed: (2, 1)
100.14505767822266


Removed: (2, 2)
99.1255111694336


Removed: (2, 3)
99.42776489257812


Removed: (2, 4)
99.11087036132812


Removed: (2, 5)
99.4810562133789


Removed: (2, 6)
99.1651611328125


Removed: (2, 7)
98.68614959716797


Removed: (2, 8)
98.37564086914062


Removed: (2, 9)
97.45429992675781


Removed: (2, 10)
97.83071899414062


Removed: (2, 11)
98.20713806152344


Removed: (3, 1)
98.5207824707

In [None]:
curr_circuit

[(0, 1),
 (0, 3),
 (0, 5),
 (0, 7),
 (0, 8),
 (0, 9),
 (0, 10),
 (1, 0),
 (1, 5),
 (3, 0),
 (3, 3),
 (3, 7),
 (3, 10),
 (3, 11),
 (4, 4),
 (4, 6),
 (4, 7),
 (4, 8),
 (4, 10),
 (4, 11),
 (5, 4),
 (5, 5),
 (5, 9),
 (6, 1),
 (6, 6),
 (6, 10),
 (7, 6),
 (7, 10),
 (7, 11),
 (8, 1),
 (8, 2),
 (8, 6),
 (8, 8),
 (9, 1),
 (9, 5),
 (10, 7),
 (11, 10)]

## prune fwds then back iteratively- fns

In [47]:
def find_circuit_forw(curr_circuit=None, threshold=10):
    # threshold is T, a %. if performance is less than T%, allow its removal
    if curr_circuit == []:
        # Start with full circuit
        curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]

    for layer in range(0, 12):
        for head in range(12):
            if (layer, head) not in curr_circuit:
                continue

            # Copying the curr_circuit so we can iterate over one and modify the other
            copy_circuit = curr_circuit.copy()

            # Temporarily removing the current tuple from the copied circuit
            copy_circuit.remove((layer, head))

            new_score = mean_ablate_by_lst(copy_circuit, model, print_output=False).item()

            # print((layer,head), new_score)
            # If the result is less than the threshold, remove the tuple from the original list
            if (100 - new_score) < threshold:
                curr_circuit.remove((layer, head))

                print("\nRemoved:", (layer, head))
                print(new_score)

    return curr_circuit, new_score

In [48]:
def find_circuit_backw(curr_circuit=None, threshold=10):
    # threshold is T, a %. if performance is less than T%, allow its removal
    if curr_circuit == []:
        # Start with full circuit
        curr_circuit = [(layer, head) for layer in range(12) for head in range(12)]

    for layer in range(11, -1, -1):  # go thru all heads in a layer first
        for head in range(12):
            if (layer, head) not in curr_circuit:
                continue

            # Copying the curr_circuit so we can iterate over one and modify the other
            copy_circuit = curr_circuit.copy()

            # Temporarily removing the current tuple from the copied circuit
            copy_circuit.remove((layer, head))

            new_score = mean_ablate_by_lst(copy_circuit, model, print_output=False).item()

            # If the result is less than the threshold, remove the tuple from the original list
            if (100 - new_score) < threshold:
                curr_circuit.remove((layer, head))

                print("\nRemoved:", (layer, head))
                print(new_score)

    return curr_circuit, new_score

In [24]:
curr_circuit = [(0, 1), (0, 3), (0, 5), (0, 7), (0, 8), (0, 9), (0, 10), (1, 0), (1, 5), (3, 0), (3, 3), (3, 7), (3, 10), (3, 11), (4, 4), (4, 6), (4, 7), (4, 8), (4, 10), (4, 11), (5, 4), (5, 5), (5, 9), (6, 1), (6, 6), (6, 10), (7, 6), (7, 10), (7, 11), (8, 1), (8, 2), (8, 6), (8, 8), (9, 1), (9, 5), (10, 7), (11, 10)]
curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=2)

Removed: (8, 2)
98.4662094116211


Removed: (3, 10)
98.19361877441406




In [28]:
curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=2)

### iter fwd backw, threshold 2

In [30]:
threshold = 2
curr_circuit = None
prev_score = 100
iter = 1
while prev_score - new_score < threshold/2:
    print('fwd prune, iter ', str(iter))
    prev_score = new_score  # save old score before finding new one
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=2)
    if prev_score - new_score < threshold/2:
        break
    print('backw prune, iter ', str(iter))
    prev_score = new_score # save old score before finding new one
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=2)
    iter += 1

fwd prune, iter  1


Removed: (0, 0)
100.00470733642578


Removed: (0, 2)
98.07218170166016


Removed: (0, 11)
98.68363952636719


Removed: (1, 1)
98.21087646484375


Removed: (1, 2)
98.30923461914062


Removed: (1, 3)
98.48136901855469


Removed: (1, 4)
98.44184875488281


Removed: (1, 6)
98.35868072509766


Removed: (1, 7)
98.70216369628906


Removed: (1, 8)
98.81130981445312


Removed: (1, 9)
99.10165405273438


Removed: (1, 10)
99.28426361083984


Removed: (1, 11)
98.32254791259766


Removed: (2, 0)
98.21942138671875


Removed: (2, 1)
98.96421813964844


Removed: (2, 3)
99.10186004638672


Removed: (2, 4)
98.81763458251953


Removed: (2, 5)
98.89395141601562


Removed: (2, 6)
98.6978988647461


Removed: (2, 7)
98.333984375


Removed: (2, 8)
98.04685974121094


Removed: (2, 10)
98.67316436767578


Removed: (2, 11)
98.90179443359375


Removed: (3, 1)
99.17343139648438


Removed: (3, 2)
98.44761657714844


Removed: (3, 4)
99.8545150756836


Removed: (3, 5)
99.87609100341797


Removed:

### iter fwd backw, threshold 10

In [32]:
threshold = 10
curr_circuit = None
prev_score = 100
new_score = 0
iter = 1
while prev_score != new_score:
    print('\nfwd prune, iter ', str(iter))
    prev_score = new_score  # save old score before finding new one
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if prev_score == new_score:
        break
    print('\nbackw prune, iter ', str(iter))
    prev_score = new_score # save old score before finding new one
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    iter += 1

fwd prune, iter  1


Removed: (0, 0)
100.00470733642578


Removed: (0, 2)
98.07218170166016


Removed: (0, 3)
96.57498931884766


Removed: (0, 4)
96.4569320678711


Removed: (0, 5)
92.88771057128906


Removed: (0, 6)
92.20349884033203


Removed: (0, 7)
90.53197479248047


Removed: (0, 11)
91.30901336669922


Removed: (1, 1)
90.67813110351562


Removed: (1, 2)
90.63320922851562


Removed: (1, 3)
90.15375518798828


Removed: (1, 4)
90.081298828125


Removed: (1, 6)
90.14987182617188


Removed: (1, 7)
90.65447998046875


Removed: (1, 8)
90.69853210449219


Removed: (1, 9)
90.91239929199219


Removed: (1, 10)
91.1763687133789


Removed: (2, 0)
90.96540069580078


Removed: (2, 1)
91.77622985839844


Removed: (2, 2)
90.87447357177734


Removed: (2, 3)
90.8670425415039


Removed: (2, 4)
90.60151672363281


Removed: (2, 5)
90.76985931396484


Removed: (2, 7)
90.26012420654297


Removed: (2, 8)
90.02489471435547


Removed: (2, 10)
90.35562133789062


Removed: (2, 11)
90.95338439941406


Removed

Exception ignored in: <function _xla_gc_callback at 0x7904376a6e60>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/lib/__init__.py", line 97, in _xla_gc_callback
    def _xla_gc_callback(*args):
KeyboardInterrupt: 


KeyboardInterrupt: ignored

In [33]:
curr_circuit

[(0, 1),
 (0, 8),
 (0, 9),
 (0, 10),
 (1, 0),
 (1, 5),
 (1, 11),
 (2, 6),
 (2, 9),
 (3, 0),
 (3, 3),
 (3, 7),
 (3, 10),
 (3, 11),
 (4, 4),
 (4, 6),
 (4, 7),
 (4, 8),
 (4, 10),
 (4, 11),
 (5, 4),
 (5, 5),
 (5, 9),
 (6, 1),
 (6, 6),
 (6, 10),
 (7, 7),
 (7, 10),
 (7, 11),
 (8, 0),
 (8, 1),
 (8, 2),
 (8, 6),
 (8, 8),
 (9, 1),
 (9, 5),
 (10, 7)]

### iter fwd backw, threshold 20

In [37]:
threshold = 20
curr_circuit = None
prev_score = 100
new_score = 0
iter = 1
while prev_score != new_score:
    print('\nfwd prune, iter ', str(iter))
    prev_score = new_score  # save old score before finding new one
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if prev_score == new_score:
        break
    print('\nbackw prune, iter ', str(iter))
    prev_score = new_score # save old score before finding new one
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    iter += 1


fwd prune, iter  1

Removed: (0, 0)
100.00470733642578

Removed: (0, 2)
98.07218170166016

Removed: (0, 3)
96.57498931884766

Removed: (0, 4)
96.4569320678711

Removed: (0, 5)
92.88771057128906

Removed: (0, 6)
92.20349884033203

Removed: (0, 7)
90.53197479248047

Removed: (0, 8)
89.92404174804688

Removed: (0, 9)
85.35267639160156

Removed: (0, 10)
89.38638305664062

Removed: (0, 11)
89.96485137939453

Removed: (1, 0)
85.79306030273438

Removed: (1, 1)
85.86489868164062

Removed: (1, 2)
86.05559539794922

Removed: (1, 3)
85.52399444580078

Removed: (1, 4)
85.40201568603516

Removed: (1, 6)
85.52527618408203

Removed: (1, 7)
86.04039001464844

Removed: (1, 8)
86.73933410644531

Removed: (1, 9)
87.02789306640625

Removed: (1, 10)
87.12644958496094

Removed: (1, 11)
86.26460266113281

Removed: (2, 0)
86.16632080078125

Removed: (2, 1)
87.24700164794922

Removed: (2, 2)
86.12347412109375

Removed: (2, 3)
86.27395629882812

Removed: (2, 4)
85.96822357177734

Removed: (2, 5)
86.13011169433

KeyboardInterrupt: ignored

In [38]:
prev_score

46.846946716308594

In [39]:
new_score

46.846946716308594

In [40]:
curr_circuit

[(0, 1),
 (1, 5),
 (3, 0),
 (3, 6),
 (3, 7),
 (3, 10),
 (4, 4),
 (4, 6),
 (4, 7),
 (4, 8),
 (4, 11),
 (5, 4),
 (5, 5),
 (5, 6),
 (6, 6),
 (6, 11),
 (7, 0),
 (7, 2),
 (7, 6),
 (7, 10),
 (7, 11),
 (8, 1),
 (8, 6),
 (8, 8),
 (9, 1),
 (9, 7),
 (10, 7),
 (11, 10)]

### debug why it doesn't stop

In [49]:
import pdb

threshold = 2
curr_circuit = []
prev_score = 100
new_score = 0
iter = 1
while prev_score != new_score:
    print('\nfwd prune, iter ', str(iter))
    if new_score != 0:
        prev_score = new_score  # save old score before finding new one
    old_circuit = curr_circuit.copy()
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if prev_score == new_score:
        break
    elif curr_circuit == old_circuit:
        pdb.set_trace()
    print('\nbackw prune, iter ', str(iter))
    prev_score = new_score # save old score before finding new one
    old_circuit = curr_circuit.copy()
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if prev_score == new_score:
        break
    elif curr_circuit == old_circuit:
        pdb.set_trace()
    iter += 1


fwd prune, iter  1

Removed: (0, 0)
100.00470733642578

Removed: (0, 2)
98.07218170166016

Removed: (0, 11)
98.68363952636719

Removed: (1, 1)
98.21087646484375

Removed: (1, 2)
98.30923461914062

Removed: (1, 3)
98.48136901855469

Removed: (1, 4)
98.44184875488281

Removed: (1, 6)
98.35868072509766

Removed: (1, 7)
98.70216369628906

Removed: (1, 8)
98.81130981445312

Removed: (1, 9)
99.10165405273438

Removed: (1, 10)
99.28426361083984

Removed: (1, 11)
98.32254791259766

Removed: (2, 0)
98.21942138671875

Removed: (2, 1)
98.96421813964844

Removed: (2, 3)
99.10186004638672

Removed: (2, 4)
98.81763458251953

Removed: (2, 5)
98.89395141601562

Removed: (2, 6)
98.6978988647461

Removed: (2, 7)
98.333984375

Removed: (2, 8)
98.04685974121094

Removed: (2, 10)
98.67316436767578

Removed: (2, 11)
98.90179443359375

Removed: (3, 1)
99.17343139648438

Removed: (3, 2)
98.44761657714844

Removed: (3, 4)
99.8545150756836

Removed: (3, 5)
99.87609100341797

Removed: (3, 6)
98.54624938964844



Exception ignored in: <function _xla_gc_callback at 0x7904376a6e60>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/lib/__init__.py", line 97, in _xla_gc_callback
    def _xla_gc_callback(*args):
KeyboardInterrupt: 


KeyboardInterrupt: ignored

In [50]:
threshold = 2
curr_circuit = []
prev_score = 100
new_score = 0
iter = 1
while prev_score != new_score:
    print('\nfwd prune, iter ', str(iter))
    # if new_score != 0:
    #     prev_score = new_score  # save old score before finding new one
    # track changes in circuit as for some reason it doesn't work with scores
    old_circuit = curr_circuit.copy()
    curr_circuit, new_score = find_circuit_forw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    print('\nbackw prune, iter ', str(iter))
    # prev_score = new_score # save old score before finding new one
    old_circuit = curr_circuit.copy()
    curr_circuit, new_score = find_circuit_backw(curr_circuit=curr_circuit, threshold=threshold)
    if curr_circuit == old_circuit:
        break
    iter += 1


fwd prune, iter  1

Removed: (0, 0)
100.00470733642578

Removed: (0, 2)
98.07218170166016

Removed: (0, 11)
98.68363952636719

Removed: (1, 1)
98.21087646484375

Removed: (1, 2)
98.30923461914062

Removed: (1, 3)
98.48136901855469

Removed: (1, 4)
98.44184875488281

Removed: (1, 6)
98.35868072509766

Removed: (1, 7)
98.70216369628906

Removed: (1, 8)
98.81130981445312

Removed: (1, 9)
99.10165405273438

Removed: (1, 10)
99.28426361083984

Removed: (1, 11)
98.32254791259766

Removed: (2, 0)
98.21942138671875

Removed: (2, 1)
98.96421813964844

Removed: (2, 3)
99.10186004638672

Removed: (2, 4)
98.81763458251953

Removed: (2, 5)
98.89395141601562

Removed: (2, 6)
98.6978988647461

Removed: (2, 7)
98.333984375

Removed: (2, 8)
98.04685974121094

Removed: (2, 10)
98.67316436767578

Removed: (2, 11)
98.90179443359375

Removed: (3, 1)
99.17343139648438

Removed: (3, 2)
98.44761657714844

Removed: (3, 4)
99.8545150756836

Removed: (3, 5)
99.87609100341797

Removed: (3, 6)
98.54624938964844



In [51]:
curr_circuit


[(0, 1),
 (0, 3),
 (0, 5),
 (0, 7),
 (0, 9),
 (0, 10),
 (1, 0),
 (1, 5),
 (2, 2),
 (3, 0),
 (3, 3),
 (3, 7),
 (3, 10),
 (4, 4),
 (4, 6),
 (4, 7),
 (4, 8),
 (4, 10),
 (4, 11),
 (5, 4),
 (5, 5),
 (5, 8),
 (5, 9),
 (6, 1),
 (6, 6),
 (6, 10),
 (7, 6),
 (7, 10),
 (7, 11),
 (8, 0),
 (8, 6),
 (8, 8),
 (9, 1),
 (9, 5),
 (10, 7)]

## etc fns

In [None]:
# base_lst = [(0, 1), (0, 10), (3, 0), (4, 4), (5, 5), (6, 1), (6, 6), (7, 10), (7, 11), (8, 8), (8, 11), (9, 1), (9, 5), (10, 7)]

In [None]:
# import json

# with open("scores.json", "w") as file:
#     json.dump(all_scores, file, default=lambda x: str(x))  # Convert tuples to strings for JSON serialization

In [None]:
# from google.colab import files
# files.download("scores.json")  # or "scores.pkl" or "scores.json" depending on the file you saved

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>