In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
from functools import partial

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import PreTrainedTokenizer, AutoTokenizer
from transformer_lens import HookedTransformer
import torch.nn.functional as F

In [3]:
from eap.graph import Graph
from eap.evaluate import evaluate_graph, evaluate_baseline
from eap.attribute import attribute 

In [4]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

def load_adapter_into_hooked_transformer(adapter_path, 
                                         hf_model_name, 
                                         translens_model_name, 
                                         adapter = True, 
                                         scratch_cache_dir = None):
    if adapter == True: 
        base_model = AutoModelForCausalLM.from_pretrained(hf_model_name, cache_dir=scratch_cache_dir)
        model_with_lora = PeftModel.from_pretrained(base_model, adapter_path)
        adapter_model = model_with_lora.merge_and_unload()
        model = HookedTransformer.from_pretrained(model_name=translens_model_name, hf_model=adapter_model, cache_dir=scratch_cache_dir)  
    else:
        model = HookedTransformer.from_pretrained(model_name=translens_model_name, cache_dir=scratch_cache_dir)  

    model.cfg.use_split_qkv_input = True
    model.cfg.use_attn_result = True
    model.cfg.use_hook_mlp_in = True
    return model

In [5]:
def collate_EAP(xs):
    clean, corrupted, labels = zip(*xs)
    clean = list(clean)
    corrupted = list(corrupted)
    return clean, corrupted, labels

class EAPDataset(Dataset):
    def __init__(self, filepath):
        self.df = pd.read_csv(filepath)

    def __len__(self):
        return len(self.df)
    
    def shuffle(self):
        self.df = self.df.sample(frac=1)

    def head(self, n: int):
        self.df = self.df.head(n)
    
    def __getitem__(self, index):
        row = self.df.iloc[index]
        return row['clean'], row['corrupted'], row['label']
    
    def to_dataloader(self, batch_size: int):
        return DataLoader(self, batch_size=batch_size, collate_fn=collate_EAP)
    
def get_logit_positions(logits: torch.Tensor, input_length: torch.Tensor):
    batch_size = logits.size(0)
    idx = torch.arange(batch_size, device=logits.device)

    logits = logits[idx, input_length - 1]
    return logits

def kl_divergence(logits: torch.Tensor, clean_logits: torch.Tensor, input_length: torch.Tensor, labels: torch.Tensor, mean=True, loss=True):
    logits = get_logit_positions(logits, input_length)
    clean_logits = get_logit_positions(clean_logits, input_length)

    probs = torch.softmax(logits, dim=-1)
    clean_probs = torch.softmax(clean_logits, dim=-1)
    results = F.kl_div(probs.log(), clean_probs.log(), log_target=True, reduction='none').mean(-1)
    return results.mean() if mean else results

In [6]:
scratch_cache_dir = "/mnt/faster0/rje41/.cache/huggingface"    
hf_model_name = "EleutherAI/pythia-1.4B-deduped"
translens_model_name="pythia-1.4B-deduped"
adapter_path = "/homes/rje41/mech-interp-ft/circuit-analysis/prompt_variation/results/checkpoint-300/prompt_template_0/graph.json"
model = load_adapter_into_hooked_transformer(
        adapter_path=None,
        hf_model_name=hf_model_name,
        translens_model_name=translens_model_name,
        scratch_cache_dir=scratch_cache_dir,
        adapter = False
    )

Loaded pretrained model pythia-1.4B-deduped into HookedTransformer


In [7]:
ds = EAPDataset('AddSub_corrupt_swap/datasets_csv/Add_Sub_100_ft (2).csv')
dataloader = ds.to_dataloader(6)

In [8]:
def kl_divergence(logits: torch.Tensor, clean_logits: torch.Tensor, input_length: torch.Tensor, labels: torch.Tensor, mean=True, loss=True):
    logits = get_logit_positions(logits, input_length)
    clean_logits = get_logit_positions(clean_logits, input_length)

    probs = torch.softmax(logits, dim=-1)
    clean_probs = torch.softmax(clean_logits, dim=-1)
    results = F.kl_div(probs.log(), clean_probs.log(), log_target=True, reduction='none').mean(-1)
    return results.mean() if mean else results
g = Graph.from_model(model)
attribute(model, g, dataloader, partial(kl_divergence, loss=True, mean=True), method='EAP-IG-inputs', ig_steps=5)

  0%|          | 0/750 [00:01<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 1.88 GiB. GPU 0 has a total capacity of 23.57 GiB of which 1.46 GiB is free. Process 2507595 has 2.06 GiB memory in use. Process 2856799 has 4.32 GiB memory in use. Including non-PyTorch memory, this process has 15.71 GiB memory in use. Of the allocated memory 15.02 GiB is allocated by PyTorch, and 390.46 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
def calculate_faithfulness(model, g, dataloader, metric_fn):
    baseline_performance = evaluate_baseline(model, dataloader, metric_fn).mean().item()
    circuit_performance = evaluate_graph(model, g, dataloader, metric_fn,skip_clean=False).mean().item()
    faithfulness = abs(baseline_performance - circuit_performance)
    percentage_performance = (1 - faithfulness / baseline_performance) * 100

    print(f"Baseline performance: {baseline_performance}")
    print(f"Circuit performance: {circuit_performance}")
    print(f"Faithfulness: {faithfulness}")
    print(f"Percentage of model performance achieved by the circuit: {percentage_performance:.2f}%")

    return baseline_performance, circuit_performance, faithfulness, percentage_performance

In [None]:
total_edges = len(g.edges)
percent_edges = int(total_edges * 0.05)
g.apply_topn(percent_edges , absolute=True, prune=True)
print('Pruned the graph!')

metric_fn = partial(kl_divergence, loss=False, mean=False)
baseline_performance, circuit_performance, faithfulness, percentage_performance = calculate_faithfulness(model, g, dataloader, metric_fn)

Pruned the graph!


100%|██████████| 167/167 [00:57<00:00,  2.91it/s]
100%|██████████| 167/167 [01:41<00:00,  1.64it/s]

Baseline performance: 3.4469907404854894e-05
Circuit performance: 2.936303280876018e-05
Faithfulness: 5.106874596094713e-06
Percentage of model performance achieved by the circuit: 85.18%



