In [51]:
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtext.datasets import WikiText2
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data import to_map_style_dataset
from torchtext.data.utils import get_tokenizer
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import LambdaLR
import numpy as np

MIN_FREQUENCY_TOKEN = 50
BATCH_SIZE = 16
N_SKIPGRAM = 2
MAX_CONTEXT_WINDOW = 256
DEVICE = 'mps' if torch.has_mps else 'cpu'
train_iter = to_map_style_dataset(WikiText2(split='train'))
val_iter = to_map_style_dataset(WikiText2(split='valid'))
tokenizer = get_tokenizer('basic_english', language='en')

def build_vocab(data_iter: Dataset):
    vocab = build_vocab_from_iterator(map(tokenizer, data_iter), specials=['<unk>'], min_freq=MIN_FREQUENCY_TOKEN)
    vocab.set_default_index(vocab['<unk>'])
    return vocab

vocab = build_vocab(train_iter)

def collate_skipgram(paragraphs: List[str]):
    xb, yb = [], []
    for paragraph in paragraphs:
        text_indices = vocab(tokenizer(paragraph))
        window = N_SKIPGRAM * 2 + 1
        if len(text_indices) < window:
            continue
        text_indices = text_indices[:MAX_CONTEXT_WINDOW]
        for i in range(len(text_indices) - 2 * N_SKIPGRAM):
            for j in range(1, N_SKIPGRAM+1):
                yb.append(text_indices[i+N_SKIPGRAM+j])
                yb.append(text_indices[i+N_SKIPGRAM-j])
                xb.append(text_indices[i+N_SKIPGRAM])
                xb.append(text_indices[i+N_SKIPGRAM])
    
    return torch.tensor(xb, dtype=torch.long).to(DEVICE), torch.tensor(yb, dtype=torch.long).to(DEVICE)

train_data = DataLoader(
    train_iter,
    BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_skipgram
)
val_data = DataLoader(
    val_iter,
    BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_skipgram
)





In [23]:
a, b = collate_skipgram(["this movie is bad by a lot"])
for x, y in zip(a, b):
    x_word = vocab.lookup_token(x)
    y_word = vocab.lookup_token(y)
    print(f"{x_word} --> {y_word}")

is --> bad
is --> movie
is --> by
is --> this
bad --> by
bad --> is
bad --> a
bad --> movie
by --> a
by --> bad
by --> lot
by --> is


In [52]:
max(a.shape[0] for a, _ in train_data)

7640

In [41]:
VOCAB_SIZE = len(vocab)
N_EMBED = 300
class Skipgram(nn.Module):
    
    def __init__(self) -> None:
        super().__init__()
        self.embeddings = nn.Embedding(VOCAB_SIZE, N_EMBED, max_norm=1.0)
        self.ln = nn.Linear(N_EMBED, VOCAB_SIZE)

    def forward(self, x: torch.Tensor, target: torch.Tensor = None):
        # x = (B)
        x = self.embeddings(x) # (B, N_EMBED)
        logits = self.ln(x) # (B, VOCAB_SIZE)
        loss = None
        if target is not None:
            loss = F.cross_entropy(logits, target)
        
        return logits, loss

m = Skipgram().to(DEVICE)
for a, b in train_data:
    m(a)
    break

In [54]:
def train(model: nn.Module, optimizer: torch.optim.Optimizer, eval_iter: int, dataloader: DataLoader):
    losses = []
    for i, (xb, yb) in enumerate(dataloader, 1):
        # forward
        _, loss = model(xb, yb)
        losses.append(loss.item())
        if i % eval_iter == 0:
            loss_mean = np.mean(losses)
            print(f"iteration: {i}, training_loss: {loss_mean:.4f}")
            losses = []

        # backward
        optimizer.zero_grad(set_to_none=True)
        loss.backward()

        # update
        optimizer.step()
    
    return model

@torch.no_grad()
def validate(model: nn.Module, dataloader: DataLoader):
    losses = []
    model.eval()
    for x, y in dataloader:
        _, loss = model(x, y)
        losses.append(loss.item())
    loss_mean = np.mean(losses)
    print(f"validation_loss: {loss_mean}")
    model.train()

def training(model: nn.Module, eval_iter: int, epochs: int, train_data: DataLoader, val_data: DataLoader):
    optimzr = torch.optim.Adam(model.parameters(), lr=0.025)
    lr_scheduler = LambdaLR(optimzr, lambda i: (epochs - i) / epochs, verbose=True)
    validate(model, val_data)
    for epoch in range(epochs):
        print(f"================ EPOCHS {epoch} ================")
        train(model, optimzr, eval_iter, train_data)
        validate(model, val_data)
        lr_scheduler.step()
    
    return model

In [55]:
m = Skipgram().to(DEVICE)
m = training(m, 100, 5, train_data, val_data)
torch.save(m, 'params/skipgram.pt')

Adjusting learning rate of group 0 to 2.5000e-02.
validation_loss: 8.306036502756971
iteration: 100, training_loss: 6.0311
iteration: 200, training_loss: 5.6738
iteration: 300, training_loss: 5.5879
iteration: 400, training_loss: 5.5181
iteration: 500, training_loss: 5.4631
iteration: 600, training_loss: 5.4610
iteration: 700, training_loss: 5.4076
iteration: 800, training_loss: 5.3964
iteration: 900, training_loss: 5.4159
iteration: 1000, training_loss: 5.4401
iteration: 1100, training_loss: 5.4409
iteration: 1200, training_loss: 5.4002
iteration: 1300, training_loss: 5.4052
iteration: 1400, training_loss: 5.4110
iteration: 1500, training_loss: 5.4051
iteration: 1600, training_loss: 5.4050
iteration: 1700, training_loss: 5.4272
iteration: 1800, training_loss: 5.3803
iteration: 1900, training_loss: 5.4039
iteration: 2000, training_loss: 5.3833
iteration: 2100, training_loss: 5.4366
iteration: 2200, training_loss: 5.4295
validation_loss: 5.373748115783042
Adjusting learning rate of grou

In [56]:
import pandas as pd
from sklearn.manifold import TSNE
import plotly.graph_objects as go

model: torch.Tensor = torch.load('params/skipgram.pt')
embeddings = list(model.parameters())[0]
# get embeddings
embeddings_df = pd.DataFrame(embeddings.cpu().detach().numpy())

# t-SNE transform
tsne = TSNE(n_components=2)
embeddings_df_trans = tsne.fit_transform(embeddings_df)
embeddings_df_trans = pd.DataFrame(embeddings_df_trans)

# get token order
embeddings_df_trans.index = vocab.get_itos()

# if token is a number
is_numeric = embeddings_df_trans.index.str.isnumeric()
color = np.where(is_numeric, "green", "black")
fig = go.Figure()

fig.add_trace(
    go.Scatter(
        x=embeddings_df_trans[0],
        y=embeddings_df_trans[1],
        mode="text",
        text=embeddings_df_trans.index,
        textposition="middle center",
        textfont=dict(color=color),
    )
)
fig.write_html("./word2vec_visualization_skipgram.html")
fig.show()