# Word2Vec (Skipgram )

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from string import punctuation
import time

In [2]:
import nltk
from nltk.corpus import brown
from nltk.corpus import stopwords
from collections import Counter
import matplotlib
nltk.download('stopwords')
nltk.download('brown')

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\swara\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package brown to
[nltk_data]     C:\Users\swara\AppData\Roaming\nltk_data...
[nltk_data]   Package brown is already up-to-date!


True

## 1. Load data

In [3]:
corpus = brown.sents()

In [4]:
stop_words = set(stopwords.words('english'))
corpus = [[word for word in sent if word.lower() not in stop_words] for sent in corpus]

# Remove punctuation from corpus
corpus = [[word for word in sent if word not in punctuation] for sent in corpus]

# Remove empty sentences
corpus = [sent for sent in corpus if len(sent) > 0]

# Remove sentences with less than 5 words
corpus = [sent for sent in corpus if len(sent) >= 5]

# Remove sentences with more than 20 words
corpus = [sent for sent in corpus if len(sent) <= 20]

# Remove rare words
word_freq = Counter([word for sent in corpus for word in sent])
corpus = [[word for word in sent if word_freq[word] > 5] for sent in corpus]

In [5]:
#2. numeralization
#find unique words
flatten = lambda l: [item for sublist in l for item in sublist]
#assign unique integer
vocabs = list(set(flatten(corpus))) #all the words we have in the system - <UNK>

In [6]:
#create handy mapping between integer and word
word2index = {v:idx for idx, v in enumerate(vocabs)}
word2index['dog']

3158

In [7]:
vocabs.append('<UNK>')
word2index['<UNK>'] = len(word2index)

In [8]:
index2word = {v:k for k, v in word2index.items()}
index2word[len(index2word) - 1]

'<UNK>'

## 2. Prepare train data

In [9]:
#create pairs of center word, and outside word

def random_batch(batch_size, corpus, window_size=2):

    skipgrams = []

    #loop each corpus
    for doc in corpus:
        #look from the 2nd word until second last word
        for i in range(window_size, len(doc) - window_size):
            #center word
            center = word2index[doc[i]]
            #outside words = rest of the words
            outside_start =  i - window_size
            outside_end =  i + window_size + 1

            for j in range(outside_start, outside_end):
                if i != j:  # Skip the center word
                    outside = word2index[doc[j]]
                    skipgrams.append([center, outside])
                
    random_index = np.random.choice(range(len(skipgrams)), batch_size, replace=False)
    
    inputs, labels = [], []
    for index in random_index:
        inputs.append([skipgrams[index][0]])
        labels.append([skipgrams[index][1]])
        
    return np.array(inputs), np.array(labels)
            
x, y = random_batch(2, corpus)

In [10]:
x.shape  #batch_size, 1

(2, 1)

In [11]:
x

array([[2911],
       [8851]])

In [12]:
y.shape  #batch_size 1

(2, 1)

## 3. Model

$$J(\theta) = -\frac{1}{T}\sum_{t=1}^{T}\sum_{\substack{-m \leq j \leq m \\ j \neq 0}}\log P(w_{t+j} | w_t; \theta)$$

where $P(w_{t+j} | w_t; \theta) = $

$$P(o|c)=\frac{\exp(\mathbf{u_o^{\top}v_c})}{\sum_{w=1}^V\exp(\mathbf{u_w^{\top}v_c})}$$

where $o$ is the outside words and $c$ is the center word

In [13]:
len(vocabs)

10583

In [14]:
embedding = nn.Embedding(len(vocabs), 2)

In [15]:
x_tensor = torch.LongTensor(x)
embedding(x_tensor).shape  #(batch_size, 1, emb_size)

torch.Size([2, 1, 2])

$$P(o|c)=\frac{\exp(\mathbf{u_o^{\top}v_c})}{\sum_{w=1}^V\exp(\mathbf{u_w^{\top}v_c})}$$

In [16]:
class Skipgram(nn.Module):
    
    def __init__(self, voc_size, emb_size):
        super(Skipgram, self).__init__()
        self.embedding_center  = nn.Embedding(voc_size, emb_size)
        self.embedding_outside = nn.Embedding(voc_size, emb_size)
    
    def forward(self, center, outside, all_vocabs):
        center_embedding     = self.embedding_center(center)  #(batch_size, 1, emb_size)
        outside_embedding    = self.embedding_center(outside) #(batch_size, 1, emb_size)
        all_vocabs_embedding = self.embedding_center(all_vocabs) #(batch_size, voc_size, emb_size)
        
        top_term = torch.exp(outside_embedding.bmm(center_embedding.transpose(1, 2)).squeeze(2))
        #batch_size, 1, emb_size) @ (batch_size, emb_size, 1) = (batch_size, 1, 1) = (batch_size, 1) 

        lower_term = all_vocabs_embedding.bmm(center_embedding.transpose(1, 2)).squeeze(2)
        #batch_size, voc_size, emb_size) @ (batch_size, emb_size, 1) = (batch_size, voc_size, 1) = (batch_size, voc_size) 
        
        lower_term_sum = torch.sum(torch.exp(lower_term), 1)  #(batch_size, 1)
        
        loss = -torch.mean(torch.log(top_term / lower_term_sum))  #scalar
        
        return loss
        

In [17]:
#prepare all vocabs

batch_size = 2
voc_size   = len(vocabs)

def prepare_sequence(seq, word2index):
    idxs = list(map(lambda w: word2index[w] if word2index.get(w) is not None else word2index["<UNK>"], seq))
    return torch.LongTensor(idxs)

all_vocabs = prepare_sequence(list(vocabs), word2index).expand(batch_size, voc_size)
all_vocabs

tensor([[    0,     1,     2,  ..., 10580, 10581, 10582],
        [    0,     1,     2,  ..., 10580, 10581, 10582]])

In [18]:
model = Skipgram(voc_size, 2)
model

Skipgram(
  (embedding_center): Embedding(10583, 2)
  (embedding_outside): Embedding(10583, 2)
)

In [19]:
input_tensor = torch.LongTensor(x)
label_tensor = torch.LongTensor(y)

In [20]:
loss = model(input_tensor, label_tensor, all_vocabs)

In [21]:
loss

tensor(10.0398, grad_fn=<NegBackward0>)

## 4. Training

In [22]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [23]:
batch_size = 2
emb_size   = 2
model      = Skipgram(voc_size, emb_size)
optimizer  = optim.Adam(model.parameters(), lr=0.001)

In [24]:
start = time.time()
num_epochs = 1000
window_size = 5

for epoch in range(num_epochs):
    
    #get batch
    input_batch, label_batch = random_batch(batch_size, corpus, window_size)
    input_tensor = torch.LongTensor(input_batch)
    label_tensor = torch.LongTensor(label_batch)
    
    #predict
    loss = model(input_tensor, label_tensor, all_vocabs)
    
    #backprogate
    optimizer.zero_grad()
    loss.backward()
    
    #update alpha
    optimizer.step()

    epoch_mins, epoch_secs = epoch_time(start, time.time())
    
    #print the loss
    if (epoch + 1) % 100 == 0:
        print(f"Epoch: {epoch + 1} | Loss: {loss:.6f} | Time: {epoch_mins}m {epoch_secs}s")

Epoch: 100 | Loss: 9.686029 | Time: 1m 23s
Epoch: 200 | Loss: 9.613126 | Time: 2m 42s
Epoch: 300 | Loss: 10.529035 | Time: 4m 1s
Epoch: 400 | Loss: 9.488087 | Time: 5m 19s
Epoch: 500 | Loss: 8.416470 | Time: 6m 40s
Epoch: 600 | Loss: 9.561131 | Time: 8m 0s
Epoch: 700 | Loss: 9.829943 | Time: 9m 20s
Epoch: 800 | Loss: 10.165442 | Time: 10m 38s
Epoch: 900 | Loss: 8.747721 | Time: 11m 54s
Epoch: 1000 | Loss: 8.801099 | Time: 13m 11s


## 5. Plot the embeddings

Is fruit really near to fish?
Is fruit really far from cat?

In [25]:
vocabs

['evil',
 'accelerometer',
 '200',
 'fish',
 'mimesis',
 'detectives',
 'paste',
 'intend',
 'odors',
 'draw',
 'Warsaw',
 'bears',
 'thief',
 'develop',
 'butter',
 'swell',
 'ads',
 'differently',
 'requests',
 'H',
 'Methodist',
 'architectural',
 'smell',
 'Sponsor',
 'Cologne',
 '2:37',
 'chapter',
 'utterly',
 'break',
 'i.e.',
 'Riverside',
 'towels',
 'tires',
 'Bright',
 'conclusive',
 'experts',
 'blanket',
 'Shelley',
 'illustration',
 'radio',
 'waste',
 'Sea',
 'Civil',
 'verse',
 'boss',
 'inspection',
 'non-violent',
 'addition',
 'healthy',
 'homely',
 'proclaimed',
 'unimportant',
 'Homeric',
 'candle',
 'Blue',
 'indebted',
 'Usually',
 'ashamed',
 'decks',
 'bought',
 'committees',
 'fresh',
 "they've",
 'dust',
 'tightly',
 'anxious',
 'troubled',
 'battle',
 'prone',
 'availability',
 'dedicated',
 'Highlands',
 'E.',
 'friendship',
 'grinning',
 'Life',
 'drops',
 'appearances',
 'happened',
 'sensational',
 'alliance',
 'sample',
 'commercials',
 'erected',
 'wea

In [26]:
fish = torch.LongTensor([word2index['dog']])
fish

tensor([3158])

In [27]:
fish_embed_c = model.embedding_center(fish)
fish_embed_o = model.embedding_outside(fish)
fish_embed   = (fish_embed_c + fish_embed_o) / 2
fish_embed

tensor([[-0.4071,  0.0078]], grad_fn=<DivBackward0>)

In [28]:
fish_embed_o

tensor([[-0.5963, -0.0380]], grad_fn=<EmbeddingBackward0>)

In [29]:
def get_embed(word):
    try:
        index = word2index[word]
    except:
        index = word2index['<UNK>']
        
    word = torch.LongTensor([word2index[word]])
    
    embed_c = model.embedding_center(word)
    embed_o = model.embedding_outside(word)
    embed   = (embed_c + embed_o) / 2
    
    return embed[0][0].item(), embed[0][1].item()

In [30]:
get_embed('animal')

(0.5301620364189148, -1.1365153789520264)

In [31]:
get_embed('cat')

(0.2453116923570633, 0.5773712992668152)

In [32]:
get_embed('dog')

(-0.40714454650878906, 0.00776178203523159)

In [33]:
get_embed('fish')

(-0.7001842260360718, 1.021337866783142)

## 6. Cosine similarity

In [34]:
fish = get_embed('fish')
fish

(-0.7001842260360718, 1.021337866783142)

In [35]:
fruit = get_embed('fruit')
fruit

(0.021597012877464294, 0.7026581168174744)

In [36]:
unk = get_embed('<UNK>')
unk

(0.7222745418548584, 0.12681323289871216)

In [37]:
np.array(fish) @ np.array(unk)

-0.3762060843055579

In [38]:
#more formally is to divide by its norm
def cosine_similarity(A, B):
    dot_product = np.dot(A, B)
    norm_a = np.linalg.norm(A)
    norm_b = np.linalg.norm(B)
    similarity = dot_product / (norm_a * norm_b)
    return similarity

print(cosine_similarity(np.array(fish), np.array(unk)))
print(cosine_similarity(np.array(fish), np.array(fruit)))

-0.4142900904790245
0.807029211856514


In [39]:
# Create a pickle of the model
import pickle

with open('skipgram.pkl', 'wb') as f:
    pickle.dump(model, f)