In [1]:
import warnings
warnings.filterwarnings('ignore')

import os
os.chdir('..')

In [None]:
import os
from collections import OrderedDict

import circuitsvis as cv
import einops
import numpy as np
import torch
from IPython.display import HTML, display
from transformer_lens import HookedTransformer
from transformer_lens import utils
from transformers import GPT2LMHeadModel

from plotly_utils import line, imshow

In [3]:
os.chdir('..')
print('Changed working directory to parent directory')

with open(os.path.expanduser('~/.huggingface/token')) as f:
    os.environ['HF_TOKEN'] = f.read().strip()
    print(f'Hugging Face token loaded: {os.environ['HF_TOKEN'][:3]}...')

torch.set_grad_enabled(False)

if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Changed working directory to parent directory
Hugging Face token loaded: hf_...
Using device: mps


In [None]:
model = GPT2LMHeadModel.from_pretrained('gpt2')

checkpoint = torch.load('checkpoint_step.pth')
custom_state = checkpoint['model_state_dict']

fixed_state = OrderedDict()
for k, v in custom_state.items():
    new_key = k.replace('_orig_mod.', '')
    fixed_state[new_key] = v

for k, v in fixed_state.items():
    print(k, v.shape)


model.load_state_dict(fixed_state, strict=True)
print('Model loaded from local checkpoint')

Model loaded from local checkpoint


In [5]:
model = HookedTransformer.from_pretrained(
    'gpt2-small',
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
    device=device,
)

Loaded pretrained model gpt2-small into HookedTransformer


In [None]:
prompt_format = [
    'When John and Mary went to the shops, {} gave the bag to',
    'When Tom and James went to the park, {} gave the ball to',
    'When Dan and Sid went to the shops, {} gave an apple to',
    'After Martin and Amy went to the park, {} gave a drink to',
]

In [None]:
name_pairs = [
    (' Mary', ' John'),
    (' Tom', ' James'),
    (' Dan', ' Sid'),
    (' Martin', ' Amy'),
]

prompts = [prompt.format(name) for (prompt, names) in zip(prompt_format, name_pairs) for name in names[::-1]]
answers = [names[::i] for names in name_pairs for i in (1, -1)]
answer_tokens = torch.concat([model.to_tokens(names, prepend_bos=False).T for names in answers])

In [9]:
# get logits and cache of all internal activations for later analysis
tokens = model.to_tokens(prompts, prepend_bos=True)
tokens = tokens.to(device)
logits, cache = model.run_with_cache(tokens)

In [10]:
def logits_to_ave_logit_diff(
    logits,
    answer_tokens = answer_tokens,
    per_prompt = False,
):
    final_logits = logits[:, -1, :]
    answer_logits = final_logits.gather(dim=-1, index=answer_tokens)
    correct_logits, incorrect_logits = answer_logits.unbind(dim=-1)
    answer_logit_diff = correct_logits - incorrect_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()

ave_logit_diff = logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False)

# 1. Direct Logit Attribution

In [11]:
answer_tokens

tensor([[ 2940, 23165],
        [23165,  2940],
        [ 4543, 10490],
        [10490,  4543],
        [35754,  5181],
        [ 5181, 35754],
        [ 8114,  3905],
        [ 3905,  8114]], device='mps:0')

In [12]:
# map answer_tokens to logit diff direction
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)

correct_residual_directions, incorrect_residual_directions = answer_residual_directions.unbind(dim=1)
logit_diff_directions = correct_residual_directions - incorrect_residual_directions
print('Logit difference directions shape:', logit_diff_directions.shape)

Logit difference directions shape: torch.Size([8, 768])


In [None]:
final_residual_stream = cache['resid_post', -1]
print(f'Final residual stream shape: {final_residual_stream.shape}')

final_token_residual_stream = final_residual_stream[:, -1, :]
scaled_final_token_residual_stream = cache.apply_ln_to_stack(final_token_residual_stream, layer=-1, pos_slice=-1)

average_logit_diff = einops.einsum(
    scaled_final_token_residual_stream, logit_diff_directions, 'batch d_model, batch d_model ->'
) / len(prompts)

print(f'Calculated average logit diff: {average_logit_diff:.10f}')
print(f'Original logit difference:     {ave_logit_diff:.10f}')

torch.testing.assert_close(average_logit_diff, ave_logit_diff)

Final residual stream shape: torch.Size([8, 16, 768])
Calculated average logit diff: 2.8493238921
Original logit difference:     2.8493238921


# 2. Logit Lens

In [15]:
def residual_stack_to_logit_diff(
    residual_stack,
    cache,
    logit_diff_directions = logit_diff_directions,
):
    batch_size = residual_stack.size(-2)
    scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer=-1, pos_slice=-1)
    return (
        einops.einsum(scaled_residual_stack, logit_diff_directions, "... batch d_model, batch d_model -> ...")
        / batch_size
    )

In [16]:
accumulated_residual, labels = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1, return_labels=True)

logit_lens_logit_diffs= residual_stack_to_logit_diff(accumulated_residual, cache)

line(
    logit_lens_logit_diffs,
    hovermode='x unified',
    title='Logit Difference From Accumulated Residual Stream',
    labels={'x': 'Layer', 'y': 'Logit Diff'},
    xaxis_tickvals=labels,
    width=800,
)

# 3. Layer attribution

In [17]:
per_layer_residual, labels = cache.decompose_resid(layer=-1, pos_slice=-1, return_labels=True)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, cache)

line(
    per_layer_logit_diffs,
    hovermode='x unified',
    title='Logit Difference From Each Layer',
    labels={'x': 'Layer', 'y': 'Logit Diff'},
    xaxis_tickvals=labels,
    width=800,
)

# 4. Head attribution

In [18]:
per_head_residual, labels = cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_residual = einops.rearrange(per_head_residual, '(layer head) ... -> layer head ...', layer=model.cfg.n_layers)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)

imshow(
    per_head_logit_diffs,
    labels={'x': 'Head', 'y': 'Layer'},
    title='Logit Difference From Each Head',
    width=600,
)

Tried to stack head results when they weren't cached. Computing head results now


# Sources

1. [Ground truth - Arena::Logit attribution](https://arena-chapter1-transformer-interp.streamlit.app/[1.4.1]_Indirect_Object_Identification#keeping-track-of-your-guesses-predictions)
2. [Interpretability in the wild: A circuit for indirect object identification in GPT-2 small, by Wang, K, et. al.](https://arxiv.org/pdf/2211.00593)
3. [A mathematical framework for transformer circuits, by Chris Olah, Neel Nanda, et. al.](https://transformer-circuits.pub/2021/framework/index.html)