In [1]:
%load_ext autoreload
%autoreload 2

import sys
import torch
sys.path.append('EAP-IG/src')

from functools import partial

from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from datasets import load_dataset

from eap.graph import Graph
from eap.attribute import attribute
from eap.attribute_node import attribute_node
from eap.evaluate import evaluate_graph, evaluate_baseline
from MIB_circuit_track.dataset import HFEAPDataset
from MIB_circuit_track.metrics import get_metric
from MIB_circuit_track.utils import MODEL_NAME_TO_FULLNAME, TASKS_TO_HF_NAMES

## Loading model and data

In [2]:
# Select the model and task by short key; the lookup tables turn them into full names.
#   Model keys: "gpt2", "llama3", "gemma2", "qwen2.5"
#   Task keys:  "ioi", "mcqa", "arithmetic_addition", "arithmetic_subtraction", "arc_easy", "arc_challenge"
# The task's Hugging Face dataset is looked up under the mib-bench organisation.
model_name = MODEL_NAME_TO_FULLNAME["gpt2"]
dataset_name = f"mib-bench/{TASKS_TO_HF_NAMES['ioi']}"

In [None]:
# Load the model in bfloat16 with eager attention. Prefer the GPU, but fall back to CPU when CUDA is
# unavailable so the notebook can still run (slowly) on machines without a GPU.
model = HookedTransformer.from_pretrained(
    model_name,
    attn_implementation="eager",
    torch_dtype=torch.bfloat16,
    device="cuda" if torch.cuda.is_available() else "cpu",
)
# TransformerLens switches that expose finer-grained hook points (split q/k/v inputs, per-head
# attention results, MLP inputs) and expand grouped-query K/V heads. NOTE(review): presumably needed
# so EAP-IG's Graph.from_model / attribution can address individual heads and edges — confirm.
model.cfg.use_split_qkv_input = True
model.cfg.use_attn_result = True
model.cfg.use_hook_mlp_in = True
model.cfg.ungroup_grouped_query_attention = True

In [None]:
# Load the task data and batch it 64 examples at a time. `dataset.head(500)` is called only for its
# side effect — presumably an in-place truncation to the first 500 examples (confirm in HFEAPDataset).
dataset = HFEAPDataset(dataset_name, model.tokenizer)
dataset.head(500)
dataloader = dataset.to_dataloader(batch_size=64)

# Logit-difference metric. NOTE(review): `task` receives a placeholder list ("ignore this"), which
# suggests get_metric does not use it for logit_diff — confirm.
metric_fn = get_metric(
    metric_name="logit_diff",
    task=["ignore this"],
    tokenizer=model.tokenizer,
    model=model,
)

# Reference score of the full, unpruned model: per-example metric values (loss=False, mean=False)
# averaged here into a single Python scalar.
baseline = (
    evaluate_baseline(model, dataloader, partial(metric_fn, loss=False, mean=False))
    .mean()
    .item()
)

## Attribution, Pruning, and Evaluation

In [None]:
# Build two fresh graphs of the model and score both with the same metric and EAP-IG settings:
#   g_edges -> edge-level scores (stored in graph.scores)
#   g_nodes -> node-level scores (stored in graph.nodes_scores)
g_edges = Graph.from_model(model)
g_nodes = Graph.from_model(model)

attribute(
    model,
    g_edges,
    dataloader,
    partial(metric_fn, loss=True, mean=True),
    method='EAP-IG-inputs',
    ig_steps=5,
)
attribute_node(
    model,
    g_nodes,
    dataloader,
    partial(metric_fn, loss=True, mean=True),
    method='EAP-IG-inputs',
    ig_steps=5,
)

In [6]:
# Build the graph given the calculated scores
# Some relevant information:
#   graph.in_graph and graph.nodes_in_graph need to be updated to reflect the edges or nodes that are in the graph
#   graph.scores and graph.nodes_scores contain the scores of the edges or nodes, post attribution
#   (graph.scores is non-zero only if attribution was done on edges; likewise graph.nodes_scores for nodes)

def build_graph_from_attribution_scores(g_edges, g_nodes, n_edges=250, n_nodes=45):
    """Prune the attributed graphs to a fixed budget and return the edge graph.

    Placeholder heuristic: asks ``g_edges.apply_greedy`` to keep ``n_edges`` edges and
    ``g_nodes.apply_node_topn`` to keep ``n_nodes`` nodes, both with ``absolute=True``
    (presumably ranking by |score|), ``reset=True`` and ``prune=True``.

    Args:
        g_edges: Graph carrying edge attribution scores (``graph.scores``). Modified by the call
            (the apply_* return values are not used, so presumably in place).
        g_nodes: Graph carrying node attribution scores (``graph.nodes_scores``). Also modified by
            the call but NOT returned; read ``g_nodes`` directly afterwards if you want the
            node-based circuit.
        n_edges: Edge budget for ``g_edges`` (default 250, the previously hard-coded value).
        n_nodes: Node budget for ``g_nodes`` (default 45, the previously hard-coded value).

    Returns:
        ``g_edges`` after pruning.
    """
    # NOTE: THIS FUNCTION SHOULD BE REPLACED TO BE BASED ON SOMETHING SMARTER (GRAPH ALGORITHM ETC)
    g_edges.apply_greedy(n_edges=n_edges, absolute=True, reset=True, prune=True)
    g_nodes.apply_node_topn(n_nodes=n_nodes, absolute=True, reset=True, prune=True)
    return g_edges

# `graph` is the pruned edge-attributed graph (g_edges). The call also modifies g_nodes, but only g_edges is returned.
graph = build_graph_from_attribution_scores(g_edges, g_nodes)

In [None]:
# Evaluate the pruned circuit to find its faithfulness: the ratio of the circuit's mean metric value
# to the full model's (`baseline`). A ratio of 1.0 means the circuit's mean exactly matches the full
# model's (higher = better, per the original note).
# Fail fast on a zero baseline: the ratio is undefined, and evaluate_graph is the expensive step.
if baseline == 0:
    raise ValueError("Baseline metric is 0; faithfulness (results / baseline) is undefined.")
results = evaluate_graph(model, graph, dataloader, partial(metric_fn, loss=False, mean=False)).mean().item()
print(f"Faithfulness: {results / baseline}")