In [1]:
import warnings
warnings.filterwarnings('ignore')

import os
os.chdir('..')

In [None]:
import os
from collections import OrderedDict

import circuitsvis as cv
import numpy as np
import torch
from IPython.display import HTML
from transformer_lens import HookedTransformer, patching, utils
from transformers import GPT2LMHeadModel

from plotly_utils import imshow


In [3]:
os.chdir('..')
print('Changed working directory to parent directory')

with open(os.path.expanduser('~/.huggingface/token')) as f:
    os.environ['HF_TOKEN'] = f.read().strip()
    print(f'Hugging Face token loaded: {os.environ['HF_TOKEN'][:3]}...')

torch.set_grad_enabled(False)

if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Changed working directory to parent directory
Hugging Face token loaded: hf_...
Using device: mps


In [None]:
model = GPT2LMHeadModel.from_pretrained('gpt2')

checkpoint = torch.load('checkpoint_step.pth')
custom_state = checkpoint['model_state_dict']

fixed_state = OrderedDict()
for k, v in custom_state.items():
    new_key = k.replace('_orig_mod.', '')
    fixed_state[new_key] = v

for k, v in fixed_state.items():
    print(k, v.shape)


model.load_state_dict(fixed_state, strict=True)
print('Model loaded from local checkpoint')

Model loaded from local checkpoint


In [5]:
model = HookedTransformer.from_pretrained(
    'gpt2-small',
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
    device=device,
)

Loaded pretrained model gpt2-small into HookedTransformer


In [None]:
prompt_format = [
    'When John and Mary went to the shops, {} gave the bag to',
    'When Tom and James went to the park, {} gave the ball to',
    'When Dan and Sid went to the shops, {} gave an apple to',
    'After Martin and Amy went to the park, {} gave a drink to',
]

In [133]:
name_pairs = [
    (' Mary', ' John'),
    (' Tom', ' James'),
    (' Dan', ' Sid'),
    (' Martin', ' Amy'),
]

prompts = [prompt.format(name) for (prompt, names) in zip(prompt_format, name_pairs) for name in names[::-1]]
answers = [names[::i] for names in name_pairs for i in (1, -1)]
answer_tokens = torch.concat([model.to_tokens(names, prepend_bos=False).T for names in answers])

In [134]:
tokens = model.to_tokens(prompts, prepend_bos=True)
tokens = tokens.to(device)

In [135]:
def logits_to_ave_logit_diff(
    logits,
    answer_tokens = answer_tokens,
    per_prompt = False,
):
    final_logits = logits[:, -1, :]
    answer_logits = final_logits.gather(dim=-1, index=answer_tokens)
    correct_logits, incorrect_logits = answer_logits.unbind(dim=-1)
    answer_logit_diff = correct_logits - incorrect_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()

In [None]:
clean_tokens = tokens

# wwap each adjacent pair to get corrupted tokens
indices = [i + 1 if i % 2 == 0 else i - 1 for i in range(len(tokens))]
corrupted_tokens = clean_tokens[indices]

print(
    "Clean string 0:    ",
    model.to_string(clean_tokens[0]),
    "\nCorrupted string 0:",
    model.to_string(corrupted_tokens[0]),
)

clean_logits, clean_cache = model.run_with_cache(clean_tokens)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)

clean_logit_diff = logits_to_ave_logit_diff(clean_logits, answer_tokens)
print(f"Clean logit diff: {clean_logit_diff:.4f}")

corrupted_logit_diff = logits_to_ave_logit_diff(corrupted_logits, answer_tokens)
print(f"Corrupted logit diff: {corrupted_logit_diff:.4f}")

Clean string 0:     <|endoftext|>When John and Mary went to the shops,  John gave the bag to 
Corrupted string 0: <|endoftext|>When John and Mary went to the shops,  Mary gave the bag to
Clean logit diff: 2.8493
Corrupted logit diff: -2.8493


In [126]:
def ioi_metric(
    logits,
    answer_tokens = answer_tokens,
    corrupted_logit_diff: float = corrupted_logit_diff,
    clean_logit_diff: float = clean_logit_diff,
):
    patched_logit_diff = logits_to_ave_logit_diff(logits, answer_tokens)
    return (patched_logit_diff - corrupted_logit_diff) / (clean_logit_diff - corrupted_logit_diff)

torch.testing.assert_close(ioi_metric(clean_logits).item(), 1.0)
torch.testing.assert_close(ioi_metric(corrupted_logits).item(), 0.0)
torch.testing.assert_close(ioi_metric((clean_logits + corrupted_logits) / 2).item(), 0.5)

# 1. Residual stream patching

In [None]:
act_patch_resid_pre = patching.get_act_patch_resid_pre(
    model=model, corrupted_tokens=corrupted_tokens, clean_cache=clean_cache, patching_metric=ioi_metric
)

labels = [f'{tok} {i}' for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))]

imshow(
    act_patch_resid_pre,
    labels={'x': 'Position', 'y': 'Layer'},
    x=labels,
    title='resid_pre Activation Patching',
    width=700,
)

100%|██████████| 192/192 [00:11<00:00, 17.33it/s]


# 2. Patching in residual stream by block

In [129]:
act_patch_block_every = patching.get_act_patch_block_every(model, corrupted_tokens, clean_cache, ioi_metric)

imshow(
    act_patch_block_every,
    x=labels,
    facet_col=0,
    facet_labels=['Residual Stream', 'Attn Output', 'MLP Output'],
    title='Logit Difference From Patched Attn Head Output',
    labels={'x': 'Sequence Position', 'y': 'Layer'},
    width=1200,
)

100%|██████████| 192/192 [00:10<00:00, 18.70it/s]
100%|██████████| 192/192 [00:10<00:00, 18.27it/s]
100%|██████████| 192/192 [00:10<00:00, 18.87it/s]


# 3. Head patching

In [130]:
act_patch_attn_head_out_all_pos = patching.get_act_patch_attn_head_out_all_pos(
    model, corrupted_tokens, clean_cache, ioi_metric
)

imshow(
    act_patch_attn_head_out_all_pos,
    labels={'y': 'Layer', 'x': 'Head'},
    title='attn_head_out Activation Patching (All Pos)',
    width=600,
)

100%|██████████| 144/144 [00:08<00:00, 16.43it/s]


# 4. Decomposing heads

In [115]:
act_patch_attn_head_all_pos_every = patching.get_act_patch_attn_head_all_pos_every(
    model, corrupted_tokens, clean_cache, ioi_metric
)

imshow(
    act_patch_attn_head_all_pos_every,
    facet_col=0,
    facet_labels=['Output', 'Query', 'Key', 'Value', 'Pattern'],
    title='Activation Patching Per Head (All Pos)',
    labels={'x': 'Head', 'y': 'Layer'},
    width=1200,
)

  0%|          | 0/144 [00:00<?, ?it/s]

100%|██████████| 144/144 [00:07<00:00, 19.36it/s]
100%|██████████| 144/144 [00:08<00:00, 17.93it/s]
100%|██████████| 144/144 [00:08<00:00, 17.96it/s]
100%|██████████| 144/144 [00:07<00:00, 18.19it/s]
100%|██████████| 144/144 [00:07<00:00, 18.24it/s]


In [131]:
def topk_of_Nd_tensor(tensor, k):
    i = torch.topk(tensor.flatten(), k).indices
    return np.array(np.unravel_index(utils.to_numpy(i), tensor.shape)).T.tolist()


# get heads with largest value patching
# from plot above, these are the 4 heads in layers 7 and 8
k = 4
top_heads = topk_of_Nd_tensor(act_patch_attn_head_all_pos_every[3], k=k)

# get all their attention patterns
attn_patterns_for_important_heads = torch.stack([
    clean_cache['pattern', layer][:, head].mean(0)
        for layer, head in top_heads
])

# display results
display(HTML(f'<h2>Top {k} Logit Attribution Heads (from value-patching)</h2>'))
display(cv.attention.attention_patterns(
    attention = attn_patterns_for_important_heads,
    tokens = model.to_str_tokens(tokens[0]),
))

# Sources

1. [Ground truth - Arena::Activation Patching](https://arena-chapter1-transformer-interp.streamlit.app/[1.4.1]_Indirect_Object_Identification#keeping-track-of-your-guesses-predictions)
2. [Locating and Editing Factual Associations in GPT, by Meng, K, et. al.](https://arxiv.org/pdf/2202.05262)
3. [A Mathematical Framework for Transformer Circuits, by Chris Olah, Neel Nanda, et. al.](https://transformer-circuits.pub/2021/framework/index.html)
4. [Neel Nanda's walkthrough of Mathematica Framework](https://www.youtube.com/watch?v=KV5gbOmHbjU)