In [1]:
import os
# Pin the notebook to GPU 1; this has to happen before torch initializes CUDA.
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})

In [2]:
from functools import partial

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import PreTrainedTokenizer, AutoTokenizer
from transformer_lens import HookedTransformer

from eap.graph import Graph
from eap.evaluate import evaluate_graph, evaluate_baseline
from eap.attribute import attribute 

In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Hugging Face download cache. The default is a machine-specific scratch path;
# set HF_SCRATCH_CACHE_DIR to reuse this notebook on another machine.
scratch_cache_dir = os.environ.get("HF_SCRATCH_CACHE_DIR", "/mnt/fast0/rje41/.cache/huggingface")
model_name = "EleutherAI/pythia-1.4B-deduped"
base_model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=scratch_cache_dir)
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=scratch_cache_dir)

# LoRA adapter checkpoint (presumably from the add/sub fine-tuning run — confirm path).
adapter_path = "../fine-tuning/add_sub/results/checkpoint-2000"

# Attach the adapter, then fold its weights into the base model so that
# HookedTransformer is handed a plain (non-PEFT) HF model.
model_with_lora = PeftModel.from_pretrained(base_model, adapter_path)
model_with_lora = model_with_lora.merge_and_unload()
model = HookedTransformer.from_pretrained(model_name="pythia-1.4B-deduped", hf_model=model_with_lora, cache_dir=scratch_cache_dir)

# TransformerLens switches that expose per-head q/k/v inputs, per-head attention
# results and MLP inputs as hook points; presumably required by eap's edge-level
# attribution — confirm against the eap docs.
model.cfg.use_split_qkv_input = True
model.cfg.use_attn_result = True
model.cfg.use_hook_mlp_in = True
# NOTE(review): only relevant for grouped-query-attention models; Pythia does not
# appear to use GQA, so this is presumably a no-op here.
model.cfg.ungroup_grouped_query_attention = True

Loaded pretrained model pythia-1.4B-deduped into HookedTransformer


In [4]:
def collate_EAP(xs):
    """Collate a list of ``(clean, corrupted, label_id)`` items into one batch.

    Args:
        xs: items as produced by ``EAPDataset.__getitem__``.

    Returns:
        ``(clean, corrupted, labels)`` where ``clean`` and ``corrupted`` are lists
        of prompt strings and ``labels`` is a 1-D ``torch.long`` tensor of label
        token ids, so metrics can index logits with it directly.
    """
    clean, corrupted, labels = zip(*xs)
    return list(clean), list(corrupted), torch.tensor(labels, dtype=torch.long)

class EAPDataset(Dataset):
    """Clean/corrupted prompt pairs with a label token, loaded from a CSV file.

    The CSV must provide ``clean``, ``corrupted`` and ``label`` columns. Each item
    is ``(clean_prompt, corrupted_prompt, label_token_id)``.
    """

    def __init__(self, filepath, tokenizer):
        self.df = pd.read_csv(filepath)
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.df)

    def shuffle(self, random_state=None):
        """Shuffle the rows; pass ``random_state`` for a reproducible order."""
        self.df = self.df.sample(frac=1, random_state=random_state)

    def head(self, n: int):
        """Keep only the first ``n`` rows in the current order."""
        self.df = self.df.head(n)

    def __getitem__(self, index):
        """Return ``(clean, corrupted, label_id)`` for the row at position ``index``."""
        row = self.df.iloc[index]
        clean = row['clean']
        corrupted = row['corrupted']
        label_str = str(row['label'])
        # NOTE(review): the label is tokenized without a leading space and only
        # its first token is kept, so labels that span several tokens are silently
        # truncated — confirm this matches how the prompts end.
        label_ids = self.tokenizer(label_str, add_special_tokens=False).input_ids
        if not label_ids:
            raise ValueError(f"label {label_str!r} at position {index} tokenizes to no tokens")
        return clean, corrupted, label_ids[0]

    def to_dataloader(self, batch_size: int):
        """Wrap the dataset in an unshuffled DataLoader batching with collate_EAP."""
        return DataLoader(self, batch_size=batch_size, collate_fn=collate_EAP)

def get_logit_positions(logits: torch.Tensor, input_length: torch.Tensor):
    """Select, for each batch row, the logits at position ``input_length - 1``.

    ``logits`` is indexed as ``logits[batch, position, ...]``. Presumably
    ``input_length`` holds each sequence's unpadded length, making this the last
    real token — confirm against the caller.
    """
    last_position = input_length - 1
    batch_rows = torch.arange(logits.size(0), device=logits.device)
    return logits[batch_rows, last_position]

def prob_diff(
    logits: torch.Tensor,
    clean_logits: torch.Tensor,
    input_length: torch.Tensor,
    labels: torch.Tensor,
    mean: bool = True,
    loss: bool = False
) -> torch.Tensor:
    """Probability assigned to the label token at each sequence's last position.

    Despite the name, no difference is taken: the per-example score is
    ``softmax(logits[b, input_length[b] - 1])[labels[b]]``.

    Args:
        logits: model logits for the batch being scored.
        clean_logits: unused; accepted so the signature matches the metric
            interface that eap's ``attribute`` / ``evaluate_*`` call with.
        input_length: per-sequence lengths used to locate the scored position.
        labels: label token id per example (tensor or any sequence of ints).
        mean: if True, average over the batch; otherwise keep one score per example.
        loss: if True, negate the score so that minimizing it raises the label
            probability.

    Returns:
        A 0-d tensor when ``mean`` is True, else one score per example.
    """
    logits = get_logit_positions(logits, input_length)
    probs = torch.softmax(logits, dim=-1)
    # Build both index tensors on probs' device (as get_logit_positions does) and
    # accept plain int sequences as well as tensors for labels.
    labels = torch.as_tensor(labels, device=probs.device)
    rows = torch.arange(probs.size(0), device=probs.device)
    results = probs[rows, labels]
    if loss:
        results = -results
    if mean:
        results = results.mean()
    return results

In [5]:
# Clean/corrupted add-sub pairs, labels tokenized with the HookedTransformer's tokenizer.
ds = EAPDataset(filepath='../datasets/add_sub_test.csv', tokenizer=model.tokenizer)
dataloader = ds.to_dataloader(batch_size=10)

In [6]:
# Build the model's edge graph and score its edges with eap's 'EAP-IG-inputs'
# method, using the negated mean label probability as the objective.
g = Graph.from_model(model)
attribute(
    model,
    g,
    dataloader,
    partial(prob_diff, mean=True, loss=True),
    method='EAP-IG-inputs',
    ig_steps=5,
)

100%|██████████| 50/50 [01:33<00:00,  1.86s/it]


In [7]:
# Presumably keeps the 200 top-scoring edges; the positional True is presumably
# the "rank by absolute score" flag — NOTE(review): confirm against Graph.apply_topn.
g.apply_topn(200, True)
# Write the pruned graph to graph.json.
g.to_json('graph.json')

In [8]:
# Rendering needs the optional pygraphviz package. Only the probe import is
# guarded, so an ImportError raised inside to_image is not misreported as
# "pygraphviz missing".
try:
    import pygraphviz  # noqa: F401 -- imported only to check availability
except ImportError:
    print("No pygraphviz installed; skipping this part")
else:
    gz = g.to_image('graph.png')

No pygraphviz installed; skipping this part


In [9]:
# Per-example label probability (not negated, not averaged); the mean is taken
# here so baseline and circuit are summarized identically.
# NOTE(review): this is the same dataloader that was used for attribution, so the
# circuit is evaluated on the data it was found on.
eval_metric = partial(prob_diff, loss=False, mean=False)
baseline = evaluate_baseline(model, dataloader, eval_metric).mean().item()
results = evaluate_graph(model, g, dataloader, eval_metric).mean().item()
print(f"Original performance was {baseline}; the circuit's performance is {results}")
# Share of the full model's score retained by the circuit (undefined when baseline is 0).
if baseline != 0:
    print(f"The circuit recovers {results / baseline:.1%} of the original performance")

100%|██████████| 50/50 [00:10<00:00,  4.87it/s]
100%|██████████| 50/50 [00:12<00:00,  4.13it/s]

Original performance was 0.4229940176010132; the circuit's performance is 0.22875460982322693



