TODO

1. Implement beam search
2. Write down hyperparameters and come up with experiments to run
3. Generate graphs for paper
4. Write paper sections
5. modularize code

In [2]:
!pip install datasets
!pip install rouge_score
!pip install transformers
!pip install torch

Collecting datasets
  Downloading datasets-2.15.0-py3-none-any.whl (521 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m521.2/521.2 kB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow-hotfix (from datasets)
  Downloading pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pyarrow-hotfix, dill, multiprocess, datasets
Successfully installed datasets-2.15.0 dill-0.3.7 multiprocess-0.70.15 pyarrow-hotfix-0.6
Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata

In [3]:
!pip install numba

from numba import cuda
device = cuda.get_current_device()
device.reset()



In [4]:
from google.colab import drive
drive.mount('/content/drive/')

import os
os.chdir('/content/drive/MyDrive/Colab Notebooks/Cambridge/L90')

Mounted at /content/drive/


In [5]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Wed Nov 29 14:38:30 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P0    24W / 300W |      2MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [6]:
import numpy as np
import random
from tqdm import tqdm
import json
import math

import spacy
from torch.nn import Transformer
import torch
import torch.nn as nn
from torch import Tensor
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import multi30k, Multi30k, WikiText2
from typing import Iterable, List
from torch.utils.data import DataLoader, TensorDataset, dataset
from torch.cuda.amp import GradScaler, autocast  # For mixed precision training
from torch.profiler import profile, record_function, ProfilerActivity


from torch.nn.utils.rnn import pad_sequence

In [7]:
with open("data/train.json", 'r') as f:
    train_data = json.load(f)

with open("data/validation.json", 'r') as f:
    validation_data = json.load(f)

with open("data/test.json", 'r') as f:
    test_data = json.load(f)

# ================= LOAD DATASET ===========================
train_articles = [article['article'] for article in train_data]
train_summaries = [article['summary'] for article in train_data]

val_articles = [article['article'] for article in validation_data]
val_summaries = [article['summary'] for article in validation_data]


test_articles = [article['article'] for article in test_data]
test_summaries = [article['summary'] for article in test_data]

# ================= REDUCE SIZE ===========================
size_of_dataset = 5000
train_articles = train_articles[:size_of_dataset]
train_summaries = train_summaries[:size_of_dataset]

val_articles = val_articles[:int(size_of_dataset/5)]
val_summaries = val_summaries[:int(size_of_dataset/5)]


test_articles = test_articles[:int(size_of_dataset/5)]
test_summaries = test_summaries[:int(size_of_dataset/5)]

In [8]:
SEED_NUM = 42

torch.manual_seed(SEED_NUM)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED_NUM)

# Model parameters may need to be adjusted based on the specifics of the summarization task
EMB_SIZE = 256
NHEAD = 4
FFN_HID_DIM = 256
BATCH_SIZE = 8
NUM_ENCODER_LAYERS = 1
NUM_DECODER_LAYERS = 1

In [9]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

# helper Module to convert tensor of input indices into corresponding tensor of token embeddings
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

# Seq2Seq Network
class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(
            emb_size, dropout=dropout)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)

In [10]:
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len),device=DEVICE).type(torch.bool)

    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

### New Data Preparation

In [14]:
# Define special symbols and indices
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

token_transform = get_tokenizer('spacy', language='en_core_web_sm')

def yield_tokens(data_iter: Iterable) -> List[str]:
    for data_sample in tqdm(data_iter, desc='Tokenizing data.'):
        yield token_transform(data_sample)[:(EMB_SIZE-2)]

# Create torchtext's Vocab object for English
vocab_transform = build_vocab_from_iterator(yield_tokens(train_articles),
                                            min_freq=1,
                                            specials=special_symbols,
                                            special_first=True)

vocab_transform.set_default_index(UNK_IDX)

Tokenizing data.: 100%|██████████| 5000/5000 [00:18<00:00, 275.46it/s]


In [15]:
# Adjusting the model for English only
SRC_VOCAB_SIZE = TGT_VOCAB_SIZE = len(vocab_transform)
transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)

for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(DEVICE)

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)



In [16]:
len(vocab_transform)

52516

In [17]:
# helper function to club together sequential operations
def sequential_transforms(*transforms):
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func

# function to add BOS/EOS and create tensor for input sequence indices
# Function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids: List[int]):
    tensor = torch.cat((torch.tensor([BOS_IDX]),
                        torch.tensor(token_ids)))
    # Ensure the tensor is at most 254 in length to accommodate EOS
    if len(tensor) >= 255:
        tensor = tensor[:255]
    tensor = torch.cat((tensor, torch.tensor([EOS_IDX])))
    # Pad if less than 256
    final_tensor = torch.cat((tensor, torch.tensor([PAD_IDX] * (256 - len(tensor))))) if len(tensor) < 256 else tensor
    return final_tensor


# src and tgt language text transforms
text_transform = sequential_transforms(token_transform,  # Tokenization
                                       vocab_transform,  # Numericalization
                                       tensor_transform) # Add BOS/EOS and create tensor


def pad(tensor, length):
    if len(tensor) < length:
        return torch.cat([tensor, torch.full((length - len(tensor),), PAD_IDX)])
    return tensor


# function to collate data samples into batch tensors
def collate_fn(batch):
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        src_processed = text_transform(src_sample)
        tgt_processed = text_transform(tgt_sample)

        # Truncate if longer than 256, else pad
        src_batch.append(src_processed[:EMB_SIZE] if len(src_processed) > EMB_SIZE else pad(src_processed, EMB_SIZE))
        tgt_batch.append(tgt_processed[:EMB_SIZE] if len(tgt_processed) > EMB_SIZE else pad(tgt_processed, EMB_SIZE))

    # Convert lists to tensors
    src_batch = torch.stack(src_batch)
    tgt_batch = torch.stack(tgt_batch)
    return src_batch, tgt_batch

In [18]:
# Creating DataLoader instances
train_dataloader = DataLoader([(article, summary) for article, summary in zip(train_articles, train_summaries)],
                        batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=0, shuffle=True)

# Creating DataLoader instances
val_dataloader = DataLoader([(article, summary) for article, summary in zip(val_articles, val_summaries)],
                        batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=0, shuffle=True)

In [21]:
for step, (src, tgt) in enumerate(tqdm(train_dataloader, desc='Training')):
    for t in src:
        print(t)
        print(len(t))
    break

Training:   0%|          | 0/625 [00:00<?, ?it/s]

tensor([    2,    68,     4,  2168, 16754,     6,  1717,  2853,  6737,     4,
        17413,   249,    35,  1035,  1099,     7,  4935,    10,     5,  1205,
          545,     5,   357,  3582,   135,    50,    39,   600,    11,     8,
        12416,    62,  1723,   474,    73,     6,    32,    13,   430,   160,
            4,  1502,  6318,     9,  6360,  1575, 20359,  2232,  2673,    31,
         1323,    14, 22837,     6, 15801,    10,  2650,     7,    35,  1035,
            7,  9313,    15,     7,  5868,     4,    76,  3114,  1169,  5174,
            5,   780,     7,  7409,     5,  4863,    84,   230,    23, 15907,
           72,  2192,    10,  2803, 50198,     4,    86,  2834,    27,    32,
           13,    62, 44686,    73,     7,   531,  3342,  1099,     7,     5,
         1205,   120,    50,    93,   938,  1937,     6,   119,   592,  4077,
            6, 23570,   881,    10,  9313,  1024,  3476,  5868,     4,   127,
        20359,    27,  2411,     6,  2995,   403,     9,     5, 




### Old Data Preparation

In [32]:
# Define special symbols and indices
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

token_transform = get_tokenizer('spacy', language='en_core_web_sm')

def yield_tokens(data_iter: Iterable) -> List[str]:
    for data_sample in tqdm(data_iter, desc='Tokenizing data.'):
        yield token_transform(data_sample)

# Create torchtext's Vocab object for English
vocab_transform = build_vocab_from_iterator(yield_tokens(train_articles),
                                            min_freq=1,
                                            specials=special_symbols,
                                            special_first=True)

vocab_transform.set_default_index(UNK_IDX)

Tokenizing data.: 100%|██████████| 5000/5000 [00:17<00:00, 284.86it/s]


In [33]:
len(vocab_transform)

86565

In [34]:
# Adjusting the model for English only
SRC_VOCAB_SIZE = TGT_VOCAB_SIZE = len(vocab_transform)
transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)

for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(DEVICE)

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)



In [35]:
# helper function to club together sequential operations
def sequential_transforms(*transforms):
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func

# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids: List[int]):
    return torch.cat((torch.tensor([BOS_IDX]),
                      torch.tensor(token_ids),
                      torch.tensor([EOS_IDX])))

# ``src`` and ``tgt`` language text transforms to convert raw strings into tensors indices
text_transform = sequential_transforms(token_transform, #Tokenization
                                    vocab_transform, #Numericalization
                                    tensor_transform) # Add BOS/EOS and create tensor

def pad(tensor, length):
    if len(tensor) < length:
        return torch.cat([tensor, torch.full((length - len(tensor),), PAD_IDX)])
    return tensor


# function to collate data samples into batch tensors
def collate_fn(batch):
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        src_batch.append(text_transform(src_sample))
        tgt_batch.append(text_transform(tgt_sample))

    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
    return src_batch, tgt_batch

In [36]:
# Creating DataLoader instances
train_dataloader = DataLoader([(article, summary) for article, summary in zip(train_articles, train_summaries)],
                        batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=0, shuffle=True)

# Creating DataLoader instances
val_dataloader = DataLoader([(article, summary) for article, summary in zip(val_articles, val_summaries)],
                        batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=0, shuffle=True)

In [40]:
for step, (src, tgt) in enumerate(tqdm(train_dataloader, desc='Training')):

    print(len(src))
    break

Training:   0%|          | 0/625 [00:00<?, ?it/s]

2043





### Training

In [16]:
torch.cuda.empty_cache()

In [17]:
def train_epoch(model, train_dataloader, optimizer, scaler, grad_accumulate_steps=2):
    model.train()
    losses = 0

    optimizer.zero_grad()

    for step, (src, tgt) in enumerate(tqdm(train_dataloader, desc='Training')):
        with autocast():  # Mixed precision training
            src = src.to(DEVICE)
            tgt = tgt.to(DEVICE)

            # assert src.shape[1] == EMB_SIZE, "To avoid unnecessary overhead, please truncate src tensors to the expected embedding size."
            # assert tgt.shape[1] == EMB_SIZE, "To avoid unnecessary overhead, please truncate tgt tensors to the expected embedding size."
            # assert src.shape[0] == BATCH_SIZE, f"Batch size mismatch (src). Current src batch size: {src.shape[0]}, expected batch size: {BATCH_SIZE}"
            # assert tgt.shape[0] == BATCH_SIZE, f"Batch size mismatch (tgt). Current tgt batch size: {tgt.shape[0]}, expected batch size: {BATCH_SIZE}"

            tgt_input = tgt[:-1, :]

            src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

            logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)
            tgt_out = tgt[1:, :]
            loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
            loss = loss / grad_accumulate_steps  # Normalize loss

        scaler.scale(loss).backward()  # Scaled backpropagation

        if (step + 1) % grad_accumulate_steps == 0:
            scaler.step(optimizer)  # Update optimizer
            scaler.update()  # Update scaler
            optimizer.zero_grad()
            torch.cuda.empty_cache()  # Clear GPU cache

        losses += loss.item()

    return losses / len(train_dataloader)

def evaluate(model, val_dataloader):
    model.eval()
    losses = 0

    with torch.no_grad():  # No gradient computation in evaluation
        for src, tgt in tqdm(val_dataloader, desc='Validating'):
            src = src.to(DEVICE)
            tgt = tgt.to(DEVICE)

            tgt_input = tgt[:-1, :]

            src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

            logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)
            tgt_out = tgt[1:, :]
            loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
            losses += loss.item()

    return losses / len(val_dataloader)

In [18]:
def pretty_print(epoch, train_loss, val_loss, epoch_time):
    print(f"\n{'Epoch':<10}{'Train Loss':<20}{'Val Loss':<20}{'Epoch Time (s)':<20}")
    print("-" * 70)
    print(f"{epoch:<10}{train_loss:<20.3f}{val_loss:<20.3f}{epoch_time:<20.3f}\n")

In [41]:
from timeit import default_timer as timer
NUM_EPOCHS = 15
GRAD_ACCUMULATE_STEPS = 2  # Adjust as per your requirement
scaler = GradScaler()  # For mixed precision

train_losses = []
val_losses = []
for epoch in range(1, NUM_EPOCHS+1):
    start_time = timer()
    train_loss = train_epoch(transformer, train_dataloader, optimizer, scaler, GRAD_ACCUMULATE_STEPS)
    end_time = timer()
    val_loss = evaluate(transformer, val_dataloader)
    epoch_time = end_time - start_time
    train_losses.append(train_loss)
    val_losses.append(val_loss)

    pretty_print(epoch, train_loss, val_loss, epoch_time)

Training: 100%|██████████| 625/625 [00:38<00:00, 16.32it/s]
Validating: 100%|██████████| 125/125 [00:04<00:00, 26.28it/s]



Epoch     Train Loss          Val Loss            Epoch Time (s)      
----------------------------------------------------------------------
1         4.195               7.124               38.296              



Training: 100%|██████████| 625/625 [00:37<00:00, 16.70it/s]
Validating: 100%|██████████| 125/125 [00:05<00:00, 24.64it/s]



Epoch     Train Loss          Val Loss            Epoch Time (s)      
----------------------------------------------------------------------
2         3.489               6.817               37.431              



Training: 100%|██████████| 625/625 [00:37<00:00, 16.61it/s]
Validating: 100%|██████████| 125/125 [00:04<00:00, 30.88it/s]



Epoch     Train Loss          Val Loss            Epoch Time (s)      
----------------------------------------------------------------------
3         3.319               6.676               37.649              



Training: 100%|██████████| 625/625 [00:36<00:00, 16.89it/s]
Validating: 100%|██████████| 125/125 [00:05<00:00, 22.68it/s]



Epoch     Train Loss          Val Loss            Epoch Time (s)      
----------------------------------------------------------------------
4         3.197               6.564               37.010              



Training: 100%|██████████| 625/625 [00:37<00:00, 16.57it/s]
Validating: 100%|██████████| 125/125 [00:04<00:00, 31.03it/s]



Epoch     Train Loss          Val Loss            Epoch Time (s)      
----------------------------------------------------------------------
5         3.094               6.465               37.727              



Training: 100%|██████████| 625/625 [00:37<00:00, 16.57it/s]
Validating: 100%|██████████| 125/125 [00:05<00:00, 21.43it/s]



Epoch     Train Loss          Val Loss            Epoch Time (s)      
----------------------------------------------------------------------
6         3.006               6.417               37.723              



Training: 100%|██████████| 625/625 [00:37<00:00, 16.77it/s]
Validating: 100%|██████████| 125/125 [00:04<00:00, 29.34it/s]



Epoch     Train Loss          Val Loss            Epoch Time (s)      
----------------------------------------------------------------------
7         2.929               6.340               37.274              



Training: 100%|██████████| 625/625 [00:38<00:00, 16.30it/s]
Validating: 100%|██████████| 125/125 [00:04<00:00, 25.61it/s]



Epoch     Train Loss          Val Loss            Epoch Time (s)      
----------------------------------------------------------------------
8         2.860               6.310               38.343              



Training: 100%|██████████| 625/625 [00:37<00:00, 16.81it/s]
Validating: 100%|██████████| 125/125 [00:04<00:00, 29.03it/s]



Epoch     Train Loss          Val Loss            Epoch Time (s)      
----------------------------------------------------------------------
9         2.796               6.277               37.206              



Training: 100%|██████████| 625/625 [00:38<00:00, 16.10it/s]
Validating: 100%|██████████| 125/125 [00:04<00:00, 29.81it/s]



Epoch     Train Loss          Val Loss            Epoch Time (s)      
----------------------------------------------------------------------
10        2.738               6.285               38.826              



Training: 100%|██████████| 625/625 [00:37<00:00, 16.88it/s]
Validating: 100%|██████████| 125/125 [00:04<00:00, 25.91it/s]



Epoch     Train Loss          Val Loss            Epoch Time (s)      
----------------------------------------------------------------------
11        2.683               6.251               37.035              



Training: 100%|██████████| 625/625 [00:38<00:00, 16.35it/s]
Validating: 100%|██████████| 125/125 [00:04<00:00, 29.42it/s]



Epoch     Train Loss          Val Loss            Epoch Time (s)      
----------------------------------------------------------------------
12        2.630               6.247               38.234              



Training: 100%|██████████| 625/625 [00:38<00:00, 16.42it/s]
Validating: 100%|██████████| 125/125 [00:05<00:00, 21.27it/s]



Epoch     Train Loss          Val Loss            Epoch Time (s)      
----------------------------------------------------------------------
13        2.578               6.251               38.080              



Training: 100%|██████████| 625/625 [00:37<00:00, 16.70it/s]
Validating: 100%|██████████| 125/125 [00:04<00:00, 30.47it/s]



Epoch     Train Loss          Val Loss            Epoch Time (s)      
----------------------------------------------------------------------
14        2.527               6.289               37.433              



Training: 100%|██████████| 625/625 [00:37<00:00, 16.53it/s]
Validating: 100%|██████████| 125/125 [00:05<00:00, 22.82it/s]


Epoch     Train Loss          Val Loss            Epoch Time (s)      
----------------------------------------------------------------------
15        2.477               6.302               37.832              






In [None]:
import matplotlib.pyplot as plt

# Creating the plot
plt.figure(figsize=(10, 6))
plt.plot(range(1, NUM_EPOCHS+1), train_losses, label='Training Loss')
plt.plot(range(1, NUM_EPOCHS+1), val_losses, label='Validation Loss')
plt.title('Training vs Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()

### Saving Model

In [30]:
torch.save(transformer, 'models/testing_the_nans.pt')

In [26]:
# transformer = torch.load('models/30_epoch_emb_size_256_half_dataset.pt')

### Inference and Decoding

In [23]:
# function to generate output sequence using greedy algorithm
def greedy_decode(model, src, src_mask, max_len, start_symbol):

    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
    for i in range(max_len-1):
        memory = memory.to(DEVICE)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                    .type(torch.bool)).to(DEVICE)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()

        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == EOS_IDX:
            break
    return ys


def beam_search_decode(model, src, src_mask, max_len, start_symbol, beam_size):
    # Initialize the beam with the start symbol and an initial score
    initial_beam = (torch.tensor([start_symbol]), 0.0)  # (sequence, score)
    beams = [initial_beam]

    for _ in range(max_len):
        new_beams = []
        for beam in beams:
            # Expand the current beam
            seq, score = beam
            if seq[-1] == EOS_IDX:
                # If the sequence is finished, pass it through
                new_beams.append(beam)
                continue

            # Get probabilities of next words
            # This part depends on your model architecture and may need modification
            prob = get_prob_from_model(model, seq, src, src_mask)

            # Choose top `beam_size` continuations
            topk_prob, topk_indices = torch.topk(prob, beam_size)

            for prob, word_idx in zip(topk_prob, topk_indices):
                new_seq = torch.cat([seq, word_idx.view(1)])
                new_score = score + torch.log(prob)
                new_beams.append((new_seq, new_score))

        # Sort all new beams and select the top `beam_size`
        beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_size]

    # Choose the sequence with the highest score
    best_seq, _ = max(beams, key=lambda x: x[1])
    return best_seq


def get_prob_from_model(model, seq, src, src_mask):
    # Ensure the model is in evaluation mode, which turns off layers like dropout
    model.eval()

    # Encode the source sentence
    memory = model.encode(src, src_mask)
    # Add batch dimension to sequence for compatibility with the model
    seq = seq.unsqueeze(1)

    # Create a target mask for the sequence
    tgt_mask = generate_square_subsequent_mask(seq.size(0)).to(DEVICE)

    # Decode the sequence
    output = model.decode(seq, memory, tgt_mask)

    # Convert the output to probabilities using the generator
    prob = model.generator(output[-1])

    return torch.softmax(prob, dim=-1).squeeze(0)


# actual function summarize a piece of text
def summarize(model: torch.nn.Module, src_sentence: str, decoding_method: str = 'greedy'):
    model.eval()
    src = text_transform(src_sentence).view(-1, 1)
    num_tokens = src.shape[0]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    max_len = 50

    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)

    if decoding_method == 'greedy':
        tgt_tokens = greedy_decode(model, src, src_mask, max_len, start_symbol=BOS_IDX)


    elif decoding_method == 'beam_search':
        beam_size = 5
        # Assuming beam_size is defined or passed as an argument
        tgt_tokens = beam_search_decode(
            model,
            src,
            src_mask,
            max_len,
            start_symbol=BOS_IDX,
            beam_size=beam_size
        )

    else:
        raise ValueError("Invalid decoding method. Choose 'greedy' or 'beam_search'.")

    tgt_tokens = tgt_tokens.flatten()
    outputted_text = vocab_transform.lookup_tokens(list(tgt_tokens.cpu().numpy()))
    return " ".join(outputted_text).replace("<bos>", "").replace("<eos>", "")

In [43]:
article = """
Playing computer games such as Angry Birds teaches children important life skills including concentration, resilience and problem solving, an academic has said. Professor Angela Mcfarlane, an education expert who will become head of training body the College of Teachers next month, said many games were complex and required deep learning and lateral thinking to solve them. Prof Mcfarlane said she herself had become 'hooked' on the Lemmings computer game, as well as Angry Birds, and said such games could have a place in the classroom provided they were used under supervision. Professor Angela Mcfarlane says computer games like Angry Birds can teach children valuable life-skills . Expert: Prof Mcfarlane says the games can help children learn problem solving, resilience and concentration . She said: 'There are many computer games that require quite deep learning to master the games. 'Some of that learning applies beyond games to wider life, such as concentration, problem solving, and resilience - important life skills. 'Anyone who has tried to play complex video games will know they are difficult.' Speaking to The Times, Prof Mcfarlane said she had developed an obsession with both Angry Birds and a precursor, Lemmings, because they had made her think and get her strategy right. The education expert, who has advised the government on educational technology, and who is currently writing a book, Authentic Learning for the Digital Generation, said computer games could be used in the classroom to good effect provided it was done properly. Prof Mcfarlane said even pre-school children could benefit from games, as long as they were supervised and not just given a phone to play with to keep them quiet. Prof Mcfarlane said she herself had become 'hooked' on a computer game called Lemmings, pictured . She said some games could teach children fine motor control, or help with vocabulary or simple maths, and taught skills such as resilience that could be applicable to real life. Next month Prof Mcfarlane, who began her career as a secondary school teacher and head of department, will become chief executive and registrar of the College of Teachers, which offers professional training to teachers and support staff.
"""

summarize(transformer, article, decoding_method='greedy')

" The former former Court of the first time of the world 's mother . <unk> The former Court of her mother - old son 's mother 's mother . <unk> The couple have been charged with a ' I 'm not be ' <unk> The couple have been '"

In [29]:
# def beam_search_decode(model, beam_size, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
#     sos_idx = tokenizer_tgt.token_to_id('[SOS]')
#     eos_idx = tokenizer_tgt.token_to_id('[EOS]')

#     # Precompute the encoder output and reuse it for every step
#     encoder_output = model.encode(source, source_mask)
#     # Initialize the decoder input with the sos token
#     decoder_initial_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)

#     # Create a candidate list
#     candidates = [(decoder_initial_input, 1)]

#     while True:

#         # If a candidate has reached the maximum length, it means we have run the decoding for at least max_len iterations, so stop the search
#         if any([cand.size(1) == max_len for cand, _ in candidates]):
#             break

#         # Create a new list of candidates
#         new_candidates = []

#         for candidate, score in candidates:

#             # Do not expand candidates that have reached the eos token
#             if candidate[0][-1].item() == eos_idx:
#                 continue

#             # Build the candidate's mask
#             candidate_mask = causal_mask(candidate.size(1)).type_as(source_mask).to(device)
#             # calculate output
#             out = model.decode(encoder_output, source_mask, candidate, candidate_mask)
#             # get next token probabilities
#             prob = model.project(out[:, -1])
#             # get the top k candidates
#             topk_prob, topk_idx = torch.topk(prob, beam_size, dim=1)
#             for i in range(beam_size):
#                 # for each of the top k candidates, get the token and its probability
#                 token = topk_idx[0][i].unsqueeze(0).unsqueeze(0)
#                 token_prob = topk_prob[0][i].item()
#                 # create a new candidate by appending the token to the current candidate
#                 new_candidate = torch.cat([candidate, token], dim=1)
#                 # We sum the log probabilities because the probabilities are in log space
#                 new_candidates.append((new_candidate, score + token_prob))

#         # Sort the new candidates by their score
#         candidates = sorted(new_candidates, key=lambda x: x[1], reverse=True)
#         # Keep only the top k candidates
#         candidates = candidates[:beam_size]

#         # If all the candidates have reached the eos token, stop
#         if all([cand[0][-1].item() == eos_idx for cand, _ in candidates]):
#             break

#     # Return the best candidate
#     return candidates[0][0].squeeze()

In [45]:
for test_article in test_articles[:10]:

    summary = summarize(transformer, test_article, decoding_method='greedy')
    print(test_article)
    print(summary, end='\n\n')

The mother and daughter who survived a tragic car accident this week, which saw three children die, have been reunited. Aluel Manyang was moved from the intensive care unit at the Royal Children's about 5.15pm on Friday, and greeted her distraught mother, Akon Goode, with a 'big hug', her father said. 'She didn't believe that her mum was still alive,' Joseph Manyang said, according to the Herald Sun. Scroll down for videos . Aueel Manyang, pictured here as a baby with her mother Akon Guode, believes her three siblings who died in the crash at a Melbourne lake were eaten by crocodiles in the water . Ms Guode visited her daughter for the first time but did not stay the night in the hospital. Mr Manyang said his daughter was expected to make a '100 per cent' recovery and she should be allowed to go home within four days. The five-year-old girl who survived when a car driven by her mother plunged into a lake believes her three siblings who died in the crash were eaten by crocodiles. Aluel 

### Misc Testing

In [20]:
model = torch.load('models/30_epoch_emb_size_256_half_dataset.pt')

In [24]:
article = """
Playing computer games such as Angry Birds teaches children important life skills including concentration, resilience and problem solving, an academic has said. Professor Angela Mcfarlane, an education expert who will become head of training body the College of Teachers next month, said many games were complex and required deep learning and lateral thinking to solve them. Prof Mcfarlane said she herself had become 'hooked' on the Lemmings computer game, as well as Angry Birds, and said such games could have a place in the classroom provided they were used under supervision. Professor Angela Mcfarlane says computer games like Angry Birds can teach children valuable life-skills . Expert: Prof Mcfarlane says the games can help children learn problem solving, resilience and concentration . She said: 'There are many computer games that require quite deep learning to master the games. 'Some of that learning applies beyond games to wider life, such as concentration, problem solving, and resilience - important life skills. 'Anyone who has tried to play complex video games will know they are difficult.' Speaking to The Times, Prof Mcfarlane said she had developed an obsession with both Angry Birds and a precursor, Lemmings, because they had made her think and get her strategy right. The education expert, who has advised the government on educational technology, and who is currently writing a book, Authentic Learning for the Digital Generation, said computer games could be used in the classroom to good effect provided it was done properly. Prof Mcfarlane said even pre-school children could benefit from games, as long as they were supervised and not just given a phone to play with to keep them quiet. Prof Mcfarlane said she herself had become 'hooked' on a computer game called Lemmings, pictured . She said some games could teach children fine motor control, or help with vocabulary or simple maths, and taught skills such as resilience that could be applicable to real life. Next month Prof Mcfarlane, who began her career as a secondary school teacher and head of department, will become chief executive and registrar of the College of Teachers, which offers professional training to teachers and support staff.
"""

summarize(model, article, decoding_method='greedy')

' The new book is part of her new book on Twitter . <unk> She says she will be happy with reading comprehension and leisure areas of thousands of her mother . <unk> The pair have been seen at her eyes , she says she will be able to be'