In [1]:
import spacy
import tqdm
import torch
import numpy as np
from torch import nn
from collections import Counter
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [2]:
spacy_eng = spacy.load('en')
def tokenizer_eng(text):
    return [t.text for t in spacy_eng.tokenizer(text)]

In [3]:
text = 'what is your name?'
tokenizer_eng(text)

['what', 'is', 'your', 'name', '?']

In [4]:
class GetDataset(Dataset):
    def __init__(self, file_name, window_size=2):
        self.window_size = window_size
        raw_text = open(file_name, 'r', encoding='utf-8').read().lower()
        tokenized_words = tokenizer_eng(raw_text)
        self.data_pairs = [
                      (
                          [tokenized_words[i-(j+1)] for j in range(window_size)] + [tokenized_words[i+(j+1)] for j in range(window_size)], tokenized_words[i]
                      ) for i in range(window_size, len(tokenized_words)-window_size)
                    ]
        self.vocab = Counter(tokenized_words)
        self.vocab_size = len(self.vocab)
        self.stoi = {item[0]: idx for idx, item in enumerate(self.vocab.most_common())}
        self.itos = list(self.stoi.keys())

    def __len__(self):
        return len(self.data_pairs)
    
    def __getitem__(self, index):
        context = torch.LongTensor([self.stoi[w] for w in self.data_pairs[index][0]])
        target = torch.tensor(self.stoi[self.data_pairs[index][1]])
        return context, target

In [5]:
class CBOW(nn.Module):
    def __init__(self, vocab_size, embedding_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.fc = nn.Sequential(
                        nn.Linear(embedding_size, 128),
                        nn.ReLU(),
                        nn.Linear(128, 256),
                        nn.ReLU(),
                        nn.Linear(256, 512),
                        nn.ReLU(),
                        nn.Linear(512, vocab_size)
                    )        

    def forward(self, x):
        x = self.embedding(x).sum(1) # (b, t, d) -> (b, d)
        return self.fc(x)

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
window_size = 2
n_epochs = 20
embedding_size = 1028
batch_size = 64
lr = 3e-4
print(device)

cuda


In [7]:
data = GetDataset('input.txt', window_size)
loader = DataLoader(data, batch_size=batch_size, num_workers=2, pin_memory=True, shuffle=True)
x, y = next(iter(loader))
print(len(data), data.vocab_size, x.shape, y.shape)

287756 12340 torch.Size([64, 4]) torch.Size([64])


In [8]:
net = CBOW(data.vocab_size, embedding_size).to(device)
inp = torch.LongTensor([[0, 1, 2, 3]]).to(device)
out = net(inp)
print(out.shape)
del inp, out
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()

torch.Size([1, 12340])


In [9]:
def loop(net, loader, epoch):
    net.train()
    losses = []
    pbar = tqdm.tqdm(loader, total=len(loader))
    for x, y in pbar:
        x = x.to(device)
        y = y.to(device)
        preds = net(x)
        loss = loss_fn(preds, y)
        losses.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        pbar.set_description(f'epoch={epoch}, avg_loss={np.mean(losses):.4f}')

In [10]:
for epoch in range(n_epochs):
    loop(net, loader, epoch)

epoch=0, avg_loss=5.3158: 100%|██████████| 4497/4497 [00:42<00:00, 105.01it/s]
epoch=1, avg_loss=4.7350: 100%|██████████| 4497/4497 [00:42<00:00, 105.30it/s]
epoch=2, avg_loss=4.3994: 100%|██████████| 4497/4497 [00:42<00:00, 105.25it/s]
epoch=3, avg_loss=4.1293: 100%|██████████| 4497/4497 [00:42<00:00, 105.00it/s]
epoch=4, avg_loss=3.8954: 100%|██████████| 4497/4497 [00:42<00:00, 105.05it/s]
epoch=5, avg_loss=3.6806: 100%|██████████| 4497/4497 [00:42<00:00, 105.23it/s]
epoch=6, avg_loss=3.4775: 100%|██████████| 4497/4497 [00:42<00:00, 105.40it/s]
epoch=7, avg_loss=3.2823: 100%|██████████| 4497/4497 [00:42<00:00, 105.28it/s]
epoch=8, avg_loss=3.0992: 100%|██████████| 4497/4497 [00:42<00:00, 105.22it/s]
epoch=9, avg_loss=2.9342: 100%|██████████| 4497/4497 [00:42<00:00, 105.33it/s]
epoch=10, avg_loss=2.7922: 100%|██████████| 4497/4497 [00:42<00:00, 105.22it/s]
epoch=11, avg_loss=2.6720: 100%|██████████| 4497/4497 [00:42<00:00, 105.07it/s]
epoch=12, avg_loss=2.5684: 100%|██████████| 4497/4

In [23]:
def get_similar_words(word, data):
    if word not in data.stoi:
        raise Exception('word not found in the vocab!!')
    word_idx = torch.LongTensor([data.stoi[word]]).to(device)
    word_embedding = net.embedding(word_idx)
    similar_words = []
    
    for curr_word in data.vocab:
        if curr_word == word: 
            continue
        curr_idx = torch.LongTensor([data.stoi[curr_word]]).to(device)
        curr_embedding = net.embedding(curr_idx)
        cosine_sim = F.cosine_similarity(word_embedding, curr_embedding)
        similar_words.append([curr_word, cosine_sim.item()])

    return sorted(similar_words, key=lambda x: x[1], reverse=True)[:10]

In [24]:
get_similar_words('what', data)

[['excepting', 0.12064912170171738],
 ['bon', 0.11391300708055496],
 ['foremost', 0.1138283982872963],
 ['volumnia', 0.11017448455095291],
 ['devouring', 0.10938507318496704],
 ["view'd", 0.1080944836139679],
 ['descends', 0.10349922627210617],
 ['hasty', 0.1014949157834053],
 ['gory', 0.1014704555273056],
 ['unseemly', 0.10107247531414032]]

In [25]:
get_similar_words('dog', data)

[['troth', 0.11910170316696167],
 ['kindest', 0.11287535727024078],
 ['prisoner', 0.11184903979301453],
 ['capers', 0.10974440723657608],
 ['den', 0.10471750050783157],
 ['characters', 0.10162997245788574],
 ['smell', 0.10123172402381897],
 ['answering', 0.10066694766283035],
 ['allegiance', 0.1002424955368042],
 ['uncrown', 0.09983837604522705]]

In [26]:
get_similar_words('asdas', data)

Exception: ignored