# Import

In [1]:
import pandas as pd
import numpy as np
import plotly.express as px
from sklearn.manifold import TSNE
import torch

In [2]:
from word2vec.src.train import train
from word2vec.src.utils import load_yaml

# Train model

In [3]:
# Read the run configuration from config.yaml using the project's load_yaml helper.
config = load_yaml('config.yaml')
# Bare last expression so the notebook displays the loaded configuration.
config

{'model': 'SkipGram',
 'model_dir': 'artifacts',
 'data_dir': 'data',
 'epochs': 10,
 'batch_size': 64,
 'learning_rate': 0.05,
 'checkpoint_frequency': 1,
 'steps_train': 60,
 'steps_val': 60,
 'n_words': 4,
 'min_word_frequency': 50,
 'max_sequence_length': 256,
 'embed_size': 50,
 'embed_max_norm': 1}

In [4]:
# Run the project's training entry point with the configuration loaded earlier.
# NOTE(review): the Demo section reads model.pt / vocab.pt from config['model_dir'];
# presumably train() writes them there — confirm in word2vec.src.train.
train(config)

Vocabulary size: 4099
Adjusting learning rate of group 0 to 5.0000e-02.
Epoch 1 / 10. Train loss: 47.9278, Validation loss: 32.2594
Adjusting learning rate of group 0 to 4.5000e-02.
Epoch 2 / 10. Train loss: 44.7503, Validation loss: 31.9854
Adjusting learning rate of group 0 to 4.0000e-02.
Epoch 3 / 10. Train loss: 44.3496, Validation loss: 31.7903
Adjusting learning rate of group 0 to 3.5000e-02.
Epoch 4 / 10. Train loss: 44.2255, Validation loss: 31.2724
Adjusting learning rate of group 0 to 3.0000e-02.
Epoch 5 / 10. Train loss: 44.0558, Validation loss: 31.8481
Adjusting learning rate of group 0 to 2.5000e-02.
Epoch 6 / 10. Train loss: 43.8478, Validation loss: 31.6540
Adjusting learning rate of group 0 to 2.0000e-02.
Epoch 7 / 10. Train loss: 43.7491, Validation loss: 31.7971
Adjusting learning rate of group 0 to 1.5000e-02.
Epoch 8 / 10. Train loss: 43.4695, Validation loss: 31.5692
Adjusting learning rate of group 0 to 1.0000e-02.
Epoch 9 / 10. Train loss: 43.5581, Validation lo

# Demo

In [5]:
# Prefer the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Restore the trained model onto that device, and the vocabulary stored in the same directory.
# NOTE(review): torch.load unpickles arbitrary objects — only load artifacts you trust.
model = torch.load(f"{config['model_dir']}/model.pt", map_location=device)
vocab = torch.load(f"{config['model_dir']}/vocab.pt")

In [6]:
# Display the module repr to inspect the layers of the loaded model.
model

SkipGram(
  (embeddings): Embedding(4099, 50, max_norm=1)
  (linear): Linear(in_features=50, out_features=4099, bias=True)
)

In [10]:
# Pull the embedding matrix out of the model (one row per vocabulary id).
# The layer is addressed by name rather than by its position in model.parameters(),
# and the tensor is copied to the CPU first: .numpy() raises on a CUDA tensor, which
# is what the model holds whenever it was loaded with a 'cuda' map_location.
embeddings = model.embeddings.weight.detach().cpu().numpy()
df_embeddings = pd.DataFrame(embeddings)
df_embeddings

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,40,41,42,43,44,45,46,47,48,49
0,0.004864,0.058226,0.102455,-0.068684,-0.110115,0.157051,-0.022095,-0.100077,-0.240014,0.028700,...,-0.051605,-0.046651,-0.065859,0.013713,-0.001473,-0.008831,0.059250,0.093462,-0.212240,0.030380
1,-0.124084,-0.066151,0.154998,-0.069169,0.012323,0.069352,-0.069914,-0.011487,-0.004807,0.006637,...,-0.031389,0.026722,-0.027680,-0.039143,0.058650,-0.099987,-0.047607,0.142750,-0.168110,0.085693
2,0.031894,0.043751,-0.105125,0.006367,0.057542,0.010927,-0.036122,-0.153236,-0.356225,0.000016,...,0.071755,-0.013104,-0.309857,0.042005,0.083636,0.070455,-0.022662,0.031165,-0.140603,-0.032430
3,0.249401,0.041974,0.005398,-0.012846,0.036359,0.100475,-0.125695,-0.022251,-0.111198,0.067785,...,-0.060293,0.112995,-0.001086,0.010267,-0.014071,0.132088,-0.088297,0.021666,-0.533064,0.004619
4,0.046675,0.128045,-0.126386,0.017677,-0.086506,-0.026073,0.068903,-0.230947,0.108470,0.149829,...,-0.247969,-0.004013,-0.243505,0.216038,-0.222477,0.161965,0.075252,0.132246,-0.013527,0.103756
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4094,-0.199458,0.127758,0.004696,-0.151444,0.196218,0.137440,0.119755,-0.103406,-0.065517,-0.059096,...,0.000133,-0.248536,-0.191121,-0.064573,-0.024117,-0.024690,-0.242933,0.131435,-0.051108,0.070901
4095,0.009298,0.052724,0.057545,0.055180,0.101406,0.047301,0.074818,-0.028003,-0.195690,-0.169775,...,-0.202503,0.006349,-0.088079,-0.077961,-0.016131,-0.173869,0.283655,0.135124,-0.108162,-0.108959
4096,-0.089679,0.117325,0.083674,0.061179,-0.234210,0.127248,-0.117865,-0.072367,-0.073205,0.162165,...,-0.080407,-0.067745,-0.050152,-0.161416,-0.143546,0.154494,0.059882,0.054609,-0.073639,0.124874
4097,-0.226796,0.133705,0.016640,0.089745,0.021849,0.014482,-0.232020,-0.129789,-0.055314,0.121525,...,-0.099888,-0.240493,-0.179643,-0.114220,-0.062328,-0.014637,-0.080130,0.043271,-0.143201,-0.018509


# Visualisation

In [11]:
# Project the embeddings to 2-D with t-SNE for plotting.
# t-SNE is stochastic: pin random_state so Restart & Run All reproduces the same layout.
tsne = TSNE(n_components=2, random_state=42)
# Build the frame in one step (rows labelled by token, columns named x / y)
# instead of mutating the index and column names in place afterwards.
df_embeddings_transformed = pd.DataFrame(
    tsne.fit_transform(df_embeddings),
    index=vocab.get_itos(),
    columns=['x', 'y'],
)

In [None]:
# Plot only a random subset of words so the text labels stay readable.
# The sample size is capped at the number of available rows (sampling without
# replacement raises when asked for more), and seeded so the figure is reproducible.
df_embeddings_plot = df_embeddings_transformed.sample(
    n=min(50, len(df_embeddings_transformed)),
    random_state=42,
)

fig = px.scatter(df_embeddings_plot, x='x', y='y', text=df_embeddings_plot.index)
fig.update_layout(
    height=800,
    title_text='Visualisation of words using t-SNE algorithm'
)
# Put each label above its marker so the point itself stays visible.
fig.update_traces(textposition='top center')
fig.show()

# Find similar words

In [13]:
def get_embedding(df_embeddings: pd.DataFrame, vocab, word: str) -> pd.Series:
    """Return the embedding vector of ``word``.

    Args:
        df_embeddings: Embedding matrix whose row labels are vocabulary ids.
        vocab: Vocabulary object exposing ``get_stoi()``, a token -> id mapping.
        word: Token to look up.

    Returns:
        The row of ``df_embeddings`` labelled with the id of ``word``.

    Raises:
        KeyError: If ``word`` is missing from the mapping returned by
            ``get_stoi()`` (assuming a plain dict), or its id is not a row label.
    """
    word_id = vocab.get_stoi()[word]
    return df_embeddings.loc[word_id]

In [14]:
def top_similar_words(df_embeddings: pd.DataFrame, word: str | list[float], vocab = None, n: int = 10) -> dict[str, float]:
    if isinstance(word, str):
        word_id = vocab[word]
        word_vec = df_embeddings.loc[word_id]
    else:
        word_id = -1
        word_vec = word
    distances = df_embeddings.apply(lambda embed: np.matmul(word_vec, embed), axis=1)
    if isinstance(word, str):
        distances = distances.drop(word_id)
    similar_distances = distances.sort_values(ascending=False).iloc[:n]
    return {vocab.get_itos()[id]: similar_distances[id] for id in similar_distances.index}

In [15]:
# Ten highest-scoring neighbours of 'know'; the scores are dot products between embedding rows.
top_similar_words(df_embeddings, 'know', vocab, n=10)

{'think': 0.9609865,
 'really': 0.94858843,
 'want': 0.94356865,
 'don': 0.8979643,
 'feel': 0.8942174,
 'someone': 0.891913,
 'do': 0.8917615,
 'you': 0.8869977,
 'doesn': 0.883993,
 'believe': 0.87273043}

# Vector equations

In [16]:
# Look up the three operands of the analogy: man, woman, king.
embed_1, embed_2, embed_3 = (
    get_embedding(df_embeddings, vocab, token)
    for token in ('man', 'woman', 'king')
)
# Vector arithmetic on the embeddings: king - man + woman (same evaluation order).
embed_4 = embed_3.sub(embed_1).add(embed_2)

In [17]:
# Nearest neighbours of the composed vector. Because a vector (not a token) is passed,
# top_similar_words excludes nothing, so the input word 'king' can rank in its own results.
top_n = top_similar_words(df_embeddings, embed_4, vocab, n=10)
top_n

{'djedkare': 0.97340596,
 'king': 0.9723966,
 'christ': 0.9478258,
 'lord': 0.93334126,
 'st': 0.9236278,
 'mary': 0.9071252,
 'gerard': 0.904146,
 'alfred': 0.8956475,
 'catholic': 0.8883756,
 'biographer': 0.8828482}