In [None]:
import time
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchtext.data.utils import get_tokenizer
from torchtext.datasets import DATASETS
from torchtext.vocab import build_vocab_from_iterator
import torch.nn as nn
from tqdm import tqdm
import pickle
import random
import numpy as np
from collections import Counter, defaultdict
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from gensim.test.utils import datapath, get_tmpfile
from gensim.models import KeyedVectors
import gensim.downloader
from torch import FloatTensor as FT

# Get the interactive Tools for Matplotlib
%matplotlib notebook
%matplotlib inline

plt.style.use('ggplot')

### Instructions
For this part, fill in the required code and make the notebook work. This will be very similar to the Skip-Gram model, but a little more difficult. Look for the """ FILL IN """ string to guide you.

In [None]:
# Where do I want to run my job. You can do "cuda" on linux machines
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"
# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the random number generators used in this notebook (corpus subsampling uses
# `random`; negative sampling, DataLoader shuffling and weight init use torch)
# so that re-running the notebook gives the same results (GPU/MPS kernels may still
# add small nondeterminism).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# The batch size in Adam or SGD
BATCH_SIZE = 512

# Number of epochs
NUM_EPOCHS = 10

# Context radius: WINDOW words on each side of the center word.
# With WINDOW = 1 a window of continuous text looks like ["a", "b", "c"] (each is a word):
#   - Skip-Gram predicts each context word "a" and "c" from the center word "b".
#   - CBOW (this notebook) uses the context ["a", "c"] together to predict the center word "b".
# In the code below the center word is called `wc` and the context words `w_context`.
WINDOW = 1

# Number of negative samples drawn per positive example.
K = 4

The text8 Wikipedia corpus. 100M characters.

In [None]:
# Put the data in your Google Drive
# You can get the data from the HW page
from google.colab import drive
drive.mount('/content/drive')

!du -h text8

f = open('/content/drive/MyDrive/text8', 'r')
text = f.read()
# One big string of size 100M
print(len(text))

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
du: cannot access 'text8': No such file or directory
100000000


In [None]:
punc = '!"#$%&()*+,-./:;<=>?@[\\]^_\'{|}~\t\n'

# Replace every character in `punc` with a space. Python strings are immutable:
# str.replace / str.translate return a NEW string, so the result must be assigned
# back to `text`. A regular expression would work too.
text = text.translate(str.maketrans(punc, ' ' * len(punc)))

In [None]:
# torchtext's built-in "basic_english" tokenizer: lower-cases the text, applies some
# light punctuation normalization and splits on whitespace. Crude, but free.
TOKENIZER = get_tokenizer(tokenizer="basic_english")

In [None]:
# Tokenize the whole corpus once, then count how often each token occurs
words = TOKENIZER(text)
f = Counter()
f.update(words)

In [None]:
n_tokens = len(words)  # total number of tokens in the raw corpus
n_tokens

17005207

In [None]:
# Do a very crude filter on the text which removes all RARE words: keep only tokens
# that occur more than MIN_COUNT times in the corpus. (Very frequent words are not
# removed here; they are down-weighted by the subsampling step further down.)
MIN_COUNT = 5
text = [word for word in words if f[word] > MIN_COUNT]

In [None]:
text[:5]  # first five tokens of the filtered corpus

['anarchism', 'originated', 'as', 'a', 'term']

In [None]:
# Build the vocabulary from the filtered corpus, passed as a single "document".
# No specials (e.g. "<unk>") are added, so stoi[...] on an unseen word raises KeyError.
VOCAB = build_vocab_from_iterator(iterator=[text], min_freq=1)

In [None]:
# Two lookup tables for the same vocabulary
itos = VOCAB.get_itos()  # int -> word (list indexed by id)
stoi = VOCAB.get_stoi()  # word -> int (dict)

In [None]:
as_id = stoi['as']  # integer id of a common word
as_id

11

In [None]:
# Vocabulary size: number of distinct words that survived the frequency filter
vocab_size = len(stoi)
vocab_size

63641

In [None]:
f = Counter(text)
n_text = len(text)
# Unigram probability of each word: its share of all tokens in the filtered corpus
z = {word: count / n_text for word, count in f.items()}

In [None]:
# Subsampling of frequent words, following the formula explained here (slightly
# different from the one in the word2vec paper):
# http://mccormickml.com/2017/01/11/word2vec-tutorial-part-2-negative-sampling/
#     P_keep(w) = (sqrt(z(w) / sample) + 1) * sample / z(w)
# The same `sample` constant appears in both factors. Smaller values discard
# frequent words more aggressively; a value >= 1 means the word is always kept.
threshold = 1e-3  # the `sample` parameter of the formula above
p_keep = {word: (np.sqrt(z[word] / threshold) + 1) * (threshold / z[word]) for word in f}

In [None]:
# Subsample the corpus: each occurrence of a word is kept independently with
# probability p_keep[word] (random.random() is in [0, 1), so values >= 1 always keep).
# Note these are still words (strings), not integer ids.
train_dataset = []
for word in text:
    if random.random() < p_keep[word]:
        train_dataset.append(word)

# Rebuild the vocabulary from the subsampled corpus
VOCAB = build_vocab_from_iterator([train_dataset])

In [None]:
n_kept = len(train_dataset)  # number of tokens that survived subsampling
n_kept

7847985

In [None]:
# Refresh both lookup tables from the rebuilt vocabulary
itos = VOCAB.get_itos()  # int -> word
stoi = VOCAB.get_stoi()  # word -> int

In [None]:
# Vocabulary size after all the filters (vocabulary rebuilt from the subsampled corpus)
vocab_size_subsampled = len(VOCAB)
vocab_size_subsampled

63641

In [None]:
# Negative-sampling distribution p over word ids: counts in the subsampled corpus
# raised to the 3/4 power and normalized. The 0.75 power flattens the distribution,
# so frequent words are drawn less and rare words more than their raw frequency.
f = Counter(train_dataset)
p = torch.zeros(len(VOCAB))

weights = {word: np.power(freq, 0.75) for word, freq in f.items()}
s = sum(weights.values())

for word, weight in weights.items():
    p[stoi[word]] = weight / s

In [None]:
# Map every word to its integer id.
# NOTE(review): this maps the full frequency-filtered corpus `text`, not the subsampled
# word list previously held in `train_dataset`, so subsampling only affects the rebuilt
# vocabulary and the negative-sampling distribution `p`. The hard-coded batch-count
# assert further down appears to assume this; confirm which corpus is intended.
train_dataset = list(map(stoi.__getitem__, text))

In [None]:
# Build the positive (context, center) rows - these words are seen together!
def get_tokenized_dataset(dataset, verbose=False, window=None):
    """Build one CBOW training row per center position of `dataset`.

    Each row is ``[left_1, ..., left_w, right_1, ..., right_w, center]``: the ``w``
    tokens before the center word, the ``w`` tokens after it, and the center word
    last (the model splits rows with ``x[:, :-1], x[:, -1]``). Positions without a
    full window on both sides (the first and last ``w`` tokens) are skipped, so every
    row has exactly ``2 * w + 1`` entries, even for corpora shorter than that.

    Args:
        dataset: sequence of token ids.
        verbose: unused; kept so existing calls keep working.
        window: context radius ``w``. Defaults to the global ``WINDOW``.

    Returns:
        A list of rows (lists), one per valid center position, in corpus order.
    """
    if window is None:
        window = WINDOW

    x_list = []
    # Only centers with `window` tokens on both sides produce a row
    for i in range(window, len(dataset) - window):
        left_tokens = list(dataset[i - window:i])
        right_tokens = list(dataset[i + 1:i + 1 + window])
        x_list.append(left_tokens + right_tokens + [dataset[i]])

    return x_list

In [None]:
# Build the positive [context..., center] rows from the integer corpus
train_x_list = get_tokenized_dataset(dataset=train_dataset, verbose=False)

In [None]:
# Cache the (slow to build) training rows on disk; `with` flushes and closes the file
with open('train_x_list.pkl', 'wb') as fp:
    pickle.dump(train_x_list, fp)

In [None]:
# Reload the cached training rows. Only unpickle files you created yourself:
# pickle.load can execute arbitrary code contained in an untrusted file.
with open('train_x_list.pkl', 'rb') as fp:
    train_x_list = pickle.load(fp)

In [None]:
# Each row is [left context word, right context word, center word]; all are positives (y = +1)
train_x_list[0:10]

[[5233, 11, 3083],
 [3083, 5, 11],
 [11, 218, 5],
 [5, 1, 218],
 [218, 3133, 1],
 [1, 45, 3133],
 [3133, 60, 45],
 [45, 174, 60],
 [60, 132, 174],
 [174, 740, 132]]

In [None]:
# Number of full batches of BATCH_SIZE = 512 rows (any remainder becomes one final,
# partial batch in the DataLoader below)
assert len(train_x_list) // BATCH_SIZE == 32579

### Set up the dataloader.

In [None]:
# Move all training rows to DEVICE once; the DataLoader then draws shuffled batches from them
all_rows = torch.tensor(train_x_list).to(DEVICE)
train_dl = DataLoader(
    TensorDataset(all_rows),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [None]:
# Sanity check: the first element of each batch is a (BATCH_SIZE, 3) tensor of rows
xb = next(iter(train_dl), None)
if xb is not None:
    assert xb[0].shape == (BATCH_SIZE, 3)

### Words we'll use to assess the quality of the model ...

In [None]:
# Probe words whose nearest neighbours are printed during training
VALID_WORDS = ['money', 'lion', 'africa', 'musician', 'dance']
valid_ids = torch.tensor([stoi[word] for word in VALID_WORDS])

### Get the model.

In [None]:
class CBOWNegativeSampling(nn.Module):
    """Continuous-bag-of-words word2vec model scored with negative sampling.

    Each input row is ``[context_1, ..., context_C, center]`` (token ids). The model
    averages the ``A`` embeddings of the context words and returns, as a logit, the
    dot product of that average with the ``B`` embedding of the center word.

    Args:
        vocab_size: number of rows in each embedding table.
        embed_dim: embedding dimension D.
        init_range: both tables are initialized uniformly in [-init_range, init_range].
    """

    def __init__(self, vocab_size, embed_dim, init_range=0.5):
        super().__init__()
        self.init_range = init_range
        self.A = nn.Embedding(vocab_size, embed_dim)  # context-word embeddings (averaged in forward)
        self.B = nn.Embedding(vocab_size, embed_dim)  # center-word (target) embeddings
        self.init_weights()

    def init_weights(self):
        """(Re-)initialize both embedding tables uniformly in [-init_range, init_range]."""
        # Is this the best way? Not sure
        nn.init.uniform_(self.A.weight, -self.init_range, self.init_range)
        nn.init.uniform_(self.B.weight, -self.init_range, self.init_range)

    def forward(self, x):
        """Score every row of ``x``.

        Args:
            x: integer tensor of shape (N, C + 1). The first C >= 1 columns are the
               context word ids (C = 2 * WINDOW in this notebook) and the last column
               is the center word id.

        Returns:
            Tensor of shape (N, 1) holding one logit per row.
        """
        w_context, wc = x[:, :-1], x[:, -1]

        # (N, C, D) -> (N, D): average the context embeddings
        a_avg = self.A(w_context).mean(dim=1)

        # (N, D): embedding of the center word
        b = self.B(wc)

        # (N, 1): dot product between averaged context and center embeddings
        return (a_avg * b).sum(dim=-1, keepdim=True)

In [None]:
@torch.no_grad()
def validate_embeddings(
    model,
    valid_ids,
    itos,
    top_k=10,
):
    """Print the `top_k` nearest neighbours (cosine similarity) of each probe word.

    Uses the context embeddings ``model.A``; other strategies include using the
    target embeddings, or the mean of context and target embeddings.

    Args:
        model: object with an ``A`` embedding layer (e.g. ``CBOWNegativeSampling``).
        valid_ids: 1-D integer tensor (or sequence) of probe word ids.
        itos: int -> word lookup.
        top_k: number of neighbours printed per probe word (top k items are displayed).
    """
    weights = model.A.weight.detach().cpu()

    # L2-normalize every row so dot products are cosine similarities.
    # The clamp keeps an all-zero row at zero instead of producing NaNs.
    normalized = weights / weights.norm(dim=1, keepdim=True).clamp_min(1e-12)

    probe_ids = torch.as_tensor(valid_ids).long().reshape(-1).cpu()

    # (S, V): similarity between each probe word (S) and every vocabulary word (V)
    similarity = normalized[probe_ids] @ normalized.T

    # Exclude each probe word itself explicitly instead of assuming it ranks first
    similarity[torch.arange(len(probe_ids)), probe_ids] = float('-inf')

    k = min(top_k, similarity.shape[1] - 1)
    top_ids = similarity.topk(k, dim=1).indices

    for word_id, neighbours in zip(probe_ids.tolist(), top_ids.tolist()):
        similar_word_str = ', '.join(itos[j] for j in neighbours)
        print(f"{itos[word_id]}: {similar_word_str}")

    print('\n')

### Set up the model

In [None]:
# Optimization hyperparameters
EMBED_DIM = 300   # embedding size D of both tables
NUM_EPOCHS = 10   # same value as in the configuration cell at the top
LR = 10.0         # initial SGD learning rate (gradients are clipped during training)

In [None]:
model = CBOWNegativeSampling(len(VOCAB), EMBED_DIM).to(DEVICE)
optimizer = torch.optim.SGD(params=model.parameters(), lr=LR)
# StepLR(step_size=1, gamma=0.1) multiplies the learning rate by 0.1 after every
# epoch: LR, LR / 10, LR / 100, ... Is this a good idea?
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

In [None]:
# Display the module summary: the two embedding tables A and B
model

CBOWNegativeSampling(
  (A): Embedding(63641, 300)
  (B): Embedding(63641, 300)
)

In [None]:
# Nearest neighbours before any training (random weights, so they should look arbitrary)
validate_embeddings(model=model, valid_ids=valid_ids, itos=itos)

money: tfl, riding, samara, aphra, borrow, orford, predictive, theodosian, designing, proposing
lion: att, stirrup, outdated, schoolmaster, beckham, musicianship, cultivators, suffixes, thread, alter
africa: foreseeable, wali, haec, scheele, client, gamers, surfer, bhopal, detractors, alu
musician: forebears, vx, cartoons, ideological, housed, could, hydrolysis, bestowing, terrorized, mlas
dance: manmade, mulligan, oasis, orgy, cognitive, string, provability, eapc, kitab, radiocommunications




### Train the model

In [None]:
ratios = []

def train(dataloader, model, optimizer, epoch, log_interval=500, max_grad_norm=0.1):
    """Run one epoch of CBOW training with negative sampling.

    For each batch of positive rows ``[context..., center]`` the loss is

        mean_i [ -log sigma(s_i) - sum_{k=1..K} log sigma(-s_ik) ]

    where ``s_i`` is the model logit of the observed (context, center) pair and
    ``s_ik`` the logit of the same context paired with the k-th negative center
    word, drawn from the global negative-sampling distribution ``p``.

    Side effects: updates ``model`` in place; appends, every batch, one log10
    update-to-data ratio per parameter to the global ``ratios`` list; every
    ``log_interval`` batches prints the running loss and the nearest neighbours of
    the global ``valid_ids``.

    Args:
        dataloader: yields batches whose first element is an (N, C + 1) id tensor.
        model: a ``CBOWNegativeSampling``-like module returning (N, 1) logits.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, used only for logging.
        log_interval: print progress every this many batches.
        max_grad_norm: gradients are clipped to this total norm before each step.
    """
    model.train()
    total_loss, total_batches = 0.0, 0

    for idx, x_batch in enumerate(tqdm(dataloader)):
        # TensorDataset batches hold a single tensor of rows
        x_batch = x_batch[0]
        batch_size = x_batch.shape[0]

        # Zero the gradient so they don't accumulate
        optimizer.zero_grad()

        # Positive rows are observed (context, center) pairs with target 1
        logits = model(x_batch)
        positive_loss = nn.functional.binary_cross_entropy_with_logits(logits, torch.ones_like(logits))

        # K negative center words per positive row; `p` lives on the CPU
        negative_centers = torch.multinomial(p, batch_size * K, replacement=True).to(x_batch.device)

        # Repeat each row's context K times consecutively, so rows i*K .. i*K + K - 1
        # all belong to positive row i, e.g. with K = 2:
        #   [(a, b), (c, d)] -> [(a, b), (a, b), (c, d), (c, d)]
        negative_contexts = x_batch[:, :-1].repeat_interleave(K, dim=0)
        x_batch_negative = torch.cat([negative_contexts, negative_centers.reshape(-1, 1)], dim=1)

        # -log sigma(-s) summed over the K negatives of each positive row, then averaged
        # over rows. logsigmoid is the numerically stable form of sigmoid().log().
        negative_logits = model(x_batch_negative).reshape(batch_size, K)
        negative_loss = -nn.functional.logsigmoid(-negative_logits).sum(dim=1).mean()

        loss = positive_loss + negative_loss

        # Get the gradients via back propagation
        loss.backward()

        # Clip the gradients? Generally a good idea
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        # log10(update size / parameter size) per parameter; a ratio of roughly 0.001
        # (i.e. about -3 here) suggests a reasonable learning rate. Use the optimizer's
        # current lr so the ratio stays meaningful after the scheduler lowers it.
        with torch.no_grad():
            lr = optimizer.param_groups[0]['lr']
            ratios.append([
                (lr * param.grad.std() / param.data.std()).log10().item()
                for param in model.parameters()
            ])

        # Do an optimization step. Update the parameters A and B
        optimizer.step()

        total_loss += loss.item()
        total_batches += 1

        if idx % log_interval == 0:
            print(
                "| epoch {:3d} | {:5d}/{:5d} batches "
                "| loss {:8.3f} ".format(
                    epoch,
                    idx,
                    len(dataloader),
                    total_loss / total_batches,
                )
            )
            validate_embeddings(model, valid_ids, itos)
            total_loss, total_batches = 0.0, 0

### Some results from the run look like below:

Within about 2 iterations you should see sensible associations.
Paste here a screenshot of the closest vectors.

In [None]:
for epoch in range(1, NUM_EPOCHS + 1):
    epoch_start_time = time.time()

    train(train_dl, model, optimizer, epoch)
    # Advance the learning-rate schedule once per epoch (the schedule is chosen where
    # `scheduler` is created)
    scheduler.step()

    # Report how long the epoch took and the learning rate used for the next one
    print(
        f"| end of epoch {epoch:3d} | time: {time.time() - epoch_start_time:8.1f}s "
        f"| next lr {scheduler.get_last_lr()[0]:.2e}"
    )

1it [00:02,  2.25s/it]

| epoch   1 |     0/32580 batches | loss    4.058 
money: tfl, riding, samara, aphra, borrow, orford, predictive, theodosian, designing, proposing
lion: att, stirrup, outdated, schoolmaster, beckham, musicianship, cultivators, suffixes, thread, alter
africa: foreseeable, wali, haec, scheele, client, gamers, surfer, bhopal, detractors, alu
musician: forebears, vx, cartoons, ideological, housed, could, hydrolysis, bestowing, terrorized, mlas
dance: manmade, mulligan, oasis, orgy, cognitive, string, provability, eapc, kitab, radiocommunications




501it [04:16,  2.07it/s]

| epoch   1 |   500/32580 batches | loss    3.713 
money: tfl, samara, riding, orford, borrow, aphra, theodosian, predictive, designing, finality
lion: att, stirrup, outdated, schoolmaster, musicianship, beckham, cultivators, thread, suffixes, alter
africa: foreseeable, wali, haec, client, surfer, scheele, bhopal, detractors, gamers, eta
musician: forebears, vx, cartoons, ideological, housed, hydrolysis, bestowing, could, overlooks, mlas
dance: manmade, mulligan, oasis, orgy, cognitive, string, kitab, eapc, provability, radiocommunications




1000it [08:35,  1.50it/s]

| epoch   1 |  1000/32580 batches | loss    3.356 
money: tfl, samara, riding, orford, borrow, theodosian, aphra, predictive, designing, finality
lion: att, stirrup, outdated, schoolmaster, musicianship, beckham, cultivators, thread, suffixes, alter

1001it [08:36,  1.37it/s]


africa: foreseeable, surfer, haec, wali, client, bhopal, scheele, detractors, gamers, alu
musician: vx, ideological, forebears, cartoons, could, housed, hydrolysis, bestowing, mlas, overlooks
dance: manmade, mulligan, oasis, string, orgy, cognitive, eapc, kitab, provability, stringent




1501it [12:56,  2.03it/s]

| epoch   1 |  1500/32580 batches | loss    3.029 
money: tfl, samara, riding, orford, borrow, theodosian, aphra, predictive, finality, terms
lion: att, stirrup, outdated, schoolmaster, beckham, musicianship, cultivators, thread, generic, suffixes
africa: foreseeable, surfer, haec, client, wali, scheele, bhopal, detractors, gamers, alu
musician: ideological, forebears, vx, could, cartoons, housed, hydrolysis, mlas, tsui, gujarat
dance: manmade, oasis, mulligan, orgy, stringent, string, cognitive, alleviating, desecration, eapc




2001it [17:17,  1.78it/s]

| epoch   1 |  2000/32580 batches | loss    2.788 
money: tfl, samara, orford, borrow, riding, theodosian, terms, finality, designing, predictive
lion: att, stirrup, outdated, schoolmaster, musicianship, cultivators, beckham, thread, generic, mercian
africa: foreseeable, surfer, haec, client, wali, bhopal, scheele, detractors, gamers, glial
musician: could, ideological, forebears, cartoons, vx, housed, hydrolysis, overlooks, mlas, tsui
dance: manmade, mulligan, oasis, orgy, stringent, string, alleviating, radiocommunications, cognitive, untrained




2501it [21:45,  2.01it/s]

| epoch   1 |  2500/32580 batches | loss    2.616 
money: tfl, orford, samara, riding, borrow, terms, theodosian, finality, designing, deallocation
lion: att, stirrup, outdated, service, schoolmaster, beckham, cultivators, musicianship, mercian, thread
africa: foreseeable, seven, some, only, surfer, many, four, five, war, second
musician: could, ideological, forebears, cartoons, vx, housed, hydrolysis, duchess, consultancies, overlooks
dance: manmade, mulligan, oasis, stringent, string, alleviating, untrained, orgy, radiocommunications, desecration




3000it [26:08,  1.56it/s]

| epoch   1 |  3000/32580 batches | loss    2.485 
money: tfl, terms, orford, samara, borrow, riding, theodosian, finality, predictive, designing
lion: service, att, stirrup, outdated, beckham, schoolmaster, cultivators, musicianship, mercian, thread
africa: seven, foreseeable, war, some, four, five, many, three, two, united
musician: could, ideological, cartoons, forebears, number, fact, vx, housed, hydrolysis, duchess
dance: manmade, oasis, stringent, string, mulligan, untrained, alleviating, eapc, radiocommunications, panegyric



3001it [26:09,  1.40it/s]




3501it [30:34,  2.09it/s]

| epoch   1 |  3500/32580 batches | loss    2.374 
money: tfl, orford, terms, samara, riding, borrow, finality, theodosian, predictive, designing
lion: service, att, stirrup, outdated, beckham, schoolmaster, cultivators, musicianship, mercian, thread
africa: some, war, seven, five, three, many, united, only, six, which
musician: could, number, ideological, fact, well, cartoons, forebears, x, housed, duchess
dance: manmade, stringent, string, oasis, eapc, alleviating, untrained, radiocommunications, mulligan, panegyric




4000it [34:50,  1.80it/s]

| epoch   1 |  4000/32580 batches | loss    2.290 
money: tfl, terms, orford, samara, riding, borrow, finality, theodosian, designing, predictive


4001it [34:51,  1.51it/s]

lion: service, att, stirrup, outdated, beckham, schoolmaster, cultivators, thread, musicianship, mercian
africa: some, war, seven, many, five, two, three, united, only, six
musician: could, number, well, fact, x, ideological, including, form, five, all
dance: manmade, string, stringent, untrained, oasis, radiocommunications, alleviating, eapc, panegyric, weapon




4501it [39:11,  1.97it/s]

| epoch   1 |  4500/32580 batches | loss    2.217 
money: tfl, terms, one, example, orford, order, his, government, samara, as
lion: service, stirrup, att, outdated, beckham, economic, schoolmaster, cultivators, thread, musicianship
africa: many, united, some, seven, war, five, six, three, only, eight
musician: could, number, well, fact, x, including, form, under, u, common
dance: string, manmade, stringent, untrained, alleviating, oasis, panegyric, eapc, weapon, radiocommunications




5000it [43:33,  1.77it/s]

| epoch   1 |  5000/32580 batches | loss    2.156 
money: tfl, terms, example, order, s, one, his, no, both, government


5001it [43:34,  1.50it/s]

lion: service, stirrup, att, outdated, economic, beckham, schoolmaster, thread, cultivators, musicianship
africa: five, united, many, seven, some, three, six, which, war, four
musician: could, number, x, well, fact, including, form, u, three, common
dance: string, stringent, manmade, alleviating, untrained, weapon, panegyric, eapc, oasis, somerset




5501it [47:58,  2.03it/s]

| epoch   1 |  5500/32580 batches | loss    2.102 
money: terms, tfl, example, order, one, government, later, time, no, his
lion: service, stirrup, att, economic, outdated, thread, beckham, cultivators, schoolmaster, classical
africa: five, many, seven, united, some, six, three, four, war, two
musician: could, number, x, well, fact, form, u, including, three, common
dance: string, stringent, manmade, alleviating, weapon, untrained, panegyric, eapc, somerset, infusion




5973it [52:13,  1.00s/it]