|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 5:</h2>|<h1>Observation (non-causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Identifying circuits and components</h1>|
|<h2>Lecture:</h2>|<h1><b>Isolating and investigating attention heads</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">udemy.com/course/dullms_x/?couponCode=202508</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import scipy.stats as stats

import torch
import torch.nn.functional as F
from transformers import GPT2Model, GPT2Tokenizer

import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

In [None]:
# pretrained GPT-2 tokenizer and base model (GPT2Model), both from the same checkpoint
checkpoint = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(checkpoint)
model = GPT2Model.from_pretrained(checkpoint)

# switch to evaluation mode (no dropout); the returned model is displayed as the cell output
model.eval()

In [None]:
# display the model's configuration; n_head and n_embd are read in the next cell
model.config

In [None]:
# per-head geometry: the embedding dimension is split evenly across the attention heads
nheads   = model.config.n_head
head_dim = model.config.n_embd // nheads

# sqrt(head_dim), the scaling term in the scaled dot-product attention equation
sqrtD = torch.tensor(head_dim).sqrt()

print(f'There are {nheads} heads, each with {head_dim} dimensions.')

# Implanting a hook in the model

In [None]:
# hook the output of c_attn, i.e. the concatenated Q, K and V projections
activations = {}

def implant_hook(layer_number):
  """Build a forward hook that stores a module's (detached) output in `activations`.

  layer_number: used to build the storage key f'attn_{layer_number}'. The key is
    fixed when the hook is created, so hooks on different layers write to different
    entries instead of all sharing whatever the global `keyName` is at call time.
  Returns the hook function, suitable for `register_forward_hook`.
  """
  key = f'attn_{layer_number}'
  def hook(module, args, output):
    activations[key] = output.detach()
  return hook

# re-running this cell would otherwise stack a second identical hook on the module
if 'hook_handle' in globals():
  hook_handle.remove()

# implant the hooks
whichlayer = 6
keyName = f'attn_{whichlayer}'
hook_handle = model.h[whichlayer].attn.c_attn.register_forward_hook(implant_hook(whichlayer))

# Forward pass and get activations

In [None]:
# source passage: https://en.wikipedia.org/wiki/Fiji
txt = "The majority of Fiji's islands were formed by volcanic activity starting around 150 million years ago. Some geothermal activity still occurs today on the islands of Vanua Levu and Taveuni."

# encode into a batch containing one sequence of token IDs, then count the tokens
tokens = tokenizer.encode(txt,return_tensors='pt')
ntokens = tokens.shape[-1]

# gradient-free forward pass; the output is discarded because the
# point is to trigger the hook, which fills `activations`
with torch.no_grad():
  model(tokens)

In [None]:
# checking sizes
print(activations.keys())
print(activations[keyName].shape)

# Split into heads

In [None]:
# drop the batch dimension, then slice the concatenated projection into
# three equal n_embd-wide pieces: queries, keys, values
QKV = activations[keyName][0]
Q,K,V = QKV.split(model.config.n_embd,dim=-1)
Q.shape

In [None]:
# now split into heads
Q_h = torch.split(Q,head_dim,dim=1)

print(f'There are {len(Q_h)} heads')
print(f'Each head has size {Q_h[2].shape}')

In [None]:
# repeat for the keys
K_h = torch.split(K,head_dim,dim=1)

# Means and standard deviations of different heads over tokens

In [None]:
# averages and standard deviations over all vectors and all tokens
Qh_means = np.array([ q[1:].mean() for q in Q_h ])
Qh_stds  = np.array([ q[1:].std()  for q in Q_h ])

Kh_means = np.array([ k[1:].mean() for k in K_h ])
Kh_stds  = np.array([ k[1:].std()  for k in K_h ])

In [None]:
# visualizations
_,axs = plt.subplots(1,3,figsize=(12,3.5))

axs[0].errorbar(np.arange(nheads)-.2,Qh_means,Qh_stds,fmt='ks',markerfacecolor=[.7,.9,.7],label='Q')
axs[0].errorbar(np.arange(nheads)+.2,Kh_means,Kh_stds,fmt='ko',markerfacecolor=[.9,.7,.7],label='K')
axs[0].set(xlabel='Heads',ylabel='Activation',title='Means and stds')
axs[0].legend()

axs[1].plot(Qh_means,Kh_means,'kh',markerfacecolor=[.7,.7,.7],markersize=10)
axs[1].set(xlabel='Q heads',ylabel='K heads',title='Average activations')

axs[2].plot(Qh_stds,Kh_stds,'k^',markerfacecolor=[.7,.7,.7],markersize=10)
axs[2].set(xlabel='Q heads',ylabel='K heads',title='Standard deviations')

plt.suptitle(f'Descriptives of attention head activations in layer {whichlayer}',fontweight='bold')
plt.tight_layout()
plt.show()

# Attention scores over tokens

In [None]:
# Raw (unscaled) QK^T dot products between the final token's query and the keys of
# all earlier tokens except the first. Results are pooled separately for same-head
# (qi==ki) and cross-head (qi!=ki) query/key pairings.
same_chunks = []
diff_chunks = []

# loop over pairs of heads
for qi in range(nheads):
  for ki in range(nheads):

    # dot product for last token in Q with all previous tokens in K
    # (rows 1:-1 exclude the first token and the final token itself)
    dp = (Q_h[qi][-1,:] @ K_h[ki][1:-1,:].t()).numpy()

    # store in the appropriate list
    if qi==ki:
      same_chunks.append(dp)
    else:
      diff_chunks.append(dp)

# concatenate once; calling np.concatenate inside the loop re-copies the growing array every iteration
samehead_dp = np.concatenate(same_chunks)
diffhead_dp = np.concatenate(diff_chunks)



## visualizations
rng = np.random.default_rng(1) # seeded so the horizontal jitter (and thus the figure) is reproducible
_,axs = plt.subplots(1,2,figsize=(10,4))

# plot the raw data, with a small random x-jitter to reduce overplotting
axs[0].plot(rng.standard_normal(len(samehead_dp))/60 - .1,samehead_dp,'ko',markerfacecolor=[.7,.9,.7,.7],markersize=8)
axs[0].plot(rng.standard_normal(len(diffhead_dp))/60 + .1,diffhead_dp,'ks',markerfacecolor=[.9,.7,.7,.7],markersize=8)
axs[0].axhline(0,linestyle='--',color=[.7,.7,.7],zorder=-3)
axs[0].set(xticks=[-.1,.1],xticklabels=['Same head','Diff heads'],
              ylabel='QK$^T$ dot products',title='Raw attention scores',xlim=[-.3,.3])

# distributions, peak-normalized because there are nheads same-head pairings but
# nheads*(nheads-1) cross-head pairings. Counts are plotted at bin centers: the two
# groups use different bin widths, so plotting at left edges would shift the curves
# by different amounts.
y,x = np.histogram(samehead_dp,bins=30)
axs[1].plot((x[:-1]+x[1:])/2,y/y.max(),'g',linewidth=2,label='Same head')

y,x = np.histogram(diffhead_dp,bins=80)
axs[1].plot((x[:-1]+x[1:])/2,y/y.max(),'r',linewidth=2,label='Diff heads')

axs[1].legend()
axs[1].set(xlabel='Dot product value',ylabel='Proportion (norm.)',title='Distributions')
axs[1].axvline(0,linestyle='--',color=[.7,.7,.7])

plt.tight_layout()
plt.show()

# Softmax-transformed attention values within each head

In [None]:
# Causal softmax attention weights computed within each head (Q and K from the same
# head), pooled into three groups: final token -> earlier tokens, self-attention of
# every token except the first, and self-attention of the first token.
final_chunks = []
self_chunks  = []
first_chunks = []

# the causal mask is the same for every head, so build it once outside the loop
pastmask = torch.tril(torch.ones(ntokens,ntokens))

# loop over heads
for qi in range(nheads):

  # attention scores scaled by sqrt(head_dim), with future positions masked out
  attn_scores = (Q_h[qi] @ K_h[qi].t()) / sqrtD
  attn_scores[pastmask==0] = -torch.inf

  # softmax over keys
  attn_sm = F.softmax( attn_scores ,dim=-1)

  # the final token with all previous tokens (including the first but excluding self-attn)
  final_chunks.append(attn_sm[-1,:-1].numpy())

  # matching tokens (the diagonal) are self-attention; exclude the first token in the sequence
  self_chunks.append(torch.diag(attn_sm[1:,1:]).numpy())

  # isolate the first token; under the causal mask it can only attend to itself,
  # so this weight is always exactly 1
  first_chunks.append(attn_sm[0,:1].numpy())

# concatenate once instead of growing the arrays inside the loop
final2prev = np.concatenate(final_chunks)
selfAttend = np.concatenate(self_chunks)
firstSelf  = np.concatenate(first_chunks)


## visualize
rng = np.random.default_rng(1) # seeded so the horizontal jitter (and thus the figure) is reproducible
plt.figure(figsize=(10,4))

plt.plot(rng.standard_normal(len(final2prev))/70 - .1,final2prev,'ko',markerfacecolor=[.7,.9,.7,.7],markersize=8)
plt.plot(rng.standard_normal(len(selfAttend))/70 + .1,selfAttend,'ks',markerfacecolor=[.9,.7,.7,.7],markersize=8)
plt.plot(rng.standard_normal(len(firstSelf))/70  + .3,firstSelf,'ks',markerfacecolor=[.7,.7,.9,.7],markersize=8)

plt.gca().set(xticks=[-.1,.1,.3],ylabel='Softmax attention weight',xlim=[-.3,.5],
              xticklabels=['Final to\nprev','Self-attention\nother tokens','Self-attention\nfirst token'],
              title='Within-head softmax attention weights')

plt.show()

# Kernel density estimator of softmax-attention distribution

In [None]:
# a tiny toy dataset: one cluster near 0.45 plus two outliers
sparseData = [-1.5,.4,.45,.5,1.4]

# dense grid on which the density estimate is evaluated
xgrid = np.linspace(-2,2,301)

# Gaussian KDE; bw_method=.2 controls the smoothing (bandwidth): smaller means less smoothing
kde = stats.gaussian_kde(sparseData,bw_method=.2)
y = kde.evaluate(xgrid)


## visualize!
plt.figure(figsize=(8,4))

# each raw data point is drawn as a red vertical tick of height 1
for point in sparseData:
  plt.plot([point]*2,[0,1],'r')

# overlay the estimated density
plt.plot(xgrid,y,linewidth=2)

plt.gca().set(xlim=xgrid[[0,-1]],ylim=[0,None],xlabel='Data value',
              ylabel='Probability density estimate',title='Simple demo of KDE estimation')
plt.show()

In [None]:
# evaluation grid spanning the full range of softmax weights
smx = np.linspace(0,1,300)

plt.figure(figsize=(10,3))

# one KDE curve per group of attention weights. Each curve is divided by its own peak
# because the sample sizes differ and only the shape matters.
for data,color,lbl in [ (final2prev,'g','Final to previous'),
                        (selfAttend,'r','Self-attention') ]:
  kde = stats.gaussian_kde(data)
  y = kde(smx)
  plt.plot(smx,y/y.max(),color,linewidth=3,label=lbl)

# note: ylim tops out at .1, so only the lowest 10% of each peak-normalized curve is shown
plt.legend()
plt.gca().set(ylim=[-.001,.1],xlim=[0,1],title='KDE distribution of attention scores',
              xlabel='Softmax attention scores',ylabel='KDE probability (a.u.)')

plt.show()