<a href="https://colab.research.google.com/github/sibat119/papers-review-code-impl/blob/main/word2vec/word2vec_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [60]:
import torch
from torch.autograd import Variable
import numpy as np
import torch.functional as F
import torch.nn.functional as F
import os
import json
import re
from collections import Counter
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

## **Data Preprocessing**

In [61]:
from torch.utils.data import Dataset
class Word2vecDataset(Dataset):
    def __init__(self, datapath, window_size):
        self.window_size = window_size
        self.sentences_count = 0
        self.input_file = open(datapath, encoding="utf8")
        self.corpus = self.get_corpus(datapath)
        self.words = self.get_words(self.corpus)
        self.vocab_to_int, self.int_to_vocab = self.create_lookup_tables(self.words)


    def __len__(self):
        return self.sentences_count

    def __getitem__(self, idx):
        return self.sentences_count

    @staticmethod
    def collate(batches):
        all_u = [u for batch in batches for u, _, _ in batch if len(batch) > 0]
        all_v = [v for batch in batches for _, v, _ in batch if len(batch) > 0]
        all_neg_v = [neg_v for batch in batches for _, _, neg_v in batch if len(batch) > 0]

        return torch.LongTensor(all_u), torch.LongTensor(all_v), torch.LongTensor(all_neg_v)
    
    def get_corpus(self, datapath):
        corpus = ''
        with open(datapath) as input_file:
            data_file = json.load(input_file)
            for x in data_file:
                self.sentences_count += 1
                corpus += (x["text"].lower().strip())
        print(len(corpus.strip()))
        return corpus
  
    def get_words(self, text):
        # Replace punctuation with tokens so we can use them in our model
        text = text.lower()
        text = text.replace('.', ' <PERIOD> ')
        text = text.replace(',', ' <COMMA> ')
        text = text.replace('"', ' <QUOTATION_MARK> ')
        text = text.replace(';', ' <SEMICOLON> ')
        text = text.replace('!', ' <EXCLAMATION_MARK> ')
        text = text.replace('?', ' <QUESTION_MARK> ')
        text = text.replace('(', ' <LEFT_PAREN> ')
        text = text.replace(')', ' <RIGHT_PAREN> ')
        text = text.replace('--', ' <HYPHENS> ')
        text = text.replace('?', ' <QUESTION_MARK> ')
        # text = text.replace('\n', ' <NEW_LINE> ')
        text = text.replace(':', ' <COLON> ')
        words = text.split()
        
        # Remove all words with  5 or fewer occurences
        word_counts = Counter(words)
        trimmed_words = [word for word in words if word_counts[word] > 2]

        return trimmed_words

    def create_lookup_tables(self, words):
        """
        Create lookup tables for vocabulary
        :param words: Input list of words
        :return: Two dictionaries, vocab_to_int, int_to_vocab
        """
        word_counts = Counter(words)
        # sorting the words from most to least frequent in text occurrence
        sorted_vocab = sorted(word_counts, key=word_counts.get, reverse=True)
        # create int_to_vocab dictionaries
        int_to_vocab = {ii: word for ii, word in enumerate(sorted_vocab)}
        vocab_to_int = {word: ii for ii, word in int_to_vocab.items()}

        return vocab_to_int, int_to_vocab

In [62]:

threshold = 1e-5
word_counts = Counter(int_words)
#print(list(word_counts.items())[0])  # dictionary of int_words, how many times they appear

total_count = len(int_words)
freqs = {word: count/total_count for word, count in word_counts.items()}
p_drop = {word: 1 - np.sqrt(threshold/freqs[word]) for word in word_counts}
# discard some frequent words, according to the subsampling equation
# create a new list of words for training
train_words = [word for word in int_words if random.random() < (1 - p_drop[word])]

print(train_words[:30])
print(Counter(train_words))

[8, 5, 1, 13]
Counter({8: 1, 5: 1, 1: 1, 13: 1})


In [63]:
def get_context_words(words, idx, window_size=5):
    ''' Get a list of words in a window around an index. '''
    
    # R = np.random.randint(1, window_size+1)
    start = idx - window_size if (idx - window_size) > 0 else 0
    stop = idx + window_size
    target_words = words[start:idx] + words[idx+1:stop+1]
    
    return list(target_words)

In [64]:
def neg_sampling(self, vocab):
  NEG_SIZE = 1e6
  neg_word_list = []
  sorted_vocab = []
  freq_sum = np.sum(vocab[word]['word_freq']**0.75 for word in vocab)
  for word in vocab:
      sorted_vocab.append((word, vocab[word]['word_freq']))
  sorted_vocab.sort(key=lambda tup: tup[1], reverse=True)
  for word in sorted_vocab:
      neg_word_list.extend([word[0]] * int((word[1]**0.75 / freq_sum) * NEG_SIZE))
  return neg_word_list

## **Define Model**

In [65]:
class SkipGramModel(nn.Module):

    def __init__(self, emb_size, emb_dimension):
        super(SkipGramModel, self).__init__()
        self.emb_size = emb_size
        self.emb_dimension = emb_dimension
        self.word_embeddings = nn.Embedding(emb_size, emb_dimension, sparse=True)
        self.context_embeddings = nn.Embedding(emb_size, emb_dimension, sparse=True)

        initrange = 1.0 / self.emb_dimension
        init.uniform_(self.word_embeddings.weight.data, -initrange, initrange)
        init.constant_(self.context_embeddings.weight.data, 0)

    def forward(self, pos_u, pos_v, neg_v):
        emb_u = self.word_embeddings(pos_u)
        emb_v = self.context_embeddings(pos_v)
        emb_neg_v = self.context_embeddings(neg_v)

        score = torch.sum(torch.mul(emb_u, emb_v), dim=1)
        score = torch.clamp(score, max=10, min=-10)
        score = -F.logsigmoid(score)

        neg_score = torch.bmm(emb_neg_v, emb_u.unsqueeze(2)).squeeze()
        neg_score = torch.clamp(neg_score, max=10, min=-10)
        neg_score = -torch.sum(F.logsigmoid(-neg_score), dim=1)

        return torch.mean(score + neg_score)

    def save_embedding(self, id2word, file_name):
        embedding = self.word_embeddings.weight.cpu().data.numpy()
        with open(file_name, 'w') as f:
            f.write('%d %d\n' % (len(id2word), self.emb_dimension))
            for wid, w in id2word.items():
                e = ' '.join(map(lambda x: str(x), embedding[wid]))
                f.write('%s %s\n' % (w, e))

## **Training Loop**

In [66]:
def train(skip_gram_model, dataloader, initial_lr=1e-5, iterations=3, device='cpu'):
  for iteration in range(iterations):

      print("\n\n\nIteration: " + str(iteration + 1))
      optimizer = optim.SparseAdam(skip_gram_model.parameters(), lr=initial_lr)
      scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(dataloader))

      running_loss = 0.0
      for i, sample_batched in enumerate(tqdm(dataloader)):

          if len(sample_batched[0]) > 1:
              pos_u = sample_batched[0].to(device)
              pos_v = sample_batched[1].to(device)
              neg_v = sample_batched[2].to(device)

              scheduler.step()
              optimizer.zero_grad()
              loss = skip_gram_model.forward(pos_u, pos_v, neg_v)
              loss.backward()
              optimizer.step()

              running_loss = running_loss * 0.9 + loss.item() * 0.1
              if i > 0 and i % 500 == 0:
                  print(" Loss: " + str(running_loss))
      # skip_gram_model.save_embedding(id2word, output_file_name)

## **Train Model**

In [67]:
skip_gram_model = SkipGramModel(len(vocab_to_int), 50)

In [68]:
dataset = Word2vecDataset('/content/wiki_clean.json', 5)
dataloader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=0, collate_fn=dataset.collate)

33536661


In [70]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train(skip_gram_model=skip_gram_model, dataloader=dataloader, device=device)




Iteration: 1


  0%|          | 0/312 [00:00<?, ?it/s]


TypeError: ignored

## **Evaluation**