In [14]:
%load_ext autoreload
%autoreload 2

from functools import partial

import torch
from transformer_lens import HookedTransformer

from graph import Graph
from attribute import attribute
from dataset import HFEAPDataset
from metrics import get_metric
from evaluate_graph import evaluate_graph, evaluate_baseline, evaluate_area_under_curve

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# Checkpoint to load into TransformerLens.
# NOTE(review): the saved output of this cell reports
# "Loaded pretrained model Qwen/Qwen2-1.5B", which does not match this value,
# so the recorded outputs in this notebook appear to come from an earlier run
# with a different model_name. Re-run the notebook (or set model_name to the
# intended checkpoint) before relying on the numbers below.
model_name = "gpt2-small"

# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing while loading the model.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = HookedTransformer.from_pretrained(model_name, device=device)

# Switch on the TransformerLens config flags used by this workflow
# (presumably required by the graph / attribution code below -- confirm
# against graph.py and attribute.py).
model.cfg.use_split_qkv_input = True
model.cfg.use_attn_result = True
model.cfg.use_hook_mlp_in = True
model.cfg.ungroup_grouped_query_attention = True



Loaded pretrained model Qwen/Qwen2-1.5B into HookedTransformer


In [None]:
# The dataset and the metric are both built around the model's own tokenizer.
tokenizer = model.tokenizer

# Load 100 examples of the "ioi" task from the "danaarad/ioi_dataset" dataset.
dataset = HFEAPDataset(
    "danaarad/ioi_dataset",
    tokenizer,
    task="ioi",
    num_examples=100,
)

# 20 is presumably the batch size -- confirm against HFEAPDataset.to_dataloader.
dataloader = dataset.to_dataloader(20)

# Logit-difference metric for the "ioi" task; the later cells specialise it
# with functools.partial (loss= / mean= keyword arguments).
metric_fn = get_metric("logit_diff", "ioi", tokenizer, model)

Filter:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [None]:
# Metric variant handed to the attribution call: loss=True, mean=True.
attribution_metric = partial(metric_fn, loss=True, mean=True)

# Build a fresh graph object for this model.
g = Graph.from_model(model)

# Run attribution over the dataloader with the "EAP" method. The return value
# is not captured, so the scores presumably end up on `g` itself -- confirm
# against attribute.py.
attribute(model, g, dataloader, attribution_metric, "EAP")

100%|██████████| 10/10 [00:08<00:00,  1.20it/s]


In [None]:
# Both evaluations below use the same metric configuration (loss=False,
# mean=False); the averaging is done here via .mean() on the returned values.
eval_metric = partial(metric_fn, loss=False, mean=False)

# Select the top 300 entries of the graph in place (presumably the
# highest-scoring edges; the meaning of the second positional argument is not
# visible from this notebook -- confirm against Graph.apply_topn).
g.apply_topn(300, True)

# Mean metric value from evaluate_baseline, then from evaluate_graph on `g`.
baseline = evaluate_baseline(model, dataloader, eval_metric).mean().item()
results = evaluate_graph(model, g, dataloader, eval_metric).mean().item()

# Faithfulness is reported as the ratio of the graph's score to the baseline.
print(f"Faithfulness: {results / baseline}. Original {baseline}, new {results}")

  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:02<00:00,  4.82it/s]
100%|██████████| 10/10 [00:03<00:00,  2.52it/s]

Faithfulness: 0.5966821924853728. Original 0.9711557626724243, new 0.5794713497161865





In [19]:
# Same metric configuration as the faithfulness evaluation (loss=False,
# mean=False). Per the saved output, this call reports progress for a series
# of edge percentages from 0.1% up to 100%.
auc_metric = partial(metric_fn, loss=False, mean=False)
results = evaluate_area_under_curve(model, g, dataloader, auc_metric)

  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:02<00:00,  4.64it/s]
100%|██████████| 10/10 [00:02<00:00,  3.42it/s]


Computing results for 0.1% of edges (N=183)


100%|██████████| 10/10 [00:03<00:00,  2.93it/s]


Computing results for 0.2% of edges (N=367)


100%|██████████| 10/10 [00:03<00:00,  2.75it/s]


Computing results for 0.5% of edges (N=917)


100%|██████████| 10/10 [00:03<00:00,  2.63it/s]


Computing results for 1.0% of edges (N=1835)


100%|██████████| 10/10 [00:04<00:00,  2.50it/s]


Computing results for 2.0% of edges (N=3671)


100%|██████████| 10/10 [00:03<00:00,  2.58it/s]


Computing results for 5.0% of edges (N=9177)


100%|██████████| 10/10 [00:04<00:00,  2.49it/s]


Computing results for 10.0% of edges (N=18355)


100%|██████████| 10/10 [00:03<00:00,  2.54it/s]


Computing results for 20.0% of edges (N=36711)


100%|██████████| 10/10 [00:03<00:00,  2.54it/s]


Computing results for 50.0% of edges (N=91777)


100%|██████████| 10/10 [00:03<00:00,  2.50it/s]


Computing results for 100% of edges (N=183555)


100%|██████████| 10/10 [00:03<00:00,  2.55it/s]
