In [1]:
import os
# Expose only GPU index 2 to this process. This cell runs before the cell that
# imports torch, so the variable is already set when CUDA is first initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [2]:
from eap.graph import Graph
from eap.evaluate import evaluate_graph, evaluate_baseline
from eap.attribute import attribute 
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
from peft import PeftModel
from functools import partial

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformer_lens import HookedTransformer
import torch.nn.functional as F

In [3]:
def load_model(adapter_path, hf_model_name, translens_model_name, scratch_cache_dir = None):
    """Load a LoRA-adapted causal LM and wrap it as a HookedTransformer.

    Args:
        adapter_path: Location of the PEFT adapter passed to
            ``PeftModel.from_pretrained``.
        hf_model_name: Hugging Face model identifier for the base model.
        translens_model_name: Model name passed to
            ``HookedTransformer.from_pretrained``.
        scratch_cache_dir: Optional cache directory forwarded to both
            ``from_pretrained`` calls.

    Returns:
        The HookedTransformer built from the adapter-merged weights, with the
        four hook-related config flags below switched on.
    """
    hf_base = AutoModelForCausalLM.from_pretrained(hf_model_name, cache_dir=scratch_cache_dir)
    # Fold the adapter into the base weights so a plain HF model is handed over.
    merged = PeftModel.from_pretrained(hf_base, adapter_path).merge_and_unload()
    hooked = HookedTransformer.from_pretrained(
        model_name=translens_model_name,
        hf_model=merged,
        cache_dir=scratch_cache_dir,
    )

    # NOTE(review): presumably required by the eap attribution/evaluation code
    # used later in this notebook -- confirm against the eap documentation.
    for flag in (
        "use_split_qkv_input",
        "use_attn_result",
        "use_hook_mlp_in",
        "ungroup_grouped_query_attention",
    ):
        setattr(hooked.cfg, flag, True)
    return hooked

In [4]:
def collate_EAP(xs):
    """Turn a batch of (clean, corrupted, label) triples into three columns.

    Args:
        xs: Non-empty sequence of 3-tuples as produced by
            ``EAPDataset.__getitem__``.

    Returns:
        ``(clean, corrupted, labels)`` where ``clean`` and ``corrupted`` are
        lists and ``labels`` is left as the tuple produced by ``zip``.
    """
    clean_col, corrupted_col, label_col = zip(*xs)
    return [*clean_col], [*corrupted_col], label_col

class EAPDataset(Dataset):
    """Dataset over a CSV file with ``clean``, ``corrupted`` and ``label`` columns."""

    def __init__(self, filepath):
        """Read the whole CSV at ``filepath`` into ``self.df``."""
        self.df = pd.read_csv(filepath)

    def __len__(self):
        """Number of rows currently held in ``self.df``."""
        return self.df.shape[0]

    def shuffle(self):
        """Replace ``self.df`` with a random permutation of its rows (unseeded)."""
        self.df = self.df.sample(frac=1)

    def head(self, n: int):
        """Truncate ``self.df`` in place to its first ``n`` rows."""
        self.df = self.df.head(n)

    def __getitem__(self, index):
        """Return the ``(clean, corrupted, label)`` triple at positional ``index``."""
        record = self.df.iloc[index]
        return tuple(record[column] for column in ('clean', 'corrupted', 'label'))

    def to_dataloader(self, batch_size: int):
        """Wrap this dataset in a DataLoader that batches with ``collate_EAP``."""
        return DataLoader(self, collate_fn=collate_EAP, batch_size=batch_size)
    
def get_logit_positions(logits: torch.Tensor, input_length: torch.Tensor):
    """Pick, for each batch element, the logits at position ``input_length - 1``.

    Args:
        logits: Tensor whose dim 0 is the batch and dim 1 is the position.
        input_length: Per-example lengths; entry ``i`` selects position
            ``input_length[i] - 1`` of example ``i``.

    Returns:
        ``logits`` with the position dimension indexed away.
    """
    rows = torch.arange(logits.size(0), device=logits.device)
    final_position = input_length - 1
    return logits[rows, final_position]

def kl_divergence(logits: torch.Tensor, clean_logits: torch.Tensor, input_length: torch.Tensor, labels: torch.Tensor, mean=True, loss=True):
    """Per-example divergence between the clean and the evaluated next-token distributions.

    Both logit tensors are reduced to the position ``input_length - 1`` of each
    example, converted to log-probabilities, and compared with ``F.kl_div``
    using the clean distribution as the target.

    Args:
        logits: Logits of the run being scored.
        clean_logits: Logits of the clean reference run.
        input_length: Per-example lengths used to pick the scored position.
        labels: Accepted for interface compatibility; not read here.
        mean: If true, return the mean over the batch as a scalar tensor;
            otherwise return one value per example.
        loss: Accepted for interface compatibility; not read here.

    Returns:
        A scalar tensor when ``mean`` is true, else a 1-D tensor over the batch.
    """
    logits = get_logit_positions(logits, input_length)
    clean_logits = get_logit_positions(clean_logits, input_length)

    # log_softmax instead of softmax(...).log(): the latter underflows to -inf
    # for very unlikely tokens, and kl_div with log_target=True then evaluates
    # exp(-inf) * (-inf - x) = 0 * -inf = NaN.
    log_probs = F.log_softmax(logits, dim=-1)
    clean_log_probs = F.log_softmax(clean_logits, dim=-1)
    # NOTE(review): the pointwise terms are averaged over the vocabulary rather
    # than summed, so the value is the KL divergence scaled by 1/vocab_size.
    # Kept as-is so results stay comparable with previously recorded numbers.
    results = F.kl_div(log_probs, clean_log_probs, log_target=True, reduction='none').mean(-1)
    return results.mean() if mean else results

In [5]:
# Add/sub run: load Pythia-1.4B-deduped with the add_sub LoRA checkpoint merged in.
# NOTE(review): scratch_cache_dir is an absolute, user-specific path; anyone
# else running this notebook has to edit it.
scratch_cache_dir = "/mnt/fast0/rje41/.cache/huggingface"    
hf_model_name = "EleutherAI/pythia-1.4B-deduped"
translens_model_name="pythia-1.4B-deduped"
# Relative path: resolved against the notebook's working directory.
adapter_path = "../../fine-tuning/joint_training/checkpoints/add_sub/checkpoint-300"
model = load_model(
        adapter_path=adapter_path,
        hf_model_name=hf_model_name,
        translens_model_name=translens_model_name,
        scratch_cache_dir=scratch_cache_dir,
    )

Loaded pretrained model pythia-1.4B-deduped into HookedTransformer


In [6]:
# Clean/corrupted prompt pairs for the add/sub task, served in batches of 6.
ds = EAPDataset('../../datasets/Add_Sub_100_circuit.csv')
dataloader = ds.to_dataloader(6)

In [7]:
# Build the graph for this model and run attribution over the add/sub data.
# The metric is kl_divergence with mean=True, i.e. one scalar per batch
# (the loss=True keyword is accepted by kl_divergence but not read by it).
# The return value of attribute() is discarded; `g` is what later cells use.
g = Graph.from_model(model)
attribute(model, g, dataloader, partial(kl_divergence, loss=True, mean=True), method='EAP-IG-inputs', ig_steps=5)

100%|██████████| 84/84 [05:09<00:00,  3.69s/it]


In [8]:
# Greedy selection with a budget of 5% of all edges (rounded down), then
# export the resulting graph to JSON in the current working directory.
total_edges = len(g.edges)
five_percent_edges = int(total_edges * 0.05)
# NOTE(review): absolute=True presumably ranks edges by |score| -- confirm
# against the eap Graph.apply_greedy documentation.
g.apply_greedy(five_percent_edges , absolute=True)
g.to_json('add_sub_graph.json')

In [11]:
def calculate_faithfulness(model, g, dataloader, metric_fn):
    """Compare the circuit's metric against the full model's and print a summary.

    Args:
        model: Model passed through to ``evaluate_baseline`` / ``evaluate_graph``.
        g: Graph (circuit) passed to ``evaluate_graph``.
        dataloader: Batches to evaluate on.
        metric_fn: Metric callable forwarded to both evaluation functions; the
            ``.mean()`` of whatever they return is used.

    Returns:
        ``(faithfulness, percentage_performance)`` where ``faithfulness`` is the
        absolute gap between baseline and circuit metric, and
        ``percentage_performance`` is ``(1 - faithfulness / baseline) * 100``.
        When the baseline metric is exactly 0 the ratio is undefined and
        ``percentage_performance`` is NaN.
    """
    baseline_performance = evaluate_baseline(model, dataloader, metric_fn).mean().item()
    circuit_performance = evaluate_graph(model, g, dataloader, metric_fn,skip_clean=False).mean().item()
    faithfulness = abs(baseline_performance - circuit_performance)
    if baseline_performance == 0:
        # Avoid a ZeroDivisionError after two expensive evaluation passes;
        # the relative figure is simply undefined in this case.
        percentage_performance = float("nan")
    else:
        percentage_performance = (1 - faithfulness / baseline_performance) * 100

    print(f"Baseline performance: {baseline_performance}")
    print(f"Circuit performance: {circuit_performance}")
    print(f"Faithfulness: {faithfulness}")
    print(f"Percentage of model performance achieved by the circuit: {percentage_performance:.2f}%")

    return faithfulness, percentage_performance

In [None]:
# Evaluation metric: mean=False keeps one value per example;
# calculate_faithfulness averages what the evaluation functions return.
metric_fn = partial(kl_divergence, loss=False, mean=False)
faithfulness, percentage_performance = calculate_faithfulness(model, g, dataloader, metric_fn)

100%|██████████| 84/84 [00:31<00:00,  2.71it/s]
100%|██████████| 84/84 [00:55<00:00,  1.51it/s]

Baseline performance: 0.0003849417844321579
Circuit performance: 0.00031651731114834547
Faithfulness: 6.84244732838124e-05
Percentage of model performance achieved by the circuit: 82.22%





### Mul_div circuit

In [10]:
# Mul/div run: same base model, with the mul_div LoRA checkpoint merged in.
# Rebinds `model`, so the previously loaded model is no longer referenced by
# that name after this cell.
# NOTE(review): scratch_cache_dir is an absolute, user-specific path.
scratch_cache_dir = "/mnt/fast0/rje41/.cache/huggingface"    
hf_model_name = "EleutherAI/pythia-1.4B-deduped"
translens_model_name="pythia-1.4B-deduped"
adapter_path = "../../fine-tuning/joint_training/checkpoints/mul_div/checkpoint-300"
model = load_model(
        adapter_path=adapter_path,
        hf_model_name=hf_model_name,
        translens_model_name=translens_model_name,
        scratch_cache_dir=scratch_cache_dir,
    )

Loaded pretrained model pythia-1.4B-deduped into HookedTransformer


In [11]:
# Clean/corrupted prompt pairs for the mul/div task, in batches of 6.
# Rebinds `ds` and `dataloader` from the previous section.
ds = EAPDataset('../../datasets/Mul_Div_circuit.csv')
dataloader = ds.to_dataloader(6)

In [12]:
# Build the graph for the mul/div model and run attribution with the same
# settings as the add/sub section (scalar kl_divergence, EAP-IG-inputs, 5 steps).
g_md = Graph.from_model(model)
attribute(model, g_md, dataloader, partial(kl_divergence, loss=True, mean=True), method='EAP-IG-inputs', ig_steps=5)

100%|██████████| 84/84 [05:10<00:00,  3.69s/it]


In [13]:
# Greedy selection with a budget of 5% of the mul/div graph's edges
# (rounded down), then export to JSON.
total_edges = len(g_md.edges)
five_percent_edges = int(total_edges * 0.05)
g_md.apply_greedy(five_percent_edges , absolute=True)
g_md.to_json('mul_div_graph.json')

In [14]:
# Per-example KL metric (mean=False); uses calculate_faithfulness, which is
# defined in an earlier cell of this notebook.
metric_fn = partial(kl_divergence, loss=False, mean=False)
faithfulness, percentage_performance = calculate_faithfulness(model, g_md, dataloader, metric_fn)

100%|██████████| 84/84 [00:31<00:00,  2.70it/s]
100%|██████████| 84/84 [00:55<00:00,  1.51it/s]

Baseline performance: 0.0003781667910516262
Circuit performance: 0.00033664467628113925
Faithfulness: 4.152211477048695e-05
Percentage of model performance achieved by the circuit: 89.02%





### Merged circuit

In [6]:
# Merged run: same base model, with the "merged" LoRA checkpoint merged in.
# Rebinds `model` again.
# NOTE(review): scratch_cache_dir is an absolute, user-specific path.
scratch_cache_dir = "/mnt/fast0/rje41/.cache/huggingface"    
hf_model_name = "EleutherAI/pythia-1.4B-deduped"
translens_model_name="pythia-1.4B-deduped"
adapter_path = "../../fine-tuning/joint_training/checkpoints/merged/checkpoint-300"
model = load_model(
        adapter_path=adapter_path,
        hf_model_name=hf_model_name,
        translens_model_name=translens_model_name,
        scratch_cache_dir=scratch_cache_dir,
    )

Loaded pretrained model pythia-1.4B-deduped into HookedTransformer


In [7]:
# The merged checkpoint is probed on the add/sub dataset (same CSV as the
# first section), in batches of 6.
ds = EAPDataset('../../datasets/Add_Sub_100_circuit.csv')
dataloader = ds.to_dataloader(6)

In [8]:
# Build the graph for the merged model and run attribution with the same
# settings as the other sections (scalar kl_divergence, EAP-IG-inputs, 5 steps).
g_merge = Graph.from_model(model)
attribute(model, g_merge, dataloader, partial(kl_divergence, loss=True, mean=True), method='EAP-IG-inputs', ig_steps=5)

100%|██████████| 84/84 [05:11<00:00,  3.70s/it]


In [9]:
# Greedy selection with a budget of 5% of the merged graph's edges
# (rounded down), then export to JSON.
total_edges = len(g_merge.edges)
five_percent_edges = int(total_edges * 0.05)
g_merge.apply_greedy(five_percent_edges , absolute=True)
g_merge.to_json('add_sub_merge_graph.json')

In [12]:
# Per-example KL metric (mean=False); uses calculate_faithfulness, which is
# defined in an earlier cell of this notebook and must already have been run
# in the current kernel.
metric_fn = partial(kl_divergence, loss=False, mean=False)
faithfulness, percentage_performance = calculate_faithfulness(model, g_merge, dataloader, metric_fn)

100%|██████████| 84/84 [00:30<00:00,  2.74it/s]
100%|██████████| 84/84 [00:55<00:00,  1.51it/s]

Baseline performance: 0.00038985011633485556
Circuit performance: 0.0003385224554222077
Faithfulness: 5.132766091264784e-05
Percentage of model performance achieved by the circuit: 86.83%



