# TransformerLens Head Detector Demo 

In [None]:
# Plotly's default renderer depends on the frontend (VSCode/Jupyter vs. Colab).
# Force static PNG output instead of relying on the frontend-specific default.
import plotly.io as pio
pio.renderers.default = "png"

In [None]:
import torch
import einops
import pysvelte
from tqdm import tqdm

import transformer_lens
from transformer_lens import HookedTransformer, ActivationCache
from neel_plotly import line, imshow, scatter

# Disable autograd globally: the cells here only run forward passes, so no
# computation graphs need to be recorded.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled>

In [None]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"{device = }")

device = 'cpu'


In [None]:
def plot_head_detection_scores(
    scores: torch.Tensor,
    zmin: float = -1,
    zmax: float = 1,
    xaxis: str = "Head",
    yaxis: str = "Layer",
    title: str = "Head Matches"
) -> None:
    """Draw a grid of head scores as a heatmap via ``neel_plotly.imshow``.

    The default axis labels put heads on the x-axis and layers on the y-axis,
    and the colour scale is clamped to ``[zmin, zmax]`` (``[-1, 1]`` by default).
    Nothing is returned; the figure is produced by ``imshow`` as a side effect.
    """
    heatmap_options = dict(zmin=zmin, zmax=zmax, xaxis=xaxis, yaxis=yaxis, title=title)
    imshow(scores, **heatmap_options)

def plot_attn_pattern_from_cache(cache: ActivationCache, layer_i: int, str_tokens=None):
    """Render every attention head of one layer as a pysvelte ``AttentionMulti`` view.

    Args:
        cache: activation cache from ``model.run_with_cache``. Only batch
            element 0 is plotted.
        layer_i: index of the layer whose attention patterns are shown.
        str_tokens: token strings used to label the attention axes, e.g.
            ``model.to_str_tokens(text)``. The previous version always called
            ``model.to_str_tokens(prompt)``, which requires notebook globals
            ``model`` and ``prompt`` to exist (this notebook names its prompt
            ``text``). When omitted, that global-based lookup is still used
            for backward compatibility.

    Returns:
        The ``pysvelte.AttentionMulti`` object, so it displays as cell output.
    """
    # hook_pattern is indexed (batch, head, query, key) elsewhere in this
    # notebook; take the first sequence explicitly instead of squeeze(0),
    # which silently kept the batch axis (and broke rearrange) for batch > 1.
    attention_pattern = cache["pattern", layer_i, "attn"][0]
    # Move the head axis last: the layout handed to pysvelte.
    attention_pattern = einops.rearrange(attention_pattern, "heads seq1 seq2 -> seq1 seq2 heads")
    print(f"Layer {layer_i} Attention Heads:")
    if str_tokens is None:
        # Legacy fallback: relies on global `model` and `prompt`.
        str_tokens = model.to_str_tokens(prompt)
    return pysvelte.AttentionMulti(tokens=str_tokens, attention=attention_pattern)

In [None]:
# Load GPT-2 small onto the device chosen in the device cell, rather than leaving
# that computed `device` unused and relying on the library's default placement.
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


### Head Detector

In [None]:
def is_square(x):
    """Return True when ``x`` is 2-D and has as many rows as columns."""
    if x.ndim != 2:
        return False
    n_rows, n_cols = x.shape
    return n_rows == n_cols

In [None]:
def is_lower_triangular(x):
    """Return True when ``x`` is square and every entry above its main diagonal is zero.

    Non-square (or non-2-D) inputs are reported as not lower triangular.
    """
    if is_square(x):
        # Lower triangular <=> keeping only the lower triangle changes nothing.
        return torch.equal(x, torch.tril(x))
    return False

In [None]:
def detect_head(
    model,
    seq,
    detection_pattern,
    heads=None,
    cache=None,
):
    """Score how closely each attention head's pattern matches ``detection_pattern``.

    A head's score is ``sum(pattern * detection_pattern) / sum(pattern)``, the
    same measure the per-head scoring loop in this notebook uses. For a 0/1
    detection pattern it is the share of the head's attention weight that
    falls on the marked (query, key) positions.

    Args:
        model: the transformer. Only ``model.cfg.n_layers`` and
            ``model.cfg.n_heads`` are read, plus ``model.run_with_cache(seq)``
            when no cache is supplied.
        seq: tokens the cache was (or will be) computed on.
        detection_pattern: (seq_len, seq_len) pattern to compare against.
        heads: mapping ``layer index -> list of head indices`` to score (the
            shape of ``layer2heads``). Defaults to every head of every layer.
        cache: activation cache from ``model.run_with_cache(seq)``; computed
            on demand when omitted.

    Returns:
        An (n_layers, n_heads) float tensor of scores. Heads that were not
        requested in ``heads`` stay at 0.
    """
    if cache is None:
        _, cache = model.run_with_cache(seq)

    n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads
    if heads is None:
        heads = {layer: list(range(n_heads)) for layer in range(n_layers)}

    matches = torch.zeros(n_layers, n_heads)
    for layer, head_idxs in heads.items():
        # hook_pattern is (batch, head, query, key); score the first sequence.
        layer_patterns = cache["pattern", layer, "attn"][0]
        # The detection pattern is usually built on the CPU; the cache may not be.
        target = detection_pattern.to(layer_patterns.device)
        for head in head_idxs:
            head_pattern = layer_patterns[head]
            score = (head_pattern * target).sum() / head_pattern.sum()
            matches[layer, head] = score.item()
    return matches

In [None]:
# Model dimensions, used to size the per-head score grid.
n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads

In [None]:
# One similarity score per (layer, head); filled in by the scoring loop.
data = torch.zeros((n_layers, n_heads))

In [None]:
# Heads to score per layer: every head of every layer (a fresh list per layer).
layer2heads = dict((layer_idx, list(range(n_heads))) for layer_idx in range(n_layers))

In [None]:
# Inspect the layer -> head-index mapping that will be scored.
layer2heads

{0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
 1: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
 2: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
 3: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
 4: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
 5: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
 6: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
 7: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
 8: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
 9: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
 10: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
 11: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]}

In [None]:
from transformer_lens import utils

In [None]:
# Check which hook name the ("pattern", layer, "attn") shorthand resolves to.
utils.get_act_name("pattern", 1, "attn")

'blocks.1.attn.hook_pattern'

In [None]:
# Short prompt whose attention patterns are analysed below.
text = "Persistence is all you need."

In [None]:
# Tokenize the prompt into a batch of one sequence of token ids.
tokens = model.to_tokens(text)

In [None]:
# Forward pass that records hook activations (including attention patterns);
# the model output itself is discarded.
_, cache = model.run_with_cache(tokens)

In [None]:
def compute_head_attention_similarity_score(attention_pattern, target_pattern):
    """Return how much of ``attention_pattern``'s weight overlaps ``target_pattern``.

    Computes ``sum(attention_pattern * target_pattern) / sum(attention_pattern)``.
    For a 0/1 ``target_pattern`` and a non-negative attention pattern this is
    the share of total attention weight on the marked positions, in [0, 1].

    Args:
        attention_pattern: one head's attention pattern, e.g. (query, key).
        target_pattern: pattern of the same shape to compare against. It is
            moved to ``attention_pattern``'s device first, so a target built
            on the CPU can be scored against a cache that lives on the GPU
            (previously this raised a device-mismatch error).

    Returns:
        A 0-dim tensor holding the score.
    """
    target_pattern = target_pattern.to(attention_pattern.device)
    overlap = (attention_pattern * target_pattern).sum()
    return overlap / attention_pattern.sum()

In [None]:
def get_target_pattern(tokens):
    """Return the (seq_len, seq_len) identity matrix for ``tokens``.

    seq_len is the size of the last dimension of ``tokens``. Every query
    position is marked as attending only to itself (the main diagonal).
    """
    return torch.eye(tokens.shape[-1])

In [None]:
# Diagonal (attend-to-self) pattern sized to the prompt's token count.
target_pattern = get_target_pattern(tokens)

In [None]:
# Inspect it: ones on the main diagonal, zeros elsewhere.
target_pattern

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.]])

In [None]:
from IPython.core.debugger import set_trace

In [None]:
# The cache was computed on a single prompt, so read batch element 0.
batch_idx = 0

In [None]:
# Score every requested head against the target pattern into data[layer, head].
# target_pattern is created on the CPU by default while the cache lives on the
# model's device, so move it across before comparing.
for layer_idx, head_idxs in layer2heads.items():
    print(f"layer_idx={layer_idx}, head_idxs={head_idxs}")

    # Resolves to the same "blocks.{layer}.attn.hook_pattern" key as before,
    # without hand-building the hook string.
    layer_attn_patterns = cache[utils.get_act_name("pattern", layer_idx, "attn")]
    target_on_device = target_pattern.to(layer_attn_patterns.device)

    for head_idx in head_idxs:
        head_attn_patterns = layer_attn_patterns[batch_idx, head_idx, :, :]
        data[layer_idx, head_idx] = compute_head_attention_similarity_score(
            attention_pattern=head_attn_patterns,
            target_pattern=target_on_device,
        )

layer_idx=0, head_idxs=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
layer_idx=1, head_idxs=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
layer_idx=2, head_idxs=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
layer_idx=3, head_idxs=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
layer_idx=4, head_idxs=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
layer_idx=5, head_idxs=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
layer_idx=6, head_idxs=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
layer_idx=7, head_idxs=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
layer_idx=8, head_idxs=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
layer_idx=9, head_idxs=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
layer_idx=10, head_idxs=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
layer_idx=11, head_idxs=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]


In [None]:
# Inspect the (n_layers, n_heads) score grid; a higher score means more of that
# head's attention weight lands on the target positions.
data

tensor([[0.1878, 0.9888, 0.1791, 0.7570, 0.5007, 0.8425, 0.2004, 0.2856, 0.3645,
         0.2010, 0.3534, 0.2259],
        [0.2204, 0.2185, 0.2602, 0.2743, 0.2132, 0.2384, 0.1680, 0.1819, 0.2130,
         0.1605, 0.4746, 0.7680],
        [0.2174, 0.1606, 0.2036, 0.1936, 0.1981, 0.2440, 0.2492, 0.3668, 0.2184,
         0.1947, 0.2794, 0.1931],
        [0.1489, 0.1969, 0.1992, 0.1892, 0.1327, 0.1666, 0.2243, 0.2040, 0.2486,
         0.2181, 0.2232, 0.2539],
        [0.1834, 0.1760, 0.1549, 0.1920, 0.1501, 0.2024, 0.1614, 0.5871, 0.1408,
         0.1688, 0.1483, 0.1250],
        [0.1339, 0.1263, 0.1409, 0.1597, 0.1684, 0.1339, 0.1594, 0.1627, 0.1487,
         0.1294, 0.1449, 0.2265],
        [0.1676, 0.1611, 0.1386, 0.1384, 0.1767, 0.1743, 0.1750, 0.1594, 0.1490,
         0.1307, 0.1423, 0.1855],
        [0.1532, 0.1368, 0.1259, 0.1558, 0.1426, 0.1501, 0.1388, 0.1310, 0.1659,
         0.1522, 0.1273, 0.1286],
        [0.1678, 0.1349, 0.1677, 0.1465, 0.1869, 0.1771, 0.1612, 0.1673, 0.1716,

In [None]:
def get_previous_token_head_detection_pattern(tokens):
    seq_len = tokens.shape[-1]
    detection_pattern = torch.zeros(seq_len, seq_len)
    detec

In [None]:
# Re-declares the identical prompt used earlier in the notebook.
text = "Persistence is all you need."

In [None]:
# Re-tokenize the (unchanged) prompt.
tokens = model.to_tokens(text)

In [None]:
# Inspect the token ids; the leading 50256 appears to be a prepended BOS token.
tokens

tensor([[50256, 30946, 13274,   318,   477,   345,   761,    13]])

In [None]:
# Prompt length in tokens (leading 50256 included): the side length of the
# detection patterns built from these tokens.
seq_len = tokens.shape[-1]

In [None]:
# Display the length (8 tokens for this prompt, per the output below).
seq_len

8