In [1]:
import torch as t
from torch import Tensor
from jaxtyping import Float
from tqdm import tqdm
import numpy as np

from nnsight.models.UnifiedTransformer import UnifiedTransformer

# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines where CUDA is not available.
device = "cuda" if t.cuda.is_available() else "cpu"

  from .autonotebook import tqdm as notebook_tqdm
  torch.utils._pytree._register_pytree_node(


In [2]:
# Load 'gpt2-small' through the UnifiedTransformer wrapper. Every optional
# weight-processing flag passed here is switched off (processing, centering of
# the writing weights, centering of the unembed, LayerNorm folding).
model = UnifiedTransformer(
    'gpt2-small',
    processing=False,
    center_writing_weights=False,
    center_unembed=False,
    fold_ln=False,
    device=device,
)
tokenizer = model.tokenizer

# Switch on three optional settings of the wrapped local model: the MLP-input
# hook, split Q/K/V inputs, and the attention result.
# NOTE(review): presumably these expose the per-component hook points that the
# edge-attribution code further down relies on — confirm against eap_graph.
model.local_model.set_use_hook_mlp_in(True)
model.local_model.set_use_split_qkv_input(True)
model.local_model.set_use_attn_result(True)
# NOTE(review): presumably refreshes the wrapper's meta copy of the model so it
# reflects the settings changed above — confirm against the nnsight API.
model.update_meta()

Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  meta
Moving model to device:  meta


In [3]:
from ioi_dataset import IOIDataset, format_prompt, make_table

# Passed to IOIDataset as `N` — presumably the number of prompts to generate.
N = 25
# "Clean" dataset, built with a fixed seed and the model's own tokenizer.
clean_dataset = IOIDataset(
    prompt_type='mixed',
    N=N,
    tokenizer=model.tokenizer,
    prepend_bos=False,
    seed=1,
    device=device
)
# "Corrupted" counterpart, derived from the clean dataset.
# NOTE(review): the flip spec presumably swaps the names in each prompt for
# unrelated ones — confirm in ioi_dataset.gen_flipped_prompts.
corr_dataset = clean_dataset.gen_flipped_prompts('ABC->XYZ, BAB->XYZ')

# Disabled preview table of the prompts.
# NOTE(review): the fourth column is labelled "ABC prompt" but maps over
# clean_dataset.sentences, the same source as the first column; presumably
# corr_dataset.sentences was intended — fix before re-enabling.
# make_table(
#   colnames = ["IOI prompt", "IOI subj", "IOI indirect obj", "ABC prompt"],
#   cols = [
#     map(format_prompt, clean_dataset.sentences),
#     tokenizer.decode(clean_dataset.s_tokenIDs).split(),
#     tokenizer.decode(clean_dataset.io_tokenIDs).split(),
#     map(format_prompt, clean_dataset.sentences),
#   ],
#   title = "Sentences from IOI vs ABC distribution",
# )


In [4]:
def ave_logit_diff(
    logits: Float[Tensor, 'batch seq d_vocab'],
    ioi_dataset: IOIDataset,
    per_prompt: bool = False
):
    '''
        Logit difference between the indirect-object token and the subject
        token, read at each prompt's 'end' position.

        Only the first `logits.size(0)` entries of the dataset's index tensors
        are used, so `logits` may cover fewer prompts than the dataset holds.

        Returns the per-prompt differences when `per_prompt` is True,
        otherwise their mean.
    '''
    n_prompts = logits.size(0)
    rows = range(n_prompts)

    # Sequence position and the two candidate token ids for every prompt.
    end_positions = ioi_dataset.word_idx['end'][:n_prompts]
    io_ids = ioi_dataset.io_tokenIDs[:n_prompts]
    s_ids = ioi_dataset.s_tokenIDs[:n_prompts]

    # One logit per prompt for each candidate, then their difference.
    io_at_end = logits[rows, end_positions, io_ids]
    s_at_end = logits[rows, end_positions, s_ids]
    per_prompt_diff = io_at_end - s_at_end

    if per_prompt:
        return per_prompt_diff
    return per_prompt_diff.mean()


# Forward pass on the clean tokens. Nothing is done inside the context body;
# only the invoker's output is read afterwards.
with model.invoke(clean_dataset.toks) as clean_invoker:
    pass

# Same for the corrupted tokens.
with model.invoke(corr_dataset.toks) as corrupted_invoker:
    pass

# The invoker outputs are passed to ave_logit_diff below as the logits.
clean_logits = clean_invoker.output
corrupt_logits = corrupted_invoker.output

# Mean logit difference of each run on its own dataset, converted to plain
# Python floats. ioi_metric uses these two values as its default baselines.
clean_logit_diff = ave_logit_diff(clean_logits, clean_dataset).item()
corrupt_logit_diff = ave_logit_diff(corrupt_logits, corr_dataset).item()

def ioi_metric(
    logits: Float[Tensor, "batch seq_len d_vocab"],
    corrupted_logit_diff: float = corrupt_logit_diff,
    clean_logit_diff: float = clean_logit_diff,
    ioi_dataset: IOIDataset = clean_dataset
):
    '''
        Mean logit difference of `logits` on `ioi_dataset`, rescaled linearly
        so that `corrupted_logit_diff` maps to 0 and `clean_logit_diff` maps
        to 1.

        The default baselines and dataset are bound when this function is
        defined, so re-run this definition after recomputing them.
    '''
    # Distance between the two baselines: the normalisation constant.
    baseline_gap = clean_logit_diff - corrupted_logit_diff
    # How far the given logits sit above the corrupted baseline.
    above_corrupted = ave_logit_diff(logits, ioi_dataset) - corrupted_logit_diff
    return above_corrupted / baseline_gap

def negative_ioi_metric(logits: Float[Tensor, "batch seq_len d_vocab"]):
    '''
        Sign-flipped ioi_metric, evaluated with ioi_metric's default
        baselines and dataset.
    '''
    score = ioi_metric(logits)
    return -score
    
# Sanity check: evaluate the normalised metric on the clean and corrupted runs
# themselves. By the formula in ioi_metric, the clean run yields
# (clean - corrupt) / (clean - corrupt) and the corrupted run has a zero
# numerator, so the two values should come out as 1 and 0.
with t.no_grad():
    clean_metric = ioi_metric(clean_logits, corrupt_logit_diff, clean_logit_diff, clean_dataset)
    corrupt_metric = ioi_metric(corrupt_logits, corrupt_logit_diff, clean_logit_diff, corr_dataset)

# Report the raw baselines and the normalised values.
print(f'Clean direction: {clean_logit_diff}, Corrupt direction: {corrupt_logit_diff}')
print(f'Clean metric: {clean_metric}, Corrupt metric: {corrupt_metric}')

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Clean direction: 2.805180311203003, Corrupt direction: 1.5410711765289307
Clean metric: 1.0, Corrupt metric: 0.0


In [9]:
import eap_graph

import importlib

# Re-execute the eap_graph module so that edits made to it on disk are picked
# up without restarting the kernel.
importlib.reload(eap_graph)

# Build the EAP object from the model's config, restricted to the "head" and
# "mlp" component types.
graph = eap_graph.EAP(model.config, components=["head", "mlp"])

In [10]:
# Run EAP on the model with the clean and corrupted tokens, using ioi_metric
# as the metric.
# NOTE(review): batch_size repeats the literal 25 that N is set to when the
# datasets are built; presumably the whole dataset is meant to go through in
# one batch — confirm, and keep the two values in sync if N changes.
graph.run(
    model,
    clean_dataset.toks,
    corr_dataset.toks,
    batch_size=25,
    metric=ioi_metric,
)

In [11]:
# Request the top 20 edges from the graph.
# NOTE(review): format=True presumably produces the "src -> [score] -> dst"
# lines shown in this cell's output — confirm in eap_graph.
edges = graph.top_edges(n=20, format=True)

head.9.9 -> [-0.027] -> head.11.10.q
head.10.7 -> [0.026] -> head.11.10.q
head.5.5 -> [0.019] -> head.8.6.v
head.9.9 -> [-0.018] -> head.10.7.q
head.5.5 -> [-0.017] -> mlp.5
mlp.0 -> [-0.017] -> head.7.9.k
mlp.0 -> [0.017] -> mlp.5
mlp.0 -> [0.014] -> head.6.9.q
mlp.0 -> [-0.014] -> head.6.9.k
head.5.5 -> [-0.013] -> head.6.9.q
mlp.0 -> [-0.013] -> mlp.2
mlp.0 -> [0.012] -> head.6.5.k
head.9.6 -> [-0.012] -> head.11.10.q
head.3.0 -> [-0.012] -> mlp.5
head.9.6 -> [-0.012] -> head.10.7.q
mlp.5 -> [-0.011] -> mlp.6
mlp.0 -> [-0.011] -> head.8.10.k
mlp.10 -> [0.01] -> head.11.10.k
mlp.0 -> [-0.01] -> mlp.1
head.4.11 -> [0.01] -> head.6.9.k
