In [1]:
import itertools
from collections import OrderedDict 
import re
import nltk
from nltk.corpus import brown, gutenberg
from nltk.probability import FreqDist
from nltk.corpus import stopwords

## Corpus

In [2]:
samples  =gutenberg.sents(gutenberg.fileids()[0])
pattern = re.compile("[A-Za-z]+")
stop_w =  set(stopwords.words('english'))
corpus = []
for sent in samples:
    sent = [w.lower() for w in sent]
    sent = [w for w in sent if w not in stop_w]
    sent = [w.replace('\n', ' ') for w in sent]
    sent = [w for w in sent if pattern.fullmatch(w)]
    if len(sent) > 5:
        corpus.append(sent)

In [3]:
fre_dist = FreqDist()
for sent in corpus:
    fre_dist.update(sent)
fre_dist = {k : v for k, v in fre_dist.items() if v > 5}

In [4]:
vocab_size = len(fre_dist)
idx_to_word = {idx: word for idx,  word in enumerate(fre_dist.keys())}
word_to_idx = {word: idx for idx, word in idx_to_word.items()}

In [5]:
corpus_indexed = [[word_to_idx[word] for word in sent if word in word_to_idx]for sent in corpus]
corpus_indexed = [sent for sent in corpus_indexed if len(sent) > 5]
fre_dist_indexed = {word_to_idx[w]: f for w, f in fre_dist.items()}

## CBOW with softmax

In [6]:
import torch
import numpy as np
import torch.functional as F
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import tqdm

In [7]:
class CBOWDataset(torch.utils.data.Dataset):
    def __init__(self, corpus, windows_size=2, sentence_length_threshold=5):
        self.windows_size = windows_size
        self.sentence_length_threshold = sentence_length_threshold
        self.contexts, self.centers = self._generate_pairs(corpus, windows_size)
        
    def _generate_pairs(self, corpus, windows_size):
        contexts = []
        centers = []
        
        for sent in corpus:
            if len(sent) < self.sentence_length_threshold:
                continue
            
            for center_word_pos in range(len(sent)):
                context = []
                for w in range(-windows_size, windows_size + 1):
                    context_word_pos = center_word_pos + w
                    if(0 <= context_word_pos < len(sent) and context_word_pos != center_word_pos):
                        context.append(sent[context_word_pos])
                if(len(context) == 2 * self.windows_size):
                    contexts.append(context)
                    centers.append(sent[center_word_pos])
        return contexts, centers
    
    def __len__(self):
        return len(self.centers)
    
    def __getitem__(self, index):
        return np.array(self.contexts[index]), np.array([self.centers[index]])

In [8]:
class CBOWSoftmax(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.syn0 = nn.Embedding(vocab_size, embedding_dim)
        self.syn1 = nn.Linear(embedding_dim, vocab_size)
    
    def forward(self, context, center):
        #  context: [b_size, windows_size]
        #  center: [b_size, 1]
        embds = self.syn0(context).mean(dim=1) # [b_size, embedding_dim]
        out = self.syn1(embds)
        
        log_probs = F.log_softmax(out, dim=1)
        loss = F.nll_loss(log_probs, center.view(-1), reduction='mean')
        return loss
        

In [12]:
EMBEDDING_DIM = 50
model = CBOWSoftmax(vocab_size, EMBEDDING_DIM)
optimizer = optim.Adam(model.parameters(), lr=0.001,  weight_decay=1e-6)

In [13]:
dataset = CBOWDataset(corpus_indexed)
data_loader = DataLoader(dataset, batch_size=100, num_workers=0)

In [16]:
log_interval = 100
for epoch_i in range(10):
    total_loss = 0
    model.train()
    tk0 = tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0)
    for i, (context, center) in enumerate(tk0):
       
        loss = model(context, center)

        model.zero_grad()
        loss.backward()
        
        optimizer.step()
        total_loss += loss.item()
        if(i + 1) % log_interval == 0:
            tk0.set_postfix(loss = total_loss/log_interval)
            total_loss = 0

100%|██████████| 361/361 [00:03<00:00, 112.92it/s, loss=7.3]
100%|██████████| 361/361 [00:03<00:00, 116.16it/s, loss=6.87]
100%|██████████| 361/361 [00:03<00:00, 93.38it/s, loss=6.54]
100%|██████████| 361/361 [00:06<00:00, 57.69it/s, loss=6.37]
100%|██████████| 361/361 [00:04<00:00, 83.32it/s, loss=6.24]
100%|██████████| 361/361 [00:04<00:00, 87.73it/s, loss=6.13]
100%|██████████| 361/361 [00:03<00:00, 99.07it/s, loss=6.03]
100%|██████████| 361/361 [00:03<00:00, 90.57it/s, loss=5.93]
100%|██████████| 361/361 [00:03<00:00, 94.97it/s, loss=5.85]
100%|██████████| 361/361 [00:03<00:00, 96.49it/s, loss=5.76]


In [17]:
emebdding = model.syn0.weight.data
emebdding.shape

torch.Size([1749, 50])

#### debug

In [20]:
emebdding[word_to_idx['woodhouse']]

tensor([-3.2198,  1.6870, -1.1216,  0.2148,  0.5650, -1.5975,  2.0637, -0.0704,
        -0.5054,  0.1672, -0.6360, -1.1686,  0.2833, -0.5476,  1.6006, -1.8558,
         2.0661,  1.4087, -0.9098, -0.1002,  0.1536,  0.6631,  0.4492,  1.7913,
        -2.0220,  1.4267,  2.7691, -2.6137,  0.7928, -1.2498,  0.2200,  0.5134,
         0.1859,  0.9686,  0.5371,  1.5631, -0.0983, -0.8209, -0.1594, -0.2577,
        -1.6348, -0.7915,  0.8425, -3.2552, -0.3174,  0.9666,  0.9690,  0.2145,
        -1.5016, -1.5281])

In [15]:
context.shape, center.shape

(torch.Size([100, 4]), torch.Size([100, 1]))