In [6]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
import sys
import os
from pathlib import Path

# Get the parent directory (i.e. project root)
project_root = Path().resolve().parent.parent 
sys.path.insert(0, str(project_root))

import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from functools import partial

from tqdm import tqdm

from transformer.pre_training.sort_integer_lists.dataset import RandomIntegerDataset

from src.embedding import CustomEmbedding
from src.transformer import EncoderDecoderTransformer
from src.utils import padding_collate_fn

In [8]:
n_real_tokens = 10
PAD_TOKEN_IDX = n_real_tokens
SOS_TOKEN_IDX = n_real_tokens + 1
EOS_TOKEN_IDX = n_real_tokens + 2
vocab_size = n_real_tokens + 3
D_MODEL = 64

embeddings = CustomEmbedding(vocab_size, d_model = D_MODEL) # 3 = PAD, SOS, EOS

indices = torch.tensor([1,9])

# print(embeddings.embeddings.weight)
print(embeddings.embeddings(indices))

tensor([[-2.3593e+00, -1.8738e+00, -1.4160e+00,  1.8563e-02, -2.2517e-01,
         -1.0900e+00,  5.5094e-01, -2.6020e-01, -1.8901e+00,  2.8605e+00,
          1.5376e+00,  7.5922e-02, -2.2348e+00, -1.5930e+00,  5.2896e-01,
         -7.8548e-01, -9.2142e-01, -5.8698e-01, -1.5417e+00,  2.2931e+00,
          1.7263e+00, -1.1216e-01, -1.2678e+00, -9.0122e-01,  2.2178e+00,
          8.6897e-01,  5.4969e-01,  7.9022e-01,  7.3506e-01, -1.9632e+00,
         -3.4712e-01, -1.3287e-01, -2.2265e+00,  3.7048e-01,  4.8923e-01,
         -7.2138e-01,  1.0277e+00, -5.5461e-01,  2.5897e-01,  2.2498e-01,
         -1.0654e+00,  1.3920e+00, -5.3598e-01,  8.4914e-01,  1.1571e-01,
          3.2136e-01,  1.4248e+00, -2.9901e-01, -9.3666e-01,  5.2934e-01,
          1.2361e+00, -1.1005e+00, -1.6959e+00, -1.7556e+00,  9.2442e-02,
         -4.5121e-01,  2.2787e-01,  1.8150e-01, -3.1161e-03, -9.6059e-01,
          3.4410e-01,  5.5738e-01,  6.0925e-01,  6.4350e-01],
        [-3.3884e-01, -9.4837e-01,  1.2291e+00, -7

In [9]:
MAX_CONTEXT_WINDOW = 50

BATCH_SIZE = 64
MIN_SEQ_LEN = 2
MAX_SEQ_LEN = min(20, MAX_CONTEXT_WINDOW)

NUM_TRAINING_SEQUENCES = 10000
NUM_VALIDATION_SEQUENCES = 1000

VOCAB = [i for i in range(n_real_tokens)] # does not include SOS, EOS, PAD

VOCAB_MAP = dict()

for i, token in enumerate(VOCAB):
    VOCAB_MAP[i] = token
VOCAB_MAP[len(VOCAB_MAP)] = '<PAD>'
VOCAB_MAP[len(VOCAB_MAP) + 1] = '<SOS>'
VOCAB_MAP[len(VOCAB_MAP) + 2] = '<EOS>'

train_rand_ds = RandomIntegerDataset(MIN_SEQ_LEN, MAX_SEQ_LEN, NUM_TRAINING_SEQUENCES, VOCAB)
train_dataloader = DataLoader(train_rand_ds, batch_size = BATCH_SIZE, shuffle = True, collate_fn = partial(padding_collate_fn, pad_token_idx = PAD_TOKEN_IDX))

val_rand_ds = RandomIntegerDataset(MIN_SEQ_LEN, MAX_SEQ_LEN, NUM_VALIDATION_SEQUENCES, VOCAB)
val_dataloader = DataLoader(val_rand_ds, batch_size = BATCH_SIZE, collate_fn = partial(padding_collate_fn, pad_token_idx = PAD_TOKEN_IDX))

In [10]:
input, label = next(iter(train_dataloader))
print(input[0])
print(input[1])
print(label)

tensor([[ 9,  7,  9,  ..., 10, 10, 10],
        [ 9,  2,  8,  ..., 10, 10, 10],
        [ 8,  7,  6,  ..., 10, 10, 10],
        ...,
        [ 5,  2,  9,  ...,  5, 10, 10],
        [ 7,  1,  3,  ..., 10, 10, 10],
        [ 6,  8,  8,  ..., 10, 10, 10]])
tensor([[11,  0,  0,  ..., 10, 10, 10],
        [11,  0,  1,  ..., 10, 10, 10],
        [11,  0,  2,  ..., 10, 10, 10],
        ...,
        [11,  0,  0,  ...,  9, 10, 10],
        [11,  1,  1,  ..., 10, 10, 10],
        [11,  1,  3,  ..., 10, 10, 10]])
tensor([[ 0,  0,  5,  ..., 10, 10, 10],
        [ 0,  1,  1,  ..., 10, 10, 10],
        [ 0,  2,  2,  ..., 10, 10, 10],
        ...,
        [ 0,  0,  2,  ..., 12, 10, 10],
        [ 1,  1,  2,  ..., 10, 10, 10],
        [ 1,  3,  5,  ..., 10, 10, 10]])


In [11]:
loss_fn = nn.CrossEntropyLoss(ignore_index = PAD_TOKEN_IDX, reduction = 'sum')

model = EncoderDecoderTransformer(
                    embeddings = embeddings, 
                    vocab_size = vocab_size, 
                    d_model = D_MODEL, 
                    num_attention_heads = 4, 
                    num_encoder_layers = 2, 
                    num_decoder_layers = 2, 
                    dim_feedforward = 32, 
                    dropout = 0.0,
                    max_context_window = MAX_CONTEXT_WINDOW,
                    use_pre_lnorm = True)

optim = torch.optim.SGD(params = model.parameters(), lr = 1e-4, momentum = 0.9, weight_decay = 1e-4)

  from .autonotebook import tqdm as notebook_tqdm


In [12]:
def greedy_decode(source: torch.Tensor, model: nn.Module) -> torch.Tensor:
    """
    Designed to do autoregressive inference in an Encoder-Decoder transformer.

    This greedy decoder always predicts the vocabulary token corresponding to the highest logit.

    Takes the source sequence and the Encoder-Decoder to produce a predicted
    sequence starting from <SOS>.

    Note: This function *can* handle batches of sequences.

    Args:
        source - The source sequence to be passed to the Transformer's encoder block.
        model - The Encoder-Decoder transformer over which to greedy decode

    Returns:
        target - The batch of predicted sequences corresponding to the input sources.
        target_logits - (# batch elements =) batch_size (# rows =) seq_len (# cols =) vocab-dimensional vectors, each of which
        corresponds to the set of logits on a particular inference step within a given sequence.
    """
    batch_size = source.size(dim = 0)

    encoder_output, source_pad_mask = model.encode(source)

    # target will contain num_batch sequences of indices that are the predicted next-words for each batch element
    target = torch.full((batch_size, 1), SOS_TOKEN_IDX) # target.shape: [batch_size, num_loops_complete - 1]
    target_logits = torch.zeros((batch_size, 1, vocab_size))

    finished = torch.full((batch_size, ), False)

    while not finished.all() and target.size(dim = 1) <= MAX_CONTEXT_WINDOW:

        decoder_output, _ = model.decode(target, encoder_output, source_pad_mask)
        pred_logits = model.project_into_vocab(decoder_output) # pred_logits.shape: [batch_size, seq_len, vocab_size]

        last_row_pred_logits = pred_logits[:, -1, :] # last_row_pred_logits.shape == [batch_size, vocab_size]

        # Track next-word logits for loss_fn later.
        target_logits = torch.concat((target_logits, last_row_pred_logits.unsqueeze(1)), dim = 1)

        predictions = torch.argmax(last_row_pred_logits, dim = -1) # predictions.shape: [batch_size]

        # For any finished sequences (i.e. previous EOS-producers), force their prediction from this round to be a pad.
        predictions[finished] = PAD_TOKEN_IDX

        # Mark any additional sequences that just produced an EOS as finished.
        finished |= predictions == EOS_TOKEN_IDX

        target = torch.concat((target, predictions.reshape(-1, 1)), dim = 1) # target.shape: [batch_size, num_loops_complete]

    return target, target_logits[:, 1:, :]

In [13]:
def run_train_epoch(dataloader: DataLoader, model: nn.Module, loss_fn: nn.Module, optimizer: torch.optim.Optimizer, calculate_sequence_accuracy: bool = False, calculate_token_accuracy: bool = False):
    """
    Runs one training epoch (processing the entire training dataset once).
    
    Uses Teacher Forcing to train token-to-token mapping quality without cascading errors and for parallelization.

    Args:
        dataloader - The dataloader to process the dataset in BATCH_SIZE batches
        model - The Encoder-Decoder that is being trained
        loss_fn - The loss function to calculate the model's correctness
        optimizer - The optimizer to improve the model's weights
        calculate_sequence_accuracy - A flag to mark whether sequence-level correctness should be tracked
        calculate_token_accuracy - A flag to mark whether token-level correctness should be tracked
    """
    model.train()

    num_sequences = len(dataloader.dataset)
    num_tokens = 0

    epoch_loss = 0.0
    total_correct_sequences = 0
    total_correct_tokens = 0

    for (source, target), label in tqdm(dataloader):

        # FORWARD
        pred_logits = model(source, target)

        # pred_logits.shape: [batch_size, seq_len, vocab_size]
        # label.shape: [batch_size, seq_len]

        # CrossEntropyLoss (loss_fn) only takes 2D predictions (n_batch * seq_len, vocab_size) and 1D labels (n_batch * seq_len)
        batch_loss = loss_fn(pred_logits.view(-1, pred_logits.size(-1)), label.view(-1))

        # LOG
        with torch.no_grad():
            epoch_loss += batch_loss.item()

            predictions = torch.argmax(pred_logits, dim = -1) # predictions.shape: [batch_size, seq_len]
            match_matrix = torch.eq(predictions, label)

            if calculate_sequence_accuracy:
                num_correct_sequences = torch.all(match_matrix, dim = 1).sum()
                total_correct_sequences += num_correct_sequences.item()

            if calculate_token_accuracy:
                num_correct_tokens = match_matrix.sum()      
                total_correct_tokens += num_correct_tokens.item()

                num_tokens += torch.numel(label)

        # BACKWARD
        batch_loss.backward()

        # OPTIMIZE
        optimizer.step()
        optimizer.zero_grad()

    average_epoch_loss = epoch_loss / num_sequences
    average_epoch_sequence_accuracy = total_correct_sequences / num_sequences if calculate_sequence_accuracy else None
    average_epoch_token_accuracy = total_correct_tokens / num_tokens if calculate_token_accuracy else None

    return average_epoch_loss, average_epoch_sequence_accuracy, average_epoch_token_accuracy

def run_gold_validation_loop(dataloader: DataLoader, model: nn.Module, loss_fn: nn.Module, calculate_sequence_accuracy: bool = False, calculate_token_accuracy: bool = False):
    """
    Runs one validation epoch (processing the entire validation dataset once). 

    Uses Teacher Forcing (i.e. "gold") to evaluate token-to-token mapping quality and for parallelization.

    Args:
        dataloader - The dataloader to process the dataset in BATCH_SIZE batches
        model - The Encoder-Decoder that is being trained
        loss_fn - The loss function to calculate the model's correctness
        calculate_sequence_accuracy - A flag to mark whether sequence-level correctness should be tracked
        calculate_token_accuracy - A flag to mark whether token-level correctness should be tracked
    """
    model.eval()

    num_sequences = len(dataloader.dataset)
    num_tokens = 0

    epoch_loss = 0.0
    total_correct_sequences = 0
    total_correct_tokens = 0

    with torch.no_grad():
        
        for (source, target), label in tqdm(dataloader):
            
            # FORWARD
            pred_logits = model(source, target)
            batch_loss = loss_fn(pred_logits.view(-1, pred_logits.size(-1)), label.view(-1))

            # LOG
            epoch_loss += batch_loss.item()

            predictions = torch.argmax(pred_logits, dim = -1) # predictions.shape: [batch_size, seq_len]
            match_matrix = torch.eq(predictions, label)

            if calculate_sequence_accuracy:
                num_correct_sequences = torch.all(match_matrix, dim = 1).sum()
                total_correct_sequences += num_correct_sequences.item()

            if calculate_token_accuracy:
                num_correct_tokens = match_matrix.sum()      
                total_correct_tokens += num_correct_tokens.item()

                num_tokens += torch.numel(label)

    average_epoch_loss = epoch_loss / num_sequences
    average_epoch_sequence_accuracy = total_correct_sequences / num_sequences if calculate_sequence_accuracy else None
    average_epoch_token_accuracy = total_correct_tokens / num_tokens if calculate_token_accuracy else None

    return average_epoch_loss, average_epoch_sequence_accuracy, average_epoch_token_accuracy

def run_autoregressive_validation_loop(dataloader: DataLoader, model: nn.Module):
    """
    Runs one autoregressive validation epoch (processing the entire validation dataset once). 

    Args:
        dataloader - The dataloader to process the dataset in BATCH_SIZE batches
        model - The Encoder-Decoder that is being trained
    """
    model.eval()

    correct_sequences = 0
    incorrect_sequences = 0
    total_sequences = 0

    with torch.no_grad():
        
        for (source, _), label in tqdm(dataloader):

            # FORWARD
            pred_indices, pred_logits = greedy_decode(source, model)

            np_source_indices = source.numpy().copy()
            np_pred_target_indices = pred_indices.numpy().copy()

            token_values = np.array(list(VOCAB_MAP.values()))
            predicted_source_tokens = token_values[np_source_indices]
            predicted_target_tokens = token_values[np_pred_target_indices]

            for s, t in zip(predicted_source_tokens, predicted_target_tokens):
                source_end_index = np.argmax(s == '<PAD>')  if '<PAD>' in s else len(s)
                target_end_index = np.argmax(t == '<EOS>')
                if np.array_equal(np.sort(s[:source_end_index]), t[1:target_end_index]):
                    correct_sequences += 1
                else:
                    incorrect_sequences += 1
                    print(f'Incorrect Sequence {incorrect_sequences}:')
                    print(np.sort(s[:source_end_index]))
                    print(t[1:target_end_index])
                    print(f'{'Source:':<20} {s}\n{'Predicted Target:':<20} {t}', end = '\n\n')

            total_sequences += predicted_target_tokens.shape[0]

    return correct_sequences / total_sequences


In [14]:
EPOCHS = 10

training_losses = list()
training_sequence_accuracies = list()
training_token_accuracies = list()

gold_validation_losses = list()
gold_validation_sequence_accuracies = list()
gold_validation_token_accuracies = list()

for i in range(EPOCHS):
    # print(f'Running epoch {i+1}...')

    training_loss, training_sequence_accuracy, training_token_accuracy = run_train_epoch(train_dataloader, model, loss_fn, optim, calculate_sequence_accuracy = True, calculate_token_accuracy = True)

    training_losses.append(training_loss)
    training_sequence_accuracies.append(training_sequence_accuracy)
    training_token_accuracies.append(training_token_accuracy)

    gold_val_loss, gold_val_sequence_accuracy, gold_val_token_accuracy = run_gold_validation_loop(val_dataloader, model, loss_fn, calculate_sequence_accuracy = True, calculate_token_accuracy = True)
    
    gold_validation_losses.append(gold_val_loss)
    gold_validation_sequence_accuracies.append(gold_val_sequence_accuracy)
    gold_validation_token_accuracies.append(gold_val_token_accuracy)

print(training_losses)
print(training_sequence_accuracies)
print(training_token_accuracies)

print()

print(gold_validation_losses)
print(gold_validation_sequence_accuracies)
print(gold_validation_token_accuracies)

100%|██████████| 157/157 [00:09<00:00, 17.38it/s]
100%|██████████| 16/16 [00:00<00:00, 47.70it/s]
100%|██████████| 157/157 [00:08<00:00, 19.19it/s]
100%|██████████| 16/16 [00:00<00:00, 57.21it/s]
100%|██████████| 157/157 [00:08<00:00, 19.59it/s]
100%|██████████| 16/16 [00:00<00:00, 57.31it/s]
100%|██████████| 157/157 [00:08<00:00, 19.19it/s]
100%|██████████| 16/16 [00:00<00:00, 57.17it/s]
100%|██████████| 157/157 [00:08<00:00, 18.27it/s]
100%|██████████| 16/16 [00:00<00:00, 53.58it/s]
100%|██████████| 157/157 [00:09<00:00, 17.12it/s]
100%|██████████| 16/16 [00:00<00:00, 53.71it/s]
100%|██████████| 157/157 [00:08<00:00, 18.89it/s]
100%|██████████| 16/16 [00:00<00:00, 55.43it/s]
100%|██████████| 157/157 [00:08<00:00, 19.05it/s]
100%|██████████| 16/16 [00:00<00:00, 55.37it/s]
100%|██████████| 157/157 [00:09<00:00, 16.94it/s]
100%|██████████| 16/16 [00:00<00:00, 43.63it/s]
100%|██████████| 157/157 [00:08<00:00, 19.12it/s]
100%|██████████| 16/16 [00:00<00:00, 51.35it/s]


[6.349227062988281, 1.903295444869995, 0.8003001699447632, 0.8503367011547088, 0.6524146330356598, 0.24151200650930404, 0.2364796200990677, 0.6866460898399352, 0.2824481308877468, 0.033657983401417735]
[0.0015, 0.0135, 0.027, 0.026, 0.0351, 0.0474, 0.0465, 0.0388, 0.0494, 0.0563]
[0.4545991915185722, 0.541941856913797, 0.5598696248856359, 0.5582389613360788, 0.5629674145299145, 0.5699728779891512, 0.5696750248072666, 0.5632823830241966, 0.5698223156285823, 0.57234180596901]

[5.321473709106446, 3.8494616088867186, 1.3720016479492188, 0.49469689750671386, 0.23472629165649414, 1.3920778617858887, 4.330578262329102, 4.869108871459961, 0.0588050457239151, 0.016656220108270645]
[0.0, 0.007, 0.024, 0.039, 0.036, 0.023, 0.001, 0.005, 0.048, 0.05]
[0.4797193387158785, 0.5456074586697424, 0.5594482891195693, 0.576316801230296, 0.5804978854286813, 0.5600249903883122, 0.5266243752402922, 0.5505094194540562, 0.5827566320645905, 0.5831891580161477]


In [15]:
acc = run_autoregressive_validation_loop(val_dataloader, model)
print(acc)

 50%|█████     | 8/16 [00:01<00:01,  5.95it/s]

Incorrect Sequence 1:
['0' '0' '0' '8']
['0' '0' '0' '0' '8' '8']
Source:              ['0' '0' '0' '8' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>'
 '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>']
Predicted Target:    ['<SOS>' '0' '0' '0' '0' '8' '8' '<EOS>' '<PAD>' '<PAD>' '<PAD>' '<PAD>'
 '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>']



 69%|██████▉   | 11/16 [00:01<00:00,  6.06it/s]

Incorrect Sequence 2:
['2' '3' '3' '3' '3' '4']
['2' '3' '3' '3' '3' '3' '4']
Source:              ['2' '3' '3' '3' '4' '3' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>'
 '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>']
Predicted Target:    ['<SOS>' '2' '3' '3' '3' '3' '3' '4' '<EOS>' '<PAD>' '<PAD>' '<PAD>'
 '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>']



100%|██████████| 16/16 [00:02<00:00,  6.03it/s]

0.998



