<a href="https://colab.research.google.com/github/sc22lg/ML-Notebooks/blob/paper_recreation/semantic_attention_recreation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## A Recreation of the Results of: The Self-Hating Attention Head: A Deep Dive in GPT-2 - Matteo Migliarini July 2025
by Leo Gott

Original publication can be found [here](https://www.lesswrong.com/posts/wxPvdBwWeaneAsWRB/the-self-hating-attention-head-a-deep-dive-in-gpt-2-1)

### Overall idea:
"gpt2-small's head L1H5 directs attention to semantically similar tokens and actively suppresses self-attention"
### Results to re-create:
- Create inputs to ellicit expected behaviour
- Use inputs to identify heads performing behaviour in gpt2-small (expected head L1H5)
- Perform mean-ablation of preceding components to find which components effect L1H5

### Setup:

In [1]:
import os
import sys
from pathlib import Path
import functools

import pkg_resources

installed_packages = [pkg.key for pkg in pkg_resources.working_set]
if "transformer-lens" not in installed_packages:
    %pip install transformer_lens==2.11.0 einops eindex-callum jaxtyping git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python

import pandas as pd
import circuitsvis as cv
import einops
import numpy as np
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from jaxtyping import Float, Int
from torch import Tensor
from transformer_lens import (
    ActivationCache,
    FactoredMatrix,
    HookedTransformer,
    HookedTransformerConfig,
    utils,
)
from transformer_lens.hook_points import HookPoint

  import pkg_resources


In [2]:
import random as rand
import plotly.express as px
from IPython.display import display

### 1.1 Generate input prompt

In [3]:
semantic_words_file = pd.read_csv('semantic_words.csv', header=None)
print(semantic_words_file.to_string())

        0        1          2         3          4          5
0  Monday  Tuesday  Wednesday  Thursday     Friday   Saturday
1     red     blue      green    silver      white       Blue
2    1918     1920       1930      1943       1998       2000
3     You       He        his       she        her      their
4   Italy  Iceland    Austria    Mexico      Spain     France
5     dog      cat      horse      bird       fish     lizard
6      60       65         69        70         71         90
7     car      bus        van     truck  motorbike  aeroplane
8   chair    table       sofa      desk        bed      shelf
9   river     lake      ocean       sea       pond     stream


In [4]:
# Create shuffled list of tokens
n_sequences = 30
n_tokens = 12
n_rows = semantic_words_file.shape[0]

inputs = np.empty((n_sequences, n_tokens), dtype=tuple)
seed:int = 5 # seeds that work: 5
rnd = np.random.RandomState(seed)

for i in range(n_sequences):
  subset = semantic_words_file.sample(4, random_state = rnd)
  for j in range(n_tokens):
    category_list = subset.sample(1, random_state = rnd)
    category = category_list.index[0]
    token = category_list.iloc[0].sample(1, random_state = rnd).values[0]
    inputs[i, j] = (category, token)
# print(inputs)

### 1.2 Create Masks

In [5]:
#create masks representing where tokens in an input share a category
masks = np.zeros((n_sequences, n_tokens, n_tokens))

for seq in range(n_sequences):
  for i in range(n_tokens):
    for j in range(n_tokens):
      if inputs[seq, i][0] == inputs[seq, j][0] and inputs[seq, i][1] != inputs[seq, j][1] and i > j: # ensures upper triangle is 0s
        masks[seq, i, j] = 1

In [6]:
show_mask = 1
fig = px.imshow(masks[show_mask], labels=dict(x="Token Index", y="Token Index", color="Same Category")) # Added labels dictionary
fig.update_layout(xaxis = dict(
        tickmode = 'array',
        tickvals = list(range(n_tokens)),
        ticktext = [inputs[show_mask, i][1] for i in range(n_tokens)] # Use tokens from inputs for x-axis
    ),
    yaxis = dict(
        tickmode = 'array',
        tickvals = list(range(n_tokens)),
        ticktext = [inputs[show_mask, i][1] for i in range(n_tokens)] # Use tokens from inputs for y-axis
    )
)
fig.show()

### 2.1 Load & test gpt2-small:

In [7]:
model = HookedTransformer.from_pretrained("gpt2-small", device="cpu")
print(model.cfg)



The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.



config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-small into HookedTransformer
HookedTransformerConfig:
{'NTK_by_parts_factor': 8.0,
 'NTK_by_parts_high_freq_factor': 4.0,
 'NTK_by_parts_low_freq_factor': 1.0,
 'act_fn': 'gelu_new',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': 8.0,
 'attn_scores_soft_cap': -1.0,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 3072,
 'd_model': 768,
 'd_vocab': 50257,
 'd_vocab_out': 50257,
 'decoder_start_token_id': None,
 'default_prepend_bos': True,
 'device': 'cpu',
 'dtype': torch.float32,
 'eps': 1e-05,
 'experts_per_token': None,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.02886751345948129,
 'load_in_4bit': False,
 'model_name': 'gpt2',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 12,
 'n_key_value_heads': None,
 'n_layers': 12,
 'n_params': 84934656,
 'normalization_type

In [8]:
sequence_index = show_mask
string_inputs = []
token_inputs = []

#convert sequences to strings then tokens
for seq in range(n_sequences):
  string_inputs.append(' '.join([inputs[seq, i][1] for i in range(n_tokens)]))
  tokens = model.to_tokens(string_inputs[seq])
  token_inputs.append(tokens)

print(token_inputs)

#remove invalid token sequences and masks by iterating in reverse
for i in range(len(token_inputs) - 1, -1, -1):
  if(len(token_inputs[i][0]) != n_tokens+1):
    del token_inputs[i]
    del string_inputs[i]
    masks = np.delete(masks, i, axis=0) # Delete the corresponding mask


#print(len(token_inputs))
for seq in token_inputs:
  print(len(seq.squeeze()))

#run model
input_tokens = token_inputs[sequence_index]
logits, cache = model.run_with_cache(input_tokens)

[tensor([[50256, 30527,  5828,  7850,  3797, 17333,  7850, 15533,  7850,  6512,
         13546,  7795,  6512]]), tensor([[50256, 14573,   607,   511,  4171,  6915,   679,  3431,  5118,  3996,
           921,   679, 18316]]), tensor([[50256,  7718,  5916,  9166,  1323,  9166,  3126,  7779,  1097,  9151,
          7850,  5916, 42406]]), tensor([[50256, 11024,  8031, 17322,  4101,  8602,  9166,  4881,  9166,  8602,
          6135,  5118,  8602]]), tensor([[50256,  1639,  3797,   673,  4881,  3126,  5828,   607, 17322, 42406,
          8602,   921,  5916]]), tensor([[50256, 11024,  4751,  5828,  5118,  3084, 14062, 21577,  3084, 18316,
         25859,  8031,  4751]]), tensor([[50256, 45001,  8031,  3321,  5828,  4881, 21577,  3431,  3583,  8644,
          7795,  3321, 17333]]), tensor([[50256, 19844, 17322,  8644,  3635,  8223,  3797, 42406,  3583,  8644,
          5916, 42406,  9166]]), tensor([[50256, 17585,  3290,  2330,  4518,  3797,  4751,  2330,  5916, 42406,
          5118,  7795, 2

### 2.2 Display Attention Patterns

In [9]:
layer_number = 1
layer1_patterns = cache["pattern", layer_number]
print(layer1_patterns.shape)
print(input_tokens.shape)
print(input_tokens.squeeze())
str_tokens = model.to_str_tokens(input_tokens.squeeze())

display(
    cv.attention.attention_patterns(
        tokens=str_tokens,
        attention=layer1_patterns.squeeze(),
        attention_head_names=[f"L{layer_number}H{i}" for i in range(12)],
    )
)

torch.Size([1, 12, 13, 13])
torch.Size([1, 13])
tensor([50256, 14573,   607,   511,  4171,  6915,   679,  3431,  5118,  3996,
          921,   679, 18316])


### 2.3 Pattern detector:

In [10]:
print(f"layers: {model.cfg.n_layers}")
print(f"heads per layer: {model.cfg.n_heads}")

layers: 12
heads per layer: 12


In [11]:
def evaluate_semantic_head(cache: ActivationCache, layer:int, head:int):
  attention_pattern = cache["pattern", layer].squeeze()[head]
  expected_attention = attention_pattern[1:,1:][t.from_numpy(masks[sequence_index]).bool()] # slice attention to ignore |endoftext| and then apply mask
  return t.mean(expected_attention)

def semantic_head_detector(cache: ActivationCache):
  scores = np.zeros((model.cfg.n_layers, model.cfg.n_heads))

  # calculate attention score for current input for each attention head
  for layer in range(model.cfg.n_layers):
    for head in range(model.cfg.n_heads):
      scores[layer, head] = evaluate_semantic_head(cache, layer, head)
  return scores


In [12]:
attn_head_scores = semantic_head_detector(cache)
fig = px.imshow(attn_head_scores, labels=dict(x="Head", y="Layer", color="Attention Score"))
fig.show()

## 3.1 Measuring Component importance:
In order to measure which components are contributing to the function of L1H5, we perform a mean ablation study. We replace each previous compenent of the network with its mean and re-measure the attention score of our metric for L1H5 to see if it was impacted by the ablation of that component. The components which when ablated has the most effect on the performance of L1H5 are the components which contribute to its function.

In [13]:
pos_embed_name = utils.get_act_name("pos")
print(model.W_pos.size())
print(model.hook_dict) # print hookpoint names

torch.Size([1024, 768])
{'hook_embed': HookPoint(), 'hook_pos_embed': HookPoint(), 'blocks.0.ln1.hook_scale': HookPoint(), 'blocks.0.ln1.hook_normalized': HookPoint(), 'blocks.0.ln2.hook_scale': HookPoint(), 'blocks.0.ln2.hook_normalized': HookPoint(), 'blocks.0.attn.hook_k': HookPoint(), 'blocks.0.attn.hook_q': HookPoint(), 'blocks.0.attn.hook_v': HookPoint(), 'blocks.0.attn.hook_z': HookPoint(), 'blocks.0.attn.hook_attn_scores': HookPoint(), 'blocks.0.attn.hook_pattern': HookPoint(), 'blocks.0.attn.hook_result': HookPoint(), 'blocks.0.mlp.hook_pre': HookPoint(), 'blocks.0.mlp.hook_post': HookPoint(), 'blocks.0.hook_attn_in': HookPoint(), 'blocks.0.hook_q_input': HookPoint(), 'blocks.0.hook_k_input': HookPoint(), 'blocks.0.hook_v_input': HookPoint(), 'blocks.0.hook_mlp_in': HookPoint(), 'blocks.0.hook_attn_out': HookPoint(), 'blocks.0.hook_mlp_out': HookPoint(), 'blocks.0.hook_resid_pre': HookPoint(), 'blocks.0.hook_resid_mid': HookPoint(), 'blocks.0.hook_resid_post': HookPoint(), 'bl

In [14]:
batch = 1
heads = model.cfg.n_heads
seq_len = n_tokens

### Example Hook Funciton

In [15]:
def hook_function(
    attn_pattern: Float[Tensor, "batch heads seq_len seq_len"],
    hook: HookPoint
) -> Float[Tensor, "batch heads seq_len seq_len"]:

    # modify attn_pattern (can be inplace)
    return attn_pattern

In [16]:
loss = model.run_with_hooks(
    input_tokens,
    return_type="loss",
    fwd_hooks=[
        ('blocks.1.attn.hook_pattern', hook_function)
    ]
)

### 3.2 Ablation hooks

In [28]:
ablated_head_scores = t.zeros(1)

def eval_head(attention_pattern, mask_index):
  expected_attention = attention_pattern[1:,1:][t.from_numpy(masks[mask_index]).bool()] # slice attention to ignore |endoftext| and then apply mask
  return t.mean(expected_attention)

#evaluate attention score @ head 1,5
def eval_attention_hook(
    attn_pattern: Float[Tensor, "batch heads seq_len seq_len"],
    hook: HookPoint,
    mask_index: int,
) -> Float[Tensor, "batch heads seq_len seq_len"]:
    global ablated_head_scores
    eval_score = eval_head(attn_pattern.squeeze()[5], mask_index)
    ablated_head_scores = t.cat((ablated_head_scores, eval_score.unsqueeze(dim=0)))
    return attn_pattern

#perform mean ablation on embedding layer
def embedding_ablation_hook(
    embed_output: Float[t.Tensor, "n_ctx d_model"],
    hook: HookPoint,
):
  # Ablate by replacing with the mean
  embed_output[:] = embed_output.mean(dim=1, keepdim=True)
  return embed_output

#perform mean ablation on attention layer
def attention_ablation_hook(
    attn_pattern: Float[Tensor, "batch heads seq_len seq_len"],
    hook: HookPoint
) -> Float[Tensor, "batch heads seq_len seq_len"]:

  attn_pattern[:] = attn_pattern.mean(dim=1, keepdim=True)
  return attn_pattern

#perform mean ablation on mlp layer
def mlp_ablation_hook(
    mlp_output: Float[t.Tensor, "n_ctx d_model"],
    hook: HookPoint,
) -> Float[Tensor, "n_ctx d_model"]:
  # Ablate by replacing with the mean
  mlp_output[:] = mlp_output.mean(dim=1, keepdim=True)
  return mlp_output


test_index = sequence_index
temp_hook_fn = functools.partial(eval_attention_hook, mask_index = test_index)
#run with no ablation
model.run_with_hooks(input_tokens, return_type = None, fwd_hooks = [('blocks.1.attn.hook_pattern', temp_hook_fn)], reset_hooks_end = True, clear_contexts=True)

for i in range(len(token_inputs)):
  temp_hook_fn = functools.partial(eval_attention_hook, mask_index = i) # update with new mask
  #run with embedding ablation
  model.run_with_hooks(token_inputs[i], return_type = None, fwd_hooks = [('blocks.1.attn.hook_pattern', temp_hook_fn), (utils.get_act_name('embed'), embedding_ablation_hook)], reset_hooks_end = True, clear_contexts=True)
  #run with positional embedding ablation
  model.run_with_hooks(token_inputs[i], return_type = None, fwd_hooks = [('blocks.1.attn.hook_pattern', temp_hook_fn), ('hook_pos_embed', embedding_ablation_hook)], reset_hooks_end = True, clear_contexts=True)
  #run with L0 attention ablation
  model.run_with_hooks(token_inputs[i], return_type = None, fwd_hooks = [('blocks.1.attn.hook_pattern', temp_hook_fn), ('blocks.0.attn.hook_pattern', attention_ablation_hook)], reset_hooks_end = True, clear_contexts=True)
  #run with L0 MLP ablation
  model.run_with_hooks(token_inputs[i], return_type = None, fwd_hooks = [('blocks.1.attn.hook_pattern', temp_hook_fn), ('blocks.0.hook_mlp_out', mlp_ablation_hook)], reset_hooks_end = True, clear_contexts=True)

ablated_head_scores = ablated_head_scores[1:]
print(ablated_head_scores)

tensor([0.1877, 0.0933, 0.3569, 0.2782, 0.1141, 0.0966, 0.2191, 0.2238, 0.1089,
        0.1107, 0.2735, 0.2908, 0.1198, 0.0918, 0.3223, 0.2794, 0.1215, 0.0982,
        0.2669, 0.2552, 0.1067, 0.0959, 0.2157, 0.1943, 0.1053, 0.0993, 0.3405,
        0.3234, 0.1225, 0.0943, 0.3120, 0.2834, 0.1158, 0.1094, 0.2823, 0.2947,
        0.1208, 0.1015, 0.3436, 0.4109, 0.1172, 0.0999, 0.2907, 0.2728, 0.1136,
        0.0964, 0.2959, 0.2907, 0.1104, 0.0900, 0.4442, 0.3978, 0.1149, 0.0901,
        0.2202, 0.2069, 0.1169, 0.0878, 0.3197, 0.3383, 0.1130, 0.0845, 0.3877,
        0.3588, 0.1075, 0.1084, 0.2127, 0.2140, 0.1149, 0.0897, 0.2872, 0.2757,
        0.1100, 0.0934, 0.3560, 0.3313, 0.1282, 0.0963, 0.2805, 0.2767, 0.1182,
        0.1006, 0.2345, 0.1808, 0.1091, 0.1076, 0.3343, 0.3587, 0.1173],
       grad_fn=<SliceBackward0>)
