In [9]:
%load_ext autoreload
%autoreload 2

from functools import partial
import os
#os.environ['CUDA_LAUNCH_BLOCKING']='1'

import pandas as pd
import torch
from transformer_lens import HookedTransformer

from graph import Graph
from dataset import EAPDataset, HFEAPDataset
from attribute import attribute
from metrics import get_metric
from evaluate_graph import evaluate_graph, evaluate_baseline, evaluate_area_under_curve
from circuit_loading import load_graph_from_json, load_graph_from_pt

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [17]:
# Experiment configuration: which model / dataset to load and how to batch it.
MODEL_NAME = "google/gemma-2-2b"
DATASET_NAME = "danaarad/ioi_dataset"
TASK = "ioi"
NUM_EXAMPLES = 1000
BATCH_SIZE = 8

# Load the model with both weight-centering options switched off.
model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    center_unembed=False,
    center_writing_weights=False,
)

# Enable the config flags this notebook relies on, in the original order.
# NOTE(review): presumably these expose the per-head / MLP-input hook points
# that edge-level attribution needs -- confirm against the TransformerLens docs.
for cfg_flag in (
    "use_split_qkv_input",
    "use_attn_result",
    "use_hook_mlp_in",
    "ungroup_grouped_query_attention",
):
    setattr(model.cfg, cfg_flag, True)

# Build the IOI dataset, its batched dataloader, and the logit-diff metric.
dataset = HFEAPDataset(DATASET_NAME, model.tokenizer, task=TASK, num_examples=NUM_EXAMPLES)
dataloader = dataset.to_dataloader(BATCH_SIZE)
metric_fn = get_metric("logit_diff", TASK, model.tokenizer, model)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded pretrained model google/gemma-2-2b into HookedTransformer


In [18]:
# Flatten the batched dataloader into three parallel per-example lists
# (clean prompts, corrupted prompts, labels) so they can be exported as a table.
# Each batch yields a (clean, corrupted, labels) triple; every element of each
# component is appended individually, exactly as list `+=` would do.
cleans, corrupteds, labelss = [], [], []
for clean, corrupted, labels in dataloader:
    cleans.extend(clean)
    corrupteds.extend(corrupted)
    labelss.extend(labels)

In [22]:
# Each collected label unpacks into two values; transpose the list of labels
# into one sequence per component (fails if any label does not have exactly 2).
correct_idx, incorrect_idx = zip(*labelss)

# Column-name -> values mapping; insertion order fixes the CSV column order.
d = dict(
    clean=cleans,
    corrupted=corrupteds,
    correct_idx=correct_idx,
    incorrect_idx=incorrect_idx,
)

# Write one row per example, without the DataFrame index column.
df = pd.DataFrame(data=d)
df.to_csv('ioi_gemma.csv', index=False)

In [3]:
# Build the graph for this model, then score it with the 'EAP-IG-inputs'
# method (10 IG steps, patching intervention). The metric is bound with
# mean=True and loss=True for the attribution pass.
g = Graph.from_model(model)
attribution_metric = partial(metric_fn, loss=True, mean=True)
attribute(
    model,
    g,
    dataloader,
    attribution_metric,
    'EAP-IG-inputs',
    intervention='patching',
    ig_steps=10,
)

100%|██████████| 13/13 [00:46<00:00,  3.55s/it]


In [4]:
# Baseline metric of the unmodified model, averaged over the per-example
# values returned by evaluate_baseline (metric bound with mean=False, loss=False).
per_example_metric = partial(metric_fn, loss=False, mean=False)
evaluate_baseline(model, dataloader, per_example_metric).mean()

  0%|          | 0/13 [00:00<?, ?it/s]

100%|██████████| 13/13 [00:02<00:00,  4.81it/s]


tensor(4.4935)

In [5]:
# Same baseline evaluation, but with run_corrupted=True; the result is again
# averaged over the returned per-example values.
per_example_metric = partial(metric_fn, loss=False, mean=False)
evaluate_baseline(
    model,
    dataloader,
    per_example_metric,
    run_corrupted=True,
).mean()

  0%|          | 0/13 [00:00<?, ?it/s]

100%|██████████| 13/13 [00:02<00:00,  4.81it/s]


tensor(-4.6264)

In [6]:
# Select a circuit with greedy selection using an argument of 0, then evaluate
# that circuit and average the per-example metric values.
g.apply_greedy(0)
per_example_metric = partial(metric_fn, loss=False, mean=False)
evaluate_graph(model, g, dataloader, per_example_metric).mean()

100%|██████████| 13/13 [00:04<00:00,  3.09it/s]


tensor(-4.6264)

In [7]:
# Select a circuit with top-n selection (n=800, absolute=True), report how many
# edges ended up included, then evaluate the circuit and average the
# per-example metric values.
g.apply_topn(800, absolute=True)
print(g.count_included_edges())
per_example_metric = partial(metric_fn, loss=False, mean=False)
evaluate_graph(model, g, dataloader, per_example_metric).mean()

601


  0%|          | 0/13 [00:00<?, ?it/s]

100%|██████████| 13/13 [00:23<00:00,  1.79s/it]


tensor(6.0593)

In [8]:
# Area-under-curve evaluation with the evaluator's default intervention and
# absolute=True; the metric is bound with mean=False, loss=False.
per_example_metric = partial(metric_fn, mean=False, loss=False)
results_patching = evaluate_area_under_curve(
    model,
    g,
    dataloader,
    per_example_metric,
    absolute=True,
)

100%|██████████| 13/13 [00:02<00:00,  4.82it/s]
100%|██████████| 13/13 [00:23<00:00,  1.79s/it]


Computing results for 0.1% of edges (N=74)


100%|██████████| 13/13 [00:08<00:00,  1.48it/s]


-0.05327235873081935
Computing results for 0.2% of edges (N=148)


100%|██████████| 13/13 [00:10<00:00,  1.23it/s]


0.7936992492165699
Computing results for 0.5% of edges (N=371)


100%|██████████| 13/13 [00:18<00:00,  1.40s/it]


1.3477920328519084
Computing results for 1.0% of edges (N=742)


100%|██████████| 13/13 [00:22<00:00,  1.69s/it]


1.2715755925295398
Computing results for 2.0% of edges (N=1484)


100%|██████████| 13/13 [00:25<00:00,  1.97s/it]


1.1949043377478068
Computing results for 5.0% of edges (N=3710)


100%|██████████| 13/13 [00:27<00:00,  2.11s/it]


1.0489056066865154
Computing results for 10.0% of edges (N=7421)


100%|██████████| 13/13 [00:27<00:00,  2.14s/it]


0.8954876484341532
Computing results for 20.0% of edges (N=14843)


100%|██████████| 13/13 [00:27<00:00,  2.12s/it]


1.005622359892793
Computing results for 50.0% of edges (N=37109)


 54%|█████▍    | 7/13 [00:16<00:13,  2.31s/it]


KeyboardInterrupt: 

In [9]:
# Area-under-curve evaluation with intervention='mean'; the same dataloader is
# passed as intervention_dataloader.
per_example_metric = partial(metric_fn, mean=False, loss=False)
results_mean = evaluate_area_under_curve(
    model,
    g,
    dataloader,
    per_example_metric,
    absolute=True,
    intervention='mean',
    intervention_dataloader=dataloader,
)

100%|██████████| 10/10 [00:02<00:00,  4.09it/s]
Computing mean: 100%|██████████| 10/10 [00:02<00:00,  4.02it/s]
100%|██████████| 10/10 [00:24<00:00,  2.49s/it]


Computing results for 0.1% of edges (N=74)


Computing mean: 100%|██████████| 10/10 [00:02<00:00,  4.61it/s]
100%|██████████| 10/10 [00:08<00:00,  1.20it/s]


-0.06963775522169238
Computing results for 0.2% of edges (N=148)


Computing mean: 100%|██████████| 10/10 [00:02<00:00,  4.63it/s]
100%|██████████| 10/10 [00:10<00:00,  1.03s/it]


-0.07297893373470771
Computing results for 0.5% of edges (N=371)


Computing mean: 100%|██████████| 10/10 [00:02<00:00,  4.62it/s]
100%|██████████| 10/10 [00:16<00:00,  1.65s/it]


-0.05658852499479765
Computing results for 1.0% of edges (N=742)


Computing mean: 100%|██████████| 10/10 [00:02<00:00,  4.60it/s]
100%|██████████| 10/10 [00:22<00:00,  2.25s/it]


-0.1090220256099009
Computing results for 2.0% of edges (N=1484)


Computing mean: 100%|██████████| 10/10 [00:02<00:00,  4.60it/s]
100%|██████████| 10/10 [00:24<00:00,  2.48s/it]


-0.1094702609538308
Computing results for 5.0% of edges (N=3710)


Computing mean: 100%|██████████| 10/10 [00:02<00:00,  4.65it/s]
100%|██████████| 10/10 [00:25<00:00,  2.56s/it]


-0.10110202728866112
Computing results for 10.0% of edges (N=7421)


Computing mean: 100%|██████████| 10/10 [00:02<00:00,  4.63it/s]
100%|██████████| 10/10 [00:25<00:00,  2.56s/it]


-0.20903703045782507
Computing results for 20.0% of edges (N=14843)


Computing mean:  10%|█         | 1/10 [00:00<00:01,  5.39it/s]

In [None]:
# Area-under-curve evaluation with intervention='zero' and absolute=True.
per_example_metric = partial(metric_fn, mean=False, loss=False)
results_zero = evaluate_area_under_curve(
    model,
    g,
    dataloader,
    per_example_metric,
    absolute=True,
    intervention='zero',
)