# Import

In [2]:
import pandas as pd
import numpy as np
import plotly.express as px
from sklearn.manifold import TSNE
import torch

In [3]:
from train import train
from word2vec.utils import load_yaml

# Train model

In [4]:
# Load the training/model settings from the YAML config file.
config = load_yaml('config.yaml')
# Trailing expression: show the loaded settings as the cell output.
config

{'model': 'CBOW',
 'model_dir': 'artifacts',
 'data_dir': 'data',
 'epochs': 10,
 'batch_size': 64,
 'learning_rate': 0.05,
 'checkpoint_frequency': 1,
 'steps_train': 60,
 'steps_val': 60,
 'n_words': 4,
 'min_word_frequency': 50,
 'max_sequence_length': 256,
 'embed_size': 50,
 'embed_max_norm': 1}

In [5]:
# Run training with the loaded settings; the call prints the vocabulary size
# and the train/validation loss for each epoch.
# NOTE(review): the Demo cell reads model.pt and vocab.pt from
# config['model_dir'] — presumably written by this call; confirm in train.py.
train(config)

Vocabulary size: 4099
Adjusting learning rate of group 0 to 5.0000e-02.
Epoch 1 / 10. Train loss: 5.8727, Validation loss: 3.8772
Adjusting learning rate of group 0 to 4.5000e-02.
Epoch 2 / 10. Train loss: 5.4055, Validation loss: 3.8699
Adjusting learning rate of group 0 to 4.0000e-02.
Epoch 3 / 10. Train loss: 5.2783, Validation loss: 3.9240
Adjusting learning rate of group 0 to 3.5000e-02.
Epoch 4 / 10. Train loss: 5.2126, Validation loss: 3.9753
Adjusting learning rate of group 0 to 3.0000e-02.
Epoch 5 / 10. Train loss: 5.1678, Validation loss: 3.9251
Adjusting learning rate of group 0 to 2.5000e-02.
Epoch 6 / 10. Train loss: 5.1175, Validation loss: 3.9390
Adjusting learning rate of group 0 to 2.0000e-02.
Epoch 7 / 10. Train loss: 5.0875, Validation loss: 3.9312
Adjusting learning rate of group 0 to 1.5000e-02.
Epoch 8 / 10. Train loss: 5.0409, Validation loss: 3.9930
Adjusting learning rate of group 0 to 1.0000e-02.
Epoch 9 / 10. Train loss: 5.0002, Validation loss: 3.9989
Adjust

# Demo

In [6]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): torch.load unpickles arbitrary Python objects, so only load
# artifacts produced by a trusted training run. Recent PyTorch releases may
# default to weights_only=True, which can reject whole-model pickles — confirm
# against the installed version.
model = torch.load(f"{config['model_dir']}/model.pt", map_location=device)
# The vocabulary is loaded without map_location.
vocab = torch.load(f"{config['model_dir']}/vocab.pt")

In [7]:
# Display the loaded module's layer structure.
model

CBOW(
  (embeddings): Embedding(4099, 50, max_norm=1)
  (linear): Linear(in_features=50, out_features=4099, bias=True)
)

In [8]:
# The first registered parameter of the model is the embedding matrix
# (one row per vocabulary entry).
# The model may have been mapped to CUDA when it was loaded, and
# Tensor.numpy() only works on CPU tensors, so move it to the CPU first.
embeddings = next(model.parameters()).detach().cpu().numpy()
# Row label == vocabulary index, column label == embedding dimension.
df_embeddings = pd.DataFrame(embeddings)
df_embeddings

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,40,41,42,43,44,45,46,47,48,49
0,0.000579,0.074418,0.147624,0.065660,-0.049540,0.116364,0.211759,0.015510,-0.070899,-0.069466,...,-0.047297,-0.154251,0.113365,0.097166,-0.288124,-0.102232,-0.097474,-0.182347,0.180586,0.070746
1,0.096340,0.090729,-0.053377,0.066608,-0.283814,0.113487,-0.086393,-0.285818,-0.141846,0.046376,...,-0.017125,-0.126061,-0.224858,0.149425,-0.049315,-0.046874,0.082443,0.065943,0.006308,-0.185886
2,-0.040920,-0.020283,0.006201,0.101501,0.333140,-0.056435,0.103615,0.190687,-0.068326,0.069308,...,-0.079025,-0.256366,-0.206165,0.086232,-0.120812,-0.144534,0.088946,-0.221591,-0.317499,0.092255
3,0.226830,0.074295,0.016519,0.150539,0.274640,0.034020,0.047728,0.269956,0.050331,0.225864,...,0.014992,-0.111618,0.114063,0.033514,0.077608,-0.121704,0.134801,-0.212910,-0.283753,-0.024760
4,0.075703,0.035119,0.113515,0.152397,0.249103,0.073610,0.157959,-0.030533,0.126269,0.227408,...,-0.148045,-0.005974,0.213000,-0.045665,0.126412,-0.049099,-0.146071,-0.032726,0.223849,0.142328
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4094,0.086245,0.136938,0.135125,0.104863,0.244871,-0.085348,0.285394,0.035199,-0.049950,0.122726,...,-0.202779,0.005533,0.100520,-0.022237,-0.004941,-0.064221,-0.281689,0.139486,-0.010553,0.074286
4095,0.019155,0.088238,0.074970,-0.082796,-0.156666,0.126047,0.133371,0.026874,-0.056753,-0.118710,...,-0.019874,-0.083749,0.040024,0.102385,-0.085964,-0.050792,-0.101137,-0.266856,0.389514,-0.032107
4096,0.122592,0.145686,0.046430,-0.154911,-0.161012,-0.002371,0.069378,0.133784,-0.091619,-0.122243,...,0.045587,0.049309,-0.182077,0.093749,-0.047278,-0.012451,0.108881,-0.042897,-0.094294,0.237073
4097,-0.001457,0.027580,0.111501,0.030600,-0.051120,0.036001,0.132384,0.058927,-0.143698,-0.018530,...,-0.066456,-0.158292,0.116807,0.083249,-0.136541,-0.074800,-0.113990,-0.218624,0.247066,0.109642


# Visualisation

In [9]:
# t-SNE is stochastic: pin the seed so the 2-D layout is the same on every
# Restart & Run All.
tsne = TSNE(n_components=2, random_state=42)
# Build the frame with its final labels in one step (rows = vocabulary tokens,
# columns = plot coordinates) instead of mutating it in place afterwards.
df_embeddings_transformed = pd.DataFrame(
    tsne.fit_transform(df_embeddings),
    index=vocab.get_itos(),
    columns=['x', 'y'],
)

In [22]:
# Plot only a small random subset of words so the labels stay readable.
# The sample is seeded so the figure is reproducible, and capped at the number
# of available rows so a small vocabulary does not make sample() raise.
n_plot_words = min(50, len(df_embeddings_transformed))
df_embeddings_plot = df_embeddings_transformed.sample(
    n_plot_words, replace=False, random_state=42
)

fig = px.scatter(df_embeddings_plot, x='x', y='y', text=df_embeddings_plot.index)
fig.update_layout(
    height=800,
    title_text='Visualisation of words using t-SNE algorithm'
)
# Put each word label above its marker.
fig.update_traces(textposition='top center')
fig.show()

# Find similar words

In [17]:
def get_embedding(df_embeddings, vocab, word: str) -> pd.Series:
    """Return the embedding row of ``word``.

    Args:
        df_embeddings: DataFrame holding one embedding per row, with rows
            labelled by vocabulary index.
        vocab: Vocabulary object exposing ``get_stoi()``, a token -> index
            mapping.
        word: Token to look up.

    Returns:
        The row of ``df_embeddings`` whose label is the index of ``word``.

    Raises:
        KeyError: If ``word`` is missing from the mapping returned by
            ``vocab.get_stoi()``, or its index is not a row label.
    """
    word_id = vocab.get_stoi()[word]
    return df_embeddings.loc[word_id]

In [18]:
def top_similar_words(df_embeddings: pd.DataFrame, word: str | list[float], vocab = None, n: int = 10) -> dict[str, float]:
    if isinstance(word, str):
        word_id = vocab[word]
        word_vec = df_embeddings.loc[word_id]
    else:
        word_id = -1
        word_vec = word
    distances = df_embeddings.apply(lambda embed: np.matmul(word_vec, embed), axis=1)
    if isinstance(word, str):
        distances = distances.drop(word_id)
    similar_distances = distances.sort_values(ascending=False).iloc[:n]
    return {vocab.get_itos()[id]: similar_distances[id] for id in similar_distances.index}

In [19]:
# Ten highest-scoring neighbours of 'know' (dot-product score; the query word
# itself is excluded by top_similar_words).
top_similar_words(df_embeddings, 'know', vocab, n=10)

{'really': 0.99585295,
 'want': 0.9918584,
 'didn': 0.9839044,
 'think': 0.9781511,
 'doesn': 0.96299285,
 'believe': 0.96082723,
 'don': 0.9576585,
 'wasn': 0.9570808,
 'you': 0.9407921,
 'mean': 0.9243557}

# Vector equations

In [20]:
# Look up the three operand embeddings in one pass: man, woman, king.
embed_1, embed_2, embed_3 = (
    get_embedding(df_embeddings, vocab, token)
    for token in ('man', 'woman', 'king')
)
# Word-analogy arithmetic: king - man + woman.
embed_4 = embed_3 - embed_1 + embed_2

In [21]:
# Rank the vocabulary against the composed vector. Because the query is a
# vector rather than a word, top_similar_words drops nothing, so 'king' itself
# can appear in the results.
# NOTE(review): scores are raw dot products; values above 1 suggest embed_4 is
# not unit-norm — normalise it if cosine similarity is intended.
top_n = top_similar_words(df_embeddings, embed_4, vocab, n=10)
top_n

{'opinion': 1.1277608,
 'king': 1.1260895,
 'atmosphere': 1.1182954,
 'popularity': 1.1152412,
 'mariana': 1.1134031,
 'trujillo': 1.1068636,
 'philosophy': 1.0924467,
 'description': 1.090578,
 'circumstances': 1.0867015,
 'nation': 1.0761068}