In [172]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [173]:
import sys
import os
from pathlib import Path

# Get the parent directory (i.e. project root)
project_root = Path().resolve().parent.parent 
sys.path.insert(0, str(project_root))

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from functools import partial

from tqdm import tqdm

from tokenization.byte_pair_encoding.get_tokenizers import train_and_save_tokenizer_for, load_tokenizer_from

from pre_training.text_summarization.dataset import TextSummarizationDataset

from src.embedding import CustomEmbedding
from src.transformer import EncoderDecoderTransformer
from src.utils import padding_collate_fn

In [174]:
DF_DATA_PATH = '../../data/SAMSum/'

BPE_IN_PATH = '../../data/SAMSum/train_summary_and_dialogue.txt'
BPE_OUT_PATH = '../../tokenization/trained_tokenizers/SAMSum_BPE'

In [175]:
MAX_CONTEXT_WINDOW = 100

BATCH_SIZE = 2

In [176]:
train_df = pd.read_json(DF_DATA_PATH + 'train_df.json', orient = 'records', lines = True)
val_df = pd.read_json(DF_DATA_PATH + 'val_df.json', orient = 'records', lines = True)
test_df = pd.read_json(DF_DATA_PATH + 'test_df.json', orient = 'records', lines = True)

In [177]:
bpe_tokenizer = train_and_save_tokenizer_for(in_file_paths = [BPE_IN_PATH], out_file_dir_path = BPE_OUT_PATH, vocab_size = 4_000)
pretrained_bpe_tokenizer = load_tokenizer_from(dir_path = BPE_OUT_PATH, model_max_length = 10000)

PAD_TOKEN_IDX = pretrained_bpe_tokenizer.pad_token_id
print(f'The pad token index is {PAD_TOKEN_IDX}.')




The pad token index is 2.


In [178]:
FILTER_tokenized_train_sources = pretrained_bpe_tokenizer(
    train_df['dialogue'].tolist(),
    add_special_tokens = False,
    return_attention_mask = False,
    return_token_type_ids = False
)

FILTER_tokenized_train_targets = pretrained_bpe_tokenizer(
    train_df['summary'].tolist(),
    add_special_tokens = False,
    return_attention_mask = False,
    return_token_type_ids = False
)

FILTER_tokenized_val_sources = pretrained_bpe_tokenizer(
    val_df['dialogue'].tolist(),
    add_special_tokens = False,
    return_attention_mask = False,
    return_token_type_ids = False
)

FILTER_tokenized_val_targets = pretrained_bpe_tokenizer(
    val_df['summary'].tolist(),
    add_special_tokens = False,
    return_attention_mask = False,
    return_token_type_ids = False
)

FILTER_tokenized_test_sources = pretrained_bpe_tokenizer(
    test_df['dialogue'].tolist(),
    add_special_tokens = False,
    return_attention_mask = False,
    return_token_type_ids = False
)

FILTER_tokenized_test_targets = pretrained_bpe_tokenizer(
    test_df['summary'].tolist(),
    add_special_tokens = False,
    return_attention_mask = False,
    return_token_type_ids = False
)

valid_src_train_indices = np.array([len(example) <= MAX_CONTEXT_WINDOW for example in FILTER_tokenized_train_sources.data['input_ids']])
valid_src_val_indices = np.array([len(example) <= MAX_CONTEXT_WINDOW for example in FILTER_tokenized_val_sources.data['input_ids']])
valid_src_test_indices = np.array([len(example) <= MAX_CONTEXT_WINDOW for example in FILTER_tokenized_test_sources.data['input_ids']])

valid_tgt_train_indices = np.array([len(example) <= MAX_CONTEXT_WINDOW - 1 for example in FILTER_tokenized_train_targets.data['input_ids']])
valid_tgt_val_indices = np.array([len(example) <= MAX_CONTEXT_WINDOW - 1 for example in FILTER_tokenized_val_targets.data['input_ids']])
valid_tgt_test_indices = np.array([len(example) <= MAX_CONTEXT_WINDOW - 1 for example in FILTER_tokenized_test_targets.data['input_ids']])

valid_train_df = train_df.iloc[valid_src_train_indices & valid_tgt_train_indices]
valid_val_df = val_df.iloc[valid_src_val_indices & valid_tgt_val_indices]
valid_test_df = test_df.iloc[valid_src_test_indices & valid_tgt_test_indices]

print(f'With a max_context_window of {MAX_CONTEXT_WINDOW}...')
print(f'The number of training samples went from {train_df.shape[0]} to {valid_train_df.shape[0]}')
print(f'The number of validation samples went from {val_df.shape[0]} to {valid_val_df.shape[0]}')
print(f'The number of test samples went from {test_df.shape[0]} to {valid_test_df.shape[0]}')

With a max_context_window of 100...
The number of training samples went from 14732 to 5580
The number of validation samples went from 818 to 325
The number of test samples went from 819 to 308


In [179]:
tokenized_train_sources = pretrained_bpe_tokenizer(
    valid_train_df['dialogue'].tolist(),
    add_special_tokens = False,
    return_attention_mask = False,
    return_token_type_ids = False
)

tokenized_train_targets = pretrained_bpe_tokenizer(
    ('<SOS> ' + valid_train_df['summary']).tolist(),
    add_special_tokens = False,
    return_attention_mask = False,
    return_token_type_ids = False
)

tokenized_train_labels = pretrained_bpe_tokenizer(
    (valid_train_df['summary'] + ' <EOS>').tolist(),
    add_special_tokens = False,
    return_attention_mask = False,
    return_token_type_ids = False
)

tokenized_val_sources = pretrained_bpe_tokenizer(
    valid_val_df['dialogue'].tolist(),
    add_special_tokens = False,
    return_attention_mask = False,
    return_token_type_ids = False
)

tokenized_val_targets = pretrained_bpe_tokenizer(
    ('<SOS> ' + valid_val_df['summary']).tolist(),
    add_special_tokens = False,
    return_attention_mask = False,
    return_token_type_ids = False
)

tokenized_val_labels = pretrained_bpe_tokenizer(
    (valid_val_df['summary'] + ' <EOS>').tolist(),
    add_special_tokens = False,
    return_attention_mask = False,
    return_token_type_ids = False
)

tokenized_test_sources = pretrained_bpe_tokenizer(
    valid_test_df['dialogue'].tolist(),
    add_special_tokens = False,
    return_attention_mask = False,
    return_token_type_ids = False
)

tokenized_test_targets = pretrained_bpe_tokenizer(
    ('<SOS> ' + valid_test_df['summary']).tolist(),
    add_special_tokens = False,
    return_attention_mask = False,
    return_token_type_ids = False
)

tokenized_test_labels = pretrained_bpe_tokenizer(
    (valid_test_df['summary'] + ' <EOS>').tolist(),
    add_special_tokens = False,
    return_attention_mask = False,
    return_token_type_ids = False
)

In [180]:
train_ds = TextSummarizationDataset(tokenized_train_sources.data['input_ids'], tokenized_train_targets.data['input_ids'], tokenized_train_labels.data['input_ids'])
val_ds = TextSummarizationDataset(tokenized_val_sources.data['input_ids'], tokenized_val_targets.data['input_ids'], tokenized_val_labels.data['input_ids'])
test_ds = TextSummarizationDataset(tokenized_test_sources.data['input_ids'], tokenized_test_targets.data['input_ids'], tokenized_test_labels.data['input_ids'])

# NOTE: Option to use HuggingFace DataCollatorWithPadding : requires changing TextSummarizationDataset __getitem__
train_dataloader = DataLoader(train_ds, batch_size = BATCH_SIZE, shuffle = True, collate_fn = partial(padding_collate_fn, pad_token_idx = PAD_TOKEN_IDX))
val_dataloader = DataLoader(val_ds, batch_size = BATCH_SIZE, collate_fn = partial(padding_collate_fn, pad_token_idx = PAD_TOKEN_IDX))
test_dataloader = DataLoader(test_ds, batch_size = BATCH_SIZE, collate_fn = partial(padding_collate_fn, pad_token_idx = PAD_TOKEN_IDX))

In [181]:
(source, target), label = next(iter(train_dataloader))
print(source)
print(target)
print(label)

([1665, 30, 1798, 16, 448, 314, 263, 1676, 302, 401, 327, 3329, 588, 2071, 287, 313, 365, 81, 18, 206, 203, 1665, 30, 1164, 286, 2296, 476, 324, 896, 1675, 80, 299, 390, 16, 286, 381, 466, 263, 1170, 1457, 18, 206, 203, 1665, 30, 487, 351, 527, 709, 523, 286, 518, 837, 2547, 307, 440, 1043, 280, 353, 304, 912, 275, 1232, 337, 275, 582, 27, 436, 18, 206, 203, 867, 30, 273, 419, 898, 304, 586, 273, 400, 612, 365, 403, 756, 18, 206, 203, 867, 30, 812, 5, 206, 203, 1665, 30, 610, 811, 18], [0, 421, 301, 2446, 1798, 346, 316, 585, 2547, 476, 324, 275, 1675, 80, 299, 390, 317, 1664, 280, 466, 263, 1170, 1457, 327, 3329, 588, 2071, 287, 313, 18, 1798, 351, 898, 304, 586, 316, 314, 612, 365, 538, 756, 18]) [1665, 2446, 1798, 346, 316, 585, 2547, 476, 324, 275, 1675, 80, 299, 390, 317, 1664, 280, 466, 263, 1170, 1457, 327, 3329, 588, 2071, 287, 313, 18, 1798, 351, 898, 304, 586, 316, 314, 612, 365, 538, 756, 18, 225, 1]
([748, 30, 496, 263, 311, 858, 281, 206, 203, 748, 30, 263, 1173, 501, 206,