In [249]:
# Re-import changed project modules (e.g. src/, tokenization/) automatically before each
# cell executes, so edits to those .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [250]:
import sys
from pathlib import Path

# Project root = the directory TWO levels above the kernel's working directory
# (presumably the notebook lives in <root>/pre_training/text_summarization/ — TODO confirm).
project_root = Path().resolve().parent.parent
# Put the project root first on sys.path so `src`, `tokenization`, `pre_training` resolve to
# this checkout. Existing copies are dropped first, so re-running the cell is idempotent
# instead of growing sys.path by one duplicate entry per run.
sys.path[:] = [str(project_root)] + [entry for entry in sys.path if entry != str(project_root)]

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from functools import partial

from tokenization.byte_pair_encoding.get_tokenizers import train_and_save_tokenizer_for, load_tokenizer_from

from pre_training.text_summarization.dataset import TextSummarizationDataset

from src.embedding import CustomEmbedding
from src.transformer import EncoderDecoderTransformer
from src.utils import padding_collate_fn

from src.train_utils import run_train_epoch
from src.validation_utils import run_gold_validation_loop, run_autoregressive_validation_loop

In [251]:
# Directory holding the SAMSum train/val/test dataframes (JSON-lines files).
DF_DATA_PATH = '../../data/SAMSum/'

# Plain-text corpus the BPE tokenizer is trained on (same directory as the dataframes),
# and the directory the trained tokenizer is written to / loaded from.
BPE_IN_PATH = DF_DATA_PATH + 'train_summary_and_dialogue.txt'
BPE_OUT_PATH = '../../tokenization/trained_tokenizers/SAMSum_BPE'

In [252]:
# Seed torch's global RNG so model initialisation and DataLoader shuffling are reproducible
# under Restart & Run All (no seed was set anywhere in the notebook).
SEED = 42
torch.manual_seed(SEED)

# Max tokens allowed for a source or target sequence: longer examples are filtered out,
# and the same value is passed to the model as its max_context_window.
MAX_CONTEXT_WINDOW = 100

BATCH_SIZE = 64

# Model / embedding width.
D_MODEL = 16

In [253]:
# Each split is stored as a JSON-lines file named '<split>_df.json' under DF_DATA_PATH.
train_df, val_df, test_df = (
    pd.read_json(f'{DF_DATA_PATH}{split}_df.json', orient = 'records', lines = True)
    for split in ('train', 'val', 'test')
)

In [254]:
# NOTE: this (re)trains a 4k-token BPE tokenizer and saves it to BPE_OUT_PATH on every run,
# then reloads it from disk for use in the rest of the notebook.
bpe_tokenizer = train_and_save_tokenizer_for(in_file_paths = [BPE_IN_PATH], out_file_dir_path = BPE_OUT_PATH, vocab_size = 4_000)
pretrained_bpe_tokenizer = load_tokenizer_from(dir_path = BPE_OUT_PATH, model_max_length = 10000)

# For Hugging Face tokenizers (which this appears to be), `vocab_size` counts only the base
# vocabulary, while `len(tokenizer)` also counts added tokens. The embedding table and the
# model's vocab_size need a row for every id the tokenizer can emit, so size them with len().
VOCAB_SIZE = len(pretrained_bpe_tokenizer)
PAD_TOKEN_IDX = pretrained_bpe_tokenizer.pad_token_id

# Padding (collate_fn) and the loss's ignore_index both depend on a real pad id.
assert PAD_TOKEN_IDX is not None, 'The tokenizer has no pad token; padding and the loss require one.'

print(f'The vocab size is {VOCAB_SIZE}.')
print(f'The pad token index is {PAD_TOKEN_IDX}.')




The vocab size is 4000.
The pad token index is 2.


In [255]:
# Token-embedding module handed to the EncoderDecoderTransformer below (presumably a
# VOCAB_SIZE x D_MODEL lookup table — see src.embedding). It holds trainable state, so any
# code that reuses this object after training starts from the trained weights, not a fresh init.
embeddings = CustomEmbedding(VOCAB_SIZE, D_MODEL)

In [256]:
def normalize_prefix_space(texts: list[str], include_SOS: bool = False, sos_token: str = '<SOS>') -> list[str]:
    """Give every text exactly one leading space, optionally preceded by a start token.

    Leading whitespace is stripped and replaced with a single space. Presumably this is so
    the first word of each text is tokenized the same way as a mid-sentence word (BPE
    tokenizers with prefix-space handling encode ' word' and 'word' differently).

    Args:
        texts: Strings to normalize. Each element must be a ``str``.
        include_SOS: If True, prepend ``sos_token`` before the leading space.
        sos_token: The start-of-sequence marker used when ``include_SOS`` is True.
            Defaults to '<SOS>', the marker used throughout this notebook.

    Returns:
        A new list such as ``[' text', ...]`` or ``['<SOS> text', ...]``.
        Trailing whitespace is left untouched.
    """
    prefix = (sos_token if include_SOS else '') + ' '
    return [prefix + text.lstrip() for text in texts]

In [257]:
def _tokenize_for_length_filter(split_df):
    """Encode one split's dialogues and '<SOS>'-prefixed summaries (ids only, no auto specials)."""
    encode_kwargs = dict(
        add_special_tokens = False,
        return_attention_mask = False,
        return_token_type_ids = False,
    )
    dialogues = pretrained_bpe_tokenizer(normalize_prefix_space(split_df['dialogue'].tolist()), **encode_kwargs)
    summaries = pretrained_bpe_tokenizer(normalize_prefix_space(split_df['summary'].tolist(), include_SOS = True), **encode_kwargs)
    return dialogues, summaries


def _within_context_window(encoding):
    """Boolean mask: True for every example whose token count is <= MAX_CONTEXT_WINDOW."""
    return np.array([len(ids) <= MAX_CONTEXT_WINDOW for ids in encoding.data['input_ids']])


# Tokenize every split once just to measure lengths.
FILTER_tokenized_train_sources, FILTER_tokenized_train_targets = _tokenize_for_length_filter(train_df)
FILTER_tokenized_val_sources, FILTER_tokenized_val_targets = _tokenize_for_length_filter(val_df)
FILTER_tokenized_test_sources, FILTER_tokenized_test_targets = _tokenize_for_length_filter(test_df)

valid_src_train_indices = _within_context_window(FILTER_tokenized_train_sources)
valid_src_val_indices = _within_context_window(FILTER_tokenized_val_sources)
valid_src_test_indices = _within_context_window(FILTER_tokenized_test_sources)

valid_tgt_train_indices = _within_context_window(FILTER_tokenized_train_targets)
valid_tgt_val_indices = _within_context_window(FILTER_tokenized_val_targets)
valid_tgt_test_indices = _within_context_window(FILTER_tokenized_test_targets)

# Keep only rows where BOTH the dialogue and the summary fit in the context window.
valid_train_df = train_df.iloc[valid_src_train_indices & valid_tgt_train_indices]
valid_val_df = val_df.iloc[valid_src_val_indices & valid_tgt_val_indices]
valid_test_df = test_df.iloc[valid_src_test_indices & valid_tgt_test_indices]

print(f'With a max_context_window of {MAX_CONTEXT_WINDOW}...')
for split_name, full_split, kept_split in (
    ('training', train_df, valid_train_df),
    ('validation', val_df, valid_val_df),
    ('test', test_df, valid_test_df),
):
    print(f'The number of {split_name} samples went from {full_split.shape[0]} to {kept_split.shape[0]}')

With a max_context_window of 100...
The number of training samples went from 14732 to 5561
The number of validation samples went from 818 to 325
The number of test samples went from 819 to 306


In [258]:
def _tokenize_split(split_df):
    """Encode one filtered split into (encoder inputs, decoder inputs, decoder labels).

    - encoder inputs: the dialogue
    - decoder inputs: '<SOS>' followed by the summary
    - decoder labels: the summary followed by '<EOS>' (intended to be the decoder inputs
      shifted one position left)
    Only input ids are produced: no attention masks, token type ids, or auto-added specials.
    """
    encode_kwargs = dict(
        add_special_tokens = False,
        return_attention_mask = False,
        return_token_type_ids = False,
    )
    sources = pretrained_bpe_tokenizer(normalize_prefix_space(split_df['dialogue'].tolist()), **encode_kwargs)
    targets = pretrained_bpe_tokenizer(normalize_prefix_space(split_df['summary'].tolist(), include_SOS = True), **encode_kwargs)
    labels = pretrained_bpe_tokenizer(normalize_prefix_space((split_df['summary'] + '<EOS>').tolist()), **encode_kwargs)
    return sources, targets, labels


tokenized_train_sources, tokenized_train_targets, tokenized_train_labels = _tokenize_split(valid_train_df)
tokenized_val_sources, tokenized_val_targets, tokenized_val_labels = _tokenize_split(valid_val_df)
tokenized_test_sources, tokenized_test_targets, tokenized_test_labels = _tokenize_split(valid_test_df)

In [259]:
train_ds, val_ds, test_ds = (
    TextSummarizationDataset(sources.data['input_ids'], targets.data['input_ids'], labels.data['input_ids'])
    for sources, targets, labels in (
        (tokenized_train_sources, tokenized_train_targets, tokenized_train_labels),
        (tokenized_val_sources, tokenized_val_targets, tokenized_val_labels),
        (tokenized_test_sources, tokenized_test_targets, tokenized_test_labels),
    )
)

# One shared collate function pads every batch with the tokenizer's pad id.
# NOTE: Option to use HuggingFace DataCollatorWithPadding : requires changing TextSummarizationDataset __getitem__
pad_collate = partial(padding_collate_fn, pad_token_idx = PAD_TOKEN_IDX)

train_dataloader = DataLoader(train_ds, batch_size = BATCH_SIZE, shuffle = True, collate_fn = pad_collate)
val_dataloader = DataLoader(val_ds, batch_size = BATCH_SIZE, collate_fn = pad_collate)
test_dataloader = DataLoader(test_ds, batch_size = BATCH_SIZE, collate_fn = pad_collate)

In [260]:
# Sanity-check one padded training batch: encoder input, decoder input, decoder labels.
first_batch = next(iter(train_dataloader))
(source, target), label = first_batch
for batch_part in (source, target, label):
    print(batch_part)

tensor([[ 367, 2504,   30,  ...,    2,    2,    2],
        [ 604,  799,   69,  ...,    2,    2,    2],
        [ 331, 2310,   30,  ...,    2,    2,    2],
        ...,
        [1306,   30,  674,  ...,    2,    2,    2],
        [ 988,   30, 1124,  ...,    2,    2,    2],
        [2763,   30,  680,  ...,    2,    2,    2]])
tensor([[   0,  367, 2504,  ...,    2,    2,    2],
        [   0,  331,  389,  ...,    2,    2,    2],
        [   0,  331, 2310,  ...,  365,  996,   18],
        ...,
        [   0,  328,  281,  ...,    2,    2,    2],
        [   0,  988,  317,  ...,    2,    2,    2],
        [   0, 1798,  309,  ...,    2,    2,    2]])
tensor([[ 367, 2504,   16,  ...,    2,    2,    2],
        [ 331,  389,  382,  ...,    2,    2,    2],
        [ 331, 2310,  509,  ...,  996,   18,    1],
        ...,
        [ 328,  281, 2046,  ...,    2,    2,    2],
        [ 988,  317,   78,  ...,    2,    2,    2],
        [1798,  309, 2763,  ...,    2,    2,    2]])


In [261]:
# Re-create the embedding module here so that re-running this cell always builds a freshly
# initialised model. The `embeddings` object from the earlier cell is passed into the model
# and (presumably via model.parameters()) updated by the optimizer. Reusing it after a
# training run would silently start the "new" model from already-trained embeddings.
embeddings = CustomEmbedding(VOCAB_SIZE, D_MODEL)

# Summed (not averaged) cross-entropy; positions holding the pad id are ignored.
loss_fn = nn.CrossEntropyLoss(ignore_index = PAD_TOKEN_IDX, reduction = 'sum')

model = EncoderDecoderTransformer(
    embeddings = embeddings,
    vocab_size = VOCAB_SIZE,
    d_model = D_MODEL,
    num_attention_heads = 4,
    num_encoder_layers = 1,
    num_decoder_layers = 1,
    dim_feedforward = 32,
    dropout = 0.0,
    max_context_window = MAX_CONTEXT_WINDOW,
    use_pre_lnorm = True
)

# NOTE(review): with reduction='sum' the gradient magnitude grows with the number of
# non-pad tokens per batch, so the effective step size depends on BATCH_SIZE and lengths.
optim = torch.optim.SGD(model.parameters(), lr = 1e-4, momentum = 0.9, weight_decay = 1e-4)

In [262]:
EPOCHS = 15

# Per-epoch metric histories: (loss, sequence accuracy, token accuracy) for each phase.
training_losses, training_sequence_accuracies, training_token_accuracies = [], [], []
gold_validation_losses, gold_validation_sequence_accuracies, gold_validation_token_accuracies = [], [], []

for epoch in range(1, EPOCHS + 1):
    print(f'Running epoch {epoch}...')

    # One optimisation pass over the training set.
    training_loss, training_sequence_accuracy, training_token_accuracy = run_train_epoch(
        train_dataloader, model, loss_fn, optim,
        calculate_sequence_accuracy = True,
        calculate_token_accuracy = True,
    )
    training_losses.append(training_loss)
    training_sequence_accuracies.append(training_sequence_accuracy)
    training_token_accuracies.append(training_token_accuracy)

    # Validation with the gold decoder inputs (presumably teacher forcing — see
    # src.validation_utils), as opposed to autoregressive decoding.
    gold_val_loss, gold_val_sequence_accuracy, gold_val_token_accuracy = run_gold_validation_loop(
        val_dataloader, model, loss_fn,
        calculate_sequence_accuracy = True,
        calculate_token_accuracy = True,
    )
    gold_validation_losses.append(gold_val_loss)
    gold_validation_sequence_accuracies.append(gold_val_sequence_accuracy)
    gold_validation_token_accuracies.append(gold_val_token_accuracy)

for history in (training_losses, training_sequence_accuracies, training_token_accuracies):
    print(history)

print()

for history in (gold_validation_losses, gold_validation_sequence_accuracies, gold_validation_token_accuracies):
    print(history)

100%|██████████| 87/87 [00:07<00:00, 11.88it/s]
100%|██████████| 6/6 [00:00<00:00, 30.84it/s]
100%|██████████| 87/87 [00:06<00:00, 13.15it/s]
100%|██████████| 6/6 [00:00<00:00, 30.43it/s]
100%|██████████| 87/87 [00:06<00:00, 13.18it/s]
100%|██████████| 6/6 [00:00<00:00, 23.39it/s]
100%|██████████| 87/87 [00:06<00:00, 12.83it/s]
100%|██████████| 6/6 [00:00<00:00, 24.89it/s]
100%|██████████| 87/87 [00:06<00:00, 12.91it/s]
100%|██████████| 6/6 [00:00<00:00, 30.74it/s]
100%|██████████| 87/87 [00:06<00:00, 13.28it/s]
100%|██████████| 6/6 [00:00<00:00, 32.12it/s]
100%|██████████| 87/87 [00:06<00:00, 13.23it/s]
100%|██████████| 6/6 [00:00<00:00, 31.23it/s]
100%|██████████| 87/87 [00:06<00:00, 13.29it/s]
100%|██████████| 6/6 [00:00<00:00, 29.69it/s]
100%|██████████| 87/87 [00:06<00:00, 12.54it/s]
100%|██████████| 6/6 [00:00<00:00, 26.48it/s]
100%|██████████| 87/87 [00:06<00:00, 13.10it/s]
100%|██████████| 6/6 [00:00<00:00, 30.98it/s]
100%|██████████| 87/87 [00:06<00:00, 13.51it/s]
100%|███████

[144.41541609535156, 129.68088295930926, 125.27613029075256, 121.95349360361222, 119.49773086632912, 117.46702762542708, 115.83128429646759, 114.47991408849015, 113.25448303826313, 112.21780504714755, 111.26309912574739, 110.4387042397669, 109.54869712057184, 108.8436748392825, 108.14483206847915]
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[0.03770503471947612, 0.051263797729779625, 0.056441929635937, 0.06067534980997495, 0.06301493431003177, 0.06598223826093083, 0.06738570721883995, 0.06877961502919723, 0.06942644992347707, 0.0701620149654899, 0.07182874259089037, 0.07183013575410399, 0.07264144991087906, 0.07253662853070846, 0.07379351367108254]

[134.98083364633413, 129.48837665264423, 125.80706486628605, 122.8987714092548, 121.08897536057692, 119.47872013972356, 118.6186923452524, 117.09996300330529, 116.32009446364182, 115.45841796875, 115.09649864783654, 114.21207688551682, 113.73935302734375, 113.24013953575721, 112.70538461538462]
[0.0, 0.0, 0.0


