# 2025 DL Lab6: Text Summarization with Seq2Seq Model

Before we start, please put **your name** and **SID** in the following format: <br>
Hi I'm 陸仁賈, 314831000.

**Your Answer:**    
Hi I'm 邱照元, 314834001.

## Overview
This assignment involves implementing a hybrid sequence-to-sequence model to perform text summarization on the SAMSum and Reddit TIFU datasets.

The model architecture is composed of two main parts:
- A pre-trained model used as the encoder.
- A new decoder, which must be implemented from scratch.

The objective is to fine-tune the existing encoder while training the custom decoder from the beginning, enabling the complete model to generate accurate and concise summaries. Performance is measured using the standard summarization metric: ROUGE-L Score.

## Kaggle Competition
Kaggle is an online community of data scientists and machine learning practitioners. Kaggle allows users to find and publish datasets, explore and build models in a web-based data-science environment, work with other data scientists and machine learning engineers, and enter competitions to solve data science challenges.

This assignment uses Kaggle to calculate your grade.  
Please use this [**LINK**](https://www.kaggle.com/t/69788476947b482b88e46c9565db190b) to join the competition.

## Unzip Data

Unzip dataset.zip

### SAMSum
+ `train` : 14700
+ `val` : 818
+ `test` : 819

### Reddit_TIFU
+ `train` : 29498
+ `val` : 4212
+ `test` : 8429

In [None]:
# Nmixx
"""
pytorch
numpy
kaggle
datasets
pyarrow
transformers
"""

In [13]:
import torch
import sys
# Environment report: print the PyTorch version string, the CUDA version
# exposed by torch.version.cuda, and the running Python major.minor version.
print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Version: {torch.version.cuda}")
print(f"Python Major.Minor: {sys.version_info.major}.{sys.version_info.minor}")
# Reads a private torch attribute (leading underscore), so it may not exist in
# every PyTorch release — NOTE(review): presumably needed to pick a matching
# prebuilt extension wheel; confirm.
print(f"ABI: {torch._C._GLIBCXX_USE_CXX11_ABI}")

PyTorch Version: 2.9.0+cu128
CUDA Version: 12.8
Python Major.Minor: 3.12
ABI: True


In [None]:
import csv
import math
import random
from pathlib import Path
from typing import Optional, Tuple, Union
from data_utils import *
import torch
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import ConcatDataset, DataLoader, Dataset
from transformers import get_linear_schedule_with_warmup
from tqdm.auto import tqdm
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformer.Constants import *
from transformer.Models import Seq2SeqModelWithFlashAttn
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# ---- Notebook-wide configuration -------------------------------------------
# NOTE(review): MODE is not read by any of the cells shown in this notebook.
MODE = "train"  # set to "predict" for inference
# Checkpoint rewritten after every epoch by the training loop.
CHECKPOINT_PATH = Path("checkpoints/latest.pt")
# Checkpoint rewritten only when the validation perplexity improves.
BEST_CHECKPOINT_PATH = Path("checkpoints/best.pt")
# Checkpoint loaded by the prediction cell.
PREDICT_CHECKPOINT = Path("checkpoints/best.pt")
# Test files read by the prediction cell.
TIFU_TEST_PATH = Path("dataset/tifu/tifu_test.jsonl")
SAMSUN_TEST_PATH = Path("dataset/samsun/test.csv")
# CSV file the predictions are written to.
PREDICTION_OUTPUT = Path("result.csv")
# Passed to the dataset as max_target_len; also reused as the generation limit.
MAX_TARGET_LEN = 512
MAX_GENERATION_LEN = MAX_TARGET_LEN
# Training schedule and data-loading settings (copied into the training cell).
TRAIN_EPOCHS = 30
TRAIN_BATCH_SIZE = 100
# Seed handed to set_seed() before the model is built.
GLOBAL_SEED = 42
NUM_WORKERS = 4
def set_seed(seed: int) -> None:
    """Seed Python's ``random`` module and PyTorch with ``seed``.

    CUDA generators are seeded as well whenever CUDA is available.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


In [None]:
def set_seed(seed: int) -> None:
    """Make random draws repeatable by seeding ``random`` and ``torch``.

    Args:
        seed: Value applied to Python's ``random``, to ``torch`` and, when
            CUDA is available, to all CUDA generators.
    """
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


## CREATE DATASET
Use `ConcatDataset` to handle the case where several datasets are combined.

In [None]:
def build_dataset(
    path: list[Optional[Union[str, Path]]],
    tokenizer: PreTrainedTokenizerBase,
    require_target: bool = True,
) -> Optional[Dataset]:
    """Build a single dataset from one or more data files.

    Args:
        path: File locations as ``str`` or ``Path``; ``None`` entries are
            skipped.
        tokenizer: Tokenizer handed to every ``SquadSeq2SeqDataset``.
        require_target: Forwarded unchanged to ``SquadSeq2SeqDataset`` (the
            prediction cell passes ``False`` for the test files).

    Returns:
        ``None`` when no usable path is given, the dataset itself when exactly
        one path is given, otherwise a ``ConcatDataset`` over all of them.
    """
    # Builtin generics are used in the signature because ``typing.List`` is
    # not among the names this notebook imports explicitly.
    locations = [Path(p) for p in path if p is not None]
    if not locations:
        return None

    datasets = [
        SquadSeq2SeqDataset(
            location,
            tokenizer,
            max_source_len=MAX_SOURCE_LEN,
            max_target_len=MAX_TARGET_LEN,
            require_target=require_target,
        )
        for location in locations
    ]
    print(f"Built dataset with {sum(len(ds) for ds in datasets)} samples.")

    # A lone dataset is returned as-is rather than wrapped in ConcatDataset.
    if len(datasets) == 1:
        return datasets[0]
    return ConcatDataset(datasets)

def build_dataloader(
    source: Union[Optional[Dataset], Optional[str]],
    batch_size: int = 4,
    shuffle: bool = False,
    num_workers: int = 8,
) -> Optional[DataLoader]:
    """Wrap ``source`` in a ``DataLoader`` that collates with ``QACollator``.

    Args:
        source: Dataset to iterate. ``None`` (which ``build_dataset`` returns
            when it has no paths) yields ``None``.
            NOTE(review): the annotation also admits ``str``, but a string is
            passed to ``DataLoader`` unchanged, not loaded as a file — confirm
            whether path support was intended.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the data.
        num_workers: Number of loader worker processes.

    Returns:
        The configured ``DataLoader``, or ``None`` when ``source`` is ``None``.
    """
    # Without this guard a None source would be wrapped in a DataLoader that
    # only fails later, when it is first iterated or measured with len().
    if source is None:
        return None

    # QACollator is expected to be provided by data_utils.py.
    return DataLoader(
        source,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=QACollator,
        num_workers=num_workers,
    )


## Main loop of your model

In [None]:
def run_epoch(
    dataloader: DataLoader,
    model: Seq2SeqModelWithFlashAttn,
    device: torch.device,
    optimizer: Optional[torch.optim.Optimizer],
    scheduler: Optional[object],
    pad_id: int,
    max_grad_norm: float,
    train: bool,
) -> float:
    """Run one teacher-forced pass over ``dataloader``.

    Args:
        dataloader: Yields batches with ``"src"`` and ``"tgt"`` token-id
            tensors.
        model: Called with ``src_input_ids``, ``trg_input_ids``, ``src_mask``
            and ``trg_mask``; its output is treated as per-token logits.
        device: Device the batch tensors are moved to.
        optimizer: Stepped once per batch when ``train`` is True.
        scheduler: Stepped after each optimizer step when not ``None``.
        pad_id: Padding token id, used for the masks and ignored by the loss.
        max_grad_norm: Threshold passed to ``clip_grad_norm_``.
        train: True to update parameters, False to only measure the loss.

    Returns:
        The mean of the per-batch losses (0.0 when the loader is empty).

    Raises:
        ValueError: If ``train`` is True without an optimizer, or if a target
            batch has fewer than two tokens.
    """
    # Fail before any compute instead of after the first forward pass.
    if train and optimizer is None:
        raise ValueError("An optimizer is required when train=True.")

    model.train(train)
    total_loss = 0.0
    steps = 0
    iterator = tqdm(dataloader, desc="train" if train else "eval", leave=False)

    # Gradients are only tracked while training, so evaluation stays cheap
    # even if the caller forgets to wrap the call in torch.no_grad().
    with torch.set_grad_enabled(train):
        for batch in iterator:
            src = batch["src"].to(device)  # (B, L_src)
            # (B, L_tgt), content: [BOS, w1, ..., EOS, PAD...]
            tgt = batch["tgt"].to(device)

            # The shift below needs at least two target positions.
            if tgt.size(1) < 2:
                raise ValueError("Each target sequence must contain at least BOS and EOS tokens.")

            # Teacher forcing: the decoder reads tgt without its last token
            # and is scored against tgt without its first token (BOS).
            dec_in = tgt[:, :-1]
            labels = tgt[:, 1:]

            # Padding masks are True at non-pad positions.
            # NOTE(review): assumes the model treats True as "attend" — confirm
            # against Models.py.
            src_mask = (src != pad_id).unsqueeze(-2)
            trg_pad_mask = (dec_in != pad_id).unsqueeze(-2)

            # Lower-triangular (causal) mask so position i sees only <= i;
            # combined with the padding mask by broadcasting.
            tgt_len = dec_in.size(1)
            causal_mask = torch.tril(torch.ones((tgt_len, tgt_len), device=device)).bool()
            trg_mask = trg_pad_mask & causal_mask.unsqueeze(0)

            logits = model(
                src_input_ids=src,
                trg_input_ids=dec_in,
                src_mask=src_mask,
                trg_mask=trg_mask,
            )

            # Flatten to (num_tokens, vocab); padded label positions do not
            # contribute to the loss.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=pad_id,
            )

            if train:
                optimizer.zero_grad()
                loss.backward()
                clip_grad_norm_(model.parameters(), max_grad_norm)
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()

            total_loss += loss.item()
            steps += 1
            iterator.set_postfix(loss=total_loss / max(1, steps))

    return total_loss / max(1, steps)

## Checkpoints management

In [None]:
def load_checkpoint(
    model: Seq2SeqModelWithFlashAttn,
    path: Path,
    device: torch.device,
) -> None:
    """Restore the model weights from the checkpoint at ``path``.

    Only the ``"model_state_dict"`` entry is used; any optimizer or scheduler
    state in the file is ignored. Tensors are mapped onto ``device``.
    """
    checkpoint = torch.load(path, map_location=device)
    weights = checkpoint["model_state_dict"]
    model.load_state_dict(weights)

def save_checkpoint(
    model: Seq2SeqModelWithFlashAttn,
    optimizer: torch.optim.Optimizer,
    scheduler: Optional[object],
    path: Path,
    epoch: int,
) -> None:
    """Write the epoch number, model state and optimizer state to ``path``.

    The parent directory is created when missing. The scheduler state is
    stored too when a scheduler is given and exposes ``state_dict``.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    payload = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
    include_scheduler = scheduler is not None and hasattr(scheduler, "state_dict")
    if include_scheduler:
        payload.update(scheduler_state_dict=scheduler.state_dict())
    torch.save(payload, path)

# (Optional) When resuming training, prefer this loader: it restores the
# optimizer (and scheduler) state together with the model weights.
def load_checkpoint_for_training(
    model, optimizer, scheduler, path, device
):
    """Restore the full training state from ``path`` and return its epoch.

    The model and optimizer states are always loaded. The scheduler state is
    loaded only when a scheduler is given and the checkpoint contains one.
    """
    state = torch.load(path, map_location=device)
    for target, key in (
        (model, "model_state_dict"),
        (optimizer, "optimizer_state_dict"),
    ):
        target.load_state_dict(state[key])

    restore_scheduler = scheduler is not None and "scheduler_state_dict" in state
    if restore_scheduler:
        scheduler.load_state_dict(state["scheduler_state_dict"])

    print(f"已載入 Epoch {state['epoch']} 的完整訓練狀態")
    return state["epoch"]

## Training

In [None]:
### Hyperparameters and arguments ###
lr = 1e-4  # learning rate given to AdamW
weight_decay = 0.001  # weight decay given to AdamW
warmup_steps = 2000  # scheduler warm-up steps (capped at the total step count later)
epochs = TRAIN_EPOCHS
max_grad_norm = 1.0  # gradient-clipping threshold passed to run_epoch
batch_size = TRAIN_BATCH_SIZE
num_workers = NUM_WORKERS
#####################################
# Seed the RNGs before the model is constructed.
set_seed(GLOBAL_SEED)
# This notebook refuses to run without a GPU; the first CUDA device is used.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    raise RuntimeError("CUDA is required to run this code.")

# Check if flash attention is available
# try:
#     import flash_attn  # noqa: F401
# except ImportError:
#     raise ImportError("flash_attn is required to run this code.")

# NOTE(review): the assignment overview describes fine-tuning the encoder,
# while freeze_encoder=True is passed here — confirm this is intended.
model = Seq2SeqModelWithFlashAttn(
    transformer_model_path="answerdotai/ModernBERT-base",
    freeze_encoder=True,
).to(device)
# The tokenizer is taken from the model so both share the same vocabulary object.
tokenizer = model.tokenizer
# Checkpoint targets used by the training loop (latest and best).
checkpoint_path = CHECKPOINT_PATH
best_checkpoint_path = BEST_CHECKPOINT_PATH

In [None]:
# Training data: the TIFU and SAMSum training files merged by build_dataset.
train_set = build_dataset(
    ["dataset/tifu/tifu_train.jsonl", "dataset/samsun/train.csv"],
    tokenizer=model.tokenizer,
)
train_loader = build_dataloader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)
# Validation data: same two sources, iterated in a fixed order.
val_set = build_dataset(
    ["dataset/tifu/tifu_val.jsonl", "dataset/samsun/validation.csv"],
    tokenizer=model.tokenizer,
)
valid_loader = build_dataloader(
    val_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)
# AdamW is given every parameter returned by model.parameters().
optimizer = torch.optim.AdamW(
    model.parameters(), lr=lr, weight_decay=weight_decay
)
# run_epoch steps the scheduler once per batch, so the schedule length is
# epochs * batches-per-epoch; warm-up is capped so it never exceeds that.
total_steps = epochs * len(train_loader)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=min(warmup_steps, total_steps),
    num_training_steps=total_steps,
)

In [None]:
# a=b

In [None]:
# # 確保你已經定義了 model 和 device，並引入了 Path
# from pathlib import Path
# latest_ckpt_path = Path("checkpoints/latest.pt")
# start_epoch = 0
# if latest_ckpt_path.exists():
#     start_epoch = load_checkpoint_for_training(model, optimizer, scheduler, latest_ckpt_path, device)
#     print(f"成功載入模型權重：{latest_ckpt_path}")
# else:
#     print(f"找不到檔案：{latest_ckpt_path}")

In [None]:
# `start_epoch` is only assigned by the optional resume cell, which is
# commented out by default. Fall back to 0 so a fresh run starts at epoch 1
# instead of failing with NameError; a value set by the resume cell is kept.
start_epoch = globals().get("start_epoch", 0)

# Lowest validation perplexity seen in this run; decides when best.pt is rewritten.
best_val_ppl = float("inf")
for epoch in range(start_epoch + 1, epochs + 1):
    # One optimisation pass over the training data.
    train_loss = run_epoch(
        train_loader,
        model,
        device,
        optimizer,
        scheduler,
        tokenizer.pad_token_id,
        max_grad_norm,
        train=True,
    )
    msg = f"Epoch {epoch}/{epochs} - train loss: {train_loss:.4f}"

    # Validation pass: no optimizer or scheduler, no gradient tracking.
    with torch.no_grad():
        val_loss = run_epoch(
            valid_loader,
            model,
            device,
            optimizer=None,
            scheduler=None,
            pad_id=tokenizer.pad_token_id,
            max_grad_norm=max_grad_norm,
            train=False,
        )
    # The exponent is capped at 20 so a very large loss cannot overflow exp().
    perplexity = math.exp(min(20, val_loss))
    msg += f" | val loss: {val_loss:.4f} | ppl: {perplexity:.2f}"
    print(msg)

    # Always refresh the "latest" checkpoint.
    if checkpoint_path is not None:
        save_checkpoint(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            path=checkpoint_path,
            epoch=epoch,
        )

    # Refresh the "best" checkpoint only on a strict improvement.
    if perplexity < best_val_ppl and best_checkpoint_path is not None:
        best_val_ppl = perplexity
        save_checkpoint(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            path=best_checkpoint_path,
            epoch=epoch,
        )

## Predict Result

Predict the labels for the test set, then upload the result to [Kaggle](https://www.kaggle.com/t/69788476947b482b88e46c9565db190b).

**How to upload**

1. Go to Kaggle and click "Submit Predictions".
2. Upload the result.csv file.
3. The system will automatically calculate the accuracy on 50% of the dataset and publish the result to the leaderboard.

In [None]:
# Restore the weights chosen for inference and switch to evaluation mode.
load_checkpoint(model, PREDICT_CHECKPOINT, device)
model.eval()

# Test files carry no reference summaries, hence require_target=False.
test_set = build_dataset(
    [TIFU_TEST_PATH, SAMSUN_TEST_PATH],
    tokenizer=model.tokenizer,
    require_target=False,
)
test_loader = build_dataloader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)

# Generation below samples (top-k / top-p). Reseeding first keeps result.csv
# repeatable between runs of this cell instead of depending on whatever RNG
# state earlier cells left behind.
set_seed(GLOBAL_SEED)

# (id, summary) pairs in loader order. Builtin generics are used because
# ``typing.List`` is not among the names this notebook imports explicitly.
predictions: list[tuple[str, str]] = []
with torch.no_grad():
    for sample in tqdm(test_loader, desc="predict", leave=False):
        input_ids = sample["src"].to(device)
        # Source lengths are passed to generate() as int32.
        src_lens = sample["src_len"].to(device=device, dtype=torch.int32)
        ids = sample["id"]  # list of ids
        summaries = model.generate(
            input_ids=input_ids,
            src_seq_len=src_lens,
            generation_limit=MAX_GENERATION_LEN,
            sampling=True,
            top_k=50,
            top_p=0.9,
        )
        predictions.extend(zip(ids, summaries))

output_path = PREDICTION_OUTPUT
write_predictions_csv(output_path, predictions)
print(f"Wrote {len(predictions)} predictions to {output_path}")

In [None]:
!kaggle competitions submit -c lab-6-training-a-seq-2-seq-model-on-s-qu-ad-639401 -f result.csv -m "Message"    