In [1]:
!pip install rouge-score



In [2]:
!pip install datasets
!pip install transformers



In [3]:
import os
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from datasets import load_dataset
from transformers import BartTokenizer, get_linear_schedule_with_warmup
from rouge_score import rouge_scorer
from tqdm import tqdm

In [4]:
max_src_len   = 512      
max_tgt_len   = 128      
batch_size    = 8
epochs        = 25
lr            = 0.0001
weight_decay  = 0.01
grad_clip     = 1.0
beam_size     = 4
warmup_ratio  = 0.1
checkpoint_dir = "checkpoint"

In [5]:
os.makedirs(checkpoint_dir, exist_ok=True)

In [6]:
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
vocab_size = tokenizer.vocab_size

In [7]:
def tokenize(batch):
    src = tokenizer(batch['text'], max_length=max_src_len, truncation=True, padding='max_length')
    tgt = tokenizer(batch['target'], max_length=max_tgt_len, truncation=True, padding='max_length')
    
    return {'input_ids': src.input_ids, 'attention_mask': src.attention_mask, 'labels': tgt.input_ids}

In [8]:
def get_loaders(dataset_path):
    splits = load_dataset(dataset_path, split={"train":"train","validation":"validation"})
    tokenized = splits.map(tokenize, batched=True)
    tokenized.set_format(type='torch', columns=['input_ids','attention_mask','labels'])
    train_loader = DataLoader(tokenized['train'], batch_size=batch_size, shuffle=True, num_workers=6, pin_memory=True,persistent_workers=True)
    valid_loader = DataLoader(tokenized['validation'], batch_size=batch_size, num_workers=5, pin_memory=True)
    
    return train_loader, valid_loader

In [9]:
train_loader, valid_loader = get_loaders("nlplabtdtu/xlsum_en")

In [10]:
!nvidia-smi

Sat Apr 19 20:28:17 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.107.02             Driver Version: 550.107.02     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        On  |   00000000:81:00.0 Off |                  Off |
|  0%   41C    P8             10W /  450W |       2MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [11]:
pad_token_id = tokenizer.pad_token_id

In [12]:
class TrainablePositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=1000):
        super().__init__()
        self.pos = nn.Parameter(torch.zeros(max_len, d_model))
    def forward(self, x):
        seq_len = x.size(0)
        return x + self.pos[:seq_len].unsqueeze(1)

In [13]:
class TransformerSummarizer(nn.Module):
    
    def __init__(self, vocab_size, d_model=512, nhead=8, enc_layers=3, dec_layers=3, dim_ff=2048, dropout=0.1):
        super().__init__()
        self.embed       = nn.Embedding(vocab_size, d_model)
        self.pos_enc     = TrainablePositionalEncoding(d_model)
        self.pos_dec     = TrainablePositionalEncoding(d_model)
        self.transformer = nn.Transformer(d_model, nhead, enc_layers, dec_layers, dim_ff, dropout, activation='gelu', batch_first=True)
        self.out_proj    = nn.Linear(d_model, vocab_size)
        
    def forward(self, src, tgt, src_attention_mask=None, tgt_attention_mask=None):

        pad_id = tokenizer.pad_token_id
        src_kpm = (src == pad_id)
        tgt_kpm = (tgt == pad_id)
        
        src_emb = self.pos_enc(self.embed(src) * math.sqrt(self.embed.embedding_dim))
        tgt_emb = self.pos_dec(self.embed(tgt) * math.sqrt(self.embed.embedding_dim))
        
        size = tgt.size(1)
        tgt_mask = torch.triu(torch.full((size, size), True,dtype=torch.bool), 1).to(device)
        
        out = self.transformer(src_emb, tgt_emb, tgt_mask=tgt_mask,
                               src_key_padding_mask=src_kpm,
                               tgt_key_padding_mask=tgt_kpm,
                               memory_key_padding_mask=src_kpm
                              )
        
        return self.out_proj(out)
        

In [14]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [15]:
model = TransformerSummarizer(vocab_size).to(device)

In [16]:
model

TransformerSummarizer(
  (embed): Embedding(50265, 512)
  (pos_enc): TrainablePositionalEncoding()
  (pos_dec): TrainablePositionalEncoding()
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-2): 3 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
          )
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (decoder): Transf

In [17]:
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

In [18]:
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id, reduction='none')

In [19]:
total_steps = len(train_loader) * epochs
warmup_steps = int(total_steps * warmup_ratio)
    
scheduler = get_linear_schedule_with_warmup(optimizer, 
                                            num_warmup_steps=warmup_steps,
                                            num_training_steps=total_steps
                                            )

In [20]:
print(f"Available GPUs: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
    print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

Available GPUs: 1
GPU 0: NVIDIA GeForce RTX 4090


In [21]:
def train_epoch(model, train_loader, optimizer, criterion, scheduler, device, vocab_size, pad_token_id, grad_clip=1.0):
    
    model.train()
    total_loss = 0
    
    progress_bar = tqdm(train_loader, desc="Training")
    
    for batch in progress_bar:
        
        src = batch['input_ids'].to(device, non_blocking=True)
        src_attention_mask = batch['attention_mask'].to(device, non_blocking=True)
        tgt = batch['labels'].to(device, non_blocking=True)
        
        tgt_inp, tgt_lbl = tgt[:, :-1], tgt[:, 1:]
        
        tgt_attention_mask = (tgt_inp != pad_token_id).to(device, non_blocking=True)
        
        loss_mask = (tgt_lbl != pad_token_id).float()
        
        optimizer.zero_grad()
        
        logits = model(
            src=src,
            tgt=tgt_inp,
            src_attention_mask=src_attention_mask,
            tgt_attention_mask=tgt_attention_mask
        )
        
        loss = criterion(logits.reshape(-1, vocab_size), tgt_lbl.reshape(-1))
        
        masked_loss = (loss * loss_mask.reshape(-1)).sum() / max(loss_mask.sum(), 1)
        
        masked_loss.backward()
        
        nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        
        optimizer.step()
        
        scheduler.step()
        
        current_loss = masked_loss.item()
        total_loss += current_loss
        progress_bar.set_postfix({"loss": f"{current_loss:.4f}"})
        
    average_loss = total_loss / len(train_loader)
    return average_loss

In [22]:
def evaluate(model, val_loader, criterion, device, vocab_size, pad_token_id, tokenizer):

    model.eval()
    
    total_loss = 0
    rouge_scores = {'rouge1': [], 'rouge2': [], 'rougeL': []}
    
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    
    progress_bar = tqdm(val_loader, desc="Validating")
    
    with torch.inference_mode():
        for batch in progress_bar:
            src = batch['input_ids'].to(device, non_blocking=True)
            src_attention_mask = batch['attention_mask'].to(device, non_blocking=True)
            tgt = batch['labels'].to(device, non_blocking=True)
            tgt_inp, tgt_lbl = tgt[:, :-1], tgt[:, 1:]
            
            tgt_attention_mask = (tgt_inp != pad_token_id).to(device, non_blocking=True)
            
            loss_mask = (tgt_lbl != pad_token_id).float()
            
            logits = model(
                src=src,
                tgt=tgt_inp,
                src_attention_mask=src_attention_mask,
                tgt_attention_mask=tgt_attention_mask
            )
            
            loss = criterion(logits.reshape(-1, vocab_size), tgt_lbl.reshape(-1))
            masked_loss = (loss * loss_mask.reshape(-1)).sum() / max(loss_mask.sum(), 1)
            total_loss += masked_loss.item()
            
    avg_loss = total_loss / len(val_loader)
    
    return avg_loss

In [23]:
def train(model, epochs, train_loader, val_loader, vocab_size, tokenizer, optimizer, criterion, scheduler, pad_token_id, grad_clip):
    
    model = model.to(device, non_blocking=True)
    
    best_val_loss = float('inf')
    
    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")
        
        train_loss = train_epoch(
            model=model,
            train_loader=train_loader,
            optimizer=optimizer,
            criterion=criterion,
            scheduler=scheduler,
            device=device,
            vocab_size=vocab_size,
            pad_token_id=pad_token_id,
            grad_clip=grad_clip
        )
        
        val_loss = evaluate(
            model=model,
            val_loader=val_loader,
            criterion=criterion,
            device=device,
            vocab_size=vocab_size,
            pad_token_id=pad_token_id,
            tokenizer=tokenizer
        )
        
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
        torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'train_loss': train_loss,
                    'val_loss': val_loss
                    }, f"{checkpoint_dir}/transformer_epoch_{epoch+1}.pt")
        print(f"Saved checkpoint for epoch {epoch+1}")

In [24]:
def beam_search_decode(model, src_ids, src_mask=None, max_len=100, beam_size=5, 
                       device='cpu', tokenizer=None, length_penalty=1.0, 
                       early_stopping=True):
    
    model.eval()
    batch_size = src_ids.size(0)
    
    bos_token_id = getattr(tokenizer, 'bos_token_id', tokenizer.cls_token_id)
    eos_token_id = getattr(tokenizer, 'eos_token_id', tokenizer.sep_token_id)
    
        
    if batch_size > 1:
        return [
            beam_search_decode(
                model, src_ids[i:i+1],
                None if src_mask is None else src_mask[i:i+1],
                max_len, beam_size, device, tokenizer, length_penalty, early_stopping
            )
            for i in range(batch_size)
        ]
    
    current_tokens = torch.full(
        (beam_size, 1), bos_token_id, dtype=torch.long, device=device
    )
    beam_scores = torch.zeros(beam_size, device=device)
    done_beams = [False] * beam_size
    
    expanded_src_ids  = src_ids.expand(beam_size, -1)
    expanded_src_mask = None if src_mask is None else src_mask.expand(beam_size, -1)
    
    for step in range(max_len - 1):
        tgt_mask = None
        if hasattr(model, 'generate_square_subsequent_mask'):
            tgt_mask = model.generate_square_subsequent_mask(
                current_tokens.size(1)
            ).to(device)

        with torch.inference_mode():   
            outputs = model(
                src=expanded_src_ids,
                tgt=current_tokens,
                src_attention_mask=expanded_src_mask,
                tgt_attention_mask=tgt_mask
            )
        
        next_token_logits = outputs[:, -1, :]
        next_token_logprobs = F.log_softmax(next_token_logits, dim=-1)
       
        vocab_size = next_token_logprobs.size(-1)
        expanded_scores = beam_scores.unsqueeze(1) + next_token_logprobs 
        flat_scores     = expanded_scores.view(-1)                      
        
        topk_scores, topk_indices = torch.topk(
            flat_scores, k=min(2 * beam_size, flat_scores.size(0))
        )
        beam_ix = topk_indices // vocab_size
        token_ix = topk_indices % vocab_size
        
        candidates = []
        for b, tok, sc in zip(beam_ix.tolist(), token_ix.tolist(), topk_scores.tolist()):
            if done_beams[b]:
                continue
            new_tokens = torch.cat([
                current_tokens[b],
                torch.tensor([tok], dtype=torch.long, device=device)
            ], dim=0)
            candidates.append({
                'tokens': new_tokens,
                'score': sc,
                'is_done': (tok == eos_token_id)
            })
            if len(candidates) >= beam_size:
                break
        
        if all(done_beams) and early_stopping:
            break
        
        while len(candidates) < beam_size:
            candidates.append(candidates[0])
        
        current_tokens = torch.stack([c['tokens'] for c in candidates])
        beam_scores    = torch.tensor([c['score'] for c in candidates], device=device)
        done_beams     = [c['is_done'] for c in candidates]
        
        if all(done_beams) and early_stopping:
            break
    
    seq_lens = current_tokens.size(1)
    adjusted_scores = beam_scores / (seq_lens ** length_penalty)
    best_idx  = adjusted_scores.argmax().item()
    best_tokens = current_tokens[best_idx].tolist()
    
    return tokenizer.decode(best_tokens, skip_special_tokens=True)

In [25]:
def greedy_decode(model, src, src_mask=None, max_len=100, device=None, tokenizer=None):
    
    model.eval()
    if device is None:
        device = src.device

    bos_token_id = getattr(tokenizer, 'bos_token_id', tokenizer.cls_token_id)
    eos_token_id = getattr(tokenizer, 'eos_token_id', tokenizer.sep_token_id)

    decoder_input = torch.tensor([[bos_token_id]], device=device)
    
    with torch.inference_mode():
        for _ in range(max_len - 1):
            logits = model(
                src=src,
                tgt=decoder_input,
                src_attention_mask=src_mask,
                tgt_attention_mask=None
            )
            next_logits = logits[:, -1, :]
            next_token  = next_logits.argmax(dim=-1, keepdim=True)
            decoder_input = torch.cat([decoder_input, next_token], dim=1)
            if next_token.item() == eos_token_id:
                break
    
    return decoder_input

In [26]:
def generate_summary(model, src_ids, src_mask=None, max_len=100, 
                     method="beam_search", beam_size=5, device=None):
                     
    if method == "greedy":
        gen_ids = greedy_decode(
            model, src_ids, src_mask, max_len, device, tokenizer
        )
        return tokenizer.decode(gen_ids[0], skip_special_tokens=True)

    elif method == "beam_search":
        return beam_search_decode(
            model, src_ids, src_mask, max_len, beam_size, device, tokenizer
        )

    else:
        raise ValueError(f"Unknown decoding method: {method}")

In [27]:
def load_checkpoint(checkpoint_path, model, optimizer=None, scheduler=None):
    
    try:
        print(f"Loading checkpoint from {checkpoint_path}")
        
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        
        model.load_state_dict(checkpoint['model_state_dict'])
        print("Model weights loaded successfully")
        
        if optimizer is not None and 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            print("Optimizer state restored")
        
        if scheduler is not None and 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            print("Scheduler state restored")
        
        metrics = {
            'train_loss': checkpoint.get('train_loss'),
            'val_loss': checkpoint.get('val_loss'),
            'rouge_scores': checkpoint.get('rouge_scores')
        }
        
        epoch = checkpoint.get('epoch', -1) + 1  # +1 because we want to start from the next epoch
        
        print(f"Checkpoint loaded from epoch {epoch}")
        if 'val_loss' in checkpoint:
            print(f"Validation loss: {checkpoint['val_loss']:.4f}")
        
        return model, optimizer, scheduler, epoch, metrics
        
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        raise e

In [28]:
import warnings
warnings.filterwarnings("ignore", message="The PyTorch API of nested tensors is in prototype stage")

In [29]:
train(model, epochs, train_loader, valid_loader, vocab_size, tokenizer, optimizer, criterion, scheduler, pad_token_id, grad_clip)


Epoch 1/25


Training: 100%|██████████| 38316/38316 [23:40<00:00, 26.97it/s, loss=4.1639]
Validating: 100%|██████████| 1442/1442 [00:20<00:00, 69.31it/s]


Train Loss: 6.1565 | Val Loss: 4.8470
Saved checkpoint for epoch 1

Epoch 2/25


Training: 100%|██████████| 38316/38316 [23:39<00:00, 27.00it/s, loss=3.7121]
Validating: 100%|██████████| 1442/1442 [00:20<00:00, 69.38it/s]


Train Loss: 4.4710 | Val Loss: 3.9820
Saved checkpoint for epoch 2

Epoch 3/25


Training: 100%|██████████| 38316/38316 [23:17<00:00, 27.41it/s, loss=3.0504]
Validating: 100%|██████████| 1442/1442 [00:20<00:00, 69.22it/s]


Train Loss: 3.8809 | Val Loss: 3.6080
Saved checkpoint for epoch 3

Epoch 4/25


Training: 100%|██████████| 38316/38316 [23:27<00:00, 27.23it/s, loss=3.1621]
Validating: 100%|██████████| 1442/1442 [00:20<00:00, 69.21it/s]


Train Loss: 3.5465 | Val Loss: 3.4076
Saved checkpoint for epoch 4

Epoch 5/25


Training: 100%|██████████| 38316/38316 [23:28<00:00, 27.21it/s, loss=3.9386]
Validating: 100%|██████████| 1442/1442 [00:20<00:00, 69.21it/s]


Train Loss: 3.3396 | Val Loss: 3.2916
Saved checkpoint for epoch 5

Epoch 6/25


Training: 100%|██████████| 38316/38316 [23:36<00:00, 27.05it/s, loss=4.9369]
Validating: 100%|██████████| 1442/1442 [00:20<00:00, 69.62it/s]


Train Loss: 3.1951 | Val Loss: 3.2235
Saved checkpoint for epoch 6

Epoch 7/25


Training: 100%|██████████| 38316/38316 [23:24<00:00, 27.28it/s, loss=1.6823]
Validating: 100%|██████████| 1442/1442 [00:20<00:00, 69.14it/s]


Train Loss: 3.0816 | Val Loss: 3.1709
Saved checkpoint for epoch 7

Epoch 8/25


Training: 100%|██████████| 38316/38316 [23:33<00:00, 27.10it/s, loss=3.8146]
Validating: 100%|██████████| 1442/1442 [00:20<00:00, 69.43it/s]


Train Loss: 2.9858 | Val Loss: 3.1401
Saved checkpoint for epoch 8

Epoch 9/25


Training: 100%|██████████| 38316/38316 [23:40<00:00, 26.98it/s, loss=1.9707]
Validating: 100%|██████████| 1442/1442 [00:20<00:00, 69.46it/s]


Train Loss: 2.9049 | Val Loss: 3.1094
Saved checkpoint for epoch 9

Epoch 10/25


Training: 100%|██████████| 38316/38316 [24:13<00:00, 26.35it/s, loss=1.8389]
Validating: 100%|██████████| 1442/1442 [00:20<00:00, 69.85it/s]


Train Loss: 2.8303 | Val Loss: 3.0883
Saved checkpoint for epoch 10

Epoch 11/25


Training: 100%|██████████| 38316/38316 [24:27<00:00, 26.11it/s, loss=2.9861]
Validating: 100%|██████████| 1442/1442 [00:20<00:00, 70.23it/s]


Train Loss: 2.7626 | Val Loss: 3.0714
Saved checkpoint for epoch 11

Epoch 12/25


Training: 100%|██████████| 38316/38316 [23:53<00:00, 26.73it/s, loss=2.2004]
Validating: 100%|██████████| 1442/1442 [00:20<00:00, 70.58it/s]


Train Loss: 2.6987 | Val Loss: 3.0611
Saved checkpoint for epoch 12

Epoch 13/25


Training: 100%|██████████| 38316/38316 [23:38<00:00, 27.01it/s, loss=1.8714]
Validating: 100%|██████████| 1442/1442 [00:20<00:00, 69.14it/s]


Train Loss: 2.6389 | Val Loss: 3.0563
Saved checkpoint for epoch 13

Epoch 14/25


Training: 100%|██████████| 38316/38316 [23:29<00:00, 27.18it/s, loss=1.6905]
Validating: 100%|██████████| 1442/1442 [00:20<00:00, 69.36it/s]


Train Loss: 2.5823 | Val Loss: 3.0514
Saved checkpoint for epoch 14

Epoch 15/25


Training: 100%|██████████| 38316/38316 [23:32<00:00, 27.12it/s, loss=2.5579]
Validating: 100%|██████████| 1442/1442 [00:20<00:00, 69.19it/s]


Train Loss: 2.5269 | Val Loss: 3.0493
Saved checkpoint for epoch 15

Epoch 16/25


Training: 100%|██████████| 38316/38316 [23:42<00:00, 26.93it/s, loss=0.8846]
Validating: 100%|██████████| 1442/1442 [00:20<00:00, 69.74it/s]


Train Loss: 2.4739 | Val Loss: 3.0459
Saved checkpoint for epoch 16

Epoch 17/25


Training: 100%|██████████| 38316/38316 [23:39<00:00, 26.99it/s, loss=2.2711]
Validating: 100%|██████████| 1442/1442 [00:20<00:00, 69.45it/s]


Train Loss: 2.4232 | Val Loss: 3.0472
Saved checkpoint for epoch 17

Epoch 18/25


Training: 100%|██████████| 38316/38316 [23:44<00:00, 26.89it/s, loss=2.7493]
Validating: 100%|██████████| 1442/1442 [00:20<00:00, 69.39it/s]


Train Loss: 2.3729 | Val Loss: 3.0487
Saved checkpoint for epoch 18

Epoch 19/25


Training: 100%|██████████| 38316/38316 [23:51<00:00, 26.77it/s, loss=2.3175]
Validating: 100%|██████████| 1442/1442 [00:20<00:00, 69.87it/s]


Train Loss: 2.3248 | Val Loss: 3.0540
Saved checkpoint for epoch 19

Epoch 20/25


Training: 100%|██████████| 38316/38316 [24:03<00:00, 26.55it/s, loss=1.3998]
Validating: 100%|██████████| 1442/1442 [00:20<00:00, 68.84it/s]


Train Loss: 2.2777 | Val Loss: 3.0587
Saved checkpoint for epoch 20

Epoch 21/25


Training: 100%|██████████| 38316/38316 [24:02<00:00, 26.56it/s, loss=1.8405]
Validating: 100%|██████████| 1442/1442 [00:20<00:00, 69.42it/s]


Train Loss: 2.2322 | Val Loss: 3.0600
Saved checkpoint for epoch 21

Epoch 22/25


Training: 100%|██████████| 38316/38316 [23:49<00:00, 26.80it/s, loss=0.5197]
Validating: 100%|██████████| 1442/1442 [00:20<00:00, 69.85it/s]


Train Loss: 2.1881 | Val Loss: 3.0662
Saved checkpoint for epoch 22

Epoch 23/25


Training: 100%|██████████| 38316/38316 [24:12<00:00, 26.38it/s, loss=2.2824]
Validating: 100%|██████████| 1442/1442 [00:21<00:00, 68.57it/s]


Train Loss: 2.1457 | Val Loss: 3.0682
Saved checkpoint for epoch 23

Epoch 24/25


Training: 100%|██████████| 38316/38316 [23:55<00:00, 26.69it/s, loss=3.1005]
Validating: 100%|██████████| 1442/1442 [00:20<00:00, 69.61it/s]


Train Loss: 2.1064 | Val Loss: 3.0701
Saved checkpoint for epoch 24

Epoch 25/25


Training: 100%|██████████| 38316/38316 [23:57<00:00, 26.66it/s, loss=2.2588]
Validating: 100%|██████████| 1442/1442 [00:20<00:00, 70.04it/s]


Train Loss: 2.0708 | Val Loss: 3.0722
Saved checkpoint for epoch 25


In [37]:
def evaluate_model(model, val_loader, criterion, device, vocab_size, pad_token_id, tokenizer):
    model.eval()
    
    total_loss = 0
    rouge_scores = {'rouge1': [], 'rouge2': [], 'rougeL': []}
    full_rouge_scores = []

    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    
    progress_bar = tqdm(val_loader, desc="Validating")
    
    with torch.inference_mode():
        for batch in progress_bar:
            src = batch['input_ids'].to(device, non_blocking=True)
            src_attention_mask = batch['attention_mask'].to(device, non_blocking=True)
            tgt = batch['labels'].to(device, non_blocking=True)
            tgt_inp, tgt_lbl = tgt[:, :-1], tgt[:, 1:]
            
            tgt_attention_mask = (tgt_inp != pad_token_id).to(device, non_blocking=True)
            loss_mask = (tgt_lbl != pad_token_id).float()
            
            logits = model(
                src=src,
                tgt=tgt_inp,
                src_attention_mask=src_attention_mask,
                tgt_attention_mask=tgt_attention_mask
            )
            
            loss = criterion(logits.reshape(-1, vocab_size), tgt_lbl.reshape(-1))
            masked_loss = (loss * loss_mask.reshape(-1)).sum() / max(loss_mask.sum(), 1)
            total_loss += masked_loss.item()
            
            batch_size = src.size(0)
            for i in range(batch_size):
                summary = generate_summary(
                    model=model,
                    src_ids=src[i:i+1],
                    src_mask=src_attention_mask[i:i+1],
                    max_len=100,
                    method="beam_search",
                    beam_size=5,
                    device=device
                )
                
                reference = tokenizer.decode(
                    [t for t in tgt[i].tolist() if t != pad_token_id], 
                    skip_special_tokens=True
                )
                
                scores = scorer.score(reference, summary)
                full_rouge_scores.append({
                    'reference': reference,
                    'summary': summary,
                    'scores': scores
                })
                
                for k in rouge_scores:
                    rouge_scores[k].append(scores[k].fmeasure)
                
    avg_loss = total_loss / len(val_loader)
    avg_rouge = {k: sum(v)/max(len(v), 1) for k, v in rouge_scores.items()}
    
    return avg_loss, avg_rouge, full_rouge_scores


In [52]:
def evaluate_checkpoints(checkpoint_paths, model_class, model_kwargs, val_loader, 
                         criterion, device, vocab_size, pad_token_id, tokenizer):
    results = []

    for ckpt_path in checkpoint_paths:
        print(f"\n--- Evaluating {ckpt_path} ---")
        
        model = model_class(**model_kwargs).to(device)
        model, _, _, epoch, metrics = load_checkpoint(ckpt_path, model)

        val_loss, avg_rouge, full_rouge_scores = evaluate_model(
            model=model,
            val_loader=val_loader,
            criterion=criterion,
            device=device,
            vocab_size=vocab_size,
            pad_token_id=pad_token_id,
            tokenizer=tokenizer
        )

        print(f"[Epoch {epoch:02d}] ValLoss: {val_loss:.4f} | "
              f"ROUGE-1: {avg_rouge['rouge1']:.4f} | "
              f"ROUGE-2: {avg_rouge['rouge2']:.4f} | "
              f"ROUGE-L: {avg_rouge['rougeL']:.4f}")

        results.append({
            "checkpoint": ckpt_path,
            "epoch": epoch,
            "val_loss": val_loss,
            "rouge": avg_rouge,
            "full_rouge_scores": full_rouge_scores
        })

    return results


In [54]:
def beam_search_decode(model, src_ids, src_mask=None, max_len=100, beam_size=5, 
                       device='cpu', tokenizer=None, length_penalty=1.0, 
                       early_stopping=True):
    
    model.eval()
    batch_size = src_ids.size(0)

    bos_token_id = getattr(tokenizer, 'bos_token_id', tokenizer.cls_token_id)
    eos_token_id = getattr(tokenizer, 'eos_token_id', tokenizer.sep_token_id)
    
    if batch_size > 1:
        return [
            beam_search_decode(
                model, src_ids[i:i+1],
                None if src_mask is None else src_mask[i:i+1],
                max_len, beam_size, device, tokenizer, length_penalty, early_stopping
            )
            for i in range(batch_size)
        ]

    current_tokens = torch.full((beam_size, 1), bos_token_id, dtype=torch.long, device=device)
    beam_scores = torch.zeros(beam_size, device=device)
    done_beams = [False] * beam_size

    expanded_src_ids = src_ids.expand(beam_size, -1)
    expanded_src_mask = None if src_mask is None else src_mask.expand(beam_size, -1)

    for step in range(max_len - 1):
        tgt_mask = None
        if hasattr(model, 'generate_square_subsequent_mask'):
            tgt_mask = model.generate_square_subsequent_mask(
                current_tokens.size(1)
            ).to(device)

        with torch.inference_mode():
            outputs = model(
                src=expanded_src_ids,
                tgt=current_tokens,
                src_attention_mask=expanded_src_mask,
                tgt_attention_mask=tgt_mask
            )

        next_token_logits = outputs[:, -1, :]
        next_token_logprobs = F.log_softmax(next_token_logits, dim=-1)

        vocab_size = next_token_logprobs.size(-1)
        expanded_scores = beam_scores.unsqueeze(1) + next_token_logprobs
        flat_scores = expanded_scores.view(-1)

        topk_scores, topk_indices = torch.topk(
            flat_scores, k=min(2 * beam_size, flat_scores.size(0))
        )
        beam_ix = topk_indices // vocab_size
        token_ix = topk_indices % vocab_size

        candidates = []
        for b, tok, sc in zip(beam_ix.tolist(), token_ix.tolist(), topk_scores.tolist()):
            new_tokens = torch.cat([
                current_tokens[b],
                torch.tensor([tok], dtype=torch.long, device=device)
            ], dim=0)
            candidates.append({
                'tokens': new_tokens,
                'score': sc,
                'is_done': done_beams[b] or (tok == eos_token_id)
            })
            if len(candidates) >= beam_size:
                break

        if len(candidates) == 0:
            return tokenizer.decode([bos_token_id], skip_special_tokens=True)

        while len(candidates) < beam_size:
            candidates.append(candidates[0])

        current_tokens = torch.stack([c['tokens'] for c in candidates])
        beam_scores = torch.tensor([c['score'] for c in candidates], device=device)
        done_beams = [c['is_done'] for c in candidates]

        if all(done_beams) and early_stopping:
            break

    seq_lens = current_tokens.size(1)
    adjusted_scores = beam_scores / (seq_lens ** length_penalty)
    best_idx = adjusted_scores.argmax().item()
    best_tokens = current_tokens[best_idx].tolist()

    return tokenizer.decode(best_tokens, skip_special_tokens=True)


In [62]:
def generate_random_sample_summary(model, val_loader, tokenizer, device, 
                                   max_len=100, method='beam_search', beam_size=5):
    model.eval()
    
    data_iter = iter(val_loader)
    batch = None
    for _ in range(random.randint(0, len(val_loader)-1)):
        batch = next(data_iter)
        
    if batch is None:
        print("Validation loader is empty or failed to sample.")
        return

    src = batch['input_ids'].to(device)
    tgt = batch['labels'].to(device)
    
    src = src.to(device)
    tgt = tgt.to(device)

    i = random.randint(0, src.size(0) - 1)
    src_sample = src[i:i+1]
    tgt_sample = tgt[i:i+1]

    src_mask = (src_sample == tokenizer.pad_token_id)

    summary = generate_summary(
        model=model,
        src_ids=src_sample,
        src_mask=src_mask,
        max_len=max_len,
        method=method,
        beam_size=beam_size,
        device=device
    )

    input_text = tokenizer.decode(
        [t for t in src_sample[0].tolist() if t != tokenizer.pad_token_id],
        skip_special_tokens=True
    )
    reference_summary = tokenizer.decode(
        [t for t in tgt_sample[0].tolist() if t != tokenizer.pad_token_id],
        skip_special_tokens=True
    )

    print("\n📘 Input Document:")
    print(input_text)
    print("\n✅ Reference Summary:")
    print(reference_summary)
    print("\n📝 Generated Summary:")
    print(summary)


In [None]:
model = TransformerSummarizer(
    vocab_size=tokenizer.vocab_size,
    d_model=512,
    nhead=8,
    enc_layers=3,
    dec_layers=3,
    dim_ff=2048,
    dropout=0.1
).to(device)

In [65]:
model, _, _, epoch, _ = load_checkpoint("checkpoint/transformer_epoch_16.pt", model)
print(f"Model loaded from epoch {epoch}")

Loading checkpoint from checkpoint/transformer_epoch_16.pt
Model weights loaded successfully
Checkpoint loaded from epoch 16
Validation loss: 3.0459
Model loaded from epoch 16


In [66]:
generate_random_sample_summary(
    model=model,
    val_loader=valid_loader,
    tokenizer=tokenizer,
    device=device,
    max_len=100,
    method="beam_search",
    beam_size=5
)


📘 Input Document:
The study by Family and Childcare Trust and Children in Scotland said there was an 80% variation in the cost of nursery care and a 92% variation for over-5s. It indicated that while the price of nursery care appeared to have stabilised, out of school care costs had increased.  The Scottish Childcare Report covered the year from December 2011 to 2012. As well as sharp variations in costs, it also found that about 40% of local authorities did not know if they had sufficient childcare for working parents, making childcare a postcode lottery for parents.  Of the councils that had some knowledge about the supply of childcare in their local area, there was a particular shortage of childcare for older children and disabled children, it claimed.    The report said more than half of all families in Scotland had used grandparents for childcare purposes in the past six months - the highest proportion in the UK. Children in Scotland Policy Officer Jim Stephen said: We welcome th

In [68]:
generate_random_sample_summary(
    model=model,
    val_loader=valid_loader,
    tokenizer=tokenizer,
    device=device,
    max_len=100,
    method="greedy",
    beam_size=5
)


📘 Input Document:
Jarring photos of facilities in the Rio Grande show 51 female migrants held in a cell made for 40 men, and 71 males held in a cell built for 41 women.  Adults were packed in standing room only cells for a week, with others held in overcrowded cells for over a month.  One facility manager called the situation a ticking time bomb.  We are concerned that overcrowding and prolonged detention represent an immediate risk to the health and safety of [Department of Homeland Security] agents and officers, and to those detained, inspectors said in the report.  The inspectors, from the US inspector general, visited seven sites throughout the Rio Grande valley in southern Texas.  At the facilities, the inspectors found that 30% of the detained children had been held for longer than the 72 hours permitted. Some had no access to showers or hot meals and had little access to clean clothes.  When detainees observed us, they banged on the cell windows, shouted, pressed notes to the w

In [69]:
def generate_summary_from_text(model, tokenizer, device, src_text, max_len=100, method='beam_search', beam_size=5):
    model.eval()
    
    src = tokenizer.encode(src_text, return_tensors="pt").to(device)
    
    src_mask = (src == tokenizer.pad_token_id)

    summary = generate_summary(
        model=model,
        src_ids=src,
        src_mask=src_mask,
        max_len=max_len,
        method=method,
        beam_size=beam_size,
        device=device
    )

    input_text = tokenizer.decode(src[0], skip_special_tokens=True)
    print("\n📘 Input Document:")
    print(input_text)
    print("\n📝 Generated Summary:")
    print(summary)


In [70]:
src_text = "XYZ University is widely recognized as one of the nation’s top-tier institutions for advanced research, education, and innovation in computer science, engineering, and interdisciplinary studies. Each year, the university admits a select cohort of only 30 exceptionally talented and driven students into its prestigious research fellowship program. This limited intake ensures high-quality mentorship and individual attention, making the selection process intensely competitive. These fellows are guided by globally recognized faculty who are pioneers in their respective domains. Faculty members at XYZ University are deeply involved in groundbreaking research across a wide range of fields, including Artificial Intelligence, Machine Learning, Quantum Computing, Computational Neuroscience, Human-Computer Interaction, Software Verification, Natural Language Processing, and Computer Vision."


📘 Input Document:
The field of artificial intelligence has seen rapid advancements in recent years. Researchers are exploring new ways to improve machine learning algorithms, particularly in natural language processing.

📝 Generated Summary:
A team of scientists has been created to help scientists in Jersey to improve the language of the language.  
