In [1]:
import itertools
from collections import OrderedDict 
import re
import nltk
from nltk.corpus import brown, gutenberg
from nltk.probability import FreqDist
from nltk.corpus import stopwords

## Corups

In [2]:
samples  =gutenberg.sents(gutenberg.fileids()[0])
pattern = re.compile("[A-Za-z]+")
stop_w =  set(stopwords.words('english'))
corpus = []
for sent in samples:
    sent = [w.lower() for w in sent]
    sent = [w for w in sent if w not in stop_w]
    sent = [w.replace('\n', ' ') for w in sent]
    sent = [w for w in sent if pattern.fullmatch(w)]
    if len(sent) > 5:
        corpus.append(sent)

In [3]:
fre_dist = FreqDist()
for sent in corpus:
    fre_dist.update(sent)
fre_dist = {k : v for k, v in fre_dist.items() if v > 5}

In [4]:
vocab_size = len(fre_dist)
idx_to_word = {idx: word for idx,  word in enumerate(fre_dist.keys())}
word_to_idx = {word: idx for idx, word in idx_to_word.items()}


In [5]:
corpus_indexed = [[word_to_idx[word] for word in sent if word in word_to_idx]for sent in corpus]
corpus_indexed = [sent for sent in corpus_indexed if len(sent) > 5]
fre_dist_indexed = {word_to_idx[w]: f for w, f in fre_dist.items()}

## skipgram with softmax

In [6]:
import torch
import numpy as np
import torch.functional as F
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import tqdm

In [7]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, corpus, window_size=2, sentence_length_threshold=5):
        self.window_size = window_size
        self.sentence_length_threshold = sentence_length_threshold
        self.pairs = self.__generate_pairs(corpus, window_size)
        
    def __generate_pairs(self, corpus, windows_size):
        pairs = []
        for sentence in corpus:
            if len(sentence) < self.sentence_length_threshold:
                continue
            for center_word_pos in range(len(sentence)):
                for shift in range(-windows_size, windows_size + 1):
                    context_word_pos = center_word_pos + shift
                    
                    if (0 <= context_word_pos < len(sentence)) and context_word_pos != center_word_pos:
                        pairs.append((sentence[center_word_pos], sentence[context_word_pos]))
        return pairs

    def __len__(self):
        return len(self.pairs)
    
    def __getitem__(self, index):
        return np.array([self.pairs[index][0]]), np.array([self.pairs[index][1]])

In [8]:
class SkipgramSoftmax(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.syn0 = nn.Embedding(vocab_size, embedding_dim)  # |V| x |K|
        self.syn1 = nn.Linear(embedding_dim, vocab_size)  # |K| x |V|

    def forward(self, center, context):
        # center: [b_size, 1]
        # context: [b_size, 1]
        embds = self.syn0(center.view(-1))
        out = self.syn1(embds)
        log_probs = F.log_softmax(out, dim=1)
        loss = F.nll_loss(log_probs, context.view(-1), reduction='mean')
        return loss

In [9]:
EMBEDDING_DIM = 50
model = SkipgramSoftmax(vocab_size, EMBEDDING_DIM)
optimizer = optim.Adam(model.parameters(), lr=0.001,  weight_decay=1e-6)


In [10]:
dataset = Dataset(corpus_indexed)
data_loader = DataLoader(dataset, batch_size=100, num_workers=0)

In [11]:
log_interval = 100
for epoch_i in range(10):
    total_loss = 0
    model.train()
    tk0 = tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0)
    for i, (center, context) in enumerate(tk0):
       
        loss = model(center, context)

        model.zero_grad()
        loss.backward()
        
        optimizer.step()
        total_loss += loss.item()
        if(i + 1) % log_interval == 0:
            tk0.set_postfix(loss = total_loss/log_interval)
            total_loss = 0

100%|██████████| 1845/1845 [00:13<00:00, 141.67it/s, loss=6.9]
100%|██████████| 1845/1845 [00:12<00:00, 146.03it/s, loss=6.61]
100%|██████████| 1845/1845 [00:12<00:00, 143.85it/s, loss=6.5]
100%|██████████| 1845/1845 [00:12<00:00, 145.09it/s, loss=6.44]
100%|██████████| 1845/1845 [00:12<00:00, 144.16it/s, loss=6.39]
100%|██████████| 1845/1845 [00:12<00:00, 145.40it/s, loss=6.36]
100%|██████████| 1845/1845 [00:12<00:00, 142.82it/s, loss=6.33]
100%|██████████| 1845/1845 [00:12<00:00, 145.18it/s, loss=6.31]
100%|██████████| 1845/1845 [00:12<00:00, 146.23it/s, loss=6.28]
100%|██████████| 1845/1845 [00:12<00:00, 145.52it/s, loss=6.26]
100%|██████████| 1845/1845 [00:13<00:00, 141.20it/s, loss=6.24]
100%|██████████| 1845/1845 [00:15<00:00, 121.88it/s, loss=6.22]
100%|██████████| 1845/1845 [00:15<00:00, 115.82it/s, loss=6.2]
100%|██████████| 1845/1845 [00:15<00:00, 115.33it/s, loss=6.19]
100%|██████████| 1845/1845 [00:14<00:00, 130.63it/s, loss=6.17]
100%|██████████| 1845/1845 [00:13<00:00, 13

### syn0 and syn1 emebdding

In [18]:
emebdding1 = model.syn0.weight.data
emebdding1.shape

torch.Size([1749, 50])

In [19]:
emebdding2 = model.syn1.weight.data
emebdding2.shape

torch.Size([1749, 50])

#### debug

In [60]:
model = SkipgramSoftmax(vocab_size, EMBEDDING_DIM)

In [61]:
dl = iter(data_loader)
center, context = next(dl)

In [63]:
center = center.view(-1)
context = context.view(-1)
center.shape, context.shape

(torch.Size([100]), torch.Size([100]))