|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 5:</h2>|<h1>Observation (non-causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Investigating layers</h1>|
|<h2>Lecture:</h2>|<h1><b>Token-related similarities within and across Q, K, V matrices</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">udemy.com/course/dullms_x/?couponCode=202508</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
from transformers import AutoModelForCausalLM, GPT2Tokenizer

# vector matplotlib
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

In [None]:
# GPT-2 (small) tokenizer and language model, both from the 'gpt2' checkpoint
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2').eval()  # eval() returns the model itself

# width of the embedding space (Q, K and V each get this many dims per token)
nEmb = model.config.n_embd

# Hook all layers to get Q, K, and V activations

In [None]:
# Store each layer's c_attn output: the concatenated [Q | K | V] projections
# (dict keys keep the '_qvk' suffix used throughout this notebook)
activations = {}

def implant_hook(layer_number):
  """Return a forward hook that saves layer `layer_number`'s c_attn output into `activations`."""
  def hook(module, input, output):
    # .cpu() so this still works if the model is moved to a GPU
    activations[f'attn_{layer_number}_qvk'] = output.detach().cpu().numpy()
  return hook


# Re-running this cell would otherwise stack another set of hooks onto the model,
# so first remove any hooks registered by a previous run of this cell.
for handle in globals().get('hook_handles',[]):
  handle.remove()

# surgery ;)
hook_handles = []
for layeri in range(model.config.n_layer):
  hook_handles.append( model.transformer.h[layeri].attn.c_attn.register_forward_hook(implant_hook(layeri)) )

# Sentences with "her" as target

In [None]:
# generated by Claude.ai
# ('his' is used for possessives, presumably so that 'her' occurs only once per sentence)
sentences = [
    "I saw her at the market.",
    "She gave her the book.",
    "They asked her for advice.",
    "We invited her to dinner.",
    "The dog followed her home.",
    "They asked her to join.",
    "He saw her at the park yesterday.",
    "Did you give her your address?",
    "I haven't seen her in ages.",
    "I told her the truth.",
    "They congratulated her on his success.",
    "She recognized her immediately.",
    "The teacher praised her for his work.",
    "I met her last summer.",
    "The child hugged her tightly.",
    "They warned her about the danger.",
    "She drove her to the airport.",
    "We waited for her for hours.",
    "The cat scratched her accidentally.",
    "They surprised her with a gift.",
    "She called her on the phone.",
    "The jury found her not guilty.",
    "I remembered her from school.",
    "They elected her as president.",
    "She forgave her for his mistake.",
    "The police questioned her yesterday.",
    "I helped her with his homework.",
    "They spotted her in the crowd.",
    "She visited her in the hospital.",
    "The manager promoted her last week.",
    "I trusted her completely.",
    "They respected her for his honesty.",
    "She taught her how to swim.",
    "The bird attacked her suddenly.",
    "I greeted her warmly.",
    "They supported her through difficult times.",
    "She ignored her at the party.",
    "The judge sentenced her to community service.",
    "I photographed her during the event.",
    "They believed her despite the evidence.",
    "She surprised her on his birthday.",
    "The guard stopped her at the entrance.",
    "I missed her terribly.",
    "They watched her leave the building.",
    "She accompanied her to the concert.",
    "The crowd cheered her enthusiastically.",
    "I described her to the police.",
    "They thanked her for his help.",
    "She admired her for his courage.",
    "The committee nominated her for the award.",
    "I married her last spring.",
    "They informed her about the changes.",
    "She introduced her to the parents.",
    "The author based the character on her."
]

# The extraction loop later uses list.index, which only finds the FIRST occurrence,
# so every sentence must contain the word "her" exactly once.
badSents = [s for s in sentences if [w.strip('.,?!') for w in s.split()].count('her') != 1]
assert not badSents, f'Sentences without exactly one "her": {badSents}'

print(f'There are {len(sentences)} sentences.')

In [None]:
# batching sentences of different lengths requires padding; reuse the end-of-sequence token
tokenizer.pad_token = tokenizer.eos_token

# one padded batch of PyTorch tensors for all sentences
tokens = tokenizer(sentences, return_tensors='pt', padding=True)
tokens

In [None]:
# the forward pass is only needed to fire the hooks; the model output itself is discarded
with torch.no_grad():
  model(**tokens)

# sanity check: one entry per layer, each indexed as [sentence, token, Q|K|V dimension]
activations.keys(), activations['attn_5_qvk'].shape

# Get activations from all attention neurons for target and non-target tokens

In [None]:
layeri = 6  # index of the single transformer layer analyzed in this script

In [None]:
# token id of ' her' (encoded with a leading space, matching its mid-sentence position)
target_ids = tokenizer.encode(' her')
assert len(target_ids) == 1, "' her' must encode to a single token"
target_token = target_ids[0]

# this layer's concatenated [Q | K | V] activations, indexed [sentence, token, dim]
layerActs = activations[f'attn_{layeri}_qvk']

# initialize
actsAll_trg = np.zeros((len(sentences),nEmb*3))
actsAll_non = np.zeros((len(sentences),nEmb*3))


# loop over sentences b/c target position varies
for senti in range(len(sentences)):

  # find the index of the target token (convert to list, then .index to find)
  targidx = tokens['input_ids'][senti].tolist().index(target_token)

  # the non-target is the token right before the target; at position 0 the -1 index
  # would silently wrap around to the last position of the padded sequence
  if targidx == 0:
    raise ValueError(f'Target token is the first token of sentence {senti}; no preceding token to use as non-target.')

  # TARGET get the activation for this token
  actsAll_trg[senti,:] = layerActs[senti,targidx,:]

  # NON-TARGET get the activation for the preceding token (note the -1 after targidx!)
  actsAll_non[senti,:] = layerActs[senti,targidx-1,:]

actsAll_trg.shape

# Plot a few pairs of vectors

In [None]:
# which two sentences (rows) and which two Q dimensions (columns) to compare
sentA, sentB = 1, 40
dimA, dimB = 123, 700
assert max(dimA,dimB) < nEmb, 'dimA/dimB must index the Q block (first nEmb columns)'

_,axs = plt.subplots(1,3,figsize=(12,3.5))

# Q = columns [0, nEmb); K = columns [nEmb, 2*nEmb)
axs[0].plot(actsAll_trg[sentA,:nEmb],actsAll_trg[sentB,:nEmb],'ko',markerfacecolor=[.7,.9,.7,.5])
axs[0].set(title='Different tokens, same Q',xlabel=f'Sentence {sentA} target token',ylabel=f'Sentence {sentB} target token')

axs[1].plot(actsAll_trg[sentA,nEmb:nEmb*2],actsAll_trg[sentB,nEmb:nEmb*2],'ko',markerfacecolor=[.7,.7,.9,.5])
axs[1].set(title='Different tokens, same K',xlabel=f'Sentence {sentA} target token',ylabel=f'Sentence {sentB} target token')

# two Q dimensions across all sentences
axs[2].plot(actsAll_trg[:,dimA],actsAll_trg[:,dimB],'ks',markerfacecolor=[.9,.7,.7,.5])
axs[2].set(title='All tokens, two Qs',xlabel=f'Q$_{{{dimA}}}$ activation',ylabel=f'Q$_{{{dimB}}}$ activation')


plt.tight_layout()
plt.show()

# Calculate cosine similarity matrices

In [None]:
def colwise_cossim(acts):
  """Cosine similarity between every pair of columns of `acts`.

  Each column of `acts` (sentences x neurons) is one neuron's activation profile
  across sentences, so the result is a (neurons x neurons) matrix.
  """
  unit = acts / np.linalg.norm(acts,axis=0,keepdims=True)
  return unit.T @ unit

# TARGET and NON-TARGET similarity matrices
cossim_trg = colwise_cossim(actsAll_trg)
cossim_non = colwise_cossim(actsAll_non)

# unique neuron pairs = strict upper triangle (np.nonzero(np.triu(...)) would
# additionally drop any pair whose similarity happens to be exactly 0)
upperIdx = np.triu_indices_from(cossim_trg,k=1)

# some quick visualizations
_,axs = plt.subplots(1,2,figsize=(12,3))
img = axs[0].imshow(cossim_trg,vmin=-.4,vmax=.4)
plt.colorbar(img,ax=axs[0])
axs[0].set(title='TARGET cosine similarity',xlabel='Neuron (Q|K|V)',ylabel='Neuron (Q|K|V)')

axs[1].hist(cossim_trg[upperIdx],bins=100)
axs[1].set(title='TARGET: unique neuron pairs',xlabel='Cosine similarity',ylabel='Count')

plt.tight_layout()
plt.show()

# Or analyze the correlation coefficients (cosine similarity after mean-centering each neuron)

In [None]:
def colwise_corr(acts):
  """Pearson correlation between every pair of columns of `acts`.

  Each column (neuron) is mean-centered across sentences and then unit-normalized,
  so the cosine similarity of the centered columns is their correlation.
  """
  centered = acts - np.mean(acts,axis=0,keepdims=True)
  unit = centered / np.linalg.norm(centered,axis=0,keepdims=True)
  return unit.T @ unit

# NOTE: this overwrites the cosine-similarity matrices; later cells use these correlations
cossim_trg = colwise_corr(actsAll_trg)   # TARGET
cossim_non = colwise_corr(actsAll_non)   # NON-TARGET

# unique neuron pairs = strict upper triangle (keeps pairs whose value is exactly 0)
upperIdx = np.triu_indices_from(cossim_trg,k=1)

# some quick visualizations
_,axs = plt.subplots(1,2,figsize=(12,3))
img = axs[0].imshow(cossim_trg,vmin=-.4,vmax=.4)
plt.colorbar(img,ax=axs[0])
axs[0].set(title='TARGET correlation',xlabel='Neuron (Q|K|V)',ylabel='Neuron (Q|K|V)')

axs[1].hist(cossim_trg[upperIdx],bins=100)
axs[1].set(title='TARGET: unique neuron pairs',xlabel='Correlation coefficient',ylabel='Count')

plt.tight_layout()
plt.show()

In [None]:
# RANDOM (from shuffled matrix)

# Shuffle ALL elements of the target activations matrix. This destroys any
# neuron-specific structure but keeps the overall distribution of values.
# Seeded so the baseline is reproducible across Restart & Run All.
rng = np.random.default_rng(1234)
actsAll_trgRand = rng.permutation(actsAll_trg.ravel()).reshape(actsAll_trg.shape)

# normalize and get cossim
# NOTE(review): this baseline uses plain (uncentered) cosine similarity, while
# cossim_trg at this point holds correlations from the previous cell; confirm
# this is the intended null for the comparison plots.
unitRand = actsAll_trgRand / np.linalg.norm(actsAll_trgRand,axis=0,keepdims=True)
cossim_trgRand = unitRand.T @ unitRand

# unique pairs = strict upper triangle
randcs_trg = cossim_trgRand[np.triu_indices_from(cossim_trgRand,k=1)]

# Create a matrix mask to extract within- and across-matrix cosine similarities

In [None]:
# unique code for each matrix type (Q, K, V)
qLoc = 1
kLoc = 2
vLoc = 3

# every pairwise product of the codes must be distinct, otherwise two different
# pairings (e.g. Q-V and K-K) would share a mask value
codes = [qLoc,kLoc,vLoc]
pairProducts = [a*b for i,a in enumerate(codes) for b in codes[i:]]
assert len(set(pairProducts)) == len(pairProducts), 'Q/K/V codes must give unique pairwise products'

# a vector mask: labels each of the 3*nEmb dimensions by the matrix it comes from
vectorMask = np.concatenate( (np.full(nEmb,qLoc),np.full(nEmb,kLoc),np.full(nEmb,vLoc)) )

# outer product -> unique value for each interaction; keep only the strict
# upper triangle so every neuron pair is counted once (diagonal/lower = 0)
matrixMask = np.triu(np.outer(vectorMask,vectorMask),1)

# illustration
_,axs = plt.subplots(1,2,figsize=(12,3.5))
axs[0].plot(vectorMask,'ks',markerfacecolor='w',alpha=.5)
axs[0].set(title='Vector mask',xlabel='Dimension index (Q+K+V)',
           ylabel=f'Matrix type (Q={qLoc}, K={kLoc}, V={vLoc})')

axs[1].imshow(matrixMask)
axs[1].set(title='Matrix mask',xlabel='Indices',ylabel='Indices')

plt.tight_layout()
plt.show()

In [None]:
# Each pairing of matrices is now addressable by its numerical value in the mask;
# for instance, Q-K interactions live where the mask equals qLoc*kLoc = 2.

uniqueMaskVals = np.unique(matrixMask)  # np.unique flattens its input
print('Unique elements in the matrix mask: ', uniqueMaskVals)

# Extract unique similarities per group

In [None]:
def _pairs(cs, codeA, codeB):
  """Elements of similarity matrix `cs` where the mask marks the (codeA, codeB) pairing."""
  return cs[matrixMask == codeA*codeB]


### TARGET: within-matrix terms, then cross-terms
QQcs_trg = _pairs(cossim_trg, qLoc, qLoc)
KKcs_trg = _pairs(cossim_trg, kLoc, kLoc)
VVcs_trg = _pairs(cossim_trg, vLoc, vLoc)

QKcs_trg = _pairs(cossim_trg, qLoc, kLoc)
QVcs_trg = _pairs(cossim_trg, qLoc, vLoc)
KVcs_trg = _pairs(cossim_trg, kLoc, vLoc)


### NON-TARGET: within-matrix terms, then cross-terms
QQcs_non = _pairs(cossim_non, qLoc, qLoc)
KKcs_non = _pairs(cossim_non, kLoc, kLoc)
VVcs_non = _pairs(cossim_non, vLoc, vLoc)

QKcs_non = _pairs(cossim_non, qLoc, kLoc)
QVcs_non = _pairs(cossim_non, qLoc, vLoc)
KVcs_non = _pairs(cossim_non, kLoc, vLoc)


QKcs_trg.shape

# Generate histograms for target and non-target

In [None]:
# shared bin edges over the full similarity range, and the bin centers
edges = np.linspace(-1,1,101)
edgesX = (edges[1:]+edges[:-1])/2

def _density(values):
  """Density-normalized histogram of `values` on the shared `edges`."""
  counts,_ = np.histogram(values,bins=edges,density=True)
  return counts


### TARGET: within-matrix, then between-matrix
yQQ_trg = _density(QQcs_trg)
yKK_trg = _density(KKcs_trg)
yVV_trg = _density(VVcs_trg)

yQK_trg = _density(QKcs_trg)
yQV_trg = _density(QVcs_trg)
yKV_trg = _density(KVcs_trg)


### NON-TARGET: within-matrix, then between-matrix
yQQ_non = _density(QQcs_non)
yKK_non = _density(KKcs_non)
yVV_non = _density(VVcs_non)

yQK_non = _density(QKcs_non)
yQV_non = _density(QVcs_non)
yKV_non = _density(KVcs_non)


### Random baseline
yRand_trg = _density(randcs_trg)

# Plotting!

In [None]:
# create the figure
_,axs = plt.subplots(2,2,figsize=(14,6))

# (row, col) -> (panel title, [(legend label, density curve), ...])
panels = {
  (0,0): ('TARGET within-matrix',     [('QQ',yQQ_trg),('KK',yKK_trg),('VV',yVV_trg)]),
  (0,1): ('TARGET across-matrix',     [('QK',yQK_trg),('QV',yQV_trg),('KV',yKV_trg)]),
  (1,0): ('NON-TARGET within-matrix', [('QQ',yQQ_non),('KK',yKK_non),('VV',yVV_non)]),
  (1,1): ('NON-TARGET across-matrix', [('QK',yQK_non),('QV',yQV_non),('KV',yKV_non)]),
}

for (r,c),(panelTitle,curves) in panels.items():
  ax = axs[r,c]
  for lbl,dens in curves:
    ax.plot(edgesX,dens,label=lbl,linewidth=2)

  # the shuffled baseline is drawn only on the first panel, before the limits are set
  # (the 1/3 scaling is presumably just for visual comparability; TODO confirm)
  if (r,c) == (0,0):
    ax.plot(edgesX,yRand_trg/3,label='Random',linewidth=2,color='gray',linestyle='--')

  # top row: no x ticks; bottom row: x-axis label
  if r == 0:
    ax.set(xlim=[-1,1],xticks=[],ylabel='Density',ylim=[0,None],title=panelTitle)
  else:
    ax.set(xlim=[-1,1],xlabel='Cosine similarity',ylabel='Density',ylim=[0,None],title=panelTitle)

  ax.legend(fontsize=10)

plt.tight_layout()
plt.show()

# Some additional explorations...

In [None]:
# token-averaged activation of every neuron, marked by the matrix (Q, K, V) it belongs to
plt.figure(figsize=(10,4))

for start,fmt,face,name in ((0,      'ko',[.9,.7,.7,.5],'Q'),
                            (nEmb,   'ks',[.7,.9,.7,.5],'K'),
                            (2*nEmb, 'kv',[.7,.7,.9,.5],'V')):
  stop = start + nEmb
  plt.plot(range(start,stop),actsAll_trg[:,start:stop].mean(axis=0),fmt,markerfacecolor=face,label=name)

plt.gca().set(xlim=[-5,nEmb*3+4],xlabel='Neurons',ylabel='Token average activation')
plt.legend()
plt.show()

In [None]:
# curious clustering in cosine similarities: show the first nShow Q-Q pairs in order
nShow = 20000
plt.figure(figsize=(10,4))
plt.plot(QQcs_trg[:nShow],'k.',markersize=2)

plt.gca().set(xlim=[-5,nShow+4],xlabel='Q pairs',ylabel='Cosine similarity',
              title='Target-related cosine similarity in Q-Q pairs')
plt.show()