In [184]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [185]:
import sys
import os
import torch
from functools import partial
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

import torch.nn as nn
from embedding import CustomEmbedding
from transformer import EncoderDecoderTransformer
from utils import padding_collate_fn
import numpy as np

from generate_data import RandomIntegerDataset

In [186]:
n_real_tokens = 10
PAD_TOKEN_IDX = n_real_tokens
SOS_TOKEN_IDX = n_real_tokens + 1
EOS_TOKEN_IDX = n_real_tokens + 2
vocab_size = n_real_tokens + 3
D_MODEL = 64

embeddings = CustomEmbedding(vocab_size, d_model = D_MODEL) # 3 = PAD, SOS, EOS

indices = torch.tensor([1,9])

# print(embeddings.embeddings.weight)
print(embeddings.embeddings(indices))

tensor([[-0.8831,  1.6792, -0.1320, -0.9545, -0.0167, -0.2237, -1.3355, -0.6046,
         -0.3862, -0.5330,  1.3366, -0.0794, -0.4472,  0.8060,  0.4868,  1.2126,
         -0.7660, -0.2120, -0.5290,  0.7862,  2.3279,  0.2625,  0.1194,  0.7605,
         -1.7818, -0.1591, -0.5018, -0.0629,  0.8225, -0.3061, -0.7713, -1.2474,
         -1.0157,  1.3000, -1.0103,  0.1111,  1.6965, -1.0151,  2.0740,  0.7094,
         -1.1590,  0.8839, -0.5437, -1.1268,  0.3676, -0.9636, -0.2503, -0.8230,
         -0.3887, -0.9166,  0.1668,  1.7458,  0.1024, -1.8647,  1.9556,  2.0464,
          0.0211,  1.6812,  0.9375, -2.0898, -0.3710, -1.0631, -0.5699, -0.0828],
        [-0.7415,  0.7212,  0.2374,  0.1031, -0.6000, -0.5691,  0.1777,  0.2734,
          0.1659,  0.6873,  1.2199, -1.3455,  1.2414,  0.2652,  0.4928, -0.1385,
         -0.1114, -0.5635,  0.2973, -0.2960, -0.5593,  1.2331, -0.2966,  0.4955,
         -0.5019, -1.3238, -0.8767, -0.4242, -0.4270, -2.6206,  1.1567,  0.3940,
          0.7986, -1.3425, 

In [187]:
MAX_CONTEXT_WINDOW = 50

BATCH_SIZE = 64
MIN_SEQ_LEN = 2
MAX_SEQ_LEN = min(20, MAX_CONTEXT_WINDOW)

NUM_TRAINING_SEQUENCES = 10000
NUM_VALIDATION_SEQUENCES = 1000

VOCAB = [i for i in range(n_real_tokens)] # does not include SOS, EOS, PAD

VOCAB_MAP = dict()

for i, token in enumerate(VOCAB):
    VOCAB_MAP[i] = token
VOCAB_MAP[len(VOCAB_MAP)] = '<PAD>'
VOCAB_MAP[len(VOCAB_MAP) + 1] = '<SOS>'
VOCAB_MAP[len(VOCAB_MAP) + 2] = '<EOS>'

train_rand_ds = RandomIntegerDataset(MIN_SEQ_LEN, MAX_SEQ_LEN, NUM_TRAINING_SEQUENCES, VOCAB)
train_dataloader = DataLoader(train_rand_ds, batch_size = BATCH_SIZE, shuffle = True, collate_fn = partial(padding_collate_fn, pad_token_idx = PAD_TOKEN_IDX))

val_rand_ds = RandomIntegerDataset(MIN_SEQ_LEN, MAX_SEQ_LEN, NUM_VALIDATION_SEQUENCES, VOCAB)
val_dataloader = DataLoader(val_rand_ds, batch_size = BATCH_SIZE, collate_fn = partial(padding_collate_fn, pad_token_idx = PAD_TOKEN_IDX))

In [188]:
input, label = next(iter(train_dataloader))
print(input[0])
print(input[1])
print(label)

tensor([[ 3,  1,  5,  ...,  1, 10, 10],
        [ 0,  1,  9,  ..., 10, 10, 10],
        [ 3,  8,  4,  ..., 10, 10, 10],
        ...,
        [ 1,  6,  2,  ..., 10, 10, 10],
        [ 1,  8,  5,  ..., 10, 10, 10],
        [ 7,  6,  4,  ..., 10, 10, 10]])
tensor([[11,  0,  0,  ...,  9, 10, 10],
        [11,  0,  1,  ..., 10, 10, 10],
        [11,  3,  4,  ..., 10, 10, 10],
        ...,
        [11,  1,  2,  ..., 10, 10, 10],
        [11,  1,  5,  ..., 10, 10, 10],
        [11,  0,  1,  ..., 10, 10, 10]])
tensor([[ 0,  0,  0,  ..., 12, 10, 10],
        [ 0,  1,  1,  ..., 10, 10, 10],
        [ 3,  4,  4,  ..., 10, 10, 10],
        ...,
        [ 1,  2,  2,  ..., 10, 10, 10],
        [ 1,  5,  7,  ..., 10, 10, 10],
        [ 0,  1,  2,  ..., 10, 10, 10]])


In [189]:
loss_fn = nn.CrossEntropyLoss(ignore_index = PAD_TOKEN_IDX, reduction = 'sum')

model = EncoderDecoderTransformer(
                    embeddings = embeddings, 
                    vocab_size = vocab_size, 
                    d_model = D_MODEL, 
                    num_attention_heads = 4, 
                    num_encoder_layers = 2, 
                    num_decoder_layers = 2, 
                    dim_feedforward = 32, 
                    dropout = 0.0,
                    max_context_window = MAX_CONTEXT_WINDOW,
                    use_pre_lnorm = True)

optim = torch.optim.SGD(params = model.parameters(), lr = 1e-4, momentum = 0.9, weight_decay = 1e-4)

In [190]:
def greedy_decode(source: torch.Tensor, model: nn.Module) -> torch.Tensor:
    """
    Designed to do autoregressive inference in an Encoder-Decoder transformer.

    This greedy decoder always predicts the vocabulary token corresponding to the highest logit.

    Takes the source sequence and the Encoder-Decoder to produce a predicted
    sequence starting from <SOS>.

    Note: This function *can* handle batches of sequences.

    Args:
        source - The source sequence to be passed to the Transformer's encoder block.
        model - The Encoder-Decoder transformer over which to greedy decode

    Returns:
        target - The batch of predicted sequences corresponding to the input sources.
        target_logits - (# batch elements =) batch_size (# rows =) seq_len (# cols =) vocab-dimensional vectors, each of which
        corresponds to the set of logits on a particular inference step within a given sequence.
    """
    batch_size = source.size(dim = 0)

    encoder_output, source_pad_mask = model.encode(source)

    # target will contain num_batch sequences of indices that are the predicted next-words for each batch element
    target = torch.full((batch_size, 1), SOS_TOKEN_IDX) # target.shape: [batch_size, num_loops_complete - 1]
    target_logits = torch.zeros((batch_size, 1, vocab_size))

    finished = torch.full((batch_size, ), False)

    while not finished.all() and target.size(dim = 1) <= MAX_CONTEXT_WINDOW:

        decoder_output, _ = model.decode(target, encoder_output, source_pad_mask)
        pred_logits = model.project_into_vocab(decoder_output) # pred_logits.shape: [batch_size, seq_len, vocab_size]

        last_row_pred_logits = pred_logits[:, -1, :] # last_row_pred_logits.shape == [batch_size, vocab_size]

        # Track next-word logits for loss_fn later.
        target_logits = torch.concat((target_logits, last_row_pred_logits.unsqueeze(1)), dim = 1)

        predictions = torch.argmax(last_row_pred_logits, dim = -1) # predictions.shape: [batch_size]

        # For any finished sequences (i.e. previous EOS-producers), force their prediction from this round to be a pad.
        predictions[finished] = PAD_TOKEN_IDX

        # Mark any additional sequences that just produced an EOS as finished.
        finished |= predictions == EOS_TOKEN_IDX

        target = torch.concat((target, predictions.reshape(-1, 1)), dim = 1) # target.shape: [batch_size, num_loops_complete]

    return target, target_logits[:, 1:, :]

In [191]:
def run_train_epoch(dataloader: DataLoader, model: nn.Module, loss_fn: nn.Module, optimizer: torch.optim.Optimizer, calculate_sequence_accuracy: bool = False, calculate_token_accuracy: bool = False):
    """
    Runs one training epoch (processing the entire training dataset once).
    
    Uses Teacher Forcing to train token-to-token mapping quality without cascading errors and for parallelization.

    Args:
        dataloader - The dataloader to process the dataset in BATCH_SIZE batches
        model - The Encoder-Decoder that is being trained
        loss_fn - The loss function to calculate the model's correctness
        optimizer - The optimizer to improve the model's weights
        calculate_sequence_accuracy - A flag to mark whether sequence-level correctness should be tracked
        calculate_token_accuracy - A flag to mark whether token-level correctness should be tracked
    """
    model.train()

    num_sequences = len(dataloader.dataset)
    num_tokens = 0

    epoch_loss = 0.0
    total_correct_sequences = 0
    total_correct_tokens = 0

    for (source, target), label in tqdm(dataloader):

        # FORWARD
        pred_logits = model(source, target)

        # pred_logits.shape: [batch_size, seq_len, vocab_size]
        # label.shape: [batch_size, seq_len]

        # CrossEntropyLoss (loss_fn) only takes 2D predictions (n_batch * seq_len, vocab_size) and 1D labels (n_batch * seq_len)
        batch_loss = loss_fn(pred_logits.view(-1, pred_logits.size(-1)), label.view(-1))

        # LOG
        with torch.no_grad():
            epoch_loss += batch_loss.item()

            predictions = torch.argmax(pred_logits, dim = -1) # predictions.shape: [batch_size, seq_len]
            match_matrix = torch.eq(predictions, label)

            if calculate_sequence_accuracy:
                num_correct_sequences = torch.all(match_matrix, dim = 1).sum()
                total_correct_sequences += num_correct_sequences.item()

            if calculate_token_accuracy:
                num_correct_tokens = match_matrix.sum()      
                total_correct_tokens += num_correct_tokens.item()

                num_tokens += torch.numel(label)

        # BACKWARD
        batch_loss.backward()

        # OPTIMIZE
        optimizer.step()
        optimizer.zero_grad()

    average_epoch_loss = epoch_loss / num_sequences
    average_epoch_sequence_accuracy = total_correct_sequences / num_sequences if calculate_sequence_accuracy else None
    average_epoch_token_accuracy = total_correct_tokens / num_tokens if calculate_token_accuracy else None

    return average_epoch_loss, average_epoch_sequence_accuracy, average_epoch_token_accuracy

def run_gold_validation_loop(dataloader: DataLoader, model: nn.Module, loss_fn: nn.Module, calculate_sequence_accuracy: bool = False, calculate_token_accuracy: bool = False):
    """
    Runs one validation epoch (processing the entire validation dataset once). 

    Uses Teacher Forcing (i.e. "gold") to evaluate token-to-token mapping quality and for parallelization.

    Args:
        dataloader - The dataloader to process the dataset in BATCH_SIZE batches
        model - The Encoder-Decoder that is being trained
        loss_fn - The loss function to calculate the model's correctness
        calculate_sequence_accuracy - A flag to mark whether sequence-level correctness should be tracked
        calculate_token_accuracy - A flag to mark whether token-level correctness should be tracked
    """
    model.eval()

    num_sequences = len(dataloader.dataset)
    num_tokens = 0

    epoch_loss = 0.0
    total_correct_sequences = 0
    total_correct_tokens = 0

    with torch.no_grad():
        
        for (source, target), label in tqdm(dataloader):
            
            # FORWARD
            pred_logits = model(source, target)
            batch_loss = loss_fn(pred_logits.view(-1, pred_logits.size(-1)), label.view(-1))

            # LOG
            epoch_loss += batch_loss.item()

            predictions = torch.argmax(pred_logits, dim = -1) # predictions.shape: [batch_size, seq_len]
            match_matrix = torch.eq(predictions, label)

            if calculate_sequence_accuracy:
                num_correct_sequences = torch.all(match_matrix, dim = 1).sum()
                total_correct_sequences += num_correct_sequences.item()

            if calculate_token_accuracy:
                num_correct_tokens = match_matrix.sum()      
                total_correct_tokens += num_correct_tokens.item()

                num_tokens += torch.numel(label)

    average_epoch_loss = epoch_loss / num_sequences
    average_epoch_sequence_accuracy = total_correct_sequences / num_sequences if calculate_sequence_accuracy else None
    average_epoch_token_accuracy = total_correct_tokens / num_tokens if calculate_token_accuracy else None

    return average_epoch_loss, average_epoch_sequence_accuracy, average_epoch_token_accuracy

def run_autoregressive_validation_loop(dataloader: DataLoader, model: nn.Module):
    """
    Runs one autoregressive validation epoch (processing the entire validation dataset once). 

    Args:
        dataloader - The dataloader to process the dataset in BATCH_SIZE batches
        model - The Encoder-Decoder that is being trained
    """
    model.eval()

    correct_sequences = 0
    total_sequences = 0

    with torch.no_grad():
        
        for (source, _), label in tqdm(dataloader):

            # FORWARD
            pred_indices, pred_logits = greedy_decode(source, model)

            np_source_indices = source.numpy().copy()
            np_pred_target_indices = pred_indices.numpy().copy()

            token_values = np.array(list(VOCAB_MAP.values()))
            predicted_source_tokens = token_values[np_source_indices]
            predicted_target_tokens = token_values[np_pred_target_indices]

            for s, t in zip(predicted_source_tokens, predicted_target_tokens):
                source_end_index = np.argmax(s == '<PAD>')  if '<PAD>' in s else len(s)
                target_end_index = np.argmax(t == '<EOS>')
                if np.array_equal(np.sort(s[:source_end_index]), t[1:target_end_index]):
                    correct_sequences += 1
                else:
                    print(np.sort(s[:source_end_index]))
                    print(t[1:target_end_index])
                    print(f'{'Source:':<20} {s}\n{'Predicted Target:':<20} {t}', end = '\n\n')

            total_sequences += predicted_target_tokens.shape[0]

    return correct_sequences / total_sequences


In [192]:
EPOCHS = 10

training_losses = list()
training_sequence_accuracies = list()
training_token_accuracies = list()

gold_validation_losses = list()
gold_validation_sequence_accuracies = list()
gold_validation_token_accuracies = list()

for i in range(EPOCHS):
    # print(f'Running epoch {i+1}...')

    training_loss, training_sequence_accuracy, training_token_accuracy = run_train_epoch(train_dataloader, model, loss_fn, optim, calculate_sequence_accuracy = True, calculate_token_accuracy = True)

    training_losses.append(training_loss)
    training_sequence_accuracies.append(training_sequence_accuracy)
    training_token_accuracies.append(training_token_accuracy)

    gold_val_loss, gold_val_sequence_accuracy, gold_val_token_accuracy = run_gold_validation_loop(val_dataloader, model, loss_fn, calculate_sequence_accuracy = True, calculate_token_accuracy = True)
    
    gold_validation_losses.append(gold_val_loss)
    gold_validation_sequence_accuracies.append(gold_val_sequence_accuracy)
    gold_validation_token_accuracies.append(gold_val_token_accuracy)

print(training_losses)
print(training_sequence_accuracies)
print(training_token_accuracies)

print()

print(gold_validation_losses)
print(gold_validation_sequence_accuracies)
print(gold_validation_token_accuracies)

100%|██████████| 157/157 [00:08<00:00, 17.94it/s]
100%|██████████| 16/16 [00:00<00:00, 54.36it/s]
100%|██████████| 157/157 [00:08<00:00, 18.76it/s]
100%|██████████| 16/16 [00:00<00:00, 54.11it/s]
100%|██████████| 157/157 [00:08<00:00, 18.84it/s]
100%|██████████| 16/16 [00:00<00:00, 54.20it/s]
100%|██████████| 157/157 [00:08<00:00, 19.01it/s]
100%|██████████| 16/16 [00:00<00:00, 54.17it/s]
100%|██████████| 157/157 [00:08<00:00, 19.03it/s]
100%|██████████| 16/16 [00:00<00:00, 54.26it/s]
100%|██████████| 157/157 [00:08<00:00, 18.86it/s]
100%|██████████| 16/16 [00:00<00:00, 54.26it/s]
100%|██████████| 157/157 [00:08<00:00, 18.91it/s]
100%|██████████| 16/16 [00:00<00:00, 54.18it/s]
100%|██████████| 157/157 [00:08<00:00, 18.79it/s]
100%|██████████| 16/16 [00:00<00:00, 53.90it/s]
100%|██████████| 157/157 [00:08<00:00, 18.86it/s]
100%|██████████| 16/16 [00:00<00:00, 54.05it/s]
100%|██████████| 157/157 [00:08<00:00, 18.97it/s]
100%|██████████| 16/16 [00:00<00:00, 53.15it/s]

[6.1481803527832035, 1.9719365125656128, 0.9654950741767884, 0.6140843582391738, 0.17663685646057128, 0.6911928603410721, 0.047026490390300754, 0.540928494977951, 0.2598152288377285, 0.0217185126721859]
[0.0023, 0.0103, 0.0239, 0.0287, 0.0452, 0.0298, 0.0522, 0.0365, 0.0442, 0.0513]
[0.4563645912910619, 0.5380749866503929, 0.5552391038851996, 0.560982013289544, 0.5669360372036288, 0.5599467394624312, 0.5695632254123396, 0.5617891763987481, 0.5664703748950142, 0.5691843390037379]

[9.066053771972657, 0.867771535873413, 0.6750355739593505, 0.5016322746276856, 0.37303468894958497, 0.1334613080024719, 0.05476113259792328, 0.7255270099639892, 0.028448798418045044, 0.02773188304901123]
[0.0, 0.028, 0.03, 0.032, 0.04, 0.049, 0.053, 0.025, 0.055, 0.051]
[0.4538116163546045, 0.5529709591134887, 0.5567443637753152, 0.5581295376385174, 0.5612342376767291, 0.5649598777225832, 0.565533053114253, 0.5545949560565533, 0.5661062285059228, 0.5657718761941154]





In [193]:
acc = run_autoregressive_validation_loop(val_dataloader, model)
print(acc)

 38%|███▊      | 6/16 [00:00<00:01,  5.95it/s]

['0' '0' '1' '1' '1' '1' '1' '2' '2' '5' '5' '5' '5' '5' '7' '8' '8' '8'
 '8' '8']
['0' '0' '1' '1' '1' '1' '1' '2' '2' '5' '5' '5' '5' '5' '7' '8' '8' '8'
 '8']
Source:              ['8' '1' '8' '0' '1' '7' '5' '1' '8' '1' '5' '5' '8' '2' '0' '5' '8' '2'
 '1' '5']
Predicted Target:    ['<SOS>' '0' '0' '1' '1' '1' '1' '1' '2' '2' '5' '5' '5' '5' '5' '7' '8'
 '8' '8' '8' '<EOS>' '<PAD>']



 50%|█████     | 8/16 [00:01<00:01,  5.94it/s]

['1' '2' '2' '2' '3' '3' '4' '4' '6' '6' '7' '8' '8' '8' '8' '8' '9' '9']
['1' '2' '2' '2' '3' '3' '4' '4' '6' '6' '7' '8' '8' '8' '8' '9' '9']
Source:              ['4' '9' '3' '6' '8' '8' '3' '6' '4' '2' '2' '8' '8' '1' '8' '2' '9' '7'
 '<PAD>' '<PAD>']
Predicted Target:    ['<SOS>' '1' '2' '2' '2' '3' '3' '4' '4' '6' '6' '7' '8' '8' '8' '8' '9'
 '9' '<EOS>' '<PAD>' '<PAD>' '<PAD>']



 62%|██████▎   | 10/16 [00:01<00:01,  5.97it/s]

['0' '1' '1' '2' '3' '4' '4' '4' '4' '5' '5' '8' '8' '8' '8' '8' '9' '9'
 '9' '9']
['0' '1' '1' '2' '3' '4' '4' '4' '4' '5' '5' '8' '8' '8' '8' '9' '9' '9'
 '9']
Source:              ['2' '9' '0' '8' '9' '8' '9' '4' '1' '1' '8' '5' '4' '8' '4' '8' '5' '9'
 '3' '4']
Predicted Target:    ['<SOS>' '0' '1' '1' '2' '3' '4' '4' '4' '4' '5' '5' '8' '8' '8' '8' '9'
 '9' '9' '9' '<EOS>' '<PAD>']

['1' '1' '1' '3' '3']
['1' '1' '3' '1' '3']
Source:              ['1' '1' '1' '3' '3' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>'
 '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>']
Predicted Target:    ['<SOS>' '1' '1' '3' '1' '3' '<EOS>' '<PAD>' '<PAD>' '<PAD>' '<PAD>'
 '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>'
 '<PAD>' '<PAD>']



 81%|████████▏ | 13/16 [00:02<00:00,  5.96it/s]

['0' '0' '1' '4' '5' '6' '6' '6' '6' '7' '8' '8' '8' '8' '8' '8' '9' '9'
 '9' '9']
['0' '0' '1' '4' '5' '6' '6' '6' '6' '7' '8' '8' '8' '8' '8' '9' '9' '9'
 '9']
Source:              ['8' '8' '4' '8' '0' '1' '6' '9' '8' '9' '6' '5' '0' '8' '9' '9' '7' '8'
 '6' '6']
Predicted Target:    ['<SOS>' '0' '0' '1' '4' '5' '6' '6' '6' '6' '7' '8' '8' '8' '8' '8' '9'
 '9' '9' '9' '<EOS>' '<PAD>']

['0' '1' '1' '1' '5' '6' '6' '6' '7' '7' '7' '8' '8' '8' '8' '8' '9' '9']
['0' '1' '1' '1' '5' '6' '6' '6' '7' '7' '7' '8' '8' '8' '8' '9' '9']
Source:              ['7' '7' '8' '8' '7' '9' '8' '8' '9' '0' '1' '6' '1' '5' '1' '6' '6' '8'
 '<PAD>' '<PAD>']
Predicted Target:    ['<SOS>' '0' '1' '1' '1' '5' '6' '6' '6' '7' '7' '7' '8' '8' '8' '8' '9'
 '9' '<EOS>' '<PAD>' '<PAD>' '<PAD>']

['0' '1' '1' '2' '3' '3' '3' '4' '4' '5' '6' '6' '7' '8' '8' '8' '8' '8'
 '8' '9']
['0' '1' '1' '2' '3' '3' '3' '4' '4' '5' '6' '6' '7' '8' '8' '8' '8' '8'
 '9']
Source:              ['8' '1' '8' '3' '3' '8' '8' '7' '2' 

100%|██████████| 16/16 [00:02<00:00,  6.08it/s]

0.993



