In [1]:
import warnings
warnings.filterwarnings('ignore')

import os
os.chdir('..')

In [2]:
import os

import circuitsvis as cv
import einops
import numpy as np
import torch
from IPython.display import HTML, display
from transformer_lens import HookedTransformer
from transformer_lens import utils

from plotly_utils import line, imshow

In [3]:
os.chdir('..')
print('Changed working directory to parent directory')

with open(os.path.expanduser('~/.huggingface/token')) as f:
    os.environ['HF_TOKEN'] = f.read().strip()
    print(f'Hugging Face token loaded: {os.environ['HF_TOKEN'][:3]}...')

torch.set_grad_enabled(False)

if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Changed working directory to parent directory
Hugging Face token loaded: hf_...
Using device: mps


In [4]:
model = HookedTransformer.from_pretrained(
    'gpt2-small',
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
    device=device,
)

Loaded pretrained model gpt2-small into HookedTransformer


In [5]:
prompt_format = [
    'When John and Mary went to the shops, {} gave the bag to',
    'When Tom and James went to the park, {} gave the ball to',
    'When Dan and Sid went to the shops, {} gave an apple to',
    'After Martin and Amy went to the park, {} gave a drink to',
]

In [6]:
name_pairs = [
    (' Mary', ' John'),
    (' Tom', ' James'),
    (' Dan', ' Sid'),
    (' Martin', ' Amy'),
]

prompts = [prompt.format(name) for (prompt, names) in zip(prompt_format, name_pairs) for name in names[::-1]]
answers = [names[::i] for names in name_pairs for i in (1, -1)]
answer_tokens = torch.concat([model.to_tokens(names, prepend_bos=False).T for names in answers])

In [7]:
# get logits and cache of all internal activations for later analysis
tokens = model.to_tokens(prompts, prepend_bos=True)
tokens = tokens.to(device)
logits, cache = model.run_with_cache(tokens)

In [8]:
def logits_to_ave_logit_diff(
    logits,
    answer_tokens = answer_tokens,
    per_prompt = False,
):
    final_logits = logits[:, -1, :]
    answer_logits = final_logits.gather(dim=-1, index=answer_tokens)
    correct_logits, incorrect_logits = answer_logits.unbind(dim=-1)
    answer_logit_diff = correct_logits - incorrect_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()

ave_logit_diff = logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False)

# 1. Direct Logit Attribution

The easiest part of the model to understand is the output - this is what the model is trained to optimize, and so it can always be directly interpreted. Often the right approach to reverse engineering a circuit is to start at the end, understand how the model produces the right answer, and to then work backwards. The main technique used to do this is called `direct logit` attribution.

### 1.1. Background and motivation of the logit difference

The central object of a transformer is the residual stream. This is the sum of the outputs of each layer and of the original token and positional embeddings. Importantly, this means that any linear function of the residual stream can be perfectly decomposed into the contribution of each layer of the transformer. Further, each attention layer's output can be broken down into the sum of the output of each head (See `A Mathematical Framework for Transformer Circuits` for details), and each MLP layer's output can be broken down into the sum of the output of each neuron (and a bias term for each layer).

In general, there are two natural ways to interpret the model's outputs: the output logits or the output log probabilities. Let $\vec{x}$ be the logits, $\vec{L}$ be the log probabilities, and $\vec{p}$ be the probabilities. Then we have the following relations:

$$
p_i = \mathrm{softmax}(\vec{x})_i = \frac{e^{x_i}}{\sum_{i=1}^{n} e^{x_i}}
$$

and

$$
L_i = \log(p_i)
$$

Combining these, we get:

$$
L_i = \log \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}} = x_i - \log \sum_{j=1}^{n} e^{x_j}
$$

The sum term on the right is the same for all $i$, so we get:

$$
L_i - L_j = x_i - x_j
$$

In other words, the logit difference $x_i - x_j$ is the same as the log probability difference $L_i - L_j$, motivating the use of logit differences to understand the model's outputs.

### 1.2. Logit diff directions

**Getting an output logit is equivalent to projecting onto a direction in the residual stream, and the same is true for getting the logit diff.**

Suppose the final value in the residual stream for a single sequence and a position within that sequence is $\vec{x}$ (i.e., a vector of length $d_{\text{model}}$). Then, we get logits by multiplying by the unembedding matrix $W_U$ (which has shape($d_{\text{model}}$, $d_{\text{vocab}}$)):

$$
\text{output} = \vec{x}^T W_U
$$

Now, the logit difference between two tokens $i$ and $j$ is given by:

$$
\text{logit diff}_{ij} = \vec{x}^T W_U[:, i] - \vec{x}^T W_U[:, j] = \vec{x}^T (W_U[:, i] - W_U[:, j])
$$

This means that the logit difference is given by the projection of the residual stream onto the vector $W_U[:, i] - W_U[:, j]$. This vector is called the **logit diff direction**, because it points in the direction of the largest logit difference between the two tokens.

In [9]:
answer_tokens

tensor([[ 5335,  1757],
        [ 1757,  5335],
        [ 4186,  3700],
        [ 3700,  4186],
        [ 6035, 15686],
        [15686,  6035],
        [ 5780, 14235],
        [14235,  5780]], device='mps:0')

In [10]:
# map answer_tokens to logit diff direction
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)

correct_residual_directions, incorrect_residual_directions = answer_residual_directions.unbind(dim=1)
logit_diff_directions = correct_residual_directions - incorrect_residual_directions
print('Logit difference directions shape:', logit_diff_directions.shape)

Logit difference directions shape: torch.Size([8, 768])


In [11]:
final_residual_stream = cache['resid_post', -1]
print(f'Final residual stream shape: {final_residual_stream.shape}')

final_token_residual_stream = final_residual_stream[:, -1, :]
scaled_final_token_residual_stream = cache.apply_ln_to_stack(final_token_residual_stream, layer=-1, pos_slice=-1)

average_logit_diff = einops.einsum(
    scaled_final_token_residual_stream, logit_diff_directions, 'batch d_model, batch d_model ->'
) / len(prompts)

print(f'Calculated average logit diff: {average_logit_diff:.10f}')
print(f'Original logit difference:     {ave_logit_diff:.10f}')

torch.testing.assert_close(average_logit_diff, ave_logit_diff)

Final residual stream shape: torch.Size([8, 16, 768])
Calculated average logit diff: 2.7098135948
Original logit difference:     2.7098159790


# 2. Logit Lens

This technique looks at the residual stream after each layer and calculates the logit difference from that. This simulates what happens if we delete all subsequent layers.

From the plot below, it can be seen that the model is unable to do the task until layer 7, almost all performance coming from attention layer 9, and performance actually decreases from there. This tells us that there must be something going on (primarily in layers 7, 8 and 9) which writes to the residual stream in the correct way to solve the IOI task.

In [12]:
def residual_stack_to_logit_diff(
    residual_stack,
    cache,
    logit_diff_directions = logit_diff_directions,
):
    batch_size = residual_stack.size(-2)
    scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer=-1, pos_slice=-1)
    return (
        einops.einsum(scaled_residual_stack, logit_diff_directions, "... batch d_model, batch d_model -> ...")
        / batch_size
    )

In [13]:
accumulated_residual, labels = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1, return_labels=True)

logit_lens_logit_diffs= residual_stack_to_logit_diff(accumulated_residual, cache)

line(
    logit_lens_logit_diffs,
    hovermode='x unified',
    title='Logit Difference From Accumulated Residual Stream',
    labels={'x': 'Layer', 'y': 'Logit Diff'},
    xaxis_tickvals=labels,
    width=800,
)

# 3. Layer attribution

The analysis above can be extended to look at the contribution of each layer to the final output, equivalent to the differences between adjacent residual streams.

We see that only attention layers matter, which makes sense. The IOI task is about moving information around (i.e. moving the correct name and not the incorrect name), and less about processing it. And again we note that attention layer 9 improves things a lot, while attention 10 and attention 11 decrease performance.

Note that a layer here is the `kth` layer in a stack of transformer blocks, but each block consists of an attention layer and an MLP layer.

In [14]:
per_layer_residual, labels = cache.decompose_resid(layer=-1, pos_slice=-1, return_labels=True)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, cache)

line(
    per_layer_logit_diffs,
    hovermode='x unified',
    title='Logit Difference From Each Layer',
    labels={'x': 'Layer', 'y': 'Logit Diff'},
    xaxis_tickvals=labels,
    width=800,
)

# 4. Head attribution

We can further break down the output of each attention layer into the sum of the outputs of each attention head. Each attention layer consists of 12 heads, which each act independently and additively. The standard way to compute the output of an attention layer is by concatenating the mixed values of each head, and multiplying by a big output weight matrix. But, as described in `A Mathematical Framework`, this is equivalent to splitting the output weight matrix into a per-head output and adding them up (including an overall bias term for the entire layer).

Below we see that only a few heads really matter - heads 9.6 and 9.9 contribute a lot positively, explaining why attention layer 9 is so important, while heads 10.7 and 11.10 contribute a lot negatively, explaining why attention layer 10 and layer 11 are actively harmful. These correspond to some of the name movers and negative name movers discussed in the `Mathematical Framework` paper. There are also several heads that matter positively or negatively but less strongly (other name movers and backup name movers).

These observations supports the claim in `A Mathematical Framework` paper that **attention heads are the right level of abstraction to understand attention**.

In [15]:
per_head_residual, labels = cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_residual = einops.rearrange(per_head_residual, '(layer head) ... -> layer head ...', layer=model.cfg.n_layers)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)

imshow(
    per_head_logit_diffs,
    labels={'x': 'Head', 'y': 'Layer'},
    title='Logit Difference From Each Head',
    width=600,
)

Tried to stack head results when they weren't cached. Computing head results now


# 5. Attention analysis

Attention heads are particularly fruitful to interpret, because we can look directly at their attention patterns and study from what positions they move information from and to. This is particularly useful here as we're looking at the direct effect on the logits so we need only look at the attention patterns from the final token.

We use the `circuitsvis` library (developed from Anthropic's `PySvelte` library) to visualize the attention patterns. We visualize the top 3 positive and negative heads by direct logit attribution, and show these for the first prompt as an illustration.

A common mistake to make when looking at attention patterns is thinking that they must convey information about the token looked at (maybe accounting for the context of the token). But actually, all we can confidently say is that it moves information from the residual stream position corresponding to that input token. Especially later on in the model, there may be components in the residual stream that have nothing to do with the input token, e.g., the period at the end of a sentence may contain summary information for that sentence, and the head may solely move that, rather than caring about whether it ends in ".", "!" or "?".

In [18]:
def topk_of_Nd_tensor(tensor, k: int):
    i = torch.topk(tensor.flatten(), k).indices
    return np.array(np.unravel_index(utils.to_numpy(i), tensor.shape)).T.tolist()


k = 3

for head_type in ['Positive', 'Negative']:
    top_heads = topk_of_Nd_tensor(per_head_logit_diffs * (1 if head_type == 'Positive' else -1), k)

    attn_patterns_for_important_heads = torch.stack(
        [cache['pattern', layer][:, head][0] for layer, head in top_heads]
    )

    display(HTML(f'<h2>Top {k} {head_type} Logit Attribution Heads</h2>'))
    display(
        cv.attention.attention_patterns(
            attention=attn_patterns_for_important_heads,
            tokens=model.to_str_tokens(tokens[0]),
        )
    )

In [19]:
def topk_of_Nd_tensor(tensor, k: int):
    i = torch.topk(tensor.flatten(), k).indices
    return np.array(np.unravel_index(utils.to_numpy(i), tensor.shape)).T.tolist()


k = 3

for head_type in ['Positive', 'Negative']:
    top_heads = topk_of_Nd_tensor(per_head_logit_diffs * (1 if head_type == 'Positive' else -1), k)

    attn_patterns_for_important_heads = torch.stack(
        [cache['pattern', layer][:, head][0] for layer, head in top_heads]
    )

    display(HTML(f'<h2>Top {k} {head_type} Logit Attribution Heads</h2>'))
    display(
        cv.attention.attention_heads(
            attention=attn_patterns_for_important_heads,
            tokens=model.to_str_tokens(tokens[0]),
            attention_head_names=[f"{layer}.{head}" for layer, head in top_heads],
        )
    )

# Sources

1. [Ground truth - Arena::Logit attribution](https://arena-chapter1-transformer-interp.streamlit.app/[1.4.1]_Indirect_Object_Identification#keeping-track-of-your-guesses-predictions)
2. [Interpretability in the wild: A circuit for indirect object identification in GPT-2 small, by Wang, K, et. al.](https://arxiv.org/pdf/2211.00593)
3. [A mathematical framework for transformer circuits, by Chris Olah, Neel Nanda, et. al.](https://transformer-circuits.pub/2021/framework/index.html)