In [1]:
import itertools
from collections import OrderedDict 
import re
import nltk
from nltk.corpus import brown, gutenberg
from nltk.probability import FreqDist
from nltk.corpus import stopwords

## Corups

In [2]:
gutenberg.fileids()[3]

'bible-kjv.txt'

In [4]:
samples  =gutenberg.sents(gutenberg.fileids()[3])
pattern = re.compile("[A-Za-z]+")
stop_w =  set(stopwords.words('english'))
corpus = []
for sent in samples:
    sent = [w.lower() for w in sent]
    sent = [w for w in sent if w not in stop_w]
    sent = [w.replace('\n', ' ') for w in sent]
    sent = [w for w in sent if pattern.fullmatch(w)]
    if len(sent) > 5:
        corpus.append(sent)

In [5]:
fre_dist = FreqDist()
for sent in corpus:
    fre_dist.update(sent)
fre_dist = {k : v for k, v in fre_dist.items() if v > 5}

In [6]:
vocab_size = len(fre_dist)
idx_to_word = {idx: word for idx,  word in enumerate(fre_dist.keys())}
word_to_idx = {word: idx for idx, word in idx_to_word.items()}


In [7]:
corpus_indexed = [[word_to_idx[word] for word in sent if word in word_to_idx]for sent in corpus]
corpus_indexed = [sent for sent in corpus_indexed if len(sent) > 5]
# fre_dist_indexed = {word_to_idx[w]: f for w, f in fre_dist.items()}

## skipgram with softmax
$$
P(context|center;\theta) = P(w_{O,1}, w_{O,2},...,w_{O,C}|w_I) = \prod^C_{c=1 }\cfrac{\exp(h^\top \text{v}'_{w_{O,c}})}{\sum_{w_i \in V} \exp(h^\top \text{v}'_{w_i})}
$$

In [8]:
import torch
import numpy as np
import torch.functional as F
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import tqdm

In [9]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, corpus, window_size=2, sentence_length_threshold=5):
        self.window_size = window_size
        self.sentence_length_threshold = sentence_length_threshold
        self.pairs = self.__generate_pairs(corpus, window_size)
        
    def __generate_pairs(self, corpus, windows_size):
        pairs = []
        for sentence in corpus:
            if len(sentence) < self.sentence_length_threshold:
                continue
            for center_word_pos in range(len(sentence)):
                for shift in range(-windows_size, windows_size + 1):
                    context_word_pos = center_word_pos + shift
                    
                    if (0 <= context_word_pos < len(sentence)) and context_word_pos != center_word_pos:
                        pairs.append((sentence[center_word_pos], sentence[context_word_pos]))
        return pairs

    def __len__(self):
        return len(self.pairs)
    
    def __getitem__(self, index):
        return np.array([self.pairs[index][0]]), np.array([self.pairs[index][1]])

In [10]:
class SkipgramSoftmax(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.syn0 = nn.Embedding(vocab_size, embedding_dim)  # |V| x |K|
        self.syn1 = nn.Linear(embedding_dim, vocab_size)  # |K| x |V|

    def forward(self, center, context):
        # center: [b_size, 1]
        # context: [b_size, 1]
        embds = self.syn0(center.view(-1))
        out = self.syn1(embds)
        log_probs = F.log_softmax(out, dim=1)
        loss = F.nll_loss(log_probs, context.view(-1), reduction='mean')
        return loss

In [11]:
EMBEDDING_DIM = 100
model = SkipgramSoftmax(vocab_size, EMBEDDING_DIM)
optimizer = optim.Adam(model.parameters(), lr=0.001,  weight_decay=1e-6)


In [12]:
dataset = Dataset(corpus_indexed)
data_loader = DataLoader(dataset, batch_size=500, num_workers=0)

In [13]:
log_interval = 100
for epoch_i in range(5):
    total_loss = 0
    model.train()
    tk0 = tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0)
    for i, (center, context) in enumerate(tk0):
       
        loss = model(center, context)

        model.zero_grad()
        loss.backward()
        
        optimizer.step()
        total_loss += loss.item()
        if(i + 1) % log_interval == 0:
            tk0.set_postfix(loss = total_loss/log_interval)
            total_loss = 0

100%|██████████| 2418/2418 [03:08<00:00, 12.81it/s, loss=7.11]
100%|██████████| 2418/2418 [02:49<00:00, 14.25it/s, loss=6.86]
100%|██████████| 2418/2418 [02:51<00:00, 14.14it/s, loss=6.75]
100%|██████████| 2418/2418 [02:41<00:00, 14.93it/s, loss=6.68]
100%|██████████| 2418/2418 [03:21<00:00, 12.00it/s, loss=6.63]
100%|██████████| 2418/2418 [02:55<00:00, 13.78it/s, loss=6.58]
100%|██████████| 2418/2418 [03:03<00:00, 13.17it/s, loss=6.55]
100%|██████████| 2418/2418 [02:50<00:00, 14.21it/s, loss=6.52]
100%|██████████| 2418/2418 [02:36<00:00, 15.45it/s, loss=6.49]
100%|██████████| 2418/2418 [02:30<00:00, 16.11it/s, loss=6.47]


### fetch word embedding

In [15]:
syn0 = model.syn0.weight.data
syn1 = model.syn1.weight.data


w2v_embedding = (syn0 + syn1) / 2
w2v_embedding = w2v_embedding.numpy()
l2norm = np.linalg.norm(w2v_embedding, 2, axis=1, keepdims=True)
w2v_embedding = w2v_embedding / l2norm


# Evaluation

In [18]:
class CosineSimilarity:
    def __init__(self, word_embedding, idx_to_word_dict, word_to_idx_dict):
        self.word_embedding = word_embedding # normed already
        self.idx_to_word_dict = idx_to_word_dict
        self.word_to_idx_dict = word_to_idx_dict
        
    def get_synonym(self, word, topK=10):
        idx = self.word_to_idx_dict[word]
        embed = self.word_embedding[idx]
        
        cos_similairty = w2v_embedding @ embed
        
        topK_index = np.argsort(-cos_similairty)[:topK]
        pairs = []
        for i in topK_index:
            w = self.idx_to_word_dict[i]
#             pairs[w] = cos_similairty[i]
            pairs.append((w, cos_similairty[i]))
        return pairs
        
    
    

In [19]:
cosineSim = CosineSimilarity(w2v_embedding, idx_to_word, word_to_idx)
cosineSim.get_synonym('christ')

[('christ', 0.9999999),
 ('gospel', 0.6221024),
 ('jesus', 0.60206234),
 ('resurrection', 0.5797884),
 ('sufferings', 0.5712591),
 ('faith', 0.5694477),
 ('appearing', 0.5664267),
 ('apostle', 0.5652396),
 ('justified', 0.5442045),
 ('timothy', 0.53993124)]

In [20]:
cosineSim.get_synonym('jesus')

[('jesus', 1.0),
 ('christ', 0.60206234),
 ('answering', 0.5809975),
 ('crucified', 0.5506017),
 ('nazareth', 0.54799545),
 ('peter', 0.54017067),
 ('apostle', 0.5311042),
 ('disciples', 0.5209755),
 ('baptized', 0.511466),
 ('pilate', 0.502664)]

#### debug

In [60]:
model = SkipgramSoftmax(vocab_size, EMBEDDING_DIM)

In [21]:
dl = iter(data_loader)
center, context = next(dl)

In [22]:
center.shape, context.shape

(torch.Size([500, 1]), torch.Size([500, 1]))

In [25]:
center[10], context[10]

(tensor([3]), tensor([2]))