In [1]:
%load_ext autoreload
%autoreload 2

import pickle
from functools import partial

import torch
#from nnsight.models import UnifiedTransformer
from transformer_lens import HookedTransformer, HookedTransformerConfig
from huggingface_hub import hf_hub_download

from graph import Graph
from circuit_loading import load_graph_from_json, load_graph_from_pt
from dataset import EAPDataset, HFEAPDataset
from attribute import attribute
from metrics import get_metric
from evaluate_graph import evaluate_graph, evaluate_baseline, evaluate_area_under_curve

In [2]:
# Load the three saved graphs from the circuits/ directory, in the same order
# as before: patching, zero, mean.
# NOTE(review): judging only by the file names, these are presumably circuits
# found under patching / zero / mean ablation -- confirm against whatever
# produced the JSON files.
g_patching, g_zero, g_mean = (
    load_graph_from_json(circuit_path)
    for circuit_path in (
        "circuits/ioi_hf_prob_diff_vanilla_gpt2.json",
        "circuits/ioi_hf_prob_diff_vanilla_zero_gpt2.json",
        "circuits/ioi_hf_prob_diff_vanilla_mean_gpt2.json",
    )
)

In [4]:
# Experiment configuration. Everything a reader might want to tune lives here
# instead of being buried as literals inside the calls below.
MODEL_NAME = "gpt2-small"
DATASET_NAME = "danaarad/ioi_dataset"
TASK = "ioi"              # shared by the dataset and the metric so they cannot drift apart
METRIC_NAME = "prob_diff"
NUM_EXAMPLES = 50
# Passed positionally to `to_dataloader`; presumably the batch size (50 examples
# at 4 per batch is consistent with the 13-step progress bars in the saved
# outputs) -- TODO confirm against dataset.py.
BATCH_SIZE = 4

# Model under evaluation; its tokenizer is reused for the dataset and the metric.
model = HookedTransformer.from_pretrained(MODEL_NAME)

# Evaluation data: NUM_EXAMPLES examples of DATASET_NAME, wrapped in a dataloader.
dataset = HFEAPDataset(DATASET_NAME, model.tokenizer, task=TASK, num_examples=NUM_EXAMPLES)
dataloader = dataset.to_dataloader(BATCH_SIZE)

# Metric callable; later cells specialise it with functools.partial.
metric_fn = get_metric(METRIC_NAME, TASK, model.tokenizer, model)



Loaded pretrained model gpt2-small into HookedTransformer


In [8]:
# Sweep g_patching over increasing edge budgets (the saved run logs steps from
# 0.1% up to 100% of edges) and collect the metric at each step.
# No `intervention` argument is passed, so the evaluator's default is used.
# `evaluate_area_under_curve` comes from the notebook's top import cell rather
# than being imported here, so this cell and the later evaluation cells do not
# depend on each other's execution order for that name.
patching_metric = partial(metric_fn, loss=False, mean=False)
results_patching = evaluate_area_under_curve(
    model,
    g_patching,
    dataloader,
    patching_metric,
    prune=True,
)

 54%|█████▍    | 7/13 [00:00<00:00, 30.08it/s]

100%|██████████| 13/13 [00:00<00:00, 30.11it/s]


Computing results for 0.1% of edges (N=32)


100%|██████████| 13/13 [00:00<00:00, 25.18it/s]


0.5904379542392778
Computing results for 0.2% of edges (N=64)


100%|██████████| 13/13 [00:00<00:00, 24.85it/s]


0.5904379542392778
Computing results for 0.5% of edges (N=162)


100%|██████████| 13/13 [00:00<00:00, 25.20it/s]


0.5904379542392778
Computing results for 1.0% of edges (N=324)


100%|██████████| 13/13 [00:00<00:00, 24.94it/s]


0.7407070339328818
Computing results for 2.0% of edges (N=649)


100%|██████████| 13/13 [00:00<00:00, 24.37it/s]


0.9765515714418114
Computing results for 5.0% of edges (N=1624)


100%|██████████| 13/13 [00:00<00:00, 24.78it/s]


1.0055424289830595
Computing results for 10.0% of edges (N=3249)


100%|██████████| 13/13 [00:00<00:00, 24.06it/s]


1.0055424289830595
Computing results for 20.0% of edges (N=6498)


100%|██████████| 13/13 [00:00<00:00, 23.91it/s]


1.0055424289830595
Computing results for 50.0% of edges (N=16245)


100%|██████████| 13/13 [00:00<00:00, 23.23it/s]


1.0055424289830595
Computing results for 100% of edges (N=32491)


100%|██████████| 13/13 [00:00<00:00, 22.75it/s]

1.0
Weighted edge counts: [32, 64, 162, 324, 649, 1624, 3249, 6498, 16245, 32491]





In [9]:
# Same edge-budget sweep over g_patching, this time with intervention='mean'.
# The evaluation dataloader is also handed in as `intervention_dataloader`.
# NOTE(review): g_mean is loaded alongside g_patching but is not referenced in
# this cell -- confirm that evaluating g_patching under the mean intervention
# (rather than g_mean) is what is intended.
mean_eval_options = dict(
    prune=True,
    intervention='mean',
    intervention_dataloader=dataloader,
)
results_mean = evaluate_area_under_curve(
    model,
    g_patching,
    dataloader,
    partial(metric_fn, loss=False, mean=False),
    **mean_eval_options,
)

 31%|███       | 4/13 [00:00<00:00, 31.28it/s]

100%|██████████| 13/13 [00:00<00:00, 29.79it/s]


Computing results for 0.1% of edges (N=32)


Computing mean: 100%|██████████| 13/13 [00:00<00:00, 24.44it/s]
100%|██████████| 13/13 [00:00<00:00, 36.36it/s]


1.4797326872491254
Computing results for 0.2% of edges (N=64)


Computing mean: 100%|██████████| 13/13 [00:00<00:00, 25.88it/s]
100%|██████████| 13/13 [00:00<00:00, 35.58it/s]


1.4797326872491254
Computing results for 0.5% of edges (N=162)


Computing mean: 100%|██████████| 13/13 [00:00<00:00, 25.18it/s]
100%|██████████| 13/13 [00:00<00:00, 33.95it/s]


1.4797326872491254
Computing results for 1.0% of edges (N=324)


Computing mean: 100%|██████████| 13/13 [00:00<00:00, 24.47it/s]
100%|██████████| 13/13 [00:00<00:00, 36.43it/s]


1.5255013003739204
Computing results for 2.0% of edges (N=649)


Computing mean: 100%|██████████| 13/13 [00:00<00:00, 24.66it/s]
100%|██████████| 13/13 [00:00<00:00, 36.31it/s]


1.1422398954316484
Computing results for 5.0% of edges (N=1624)


Computing mean: 100%|██████████| 13/13 [00:00<00:00, 24.25it/s]
100%|██████████| 13/13 [00:00<00:00, 34.57it/s]


1.0105328817675492
Computing results for 10.0% of edges (N=3249)


Computing mean: 100%|██████████| 13/13 [00:00<00:00, 24.23it/s]
100%|██████████| 13/13 [00:00<00:00, 30.82it/s]


1.0173604444031161
Computing results for 20.0% of edges (N=6498)


Computing mean: 100%|██████████| 13/13 [00:00<00:00, 24.43it/s]
 62%|██████▏   | 8/13 [00:00<00:00, 34.48it/s]

In [7]:
# Same edge-budget sweep over g_patching, this time with intervention='zero'.
# NOTE(review): g_zero is loaded alongside g_patching but is not referenced in
# this cell -- confirm that evaluating g_patching under the zero intervention
# (rather than g_zero) is what is intended.
zero_metric = partial(metric_fn, loss=False, mean=False)
results_zero = evaluate_area_under_curve(
    model,
    g_patching,
    dataloader,
    zero_metric,
    prune=True,
    intervention='zero',
)

100%|██████████| 13/13 [00:00<00:00, 29.75it/s]


Computing results for 0.1% of edges (N=32)


100%|██████████| 13/13 [00:00<00:00, 36.55it/s]


1.5060016484245045
Computing results for 0.2% of edges (N=64)


100%|██████████| 13/13 [00:00<00:00, 36.93it/s]


1.5060016484245045
Computing results for 0.5% of edges (N=162)


100%|██████████| 13/13 [00:00<00:00, 35.74it/s]


1.5060016484245045
Computing results for 1.0% of edges (N=324)


100%|██████████| 13/13 [00:00<00:00, 37.11it/s]


1.5538757393657356
Computing results for 2.0% of edges (N=649)


100%|██████████| 13/13 [00:00<00:00, 36.15it/s]


1.5538757393657356
Computing results for 5.0% of edges (N=1624)


100%|██████████| 13/13 [00:00<00:00, 35.98it/s]


1.5538757393657356
Computing results for 10.0% of edges (N=3249)


100%|██████████| 13/13 [00:00<00:00, 34.36it/s]


1.5538757393657356
Computing results for 20.0% of edges (N=6498)


100%|██████████| 13/13 [00:00<00:00, 21.80it/s]


1.5538757393657356
Computing results for 50.0% of edges (N=16245)


100%|██████████| 13/13 [00:00<00:00, 31.45it/s]


1.5538757393657356
Computing results for 100% of edges (N=32491)


100%|██████████| 13/13 [00:00<00:00, 29.84it/s]

1.0
Weighted edge counts: [32, 64, 162, 324, 649, 1624, 3249, 6498, 16245, 32491]



