In [None]:
import circuitsvis as cv
import einops
import torch
from jaxtyping import Float, Int

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens import ActivationCache, FactoredMatrix, HookedTransformer, HookedTransformerConfig
from transformer_lens.hook_points import (  # Hooking utilities
    HookedRootModule,
    HookPoint,
)

In [None]:
# Load pretrained GPT-2 small wrapped in a HookedTransformer so its internal
# activations can be hooked later (see the run_with_hooks call in Example 3).
model = HookedTransformer.from_pretrained(model_name="gpt2-small")

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


##### Example 1

In [None]:
# Toy prompt: the pattern "1234" repeated three times.
prompt = "1234" * 3

In [None]:
# Random stand-in for attention weights, shaped
# (batch=4, head=5, dest_pos=12, source_pos=12); only the shape matters here.
# A dedicated seeded generator makes the values reproducible without
# disturbing the global torch RNG used by later cells.
attn_weights = torch.randn(4, 5, 12, 12, generator=torch.Generator().manual_seed(0))

In [None]:
prompt  # display the toy prompt

'123412341234'

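`attn_weights` has shape `(batch, head, dest_pos, source_pos)`: a batch of 4, 5 attention heads, and the attention weights between every pair of the 12 positions (one per character of `prompt`).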

In [None]:
attn_weights.size()  # (batch, head, dest_pos, source_pos)

torch.Size([4, 5, 12, 12])

Calculate the induction stripe and explain the code.

**Explain**

The sequence consists of three repetitions of the pattern `"1234"`, so the length of the repeating pattern (`seq_len`) is `4`.

In [None]:
# Period of the repeated pattern: prompt is "1234" repeated, so each token
# recurs every len("1234") positions.
seq_len = len("1234")

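The induction stripe consists of the attention weights from each token to the token that came right after its *previous* occurrence. Because the pattern repeats every `seq_len` tokens, that source token sits `seq_len - 1` positions before the destination: destination position `i` attends to source position `i - (seq_len - 1)`. Strictly, the true induction positions start at `i = seq_len`, where the pattern first repeats; the diagonal below also includes `i = seq_len - 1`.
- `offset=1-seq_len`: `diagonal` with a negative offset reads the diagonal *below* the main one. Over the last two dims `(dest_pos, source_pos)`, `offset = 1 - seq_len = 1 - 4 = -3` returns the entries `[j + 3, j]`, which is exactly the stripe described above. A 12 × 12 matrix has `12 - 3 = 9` such entries, so the result has shape `(4, 5, 9)`.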

In [None]:
# Read the diagonal of each (dest_pos, source_pos) matrix that lies
# seq_len - 1 positions below the main diagonal: entry j pairs destination
# position j + seq_len - 1 with source position j.
induction_stripe = torch.diagonal(attn_weights, offset=1 - seq_len, dim1=-2, dim2=-1)

In [None]:
# Sanity check: on a square n x n matrix, the diagonal at offset 1 - seq_len
# has n - (seq_len - 1) entries, and the leading (batch, head) dims are kept.
assert induction_stripe.shape == (*attn_weights.shape[:-2], attn_weights.shape[-1] - (seq_len - 1))
induction_stripe.shape

torch.Size([4, 5, 9])

##### Example 2

In [None]:
# Random stand-in for an induction stripe: 5 heads x 9 stripe weights.
# Seeded through its own generator so the induction scores below are
# reproducible across Restart & Run All.
induction_stripe = torch.randn(5, 9, generator=torch.Generator().manual_seed(0))

In [None]:
import einops

Here `induction_stripe` is a random stand-in with 5 heads and 9 stripe weights per head.

In [None]:
induction_stripe.size()  # (head_index, position)

torch.Size([5, 9])

Calculate the induction score for each head, defined here as the mean of that head's induction-stripe weights.

In [None]:
# A head's induction score is the mean of its stripe weights: average over the
# position axis, leaving one value per head.
induction_score = einops.reduce(induction_stripe, "head pos -> head", "mean")

In [None]:
induction_score.size()  # one score per head

torch.Size([5])

In [None]:
induction_score  # values depend on the random stand-in stripe defined above

tensor([-0.1851,  0.0256, -0.2207,  0.3466,  0.1657])

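##### Example 3

Visualise the attention pattern of head 5.5 (layer 5, head 5) on a sequence of random tokens repeated twice. In the second copy, an induction head is expected to attend to the token that followed the current token's earlier occurrence, which shows up as a stripe just below an offset diagonal.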

In [None]:
# Head to visualise, addressed as layer.head = 5.5.
# NOTE(review): presumably chosen as a known induction head in GPT-2 small — confirm via its induction score.
induction_head_layer, induction_head_index = 5, 5

In [None]:
# 20 random token ids drawn from [1000, 10000), batch of 1. A seeded CPU
# generator makes the sequence (and therefore the plotted attention pattern)
# reproducible; the tensor is then moved to the model's device.
single_random_sequence = torch.randint(
    1000, 10000, (1, 20), generator=torch.Generator().manual_seed(0)
).to(model.cfg.device)

In [None]:
# Tile the sequence twice along the position axis, [t_0..t_19, t_0..t_19],
# giving shape (1, 40): every token in the second copy has an earlier occurrence.
repeated_random_sequence = single_random_sequence.repeat(1, 2)

In [None]:
def visualize_pattern_hook(
    pattern: Float[torch.Tensor, "batch head_index dest_pos source_pos"],
    hook: HookPoint,
    *,
    head_index=None,
    tokens=None,
):
    """Forward hook that renders one head's attention pattern with CircuitsVis.

    Args:
        pattern: attention pattern passed in by ``run_with_hooks``; only batch
            element 0 is displayed.
        hook: the HookPoint that fired (unused; required by the hook signature).
        head_index: optional head to display. Defaults to the notebook-level
            ``induction_head_index``, looked up when the hook runs.
        tokens: optional token ids used to label the axes. Defaults to the
            notebook-level ``repeated_random_sequence``, looked up when the hook runs.

    Returns:
        None, so the activation is not replaced.
    """
    # Resolve defaults at call time (not definition time) so re-running the
    # cells that define these globals is picked up without redefining the hook.
    if head_index is None:
        head_index = induction_head_index
    if tokens is None:
        tokens = repeated_random_sequence
    display(
        cv.attention.attention_patterns(
            tokens=model.to_str_tokens(tokens),
            # Add a dummy head axis, as CircuitsVis expects 3D (head, dest, source) patterns.
            attention=pattern[0, head_index, :, :][None, :, :],
        )
    )

In [None]:
# Run the model on the repeated sequence only for the hook's side effect
# (rendering the pattern); return_type=None skips returning logits.
model.run_with_hooks(
    repeated_random_sequence,
    fwd_hooks=[
        (utils.get_act_name("pattern", induction_head_layer), visualize_pattern_hook),
    ],
    return_type=None,
)