|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 5:</h2>|<h1>Observation (non-causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Investigating layers</h1>|
|<h2>Lecture:</h2>|<h1><b>CodeChallenge: Attention to coffee: MI and token distances</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">udemy.com/course/dullms_x/?couponCode=202508</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

# stats library for kendall correlation (when one variable is ordinal [sorted categorical])
import scipy.stats as stats
from statsmodels.stats.multitest import fdrcorrection

from sklearn.feature_selection import mutual_info_regression

import torch

# vector matplotlib: render the notebook's inline figures in SVG (vector) format
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# Exercise 1: Model, hooks, tokens, & activations

In [None]:
# load the pretrained GPT-2 XL model together with a GPT-2 tokenizer
# (all GPT-2 sizes share one BPE vocabulary, so the 'gpt2' tokenizer files work for 'gpt2-xl')
from transformers import AutoModelForCausalLM,GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token  # reuse the end-of-sequence token for padding

gpt2 = AutoModelForCausalLM.from_pretrained('gpt2-xl')

# how many transformer blocks the model has
nLayers = len(gpt2.transformer.h)

# evaluation mode (last expression, so the model summary is displayed)
gpt2.eval()

In [None]:
# hook functions that copy a sublayer's output into the global `activations` dict
# (see also part5_neuron_hookVsHiddenStates.ipynb)
activations = {}

def _make_store_hook(key):
  """Return a forward hook that stores its module's output as a numpy array in activations[key]."""
  def hook(module, input, output):
    # detach from the autograd graph; .cpu() keeps this working if the model is moved to a GPU
    activations[key] = output.detach().cpu().numpy()
  return hook

# attention layers
def implant_hook_attn(layer_number):
  """Forward hook that stores its module's output under activations['att_proj_<layer_number>']."""
  return _make_store_hook(f'att_proj_{layer_number}')

# and mlp layers
def implant_hook_mlp(layer_number):
  """Forward hook that stores its module's output under activations['mlp_proj_<layer_number>']."""
  return _make_store_hook(f'mlp_proj_{layer_number}')

# implant hooks on the attention and MLP output projections of every transformer block.
# The returned handles are kept so that re-running this cell first removes the hooks it
# registered previously, instead of stacking duplicate hooks on every module.
for oldHandle in globals().get('hook_handles', []):
  oldHandle.remove()
hook_handles = []

for layeri in range(nLayers):
  tblock = gpt2.transformer.h[layeri]
  hook_handles.append( tblock.attn.c_proj.register_forward_hook(implant_hook_attn(layeri)) )
  hook_handles.append( tblock.mlp.c_proj.register_forward_hook(implant_hook_mlp(layeri)) )

In [None]:
# from https://en.wikipedia.org/wiki/Turkish_coffee
text = 'Turkish coffee is very finely ground coffee brewed by boiling. Any coffee bean may be used; arabica varieties are considered best, but robusta or a blend is also used.[1] The coffee grounds are left in the coffee when served.[2][3] The coffee may be ground at home in a manual grinder made for the very fine grind, ground to order by coffee merchants in most parts of the world, or bought ready-ground from many shops.'

# token ids as a pytorch tensor (the single sequence is row 0)
tokens = tokenizer.encode(text,return_tensors='pt')

# report sequence length and vocabulary coverage
tokenIds = tokens[0].tolist()
print(f'There are {len(tokenIds)} tokens, {len(set(tokenIds))} of which are unique.')

In [None]:
# find all the "coffee" target indices
target = ' coffee'  # leading space matters: GPT-2 tokens include the preceding space
targetIds = tokenizer.encode(target)

# the search below matches a single token id, so the target must encode to exactly one token
assert len(targetIds)==1, f'{target!r} encodes to {len(targetIds)} tokens; expected 1'

# positions along the sequence axis of `tokens` where the target token occurs
target_idxs = torch.where(tokens==targetIds[0])[1]
target_idxs

In [None]:
# forward pass without gradient tracking; the registered forward hooks fill `activations`
with torch.no_grad():
  # hidden states are not analyzed below, but extracted for inspiration ;)
  output = gpt2(input_ids=tokens, output_hidden_states=True)

In [None]:
np.shape(activations['att_proj_3'])  # presumably (batch, tokens, model dims) -- see indexing below

# Exercise 2: Trimmed-MI function

In [None]:
# a function for mutual information
def mutInfo(x,y,outlierThresh=0,bins=15):
  """Histogram-based estimate of the mutual information (in bits) between x and y.

  Parameters
  ----------
  x, y : array-like, 1D, same length
      Paired samples of the two variables (lists are accepted).
  outlierThresh : float, default 0
      If >0, drop every (x,y) pair in which either value has an absolute
      z-score (ddof=1) above this threshold before estimating MI.
      0 disables trimming.
  bins : int, default 15
      Number of histogram bins per variable (passed to np.histogram2d).

  Returns
  -------
  float
      H(x) + H(y) - H(x,y), computed with log base 2.
  """
  x = np.asarray(x)
  y = np.asarray(y)

  # remove outliers based on a z-score threshold
  if outlierThresh>0:

    # z-standardize the variables
    zx = (x-x.mean()) / x.std(ddof=1)
    zy = (y-y.mean()) / y.std(ddof=1)

    # remove data pairs in which either variable exceeds the threshold
    outlier = (np.abs(zx)>outlierThresh) | (np.abs(zy)>outlierThresh)
    x = x[~outlier]
    y = y[~outlier]

  # 2D histogram converted to proportions (estimate of the joint probability)
  Z  = np.histogram2d(x,y,bins=bins)[0]
  pZ = Z / Z.sum()
  px = pZ.sum(axis=1) # marginal of x
  py = pZ.sum(axis=0) # marginal of y

  # entropies; eps avoids log2(0) in empty bins (those terms contribute 0 anyway)
  eps = 1e-12
  Hx = -np.sum( px * np.log2(px+eps) )
  Hy = -np.sum( py * np.log2(py+eps) )
  HZ = -np.sum( pZ * np.log2(pZ+eps) )

  return Hx+Hy - HZ

In [None]:
# extract some data: attention-projection activations of the first two target tokens
# in transformer block index 3
x = activations['att_proj_3'][0,target_idxs[0],:]
y = activations['att_proj_3'][0,target_idxs[1],:]

# z-transform (same procedure mutInfo applies internally; repeated here to color the plot)
zx = (x-x.mean()) / x.std(ddof=1)
zy = (y-y.mean()) / y.std(ddof=1)

# identify outliers: a pair is an outlier if either coordinate exceeds the threshold
threshold = 4
outlier = (abs(zx)>threshold) | (abs(zy)>threshold)

# mutual information with and without outliers (manual histogram estimator, in bits)
miAll  = mutInfo(x,y)
miTrim = mutInfo(x,y,outlierThresh=threshold) # trimmed

# sklearn's nearest-neighbor estimator (in nats). It adds small random noise to the data,
# so random_state is fixed to make the numbers reproducible across re-runs.
miSkAll  = mutual_info_regression(x.reshape(-1,1),y,random_state=0)[0]
miSkTrim = mutual_info_regression(x[~outlier].reshape(-1,1),y[~outlier],random_state=0)[0]

# plot
fig,ax = plt.subplots(figsize=(6,5))
ax.plot(x[~outlier],y[~outlier],'ko',markerfacecolor=[.7,.9,.7,.3],label='Trimmed')
ax.plot(x[outlier],y[outlier],'ks',markerfacecolor=[.9,.7,.7],label='Outliers')
ax.set(xlabel='Target word 1',ylabel='Target word 2',
       title=f'MI of all data: {miAll:.2f}\nMI of trimmed data: {miTrim:.2f}')

ax.legend()
plt.show()

In [None]:
# comparisons
# NOTE: mutInfo uses log2 (bits) while sklearn's mutual_info_regression reports nats,
# so compare the trimming *differences* within each method, not values across methods.
print(f'Manual, all data : {miAll:.3f}')
print(f'Manual, trimmed  : {miTrim:.3f}')
print(f'Difference       : {miAll-miTrim:.3f}\n')

print(f'Sklearn, all data: {miSkAll:.3f}')
print(f'Sklearn, trimmed : {miSkTrim:.3f}')
print(f'Difference       : {miSkAll-miSkTrim:.3f}')

# Exercise 3: MI and token distances in one layer

In [None]:
# Pairwise trimmed MI (attention projection, block index 3) and token distance for every
# pair of target tokens. Only the strict upper triangle (toki < tokj) is filled in;
# the diagonal and lower triangle stay 0.
nTargets = len(target_idxs)
mi       = np.zeros((nTargets,nTargets))
tokdists = np.zeros((nTargets,nTargets))

for toki in range(nTargets):
  for tokj in range(toki+1,nTargets):

    # activation vectors of the two tokens in this pair
    x = activations['att_proj_3'][0,target_idxs[toki],:]
    y = activations['att_proj_3'][0,target_idxs[tokj],:]

    # MI after removing pairs with |z|>4, and the token distance between the pair
    mi[toki,tokj]       = mutInfo(x,y,outlierThresh=4)
    tokdists[toki,tokj] = target_idxs[tokj]-target_idxs[toki]

    # next line is for exercise 5
    # mi[toki,tokj] = mutual_info_regression(x.reshape(-1,1),y)[0]

In [None]:
fig,axs = plt.subplots(1,3,figsize=(12,4))

h = axs[0].imshow(mi,origin='lower',vmin=0,vmax=.6)
axs[0].set(xlabel='Target token index',ylabel='Target token index',title='Mutual information')
fig.colorbar(h,ax=axs[0],pad=.02,fraction=.047)

# distances
h = axs[1].imshow(tokdists,origin='lower',vmin=0,vmax=60)
axs[1].set(xlabel='Target token index',ylabel='Target token index',title='Inter-token distances')
fig.colorbar(h,ax=axs[1],pad=.02,fraction=.047)

# unique token pairs = strict upper triangle. Selecting both matrices with the same
# indices (rather than np.nonzero of each one separately) keeps MI and distance values
# paired even if some MI estimate happens to be exactly 0.
pairIdx = np.triu_indices(mi.shape[0],1)
uDists  = tokdists[pairIdx]
uMI     = mi[pairIdx]

# correlate MI with token distance
r = stats.kendalltau(uDists,uMI)

axs[2].plot(uDists,uMI,'ks',markersize=10,markerfacecolor=[.7,.7,.9])
axs[2].set(xlabel='Inter-token distance',ylabel='Mutual information',
           title=f"Kendall's $\\tau: {r.statistic:.2f}$ ($p={r.pvalue:.4f}$)")

plt.tight_layout()
plt.show()

# Exercise 4: MI metrics over layers

In [None]:
# MIresults[sublayer, layer, metric]: sublayer 0=ATT, 1=MLP; metric 0=mean MI, 1=Kendall tau(MI, distance)
MIresults = np.zeros((2,nLayers,2))

# initialize temp matrices (overwritten in each layer)
miA = np.zeros((len(target_idxs),len(target_idxs)))
miM = np.zeros((len(target_idxs),len(target_idxs)))

# sublayerComps[layer, test, value]: test 0=t-test on MI, 1=z-test on tau; value 0=statistic, 1=p-value
sublayerComps = np.zeros((nLayers,2,2))

# the unique token pairs (strict upper triangle) and their distances do not depend on
# the layer or sublayer, so compute them once outside the loop
pairIdx = np.triu_indices(len(target_idxs),1)
uDi = tokdists[pairIdx] # doesn't change for attn-vs-mlp


# loop over layers
for layeri in range(nLayers):


  # double-loop over the word pairs
  for toki in range(len(target_idxs)):
    for tokj in range(toki+1,len(target_idxs)):

      ### ATTENTION block
      # extract the data
      x = activations[f'att_proj_{layeri}'][0,target_idxs[toki],:]
      y = activations[f'att_proj_{layeri}'][0,target_idxs[tokj],:]

      # trimmed manual MI implementation
      miA[toki,tokj] = mutInfo(x,y,4)

      # Exercise 5: pairwise mutual information using sklearn
      # miA[toki,tokj] = mutual_info_regression(x.reshape(-1,1),y)[0]


      ### MLP block
      # extract the data
      x = activations[f'mlp_proj_{layeri}'][0,target_idxs[toki],:]
      y = activations[f'mlp_proj_{layeri}'][0,target_idxs[tokj],:]

      # pairwise mutual information (second line is for Exercise 5)
      miM[toki,tokj] = mutInfo(x,y,4)
      # miM[toki,tokj] = mutual_info_regression(x.reshape(-1,1),y)[0]


  # ATTENTION summary statistics
  uMIa = miA[pairIdx] # (uMIa = unique mutual information attn)
  MIresults[0,layeri,0] = np.mean(uMIa)
  MIresults[0,layeri,1] = stats.kendalltau(uMIa,uDi).statistic

  # MLP summary statistics
  uMIm = miM[pairIdx]
  MIresults[1,layeri,0] = np.mean(uMIm)
  MIresults[1,layeri,1] = stats.kendalltau(uMIm,uDi).statistic

  # t-test to compare MI (two-sided)
  # NOTE(review): uMIm and uMIa come from the same token pairs, so a paired test
  # (stats.ttest_rel) may be more appropriate; kept as independent samples here.
  t = stats.ttest_ind(uMIm,uMIa)
  sublayerComps[layeri,0,0] = t.statistic
  sublayerComps[layeri,0,1] = t.pvalue

  # z-test to compare correlations
  # np.arctanh (rather than the NumPy>=2.0 alias np.atanh) also works on older NumPy.
  # NOTE(review): the 1/sqrt(n-3) standard error is derived for Fisher-transformed Pearson r
  # of independent samples; here it is applied to Kendall tau values sharing the same pairs.
  ra = np.arctanh(MIresults[0,layeri,1]) # fisher-transformed correlation
  rm = np.arctanh(MIresults[1,layeri,1])
  z = (ra-rm) / np.sqrt( 2/(len(uMIa)-3) ) # diff/ste
  p = 2*stats.norm.sf(abs(z)) # two-tailed p-value, consistent with the two-sided t-test

  sublayerComps[layeri,1,0] = z
  sublayerComps[layeri,1,1] = p


In [None]:
fig,axs = plt.subplots(2,3,figsize=(12,5.5))

sublayer_labels = [ 'ATT','MLP' ]

# shared y-limits so the ATT and MLP panels are directly comparable.
# The tau limits always include [min(tau,0)*1.1, .1] but expand to fit the data, so
# positive correlations above .1 (or a positive minimum) are not clipped off the axis.
miMax   = np.nanmax(MIresults[:,:,0])
tauMin  = np.nanmin(MIresults[:,:,1])
tauMax  = np.nanmax(MIresults[:,:,1])
tauLims = [ min(tauMin,0)*1.1, max(tauMax*1.1,.1) ]

for i in range(2):

  # plot the average MI
  axs[0,i].plot(MIresults[i,:,0],'ko',markerfacecolor=[.9,.7,.7,.7],markersize=8)
  axs[0,i].set(xlabel='Transformer block',ylabel='Mutual information',
               title=f'{sublayer_labels[i]}: Mutual information by layer',
               ylim=[0,miMax*1.1])

  # plot the MI correlation with token distance
  axs[1,i].plot(MIresults[i,:,1],'ks',markerfacecolor=[.7,.7,.9,.7],markersize=8)
  axs[1,i].set(xlabel='Transformer block',ylabel='Kendall $\\tau$',
               title=f'{sublayer_labels[i]}: MI-distance correlation by layer',
               ylim=tauLims)

  # comparing att vs mlp (row i=0: t-test on MI; row i=1: z-test on tau)
  # fdrcorrection returns (rejected, corrected p-values); [0] is the boolean "significant" mask
  isSig = fdrcorrection(sublayerComps[:,i,1])[0]
  for li in range(nLayers):
    marker = 'kh' if isSig[li] else 'rx'
    axs[i,2].plot(li,sublayerComps[li,i,0],marker,markerfacecolor=[.7,.9,.7,.7],markersize=8)

  # legend (kinda hacky): out-of-bounds points exist only to create the legend entries
  axs[i,2].plot(100,0,'rx',label='Non-sig.') # out of bounds
  axs[i,2].plot(100,0,'kh',markerfacecolor=[.7,.9,.7],markersize=8,label='Significant')
  axs[i,2].legend()
  axs[i,2].set(xlim=[-2,nLayers+1],xlabel='Transformer block',ylabel=f"{'tz'[i]}-statistic",title='ATT vs. MLP by layer')
  axs[i,2].axhline(0,linestyle='--',color=[.7,.7,.7])


plt.tight_layout()
plt.show()

# Exercise 5: Compare manual and sklearn implementations