In [1]:
import os
import requests
import zipfile

import random
import math
import itertools

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import lightning as L
import torchmetrics as tm

import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

L.seed_everything(42)

Global seed set to 42


42

## Downloading WikiText

In [2]:
# URL of the WikiText-2 zip archive fetched by download_and_extract.
DATA_URL = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip"
# Local directory the archive is saved into and extracted under.
DATA_DIR = ".data"
# Split names accepted by read_data; each maps to a wiki.<split>.tokens file.
SPLITS = ["train", "valid", "test"]

In [3]:
# Download the zip file from DATA_URL and extract it
def download_and_extract(url, data_dir, force=False):
    """Download the zip archive at ``url`` into ``data_dir`` and unpack it there.

    The download is skipped when the archive file already exists, and the
    extraction is skipped when ``<data_dir>/wikitext-2`` already exists, unless
    ``force`` is true.

    Args:
        url: Address of the zip archive; its last path component is used as
            the local file name.
        data_dir: Directory to store and extract the archive in (created if
            missing).
        force: Re-download and re-extract even if the files are present.

    Returns:
        ``data_dir`` as a string path.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test
    os.makedirs(data_dir, exist_ok=True)

    # Get filename from URL
    filename = url.split("/")[-1]
    filepath = os.path.join(data_dir, filename)

    # Download the zip file and save it to disk
    if not os.path.exists(filepath) or force:
        print(f"Downloading {url} to {filepath}")
        # Write to a side file first so an interrupted download is never
        # mistaken for a complete archive on the next run.
        partial_path = filepath + ".part"
        with requests.get(url, stream=True, timeout=60) as r:
            # Fail loudly instead of saving an HTML error page as the zip
            r.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=1024):
                    if chunk:
                        f.write(chunk)
        os.replace(partial_path, filepath)
    else:
        print(f"File {filepath} already downloaded")

    # Extract the zip file
    if not os.path.exists(os.path.join(data_dir, "wikitext-2")) or force:
        print(f"Extracting {filepath} to {data_dir}")
        with zipfile.ZipFile(filepath, "r") as f:
            f.extractall(data_dir)
    else:
        print(f"File {filepath} already extracted")

    return os.fspath(data_dir)

In [4]:
data_dir = download_and_extract(DATA_URL, DATA_DIR)

File .data/wikitext-2-v1.zip already downloaded
File .data/wikitext-2-v1.zip already extracted


In [5]:
def read_data(data_dir, split):
    """Return the full text of one WikiText-2 split.

    Args:
        data_dir: Directory that contains the extracted ``wikitext-2`` folder.
        split: One of the names listed in ``SPLITS``.

    Returns:
        The contents of ``wikitext-2/wiki.<split>.tokens`` as a single string.
    """
    assert split in SPLITS, f"split must be one of {SPLITS}"
    filepath = os.path.join(data_dir, f"wikitext-2/wiki.{split}.tokens")
    # The corpus contains non-ASCII text (e.g. "Senjō"), so decode explicitly
    # as UTF-8 instead of relying on the platform's default encoding.
    with open(filepath, "r", encoding="utf-8") as f:
        return f.read()

In [6]:
text = read_data(data_dir, "train")

In [7]:
text[:50]

' \n = Valkyria Chronicles III = \n \n Senjō no Valkyr'

## Preprocessing text into list of words

In [8]:
def split_words(text):
    """Tokenise ``text`` by splitting it on runs of whitespace."""
    tokens = text.split()
    return tokens

In [9]:
# Tokenise the raw training text. NOTE: this rebinds `text` from a str to a
# list of tokens, so the cell cannot be re-run without first re-running the
# read_data cell (split_words calls .split() on its argument).
text = split_words(text)
# Drop the "<unk>" placeholder tokens entirely.
text = [token for token in text if token != "<unk>"]

In [10]:
text[:10]

['=',
 'Valkyria',
 'Chronicles',
 'III',
 '=',
 'Senjō',
 'no',
 'Valkyria',
 '3',
 ':']

In [11]:
# Keep only the first quarter of the tokens. NOTE: `text` is rebound in place,
# so every re-run of this cell shrinks the corpus by another factor of four.
text = text[:len(text) // 4]

## Creating vocabulary

In [12]:
class Vocab:
    """Two-way mapping between words and consecutive integer ids."""

    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        # Next id to hand out; equals the number of words registered so far.
        self.idx = 0

    def build(self, text):
        """Register every unseen word of ``text``, in order of first appearance."""
        for token in text:
            if token in self.word2idx:
                continue
            new_id = self.idx
            self.word2idx[token] = new_id
            self.idx2word[new_id] = token
            self.idx = new_id + 1

    def encode(self, text):
        """Map words to ids; words that are not in the vocabulary are skipped."""
        ids = []
        for token in text:
            if token in self.word2idx:
                ids.append(self.word2idx[token])
        return ids

    def decode(self, tokens):
        """Map ids back to words; ids that are not in the vocabulary are skipped."""
        words = []
        for token_id in tokens:
            if token_id in self.idx2word:
                words.append(self.idx2word[token_id])
        return words

In [13]:
vocab = Vocab()
vocab.build(text)

In [14]:
assert text[:10] == vocab.decode(vocab.encode(text[:10]))

## Sub-Sampling

In [16]:
def get_word_counts(text):
    """Count occurrences of each item in ``text``.

    Returns a plain dict keyed by item, in order of first appearance.
    """
    counts = {}
    for token in text:
        counts[token] = counts.get(token, 0) + 1
    return counts

In [17]:
word_counts = get_word_counts(text)

In [18]:
def keep_word(count, total, threshold=1e-3):
    """Randomly decide whether one occurrence of a word survives sub-sampling.

    Args:
        count: Number of occurrences of the word in the corpus.
        total: Total number of tokens in the corpus.
        threshold: Sub-sampling threshold; the discard probability is
            ``1 - sqrt(threshold * total / count)``.

    Returns:
        True if the occurrence should be kept.
    """
    ratio = threshold * total / count
    discard_prob = 1 - math.sqrt(ratio)
    draw = random.random()
    return draw > discard_prob

In [19]:
len(text), sorted(word_counts.items(), key=lambda x: x[1], reverse=True)[:10]

(499321,
 [('the', 28049),
  (',', 25449),
  ('.', 18228),
  ('of', 14406),
  ('and', 12597),
  ('in', 9946),
  ('to', 9676),
  ('a', 8406),
  ('=', 7358),
  ('"', 7111)])

In [20]:
# Sub-sample the corpus: each occurrence is kept or dropped at random based on
# the word's count in `word_counts` and the pre-filter corpus length.
# NOTE: `text` is rebound in place, so re-running this cell sub-samples the
# already-reduced corpus again.
text = [token for token in text if keep_word(word_counts[token], len(text))]

In [21]:
len(text), sorted(get_word_counts(text).items(), key=lambda x: x[1], reverse=True)[:10]

(341238,
 [('the', 3862),
  (',', 3602),
  ('.', 3003),
  ('of', 2621),
  ('and', 2568),
  ('in', 2167),
  ('to', 2152),
  ('a', 2070),
  ('=', 1901),
  ('"', 1824)])

## Sequence Iterator

In [22]:
def get_centre_contexts(text, window_size=5):
    """Pair each centre token with the tokens in its surrounding window.

    Only positions that have a full window on both sides are used, so the
    first and last ``window_size`` tokens never appear as centres.

    Args:
        text: Sequence of tokens (must support slicing and ``+``).
        window_size: Number of tokens taken on each side of the centre.

    Returns:
        List of ``(centre, context)`` tuples, where ``context`` is the
        ``window_size`` tokens before the centre followed by the
        ``window_size`` tokens after it.
    """
    first = window_size
    last = len(text) - window_size
    return [
        (text[pos], text[pos - window_size:pos] + text[pos + 1:pos + window_size + 1])
        for pos in range(first, last)
    ]

In [23]:
get_centre_contexts(vocab.encode(text[30:50]), 2)[:5]

[(37, [34, 36, 38, 39]),
 (38, [36, 37, 39, 40]),
 (39, [37, 38, 40, 42]),
 (40, [38, 39, 42, 43]),
 (42, [39, 40, 43, 23])]

## Negative Sampling

In [24]:
class Sampler:
    """Draws negative samples with probability proportional to count ** 0.75.

    Drawing is done in bulk: ``cache_size`` samples are generated at once with
    ``random.choices`` and then consumed one by one, refilling when exhausted.

    Args:
        corpus: Sequence of (hashable) tokens the distribution is built from.
        cache_size: Number of samples generated per refill; must be positive.

    Raises:
        ValueError: If ``corpus`` is empty or ``cache_size`` is not positive.
    """

    def __init__(self, corpus, cache_size=100000):
        self.word_counts = get_word_counts(corpus)
        self.words = list(self.word_counts.keys())
        if not self.words:
            raise ValueError("cannot build a Sampler from an empty corpus")
        if cache_size < 1:
            raise ValueError("cache_size must be a positive integer")
        self.sampling_weights = [self.word_counts[word] ** 0.75 for word in self.words]
        self.cumulative_weights = list(itertools.accumulate(self.sampling_weights))
        self.cache_size = cache_size
        self.cache = self._draw_cache()
        self.cache_idx = 0

    def _draw_cache(self):
        """Generate a fresh batch of ``cache_size`` weighted samples."""
        return random.choices(
            self.words, cum_weights=self.cumulative_weights, k=self.cache_size
        )

    def sample_negative(self, contexts, k=5):
        """Return ``k`` sampled words, none of which occurs in ``contexts``.

        Args:
            contexts: Iterable of hashable tokens that must not be returned.
            k: Number of negatives to draw.

        Returns:
            List of ``k`` tokens (repeats are possible).

        Raises:
            ValueError: If ``k > 0`` and every word known to the sampler is
                excluded, in which case no negative could ever be drawn.
        """
        excluded = set(contexts)
        # Rejection sampling below would never terminate if nothing is
        # eligible. The length comparison keeps this check cheap in the
        # common case of a small exclusion list.
        if (
            k > 0
            and len(excluded) >= len(self.words)
            and all(word in excluded for word in self.words)
        ):
            raise ValueError("every word in the sampler is excluded by contexts")

        negatives = []
        while len(negatives) < k:
            word = self.cache[self.cache_idx]
            self.cache_idx += 1

            if self.cache_idx >= self.cache_size:
                self.cache = self._draw_cache()
                self.cache_idx = 0

            if word not in excluded:
                negatives.append(word)
        return negatives

In [25]:
sampler = Sampler(vocab.encode(text), cache_size=10000)
# The sampler is built from token ids, so the words to exclude have to be
# encoded too; passing raw strings would never match any candidate.
sampler.sample_negative(vocab.encode(["the", "a"]))

[10661, 3589, 13138, 6689, 11654]

In [26]:
def get_centre_contexts_negatives_pairs(centre_contexts, sampler, negative_per_context=5):
    """Eagerly build ``(centre, other, label)`` training triples.

    For every context word of every centre, emits one positive triple
    (label 1) followed by one negative triple (label 0) per word returned by
    ``sampler.sample_negative``.

    NOTE(review): the following cell defines another function with this exact
    name that returns three parallel lists instead; once that cell runs, this
    definition is shadowed.

    Args:
        centre_contexts: Iterable of ``(centre, contexts)`` tuples.
        sampler: Object with ``sample_negative(contexts, k=...)``.
        negative_per_context: Number of negatives requested per context word.

    Returns:
        List of ``(centre, word, label)`` tuples.
    """
    pairs = []
    for centre, window in tqdm(centre_contexts):
        for positive in window:
            pairs.append((centre, positive, 1))
            sampled = sampler.sample_negative(window, k=negative_per_context)
            pairs.extend((centre, negative, 0) for negative in sampled)
    return pairs

In [27]:
def get_centre_contexts_negatives_pairs(centre_contexts, sampler, negative_per_context=5):
    """Eagerly build training pairs as three parallel lists.

    For every context word of every centre, appends one positive entry
    (label 1) followed by ``negative_per_context`` negative entries (label 0)
    taken from ``sampler.sample_negative``.

    Args:
        centre_contexts: Iterable of ``(centre, contexts)`` tuples.
        sampler: Object with ``sample_negative(contexts, k=...)``.
        negative_per_context: Number of negatives used per context word.

    Returns:
        Tuple ``(centres, contexts_and_negatives, labels)`` of equal-length lists.
    """
    centres, contexts_and_negatives, labels = [], [], []
    for centre, window in tqdm(centre_contexts):
        for positive in window:
            # Positive example first ...
            centres.append(centre)
            contexts_and_negatives.append(positive)
            labels.append(1)
            # ... then its negatives, read by position from the sampled list.
            sampled = sampler.sample_negative(window, k=negative_per_context)
            slot = 0
            while slot < negative_per_context:
                centres.append(centre)
                contexts_and_negatives.append(sampled[slot])
                labels.append(0)
                slot += 1
    return centres, contexts_and_negatives, labels

In [28]:
# Tokens taken on each side of a centre word.
WINDOW_SIZE = 10
# Negative samples drawn for every (centre, context) pair.
NEGATIVE_PER_CONTEXT = 5

# The eager pair construction below is left disabled; nothing is computed here.
# centre_contexts = get_centre_contexts(vocab.encode(text), window_size=WINDOW_SIZE)
# pairs = get_centre_contexts_negatives_pairs(
#     centre_contexts, sampler, negative_per_context=NEGATIVE_PER_CONTEXT
# )

In [29]:
def get_centre_contexts_negatives_pairs_lazy(centre_contexts, sampler, negative_per_context=5):
    """Lazily yield ``(centre, word, label)`` training triples.

    For every context word of every centre, yields the positive triple
    (label 1) and then ``negative_per_context`` negative triples (label 0)
    taken from ``sampler.sample_negative``. Negatives are only sampled once
    the consumer asks for the item after the positive one.

    Args:
        centre_contexts: Iterable of ``(centre, contexts)`` tuples.
        sampler: Object with ``sample_negative(contexts, k=...)``.
        negative_per_context: Number of negatives used per context word.

    Yields:
        ``(centre, word, label)`` tuples.
    """
    for centre, window in centre_contexts:
        for positive in window:
            yield centre, positive, 1
            sampled = sampler.sample_negative(window, k=negative_per_context)
            yield from (
                (centre, sampled[slot], 0) for slot in range(negative_per_context)
            )

In [30]:
# Tokens taken on each side of a centre word.
WINDOW_SIZE = 10
# Negative samples drawn for every (centre, context) pair.
NEGATIVE_PER_CONTEXT = 5
centre_contexts = get_centre_contexts(vocab.encode(text), window_size=WINDOW_SIZE)
# Each centre has 2 * WINDOW_SIZE context words, and every context word
# contributes one positive pair plus NEGATIVE_PER_CONTEXT negative pairs.
pairs_len = len(centre_contexts) * WINDOW_SIZE * 2 * (1 + NEGATIVE_PER_CONTEXT)

## Batch Iterator

In [31]:
class Word2VecDataset(Dataset):
    """Map-style dataset of ``(centre, word, label)`` skip-gram training pairs.

    Pairs are laid out conceptually as: for each centre, for each of its
    context words, one positive pair (label 1) followed by
    ``negative_per_context`` negative pairs (label 0). ``__getitem__``
    resolves an index into that layout directly, so the item returned depends
    only on ``idx`` (negatives are still drawn at random on access). This
    keeps the dataset usable with shuffling and with multiple workers.

    Args:
        centre_contexts: Indexable sequence of ``(centre, contexts)`` tuples.
        sampler: Object with ``sample_negative(contexts, k=...)``.
        negative_per_context: Number of negative pairs per context word.
        pairs_len: Length reported by ``__len__``.
        vocab: Vocabulary object; stored but not used by the dataset itself.
    """

    def __init__(
        self,
        centre_contexts,
        sampler,
        negative_per_context,
        pairs_len,
        vocab,
    ):
        self.centre_contexts = centre_contexts
        self.sampler = sampler
        self.negative_per_context = negative_per_context
        self.pairs_len = pairs_len
        self.vocab = vocab

        # Number of pairs produced per context word: 1 positive + K negatives.
        self._group_size = 1 + negative_per_context
        # _ends[i] is the running pair count up to and including centre i,
        # which lets an index be mapped to its centre by binary search.
        pairs_per_centre = [
            len(contexts) * self._group_size for _, contexts in centre_contexts
        ]
        self._ends = np.cumsum(pairs_per_centre, dtype=np.int64)
        self._total = int(self._ends[-1]) if len(pairs_per_centre) else 0

    def __len__(self):
        return self.pairs_len

    def __getitem__(self, idx):
        """Return the ``(centre, word, label)`` triple stored at ``idx``.

        Indices past the real number of pairs wrap around to the start.

        Raises:
            IndexError: If the dataset contains no pairs at all.
        """
        if self._total == 0:
            raise IndexError("Word2VecDataset contains no pairs")

        position = idx % self._total
        row = int(np.searchsorted(self._ends, position, side="right"))
        row_start = int(self._ends[row - 1]) if row else 0
        centre, contexts = self.centre_contexts[row]

        context_idx, slot = divmod(position - row_start, self._group_size)
        if slot == 0:
            return centre, contexts[context_idx], 1
        negative = self.sampler.sample_negative(contexts, k=1)[0]
        return centre, negative, 0

In [35]:
# Build the training dataset over the centre/context windows and serve it in
# batches of 1024 pairs.
train_ds = Word2VecDataset(
    centre_contexts,
    sampler,
    negative_per_context=NEGATIVE_PER_CONTEXT,
    pairs_len=pairs_len,
    vocab=vocab,
)
train_dl = DataLoader(train_ds, batch_size=1024)

## Embedding Model

In [36]:
class EmbeddingModel(L.LightningModule):
    """Scores (word, context) pairs by the dot product of two embedding tables.

    Trained as a binary classifier: positive pairs should get a high score,
    negative pairs a low one.

    Args:
        vocab_size: Number of rows in each embedding table.
        embedding_dim: Size of every embedding vector.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.word_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.context_embedding = nn.Embedding(vocab_size, embedding_dim)

        # Xavier-initialise both tables, in registration order.
        for table in (self.word_embedding, self.context_embedding):
            nn.init.xavier_uniform_(table.weight)

        self.accuracy = tm.Accuracy(task="binary")

    def forward(self, word, context):
        """Return one logit per pair.

        Assumes ``word`` and ``context`` are 1-D index tensors of equal
        length (the sum is taken over dim 1 of the looked-up vectors).
        """
        word_vectors = self.word_embedding(word)
        context_vectors = self.context_embedding(context)
        return torch.sum(word_vectors * context_vectors, dim=1)

    def training_step(self, batch, batch_idx):
        """Compute the BCE-with-logits loss and log loss and accuracy."""
        centre, context, label = batch
        logits = self(centre, context)
        targets = label.float()
        loss = nn.functional.binary_cross_entropy_with_logits(logits, targets)
        self.log("train_loss", loss)
        # Threshold the probabilities at 0.5 to get hard 0/1 predictions.
        hard_predictions = torch.sigmoid(logits).round()
        self.accuracy(hard_predictions, label)
        self.log("train_acc", self.accuracy, prog_bar=False, on_step=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Use Adam with a learning rate of 1e-3 over all parameters."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [37]:
# One embedding row per vocabulary entry, 128 dimensions each.
model = EmbeddingModel(len(vocab.word2idx), 128)
# Train for at most 5 passes over the dataloader.
trainer = L.Trainer(max_epochs=5)
trainer.fit(model, train_dl)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type           | Params
-----------------------------------------------------
0 | word_embedding    | Embedding      | 3.0 M 
1 | context_embedding | Embedding      | 3.0 M 
2 | accuracy          | BinaryAccuracy | 0     
-----------------------------------------------------
5.9 M     Trainable params
0         Non-trainable params
5.9 M     Total params
23.642    Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


## Using embeddings

In [38]:
model.cpu()

EmbeddingModel(
  (word_embedding): Embedding(23088, 128)
  (context_embedding): Embedding(23088, 128)
  (accuracy): BinaryAccuracy()
)

In [39]:
word_embedding = model.word_embedding

In [40]:
def get_embedding(word):
    """Return the trained word embedding of ``word`` as a NumPy array.

    Looks the word up in the global ``vocab`` and ``word_embedding``; a word
    missing from the vocabulary raises ``KeyError``.
    """
    token_id = torch.tensor(vocab.word2idx[word])
    vector = word_embedding(token_id)
    return vector.detach().numpy()

### Similarity

In [41]:
def cosine_similarity(embed_a, embed_b):
    """Return the cosine of the angle between two vectors.

    Computed as the dot product divided by the product of the two norms.
    """
    dot_product = np.dot(embed_a, embed_b)
    norm_product = np.linalg.norm(embed_a) * np.linalg.norm(embed_b)
    return dot_product / norm_product

In [73]:
words = ("TV", "comedy")
embeddings = [get_embedding(word) for word in words]
cosine_similarity(*embeddings)

0.70177114

In [72]:
# Score every vocabulary word against the query word
word = "star"
embeddings = get_embedding(word)
cosine_similarities = [
    cosine_similarity(embeddings, get_embedding(candidate))
    for candidate in vocab.word2idx
]

# Pair each score with its word, order from most to least similar, show the top 10
sorted_cosine_similarities = list(zip(cosine_similarities, vocab.word2idx))
sorted_cosine_similarities.sort(reverse=True)
sorted_cosine_similarities[:10]

[(0.9999999, 'star'),
 (0.7052185, 'playing'),
 (0.6851464, 'starred'),
 (0.65827155, 'cast'),
 (0.65029514, 'comedy'),
 (0.6498021, 'TV'),
 (0.64620864, 'appeared'),
 (0.6402657, 'Sonia'),
 (0.6315743, 'starring'),
 (0.6226327, 'television')]