In [1]:
import warnings
warnings.filterwarnings('ignore')

import os
os.chdir('..')

In [None]:
import os
from rich.table import Table, Column

import einops
import torch
from transformer_lens import HookedTransformer
from transformer_lens.utils import test_prompt, rprint

from plotly_utils import imshow, line

In [3]:
os.chdir('..')
print('Changed working directory to parent directory')

with open(os.path.expanduser('~/.huggingface/token')) as f:
    os.environ['HF_TOKEN'] = f.read().strip()
    print(f'Hugging Face token loaded: {os.environ['HF_TOKEN'][:3]}...')

torch.set_grad_enabled(False)

if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Changed working directory to parent directory
Hugging Face token loaded: hf_...
Using device: mps


# 1. Model setup

##### Note on `center_unembed`

The logits are fed into a softmax. Softmax is translation invariant (e.g., adding 1 to every logit doesn't change the output), so it can be simplified by setting the mean of the logits to be zero. This is equivalent to setting the mean of every output vector of $W_U$ to zero.

##### Note on `center_writing_weights`

Every component reading an input from the residual stream is preceded by a LayerNorm, which means that the mean of a residual stream vector (i.e., the component in the direction of all ones) never matters. This means the all ones component of weights and biases whose output writes to the residual stream can be removed without changing the model's behavior.

##### Note on `fold_ln`

Layer Normalization, unlike  Batch Normalization, can't be turned off at inference time. From an interpretability perspective, naively ignoring it results in a close, but not exact, approximation to the original model. If LayerNorm is followed by a linear layer, it is computationally equivalent to reducing by centering & normalising, followed by a linear layer. Apart from dividing by the norm, these are all linear algebra operations, and from an interpretability perspective, if anything is linear, it can be ignored because it will break up into sums during backprogation, i.e., there is no need to track interference between terms.

##### Note on `center_writing_weights`

A related idea to folding layernorm - every component reading an input from the residual stream is preceded by a LayerNorm, which means that the mean of a residual stream vector (ie the component in the direction of all ones) never matters. This means we can remove the all ones component of weights and biases whose output writes to the residual stream.

##### Note on `refactor_factored_attn_matrices`

This argument means we redefine the matrices $W_Q$, $W_K$, $W_V$ and $W_O$ in the model, without changing the model's actual behaviour.

For example, we know that instead of working with $W_Q$ and $W_K$ individually, the only matrix we actually need to use in the model is the low-rank matrix $W_Q W_K^T$. So if we perform singular value decomposition $W_Q W_K^T = U S V^T$, then we see that we can just as easily define $W_Q = U \sqrt{S}$ and $W_K = V \sqrt{S}$ and use these instead. This means that $W_Q$ and $W_K$​	both have orthogonal columns with matching norms. This is arguably a more interpretable setup, because now there's no obvious asymmetry between the keys and queries.

In a similar way, since $W_OV = W_O W_V = U S V^T$, we can define $W_V = U S$ and $W_O = V^T$. This is arguably a more interpretable setup, because now $W_O$ is just a rotation, and doesn't change the norm, so 
$z$ has the same norm as the result of the head.

More details in [link](https://github.com/TransformerLensOrg/TransformerLens/blob/main/further_comments.md#what-is-layernorm-folding-fold_ln).

In [None]:
model = HookedTransformer.from_pretrained(
    'gpt2-small',
    center_unembed=True,
    fold_ln=True,
    center_writing_weights=True,
    refactor_factored_attn_matrices=True,
    device=device,
)

Loaded pretrained model gpt2-small into HookedTransformer


In [5]:
# shape: (n_blocks, n_heads, d_model, d_head)
# d_head = d_model // n_heads

model.W_Q.shape, model.W_K.shape

(torch.Size([12, 12, 768, 64]), torch.Size([12, 12, 768, 64]))

In [6]:
model.W_Q[0, 0].sum(0).shape

torch.Size([64])

In [7]:
# column norms are the same (except first few, for fiddly bias reasons)
line([model.W_Q[0, 0].pow(2).sum(0), model.W_K[0, 0].pow(2).sum(0)])

In [8]:
# columns are orthogonal
W_Q_dot_products = einops.einsum(model.W_Q[0, 0], model.W_Q[0, 0], 'd_model d_head_1, d_model d_head_2 -> d_head_1 d_head_2')
imshow(W_Q_dot_products)

### 1.1. Verify if the model performs as expected

##### 1.1.1. Run the model on single instance of a task

In [None]:
# prepend_bos adds a BOS (beginning of sequence) to the start of the prompt. 
# GPT-2 was not trained with this, but it makes the model more stable, as the first token is treated weirdly.
example_prompt = 'After John and Mary went to the store, John gave a bottle of milk to'
example_answer = ' Mary'
test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' went', ' to', ' the', ' store', ',', ' John', ' gave', ' a', ' bottle', ' of', ' milk', ' to']
Tokenized answer: [' Mary']


Top 0th token. Logit: 18.09 Prob: 70.07% Token: | Mary|
Top 1th token. Logit: 15.38 Prob:  4.67% Token: | the|
Top 2th token. Logit: 15.35 Prob:  4.54% Token: | John|
Top 3th token. Logit: 15.25 Prob:  4.11% Token: | them|
Top 4th token. Logit: 14.84 Prob:  2.73% Token: | his|
Top 5th token. Logit: 14.06 Prob:  1.24% Token: | her|
Top 6th token. Logit: 13.54 Prob:  0.74% Token: | a|
Top 7th token. Logit: 13.52 Prob:  0.73% Token: | their|
Top 8th token. Logit: 13.13 Prob:  0.49% Token: | Jesus|
Top 9th token. Logit: 12.97 Prob:  0.42% Token: | him|


##### 1.1.2. Run the model on multiple instances of a task

Each prompt is given twice - one with the first name as the indirect object, one with the second name. These prompts are composed only by single token names and the corresponding names are always in the same token positions.

The metric used to evaluate the model's performance will be the logit difference, i.e., the difference in logit between the indirect object's name and the subject's name.

Tokens are a massive headache when reverse engineering language models. Different inputs will have different numbers of tokens, the relevant tokens will be at different positions, the total number of tokens will vary, etc. Language models often devote significant amounts of parameters in early and later layers to convert  inputs from tokens to a more sensible internal format and back again. It is beneficial to avoid thinking about tokenization whenever possible when doing exploratory analysis, although it is relevant later when trying to perform rigorous analysis.

In [10]:
prompt_format = [
    'When John and Mary went to the shops, {} gave the bag to',
    'When Tom and James went to the park, {} gave the ball to',
    'When Dan and Sid went to the shops, {} gave an apple to',
    'After Martin and Amy went to the park, {} gave a drink to',
]

In [15]:
# adjacent prompts have answers swapped
name_pairs = [
    (' Mary', ' John'),
    (' Tom', ' James'),
    (' Dan', ' Sid'),
    (' Martin', ' Amy'),
]

prompts = [prompt.format(name) for (prompt, names) in zip(prompt_format, name_pairs) for name in names[::-1]]
prompts

['When John and Mary went to the shops,  John gave the bag to',
 'When John and Mary went to the shops,  Mary gave the bag to',
 'When Tom and James went to the park,  James gave the ball to',
 'When Tom and James went to the park,  Tom gave the ball to',
 'When Dan and Sid went to the shops,  Sid gave an apple to',
 'When Dan and Sid went to the shops,  Dan gave an apple to',
 'After Martin and Amy went to the park,  Amy gave a drink to',
 'After Martin and Amy went to the park,  Martin gave a drink to']

In [17]:
# answers for each prompt, in the form (correct, incorrect)
answers = [names[::i] for names in name_pairs for i in (1, -1)]
answers

[(' Mary', ' John'),
 (' John', ' Mary'),
 (' Tom', ' James'),
 (' James', ' Tom'),
 (' Dan', ' Sid'),
 (' Sid', ' Dan'),
 (' Martin', ' Amy'),
 (' Amy', ' Martin')]

In [20]:
answer_tokens = torch.concat([model.to_tokens(names, prepend_bos=False).T for names in answers])
answer_tokens

tensor([[ 5335,  1757],
        [ 1757,  5335],
        [ 4186,  3700],
        [ 3700,  4186],
        [ 6035, 15686],
        [15686,  6035],
        [ 5780, 14235],
        [14235,  5780]], device='mps:0')

In [None]:
cols = [
    'Prompt',
    Column('Correct', style='rgb(0,200,0) bold'),
    Column('Incorrect', style='rgb(255,0,0) bold'),
]
table = Table(*cols, title='Prompts & Answers:')

for prompt, answer in zip(prompts, answers):
    table.add_row(prompt, repr(answer[0]), repr(answer[1]))

rprint(table)

In [26]:
# get logits and cache of all internal activations for later analysis
tokens = model.to_tokens(prompts, prepend_bos=True)
tokens = tokens.to(device)
logits, cache = model.run_with_cache(tokens)

In [36]:
def logits_to_ave_logit_diff(
    logits,
    answer_tokens = answer_tokens,
    per_prompt = False,
):
    final_logits = logits[:, -1, :]
    answer_logits = final_logits.gather(dim=-1, index=answer_tokens)
    correct_logits, incorrect_logits = answer_logits.unbind(dim=-1)
    answer_logit_diff = correct_logits - incorrect_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()

ave_logit_diff = logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False)
print(f'Average logit difference: {ave_logit_diff:.3f}')
ave_logit_diff_per_prompt = logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=True)
print(f'Average logit difference per prompt: {ave_logit_diff_per_prompt.mean():.3f}')

Average logit difference: 2.710
Average logit difference per prompt: 2.710


In [38]:
cols = [
    "Prompt",
    Column("Correct", style="rgb(0,200,0) bold"),
    Column("Incorrect", style="rgb(255,0,0) bold"),
    Column("Logit Difference", style="bold"),
]
table = Table(*cols, title="Logit differences")

for prompt, answer, logit_diff in zip(prompts, answers, ave_logit_diff_per_prompt):
    table.add_row(prompt, repr(answer[0]), repr(answer[1]), f"{logit_diff.item():.3f}")

rprint(table)

# 2. Conclusion

Attention is really good at the primitive operations of looking nearby, or copying information. A simple model could figure out that at `to`, it should look for names and predict that those names came next (e.g., the skip trigram " John...to → John"). But it's much harder to tell how many of each previous names there are - attending to each copy of John will look exactly the same as attending to a single John token. So this will be pretty hard to figure out on the ` to` token.

The natural place to break this symmetry is on the second ` John` token - telling whether there is an earlier copy of the current token should be a much easier task. So it might be expected there to be a head which detects duplicate tokens on the second ` John` token, and then another head which moves that information from the second ` John` token to the ` to` token.

The model then needs to learn to predict ` Mary` and not ` John`. For that, it needs a head which attends to all previous names, but where the duplicate token features inhibit it from attending to specific names. So this only attends to Mary. And then the output of this head maps to the logits.

# Sources

1. [Ground truth - Arena::Indirect Object Identification](https://arena-chapter1-transformer-interp.streamlit.app/[1.4.1]_Indirect_Object_Identification)
2. [Interpretability in the wild: A circuit for indirect object identification in GPT-2 small, by Wang, K, et. al.](https://arxiv.org/pdf/2211.00593)
3. [NOTEBOOK - Exploratory Analysis Demo, by Neel Nanda](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Exploratory_Analysis_Demo.ipynb#scrollTo=WXktSe0CvBdh)
4. [An analogy for understanding transformers, by Callum McDougall](https://www.lesswrong.com/posts/euam65XjigaCJQkcN/an-analogy-for-understanding-transformers)
5. [A mathematical framework for transformer circuits, by Chris Olah, Neel Nanda, et. al.](https://transformer-circuits.pub/2021/framework/index.html)