In [1]:
import os

# Pin the notebook to a single GPU. This must happen before CUDA is initialised,
# i.e. before torch is imported in the next cell.
# setdefault keeps a CUDA_VISIBLE_DEVICES exported by the launching shell, so the
# hard-coded "2" is only a fallback default rather than an unconditional override.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "2")

In [2]:
from functools import partial

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import PreTrainedTokenizer, AutoTokenizer
from transformer_lens import HookedTransformer

from eap.graph import Graph
from eap.evaluate import evaluate_graph, evaluate_baseline
from eap.attribute import attribute 

In [3]:
model_name = 'meta-llama/Llama-3.2-1B-Instruct'
model = HookedTransformer.from_pretrained(model_name, device='cuda')

# Turn on the four TransformerLens config flags this experiment depends on.
# NOTE(review): presumably these are what eap.graph needs to hook separate q/k/v
# inputs, per-head attention outputs, the MLP input and un-grouped GQA heads —
# confirm against the eap / TransformerLens documentation.
for cfg_flag in (
    'use_split_qkv_input',
    'use_attn_result',
    'use_hook_mlp_in',
    'ungroup_grouped_query_attention',
):
    setattr(model.cfg, cfg_flag, True)



Loaded pretrained model meta-llama/Llama-3.2-1B-Instruct into HookedTransformer


In [4]:
def collate_EAP(xs):
    """Collate ``(clean, corrupted, label_id)`` triples into a single batch.

    Args:
        xs: list of triples as produced by ``EAPDataset.__getitem__``.

    Returns:
        ``(clean, corrupted, labels)``: ``clean`` and ``corrupted`` are lists of
        prompt strings, ``labels`` is a 1-D ``torch.long`` tensor of label token
        ids, matching the ``labels: torch.Tensor`` parameter of ``prob_diff``.
    """
    clean, corrupted, labels = zip(*xs)
    # A real tensor (not a tuple of ints) so metrics can index / move it to a device.
    return list(clean), list(corrupted), torch.tensor(labels, dtype=torch.long)

class EAPDataset(Dataset):
    """CSV-backed dataset of ``(clean prompt, corrupted prompt, label token id)``.

    The CSV must provide ``clean``, ``corrupted`` and ``label`` columns. Each
    label is converted to a token id with ``tokenizer`` when an item is fetched.
    """

    def __init__(self, filepath, tokenizer, label_prefix: str = ""):
        """
        Args:
            filepath: anything ``pd.read_csv`` accepts (path or buffer).
            tokenizer: callable returning an object with ``input_ids``
                (e.g. a Hugging Face tokenizer).
            label_prefix: text prepended to every label before tokenization.
                Many BPE tokenizers give ``"7"`` and ``" 7"`` different ids
                (presumably including Llama-3's — verify), so pass ``" "`` when
                the answer follows a space in the prompt. Default ``""`` keeps
                the original behavior.
        """
        self.df = pd.read_csv(filepath)
        self.tokenizer = tokenizer
        self.label_prefix = label_prefix

    def __len__(self):
        return len(self.df)

    def shuffle(self, seed=None):
        """Shuffle the rows in place; pass ``seed`` for a reproducible order."""
        self.df = self.df.sample(frac=1, random_state=seed)

    def head(self, n: int):
        """Keep only the first ``n`` rows (in the current order)."""
        self.df = self.df.head(n)

    def __getitem__(self, index):
        """Return ``(clean, corrupted, label_id)`` for the row at position ``index``.

        Raises:
            ValueError: if the label tokenizes to no tokens at all.
        """
        row = self.df.iloc[index]
        label_str = self.label_prefix + str(row['label'])
        token_ids = self.tokenizer(label_str, add_special_tokens=False).input_ids
        if not token_ids:
            raise ValueError(
                f"label {label_str!r} at position {index} tokenizes to no tokens"
            )
        # NOTE(review): only the first token is kept, so multi-token labels are
        # silently truncated — confirm every label in the CSV is a single token.
        return row['clean'], row['corrupted'], token_ids[0]

    def to_dataloader(self, batch_size: int):
        """Wrap the dataset in a DataLoader that batches with ``collate_EAP``."""
        return DataLoader(self, batch_size=batch_size, collate_fn=collate_EAP)

def get_logit_positions(logits: torch.Tensor, input_length: torch.Tensor) -> torch.Tensor:
    """Select, for each sequence in the batch, the logits at its last real token.

    Args:
        logits: tensor indexed as ``logits[batch, position, ...]``
            (presumably ``(batch, seq, d_vocab)`` — TODO confirm with the caller).
        input_length: per-sequence token counts (tensor or sequence of ints).
            ``input_length[i] - 1`` is used as the last position, which assumes
            right padding.

    Returns:
        ``logits[i, input_length[i] - 1]`` stacked over the batch.
    """
    batch_idx = torch.arange(logits.size(0), device=logits.device)
    # Keep both index tensors on the same device as ``logits``.
    last_pos = torch.as_tensor(input_length, device=logits.device) - 1
    return logits[batch_idx, last_pos]

def prob_diff(
    logits: torch.Tensor,
    clean_logits: torch.Tensor,
    input_length: torch.Tensor,
    labels: torch.Tensor,
    mean: bool = True,
    loss: bool = False
) -> torch.Tensor:
    """Softmax probability assigned to each example's label at its final position.

    Despite the name, this is the probability of the label token, not a
    difference between two probabilities.

    Args:
        logits: model logits, indexed as in ``get_logit_positions``.
        clean_logits: unused; presumably accepted only to match the metric
            signature the eap library calls with — confirm against eap.
        input_length: per-sequence token counts (see ``get_logit_positions``).
        labels: one label token id per example (tensor or sequence of ints).
        mean: average over the batch (returns a 0-d tensor) when True.
        loss: negate the result, so lower means a more probable label.

    Returns:
        Per-example probabilities of shape ``(batch,)``, or their mean.
    """
    final_logits = get_logit_positions(logits, input_length)
    probs = torch.softmax(final_logits, dim=-1)
    batch_idx = torch.arange(probs.size(0), device=probs.device)
    # ``labels`` may arrive as a tuple of ints from the collate fn; normalise it.
    label_idx = torch.as_tensor(labels, dtype=torch.long, device=probs.device)
    results = probs[batch_idx, label_idx]
    if loss:
        results = -results
    if mean:
        results = results.mean()
    return results

In [9]:
# Each CSV row supplies a clean prompt, a corrupted prompt and a label;
# batches of 10 are collated by collate_EAP.
ds = EAPDataset(filepath='../datasets/add_sub.csv', tokenizer=model.tokenizer)
dataloader = ds.to_dataloader(batch_size=10)

In [10]:
g = Graph.from_model(model)

# loss=True negates the label probability and mean=True averages it over the batch.
# NOTE(review): presumably attribute() treats the metric as a loss — confirm in eap.attribute.
# NOTE: the recorded run took ~1.8 s per batch and was interrupted; ds.head(n) can
# subsample the dataset for a quicker pass.
attribution_metric = partial(prob_diff, loss=True, mean=True)
attribute(
    model,
    g,
    dataloader,
    attribution_metric,
    method='EAP-IG-inputs',
    ig_steps=5,
)

  0%|          | 0/2500 [00:00<?, ?it/s]

Next token predicted: ?
 ?
    ?
  -


  0%|          | 1/2500 [00:02<1:25:43,  2.06s/it]

Next token predicted:    -  ?
   - 


  0%|          | 2/2500 [00:03<1:18:49,  1.89s/it]

Next token predicted:           


  0%|          | 3/2500 [00:05<1:16:48,  1.85s/it]

Next token predicted:       -?
   


  0%|          | 4/2500 [00:07<1:15:46,  1.82s/it]

Next token predicted:           


  0%|          | 5/2500 [00:09<1:15:34,  1.82s/it]

Next token predicted:           


  0%|          | 6/2500 [00:11<1:15:16,  1.81s/it]

Next token predicted:           


  0%|          | 7/2500 [00:12<1:14:40,  1.80s/it]

Next token predicted:           


  0%|          | 8/2500 [00:14<1:14:31,  1.79s/it]

Next token predicted:     ?
  -?
  


  0%|          | 9/2500 [00:16<1:14:09,  1.79s/it]

Next token predicted:       -    


  0%|          | 10/2500 [00:18<1:14:04,  1.78s/it]

Next token predicted:           


  0%|          | 11/2500 [00:19<1:14:09,  1.79s/it]

Next token predicted:   -    ?
   


  0%|          | 12/2500 [00:21<1:14:11,  1.79s/it]

Next token predicted:           -


  1%|          | 13/2500 [00:23<1:14:09,  1.79s/it]

Next token predicted:          -?



  1%|          | 14/2500 [00:25<1:14:03,  1.79s/it]

Next token predicted:           


  1%|          | 15/2500 [00:27<1:14:14,  1.79s/it]

Next token predicted:           


  1%|          | 16/2500 [00:28<1:14:15,  1.79s/it]

Next token predicted:           


  1%|          | 17/2500 [00:30<1:14:03,  1.79s/it]

Next token predicted:    ?
 ?
    


  1%|          | 18/2500 [00:32<1:14:19,  1.80s/it]

Next token predicted:       ?
   


  1%|          | 19/2500 [00:34<1:14:19,  1.80s/it]

Next token predicted:    ?
      


  1%|          | 20/2500 [00:36<1:14:01,  1.79s/it]

Next token predicted:           


  1%|          | 21/2500 [00:37<1:13:59,  1.79s/it]

Next token predicted:      -     


  1%|          | 22/2500 [00:39<1:13:51,  1.79s/it]

Next token predicted:     ?
  ?
  


  1%|          | 23/2500 [00:41<1:13:48,  1.79s/it]

Next token predicted: ?
       ?
?



  1%|          | 24/2500 [00:43<1:13:49,  1.79s/it]

Next token predicted:   - ?
   ?
  


  1%|          | 25/2500 [00:44<1:13:50,  1.79s/it]

Next token predicted:     -      


  1%|          | 26/2500 [00:46<1:14:06,  1.80s/it]

Next token predicted:           


  1%|          | 27/2500 [00:48<1:14:18,  1.80s/it]

Next token predicted:           


  1%|          | 28/2500 [00:50<1:14:10,  1.80s/it]

Next token predicted:         ?
?



  1%|          | 29/2500 [00:52<1:14:18,  1.80s/it]

Next token predicted:  ?
   - ?
   


  1%|          | 30/2500 [00:54<1:14:20,  1.81s/it]

Next token predicted:     -      


  1%|          | 31/2500 [00:55<1:14:27,  1.81s/it]

Next token predicted:           


  1%|▏         | 32/2500 [00:57<1:14:17,  1.81s/it]

Next token predicted:    -?
     ?



  1%|▏         | 33/2500 [00:59<1:14:14,  1.81s/it]

Next token predicted:  -?
    - ?
  -


  1%|▏         | 34/2500 [01:01<1:14:25,  1.81s/it]

Next token predicted:           


  1%|▏         | 35/2500 [01:03<1:14:15,  1.81s/it]

Next token predicted:     ?
     


  1%|▏         | 36/2500 [01:04<1:14:24,  1.81s/it]

Next token predicted:    -   -    


  1%|▏         | 37/2500 [01:06<1:14:24,  1.81s/it]

Next token predicted:       ?
   


  2%|▏         | 38/2500 [01:08<1:14:14,  1.81s/it]

Next token predicted:       -  - ?



  2%|▏         | 39/2500 [01:10<1:14:06,  1.81s/it]

Next token predicted:  ?
     ?
  


  2%|▏         | 40/2500 [01:12<1:13:56,  1.80s/it]

Next token predicted:   ?
       


  2%|▏         | 41/2500 [01:13<1:13:56,  1.80s/it]

Next token predicted:      ?
    


  2%|▏         | 42/2500 [01:15<1:13:51,  1.80s/it]

Next token predicted:           


  2%|▏         | 43/2500 [01:17<1:13:48,  1.80s/it]

Next token predicted:    ?
   -   


  2%|▏         | 44/2500 [01:19<1:13:49,  1.80s/it]

Next token predicted: ?
 -   ?
   ?



  2%|▏         | 45/2500 [01:21<1:13:47,  1.80s/it]

Next token predicted:  ?
        


  2%|▏         | 46/2500 [01:22<1:13:52,  1.81s/it]

Next token predicted:           


  2%|▏         | 47/2500 [01:24<1:14:04,  1.81s/it]

Next token predicted:    ?
 ?
    


  2%|▏         | 48/2500 [01:26<1:13:55,  1.81s/it]

Next token predicted:  -         


  2%|▏         | 49/2500 [01:28<1:13:44,  1.80s/it]

Next token predicted:      ?
    


  2%|▏         | 50/2500 [01:30<1:13:40,  1.80s/it]

Next token predicted:   ?
    -   


  2%|▏         | 51/2500 [01:32<1:13:51,  1.81s/it]

Next token predicted:         -  





KeyboardInterrupt: 

In [None]:
# Keep only the 200 top-scoring components of the attributed graph, then save it.
# NOTE(review): presumably these are edges and the second positional argument (True)
# selects ranking by absolute score — confirm against eap.graph.Graph.apply_topn.
# The return value is discarded, so apply_topn presumably prunes g in place; re-running
# this cell would then operate on an already-pruned graph.
g.apply_topn(200, True)
g.to_json('graph.json')

In [None]:
# Render the pruned graph to graph.png when pygraphviz is available.
try:
    import pygraphviz  # noqa: F401 -- imported only to probe availability
except ImportError:
    print("No pygraphviz installed; skipping this part")
else:
    # Rendering sits outside the try so an ImportError raised inside to_image
    # is not misreported as pygraphviz being missing.
    gz = g.to_image('graph.png')

No pygraphviz installed; skipping this part


In [None]:
# Per-example label probabilities (no negation, no batch averaging); both runs
# are averaged over all examples afterwards.
eval_metric = partial(prob_diff, loss=False, mean=False)

baseline_scores = evaluate_baseline(model, dataloader, eval_metric)
baseline = baseline_scores.mean().item()

circuit_scores = evaluate_graph(model, g, dataloader, eval_metric)
results = circuit_scores.mean().item()

print(f"Original performance was {baseline}; the circuit's performance is {results}")

100%|██████████| 209/209 [00:40<00:00,  5.21it/s]
100%|██████████| 209/209 [00:50<00:00,  4.12it/s]

Original performance was 0.00018705803086049855; the circuit's performance is 6.506806676043198e-05



