|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 5:</h2>|<h1>Observation (non-causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Investigating token embeddings</h1>|
|<h2>Lecture:</h2>|<h1><b>Creating and interpreting linear "semantic axes"</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">udemy.com/course/dullms_x/?couponCode=202508</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
# !pip install gensim

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import gensim.downloader as api

import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

In [None]:
# Load the pretrained model named 'word2vec-google-news-300' through gensim's downloader API.
# The returned object is used below for its `.vectors` matrix, `key_to_index` and `index_to_key` lookups.
w2v = api.load(name='word2vec-google-news-300')

# Normalize all embedding vectors

In [None]:
# Report the size of the embedding matrix (rows are tokens, indexed via w2v.key_to_index),
# then rescale every row to unit L2 length. After this, a dot product between a row and
# any unit-length vector equals their cosine similarity.
print(w2v.vectors.shape)
vectors_norm = w2v.vectors / np.linalg.norm(w2v.vectors, axis=1)[:, np.newaxis]

# Create a "semantic axis"

In [None]:
# Anchor words for the axis: word4pos marks the positive end, word4neg the negative end.
# (Swap in the alternatives to explore a different axis.)
word4pos = 'future'   # alternative: 'good'
word4neg = 'past'     # alternative: 'evil'

# Version 1 ("Pre-norm"): subtract the raw word vectors, then scale the difference to unit length.
v2add, v2sub = w2v[word4pos], w2v[word4neg]
semantic_axis = v2add - v2sub
semantic_axis /= np.linalg.norm(semantic_axis)

# Version 2 ("Post-norm"): subtract the already unit-length rows of vectors_norm.
# Note: this difference is NOT rescaled afterwards.
v2add, v2sub = (vectors_norm[w2v.key_to_index[word], :] for word in (word4pos, word4neg))
semantic_axis_norm = v2add - v2sub


fig, (ax_line, ax_scatter) = plt.subplots(1, 2, figsize=(12, 4))

# Left panel: both versions of the axis, dimension by dimension.
for axis_vec, axis_label in ((semantic_axis, 'Pre-norm'), (semantic_axis_norm, 'Post-norm')):
  ax_line.plot(axis_vec, label=axis_label)
ax_line.legend()
ax_line.set(xlabel='Embeddings dimension', ylabel='Embedding weight')

# Right panel: one version plotted against the other. Points lying on a straight line
# through the origin would mean the two vectors are proportional.
ax_scatter.plot(semantic_axis, semantic_axis_norm, 'ks', markerfacecolor=[.7, .7, .9])
ax_scatter.set(xlabel='Difference of "raw" vectors', ylabel='Difference of normed vectors')

fig.tight_layout()
plt.show()

# Filter for "real" words

In [None]:
# Demonstrate the "real word" filter on a few example tokens:
# a token passes if it is purely alphabetic (str.isalpha) and longer than 2 characters.
testwords = [ 'theInternet.com','health','FRITZ!Box','headphones' ]
isRealWord = np.array([word.isalpha() and len(word)>2 for word in testwords], dtype=bool)
filterWords  = np.where(isRealWord)[0]    # indices of tokens that pass the filter
excludeWords = np.where(~isRealWord)[0]   # indices of tokens that fail the filter
# Why negate the boolean mask rather than the index array: `~` on an integer array is a
# bitwise NOT (e.g. ~1 == -2). That yields negative indices, not the complementary set of positions.

print('Word set:')
print(testwords)

print('\nIncluded words:')
print([testwords[w] for w in filterWords])

print('\nExcluded words:')
print([testwords[w] for w in excludeWords])

In [None]:
# Apply the "real word" filter to the whole vocabulary. A token is kept only if it is
# purely alphabetic (no digits, punctuation, underscores, ...) and longer than 2 characters.
# words2use holds the vocabulary indices of the kept tokens.
allwords = list(w2v.key_to_index.keys())
words2use = np.flatnonzero([tok.isalpha() and len(tok) > 2 for tok in allwords])

# to test without filtering:
# words2use = np.arange(len(allwords))

# report how much of the vocabulary survived the filter
n_kept, n_total = len(words2use), len(allwords)
print(f'{n_kept:,} out of {n_total:,} ({100*n_kept/n_total:.2f}%) tokens kept.')

# Project all words onto the axis

In [None]:
# Project every kept token onto the semantic axis.
# With the raw vectors (active line), each score mixes the token's alignment with the axis and
# the token's vector length. With the unit-length rows (commented line) and the unit-length
# semantic_axis, each score is the cosine similarity between the token and the axis.
# dotprods = vectors_norm[words2use] @ semantic_axis
dotprods = w2v.vectors[words2use] @ semantic_axis

# Sort the scores once and read both extremes from that single ordering.
# The array covers the whole filtered vocabulary, so a second argsort would be expensive.
rank_order = dotprods.argsort()
top10 = rank_order[-10:][::-1]   # 10 highest scores, largest first
bot10 = rank_order[:10]          # 10 lowest scores, most negative first


# print them out (widx indexes dotprods; words2use maps it back to a vocabulary index)
print('10 most positive-projected words:')
for widx in top10:
  print(f' Similarity of {dotprods[widx]:.3f} for "{w2v.index_to_key[words2use[widx]]}"')

print('\n10 most negative-projected words:')
for widx in bot10:
  print(f' Similarity of {dotprods[widx]:.3f} for "{w2v.index_to_key[words2use[widx]]}"')