In [125]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [126]:
import sys
import os
import torch
from functools import partial
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

import torch.nn as nn
from embedding import CustomEmbedding
from transformers import EncoderDecoderTransformer
from utils import padding_collate_fn

from generate_data import RandomIntegerDataset

In [127]:
n_real_tokens = 10
PAD_TOKEN_IDX = n_real_tokens
SOS_TOKEN_IDX = n_real_tokens + 1
EOS_TOKEN_IDX = n_real_tokens + 2
vocab_size = n_real_tokens + 3
D_MODEL = 16

embeddings = CustomEmbedding(vocab_size = vocab_size + 3, d_model = D_MODEL) # 3 = PAD, SOS, EOS

indices = torch.tensor([1,9])

# print(embeddings.embeddings.weight)
print(embeddings.embeddings(indices))

tensor([[ 0.4314, -0.0769,  0.4303,  0.8248,  2.1725, -1.6138,  0.3383, -1.9118,
          0.7193,  0.4160, -2.4166, -0.0286, -0.6437,  1.1062,  1.5996, -1.1645],
        [-1.8077, -1.1667, -2.4940,  0.5303,  0.3790,  0.3417,  0.5226, -2.0415,
         -1.0051,  0.7987,  1.4476,  0.8875, -0.3965,  0.7148, -0.3149,  1.0780]],
       grad_fn=<EmbeddingBackward0>)


In [128]:
MAX_CONTEXT_WINDOW = 50

BATCH_SIZE = 32
MIN_SEQ_LEN = 2
MAX_SEQ_LEN = min(10, MAX_CONTEXT_WINDOW)

rand_ds = RandomIntegerDataset(2, 10, 10000, [i for i in range(10)])
dataloader = DataLoader(rand_ds, batch_size = BATCH_SIZE, shuffle = True, collate_fn = partial(padding_collate_fn, pad_token_idx = PAD_TOKEN_IDX))

In [129]:
input, label = next(iter(dataloader))
print(input[0])
print(input[1])
print(label)

tensor([[ 3,  3,  1, 10, 10, 10, 10, 10, 10, 10],
        [ 7,  4,  9,  6,  4,  3,  8,  0, 10, 10],
        [ 5,  2,  2,  9,  5,  8,  2, 10, 10, 10],
        [ 9,  9,  9,  0,  4,  6,  1,  8,  0,  1],
        [ 5,  9,  3,  7,  0,  6,  4,  3,  9,  9],
        [ 4,  3,  0,  5,  3,  3,  8,  3,  4, 10],
        [ 7,  4,  6,  6,  2,  4, 10, 10, 10, 10],
        [ 6,  6,  9,  9,  7,  8,  8, 10, 10, 10],
        [ 7,  7,  7,  4,  0,  4, 10, 10, 10, 10],
        [ 7,  4,  0,  0,  2,  0,  5, 10, 10, 10],
        [ 8,  3,  9,  5,  6, 10, 10, 10, 10, 10],
        [ 7,  6,  6,  0,  1,  3,  4,  3,  7,  9],
        [ 4,  3,  3,  3,  4,  9,  0,  1,  7, 10],
        [ 6,  9,  8,  2,  5,  8, 10, 10, 10, 10],
        [ 7,  8,  5,  9,  7,  4,  5,  9,  2, 10],
        [ 8,  2,  7,  2,  3,  4,  8,  0, 10, 10],
        [ 8,  0,  4,  9,  8,  4,  4,  2,  9,  0],
        [ 4,  8,  0,  0,  9,  8,  0,  8, 10, 10],
        [ 0,  9,  9,  6,  7,  0,  8, 10, 10, 10],
        [ 2,  8,  5,  7, 10, 10, 10, 10, 10, 10],


In [130]:
loss_fn = nn.CrossEntropyLoss(ignore_index = PAD_TOKEN_IDX, reduction = 'sum')

model = EncoderDecoderTransformer(
                    embeddings = embeddings, 
                    vocab_size = vocab_size, 
                    d_model = D_MODEL, 
                    num_attention_heads = 4, 
                    num_encoder_layers = 1, 
                    num_decoder_layers = 1, 
                    dim_feedforward = 32, 
                    max_context_window = MAX_CONTEXT_WINDOW)

optim = torch.optim.SGD(params = model.parameters(), lr = 1e-3, momentum = 0.9, weight_decay = 1e-4)

In [131]:
def run_train_epoch(dataloader: DataLoader, model: nn.Module, loss_fn: nn.Module, optimizer: torch.optim.Optimizer, calculate_sequence_accuracy: bool = False, calculate_token_accuracy: bool = False):
    """

    """
    model.train()

    epoch_loss = 0.0
    total_correct_sequences = 0
    total_correct_tokens = 0

    for i, ((source, target), label) in tqdm(enumerate(dataloader)):

        # FORWARD
        pred_logits = model(source, target)

        # pred_logits.shape: [batch_size, seq_len, vocab_size]
        # label.shape: [batch_size, seq_len]

        # CrossEntropyLoss (loss_fn) only takes 2D predictions (n_batch * seq_len, vocab_size) and 1D labels (n_batch * seq_len)
        batch_loss = loss_fn(pred_logits.view(-1, vocab_size), label.view(-1))

        # LOG
        with torch.no_grad():
            epoch_loss += batch_loss.item()

            if calculate_sequence_accuracy:
                predictions = torch.argmax(pred_logits, dim = -1) # predictions.shape: [batch_size, seq_len]
                match_matrix = torch.eq(predictions, label)
                num_correct_sequences = torch.all(match_matrix, dim = 1).sum()
                total_correct_sequences += num_correct_sequences

            if calculate_token_accuracy:
                ...        

        # BACKWARD
        batch_loss.backward()

        # OPTIMIZE
        optimizer.step()
        optimizer.zero_grad()

    average_epoch_loss = epoch_loss / len(dataloader.dataset)
    if calculate_sequence_accuracy:
        average_epoch_sequence_accuracy = total_correct_sequences.item() / len(dataloader.dataset)
    if calculate_token_accuracy:
        ...

    return average_epoch_loss, average_epoch_sequence_accuracy

def run_validation_loop(dataloader: DataLoader, model: nn.Module, loss_fn: nn.Module):
    
    model.eval()

In [132]:
EPOCHS = 10

training_losses = list()
training_accuracies = list()
validation_losses = list()
validation_accuracies = list()

for i in range(EPOCHS):
    # print(f'Running epoch {i+1}...')

    training_loss, training_accuracy = run_train_epoch(dataloader, model, loss_fn, optim, calculate_sequence_accuracy = True)
    training_losses.append(training_loss)
    training_accuracies.append(training_accuracy)

    # validation_loss = run_validation_loop(dataloader, model, loss_fn)
    # validation_losses.append(validation_loss)

print(training_losses)
print(training_accuracies)

313it [00:02, 134.78it/s]
313it [00:02, 140.12it/s]
313it [00:02, 140.05it/s]
313it [00:02, 140.23it/s]
313it [00:02, 136.18it/s]
313it [00:02, 136.20it/s]
313it [00:02, 137.01it/s]
313it [00:02, 130.64it/s]
313it [00:02, 136.09it/s]
313it [00:02, 139.23it/s]

[5.708299437332153, 2.954068369293213, 2.5684640590667724, 2.4651227336883546, 2.2263846506118776, 2.1493445888519287, 1.8944116312026977, 1.7958404651641846, 1.655647212409973, 1.6988458086013793]
[0.0008, 0.0054, 0.0103, 0.011, 0.0149, 0.0185, 0.0227, 0.0234, 0.0276, 0.0258]



