In [None]:
from copy import deepcopy
from functools import partial
from pathlib import Path

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer

from eap.attribute import attribute
from eap.graph import Graph
from eap.evaluate import evaluate_graph, evaluate_baseline
from mib_evaluations import evaluate_area_under_curve
from dataset import HFEAPDataset
from metrics import get_metric

In [2]:
# Load GPT-2 small into TransformerLens. Fall back to CPU when CUDA is not
# available so the notebook can still be run on CPU-only machines.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = HookedTransformer.from_pretrained("gpt2-small", device=DEVICE)

# NOTE(review): these config flags are presumably required by the `eap`
# attribute/evaluate code (separate q/k/v inputs, per-head attention results,
# MLP-input hooks) — confirm against the eap package.
model.cfg.use_split_qkv_input = True
model.cfg.use_attn_result = True
model.cfg.use_hook_mlp_in = True

Loaded pretrained model gpt2-small into HookedTransformer


In [None]:
# Task configuration. The dataset path and the metric must refer to the same
# task, so both are derived from a single constant.
TASK = "ioi"
BATCH_SIZE = 64

dataset = HFEAPDataset(f"mech-interp-bench/{TASK}", model.tokenizer)
dataloader = dataset.to_dataloader(BATCH_SIZE)
metric_fn = get_metric("logit_diff", TASK, model.tokenizer, model)

In [4]:
# Build the edge graph of the model and run EAP-IG attribution
# ('EAP-IG-inputs', 5 integrated-gradient steps) over the IOI dataloader.
# attribute() is called for its effect on `g`; its return value is discarded.
g = Graph.from_model(model)
attribute(
    model,
    g,
    dataloader,
    partial(metric_fn, loss=True, mean=True),
    method='EAP-IG-inputs',
    ig_steps=5,
    quiet=True,
)

100%|██████████| 32491/32491 [00:00<00:00, 214067.58it/s]


In [5]:
# Prune the graph attributed in the previous cell to 250 edges with
# `apply_greedy` and compare the pruned circuit's metric to the full model's.
# Attribution is not re-run here: the previous cell already ran it with the
# same method, steps and metric, so repeating it would only cost time.
g.apply_greedy(250)
eval_metric = partial(metric_fn, loss=False, mean=False)
results = evaluate_graph(model, g, dataloader, eval_metric, prune=True).mean().item()
baseline = evaluate_baseline(model, dataloader, eval_metric).mean().item()
print(f"Faithfulness: {results / baseline}")

# Name the artifact after the task and metric actually used here (IOI,
# logit difference), and create the output directory if it is missing.
Path("circuits").mkdir(parents=True, exist_ok=True)
g.to_pt("circuits/ioi_logit_diff_ig.pt")

100%|██████████| 16/16 [00:11<00:00,  1.44it/s]
100%|██████████| 32491/32491 [00:00<00:00, 225823.29it/s]
100%|██████████| 16/16 [00:02<00:00,  6.46it/s]
100%|██████████| 16/16 [00:01<00:00, 14.72it/s]


Faithfulness: 0.923550621209393


In [6]:
# Train split of the OpenWebText subset `stas/openwebtext-10k` from the
# Hugging Face Hub; judging by the name, intended as a source for mean ablations.
mean_dataset = load_dataset(path='stas/openwebtext-10k', split='train')

In [7]:
# Batch the first 1000 entries of the 'text' column, 16 per batch.
# NOTE(review): in the cells shown this loader is never passed to an
# evaluation (the mean-ablation cells use `dataloader` or
# `corrupted_dataset.to_dataloader(64)`); confirm whether it is still needed.
intervention_dataloader = DataLoader(dataset=mean_dataset['text'][:1000], batch_size=16)

In [8]:
# Faithfulness of the current circuit in `g` when pruned edges are replaced by
# per-position mean activations. The means are computed over the evaluation
# `dataloader` itself, not over the OpenWebText `intervention_dataloader`.
results2 = evaluate_graph(
    model,
    g,
    dataloader,
    partial(metric_fn, loss=False, mean=False),
    prune=True,
    intervention='mean-positional',
    intervention_dataloader=dataloader,
).mean().item()
print(f"Faithfulness: {results2 / baseline}")

Computing mean: 100%|██████████| 16/16 [00:02<00:00,  5.79it/s]
100%|██████████| 16/16 [00:02<00:00,  7.83it/s]

Faithfulness: 0.9599007763104539





In [19]:
# Faithfulness curve over increasing fractions of `g`'s edges (log_scale=True),
# ablating pruned edges to per-position means taken over the evaluation
# `dataloader`.
area_under, area_from_100, average, faithfulnesses = evaluate_area_under_curve(
    model,
    g,
    dataloader,
    partial(metric_fn, loss=False, mean=False),
    prune=True,
    quiet=False,
    node_eval=False,
    neuron_level=False,
    log_scale=True,
    absolute=True,
    intervention='mean-positional',
    intervention_dataloader=dataloader,
)

100%|██████████| 16/16 [00:01<00:00, 13.87it/s]


Computing results for 0.1% of edges (N=32)


Computing mean: 100%|██████████| 16/16 [00:02<00:00,  5.53it/s]
100%|██████████| 16/16 [00:01<00:00, 12.01it/s]


0.8483858217574802
Computing results for 0.2% of edges (N=64)


Computing mean: 100%|██████████| 16/16 [00:02<00:00,  5.53it/s]
100%|██████████| 16/16 [00:01<00:00, 10.49it/s]


0.8747131106143583
Computing results for 0.5% of edges (N=162)


Computing mean: 100%|██████████| 16/16 [00:02<00:00,  5.61it/s]
100%|██████████| 16/16 [00:01<00:00,  9.44it/s]


0.9573824152433821
Computing results for 1.0% of edges (N=324)


Computing mean: 100%|██████████| 16/16 [00:02<00:00,  5.71it/s]
100%|██████████| 16/16 [00:02<00:00,  6.81it/s]


0.9691415496638135
Computing results for 2.0% of edges (N=649)


Computing mean: 100%|██████████| 16/16 [00:02<00:00,  5.68it/s]
100%|██████████| 16/16 [00:02<00:00,  5.55it/s]


0.8993753504229962
Computing results for 5.0% of edges (N=1624)


Computing mean: 100%|██████████| 16/16 [00:02<00:00,  5.66it/s]
100%|██████████| 16/16 [00:03<00:00,  4.36it/s]


0.9563609733817594
Computing results for 10.0% of edges (N=3249)


Computing mean: 100%|██████████| 16/16 [00:02<00:00,  5.56it/s]
100%|██████████| 16/16 [00:04<00:00,  3.60it/s]


0.9089669493473166
Computing results for 20.0% of edges (N=6498)


Computing mean: 100%|██████████| 16/16 [00:02<00:00,  5.52it/s]
100%|██████████| 16/16 [00:04<00:00,  3.43it/s]


0.9208838198631831
Computing results for 50.0% of edges (N=16245)


Computing mean: 100%|██████████| 16/16 [00:03<00:00,  4.24it/s]
100%|██████████| 16/16 [00:05<00:00,  3.03it/s]


0.9444015763362187
Computing results for 100% of edges (N=32491)


Computing mean: 100%|██████████| 16/16 [00:02<00:00,  5.70it/s]
100%|██████████| 16/16 [00:05<00:00,  2.83it/s]

1.0





In [18]:
# Score of a circuit with zero edges kept (apply_topn(0, ...)) under
# mean-positional ablation, and its ratio to the full-model baseline.
# A fresh graph is used so that `g` keeps the circuit selected earlier;
# emptying `g` in place would silently change what later cells evaluate.
empty_graph = Graph.from_model(model)
empty_graph.apply_topn(0, True)
mean_positional_empty_results = evaluate_graph(
    model,
    empty_graph,
    dataloader,
    partial(metric_fn, loss=False, mean=False),
    prune=True,
    intervention='mean-positional',
    intervention_dataloader=dataloader,
).mean().item()
print(mean_positional_empty_results, mean_positional_empty_results / baseline)

Computing mean: 100%|██████████| 16/16 [00:04<00:00,  3.83it/s]
100%|██████████| 16/16 [00:01<00:00, 13.58it/s]

0.41183537244796753 0.5057400567703553





In [10]:
# Copy of the IOI dataset whose 'clean' column is overwritten with the
# corrupted prompts. Used below as the intervention dataloader, presumably so
# that ablation means are taken over corrupted inputs — confirm that the mean
# computation reads the 'clean' column. `deepcopy` leaves the original
# `dataset` untouched.
corrupted_dataset = deepcopy(dataset)
corrupted_dataset.df['clean'] = corrupted_dataset.df['corrupted']

In [17]:
# Same faithfulness curve, but the per-position means come from
# `corrupted_dataset` (IOI examples whose 'clean' inputs are the corrupted prompts).
area_under, area_from_100, average, faithfulnesses = evaluate_area_under_curve(
    model,
    g,
    dataloader,
    partial(metric_fn, loss=False, mean=False),
    prune=True,
    quiet=False,
    node_eval=False,
    neuron_level=False,
    log_scale=True,
    absolute=True,
    intervention='mean-positional',
    intervention_dataloader=corrupted_dataset.to_dataloader(64),
)

100%|██████████| 16/16 [00:01<00:00, 14.06it/s]


Computing results for 0.1% of edges (N=32)


Computing mean: 100%|██████████| 16/16 [00:02<00:00,  5.70it/s]
100%|██████████| 16/16 [00:01<00:00, 11.99it/s]


-0.09244362672979038
Computing results for 0.2% of edges (N=64)


Computing mean: 100%|██████████| 16/16 [00:02<00:00,  5.70it/s]
100%|██████████| 16/16 [00:01<00:00, 10.59it/s]


0.356446612004339
Computing results for 0.5% of edges (N=162)


Computing mean: 100%|██████████| 16/16 [00:02<00:00,  5.59it/s]
100%|██████████| 16/16 [00:01<00:00,  9.34it/s]


0.7991329272452324
Computing results for 1.0% of edges (N=324)


Computing mean: 100%|██████████| 16/16 [00:02<00:00,  5.64it/s]
100%|██████████| 16/16 [00:02<00:00,  7.25it/s]


0.890616276022796
Computing results for 2.0% of edges (N=649)


Computing mean: 100%|██████████| 16/16 [00:02<00:00,  5.70it/s]
100%|██████████| 16/16 [00:03<00:00,  5.17it/s]


0.8831035729604466
Computing results for 5.0% of edges (N=1624)


Computing mean: 100%|██████████| 16/16 [00:02<00:00,  5.68it/s]
100%|██████████| 16/16 [00:03<00:00,  4.35it/s]


1.0117668931332464
Computing results for 10.0% of edges (N=3249)


Computing mean: 100%|██████████| 16/16 [00:02<00:00,  5.53it/s]
100%|██████████| 16/16 [00:04<00:00,  3.39it/s]


0.8495027104258069
Computing results for 20.0% of edges (N=6498)


Computing mean: 100%|██████████| 16/16 [00:03<00:00,  4.18it/s]
100%|██████████| 16/16 [00:04<00:00,  3.45it/s]


0.9172569875992347
Computing results for 50.0% of edges (N=16245)


Computing mean: 100%|██████████| 16/16 [00:02<00:00,  5.52it/s]
100%|██████████| 16/16 [00:05<00:00,  2.92it/s]


0.9364455287123611
Computing results for 100% of edges (N=32491)


Computing mean: 100%|██████████| 16/16 [00:02<00:00,  5.51it/s]
100%|██████████| 16/16 [00:05<00:00,  2.85it/s]

1.0





In [20]:
# Faithfulness curve with intervention='mean' (rather than 'mean-positional'),
# means again taken over `corrupted_dataset`.
# NOTE(review): presumably 'mean' averages over positions as well as examples —
# confirm in mib_evaluations / eap.
area_under, area_from_100, average, faithfulnesses = evaluate_area_under_curve(
    model,
    g,
    dataloader,
    partial(metric_fn, loss=False, mean=False),
    prune=True,
    quiet=False,
    node_eval=False,
    neuron_level=False,
    log_scale=True,
    absolute=True,
    intervention='mean',
    intervention_dataloader=corrupted_dataset.to_dataloader(64),
)

100%|██████████| 16/16 [00:01<00:00, 13.21it/s]


Computing results for 0.1% of edges (N=32)


Computing mean: 100%|██████████| 16/16 [00:02<00:00,  7.52it/s]
100%|██████████| 16/16 [00:01<00:00, 11.86it/s]


-0.023308513686808578
Computing results for 0.2% of edges (N=64)


Computing mean: 100%|██████████| 16/16 [00:01<00:00,  9.18it/s]
100%|██████████| 16/16 [00:01<00:00, 10.36it/s]


-0.01198164959200882
Computing results for 0.5% of edges (N=162)


Computing mean: 100%|██████████| 16/16 [00:01<00:00,  9.39it/s]
100%|██████████| 16/16 [00:01<00:00,  9.50it/s]


-0.042836672324671386
Computing results for 1.0% of edges (N=324)


Computing mean: 100%|██████████| 16/16 [00:01<00:00,  9.53it/s]
100%|██████████| 16/16 [00:02<00:00,  7.40it/s]


-0.013640797260808399
Computing results for 2.0% of edges (N=649)


Computing mean: 100%|██████████| 16/16 [00:01<00:00,  9.66it/s]
100%|██████████| 16/16 [00:02<00:00,  5.56it/s]


-0.0018839437635594486
Computing results for 5.0% of edges (N=1624)


Computing mean: 100%|██████████| 16/16 [00:01<00:00,  9.60it/s]
100%|██████████| 16/16 [00:03<00:00,  4.35it/s]


-0.0039004721378035232
Computing results for 10.0% of edges (N=3249)


Computing mean: 100%|██████████| 16/16 [00:01<00:00,  9.48it/s]
100%|██████████| 16/16 [00:04<00:00,  3.43it/s]


-0.005748425063643404
Computing results for 20.0% of edges (N=6498)


Computing mean: 100%|██████████| 16/16 [00:01<00:00,  9.52it/s]
100%|██████████| 16/16 [00:04<00:00,  3.44it/s]


-0.008445109005340337
Computing results for 50.0% of edges (N=16245)


Computing mean: 100%|██████████| 16/16 [00:01<00:00,  9.66it/s]
100%|██████████| 16/16 [00:05<00:00,  2.89it/s]


0.04274510944908747
Computing results for 100% of edges (N=32491)


Computing mean: 100%|██████████| 16/16 [00:01<00:00,  9.61it/s]
100%|██████████| 16/16 [00:05<00:00,  2.81it/s]

1.0



