In [6]:
%load_ext autoreload
%autoreload 2

import pickle
from functools import partial

import torch
from huggingface_hub import hf_hub_download
#from nnsight.models import UnifiedTransformer
from transformer_lens import HookedTransformer, HookedTransformerConfig

from graph import Graph
from circuit_loading import load_graph_from_json, load_graph_from_pt
from dataset import EAPDataset, HFEAPDataset
from attribute import attribute
from metrics import get_metric
from evaluate_graph import evaluate_graph, evaluate_baseline, evaluate_area_under_curve

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
# Load three saved circuit graphs from JSON files under circuits/.
# Judging by the file names these are the plain, "zero" and "mean" variants of
# the same IOI / GPT-2 circuit -- confirm against whatever produced the files.
# Unpacking drains the lazy map left to right, so the files are read in the
# order listed below.
_circuit_files = (
    "circuits/ioi_hf_prob_diff_vanilla_gpt2.json",
    "circuits/ioi_hf_prob_diff_vanilla_zero_gpt2.json",
    "circuits/ioi_hf_prob_diff_vanilla_mean_gpt2.json",
)
g_patching, g_zero, g_mean = map(load_graph_from_json, _circuit_files)

In [8]:
# Instantiate GPT-2 small as a HookedTransformer.
model = HookedTransformer.from_pretrained("gpt2-small")

# Turn on the three config flags this notebook sets before building graphs.
# NOTE(review): presumably these expose the per-head q/k/v input, per-head
# attention result and MLP-input hook points that attribution/evaluation
# rely on -- confirm against attribute() / evaluate_graph().
for _cfg_flag in ("use_split_qkv_input", "use_attn_result", "use_hook_mlp_in"):
    setattr(model.cfg, _cfg_flag, True)

# IOI data limited to 50 examples; to_dataloader(10) presumably means a batch
# size of 10 (the captured progress bars show 5 iterations) -- verify.
dataset = HFEAPDataset(
    "danaarad/ioi_dataset",
    model.tokenizer,
    task="ioi",
    num_examples=50,
)
dataloader = dataset.to_dataloader(10)

# Metric callable shared by every later cell; each cell binds `mean` / `loss`
# itself through functools.partial.
metric_fn = get_metric("logit_diff", "ioi", model.tokenizer, model)

Loaded pretrained model gpt2-small into HookedTransformer


In [10]:
# Build a fresh graph from the model and run attribution over it.
# The return value of attribute() is not captured; `g` is what later cells use,
# so the scores are presumably written onto the graph in place -- confirm.
g = Graph.from_model(model)

# Metric configuration used for attribution in this cell: mean=True, loss=True.
attribution_metric = partial(metric_fn, mean=True, loss=True)

attribute(
    model,
    g,
    dataloader,
    attribution_metric,
    'EAP-IG-inputs',
    intervention='patching',
    ig_steps=5,
)

100%|██████████| 5/5 [00:01<00:00,  2.90it/s]
100%|██████████| 32491/32491 [00:00<00:00, 210820.86it/s]


In [11]:
# Evaluate the loaded patching circuit after a greedy selection with budget 300.
# Metric configuration for evaluation: mean=False, loss=False; the result of
# evaluate_graph is then averaged with .mean().
per_example_metric = partial(metric_fn, mean=False, loss=False)

# NOTE(review): the return value is discarded, so apply_greedy appears to modify
# g_patching in place; later cells that reuse g_patching would then see the
# modified graph -- confirm this is intended.
g_patching.apply_greedy(300)

evaluate_graph(model, g_patching, dataloader, per_example_metric).mean()

100%|██████████| 5/5 [00:00<00:00, 18.12it/s]


tensor(-0.3557)

In [12]:
# Same evaluation as for the loaded circuit, but on the graph attributed in this
# notebook (`g`): greedy selection with budget 300, then the averaged metric.
per_example_metric = partial(metric_fn, mean=False, loss=False)

# The return value is discarded, so apply_greedy appears to modify `g` in place
# (assumption -- confirm against Graph.apply_greedy).
g.apply_greedy(300)

evaluate_graph(model, g, dataloader, per_example_metric).mean()

100%|██████████| 5/5 [00:00<00:00, 11.93it/s]


tensor(0.3332)

In [5]:
# Area-under-curve evaluation of the loaded patching circuit.
# No `intervention` argument is passed here, so evaluate_area_under_curve uses
# its own default (presumably 'patching' -- confirm in evaluate_graph.py).
# The metric is bound with loss=False, mean=False; prune=True and absolute=False
# are forwarded unchanged.
results_patching = evaluate_area_under_curve(
    model,
    g_patching,
    dataloader,
    partial(metric_fn, loss=False, mean=False),
    prune=True,
    absolute=False,
)

  0%|          | 0/5 [00:00<?, ?it/s]

100%|██████████| 5/5 [00:00<00:00, 22.33it/s]


Computing results for 0.1% of edges (N=32)


100%|██████████| 5/5 [00:00<00:00, 16.49it/s]


-0.7978444633618875
Computing results for 0.2% of edges (N=64)


100%|██████████| 5/5 [00:00<00:00, 18.31it/s]


-0.7978444633618875
Computing results for 0.5% of edges (N=162)


100%|██████████| 5/5 [00:00<00:00, 18.41it/s]


-0.7978444633618875
Computing results for 1.0% of edges (N=324)


100%|██████████| 5/5 [00:00<00:00, 18.36it/s]


-0.7978443296487108
Computing results for 2.0% of edges (N=649)


100%|██████████| 5/5 [00:00<00:00, 18.32it/s]


-0.7978442627921224
Computing results for 5.0% of edges (N=1624)


100%|██████████| 5/5 [00:00<00:00, 16.09it/s]


-0.8003544602604105
Computing results for 10.0% of edges (N=3249)


100%|██████████| 5/5 [00:00<00:00, 15.21it/s]


-0.8032236776084382
Computing results for 20.0% of edges (N=6498)


100%|██████████| 5/5 [00:00<00:00, 13.75it/s]


-0.8035619050891837
Computing results for 50.0% of edges (N=16245)


100%|██████████| 5/5 [00:00<00:00, 10.53it/s]


-0.7471560042999483
Computing results for 100% of edges (N=32491)


100%|██████████| 5/5 [00:00<00:00,  7.13it/s]

1.0
Weighted edge counts: [1, 1, 1, 5, 15, 938, 2718, 6085, 15916, 32491]





In [6]:
# Area-under-curve evaluation with intervention='mean'; the same dataloader is
# passed as intervention_dataloader, and absolute=True is used in this cell.
# NOTE(review): this evaluates g_patching, while g_mean (loaded alongside it) is
# not used in the visible cells -- confirm that evaluating the patching circuit
# under the mean intervention is intended rather than a copy-paste slip.
auc_metric = partial(metric_fn, loss=False, mean=False)

results_mean = evaluate_area_under_curve(
    model,
    g_patching,
    dataloader,
    auc_metric,
    prune=True,
    intervention='mean',
    intervention_dataloader=dataloader,
    absolute=True,
)

  0%|          | 0/5 [00:00<?, ?it/s]

100%|██████████| 5/5 [00:00<00:00, 23.58it/s]


Computing results for 0.1% of edges (N=32)


Computing mean: 100%|██████████| 5/5 [00:00<00:00, 18.25it/s]
100%|██████████| 5/5 [00:00<00:00, 30.27it/s]


1.6626501073675077
Computing results for 0.2% of edges (N=64)


Computing mean: 100%|██████████| 5/5 [00:00<00:00, 18.44it/s]
100%|██████████| 5/5 [00:00<00:00, 30.23it/s]


1.6626501073675077
Computing results for 0.5% of edges (N=162)


Computing mean: 100%|██████████| 5/5 [00:00<00:00, 18.26it/s]
100%|██████████| 5/5 [00:00<00:00, 29.88it/s]


1.6626501073675077
Computing results for 1.0% of edges (N=324)


Computing mean: 100%|██████████| 5/5 [00:00<00:00, 18.46it/s]
100%|██████████| 5/5 [00:00<00:00, 30.19it/s]


1.3751830104779748
Computing results for 2.0% of edges (N=649)


Computing mean: 100%|██████████| 5/5 [00:00<00:00, 17.78it/s]
100%|██████████| 5/5 [00:00<00:00, 29.65it/s]


1.185539189242324
Computing results for 5.0% of edges (N=1624)


Computing mean: 100%|██████████| 5/5 [00:00<00:00, 18.59it/s]
100%|██████████| 5/5 [00:00<00:00, 28.77it/s]


0.9540514278320689
Computing results for 10.0% of edges (N=3249)


Computing mean: 100%|██████████| 5/5 [00:00<00:00, 18.15it/s]
100%|██████████| 5/5 [00:00<00:00, 28.32it/s]


0.9611719827063687
Computing results for 20.0% of edges (N=6498)


Computing mean: 100%|██████████| 5/5 [00:00<00:00, 18.66it/s]
100%|██████████| 5/5 [00:00<00:00, 27.85it/s]


0.9694328113960443
Computing results for 50.0% of edges (N=16245)


Computing mean: 100%|██████████| 5/5 [00:00<00:00, 17.98it/s]
100%|██████████| 5/5 [00:00<00:00, 27.26it/s]


0.9841225612718604
Computing results for 100% of edges (N=32491)


Computing mean: 100%|██████████| 5/5 [00:00<00:00, 18.61it/s]
100%|██████████| 5/5 [00:00<00:00, 12.06it/s]

1.0
Weighted edge counts: [32, 64, 162, 324, 649, 1624, 3249, 6498, 16245, 32491]





In [10]:
# Area-under-curve evaluation with intervention='zero'. `absolute` is not passed
# here, so the callee's default applies.
# NOTE(review): this evaluates g_patching, while g_zero (loaded alongside it) is
# not used in the visible cells -- confirm that this is intended. The captured
# output for this cell shows 13 batches, unlike the 5 seen elsewhere, which
# suggests it was run against a different dataloader state.
auc_metric = partial(metric_fn, loss=False, mean=False)

results_zero = evaluate_area_under_curve(
    model,
    g_patching,
    dataloader,
    auc_metric,
    prune=True,
    intervention='zero',
)

 46%|████▌     | 6/13 [00:00<00:00, 28.77it/s]

100%|██████████| 13/13 [00:00<00:00, 28.22it/s]


Computing results for 0.1% of edges (N=32)


100%|██████████| 13/13 [00:00<00:00, 34.83it/s]


1.4927942930534859
Computing results for 0.2% of edges (N=64)


100%|██████████| 13/13 [00:00<00:00, 36.72it/s]


1.4927942930534859
Computing results for 0.5% of edges (N=162)


100%|██████████| 13/13 [00:00<00:00, 37.24it/s]


1.4927942930534859
Computing results for 1.0% of edges (N=324)


100%|██████████| 13/13 [00:00<00:00, 36.49it/s]


1.5566677254913672
Computing results for 2.0% of edges (N=649)


100%|██████████| 13/13 [00:00<00:00, 35.36it/s]


1.1938927218205513
Computing results for 5.0% of edges (N=1624)


100%|██████████| 13/13 [00:00<00:00, 35.63it/s]


1.0566527173088132
Computing results for 10.0% of edges (N=3249)


100%|██████████| 13/13 [00:00<00:00, 32.93it/s]


1.0566527173088132
Computing results for 20.0% of edges (N=6498)


100%|██████████| 13/13 [00:00<00:00, 33.40it/s]


1.0566527173088132
Computing results for 50.0% of edges (N=16245)


100%|██████████| 13/13 [00:00<00:00, 32.46it/s]


1.0566527173088132
Computing results for 100% of edges (N=32491)


100%|██████████| 13/13 [00:00<00:00, 20.70it/s]

1.0
Weighted edge counts: [32, 64, 162, 324, 649, 1624, 3249, 6498, 16245, 32491]



