|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 5:</h2>|<h1>Observation (non-causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Investigating token embeddings</h1>|
|<h2>Lecture:</h2>|<h1><b>RSA (representational similarity analysis)</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">udemy.com/course/dullms_x/?couponCode=202508</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
# NOTE: If you get errors importing, run the following !pip... line,
# then restart your session (from Runtime menu) and comment out the pip line.
# !pip install gensim

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import gensim.downloader as api # see previous cell for installation

# configure the inline plotting backend so that figures are rendered in SVG format
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

In [None]:
# load a word2vec model
# (requested by name through gensim's downloader module, imported above as `api`;
#  the resulting `w2v` object is indexed by word and queried for `vector_size` in later cells)
# NOTE(review): the first call presumably downloads the model files — confirm before running offline
w2v = api.load('word2vec-google-news-300')

# Extract embeddings matrices for selected words

In [None]:
# Words used in the RSA, written as three loosely related groups.
# The order matters: it fixes the row order of `embedmat` and of every
# similarity matrix derived from it.
words = (
    ['space', 'spaceship', 'planet', 'moon', 'star', 'galaxy']
    + ['chair', 'table', 'couch', 'stool', 'floor']
    + ['apple', 'banana', 'pear', 'kiwi', 'orange', 'peach', 'watermelon', 'starfruit', 'date']
)

# Look up each word's embedding vector and stack them row-wise:
# one row per word, one column per embedding dimension.
embedmat = np.array([w2v[token] for token in words])

# show the matrix size as the cell output
embedmat.shape

In [None]:
# Split the embedding dimensions (columns) into two interleaved halves:
# columns 0,2,4,... ("even") and columns 1,3,5,... ("odd").
embedmat_evn = embedmat[:, 0::2]
embedmat_odd = embedmat[:, 1::2]

# report the sizes of the two sub-matrices
print(f'Size of "even-dimensions" matrix: {embedmat_evn.shape}')
print(f'Size of "odd-dimensions"  matrix: {embedmat_odd.shape}')

# Sanity check: plot the first word's values at their original dimension
# indices to confirm that the two halves contain different numbers.
n_dims = w2v.vector_size
_, ax = plt.subplots(figsize=(10, 4))
ax.plot(range(0, n_dims, 2), embedmat_evn[0, :], 's-', label='Even dimensions')
ax.plot(range(1, n_dims, 2), embedmat_odd[0, :], 'o-', label='Odd dimensions')
ax.set(xlim=[0, n_dims], xlabel='Dimension', ylabel='Value')
ax.legend(fontsize=10)
plt.show()

# Calculate the cosine similarity matrices

In [None]:
# Length (L2 norm) of every word vector, kept as a column so that it
# broadcasts across the embedding dimensions in the division below.
norms_evn = np.linalg.norm(embedmat_evn, axis=1, keepdims=True)
norms_odd = np.linalg.norm(embedmat_odd, axis=1, keepdims=True)

# scale every row to unit length
E_evn_norm = embedmat_evn / norms_evn
E_odd_norm = embedmat_odd / norms_odd

# With unit-length rows, the dot product between two rows is their cosine
# similarity, so the words-by-words product gives the full similarity matrix.
cs_matrix_evn = np.matmul(E_evn_norm, E_evn_norm.T)
cs_matrix_odd = np.matmul(E_odd_norm, E_odd_norm.T)

# Visualize for qualitative comparison

In [None]:
_, axs = plt.subplots(1, 2, figsize=(12, 5))

# one panel per similarity matrix: (matrix, panel title)
panels = (
    (cs_matrix_evn, 'Cossim matrix for EVEN dims'),
    (cs_matrix_odd, 'Cossim matrix for ODD dims'),
)
n_words = len(words)

for ax, (cs_matrix, panel_title) in zip(axs, panels):
    # same color limits in both panels so the images are directly comparable
    ax.imshow(cs_matrix, vmin=.1, vmax=.6, cmap='plasma')

    # label every other word on each axis (even-indexed words on x,
    # odd-indexed words on y) so the tick labels don't overlap
    ax.set(xticks=range(0, n_words, 2), xticklabels=words[::2],
           yticks=range(1, n_words, 2), yticklabels=words[1::2],
           title=panel_title)
    ax.tick_params(axis='x', labelrotation=90)

plt.tight_layout()
plt.show()

# Quantitative comparison via RSA

In [None]:
# Extract the strictly-upper-triangular elements of each similarity matrix.
# k=1 starts one step above the main diagonal, so every word pair appears
# exactly once and the self-similarities on the diagonal are left out.
unique_evn = cs_matrix_evn[np.triu_indices_from(cs_matrix_evn, k=1)]
unique_odd = cs_matrix_odd[np.triu_indices_from(cs_matrix_odd, k=1)]

# Pearson correlation between the two sets of pairwise similarities (the RSA score)
r = np.corrcoef(unique_evn, unique_odd)[0, 1]

# plot
_, axs = plt.subplots(1, 2, figsize=(11, 5))
axs[0].plot(unique_evn, unique_odd, 'ks', markerfacecolor=[.7, .7, .9, .7])
axs[0].set(xlabel='EVEN cosine similarities', ylabel='ODD cosine similarities',
           xlim=[-.2, 1.03], ylim=[-.2, 1.03],
           title=f'UNIQUE: Correlation (RSA score): r = {r:.3f}')


## -- why you need to extract only the unique elements...
# The flattened matrices contain the diagonal (each word with itself) and
# every off-diagonal pair twice (the matrices are symmetric), so these points
# are not independent observations and can distort the correlation.
# Separate names are used here so that `unique_evn`, `unique_odd`, and `r`
# still hold the proper RSA quantities after this cell has run.
full_evn = cs_matrix_evn.flatten()
full_odd = cs_matrix_odd.flatten()
r_full = np.corrcoef(full_evn, full_odd)[0, 1]
axs[1].plot(full_evn, full_odd, 'ks', markerfacecolor=[.9, .7, .7, .7])
axs[1].set(xlabel='EVEN cosine similarities', ylabel='ODD cosine similarities',
           xlim=[-.2, 1.03], ylim=[-.2, 1.03],
           title=f'FULL: Correlation (RSA score): r = {r_full:.3f}')


plt.tight_layout()
plt.show()