# Import

In [1]:
import pandas as pd
import numpy as np
import plotly.express as px
from sklearn.manifold import TSNE
import torch

In [2]:
from word2vec.src.train import train
from word2vec.src.utils import load_yaml

# Train model

In [3]:
config = load_yaml('config.yaml')
config

{'model': 'CBOW',
 'model_dir': 'artifacts',
 'data_dir': 'data',
 'epochs': 10,
 'batch_size': 64,
 'learning_rate': 0.05,
 'checkpoint_frequency': 1,
 'steps_train': 60,
 'steps_val': 60,
 'n_words': 4,
 'min_word_frequency': 50,
 'max_sequence_length': 256,
 'embed_size': 50,
 'embed_max_norm': 1}

In [4]:
train(config)

Vocabulary size: 4099
Adjusting learning rate of group 0 to 5.0000e-02.
Epoch 1 / 10. Train loss: 5.9044, Validation loss: 3.8893
Adjusting learning rate of group 0 to 4.5000e-02.
Epoch 2 / 10. Train loss: 5.4031, Validation loss: 3.9390
Adjusting learning rate of group 0 to 4.0000e-02.
Epoch 3 / 10. Train loss: 5.2968, Validation loss: 3.9414
Adjusting learning rate of group 0 to 3.5000e-02.
Epoch 4 / 10. Train loss: 5.2300, Validation loss: 3.9712
Adjusting learning rate of group 0 to 3.0000e-02.
Epoch 5 / 10. Train loss: 5.1643, Validation loss: 3.9180
Adjusting learning rate of group 0 to 2.5000e-02.
Epoch 6 / 10. Train loss: 5.1266, Validation loss: 3.9618
Adjusting learning rate of group 0 to 2.0000e-02.
Epoch 7 / 10. Train loss: 5.0710, Validation loss: 3.9375
Adjusting learning rate of group 0 to 1.5000e-02.
Epoch 8 / 10. Train loss: 5.0352, Validation loss: 3.9567
Adjusting learning rate of group 0 to 1.0000e-02.
Epoch 9 / 10. Train loss: 4.9881, Validation loss: 3.9715
Adjust

# Demo

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.load(f"{config['model_dir']}/model.pt", map_location=device)
vocab = torch.load(f"{config['model_dir']}/vocab.pt")

In [6]:
model

CBOW(
  (embeddings): Embedding(4099, 50, max_norm=1)
  (linear): Linear(in_features=50, out_features=4099, bias=True)
)

In [7]:
embeddings = list(model.parameters())[0]
embeddings = embeddings.detach().numpy()
df_embeddings = pd.DataFrame(embeddings)
df_embeddings

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,40,41,42,43,44,45,46,47,48,49
0,-0.098839,-0.172087,-0.264099,-0.038619,0.075173,-0.070824,-0.135515,0.095532,0.163182,0.196603,...,-0.165546,-0.072509,0.110850,0.142184,-0.019874,-0.030324,0.014605,0.149147,0.103493,-0.067403
1,0.105999,-0.059489,-0.060163,0.189818,-0.167095,0.290664,0.074730,-0.030358,0.276500,0.011013,...,0.133246,0.245132,-0.011315,-0.029615,-0.196696,0.126419,0.213309,-0.205289,0.256524,-0.201837
2,-0.132155,0.101718,0.239807,-0.008616,0.018744,-0.070535,-0.162106,0.121187,-0.104910,0.181329,...,0.077887,0.291904,0.139919,0.102379,0.156431,-0.196776,0.038186,-0.008581,-0.226810,-0.040487
3,-0.017584,0.006306,0.256310,-0.015306,0.048258,-0.159094,-0.141126,0.052785,-0.102291,0.072849,...,-0.022162,-0.052363,0.202046,-0.042181,-0.144830,-0.088500,0.069987,-0.146259,-0.161737,0.110107
4,0.263584,0.218560,0.235492,-0.103390,0.110942,-0.257304,-0.181631,-0.072629,0.065196,0.113084,...,-0.089073,-0.050649,-0.012844,-0.126360,0.176091,0.119329,-0.052967,0.130097,0.108904,0.256307
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4094,-0.124015,-0.015919,0.259575,0.082721,0.046009,-0.221516,-0.058338,0.055235,-0.240489,0.101589,...,-0.137392,-0.047196,-0.153564,-0.079983,0.246937,0.149794,-0.048018,0.048593,0.076900,0.290880
4095,-0.120657,-0.004297,-0.006341,0.085011,-0.122029,0.077089,-0.074074,-0.066900,-0.017804,0.061503,...,-0.072031,-0.146721,0.034227,-0.002492,-0.120996,0.092408,-0.002568,0.002689,-0.017679,0.023213
4096,-0.180845,-0.218802,-0.167262,0.092539,0.053604,0.213206,-0.105693,0.140111,0.051984,0.227320,...,-0.078821,0.090479,0.181083,0.208519,-0.012982,-0.028508,-0.295906,0.008079,0.231796,-0.013389
4097,-0.220776,-0.076696,-0.250694,0.207355,-0.016029,0.036269,-0.099627,0.284319,-0.117931,0.067102,...,0.027536,-0.241473,0.003752,0.085492,0.354816,0.022406,-0.121308,0.020005,-0.071856,-0.121679


# Visualisation

In [8]:
tsne = TSNE(n_components=2)
df_embeddings_transformed = pd.DataFrame(tsne.fit_transform(df_embeddings))
df_embeddings_transformed.index = vocab.get_itos()
df_embeddings_transformed.rename(columns={0: 'x', 1: 'y'}, inplace=True)

In [20]:
df_embeddings_plot = df_embeddings_transformed.sample(50, replace=False)

fig = px.scatter(df_embeddings_plot, x='x', y='y', text=df_embeddings_plot.index)
fig.update_layout(
    height=800,
    title_text='Visualisation of words using t-SNE algorithm'
)
fig.update_traces(textposition='top center')
fig.show()

# Find similar words

In [15]:
def get_embedding(df_embeddings, vocab, word: str) -> str:
    return df_embeddings.loc[vocab.get_stoi()[word]]

In [16]:
def top_similar_words(df_embeddings: pd.DataFrame, word: str | list[float], vocab = None, n: int = 10) -> dict[str, float]:
    if isinstance(word, str):
        word_id = vocab[word]
        word_vec = df_embeddings.loc[word_id]
    else:
        word_id = -1
        word_vec = word
    distances = df_embeddings.apply(lambda embed: np.matmul(word_vec, embed), axis=1)
    if isinstance(word, str):
        distances = distances.drop(word_id)
    similar_distances = distances.sort_values(ascending=False).iloc[:n]
    return {vocab.get_itos()[id]: similar_distances[id] for id in similar_distances.index}

In [17]:
top_similar_words(df_embeddings, 'know', vocab, n=10)

{'wasn': 1.0163375,
 'doesn': 0.9734401,
 'think': 0.9618387,
 'got': 0.95153177,
 'want': 0.9223029,
 'feel': 0.9200848,
 'do': 0.8937379,
 'didn': 0.8842744,
 'don': 0.88322556,
 'really': 0.8814473}

# Vector equations

In [18]:
embed_1 = get_embedding(df_embeddings, vocab, 'man')
embed_2 = get_embedding(df_embeddings, vocab, 'woman')
embed_3 = get_embedding(df_embeddings, vocab, 'king')
embed_4 = embed_3 - embed_1 + embed_2

In [19]:
top_n = top_similar_words(df_embeddings, embed_4, vocab, n=10)
top_n

{'king': 1.0568279,
 'st': 0.95331234,
 'castle': 0.9513191,
 'bishop': 0.94327086,
 'aftermath': 0.9377484,
 'mary': 0.9267597,
 'image': 0.9262364,
 'reign': 0.909544,
 'church': 0.90753704,
 'regime': 0.9028273}