<a href="https://colab.research.google.com/github/sc22lg/ML-Notebooks/blob/gpt2-small-paper-recreation/semantic_attention_recreation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## A Recreation of the Results of "The Self-Hating Attention Head: A Deep Dive in GPT-2" (Matteo Migliarini, July 2025)
by Leo Gott

The original publication can be found [here](https://www.lesswrong.com/posts/wxPvdBwWeaneAsWRB/the-self-hating-attention-head-a-deep-dive-in-gpt-2-1).

### Overall idea:
"gpt2-small's head L1H5 directs attention to semantically similar tokens and actively suppresses self-attention"
### Results to re-create:
- Create inputs to elicit the expected behaviour
- Use these inputs to identify the heads performing this behaviour in gpt2-small (expected head: L1H5)
- Perform mean ablation of the preceding components to find which components affect L1H5

### Setup:

In [99]:
import os
import sys
from pathlib import Path

import importlib.util

# Install dependencies only if transformer_lens is missing (e.g. on a fresh Colab runtime).
# importlib is used here because pkg_resources is deprecated in recent setuptools releases.
if importlib.util.find_spec("transformer_lens") is None:
    %pip install transformer_lens==2.11.0 einops eindex-callum jaxtyping git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python

import pandas as pd
import circuitsvis as cv
import einops
import numpy as np
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from transformer_lens import (
    ActivationCache,
    FactoredMatrix,
    HookedTransformer,
    HookedTransformerConfig,
    utils,
)
from transformer_lens.hook_points import HookPoint

In [100]:
import random as rand
import plotly.express as px
from IPython.display import display

### 1.1 Generate input prompt

In [101]:
semantic_words_file = pd.read_csv('semantic_words.csv', header=None)
print(semantic_words_file.to_string())

        0        1          2         3       4         5
0     dog      cat       fish       rat     bat     horse
1     red     blue      green      pink    gray      cyan
2     you       he        his       she     her     their
3     car      bus        van      taxi    jeep      tram
4   Italy  Iceland    Austria    Mexico   Spain    France
5     sad      mad       glad     angry    calm     happy
6  Monday  Tuesday  Wednesday  Thursday  Friday  Saturday
7      69       72         88        68      90        57
8    1800     2025       1275      1945    1600      2100


In [102]:
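# Build random prompts: each sequence draws 4 categories (rows of the CSV),
# then samples 12 words, each from a random one of those 4 categories
n_sequences = 30
n_tokens = 12
n_rows = semantic_words_file.shape[0]

# Each entry is a (category index, word) tuple
inputs = np.empty((n_sequences, n_tokens), dtype=object)

for i in range(n_sequences):
  subset = semantic_words_file.sample(4)  # 4 distinct categories
  for j in range(n_tokens):
    category_list = subset.sample(1)  # pick one of the 4 categories
    category = category_list.index[0]
    token = category_list.iloc[0].sample(1).values[0]  # pick one word from it
    inputs[i, j] = (category, token)
# print(inputs)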

In [103]:
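# Create masks marking where two tokens in an input share a category.
# masks[seq, i, j] = 1 when token i and an earlier token j (i > j) are different
# words from the same category; the diagonal and upper triangle stay 0
masks = np.zeros((n_sequences, n_tokens, n_tokens))

for seq in range(n_sequences):
  for i in range(n_tokens):
    for j in range(n_tokens):
      same_category = inputs[seq, i][0] == inputs[seq, j][0]
      different_word = inputs[seq, i][1] != inputs[seq, j][1]
      if same_category and different_word and i > j:
        masks[seq, i, j] = 1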
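
The triple loop above can also be written with NumPy broadcasting. The following is a minimal sketch rather than the notebook's original code: `build_mask` is a hypothetical helper that takes the categories and words of a single sequence as plain lists.

```python
import numpy as np

def build_mask(categories, words):
    """Return an (n, n) mask where entry [i, j] is 1 if token i and an earlier
    token j share a category but are different words (hypothetical helper)."""
    cats = np.asarray(categories)
    toks = np.asarray(words)
    same_category = cats[:, None] == cats[None, :]
    different_word = toks[:, None] != toks[None, :]
    # np.tril with k=-1 keeps only entries strictly below the diagonal (i > j)
    return np.tril(same_category & different_word, k=-1).astype(float)

# Toy sequence: category 0 (animals) and category 1 (colours)
mask = build_mask([0, 1, 0, 0], ["dog", "red", "cat", "dog"])
print(mask)
```

In the notebook, the two lists for sequence `seq` could be taken from `inputs[seq]` by splitting each `(category, word)` tuple.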

In [None]:
show_mask = 1
tick_labels = [inputs[show_mask, i][1] for i in range(n_tokens)]  # words of the selected sequence
fig = px.imshow(masks[show_mask], labels=dict(x="Token Index", y="Token Index", color="Same Category"))
fig.update_layout(
    xaxis=dict(tickmode='array', tickvals=list(range(n_tokens)), ticktext=tick_labels),
    yaxis=dict(tickmode='array', tickvals=list(range(n_tokens)), ticktext=tick_labels),
)
fig.show()

### Load & test gpt2-small:

In [105]:
model = HookedTransformer.from_pretrained("gpt2-small", device="cpu")

Loaded pretrained model gpt2-small into HookedTransformer


In [106]:
# run model on selected sequence & cache activation
sequence_index = show_mask
test_input = ' '.join([inputs[sequence_index, i][1] for i in range(n_tokens)])
print("input: " + test_input)
input_tokens = model.to_tokens(test_input)
logits, cache = model.run_with_cache(input_tokens)

input: France 1945 dog cat Iceland France rat Spain angry Italy sad rat
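
The mask built earlier has one row and one column per word, so it only lines up with the attention pattern if every word becomes exactly one token. A minimal sketch of a check for this follows. `find_multi_token_words` and its `tokenize` argument are hypothetical names, and the stand-in tokenizer exists only to keep the example self-contained.

```python
from typing import Callable, List

def find_multi_token_words(words: List[str], tokenize: Callable[[str], List[str]]) -> List[str]:
    """Return the words that the tokenizer splits into more than one token."""
    # GPT-2 tokens usually carry the leading space, so each word is tokenized with one
    return [w for w in words if len(tokenize(" " + w)) != 1]

# Stand-in tokenizer (an assumption for illustration): splits words longer than 6 characters
def toy_tokenize(text: str) -> List[str]:
    return [text] if len(text.strip()) <= 6 else [text[:4], text[4:]]

print(find_multi_token_words(["dog", "Wednesday", "cat"], toy_tokenize))
```

In the notebook, `tokenize` could be something like `lambda s: model.to_str_tokens(s, prepend_bos=False)`. A quicker check is that `input_tokens.shape[1]` equals `n_tokens + 1`, the extra position being `<|endoftext|>`.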


In [None]:
layer_number = 1
layer1_patterns = cache["pattern", layer_number]
print(layer1_patterns.shape)
print(input_tokens.shape)
print(input_tokens.squeeze())
str_tokens = model.to_str_tokens(input_tokens.squeeze())

display(
    cv.attention.attention_patterns(
        tokens=str_tokens,
        attention=layer1_patterns.squeeze(),
        attention_head_names=[f"L{layer_number}H{i}" for i in range(12)],
    )
)

### Pattern detector:

In [108]:
print(f"layers: {model.cfg.n_layers}")
print(f"heads per layer: {model.cfg.n_heads}")

layers: 12
heads per layer: 12


In [111]:
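def evaluate_semantic_head(cache: ActivationCache, layer: int, head: int):
  # Attention pattern of one head, shape (n_tokens + 1, n_tokens + 1) including <|endoftext|>
  attention_pattern = cache["pattern", layer].squeeze()[head]
  # Drop the <|endoftext|> row and column, then keep only same-category (query, key) pairs.
  # This assumes every word in the prompt is a single token, so the mask and pattern align.
  same_category = t.from_numpy(masks[sequence_index]).bool()
  expected_attention = attention_pattern[1:, 1:][same_category]
  return t.mean(expected_attention)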

def semantic_head_detector(cache: ActivationCache):
  scores = np.zeros((model.cfg.n_layers, model.cfg.n_heads))

  # calculate attention score for current input for each attention head
  for layer in range(model.cfg.n_layers):
    for head in range(model.cfg.n_heads):
      scores[layer, head] = evaluate_semantic_head(cache, layer, head)
  return scores
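
To make the metric concrete, here is a small worked example with hand-written numbers rather than model output. It is a sketch in NumPy under the same conventions as the functions above (position 0 is `<|endoftext|>`), and `head_score` is a hypothetical name. It also returns 0 when a sequence has no same-category pairs, a case in which the mean of an empty selection would otherwise be NaN.

```python
import numpy as np

def head_score(pattern: np.ndarray, mask: np.ndarray) -> float:
    """Mean attention that a head places on same-category earlier tokens."""
    selected = pattern[1:, 1:][mask.astype(bool)]  # drop <|endoftext|>, keep masked pairs
    # Assumption: score 0.0 when the mask selects nothing, instead of NaN
    return float(selected.mean()) if selected.size else 0.0

# Toy causal attention pattern for <|endoftext|> plus three words; each row sums to 1
pattern = np.array([
    [1.0, 0.0, 0.0, 0.0],
    [0.9, 0.1, 0.0, 0.0],
    [0.3, 0.6, 0.1, 0.0],
    [0.2, 0.5, 0.2, 0.1],
])
# Words 1 and 2 each share a category with word 0
mask = np.array([
    [0, 0, 0],
    [1, 0, 0],
    [1, 0, 0],
])
print(head_score(pattern, mask))  # mean of 0.6 and 0.5
```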


In [None]:
attn_head_scores = semantic_head_detector(cache)
fig = px.imshow(attn_head_scores, labels=dict(x="Head", y="Layer", color="Attention Score"))
fig.show()

## Measuring component importance:
To measure which components contribute to the function of L1H5, we perform a mean-ablation study. We replace the output of each component that precedes L1H5 with its mean and re-measure the attention score of L1H5 to see whether the ablation changed it. The components whose ablation changes the score the most are the ones that contribute most to L1H5's function.
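
As an illustration of the ablation step, the sketch below shows one way a mean-ablation hook could look. It is not the original post's implementation: `mean_ablation_hook` is a hypothetical name, and taking the mean over the batch and position dimensions is an assumption, since the text does not say which mean is used.

```python
import torch as t

def mean_ablation_hook(activation: t.Tensor, hook=None) -> t.Tensor:
    """Replace an activation of shape (batch, position, ...) with its mean.

    Assumption: the mean is taken over the batch and position dimensions,
    so every position receives the same vector.
    """
    mean = activation.mean(dim=(0, 1), keepdim=True)
    return mean.expand_as(activation).clone()

# Toy activation: batch of 1, 3 positions, 2 features
activation = t.tensor([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]])
print(mean_ablation_hook(activation))
```

With TransformerLens, a hook like this could be attached with `model.add_hook(utils.get_act_name("pos_embed"), mean_ablation_hook)` before calling `model.run_with_cache(input_tokens)`, followed by `model.reset_hooks()`. The score of L1H5 is then recomputed from the new cache and compared with the unablated score.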

In [None]:
#perform mean ablation on positional embedding
