In [None]:
# project with a mock vocabulary to show how word2vec learns.

# Mock text data built from two word groups: king/queen/prince/princess and france/capital/paris. We will train a word2vec model to learn the relationships between these words.
import torch
import torch.nn as nn
import torch.optim as optim
import re
from collections import Counter
# Sample text data: a toy corpus. Each line uses words from one group only -
# royalty (king, queen, prince, princess) or geography (france, capital, paris).
text = """
king queen king prince princess king queen prince princess
king prince princess king queen prince princess
france capital paris france capital paris france capital paris
paris capital france paris capital france paris capital france
"""


In [None]:
# Whitespace-tokenise the corpus, then map each distinct word to an integer id
# (ids follow the sorted order of the words) and back again.
words = text.split()
counter = Counter(words)
vocab = sorted(counter)
index_to_word = dict(enumerate(vocab))
word_to_index = {word: index for index, word in index_to_word.items()}
# Two dimensions so the learned vectors can be scatter-plotted directly.
embedding_dim = 2


In [None]:
# Build skip-gram (target, context) index pairs: every word is paired with each
# neighbour at most `window_size` positions to its left or right.
window_size = 2
num_words = len(words)
training_data = []
for center, target_word in enumerate(words):
    left = max(0, center - window_size)
    right = min(num_words, center + window_size + 1)
    target_id = word_to_index[target_word]
    # Walk the window left-to-right, skipping the target position itself.
    training_data.extend(
        (target_id, word_to_index[words[pos]])
        for pos in range(left, right)
        if pos != center
    )
print(f"Num training pairs: {len(training_data)}")
print(f"Sample training pairs: {training_data[:5]}")

In [None]:
# Skip-gram with negative sampling.
# The model is two embedding tables with one row per vocabulary word:
# Wv holds the input (target word) vectors and Wu the output (context word) vectors.
# The score of a (target, context) pair is the dot product of the two vectors,
# and a sigmoid turns that score into a probability.
# The loss raises that probability for observed pairs and lowers it for
# `num_negatives` context words sampled at random from the vocabulary.

# Seed the global RNG so the random initialisation below and every random draw
# made afterwards are reproducible on "Restart Kernel -> Run All".
torch.manual_seed(0)

Wv = torch.randn(len(vocab), embedding_dim, requires_grad=True)
Wu = torch.randn(len(vocab), embedding_dim, requires_grad=True)

epochs = 100          # number of mini-batch update steps
learning_rate = 0.5   # step size of the manual SGD update
batch_size = 64       # (target, context) pairs drawn per step
num_negatives = 5     # negative context words sampled per pair

def draw_embeddings(Wv, index_to_word):
    """Scatter-plot the first two embedding dimensions, one labelled point per word.

    Args:
        Wv: tensor with one embedding row per vocabulary index; only columns
            0 and 1 are plotted.
        index_to_word: lookup from row index to the label drawn next to the point.
    """
    import matplotlib.pyplot as plt

    points = Wv.detach().numpy()
    fig, ax = plt.subplots(figsize=(8, 8))
    for row in range(len(index_to_word)):
        x, y = points[row, 0], points[row, 1]
        # One scatter call per word, so each point gets the next colour in the cycle.
        ax.scatter(x, y)
        # Nudge the label slightly so it does not sit on top of the marker.
        ax.text(x + 0.01, y + 0.01, index_to_word[row], fontsize=9)
    ax.set_title("Word Embeddings")
    ax.set_xlabel("Dimension 1")
    ax.set_ylabel("Dimension 2")
    ax.grid()
    plt.show()

# Snapshots of Wv taken during training; later cells plot and animate them.
embeddings = []

# Convert the (target, context) index pairs into one (num_pairs, 2) tensor up
# front, so each mini-batch is gathered by tensor indexing rather than by a
# Python loop over the sampled positions.
pair_tensor = torch.tensor(training_data, dtype=torch.long)
log_every = 2  # print the loss and store a snapshot every `log_every` steps

# NOTE: each "epoch" below is one SGD step on a randomly drawn mini-batch,
# not a full pass over training_data.
for epoch in range(epochs):
    idx = torch.randint(0, len(training_data), (batch_size,))
    batch = pair_tensor[idx]

    # collect v_c and u_c for the batch
    v_c = Wv[batch[:, 0]]  # (batch_size, embedding_dim) target vectors
    u_c = Wu[batch[:, 1]]  # (batch_size, embedding_dim) context vectors

    # positive examples: maximise log sigmoid(v_c . u_c).
    # logsigmoid is used instead of log(sigmoid(x)) because the latter
    # underflows to log(0) = -inf for strongly negative scores.
    pos_scores = torch.sum(v_c * u_c, dim=1)
    pos_loss = -nn.functional.logsigmoid(pos_scores).mean()

    # negative sampling: context words drawn uniformly from the vocabulary
    neg_indices = torch.randint(0, len(vocab), (batch_size, num_negatives))
    neg_u_c = Wu[neg_indices]  # (batch_size, num_negatives, embedding_dim)
    # Squeeze only the trailing singleton axis so the result stays
    # (batch_size, num_negatives) even when either of them is 1.
    neg_scores = torch.bmm(neg_u_c, v_c.unsqueeze(2)).squeeze(2)
    neg_loss = -nn.functional.logsigmoid(-neg_scores).mean()

    loss = pos_loss + neg_loss
    loss.backward()

    # Plain SGD step done by hand; gradients are reset afterwards because
    # backward() accumulates into .grad.
    with torch.no_grad():
        Wv -= learning_rate * Wv.grad
        Wu -= learning_rate * Wu.grad
        Wv.grad.zero_()
        Wu.grad.zero_()

    if (epoch + 1) % log_every == 0:
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item():.4f}")

        # clone() so later in-place updates of Wv do not change the snapshot
        embeddings.append(Wv.detach().clone())


In [None]:
# Plot the last snapshot recorded during training.
final_snapshot = embeddings[-1]
draw_embeddings(final_snapshot, index_to_word)

In [None]:
from IPython.display import HTML
import matplotlib.pyplot as plt
import matplotlib.animation as animation

fig, ax = plt.subplots(figsize=(8, 8))

# Fix the axis limits ONCE, from all snapshots, so the view doesn't rescale
# between frames. The limits are converted to plain floats and padded a little
# so points and labels on the boundary are not clipped.
all_emb = torch.stack(embeddings)
xmin, xmax = all_emb[:, :, 0].min().item(), all_emb[:, :, 0].max().item()
ymin, ymax = all_emb[:, :, 1].min().item(), all_emb[:, :, 1].max().item()
pad = 0.05 * max(xmax - xmin, ymax - ymin, 1e-6)

# Snapshots are stored only every few training steps, so a frame number is not
# a step number. Recover the spacing from the totals; this is exact when
# `epochs` is a multiple of the snapshot interval used in the training loop.
steps_per_frame = max(1, epochs // len(embeddings))

def update(frame):
    """Redraw the axes with the embedding snapshot for animation frame `frame`."""
    ax.clear()
    Wv_np = embeddings[frame].numpy()

    # ax.clear() resets the limits, so they are re-applied on every frame.
    ax.set_xlim(xmin - pad, xmax + pad)
    ax.set_ylim(ymin - pad, ymax + pad)

    ax.scatter(Wv_np[:, 0], Wv_np[:, 1])

    for i in range(len(index_to_word)):
        ax.text(Wv_np[i, 0], Wv_np[i, 1], index_to_word[i], fontsize=9)

    # Frame 0 is the first stored snapshot, i.e. step `steps_per_frame`.
    ax.set_title(f"Epoch {(frame + 1) * steps_per_frame}")

ani = animation.FuncAnimation(
    fig, update,
    frames=len(embeddings),
    interval=500,
    blit=False
)

# Render to HTML first, then close the figure so the notebook shows only the
# animation and not an additional static copy of the figure.
ani_html = ani.to_jshtml()
plt.close(fig)
HTML(ani_html)