In [None]:
import pickle
import numpy as np
import pandas as pd
import torch
import datasets
from nltk.util import ngrams
import os
from tqdm import tqdm

In [2]:
# Read the pre-processed tweets back from disk.
# NOTE(review): pickle.load can execute arbitrary code while deserialising,
# so this file must only ever come from a trusted source.
with open("lemmatized_tweets.pkl", mode="rb") as file_handle:
    data = pickle.load(file_handle)

In [3]:
# Turn every entry into a single space-separated line terminated by a newline
# (each entry is presumably an iterable of string tokens — confirm against the pickle).
data = data.apply(lambda tokens: " ".join(tokens) + "\n")

In [4]:
# Materialise the joined lines as a plain Python list.
converted_raw_text = [tweet_line for tweet_line in data]

In [5]:
# Keep only lines longer than one character, i.e. drop lines that are just the newline.
truncated_raw_text = [
    tweet_line for tweet_line in converted_raw_text if len(tweet_line) > 1
]

In [6]:
# Persist the first 100000 non-empty tweet lines as a plain-text corpus.
# The encoding is pinned to UTF-8 so the file does not depend on the platform's
# default locale encoding: tweets routinely contain non-ASCII characters, and the
# file is read back later by a loader that decodes it as UTF-8.
with open("truncated-cleaned-tweets.txt","wt",encoding="utf-8") as file_handle:

    file_handle.writelines(truncated_raw_text[:100000])

In [None]:
# Load the text corpus line by line; the only split is named "train".
dset = datasets.load_dataset(
    path="text",
    data_files=dict(train="truncated-cleaned-tweets.txt"),
)

In [8]:
def tokenize_tweets(single_row):
    """Add a whitespace-tokenised copy of the row's text.

    Args:
        single_row: mapping with a "text" entry holding one string.

    Returns:
        The same mapping object, with a new "tokenized-tweets" entry containing
        the result of ``str.split()`` on the text.
    """
    tokens = single_row["text"].split()
    single_row["tokenized-tweets"] = tokens
    return single_row

In [None]:
# Apply the tokeniser to every row of the training split.
dset["train"] = dset["train"].map(function=tokenize_tweets)

In [10]:
def convert_to_trigrams(single_row):
    """Attach all consecutive token triples of the row.

    Args:
        single_row: mapping with a "tokenized-tweets" entry (a token sequence).

    Returns:
        The same mapping object, with a new "tri-grams" entry holding the list
        of 3-grams produced by ``nltk.util.ngrams``.
    """
    token_sequence = single_row["tokenized-tweets"]
    single_row["tri-grams"] = list(ngrams(token_sequence, 3))
    return single_row

In [None]:
# Compute the tri-grams of every tokenised tweet in the training split.
dset["train"] = dset["train"].map(function=convert_to_trigrams)

In [12]:
# Unique whitespace-separated tokens across the same 100000 lines that were
# written to the corpus file.
vocabulary = {
    token
    for raw_text in truncated_raw_text[:100000]
    for token in raw_text.split()
}

In [13]:
# Map every token to a dense integer id.
# The vocabulary is sorted first because iterating a set of strings yields a
# different order on every interpreter start (string hash randomisation); without
# the sort, the token -> embedding-row mapping would not be reproducible across
# kernel restarts.
vocab2idx = {token: index for index, token in enumerate(sorted(vocabulary))}

In [14]:
def convert_to_bigrams(single_row):
    """Replace each tri-gram by two (center, neighbour) index pairs.

    For a tri-gram ``(left, center, right)`` the result contains
    ``[[idx(center), idx(left)], [idx(center), idx(right)]]``, where ``idx`` is a
    lookup in the module-level ``vocab2idx`` mapping.

    Args:
        single_row: mapping with a "tri-grams" entry (sequence of 3-token items).

    Returns:
        The same mapping object; its "tri-grams" entry is overwritten with the
        list of index-pair lists described above.

    Raises:
        KeyError: if a token is missing from ``vocab2idx``.
    """
    single_row["tri-grams"] = [
        [
            [vocab2idx[center], vocab2idx[left]],
            [vocab2idx[center], vocab2idx[right]],
        ]
        for left, center, right in single_row["tri-grams"]
    ]

    return single_row

In [None]:
# Swap the token tri-grams of every row for (center, neighbour) index pairs.
dset["train"] = dset["train"].map(function=convert_to_bigrams)

In [16]:
# Flatten tweet -> tri-gram -> pair nesting into one flat list of
# [center_index, neighbour_index] pairs, preserving the original order.
input_token_target_token_pairs = [
    pair
    for single_tweet_bigrams in dset["train"]["tri-grams"]
    for bigrams_list in single_tweet_bigrams
    for pair in (bigrams_list[0], bigrams_list[1])
]

In [17]:
class SkipGramDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves items of an in-memory sequence unchanged."""

    def __init__(self, input_target_pairs):
        """Keep a reference to ``input_target_pairs`` (no copy is made)."""
        self.data = input_target_pairs

    def __len__(self):
        """Number of stored items."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the item stored at position ``index``."""
        return self.data[index]

In [18]:
# Wrap the flat list of index pairs so a DataLoader can batch it.
training_data_obj = SkipGramDataset(
    input_target_pairs=input_token_target_token_pairs
)

In [19]:
# Mini-batch iterator over the training pairs.
# - shuffle=True: the pairs are stored in corpus order, so without shuffling each
#   batch would consist of neighbouring pairs from the same few tweets.
# - os.cpu_count() may return None when the CPU count cannot be determined;
#   fall back to 0 (load in the main process) instead of passing None.
training_data_generator = torch.utils.data.DataLoader(training_data_obj,batch_size=32,
                                                     shuffle=True,
                                                     num_workers=os.cpu_count() or 0)

In [20]:
class Word2VecSkipGramNeuralNetwork(torch.nn.Module):
    """Skip-gram word2vec network: an embedding lookup followed by a linear
    projection onto the vocabulary.

    ``forward`` returns raw, unnormalised scores (logits).
    ``torch.nn.CrossEntropyLoss`` applies log-softmax internally, so feeding it
    already-softmaxed values would normalise twice and flatten the gradients.
    Use ``output_layer_activation`` on the logits when probabilities are needed.

    Args:
        vocabulary_size: number of distinct tokens; sets both the number of
            embedding rows and the width of the output layer.
        topic_vector_dim: dimensionality of each token embedding.
    """

    def __init__(self,vocabulary_size,topic_vector_dim):
        super().__init__()

        self.hidden_layer = torch.nn.Embedding(num_embeddings=vocabulary_size,
                                               embedding_dim=topic_vector_dim)
        self.output_layer = torch.nn.Linear(in_features=topic_vector_dim,
                                            out_features=vocabulary_size)
        # Not used in forward(); kept so callers can turn logits into
        # probabilities. dim is explicit because relying on Softmax's implicit
        # dimension is deprecated.
        self.output_layer_activation = torch.nn.Softmax(dim=-1)


    def forward(self,center_token):
        """Score every vocabulary entry for each center token.

        Args:
            center_token: integer tensor of token indices.

        Returns:
            Tensor of logits shaped ``center_token.shape + (vocabulary_size,)``.
        """
        embedding_layer_out = self.hidden_layer(center_token)
        linear_layer_out = self.output_layer(embedding_layer_out)

        return linear_layer_out

In [21]:
# One output score per vocabulary entry, 64-dimensional embeddings.
our_word2vec_skip_gram_nw = Word2VecSkipGramNeuralNetwork(
    vocabulary_size=len(vocab2idx),
    topic_vector_dim=64,
)

In [None]:
# Module.to moves the parameters in place and returns the module itself.
our_word2vec_skip_gram_nw.to(torch.device("cpu"))

In [None]:
# Sanity check: show only the first mini-batch, then stop iterating.
for mini_batch_idx, mini_batch in enumerate(training_data_generator):
    center_tokens, surrounding_tokens = mini_batch[0], mini_batch[1]

    print("Index of Mini Batch is {}".format(mini_batch_idx))
    print("Center Token Mini Batch is {}".format(center_tokens))
    print("Surrounding Token Mini Batch is {}".format(surrounding_tokens))
    break

In [None]:
# Train the skip-gram network: for every (center, surrounding) index pair the
# network scores the whole vocabulary and is penalised with cross-entropy
# against the surrounding token.
epochs = 5
optimizer = torch.optim.Adam(params=our_word2vec_skip_gram_nw.parameters(),
                             lr=0.01)
# NOTE: CrossEntropyLoss applies log-softmax internally, so it is meant to be
# fed unnormalised scores.
loss_fn = torch.nn.CrossEntropyLoss()
progress_bar = tqdm(range(epochs * len(training_data_generator)))

for epoch in range(epochs):

    for mini_batch_idx, mini_batch in enumerate(training_data_generator):

        # Tensor.to() is not in-place: it returns the tensor on the target
        # device, so the result has to be rebound.
        center_token_mini_batch = mini_batch[0].to("cpu")
        surrounding_token_mini_batch = mini_batch[1].to("cpu")

        optimizer.zero_grad()

        y_pred = our_word2vec_skip_gram_nw(center_token_mini_batch)

        loss_fn_value = loss_fn(y_pred,surrounding_token_mini_batch)
        loss_fn_value.backward()

        optimizer.step()
        progress_bar.update(1)

        # Log after the forward/backward pass so the reported loss belongs to
        # this step; .item() yields a plain float instead of a tensor repr, and
        # tqdm.write keeps the message from garbling the progress bar.
        if (mini_batch_idx+1) % 1000 == 0:
            tqdm.write("Epoch # {}, Time Step # {}, Loss = {}".format(epoch,(mini_batch_idx+1),
                                                             loss_fn_value.item()))

# Flush and release the progress bar once training has finished.
progress_bar.close()