# TransformerLens: Exploratory Analysis Demo

### Resources

* Working through the tutorial: [Exploratory Analysis Demo](https://transformerlensorg.github.io/TransformerLens/generated/demos/Exploratory_Analysis_Demo.html)
* Based on the demo notebook: [Exploratory_Analysis_Demo.ipynb](https://github.com/TransformerLensOrg/TransformerLens/blob/main/demos/Exploratory_Analysis_Demo.ipynb)
* [Mechanistic interpretability terms glossary](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=_Jzi6YHRHKP1JziwdE02qdYZ)
* [TransformerLens supported model list](https://github.com/TransformerLensOrg/TransformerLens/blob/a634e5757cd2ae43add9beb492be7a2dfb443a30/transformer_lens/loading_from_pretrained.py#L49)

### Setup steps

In [63]:
IN_COLAB = False

from IPython import get_ipython
ip = get_ipython()
if not ip.extension_manager.loaded:
    ip.extension_manager.load('autoreload')
    %autoreload 2

### Imports

In [64]:

from plotting_helpers import imshow, plot_line, plot_scatter

from functools import partial
from typing import List, Optional, Union

import einops
import random
import numpy as np
import plotly.express as px
import plotly.io as pio
import torch
from circuitsvis.attention import attention_heads
from fancy_einsum import einsum
from IPython.display import HTML, IFrame
from jaxtyping import Float

import transformer_lens.utils as utils
from transformer_lens import ActivationCache, HookedTransformer

# Configure torch and determinism
torch.set_grad_enabled(False)
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
torch.use_deterministic_algorithms(True)

### Task: Indirect Object Identification

* **Indirect Object Identification**: identifying the object of a sentence to generate the next token (e.g. generate "Mary" given the sentence "After John and Mary went to the shops, John gave a bottle of milk to").
* *Hypothesis:* to perform this task, a network must circuits that can:
    1. Recognise names.
    2. Identify which name is already duplicated.
    3. Predict the name that has not been duplicated.

#### Model setup

In [65]:
# Input model & hyperparameters
model_name = "gpt2-small"

# Get current device type
device = utils.get_device()

# Flags that preserve the model's output but simplify its internals.
hyperparams = {
    "device": device,
    "center_unembed": True,
    "center_writing_weights": True,
    "fold_ln": True,
    "refactor_factored_attn_matrices": True
}

# Instantiate model
model = HookedTransformer.from_pretrained(
    model_name,
    **hyperparams
)

Loaded pretrained model gpt2-small into HookedTransformer


In [66]:
# Test model can perform Indirect Object Identification task
test_prompt = "After John and Mary went to the shops, John gave a bottle of milk to"
test_completion = " Mary"

utils.test_prompt(test_prompt, test_completion, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' went', ' to', ' the', ' shops', ',', ' John', ' gave', ' a', ' bottle', ' of', ' milk', ' to']
Tokenized answer: [' Mary']


Top 0th token. Logit: 18.19 Prob: 72.20% Token: | Mary|
Top 1th token. Logit: 15.43 Prob:  4.57% Token: | John|
Top 2th token. Logit: 15.31 Prob:  4.04% Token: | the|
Top 3th token. Logit: 15.16 Prob:  3.49% Token: | them|
Top 4th token. Logit: 14.89 Prob:  2.65% Token: | his|
Top 5th token. Logit: 13.74 Prob:  0.84% Token: | her|
Top 6th token. Logit: 13.56 Prob:  0.70% Token: | a|
Top 7th token. Logit: 13.52 Prob:  0.67% Token: | their|
Top 8th token. Logit: 13.06 Prob:  0.42% Token: | Mrs|
Top 9th token. Logit: 13.04 Prob:  0.42% Token: | Jesus|


In [67]:
def format_prompt_for_analysis(prompt, correct_completion, incorrect_completion):
    """
    Formats a prompt by inserting the incorrect completion choice,
    along with the token ids for the correct and incorrect completions.

    Args:
        prompt (str): The prompt string with a '{}' placeholder for completion.
        correct_completion (str): The correct completion string (e.g., " Mary").
        incorrect_completion (str): The incorrect completion string (e.g., " John").

    Returns:
        formatted_prompt (str): The prompt with the incorrect completion inserted.
        completions_tok (tuple): Tuple of token ids for (correct_completion, incorrect_completion).
    """

    # Tokenise given correct and incorrect completions to prompt
    correct_completion_tok = model.to_single_token(correct_completion)
    incorrect_completion_tok = model.to_single_token(incorrect_completion)

    completions_tok = (correct_completion_tok, incorrect_completion_tok)

    # Add one of the completions (correct/incorrect) to the prompt
    # This allows us to inspect logits over the vocabulary when the completion is generated
    formatted_prompt = prompt.format(incorrect_completion)

    return formatted_prompt, completions_tok

In [68]:
# Define metric "logit difference" to score correct vs incorrect completion
def compute_logit_difference(generation_logits, completion_choice_tokens, per_prompt=False):
    """
    Logit Difference metric for Indirect Object Identification task.
    Scores the difference in logit between the indirect object's name and the subject's name.

    Args:
        generation_logits (Tensor): Logits output from the model of shape (batch, seq_len, vocab_size).
        completion_choice_tokens (Tensor): Tensor of shape (batch, 2) or (2,) with token ids for correct and incorrect completions.
        per_prompt (bool): If True, return logit difference per prompt; if False, return mean over batch.

    Returns:
        logit_differences (Tensor): Logit difference(s) between correct and incorrect completions. Shape is (batch,) if per_prompt is True, else float.
    """

    # Get logits corresponding to the last token (i.e. our completion)
    logits_last_token = generation_logits[:, -1, :]

    # Get logits corresponding to the correct and incorrect completion tokens
    if generation_logits.shape[0] > 1:
        # Compute over batch of prompts/completions
        logits_completion_choices = logits_last_token.gather(dim=-1, index=completion_choice_tokens)
        logit_differences = logits_completion_choices[:, 0] - logits_completion_choices[:, 1]
    else:
        # Compute on single prompt and completion choice pair
        completion_choice_tokens = completion_choice_tokens[0]
        logit_differences = logits_last_token[0, completion_choice_tokens[0]] - logits_last_token[0, completion_choice_tokens[1]]

    # Average results over batch
    if not per_prompt and generation_logits.shape[0] > 1:
        logit_differences = logit_differences.mean()

    return logit_differences

In [69]:
# Multiple prompts for fuller analysis
# prompts_to_format = [
#     "When John and Mary went to the shops,{} gave the bag to",
#     "When Tom and James went to the park,{} gave the ball to",
#     "When Dan and Sid went to the shops,{} gave an apple to",
#     "After Martin and Amy went to the park,{} gave a drink to",
# ]

# completion_choices = [
#     (" Mary", " John"),
#     (" Tom", " James"),
#     (" Dan", " Sid"),
#     (" Martin", " Amy"),
# ]

prompts_to_format = [
    "When John and Mary went to the shops,{} gave the bag to",
]

completion_choices = [
    (" Mary", " John"),
]

# List of prompts and completions, in the format (correct_token, incorrect_token)
prompts_strs = []
completions_toks = []

# Create list of the token ids corresponding to each answer
for prompt, completion in zip(prompts_to_format, completion_choices):
    formatted_prompt, completion_toks = format_prompt_for_analysis(
        prompt,
        completion[0],
        completion[1],
    )

    prompts_strs.append(formatted_prompt)
    completions_toks.append(completion_toks)
    print(formatted_prompt, completion_toks)

    # formatted_prompt, completion_toks = format_prompt_for_analysis(
    #     prompt,
    #     completion[1],
    #     completion[0],
    # )

    # prompts_strs.append(formatted_prompt)
    # completions_toks.append(completion_toks)
    # print(formatted_prompt, completion_toks)

prompts_toks = model.to_tokens(prompts_strs, prepend_bos=True).to(device)
completions_toks = torch.tensor(completions_toks).to(device)

When John and Mary went to the shops, John gave the bag to (5335, 1757)


In [70]:
# Generate completion and cache activation values
model_output_logits, activation_cache = model.run_with_cache(prompts_toks)

print(model_output_logits.shape)

torch.Size([1, 15, 50257])


In [71]:
logit_differences = compute_logit_difference(
    model_output_logits,
    completions_toks,
    per_prompt=True
)

if len(logit_differences.shape) == 0:
    print(f"Per prompt logit difference: {logit_differences:.3f}")
else:
    print(f"Per prompt logit difference: {logit_differences}")
    mean_logit_difference = logit_differences.mean()
    print(f"Average logit difference: {mean_logit_difference:.3f}")

Per prompt logit difference: 3.337


Result:
* Average logit difference is $3.5$.
* This represents putting a $e^{3.5}$ or approx. $33 \times$ higher probability on the correct answer.
  
Hypothesis:
* We might expect there to be a head which detects duplicate tokens on the second " John" token, and then another head which moves that information from the second " John" token to the " to" token.
* The model then needs to learn to predict " Mary" and not " John". There are potentially two ways to do this:
  * Detect all preceding names and move this information to " to" point. Then delete any names corresponding to the duplicate token feature, e.g. by an MLP layer that deletes the " John" direction of the residual stream.
  * Have a head which attends to all previous names, but where the duplicate token features inhibit it from attending to specific names.

A test that could distinguish these two is to look at which components of the model add directly to the logits. If it’s mostly attention heads that attend to " Mary" and not to " John" it’s probably hypothesis 2, if it’s mostly MLPs it’s probably hypothesis 1. We should be able to identify duplicate token heads by finding ones which attend from " John" to " John", and whose outputs are then moved to the " to" token by V-Composition with another head. 

(Note: the above reasoning is simplistic and could easily break in a larger model)

In [72]:
def plot_top_predicted_tokens(generation_logits: torch.Tensor, k: int = 10) -> None:
    """
    Plot the logits of the top k predicted tokens from a model generation.

    Args:
        generation_logits (Tensor): Logits output from the model of shape (batch, seq_len, vocab_size).
        k (int): Number of tokens from vocab to plot.
    """

    # Get logits corresponding to the last token (i.e. our completion)
    logits_last_token = generation_logits[0, -1, :]

    logits_by_token = list(enumerate(logits_last_token.tolist()))
    logits_by_token = sorted(
        logits_by_token,
        key=lambda x: x[1],
        reverse=True
    )

    logit_values = np.array([logit for tok, logit in logits_by_token])
    # logit_values = (logit_values - min(logit_values))
    # logit_values = logit_values / max(logit_values)

    highest_predictions = logit_values[:k]
    highest_predictions = np.expand_dims(highest_predictions, axis=0)

    token_labels = [model.to_single_str_token(tok) for tok, logit in logits_by_token[:k]]

    px.imshow(
        highest_predictions,
        x=token_labels,
        y=["Logit value  "],
        title="Logit values for final token of generated sequence",
        labels=dict(x="Token from vocabulary"),
        color_continuous_midpoint=0.0,
        color_continuous_scale="turbo",
    ).show()


plot_top_predicted_tokens(model_output_logits, k=25)

### Direct Logit Attribution

* **Direct logit attribution**: An approach to reverse engineering a circuit by understanding how a given model produces its answer (i.e. start at the end) and to then working backwards through the network.
* **Residual stream**: The 'central object' of the transformer network architecture to which many layers read from and write to. It is the sum of the outputs of each layer and the original token / positional embedding.
* As it is composed of a sum, the residual stream can be decomposed back into the contribution of each attention layer in a transformer model. Additionally, the output of each attention layer can be decomposed into the sum of outputs of each attention head, and each MLP layer into the sum of outputs of each neuron (plus bias).
* The output logits are given by: end of residual stream $\to$ normalisation function (LayerNorm) $\to$ unembedding layer (Unembed).
* As these are (approximately) linear maps, we can work back from the logits to find the network components that contribute most to a given output; this is **direct logit attribution**.
* As the model is trained to optimise the cross-entropy loss, the exact values of the logits is not important, instead their relative values to each other (hence the use of the 'logit difference' metric).

#### The logit lens

Attribution of logits throughout layers in the model. 
The logit lens looks at the residual stream after each layer and calculates the logit difference from that. This simulates what happens if we delete all subsequence layers.


In [73]:
def residual_stream_to_logit_diff(
    residual_stream: torch.Tensor,
    activation_cache: ActivationCache,
    logit_difference_directions: torch.Tensor,
) -> float:
    """
    Computes the mean logit difference for a given residual stream and logit difference directions.

    Args:
        residual_stream (torch.Tensor): The residual stream tensor from the transformer model, typically of shape (batch, seq_len, d_model) or (batch, d_model).
        activation_cache (ActivationCache): The activation cache containing model activations and utility functions, used here for applying LayerNorm.
        logit_difference_directions (torch.Tensor): The directions in residual space corresponding to the difference between correct and incorrect completion tokens, shape (batch, d_model) or (d_model,).

    Returns:
        mean_logit_differences (float): The mean logit difference after projecting the normalized residual stream onto the logit difference directions.
    """

    # Apply the "LayerNorm" function (as model would)
    scaled_residual_stream = activation_cache.apply_ln_to_stack(
        residual_stream,
        layer=-1,
        pos_slice=-1,  # subset of the positions we use (here the final token)
    )

    # Compute the logit difference
    mean_logit_differences = einsum(
        "... batch d_model, batch d_model -> ...",
        scaled_residual_stream,
        logit_difference_directions,
    )
    mean_logit_differences = mean_logit_differences / residual_stream.shape[0]

    return mean_logit_differences

In [74]:
# Accessing the network's residual stream
# "resid_post" is the residual stream at the end of the layer, -1 gets the final layer.
# The general syntax is [activation_name, layer_index, sub_layer_type].
residual_stream_final_layer = activation_cache["resid_post", -1]
print(f"Final layer residual stream shape: {residual_stream_final_layer.shape}")

residual_stream_final_token = residual_stream_final_layer[:, -1, :]

# Getting an output logit is equivalent to projecting onto a direction in the residual stream
residual_stream_projection = model.tokens_to_residual_directions(completions_toks)
logit_difference_directions = residual_stream_projection[:, 0] - residual_stream_projection[:, 1]
print(f"Logit difference directions shape: {logit_difference_directions.shape}")

mean_logit_difference = residual_stream_to_logit_diff(
    residual_stream_final_token,
    activation_cache,
    logit_difference_directions,
)

print(f"Logit difference: {mean_logit_difference:.3f}")

Final layer residual stream shape: torch.Size([1, 15, 768])
Logit difference directions shape: torch.Size([1, 768])
Logit difference: 4.633


In [80]:
accumulated_residual_stream, component_labels = activation_cache.accumulated_resid(
    layer=-1,       # layer from which we gather the residual stream
    pos_slice=-1,   # subset of the positions we use (here the final token)
    incl_mid=True,  # use the residual stream in the middle of a layer (i.e. after attention & before MLP)
    return_labels=True,
)

logit_differences = residual_stream_to_logit_diff(
    accumulated_residual_stream,
    activation_cache,
    logit_difference_directions
)

logit_differences = utils.to_numpy(logit_differences)
x_axis_range = np.arange(model.cfg.n_layers * 2 + 1) / 2
node_labels = [f'Stream component: "{l}"' for l in component_labels]

print(f"Model layers (x values): {x_axis_range}")
print(f"Logit differences (y values): {logit_differences}")

px.line(
    y=logit_differences,
    x=x_axis_range,
    hover_name=node_labels,
    title="Logit difference of layer components from residual stream",
    labels=dict(x="Layer", y="Logit difference"),
).show()

Model layers (x values): [ 0.   0.5  1.   1.5  2.   2.5  3.   3.5  4.   4.5  5.   5.5  6.   6.5
  7.   7.5  8.   8.5  9.   9.5 10.  10.5 11.  11.5 12. ]
Logit differences (y values): [-0.00043612  0.00040156  0.00036878  0.00084074  0.00061268  0.00255602
  0.00117801  0.00256532  0.00267829 -0.00029886  0.00306021  0.00597805
  0.00067935 -0.00122963  0.00075706  0.01887021  0.01560044  0.0410631
  0.05046535  0.2078262   0.20520501  0.19907951  0.22312     0.17324272
  0.18531337]


Result:

* Logit differences are much greater in later layers of the network.

Explanation:

* The model is unable to complete the task in earlier layers, so layer logits have a mostly uniform and low value.
* The peak logit difference at layer 9 means that most of the model's generation performance on this task comes from (the logits of) layer 9.
* Additionally, the performance jump comes from residual component "9_mid". Here "mid" means we have applied the attention layer but not its subsequent MLP layer. This means the performance comes mostly from the attention layer.

#### Layer attribution

Attribution of layers throughout the model on a per-layer basis, including all attention head layers and MLP layers within the 'attention block'.

In [83]:
residual_stream, component_labels = activation_cache.accumulated_resid(
    layer=-1,
    pos_slice=-1,
    return_labels=True,
)

logit_differences = residual_stream_to_logit_diff(
    accumulated_residual_stream,
    activation_cache,
    logit_difference_directions
)

logit_differences = utils.to_numpy(logit_differences)
node_labels = [f'Stream component: "{l}"' for l in component_labels]

print(f"Model layers (x values): {x_axis_range}")
print(f"Logit differences (y values): {logit_differences}")

px.line(
    y=logit_differences,
    title="Logit difference for all model layers (attention and MLP)",
    labels=dict(x="Layer", y="Logit difference"),
).show()

Model layers (x values): [ 0.   0.5  1.   1.5  2.   2.5  3.   3.5  4.   4.5  5.   5.5  6.   6.5
  7.   7.5  8.   8.5  9.   9.5 10.  10.5 11.  11.5 12. ]
Logit differences (y values): [-0.00043612  0.00040156  0.00036878  0.00084074  0.00061268  0.00255602
  0.00117801  0.00256532  0.00267829 -0.00029886  0.00306021  0.00597805
  0.00067935 -0.00122963  0.00075706  0.01887021  0.01560044  0.0410631
  0.05046535  0.2078262   0.20520501  0.19907951  0.22312     0.17324272
  0.18531337]


#### Layer attribution

The contribution of each attention layer (within the attention block) can be decomposed into the outputs of each attention head.

Decomposition of an attention layer:
* Each attention layer consists of 12 independent heads.
* In GPT2, this gives a total of 144 attention heads.
* The attention layer output is calculated by concatenating the values of each head, and multiplying by a weight matrix.
* This is equivalent to splitting the weight matrix and computing weighted values on a per-head basis, then summing (plus a bias). 

In [122]:
per_head_residual_stream, component_labels = activation_cache.stack_head_results(
    layer=-1,
    pos_slice=-1,
    return_labels=True,
)

per_head_logit_differences = residual_stream_to_logit_diff(
    per_head_residual_stream,
    activation_cache,
    logit_difference_directions
)

per_head_logit_differences = einops.rearrange(
    per_head_logit_differences,
    "(layer head_index) -> layer head_index",
    layer=model.cfg.n_layers,
    head_index=model.cfg.n_heads,
)

per_head_logit_differences = utils.to_numpy(per_head_logit_differences)
transposed_per_head_logit_differences = per_head_logit_differences.T

px.imshow(
    transposed_per_head_logit_differences,
    color_continuous_midpoint=0.0,
    color_continuous_scale="RdBu",
    title="Logit difference from each attention head",
    labels={"x": "Model attention layer", "y": "Attention head in layer"},
    origin="lower",
).show()

### Attention analysis

* For each attention head, we can observe its attention pattern to identify what positions they move information from and to.
* Here we're looking at the direct effect on the logits so we need only look at the attention patterns from the final token.