# Attention & Induction Heads

I've tried to write this in kind of a logbook style, in the order of discovery.

Culture gives the GPTs a set of tasks (similar to ARC-AGI). The first two grids offer a one-shot example of the task (`A, f(A)`), the third grid offers the new input to be transformed, and the final grid is judged as the output.

Notably, the tasks are all reversible, i.e. `B -> f(B)` and `f(B) -> B` are both possible.

```python
import torch as t
from transformer_lens import HookedTransformer, ActivationCache
import circuitsvis as cv
from IPython.display import display
from datasets import load_dataset, Dataset, DatasetDict

from interp.all import *

device = get_device()
model = load_hooked(0).eval()
assert isinstance(model, HookedTransformer)

dataset_full = load_dataset("tommyp111/culture-puzzles-1M", split="train")
dataset = load_dataset("tommyp111/culture-puzzles-1M-partitioned")

dataset_full.set_format("pt")
dataset.set_format("pt")

assert isinstance(dataset_full, Dataset)
assert isinstance(dataset, DatasetDict)

final_grid_slice = slice(303, None)
```
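A note on what `final_grid_slice` selects: with `loss_per_token=True`, TransformerLens returns one loss per *prediction*, so the loss tensor is one position shorter than the input and entry `i` is the loss for predicting token `i + 1`. Slicing the loss with `[:, 303:]` therefore scores the predictions of tokens 304 onwards. A minimal NumPy sketch of that shifted cross-entropy on toy logits (`per_token_loss` is a hypothetical helper, not the TransformerLens implementation):

```python
import numpy as np


def per_token_loss(logits: np.ndarray, tokens: np.ndarray) -> np.ndarray:
    """Shifted next-token cross-entropy (hypothetical helper).

    logits: [batch, pos, vocab], tokens: [batch, pos] -> [batch, pos - 1].
    Entry i is the loss for predicting tokens[:, i + 1] from position i.
    """
    # Numerically stable log-softmax over the vocab dimension
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # The prediction made at position i is scored against the token at i + 1
    preds, targets = log_probs[:, :-1], tokens[:, 1:]
    return -np.take_along_axis(preds, targets[..., None], axis=-1)[..., 0]


# Toy check: uniform logits give a loss of log(vocab) at every position
batch, pos, vocab = 2, 8, 5
logits = np.zeros((batch, pos, vocab))
tokens = np.random.default_rng(0).integers(0, vocab, size=(batch, pos))
loss = per_token_loss(logits, tokens)
print(loss.shape)  # (2, 7)
```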

Let's see the effect of each of the attention layers. Since we're only interested in the last grid, we'll only look at the loss on that grid:

```python
print("Running zero ablation on attention...")

def zero_abl_hook(activation, hook):
    # Zero ablation: replace the activation with zeros
    return t.zeros_like(activation)

batch = dataset_full[:64]["input_ids"].to(device)

with t.inference_mode():
    # Mean per-token loss, restricted to final_grid_slice
    orig_loss = (
        model(batch, return_type="loss", loss_per_token=True)[:, final_grid_slice]
        .mean()
        .item()
    )
    print(f"Orig loss: {orig_loss:.2e}")
    for i in range(model.cfg.n_layers):
        # Same loss, with the whole attention output of layer i zeroed
        loss = (
            model.run_with_hooks(
                batch,
                return_type="loss",
                fwd_hooks=[(f"blocks.{i}.hook_attn_out", zero_abl_hook)],
                loss_per_token=True,
            )[:, final_grid_slice]
            .mean()
            .item()
        )
        print(f"Ablate attn (layer {i}) loss: {loss:.2e}")
```

```
Running zero ablation on attention...
Orig: 2.84e-03
Ablate attn (layer 0) diff: 2.65e+00
Ablate attn (layer 1) diff: 6.48e-03
Ablate attn (layer 2) diff: 2.92e-03
Ablate attn (layer 3) diff: 2.02e-03
Ablate attn (layer 4) diff: 2.85e-03
Ablate attn (layer 5) diff: 3.26e-03
Ablate attn (layer 6) diff: 3.03e-03
Ablate attn (layer 7) diff: 2.63e-03
Ablate attn (layer 8) diff: 2.85e-03
Ablate attn (layer 9) diff: 2.95e-03
Ablate attn (layer 10) diff: 3.18e-03
Ablate attn (layer 11) diff: 3.44e-03
```


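Wow! The first (0th) layer has an outsized impact on the final-grid loss: ablating its attention takes the loss from 2.84e-03 to 2.65, about three orders of magnitude higher. Ablating any other layer leaves the loss in the same ballpark as the original (2.02e-03 to 6.48e-03). Note that the printed "diff" values are the ablated losses themselves, not differences from the original.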
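To put a number on "outsized", here are the ratios of each ablated loss to the original, using the values printed above (copied by hand, so rounded to three significant figures):

```python
orig = 2.84e-03
# Final-grid loss with each layer's attention output zeroed (layers 0 to 11)
ablated = [2.65e+00, 6.48e-03, 2.92e-03, 2.02e-03, 2.85e-03, 3.26e-03,
           3.03e-03, 2.63e-03, 2.85e-03, 2.95e-03, 3.18e-03, 3.44e-03]

ratios = [loss / orig for loss in ablated]
for layer, ratio in enumerate(ratios):
    print(f"layer {layer:2d}: {ratio:8.2f}x the original loss")
```

Layer 0 comes out at roughly 930× the original loss, while every other layer stays between about 0.7× and 2.3×.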

Visualizing the effect of ablating the first layer:

```python
# Zero-ablate the layer 0 attention output, then generate and print a completion
model.add_hook("blocks.0.hook_attn_out", zero_abl_hook)  # type: ignore
generate_and_print(model, dataset["frame"][1]["input_ids"].to(device))
model.reset_hooks()
```

Cell colours from the terminal output, one character per cell (`.` grey, `a` light cyan, `c` cyan):

```
correct: False
Y
. . . . c c c . . .
. . . . c c c . . .
. a a a c c c . . .
. a a a c c c . . .
. a a a c c c . . .
. a a a . . . . . .
```

## Jumping to conclusions

This was my first hint that the model was using an induction head in its first attention layer:

```
preds = [11, ... 12, ... 13, ... 14, 13, ... ]
```

Given 11, 12 and 13 earlier in the sequence and 14 as the previous token, the model consistently predicts 13 as the next token (which is an invalid template).

Could this mean the model uses the layer 0 attention as a previous-token circuit, so that with the layer disabled it loses the ability to attend to the previous token?

Let's first see if we can find the previous-token head:

```python
quiz = quizzes[None, 0]
str_tokens = model.to_str_tokens(quiz)

logits, cache = model.run_with_cache(quiz, remove_batch_dim=True)

def prev_attn_detector(cache, thres, n_lookback):
    """Return "layer.head: score" for every head whose mean attention from each
    position to the position n_lookback tokens earlier exceeds thres."""
    out = []
    for layer in range(model.cfg.n_layers):
        attn_pattern = cache["pattern", layer].cpu()  # [head, query, key]
        n_ctx = attn_pattern.shape[-1]  # actual sequence length of this quiz
        diag = attn_pattern[
            :,
            t.arange(n_ctx),
            t.clamp(t.arange(n_ctx) - n_lookback, 0),
        ]
        mean_diag = diag.mean(dim=1)
        mask = mean_diag > thres
        heads, values = t.nonzero(mask, as_tuple=True)[0], mean_diag[mask]
        out.extend([f"{layer}.{head.item()}: {value.item():.4f}"
                    for head, value in zip(heads, values)])
    return out

print("Heads attending to previous token =", ", ".join(prev_attn_detector(cache, thres=0.2, n_lookback=1)))
```


```
Heads attending to previous token = 0.7: 0.3272, 3.1: 0.3402, 5.4: 0.3529, 7.2: 0.6515
```
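For reference, the detector's score is the mean attention weight on the diagonal offset by `n_lookback`. A small NumPy sketch on synthetic patterns (`offset_diag_score` is a hypothetical name) shows the scale: an ideal previous-token head scores 1.0, while a head attending uniformly over the prefix scores about 0.29. Like the original indexing, it clamps the first `n_lookback` queries to key 0, so for large lookbacks those early positions are scored on their attention to the first token.

```python
import numpy as np


def offset_diag_score(pattern: np.ndarray, n_lookback: int) -> np.ndarray:
    """Mean attention from each query q to key max(q - n_lookback, 0).

    pattern: [n_heads, T, T] attention probabilities -> [n_heads] scores.
    Mirrors the indexing in prev_attn_detector, including the clamp to 0.
    """
    _, T, _ = pattern.shape
    queries = np.arange(T)
    keys = np.clip(queries - n_lookback, 0, None)
    return pattern[:, queries, keys].mean(axis=1)


T = 10
# Head A: an ideal previous-token head (query 0 can only attend to itself)
prev_tok = np.zeros((T, T))
prev_tok[np.arange(T), np.clip(np.arange(T) - 1, 0, None)] = 1.0
# Head B: uniform attention over all positions up to and including the query
causal = np.tril(np.ones((T, T)))
uniform = causal / causal.sum(axis=1, keepdims=True)

scores = offset_diag_score(np.stack([prev_tok, uniform]), n_lookback=1)
print(scores)
```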


Hm! Only one head in layer 0 (head 0.7) attends to the previous token, and its score is quite low (0.33).

In fact, there's a much stronger one in layer 7 (head 7.2, at 0.65).

```python
# Attention pattern of head 0.7 (layer 0, head 7)
attn_pattern = cache["pattern", 0][7]

display(
    cv.attention.attention_pattern(
        tokens=str_tokens,
        attention=attn_pattern,
    )
)
```

Let's look back a whole grid then:

```python
print("Heads attending to previous grid =", ", ".join(prev_attn_detector(cache, thres=0.1, n_lookback=100)))
```

```
Heads attending to previous grid = 0.0: 0.4590, 0.5: 0.2725, 0.7: 0.1920
```


There's an even stronger head in layer 0: head 0.0, at 0.46!

Its attention pattern shows it attending strongly to the same position in the previous grid and, within the first grid, to the previous token.

(It also attends slightly to the tokens 20 positions either side of that position in the previous grid.)

Could this be a fixed-distance induction head? Of course, the model knows *exactly* where the matching token will be: always 100 tokens back.
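If the head really is a fixed-distance lookup, its idealised pattern is a shifted identity matrix, and applying it just copies whatever sat `offset` positions earlier. A toy NumPy sketch, assuming one-hot value vectors per position (only the offset of 100 comes from the text; the sizes and `fixed_offset_pattern` are illustrative):

```python
import numpy as np


def fixed_offset_pattern(T: int, offset: int) -> np.ndarray:
    """Idealised attention pattern: query i attends only to key max(i - offset, 0)."""
    pattern = np.zeros((T, T))
    pattern[np.arange(T), np.clip(np.arange(T) - offset, 0, None)] = 1.0
    return pattern


rng = np.random.default_rng(0)
T, offset, vocab = 300, 100, 10
tokens = rng.integers(0, vocab, size=T)
values = np.eye(vocab)[tokens]  # toy one-hot "value" vector per position, [T, vocab]

out = fixed_offset_pattern(T, offset) @ values  # row i is the value from i - offset
copied = out.argmax(axis=-1)

# From position `offset` onwards, the output is exactly the token `offset` back
print((copied[offset:] == tokens[:-offset]).all())  # True
```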

```python
# Attention pattern of head 0.0 (layer 0, head 0)
attn_pattern = cache["pattern", 0][0]

display(
    cv.attention.attention_pattern(
        tokens=str_tokens,
        attention=attn_pattern,
    )
)
```

OK, let's test this. If this head's output writes directly into the W_U subspace, we can ablate everything else (all MLPs and every later attention layer), hard-code the layer 0 pattern to look exactly 100 tokens back, and see whether the output repeats the previous grid.

This isn't foolproof (I'm sure there's a lot more going on in the network), but it's a nice test of whether we're on track.

So we're essentially looking at the path `W_E → layer 0 attention (idealised head 0.0 pattern) → W_U`.
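One standard way to check whether a head's OV circuit copies tokens straight through to the logits is to form the full token-to-token matrix `W_E @ W_V @ W_O @ W_U` and measure how often each token's own logit is the largest in its row. A NumPy sketch with toy weights built to copy perfectly, compared against random OV weights; `copying_score` is a hypothetical helper, and it ignores positional embeddings, biases and LayerNorm:

```python
import numpy as np


def copying_score(W_E, W_V, W_O, W_U) -> float:
    """Fraction of tokens whose own logit is the largest in their row of W_E W_V W_O W_U.

    Row-vector convention: W_E [vocab, d_model], W_V [d_model, d_head],
    W_O [d_head, d_model], W_U [d_model, vocab].
    """
    full_ov = W_E @ W_V @ W_O @ W_U  # [vocab, vocab]: source token -> output logits
    return float((full_ov.argmax(axis=-1) == np.arange(full_ov.shape[0])).mean())


rng = np.random.default_rng(0)
vocab, d_model, d_head = 16, 32, 32

# Toy copying head: orthonormal embeddings, unembedding = W_E^T, identity OV
W_E = np.linalg.qr(rng.normal(size=(d_model, vocab)))[0].T  # [vocab, d_model]
W_U = W_E.T
copy_score = copying_score(W_E, np.eye(d_model, d_head), np.eye(d_head, d_model), W_U)

# Random OV weights for comparison
rand_score = copying_score(
    W_E, rng.normal(size=(d_model, d_head)), rng.normal(size=(d_head, d_model)), W_U
)
print(copy_score, rand_score)
```

On the real model, the head 0.0 matrices would be `model.W_E`, `model.W_V[0, 0]`, `model.W_O[0, 0]` and `model.W_U`, converted with `.detach().cpu().numpy()`.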

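```python
ablate_mlp = [
    (f"blocks.{i}.hook_mlp_out", zero_abl_hook)
    for i in range(model.cfg.n_layers)
]


def zero_abl_except_head(activation, hook):
    # Replace the pattern with an idealised one: every query attends only to the
    # token 100 positions back (clamped to position 0).
    # NOTE: this is applied to *every* head in layer 0, not only head 0.0.
    B, H, T, _ = activation.shape
    act = t.zeros_like(activation)  # same device and dtype as the model
    idx = t.arange(T, device=activation.device)
    act[:, :, idx, t.clamp(idx - 100, 0)] = 1
    return act

# Layer 0 gets the idealised pattern; every later layer gets an all-zero pattern
ablate_attn = [
    (f"blocks.{i}.attn.hook_pattern",
     zero_abl_except_head if i == 0 else zero_abl_hook)
    for i in range(model.cfg.n_layers)
]

for hook in ablate_mlp + ablate_attn:
    model.add_hook(*hook)  # type: ignore

generate_and_print(model, quiz)
model.reset_hooks()
```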

Cell colours from the terminal output, one character per cell (`.` grey, `p` pink, `o` orange, `m` magenta):

```
correct: False
X
. p p p o o o o . .
. p p p o o o o . .
. p p p o o o o . .
. p p p . . m m m .
. . . . . . m m m .
. . . . . . m
```


OK, so it's not perfect, but it's pretty close!

We could go further down this rabbit hole and try to extract a "clean" representation of the previous grid, but this seems fairly trivial, and it got me thinking: even in the new grid, the model can just copy the previous grid 90% of the time. What seems more interesting is the 10% of the time when the model does something different.

So even though the later layers have a much smaller impact on the overall loss, it makes sense that the computations they do are the more interesting ones.
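A quick way to check that 90% figure on a given quiz is to count how often a final-grid token equals the token a fixed offset earlier. This is a NumPy illustration only: `copy_baseline` is hypothetical, and the defaults `start=303` and `offset=100` are assumptions taken from the slice and the "100 tokens back" observation above, so adjust them to the real token layout.

```python
import numpy as np


def copy_baseline(quiz, start: int = 303, offset: int = 100) -> float:
    """Fraction of tokens in quiz[start:] equal to the token `offset` positions earlier.

    `start` and `offset` are assumptions about the token layout.
    """
    quiz = np.asarray(quiz)
    target = quiz[start:]
    source = quiz[start - offset : len(quiz) - offset]
    return float((target == source).mean())


# Toy quiz: the last 100 tokens repeat the previous 100, except for 10 changed cells
rng = np.random.default_rng(0)
prev = rng.integers(0, 10, size=100)
new = prev.copy()
new[:10] = (new[:10] + 1) % 10  # guaranteed to differ from prev
quiz = np.concatenate([rng.integers(0, 10, size=203), prev, new])

score = copy_baseline(quiz, start=303, offset=100)
print(score)  # 0.9
```

For a real quiz, something like `quizzes[0].cpu().numpy()` could be passed in.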

And just for fun: how does ablating each attention layer affect this single example?

```python
model.reset_hooks()

for layer in range(model.cfg.n_layers):
    print(f"ablating attn layer {layer}: ", end="")
    model.add_hook(f"blocks.{layer}.hook_attn_out", zero_abl_hook)  # type: ignore
    correct, pred = generate(model, quizzes[:1], verbose=False)
    print("correct:", correct.item())
    if not correct:
        # Show only the final grid (token 303 onwards)
        print("Ground Truth")
        print(repr_grid(quizzes[0, final_grid_slice]))
        print("Predicted")
        print(repr_grid(pred[0, final_grid_slice]))
    model.reset_hooks()
```


Cell colours from the terminal output, one character per cell (`.` grey, `m` magenta, `p` pink, `o` orange):

```
ablating attn layer 0: correct: False
Ground Truth
f(Y)
m m m m . . p p p .
m m m m . . p p p .
m m m m . . p p p .
. . . . . . p p p .
. o o o o o p p p .
. o . . . o
```



The two ways (that I know of) to build an induction head are:

1. Using the previous-token head through K-composition: the induction head uses the current token as the query and attends strongly to the position whose previous token matches it.
2. Using the previous-token head through Q-composition: the model uses the current token as the query, then the OV circuit rotates the positional embedding (as with RoPE) to the next token.
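Whichever way it is wired, the behaviour an induction head implements can be written down directly: find the most recent earlier occurrence of the current token and predict the token that followed it. A plain-Python reference sketch (`induction_prediction` is a hypothetical name), handy as a baseline to compare model predictions against:

```python
from typing import Optional, Sequence


def induction_prediction(tokens: Sequence[int]) -> Optional[int]:
    """Predict the next token the way an idealised induction head would.

    Finds the most recent earlier occurrence of the last token and returns the
    token that followed it, or None if the last token has not appeared before.
    """
    current = tokens[-1]
    # Scan backwards over earlier positions, excluding the last one itself
    for i in range(len(tokens) - 2, -1, -1):
        if tokens[i] == current:
            return tokens[i + 1]
    return None


print(induction_prediction([3, 7, 1, 9, 3]))  # 7
print(induction_prediction([3, 7, 1]))        # None
```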

We're using sinusoidal positional encodings, so it might be harder (but not impossible) to rotate the positional embedding.
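For what it's worth, a fixed shift is a linear map for the standard sinusoidal encoding: each (sin, cos) pair at frequency w satisfies sin((p + k)w) = sin(pw)cos(kw) + cos(pw)sin(kw) and cos((p + k)w) = cos(pw)cos(kw) - sin(pw)sin(kw), i.e. a 2×2 rotation by kw. So moving to the next position is learnable as a matrix in principle. A NumPy check, assuming the interleaved sin/cos layout from *Attention Is All You Need* (the model's actual layout and base may differ):

```python
import numpy as np


def sinusoidal_pe(n_pos: int, d_model: int) -> np.ndarray:
    """Interleaved sin/cos positional encoding, [n_pos, d_model] (assumed layout)."""
    pos = np.arange(n_pos)[:, None]
    omega = 1.0 / 10000 ** (np.arange(0, d_model, 2) / d_model)  # [d_model // 2]
    pe = np.zeros((n_pos, d_model))
    pe[:, 0::2] = np.sin(pos * omega)
    pe[:, 1::2] = np.cos(pos * omega)
    return pe


def shift_matrix(d_model: int, k: int) -> np.ndarray:
    """Block-diagonal rotation R such that PE[p + k] = R @ PE[p] for every p."""
    omega = 1.0 / 10000 ** (np.arange(0, d_model, 2) / d_model)
    R = np.zeros((d_model, d_model))
    for i, w in enumerate(omega):
        c, s = np.cos(k * w), np.sin(k * w)
        # [sin((p+k)w), cos((p+k)w)] = [[c, s], [-s, c]] @ [sin(pw), cos(pw)]
        R[2 * i : 2 * i + 2, 2 * i : 2 * i + 2] = [[c, s], [-s, c]]
    return R


pe = sinusoidal_pe(n_pos=400, d_model=64)
R = shift_matrix(d_model=64, k=1)
# Each position's encoding is the previous one rotated (rows are positions)
print(np.allclose(pe[1:], pe[:-1] @ R.T))  # True
```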

### Multiplying through the residual stream

Let's see which layers use the output of layer 0 the most. We can do this by multiplying the relevant weight matrices through the residual stream.
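One concrete version of this is the composition score from *A Mathematical Framework for Transformer Circuits*: `||W_A @ W_B||_F / (||W_A||_F * ||W_B||_F)`, where `W_A` is an earlier head's OV matrix and `W_B` is a later head's QK or OV matrix. The NumPy sketch below uses TransformerLens's row-vector convention (`x @ W`); the helper names and the random low-rank weights are illustrative only.

```python
import numpy as np


def comp_score(W_A: np.ndarray, W_B: np.ndarray) -> float:
    """||W_A @ W_B||_F / (||W_A||_F * ||W_B||_F), a value in [0, 1]."""
    return float(np.linalg.norm(W_A @ W_B) / (np.linalg.norm(W_A) * np.linalg.norm(W_B)))


def composition_scores(W_OV_early, W_QK_late, W_OV_late) -> dict:
    """Q-, K- and V-composition of an earlier head into a later one.

    Row-vector convention (x @ W): W_OV = W_V @ W_O, W_QK = W_Q @ W_K.T,
    all [d_model, d_model].
    """
    return {
        "q": comp_score(W_OV_early, W_QK_late),    # earlier output read by the query side
        "k": comp_score(W_OV_early, W_QK_late.T),  # earlier output read by the key side
        "v": comp_score(W_OV_early, W_OV_late),    # earlier output read by the value side
    }


rng = np.random.default_rng(0)
d_model, d_head = 64, 8


def low_rank(d_model: int, d_head: int) -> np.ndarray:
    # Random product with the rank structure of a single attention head
    return rng.normal(size=(d_model, d_head)) @ rng.normal(size=(d_head, d_model))


scores = composition_scores(low_rank(d_model, d_head), low_rank(d_model, d_head),
                            low_rank(d_model, d_head))
print(scores)
```

On the real model, `W_OV_early` for head 0.0 would be `model.W_V[0, 0] @ model.W_O[0, 0]`, and a later head's `W_QK` would be `model.W_Q[l, h] @ model.W_K[l, h].T`. Higher scores mean the earlier head's output lines up more closely with what the later head reads.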
