# 2025 DL Lab6: Text Summarization with Seq2Seq Model

**Your Answer:**    
Hi I'm 邱照元, 314834001.

## Overview
This assignment involves implementing a hybrid sequence-to-sequence model to perform text summarization on the SAMSum and Reddit TIFU datasets.

The model architecture is composed of two main parts:

- A pre-trained model used as the encoder.
- A new decoder, which must be implemented from scratch.

The objective is to fine-tune the existing encoder while training the custom decoder from the beginning, enabling the complete model to generate accurate and concise summaries. Performance is measured using the standard summarization metric: ROUGE-L Score.

## Kaggle Competition
Kaggle is an online community of data scientists and machine learning practitioners. Kaggle allows users to find and publish datasets, explore and build models in a web-based data-science environment, work with other data scientists and machine learning engineers, and enter competitions to solve data science challenges.

This assignment uses Kaggle to calculate your grade.  
Please use this [**LINK**](https://www.kaggle.com/t/69788476947b482b88e46c9565db190b) to join the competition.

## Unzip Data

Unzip dataset.zip

### SAMSum
+ `train` : 14700
+ `val` : 818
+ `test` : 819

### Reddit_TIFU
+ `train` : 29498
+ `val` : 4212
+ `test` : 8429

In [1]:
import torch
import sys
import transformers

# Report the versions of the core libraries and the CUDA device in use.
print(f"PyTorch Version: {torch.__version__}")
print(f"Python Major.Minor: {sys.version_info.major}.{sys.version_info.minor}")
print(f"transformers version: {transformers.__version__}")
print(f"ABI: {torch._C._GLIBCXX_USE_CXX11_ABI}")
print(f"Is CUDA available: {torch.cuda.is_available()}")
print(f"CUDA Version: {torch.version.cuda}")
# Queries CUDA device index 0.
print(f"Device Name: {torch.cuda.get_device_name(0)}")

import flash_attn
print(f"Flash Attention version: {flash_attn.__version__}")
# Quick smoke test (it succeeds if no error is raised).
# Random float16 tensors on the GPU serve as query/key/value.
# NOTE(review): flash_attn_func presumably expects
# (batch, seqlen, nheads, headdim) — confirm against the flash-attn docs.
q = torch.randn(1, 1, 32, 64, device='cuda', dtype=torch.float16)
k = torch.randn(1, 1, 32, 64, device='cuda', dtype=torch.float16)
v = torch.randn(1, 1, 32, 64, device='cuda', dtype=torch.float16)
from flash_attn import flash_attn_func
out = flash_attn_func(q, k, v)
print("Flash Attention calculation successful.")

  from .autonotebook import tqdm as notebook_tqdm


PyTorch Version: 2.8.0+cu129
Python Major.Minor: 3.12
transformers version: 4.57.3
ABI: True
Is CUDA available: True
CUDA Version: 12.9
Device Name: NVIDIA GeForce RTX 5090
Flash Attention version: 2.8.3
Flash Attention calculation successful.


In [2]:
import csv
import math
import random
from pathlib import Path
from typing import Optional, Tuple, Union
from data_utils import *
import torch
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import ConcatDataset, DataLoader, Dataset
from transformers import get_linear_schedule_with_warmup
from tqdm.auto import tqdm
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformer.Const import *
from transformer.Models import Seq2SeqModelWithFlashAttn
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Run mode flag. NOTE(review): nothing in the visible cells reads MODE.
MODE = "train"  # set to "predict" for inference
# Checkpoint written after every epoch by the training loop.
CHECKPOINT_PATH = Path("checkpoints/latest.pt")
# Checkpoint written whenever the validation perplexity improves.
BEST_CHECKPOINT_PATH = Path("checkpoints/best.pt")
# Checkpoint loaded by the prediction cell.
PREDICT_CHECKPOINT = Path("checkpoints/best.pt")
# Test files that the prediction cell summarizes.
TIFU_TEST_PATH = Path("dataset/tifu/tifu_test.jsonl")
SAMSUN_TEST_PATH = Path("dataset/samsun/test.csv")
# CSV file that receives the (id, summary) predictions.
PREDICTION_OUTPUT = Path("result.csv")
# Passed to the dataset as max_target_len and reused as the generation limit.
MAX_TARGET_LEN = 512
MAX_GENERATION_LEN = MAX_TARGET_LEN
# Training schedule and data-loading settings used by the training cells.
TRAIN_EPOCHS = 30
TRAIN_BATCH_SIZE = 256
# Seed handed to set_seed() before the model is built.
GLOBAL_SEED = 42
NUM_WORKERS = 4
def set_seed(seed: int) -> None:
    """Seed the Python and PyTorch random number generators with ``seed``.

    The CUDA generators are seeded as well when a CUDA device is available.
    """
    for seed_fn in (random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


In [3]:
def set_seed(seed: int) -> None:
    """Make runs repeatable by seeding ``random`` and ``torch`` with ``seed``.

    When CUDA is available, every CUDA device generator is seeded too.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)


## CREATE DATASET
Use `ConcatDataset` to handle the case where several datasets are combined.

In [4]:
def build_dataset(
    path: List[Optional[str]],
    tokenizer: PreTrainedTokenizerBase,
    require_target: bool = True,
) -> Optional[Dataset]:
    """Build a single dataset from a list of optional data-file paths.

    Args:
        path: Candidate data files; ``None`` entries are skipped.
        tokenizer: Tokenizer forwarded to every ``SquadSeq2SeqDataset``.
        require_target: Forwarded unchanged to ``SquadSeq2SeqDataset``.

    Returns:
        ``None`` when no usable path is given, the dataset itself when exactly
        one path is usable, otherwise a ``ConcatDataset`` over all of them in
        input order.
    """
    usable_paths = [entry for entry in path if entry is not None]
    if not usable_paths:
        return None

    parts = [
        SquadSeq2SeqDataset(
            Path(entry),
            tokenizer,
            max_source_len=MAX_SOURCE_LEN,
            max_target_len=MAX_TARGET_LEN,
            require_target=require_target,
        )
        for entry in usable_paths
    ]
    sample_count = sum(len(part) for part in parts)
    print(f"Built dataset with {sample_count} samples.")
    return parts[0] if len(parts) == 1 else ConcatDataset(parts)

def build_dataloader(
    source: Union[Optional[Dataset], Optional[str]],
    batch_size: int = 4,
    shuffle: bool = False,
    num_workers: int = 8,
) -> Optional[DataLoader]:
    """Wrap a dataset in a ``DataLoader`` that collates with ``QACollator``.

    Args:
        source: Dataset to iterate over. ``None`` (for example the result of
            building a dataset from no paths) yields ``None``. Any other value
            is handed to ``DataLoader`` as-is; no path loading happens here
            even though the annotation admits ``str``.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the data every epoch.
        num_workers: Number of worker processes used for loading.

    Returns:
        The configured ``DataLoader``, or ``None`` when ``source`` is ``None``.
    """
    # Honour the Optional return type instead of wrapping None in a loader
    # that only fails later, when it is iterated or measured.
    if source is None:
        return None
    return DataLoader(
        source,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=QACollator,  # QACollator must be defined in data_utils.py
        num_workers=num_workers,
    )


## Main loop of your model

In [5]:
def _shift_packed_targets(
    tgt: torch.Tensor,
    tgt_len: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Derive teacher-forcing inputs and labels from packed target tokens.

    ``tgt`` is treated as several target sequences concatenated along dim 0,
    so a plain ``[:, :-1]`` slice cannot be used; ``tgt_len`` holds the length
    of each sequence.

    Returns:
        ``(dec_in, labels, dec_len)``: ``tgt`` without the last token of each
        sequence, ``tgt`` without the first token of each sequence, and the
        per-sequence lengths reduced by one.
    """
    # Cumulative lengths mark where each packed sequence ends.
    seq_ends = torch.cumsum(tgt_len, dim=0, dtype=torch.long)
    # Last token of each sequence (the EOS, per the collator's packing
    # convention assumed here) is dropped from the decoder input.
    last_positions = seq_ends - 1
    # First token of each sequence (the BOS under the same assumption) is
    # dropped from the labels.
    first_positions = seq_ends - tgt_len

    total_tokens = tgt.size(0)
    keep_for_input = torch.ones(total_tokens, dtype=torch.bool, device=tgt.device)
    keep_for_input[last_positions] = False
    keep_for_labels = torch.ones(total_tokens, dtype=torch.bool, device=tgt.device)
    keep_for_labels[first_positions] = False

    # Every sequence lost exactly one token, hence the length adjustment.
    return tgt[keep_for_input], tgt[keep_for_labels], tgt_len - 1


def run_epoch(
    dataloader: DataLoader,
    model: Seq2SeqModelWithFlashAttn,
    device: torch.device,
    optimizer: Optional[torch.optim.Optimizer],
    scheduler: Optional[object],
    pad_id: int,
    max_grad_norm: float,
    train: bool,
) -> float:
    """Run one pass over ``dataloader`` and return the mean per-batch loss.

    Each batch must provide the keys ``"src"``, ``"tgt"``, ``"src_len"`` and
    ``"tgt_len"``. ``src`` and ``tgt`` are handled as packed 1-D token tensors
    with the per-sequence lengths in ``src_len`` / ``tgt_len``.

    Args:
        dataloader: Source of batches.
        model: Model called with ``src_input_ids``, ``trg_input_ids``,
            ``src_seq_len`` and ``trg_seq_len``; it must return one row of
            logits per decoder input token.
        device: Device the batch tensors are moved to.
        optimizer: Optimizer to step; required when ``train`` is true.
        scheduler: Optional scheduler, stepped once per batch when training.
        pad_id: Label value ignored by the cross-entropy loss.
        max_grad_norm: Gradient-norm clipping threshold.
        train: ``True`` to update parameters, ``False`` to evaluate only.

    Returns:
        The average of the per-batch losses (``0.0`` for an empty loader).

    Raises:
        ValueError: If ``train`` is true but no optimizer is given.
    """
    if train and optimizer is None:
        raise ValueError("optimizer is required when train=True")

    model.train(train)
    total_loss = 0.0
    steps = 0
    iterator = tqdm(dataloader, desc="train" if train else "eval", leave=False)

    for batch in iterator:
        src = batch["src"].to(device)  # (Total_Src_Tokens,)
        tgt = batch["tgt"].to(device)  # (Total_Tgt_Tokens,)
        src_len = batch["src_len"].to(device)  # (B,)
        tgt_len = batch["tgt_len"].to(device)  # (B,)

        # Gradients are only tracked while training, so evaluation stays
        # cheap even if the caller forgets to wrap it in torch.no_grad().
        with torch.set_grad_enabled(train):
            dec_in, labels, dec_len = _shift_packed_targets(tgt, tgt_len)
            logits = model(
                src_input_ids=src,
                trg_input_ids=dec_in,
                src_seq_len=src_len,
                trg_seq_len=dec_len,
            )
            # logits: (Total_Tokens, Vocab_Size), labels: (Total_Tokens,);
            # both are already flat, so no reshape is needed.
            loss = F.cross_entropy(logits, labels, ignore_index=pad_id)

        if train:
            optimizer.zero_grad()
            loss.backward()
            clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
        total_loss += loss.item()
        steps += 1
        iterator.set_postfix(loss=total_loss / max(1, steps))

    return total_loss / max(1, steps)

## Checkpoints management

In [6]:
def load_checkpoint(
    model: Seq2SeqModelWithFlashAttn,
    path: Path,
    device: torch.device,
) -> None:
    """Restore only the model weights stored in the checkpoint at ``path``.

    The checkpoint is loaded onto ``device`` and its ``"model_state_dict"``
    entry is applied to ``model`` in place.
    """
    checkpoint = torch.load(path, map_location=device)
    weights = checkpoint["model_state_dict"]
    model.load_state_dict(weights)

def save_checkpoint(
    model: Seq2SeqModelWithFlashAttn,
    optimizer: torch.optim.Optimizer,
    scheduler: Optional[object],
    path: Path,
    epoch: int,
) -> None:
    """Write the training state to ``path``, creating parent folders first.

    The saved dict holds ``"epoch"``, ``"model_state_dict"`` and
    ``"optimizer_state_dict"``. ``"scheduler_state_dict"`` is added when a
    scheduler exposing ``state_dict`` is supplied.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    payload = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
    )
    scheduler_is_usable = not (
        scheduler is None or not hasattr(scheduler, "state_dict")
    )
    if scheduler_is_usable:
        payload["scheduler_state_dict"] = scheduler.state_dict()
    torch.save(payload, path)

# (Optional) Prefer this variant when resuming training: it restores the
# optimizer (and scheduler) state together with the model weights.
def load_checkpoint_for_training(
    model, optimizer, scheduler, path, device
):
    """Restore the full training state from ``path`` and return its epoch.

    Model and optimizer states are always loaded. The scheduler state is
    loaded only when a scheduler is passed and the checkpoint contains
    ``"scheduler_state_dict"``. A confirmation line is printed afterwards.
    """
    checkpoint = torch.load(path, map_location=device)
    for target, key in (
        (model, "model_state_dict"),
        (optimizer, "optimizer_state_dict"),
    ):
        target.load_state_dict(checkpoint[key])

    restore_scheduler = (
        scheduler is not None and "scheduler_state_dict" in checkpoint
    )
    if restore_scheduler:
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

    resumed_epoch = checkpoint["epoch"]
    print(f"已載入 Epoch {resumed_epoch} 的完整訓練狀態")
    return resumed_epoch

## Training

In [7]:
### Hyperparameters and arguments ###
# Optimizer settings (consumed by AdamW and the warmup scheduler below).
lr = 8e-5
weight_decay = 0.001
warmup_steps = 2000
# Schedule and loader settings mirror the notebook-level constants.
epochs = TRAIN_EPOCHS
max_grad_norm = 1.0
batch_size = TRAIN_BATCH_SIZE
num_workers = NUM_WORKERS
#####################################
# Seed the RNGs before the model is constructed.
set_seed(GLOBAL_SEED)
# This notebook only runs on GPU: fail early when CUDA is missing.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    raise RuntimeError("CUDA is required to run this code.")

# Check if flash attention is available
try:
    import flash_attn  # noqa: F401
except ImportError:
    raise ImportError("flash_attn is required to run this code.")

# Build the seq2seq model on top of the ModernBERT-base checkpoint.
# NOTE(review): freeze_encoder=True presumably keeps the encoder weights
# fixed, whereas the overview talks about fine-tuning the encoder — confirm
# which one is intended.
model = Seq2SeqModelWithFlashAttn(
    transformer_model_path="answerdotai/ModernBERT-base",
    freeze_encoder=True,
).to(device)
# The model carries its own tokenizer; reuse it for the datasets.
tokenizer = model.tokenizer
# Local aliases for the checkpoint locations used by the training loop.
checkpoint_path = CHECKPOINT_PATH
best_checkpoint_path = BEST_CHECKPOINT_PATH

`torch_dtype` is deprecated! Use `dtype` instead!


In [8]:
# Training data: the TIFU and SAMSum training files, combined by build_dataset.
train_set = build_dataset(
    ["dataset/tifu/tifu_train.jsonl", "dataset/samsun/train.csv"],
    tokenizer=model.tokenizer,
)
# Shuffle the training samples every epoch.
train_loader = build_dataloader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)
# Validation data: same two sources, iterated in a fixed order.
val_set = build_dataset(
    ["dataset/tifu/tifu_val.jsonl", "dataset/samsun/validation.csv"],
    tokenizer=model.tokenizer,
)
valid_loader = build_dataloader(
    val_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)
# AdamW receives every parameter returned by model.parameters().
optimizer = torch.optim.AdamW(
    model.parameters(), lr=lr, weight_decay=weight_decay
)
# run_epoch steps the scheduler once per batch, so the schedule length is
# the number of batches over all epochs.
total_steps = epochs * len(train_loader)
# Cap the warmup so it never exceeds the total number of steps.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=min(warmup_steps, total_steps),
    num_training_steps=total_steps,
)

# Last completed epoch; 0 means training starts from epoch 1.
start_epoch = 0

Built dataset with 44229 samples.
Built dataset with 5030 samples.


In [9]:
# a=b

In [10]:
# # Make sure `model` and `device` are already defined and that Path is imported
# from pathlib import Path
# latest_ckpt_path = Path("checkpoints/latest.pt")
# if latest_ckpt_path.exists():
#     start_epoch = load_checkpoint_for_training(model, optimizer, scheduler, latest_ckpt_path, device)
#     print(f"成功載入模型權重：{latest_ckpt_path}")
# else:
#     print(f"找不到檔案：{latest_ckpt_path}")

In [None]:
# Lowest validation perplexity seen so far; any real value beats infinity.
best_val_ppl = float("inf")
for epoch in range(start_epoch + 1, epochs + 1):
    # One optimisation pass over the training data.
    train_loss = run_epoch(
        dataloader=train_loader,
        model=model,
        device=device,
        optimizer=optimizer,
        scheduler=scheduler,
        pad_id=tokenizer.pad_token_id,
        max_grad_norm=max_grad_norm,
        train=True,
    )
    current_val_ppl = None
    # Validation pass: no optimizer, no scheduler, no gradients.
    with torch.no_grad():
        val_loss = run_epoch(
            valid_loader,
            model,
            device,
            None,
            None,
            tokenizer.pad_token_id,
            max_grad_norm,
            False,
        )
    # The exponent is capped at 20 so a diverged loss cannot overflow.
    perplexity = math.exp(min(20, val_loss))
    current_val_ppl = perplexity
    msg = (
        f"Epoch {epoch}/{epochs} - train loss: {train_loss:.4f}"
        f" | val loss: {val_loss:.4f} | ppl: {perplexity:.2f}"
    )
    print(msg)

    # Always refresh the "latest" checkpoint.
    if checkpoint_path is not None:
        save_checkpoint(model, optimizer, scheduler, checkpoint_path, epoch)

    # Keep a separate copy of the best model by validation perplexity.
    if current_val_ppl is not None and current_val_ppl < best_val_ppl:
        if best_checkpoint_path is not None:
            best_val_ppl = current_val_ppl
            save_checkpoint(
                model, optimizer, scheduler, best_checkpoint_path, epoch
            )

                                                                   

Epoch 1/30 - train loss: 11.1582 | val loss: 8.0469 | ppl: 3124.02


train:  53%|█████▎    | 91/173 [00:22<00:20,  4.02it/s, loss=7.84]

In [None]:
import re
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# --- 原始日誌字串 ---
log_string = """
Epoch 1/30 - train loss: 8.7638 | val loss: 6.9381 | ppl: 1030.82
                                                                   
Epoch 2/30 - train loss: 6.7033 | val loss: 6.2065 | ppl: 495.96
                                                                   
Epoch 3/30 - train loss: 6.0818 | val loss: 5.7010 | ppl: 299.16
                                                                   
Epoch 4/30 - train loss: 5.7039 | val loss: 5.4553 | ppl: 233.99
                                                                   
Epoch 5/30 - train loss: 5.4383 | val loss: 5.2457 | ppl: 189.75
                                                                   
Epoch 6/30 - train loss: 5.1855 | val loss: 5.0729 | ppl: 159.64
                                                                   
Epoch 7/30 - train loss: 4.9982 | val loss: 4.9571 | ppl: 142.18
                                                                   
Epoch 8/30 - train loss: 4.8601 | val loss: 4.8750 | ppl: 130.97
                                                                   
Epoch 9/30 - train loss: 4.7542 | val loss: 4.7981 | ppl: 121.28
                                                                   
Epoch 10/30 - train loss: 4.6684 | val loss: 4.7381 | ppl: 114.21
                                                                   
Epoch 11/30 - train loss: 4.6021 | val loss: 4.6900 | ppl: 108.85
                                                                   
Epoch 12/30 - train loss: 4.5514 | val loss: 4.6581 | ppl: 105.43
                                                                   
Epoch 13/30 - train loss: 4.5083 | val loss: 4.6342 | ppl: 102.94
                                                                   
Epoch 14/30 - train loss: 4.4744 | val loss: 4.6109 | ppl: 100.58
                                                                   
Epoch 15/30 - train loss: 4.4458 | val loss: 4.6017 | ppl: 99.66
                                                                   
Epoch 16/30 - train loss: 4.4242 | val loss: 4.5864 | ppl: 98.14
                                                                   
Epoch 17/30 - train loss: 4.4043 | val loss: 4.5692 | ppl: 96.47
                                                                   
Epoch 18/30 - train loss: 4.3873 | val loss: 4.5582 | ppl: 95.41
                                                                   
Epoch 19/30 - train loss: 4.3742 | val loss: 4.5545 | ppl: 95.06
                                                                   
Epoch 20/30 - train loss: 4.3646 | val loss: 4.5447 | ppl: 94.14
                                                                   
Epoch 21/30 - train loss: 4.3573 | val loss: 4.5447 | ppl: 94.14
                                                                   
Epoch 22/30 - train loss: 4.3493 | val loss: 4.5392 | ppl: 93.62
                                                                   
Epoch 23/30 - train loss: 4.3438 | val loss: 4.5377 | ppl: 93.47
                                                                   
Epoch 24/30 - train loss: 4.3403 | val loss: 4.5368 | ppl: 93.39
                                                                   
Epoch 25/30 - train loss: 4.3362 | val loss: 4.5325 | ppl: 92.99
                                                                   
Epoch 26/30 - train loss: 4.3346 | val loss: 4.5294 | ppl: 92.70
                                                                   
Epoch 27/30 - train loss: 4.3319 | val loss: 4.5294 | ppl: 92.70
                                                                   
Epoch 28/30 - train loss: 4.3311 | val loss: 4.5279 | ppl: 92.56
                                                                   
Epoch 29/30 - train loss: 4.3315 | val loss: 4.5273 | ppl: 92.51
                                                                   
Epoch 30/30 - train loss: 4.3299 | val loss: 4.5276 | ppl: 92.53
"""

# --- Data extraction and processing ---
# NOTE: the epoch numbers live in `epoch_ids`, not `epochs`, so this cell does
# not overwrite the integer `epochs` hyperparameter used by the training loop.
epoch_ids = []
train_losses = []
val_losses = []
ppl_values = []

# The regular expression captures the epoch number and the three metrics:
# Group 1: epoch number
# Group 2: train loss
# Group 3: val loss
# Group 4: ppl
regex_pattern = re.compile(
    r'Epoch (\d+)/\d+ - train loss: ([\d.]+)\s*\| val loss: ([\d.]+)\s*\| ppl: ([\d.]+)'
)

for line in log_string.split('\n'):
    match = regex_pattern.search(line.strip())
    if match is None:
        continue
    try:
        # Convert everything first so a malformed number (e.g. "1.2.3")
        # skips the whole row instead of leaving the lists out of sync.
        row = (
            int(match.group(1)),
            float(match.group(2)),
            float(match.group(3)),
            float(match.group(4)),
        )
    except ValueError:
        continue
    epoch_ids.append(row[0])
    train_losses.append(row[1])
    val_losses.append(row[2])
    ppl_values.append(row[3])

# Collect the series in a DataFrame to keep them aligned.
df = pd.DataFrame({
    'Epoch': epoch_ids,
    'Train Loss': train_losses,
    'Val Loss': val_losses,
    'PPL': ppl_values
})

# --- Plot setup and rendering (two Y axes) ---
fig, ax1 = plt.subplots(figsize=(10, 6))
fig.suptitle('Training Metrics Trend', fontsize=16)

# X axis
ax1.set_xlabel('Epoch')
ax1.set_xticks(df['Epoch'])  # show ticks only at the epochs that were logged
ax1.grid(True, linestyle='--', alpha=0.6)


# --- Y1 axis (losses) ---
color_loss = 'tab:blue'
ax1.set_ylabel('Loss Value', color=color_loss)
ax1.tick_params(axis='y', labelcolor=color_loss)

# Draw train loss and val loss (no markers)
line1 = ax1.plot(df['Epoch'], df['Train Loss'], linestyle='-', color='blue', label='Train Loss')
line2 = ax1.plot(df['Epoch'], df['Val Loss'], linestyle='-', color='darkblue', label='Val Loss')


# --- Y2 axis (PPL) ---
# Create the second Y axis, sharing the X axis with the loss plot.
ax2 = ax1.twinx()
color_ppl = 'tab:red'
ax2.set_ylabel('PPL (Perplexity)', color=color_ppl)
ax2.tick_params(axis='y', labelcolor=color_ppl)

# Draw PPL (no markers)
line3 = ax2.plot(df['Epoch'], df['PPL'], linestyle='-', color='red', label='PPL')


# Single legend covering the lines from both axes.
lines = line1 + line2 + line3
labels = [handle.get_label() for handle in lines]
ax1.legend(lines, labels, loc='upper right')

fig.tight_layout(rect=[0, 0, 1, 0.96])  # leave room for the figure title
plt.show()

# Print the parsed table if needed.
# print(df)

## Predict Result

Predict the labels for the test set and upload them to [Kaggle](https://www.kaggle.com/t/69788476947b482b88e46c9565db190b).

**How to upload**

1. Go to Kaggle and click "Submit Predictions".
2. Upload the result.csv
3. The system will automatically calculate the accuracy on 50% of the dataset and publish this result to the leaderboard.

In [None]:
# Restore the weights chosen for inference and switch the model to eval mode.
load_checkpoint(model, PREDICT_CHECKPOINT, device)
model.eval()
# Build the test set from both test files.
# NOTE(review): require_target=False presumably means the test files have no
# reference summaries — confirm against SquadSeq2SeqDataset.
test_set = build_dataset(
    [TIFU_TEST_PATH, SAMSUN_TEST_PATH],
    tokenizer=model.tokenizer,
    require_target=False,
)
# Keep the file order so predictions line up with the sample ids.
test_loader = build_dataloader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)
# Collected (sample id, generated summary) pairs.
predictions: List[Tuple[str, str]] = []
with torch.no_grad():
    for sample in tqdm(test_loader, desc="predict", leave=False):
        input_ids = sample["src"].to(device)
        src_lens = sample["src_len"].to(device=device, dtype=torch.int32)
        ids = sample["id"] # list of ids
        # NOTE(review): sampling=True with top_k/top_p looks like stochastic
        # decoding, so results may differ between runs — confirm in
        # model.generate.
        summaries = model.generate(
            input_ids=input_ids,
            src_seq_len=src_lens,
            generation_limit=MAX_GENERATION_LEN,
            sampling=True,
            top_k=50,
            top_p=0.9,
        )
        # Pair each id with its summary, in batch order.
        predictions.extend(zip(ids, summaries))
# Write all predictions to the submission CSV.
output_path = PREDICTION_OUTPUT
write_predictions_csv(output_path, predictions)
print(f"Wrote {len(predictions)} predictions to {output_path}")

In [None]:
Message = (""
    +f"TRAIN_EPOCHS = {TRAIN_EPOCHS}\n"
    +f"TRAIN_BATCH_SIZE = {TRAIN_BATCH_SIZE}\n"
    +f"LERANING_RATE = {lr}\n"
    # +"citerion1=nn.CrossEntropyLoss()\n"
)
!kaggle competitions submit -c lab-6-training-a-seq-2-seq-model-on-s-qu-ad-639401 -f result.csv -m "{Message}"    

100%|███████████████████████████████████████| 0.99M/0.99M [00:01<00:00, 597kB/s]
Successfully submitted to Lab6 Text Summarization with Seq2Seq Model(639401)