|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 5:</h2>|<h1>Observation (non-causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Investigating token embeddings</h1>|
|<h2>Lecture:</h2>|<h1><b>kNN for synonym-searching in BERT</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">udemy.com/course/dullms_x/?couponCode=202508</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import matplotlib_inline.backend_inline
# request SVG as the output format for inline figures
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# Demo of kNN for classification

In [None]:
# data labels and categories ("X" is the unlabeled data value)
dataLabels = 'ABCDEFGHIJLKMNOPX'
# binary category per labeled point: first half is 0 (drawn blue), second half is 1 (drawn red)
categories = ( np.linspace(0,1,len(dataLabels)-1)>.5 ).astype(int)
unlabeled  = len(dataLabels)-1 # index of the final value ("X")

# generate some random data: one 2D coordinate per label (including the unlabeled one)
data = np.random.randn(len(dataLabels),2)


# Euclidean distance from unlabeled data value to all others
# (the entry at index `unlabeled` is the distance to itself, i.e., 0)
eucldist = np.sqrt(np.sum( (data-data[unlabeled,:])**2 ,axis=1))

# plot all letters (the final row is the unlabeled value and is drawn separately below)
for i in range(len(data)-1):
  c = 'br'[categories[i]] # category 0 -> 'b', category 1 -> 'r'
  plt.plot(data[i,0],data[i,1],marker=f'${dataLabels[i]}$',color=c,markersize=10)

# plot the seed in black and slightly larger
plt.plot(data[unlabeled,0],data[unlabeled,1],marker=f'${dataLabels[unlabeled]}$',color='k',markersize=12)

plt.gca().set(xlim=[-3,3],ylim=[-3,3],xlabel='Embedding dimension "1"',ylabel='Embedding dimension "2"')
plt.show()

In [None]:
# order the points from nearest to farthest; the first sorted entry is dropped
# because it is the unlabeled value's zero distance to itself
distidx = eucldist.argsort()[1:]

# report each labeled point, nearest first
for idx in distidx:
  print(f'"{dataLabels[idx]}" is {eucldist[idx]:>5.2f} units from "{dataLabels[unlabeled]}"')

In [None]:
# number of nearest neighbors that vote on the category
k = 3

# categories of the k nearest neighbors
neighborCats = categories[distidx[:k]]

# label the unlabeled: the median of binary labels is the majority vote
# (use an odd k; with an even k a tie gives a median of .5, which astype(int) truncates to 0)
targCat = np.median(neighborCats).astype(int)

# print the result
print(f'The categories of the {k} nearest neighbors are {neighborCats}\n')
print(f'The unlabeled data value is in category "{targCat}" ({"br"[targCat]} in the plot)')

# Now for the synonym-searching

In [None]:
from transformers import BertTokenizer, BertModel

# the tokenizer and the model are loaded from the same pretrained checkpoint
checkpoint = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(checkpoint)
model = BertModel.from_pretrained(checkpoint)

In [None]:
# pull the word-embedding weights out of the model and convert them to a NumPy array
embedWeights = model.embeddings.word_embeddings.weight
embeddings = embedWeights.detach().numpy()

# report the size of the matrix
print(f'Embeddings matrix shape: {embeddings.shape}')

# kNN on BERT

In [None]:
# pick a "seed" vector
seedword = 'beauty'

# token indices of the seed word (encode returns a list, one entry per token)
seedidx = tokenizer.encode(seedword,add_special_tokens=False)

# the distance calculations below need exactly one seed vector; otherwise the
# subtraction from the embeddings matrix fails with a broadcasting error
if len(seedidx)!=1:
  raise ValueError(f'"{seedword}" maps to {len(seedidx)} tokens; choose a seed word that is a single token.')

seedvect = embeddings[seedidx,:] # one row, kept 2D so it broadcasts against all rows

# Euclidean distance to all other vectors
eucDist = np.sqrt( np.sum( (embeddings-seedvect)**2 ,axis=1) )

# cosine similarity for comparison (unit-normalize the rows, then take dot products)
E = embeddings / np.linalg.norm(embeddings,axis=1,keepdims=True)
cs = (seedvect/np.linalg.norm(seedvect)) @ E.T
cs = np.squeeze(cs) # remove singleton dimension

# for visualization, replace 0 with nan (hides zero-distance points, e.g., the seed to itself)
eucDist_nan = eucDist.copy()
eucDist_nan[eucDist==0] = np.nan

# visualizations
_,axs = plt.subplots(1,2,figsize=(12,4))

# left panel: distance of every token to the seed, colored by cosine similarity
axs[0].scatter(range(len(eucDist)),eucDist_nan,s=50,c=cs,alpha=.4)
axs[0].set(xlim=[-20,len(eucDist)+20],xlabel='Token index',ylabel='Euclidean distance',
           title=f'Distance to "{seedword}", colored by cosine sim.')

# right panel: the two measures plotted against each other
axs[1].plot(eucDist_nan,cs,'ko',markerfacecolor=[.7,.7,.9,.6])
axs[1].set(xlabel='Euclidean distance',ylabel='Cosine similarity',
           title='Relation between $S_c$ and Euclidean distance')

plt.tight_layout()
plt.show()

In [None]:
# number of closest tokens to list
k = 15

# token indices ordered by increasing distance; keep the first k
# (the seed token itself has distance 0, so it is expected to appear first)
sortidx = np.argsort(eucDist)
topKidx = sortidx[:k]

print(f'Nearest {k} words to "{seedword}":')
for tokidx,tokdist in zip(topKidx,eucDist[topKidx]):
  print(f'  Distance of {tokdist:.3f} to "{tokenizer.decode(tokidx)}"')