|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 5:</h2>|<h1>Observation (non-causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Investigating layers</h1>|
|<h2>Lecture:</h2>|<h1><b>CodeChallenge: Clusters in internal vs. terminal punctuation</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">udemy.com/course/dullms_x/?couponCode=202508</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import matplotlib as mpl

import requests

import torch

# vector plots
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# Exercise 1: Model and punctuation indices

In [None]:
# pretrained GPT-2 (medium) model and its tokenizer
from transformers import AutoModelForCausalLM,GPT2Tokenizer

# use the first GPU when one is available, otherwise the CPU
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# load the model, move it to the device, and switch to evaluation mode
gpt2 = AutoModelForCausalLM.from_pretrained('gpt2-medium').to(device)
gpt2.eval()

# embedding dimensionality, read from the model configuration
nEmb = gpt2.config.n_embd

tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')

In [None]:
# download the book (Heart of Darkness) as plain text
response = requests.get('https://www.gutenberg.org/cache/epub/219/pg219.txt',timeout=30)
response.raise_for_status() # fail loudly on an HTTP error instead of tokenizing an error page
text = response.text

# tokenize the whole book into one long sequence (tokens[0] holds the token ids)
tokens = tokenizer.encode(text,return_tensors='pt')
num_tokens = len(tokens[0])

print(f'There are {num_tokens:,} tokens, {len(np.unique(tokens[0].tolist()))} of which are unique.')

In [None]:
# tokens to match exactly (the commented-out marks can be added to extend each category)
internal_punctuations = [',']#,':',';']
terminal_punctuations = ['.']#,'?','!']

# characters that identify a neighboring token as part of a number
digits = tuple('0123456789')

# category label per token: 0 = not a target, 1 = internal, 2 = terminal
isPunctuation = np.zeros(num_tokens,dtype=int)

# loop over all tokens (ignore starting tokens before the book starts)
for ti in range(400,num_tokens):

  # current token, without surrounding whitespace
  currtok = tokenizer.decode(tokens[0,ti]).strip()

  # test for internal punctuation
  if currtok in internal_punctuations:
    isPunctuation[ti] = 1

  # check if it's terminal -- and not a decimal-point number!
  elif currtok in terminal_punctuations:

    # neighboring tokens; the final token of the text has no successor
    prevtok = tokenizer.decode(tokens[0,ti-1])
    nexttok = tokenizer.decode(tokens[0,ti+1]) if ti+1<num_tokens else ''

    # endswith/startswith are safe on empty strings (indexing [-1] or [0] is not)
    if not prevtok.endswith(digits) and not nexttok.startswith(digits):
      isPunctuation[ti] = 2


# report
print(f'There are:\n  {np.sum(isPunctuation==1):3} internal punctuation marks, and\n  {np.sum(isPunctuation==2):3} terminal punctuation marks.')

In [None]:
# token positions of the two punctuation categories
internalIdx = np.flatnonzero(isPunctuation==1)
terminalIdx = np.flatnonzero(isPunctuation==2)

# print the first few terminal punctuation marks with their surrounding tokens
context_win = 9
for t in terminalIdx[:5]:
  snippet = tokenizer.decode(tokens[0,t-context_win:t+context_win])
  print(f'Example {t}:\n{snippet}\n')

# Exercise 2: Create batches and get activations

In [None]:
# some data parameters
batchsize   = 250 # number of sequences taken per punctuation category
context_pre = 20 # number of tokens kept before the target punctuation token
context_pst = 10 # number of tokens kept after the target punctuation token

In [None]:
# each sequence is the target token plus its preceding and following context
seqlen = context_pre + context_pst + 1

# one row per sequence, one batch per punctuation category
batch_internal = torch.zeros((batchsize,seqlen),dtype=torch.long)
batch_terminal = torch.zeros_like(batch_internal)

# fill each row with the token window around one punctuation mark
for rowi in range(batchsize):
  for batch,punctIdx in ((batch_internal,internalIdx),(batch_terminal,terminalIdx)):

    # the target token ends up at column index context_pre
    center = punctIdx[rowi]
    batch[rowi,:] = tokens[0,center-context_pre:center+context_pst+1]


batch_terminal.shape

In [None]:
# forward-pass both batches without tracking gradients, requesting the hidden states
with torch.no_grad():
  output_internal = gpt2(batch_internal.to(device),output_hidden_states=True)
  output_terminal = gpt2(batch_terminal.to(device),output_hidden_states=True)

In [None]:
# size of the hidden-state tensor at index 3
print(output_internal.hidden_states[3].shape)
# number of hidden-state tensors returned by the model
# (presumably the embeddings plus one per transformer block -- confirm against the transformers docs)
len(output_internal.hidden_states)

In [None]:
# for convenience, copy every hidden state to the CPU as a numpy array (shorter variable names)
hsIntern = [ hs.detach().cpu().numpy() for hs in output_internal.hidden_states ]
hsTermin = [ hs.detach().cpu().numpy() for hs in output_terminal.hidden_states ]

In [None]:
# shape of one hidden-state array (indexed below as [sequence, token position, embedding dimension])
hsIntern[4].shape

In [None]:
# visualize the activations at the target token (hidden-state index 3), averaged over sequences

ave_intern = np.mean(hsIntern[3][:,context_pre,:],axis=0)
ave_termin = np.mean(hsTermin[3][:,context_pre,:],axis=0)

_,(axL,axR) = plt.subplots(1,2,figsize=(10,4))

# left panel: average activation of each embedding dimension
axL.plot(ave_intern,'o',label='Internal')
axL.plot(ave_termin,'s',label='Terminal')
axL.set(xlabel='Embedding dimension index',ylabel='Hidden state activation',title='Activations')
axL.legend()

# right panel: the two averages plotted against each other
axR.plot(ave_intern,ave_termin,'ko',markerfacecolor=[.4,.7,.5],alpha=.4)
axR.set(xlabel='Internal',ylabel='Terminal',title='Scatter plot of activations')

plt.tight_layout()
plt.show()

# Exercise 3: MI and cov in one layer

In [None]:
# a function for mutual information and covariance
def mutInfo_and_cov(x,y,outlierThresh=0):
  """Mutual information and covariance between two equal-length vectors.

  Parameters
  ----------
  x, y : 1D numpy arrays of the same length
  outlierThresh : z-score threshold; element pairs in which either value
      exceeds it (in absolute value) are removed before the calculations.
      The default of 0 disables outlier removal.

  Returns
  -------
  MI : mutual information in bits, estimated from a 15x15 joint histogram
  C  : sample covariance (normalized by n-1)
  """

  # remove outliers based on a z-score threshold
  if outlierThresh>0:
    zx = (x-x.mean())/x.std(ddof=1)
    zy = (y-y.mean())/y.std(ddof=1)
    outlier = (np.abs(zx)>outlierThresh) | (np.abs(zy)>outlierThresh)
    x = x[~outlier]
    y = y[~outlier]


  # histogram and convert to proportion (estimate of probability)
  Z  = np.histogram2d(x,y,bins=15)[0]
  pZ = Z / Z.sum()
  px = pZ.sum(axis=1) # marginal distribution of x
  py = pZ.sum(axis=0) # marginal distribution of y

  # calculate entropy (eps avoids log2(0) in empty bins)
  eps = 1e-12
  Hx = -np.sum( px * np.log2(px+eps) )
  Hy = -np.sum( py * np.log2(py+eps) )
  HZ = -np.sum( pZ * np.log2(pZ+eps) )

  # mutual information
  MI = Hx+Hy - HZ

  # and covariance (np.sum is vectorized; the builtin sum loops over elements in Python)
  C = np.sum( (x-x.mean())*(y-y.mean()) ) / (len(x)-1)

  return MI,C

In [None]:
# sanity-check target index ;)
# column context_pre of each batch row is the punctuation token the window was built around,
# so decoding it for a terminal sequence should return the punctuation mark itself
tokenizer.decode(batch_terminal[3,context_pre])

In [None]:
# this cell takes ~40 sec
whichLayer = 1 # 20

# pairwise matrices (sequences x sequences); only the upper triangle is filled in
mi_in,mi_tr,cov_in,cov_tr = ( np.zeros((batchsize,batchsize)) for _ in range(4) )

# activations at the target punctuation token for all sequences in this layer
actsIntern = hsIntern[whichLayer][:,context_pre,:]
actsTermin = hsTermin[whichLayer][:,context_pre,:]

# double-loop over the unique sequence pairs
for bi in range(batchsize):
  for bj in range(bi+1,batchsize):

    # pairwise mutual information and covariance for internal punctuation
    mi_in[bi,bj],cov_in[bi,bj] = mutInfo_and_cov(actsIntern[bi],actsIntern[bj],5)

    # and for terminal punctuation
    mi_tr[bi,bj],cov_tr[bi,bj] = mutInfo_and_cov(actsTermin[bi],actsTermin[bj],5)

In [None]:
# (row,column) indices of the elements above the diagonal, i.e., the unique sequence pairs
uniqueIndices = np.triu_indices(n=batchsize,k=1)
uniqueIndices

In [None]:
# visualize all pairwise MI from one layer

# two matrices on top, one wide distribution plot underneath
fig = plt.figure(figsize=(8,6))
gs = GridSpec(2,2,figure=fig)
ax0 = fig.add_subplot(gs[0,0])
ax1 = fig.add_subplot(gs[0,1])
ax2 = fig.add_subplot(gs[1,:])


# show the matrices
panels = zip( (ax0,ax1), (mi_in,mi_tr), ('Pairwise MI for INTERNAL','Pairwise MI for TERMINAL') )
for ax,mat,panelTitle in panels:
  h = ax.imshow(mat,vmin=.1,vmax=1,origin='lower',aspect='auto')
  plt.colorbar(h,ax=ax,fraction=.046,pad=.02)
  ax.set(xticks=[],yticks=[],xlabel='Sequence index',ylabel='Sequence index',title=panelTitle)


# histograms of the unique (upper-triangle) matrix elements
yIntern,xIntern = np.histogram(mi_in[uniqueIndices],bins=60)
yTermin,xTermin = np.histogram(mi_tr[uniqueIndices],bins=60)

# drawn as lines at the left bin edges
ax2.plot(xIntern[:-1],yIntern,label='Internal',linewidth=2)
ax2.plot(xTermin[:-1],yTermin,label='Terminal',linewidth=2)

# x-axis spans the range of both distributions
xlimits = [min(xIntern[0],xTermin[0]),max(xIntern[-1],xTermin[-1])]
ax2.set(xlim=xlimits,xlabel='Mutual information',ylabel='Count',title='Distribution of MI values')
ax2.legend()


plt.tight_layout()
plt.suptitle(f'Mutual information in layer {whichLayer}',y=1.05,fontweight='bold',fontsize=16)
plt.show()

In [None]:
# visualize all pairwise covariance from one layer

# two matrices on top, one wide distribution plot underneath
fig = plt.figure(figsize=(8,6))
gs = GridSpec(2,2,figure=fig)
ax0 = fig.add_subplot(gs[0,0])
ax1 = fig.add_subplot(gs[0,1])
ax2 = fig.add_subplot(gs[1,:])


# show the covariance matrices
h = ax0.imshow(cov_in,vmin=0,vmax=2,origin='lower')
plt.colorbar(h,ax=ax0,fraction=.046,pad=.02)
ax0.set(xticks=[],yticks=[],xlabel='Sequence index',ylabel='Sequence index',title='Pairwise covariance for INTERNAL')

h = ax1.imshow(cov_tr,vmin=0,vmax=2,origin='lower')
plt.colorbar(h,ax=ax1,fraction=.046,pad=.02)
ax1.set(xticks=[],yticks=[],xlabel='Sequence index',ylabel='Sequence index',title='Pairwise covariance for TERMINAL')


# get histograms from the unique (upper-triangle) matrix elements
yIntern,xIntern = np.histogram(cov_in[uniqueIndices],bins=60)
yTermin,xTermin = np.histogram(cov_tr[uniqueIndices],bins=60)

# and show those (raw counts, drawn at the left bin edges)
ax2.plot(xIntern[:-1],yIntern,label='Internal',linewidth=2)
ax2.plot(xTermin[:-1],yTermin,label='Terminal',linewidth=2)
ax2.set(xlim=[min(xIntern[0],xTermin[0]),max(xIntern[-1],xTermin[-1])],
        xlabel='Covariance',ylabel='Count',title='Distribution of covariance values')
ax2.legend()


plt.tight_layout()
plt.suptitle(f'Covariance in layer {whichLayer}',y=1.05,fontweight='bold',fontsize=16)
plt.show()

In [None]:
## covariance by mutual information

_,axs = plt.subplots(1,3,figsize=(12,4))

# skip to facilitate plotting (only every pnts2skip-th pair is drawn)
pnts2skip = 17

# vectorize the unique (upper-triangle) pairs once
covIntern,miIntern = cov_in[uniqueIndices],mi_in[uniqueIndices]
covTermin,miTermin = cov_tr[uniqueIndices],mi_tr[uniqueIndices]

# correlations are computed from all pairs, not only the plotted subset
rIntern = np.corrcoef(covIntern,miIntern)[0,1]
rTermin = np.corrcoef(covTermin,miTermin)[0,1]

# plot the internal punctuations
axs[0].plot(covIntern[::pnts2skip],miIntern[::pnts2skip],'.',markerfacecolor=[.7,.7,.9,.3])
axs[0].set(xlabel='Covariance',ylabel='Mutual information',
           title=f'Internal punctuation (r = {rIntern:.3f})')

# plot the terminal punctuations
axs[1].plot(covTermin[::pnts2skip],miTermin[::pnts2skip],'.',markerfacecolor=[.9,.7,.7,.3])
axs[1].set(xlabel='Covariance',ylabel='Mutual information',
           title=f'Terminal punctuation (r = {rTermin:.3f})')

# both categories in the same axis
axs[2].plot(covIntern[::pnts2skip],miIntern[::pnts2skip],'r.',alpha=.3,label='Internal')
axs[2].plot(covTermin[::pnts2skip],miTermin[::pnts2skip],'k.',alpha=.3,label='Terminal')
axs[2].set(xlabel='Covariance',ylabel='Mutual information',title='Both')
axs[2].legend()


plt.tight_layout()
plt.show()

# Exercise 4: MI and cov over all layers

In [None]:
# initializations (layers x sequences x sequences; only the upper triangles are filled in)
nLayers = len(hsIntern)
mi_in_all,mi_tr_all,cv_in_all,cv_tr_all = ( np.zeros((nLayers,batchsize,batchsize)) for _ in range(4) )


# loop over layers
for layeri,(layerIntern,layerTermin) in enumerate(zip(hsIntern,hsTermin)):

  # activations at the target punctuation token for all sequences in this layer
  actsIntern = layerIntern[:,context_pre,:]
  actsTermin = layerTermin[:,context_pre,:]

  # double-loop over the unique sequence pairs
  for bi in range(batchsize):
    for bj in range(bi+1,batchsize):

      # internal punctuation
      mi_in_all[layeri,bi,bj],cv_in_all[layeri,bi,bj] = mutInfo_and_cov(actsIntern[bi],actsIntern[bj],5)

      # terminal punctuation
      mi_tr_all[layeri,bi,bj],cv_tr_all[layeri,bi,bj] = mutInfo_and_cov(actsTermin[bi],actsTermin[bj],5)

  # progress report
  print(f'Finished layer {layeri+1:2}/{len(hsIntern)}')

In [None]:
# scatter plots of covariance by mutual information, one axis per hidden state
# NOTE: the 5x5 grid has room for at most 25 hidden-state arrays
_,axs = plt.subplots(5,5,figsize=(12,11))
axs = axs.flatten()


# skip to facilitate plotting (only every pnts2skip-th pair is drawn)
pnts2skip = 7

for layeri in range(len(hsIntern)):

  # plot the dots
  axs[layeri].plot(cv_in_all[layeri][uniqueIndices][::pnts2skip],mi_in_all[layeri][uniqueIndices][::pnts2skip],'r.',alpha=.2,label='Internal')
  axs[layeri].plot(cv_tr_all[layeri][uniqueIndices][::pnts2skip],mi_tr_all[layeri][uniqueIndices][::pnts2skip],'k.',alpha=.2,label='Terminal')

  # adjust the axis
  axs[layeri].set(xticks=[],yticks=[],title=f'Layer {layeri}')

# axis labels and legend only in the bottom-left axis
axs[20].set(xlabel='Covariance',ylabel='Mutual information')
axs[20].legend(fontsize=8)

plt.tight_layout()
# plt.savefig('ex4.png')
plt.show()

# Exercise 5: Distributions and means

In [None]:
fig,axs = plt.subplots(2,2,figsize=(10,6))

# map layer index onto a color
norm = mpl.colors.Normalize(vmin=0,vmax=len(hsIntern))
cmap = mpl.cm.plasma

# histogram bin edges for the two measures
covBins = np.linspace(-1,105,101)
miBins  = np.linspace(0,1.5,101)

for layeri in range(len(hsIntern)):

  # color and unique-pair values (internal punctuation) for this layer
  layerColor = cmap(norm(layeri))
  covVals = cv_in_all[layeri][uniqueIndices]
  miVals  = mi_in_all[layeri][uniqueIndices]

  ### covariance: distribution (left) and its mean (right)
  yy,xx = np.histogram(covVals,bins=covBins,density=True)
  axs[0,0].plot(xx[:-1],yy,color=layerColor)
  axs[0,1].plot(layeri,covVals.mean(),'ks',markersize=10,markerfacecolor=layerColor)

  ### mutual information: distribution (left) and its mean (right)
  yy,xx = np.histogram(miVals,bins=miBins,density=True)
  axs[1,0].plot(xx[:-1],yy,color=layerColor)
  axs[1,1].plot(layeri,miVals.mean(),'ks',markersize=10,markerfacecolor=layerColor)



# colorbars showing the layer-to-color mapping
sm = mpl.cm.ScalarMappable(cmap=mpl.cm.plasma,norm=norm)
for distAx in (axs[0,0],axs[1,0]):
  plt.colorbar(sm,ax=distAx,pad=.01)

# labels and titles
axs[0,0].set(xlabel='Covariance value',ylabel='Density',title='Covariance distribution')
axs[0,1].set(xlabel='Hidden layer',ylabel='Covariance',title='Average covariances')
axs[1,0].set(xlabel='Mutual information value',ylabel='Density',title='MI distribution')
axs[1,1].set(xlabel='Hidden layer',ylabel='Mutual information',title='Average MI values')

plt.tight_layout()
plt.show()