In [1]:
%load_ext autoreload
%autoreload 2

from functools import partial
import os 
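# Debugging aid: uncomment to make CUDA kernel launches synchronous, so that
# device-side errors are reported at the call that caused them. It has to be
# set before CUDA is initialised, which is why it sits above `import torch`.
#os.environ['CUDA_LAUNCH_BLOCKING']='1'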

import torch
from transformer_lens import HookedTransformer

from graph import Graph
from dataset import EAPDataset, HFEAPDataset
from attribute import attribute
from metrics import get_metric
from evaluate_graph import evaluate_graph, evaluate_baseline, evaluate_area_under_curve
from circuit_loading import load_graph_from_json, load_graph_from_pt

In [2]:
# --- Experiment configuration: everything a reader might want to tune ---
MODEL_NAME = "google/gemma-2-2b"
DATASET_NAME = "danaarad/ioi_dataset"
TASK = "ioi"          # shared by the dataset loader and the metric so they cannot drift apart
NUM_EXAMPLES = 100
BATCH_SIZE = 8        # 100 examples / 8 per batch -> the 13 iterations seen in the progress bars

# Load the model with unembed centering and writing-weight centering disabled.
model = HookedTransformer.from_pretrained(MODEL_NAME, center_unembed=False, center_writing_weights=False)

# Turn on the extra TransformerLens hook points: separate q/k/v inputs per head,
# per-head attention results, the MLP input hook, and ungrouped key/value heads.
# NOTE(review): presumably required by the edge-level graph used below — confirm
# against the `graph` / `attribute` modules.
model.cfg.use_split_qkv_input = True
model.cfg.use_attn_result = True
model.cfg.use_hook_mlp_in = True
model.cfg.ungroup_grouped_query_attention = True

# Task data, batched for attribution and evaluation.
dataset = HFEAPDataset(DATASET_NAME, model.tokenizer, task=TASK, num_examples=NUM_EXAMPLES)
dataloader = dataset.to_dataloader(BATCH_SIZE)

# Base metric; later cells bind `mean=` / `loss=` with functools.partial
# depending on whether it is used for attribution or for evaluation.
metric_fn = get_metric("logit_diff", TASK, model.tokenizer, model)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded pretrained model google/gemma-2-2b into HookedTransformer


In [3]:
# Build the computational graph for the model and run EAP-IG (inputs variant)
# attribution over it. The return value of `attribute` is not used;
# presumably it records edge scores on `g` in place — TODO confirm.
g = Graph.from_model(model)

# Attribution uses the metric as a loss averaged over each batch.
attribution_metric = partial(metric_fn, mean=True, loss=True)
attribute(
    model,
    g,
    dataloader,
    attribution_metric,
    'EAP-IG-inputs',
    ig_steps=10,
    intervention='patching',
)

100%|██████████| 13/13 [00:45<00:00,  3.53s/it]


In [4]:
# Reference point: mean per-example logit difference of the unmodified model
# on the clean inputs.
per_example_metric = partial(metric_fn, mean=False, loss=False)
clean_scores = evaluate_baseline(model, dataloader, per_example_metric)
clean_scores.mean()

  0%|          | 0/13 [00:00<?, ?it/s]

100%|██████████| 13/13 [00:02<00:00,  4.83it/s]


tensor(4.4935)

In [5]:
# Second reference point: the same baseline evaluation, but with
# `run_corrupted=True`.
per_example_metric = partial(metric_fn, mean=False, loss=False)
corrupted_scores = evaluate_baseline(
    model,
    dataloader,
    per_example_metric,
    run_corrupted=True,
)
corrupted_scores.mean()

  0%|          | 0/13 [00:00<?, ?it/s]

100%|██████████| 13/13 [00:02<00:00,  4.84it/s]


tensor(-5.0355)

In [6]:
# Sanity check: select a circuit with `apply_greedy(0)` and evaluate it.
# In the saved output this matches the corrupted baseline exactly (-5.0355),
# which suggests the resulting circuit contains no edges — TODO confirm.
g.apply_greedy(0)
per_example_metric = partial(metric_fn, mean=False, loss=False)
empty_circuit_scores = evaluate_graph(model, g, dataloader, per_example_metric)
empty_circuit_scores.mean()

100%|██████████| 13/13 [00:04<00:00,  3.11it/s]


tensor(-5.0355)

In [7]:
# Select edges with `apply_topn(800, absolute=True)` and report how many
# end up in the circuit before evaluating it.
# NOTE(review): the saved output reports 688 included edges rather than 800;
# presumably some selected edges are pruned afterwards — confirm in `graph`.
g.apply_topn(800, absolute=True)
n_included_edges = g.count_included_edges()
print(n_included_edges)

per_example_metric = partial(metric_fn, mean=False, loss=False)
topn_scores = evaluate_graph(model, g, dataloader, per_example_metric)
topn_scores.mean()

688


  0%|          | 0/13 [00:00<?, ?it/s]

100%|██████████| 13/13 [00:26<00:00,  2.05s/it]


tensor(3.0814)

In [8]:
# Sweep over circuit sizes and record the per-size results. No `intervention`
# argument is passed, so `evaluate_area_under_curve` uses its default.
per_example_metric = partial(metric_fn, loss=False, mean=False)
results_patching = evaluate_area_under_curve(
    model,
    g,
    dataloader,
    per_example_metric,
    absolute=True,
)

  0%|          | 0/13 [00:00<?, ?it/s]

100%|██████████| 13/13 [00:02<00:00,  4.83it/s]
100%|██████████| 13/13 [00:04<00:00,  3.11it/s]


Computing results for 0.1% of edges (N=74)


100%|██████████| 13/13 [00:07<00:00,  1.72it/s]


-1.0306202826662314
Computing results for 0.2% of edges (N=148)


100%|██████████| 13/13 [00:13<00:00,  1.02s/it]


-0.619224633667964
Computing results for 0.5% of edges (N=371)


100%|██████████| 13/13 [00:23<00:00,  1.81s/it]


0.08935092559411598
Computing results for 1.0% of edges (N=742)


100%|██████████| 13/13 [00:26<00:00,  2.05s/it]


0.5540960201824712
Computing results for 2.0% of edges (N=1484)


100%|██████████| 13/13 [00:27<00:00,  2.09s/it]


0.34616346600452924
Computing results for 5.0% of edges (N=3710)


100%|██████████| 13/13 [00:27<00:00,  2.09s/it]


0.6065725995193144
Computing results for 10.0% of edges (N=7421)


100%|██████████| 13/13 [00:27<00:00,  2.11s/it]


0.7485725500691188
Computing results for 20.0% of edges (N=14843)


100%|██████████| 13/13 [00:27<00:00,  2.10s/it]


0.865512340635868
Computing results for 50.0% of edges (N=37109)


100%|██████████| 13/13 [00:27<00:00,  2.09s/it]


0.9114247248643117
Computing results for 100% of edges (N=74218)


100%|██████████| 13/13 [00:27<00:00,  2.08s/it]

1.0
Weighted edge counts: [24.0, 61.0, 249.0, 641.0, 1396.0, 3519.0, 7215.0, 14771.0, 37109.0, 74218.0]





In [9]:
# Same sweep, but with the 'mean' intervention; the same dataloader is passed
# as `intervention_dataloader`.
# NOTE(review): the saved output for this cell shows 10 batches per pass while
# every other cell shows 13, and it stops partway through the 20% step. It
# looks like a stale or interrupted run — re-execute before using
# `results_mean`.
per_example_metric = partial(metric_fn, loss=False, mean=False)
results_mean = evaluate_area_under_curve(
    model,
    g,
    dataloader,
    per_example_metric,
    intervention='mean',
    intervention_dataloader=dataloader,
    absolute=True,
)

100%|██████████| 10/10 [00:02<00:00,  4.09it/s]
Computing mean: 100%|██████████| 10/10 [00:02<00:00,  4.02it/s]
100%|██████████| 10/10 [00:24<00:00,  2.49s/it]


Computing results for 0.1% of edges (N=74)


Computing mean: 100%|██████████| 10/10 [00:02<00:00,  4.61it/s]
100%|██████████| 10/10 [00:08<00:00,  1.20it/s]


-0.06963775522169238
Computing results for 0.2% of edges (N=148)


Computing mean: 100%|██████████| 10/10 [00:02<00:00,  4.63it/s]
100%|██████████| 10/10 [00:10<00:00,  1.03s/it]


-0.07297893373470771
Computing results for 0.5% of edges (N=371)


Computing mean: 100%|██████████| 10/10 [00:02<00:00,  4.62it/s]
100%|██████████| 10/10 [00:16<00:00,  1.65s/it]


-0.05658852499479765
Computing results for 1.0% of edges (N=742)


Computing mean: 100%|██████████| 10/10 [00:02<00:00,  4.60it/s]
100%|██████████| 10/10 [00:22<00:00,  2.25s/it]


-0.1090220256099009
Computing results for 2.0% of edges (N=1484)


Computing mean: 100%|██████████| 10/10 [00:02<00:00,  4.60it/s]
100%|██████████| 10/10 [00:24<00:00,  2.48s/it]


-0.1094702609538308
Computing results for 5.0% of edges (N=3710)


Computing mean: 100%|██████████| 10/10 [00:02<00:00,  4.65it/s]
100%|██████████| 10/10 [00:25<00:00,  2.56s/it]


-0.10110202728866112
Computing results for 10.0% of edges (N=7421)


Computing mean: 100%|██████████| 10/10 [00:02<00:00,  4.63it/s]
100%|██████████| 10/10 [00:25<00:00,  2.56s/it]


-0.20903703045782507
Computing results for 20.0% of edges (N=14843)


Computing mean:  10%|█         | 1/10 [00:00<00:01,  5.39it/s]

In [None]:
# Same sweep with the 'zero' intervention.
# NOTE(review): this cell has no execution count in the saved notebook, so
# `results_zero` has not been produced yet.
per_example_metric = partial(metric_fn, loss=False, mean=False)
results_zero = evaluate_area_under_curve(
    model,
    g,
    dataloader,
    per_example_metric,
    intervention='zero',
    absolute=True,
)