In [1]:
from copy import deepcopy
from functools import partial

from datasets import load_dataset
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer

from eap.attribute import attribute
from eap.evaluate import evaluate_graph, evaluate_baseline
from eap.graph import Graph
from MIB_circuit_track.dataset import HFEAPDataset
from MIB_circuit_track.metrics import get_metric

In [2]:
# Load GPT-2 small onto the GPU through TransformerLens.
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")

# Switch on the three config flags the rest of the notebook relies on.
# NOTE(review): presumably these expose the extra hook points that EAP
# attributes over -- confirm against the eap / TransformerLens docs.
for cfg_flag in ("use_split_qkv_input", "use_attn_result", "use_hook_mlp_in"):
    setattr(model.cfg, cfg_flag, True)

Loaded pretrained model gpt2-small into HookedTransformer


In [7]:
# IOI task data, tokenized with the model's own tokenizer.
dataset = HFEAPDataset("mech-interp-bench/ioi", model.tokenizer)
# The return value is unused; the 8-batch progress bars at batch size 64
# suggest this truncates the dataset in place to 500 examples -- TODO confirm.
dataset.head(500)
dataloader = dataset.to_dataloader(64)

metric_fn = get_metric("logit_diff", "ioi", model.tokenizer, model)

# Reference score that every "Faithfulness" ratio below is divided by:
# the metric from evaluate_baseline (loss=False, mean=False), averaged here.
baseline_scores = evaluate_baseline(
    model,
    dataloader,
    partial(metric_fn, loss=False, mean=False),
)
baseline = baseline_scores.mean().item()

Filter:   0%|          | 0/10000 [00:00<?, ? examples/s]

100%|██████████| 8/8 [00:05<00:00,  1.52it/s]


This cell scores edges with activation-space EAP-IG (`ig_steps=5`), using patching as the intervention.

In [None]:
# Fresh graph for the model, scored with EAP-IG in activation space (5 steps).
g = Graph.from_model(model)
attribute(
    model,
    g,
    dataloader,
    partial(metric_fn, loss=True, mean=True),
    method='EAP-IG-activations',
    ig_steps=5,
)
# NOTE(review): presumably greedily selects a 250-edge circuit -- confirm in eap.graph.
g.apply_greedy(250)

# Evaluate the selected circuit and average the metric it returns.
circuit_scores = evaluate_graph(
    model,
    g,
    dataloader,
    partial(metric_fn, loss=False, mean=False),
    prune=True,
)
results = circuit_scores.mean().item()
print(f"Faithfulness: {results / baseline}")

  0%|          | 0/8 [01:18<?, ?it/s]


KeyboardInterrupt: 

Same thing, but EAP-IG in input space, which strangely enough does better (normally they're about the same)

In [8]:
# Same pipeline as the activation-space run, but with EAP-IG over inputs.
g2 = Graph.from_model(model)
attribute(
    model,
    g2,
    dataloader,
    partial(metric_fn, loss=True, mean=True),
    method='EAP-IG-inputs',
    ig_steps=5,
)
g2.apply_greedy(250)

# No `intervention` or `prune` argument here: the library defaults apply.
circuit_scores = evaluate_graph(
    model,
    g2,
    dataloader,
    partial(metric_fn, loss=False, mean=False),
)
results = circuit_scores.mean().item()
print(f"Faithfulness: {results / baseline}")

100%|██████████| 8/8 [01:00<00:00,  7.62s/it]
100%|██████████| 8/8 [00:05<00:00,  1.37it/s]

Faithfulness: 0.49052765344109756





What about if we apply zero ablations to these graphs?

In [9]:
# Re-evaluate the existing graph `g`, this time with zero ablations as the
# intervention; the graph itself is not re-scored.
circuit_scores = evaluate_graph(
    model,
    g,
    dataloader,
    partial(metric_fn, loss=False, mean=False),
    intervention='zero',
)
results = circuit_scores.mean().item()
print(f"Faithfulness: {results / baseline}")

100%|██████████| 8/8 [00:02<00:00,  2.86it/s]

Faithfulness: 0.006563079237567003





And if we use an attribution method that natively uses zero ablations? Also consider trying with more ig_steps!

In [10]:
# Score a new graph with zero ablations as the attribution-time intervention,
# then evaluate it with the same (zero) intervention.
g3 = Graph.from_model(model)
attribute(
    model,
    g3,
    dataloader,
    partial(metric_fn, loss=True, mean=True),
    method='EAP-IG-activations',
    intervention='zero',
    ig_steps=5,
)
g3.apply_greedy(250)

circuit_scores = evaluate_graph(
    model,
    g3,
    dataloader,
    partial(metric_fn, loss=False, mean=False),
    intervention='zero',
)
results = circuit_scores.mean().item()
print(f"Faithfulness: {results / baseline}")

  0%|          | 0/8 [02:25<?, ?it/s]


KeyboardInterrupt: 

What about the reverse—zero attribution but patching evaluation?

In [None]:
# Reverse setting: `g3` was scored with zero ablations, and is evaluated here
# with patching. The call therefore passes no `intervention` argument, exactly
# like the patching evaluations of `g` and `g2` earlier in this notebook;
# passing intervention='zero' would only repeat the zero-ablation evaluation
# already done when `g3` was built.
# Requires the cell that creates and scores `g3` to have completed first.
g3.apply_greedy(250)
results = evaluate_graph(
    model,
    g3,
    dataloader,
    partial(metric_fn, loss=False, mean=False),
).mean().item()
# Baseline is recomputed here so this cell's ratio does not depend on a stale value.
baseline = evaluate_baseline(
    model,
    dataloader,
    partial(metric_fn, loss=False, mean=False),
).mean().item()
print(f"Faithfulness: {results / baseline}")

NameError: name 'g3' is not defined

Now, let's do some mean ablations! Here's a generic dataset:

In [11]:
# Generic text corpus used as the source of activations for mean ablations.
mean_dataset = load_dataset('stas/openwebtext-10k', split='train')

# Only the first 1000 documents are used, served in batches of 16.
intervention_texts = mean_dataset['text'][:1000]
intervention_dataloader = DataLoader(intervention_texts, batch_size=16)

Let's do mean ablations on a graph scored with patching ablations

In [None]:
# Evaluate the patching-scored graph `g` under mean ablations, with the means
# taken from the generic-text dataloader.
circuit_scores = evaluate_graph(
    model,
    g,
    dataloader,
    partial(metric_fn, loss=False, mean=False),
    prune=True,
    intervention='mean',
    intervention_dataloader=intervention_dataloader,
)
results = circuit_scores.mean().item()
print(f"Faithfulness: {results / baseline}")

Computing mean: 100%|██████████| 63/63 [00:21<00:00,  2.87it/s]
100%|██████████| 8/8 [00:01<00:00,  7.04it/s]

Faithfulness: -2.112404321439537e-05





And now on a graph scored with mean ablations

In [12]:
# Mean ablations used consistently: both for scoring the graph and for
# evaluating the resulting circuit, with the same intervention dataloader.
g4 = Graph.from_model(model)
attribute(
    model,
    g4,
    dataloader,
    partial(metric_fn, loss=True, mean=True),
    method='EAP-IG-activations',
    intervention='mean',
    ig_steps=5,
    intervention_dataloader=intervention_dataloader,
)
g4.apply_greedy(250)

circuit_scores = evaluate_graph(
    model,
    g4,
    dataloader,
    partial(metric_fn, loss=False, mean=False),
    prune=True,
    intervention='mean',
    intervention_dataloader=intervention_dataloader,
)
results = circuit_scores.mean().item()
print(f"Faithfulness: {results / baseline}")

Computing mean:   0%|          | 0/63 [00:01<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 18.00 GiB. GPU 0 has a total capacity of 79.15 GiB of which 11.01 GiB is free. Process 964633 has 5.01 GiB memory in use. Including non-PyTorch memory, this process has 63.12 GiB memory in use. Of the allocated memory 44.06 GiB is allocated by PyTorch, and 18.56 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

What if we do positional ablations, which preserve task information? We can do this using either the original dataset (but note that in the greater-than task, this leads to positional mean ablations being a less-destructive ablation than patching) or just the corrupted version of the dataset. Let's start with the original:

In [None]:
# Positional mean ablation of graph `g`; the task dataloader itself is the
# source of the means (the "original dataset" variant).
circuit_scores = evaluate_graph(
    model,
    g,
    dataloader,
    partial(metric_fn, loss=False, mean=False),
    prune=True,
    intervention='mean-positional',
    intervention_dataloader=dataloader,
)
results = circuit_scores.mean().item()
print(f"Faithfulness: {results / baseline}")

Computing mean: 100%|██████████| 8/8 [00:01<00:00,  5.78it/s]
100%|██████████| 8/8 [00:01<00:00,  7.10it/s]

Faithfulness: 0.9017045250395437





In [None]:
# Score and evaluate a graph with positional mean ablations, taking the means
# from the original task dataloader in both stages.
g5 = Graph.from_model(model)
attribute(
    model,
    g5,
    dataloader,
    partial(metric_fn, loss=True, mean=True),
    method='EAP-IG-activations',
    intervention='mean-positional',
    ig_steps=5,
    intervention_dataloader=dataloader,
)
g5.apply_greedy(250)

circuit_scores = evaluate_graph(
    model,
    g5,
    dataloader,
    partial(metric_fn, loss=False, mean=False),
    prune=True,
    intervention='mean-positional',
    intervention_dataloader=dataloader,
)
results = circuit_scores.mean().item()
print(f"Faithfulness: {results / baseline}")

Computing mean: 100%|██████████| 8/8 [00:02<00:00,  3.61it/s]
100%|██████████| 8/8 [02:05<00:00, 15.74s/it]
100%|██████████| 32491/32491 [00:00<00:00, 202402.10it/s]
Computing mean: 100%|██████████| 8/8 [00:01<00:00,  5.73it/s]
100%|██████████| 8/8 [00:00<00:00,  8.00it/s]

Faithfulness: 0.9145310149099368





And then try the corrupted:

In [None]:
# Build a variant of the task dataset whose 'clean' column holds the corrupted
# prompts, so that a dataloader over it serves corrupted inputs wherever the
# clean ones would normally be read. Working on a deep copy keeps the original
# `dataset` (and the `dataloader` built from it) untouched.
# `deepcopy` comes from the notebook's top import cell.
corrupted_dataset = deepcopy(dataset)
corrupted_dataset.df['clean'] = corrupted_dataset.df['corrupted']
# Batch size 16 here, versus 64 for the main task dataloader.
corrupted_dataloader = corrupted_dataset.to_dataloader(16)

In [None]:
# Positional mean ablation of graph `g`, with the means now taken from the
# corrupted-prompt dataloader instead of the original task data.
circuit_scores = evaluate_graph(
    model,
    g,
    dataloader,
    partial(metric_fn, loss=False, mean=False),
    prune=True,
    intervention='mean-positional',
    intervention_dataloader=corrupted_dataloader,
)
results = circuit_scores.mean().item()
print(f"Faithfulness: {results / baseline}")

Computing mean: 100%|██████████| 32/32 [00:05<00:00,  5.60it/s]
100%|██████████| 8/8 [00:01<00:00,  7.05it/s]

Faithfulness: 0.7456998058356434





In [None]:
# Score and evaluate a graph with positional mean ablations, taking the means
# from the corrupted-prompt dataloader in both stages.
g6 = Graph.from_model(model)
attribute(
    model,
    g6,
    dataloader,
    partial(metric_fn, loss=True, mean=True),
    method='EAP-IG-activations',
    intervention='mean-positional',
    ig_steps=5,
    intervention_dataloader=corrupted_dataloader,
)
g6.apply_greedy(250)

circuit_scores = evaluate_graph(
    model,
    g6,
    dataloader,
    partial(metric_fn, loss=False, mean=False),
    prune=True,
    intervention='mean-positional',
    intervention_dataloader=corrupted_dataloader,
)
results = circuit_scores.mean().item()
print(f"Faithfulness: {results / baseline}")

Computing mean: 100%|██████████| 32/32 [00:05<00:00,  5.47it/s]
100%|██████████| 8/8 [02:07<00:00, 15.88s/it]
100%|██████████| 32491/32491 [00:00<00:00, 199661.17it/s]
Computing mean: 100%|██████████| 32/32 [00:05<00:00,  5.53it/s]
100%|██████████| 8/8 [00:01<00:00,  7.56it/s]

Faithfulness: 0.7835840997563741



