In [1]:
%load_ext autoreload
%autoreload 2

import sys
import torch
import math
import copy
sys.path.append('EAP-IG/src')

from functools import partial

from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from datasets import load_dataset

from eap.graph import Graph
from eap.attribute import attribute
from eap.attribute_node import attribute_node
from eap.evaluate import evaluate_graph, evaluate_baseline
from MIB_circuit_track.dataset import HFEAPDataset
from MIB_circuit_track.metrics import get_metric
from MIB_circuit_track.utils import MODEL_NAME_TO_FULLNAME, TASKS_TO_HF_NAMES
from MIB_circuit_track.evaluation import evaluate_area_under_curve

  from .autonotebook import tqdm as notebook_tqdm


## Loading model and data

In [7]:
# Model alias, resolved directly to the full model name used for loading.
# Aliases: "gpt2", "llama3", "gemma2", "qwen2.5".
model_name = MODEL_NAME_TO_FULLNAME["qwen2.5"]

# MIB task. Choices: "ioi", "mcqa", "arithmetic_addition",
# "arithmetic_subtraction", "arc_easy", "arc_challenge".
task_name = "mcqa"
dataset_name = "mib-bench/{}".format(TASKS_TO_HF_NAMES[task_name])

In [3]:
model = HookedTransformer.from_pretrained(
    model_name,
    attn_implementation="eager",
    torch_dtype=torch.bfloat16,
    device="cuda",
)

# Extra TransformerLens hook points; presumably needed by the EAP-IG graph
# (separate q/k/v inputs, per-head attention results, MLP-input hooks,
# un-grouped GQA heads) -- confirm against EAP-IG's requirements.
model.cfg.use_split_qkv_input = True
model.cfg.use_attn_result = True
model.cfg.use_hook_mlp_in = True
model.cfg.ungroup_grouped_query_attention = True

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.


Loaded pretrained model Qwen/Qwen2.5-0.5B into HookedTransformer


In [8]:
dataset = HFEAPDataset(dataset_name, model.tokenizer, task=task_name)
# Return value is unused, so this presumably truncates the dataset in place
# to its first 500 examples -- confirm against HFEAPDataset.head.
dataset.head(500)
dataloader = dataset.to_dataloader(batch_size=64)

# Logit-difference metric; the placeholder `task` value suggests get_metric
# does not use it for this metric.
metric_fn = get_metric(
    metric_name="logit_diff",
    task=["ignore this"],
    tokenizer=model.tokenizer,
    model=model,
)

# Full-model reference: per-example metric, averaged over the dataset.
baseline = evaluate_baseline(
    model,
    dataloader,
    partial(metric_fn, loss=False, mean=False),
).mean().item()

Generating train split: 100%|██████████| 110/110 [00:00<00:00, 336.58 examples/s]
Generating validation split: 100%|██████████| 50/50 [00:00<00:00, 3091.14 examples/s]
Generating test split: 100%|██████████| 50/50 [00:00<00:00, 3166.32 examples/s]
Filter: 100%|██████████| 110/110 [00:00<00:00, 1097.19 examples/s]




100%|██████████| 2/2 [00:08<00:00,  4.33s/it]


## Attribution, Pruning, and Evaluation

In [13]:
# Build two copies of the model's computational graph and score them:
# edge-level attribution goes into g_edges, node-level attribution into g_nodes.
g_edges = Graph.from_model(model)  # edge-based attribution
g_nodes = Graph.from_model(model)  # node-based attribution
print(len(g_edges.nodes), len(g_edges.edges))

# Both use the 'EAP-IG-inputs' method with 5 integrated-gradient steps and the
# mean metric in loss mode.
attribute(
    model, g_edges, dataloader,
    partial(metric_fn, loss=True, mean=True),
    method='EAP-IG-inputs', ig_steps=5,
)
attribute_node(
    model, g_nodes, dataloader,
    partial(metric_fn, loss=True, mean=True),
    method='EAP-IG-inputs', ig_steps=5,
)

362 179749


100%|██████████| 2/2 [00:36<00:00, 18.23s/it]
100%|██████████| 2/2 [00:33<00:00, 16.85s/it]


In [11]:
# Build a circuit from the attribution scores computed above.
# Relevant Graph attributes:
#   graph.in_graph / graph.nodes_in_graph: masks of the edges / nodes currently in the circuit;
#     these must be updated to reflect the selected edges or nodes.
#   graph.scores / graph.nodes_scores: edge / node attribution scores. `scores` is non-zero only
#     after edge attribution, and `nodes_scores` only after node attribution.

def build_graph_from_attribution_scores(g_edges, g_nodes, n_edges=250, n_nodes=45):
    """Prune the attributed graphs to a fixed-size circuit.

    Keeps ``n_edges`` edges of ``g_edges`` via ``apply_greedy`` (ranked by
    absolute score), and separately keeps ``n_nodes`` nodes of ``g_nodes`` via
    ``apply_node_topn`` (also by absolute score).

    NOTE: only ``g_edges`` is returned. The node selection on ``g_nodes`` is
    applied to that object but does not influence the returned circuit. This
    is a placeholder and should be replaced by something smarter (e.g. a graph
    algorithm that combines edge and node scores).

    Args:
        g_edges: Graph carrying edge attribution scores; pruned and returned.
        g_nodes: Graph carrying node attribution scores; pruned as a side effect.
        n_edges: number of edges to keep (default 250, the previous hard-coded value).
        n_nodes: number of nodes to keep in ``g_nodes`` (default 45, the previous
            hard-coded value).

    Returns:
        ``g_edges`` after pruning.
    """
    g_edges.apply_greedy(n_edges=n_edges, absolute=True, reset=True, prune=True)
    g_nodes.apply_node_topn(n_nodes=n_nodes, absolute=True, reset=True, prune=True)
    return g_edges

# Circuit used by the single-point faithfulness check below.
graph = build_graph_from_attribution_scores(g_edges=g_edges, g_nodes=g_nodes)

In [12]:
# Evaluate the pruned circuit and report its faithfulness: the circuit's mean
# metric divided by the full model's (`baseline`). A value of 1 means the
# circuit matches the full model on this metric.
results = evaluate_graph(
    model,
    graph,
    dataloader,
    partial(metric_fn, loss=False, mean=False),
).mean().item()
print(f"Faithfulness: {results / baseline}")

100%|██████████| 2/2 [00:04<00:00,  2.12s/it]

Faithfulness: -0.23995535714285715





In [18]:
# Like the two cells above, but sweeps the fraction of edges kept (`percentages`),
# computes faithfulness at each point, and summarizes the resulting curve by its
# AUC (area under the faithfulness curve). This AUC is the score that is ultimately evaluated.
 
from einops import einsum
import networkx as nx

# Fractions of all graph edges to keep, from the empty circuit (0) up to half the edges.
percentages = [
    0,
    0.001, 0.002, 0.005,
    0.01, 0.02, 0.05,
    0.1, 0.2, 0.5,
]

def auc(faithfulnesses, percentages, log_scale=False):
    """Trapezoidal areas of a faithfulness-vs-percentage curve.

    Args:
        faithfulnesses: faithfulness score at each x position.
        percentages: x positions (e.g. fraction of edges kept). Must have the
            same length as ``faithfulnesses``; presumably sorted ascending.
        log_scale: integrate over ``log(x)`` instead of ``x``. Every x must then
            be > 0, so a sweep that starts at 0 cannot use it.

    Returns:
        Tuple ``(area_under, area_from_1, average)``:
          - area_under: trapezoidal integral of the faithfulness curve.
          - area_from_1: trapezoidal integral of ``|1 - faithfulness|``, i.e. the
            distance from a perfect faithfulness of 1.
          - average: unweighted mean of ``faithfulnesses``.

    Raises:
        ValueError: if the two sequences differ in length, are empty, or
            ``log_scale`` is requested with a non-positive x.
    """
    if len(faithfulnesses) != len(percentages):
        raise ValueError(
            f"faithfulnesses ({len(faithfulnesses)}) and percentages "
            f"({len(percentages)}) must have the same length"
        )
    if len(faithfulnesses) == 0:
        raise ValueError("auc() needs at least one point")

    if log_scale:
        if any(x <= 0 for x in percentages):
            raise ValueError(
                "log_scale=True requires all percentages > 0 (log(0) is undefined); "
                "drop the 0 point or use log_scale=False"
            )
        xs = [math.log(x) for x in percentages]
    else:
        xs = list(percentages)

    area_under = 0.
    area_from_1 = 0.
    for i in range(len(xs) - 1):
        width = xs[i + 1] - xs[i]
        f_1, f_2 = faithfulnesses[i], faithfulnesses[i + 1]
        # Trapezoid of the distance from perfect faithfulness (1.0).
        area_from_1 += width * ((abs(1. - f_1) + abs(1. - f_2)) / 2)
        # Trapezoid of the faithfulness curve itself.
        area_under += width * ((f_1 + f_2) / 2)
    average = sum(faithfulnesses) / len(faithfulnesses)

    return area_under, area_from_1, average


def build_greedy_graph(g_edges, edge_percent):
    """Keep ``edge_percent`` of all edges, selected greedily by absolute score.

    The graph is reset and pruned via ``apply_greedy`` and the same object is
    returned.
    """
    edge_budget = int(len(g_edges.edges) * edge_percent)
    g_edges.apply_greedy(
        n_edges=edge_budget,
        absolute=True,
        reset=True,
        prune=True,
    )
    return g_edges

def build_topn_graph(g_edges, edge_percent):
    """Keep the top ``edge_percent`` fraction of edges by absolute score.

    Delegates to ``apply_topn`` with reset and prune enabled; the same graph
    object is returned.
    """
    g_edges.apply_topn(
        n=int(len(g_edges.edges) * edge_percent),
        absolute=True,
        reset=True,
        prune=True,
    )
    return g_edges

def build_graph_from_edges_and_nodes(g_edges, g_nodes, edge_percent, node_percent):
    """Combine edge and node attribution into one circuit.

    Steps:
      1. Greedily keep ``edge_percent`` of the edges of ``g_edges`` (absolute score).
      2. Keep ``node_percent`` of the nodes of ``g_nodes`` (absolute score), with
         the count clamped to ``[0, len(g_nodes.nodes) - 1]``.
      3. Copy ``g_nodes``' node mask into ``g_edges`` and zero out the edges going
         out of every node that node attribution dropped.
      4. Call ``g_edges.prune()`` again now that the edge mask has changed.

    Both graphs are modified in place; ``g_edges`` is returned.
    """
    n_edges = int(len(g_edges.edges) * edge_percent)
    # NOTE(review): the upper clamp of len(nodes) - 1 presumably excludes one node
    # (e.g. the logits node) that node selection cannot count -- confirm against
    # Graph.apply_node_topn.
    n_nodes = min(len(g_nodes.nodes) - 1, max(0, int(len(g_nodes.nodes) * node_percent)))
    g_edges.apply_greedy(n_edges=n_edges, absolute=True, reset=True, prune=True)

    # Prune nodes based on the node-attribution graph.
    g_nodes.apply_node_topn(n_nodes=n_nodes, absolute=True, reset=True, prune=True)

    # Take only nodes that node attribution also chose. Copy the mask so the two
    # graphs do not share one tensor.
    g_edges.nodes_in_graph = copy.deepcopy(g_nodes.nodes_in_graph)
    # Zero out edges going out of nodes that were pruned out (rows of in_graph
    # are indexed like nodes_in_graph).
    g_edges.in_graph[~g_edges.nodes_in_graph] = 0

    g_edges.prune()
    return g_edges

def build_topn_graph_with_forced_connectivity(g_edges, edge_percent, verbose=False):
    """Top-n edge selection, then patch in edges so every kept node is connected.

    A node is kept if it has an incoming *or* an outgoing selected edge; per the
    inline note below, this differs from ``apply_topn``. Then:
      - every kept node with no path to ``logits`` gets a direct ``node->logits`` edge;
      - every kept node with no path from ``input`` gets a direct edge from
        ``input``. For nodes with q/k/v inputs, only the highest-scoring of the
        three ``input->node<q|k|v>`` edges is added.
    Finally the graph is pruned.

    Args:
        g_edges: attributed Graph; reset, modified in place and returned.
        edge_percent: fraction of all edges to select before patching.
        verbose: if True, print the selected-edge count after top-n selection,
            after patching, and after pruning. These prints were previously
            unconditional debug output.

    Returns:
        ``g_edges``.
    """
    g_edges.reset()
    g_edges.apply_topn(n=int(len(g_edges.edges) * edge_percent), absolute=True, reset=False, prune=False)

    # Rows of in_graph are source (forward) nodes: a node has an outgoing edge if
    # any entry in its row is selected.
    nodes_with_outgoing = g_edges.in_graph.any(dim=1)
    # Columns are destination (backward) slots; forward_to_backward maps each
    # forward node to its backward slots, so this marks nodes receiving any edge.
    nodes_with_ingoing = einsum(g_edges.in_graph.any(dim=0).float(), g_edges.forward_to_backward.float(), 'backward, forward backward -> forward') > 0
    # Index 0 is presumably the input node, which never has incoming edges -- confirm.
    nodes_with_ingoing[0] = True
    # Note the difference from the code in apply_topn - here we add nodes if they have an incoming / outgoing edge
    g_edges.nodes_in_graph += nodes_with_outgoing | nodes_with_ingoing

    # NOTE(review): this snapshot is not updated as edges are added below, so a
    # node may get a direct edge even if an edge added for an earlier node would
    # already have connected it.
    nx_graph = g_edges.to_networkx(check_in_graph=True)

    if verbose:
        print(f"1, {g_edges.in_graph.sum()}")
    # Ensure connectivity of all chosen nodes.
    for node in g_edges.nodes.values():
        if not node.in_graph:
            continue
        if not nx.has_path(nx_graph, node.name, 'logits'):
            g_edges.edges[f'{node.name}->logits'].in_graph = True
        if not nx.has_path(nx_graph, 'input', node.name):
            if node.qkv_inputs:
                # NOTE(review): picks by raw score, whereas the top-n selection
                # above uses absolute scores -- confirm this is intended.
                candidates = [g_edges.edges[f'input->{node.name}<{letter}>'] for letter in 'qkv']
                max(candidates, key=lambda e: e.score).in_graph = True
            else:
                g_edges.edges[f'input->{node.name}'].in_graph = True
    if verbose:
        print(f"2, {g_edges.in_graph.sum()}")

    g_edges.prune()
    if verbose:
        print(f"3, {g_edges.in_graph.sum()}")
    return g_edges


# Score every edge once (EAP-IG-inputs, 5 IG steps). This replaces the g_edges
# built in the earlier attribution cell.
g_edges = Graph.from_model(model)
attribute(
    model, g_edges, dataloader,
    partial(metric_fn, loss=True, mean=True),
    method='EAP-IG-inputs', ig_steps=5,
)

# Sweep the edge budget. Alternative circuit builders defined above:
#   build_greedy_graph(g_edges, edge_percent)
#   build_graph_from_edges_and_nodes(g_edges, g_nodes, edge_percent, node_percent)
#       -- needs a node-attributed g_nodes (attribute_node), e.g. node_percent = edge_percent
#   build_topn_graph_with_forced_connectivity(g_edges, edge_percent)
faithfulnesses = []
for edge_percent in percentages:
    graph = build_topn_graph(g_edges, edge_percent)
    faith = evaluate_graph(
        model, graph, dataloader,
        partial(metric_fn, loss=False, mean=False),
        quiet=True,
    ).mean().item() / baseline
    faithfulnesses.append(faith)
    print(edge_percent, faith)

# Close the curve at 100% of edges with faithfulness 1.0 (assumes the full graph
# reproduces the full model) and integrate on a linear x-axis.
print('AUC (Higher is better): ', auc(faithfulnesses + [1.0], percentages + [1.0], log_scale=False)[0])

# Logged results (GPT2, IOI):
#   Greedy edge-based AUC: 0.9687
#   Edge-based + pruning nodes based on node scores (node_percent=edge_percent): 0.5644
#   Edge-based topn: 0.9666
#   Edge-based topn + forced connectivity: 0.9666
#
# Logged results (QWEN2.5-0.5B, MCQA):
#   Greedy edge-based AUC: 0.8602
#   Edge-based + pruning nodes based on node scores (node_percent=edge_percent): 0.5618
#   Edge-based topn: 0.8630
#   Edge-based topn + forced connectivity: 0.8633

100%|██████████| 2/2 [00:36<00:00, 18.20s/it]


0 -0.8125
0.001 -0.21875
0.002 -0.16964285714285715
0.005 0.10770089285714286
0.01 0.17410714285714285
0.02 0.8214285714285714
0.05 0.2734375
0.1 0.4888392857142857
0.2 0.8392857142857143
0.5 0.9508928571428571
AUC (Higher is better):  0.863015625
