<a href="https://colab.research.google.com/github/sc22lg/ML-Notebooks/blob/gpt2-small-paper-recreation/semantic_attention_recreation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## A Recreation of the Results of: The Self-Hating Attention Head: A Deep Dive in GPT-2 - Matteo Migliarini July 2025
by Leo Gott

Original publication can be found [here](https://www.lesswrong.com/posts/wxPvdBwWeaneAsWRB/the-self-hating-attention-head-a-deep-dive-in-gpt-2-1)

### Overall idea:
"gpt2-small's head L1H5 directs attention to semantically similar tokens and actively suppresses self-attention"
### Results to re-create:
- Create inputs to elicit the expected behaviour
- Use those inputs to identify the heads performing this behaviour in gpt2-small (expected head: L1H5)
- Perform mean-ablation of preceding components to find which components affect L1H5

### Setup:

In [1]:
import os
import sys
from pathlib import Path

import pkg_resources

installed_packages = [pkg.key for pkg in pkg_resources.working_set]
if "transformer-lens" not in installed_packages:
    %pip install transformer_lens==2.11.0 einops eindex-callum jaxtyping git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python

import pandas as pd
import circuitsvis as cv
import einops
import numpy as np
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from transformer_lens import (
    ActivationCache,
    FactoredMatrix,
    HookedTransformer,
    HookedTransformerConfig,
    utils,
)
from transformer_lens.hook_points import HookPoint

  import pkg_resources


In [5]:
import random as rand

### 1.1 Generate input prompt

In [4]:
# Load the word table. The CSV has no header row, so pandas numbers the
# columns 0..N-1 and the row index identifies each group of words.
semantic_words_file = pd.read_csv(
    filepath_or_buffer='semantic_words.csv',
    header=None,
)
# Print the complete table (to_string() disables pandas' row/column truncation).
print(semantic_words_file.to_string())

         0           1           2          3           4           5
0   Monday     Tuesday   Wednesday   Thursday      Friday    Saturday
1      red        blue       green     silver       white        Blue
2     1918        1920        1930       1943        1998        2000
3      You          He         his        she         her       their
4    Italy     Iceland     Austria     Mexico       Spain      France
5      dog         cat       horse    hamster        fish      lizard
6       60          65          69         70          71          90
7    angry       happy         sad    excited       bored    stressed
8      car         bus         van      truck   motorbike   aeroplane
9     rose       tulip        lily      daisy      orchid   sunflower
10  guitar       piano      violin       drum       flute     trumpet
11  soccer  basketball      tennis   baseball       rugby      hockey
12  circle      square    triangle  rectangle     hexagon     octagon
13   chair       tab

In [37]:
# Create shuffled list of tokens.
# Builds `n_sequences` prompts of `n_tokens` words each. Every prompt draws its
# words from only 4 randomly chosen rows (categories) of `semantic_words_file`,
# so words from the same category recur within a prompt.
n_sequences = 30
n_tokens = 16
n_rows = semantic_words_file.shape[0]

# Fixed seed and a single shared generator so that Restart & Run All
# reproduces the same prompts; change INPUT_SEED to draw a different set.
INPUT_SEED = 0
input_rng = np.random.default_rng(INPUT_SEED)

# Object array: each cell holds a (category row index, word) tuple.
inputs = np.empty((n_sequences, n_tokens), dtype=object)

for i in range(n_sequences):
    # The 4 categories available to this prompt.
    subset = semantic_words_file.sample(4, random_state=input_rng)
    for j in range(n_tokens):
        # Pick one of those categories, then one word from its row.
        category_list = subset.sample(1, random_state=input_rng)
        category = category_list.index[0]
        # dropna() stops NaN padding (present when a CSV row has fewer fields
        # than the table has columns) from being drawn as a word; it is a
        # no-op for fully populated rows.
        token = (
            category_list.iloc[0]
            .dropna()
            .sample(1, random_state=input_rng)
            .values[0]
        )
        inputs[i, j] = (category, token)
print(inputs)

[[(11, 'soccer') (15, 'dentist') (15, 'surgeon') (2, ' 2000')
  (2, ' 2000') (2, '1918') (7, ' stressed') (2, ' 1920') (11, 'soccer')
  (7, ' happy') (11, 'rugby') (7, ' bored') (11, 'soccer')
  (11, 'basketball') (7, ' happy') (7, ' excited')]
 [(6, ' 65') (3, ' she') (6, ' 71') (6, ' 69') (6, '60') (14, 'lake')
  (12, 'circle') (3, 'You') (12, 'square') (12, 'rectangle') (14, 'lake')
  (14, 'river') (12, 'octagon') (3, 'You') (3, ' she') (6, ' 71')]
 [(6, ' 65') (12, 'square') (12, 'rectangle') (9, 'orchid') (13, 'chair')
  (12, 'triangle') (9, 'lily') (13, 'bed') (6, ' 71') (12, 'rectangle')
  (12, 'triangle') (6, ' 69') (6, ' 70') (12, 'rectangle') (9, 'orchid')
  (9, 'rose')]
 [(16, 'javascript') (1, ' silver') (13, 'shelf') (13, 'shelf')
  (6, ' 90') (13, 'shelf') (1, ' white') (1, ' blue') (1, ' silver')
  (13, 'shelf') (16, 'ruby') (16, 'java') (6, ' 71') (13, 'chair')
  (13, 'desk') (16, 'java')]
 [(10, 'violin') (16, 'javascript') (6, ' 65') (16, 'java') (6, '60')
  (11, 'bas