|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 5:</h2>|<h1>Observation (non-causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Investigating token embeddings</h1>|
|<h2>Lecture:</h2>|<h1><b>CodeChallenge: cluster the "x" terms</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">udemy.com/course/dullms_x/?couponCode=202508</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from sklearn.manifold import TSNE
from sklearn.cluster import DBSCAN

import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

In [None]:
from transformers import GPT2Tokenizer, GPT2Model

# tokenizer and model weights for the pretrained 'gpt2' checkpoint
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
gpt2 = GPT2Model.from_pretrained('gpt2')

# token-embedding matrix (row k = embedding vector of token id k),
# detached from autograd and converted to a NumPy array
embedding = gpt2.wte.weight.detach().numpy()

# Exercise 1: Find the "x" terms

In [None]:
# Find all tokens that contain a lowercase "x" and whose text (ignoring
# leading/trailing whitespace) is between minTokLen and maxTokLen characters.
# Results: xToks = list of matching token ids, nToks = how many there are.
minTokLen, maxTokLen = 4, 8
xToks = []

for i in range(tokenizer.vocab_size):

  # decode the token. The decoded text can include a preceding space, so the
  # length test below is applied to the stripped text.
  tok = tokenizer.decode([i])
  tokStripped = tok.strip()

  # keep it if it contains an "x" and its stripped length is in range
  # (stripping only removes whitespace, so it cannot affect the "x" test)
  if ('x' in tokStripped) and (minTokLen <= len(tokStripped) <= maxTokLen):
    print(tok)
    xToks.append(i)

nToks = len(xToks)

In [None]:
print('There are {} tokens with an "x"'.format(nToks))

# Exercise 2: tSNE

In [None]:
# extract a submatrix with only the relevant tokens (one row per "x" token)
subEmbed = embedding[xToks,:]


# reduce to 2D with t-SNE. t-SNE is stochastic, so the seed is fixed to make
# the layout (and therefore the DBSCAN parameters tuned on it) reproducible.
tsne = TSNE(n_components=2,perplexity=5,random_state=42)
tsne_result = tsne.fit_transform(subEmbed)

# label every tickStep-th token on the Gram-matrix axes
tickStep = 21

# plot results
_,axs = plt.subplots(1,2,figsize=(13,5))

# Gram matrix: dot products between all pairs of selected token embeddings
axs[0].imshow(subEmbed@subEmbed.T,origin='lower',vmin=2,vmax=7)
axs[0].set(xticks=range(1,nToks,tickStep),
           yticks=range(0,nToks,tickStep),
           yticklabels=[tokenizer.decode([xToks[i]]) for i in range(0,nToks,tickStep)],
           title=f'Gram matrix of {nToks} token embeds')
axs[0].set_xticklabels([tokenizer.decode([xToks[i]]) for i in range(1,nToks,tickStep)],rotation=90)

axs[1].scatter(tsne_result[:,0], tsne_result[:,1], color=[.7,.7,1],edgecolor='k')


# label words, shifted up by 2% of the y-axis range. Query the t-SNE axis
# explicitly and compute a scalar offset (np.diff would give a 1-element array).
ylo,yhi = axs[1].get_ylim()
yoffset = .02 * (yhi-ylo)
for i in range(nToks):
  axs[1].text(tsne_result[i,0], tsne_result[i,1]+yoffset, tokenizer.decode([xToks[i]]),  ha='center', fontsize=10,color='gray')

axs[1].set(xlabel='TSNE dim 1',ylabel='TSNE dim 2',title='T-SNE visualization of embeddings')

plt.tight_layout()
plt.show()

# Exercise 3: DBSCAN

In [None]:
## dbscan
epsilon = 7
minsamples = 3

clustmodel = DBSCAN(eps=epsilon,min_samples=minsamples).fit(tsne_result)
groupidx = clustmodel.labels_

# number of clusters: cluster labels start at 0 and noise points are labeled
# -1, so the count is the largest label + 1 (0 when everything is noise)
nclust = groupidx.max()+1

# compute cluster centers: mean t-SNE coordinate of each cluster's members
cents = np.zeros((nclust,2))
for ci in range(nclust):
  cents[ci] = tsne_result[groupidx==ci].mean(axis=0)

# draw lines from each data point to the centroid of its cluster;
# noise points (label -1) are drawn as black '+' markers instead
fig,ax = plt.subplots(figsize=(8,6))
lineColors = 'rkbgm'
for i in range(len(tsne_result)):
  g = groupidx[i]
  if g==-1:
    ax.plot(tsne_result[i,0],tsne_result[i,1],'k+')
  else:
    ax.plot([ tsne_result[i,0], cents[g,0] ],[ tsne_result[i,1], cents[g,1] ],lineColors[g%len(lineColors)])


# now draw the raw data in different colors
for ci in range(nclust):
  ax.plot(tsne_result[groupidx==ci,0],tsne_result[groupidx==ci,1],'o',markerfacecolor=lineColors[ci%len(lineColors)])

# and now plot the centroid locations
ax.plot(cents[:,0],cents[:,1],'*',color='k',markerfacecolor='y',markersize=15)

ax.set(xlabel='tSNE axis 1',ylabel='tSNE axis 2',
       title=f'DBSCAN ($\\epsilon$={epsilon}, minclust={minsamples}) identified {nclust} clusters')

plt.show()

In [None]:
# print the tokens in each cluster; label -1 holds DBSCAN's noise points

for cidx in range(-1,nclust):

  # indices (into xToks) of the tokens assigned to this cluster
  tokensInGroup = np.where(groupidx==cidx)[0]

  # print them out
  if cidx==-1:
    print('\nUngrouped tokens:')
  else:
    print(f'\nTokens in group {cidx}:')
  print([ tokenizer.decode([xToks[t]]) for t in tokensInGroup ])

# Exercise 4: Cluster count by parameter value

In [None]:
# vary the epsilons for a fixed number of samples
fixedMinSamples = 3
epsilons = np.linspace(2,20,15)
numClusters = np.zeros(len(epsilons))

for i,epi in enumerate(epsilons):
  clustmodel = DBSCAN(eps=epi,min_samples=fixedMinSamples).fit(tsne_result)
  # labels start at 0 and noise is -1, so largest label + 1 = cluster count
  numClusters[i] = clustmodel.labels_.max()+1

fig,ax = plt.subplots(figsize=(10,3))
ax.plot(epsilons,numClusters,'ks-',markerfacecolor=[.9,.7,.7],markersize=10)
ax.set(xlabel='Epsilon parameter',ylabel='Number of labeled clusters',
       title=f'Minimum sample size fixed to {fixedMinSamples}')
plt.show()

In [None]:
# vary the number of samples for a fixed epsilon
fixedEpsilon = 7
minsamples = np.arange(1,20)
numClusters = np.zeros(len(minsamples))

for i,nsamp in enumerate(minsamples):
  clustmodel = DBSCAN(eps=fixedEpsilon,min_samples=nsamp).fit(tsne_result)
  # labels start at 0 and noise is -1, so largest label + 1 = cluster count
  numClusters[i] = clustmodel.labels_.max()+1

fig,ax = plt.subplots(figsize=(10,3))
ax.plot(minsamples,numClusters,'ks-',markerfacecolor=[.7,.9,.7],markersize=10)
# escape the backslash: '\e' is not a valid Python escape sequence
ax.set(xlabel='Minimum samples per cluster',ylabel='Number of labeled clusters',
       title=f'$\\epsilon$ fixed to {fixedEpsilon}')
plt.show()

In [None]:
# vary both!
epsilons = np.linspace(2,20,15)
minsamples = np.arange(1,20)

# numClusters[i,j] = number of clusters found with epsilons[i] and minsamples[j]
numClusters = np.zeros((len(epsilons),len(minsamples)))

for i,epi in enumerate(epsilons):
  for j,nsamp in enumerate(minsamples):
    clustmodel = DBSCAN(eps=epi,min_samples=nsamp).fit(tsne_result)
    numClusters[i,j] = clustmodel.labels_.max()+1


# show in an image. imshow draws matrix rows along y and columns along x, so
# transpose to put epsilon on x and min-samples on y (matching the labels).
# The extent is widened by half a grid step on every side so that each pixel
# is centered on its parameter value rather than starting at it.
epsStep = epsilons[1]-epsilons[0]
sampStep = minsamples[1]-minsamples[0]
imgExtent = [epsilons[0]-epsStep/2, epsilons[-1]+epsStep/2,
             minsamples[0]-sampStep/2, minsamples[-1]+sampStep/2]

fig,ax = plt.subplots(figsize=(10,7))
# colors saturate at 80 clusters (vmax)
img = ax.imshow(numClusters.T,vmin=0,vmax=80,origin='lower',aspect='auto',
                extent=imgExtent)
ax.set(xlabel='Epsilon parameter',ylabel='Minimum samples per cluster',title='Number of labeled clusters')
fig.colorbar(img,ax=ax,pad=.02)
plt.show()