In [None]:
%%capture
!pip install torch torchtext torchdata

In [None]:
import os
import random
import re
from collections import Counter, OrderedDict
from dataclasses import dataclass
from time import monotonic
from typing import Dict, List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
from scipy.spatial.distance import cosine
from torch.utils.data import DataLoader
from torchtext.data import to_map_style_dataset
from torchtext.data.utils import get_tokenizer
from torchtext.datasets import WikiText103
from tqdm import tqdm


In [None]:
def get_data():
    """Load WikiText103 as two map-style datasets.

    Returns:
        tuple: ``(train_iter, valid_iter)``. The first element is built from
        the ``'train'`` split; the second one is built from the ``'test'``
        split (not from a ``'valid'`` split).
    """
    # The generator is consumed left to right by the unpacking, so the train
    # split is loaded and converted before the test split is touched.
    train_iter, valid_iter = (
        to_map_style_dataset(WikiText103(split=split_name))
        for split_name in ('train', 'test')
    )
    return train_iter, valid_iter

In [None]:
@dataclass
class Word2VecParams:
    """Hyper-parameters and shared objects used throughout this notebook.

    None of the names below carries a type annotation, so they are ordinary
    class attributes (shared by every instance) rather than dataclass fields.
    """

    # skipgram parameters
    MIN_FREQ = 50  # passed as `min_freq` when build_vocab() creates the vocabulary
    SKIPGRAM_N_WORDS = 8  # context words taken on EACH side of the centre word
    T = 85  # percentile (via np.percentile) of relative frequencies used as subsampling threshold
    NEG_SAMPLES = 50  # negative context ids drawn for every positive pair
    NS_ARRAY_LEN = 5_000_000  # `ns_array_len` handed to NegativeSampler
    SPECIALS = "<unk>"  # special token handed to vocab() as `specials`
    TOKENIZER = 'basic_english'  # name given to torchtext's get_tokenizer()

    # network parameters
    BATCH_SIZE = 100  # DataLoader batch size, i.e. texts per batch before pair expansion
    EMBED_DIM = 300  # embedding dimension of both nn.Embedding tables
    EMBED_MAX_NORM = None  # forwarded as `max_norm` to nn.Embedding
    N_EPOCHS = 5  # number of passes Trainer.train() makes over the training data
    # Evaluated once, when this class body executes.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Single loss-module instance used by Trainer for both training and validation.
    CRITERION = nn.BCEWithLogitsLoss()

In [None]:
class Vocab:
    """Two-way mapping between tokens and integer ids, with token frequencies.

    Attributes:
        stoi: ``{token: (index, frequency)}``.
        itos: ``{index: (token, frequency)}``.
        total_tokens: sum of all frequencies, ignoring NaN entries.
    """

    def __init__(self, list, specials):
        """Build the lookup tables.

        Args:
            list: sequence of ``(token, frequency)`` entries; the position of
                an entry becomes its index. (The parameter name shadows the
                builtin inside this method only; it is kept so existing
                callers are unaffected.)
            specials: indexable whose first element is the special/unknown
                token used as the fallback for out-of-vocabulary words.
        """
        self.stoi = {v[0]: (k, v[1]) for k, v in enumerate(list)}
        self.itos = {k: (v[0], v[1]) for k, v in enumerate(list)}
        self._specials = specials[0]
        # nansum: a NaN frequency (e.g. on the special token) counts as 0.
        self.total_tokens = np.nansum(
            [f for _, (_, f) in self.stoi.items()]
            , dtype=int)

    def __len__(self):
        # One less than the number of stored entries.
        return len(self.stoi) - 1

    def _lookup(self, word, field):
        """Return element ``field`` of the stoi entry for a word or word list.

        ``field`` is 0 for the index and 1 for the frequency. Words missing
        from the vocabulary resolve to the special token's entry.
        """
        def resolve(single_word):
            entry = self.stoi.get(single_word)
            if entry is None:
                entry = self.stoi.get(self._specials)
            return entry[field]

        if isinstance(word, str):
            return resolve(word)
        if isinstance(word, list):
            return [resolve(w) for w in word]
        raise ValueError(
            f"Word {word} is not a string or a list of strings."
            )

    def get_index(self, word: Union[str, List]):
        """Return the index of ``word`` (or a list of indices for a list).

        Unknown words map to the special token's index.

        Raises:
            ValueError: if ``word`` is neither a string nor a list.
        """
        return self._lookup(word, 0)

    def get_freq(self, word: Union[str, List]):
        """Return the frequency of ``word`` (or a list of frequencies).

        Unknown words map to the special token's frequency.

        Raises:
            ValueError: if ``word`` is neither a string nor a list.
        """
        return self._lookup(word, 1)

    def lookup_token(self, token: Union[int, List]):
        """Return the token string for an index, or a list for a list of them.

        Args:
            token: a Python or NumPy integer, or a list of such integers.

        Raises:
            ValueError: if an index is not in the vocabulary, or if ``token``
                is neither an integer nor a list.
        """
        # np.integer covers every NumPy integer width, not just int64.
        if isinstance(token, (int, np.integer)):
            if token in self.itos:
                return self.itos[token][0]
            raise ValueError(f"Token {token} not in vocabulary")
        if isinstance(token, list):
            res = []
            for t in token:
                if t not in self.itos:
                    raise ValueError(f"Token {t} is not a valid index.")
                # Look up the element, not the whole list (which is
                # unhashable and cannot be a dict key).
                res.append(self.itos[t][0])
            return res
        raise ValueError(
            f"Token {token} is not an integer or a list of integers."
            )

In [None]:
def yield_tokens(iterator, tokenizer):
    """Yield one filtered token list per text in ``iterator``.

    Args:
        iterator: iterable of raw text strings.
        tokenizer: callable turning a text into a list of string tokens.

    Yields:
        list: the tokens of each text whose first character is a lowercase
        ASCII letter or a digit; everything else (punctuation, markup such as
        ``<unk>``) is dropped.
    """
    # `match` anchors at the start of the token, so only the first character
    # is tested. The class includes 0 so tokens like "0" or "000" are kept.
    starts_alnum = re.compile('[a-z0-9]')
    for text in iterator:
        yield [token for token in tokenizer(text) if starts_alnum.match(token)]

def vocab(ordered_dict: Dict, min_freq: int = 1, specials: str = '<unk>'):
    """Create a ``Vocab`` from a token -> frequency mapping.

    Args:
        ordered_dict: mapping of token to frequency; iteration order decides
            the index order of the kept tokens.
        min_freq: tokens with a frequency below this value are dropped.
        specials: the special/unknown token. It always receives index 0 and a
            NaN frequency.

    Returns:
        Vocab: vocabulary whose entry 0 is the special token, followed by
        every token of ``ordered_dict`` that meets ``min_freq``.
    """
    special_entry = (specials, np.nan)
    # Reserve slot 0 for the special token by prepending it, so no real token
    # is overwritten and an empty result is still a valid vocabulary.
    tokens = [special_entry]
    tokens.extend(
        (token, freq)
        for token, freq in ordered_dict.items()
        # Skip a counted occurrence of the special token itself; otherwise it
        # would appear twice and end up with a non-zero index.
        if freq >= min_freq and token != specials
    )
    return Vocab(tokens, special_entry)

def pipeline(word, vocab, tokenizer):
    """Tokenise ``word`` and convert the tokens with ``vocab``.

    Args:
        word: raw text handed to ``tokenizer``.
        vocab: either a callable that maps a token list to ids, or an object
            exposing ``get_index(tokens)`` (such as this notebook's ``Vocab``,
            which is not callable).
        tokenizer: callable turning ``word`` into a list of tokens.

    Returns:
        Whatever ``vocab`` produces for the token list.
    """
    tokens = tokenizer(word)
    if callable(vocab):
        return vocab(tokens)
    return vocab.get_index(tokens)

def build_vocab(
        iterator,
        tokenizer,
        params: Word2VecParams,
        max_tokens: Optional[int] = None,
    ):
    """Count the tokens of a corpus and turn the counts into a vocabulary.

    Args:
        iterator: iterable of raw texts.
        tokenizer: callable splitting a text into tokens.
        params: supplies ``MIN_FREQ`` and ``SPECIALS`` for ``vocab()``.
        max_tokens: accepted but not used by this function.

    Returns:
        The object returned by ``vocab()``, with tokens ordered by descending
        frequency and ties broken alphabetically.
    """
    frequencies = Counter(
        token
        for token_list in yield_tokens(iterator, tokenizer)
        for token in token_list
    )

    # Two stable passes: alphabetical first, then descending frequency, which
    # leaves equally frequent tokens in alphabetical order.
    ranked = sorted(frequencies.items(), key=lambda pair: pair[0])
    ranked.sort(key=lambda pair: pair[1], reverse=True)

    return vocab(
        OrderedDict(ranked),
        min_freq=params.MIN_FREQ,
        specials=params.SPECIALS,
    )


In [None]:
class SkipGrams:
    """Turns raw texts into (centre id, context id) skip-gram pairs.

    Frequent tokens are randomly dropped ("subsampled") using per-id discard
    probabilities computed once at construction time.
    """

    def __init__(self, vocab: Vocab, params: Word2VecParams, tokenizer):
        self.vocab = vocab
        self.params = params
        self.t = self._t()
        self.tokenizer = tokenizer
        self.discard_probs = self._create_discard_dict()

    def _t(self):
        """Return the subsampling threshold.

        It is the ``params.T``-th percentile of the relative frequencies of
        all vocabulary entries except the first one.
        """
        corpus_size = self.vocab.total_tokens
        # stoi values are (index, frequency); entry 0 is skipped.
        relative_freqs = [
            freq / corpus_size
            for _, freq in list(self.vocab.stoi.values())[1:]
        ]
        return np.percentile(relative_freqs, self.params.T)

    def _create_discard_dict(self):
        """Return ``{token index: discard probability}`` for every entry."""
        corpus_size = self.vocab.total_tokens
        return {
            index: 1 - np.sqrt(self.t / (freq / corpus_size + self.t))
            for index, freq in self.vocab.stoi.values()
        }

    def collate_skipgram(self, batch):
        """Collate function producing skip-gram pairs for a batch of texts.

        Args:
            batch: iterable of raw text strings.

        Returns:
            tuple: two 1-D ``torch.long`` tensors of equal length holding the
            centre ids and the matching context ids.
        """
        half_window = self.params.SKIPGRAM_N_WORDS
        window_len = half_window * 2 + 1
        centres, contexts = [], []

        for text in batch:
            ids = self.vocab.get_index(self.tokenizer(text))

            # Texts too short to hold one full window contribute nothing.
            if len(ids) < window_len:
                continue

            for start in range(len(ids) - half_window * 2):
                window = ids[start:start + window_len]
                centre = window.pop(half_window)

                # One random draw per window, taken before the centre is
                # inspected, then one per context id.
                draw = random.random()
                if centre == 0 or self.discard_probs.get(centre) >= draw:
                    continue

                for neighbour in window:
                    draw = random.random()
                    if neighbour == 0 or self.discard_probs.get(neighbour) >= draw:
                        continue
                    centres.append(centre)
                    contexts.append(neighbour)

        return (
            torch.tensor(centres, dtype=torch.long),
            torch.tensor(contexts, dtype=torch.long),
        )

In [None]:
class NegativeSampler:
    """Draws negative context ids from a smoothed unigram distribution.

    A lookup table (``ns_array``) is built in which every token id appears a
    number of times proportional to ``frequency ** ns_exponent``; sampling
    picks entries of that table.
    """

    def __init__(self, vocab: "Vocab", ns_exponent: float, ns_array_len: int):
        """
        Args:
            vocab: vocabulary exposing ``stoi`` as ``{token: (index, freq)}``.
            ns_exponent: exponent applied to the raw frequencies.
            ns_array_len: approximate target length of the sampling table.
        """
        self.vocab = vocab
        self.ns_exponent = ns_exponent
        self.ns_array_len = ns_array_len
        self.ns_array = self._create_negative_sampling()

    def __len__(self):
        return len(self.ns_array)

    def _create_negative_sampling(self):
        """Build the sampling table as a list of token ids.

        Returns:
            list: each id repeated ``max(1, int(share * ns_array_len))``
            times, where ``share`` is its exponentiated frequency divided by
            the sum of all exponentiated frequencies.
        """
        # stoi values are (index, frequency); the first entry is skipped.
        weights = {
            index: freq ** self.ns_exponent
            for index, freq in list(self.vocab.stoi.values())[1:]
        }
        # Normalise by the sum of the *exponentiated* frequencies so the
        # shares add up to 1 and the table really has ~ns_array_len entries.
        total_weight = sum(weights.values())

        ns_array = []
        for index, weight in tqdm(weights.items()):
            copies = max(1, int(weight * self.ns_array_len / total_weight))
            # extend() appends in place; rebuilding the list with `+` on
            # every iteration would copy the whole table each time.
            ns_array.extend([index] * copies)
        return ns_array

    def sample(self, n_batches: int = 1, n_samples: int = 1):
        """Draw negative ids.

        Args:
            n_batches: number of rows to draw.
            n_samples: ids per row, taken from distinct table positions.

        Returns:
            torch.Tensor: tensor built from the ``n_batches`` rows of
            ``n_samples`` ids each.
        """
        samples = [
            random.sample(self.ns_array, n_samples) for _ in range(n_batches)
        ]
        return torch.as_tensor(np.array(samples))


In [None]:
class Model(nn.Module):
    """Skip-gram model with separate target and context embedding tables."""

    def __init__(self, vocab: "Vocab", params: "Word2VecParams"):
        """
        Args:
            vocab: vocabulary; both tables get ``len(vocab) + 1`` rows.
            params: supplies ``EMBED_DIM`` and ``EMBED_MAX_NORM``.
        """
        super().__init__()
        self.vocab = vocab
        n_rows = len(self.vocab) + 1
        self.t_embeddings = nn.Embedding(
            n_rows,
            params.EMBED_DIM,
            max_norm=params.EMBED_MAX_NORM
            )
        self.c_embeddings = nn.Embedding(
            n_rows,
            params.EMBED_DIM,
            max_norm=params.EMBED_MAX_NORM
            )

    def forward(self, inputs, context):
        """Score every target against its own row of context ids.

        Args:
            inputs: 1-D tensor of target ids, length ``n_examples``.
            context: 2-D tensor of context ids with ``n_examples`` rows.

        Returns:
            torch.Tensor: ``(n_examples, n_context)`` dot products (logits).
        """
        # (n_examples, dim) -> (n_examples, 1, dim)
        target_embeddings = self.t_embeddings(inputs)
        n_examples = target_embeddings.shape[0]
        n_dimensions = target_embeddings.shape[1]
        target_embeddings = target_embeddings.view(n_examples, 1, n_dimensions)

        # (n_examples, n_context, dim) -> (n_examples, dim, n_context) so one
        # batched matmul yields all dot products of a row at once.
        context_embeddings = self.c_embeddings(context).permute(0, 2, 1)

        dots = target_embeddings.bmm(context_embeddings)
        return dots.view(dots.shape[0], dots.shape[2])

    def normalize_embeddings(self):
        """Return the target embeddings as a NumPy array of unit-length rows."""
        embeddings = self.t_embeddings.weight.cpu().detach().numpy()
        norms = (embeddings ** 2).sum(axis=1) ** (1 / 2)
        norms = norms.reshape(norms.shape[0], 1)
        return embeddings / norms

    def get_similar_words(self, word, n):
        """Return the ``n`` nearest words to ``word`` by cosine similarity.

        Returns:
            dict mapping word -> similarity, most similar first, or ``None``
            (after printing a notice) when ``word`` is out of vocabulary.
        """
        word_id = self.vocab.get_index(word)
        if word_id == 0:
            print("Out of vocabulary word")
            return

        embedding_norms = self.normalize_embeddings()
        word_vec = embedding_norms[word_id]
        word_vec = np.reshape(word_vec, (word_vec.shape[0], 1))
        dists = np.matmul(embedding_norms, word_vec).flatten()
        # Position 0 of the ranking is skipped: it is expected to be the
        # query word itself (similarity 1).
        topN_ids = np.argsort(-dists)[1 : n + 1]

        topN_dict = {}
        for sim_word_id in topN_ids:
            sim_word = self.vocab.lookup_token(sim_word_id)
            topN_dict[sim_word] = dists[sim_word_id]
        return topN_dict

    def get_similarity(self, word1, word2):
        """Return the cosine similarity between two words.

        Returns:
            float in [-1, 1] (1 means identical direction), or ``None``
            (after printing a notice) when either word is out of vocabulary.
        """
        idx1 = self.vocab.get_index(word1)
        idx2 = self.vocab.get_index(word2)
        if idx1 == 0 or idx2 == 0:
            print("One or both words are out of vocabulary")
            return

        embedding_norms = self.normalize_embeddings()
        word1_vec, word2_vec = embedding_norms[idx1], embedding_norms[idx2]

        # Rows are unit length, so their dot product IS the cosine
        # similarity. (scipy's `cosine` would give the distance, 1 - this.)
        return float(np.dot(word1_vec, word2_vec))

In [None]:
class Trainer:
    """Runs training and validation of the skip-gram model.

    Keeps per-epoch mean losses in ``self.loss`` and per-epoch training
    durations (minutes) in ``self.epoch_train_mins``.
    """

    def __init__(self, model: "Model", params: "Word2VecParams", optimizer,
                 vocab: "Vocab", train_iter, valid_iter,
                 skipgrams: "SkipGrams"):
        """
        Args:
            model: the network to optimise.
            params: supplies DEVICE, CRITERION, BATCH_SIZE, N_EPOCHS,
                NEG_SAMPLES and NS_ARRAY_LEN.
            optimizer: optimiser already bound to the model's parameters.
            vocab: vocabulary used to build the negative sampler.
            train_iter: dataset of training texts.
            valid_iter: dataset of validation texts.
            skipgrams: provides ``collate_skipgram`` for the DataLoaders.
        """
        self.model = model
        self.optimizer = optimizer
        self.vocab = vocab
        self.train_iter = train_iter
        self.valid_iter = valid_iter
        self.skipgrams = skipgrams
        self.params = params

        self.epoch_train_mins = {}
        self.loss = {"train": [], "valid": []}

        # sending all to device
        self.model.to(self.params.DEVICE)
        self.params.CRITERION.to(self.params.DEVICE)

        self.negative_sampler = NegativeSampler(
            vocab=self.vocab, ns_exponent=.75,
            ns_array_len=self.params.NS_ARRAY_LEN
            )
        self.testwords = ['love', 'hurricane', 'military', 'army']

    def train(self):
        """Train for ``params.N_EPOCHS`` epochs, validating after each one.

        Prints the nearest neighbours of the probe words before training and
        after every epoch, together with the epoch's losses and duration.
        """
        self.test_testwords()

        # The loaders do not change between epochs (shuffle=False, same
        # datasets), so they are built once.
        self.train_dataloader = DataLoader(
            self.train_iter,
            batch_size=self.params.BATCH_SIZE,
            shuffle=False,
            collate_fn=self.skipgrams.collate_skipgram
        )
        self.valid_dataloader = DataLoader(
            self.valid_iter,
            batch_size=self.params.BATCH_SIZE,
            shuffle=False,
            collate_fn=self.skipgrams.collate_skipgram
        )

        for epoch in range(self.params.N_EPOCHS):
            # training the model
            st_time = monotonic()
            self._train_epoch()
            self.epoch_train_mins[epoch] = round((monotonic()-st_time)/60, 1)

            # validating the model
            self._validate_epoch()
            print(f"""Epoch: {epoch+1}/{self.params.N_EPOCHS}\n""",
            f"""    Train Loss: {self.loss['train'][-1]:.2}\n""",
            f"""    Valid Loss: {self.loss['valid'][-1]:.2}\n""",
            f"""    Training Time (mins): {self.epoch_train_mins.get(epoch)}"""
            """\n"""
            )
            self.test_testwords()

    def _prepare_batch(self, batch_data):
        """Move a collated batch to the device and attach negatives + labels.

        Returns:
            tuple: ``(inputs, context, y)`` where column 0 of ``context`` is
            the positive id followed by the sampled negatives, and ``y`` has
            the same shape with 1 in column 0 and 0 elsewhere.
        """
        device = self.params.DEVICE
        inputs = batch_data[0].to(device)
        pos_labels = batch_data[1].to(device)
        neg_labels = self.negative_sampler.sample(
            pos_labels.shape[0], self.params.NEG_SAMPLES
            ).to(device)
        context = torch.cat(
            [pos_labels.view(pos_labels.shape[0], 1), neg_labels], dim=1
            )

        # building the targets tensor directly on the target device
        y = torch.zeros(context.shape, device=device)
        y[:, 0] = 1.0
        return inputs, context, y

    def _train_epoch(self):
        """Run one optimisation pass and append its mean loss."""
        self.model.train()
        running_loss = []

        for batch_data in self.train_dataloader:
            # The collate function can return no pairs at all for a batch.
            if len(batch_data[0]) == 0:
                continue
            inputs, context, y = self._prepare_batch(batch_data)

            self.optimizer.zero_grad()

            outputs = self.model(inputs, context)
            loss = self.params.CRITERION(outputs, y)
            loss.backward()
            self.optimizer.step()

            running_loss.append(loss.item())

        self.loss['train'].append(np.mean(running_loss))

    def _validate_epoch(self):
        """Evaluate on the validation loader and append the mean loss."""
        self.model.eval()
        running_loss = []

        with torch.no_grad():
            for batch_data in self.valid_dataloader:
                if len(batch_data[0]) == 0:
                    continue
                inputs, context, y = self._prepare_batch(batch_data)

                preds = self.model(inputs, context)
                loss = self.params.CRITERION(preds, y)

                running_loss.append(loss.item())

            self.loss['valid'].append(np.mean(running_loss))

    def test_testwords(self, n: int = 5):
        """Print the ``n`` nearest neighbours of every probe word."""
        for word in self.testwords:
            print(word)
            nn_words = self.model.get_similar_words(word, n)
            # An out-of-vocabulary probe word yields None instead of a dict;
            # treat that as "no neighbours" rather than crashing.
            for w, sim in (nn_words or {}).items():
                print(f"{w} ({sim:.3})", end=' ')
            print('\n')

In [None]:
# Build everything the Trainer needs: hyper-parameters, datasets, tokenizer,
# vocabulary, skip-gram collator, model and optimizer.
params = Word2VecParams()
train_iter, valid_iter = get_data()
tokenizer = get_tokenizer(params.TOKENIZER)
# NOTE(review): this rebinds the global name `vocab`, which is also the name
# of the factory function that build_vocab() calls. Re-running this cell in
# the same kernel makes build_vocab() call the Vocab instance instead.
vocab = build_vocab(train_iter, tokenizer, params)
skip_gram = SkipGrams(vocab=vocab, params=params, tokenizer=tokenizer)
model = Model(vocab=vocab, params=params).to(params.DEVICE)
# Adam with its default settings; only the parameter list is supplied.
optimizer = torch.optim.Adam(params = model.parameters())

In [None]:
# Constructing the Trainer also builds the negative-sampling table (the
# source of the progress bar in this cell's output); train() then runs all
# epochs and prints losses, timings and nearest neighbours of the probe words.
trainer = Trainer(
        model=model,
        params=params,
        optimizer=optimizer,
        train_iter=train_iter,
        valid_iter=valid_iter,
        vocab=vocab,
        skipgrams=skip_gram
    )
trainer.train()

100%|██████████| 49331/49331 [01:14<00:00, 663.13it/s]


Epoch: 1/5
     Train Loss: 0.96
     Valid Loss: 0.18
     Training Time (mins): 43.0

love
music (0.44) her (0.44) has (0.43) s (0.424) wrote (0.424) 

hurricane
landfall (0.287) changing (0.28) rapidly (0.279) steadily (0.276) winds (0.275) 

military
by (0.457) for (0.45) although (0.449) was (0.449) any (0.444) 

army
british (0.641) were (0.605) only (0.605) during (0.604) in (0.604) 

Epoch: 2/5
     Train Loss: 0.14
     Valid Loss: 0.11
     Training Time (mins): 42.4

love
me (0.851) my (0.849) said (0.845) good (0.838) saying (0.832) 

hurricane
storm (0.676) tropical (0.64) landfall (0.629) winds (0.625) rapidly (0.589) 

military
war (0.852) army (0.839) support (0.829) forces (0.82) supported (0.819) 

army
forces (0.896) troops (0.891) war (0.854) captured (0.839) military (0.839) 

Epoch: 3/5
     Train Loss: 0.1
     Valid Loss: 0.096
     Training Time (mins): 42.5

love
girl (0.825) woman (0.787) me (0.782) herself (0.781) song (0.773) 

hurricane
storm (0.787) tropi