In [405]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [503]:
import sys
import os
import torch
from functools import partial
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

import torch.nn as nn
from embedding import CustomEmbedding
from transformers import EncoderDecoderTransformer
from utils import padding_collate_fn
import numpy as np

from generate_data import RandomIntegerDataset

In [473]:
n_real_tokens = 10
PAD_TOKEN_IDX = n_real_tokens
SOS_TOKEN_IDX = n_real_tokens + 1
EOS_TOKEN_IDX = n_real_tokens + 2
vocab_size = n_real_tokens + 3
D_MODEL = 32

embeddings = CustomEmbedding(vocab_size, d_model = D_MODEL) # 3 = PAD, SOS, EOS

indices = torch.tensor([1,9])

# print(embeddings.embeddings.weight)
print(embeddings.embeddings(indices))

tensor([[ 1.4521,  0.7097,  1.0505, -0.6166, -1.0084,  1.5881, -0.4116,  0.0634,
         -1.4550, -0.4137,  0.1941, -0.7222, -0.7129,  0.9572, -0.2055, -0.3407,
          0.1691, -0.4828,  1.8598,  0.3019, -0.8803, -0.1361, -0.8615,  0.1109,
         -0.4062, -0.6712,  0.2429, -1.2812, -0.1855, -0.2590, -2.9673,  0.6972],
        [ 2.0117, -0.3468, -0.2348,  0.3069,  0.3879, -0.7460,  0.8612,  1.1340,
         -0.4244, -2.9503,  1.7311, -1.0154,  1.6944, -0.3789,  0.9265, -0.7821,
          0.8744,  0.1687,  1.2608, -1.3116,  0.5958,  0.5129,  0.4727,  0.2873,
          0.6193,  2.2453,  0.2456,  1.1684,  0.8631, -0.5087,  0.4214,  2.7316]],
       grad_fn=<EmbeddingBackward0>)


In [None]:
MAX_CONTEXT_WINDOW = 50

BATCH_SIZE = 32
MIN_SEQ_LEN = 2
MAX_SEQ_LEN = min(10, MAX_CONTEXT_WINDOW)

NUM_TRAINING_SEQUENCES = 10000
NUM_VALIDATION_SEQUENCES = 1000

VOCAB = [i for i in range(n_real_tokens)] # does not include SOS, EOS, PAD

VOCAB_MAP = dict()

for i, token in enumerate(VOCAB):
    VOCAB_MAP[i] = token
VOCAB_MAP[len(VOCAB_MAP)] = PAD_TOKEN_IDX
VOCAB_MAP[len(VOCAB_MAP) + 1] = SOS_TOKEN_IDX
VOCAB_MAP[len(VOCAB_MAP) + 2] = EOS_TOKEN_IDX

train_rand_ds = RandomIntegerDataset(MIN_SEQ_LEN, MAX_SEQ_LEN, NUM_TRAINING_SEQUENCES, VOCAB)
train_dataloader = DataLoader(train_rand_ds, batch_size = BATCH_SIZE, shuffle = True, collate_fn = partial(padding_collate_fn, pad_token_idx = PAD_TOKEN_IDX))

val_rand_ds = RandomIntegerDataset(MIN_SEQ_LEN, MAX_SEQ_LEN, NUM_VALIDATION_SEQUENCES, VOCAB)
val_dataloader = DataLoader(val_rand_ds, batch_size = BATCH_SIZE, collate_fn = partial(padding_collate_fn, pad_token_idx = PAD_TOKEN_IDX))

In [498]:
input, label = next(iter(train_dataloader))
print(input[0])
print(input[1])
print(label)

tensor([[ 8,  7,  9,  9,  4,  8,  4,  1,  5,  8],
        [ 6,  8,  7,  8,  5,  2, 10, 10, 10, 10],
        [ 7,  5,  8, 10, 10, 10, 10, 10, 10, 10],
        [ 5,  8,  7,  0,  7,  3,  0,  1,  3, 10],
        [ 9,  2,  6,  4,  5,  3,  9, 10, 10, 10],
        [ 1,  2,  6,  4, 10, 10, 10, 10, 10, 10],
        [ 0,  7,  0,  0,  6,  2,  8,  1,  1,  0],
        [ 0,  0,  9,  3, 10, 10, 10, 10, 10, 10],
        [ 1,  1,  6,  7,  7,  6,  0,  4,  4,  5],
        [ 7,  1,  3,  0,  2,  9,  0, 10, 10, 10],
        [ 4,  3,  8,  5,  1, 10, 10, 10, 10, 10],
        [ 0,  6,  1,  0,  6,  9,  6,  9, 10, 10],
        [ 9,  4,  9,  5, 10, 10, 10, 10, 10, 10],
        [ 2,  5,  3,  9,  6,  9,  4,  6,  2,  4],
        [ 2,  2,  3, 10, 10, 10, 10, 10, 10, 10],
        [ 6,  9,  9,  6,  5,  2,  8, 10, 10, 10],
        [ 4,  1,  7,  9,  8,  5, 10, 10, 10, 10],
        [ 6,  8,  9,  3,  2,  9,  2,  6,  1,  0],
        [ 9,  5, 10, 10, 10, 10, 10, 10, 10, 10],
        [ 0,  0, 10, 10, 10, 10, 10, 10, 10, 10],


In [499]:
loss_fn = nn.CrossEntropyLoss(ignore_index = PAD_TOKEN_IDX, reduction = 'sum')

model = EncoderDecoderTransformer(
                    embeddings = embeddings, 
                    vocab_size = vocab_size, 
                    d_model = D_MODEL, 
                    num_attention_heads = 4, 
                    num_encoder_layers = 1, 
                    num_decoder_layers = 1, 
                    dim_feedforward = 32, 
                    dropout = 0.0,
                    max_context_window = MAX_CONTEXT_WINDOW,
                    use_pre_lnorm = True)

optim = torch.optim.SGD(params = model.parameters(), lr = 1e-3, momentum = 0.9, weight_decay = 1e-4)

In [500]:
def greedy_decode(source: torch.Tensor, model: nn.Module) -> torch.Tensor:
    """
    """
    batch_size = source.size(dim = 0)

    encoder_output, source_pad_mask = model.encode(source)

    # target will contain num_batch sequences of indices that are the predicted next-words for each batch element
    target = torch.full((batch_size, 1), SOS_TOKEN_IDX)

    while (target[:, -1] == PAD_TOKEN_IDX).sum() < batch_size and target.size(dim = 1) <= MAX_CONTEXT_WINDOW:

        decoder_output, _ = model.decode(target, encoder_output, source_pad_mask)
        pred_logits = model.project_into_vocab(decoder_output)

        # pred_logits.shape: [batch_size, seq_len, vocab_size]

        last_row_pred_logits = pred_logits[:, -1, :]

        # last_row_pred_logits.shape == [batch_size, vocab_size]

        predictions = torch.argmax(last_row_pred_logits, dim = -1)

        # target.shape: [batch_size, num_loops_complete - 1]
        # predictions.shape: [batch_size]

        target = torch.concat((target, predictions.reshape(-1, 1)), dim = 1)

    return target

In [510]:
def run_train_epoch(dataloader: DataLoader, model: nn.Module, loss_fn: nn.Module, optimizer: torch.optim.Optimizer, calculate_sequence_accuracy: bool = False, calculate_token_accuracy: bool = False):
    """
    """
    model.train()

    num_sequences = len(dataloader.dataset)
    num_tokens = 0

    epoch_loss = 0.0
    total_correct_sequences = 0
    total_correct_tokens = 0

    for (source, target), label in tqdm(dataloader):

        # FORWARD
        pred_logits = model(source, target)

        # pred_logits.shape: [batch_size, seq_len, vocab_size]
        # label.shape: [batch_size, seq_len]

        # CrossEntropyLoss (loss_fn) only takes 2D predictions (n_batch * seq_len, vocab_size) and 1D labels (n_batch * seq_len)
        batch_loss = loss_fn(pred_logits.view(-1, pred_logits.size(-1)), label.view(-1))

        # LOG
        with torch.no_grad():
            epoch_loss += batch_loss.item()

            predictions = torch.argmax(pred_logits, dim = -1) # predictions.shape: [batch_size, seq_len]
            match_matrix = torch.eq(predictions, label)

            if calculate_sequence_accuracy:
                num_correct_sequences = torch.all(match_matrix, dim = 1).sum()
                total_correct_sequences += num_correct_sequences.item()

            if calculate_token_accuracy:
                num_correct_tokens = match_matrix.sum()      
                total_correct_tokens += num_correct_tokens.item()

                num_tokens += torch.numel(label)

        # BACKWARD
        batch_loss.backward()

        # OPTIMIZE
        optimizer.step()
        optimizer.zero_grad()

    average_epoch_loss = epoch_loss / num_sequences
    average_epoch_sequence_accuracy = total_correct_sequences / num_sequences if calculate_sequence_accuracy else None
    average_epoch_token_accuracy = total_correct_tokens / num_tokens if calculate_token_accuracy else None

    return average_epoch_loss, average_epoch_sequence_accuracy, average_epoch_token_accuracy

def run_gold_validation_loop(dataloader: DataLoader, model: nn.Module, loss_fn: nn.Module, calculate_sequence_accuracy: bool = False, calculate_token_accuracy: bool = False):
    """
    """
    model.eval()

    num_sequences = len(dataloader.dataset)
    num_tokens = 0

    epoch_loss = 0.0
    total_correct_sequences = 0
    total_correct_tokens = 0

    with torch.no_grad():
        
        for (source, target), label in tqdm(dataloader):
            
            # FORWARD
            pred_logits = model(source, target)
            batch_loss = loss_fn(pred_logits.view(-1, pred_logits.size(-1)), label.view(-1))

            # LOG
            epoch_loss += batch_loss.item()

            predictions = torch.argmax(pred_logits, dim = -1) # predictions.shape: [batch_size, seq_len]
            match_matrix = torch.eq(predictions, label)

            if calculate_sequence_accuracy:
                num_correct_sequences = torch.all(match_matrix, dim = 1).sum()
                total_correct_sequences += num_correct_sequences.item()

            if calculate_token_accuracy:
                num_correct_tokens = match_matrix.sum()      
                total_correct_tokens += num_correct_tokens.item()

                num_tokens += torch.numel(label)

    average_epoch_loss = epoch_loss / num_sequences
    average_epoch_sequence_accuracy = total_correct_sequences / num_sequences if calculate_sequence_accuracy else None
    average_epoch_token_accuracy = total_correct_tokens / num_tokens if calculate_token_accuracy else None

    return average_epoch_loss, average_epoch_sequence_accuracy, average_epoch_token_accuracy

def run_autoregressive_validation_loop(dataloader: DataLoader, model: nn.Module, loss_fn: nn.Module):
    """
    """
    model.eval()

    epoch_loss = 0.0

    with torch.no_grad():
        
        for (source, _), label in tqdm(dataloader):

            # FORWARD
            pred_indices = greedy_decode(source, model)
            np_pred_indices = pred_indices.numpy().copy()

            token_values = np.array(list(VOCAB_MAP.values()))
            predicted_source_tokens = token_values[source.numpy().copy()]
            predicted_target_tokens = token_values[np_pred_indices]

            for s, t in zip(predicted_source_tokens, predicted_target_tokens):
                print(f'Source: {s} | Predicted Target: {t}', end = '\n\n')
        
            # batch_loss = loss_fn(..., ...)

            # LOG
            # epoch_loss += batch_loss.item()

            break


In [511]:
run_autoregressive_validation_loop(val_dataloader, model, loss_fn)

  0%|          | 0/32 [00:00<?, ?it/s]

Source: ['2' '5' '1' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>'] | Predicted Target: ['<SOS>' '7' '3' '7' '8' '<EOS>' '<SOS>' '7' '1' '5' '3' '3' '3' '3' '3'
 '5' '<SOS>' '1' '<EOS>' '<SOS>' '1' '<EOS>' '0' '0' '0' '0' '0' '0' '0'
 '0' '0' '0' '2' '9' '5' '3' '3' '3' '3' '7' '8' '5' '3' '7' '3' '7' '1'
 '5' '3' '7' '1']

Source: ['9' '7' '7' '4' '1' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>'] | Predicted Target: ['<SOS>' '1' '<EOS>' '0' '0' '0' '2' '7' '1' '5' '3' '3' '3' '3' '5'
 '<SOS>' '1' '<EOS>' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0'
 '0' '9' '5' '7' '1' '0' '0' '0' '0' '0' '0' '0' '0' '0' '2' '7' '1' '5'
 '<SOS>' '5']

Source: ['8' '8' '7' '4' '3' '6' '8' '3' '5' '<PAD>'] | Predicted Target: ['<SOS>' '0' '0' '0' '0' '0' '0' '2' '7' '1' '<EOS>' '3' '3' '3' '3' '3'
 '3' '3' '3' '3' '7' '1' '<EOS>' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0'
 '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0']

Source: ['1' '6' '<PAD>' '<PAD>' '<PAD>' '<PAD>' 




In [412]:
EPOCHS = 10

training_losses = list()
training_sequence_accuracies = list()
training_token_accuracies = list()

gold_validation_losses = list()
gold_validation_sequence_accuracies = list()
gold_validation_token_accuracies = list()

for i in range(EPOCHS):
    # print(f'Running epoch {i+1}...')

    training_loss, training_sequence_accuracy, training_token_accuracy = run_train_epoch(train_dataloader, model, loss_fn, optim, calculate_sequence_accuracy = True, calculate_token_accuracy = True)

    training_losses.append(training_loss)
    training_sequence_accuracies.append(training_sequence_accuracy)
    training_token_accuracies.append(training_token_accuracy)

    gold_val_loss, gold_val_sequence_accuracy, gold_val_token_accuracy = run_gold_validation_loop(val_dataloader, model, loss_fn, calculate_sequence_accuracy = True, calculate_token_accuracy = True)
    gold_validation_losses.append(gold_val_loss)
    gold_validation_sequence_accuracies.append(gold_val_sequence_accuracy)
    gold_validation_token_accuracies.append(gold_val_token_accuracy)

print(training_losses)
print(training_sequence_accuracies)
print(training_token_accuracies)

print()

print(gold_validation_losses)
print(gold_validation_sequence_accuracies)
print(gold_validation_token_accuracies)

100%|██████████| 313/313 [00:02<00:00, 113.69it/s]
100%|██████████| 32/32 [00:00<00:00, 312.40it/s]
100%|██████████| 313/313 [00:02<00:00, 116.41it/s]
100%|██████████| 32/32 [00:00<00:00, 304.74it/s]
100%|██████████| 313/313 [00:02<00:00, 117.62it/s]
100%|██████████| 32/32 [00:00<00:00, 303.17it/s]
100%|██████████| 313/313 [00:02<00:00, 118.31it/s]
100%|██████████| 32/32 [00:00<00:00, 315.31it/s]
100%|██████████| 313/313 [00:02<00:00, 118.65it/s]
100%|██████████| 32/32 [00:00<00:00, 321.20it/s]
100%|██████████| 313/313 [00:02<00:00, 123.55it/s]
100%|██████████| 32/32 [00:00<00:00, 317.47it/s]
100%|██████████| 313/313 [00:02<00:00, 123.65it/s]
100%|██████████| 32/32 [00:00<00:00, 316.87it/s]
100%|██████████| 313/313 [00:02<00:00, 117.97it/s]
100%|██████████| 32/32 [00:00<00:00, 292.90it/s]
100%|██████████| 313/313 [00:02<00:00, 117.72it/s]
100%|██████████| 32/32 [00:00<00:00, 309.24it/s]
100%|██████████| 313/313 [00:02<00:00, 121.90it/s]
100%|██████████| 32/32 [00:00<00:00, 309.78it/s]

[3.433036536407471, 1.772785053062439, 1.782198503112793, 1.4157891374588012, 1.3334199272155762, 1.1036983204126358, 1.186024510860443, 0.9799411701440811, 1.0434875957012177, 0.8846561305046081]
[0.008, 0.0272, 0.0326, 0.0364, 0.0439, 0.0495, 0.0478, 0.0531, 0.0524, 0.0603]
[0.514817241881462, 0.5753443375601224, 0.5778949861536219, 0.5872469783020242, 0.5911029304563348, 0.5977832871518156, 0.5958861681970559, 0.6020326387877022, 0.5998452294246176, 0.6056445739257101]

[2.105001425743103, 1.9623536720275878, 4.119647649765015, 1.3570107583999633, 1.1826796026229858, 0.7914193334579468, 0.962474100112915, 1.753965274810791, 0.7659035215377807, 0.7691247215270997]
[0.008, 0.015, 0.001, 0.042, 0.067, 0.04, 0.051, 0.04, 0.064, 0.061]
[0.5802334062727936, 0.587709700948213, 0.5289934354485777, 0.598741794310722, 0.6091356673960613, 0.6198942377826404, 0.6146973012399708, 0.5888037928519329, 0.6227206418672502, 0.6208059810357404]



