In [1]:
from tokenizers import ByteLevelBPETokenizer
from datasets import load_dataset
from pathlib import Path

In [2]:
dataset = load_dataset("nlplabtdtu/xlsum_en")

In [3]:
def get_corpus(dataset):
    for split in ['train', 'validation', 'test']:
        for example in dataset[split]:
            yield example['text'] + ' ' + example['target']

In [4]:
corpus = list(get_corpus(dataset))

In [5]:
tokenizer = ByteLevelBPETokenizer()

In [6]:
tokenizer.train_from_iterator(
    corpus,                     
    min_frequency=2,       
    special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
)






In [7]:
tokenizer.get_vocab_size()

30000

In [8]:
output_dir = "my_tokenizer"
Path(output_dir).mkdir(parents=True, exist_ok=True)  # Create directory if it doesn't exist
tokenizer.save_model(output_dir)

['my_tokenizer/vocab.json', 'my_tokenizer/merges.txt']

In [9]:
from transformers import BartTokenizerFast

tokenizer = BartTokenizerFast(
    vocab_file=f"{output_dir}/vocab.json",
    merges_file=f"{output_dir}/merges.txt",
    bos_token="<s>",
    eos_token="</s>",
    unk_token="<unk>",
    pad_token="<pad>",
    mask_token="<mask>"
)

In [10]:
import os
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from datasets import load_dataset
from transformers import BartTokenizer, get_linear_schedule_with_warmup
from rouge_score import rouge_scorer
from tqdm import tqdm

In [11]:
max_src_len   = 512      
max_tgt_len   = 128      
batch_size    = 8
epochs        = 10
lr            = 0.0001
weight_decay  = 0.01
grad_clip     = 1.0
beam_size     = 4
warmup_ratio  = 0.1
checkpoint_dir = "checkpoint"

In [12]:
vocab_size = tokenizer.vocab_size

In [13]:
def tokenize(batch):
    src = tokenizer(batch['text'], max_length=max_src_len, truncation=True, padding='max_length')
    tgt = tokenizer(batch['target'], max_length=max_tgt_len, truncation=True, padding='max_length')
    
    return {'input_ids': src.input_ids, 'attention_mask': src.attention_mask, 'labels': tgt.input_ids}

In [14]:
def get_loaders(dataset_path):
    splits = load_dataset(dataset_path, split={"train":"train","validation":"validation"})
    tokenized = splits.map(tokenize, batched=True)
    tokenized.set_format(type='torch', columns=['input_ids','attention_mask','labels'])
    train_loader = DataLoader(tokenized['train'], batch_size=batch_size, shuffle=True, num_workers=6, pin_memory=True,persistent_workers=True)
    valid_loader = DataLoader(tokenized['validation'], batch_size=batch_size, num_workers=5, pin_memory=True)
    
    return train_loader, valid_loader

In [15]:
train_loader, valid_loader = get_loaders("nlplabtdtu/xlsum_en")

In [16]:
pad_token_id = tokenizer.pad_token_id

In [17]:
class TrainablePositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=1000):
        super().__init__()
        self.pos = nn.Parameter(torch.zeros(max_len, d_model))
    def forward(self, x):
        seq_len = x.size(0)
        return x + self.pos[:seq_len].unsqueeze(1)

In [18]:
class TransformerSummarizer(nn.Module):
    
    def __init__(self, vocab_size, d_model=512, nhead=8, enc_layers=3, dec_layers=3, dim_ff=2048, dropout=0.1):
        super().__init__()
        self.embed       = nn.Embedding(vocab_size, d_model)
        self.pos_enc     = TrainablePositionalEncoding(d_model)
        self.pos_dec     = TrainablePositionalEncoding(d_model)
        self.transformer = nn.Transformer(d_model, nhead, enc_layers, dec_layers, dim_ff, dropout, activation='gelu', batch_first=True)
        self.out_proj    = nn.Linear(d_model, vocab_size)
        
    def forward(self, src, tgt, src_attention_mask=None, tgt_attention_mask=None):

        pad_id = tokenizer.pad_token_id
        src_kpm = (src == pad_id)
        tgt_kpm = (tgt == pad_id)
        
        src_emb = self.pos_enc(self.embed(src) * math.sqrt(self.embed.embedding_dim))
        tgt_emb = self.pos_dec(self.embed(tgt) * math.sqrt(self.embed.embedding_dim))
        
        size = tgt.size(1)
        tgt_mask = torch.triu(torch.full((size, size), True,dtype=torch.bool), 1).to(device)
        
        out = self.transformer(src_emb, tgt_emb, tgt_mask=tgt_mask,
                               src_key_padding_mask=src_kpm,
                               tgt_key_padding_mask=tgt_kpm,
                               memory_key_padding_mask=src_kpm
                              )
        
        return self.out_proj(out)

In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [20]:
model = TransformerSummarizer(vocab_size).to(device)

In [21]:
model

TransformerSummarizer(
  (embed): Embedding(30000, 512)
  (pos_enc): TrainablePositionalEncoding()
  (pos_dec): TrainablePositionalEncoding()
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-2): 3 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
          )
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (decoder): Transf

In [22]:
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

In [23]:
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id, reduction='none')

In [24]:
total_steps = len(train_loader) * epochs
warmup_steps = int(total_steps * warmup_ratio)
    
scheduler = get_linear_schedule_with_warmup(optimizer, 
                                            num_warmup_steps=warmup_steps,
                                            num_training_steps=total_steps
                                            )

In [25]:
print(f"Available GPUs: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
    print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

Available GPUs: 1
GPU 0: NVIDIA GeForce RTX 4090


In [26]:
def train_epoch(model, train_loader, optimizer, criterion, scheduler, device, vocab_size, pad_token_id, grad_clip=1.0):
    
    model.train()
    total_loss = 0
    
    progress_bar = tqdm(train_loader, desc="Training")
    
    for batch in progress_bar:
        
        src = batch['input_ids'].to(device, non_blocking=True)
        src_attention_mask = batch['attention_mask'].to(device, non_blocking=True)
        tgt = batch['labels'].to(device, non_blocking=True)
        
        tgt_inp, tgt_lbl = tgt[:, :-1], tgt[:, 1:]
        
        tgt_attention_mask = (tgt_inp != pad_token_id).to(device, non_blocking=True)
        
        loss_mask = (tgt_lbl != pad_token_id).float()
        
        optimizer.zero_grad()
        
        logits = model(
            src=src,
            tgt=tgt_inp,
            src_attention_mask=src_attention_mask,
            tgt_attention_mask=tgt_attention_mask
        )
        
        loss = criterion(logits.reshape(-1, vocab_size), tgt_lbl.reshape(-1))
        
        masked_loss = (loss * loss_mask.reshape(-1)).sum() / max(loss_mask.sum(), 1)
        
        masked_loss.backward()
        
        nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        
        optimizer.step()
        
        scheduler.step()
        
        current_loss = masked_loss.item()
        total_loss += current_loss
        progress_bar.set_postfix({"loss": f"{current_loss:.4f}"})
        
    average_loss = total_loss / len(train_loader)
    return average_loss

In [27]:
def evaluate(model, val_loader, criterion, device, vocab_size, pad_token_id, tokenizer):

    model.eval()
    
    total_loss = 0
    rouge_scores = {'rouge1': [], 'rouge2': [], 'rougeL': []}
    
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    
    progress_bar = tqdm(val_loader, desc="Validating")
    
    with torch.inference_mode():
        for batch in progress_bar:
            src = batch['input_ids'].to(device, non_blocking=True)
            src_attention_mask = batch['attention_mask'].to(device, non_blocking=True)
            tgt = batch['labels'].to(device, non_blocking=True)
            tgt_inp, tgt_lbl = tgt[:, :-1], tgt[:, 1:]
            
            tgt_attention_mask = (tgt_inp != pad_token_id).to(device, non_blocking=True)
            
            loss_mask = (tgt_lbl != pad_token_id).float()
            
            logits = model(
                src=src,
                tgt=tgt_inp,
                src_attention_mask=src_attention_mask,
                tgt_attention_mask=tgt_attention_mask
            )
            
            loss = criterion(logits.reshape(-1, vocab_size), tgt_lbl.reshape(-1))
            masked_loss = (loss * loss_mask.reshape(-1)).sum() / max(loss_mask.sum(), 1)
            total_loss += masked_loss.item()
            
    avg_loss = total_loss / len(val_loader)
    
    return avg_loss

In [28]:
def train(model, epochs, train_loader, val_loader, vocab_size, tokenizer, optimizer, criterion, scheduler, pad_token_id, grad_clip):
    
    model = model.to(device, non_blocking=True)
    
    best_val_loss = float('inf')
    
    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")
        
        train_loss = train_epoch(
            model=model,
            train_loader=train_loader,
            optimizer=optimizer,
            criterion=criterion,
            scheduler=scheduler,
            device=device,
            vocab_size=vocab_size,
            pad_token_id=pad_token_id,
            grad_clip=grad_clip
        )
        
        val_loss = evaluate(
            model=model,
            val_loader=val_loader,
            criterion=criterion,
            device=device,
            vocab_size=vocab_size,
            pad_token_id=pad_token_id,
            tokenizer=tokenizer
        )
        
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
        torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'train_loss': train_loss,
                    'val_loss': val_loss
                    }, f"{checkpoint_dir}/transformer_epoch_{epoch+1}.pt")
        print(f"Saved checkpoint for epoch {epoch+1}")

In [29]:
import warnings
warnings.filterwarnings("ignore", message="The PyTorch API of nested tensors is in prototype stage")
warnings.filterwarnings("ignore", message=".*The current process just got forked.*")

In [30]:
os.environ["TOKENIZERS_PARALLELISM"] = "true"

In [31]:
train(model, epochs, train_loader, valid_loader, vocab_size, tokenizer, optimizer, criterion, scheduler, pad_token_id, grad_clip)


Epoch 1/10


Training: 100%|██████████| 38316/38316 [41:12<00:00, 15.49it/s, loss=3.9932]
Validating: 100%|██████████| 1442/1442 [00:49<00:00, 29.00it/s]


Train Loss: 5.6396 | Val Loss: 4.3961
Saved checkpoint for epoch 1

Epoch 2/10


Training: 100%|██████████| 38316/38316 [41:11<00:00, 15.51it/s, loss=5.9295]
Validating: 100%|██████████| 1442/1442 [00:50<00:00, 28.78it/s]


Train Loss: 4.1076 | Val Loss: 3.7102
Saved checkpoint for epoch 2

Epoch 3/10


Training: 100%|██████████| 38316/38316 [41:08<00:00, 15.52it/s, loss=3.3837]
Validating: 100%|██████████| 1442/1442 [00:49<00:00, 29.14it/s]


Train Loss: 3.6406 | Val Loss: 3.4667
Saved checkpoint for epoch 3

Epoch 4/10


Training: 100%|██████████| 38316/38316 [41:11<00:00, 15.50it/s, loss=3.7923]
Validating: 100%|██████████| 1442/1442 [00:49<00:00, 29.13it/s]


Train Loss: 3.3926 | Val Loss: 3.3279
Saved checkpoint for epoch 4

Epoch 5/10


Training: 100%|██████████| 38316/38316 [41:19<00:00, 15.46it/s, loss=3.6574] 
Validating: 100%|██████████| 1442/1442 [00:49<00:00, 29.13it/s]


Train Loss: 3.2176 | Val Loss: 3.2373
Saved checkpoint for epoch 5

Epoch 6/10


Training: 100%|██████████| 38316/38316 [41:09<00:00, 15.51it/s, loss=2.1435]
Validating: 100%|██████████| 1442/1442 [00:50<00:00, 28.74it/s]


Train Loss: 3.0772 | Val Loss: 3.1787
Saved checkpoint for epoch 6

Epoch 7/10


Training: 100%|██████████| 38316/38316 [41:13<00:00, 15.49it/s, loss=3.6840]
Validating: 100%|██████████| 1442/1442 [00:49<00:00, 28.90it/s]


Train Loss: 2.9566 | Val Loss: 3.1343
Saved checkpoint for epoch 7

Epoch 8/10


Training: 100%|██████████| 38316/38316 [41:13<00:00, 15.49it/s, loss=2.8114]
Validating: 100%|██████████| 1442/1442 [00:49<00:00, 29.06it/s]


Train Loss: 2.8510 | Val Loss: 3.1043
Saved checkpoint for epoch 8

Epoch 9/10


Training: 100%|██████████| 38316/38316 [41:14<00:00, 15.49it/s, loss=4.8861]
Validating: 100%|██████████| 1442/1442 [00:49<00:00, 29.14it/s]


Train Loss: 2.7561 | Val Loss: 3.0839
Saved checkpoint for epoch 9

Epoch 10/10


Training: 100%|██████████| 38316/38316 [41:10<00:00, 15.51it/s, loss=4.2148]
Validating: 100%|██████████| 1442/1442 [00:49<00:00, 29.11it/s]


Train Loss: 2.6749 | Val Loss: 3.0764
Saved checkpoint for epoch 10
