In [1]:
from functools import partial
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from datasets import load_dataset

from graph import Graph
from gt_dataset import EAPDataset
from attribute import attribute
from metrics import get_metric
from evaluate_graph import evaluate_graph, evaluate_baseline

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load GPT-2 small. Prefer the GPU, but fall back to CPU so the notebook still
# runs (slowly) on a machine without CUDA instead of failing at load time.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# Enable three optional flags on the model config.
# NOTE(review): presumably these expose the hook points that the attribution
# and evaluation code below relies on — confirm against that code.
model.cfg.use_split_qkv_input = True
model.cfg.use_attn_result = True
model.cfg.use_hook_mlp_in = True

Loaded pretrained model gpt2-small into HookedTransformer


In [3]:
# Task data and metric for the "greater-than" task.
GREATER_THAN_CSV = "greater-than-gpt2.csv"
BATCH_SIZE = 64  # argument to `to_dataloader`; presumably the batch size — confirm

dataset = EAPDataset(GREATER_THAN_CSV)
dataloader = dataset.to_dataloader(BATCH_SIZE)

# Later cells specialise `metric_fn` via `partial(metric_fn, loss=..., mean=...)`.
metric_fn = get_metric(
    "prob_diff",
    "greater-than",
    model.tokenizer,
    model,
)

In [4]:
# Build a graph from the model, run attribution over the task data, apply a
# greedy selection, then report the pruned graph's metric relative to the
# baseline and save the graph to disk.
IG_STEPS = 5         # value passed as `integrated_gradients`
GREEDY_BUDGET = 250  # argument to `apply_greedy`; presumably how many edges are kept — confirm
CIRCUIT_PATH = "circuits/greater-than_prob_diff_ig.pt"

g = Graph.from_model(model)
attribute(
    model,
    g,
    dataloader,
    partial(metric_fn, loss=True, mean=True),
    integrated_gradients=IG_STEPS,
)
g.apply_greedy(GREEDY_BUDGET)

# Both evaluations use the same metric configuration, so build it once.
eval_metric = partial(metric_fn, loss=False, mean=False)
results = evaluate_graph(model, g, dataloader, eval_metric, prune=True).mean().item()
baseline = evaluate_baseline(model, dataloader, eval_metric).mean().item()
print(f"Faithfulness: {results / baseline}")

# A fresh checkout may not contain `circuits/`; create it so the save below
# does not fail on a missing parent directory.
Path(CIRCUIT_PATH).parent.mkdir(parents=True, exist_ok=True)
g.to_pt(CIRCUIT_PATH)

100%|██████████| 16/16 [00:11<00:00,  1.42it/s]
100%|██████████| 32491/32491 [00:00<00:00, 87434.82it/s] 
100%|██████████| 16/16 [00:02<00:00,  7.55it/s]
100%|██████████| 16/16 [00:01<00:00, 14.27it/s]


Faithfulness: 0.923550694404797


In [5]:
# Load the 'train' split of the 'stas/openwebtext-10k' dataset via `load_dataset`.
# NOTE(review): the variable name suggests this is intended as the corpus for
# mean-ablation activations — confirm; the only visible use is building
# `intervention_dataloader` in the next cell.
mean_dataset = load_dataset('stas/openwebtext-10k', split='train')

In [6]:
# Wrap the first 1000 entries of the 'text' column in a DataLoader with batch_size=16.
# NOTE(review): no visible cell consumes `intervention_dataloader`; the
# 'mean-positional' evaluation further down is given the task `dataloader`
# instead — confirm whether this loader is still needed.
intervention_dataloader = DataLoader(mean_dataset['text'][:1000], batch_size=16)

In [7]:
# Evaluate the pruned graph again, this time with intervention='mean-positional',
# and report the result relative to the previously computed `baseline`.
# NOTE(review): the task `dataloader` is also supplied as `intervention_dataloader`
# (not the OpenWebText loader built earlier) — confirm this is intended.
faithfulness_metric = partial(metric_fn, loss=False, mean=False)
mean_positional_scores = evaluate_graph(
    model,
    g,
    dataloader,
    faithfulness_metric,
    prune=True,
    intervention='mean-positional',
    intervention_dataloader=dataloader,
)
results2 = mean_positional_scores.mean().item()
print(f"Faithfulness: {results2 / baseline}")

Computing mean: 100%|██████████| 16/16 [00:02<00:00,  5.67it/s]
100%|██████████| 16/16 [00:02<00:00,  7.48it/s]

Faithfulness: 0.95990070311505



