In [22]:
import gc
import os
import re
from collections import defaultdict

import mlflow
import plotly.express as px
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from IPython.display import clear_output
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

In [2]:
# Fetch the news-category corpus (the local Hugging Face cache is reused when present).
dataset = load_dataset("Fraser/news-category-dataset")

Using custom data configuration default
Reusing dataset news_category (/home/tsimur/.cache/huggingface/datasets/Fraser___news_category/default/0.0.0/737b7b6dff469cbba49a6202c9e94f9d39da1fed94e13170cf7ac4b61a75fb9c)
100%|██████████| 3/3 [00:00<00:00, 478.40it/s]


## Preprocessing method

In [3]:
def sample_preprocess(sample):
    """
    Normalize a raw text and split it into tokens.

    Steps: lowercase, expand a few English contractions, blank out
    @mentions, links and every character that is not a latin letter,
    digit, space or tab, then replace each run of digits with ``<NUM>``.

    :param sample: raw text (str).
    :returns: list of tokens; empty list for empty or whitespace-only input.
    """

    # To lower.
    sample = sample.lower()

    # Expand contractions. Plain substring replacement, applied in insertion order.
    contraction_dict = {"ain't": "are not", "'s": " is", "aren't": "are not",
                        "i'm": "i am", "'re": " are", "'ve": " have"}
    for contraction, expansion in contraction_dict.items():
        sample = sample.replace(contraction, expansion)

    # Blank out @users, links and punctuation (so '#tag' keeps 'tag').
    # Raw string: '\w', '\/' and '\S' are invalid escapes in a normal literal
    # and trigger a warning on recent Python versions.
    reg_str = r"(@[A-Za-z0-9]+)|([^0-9A-Za-z \t])|(\w+:\/\/\S+)"
    sample = re.sub(reg_str, " ", sample)

    # Replace numbers with NUM.
    sample = re.sub('[0-9]+', '<NUM>', sample)

    # str.split() without arguments already drops leading/trailing whitespace
    # and collapses repeated separators, so no extra cleanup is required.
    return sample.split()

## Build the text corpus and the vocabulary

In [4]:
# Corpus layout: every headline first, followed by every short description.
texts = [
    row[field]
    for field in ('headline', 'short_description')
    for row in dataset['train']
]

In [5]:
# Collect every distinct token seen in either text field of the train split.
vocab = set()
for row in dataset['train']:
    for field in ('headline', 'short_description'):
        vocab.update(sample_preprocess(row[field]))

In [6]:
# Sort the tokens so the token -> index mapping is identical on every kernel
# restart (iteration order of a set of strings is not stable between
# interpreter runs), and strip the special tokens before re-adding them so
# that re-running this cell never appends duplicates.
vocab = sorted(set(vocab).difference(('<UNK>', '<PAD>')))
vocab.extend(['<UNK>', '<PAD>'])

## Dataset with on-the-fly text preprocessing

In [7]:
class TextDataset(Dataset):
    def __init__(self, texts, vocab=None):
        """
        Dataset class. Preprocessing data according to vocabulary.

        Every item of ``texts`` is a raw string; it is tokenized lazily in
        ``__getitem__`` and encoded with ``word2idx``.

        :param texts: corpus of texts (sequence of str),
        :param vocab: optional iterable of tokens. When omitted (or empty) the
            vocabulary is built from ``texts``. ``<UNK>`` and ``<PAD>`` are
            appended when they are missing.
        """

        super().__init__()
        self.texts = texts

        if vocab:
            # Copy so the caller's container is never modified below.
            self.vocab = list(vocab)
        else:
            tokens = set()
            for sample in self.texts:
                # Items are plain strings, exactly as __getitem__ consumes them.
                tokens.update(self.sample_preprocess(sample))

            # Sorted so indices do not depend on set iteration order.
            self.vocab = sorted(tokens)

        self.word2idx = {word: idx for idx, word in enumerate(self.vocab)}

        # Guarantee the special tokens exist instead of failing with KeyError.
        for special in ("<UNK>", "<PAD>"):
            if special not in self.word2idx:
                self.word2idx[special] = len(self.vocab)
                self.vocab.append(special)

        self.UNK, self.PAD = self.word2idx['<UNK>'], self.word2idx['<PAD>']

    @staticmethod
    def sample_preprocess(sample):
        """
        Static method for text preprocessing.

        Lowercases, expands a few contractions, blanks out @mentions, links
        and non-alphanumeric characters, and maps digit runs to ``<NUM>``.

        :param sample: sample to preprocess (str).
        :returns: tokenized and preprocessed string -> list.
        """

        # To lower.
        sample = sample.lower()

        # Expand contractions (substring replacement, in insertion order).
        contraction_dict = {"ain't": "are not", "'s": " is", "aren't": "are not",
                            "i'm": "i am", "'re": " are", "'ve": " have"}
        for contraction, expansion in contraction_dict.items():
            sample = sample.replace(contraction, expansion)

        # Blank out @users, links and punctuation. Raw string avoids the
        # invalid-escape warning for '\w', '\/' and '\S'.
        reg_str = r"(@[A-Za-z0-9]+)|([^0-9A-Za-z \t])|(\w+:\/\/\S+)"
        sample = re.sub(reg_str, " ", sample)

        # Replace numbers with NUM.
        sample = re.sub('[0-9]+', '<NUM>', sample)

        # split() with no arguments already trims and collapses whitespace.
        return sample.split()

    def pad_collate(self, batch):
        """
        Function to pad data according to the maximum length
        of sentence in batch (for dataloader).

        :param batch: list of 1-D index tensors as produced by ``__getitem__``.
        :returns: int64 tensor of shape (len(batch), max_len) padded with the
            <PAD> index, where max_len is rounded up to an even number.
        """

        # Plain Python ints: F.pad expects integer pad widths.
        lengths = [int(x.shape[0]) for x in batch]
        max_len_per_batch = max(lengths)

        # Extend to even max_len.
        max_len_per_batch += max_len_per_batch % 2

        X_batch = torch.vstack([
            F.pad(x, pad=(0, max_len_per_batch - length), value=self.PAD)
            for x, length in zip(batch, lengths)
        ]).type(torch.int64)

        return X_batch

    def __getitem__(self, idx):
        """Return text ``idx`` as a LongTensor of vocabulary indices (<UNK> for unknown tokens)."""
        preprocessed = self.sample_preprocess(self.texts[idx])

        return torch.LongTensor([self.word2idx.get(token, self.UNK)
                                for token in preprocessed])

    def __len__(self):
        """Number of texts in the corpus."""
        return len(self.texts)

In [8]:
# Wrap the corpus with the shared vocabulary; batches are padded per batch by pad_collate.
text_dataset = TextDataset(texts, vocab=vocab)
text_dataloader = DataLoader(
    dataset=text_dataset,
    batch_size=16,
    collate_fn=text_dataset.pad_collate,
)

## Token counts by frequency

In [9]:
# Count how often each token occurs across headlines and short descriptions.
vocab_count = defaultdict(int)
for row in dataset['train']:
    for field in ('headline', 'short_description'):
        for token in sample_preprocess(row[field]):
            vocab_count[token] += 1

In [10]:
# Tokens seen fewer than 3 times or more than 60000 times.
tokens_to_drop = [
    word
    for word, count in vocab_count.items()
    if not (3 <= count <= 60000)
]

In [11]:
# Per-token occurrence counts and the smallest of them (used as the plot's left edge).
freq = [count for count in vocab_count.values()]
min_freq = min(freq)

In [14]:
# Histogram of token counts, zoomed to the 1000 lowest frequency values.
fig = px.histogram(
    freq,
    title="Tokens amount by frequencies in dataset",
    range_x=[min_freq, min_freq + 1000],
    width=1000,
    height=600,
)
fig.show()

## Model

In [15]:
class Word2Vec(nn.Module):
    """
    Skip-gram word2vec with negative sampling.

    Two embedding tables are kept: one for central words and one for
    context words (the latter is also used for the negative samples).
    """

    def __init__(self, vocab_size, embedding_dim):
        """
        :param vocab_size: number of rows in each embedding table.
        :param embedding_dim: size of every embedding vector.
        """
        super().__init__()

        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim

        self.central_emb = nn.Embedding(num_embeddings=self.vocab_size,
                                        embedding_dim=self.embedding_dim)
        self.context_emb = nn.Embedding(num_embeddings=self.vocab_size,
                                        embedding_dim=self.embedding_dim)

    def forward(self, central, context, neg_samples):
        """
        Negative-sampling loss for one pair or a batch of pairs.

        :param central: index tensor of central words, 0-dim or shape (B,).
        :param context: index tensor of context words, same shape as ``central``.
        :param neg_samples: index tensor of negatives, shape (K,) (shared by
            all pairs) or (B, K).
        :returns: scalar tensor: the mean over pairs of
            -(logsigmoid(u_ctx . v_c) + sum_k logsigmoid(-u_neg_k . v_c)).
        """
        emb_central = self.central_emb(central)
        emb_context = self.context_emb(context)
        emb_neg_samples = self.context_emb(neg_samples)

        # Dot product along the embedding axis. Unlike `emb_context.T @ ...`
        # this is correct for batched input and avoids `.T` on a 1-D tensor.
        central_out = (emb_context * emb_central).sum(dim=-1)

        # (..., K, D) @ (..., D, 1) -> (..., K); matmul broadcasts a shared
        # (K, D) block of negatives over the batch dimension.
        rest_out = torch.matmul(emb_neg_samples, emb_central.unsqueeze(-1)).squeeze(-1)

        per_pair = F.logsigmoid(central_out) + F.logsigmoid(-rest_out).sum(dim=-1)

        return -torch.mean(per_pair)

In [16]:
# One embedding row per vocabulary entry, 300 dimensions each.
model = Word2Vec(len(vocab), 300)

In [17]:
# Smoke test: loss for a single (central, context) pair with four negatives.
model(
    central=torch.tensor(4),
    context=torch.tensor(50),
    neg_samples=torch.tensor([22, 55, 66, 77]),
)

tensor(38.1902, grad_fn=<NegBackward0>)

In [None]:
def _skipgram_pairs(batch, window_size, pad_idx=None):
    """Return aligned 1-D tensors (centrals, contexts) of token ids that lie within
    ``window_size`` positions of each other in a row of ``batch``; pairs touching
    ``pad_idx`` are dropped."""
    centrals, contexts = [], []
    for offset in range(1, min(window_size, batch.shape[1] - 1) + 1):
        left, right = batch[:, :-offset], batch[:, offset:]
        if pad_idx is not None:
            keep = (left != pad_idx) & (right != pad_idx)
            left, right = left[keep], right[keep]
        else:
            left, right = left.reshape(-1), right.reshape(-1)
        # Every co-occurrence is used in both directions.
        centrals.extend((left, right))
        contexts.extend((right, left))

    if not centrals:
        empty = batch.new_empty((0,))
        return empty, empty
    return torch.cat(centrals), torch.cat(contexts)


def _sample_negatives(vocab_size, num_samples, central, context, device):
    """Draw ``num_samples`` uniform token ids that differ from both ``central`` and ``context``."""
    if vocab_size < 3:
        raise ValueError("vocab_size must be at least 3 to draw negatives that "
                         "differ from both words of a pair")

    neg_samples = torch.randint(0, vocab_size, (num_samples,), device=device)
    clash = (neg_samples == central) | (neg_samples == context)
    # Redraw only the colliding entries until none is left.
    while bool(clash.any()):
        neg_samples[clash] = torch.randint(0, vocab_size, (int(clash.sum()),), device=device)
        clash = (neg_samples == central) | (neg_samples == context)
    return neg_samples


def train(model, dataloaders, optimizer, lr, scheduler=None, K_neg_sampling=100,
          num_epochs=5, start_epoch=-1, prev_losses=None, device=torch.device('cpu'),
          folder_for_checkpoints='/', window_size=2):
    """
    Train a skip-gram model with negative sampling and log progress to mlflow.

    For every padded batch of token ids, all (central, context) pairs within
    ``window_size`` positions are extracted, ``K_neg_sampling`` negatives are
    drawn per pair, and the mean pair loss is back-propagated once per batch.
    The model is called one pair at a time with 0-dim index tensors and a
    1-D tensor of negatives.

    :param model: module called as ``model(central, context, neg_samples)``
        that returns a scalar loss and exposes ``vocab_size``.
    :param dataloaders: a single DataLoader yielding 2-D index tensors; if its
        dataset has a ``PAD`` attribute, pairs containing that index are skipped.
    :param optimizer: optimizer bound to ``model`` parameters.
    :param lr: learning rate; only stored in the checkpoint.
    :param scheduler: optional LR scheduler, stepped after every optimizer step.
    :param K_neg_sampling: number of negatives per pair.
    :param num_epochs: number of epochs to run.
    :param start_epoch: last finished epoch (-1 for a fresh run).
    :param prev_losses: optional list of (step, loss) tuples from an earlier
        run; replayed to mlflow and used to continue the step counter.
    :param device: device used for the model, batches and negatives.
    :param folder_for_checkpoints: directory receiving the rotating checkpoints.
        NOTE(review): the default '/' is the filesystem root; pass a real folder.
    :param window_size: maximum distance between central and context word.
    :returns: (model, history) where history is a list of (step, loss) tuples.
    """
    if prev_losses:
        history = list(prev_losses)
        # Continue counting right after the last recorded step.
        curr_step = history[-1][0] + 1

        for step, loss in history:
            mlflow.log_metric('train_loss', loss, step=step)
    else:
        history = list()
        curr_step = 1

    pad_idx = getattr(getattr(dataloaders, 'dataset', None), 'PAD', None)
    os.makedirs(folder_for_checkpoints, exist_ok=True)

    model.to(device)
    model.train()
    for epoch in range(start_epoch + 1, start_epoch + 1 + num_epochs):
        running_loss = 0.0
        num_batches = 0

        clear_output(True)
        print("-" * 20)
        print(f"Epoch: {epoch}/{start_epoch + num_epochs}")
        print("-" * 20)
        print("Train: ")

        for batch_idx, X in enumerate(tqdm(dataloaders)):
            X = X.to(device)
            num_batches += 1

            centrals, contexts = _skipgram_pairs(X, window_size, pad_idx)

            # A batch made only of texts with fewer than two tokens has no pairs.
            if centrals.numel() > 0:
                pair_losses = [
                    model(central, context,
                          _sample_negatives(model.vocab_size, K_neg_sampling,
                                            central, context, device))
                    for central, context in zip(centrals, contexts)
                ]
                loss = torch.stack(pair_losses).mean()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if scheduler is not None:
                    scheduler.step()

                running_loss += loss.item()

            if batch_idx % 30 == 29:
                mlflow.log_metric('train_loss', running_loss / (batch_idx + 1), step=curr_step)

                history.append((curr_step, running_loss / (batch_idx + 1)))

                print(f"\nRunning training loss: {running_loss / (batch_idx + 1)}")

            curr_step += 1

        # Release cached memory once per epoch; doing it per batch is needlessly slow.
        gc.collect()
        torch.cuda.empty_cache()

        mean_train_loss = running_loss / max(num_batches, 1)

        print(f"Training loss: {mean_train_loss}")

        # NOTE: 'model_architecture', 'optimizer' and 'scheduler' store whole
        # pickled objects; loading them requires the same class definitions.
        state = {
            'epoch': epoch,
            'batch_size_training': dataloaders.batch_size,
            'model_architecture': model,
            'model_state_dict': model.state_dict(),
            'optimizer': optimizer,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler': scheduler if scheduler else None,
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'lr': lr,
            'losses': history
        }

        # File names rotate over 5 slots, so only the 5 latest checkpoints are kept.
        checkpoint_path = os.path.join(folder_for_checkpoints,
                                       f'checkpoint_epoch_{epoch % 5 + 1}.pt')
        torch.save(state, checkpoint_path)
        # Upload just the file written above rather than the entire folder.
        mlflow.log_artifact(checkpoint_path)

    return model, history