In [None]:
import math
from collections import defaultdict
from argparse import Namespace
from pathlib import Path

import matplotlib.pylab as plt
from IPython.display import display

import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader

from utils import seed_everything, read_cs224_sentences, pca

seed_everything()

%load_ext autoreload
%autoreload 2

In [None]:
# Experiment configuration: data location, model/training hyper-parameters and
# the compute device. Everything a reader might want to tune lives here.
config = Namespace(
    data_path=Path("../code/utils/datasets/stanfordSentimentTreebank/"),
    num_context=3,  # context words taken on each side of the centre word
    batch_size=64,
    embedding_dim=10,
    num_epochs=10,
    lr=0.01,
    device=torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu"),
)

In [None]:
class NaiveDataset(Dataset):
    """Skip-gram (centre word, context) samples built from the CS224 sentence file.

    Each non-empty sentence is padded with ``num_context`` copies of
    ``unk_token`` on both sides. Every window of ``2 * num_context + 1`` token
    indices then yields one sample: the middle index is the centre word and the
    remaining ``2 * num_context`` indices are its context. With
    ``flatten=True`` each sample is expanded into ``2 * num_context``
    (centre word, single context word) pairs.

    Args:
        path: File handed to ``read_cs224_sentences``.
        unk_token: Token used both for padding and for out-of-vocabulary words.
            It must be a key of the vocabulary returned by
            ``read_cs224_sentences``; ``sentence_to_index`` raises ``KeyError``
            otherwise.
        num_context: Number of context words on each side of the centre word.
        flatten: If True, emit one sample per (centre, context word) pair.
        nrows: Forwarded to ``read_cs224_sentences`` (default keeps the
            previous hard-coded 5000).
    """

    def __init__(self, path, unk_token="UNK", num_context=2, flatten=False, nrows=5000):
        self._unk_token = unk_token
        self._num_context = num_context
        self._idx2word = None  # built lazily by the ``idx2word`` property
        self.df, _, self._token_dict, self._token_freq = read_cs224_sentences(path, unk_token, nrows=nrows)

        padding = [unk_token] * num_context
        window = 2 * num_context + 1
        # Empty sentences are skipped: padded they are shorter than one window,
        # which would make ``sliding_window_view`` raise.
        # assumes each ``sentence`` cell is a list of tokens (it is concatenated with lists)
        windows = [
            np.lib.stride_tricks.sliding_window_view(self.sentence_to_index(padding + s + padding), window)
            for s in self.df.sentence
            if len(s) > 0
        ]
        if windows:
            context = np.concatenate(windows)
        else:
            # No usable sentences: keep the (n, window) layout so slicing below works.
            context = np.empty((0, window), dtype=np.int64)

        center_word = context[:, num_context]
        # Drop the centre column; left and right context keep their order.
        context = np.delete(context, num_context, axis=1)
        if flatten:
            # Pair each centre word with each of its context words (row-major order).
            center_word = np.repeat(center_word, context.shape[1])
            context = context.flatten()

        self.data = list(zip(center_word, context))

    def sentence_to_index(self, sentence):
        """Map tokens to vocabulary indices; unknown tokens map to ``unk_token``'s index."""
        mapping = self.word2idx
        tokens = list(sentence)
        if not tokens:
            return []
        unk_idx = mapping[self._unk_token]
        return [mapping.get(w, unk_idx) for w in tokens]

    @property
    def idx2word(self):
        """List whose position ``i`` holds the word with index ``i`` (inverse of ``word2idx``).

        Ordered by index value instead of dict insertion order (assumes indices
        are 0..V-1). Built once and cached, because callers index it inside loops.
        """
        if getattr(self, "_idx2word", None) is None:
            self._idx2word = sorted(self._token_dict, key=self._token_dict.__getitem__)
        return self._idx2word

    @property
    def word2idx(self):
        """Vocabulary mapping word -> index, as returned by ``read_cs224_sentences``."""
        return self._token_dict

    @property
    def word_freq(self):
        """Token frequencies from ``read_cs224_sentences`` (callers use ``.most_common``)."""
        return self._token_freq

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(centre_word, context)`` pair at ``idx``."""
        return self.data[idx]


class SkipGramMode(torch.nn.Module):
    """Skip-gram model: centre-word indices -> unnormalised scores over the vocabulary.

    An embedding lookup followed by a linear projection back to ``vocab_size``
    scores. The class name keeps its original spelling so existing callers work.
    """

    def __init__(self, vocab_size, embedding_dim):
        super(SkipGramMode, self).__init__()
        # Creation order (embedding, then linear) is kept so seeded initialisation is unchanged.
        self._embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        self._linear = torch.nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        """Return scores of shape ``x.shape + (vocab_size,)`` for index tensor ``x``."""
        embedded = self._embedding(x)
        return self._linear(embedded)

    @property
    def embeddings(self):
        """Input embedding matrix as a detached CPU NumPy array (vocab_size, embedding_dim)."""
        weight = self._embedding.weight
        return weight.detach().cpu().numpy()


def train(train_loader, model, criterion, optimizer):
    """Run one training epoch over ``train_loader``.

    Args:
        train_loader: Iterable yielding ``(inputs, targets)`` tensor batches.
            Each batch is moved to the device of the model's first parameter,
            so the loop also works after ``model.to(device)``.
        model: Module to optimise; put into ``train()`` mode.
        criterion: Loss callable ``criterion(output, targets)``. Assumed to
            return the *mean* loss over the batch (e.g. the default
            ``reduction="mean"`` of ``CrossEntropyLoss``). It is re-weighted
            by the batch size so the epoch figure is a per-sample mean.
        optimizer: Optimizer over ``model``'s parameters.

    Returns:
        ``(model, mean_loss)``: the same model object and the sample-weighted
        mean loss of the epoch, or ``nan`` if the loader yielded no samples.
    """
    model.train()
    first_param = next(model.parameters(), None)
    device = None if first_param is None else first_param.device

    loss_sum = 0.0
    num_samples = 0
    for inputs, targets in train_loader:
        if device is not None:
            inputs, targets = inputs.to(device), targets.to(device)
        output = model(inputs)
        loss = criterion(output, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # ``loss`` is a batch mean: scale it back to a batch sum so the final
        # division averages over samples rather than mixing batches and samples.
        batch_size = len(inputs)
        loss_sum += loss.item() * batch_size
        num_samples += batch_size

    if num_samples == 0:
        return model, float("nan")
    return model, loss_sum / num_samples


def valid(valid_loader, model, criterion):
    """Evaluate ``model`` on ``valid_loader`` without gradient tracking.

    Args:
        valid_loader: Iterable yielding ``(inputs, targets)`` tensor batches.
            Each batch is moved to the device of the model's first parameter.
        model: Module to evaluate; put into ``eval()`` mode.
        criterion: Loss callable ``criterion(output, targets)``. Assumed to
            return the *mean* loss over the batch (e.g. the default
            ``reduction="mean"`` of ``CrossEntropyLoss``). It is re-weighted
            by the batch size so the result is a per-sample mean.

    Returns:
        The sample-weighted mean loss, or ``nan`` if the loader yielded no samples.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    device = None if first_param is None else first_param.device

    loss_sum = 0.0
    num_samples = 0
    with torch.no_grad():
        for inputs, targets in valid_loader:
            if device is not None:
                inputs, targets = inputs.to(device), targets.to(device)
            output = model(inputs)
            loss = criterion(output, targets)
            # Undo the batch mean so batches of different sizes weigh correctly.
            batch_size = len(inputs)
            loss_sum += loss.item() * batch_size
            num_samples += batch_size

    if num_samples == 0:
        return float("nan")
    return loss_sum / num_samples

In [None]:
# Small non-flattened demo dataset: each sample is a centre word with its whole
# context window, so one batch can be decoded side by side in the next cell.
# The window size comes from the config rather than a duplicated literal.
tmp_dataset = NaiveDataset(path=config.data_path / "datasetSentences.txt", num_context=config.num_context)
tmp_loader = DataLoader(tmp_dataset, batch_size=4, shuffle=True)

print(f"The total number of unique tokens is {len(tmp_dataset.word_freq)}")
print(f"The total number of dataset is {len(tmp_dataset)}")

display(tmp_dataset.df.head())
print("The top 10 most common words are:")
display(tmp_dataset.word_freq.most_common(10))

In [None]:
# Draw a single batch from the demo loader and decode the indices back to words.
inputs, targets = next(iter(tmp_loader))

print(f"The inputs shape is {inputs.shape}")
print(f"The targets shape is {targets.shape}")

print(f"The center and context words are:")
vocab = tmp_dataset.idx2word  # fetch the index -> word list once, not per token
for center_idx, context_idxs in zip(inputs.tolist(), targets.tolist()):
    print(vocab[center_idx], [vocab[i] for i in context_idxs])

In [None]:
# Training set: flatten=True turns every (centre, window) sample into
# 2 * num_context (centre, single context word) pairs, i.e. one class-index
# target per input as expected by the CrossEntropyLoss used for training.
dataset = NaiveDataset(path=config.data_path / "datasetSentences.txt", num_context=config.num_context, flatten=True)
data_loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
inputs, targets = next(iter(data_loader))

# NOTE(review): config.device is not applied here; the model stays on the default device.
model = SkipGramMode(vocab_size=len(dataset.word2idx), embedding_dim=config.embedding_dim)
# Shape sanity check only, so no autograd graph is needed for this forward pass.
with torch.no_grad():
    output = model(inputs)
print(f"The inputs shape is {inputs.shape}")
print(f"The targets shape is {targets.shape}")
print(f"The output shape is {output.shape}")

In [None]:
# Train for exactly ``config.num_epochs`` epochs, recording the raw epoch loss
# and an exponential moving average of it.
num_digits = len(str(config.num_epochs))  # width of zero-padded epoch numbers; also safe for 0 epochs
smoothing = 0.95  # weight kept from the previous moving average

optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
criterion = torch.nn.CrossEntropyLoss()

moving_loss = None

losses = defaultdict(list)
for epoch in range(config.num_epochs):
    model, loss = train(train_loader=data_loader, model=model, criterion=criterion, optimizer=optimizer)
    if moving_loss is None:
        moving_loss = loss
    else:
        moving_loss = smoothing * moving_loss + (1 - smoothing) * loss
    losses["loss"].append(loss)
    losses["moving_loss"].append(moving_loss)

    print(f"Epoch {epoch + 1:0{num_digits}d}/{config.num_epochs} Loss: {loss:.4f} Moving loss: {moving_loss:.4f}")

In [None]:
# One panel per recorded series (raw and smoothed loss), sharing the epoch axis.
if not losses:
    print("No losses recorded; run the training cell first.")
else:
    # squeeze=False keeps a 2-D array of Axes even when only one series exists.
    _, axs = plt.subplots(nrows=len(losses), figsize=(10, 10), sharex=True, squeeze=False)
    for ax, (name, values) in zip(axs[:, 0], losses.items()):
        ax.plot(values)
        ax.set_title(name)
        ax.set_ylabel("loss")
        ax.grid()
    axs[-1, 0].set_xlabel("epoch index")
    plt.tight_layout()
    plt.show()

In [None]:
# Probe words for the embedding plot: sentiment-laden words, gendered / royalty
# pairs, and a few weather and drink nouns.
words = """
    great cool brilliant wonderful well amazing worth sweet enjoyable
    boring bad dumb annoying
    female male queen king man woman
    rain snow hail
    coffee tea
""".split()

# Only probe words present in the vocabulary can be looked up; report the rest.
plot_words = [w for w in words if w in dataset.word2idx]
missing_words = [w for w in words if w not in dataset.word2idx]
if missing_words:
    print(f"Skipping out-of-vocabulary words: {missing_words}")

if len(plot_words) < 2:
    # A 2-component projection of fewer than two points is degenerate.
    print("Not enough in-vocabulary words to plot.")
else:
    word_indices = [dataset.word2idx[w] for w in plot_words]
    # assumes pca(X, k) returns an (n_words, k) projection — TODO confirm in utils
    result = pca(model.embeddings[word_indices], 2)
    # Scale each projected point to unit length so only its direction is shown;
    # an exactly-zero projection is left at the origin instead of dividing by zero.
    norms = np.linalg.norm(result, axis=1, keepdims=True)
    result = result / np.where(norms == 0, 1, norms)

    _, ax = plt.subplots(figsize=(10, 8))
    ax.scatter(result[:, 0], result[:, 1])
    for i, word in enumerate(plot_words):
        ax.annotate(word, xy=(result[i, 0], result[i, 1]))
    ax.set_title("2-D PCA projection of word embeddings (unit-normalised)")
    ax.set_xlabel("component 1")
    ax.set_ylabel("component 2")
    plt.show()