|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 5:</h2>|<h1>Observation (non-causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Investigating token embeddings</h1>|
|<h2>Lecture:</h2>|<h1><b>T-SNE projection and DBscan clustering</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">udemy.com/course/dullms_x/?couponCode=202508</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from sklearn.manifold import TSNE
from sklearn.cluster import DBSCAN

# request SVG output for figures drawn by the inline matplotlib backend
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

In [None]:
# Hugging Face classes for the GPT-2 network and its tokenizer
from transformers import GPT2Model
from transformers import GPT2Tokenizer

# load the pretrained tokenizer and model (both from the 'gpt2' checkpoint)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
gpt2 = GPT2Model.from_pretrained('gpt2')

In [None]:
# token-embeddings matrix as a NumPy array (one row per token);
# get_input_embeddings() returns the model's wte layer, and the weights
# are detached from autograd before the conversion
embedding = gpt2.get_input_embeddings().weight.detach().numpy()

# Gram matrix

In [None]:
# number of tokens to analyze
nToks = 100

# keep the rows of the first nToks tokens (all embedding dimensions)
subEmbed = embedding[:nToks]

In [None]:
# Gram matrix: dot products between all pairs of token embedding vectors
G = np.matmul(subEmbed,subEmbed.T)

# three panels showing how the Gram matrix is built:
# embeddings, their transpose, and the product of the two
_,axs = plt.subplots(nrows=1,ncols=3,figsize=(12,4))

# left panel: tokens-by-dimensions
axs[0].imshow(subEmbed,vmin=-.2,vmax=.2)
axs[0].set_title('Embeddings')
axs[0].set_xlabel('Embedding dimension')
axs[0].set_ylabel('Token')

# middle panel: dimensions-by-tokens
axs[1].imshow(subEmbed.T,vmin=-.2,vmax=.2)
axs[1].set_title('Embeddings transpose')
axs[1].set_xlabel('Token')
axs[1].set_ylabel('Embedding dimension')

# right panel: tokens-by-tokens
axs[2].imshow(G,vmin=2,vmax=7)
axs[2].set_title('Gram matrix')
axs[2].set_xlabel('Token')
axs[2].set_ylabel('Token')

plt.tight_layout()
plt.show()

# TSNE on some embedding vectors

In [None]:
# configure a t-SNE model that projects to two dimensions
tsne = TSNE(
  n_components = 2,
  perplexity = 5,
)

# fit it to the selected token embeddings and get the projected coordinates
tsne_result = tsne.fit_transform(subEmbed)

# confirm the size of the result: one row per token, one column per t-SNE dimension
tsne_result.shape

In [None]:
# plot the results: the Gram matrix next to the 2D t-SNE map of the same tokens
_,axs = plt.subplots(1,2,figsize=(13,5))

# every third token gets a tick label; the x-ticks start at token 1 and the
# y-ticks at token 0, so the two axes label different tokens
xtickIdx = range(1,nToks,3)
ytickIdx = range(0,nToks,3)

# show the gram matrix
axs[0].imshow(G,origin='lower',vmin=2,vmax=7)
axs[0].set(xticks=xtickIdx,xticklabels=[tokenizer.decode([i]) for i in xtickIdx],
           yticks=ytickIdx,yticklabels=[tokenizer.decode([i]) for i in ytickIdx],
           title=f'Gram matrix, first {nToks} token embeds')

# t-SNE coordinates, one dot per token
axs[1].scatter(tsne_result[:,0], tsne_result[:,1], color=[.7,.7,1],edgecolor='k')


# label words
# The offset is read from the scatter axis explicitly (not whichever axis is
# "current") and kept as a plain scalar, so each text y-position is a single
# number rather than a 1-element array.
ylo,yhi = axs[1].get_ylim()
yoffset = .02 * (yhi-ylo) # shift words up by 2% of the y-axis range
for i in range(nToks):
  axs[1].text(tsne_result[i,0], tsne_result[i,1]+yoffset, tokenizer.decode([i]),  ha='center')

axs[1].set(xlabel='TSNE dim 1',ylabel='TSNE dim 2',title='T-SNE visualization of embeddings')

plt.tight_layout()
plt.show()

# DBscan to find clusters

In [None]:
# set up DBSCAN, then fit it to the 2D t-SNE coordinates
clustmodel = DBSCAN(eps=7,min_samples=3)
clustmodel.fit(tsne_result)

# list the attributes and methods available on the fitted model
dir(clustmodel)

In [None]:
# one cluster label per token, taken from the fitted model
groupidx = clustmodel.labels_

# labels start at 0, so the largest label plus one is the cluster count
nclust = groupidx.max()+1

# centroid of each cluster: the mean t-SNE coordinate of its members
cents = np.zeros((nclust,2))
for ci in range(nclust):
  inClust = groupidx==ci
  cents[ci,:] = tsne_result[inClust,0].mean(), tsne_result[inClust,1].mean()

In [None]:
# draw lines from each data point to the centroids of each cluster
plt.figure(figsize=(8,6))
lineColors = 'rkbgm'

# The number of clusters is not known in advance, so colors are reused
# cyclically instead of indexing past the end of the color string.
nColors = len(lineColors)

# plot each dot according to its cluster (or lack thereof)
for i,label in enumerate(groupidx):
  if label==-1:
    # unclustered point: black cross, no line
    plt.plot(tsne_result[i,0],tsne_result[i,1],'k+')
  else:
    # line from the point to the centroid of its cluster
    plt.plot([ tsne_result[i,0], cents[label,0] ],[ tsne_result[i,1], cents[label,1] ],lineColors[label%nColors])


# now draw the raw data in different colors
for i in range(nclust):
  plt.plot(tsne_result[groupidx==i,0],tsne_result[groupidx==i,1],'o',markerfacecolor=lineColors[i%nColors])

# and now plot the centroid locations
plt.plot(cents[:,0],cents[:,1],'kd',markerfacecolor=[.8,.7,.1],markersize=10)
plt.gca().set(xlabel='tSNE axis 1',ylabel='tSNE axis 2',title=f'Result of dbscan clustering (k={nclust})')

plt.show()

In [None]:
# show the cluster label of each token (-1 marks a token that is not in any cluster)
groupidx

In [None]:
# list the tokens in each group, starting with the ungrouped tokens (label -1)
for cidx in range(-1,nclust):

  # find all the tokens in this group
  tokensInGroup = np.where(groupidx==cidx)[0]

  # print them out
  if cidx==-1:
    print('\nUngrouped tokens:')
  else:
    print(f'\nTokens in group {cidx}:')

  # Decode each token index to its text as-is; joining the decoded string
  # with spaces would insert a space between the characters of any
  # multi-character token.
  print([ tokenizer.decode([t]) for t in tokensInGroup ])