In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
from functools import partial

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import PreTrainedTokenizer, AutoTokenizer
from transformer_lens import HookedTransformer
import torch.nn.functional as F
import json
from pathlib import Path

In [3]:
from eap.graph import Graph
from eap.evaluate import evaluate_graph, evaluate_baseline
from eap.attribute import attribute 

In [4]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

def load_model(adapter_path, hf_model_name, translens_model_name, scratch_cache_dir = None):
    base_model = AutoModelForCausalLM.from_pretrained(hf_model_name, cache_dir=scratch_cache_dir)
    model_with_lora = PeftModel.from_pretrained(base_model, adapter_path)
    model_with_lora = model_with_lora.merge_and_unload()
    model = HookedTransformer.from_pretrained(model_name=translens_model_name, hf_model=model_with_lora, cache_dir=scratch_cache_dir)  

    model.cfg.use_split_qkv_input = True
    model.cfg.use_attn_result = True
    model.cfg.use_hook_mlp_in = True
    model.cfg.ungroup_grouped_query_attention = True
    return model

In [5]:
def collate_EAP(xs):
    clean, corrupted, labels = zip(*xs)
    clean = list(clean)
    corrupted = list(corrupted)
    return clean, corrupted, labels

class EAPDataset(Dataset):
    def __init__(self, filepath):
        self.df = pd.read_csv(filepath)

    def __len__(self):
        return len(self.df)
    
    def shuffle(self):
        self.df = self.df.sample(frac=1)

    def head(self, n: int):
        self.df = self.df.head(n)
    
    def __getitem__(self, index):
        row = self.df.iloc[index]
        return row['clean'], row['corrupted'], row['answer']
    
    def to_dataloader(self, batch_size: int):
        return DataLoader(self, batch_size=batch_size, collate_fn=collate_EAP)
    
def get_logit_positions(logits: torch.Tensor, input_length: torch.Tensor):
    batch_size = logits.size(0)
    idx = torch.arange(batch_size, device=logits.device)

    logits = logits[idx, input_length - 1]
    return logits

def kl_divergence(logits: torch.Tensor, clean_logits: torch.Tensor, input_length: torch.Tensor, labels: torch.Tensor, mean=True, loss=True):
    logits = get_logit_positions(logits, input_length)
    clean_logits = get_logit_positions(clean_logits, input_length)

    probs = torch.softmax(logits, dim=-1)
    clean_probs = torch.softmax(clean_logits, dim=-1)
    results = F.kl_div(probs.log(), clean_probs.log(), log_target=True, reduction='none').mean(-1)
    return results.mean() if mean else results

In [6]:
def calculate_faithfulness(model, g, dataloader, metric_fn):
    baseline_performance = evaluate_baseline(model, dataloader, metric_fn).mean().item()
    circuit_performance = evaluate_graph(model, g, dataloader, metric_fn,skip_clean=False).mean().item()
    faithfulness = abs(baseline_performance - circuit_performance)
    percentage_performance = (1 - faithfulness / baseline_performance) * 100

    print(f"Baseline performance: {baseline_performance}")
    print(f"Circuit performance: {circuit_performance}")
    print(f"Faithfulness: {faithfulness}")
    print(f"Percentage of model performance achieved by the circuit: {percentage_performance:.2f}%")

    return faithfulness, percentage_performance

def exact_match_accuracy(model, logits, corrupted_logits, input_lengths, labels):
    batch_size = logits.size(0)
    device = logits.device
    positions = input_lengths - 1

    last_logits = logits[torch.arange(batch_size), positions, :]
    predicted_tokens = last_logits.argmax(dim=-1)
    predicted_strings = [model.to_string(token.item()).strip() for token in predicted_tokens]

    labels_strings = []
    for i in range(batch_size):
        lab = labels[i]
        if isinstance(lab, torch.Tensor):
            lab = lab.item()
        labels_strings.append(str(lab).strip())

    correct = []
    for pred_str, label_str in zip(predicted_strings, labels_strings):
        print(f'pred_str:{pred_str}, label_str:{label_str}')
        if pred_str == label_str:
            correct.append(1.0)
        else:
            correct.append(0.0)

    return torch.tensor(correct, device=device)

def calculate_accuracy(model, g, dataloader):
    baseline_accuracy = evaluate_baseline(model, dataloader, partial(exact_match_accuracy, model)).mean().item()
    graph_accuracy = evaluate_graph(model, g, dataloader, partial(exact_match_accuracy, model)).mean().item()   
    return baseline_accuracy, graph_accuracy

def exact_match_accuracy_diff(model, logits, corrupted_logits, input_lengths, labels, loss, mean):
    batch_size = logits.size(0)
    device = logits.device
    positions = input_lengths - 1

    # Helper to get predicted strings from logits
    def get_predicted_strings(logits_tensor):
        last_logits = logits_tensor[torch.arange(batch_size), positions, :]
        predicted_tokens = last_logits.argmax(dim=-1)
        return [model.to_string(token.item()).strip() for token in predicted_tokens]

    # Get predicted strings for clean and corrupted
    predicted_clean = get_predicted_strings(logits)
    predicted_corrupted = get_predicted_strings(corrupted_logits)

    # Convert labels to strings (same as your original)
    labels_strings = []
    for i in range(batch_size):
        lab = labels[i]
        if isinstance(lab, torch.Tensor):
            lab = lab.item()
        labels_strings.append(str(lab).strip())

    # Compute exact match (1 or 0) for clean and corrupted
    clean_correct = torch.tensor(
        [1.0 if p == l else 0.0 for p, l in zip(predicted_clean, labels_strings)],
        device=device
    )
    corrupted_correct = torch.tensor(
        [1.0 if p == l else 0.0 for p, l in zip(predicted_corrupted, labels_strings)],
        device=device
    )

    # Accuracy difference per example (corrupted - clean)
    acc_diff = corrupted_correct - clean_correct

    if mean:
        return acc_diff.mean()
    else:
        return acc_diff


In [7]:
scratch_cache_dir = "/mnt/faster0/rje41/.cache/huggingface"   
hf_model_name = "EleutherAI/pythia-1.4B-deduped"
translens_model_name="pythia-1.4B-deduped"
model = load_model(
        adapter_path='../../initial-fine-tunng/add_sub_nlp/checkpoints/prompt_id_0/checkpoint-500',
        hf_model_name=hf_model_name,
        translens_model_name=translens_model_name,
        scratch_cache_dir=scratch_cache_dir,
    )


Loaded pretrained model pythia-1.4B-deduped into HookedTransformer


In [8]:
ds = EAPDataset(f'../../initial-fine-tunng/add_sub_nlp/datasets_csv/prompts_id_0/test.csv')
dataloader = ds.to_dataloader(6)
g = Graph.from_model(model)
attribute(model, g, dataloader, partial(exact_match_accuracy_diff, loss=True, mean=True), method='EAP-IG-inputs', ig_steps=5)

  0%|          | 0/84 [00:01<?, ?it/s]


TypeError: exact_match_accuracy_diff() missing 1 required positional argument: 'labels'