|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 5:</h2>|<h1>Observation (non-causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Investigating neurons and dimensions</h1>|
|<h2>Lecture:</h2>|<h1><b>Proper noun tuning in GPT2-medium</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">udemy.com/course/dullms_x/?couponCode=202508</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
# run first to install (or upgrade) the packages below, and then restart the runtime
# before running the rest of the notebook (uncomment the next line to use it)
# !pip install -U datasets huggingface_hub fsspec

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.gridspec import GridSpec

import statsmodels.api as sm

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, GPT2Tokenizer

from datasets import load_dataset

# vector plots
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

In [None]:
# name of the pretrained checkpoint shared by the tokenizer and the model
checkpoint = 'gpt2-medium'

# load the tokenizer and the GPT2 model from that checkpoint
tokenizer = GPT2Tokenizer.from_pretrained(checkpoint)

model = AutoModelForCausalLM.from_pretrained(checkpoint)

In [None]:
# pick the first GPU when one is available; otherwise fall back to the CPU
if torch.cuda.is_available():
  device = torch.device('cuda:0')
else:
  device = torch.device('cpu')

# put the model on that device and switch it to evaluation mode
model = model.to(device)
model.eval();

# Hooks

In [None]:
# display the model's module structure (the hooks below are attached to sub-modules of model.transformer.h)
model

In [None]:
# hooks

# dictionary that the hooks overwrite on every forward pass
# (keys are 'mlp_<layer index>'; values are detached tensors on the CPU)
activations = {}

# remove hooks that were registered by a previous run of this cell, so that
# re-running the cell does not stack duplicate hooks onto the model
try:
  for handle in hookHandles:
    handle.remove()
except NameError:
  pass # first run of this cell; there is nothing to remove
hookHandles = []

def implant_hook(layer_number):
  """Create a forward hook that stores a module's output for one layer.

  Parameters
  ----------
  layer_number : index of the transformer block; used to build the
                 dictionary key f'mlp_{layer_number}'.

  Returns
  -------
  hook : function with the forward-hook signature hook(module, input, output).
  """
  def hook(module, input, output):

    # store in the dictionary (the hook is attached directly to c_fc, so its
    # output is the expansion-layer activation; no second pass through c_fc is needed)
    activations[f'mlp_{layer_number}'] = output.detach().cpu()
  return hook

# put hooks in all layers, and keep the handles so the hooks can be removed later
for layeri in range(len(model.transformer.h)):
  handle = model.transformer.h[layeri].mlp.c_fc.register_forward_hook(implant_hook(layeri))
  hookHandles.append(handle)

# Get wiki text and find the proper nouns

In [None]:
# get wiki text
text = load_dataset('wikitext','wikitext-2-raw-v1',split='test')

# extract the first 200 entries once
# (indexing text['text'] inside the loop would re-read the column on every iteration)
entries = text[:200]['text']

# get tokens from longer entries: tokenize each entry separately...
tokenChunks = []
for entry in entries:
  if len(entry)>50:
    tokenChunks.append(tokenizer.encode(entry,return_tensors='pt'))

# ...and concatenate once along the token dimension (result is 1 x num_tokens)
if len(tokenChunks)>0:
  tokens = torch.cat(tokenChunks,dim=1)
else:
  # no entry was long enough; keep the (1 x N) layout with N=0
  tokens = torch.empty((1,0),dtype=torch.long)

num_tokens = torch.numel(tokens)
print(f'{num_tokens:,} tokens imported')

In [None]:
# Example of finding proper nouns
' Mike'.strip()[0].isupper()

In [None]:
# simple algorithm, may include false positives and false negatives

# initialize vector (one flag per token; 1 = likely proper noun)
isProperNoun = np.zeros(num_tokens,dtype=int)

# a capitalized token right after one of these markers is likely just the start of a sentence
sentenceEnders = ('.','!','?',':')

# initialize first previous token
prevToken = tokenizer.decode(tokens[0,0])

# loop over all tokens (starting from the second, because the test needs a previous token)
for ti in range(1,num_tokens):

  # current token
  thisToken = tokenizer.decode(tokens[0,ti])

  # only consider tokens longer than two characters (the count includes any leading space)
  if len(thisToken)>2:

    # conditionals
    # note: slicing with [:1] instead of indexing with [0] avoids an IndexError when
    # the token contains only whitespace (an empty string is simply not upper-case)
    condA = thisToken.strip()[:1].isupper() # first letter is upper-case (ignoring leading spaces)
    condB = not prevToken.endswith(sentenceEnders)  # previous token can't end in an end-of-sentence marker

    # if conditionals are met, this is likely a proper noun
    if condA and condB:
      isProperNoun[ti] = 1

  # update previous token
  prevToken = thisToken

In [None]:
# examine some proper nouns

# token positions that were flagged as proper nouns
w = np.where(isProperNoun)[0]

print(f'There are {len(w)} proper nouns, including:')

# print up to the first 50 (slicing avoids an IndexError when fewer than 50 were found)
for tokenIdx in w[:50]:
  print(f'{tokenizer.decode(tokens[0,tokenIdx])}')

In [None]:
# in context...

# position of the second flagged token
centerIdx = w[1]

# show the 5 tokens before through the 4 tokens after it; the start is clipped at 0
# because a negative start would be interpreted as counting from the end of the tokens
tokenizer.decode(tokens[0,max(0,centerIdx-5):centerIdx+5])

# Forward pass through the model to get the activations

In [None]:
# number of tokens to process in one forward pass
seq_len = 1024

# push some data through (the output is not kept; the hooks store the activations)
# note: the tokens are passed as an explicit (batch x sequence) matrix with one
# row, which is the documented layout of the model's input_ids
with torch.no_grad():
  model(tokens[:,:seq_len].to(device))

In [None]:
# inspect which keys the hooks have stored, and the size of the tensor stored for one of them
activations.keys(), activations['mlp_21'].shape

In [None]:
# number of neurons = size of the last dimension of one stored activation tensor
nneurons = activations['mlp_21'].size(-1)

# Logistic regression exploration

In [None]:
# in one neuron as an illustration

# get the proper nouns in this sequence
propnounsBatch = np.where(isProperNoun[:seq_len])[0]

# get an equal-sample-size set of non-proper-noun tokens for comparison
# (a seeded generator makes the random subsample reproducible across runs)
rng = np.random.default_rng(0)
nonPropnouns = np.setdiff1d(np.arange(seq_len),propnounsBatch)
comparisonTokens = rng.permutation(nonPropnouns)[:len(propnounsBatch)]

# activations of one example neuron, as a numpy vector with one value per token
neuronActs = activations['mlp_10'][0,:,33].numpy()

# build a logistic model (comparison tokens first, then proper nouns)
X = np.hstack((neuronActs[comparisonTokens],neuronActs[propnounsBatch]))

# labels: 0 for comparison tokens, 1 for proper nouns
# (lengths are taken from each group so that y always matches X)
y = np.hstack((np.zeros(len(comparisonTokens)),np.ones(len(propnounsBatch))))

result = sm.Logit(y,sm.add_constant(X)).fit()

# Print the summary, which includes coefficients, p-values, and confidence intervals.
print(result.summary())

# Linear classifier over all neurons in one layer

In [None]:
# sample of proper nouns
propnounsBatch = np.where(isProperNoun[:seq_len])[0]

# pick an MLP block
whichLayer2use = '18'

# stored activations of that block, without the leading dimension
layerActs = activations[f'mlp_{whichLayer2use}'][0]

# the labels are the same for every neuron (0 = comparison token, 1 = proper noun)
y = np.hstack((np.zeros(len(propnounsBatch)),np.ones(len(propnounsBatch))))

# initialize matrix to store the classifier results (columns: p-value, beta)
classifierResults = np.zeros((nneurons,2))

# fit one logistic regression per neuron in this layer
for neuroni in range(nneurons):

  # predictor: this neuron's activations for the comparison tokens, then the proper nouns
  X = np.hstack((layerActs[comparisonTokens,neuroni],layerActs[propnounsBatch,neuroni]))

  # run the model (silently)
  result = sm.Logit(y,sm.add_constant(X)).fit(disp=0)

  # extract the results for the activation term (index 0 is the constant)
  classifierResults[neuroni,:] = result.pvalues[1],result.params[1]

In [None]:
# visualization of model significance and sign

# separate the two columns of the results matrix
pvals = classifierResults[:,0]
betas = classifierResults[:,1]

# supra-threshold results, and the sign of the betas
pvalThresh = .05/nneurons # p<.05, Bonferroni-corrected
sigBetas = pvals < pvalThresh
posBetas = betas>0
negBetas = betas<0

# setup the figure (wide panel for betas by neuron, narrow panel for betas by p-values)
fig = plt.figure(figsize=(12,4))
gs = GridSpec(1,4,figure=fig)

ax0 = fig.add_subplot(gs[:3])
ax1 = fig.add_subplot(gs[3])

# the four groups to draw, in order: (selection mask, color, is significant, legend label)
groups = [
    (posBetas &  sigBetas, 'r', True,  'Positive and sig.'),
    (posBetas & ~sigBetas, 'r', False, 'Positive and non-sig.'),
    (negBetas &  sigBetas, 'g', True,  'Negative and sig.'),
    (negBetas & ~sigBetas, 'g', False, 'Negative and non-sig.'),
]

for idx2plot,color,isSig,label in groups:

  if isSig:
    # significant results are drawn as circles with a gray face
    ax0.plot(np.where(idx2plot)[0],betas[idx2plot],color+'o',markerfacecolor=[.7,.7,.7],label=label)
    ax1.plot(betas[idx2plot],-np.log(pvals[idx2plot]),color+'o',markerfacecolor=[.7,.7,.7,.5])

  else:
    # non-significant results are drawn as small crosses
    ax0.plot(np.where(idx2plot)[0],betas[idx2plot],color+'x',markersize=3,label=label)
    ax1.plot(betas[idx2plot],-np.log(pvals[idx2plot]),color+'x',markersize=3)

# adjust the left panel
ax0.set(ylabel='Beta coefficient',xlabel='Neuron index',xlim=[-10,nneurons+9],
              title='Statistical parameters of proper noun classification')
ax0.legend(fontsize=8)

# adjust the right panel, including a line at the significance threshold
ax1.axhline(-np.log(pvalThresh),linestyle='--',color='b',label='Significance threshold')
ax1.set(xlabel='Beta coeff',ylabel='-log(p)',title='Betas by p-values')
ax1.legend(fontsize=8)

plt.tight_layout()
plt.show()

In [None]:
# find the neuron with best classification

# indices of the neurons with a significant beta
sigNeurons = np.where(sigBetas)[0]
if len(sigNeurons)==0:
  raise ValueError('No neuron passed the significance threshold, so there is no best neuron to select.')

# betas of the significant neurons only
sigBetaVals = classifierResults[sigNeurons,1]

# largest beta among the significant neurons
# (the index is looked up within the significant set, so it cannot point to a
#  non-significant neuron that happens to have the same beta value)
maxBeta = np.max(sigBetaVals)
maxBetaNeuron = sigNeurons[np.argmax(sigBetaVals)]

# smallest (most negative) beta among the significant neurons
minBeta = np.min(sigBetaVals)
minBetaNeuron = sigNeurons[np.argmin(sigBetaVals)]

maxBetaNeuron,minBetaNeuron

# Heatmap of neuron's activation in-sample

In [None]:
# min-max scale the activations for colormapping

# stored activations of the selected block, without the leading dimension
layerActs = activations[f'mlp_{whichLayer2use}'][0]

# neuron with the largest significant beta
posActs = layerActs[:,maxBetaNeuron].squeeze()
posLo,posHi = posActs.min(),posActs.max()
posActsNorm = (posActs - posLo) / (posHi-posLo)

# neuron with the smallest significant beta
negActs = layerActs[:,minBetaNeuron].squeeze()
negLo,negHi = negActs.min(),negActs.max()
negActsNorm = (negActs - negLo) / (negHi-negLo)

In [None]:
# measure the width of one monospace letter, in axis coordinates

# temporary figure with the same size as the figures used for the token plots
fig,ax = plt.subplots(figsize=(10,2))

# draw a single letter and get its bounding box in display coordinates
probe = ax.text(0,0,'n',fontsize=12,fontfamily='monospace')
bbox = probe.get_window_extent(renderer=fig.canvas.get_renderer())

# convert the two corners of the box from display to axis coordinates
corners = [[bbox.x0,bbox.y0], [bbox.x1,bbox.y1]]
bbox_axes = ax.transAxes.inverted().transform(corners)

# width = difference between the x coordinates of the two corners
x_left,x_right = bbox_axes[:,0]
en_width = x_right - x_left

# the figure was only needed for the measurement
plt.close(fig)

In [None]:
def plotTokenHeatmap(tokenIds,actsNorm,tokensPerLine=20,lineSpacing=.17,tokenGap=.01):
  """Draw tokens as text boxes whose background color reflects an activation value.

  Parameters
  ----------
  tokenIds      : sequence of token ids; each is decoded with `tokenizer`.
  actsNorm      : one value per token (same order as tokenIds); each value is
                  squared and passed to the Reds colormap.
  tokensPerLine : number of tokens to draw before starting a new line.
  lineSpacing   : vertical distance between lines (in axis coordinates).
  tokenGap      : horizontal gap after each token (in axis coordinates).

  Returns
  -------
  fig, ax : the figure and axis that were created.

  Notes
  -----
  The width of each token is estimated as `en_width` (width of one monospace
  letter at fontsize 12) times the number of characters in the token.
  """

  fig, ax = plt.subplots(figsize=(10,2))
  ax.axis('off')

  x_pos = 0  # starting x position (in axis coordinates)
  y_pos = 1  # vertical center

  for toki,(tokenId,act) in enumerate(zip(tokenIds,actsNorm)):

    # text of this token
    toktext = tokenizer.decode([tokenId])

    # width of the token
    token_width = en_width*len(toktext)

    # text object with background color matching the activation
    ax.text(x_pos+token_width/2, y_pos, toktext, fontsize=12, ha='center', va='center',fontfamily='monospace',
            bbox = dict(boxstyle='round,pad=.3', facecolor=mpl.cm.Reds(act**2), edgecolor='none', alpha=.8))

    # update x_pos
    x_pos += token_width + tokenGap # plus a small gap

    # end of the line; reset the coordinates
    if (toki+1)%tokensPerLine==0:
      y_pos -= lineSpacing
      x_pos = 0

  return fig,ax


# draw the first half of the sequence, colored by the positive-beta neuron
fig,ax = plotTokenHeatmap(tokens[0,:seq_len//2],posActsNorm)
plt.show()

# Process the next batch to examine generalization

In [None]:
# push some data through: the next seq_len tokens, which were not used to fit the classifiers
# (the output is not kept; the hooks overwrite the stored activations)
# note: the tokens are passed as an explicit (batch x sequence) matrix with one
# row, which is the documented layout of the model's input_ids
with torch.no_grad():
  model(tokens[:,seq_len:seq_len*2].to(device))

In [None]:
# min-max scale the activations for colormapping

# activations of the same neuron, now taken from the most recent forward pass
posActs = activations[f'mlp_{whichLayer2use}'][0][:,maxBetaNeuron].squeeze()

# subtract the minimum and divide by the range
actRange = posActs.max()-posActs.min()
posActsNorm = (posActs - posActs.min()) / actRange

In [None]:
# draw the tokens of the second batch, colored by the neuron's activation

fig, ax = plt.subplots(figsize=(10,2))
ax.axis('off')

x_pos = 0  # starting x position (in axis coordinates)
y_pos = 1  # vertical center
tokCount = 0 # number of tokens drawn on the current line

for toki in range(seq_len//2):

  # text of this token (offset by seq_len to index into the second batch)
  toktext = tokenizer.decode([tokens[0,seq_len+toki]])

  # width of the token
  token_width = en_width*len(toktext)

  # background box with a color matching the activation
  boxColor = mpl.cm.Reds(posActsNorm[toki]**2)
  boxProps = dict(boxstyle='round,pad=.3', facecolor=boxColor, edgecolor='none', alpha=.8)

  # text object, centered on the middle of the token
  ax.text(x_pos+token_width/2, y_pos, toktext, fontsize=12, ha='center', va='center',fontfamily='monospace',
          bbox = boxProps)

  # update the token counter
  tokCount += 1

  if tokCount<20:
    # stay on this line: move to the right by the token width plus a small gap
    x_pos += token_width + .01

  else:
    # end of the line; reset coordinates and counter
    y_pos -= .17
    x_pos = 0
    tokCount = 0

plt.show()