In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
from pathlib import Path

# Resolve the current working directory and walk up two levels to reach the
# project root, then put it first on sys.path so the `src` and
# `pre_training_encoder_decoder` packages imported below can be found.
# NOTE(review): this assumes the kernel's working directory is the folder
# holding this notebook, two levels under the project root — confirm.
project_root = Path().resolve().parent.parent 
sys.path.insert(0, str(project_root))

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from functools import partial

from pre_training_encoder_decoder.sort_integer_lists.dataset import RandomIntegerDataset

from src.embedding import CustomEmbedding
from src.transformer import EncoderDecoderTransformer
from src.utils import padding_collate_fn

from src.train_utils import run_train_epoch
from src.validation_utils import run_gold_validation_loop, run_autoregressive_validation_loop

In [13]:
# Token-id layout: ids 0 .. n_real_tokens-1 are the real tokens; the three
# special tokens take the next three ids, in the order PAD, SOS, EOS.
n_real_tokens = 10
PAD_TOKEN_IDX, SOS_TOKEN_IDX, EOS_TOKEN_IDX = (
    n_real_tokens + offset for offset in range(3)
)
vocab_size = n_real_tokens + 3  # real tokens plus PAD, SOS, EOS
D_MODEL = 64

# One embedding row per id in the vocabulary (vocab_size already counts the
# three special tokens PAD, SOS and EOS), each D_MODEL wide.
embeddings = CustomEmbedding(vocab_size, d_model=D_MODEL)

In [14]:
# Batching and sequence-length configuration.
MAX_CONTEXT_WINDOW = 50

BATCH_SIZE = 64
MIN_SEQ_LEN = 2
# Upper length bound: 20, clamped so it can never exceed the context window.
MAX_SEQ_LEN = min(20, MAX_CONTEXT_WINDOW)

# Number of sequences in the training and validation datasets.
NUM_TRAINING_SEQUENCES, NUM_VALIDATION_SEQUENCES = 10_000, 1_000

# Real tokens only; PAD, SOS and EOS are not part of VOCAB.
VOCAB = list(range(n_real_tokens))

# Token id -> printable token. Real tokens map to themselves.
VOCAB_MAP = dict(enumerate(VOCAB))

# Key the special tokens by the *_TOKEN_IDX constants used everywhere else in
# the notebook. Do not derive the keys from len(VOCAB_MAP): the dict grows
# with every insert, so offsets computed from its length skip ids and no
# longer line up with SOS_TOKEN_IDX / EOS_TOKEN_IDX.
VOCAB_MAP[PAD_TOKEN_IDX] = '<PAD>'
VOCAB_MAP[SOS_TOKEN_IDX] = '<SOS>'
VOCAB_MAP[EOS_TOKEN_IDX] = '<EOS>'

# Training split. The loader reshuffles it and pads each batch with
# PAD_TOKEN_IDX through padding_collate_fn.
train_rand_ds = RandomIntegerDataset(
    MIN_SEQ_LEN,
    MAX_SEQ_LEN,
    NUM_TRAINING_SEQUENCES,
    VOCAB,
)
train_dataloader = DataLoader(
    train_rand_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=partial(padding_collate_fn, pad_token_idx=PAD_TOKEN_IDX),
)

# Validation split. Same padding, but shuffle is left at the DataLoader
# default, so batches follow dataset order.
val_rand_ds = RandomIntegerDataset(
    MIN_SEQ_LEN,
    MAX_SEQ_LEN,
    NUM_VALIDATION_SEQUENCES,
    VOCAB,
)
val_dataloader = DataLoader(
    val_rand_ds,
    batch_size=BATCH_SIZE,
    collate_fn=partial(padding_collate_fn, pad_token_idx=PAD_TOKEN_IDX),
)

In [15]:
# Peek at one training batch. The names deliberately avoid `input`, which
# would shadow the Python built-in for the rest of the kernel session.
batch_inputs, batch_labels = next(iter(train_dataloader))

# NOTE(review): judging by the printed batch, element 0 is the padded source
# and element 1 the SOS-prefixed decoder input — confirm against
# padding_collate_fn.
source_batch = batch_inputs[0]
decoder_input_batch = batch_inputs[1]

print(source_batch)
print(decoder_input_batch)
print(batch_labels)

tensor([[ 7,  6,  4,  ..., 10, 10, 10],
        [ 1,  3,  6,  ..., 10, 10, 10],
        [ 6,  8,  5,  ..., 10, 10, 10],
        ...,
        [ 1,  2, 10,  ..., 10, 10, 10],
        [ 1,  3,  1,  ...,  3,  7, 10],
        [ 3,  8,  2,  ..., 10, 10, 10]])
tensor([[11,  4,  6,  ..., 10, 10, 10],
        [11,  1,  3,  ..., 10, 10, 10],
        [11,  5,  6,  ..., 10, 10, 10],
        ...,
        [11,  1,  2,  ..., 10, 10, 10],
        [11,  0,  0,  ...,  7,  9, 10],
        [11,  0,  0,  ..., 10, 10, 10]])
tensor([[ 4,  6,  7,  ..., 10, 10, 10],
        [ 1,  3,  6,  ..., 10, 10, 10],
        [ 5,  6,  6,  ..., 10, 10, 10],
        ...,
        [ 1,  2, 12,  ..., 10, 10, 10],
        [ 0,  0,  1,  ...,  9, 12, 10],
        [ 0,  0,  1,  ..., 10, 10, 10]])


In [16]:
# Cross-entropy summed over the batch (not averaged); targets equal to
# PAD_TOKEN_IDX are ignored.
loss_fn = nn.CrossEntropyLoss(reduction='sum', ignore_index=PAD_TOKEN_IDX)

# Encoder-decoder transformer built on the `embeddings` module created
# earlier in the notebook.
model = EncoderDecoderTransformer(
    embeddings=embeddings,
    vocab_size=vocab_size,
    d_model=D_MODEL,
    num_attention_heads=4,
    num_encoder_layers=2,
    num_decoder_layers=2,
    dim_feedforward=32,
    dropout=0.0,
    max_context_window=MAX_CONTEXT_WINDOW,
    use_pre_lnorm=True,
)

# SGD with momentum and L2 weight decay over every model parameter.
optim = torch.optim.SGD(
    model.parameters(),
    lr=1e-4,
    momentum=0.9,
    weight_decay=1e-4,
)

  from .autonotebook import tqdm as notebook_tqdm


In [18]:
EPOCHS = 10

# Per-epoch metric histories, reset every time this cell runs.
training_losses, training_sequence_accuracies, training_token_accuracies = [], [], []
gold_validation_losses, gold_validation_sequence_accuracies, gold_validation_token_accuracies = [], [], []

for epoch in range(EPOCHS):
    # One pass over the training loader, recording loss and both accuracies.
    train_loss, train_seq_acc, train_tok_acc = run_train_epoch(
        train_dataloader,
        model,
        loss_fn,
        optim,
        calculate_sequence_accuracy=True,
        calculate_token_accuracy=True,
    )
    training_losses.append(train_loss)
    training_sequence_accuracies.append(train_seq_acc)
    training_token_accuracies.append(train_tok_acc)

    # Gold validation pass over the validation loader with the same metrics.
    val_loss, val_seq_acc, val_tok_acc = run_gold_validation_loop(
        val_dataloader,
        model,
        loss_fn,
        calculate_sequence_accuracy=True,
        calculate_token_accuracy=True,
    )
    gold_validation_losses.append(val_loss)
    gold_validation_sequence_accuracies.append(val_seq_acc)
    gold_validation_token_accuracies.append(val_tok_acc)

# Dump the histories: training first, a blank line, then validation.
for history in (training_losses, training_sequence_accuracies, training_token_accuracies):
    print(history)

print()

for history in (gold_validation_losses, gold_validation_sequence_accuracies, gold_validation_token_accuracies):
    print(history)

100%|██████████| 157/157 [00:07<00:00, 20.34it/s]
100%|██████████| 16/16 [00:00<00:00, 60.96it/s]
100%|██████████| 157/157 [00:07<00:00, 21.45it/s]
100%|██████████| 16/16 [00:00<00:00, 60.75it/s]
100%|██████████| 157/157 [00:07<00:00, 21.20it/s]
100%|██████████| 16/16 [00:00<00:00, 55.14it/s]
100%|██████████| 157/157 [00:07<00:00, 19.92it/s]
100%|██████████| 16/16 [00:00<00:00, 62.05it/s]
100%|██████████| 157/157 [00:07<00:00, 22.38it/s]
100%|██████████| 16/16 [00:00<00:00, 62.77it/s]
100%|██████████| 157/157 [00:07<00:00, 22.07it/s]
100%|██████████| 16/16 [00:00<00:00, 58.50it/s]
100%|██████████| 157/157 [00:07<00:00, 21.76it/s]
100%|██████████| 16/16 [00:00<00:00, 62.95it/s]
100%|██████████| 157/157 [00:06<00:00, 22.51it/s]
100%|██████████| 16/16 [00:00<00:00, 54.11it/s]
100%|██████████| 157/157 [00:07<00:00, 20.12it/s]
100%|██████████| 16/16 [00:00<00:00, 58.94it/s]
100%|██████████| 157/157 [00:07<00:00, 20.69it/s]
100%|██████████| 16/16 [00:00<00:00, 43.74it/s]

[6.073650161743164, 1.9356828886032105, 1.0216213793754578, 0.4279505677700043, 0.2879812387943268, 0.7921349413394928, 0.14866112573742868, 0.36519185576438906, 0.3517247172355652, 0.17542103811502457]
[0.0011, 0.0125, 0.024, 0.0386, 0.0425, 0.0294, 0.0479, 0.0448, 0.0446, 0.045]
[0.4569874532835024, 0.53721192002442, 0.5527929285223891, 0.5631155050003817, 0.5651566315628815, 0.5566821681786994, 0.5673946645294252, 0.5643936503090895, 0.5643636086326547, 0.5663416075650118]

[3.8889976806640627, 1.5424590377807617, 0.9636995162963867, 0.38659665870666504, 0.6576802501678467, 0.48469909572601316, 0.05351661467552185, 4.776434967041015, 0.040406386137008664, 0.03617469418048859]
[0.001, 0.003, 0.016, 0.03, 0.018, 0.03, 0.05, 0.002, 0.051, 0.052]
[0.4915456629728697, 0.534772640427971, 0.5437523882307986, 0.5546904852884983, 0.5506782575468093, 0.5544516622086358, 0.5595147115017195, 0.51920137562094, 0.559753534581582, 0.5598012991975545]





In [21]:
# Special-token ids bundled under the names handed to the validation loop.
special_token_idxs = dict(
    SOS_TOKEN_IDX=SOS_TOKEN_IDX,
    EOS_TOKEN_IDX=EOS_TOKEN_IDX,
    PAD_TOKEN_IDX=PAD_TOKEN_IDX,
)

# Run the autoregressive validation loop over the validation loader and print
# the value it returns.
acc = run_autoregressive_validation_loop(
    val_dataloader,
    model,
    VOCAB_MAP,
    special_token_idxs,
    MAX_CONTEXT_WINDOW,
)
print(acc)

 50%|█████     | 8/16 [00:01<00:01,  6.34it/s]

Incorrect Sequence 1:
['0' '4' '6' '7' '8' '8' '9' '9' '9' '9' '9']
['0' '4' '6' '7' '8' '8' '9' '9' '9' '9']
Source:              ['9' '9' '4' '9' '8' '6' '0' '8' '7' '9' '9' '<PAD>' '<PAD>' '<PAD>'
 '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>']
Predicted Target:    ['<SOS>' '0' '4' '6' '7' '8' '8' '9' '9' '9' '9' '<EOS>' '<PAD>' '<PAD>'
 '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>']

Incorrect Sequence 2:
['0' '0' '1' '1' '1' '8' '9']
['0' '0' '1' '1' '1' '1' '8' '9']
Source:              ['9' '0' '1' '1' '1' '8' '0' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>'
 '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>']
Predicted Target:    ['<SOS>' '0' '0' '1' '1' '1' '1' '8' '9' '<EOS>' '<PAD>' '<PAD>' '<PAD>'
 '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>' '<PAD>']



 88%|████████▊ | 14/16 [00:02<00:00,  6.38it/s]

Incorrect Sequence 3:
['0' '0' '1' '1' '1' '1' '1' '2' '2' '2' '2' '2' '2' '7' '7' '9' '9' '9']
['0' '0' '1' '1' '1' '1' '1' '2' '2' '2' '2' '2' '7' '7' '9' '9' '9']
Source:              ['9' '0' '7' '0' '2' '2' '2' '1' '7' '1' '1' '9' '2' '2' '2' '1' '9' '1'
 '<PAD>' '<PAD>']
Predicted Target:    ['<SOS>' '0' '0' '1' '1' '1' '1' '1' '2' '2' '2' '2' '2' '7' '7' '9' '9'
 '9' '<EOS>' '<PAD>' '<PAD>' '<PAD>']



100%|██████████| 16/16 [00:02<00:00,  6.33it/s]

Incorrect Sequence 4:
['2' '2' '2' '2' '2' '2' '2' '2' '2' '2' '3' '3' '3' '4' '6' '7' '8' '8']
['2' '2' '2' '2' '2' '2' '2' '2' '2' '2' '2' '3' '3' '3' '4' '6' '7' '8'
 '8']
Source:              ['7' '2' '2' '8' '3' '2' '3' '2' '2' '3' '4' '2' '6' '8' '2' '2' '2' '2'
 '<PAD>' '<PAD>']
Predicted Target:    ['<SOS>' '2' '2' '2' '2' '2' '2' '2' '2' '2' '2' '2' '3' '3' '3' '4' '6'
 '7' '8' '8' '<EOS>' '<PAD>']

0.996



