In [2]:
import warnings
warnings.filterwarnings('ignore')

from typing import Callable

import circuitsvis as cv
import einops
import plotly_utils
import torch
from eindex import eindex
from huggingface_hub import hf_hub_download
from torch import Tensor
from transformer_lens import HookedTransformerConfig, HookedTransformer, utils
from transformer_lens.hook_points import HookPoint

In [11]:
!nvidia-smi

Tue Apr  8 14:01:22 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.86.15              Driver Version: 570.86.15      CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA T4G                     On  |   00000000:00:1F.0 Off |                    0 |
| N/A   42C    P8             15W /   70W |       1MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [12]:
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Wed_Oct_30_00:08:18_PDT_2024
Cuda compilation tools, release 12.6, V12.6.85
Build cuda_12.6.r12.6/compiler.35059454_0


In [15]:
import torch
# Report the installed PyTorch build, then the CUDA version it was built with, one per line.
print(torch.__version__, torch.version.cuda, sep="\n")

2.6.0+cpu
None


In [14]:
# Summarise what the CUDA runtime exposes to this kernel.
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device count: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    print(f"CUDA device name: {torch.cuda.get_device_name(0)}")
else:
    print("CUDA device name: N/A")

CUDA available: False
CUDA device count: 0
CUDA device name: N/A


In [3]:
# Pick the compute device in order of preference: MPS first, then CUDA, otherwise the CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f'Using device: {device}')

Using device: cpu


# 1. Hook Points and Hook Functions

One of the great things about interpreting neural networks is that we have full control over our system. From a computational perspective, we know exactly what operations are going on inside (even if we don't know what they mean). And we can make precise, surgical edits and see how the model's behavior and other internals change. This is an extremely powerful tool, because it can let us set up careful counterfactuals and causal intervention to easily understand model behavior. Accordingly, being able to do this is a core operation in mechanistic interpretability.

`Hook points` are used to allow observability and editing of every activation inside the transformer. They allow us to add `hook functions` to any activation and then run the model with those hook functions by calling `model.run_with_hooks`. They also have methods like `hook.layer()` and attributes like `hook.name` that are sometimes useful to call within the functions. Performing observability of activation patterns is useful for things like:
  - extracting activations for a specific task
  - doing long-running calculations across many inputs, e.g. finding the text that most activates a specific neuron

`Hook functions` are used to observe or edit the activations. They take as arguments an `activation_value`, a tensor representing some activation pattern in the model, and a `hook_point`. If a `hook function` is being used to edit activations, then it should return a tensor of the same shape as the activation pattern. But if it is being used for observability, it should not return anything.

### 1.1. Finding induction heads in the toy model from `induction_heads.py`

In [None]:
# Architecture of the 2-layer attention-only toy model. `device` is passed explicitly so the model is
# built on the same device the inputs are later moved to, instead of relying on the library default.
cfg = HookedTransformerConfig(
    d_model=768,
    d_head=64,
    n_heads=12,
    n_layers=2,
    n_ctx=2048,
    d_vocab=50278,
    attention_dir='causal',
    attn_only=True,
    tokenizer_name='EleutherAI/gpt-neox-20b',
    use_attn_result=True,  # keeps per-head outputs available as the "result" activations
    normalization_type=None,
    positional_embedding_type='shortformer',
    device=str(device),
)

REPO_ID = 'callummcdougall/attn_only_2L_half'
FILENAME = 'attn_only_2L_half.pth'

# Download the checkpoint from the Hugging Face Hub and get its local path
weights_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

model = HookedTransformer(cfg)
# `weights_only=True` restricts unpickling to tensors/primitive containers
pretrained_weights = torch.load(weights_path, map_location=device, weights_only=True)
model.load_state_dict(pretrained_weights)

In [None]:
seq_len = 50
batch_size = 10

# Dedicated, seeded generator: the random prompt (and every plot derived from it) is reproducible
# across kernel restarts, and the global RNG state is left untouched.
token_rng = torch.Generator().manual_seed(0)

# Each row is laid out as [BOS, t_1..t_n, t_1..t_n]: a random block of `seq_len` tokens repeated twice
prefix = torch.full((batch_size, 1), model.tokenizer.bos_token_id, dtype=torch.int64)
tokens = torch.randint(0, model.cfg.d_vocab, (batch_size, seq_len), generator=token_rng, dtype=torch.int64)
repeated_tokens = torch.cat([prefix, tokens, tokens], dim=-1).to(device)

print('Shape of repeated tokens:', repeated_tokens.shape)

In [None]:
# One slot per (layer, head); the induction-score hook fills in one layer's row at a time
induction_score_store = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, device=model.cfg.device)
print(f'Shape of induction score store: {induction_score_store.shape}')

In [None]:
def induction_score_hook(activations: Tensor, hook: HookPoint):
    """
    Calculates the induction score of every head in one layer and stores it in the row of the module-level
    `induction_score_store` tensor selected by `hook.layer()`.

    The hook assumes the model was run on sequences laid out as `[BOS, t_1..t_n, t_1..t_n]` (one prefix token
    followed by a block of `n` tokens repeated twice). `n` is recovered from the size of the attention pattern,
    so the hook does not depend on any module-level sequence-length variable.

    Args:
        activations (Tensor): The attention pattern with dimensions (batch, head_index, query_pos, key_pos).
        hook (HookPoint): The hook point that triggered this function, eg. `blocks.0.attn.hook_pattern`. Only
                          `hook.layer()` is used, to pick the row of the store to write.

    Returns:
        None. The hook is used for observation only, so the activation is left untouched.
    """
    # length of the repeated block: total positions = 1 (prefix) + 2 * rep_len
    rep_len = (activations.shape[-1] - 1) // 2

    # Stripe `rep_len - 1` positions below the main diagonal: each query position looks at the key position
    # right AFTER the previous occurrence of its own token, which is the attention pattern induction relies on.
    induction_stripe = activations.diagonal(dim1=-2, dim2=-1, offset=1 - rep_len)
    # average over the batch and over the stripe positions -> one score per head
    induction_score = induction_stripe.mean(dim=(0, -1))
    # detach before storing so the store does not keep the forward pass's autograd graph alive
    induction_score_store[hook.layer(), :] = induction_score.detach()

In [None]:
def pattern_hook_names_filter(name):
    """Returns True for activation names ending in "pattern", i.e. the attention patterns."""
    return name.endswith("pattern")


# One forward pass over the repeated sequences; the hook records a score for every matching hook point.
model.run_with_hooks(
    repeated_tokens,
    fwd_hooks=[(pattern_hook_names_filter, induction_score_hook)],
    return_type=None,  # for efficiency, logits are not calculated here
)

In [None]:
# Heatmap of the per-head induction scores gathered by the hook (rows: layers, columns: heads)
plotly_utils.imshow(
    induction_score_store,
    title="Induction Score by Head",
    labels=dict(x="Head", y="Layer"),
    text_auto=".2f",
    height=350,
    width=900,
)

### 1.2. Finding induction heads in GPT2-small

In [None]:
gpt2_small = HookedTransformer.from_pretrained('gpt2-small').to(device)

# `repeated_tokens` was sampled from the toy model's vocabulary (d_vocab=50278) with the toy tokenizer's
# BOS id. GPT-2's vocabulary is smaller, so some of those ids can fall outside its embedding matrix and
# crash the forward pass. Build a GPT-2 specific prompt with the same [BOS, block, block] layout instead.
gpt2_prefix = torch.full((batch_size, 1), gpt2_small.tokenizer.bos_token_id, dtype=torch.int64)
gpt2_random_tokens = torch.randint(0, gpt2_small.cfg.d_vocab, (batch_size, seq_len), dtype=torch.int64)
gpt2_repeated_tokens = torch.cat([gpt2_prefix, gpt2_random_tokens, gpt2_random_tokens], dim=-1).to(device)

# Fresh store sized for GPT2-small; the hook looks the name up at call time, so rebinding it here is enough
induction_score_store = torch.zeros((gpt2_small.cfg.n_layers, gpt2_small.cfg.n_heads), device=gpt2_small.cfg.device)
gpt2_small.run_with_hooks(gpt2_repeated_tokens, return_type=None, fwd_hooks=[(pattern_hook_names_filter, induction_score_hook)])

In [None]:
# Same heatmap as for the toy model, now over the store that was re-created for GPT2-small
plotly_utils.imshow(
    induction_score_store,
    title="Induction Score by Head",
    labels=dict(x="Head", y="Layer"),
    text_auto=".1f",
    height=500,
    width=700,
)

In [None]:
def visualize_activations_hook(activations: Tensor, hook: HookPoint):
    """
    Displays the attention patterns of every head in the layer that triggered the hook.

    The hook is attached to `gpt2_small`, so the token labels are decoded with `gpt2_small`'s tokenizer
    (not the toy model's) to match the model that produced the pattern.

    Args:
        activations (Tensor): attention pattern whose first dimension is the batch; it is averaged over
                              the batch before plotting.
        hook (HookPoint): the hook point that triggered this function; only `hook.layer()` is used.

    Returns:
        None. The hook only observes the activation.
    """
    print("Layer: ", hook.layer())
    # labels come from the first sequence of the module-level `repeated_tokens`
    str_tokens = gpt2_small.to_str_tokens(repeated_tokens[0])
    display(cv.attention.attention_patterns(tokens=str_tokens, attention=activations.mean(0)))

In [None]:
# Layers whose attention patterns get rendered
induction_head_layers = [5, 6, 7]

# One (hook point name, hook function) pair per selected layer
fwd_hooks = [
    (utils.get_act_name("pattern", layer), visualize_activations_hook)
    for layer in induction_head_layers
]

# NOTE(review): `repeated_tokens` was sampled from the toy model's vocabulary — confirm every id is a
# valid GPT-2 token id before running it through `gpt2_small`.
gpt2_small.run_with_hooks(repeated_tokens, fwd_hooks=fwd_hooks, return_type=None)

# 2. Direct logit attribution

At the end of the day, interpretability boils down to the following question:

> How much of the model's performance on some particular task is attributable to each component of the model?

In order to develop such an understanding for how transformers perform certain tasks, we might look at how a head interacts with other heads in different layers, perform causal intervention by seeing how well the model performs when we remove a head, among other things. As a consequence of the residual stream, the output logits are the sum of the contributions of each layer, and thus the sum of the results of each head. This means that the output logits can be decomposed into direct contributions coming from each head.

Suppose a model knows that the token Harry is followed by the token Potter. The logits on Harry are given by `residual @ W_U`, where `W_U` stands for the unembedding matrix. In turn, the residual stream is the sum of the outputs of all previous layers: `residual = embed + attn_out_0 + attn_out_1`. So `logits = (embed @ W_U) + (attn_out_0 @ W_U) + (attn_out_1 @ W_U)`. More specifically, the logit of the Potter token is computed from a single column of `W_U` (call it `potter_U`), so it is a single number: the sum `(embed @ potter_U) + (attn_out_0 @ potter_U) + (attn_out_1 @ potter_U)`. Going one step further, each attention layer's output can itself be decomposed into the sum of the results of its individual heads.

The example below uses the toy model. Its logits are composed of the paths that lead from the output of each component directly to the logits:
  - direct path: the residual connections from the embedding to unembedding
  - layer 0 head: via the residual connection and skipping layer 1
  - layer 1 head: via the residual connection directly to the logits

Direct logit attribution looks at the direct effect on the logits of what each component writes into the residual stream. Subtle effects can be missed by looking only at the logit of the correct token, such as a head that suppresses other plausible logits in order to increase the log prob of the correct one.

In [None]:
# Natural-language prompt used for the logit-attribution example
text = (
    'We think that powerful, significantly superhuman machine intelligence is more likely than not to be '
    'created this century. If current machine learning techniques were scaled up to this level, we think '
    'they would by default produce systems that are deceptive or manipulative, and that no solid plans '
    'are known for how to avoid this.'
)

# Token ids, their string forms, and a forward pass that caches the activations
tokens = model.to_tokens(text)
str_tokens = model.to_str_tokens(text)
logits, cache = model.run_with_cache(text, remove_batch_dim=True)

print(f'Shape of tokens: {tokens.shape}')
print(f'Shape of logits: {logits.shape}')

In [None]:
def logit_attribution(embed: Tensor, l1_results: Tensor, l2_results: Tensor, W_U: Tensor, tokens: Tensor) -> Tensor:
    """
    Decomposes the logit of each correct next token into the direct contribution of every component.

    Inputs:
        embed: token + position embeddings of the tokens [seq_len, d_model]
        l1_results: per-head outputs of the first attention layer (layer 0) [seq_len, n_heads, d_model]
        l2_results: per-head outputs of the second attention layer (layer 1) [seq_len, n_heads, d_model]
        W_U: unembedding matrix [d_model, d_vocab]
        tokens: token ids of the input sequence [seq_len]

    Returns:
        Tensor of shape (seq_len-1, n_components), the concatenation along the last dimension of:
            the direct path attribution (seq_len-1, 1)
            the first layer's per-head attributions (seq_len-1, n_heads)
            the second layer's per-head attributions (seq_len-1, n_heads)
        so n_components = 1 + 2*n_heads

    The last position is dropped from every component, as there is no next token to predict after it.
    """
    # Unembedding direction of the token that actually follows each position: [d_model, seq_len-1]
    next_token_directions = W_U[:, tokens[1:]]

    # Projection of the embeddings onto those directions, one value per position
    direct_path = einops.einsum(next_token_directions, embed[:-1], "emb seq, seq emb -> seq")
    # Same projection for every head of each attention layer, in layer order
    per_layer = [
        einops.einsum(next_token_directions, head_results[:-1], "emb seq, seq nhead emb -> seq nhead")
        for head_results in (l1_results, l2_results)
    ]
    return torch.cat([direct_path.unsqueeze(-1), *per_layer], dim=-1)

In [None]:
# Sanity check: summed over all components, the attributions must reproduce the logit that the model
# actually assigned to each correct next token.
embed = cache["embed"]
l1_results = cache["result", 0]
l2_results = cache["result", 1]

with torch.inference_mode():
    logit_attr = logit_attribution(embed, l1_results, l2_results, model.W_U, tokens[0])
    # entry k is the logit at position k for the token that really appears at position k + 1
    correct_token_logits = logits[0, :-1].gather(dim=-1, index=tokens[0, 1:].unsqueeze(-1)).squeeze(-1)
    torch.testing.assert_close(logit_attr.sum(1), correct_token_logits, atol=1e-3, rtol=0)
    print("Tests passed!")

### 2.1. Interpreting logit attribution

Most variation in the logit attribution plot comes from the direct path, in particular for tokens that are the first token in common bigrams. For instance, the highest contribution on the direct path comes from | manip|, because this is very likely to be followed by |ulative| (or presumably a different continuation like |ulation|). | super| -> |human| is another example of a bigram formed when the tokenizer splits one word into multiple tokens.

There are also examples that come from two different words, rather than a single word split by the tokenizer. These include:

- `| more|` -> `| likely|` (12)
- `| machine|` -> `| learning|` (24)
- `| by|` -> `| default|` (38)
- `| how|` -> `| to|` (58)

In addition, the heads in layer 1 have higher contributions than the heads in layer 0. This is because this plot doesn't pick up on a head's effect in composition with another head. So the attribution for layer-0 heads won't involve any composition, whereas the attributions for layer-1 heads will involve not only the single-head paths through those attention heads, but also the 2-layer compositional paths through heads in layer 0 and layer 1.

In [None]:
# Plot the per-component attributions (`logit_attr`) for the text prompt's tokens
plotly_utils.plot_logit_attribution(model, logit_attr, tokens, title="Logit attribution")

### 2.2. Interpreting logit attribution for induction heads

The first half of the plot below is mostly meaningless, because the sequences are random and carry no predictable pattern, and so there can't be any part of the model that is doing meaningful computation to make predictions.

In the second half, heads `1.4` and `1.10` have a large logit attribution score. This makes sense given the previous observation that these heads seemed to be performing induction (since they both exhibited the characteristic induction pattern). However, it's worth emphasizing that this plot gives additional evidence of induction, because just observing that some head attends to a particular token doesn't mean it's necessarily using that information to make a concrete prediction.

In [None]:
# Run the toy model on the first random repeated sequence and cache its activations
logits, rep_cache = model.run_with_cache(repeated_tokens[0], remove_batch_dim=True)

print(f'Shape of tokens: {repeated_tokens[0].shape}')
print(f'Shape of logits: {logits.shape}')

In [None]:
# Same sanity check as for the text prompt, now on the first random repeated sequence
with torch.inference_mode():
    embed = rep_cache["embed"]
    l1_results = rep_cache["result", 0]
    l2_results = rep_cache["result", 1]
    logit_attr = logit_attribution(embed, l1_results, l2_results, model.W_U, repeated_tokens[0])
    # entry k is the logit at position k for the token that really appears at position k + 1
    correct_token_logits = logits[0, :-1].gather(dim=-1, index=repeated_tokens[0, 1:].unsqueeze(-1)).squeeze(-1)
    torch.testing.assert_close(logit_attr.sum(1), correct_token_logits, atol=1e-3, rtol=0)
    print("Tests passed!")

In [None]:
# Plot the per-component attributions (`logit_attr`) for the first random repeated sequence
plotly_utils.plot_logit_attribution(model, logit_attr, repeated_tokens[0], title="Logit attribution (random induction prompt)")

# 3. Intervening on activations using hooks

In [None]:
import functools

def get_log_probs(logits: Tensor, tokens: Tensor) -> Tensor:
    """
    Returns the log-probability that the model assigns to the token that actually comes next.

    The logits are turned into log-probabilities over the last dimension; entry [b, s] of the result is
    the log-probability at position s of sequence b for the token found at position s + 1.
    """
    # eindex pattern "b s [b s+1]" reads logprobs[b, s, tokens[b, s+1]]
    return eindex(logits.log_softmax(dim=-1), tokens, "b s [b s+1]")

def head_zero_ablation_hook(
    z: Tensor,
    hook: HookPoint,
    head_index_to_ablate: int,
) -> None:
    """
    Zero-ablates one attention head by overwriting its slice of `z` in place.

    Args:
        z: activation indexed as [dim0, dim1, head, dim3]; only the slice of the selected head is changed.
        hook: the hook point that triggered the call (unused).
        head_index_to_ablate: index of the head, along the third dimension of `z`, to set to zero.

    Returns:
        None; the edit is made in place on `z`.
    """
    ablated_head = (slice(None), slice(None), head_index_to_ablate, slice(None))
    z[ablated_head] = 0.0


def get_ablation_scores(
    model: HookedTransformer,
    tokens: Tensor,
    ablation_function: Callable = head_zero_ablation_hook,
) -> Tensor:
    """
    Returns a tensor of shape (n_layers, n_heads) containing the increase in cross entropy loss from ablating the output
    of each head.

    Args:
        model: the model whose heads are ablated one at a time.
        tokens: token ids of shape (batch, 1 + 2 * seq_len), laid out as [prefix token, block, same block again];
                `seq_len` is recovered from this shape.
        ablation_function: hook called as `ablation_function(z, hook, head_index_to_ablate=head)` on the "z"
                           activation of each layer.

    Returns:
        Tensor of shape (n_layers, n_heads) on `model.cfg.device`; entry [layer, head] is the ablated loss minus the
        clean loss, so zero means that ablating the head did not change the loss. The result carries no autograd
        history.
    """
    # Initialize an object to store the ablation scores
    ablation_scores = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)

    seq_len = (tokens.shape[1] - 1) // 2

    def repeat_loss(logits: Tensor) -> Tensor:
        # Mean negative log-prob of the correct token over the last `seq_len - 1` predictions only, i.e. the
        # predictions made once the repeated block has started and the next token can be inferred from it.
        return -get_log_probs(logits, tokens)[:, -(seq_len - 1) :].mean()

    # Pure evaluation: no_grad avoids recording (and retaining) an autograd graph for each of the
    # 1 + n_layers * n_heads forward passes.
    with torch.no_grad():
        # Calculating loss without any ablation, to act as a baseline
        model.reset_hooks()
        loss_no_ablation = repeat_loss(model(tokens, return_type="logits"))

        for layer in range(model.cfg.n_layers):
            hook_name = utils.get_act_name("z", layer)  # same hook point for every head of this layer
            for head in range(model.cfg.n_heads):
                # Use functools.partial to create a temporary hook function with the head number fixed
                temp_hook_fn = functools.partial(ablation_function, head_index_to_ablate=head)
                # Run the model with the ablation hook
                ablated_logits = model.run_with_hooks(tokens, fwd_hooks=[(hook_name, temp_hook_fn)])
                # Store the result, subtracting the clean loss so that a value of zero means no change in loss
                ablation_scores[layer, head] = repeat_loss(ablated_logits) - loss_no_ablation

    return ablation_scores

def test_get_ablation_scores(
    ablation_scores: Tensor,
    model: HookedTransformer,
    rep_tokens: Tensor,
):
    """
    Checks `ablation_scores` against a fresh call to `get_ablation_scores(model, rep_tokens)`.

    Raises an AssertionError (from `torch.testing.assert_close`) if the two tensors differ; otherwise prints a
    confirmation message.
    """
    expected = get_ablation_scores(model, rep_tokens)
    torch.testing.assert_close(ablation_scores, expected)
    print("All tests in `test_get_ablation_scores` passed!")

# Score every head of the toy model on the repeated random sequences, then run the consistency check
ablation_scores = get_ablation_scores(model=model, tokens=repeated_tokens)
test_get_ablation_scores(ablation_scores=ablation_scores, model=model, rep_tokens=repeated_tokens)

In [None]:
# Heatmap of the loss increase caused by ablating each head. The colour bar is labelled as a loss
# difference to match both the title and what `get_ablation_scores` returns (it was "Logit diff").
plotly_utils.imshow(
    ablation_scores,
    labels={"x": "Head", "y": "Layer", "color": "Loss diff"},
    title="Loss Difference After Ablating Heads",
    text_auto=".2f",
    width=900,
    height=350,
)

# TODO before finishing

- **TODO:** go back to `A BIT MORE ABOUT HOOKS`.

# Sources

1. [Ground truth - TransformerLens: Hooks, by ARENA](https://arena-chapter1-transformer-interp.streamlit.app/[1.2]_Intro_to_Mech_Interp)
2. [Pytorch 101: Understanding hooks, by DigitalOcean](https://www.digitalocean.com/community/tutorials/pytorch-hooks-gradient-clipping-debugging)
3. [Garcon, by Neel Nanda, Chris Olah, et al.](https://transformer-circuits.pub/2021/garcon/index.html)