In [1]:
import gc
import os
import re
import sys
from collections import Counter, defaultdict

import mlflow
import plotly.express as px
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import concatenate_datasets, load_dataset
from IPython.display import clear_output
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import plotly.express as px

In [2]:
dataset = load_dataset('Fraser/news-category-dataset')
# Word2Vec training is unsupervised, so all three published splits are pooled into one corpus.
dataset = concatenate_datasets([dataset[split] for split in ('train', 'test', 'validation')])

Using custom data configuration default
Reusing dataset news_category (/home/teacher/.cache/huggingface/datasets/Fraser___news_category/default/0.0.0/737b7b6dff469cbba49a6202c9e94f9d39da1fed94e13170cf7ac4b61a75fb9c)


  0%|          | 0/3 [00:00<?, ?it/s]

## Preprocessing method

In [3]:
# Contractions expanded before punctuation is stripped (otherwise the apostrophe is lost).
_CONTRACTIONS = {"ain't": "are not", "'s": " is", "aren't": "are not",
                 "i'm": "i am", "'re": " are", "'ve": " have"}
# @mentions, any character that is not an ASCII letter/digit/space/tab, and URLs.
# Raw string: a non-raw literal with `\w` / `\S` is an invalid escape (SyntaxWarning on 3.12+).
_NOISE_RE = re.compile(r"(@[A-Za-z0-9]+)|([^0-9A-Za-z \t])|(\w+://\S+)")
_NUMBER_RE = re.compile(r"[0-9]+")


def sample_preprocess(sample):
    """Normalize and tokenize a raw text.

    Lower-cases the text, expands a few English contractions, replaces @mentions,
    URLs and every non-alphanumeric character with a space, replaces each run of
    digits with ``<NUM>`` and splits on whitespace.

    :param sample: raw text.
    :returns: list of string tokens (empty for a blank text).
    """
    sample = sample.lower()
    for contraction, expansion in _CONTRACTIONS.items():
        sample = sample.replace(contraction, expansion)
    sample = _NOISE_RE.sub(" ", sample)
    sample = _NUMBER_RE.sub("<NUM>", sample)
    # str.split() with no argument already collapses runs of whitespace and trims the ends.
    return sample.split()

## Build the text corpus and vocabulary

In [4]:
# Corpus: every headline first, then every short description.
texts = [row[field] for field in ('headline', 'short_description') for row in dataset]

In [5]:
# Every distinct token `sample_preprocess` produces from either text field.
vocab = {
    token
    for row in dataset
    for field in ('headline', 'short_description')
    for token in sample_preprocess(row[field])
}

In [6]:
# Sort so word ids are reproducible: set iteration order depends on the per-process string
# hash seed, so `list(set)` gave different ids (and unusable checkpoints) in every session.
# Special tokens are removed first so re-running this cell does not append them twice.
vocab = sorted(set(vocab) - {'<UNK>', '<PAD>'}) + ['<UNK>', '<PAD>']

## Dataset: preprocessing and skip-gram pair generation

In [7]:
class TextDataset(Dataset):
    """Skip-gram ``(central, context)`` pairs with uniform negative sampling.

    Every text is tokenized, mapped to vocabulary indices (unknown words -> ``<UNK>``)
    and split into ``(central, context)`` index pairs over a symmetric window.
    ``__getitem__`` returns one pair plus ``K_sampling`` negative word indices drawn
    uniformly from the vocabulary, excluding the pair's own two words.
    """

    # Same normalization as the notebook-level ``sample_preprocess``, kept here so the
    # class is self-contained.
    CONTRACTIONS = {"ain't": "are not", "'s": " is", "aren't": "are not",
                    "i'm": "i am", "'re": " are", "'ve": " have"}
    # @mentions, any char that is not an ASCII letter/digit/space/tab, and URLs
    # (raw string avoids the invalid-escape SyntaxWarning of a non-raw literal).
    NOISE_RE = re.compile(r"(@[A-Za-z0-9]+)|([^0-9A-Za-z \t])|(\w+://\S+)")
    NUMBER_RE = re.compile(r"[0-9]+")
    SPECIAL_TOKENS = ("<UNK>", "<PAD>")

    def __init__(self, texts, vocab=None, window_size=5, K_sampling=100):
        """
        Dataset class. Preprocesses texts according to the vocabulary.
        :param texts: sequence of raw strings (the corpus); iterated twice when ``vocab``
            is not given.
        :param vocab: list of tokens whose list index is the word id; must contain
            ``"<UNK>"`` and ``"<PAD>"``. If falsy, it is built from ``texts`` (sorted, so
            ids are reproducible) with the two special tokens appended.
        :param window_size: odd window width, central word included.
        :param K_sampling: number of negative samples per pair.
        """
        super().__init__()
        assert window_size % 2 == 1, "Window size must be odd!"
        self.window_size = window_size
        self.K = K_sampling

        self.texts = texts

        if vocab:
            self.vocab = vocab
        else:
            # Built from string tokens: the index mapping (`word2idx`) does not exist yet.
            tokens = {token for text in texts for token in self._tokenize(text)}
            self.vocab = sorted(tokens) + list(self.SPECIAL_TOKENS)

        self.word2idx = {word: idx for idx, word in enumerate(self.vocab)}
        self.UNK, self.PAD = self.word2idx["<UNK>"], self.word2idx["<PAD>"]

        # Flat list of (central_id, context_id) int tuples over the whole corpus.
        # Plain ints are far lighter than one 0-d tensor per id.
        self.deuce_sample = []
        for text in texts:
            self.deuce_sample.extend(self.split_sentence(self.sample_preprocess(text)))

    @classmethod
    def _tokenize(cls, sample):
        """Lower-case, expand contractions, strip noise, map digit runs to ``<NUM>``, split."""
        sample = sample.lower()
        for contraction, expansion in cls.CONTRACTIONS.items():
            sample = sample.replace(contraction, expansion)
        sample = cls.NOISE_RE.sub(" ", sample)
        sample = cls.NUMBER_RE.sub("<NUM>", sample)
        return sample.split()

    def sample_preprocess(self, sample):
        """
        Tokenize a raw text and map it to vocabulary indices.
        :param sample: raw text.
        :returns: 1-D LongTensor of word ids; out-of-vocabulary tokens map to ``<UNK>``.
        """
        return torch.LongTensor([self.word2idx.get(token, self.UNK)
                                 for token in self._tokenize(sample)])

    def split_sentence(self, sample):
        """
        Pair every position with each other position at most ``window_size // 2`` away.

        Windows are truncated at the text boundaries, so edge words and texts shorter
        than the window still produce pairs. Only the central *position* is skipped, so
        a word repeated inside its own window is still paired with that repetition.
        :param sample: sequence (or 1-D tensor) of word ids.
        :returns: list of ``(central_id, context_id)`` int tuples.
        """
        ids = sample.tolist() if torch.is_tensor(sample) else list(sample)
        half = self.window_size // 2
        pairs = []
        for pos, central in enumerate(ids):
            start, stop = max(0, pos - half), min(len(ids), pos + half + 1)
            pairs.extend((central, ids[other]) for other in range(start, stop) if other != pos)
        return pairs

    def __getitem__(self, idx):
        """
        :returns: ``(central, context, neg_samples)``: two 0-d LongTensors and a ``(K,)``
            LongTensor of ids drawn uniformly from the vocabulary minus ``{central, context}``.
        :raises ValueError: if the vocabulary has no id left to sample.
        """
        central, context = self.deuce_sample[idx]

        # Sample the allowed ids directly instead of rejection sampling (which can spin
        # indefinitely on a small vocabulary): draw from [0, V - |excluded|) and shift each
        # value past every excluded id at or below it (excluded ids in ascending order).
        excluded = sorted({central, context})
        n_allowed = len(self.vocab) - len(excluded)
        if n_allowed <= 0:
            raise ValueError("Vocabulary is too small to draw negative samples.")
        neg_samples = torch.randint(0, n_allowed, (self.K,))
        for banned in excluded:
            neg_samples += (neg_samples >= banned).long()

        return torch.tensor(central), torch.tensor(context), neg_samples

    def __len__(self):
        """Number of (central, context) pairs, not the number of texts."""
        return len(self.deuce_sample)

In [8]:
def collate_fn(batch):
    """Stack a list of ``(central, context, negatives)`` samples field by field.

    :param batch: list of samples whose first three items are tensors of equal shape per field.
    :returns: tuple ``(central, context, negatives)`` of stacked tensors.
    """
    return tuple(torch.stack([sample[field] for sample in batch]) for field in range(3))

## Token counts by frequency

In [None]:
# Corpus-wide token counts. `dataset` is a single concatenated `Dataset` with no 'train'
# split to index (`dataset['train']` would look up a non-existent column).
vocab_count = Counter()
for row in dataset:
    for field in ('headline', 'short_description'):
        vocab_count.update(sample_preprocess(row[field]))

In [None]:
# Keep tokens seen at least MIN_TOKEN_COUNT and at most MAX_TOKEN_COUNT times; the rest
# are candidates for removal (presumably rare noise / stop-word-like tokens).
MIN_TOKEN_COUNT = 3
MAX_TOKEN_COUNT = 60000
# NOTE(review): `tokens_to_drop` is not applied to `vocab` in the visible notebook. TODO: apply or remove.
tokens_to_drop = [word for word, count in vocab_count.items()
                  if not MIN_TOKEN_COUNT <= count <= MAX_TOKEN_COUNT]

In [None]:
# One raw occurrence count per distinct token, and the smallest of them.
freq = [*vocab_count.values()]
min_freq = min(freq)

In [None]:
# Width of the count range shown, starting at the rarest token's count.
HIST_RANGE_WIDTH = 1000
fig = px.histogram(freq, title="Tokens amount by frequencies in dataset",
                   range_x=[min_freq, min_freq + HIST_RANGE_WIDTH],
                   width=1000, height=600)
# Label the axes so the figure stands on its own.
fig.update_layout(xaxis_title="Token count in corpus",
                  yaxis_title="Number of tokens",
                  showlegend=False)
fig.show()

## Model

In [9]:
class Word2Vec(nn.Module):
    """Skip-gram with negative sampling (SGNS).

    Holds two embedding tables: ``central_emb`` for centre words and ``context_emb`` for
    context and negative words. ``forward`` returns the mean SGNS loss over the batch.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()

        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim

        self.central_emb = nn.Embedding(num_embeddings=self.vocab_size,
                                        embedding_dim=self.embedding_dim)
        self.context_emb = nn.Embedding(num_embeddings=self.vocab_size,
                                        embedding_dim=self.embedding_dim)
        # word2vec-style init: small uniform centre vectors, zero context vectors. The
        # nn.Embedding default N(0, 1) makes initial dot products have variance
        # embedding_dim, saturating the sigmoids (std ~17 for 300 dims).
        bound = 0.5 / self.embedding_dim
        nn.init.uniform_(self.central_emb.weight, -bound, bound)
        nn.init.zeros_(self.context_emb.weight)

    def forward(self, central, context, neg_samples):
        """
        :param central: ``(B,)`` LongTensor of centre-word ids.
        :param context: ``(B,)`` LongTensor of positive context-word ids.
        :param neg_samples: ``(B, K)`` LongTensor of negative word ids.
        :returns: scalar loss
            ``-mean_b[log σ(c_b·o_b) + Σ_k log σ(-c_b·n_bk)]``.
        """
        emb_central = self.central_emb(central)            # (B, D)
        emb_context = self.context_emb(context)            # (B, D)
        emb_neg_samples = self.context_emb(neg_samples)    # (B, K, D)

        # Per-sample dot product (the previous `emb_context.T @ emb_central` was a D x D matrix).
        pos_score = (emb_central * emb_context).sum(dim=-1)                                 # (B,)
        neg_score = torch.bmm(emb_neg_samples, emb_central.unsqueeze(-1)).squeeze(-1)      # (B, K)

        # Sum negatives per sample (not over the whole batch), then average over samples.
        per_sample = F.logsigmoid(pos_score) + F.logsigmoid(-neg_score).sum(dim=-1)
        return -per_sample.mean()

In [10]:
def train(model, dataloaders, optimizer, lr, scheduler=None,
          num_epochs=5, start_epoch=-1, prev_losses=None, device=torch.device('cpu'),
          folder_for_checkpoints='/', log_every=100):
    """Train a model whose forward pass ``model(central, context, negatives)`` returns the loss.

    Every ``log_every`` batches the running mean loss of the current epoch is logged to
    MLflow, appended to the history, and a checkpoint is saved.

    :param model: module returning a scalar loss.
    :param dataloaders: a single ``DataLoader`` yielding ``(central, context, negatives)``
        batches (name kept for backward compatibility).
    :param optimizer: optimizer over ``model.parameters()``.
    :param lr: learning rate; only recorded in checkpoints.
    :param scheduler: optional LR scheduler, stepped once per batch.
    :param num_epochs: number of epochs to run.
    :param start_epoch: last completed epoch; training starts at ``start_epoch + 1``.
    :param prev_losses: ``(step, loss)`` history of a previous run; it is re-logged to
        MLflow and step numbering continues after its last step.
    :param device: device every batch is moved to.
    :param folder_for_checkpoints: directory for checkpoint files (created if missing).
    :param log_every: logging / checkpointing period in batches.
    :returns: ``(model, history)``, where ``history`` is a list of ``(step, running_loss)``.
    :raises ValueError: if ``log_every`` < 1.
    """
    if log_every < 1:
        raise ValueError("log_every must be >= 1")

    history = list(prev_losses) if prev_losses else []
    # Replay a resumed run's curve into the active MLflow run and continue its step count.
    for step, loss in history:
        mlflow.log_metric('train_loss', loss, step=step)
    curr_step = history[-1][0] + 1 if history else 1

    os.makedirs(folder_for_checkpoints, exist_ok=True)
    last_epoch = start_epoch + num_epochs

    model.train()
    for epoch in range(start_epoch + 1, last_epoch + 1):
        running_loss = 0.0

        clear_output(True)
        print("-" * 20)
        print(f"Epoch: {epoch}/{last_epoch}")
        print("-" * 20)
        print("Train: ")

        for batch_idx, (central, context, negatives) in enumerate(tqdm(dataloaders)):
            loss = model(central.to(device), context.to(device), negatives.to(device))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if scheduler is not None:
                scheduler.step()

            running_loss += loss.item()

            if batch_idx % log_every == log_every - 1:
                mean_loss = running_loss / (batch_idx + 1)
                mlflow.log_metric('train_loss', mean_loss, step=curr_step)
                history.append((curr_step, mean_loss))
                print(f"\nRunning training loss: {mean_loss}")
                _save_checkpoint(folder_for_checkpoints, epoch, batch_idx, model, optimizer,
                                 scheduler, lr, dataloaders.batch_size, history)
            # No per-batch gc.collect()/empty_cache(): both are expensive and the batch
            # tensors are released when the loop variables are rebound.
            curr_step += 1

        # `dataloaders` is a single DataLoader (not a dict); guard against an empty one.
        mean_train_loss = running_loss / max(len(dataloaders), 1)
        print(f"Training loss: {mean_train_loss}")

    return model, history


def _save_checkpoint(folder, epoch, batch_idx, model, optimizer, scheduler, lr,
                     batch_size, history):
    """Write the rolling checkpoint for ``epoch``'s slot and upload that file to MLflow."""
    state = {
        'epoch': epoch,
        'batch_idx': batch_idx,
        'batch_size_training': batch_size,
        'model_architecture': model,
        'model_state_dict': model.state_dict(),
        'optimizer': optimizer,
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler': scheduler,
        'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
        'lr': lr,
        'losses': history,
    }
    # One file per `epoch % 5` slot, overwritten within the epoch, so at most five
    # checkpoints exist on disk (a batch index in the name grew disk usage without bound).
    path = os.path.join(folder, f'checkpoint_epoch_{epoch % 5 + 1}.pt')
    torch.save(state, path)
    # Upload only the file just written instead of re-uploading the whole folder.
    mlflow.log_artifact(path)
    return path

## Model configuration

In [11]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
CHECKPOINT_PATH = 'Checkpoints/'
NUM_EPOCHS = 5
BATCH_SIZE = 512
LR = 1e-3
EMBEDDING_DIM = 300
WINDOW = 5   # odd skip-gram window width, central word included (WINDOW // 2 neighbours per side)
K = 100      # negative samples drawn per (central, context) pair
SEED = 42

# Seed torch's RNGs (CPU and CUDA) before anything random happens below:
# embedding init, DataLoader shuffling and negative sampling.
torch.manual_seed(SEED)

In [None]:
# Skip-gram pairs over the whole corpus, served in shuffled batches.
text_dataset = TextDataset(texts=texts, vocab=vocab, window_size=WINDOW, K_sampling=K)
text_dataloader = DataLoader(
    dataset=text_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)

In [None]:
# Drop objects left by a previous run of these cells so their (GPU) memory can be reclaimed.
# `scheduler` is included because it keeps a reference to `optimizer` (and thus the model).
# Unlike the former bare `try/except`, a missing name no longer skips the remaining cleanup.
for _stale_name in ('scheduler', 'optimizer', 'model'):
    globals().pop(_stale_name, None)
gc.collect()               # free unreachable Python objects first ...
torch.cuda.empty_cache()   # ... then return their cached CUDA blocks to the driver
model = Word2Vec(vocab_size=len(vocab), embedding_dim=EMBEDDING_DIM).to(DEVICE)

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
# `train` steps the scheduler once per batch, so anneal over the whole run. With
# T_max = one epoch the LR reached 0 after epoch 1 and then climbed back up.
scheduler = CosineAnnealingLR(optimizer, T_max=len(text_dataloader) * NUM_EPOCHS)

## Training

In [None]:
mlflow.set_tracking_uri('databricks')
# NOTE(review): the default experiment path embeds a personal e-mail; set
# MLFLOW_EXPERIMENT_PATH to log elsewhere without editing the notebook.
mlflow.set_experiment(os.environ.get("MLFLOW_EXPERIMENT_PATH",
                                     "/Users/timkakhanovich@gmail.com/Word2Vec/v1"))

with mlflow.start_run(run_name="Word2Vec model v1.0"):
    mlflow.set_tags({
        'Python': '.'.join(map(str, sys.version_info[:3])),
        'Device': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'
    })
    # Log every hyper-parameter that shapes the run, not only batch size and LR.
    mlflow.log_params({
        'batch_size': BATCH_SIZE,
        'lr': LR,
        'num_epochs': NUM_EPOCHS,
        'embedding_dim': EMBEDDING_DIM,
        'window': WINDOW,
        'negative_samples': K,
        'vocab_size': len(vocab),
    })
    # `train` updates `model` in place and returns it; keep the loss history for inspection.
    model, history = train(
        model, text_dataloader, optimizer, lr=LR, scheduler=scheduler,
        num_epochs=NUM_EPOCHS, device=DEVICE, folder_for_checkpoints=CHECKPOINT_PATH
    )