In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [2]:
from functools import partial

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import PreTrainedTokenizer, AutoTokenizer
from transformer_lens import HookedTransformer
import torch.nn.functional as F

In [3]:
from eap.graph import Graph
from eap.evaluate import evaluate_graph, evaluate_baseline
from eap.attribute import attribute 

In [4]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

def load_model(adapter_path, hf_model_name, translens_model_name, scratch_cache_dir = None):
    base_model = AutoModelForCausalLM.from_pretrained(hf_model_name, cache_dir=scratch_cache_dir)
    model_with_lora = PeftModel.from_pretrained(base_model, adapter_path)
    model_with_lora = model_with_lora.merge_and_unload()
    model = HookedTransformer.from_pretrained(model_name=translens_model_name, hf_model=model_with_lora, cache_dir=scratch_cache_dir)  

    model.cfg.use_split_qkv_input = True
    model.cfg.use_attn_result = True
    model.cfg.use_hook_mlp_in = True
    model.cfg.ungroup_grouped_query_attention = True
    return model

In [5]:
def collate_EAP(xs):
    clean, corrupted, labels = zip(*xs)
    clean = list(clean)
    corrupted = list(corrupted)
    return clean, corrupted, labels

class EAPDataset(Dataset):
    def __init__(self, filepath):
        self.df = pd.read_csv(filepath)

    def __len__(self):
        return len(self.df)
    
    def shuffle(self):
        self.df = self.df.sample(frac=1)

    def head(self, n: int):
        self.df = self.df.head(n)
    
    def __getitem__(self, index):
        row = self.df.iloc[index]
        return row['clean'], row['corrupted'], row['label']
    
    def to_dataloader(self, batch_size: int):
        return DataLoader(self, batch_size=batch_size, collate_fn=collate_EAP)
    
def get_logit_positions(logits: torch.Tensor, input_length: torch.Tensor):
    batch_size = logits.size(0)
    idx = torch.arange(batch_size, device=logits.device)

    logits = logits[idx, input_length - 1]
    return logits

def kl_divergence(logits: torch.Tensor, clean_logits: torch.Tensor, input_length: torch.Tensor, labels: torch.Tensor, mean=True, loss=True):
    logits = get_logit_positions(logits, input_length)
    clean_logits = get_logit_positions(clean_logits, input_length)

    probs = torch.softmax(logits, dim=-1)
    clean_probs = torch.softmax(clean_logits, dim=-1)
    results = F.kl_div(probs.log(), clean_probs.log(), log_target=True, reduction='none').mean(-1)
    return results.mean() if mean else results

In [6]:
scratch_cache_dir = "/mnt/fast0/rje41/.cache/huggingface"    
hf_model_name = "EleutherAI/pythia-1.4B-deduped"
translens_model_name="pythia-1.4B-deduped"
adapter_path = "../fine-tuning/sequential-fine-tuning/adapter_checkpoints/add_sub_100/checkpoint-282"
model = load_model(
        adapter_path=adapter_path,
        hf_model_name=hf_model_name,
        translens_model_name=translens_model_name,
        scratch_cache_dir=scratch_cache_dir,
    )

Loaded pretrained model pythia-1.4B-deduped into HookedTransformer


In [13]:
ds = EAPDataset('../datasets/Add_Sub_100_circuit.csv')
dataloader = ds.to_dataloader(6)

In [14]:
g = Graph.from_model(model)
attribute(model, g, dataloader, partial(kl_divergence, loss=True, mean=True), method='EAP-IG-inputs', ig_steps=5)

100%|██████████| 84/84 [05:10<00:00,  3.70s/it]


In [None]:
total_edges = len(g.edges)
five_percent_edges = int(total_edges * 0.05)
g.apply_topn(five_percent_edges , absolute=True)
g.to_json('graph.json')
g.edges

{'input': Node(input, in_graph: True),
 'a0.h0': Node(a0.h0, in_graph: True),
 'a0.h1': Node(a0.h1, in_graph: True),
 'a0.h2': Node(a0.h2, in_graph: True),
 'a0.h3': Node(a0.h3, in_graph: True),
 'a0.h4': Node(a0.h4, in_graph: True),
 'a0.h5': Node(a0.h5, in_graph: True),
 'a0.h6': Node(a0.h6, in_graph: True),
 'a0.h7': Node(a0.h7, in_graph: True),
 'a0.h8': Node(a0.h8, in_graph: True),
 'a0.h9': Node(a0.h9, in_graph: True),
 'a0.h10': Node(a0.h10, in_graph: True),
 'a0.h11': Node(a0.h11, in_graph: True),
 'a0.h12': Node(a0.h12, in_graph: True),
 'a0.h13': Node(a0.h13, in_graph: True),
 'a0.h14': Node(a0.h14, in_graph: True),
 'a0.h15': Node(a0.h15, in_graph: True),
 'm0': Node(m0, in_graph: True),
 'a1.h0': Node(a1.h0, in_graph: True),
 'a1.h1': Node(a1.h1, in_graph: True),
 'a1.h2': Node(a1.h2, in_graph: True),
 'a1.h3': Node(a1.h3, in_graph: True),
 'a1.h4': Node(a1.h4, in_graph: True),
 'a1.h5': Node(a1.h5, in_graph: True),
 'a1.h6': Node(a1.h6, in_graph: True),
 'a1.h7': Node(a1.h

In [16]:
def calculate_faithfulness(model, g, dataloader, metric_fn):
    baseline_performance = evaluate_baseline(model, dataloader, metric_fn).mean().item()
    circuit_performance = evaluate_graph(model, g, dataloader, metric_fn,skip_clean=False).mean().item()
    faithfulness = abs(baseline_performance - circuit_performance)
    percentage_performance = (1 - faithfulness / baseline_performance) * 100

    print(f"Baseline performance: {baseline_performance}")
    print(f"Circuit performance: {circuit_performance}")
    print(f"Faithfulness: {faithfulness}")
    print(f"Percentage of model performance achieved by the circuit: {percentage_performance:.2f}%")

    return faithfulness, percentage_performance

metric_fn = partial(kl_divergence, loss=False, mean=False)
faithfulness, percentage_performance = calculate_faithfulness(model, g, dataloader, metric_fn)

100%|██████████| 84/84 [00:31<00:00,  2.70it/s]
  1%|          | 1/84 [00:00<00:55,  1.49it/s]

here


  2%|▏         | 2/84 [00:01<00:55,  1.48it/s]

here


  4%|▎         | 3/84 [00:02<00:54,  1.49it/s]

here


  5%|▍         | 4/84 [00:02<00:53,  1.50it/s]

here


  6%|▌         | 5/84 [00:03<00:52,  1.49it/s]

here


  7%|▋         | 6/84 [00:04<00:52,  1.50it/s]

here


  8%|▊         | 7/84 [00:04<00:51,  1.50it/s]

here


 10%|▉         | 8/84 [00:05<00:50,  1.50it/s]

here


 11%|█         | 9/84 [00:06<00:50,  1.49it/s]

here


 12%|█▏        | 10/84 [00:06<00:49,  1.49it/s]

here


 13%|█▎        | 11/84 [00:07<00:48,  1.50it/s]

here


 14%|█▍        | 12/84 [00:08<00:48,  1.49it/s]

here


 15%|█▌        | 13/84 [00:08<00:47,  1.49it/s]

here


 17%|█▋        | 14/84 [00:09<00:47,  1.49it/s]

here


 18%|█▊        | 15/84 [00:10<00:46,  1.50it/s]

here


 19%|█▉        | 16/84 [00:10<00:45,  1.49it/s]

here


 20%|██        | 17/84 [00:11<00:44,  1.49it/s]

here


 21%|██▏       | 18/84 [00:12<00:44,  1.50it/s]

here


 23%|██▎       | 19/84 [00:12<00:43,  1.49it/s]

here


 24%|██▍       | 20/84 [00:13<00:42,  1.50it/s]

here


 25%|██▌       | 21/84 [00:14<00:42,  1.49it/s]

here


 26%|██▌       | 22/84 [00:14<00:41,  1.50it/s]

here


 27%|██▋       | 23/84 [00:15<00:40,  1.49it/s]

here


 29%|██▊       | 24/84 [00:16<00:40,  1.49it/s]

here


 30%|██▉       | 25/84 [00:16<00:39,  1.50it/s]

here


 31%|███       | 26/84 [00:17<00:38,  1.50it/s]

here


 32%|███▏      | 27/84 [00:18<00:37,  1.51it/s]

here


 33%|███▎      | 28/84 [00:18<00:37,  1.50it/s]

here


 35%|███▍      | 29/84 [00:19<00:36,  1.50it/s]

here


 36%|███▌      | 30/84 [00:20<00:36,  1.49it/s]

here


 37%|███▋      | 31/84 [00:20<00:35,  1.50it/s]

here


 38%|███▊      | 32/84 [00:21<00:34,  1.51it/s]

here


 39%|███▉      | 33/84 [00:22<00:33,  1.50it/s]

here


 40%|████      | 34/84 [00:22<00:33,  1.50it/s]

here


 42%|████▏     | 35/84 [00:23<00:32,  1.50it/s]

here


 43%|████▎     | 36/84 [00:24<00:32,  1.50it/s]

here


 44%|████▍     | 37/84 [00:24<00:31,  1.49it/s]

here


 45%|████▌     | 38/84 [00:25<00:30,  1.49it/s]

here


 46%|████▋     | 39/84 [00:26<00:30,  1.50it/s]

here


 48%|████▊     | 40/84 [00:26<00:29,  1.49it/s]

here


 49%|████▉     | 41/84 [00:27<00:28,  1.49it/s]

here


 50%|█████     | 42/84 [00:28<00:28,  1.49it/s]

here


 51%|█████     | 43/84 [00:28<00:27,  1.50it/s]

here


 52%|█████▏    | 44/84 [00:29<00:26,  1.49it/s]

here


 54%|█████▎    | 45/84 [00:30<00:25,  1.50it/s]

here


 55%|█████▍    | 46/84 [00:30<00:25,  1.50it/s]

here


 56%|█████▌    | 47/84 [00:31<00:24,  1.49it/s]

here


 57%|█████▋    | 48/84 [00:32<00:24,  1.48it/s]

here


 58%|█████▊    | 49/84 [00:32<00:23,  1.48it/s]

here


 60%|█████▉    | 50/84 [00:33<00:22,  1.49it/s]

here


 61%|██████    | 51/84 [00:34<00:22,  1.49it/s]

here


 62%|██████▏   | 52/84 [00:34<00:21,  1.49it/s]

here


 63%|██████▎   | 53/84 [00:35<00:20,  1.50it/s]

here


 64%|██████▍   | 54/84 [00:36<00:20,  1.50it/s]

here


 65%|██████▌   | 55/84 [00:36<00:19,  1.50it/s]

here


 67%|██████▋   | 56/84 [00:37<00:18,  1.50it/s]

here


 68%|██████▊   | 57/84 [00:38<00:17,  1.50it/s]

here


 69%|██████▉   | 58/84 [00:38<00:17,  1.50it/s]

here


 70%|███████   | 59/84 [00:39<00:16,  1.51it/s]

here


 71%|███████▏  | 60/84 [00:40<00:15,  1.51it/s]

here


 73%|███████▎  | 61/84 [00:40<00:15,  1.49it/s]

here


 74%|███████▍  | 62/84 [00:41<00:14,  1.50it/s]

here


 75%|███████▌  | 63/84 [00:42<00:13,  1.50it/s]

here


 76%|███████▌  | 64/84 [00:42<00:13,  1.50it/s]

here


 77%|███████▋  | 65/84 [00:43<00:12,  1.50it/s]

here


 79%|███████▊  | 66/84 [00:44<00:11,  1.50it/s]

here


 80%|███████▉  | 67/84 [00:44<00:11,  1.51it/s]

here


 81%|████████  | 68/84 [00:45<00:10,  1.50it/s]

here


 82%|████████▏ | 69/84 [00:46<00:09,  1.50it/s]

here


 83%|████████▎ | 70/84 [00:46<00:09,  1.50it/s]

here


 85%|████████▍ | 71/84 [00:47<00:08,  1.50it/s]

here


 86%|████████▌ | 72/84 [00:48<00:07,  1.50it/s]

here


 87%|████████▋ | 73/84 [00:48<00:07,  1.50it/s]

here


 88%|████████▊ | 74/84 [00:49<00:06,  1.50it/s]

here


 89%|████████▉ | 75/84 [00:50<00:05,  1.51it/s]

here


 90%|█████████ | 76/84 [00:50<00:05,  1.50it/s]

here


 92%|█████████▏| 77/84 [00:51<00:04,  1.50it/s]

here


 93%|█████████▎| 78/84 [00:52<00:03,  1.50it/s]

here


 94%|█████████▍| 79/84 [00:52<00:03,  1.51it/s]

here


 95%|█████████▌| 80/84 [00:53<00:02,  1.51it/s]

here


 96%|█████████▋| 81/84 [00:54<00:02,  1.50it/s]

here


 98%|█████████▊| 82/84 [00:54<00:01,  1.49it/s]

here


 99%|█████████▉| 83/84 [00:55<00:00,  1.49it/s]

here


100%|██████████| 84/84 [00:55<00:00,  1.51it/s]

here
Baseline performance: 0.0003850016219075769
Circuit performance: 0.00032297518919222057
Faithfulness: 6.202643271535635e-05
Percentage of model performance achieved by the circuit: 83.89%



