In [243]:
# Import statements
# Build Dataset
# Negative Sampling
# Define Model
# Training
# Predictions (word similarity) & vector plotting (TODO: plotting not implemented yet)

In [244]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
from collections import defaultdict
from tqdm.notebook import trange, tqdm

from preprocess_dataset import StanfordSentiment

# Dataset

In [245]:
class CustomDataset(Dataset):
    """Skip-gram (target, context) word pairs built from a StanfordSentiment corpus.

    Every word of every sentence is paired with each neighbour that lies within
    ``window_size`` positions of it in the same sentence. Items are returned as a
    pair of ``LongTensor``s of shape ``[1]`` holding the target and context indices.

    Args:
        path: Path handed to ``StanfordSentiment`` to load the corpus.
        window_size: Number of neighbours considered on each side of a word.
        transform: Optional callable applied to the target-index tensor in
            ``__getitem__``; ``None`` leaves it unchanged.
        target_transform: Optional callable applied to the context-index tensor in
            ``__getitem__``; ``None`` leaves it unchanged.
    """

    def __init__(self, path, window_size = 2, transform = None, target_transform = None):

        # Get the raw dataset. Presumably a list of tokenised sentences (lists of
        # word strings) -- construct_dataset iterates it that way. TODO confirm.
        dataset = StanfordSentiment(path)
        self.corpus = dataset.get_dataset()

        self.dataset = []                   # [target_word, context_word] string pairs
        self.window_size = window_size
        self.transform = transform
        self.target_transform = target_transform
        self.word2index = {}                # word -> index, assigned in first-seen order
        self.vocab = []                     # words ordered by index (filled below)
        self.word_freq = defaultdict(int)   # word -> occurrence count in the corpus
        self.construct_dataset()

    def construct_dataset(self):
        """Count words, assign indices, and collect every (target, context) pair."""
        for sentence in self.corpus:
            for i, word in enumerate(sentence):
                self.word_freq[word] += 1
                if word not in self.word2index:
                    self.word2index[word] = len(self.word2index)

                # Clamp the window to the sentence bounds and skip the word itself.
                lo = max(0, i - self.window_size)
                hi = min(len(sentence), i + self.window_size + 1)
                self.dataset.extend([word, sentence[k]] for k in range(lo, hi) if k != i)

        self.index2word = {i: w for w, i in self.word2index.items()}
        # Previously declared but never populated; keep it in index order.
        self.vocab = [self.index2word[i] for i in range(len(self.index2word))]
        self.vocab_size = len(self.index2word)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        target_word, context_word = self.dataset[index]
        target = torch.LongTensor([self.word2index[target_word]])
        context = torch.LongTensor([self.word2index[context_word]])
        if self.transform is not None:
            target = self.transform(target)
        if self.target_transform is not None:
            context = self.target_transform(context)
        return (target, context)

In [246]:
# Build the skip-gram pairs and feed them one pair at a time, in corpus order.
train_dataset = CustomDataset(
    path='./test.txt',
    window_size=2,
)
train_data_loader = DataLoader(
    dataset=train_dataset,
    batch_size=1,
    shuffle=False,
)

# Negative Sampling

In [247]:
negative_sample_table_size = 1000000
negative_samples_count = 1
negative_sample_table = [0] * negative_sample_table_size

def construct_negative_sample_table(dataset):
    """Fill ``negative_sample_table`` with word indices following unigram**0.75.

    Slot ``i`` receives the smallest word index ``j`` whose scaled cumulative
    probability is ``>= i``. Each word therefore occupies a share of the table
    proportional to ``freq ** 0.75``. The module-level list is updated in place,
    so ``get_sample_word`` (which reads it by name) sees the new contents.

    Args:
        dataset: Object exposing ``word_freq`` (word -> count), ``index2word``
            (index -> word) and ``vocab_size``, e.g. a ``CustomDataset``.

    Raises:
        ValueError: if the vocabulary is empty (nothing to sample from).
    """
    if dataset.vocab_size == 0:
        raise ValueError("cannot build a negative-sample table for an empty vocabulary")

    word_freq = np.array(
        [dataset.word_freq[dataset.index2word[i]] for i in range(dataset.vocab_size)],
        dtype=np.float64,
    ) ** 0.75
    cumulative = np.cumsum(word_freq / word_freq.sum()) * negative_sample_table_size

    # Vectorised form of "advance j while i > cumulative[j]". The clip guards against
    # float round-off leaving cumulative[-1] just below the last slot, which would
    # otherwise index past the end of the vocabulary.
    slots = np.searchsorted(cumulative, np.arange(negative_sample_table_size), side='left')
    slots = np.minimum(slots, dataset.vocab_size - 1)
    negative_sample_table[:] = slots.tolist()

def get_sample_word():
    """Return one word index picked uniformly from ``negative_sample_table``.

    The table holds each word in proportion to its unigram frequency ** 0.75.
    A uniform draw over its slots therefore samples words from that distribution.
    """
    slot = random.randrange(negative_sample_table_size)
    return negative_sample_table[slot]

def get_negative_samples(target_word_index, context_word_index, max_attempts=10000):
    """Draw ``negative_samples_count`` negative word indices for one training pair.

    Samples come from ``get_sample_word`` and are redrawn whenever they equal the
    target or the context index. Duplicates among the negatives are allowed.

    Args:
        target_word_index: Index of the target word (int or one-element tensor).
        context_word_index: Index of the context word (int or one-element tensor).
        max_attempts: Maximum draws per negative sample. This stops an endless loop
            when every table entry is the target or the context (e.g. a two-word
            vocabulary).

    Returns:
        1-D ``LongTensor`` of length ``negative_samples_count``.

    Raises:
        RuntimeError: if no admissible word is drawn within ``max_attempts`` tries.
    """
    # Compare plain ints instead of comparing ints against tensors.
    excluded = {int(target_word_index), int(context_word_index)}

    samples = []
    while len(samples) < negative_samples_count:
        for _ in range(max_attempts):
            sample_word_index = get_sample_word()
            if sample_word_index not in excluded:
                break
        else:
            raise RuntimeError(
                f"no negative sample distinct from {sorted(excluded)} after {max_attempts} draws"
            )
        samples.append(sample_word_index)
    # Build the tensor once rather than concatenating inside the loop.
    return torch.tensor(samples, dtype=torch.long)

def generate_batch_negative_samples(target_word_indices, context_word_indices):
    """Draw negative samples for every (target, context) pair in a batch.

    Args:
        target_word_indices: Tensor whose first dimension is the batch
            (``B x 1`` as produced by ``train_data_loader``).
        context_word_indices: Tensor row-aligned with ``target_word_indices``.

    Returns:
        ``LongTensor`` of shape ``B x negative_samples_count``, on the same device
        as ``target_word_indices``. An empty 1-D ``LongTensor`` when ``B == 0``.
    """
    rows = [
        get_negative_samples(target, context)
        for target, context in zip(target_word_indices, context_word_indices)
    ]
    if not rows:
        return torch.empty(0, dtype=torch.long, device=target_word_indices.device)
    # Keep the negatives on the same device as the indices they are paired with,
    # so the model's embedding lookups all happen on one device.
    return torch.stack(rows).to(target_word_indices.device)

In [248]:
# Populate the unigram**0.75 lookup table from the training vocabulary.
construct_negative_sample_table(dataset=train_dataset)

# Model

In [249]:
class Model(nn.Module):
    """Skip-gram embedding model trained with a negative-sampling loss.

    It keeps two embedding tables:
    - ``hidden_layer`` holds the target (input) word vectors.
    - ``output_layer`` holds the context and negative (output) word vectors.

    Both tables are initialised uniformly in ``[-1, 1]``.

    Args:
        embedding_size: Dimension ``D`` of each word vector.
        vocab_size: Number of rows in each embedding table.
    """

    def __init__(self, embedding_size, vocab_size):
        super().__init__()
        self.embedding_size = embedding_size
        self.vocab_size = vocab_size

        self.hidden_layer = nn.Embedding(vocab_size, embedding_size)
        self.output_layer = nn.Embedding(vocab_size, embedding_size)
        self.logSigmoid = nn.LogSigmoid()

        self.hidden_layer.weight.data.uniform_(-1.0, 1.0)
        self.output_layer.weight.data.uniform_(-1.0, 1.0)

    def forward(self, target_word_indices, context_word_indices, negative_word_indices):
        """Return the batch-mean negative-sampling loss.

        Per pair: ``-(log σ(u_c · v_t) + Σ_k log σ(-u_k · v_t))``, where ``v_t`` is
        the target vector, ``u_c`` the context vector and ``u_k`` the K negatives.

        Args:
            target_word_indices: ``B x 1`` LongTensor.
            context_word_indices: ``B x 1`` LongTensor.
            negative_word_indices: ``B x K`` LongTensor.

        Returns:
            Scalar loss tensor.
        """
        target_word_embeddings = self.hidden_layer(target_word_indices)       # B x 1 x D
        context_word_embeddings = self.output_layer(context_word_indices)     # B x 1 x D
        negative_word_embeddings = self.output_layer(negative_word_indices)   # B x K x D

        target_column = target_word_embeddings.transpose(1, 2)                # B x D x 1
        context_word_scores = context_word_embeddings.bmm(target_column).reshape(-1)    # B
        negative_word_scores = negative_word_embeddings.bmm(target_column).squeeze(2)   # B x K

        pos_loss = self.logSigmoid(context_word_scores)                       # B
        # log σ is applied to each negative score separately, then summed over the K
        # samples (dim=1). Summing over the batch axis, or summing the scores before
        # log σ, is not the negative-sampling objective and breaks for B > 1 or K > 1.
        neg_loss = self.logSigmoid(-negative_word_scores).sum(dim=1)          # B
        return -torch.mean(pos_loss + neg_loss)

    def get_word_vector(self, index):
        """Return the target-side (``hidden_layer``) embedding(s) for ``index``."""
        return self.hidden_layer(index)

In [250]:
# Seed the RNGs used from here on so that Restart & Run All reproduces the same
# losses and similarities:
# - torch, for the uniform_ embedding init in Model.__init__;
# - Python's `random`, for negative sampling in get_sample_word.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

EMBEDDING_SIZE = 5
model = Model(EMBEDDING_SIZE, train_dataset.vocab_size)

# Training 

In [251]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# train() moves each batch to `device`, so the model must live there too.
# This is done before the optimizer is constructed in the next cell, as the
# PyTorch docs recommend.
model = model.to(device)

In [252]:
# Plain SGD over both embedding tables of the model.
# NOTE(review): at this learning rate the epoch losses printed below fluctuate
# between roughly 65 and 74 without trending down; a larger lr may be needed.
optimizer = optim.SGD(params=model.parameters(), lr=1e-3)

In [253]:
def train(model, iterator, optimizer, device):
    """Run one epoch of negative-sampling training.

    Args:
        model: Module called as ``model(targets, contexts, negatives)`` that returns
            a scalar loss (e.g. ``Model``).
        iterator: Iterable of ``(target_indices, context_indices)`` batches, such as
            ``train_data_loader``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device every index tensor is moved to; should match the model's.

    Returns:
        Sum of the per-batch losses over the epoch (0 if ``iterator`` is empty).
    """
    epoch_loss = 0

    for batch in tqdm(iterator, desc="Training", leave=False):
        target_word_indices = batch[0].to(device)
        context_word_indices = batch[1].to(device)

        # generate_batch_negative_samples may return its samples on the CPU; move
        # them next to the other indices so all embedding lookups share one device.
        negative_word_indices = generate_batch_negative_samples(
            target_word_indices, context_word_indices
        ).to(device)

        model.zero_grad()
        loss = model(target_word_indices, context_word_indices, negative_word_indices)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss

In [254]:
EPOCHS = 10  # full passes over train_data_loader

# One progress bar for the epochs; train() draws its own per-batch bar.
for epoch in trange(EPOCHS):
    epoch_loss = train(
        model=model,
        iterator=train_data_loader,
        optimizer=optimizer,
        device=device,
    )
    print(f'Epoch: {epoch + 1:02}')
    print(f'\tTrain Loss: {epoch_loss:.3f}')

  0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/48 [00:00<?, ?it/s]

Epoch: 01
	Train Loss: 70.807


Training:   0%|          | 0/48 [00:00<?, ?it/s]

Epoch: 02
	Train Loss: 69.407


Training:   0%|          | 0/48 [00:00<?, ?it/s]

Epoch: 03
	Train Loss: 65.221


Training:   0%|          | 0/48 [00:00<?, ?it/s]

Epoch: 04
	Train Loss: 71.089


Training:   0%|          | 0/48 [00:00<?, ?it/s]

Epoch: 05
	Train Loss: 67.138


Training:   0%|          | 0/48 [00:00<?, ?it/s]

Epoch: 06
	Train Loss: 69.314


Training:   0%|          | 0/48 [00:00<?, ?it/s]

Epoch: 07
	Train Loss: 74.438


Training:   0%|          | 0/48 [00:00<?, ?it/s]

Epoch: 08
	Train Loss: 72.796


Training:   0%|          | 0/48 [00:00<?, ?it/s]

Epoch: 09
	Train Loss: 69.338


Training:   0%|          | 0/48 [00:00<?, ?it/s]

Epoch: 10
	Train Loss: 72.022


In [255]:
def word_similarity(target_word, vocab, top_k=None, embedding_model=None, dataset=None):
    """Rank words by cosine similarity of their target-side embeddings to ``target_word``.

    Args:
        target_word: Word to compare against; must be a key of ``dataset.word2index``.
        vocab: Iterable of candidate words (each must be in ``dataset.word2index``).
            ``target_word`` itself is skipped.
        top_k: If given, return only the ``top_k`` most similar words. ``None`` (the
            default) returns every candidate.
        embedding_model: Object with ``get_word_vector`` and ``parameters``.
            Defaults to the notebook's global ``model``.
        dataset: Object with a ``word2index`` mapping. Defaults to the global
            ``train_dataset``.

    Returns:
        List of ``[word, cosine_similarity]`` pairs, most similar first.

    Raises:
        KeyError: if ``target_word`` or a candidate is missing from ``word2index``.
    """
    embedding_model = model if embedding_model is None else embedding_model
    dataset = train_dataset if dataset is None else dataset
    word2index = dataset.word2index

    target_id = word2index[target_word]
    candidates = [word for word in vocab if word != target_word]
    if not candidates:
        return []

    # Build the index tensors on the model's device (it may have been moved to GPU).
    device = next(embedding_model.parameters()).device
    target_index = torch.LongTensor([target_id]).to(device)
    candidate_indices = torch.LongTensor([word2index[w] for w in candidates]).to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        target_embedding = embedding_model.get_word_vector(target_index)           # 1 x D
        candidate_embeddings = embedding_model.get_word_vector(candidate_indices)  # N x D
        cosine_sims = F.cosine_similarity(target_embedding, candidate_embeddings).tolist()

    similarities = [[word, sim] for word, sim in zip(candidates, cosine_sims)]
    similarities.sort(key=lambda pair: pair[1], reverse=True)
    return similarities if top_k is None else similarities[:top_k]

In [256]:
# Rank every other corpus word by embedding similarity to 'drink'.
word_similarity(target_word='drink', vocab=train_dataset.word_freq.keys())

[['car', 0.8025755882263184],
 ['apple', 0.6895413994789124],
 ['bacon', 0.6523947715759277],
 ['water', 0.6153475046157837],
 ['cherry', 0.59963059425354],
 ['berlin', 0.5709933042526245],
 ['mercedes', 0.5525842308998108],
 ['ford', 0.35669708251953125],
 ['usa', 0.3406466543674469],
 ['germany', 0.33181434869766235],
 ['mango', 0.2919435501098633],
 ['fruit', 0.19135074317455292],
 ['milk', 0.0682157501578331],
 ['boston', 0.026741784065961838],
 ['sugar', -0.12252910435199738],
 ['eat', -0.25700169801712036],
 ['cola', -0.3765031695365906],
 ['juice', -0.3999926447868347],
 ['cold', -0.6409810185432434]]