In [1]:
import warnings
warnings.filterwarnings('ignore')

import os
os.chdir('..')

In [2]:
import os

import einops
import torch
from transformer_lens import HookedTransformer

In [3]:
os.chdir('..')
print('Changed working directory to parent directory')

with open(os.path.expanduser('~/.huggingface/token')) as f:
    os.environ['HF_TOKEN'] = f.read().strip()
    print(f'Hugging Face token loaded: {os.environ['HF_TOKEN'][:3]}...')

torch.set_grad_enabled(False)

if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Changed working directory to parent directory
Hugging Face token loaded: hf_...
Using device: mps


In [4]:
model = HookedTransformer.from_pretrained(
    'gpt2-small',
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
    device=device,
)

Loaded pretrained model gpt2-small into HookedTransformer


In [5]:
prompt_format = [
    'When John and Mary went to the shops, {} gave the bag to',
    'When Tom and James went to the park, {} gave the ball to',
    'When Dan and Sid went to the shops, {} gave an apple to',
    'After Martin and Amy went to the park, {} gave a drink to',
]

In [6]:
name_pairs = [
    (' Mary', ' John'),
    (' Tom', ' James'),
    (' Dan', ' Sid'),
    (' Martin', ' Amy'),
]

prompts = [prompt.format(name) for (prompt, names) in zip(prompt_format, name_pairs) for name in names[::-1]]
answers = [names[::i] for names in name_pairs for i in (1, -1)]
answer_tokens = torch.concat([model.to_tokens(names, prepend_bos=False).T for names in answers])

In [7]:
# get logits and cache of all internal activations for later analysis
tokens = model.to_tokens(prompts, prepend_bos=True)
tokens = tokens.to(device)
logits, cache = model.run_with_cache(tokens)

In [8]:
def logits_to_ave_logit_diff(
    logits,
    answer_tokens = answer_tokens,
    per_prompt = False,
):
    final_logits = logits[:, -1, :]
    answer_logits = final_logits.gather(dim=-1, index=answer_tokens)
    correct_logits, incorrect_logits = answer_logits.unbind(dim=-1)
    answer_logit_diff = correct_logits - incorrect_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()

ave_logit_diff = logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False)

# 1. Direct Logit Attribution

The easiest part of the model to understand is the output - this is what the model is trained to optimize, and so it can always be directly interpreted. Often the right approach to reverse engineering a circuit is to start at the end, understand how the model produces the right answer, and to then work backwards. The main technique used to do this is called `direct logit` attribution.

### 1.1. Background and motivation of the logit difference

The central object of a transformer is the residual stream. This is the sum of the outputs of each layer and of the original token and positional embedding. Importantly, this means that any linear function of the residual stream can be perfectly decomposed into the contribution of each layer of the transformer. Further, each attention layer's output can be broken down into the sum of the output of each head (See `A Mathematical Framework for Transformer Circuits` for details), and each MLP layer's output can be broken down into the sum of the output of each neuron (and a bias term for each layer).

In general, there are two natural ways to interpret the model's outputs: the output logits or the output log probabilities. Let $\vec{x}$ be the logits, $\vec{L}$ be the log probabilities, and $\vec{p}$ be the probabilities. Then we have the following relations:

$$
p_i = \mathrm{softmax}(\vec{x})_i = \frac{e^{x_i}}{\sum_{i=1}^{n} e^{x_i}}
$$

and

$$
L_i = \log(p_i)
$$

Combining these, we get:

$$
L_i = \log \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}} = x_i - \log \sum_{j=1}^{n} e^{x_j}
$$

The sum term on the right is the same for all $i$, so we get:

$$
L_i - L_j = x_i - x_j
$$

In other words, the logit difference $x_i - x_j$ is the same as the log probability difference $L_i - L_j$, motivating the use of logit differences to understand the model's outputs.

### 1.2. Logit diff directions

**Getting an output logit is equivalent to projecting onto a direction in the residual stream, and the same is true for getting the logit diff.**

Suppose the final value in the residual stream for a single sequence and a position within that sequence is $x$ (i.e., $x$ is a vector of length $d_{\text{model}}$). Then, we get logits by multiplying by the unembedding matrix $W_U$ (which has shape($d_{\text{model}}$, $d_{\text{vocab}}$)):

$$
\text{output} = x^T W_U
$$

Now, the logit difference between two tokens $i$ and $j$ is given by:

$$
\text{logit diff}_{ij} = x^T W_U[:, i] - x^T W_U[:, j] = x^T (W_U[:, i] - W_U[:, j])
$$

This means that the logit difference is given by the projection of the residual stream onto the vector $W_U[:, i] - W_U[:, j]$. This vector is called the **logit diff direction**, because it points in the direction of the largest logit difference between the two tokens.

In [None]:
answer_tokens

tensor([[ 5335,  1757],
        [ 1757,  5335],
        [ 4186,  3700],
        [ 3700,  4186],
        [ 6035, 15686],
        [15686,  6035],
        [ 5780, 14235],
        [14235,  5780]], device='mps:0')

In [None]:
# map answer_tokens to logit diff direction
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)

correct_residual_directions, incorrect_residual_directions = answer_residual_directions.unbind(dim=1)
logit_diff_directions = correct_residual_directions - incorrect_residual_directions
print('Logit difference directions shape:', logit_diff_directions.shape)

Logit difference directions shape: torch.Size([8, 768])


In [10]:
# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].
final_residual_stream = cache["resid_post", -1]
print(f"Final residual stream shape: {final_residual_stream.shape}")
final_token_residual_stream = final_residual_stream[:, -1, :]

# Apply LayerNorm scaling (to just the final sequence position)
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = cache.apply_ln_to_stack(final_token_residual_stream, layer=-1, pos_slice=-1)

average_logit_diff = einops.einsum(
    scaled_final_token_residual_stream, logit_diff_directions, "batch d_model, batch d_model ->"
) / len(prompts)

print(f"Calculated average logit diff: {average_logit_diff:.10f}")
print(f"Original logit difference:     {ave_logit_diff:.10f}")

torch.testing.assert_close(average_logit_diff, ave_logit_diff)

Final residual stream shape: torch.Size([8, 16, 768])
Calculated average logit diff: 2.7098135948
Original logit difference:     2.7098159790


# Sources

1. [Ground truth - Arena::Logit attribution](https://arena-chapter1-transformer-interp.streamlit.app/[1.4.1]_Indirect_Object_Identification#keeping-track-of-your-guesses-predictions)
2. [A mathematical framework for transformer circuits, by Chris Olah, Neel Nanda, et. al.](https://transformer-circuits.pub/2021/framework/index.html)