In [1]:
import sys
sys.path.append("../src/")

import time
import json
import torch
import random
import pickle as pkl
import torch.nn as nn
from types import NoneType
from itertools import cycle
import torch.optim as optim
from utils import get_next_batch
from vocabulary import Vocabulary
from pytorch_lightning import Trainer
from typing import Union, Mapping, Any
from pytorch_lightning import LightningModule
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import EarlyStopping
from torch.utils.data import Dataset, IterableDataset
from pytorch_lightning.callbacks import ModelCheckpoint

DEBUG = True      # tiny batch limits for a quick smoke run (see the Trainer cell)
BATCH_SIZE = 256
EPOCHS = 10
SEED = 42

# Seed every RNG this notebook draws from (text shuffling via `random`,
# weight initialisation via `torch`) so Restart & Run All is reproducible.
random.seed(SEED)
torch.manual_seed(SEED)


In [2]:
# Interactive `Trainer?` signature lookup removed from the final notebook: it was
# leftover debugging, and `?` is IPython-only syntax. The abridged signature it
# printed is kept below for reference.

Init signature:
Trainer(
    *,
    accelerator: Union[str, pytorch_lightning.accelerators.accelerator.Accelerator] = 'auto',
    strategy: Union[str, pytorch_lightning.strategies.strategy.Strategy] = 'auto',
    devices: Union[List[int], str, int] = 'auto',
    num_nodes: int = 1,

In [26]:
def get_samples(tokenized_texts, window_size, texts_count, vocab=None):
    """Yield skip-gram (central, context) index pairs from tokenized texts.

    For every token, each neighbour within ``window_size`` positions on either
    side (excluding the token itself) produces one pair. Each index is wrapped
    in a 1-element ``torch.LongTensor``, so default DataLoader batching yields
    tensors of shape (batch, 1).

    Args:
        tokenized_texts: iterable of token sequences.
        window_size: maximum distance between the central and context tokens.
        texts_count: stop after this many texts; a falsy value (None or 0)
            means "no limit".
        vocab: object exposing ``get_index(token)``. When None, falls back to
            the notebook-global ``vocabulary`` (the original behaviour).

    Yields:
        ``(central, context)`` tuples of shape-(1,) LongTensors.
    """
    lookup = vocabulary if vocab is None else vocab
    for text_num, tokens in enumerate(tokenized_texts):
        if texts_count and text_num >= texts_count:
            break
        # Map every token to its index once, instead of once per window position.
        indices = [lookup.get_index(token) for token in tokens]
        for i, central_word in enumerate(indices):
            for delta in range(-window_size, window_size + 1):
                j = i + delta
                if delta == 0 or not 0 <= j < len(indices):
                    continue
                yield (torch.LongTensor([central_word]),
                       torch.LongTensor([indices[j]]))


def get_samples_cycle(tokenized_texts, window_size, texts_count, vocab=None):
    """Yield ``get_samples`` pairs endlessly, restarting from the first text.

    If a complete pass yields no pairs, the generator stops instead of spinning
    forever without yielding. This covers an empty corpus, a corpus of only
    single-token texts, or a one-shot iterator already exhausted by a previous
    pass. Pass a re-iterable container (e.g. a list) to get a true cycle.
    """
    while True:
        produced = False
        for sample in get_samples(tokenized_texts, window_size, texts_count, vocab):
            produced = True
            yield sample
        if not produced:
            return


class Word2VecDataset(Dataset):
    """Map-style dataset that materialises all skip-gram pairs in memory."""

    def __init__(self, tokenized_texts, vocabulary, window_size=2, texts_count=100000):
        self.vocabulary = vocabulary
        self.window_size = window_size
        self.texts_count = texts_count
        # Use the vocabulary passed in rather than the notebook global.
        self.samples = list(get_samples(tokenized_texts, window_size, texts_count, vocabulary))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        return self.samples[index]


class Word2VecIterableDataset(IterableDataset):
    """Streaming dataset that cycles over the corpus' skip-gram pairs.

    The stream restarts from the first text whenever it is exhausted, so the
    consumer (e.g. the Trainer's ``limit_*_batches``) decides when to stop.
    NOTE: with ``DataLoader(num_workers > 0)`` each worker iterates its own copy
    of this dataset, so the stream would be duplicated across workers.
    """

    def __init__(self, tokenized_texts, vocabulary, window_size=2, texts_count=None):
        self.tokenized_texts = tokenized_texts
        self.vocabulary = vocabulary
        self.window_size = window_size
        self.texts_count = texts_count

    def __iter__(self):
        return get_samples_cycle(
            self.tokenized_texts, self.window_size, self.texts_count, self.vocabulary
        )

class SkipGramModel(LightningModule):
    """Skip-gram word2vec model: predicts a context word from a central word.

    The central-word index is embedded, projected back to vocabulary size by a
    linear layer, and scored against the context-word index with cross-entropy.
    """

    def __init__(self, vocab_size, embedding_dim=128):
        """
        Args:
            vocab_size: number of embedding rows and of output logits.
            embedding_dim: size of each word embedding.
        """
        super().__init__()

        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.out_layer = nn.Linear(embedding_dim, vocab_size)
        self.loss = nn.CrossEntropyLoss()
        # Step outputs collected by the *_batch_end hooks. Each list is reduced
        # and then cleared by the matching *_epoch_end hook.
        self.train_outputs = []  # not used by any hook; kept for compatibility
        self.val_outputs = []
        self.test_outputs = []

    def forward(self, centrals, contexts):
        """Return the mean cross-entropy loss for a batch of (central, context) pairs.

        Assumes ``centrals`` and ``contexts`` are LongTensors of shape (batch, 1),
        as produced by batching the 1-element samples of the Word2Vec datasets.
        """
        projections = self.embeddings(centrals)   # (batch, 1, dim)
        logits = self.out_layer(projections)      # (batch, 1, vocab)
        # CrossEntropyLoss expects the class dimension at axis 1.
        logits = logits.transpose(1, 2)           # (batch, vocab, 1)
        return self.loss(logits, contexts)

    def training_step(self, batch, batch_nb):
        loss = self(*batch)
        self.log("train_loss", loss)
        return {'loss': loss}

    def validation_step(self, batch, batch_nb):
        loss = self(*batch)
        self.log("val_loss", loss)
        return {'val_loss': loss}

    def test_step(self, batch, batch_nb):
        # Reuse the computed loss instead of running the forward pass a second time.
        loss = self(*batch)
        self.log("test_loss", loss)
        return {'test_loss': loss}

    @staticmethod
    def _mean_and_clear(outputs, key):
        """Average ``outputs[i][key]`` across the collected steps, then empty ``outputs``.

        Returns None when no steps were collected, because ``torch.stack([])`` would raise.
        """
        if not outputs:
            return None
        avg = torch.stack([x[key] for x in outputs]).mean()
        # Clear, so the next validation/test run averages only its own batches.
        outputs.clear()
        return avg

    def on_validation_batch_end(
        self,
        outputs: Union[torch.Tensor, Mapping[str, Any], NoneType],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self.val_outputs.append(outputs)

    def on_validation_epoch_end(self):
        """Log the mean validation loss of the run that just finished."""
        avg_loss = self._mean_and_clear(self.val_outputs, 'val_loss')
        if avg_loss is None:
            return None
        tensorboard_logs = {'val_loss': avg_loss}
        self.log("val_loss_epoch", avg_loss, on_step=False, on_epoch=True)
        # NOTE(review): recent Lightning versions ignore this hook's return value;
        # it is kept for backward compatibility.
        return {'val_loss': avg_loss, 'progress_bar': tensorboard_logs}

    def on_test_batch_end(
        self,
        outputs: Union[torch.Tensor, Mapping[str, Any], NoneType],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self.test_outputs.append(outputs)

    def on_test_epoch_end(self):
        """Log the mean test loss of the run that just finished."""
        avg_loss = self._mean_and_clear(self.test_outputs, 'test_loss')
        if avg_loss is None:
            return None
        tensorboard_logs = {'test_loss': avg_loss}
        self.log("test_loss_epoch", avg_loss, on_step=False, on_epoch=True)
        return {'test_loss': avg_loss, 'progress_bar': tensorboard_logs}

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-2)
        return [optimizer]


In [3]:
# Load the preprocessed corpus artifacts produced by the preparation step.
# NOTE: pickle.load can execute arbitrary code, so only open trusted files here.
with open("../data/prepared.pkl", "rb") as handle:
    prepared = pkl.load(handle)

vocabulary, texts, contexts, test_texts = (
    prepared[key] for key in ("vocabulary", "texts", "contexts", "test_texts")
)
del prepared  # drop the container; the four unpacked names are all that is needed

In [27]:
from torch.utils.data import DataLoader, RandomSampler


def _shuffled_loader(token_lists):
    """Shuffle ``token_lists`` in place and wrap it in a streaming skip-gram DataLoader."""
    random.shuffle(token_lists)
    dataset = Word2VecIterableDataset(token_lists, vocabulary)
    return dataset, DataLoader(dataset, batch_size=BATCH_SIZE)


# Each stream cycles through its texts repeatedly; the Trainer's
# limit_*_batches settings decide how much of it is consumed.
train_data, train_loader = _shuffled_loader(texts)
val_data, val_loader = _shuffled_loader(test_texts)


In [43]:

model = SkipGramModel(vocabulary.size)

# DEBUG: a 2-batch smoke run that validates after every training batch.
# Otherwise: the full-run batch limits and validation interval.
if DEBUG:
    train_batch_limit, val_batch_limit, val_interval = 2, 2, 1
else:
    train_batch_limit, val_batch_limit, val_interval = 40000, 500, 2000

# Stop once val_loss has failed to improve for 5 consecutive validation runs.
early_stop_callback = EarlyStopping(
    monitor="val_loss", mode="min", min_delta=0.0, patience=5, verbose=True,
)
# Keep the 3 lowest-val_loss checkpoints plus the most recent one.
ckpt_callback = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    dirpath="ckpt",
    filename='{epoch}-{val_loss:.2f}',
    save_top_k=3,
    save_last=True,
)
trainer = Trainer(
    max_epochs=EPOCHS,
    callbacks=[early_stop_callback, ckpt_callback],
    limit_train_batches=train_batch_limit,
    limit_val_batches=val_batch_limit,
    val_check_interval=val_interval,
)
trainer.fit(model, train_loader, val_loader)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(val_check_interval=1)` was configured so validation will run after every batch.

  | Name       | Type             | Params
------------------------------------------------
0 | embeddings | Embedding        | 9.1 M 
1 | out_layer  | Linear           | 9.2 M 
2 | loss       | CrossEntropyLoss | 0     
------------------------------------------------
18.3 M    Trainable params
0         Non-trainable params
18.3 M    Total params
73.179    Total estimated model params size (MB)


Epoch 0:  50%|██████████████████████████████████                                  | 1/2 [00:00<00:00,  7.78it/s, v_num=20]
Validation: |                                                                                       | 0/? [00:00<?, ?it/s][A
Validation:   0%|                                                                                   | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                      | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:  50%|███████████████████████████████                               | 1/2 [00:00<00:00, 21.60it/s][A
Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 21.18it/s][A
Epoch 0: 100%|████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.59it/s, v_num=20][A
Validation: |                                                                                       | 0/? [00:00<?, ?it/s

Metric val_loss improved. New best score: 11.239


Epoch 1:  50%|██████████████████████████████████                                  | 1/2 [00:00<00:00,  8.25it/s, v_num=20]
Validation: |                                                                                       | 0/? [00:00<?, ?it/s][A
Validation:   0%|                                                                                   | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                      | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:  50%|███████████████████████████████                               | 1/2 [00:00<00:00, 21.79it/s][A
Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 21.43it/s][A
Epoch 1: 100%|████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.71it/s, v_num=20][A
Validation: |                                                                                       | 0/? [00:00<?, ?it/s

Metric val_loss improved by 0.108 >= min_delta = 0.0. New best score: 11.131


Epoch 2:  50%|██████████████████████████████████                                  | 1/2 [00:00<00:00,  8.01it/s, v_num=20]
Validation: |                                                                                       | 0/? [00:00<?, ?it/s][A
Validation:   0%|                                                                                   | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                      | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:  50%|███████████████████████████████                               | 1/2 [00:00<00:00, 20.73it/s][A
Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 20.85it/s][A
Epoch 2: 100%|████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.64it/s, v_num=20][A
Validation: |                                                                                       | 0/? [00:00<?, ?it/s

Metric val_loss improved by 0.091 >= min_delta = 0.0. New best score: 11.040


Epoch 3:  50%|██████████████████████████████████                                  | 1/2 [00:00<00:00,  6.72it/s, v_num=20]
Validation: |                                                                                       | 0/? [00:00<?, ?it/s][A
Validation:   0%|                                                                                   | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                      | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:  50%|███████████████████████████████                               | 1/2 [00:00<00:00, 21.99it/s][A
Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 21.58it/s][A
Epoch 3: 100%|████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.19it/s, v_num=20][A
Validation: |                                                                                       | 0/? [00:00<?, ?it/s

Metric val_loss improved by 0.015 >= min_delta = 0.0. New best score: 11.025


Epoch 4:  50%|██████████████████████████████████                                  | 1/2 [00:00<00:00,  8.05it/s, v_num=20]
Validation: |                                                                                       | 0/? [00:00<?, ?it/s][A
Validation:   0%|                                                                                   | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                      | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:  50%|███████████████████████████████                               | 1/2 [00:00<00:00, 21.46it/s][A
Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 20.93it/s][A
Epoch 4: 100%|████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.67it/s, v_num=20][A
Validation: |                                                                                       | 0/? [00:00<?, ?it/s

Monitored metric val_loss did not improve in the last 5 records. Best score: 11.025. Signaling Trainer to stop.


Epoch 8: 100%|████████████████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  0.93it/s, v_num=20]
