In [1]:
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from datasets import load_dataset
from transformers import BartTokenizer, get_linear_schedule_with_warmup
from rouge_score import rouge_scorer
from tqdm import tqdm

In [2]:
# Hyperparameters for this run. The code that consumes them is outside this
# cell, so the per-line notes are inferred from the names — confirm at use site.
max_src_len   = 512      # presumably the source-side token limit
max_tgt_len   = 128      # presumably the target-side (summary) token limit
batch_size    = 8
epochs        = 25
lr            = 0.0001   # presumably the optimizer's peak learning rate
weight_decay  = 0.01
grad_clip     = 1.0      # presumably a gradient-norm clipping threshold — TODO confirm
beam_size     = 4        # presumably the beam width used when decoding
warmup_ratio  = 0.1      # presumably the fraction of steps spent on LR warmup

In [3]:
# Load the pretrained BART-base tokenizer (vocabulary + special tokens).
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
# Vocabulary size as reported by the tokenizer; passed to the model constructor
# to size its embedding table and output projection.
# NOTE(review): `vocab_size` may not count tokens added after loading, whereas
# `len(tokenizer)` would — confirm no tokens are added later.
vocab_size = tokenizer.vocab_size

In [4]:
class TrainablePositionalEncoding(nn.Module):
    """Learned absolute positional embeddings added to token embeddings.

    One ``d_model``-wide vector is learned per position and added to the
    input at the matching sequence index, broadcast over the batch.

    Args:
        d_model: Embedding width; must equal the last dimension of the input.
        max_len: Number of positions for which an embedding is learned.
        batch_first: If True (default), inputs are ``(batch, seq, d_model)``;
            if False, inputs are ``(seq, batch, d_model)``.
    """

    def __init__(self, d_model, max_len=1000, batch_first=True):
        super().__init__()
        self.batch_first = batch_first
        # Zero-initialised, so the module is an identity until it is trained.
        self.pos = nn.Parameter(torch.zeros(max_len, d_model))

    def forward(self, x):
        """Return ``x`` plus the positional embedding of each sequence index.

        Raises:
            ValueError: If the sequence is longer than ``max_len``.
        """
        seq_dim = 1 if self.batch_first else 0
        seq_len = x.size(seq_dim)
        max_len = self.pos.size(0)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {max_len}"
            )
        # Slice the first `seq_len` positions and insert a size-1 axis where the
        # batch dimension sits so the addition broadcasts over the batch.
        return x + self.pos[:seq_len].unsqueeze(1 - seq_dim)

In [5]:
class TransformerSummarizer(nn.Module):
    """Encoder-decoder Transformer with a shared source/target embedding table.

    Token ids are embedded, scaled by ``sqrt(d_model)``, given learned
    positional embeddings (separate tables for encoder and decoder), run
    through a batch-first ``nn.Transformer`` and projected to vocabulary logits.

    Args:
        vocab_size: Size of the embedding table and of the output logits.
        d_model: Model / embedding width.
        nhead: Number of attention heads.
        enc_layers: Number of encoder layers.
        dec_layers: Number of decoder layers.
        dim_ff: Hidden width of the feed-forward sublayers.
        dropout: Dropout probability inside the Transformer.
        pad_token_id: Token id treated as padding when building key-padding
            masks. If None, it is read once from the module-level
            ``tokenizer`` at construction time.
    """

    def __init__(self, vocab_size, d_model=512, nhead=8, enc_layers=3, dec_layers=3, dim_ff=2048, dropout=0.1, pad_token_id=None):
        super().__init__()
        # Resolve the padding id once, here, so forward() does not have to look
        # up a notebook global on every call.
        self.pad_token_id = tokenizer.pad_token_id if pad_token_id is None else pad_token_id
        self.embed       = nn.Embedding(vocab_size, d_model)
        self.pos_enc     = TrainablePositionalEncoding(d_model)
        self.pos_dec     = TrainablePositionalEncoding(d_model)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=enc_layers,
            num_decoder_layers=dec_layers,
            dim_feedforward=dim_ff,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
        )
        self.out_proj    = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt, src_attention_mask=None, tgt_attention_mask=None):
        """Compute vocabulary logits for each target position.

        Args:
            src: Source token ids, ``(batch, src_len)``.
            tgt: Decoder input token ids, ``(batch, tgt_len)``.
            src_attention_mask: Accepted for call-site compatibility; not used.
                Padding is derived from ``pad_token_id`` instead.
            tgt_attention_mask: Accepted for call-site compatibility; not used.

        Returns:
            Logits of shape ``(batch, tgt_len, vocab_size)``.
        """
        # True marks padded positions that attention must ignore.
        src_kpm = src.eq(self.pad_token_id)
        tgt_kpm = tgt.eq(self.pad_token_id)

        scale = math.sqrt(self.embed.embedding_dim)
        src_emb = self.pos_enc(self.embed(src) * scale)
        tgt_emb = self.pos_dec(self.embed(tgt) * scale)

        # Causal mask: True strictly above the diagonal blocks attention to
        # future positions. Built directly on the input's device so the module
        # works wherever its inputs live, without a global `device`.
        tgt_len = tgt.size(1)
        causal_mask = torch.triu(
            torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device),
            diagonal=1,
        )

        out = self.transformer(
            src_emb,
            tgt_emb,
            tgt_mask=causal_mask,
            src_key_padding_mask=src_kpm,
            tgt_key_padding_mask=tgt_kpm,
            memory_key_padding_mask=src_kpm,
        )

        return self.out_proj(out)

In [6]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
# Build the summarizer with its default architecture settings, sized to the
# tokenizer vocabulary, and move its parameters to `device`.
model = TransformerSummarizer(vocab_size).to(device)

In [8]:
def count_trainable_params(model):
    """Return the total element count of all parameters that require gradients."""
    total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad=False) are not counted.
        if param.requires_grad:
            total += param.numel()
    return total

In [9]:
# Report the model size, formatted with thousands separators.
print(f"Trainable parameters: {count_trainable_params(model):,}")

Trainable parameters: 74,616,921
