### Skip-Gram Word2Vec (with Negative Sampling)

The skip-gram word2vec algorithm is a simple self-supervised model for learning dense word embedding vectors from a corpus of text. It is trained on the task of predicting the `probability distribution` over `context words` given a `center word`, for every possible center word in the vocabulary. The parameters of this model consist of two separate $|V| \times D$ embedding matrices ($|V|$ is the vocabulary size and $D$ is the embedding dimension): matrix $W$, whose rows are the embeddings of context (outside) words, and matrix $C$, whose rows are the embeddings of center words. (We could instead use a single embedding for both center and context words; however, keeping separate embeddings is more convenient.) The model for the probability distribution is then defined as follows:

$P(w|c) = \frac{\exp(\vec{w} \cdot \vec{c})}{\sum_{w' \in V} \exp(\vec{w}' \cdot \vec{c})}$

where $\vec{w}$ and $\vec{c}$ are the embedding vectors of the context and center words respectively. Note that this is just a softmax over the dot products of every possible context word $w$ with a particular center word $c$.
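
As a minimal sketch of the softmax above (the matrix `W_toy`, the vector `c_toy`, and the helper `context_distribution` are hypothetical names used only for illustration, with random toy values rather than trained embeddings):

```python
import numpy as np

def context_distribution(W, c_vec):
    """Softmax over the dot products of every row of W with the center vector."""
    logits = W @ c_vec
    logits = logits - logits.max()  # subtract the max for numerical stability
    exp_logits = np.exp(logits)
    return exp_logits / exp_logits.sum()

rng = np.random.default_rng(0)
W_toy = rng.normal(size=(5, 3))   # hypothetical context embeddings, |V| = 5, D = 3
c_toy = rng.normal(size=3)        # hypothetical center-word embedding
probs = context_distribution(W_toy, c_toy)
print(probs, probs.sum())
```

The denominator touches every row of the context matrix, which is the cost that negative sampling avoids.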

Because the vocabulary size $|V|$ is large, computing this probability distribution is very inefficient (the sum in the denominator runs over the whole vocabulary). Instead of computing a probability distribution over all possible context words, we simplify our task into a `binary classification` problem. Given a pair of context word $w$ and center word $c$, our simplified task is to train a logistic regression classifier to predict whether $w$ actually occurs in the context of $c$ or not (we use label $1$ for True and $0$ for False). We define this simple binary classification problem as follows:

$P(y=1|w_{pos},c) = \sigma (\vec{w}_{pos} \cdot \vec{c})$

$P(y=0|w_{neg},c) = 1-\sigma (\vec{w}_{neg} \cdot \vec{c}) = \sigma (-\vec{w}_{neg} \cdot \vec{c})$

where $\sigma()$ is the sigmoid function, $w_{pos}$ denotes a true context word and $w_{neg}$ denotes a `noise` word which is not a true context word. For training this classifier, we will use $k$ times as many noise words as context words (this reflects the fact that, for each center word, far fewer vocabulary words appear in its context than do not). During training, we will simply slide a context window of half-size $L$ over the training corpus, so at each position this gives us up to $2L$ different positive pairs $\{(w^{i}_{pos},c) | i =1,2,\ldots,2L\}$. For each of these positive pairs, we generate $k$ negative samples by sampling from the unigram probability distribution over the vocabulary (making sure that none of these noise words match the positive word). Then we compute the negative log-likelihood loss for each positive pair along with its $k$ negative pairs:

$L = -\log(P(y=1|w_{pos},c)) - \sum_{j=1}^k \log(P(y=0|w^{j}_{neg},c)) = -\log(\sigma (\vec{w}_{pos} \cdot \vec{c})) - \sum_{j=1}^k \log(\sigma (-\vec{w}^{j}_{neg} \cdot \vec{c}))$
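
The window-sliding step described above can be sketched in a few lines. This is an illustration only: `positive_pairs` and the toy sentence are hypothetical, and the pairs are ordered `(context, center)` to match the $(w_{pos}, c)$ notation.

```python
def positive_pairs(tokens, L):
    """Return (context, center) pairs for a window of half-size L slid over tokens."""
    pairs = []
    for i, center in enumerate(tokens):
        # up to L words on each side; fewer near the sentence boundaries
        context = tokens[max(0, i - L):i] + tokens[i + 1:i + 1 + L]
        pairs.extend((w, center) for w in context)
    return pairs

tokens = ["the", "quick", "brown", "fox", "jumps"]   # hypothetical toy sentence
pairs = positive_pairs(tokens, L=1)
print(pairs)
```

Positions near the start or end of the sequence yield fewer than $2L$ pairs, since the window is clipped at the boundaries.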

We can think of each window position as providing us with a batch of up to $2L$ positive instances and $2kL$ negative instances. We can then minimize this loss via gradient descent. The gradients are:

$\frac{\partial L}{\partial \vec c} = (\sigma (\vec{w}_{pos} \cdot \vec{c})-1) \vec{w}_{pos} + \sum_{j=1}^k \sigma (\vec{w}^{j}_{neg} \cdot \vec{c}) \vec{w}^{j}_{neg}$

$\frac{\partial L}{\partial \vec{w}_{pos}} = (\sigma (\vec{w}_{pos} \cdot \vec{c})-1) \vec{c}$

$\frac{\partial L}{\partial \vec{w}^{j}_{neg}} = \sigma (\vec{w}^{j}_{neg} \cdot \vec{c}) \vec{c}$
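
One way to gain confidence in these three gradient formulas is a finite-difference check. The sketch below is not part of the training code; the function names (`sgns_loss`, `sgns_grads`, `numerical_grad`) and the toy dimensions are assumptions made for illustration.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sgns_loss(c, w_pos, w_negs):
    """Negative-sampling loss for one positive pair and k noise words."""
    return -np.log(sigmoid(w_pos @ c)) - np.log(sigmoid(-(w_negs @ c))).sum()

def sgns_grads(c, w_pos, w_negs):
    """Analytic gradients, transcribed from the three formulas above."""
    s_pos = sigmoid(w_pos @ c)            # scalar
    s_neg = sigmoid(w_negs @ c)           # shape (k,)
    grad_c = (s_pos - 1.0) * w_pos + s_neg @ w_negs
    grad_w_pos = (s_pos - 1.0) * c
    grad_w_negs = s_neg[:, None] * c      # shape (k, dim)
    return grad_c, grad_w_pos, grad_w_negs

def numerical_grad(f, x, eps=1e-6):
    """Central finite differences of the scalar f() with respect to array x (perturbed in place)."""
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        old = x[idx]
        x[idx] = old + eps
        f_plus = f()
        x[idx] = old - eps
        f_minus = f()
        x[idx] = old
        grad[idx] = (f_plus - f_minus) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
dim, k = 4, 3                              # toy sizes
c = rng.normal(size=dim)
w_pos = rng.normal(size=dim)
w_negs = rng.normal(size=(k, dim))

grad_c, grad_w_pos, grad_w_negs = sgns_grads(c, w_pos, w_negs)
loss_fn = lambda: sgns_loss(c, w_pos, w_negs)
max_err = max(
    np.abs(grad_c - numerical_grad(loss_fn, c)).max(),
    np.abs(grad_w_pos - numerical_grad(loss_fn, w_pos)).max(),
    np.abs(grad_w_negs - numerical_grad(loss_fn, w_negs)).max(),
)
print(f"max abs difference: {max_err:.2e}")
```

A small maximum difference indicates that the analytic and numerical gradients agree on this toy instance.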


Note: When sampling the noise words, instead of using the unigram probability distribution $P(w)$ over the vocabulary words, it's better to use a weighted version of this distribution, $P_{\alpha}(w)$. The unigram distribution is

$P(w) = \frac{count(w)}{ \sum_{w'\in V} count(w')}$

and the weighted unigram distribution is defined as:

$P_{\alpha}(w) = \frac{count(w)^{\alpha}}{ \sum_{w'\in V} count(w')^{\alpha}}$

where $\alpha$ is an exponent between $0$ and $1$. This kind of weighting slightly increases the probabilities of the rarer words and slightly suppresses the probabilities of the most common words. Empirically, $\alpha = 0.75$ tends to work well.
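
To see the effect of the exponent on a small scale, the sketch below compares the two distributions on made-up counts (the `counts` array and the helper `weighted_unigram` are hypothetical and not taken from the dataset):

```python
import numpy as np

counts = np.array([1000.0, 100.0, 10.0, 1.0])   # hypothetical toy counts

def weighted_unigram(counts, alpha):
    """Normalize count**alpha into a probability distribution."""
    weights = counts ** alpha
    return weights / weights.sum()

p_unigram = weighted_unigram(counts, alpha=1.0)    # plain unigram distribution
p_weighted = weighted_unigram(counts, alpha=0.75)  # weighted version
for cnt, p1, p2 in zip(counts, p_unigram, p_weighted):
    print(f"count={cnt:6.0f}  P(w)={p1:.4f}  P_alpha(w)={p2:.4f}")
```

On these toy counts the most frequent word loses probability mass and the rarest word gains some, which is the behaviour described above.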



We will implement and train a skip-gram word2vec model using the Stanford Sentiment Treebank dataset.

In [2]:
from collections import defaultdict
import numpy as np

In [2]:
def check_punc(w):
    return any(c.isalpha() for c in w)

# remove punctuations from list of words and apply lowercase folding 
def preprocess(s):
    words = s.lower().strip().split()[1:]
    words = [w for w in words if check_punc(w)]
    return words

# load dataset
word_count = 0
unigram_count = defaultdict(int)
wierd_words = []
with open('datasetSentences.txt','r') as file:
    lines = file.readlines()
    # preprocessing
    sentences = []
    for line in lines[1:]:
        words = preprocess(line)
        s = []
        for word in words:
            # the dataset escapes slashes as "\/"; split such tokens into separate words
            if "\\/" in word:
                ws = word.replace("\\/", " ").split()
                for w in ws:
                    s.append(w)
                    unigram_count[w] += 1
                    word_count += 1
            else:
                s.append(word)    
                unigram_count[word] += 1
                word_count += 1
        sentences.append(s)        
                 

#### The skipgram paper uses a subsampling strategy to get rid of the most frequent words, like stop words. Each word $w_i$ in the training corpus is discarded with a probability given by the following:

$P(w_i) = 1 - \sqrt{\frac{T}{count(w_i)}}$ 

where $T$ is a threshold value which is a small fraction of the corpus total token count $N$ (~$10^{-5} \times N$ in the paper; the code below uses $10^{-4} \times N$) and $count(w_i)$ is the frequency of that word in the corpus. For more frequent words, the square root term is very close to zero and so the word will get discarded with high probability.

We also keep multiple copies of the same sentence to reduce the chances of entirely losing important words due to the subsampling.
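
As a quick worked example of the discard formula, the sketch below evaluates it for a few made-up word counts. The corpus size and counts are hypothetical; the fraction `t=1e-4` and the clipping at zero follow the notebook's own `subsample_prob`.

```python
import math

def discard_prob(count, total_tokens, t=1e-4):
    """Discard probability 1 - sqrt(T / count) with T = t * total_tokens, clipped at 0."""
    return max(0.0, 1.0 - math.sqrt(t * total_tokens / count))

total_tokens = 1_000_000                  # hypothetical corpus size, so T = 100
for count in (50_000, 5_000, 500, 50):    # hypothetical word counts
    print(f"count={count:6d}  discard probability={discard_prob(count, total_tokens):.3f}")
```

Words whose count is at or below $T$ get a discard probability of zero and are always kept.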

In [3]:
def subsample_prob(word, t=1e-4):
    p = max(0, 1 - np.sqrt(t*word_count/unigram_count[word]))
    return p

num_copies = 10
discard_probs = {w:subsample_prob(w) for w in unigram_count.keys()}
sentences_subsampled = [[word for word in s if np.random.random() >= discard_probs[word]] for s in sentences*num_copies]

# remove zero length subsampled sentences
sentences_subsampled = [s for s in sentences_subsampled if len(s) > 0]

In [4]:
# compare before and after subsampling
for i in range(7):
    print("Before subsampling: ", sentences[i])
    print("After subsampling: ", sentences_subsampled[i])

Before subsampling:  ['the', 'rock', 'is', 'destined', 'to', 'be', 'the', '21st', 'century', "'s", 'new', 'conan', 'and', 'that', 'he', "'s", 'going', 'to', 'make', 'a', 'splash', 'even', 'greater', 'than', 'arnold', 'schwarzenegger', 'jean-claud', 'van', 'damme', 'or', 'steven', 'segal']
After subsampling:  ['rock', 'destined', '21st', 'century', 'conan', 'going', 'splash', 'greater', 'than', 'arnold', 'schwarzenegger', 'jean-claud', 'van', 'damme', 'steven', 'segal']
Before subsampling:  ['the', 'gorgeously', 'elaborate', 'continuation', 'of', 'the', 'lord', 'of', 'the', 'rings', 'trilogy', 'is', 'so', 'huge', 'that', 'a', 'column', 'of', 'words', 'can', 'not', 'adequately', 'describe', 'co-writer', 'director', 'peter', 'jackson', "'s", 'expanded', 'vision', 'of', 'j.r.r.', 'tolkien', "'s", 'middle-earth']
After subsampling:  ['gorgeously', 'elaborate', 'continuation', 'lord', 'rings', 'trilogy', 'huge', 'column', 'words', 'adequately', 'describe', 'co-writer', 'peter', 'jackson', 'e

Note that most of the stop words are gone after subsampling.

In [8]:
# create vocabulary
vocab = sorted(list(set([word for sentence in sentences_subsampled for word in sentence])))
word2idx = {w:i for i,w in enumerate(vocab)}
vocab_size = len(vocab)

print(f"Vocab size: {len(vocab)}")
print(f"Num sentences: {len(sentences_subsampled)}")
print(f"Total number of tokens: {sum(len(s) for s in sentences_subsampled)}")

# tokenize the sentences
sentences_tokenized = []
for s in sentences_subsampled:
    sentences_tokenized.append([word2idx[w] for w in s])

Vocab size: 19332
Num sentences: 117918
Total number of tokens: 956449


In [9]:
# unigram weighted probability distribution
alpha = 0.75
P_alpha = np.zeros(shape=(vocab_size))
for i,w in enumerate(vocab):
    P_alpha[i] = unigram_count[w]**alpha
P_alpha = P_alpha / P_alpha.sum()    
unigram_idx = np.arange(0,vocab_size)

In [197]:
# hyperparameters of word2vec model
D = 32 # embedding dim
L = 8  # context window half-size

#### Instead of sliding the context window over every position from start to end of the corpus, we will randomly select a batch of context windows on every epoch. We will also add some randomness to the context window size, by sampling the half-size uniformly from $[1,L]$. A context word at distance $d$ from the center word is only included when the sampled half-size is at least $d$, so closer context words are sampled more often than farther ones. This is helpful because context words that are closer should be related more strongly to the center word than context words that are farther away.
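
The effect of the random window size can be made concrete with a short calculation. This sketch assumes the half-size $R$ is drawn uniformly from $\{1,\ldots,L\}$ as described above; `inclusion_prob` and `L_max` are hypothetical names used only here.

```python
from fractions import Fraction

def inclusion_prob(d, L):
    """Probability that a word at distance d falls inside a window of half-size R ~ Uniform{1..L}."""
    return Fraction(sum(1 for R in range(1, L + 1) if R >= d), L)

L_max = 8   # same value as the hyperparameter L above
for d in range(1, L_max + 1):
    print(f"distance {d}: included with probability {inclusion_prob(d, L_max)}")
```

Under this assumption the inclusion probability falls linearly with distance, from $1$ for adjacent words down to $1/L$ for words at distance $L$.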

In [204]:
def get_random_context(L):
    # first randomly select a sentence (np.random.randint excludes the upper bound)
    sent_idx = np.random.randint(0, len(sentences_tokenized))
    sent = sentences_tokenized[sent_idx]
    # pick random context window half-length between [1..L]
    R = np.random.randint(1, L + 1)
    # pick a random center word from the sentence (any position, including the last)
    c_idx = np.random.randint(0, len(sent))

    center_word = sent[c_idx]
    context_words = sent[max(0,c_idx-R):c_idx] + sent[c_idx+1:c_idx+1+R]

    if len(context_words) == 0:
        return get_random_context(L)
    else:
        return center_word, context_words    


def get_negative_samples(wpos_idx, k=10):
    nsamples = 0
    wnegs = []
    # generate negative samples
    while nsamples < k:
        wneg_idx = np.random.choice(unigram_idx, size=1, p=P_alpha)[0]
        # make sure noise words don't match the positive word
        if wneg_idx != wpos_idx:
            wnegs.append(wneg_idx)
            nsamples += 1
    return wnegs    


def sigmoid(x):
    return 1/(1+np.exp(-x))

# skip-gram with negative sampling loss for a single (w,c) pair
def compute_loss_and_grads(wpos_idx, c_idx, W, C):
    # get negative samples
    wnegs_idx = get_negative_samples(wpos_idx)
    # get embedding vectors
    V, D = W.shape
    c = C[c_idx]   # shape: (D,)
    w_pos = W[wpos_idx]  # shape: (D,)
    w_negs = W[wnegs_idx]  # shape: (k,D)
    
    s_wpos_dot_c = sigmoid(np.dot(w_pos,c))  # scalar
    s_wneg_dot_c = sigmoid(np.dot(w_negs,c)).reshape(w_negs.shape[0],1)  # shape: (k,1)

    # compute loss
    loss =  -np.log(s_wpos_dot_c) - np.log(1-s_wneg_dot_c).sum()

    # compute gradients
    grad_c = (s_wpos_dot_c-1) * w_pos +  (s_wneg_dot_c * w_negs).sum(axis=0)  # shape: (D,)
    grad_wpos = (s_wpos_dot_c-1) * c  # shape: (D,)
    grad_wnegs = s_wneg_dot_c * c  # shape: (k,D)
    
    return loss, grad_c, grad_wpos, grad_wnegs, wnegs_idx


# compute total loss and accumulated gradients for a single context window
def skipgram(center_word_idx, context_words_idx, W, C):
    grad_C = np.zeros_like(C) 
    grad_W = np.zeros_like(W) 
    total_loss = 0.0

    # compute loss and accumulate gradients for each positive context word and negative samples
    for wpos_idx in context_words_idx:
        loss, grad_c, grad_wpos, grad_wnegs, wnegs_idx = compute_loss_and_grads(wpos_idx, center_word_idx, W, C) 
        total_loss += loss
        grad_C[center_word_idx] += grad_c
        # negative samples could be repeated, so we need to be careful to add all the contributions
        # can't use the += operator, which would only add repeated contributions once
        np.add.at(grad_W, wnegs_idx, grad_wnegs)
        grad_W[wpos_idx] += grad_wpos

    return total_loss, grad_W, grad_C


# perform gradient descent update of parameters over a mini batch
def train_step(W, C, L, batch_size, alpha):
    grad_C = np.zeros_like(C) 
    grad_W = np.zeros_like(W)
    total_loss = 0 
    for _ in range(batch_size):
        # get a random context window
        center_word_idx, context_words_idx = get_random_context(L)
        # compute loss and gradients for this window 
        loss, grad_W_window, grad_C_window = skipgram(center_word_idx, context_words_idx, W, C)
        # accumulate loss and grads
        total_loss += loss
        grad_W += grad_W_window
        grad_C += grad_C_window

    # average over mini-batch
    total_loss /= batch_size
    grad_W /= batch_size    
    grad_C /= batch_size    

    # perform sgd update of parameters
    W -= alpha * grad_W
    C -= alpha * grad_C

    return W, C, total_loss

# training loop
def train(W, C, L, num_epochs=10, batch_size=32, alpha=0.01, print_every=100):
    for epoch in range(num_epochs):
        W, C, loss = train_step(W, C, L, batch_size, alpha)
        if epoch%print_every==0:
            print(f"Epoch #{epoch}, Train Loss: {loss}")

    return W, C    
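
The comment in `skipgram` about repeated negative samples can be checked in isolation. The toy arrays below are hypothetical and only contrast fancy-indexed `+=` with `np.add.at` when an index appears twice:

```python
import numpy as np

grad_plus = np.zeros((4, 2))
grad_at = np.zeros((4, 2))
idx = [1, 1, 3]                    # row 1 is sampled twice
updates = np.ones((3, 2))

grad_plus[idx] += updates          # buffered: the repeated row 1 is only updated once
np.add.at(grad_at, idx, updates)   # unbuffered: both contributions to row 1 are added

print(grad_plus[1], grad_at[1])
```

With `+=`, one of the two contributions to row 1 is silently lost, which is why the accumulation uses `np.add.at`.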

In [205]:
# parameters: embedding matrices
# (W and C must already exist; uncomment the two lines below to initialize them on a first run)
#W = 0.001 * np.random.randn(vocab_size, D)
#C = 0.001 * np.random.randn(vocab_size, D)

W, C = train(W, C, L, num_epochs=5000, batch_size=100, alpha=0.1, print_every=50)

Epoch #0, Train Loss: 26.200831558512604
Epoch #50, Train Loss: 23.382870880154247
Epoch #100, Train Loss: 24.619770497523135
Epoch #150, Train Loss: 24.43834585654396
Epoch #200, Train Loss: 26.31312494344878
Epoch #250, Train Loss: 24.96004441332466
Epoch #300, Train Loss: 26.903127029699125
Epoch #350, Train Loss: 25.32494968546413
Epoch #400, Train Loss: 26.5285282611221
Epoch #450, Train Loss: 25.17586418564848
Epoch #500, Train Loss: 25.996913461858494
Epoch #550, Train Loss: 27.228483777467112
Epoch #600, Train Loss: 25.20039220782222
Epoch #650, Train Loss: 22.765463235425944
Epoch #700, Train Loss: 23.820647346142145
Epoch #750, Train Loss: 26.113035287980914
Epoch #800, Train Loss: 27.041191851181903
Epoch #850, Train Loss: 23.575318658064297
Epoch #900, Train Loss: 25.529429717116
Epoch #950, Train Loss: 24.147396035317566
Epoch #1000, Train Loss: 26.09767534426002
Epoch #1050, Train Loss: 26.18397234267203
Epoch #1100, Train Loss: 23.992082590350456
Epoch #1150, Train Loss:

In [207]:
W, C, = train(W, C, L, num_epochs=20000, batch_size=100, alpha=0.5, print_every=50)

Epoch #0, Train Loss: 20.65272676847883
Epoch #50, Train Loss: 22.239343190398717
Epoch #100, Train Loss: 20.96164221369737
Epoch #150, Train Loss: 21.99620678015732
Epoch #200, Train Loss: 17.174604699019362
Epoch #250, Train Loss: 17.275600106957167
Epoch #300, Train Loss: 19.987688334956484
Epoch #350, Train Loss: 19.86592998149786
Epoch #400, Train Loss: 21.268760764781064
Epoch #450, Train Loss: 19.423829854752015
Epoch #500, Train Loss: 18.66069026663032
Epoch #550, Train Loss: 20.522580413296136
Epoch #600, Train Loss: 16.930430685389723
Epoch #650, Train Loss: 18.07658792431563
Epoch #700, Train Loss: 19.948015291515734
Epoch #750, Train Loss: 18.506683751477393
Epoch #800, Train Loss: 17.747519006052435
Epoch #850, Train Loss: 19.271912766334747
Epoch #900, Train Loss: 18.515824848053803
Epoch #950, Train Loss: 18.96908352420442
Epoch #1000, Train Loss: 17.384877709700653
Epoch #1050, Train Loss: 18.376070687138114
Epoch #1100, Train Loss: 16.632214649504657
Epoch #1150, Train

In [208]:
np.savez('word2vec_params.npz', arr1=W, arr2=C)

In [None]:
#arrays = np.load('word2vec_params.npz')
#W, C = arrays['arr1'], arrays['arr2'] 

In [244]:
# use the center-word embeddings as the final word vectors
E = np.array(C)

# normalize lengths of vectors

E = E / np.linalg.norm(E, axis=1, keepdims=True)

In [245]:
def find_most_similar(word, n=20):
    print(f"Finding most similar words for idx: {word2idx[word]}")
    # get the (normalized) embedding of this word
    w_emb = E[word2idx[word]]
    # compute dot product with all other words
    similarity_scores = np.dot(E,w_emb)
    # find the indices sorted from largest to smallest score
    idx = np.argsort(similarity_scores)[::-1]
    print(f"best idx: {idx[:n]}")

    # get the n highest scoring words
    best = []
    for i in range(n):
        best.append(vocab[idx[i]])
    return best, similarity_scores[idx[:n]]    
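
Because the rows of `E` are normalized to unit length, the dot product used in `find_most_similar` is the cosine similarity. A minimal sketch with hypothetical 2-d vectors (`E_toy` is made up for illustration):

```python
import numpy as np

E_toy = np.array([[3.0, 4.0], [4.0, 3.0], [-4.0, 3.0]])  # hypothetical 2-d embeddings
E_unit = E_toy / np.linalg.norm(E_toy, axis=1, keepdims=True)

scores = E_unit @ E_unit[0]          # cosine similarity of every row with row 0
ranking = np.argsort(scores)[::-1]   # indices sorted from most to least similar
print(scores, ranking)
```

The query word always ranks first with a score of 1, so the interesting neighbours start from the second entry.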

In [248]:
w = "awesome"
similar_words, scores = find_most_similar(w)
print(f"Words most similar to '{w}': {similar_words}")
print(f"scores: {scores}")

Finding most similar words for idx: 1243
best idx: [ 1243 18186 12156  2939 11258  3169  3565  6976 10361  9626  2414 15552
  8749 16437 18801 12605  8677  6551 13061 18090]
Words most similar to 'awesome': ['awesome', 'unsentimental', 'passionate', 'clashing', 'neo-nazism', 'color', 'controlled', 'generous', 'masterful', 'lean', 'campanella', 'sluggish', 'instincts', 'sturdy', 'well-edited', 'plotted', 'innocence', 'forcefully', 'profile', 'uniquely']
scores: [1.         0.99989264 0.99987931 0.99986776 0.99986762 0.99986141
 0.99985859 0.99985747 0.99985041 0.99984668 0.99984502 0.99983811
 0.99983667 0.99983664 0.99983146 0.99983003 0.99982877 0.99982836
 0.99982625 0.99982193]


#### PyTorchified implementation of skipgram word2vec

In [124]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import random
from tqdm import tqdm
from nltk.corpus import brown
from nltk.corpus import stopwords

In [130]:
# create pytorch dataset for neatly packaging the Stanford Sentiment Treebank sentences
class SSTDataset(Dataset):
    def __init__(self, window_size=6, num_negatives=20, subsample=True, subsample_t=1e-4):
        self.L = window_size   # context window size
        self.k = num_negatives # number of negative samples per positive
        self.subsample_t = subsample_t

        # get sentences from file
        sentences, self.unigram_count, self.word_count = self.get_sentences()
        # subsample the sentences
        if subsample:
            sentences_subsampled = self.subsample_sentences(sentences)
        else:
            sentences_subsampled = sentences
        # create vocabulary
        self.vocab = sorted(list(set([word for sentence in sentences_subsampled for word in sentence])))
        self.word2idx = {w:i for i,w in enumerate(self.vocab)}
        print(f"Vocab size: {len(self.vocab)}, word_count: {self.word_count}")
        print(f"Num sentences: {len(sentences_subsampled)}")
        print(f"Total number of tokens after subsampling: {sum(len(s) for s in sentences_subsampled)}")
        # tokenize the sentences
        self.sentences_tokenized = []
        for s in sentences_subsampled:
            self.sentences_tokenized.append([self.word2idx[w] for w in s])
        # generate all positive pairs: (w_pos, c)
        self.pos_pairs = self.generate_positive_pairs()
        # compute weighted unigram probability distribution
        self.P_alpha = self.compute_weighted_unigram_dist()

    def __len__(self):
        return len(self.pos_pairs)

    def __getitem__(self, idx):
        # get positive pair
        pos_pair = self.pos_pairs[idx]
        # generate negative words
        wnegs = self.get_negative_samples(pos_pair[0])
        return {"center": pos_pair[1], "positive": pos_pair[0], "negatives": wnegs}
    
    def on_epoch_end(self):
        # generate new pos pairs
        self.pos_pairs = self.generate_positive_pairs()
        random.shuffle(self.pos_pairs)

    def get_negative_samples(self, wpos_idx):
        # generate negative samples (generate extras to account for removal of matches with wpos_idx)
        wnegs = torch.multinomial(self.P_alpha, self.k+50)
        # remove matches with the positive word
        wnegs = wnegs[wnegs != wpos_idx]    
        
        # if not enough negative samples, generate more
        while  len(wnegs) < self.k:
            # generate more extras
            negs_extra = torch.multinomial(self.P_alpha, self.k+50)
            # remove matches with positive word
            negs_extra = negs_extra[negs_extra != wpos_idx] 
            wnegs = torch.cat([wnegs,negs_extra])
    
        return wnegs[:self.k].tolist()    

    def compute_weighted_unigram_dist(self, alpha=0.75):
        P_alpha = torch.zeros(size=(len(self.vocab),))
        for i,w in enumerate(self.vocab):
            P_alpha[i] = self.unigram_count[w]**alpha
        P_alpha = P_alpha / P_alpha.sum()    
        return P_alpha

    def generate_positive_pairs(self):
        pos_pairs = []
        for s in self.sentences_tokenized:
            for i, w in enumerate(s):
                # randomly pick a context window size between 1..L (torch.randint excludes the upper bound)
                R = torch.randint(1,self.L+1, size=(1,)).item()
                c = w # center_word 
                context_words = s[max(0,i-R):i] + s[i+1:i+1+R]
                for w_pos in context_words:
                    pos_pairs.append((w_pos,c)) 
        return pos_pairs
    
    def check_punc(self, w):
        return any(c.isalpha() for c in w)

    # remove punctuations from list of words and apply lowercase folding 
    def preprocess(self, s):
        words = s.lower().strip().split()[1:]
        words = [w for w in words if self.check_punc(w)]
        return words

    def get_sentences(self):
        # load sentences from file
        word_count = 0
        unigram_count = defaultdict(int)
        with open('datasetSentences.txt','r') as file:
            lines = file.readlines()
            # preprocessing
            sentences = []
            for line in lines[1:]:
                words = self.preprocess(line)
                s = []
                for word in words:
                    # the dataset escapes slashes as "\/"; split such tokens into separate words
                    if "\\/" in word:
                        ws = word.replace("\\/", " ").split()
                        for w in ws:
                            s.append(w)
                            unigram_count[w] += 1
                            word_count += 1
                    else:
                        s.append(word)    
                        unigram_count[word] += 1
                        word_count += 1
                sentences.append(s)   

        return sentences, unigram_count, word_count

    def subsample_sentences(self, sentences, num_copies=10):
        discard_probs = {w:self.subsample_prob(w) for w in self.unigram_count.keys()}
        sentences_subsampled = [[word for word in s if np.random.random() >= discard_probs[word]] for s in sentences*num_copies]
        # remove zero length subsampled sentences
        sentences_subsampled = [s for s in sentences_subsampled if len(s) > 0]
        return sentences_subsampled

    def subsample_prob(self, word):
        p = max(0, 1 - np.sqrt(self.subsample_t*self.word_count/self.unigram_count[word]))
        return p



# create pytorch dataset for neatly packaging the Brown Corpus
class BrownCorpus(Dataset):
    def __init__(self, window_size=6, num_negatives=20, remove_stopwords=False, subsample=False, subsample_t=1e-4):
        self.L = window_size   # context window size
        self.k = num_negatives # number of negative samples per positive
        self.remove_stopwords = remove_stopwords
        self.subsample_t = subsample_t
        if remove_stopwords:
            self.stop_words = set(stopwords.words('english'))

        # get text
        text, self.unigram_count, self.word_count = self.get_text()

        # subsample the text
        if subsample:
            text_subsampled = self.subsample_text(text)
        else:
            text_subsampled = text
        # create vocabulary
        self.vocab = sorted(list(set(text_subsampled)))
        self.word2idx = {w:i for i,w in enumerate(self.vocab)}
        print(f"Vocab size: {len(self.vocab)}, word_count: {self.word_count}")
        print(f"Total number of tokens after subsampling: {len(text_subsampled)}")
        # tokenize the text
        self.text_tokenized = [self.word2idx[w] for w in text_subsampled]
        # generate all positive pairs: (w_pos, c)
        self.pos_pairs = self.generate_positive_pairs()
        # compute weighted unigram probability distribution
        self.P_alpha = self.compute_weighted_unigram_dist()

    def __len__(self):
        return len(self.pos_pairs)

    def __getitem__(self, idx):
        # get positive pair
        pos_pair = self.pos_pairs[idx]
        # generate negative words
        wnegs = self.get_negative_samples(pos_pair[0])
        return {"center": pos_pair[1], "positive": pos_pair[0], "negatives": wnegs}
    
    def on_epoch_end(self):
        # generate new pos pairs
        self.pos_pairs = self.generate_positive_pairs()
        random.shuffle(self.pos_pairs)

    def get_negative_samples(self, wpos_idx):
        # generate negative samples (generate extras to account for removal of matches with wpos_idx)
        wnegs = torch.multinomial(self.P_alpha, self.k+50)
        # remove matches with the positive word
        wnegs = wnegs[wnegs != wpos_idx]    
        
        # if not enough negative samples, generate more
        while  len(wnegs) < self.k:
            # generate more extras
            negs_extra = torch.multinomial(self.P_alpha, self.k+50)
            # remove matches with positive word
            negs_extra = negs_extra[negs_extra != wpos_idx] 
            wnegs = torch.cat([wnegs,negs_extra])
    
        return wnegs[:self.k].tolist()    

    def compute_weighted_unigram_dist(self, alpha=0.75):
        P_alpha = torch.zeros(size=(len(self.vocab),))
        for i,w in enumerate(self.vocab):
            P_alpha[i] = self.unigram_count[w]**alpha
        P_alpha = P_alpha / P_alpha.sum()    
        return P_alpha

    def generate_positive_pairs(self):
        pos_pairs = []
        for i, w in enumerate(self.text_tokenized):
            # randomly pick a context window size between 1..L (torch.randint excludes the upper bound)
            R = torch.randint(1,self.L+1, size=(1,)).item()
            c = w # center_word 
            context_words = self.text_tokenized[max(0,i-R):i] + self.text_tokenized[i+1:i+1+R]
            for w_pos in context_words:
                pos_pairs.append((w_pos,c)) 
        return pos_pairs
    
    def check_punc(self, w):
        return any(c.isalpha() for c in w)

    # remove punctuations from list of words and apply lowercase folding 
    def preprocess(self, s):
        words = [w.lower() for w in s if self.check_punc(w)]
        if self.remove_stopwords:
            words = [w for w in words if w not in self.stop_words]
        return words

    def get_text(self):
        # preprocessing
        text = self.preprocess(brown.words())
        unigram_count = defaultdict(int)
        for word in text:
            unigram_count[word] += 1
        word_count = len(text)

        return text, unigram_count, word_count

    def subsample_text(self, text, num_copies=10):
        discard_probs = {w:self.subsample_prob(w) for w in self.unigram_count.keys()}
        text_subsampled = [word for word in text*num_copies if np.random.random() >= discard_probs[word]]
        return text_subsampled

    def subsample_prob(self, word):
        p = max(0, 1 - np.sqrt(self.subsample_t*self.word_count/self.unigram_count[word]))
        return p


In [142]:
"""
brown_data = BrownCorpus(remove_stopwords=True)
d = brown_data[55]
c = d["center"]
wpos = d["positive"]
wnegs = d["negatives"]
print(f"center: {brown_data.vocab[c]}")
print(f"positive: {brown_data.vocab[wpos]}")
print(f"negatives: {[brown_data.vocab[w] for w in wnegs]}")
"""

center: election
positive: friday
negatives: ['administered', 'leningrad-kirov', 'first', 'af', 'historical', 'attention', 'water', 'preparations', 'scandal', 'traffic', 'remember', 'mattered', 'undermine', 'supposed', 'faithful', 'age-old', 'constrictors', 'extra', 'method', 'starting']


In [143]:
# create a dataloader for our custom dataset
def collate_fn(batch):
    collated_batch = {}
    for key in batch[0]:
        collated_batch[key] = torch.stack([torch.tensor(item[key]) for item in batch]).unsqueeze(-1)
    return collated_batch

In [144]:
# define a custom pytorch model for skipgram word2vec
class Word2Vec(torch.nn.Module):
    def __init__(self, vocab_size, embedding_dims):
        super().__init__()
        self.emb_C = torch.nn.Embedding(vocab_size, embedding_dims)        
        self.emb_W = torch.nn.Embedding(vocab_size, embedding_dims)        
        # initialize small random weights
        c = 0.01 / embedding_dims
        self.emb_C.weight.data.uniform_(-c, c)
        self.emb_W.weight.data.uniform_(-c, c)

    # forward pass
    def forward(self, center, positive, negatives):
        # get the embedding vectors
        # (squeeze only the trailing singleton dim so that a batch of size 1 still works)
        center_embeds = self.emb_C(center.squeeze(-1)) # shape: (B,D)
        pos_embeds = self.emb_W(positive.squeeze(-1)) # shape: (B,D)
        negs_embeds = self.emb_W(negatives.squeeze(-1)) # shape: (B,k,D)
        B, k, D = negs_embeds.shape
        # reshape the tensors so that we can perform batch matrix multiply
        center_embeds = center_embeds.view(B,D,1)
        pos_embeds = pos_embeds.view(B,1,D)
        # compute logits
        pos_logits = torch.bmm(pos_embeds, center_embeds).view(B)  # shape: (B,)
        negs_logits = torch.bmm(negs_embeds, center_embeds).view(B,k) # shape: (B,k)
        # set up labels
        pos_labels = torch.ones_like(pos_logits)
        neg_labels = torch.zeros_like(negs_logits)
        # compute binary cross entropy loss
        pos_loss = F.binary_cross_entropy_with_logits(pos_logits, pos_labels, reduction='none') # shape: (B,)
        negs_loss = F.binary_cross_entropy_with_logits(negs_logits, neg_labels, reduction='none') # shape: (B,k)
        loss = torch.mean(pos_loss + negs_loss.sum(dim=1))
        
        return loss
    
# training loop
def train(model, optimizer, train_data, train_dataloader, device="cpu", num_epochs=10):
    avg_loss = 0
    for epoch in range(num_epochs):
        pbar = tqdm(train_dataloader, desc="Epochs")
        for batch in pbar:
            center, positive, negatives = batch["center"], batch["positive"], batch["negatives"]
            # move batch to device
            center, positive, negatives = center.to(device), positive.to(device), negatives.to(device)
            # reset gradients
            optimizer.zero_grad()
            # forward pass
            loss = model(center, positive, negatives)
            # backward pass
            loss.backward()
            # optimizer step
            optimizer.step()
            avg_loss = 0.9* avg_loss + 0.1*loss.item()

            pbar.set_description(f"Epoch {epoch + 1}, Batch Loss: {loss.item():.3f}, Moving Average Loss: {avg_loss:.3f}")  

        train_data.on_epoch_end()
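
The two `binary_cross_entropy_with_logits` calls in `Word2Vec.forward` correspond to the two terms of the negative-sampling loss: with label 1 the loss is $-\log \sigma(x)$ and with label 0 it is $-\log \sigma(-x)$. The sketch below checks this identity in NumPy. The helper `bce_with_logits` is a hand-written stand-in using a standard numerically stable form, not PyTorch's own implementation.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def bce_with_logits(logits, labels):
    """Elementwise binary cross-entropy on raw logits, in a numerically stable form."""
    return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))

logits = np.array([-3.0, -0.5, 0.0, 2.0])                   # arbitrary toy logits
pos_term = bce_with_logits(logits, np.ones_like(logits))    # label 1
neg_term = bce_with_logits(logits, np.zeros_like(logits))   # label 0

print(np.allclose(pos_term, -np.log(sigmoid(logits))))
print(np.allclose(neg_term, -np.log(sigmoid(-logits))))
```

Summing the positive term with the $k$ negative terms therefore gives the same per-pair loss as the NumPy implementation earlier in the notebook.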


In [145]:
#train_data = SSTDataset(subsample=False)
train_data = BrownCorpus(remove_stopwords=True)

Vocab size: 47892, word_count: 530090
Total number of tokens after subsampling: 530090


In [116]:
D = 64
B = 256
vocab_size = len(train_data.vocab)
learning_rate = 1e-2
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU when no GPU is available

model = Word2Vec(vocab_size=vocab_size, embedding_dims=D).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)


In [117]:
train_dataloader = DataLoader(train_data, batch_size=B, shuffle=False, collate_fn=collate_fn, pin_memory=True, num_workers=4)

In [50]:
for g in optimizer.param_groups:
    g['lr'] = 0.01


In [118]:
train(model, optimizer, train_data, train_dataloader, device=DEVICE, num_epochs=20)

Epoch 1, Batch Loss: 4.938, Moving Average Loss: 5.114: 100%|██████████| 3995/3995 [00:35<00:00, 112.30it/s]
Epoch 2, Batch Loss: 3.901, Moving Average Loss: 3.828: 100%|██████████| 3990/3990 [00:44<00:00, 90.52it/s]
Epoch 3, Batch Loss: 3.746, Moving Average Loss: 3.729: 100%|██████████| 4010/4010 [00:41<00:00, 95.88it/s] 
Epoch 4, Batch Loss: 3.471, Moving Average Loss: 3.628: 100%|██████████| 4008/4008 [00:43<00:00, 92.35it/s] 
Epoch 5, Batch Loss: 3.470, Moving Average Loss: 3.627: 100%|██████████| 3999/3999 [00:39<00:00, 101.30it/s]
Epoch 6, Batch Loss: 3.893, Moving Average Loss: 3.624: 100%|██████████| 4001/4001 [00:42<00:00, 94.68it/s] 
Epoch 7, Batch Loss: 3.748, Moving Average Loss: 3.609: 100%|██████████| 4003/4003 [00:38<00:00, 103.13it/s]
Epoch 8, Batch Loss: 3.409, Moving Average Loss: 3.550: 100%|██████████| 4005/4005 [00:44<00:00, 89.67it/s] 
Epoch 9, Batch Loss: 3.583, Moving Average Loss: 3.573: 100%|██████████| 4005/4005 [00:38<00:00, 103.17it/s]
Epoch 10, Batch Loss

In [119]:
# get copies of embedding matrices
C = model.emb_C.weight.clone().detach()
W = model.emb_W.weight.clone().detach()

# average the context and center embeddings
#E = C
#E = W
E = 0.5 * (C + W)
norms = torch.norm(E, dim=1, keepdim=True)
E = E / norms

def find_most_similar(word, n=10):
    print(f"Finding most similar words for idx: {train_data.word2idx[word]}")
    # get the embedding of this word (normalized average of context and center embeddings)
    w_emb = E[train_data.word2idx[word]].view(-1,1)
    # compute dot product similarity with all other words
    scores = torch.mm(E,w_emb).view(-1)
    scores[norms.view(-1)==0] = 0
    # find the indices sorted from largest to smallest score
    _, idx = torch.sort(scores, descending=True) 

    # get the n highest scoring words
    best = []
    for i in range(n):
        best.append(train_data.vocab[idx[i]])
    return best, scores[idx[:n]].tolist()    

In [121]:
word = "scream"
best_words, scores = find_most_similar(word)
print(f"Target word: {word}")
print("Most similar words: ", best_words)
print("Similarity scores: ", scores)

Finding most similar words for idx: 14700
Target word: scream
Most similar words:  ['scream', 'budget', 'improvise', 'geriatric', 'charmer', 'two-actor', 'consciousness', 'signpost', 'groggy', 'seeping']
Similarity scores:  [1.0000001192092896, 0.5037898421287537, 0.47533127665519714, 0.4743828773498535, 0.4240070879459381, 0.368988960981369, 0.3607769310474396, 0.36068689823150635, 0.35166338086128235, 0.34774455428123474]
