In [None]:
from google.colab import drive
import sentencepiece as spm
import os
import sentencepiece as spm
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
# Mount Google Drive inside the Colab VM so that the paths under /content/drive
# (DATA_DIR, CHECKPOINT_DIR in the config cell) resolve.
drive.mount('/content/drive')

import os

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# ================================
# TRANSFORMER CONFIG (COLAB CELL)
# ================================
# Single place for every tunable constant used by the notebook: file locations,
# special-token ids, model size, optimisation settings and evaluation options.

# -------- Paths --------
# Both directories live on the Google Drive mount (/content/drive).
DATA_DIR = "/content/drive/MyDrive/TransformerMT/DATA"
CHECKPOINT_DIR = "/content/drive/MyDrive/TransformerMT/checkpoints"

# Parallel corpus files (English / Vietnamese).
TRAIN_EN = f"{DATA_DIR}/train.en"
TRAIN_VI = f"{DATA_DIR}/train.vi"
VALID_EN = f"{DATA_DIR}/valid.en"
VALID_VI = f"{DATA_DIR}/valid.vi"

VOCAB_EN_PATH = f"{DATA_DIR}/vocab_en.pth"
VOCAB_VI_PATH = f"{DATA_DIR}/vocab_vi.pth"

# -------- Special Tokens --------
PAD_TOKEN = "<pad>"
BOS_TOKEN = "<bos>"
EOS_TOKEN = "<eos>"
UNK_TOKEN = "<unk>"

# NOTE(review): process_data_pipeline and TranslationDatasetSPM read the pad/bos/eos
# ids from the SentencePiece models instead of using these constants — confirm that
# the trained models were built with exactly these ids.
PAD_IDX = 0
BOS_IDX = 1
EOS_IDX = 2
UNK_IDX = 3

# -------- Model Hyperparameters --------
# These are the default arguments of make_model().
D_MODEL = 512
N_HEADS = 8
N_LAYERS = 6
D_FF = 2048
DROPOUT_RATE = 0.1

# -------- Training Hyperparameters --------
BATCH_SIZE = 64
EPOCHS = 150

# NOTE(review): NoamOpt overwrites the learning rate of the optimizer it wraps, so LR
# may have no effect when that scheduler is used — confirm against the training cell.
LR = 1e-4
BETAS = (0.9, 0.98)
EPS = 1e-9

LABEL_SMOOTHING = 0.1
GRAD_CLIP = 1.0
WARMUP_STEPS = 4000

# -------- Sequence --------
MAX_LEN = 128

# -------- Device --------
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# -------- Logging / Saving --------
LOG_INTERVAL = 100
SAVE_EVERY_EPOCH = 1

# -------- SentencePiece --------
VOCAB_SIZE_SPM=10000
MODEL_TYPE = 'bpe'

# -------- Data pipeline / evaluation --------
# NOTE(review): BATCH_SIZE is assigned a second time here with the same value (64), and
# MAX_SEQ_LEN / MAX_LEN_DECODE repeat the value of MAX_LEN (128); consider keeping one
# definition of each so they cannot drift apart.
BATCH_SIZE = 64
MAX_SEQ_LEN = 128
VAL_SPLIT_RATIO = 0.1
SEED = 42
TEST_MODE_LIMIT=None
MAX_LEN_DECODE = 128
BEAM_SIZE = 5
TEST_EN_PATH = os.path.join(DATA_DIR, 'tst2013.en')
TEST_VI_PATH = os.path.join(DATA_DIR, 'tst2013.vi')
METRIC_LOG_PATH = os.path.join(DATA_DIR, 'training_log.json')





In [None]:
import os
import torch
import sentencepiece as spm
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import train_test_split
from google.colab import drive


# ==============================================================================
# 2. UTILITY FUNCTIONS (LOAD & STATS)
# ==============================================================================

def load_raw_data(file_path):
    """
    Read a UTF-8 text file and return its lines as a list of strings.

    Leading and trailing whitespace (including the newline) is stripped from
    every line; empty lines are kept as empty strings. A progress message with
    the file's base name is printed before reading.

    Args:
        file_path: path of the text file to read.

    Returns:
        list[str]: one entry per line of the file.

    Raises:
        FileNotFoundError: if ``file_path`` does not exist.
    """
    if os.path.exists(file_path):
        print(f"-> ƒêang ƒë·ªçc file: {os.path.basename(file_path)}...")
        stripped_lines = []
        with open(file_path, encoding='utf-8') as handle:
            for raw_line in handle:
                stripped_lines.append(raw_line.strip())
        return stripped_lines

    raise FileNotFoundError(f"Kh√¥ng t√¨m th·∫•y file: {file_path}")

def plot_statistics(raw_data, sp_processor, lang_name):
    """
    Report sentence-length statistics measured in SentencePiece subword tokens.

    At most the first 50,000 sentences are examined. Each sentence is
    lower-cased and encoded with ``sp_processor.encode_as_ids``; the mean, the
    maximum and the 95th / 99th percentiles of the token counts are printed
    (useful for choosing MAX_SEQ_LEN) and a 50-bin histogram is drawn with
    matplotlib. When there is nothing to measure, a message is printed and the
    function returns without plotting.

    Args:
        raw_data: sequence of raw sentence strings.
        sp_processor: tokenizer object exposing ``encode_as_ids(str)``.
        lang_name: label inserted into the printed report and the plot title.
    """
    # Cap the sample at 50k sentences to keep the statistics fast.
    sample_data = raw_data if len(raw_data) <= 50000 else raw_data[:50000]

    # Count the actual number of subword ids per sentence.
    lengths = []
    for sentence in sample_data:
        token_ids = sp_processor.encode_as_ids(sentence.lower())
        lengths.append(len(token_ids))

    if len(lengths) == 0:
        print(f"D·ªØ li·ªáu {lang_name} tr·ªëng!")
        return

    print(f"\n[Th·ªëng k√™ {lang_name} (tr√™n {len(sample_data)} m·∫´u)]")
    print(f"- ƒê·ªô d√†i trung b√¨nh: {np.mean(lengths):.2f} subwords")
    print(f"- Max length: {np.max(lengths)} subwords")
    print(f"- 95% d·ªØ li·ªáu ng·∫Øn h∆°n: {np.percentile(lengths, 95):.0f} subwords")
    print(f"- 99% d·ªØ li·ªáu ng·∫Øn h∆°n: {np.percentile(lengths, 99):.0f} subwords")

    # Histogram of the token counts.
    hist_style = dict(bins=50, color='skyblue', edgecolor='black', alpha=0.7)
    plt.figure(figsize=(8, 4))
    plt.hist(lengths, **hist_style)
    plt.title(f"Ph√¢n ph·ªëi ƒë·ªô d√†i Subword ({lang_name})")
    plt.xlabel("S·ªë l∆∞·ª£ng Subword")
    plt.ylabel("S·ªë l∆∞·ª£ng c√¢u")
    plt.grid(axis='y', alpha=0.3)
    plt.show()

# ==============================================================================
# 3. DATASET & COLLATE FUNCTION
# ==============================================================================

class TranslationDatasetSPM(Dataset):
    """
    Parallel-sentence dataset tokenized with SentencePiece.

    Construction:
      1. Receives raw source/target sentences.
      2. Lower-cases and encodes each sentence to subword ids.
      3. Drops every pair in which either side would exceed ``max_len`` once
         the start and end tokens are added (i.e. ``len(ids) + 2 > max_len``).

    Item access wraps the stored ids with the start/end token ids and returns a
    pair of ``torch.long`` tensors ``(source, target)``.

    Args:
        raw_en: sequence of source sentences (strings).
        raw_vi: sequence of target sentences (strings); paired with ``raw_en``
            by position via ``zip``.
        sp_en: source tokenizer exposing ``encode_as_ids``, ``bos_id`` and ``eos_id``.
        sp_vi: target tokenizer exposing ``encode_as_ids``.
        max_len: maximum allowed length including the two special tokens.

    NOTE(review): the start/end ids are taken from ``sp_en`` only and reused for
    the target side — this assumes both SentencePiece models share the same
    bos/eos ids; confirm against how the models were trained.
    """
    def __init__(self, raw_en, raw_vi, sp_en, sp_vi, max_len=100):
        self.sp_en = sp_en
        self.sp_vi = sp_vi
        self.SOS_IDX = sp_en.bos_id()
        self.EOS_IDX = sp_en.eos_id()

        self.data = []
        original_count = len(raw_en)

        print(f"-> ƒêang x·ª≠ l√Ω v√† l·ªçc {original_count} c·∫∑p c√¢u (Max Len={max_len})...")

        # Walk over the sentence pairs.
        for en, vi in zip(raw_en, raw_vi):
            # 1. Encode to ids (lower-cased for normalisation).
            en_ids = self.sp_en.encode_as_ids(en.lower())
            vi_ids = self.sp_vi.encode_as_ids(vi.lower())

            # 2. Length check, counting the start and end tokens added later.
            if len(en_ids) + 2 <= max_len and len(vi_ids) + 2 <= max_len:
                self.data.append((en_ids, vi_ids))

        filtered_count = len(self.data)
        removed_count = original_count - filtered_count
        # Guard the percentage against an empty corpus (previously ZeroDivisionError).
        removed_ratio = removed_count / original_count if original_count else 0.0
        print(f"-> Ho√†n t·∫•t. Gi·ªØ l·∫°i: {filtered_count} | Lo·∫°i b·ªè: {removed_count} ({removed_ratio:.2%})")

    def __len__(self):
        """Number of sentence pairs that survived the length filter."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(source, target)`` LongTensors wrapped in start/end token ids."""
        en_ids, vi_ids = self.data[index]

        # Prepend the start id and append the end id.
        en_out = [self.SOS_IDX] + en_ids + [self.EOS_IDX]
        vi_out = [self.SOS_IDX] + vi_ids + [self.EOS_IDX]

        return torch.tensor(en_out, dtype=torch.long), torch.tensor(vi_out, dtype=torch.long)

class MyCollateSPM:
    """
    Collate callable for a DataLoader: pads the variable-length source and
    target tensors of a batch to a common length with ``pad_idx``.
    """
    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        """
        Args:
            batch: list of ``(source_tensor, target_tensor)`` pairs.

        Returns:
            ``(src, trg)`` where each is a padded tensor of shape
            (batch, longest_sequence_in_batch).
        """
        # Split the list of pairs into two parallel lists.
        sources, targets = [], []
        for pair in batch:
            sources.append(pair[0])
            targets.append(pair[1])

        # batch_first=True puts the batch dimension first.
        padded_src = pad_sequence(sources, batch_first=True, padding_value=self.pad_idx)
        padded_trg = pad_sequence(targets, batch_first=True, padding_value=self.pad_idx)
        return padded_src, padded_trg

# ==============================================================================
# 4. MAIN FLOW (MAIN DATA PIPELINE)
# ==============================================================================

def process_data_pipeline():
    """
    Build the training and validation DataLoaders from the raw parallel corpus.

    Steps: read ``train.en`` / ``train.vi`` from DATA_DIR, split them into
    train/validation sets (VAL_SPLIT_RATIO, SEED), load the already-trained
    SentencePiece models ``spm_en.model`` / ``spm_vi.model``, print and plot
    subword-length statistics, then wrap the data in TranslationDatasetSPM and
    DataLoader objects. One batch is pulled from the train loader at the end as
    a shape sanity check.

    Returns:
        tuple: ``(train_loader, val_loader, sp_en, sp_vi)``.

    Raises:
        FileNotFoundError: if a corpus file or a SentencePiece model is missing.
        AssertionError: if the two corpus files have different line counts.
    """
    print("=== B·∫ÆT ƒê·∫¶U QUY TR√åNH X·ª¨ L√ù D·ªÆ LI·ªÜU ===\n")

    # -----------------------------------------
    # STEP 1: LOAD THE RAW DATA
    # -----------------------------------------
    path_en = os.path.join(DATA_DIR, 'train.en')
    path_vi = os.path.join(DATA_DIR, 'train.vi')

    raw_en = load_raw_data(path_en)
    raw_vi = load_raw_data(path_vi)

    # The corpus is line-aligned, so both files must have the same number of lines.
    assert len(raw_en) == len(raw_vi), "L·ªói: S·ªë l∆∞·ª£ng c√¢u Anh-Vi·ªát kh√¥ng kh·ªõp!"
    print(f"-> T·ªïng s·ªë c√¢u raw: {len(raw_en)}")

    # -----------------------------------------
    # STEP 2: TRAIN / VALIDATION SPLIT
    # -----------------------------------------
    print(f"\n--- Chia t·∫≠p d·ªØ li·ªáu (Val Ratio: {VAL_SPLIT_RATIO}) ---")
    # Both lists are passed together so the sentence pairs stay aligned after the split.
    en_train, en_val, vi_train, vi_val = train_test_split(
        raw_en, raw_vi, test_size=VAL_SPLIT_RATIO, random_state=SEED
    )
    print(f"-> Train set: {len(en_train)} c√¢u")
    print(f"-> Val set:   {len(en_val)} c√¢u")

    # -----------------------------------------
    # STEP 3: LOAD THE (PRE-TRAINED) SENTENCEPIECE MODELS
    # -----------------------------------------
    print("\n--- Load SentencePiece Models ---")
    # The model files are assumed to be named spm_en.model and spm_vi.model.
    sp_en_path = os.path.join(DATA_DIR, 'spm_en.model')
    sp_vi_path = os.path.join(DATA_DIR, 'spm_vi.model')

    if not os.path.exists(sp_en_path) or not os.path.exists(sp_vi_path):
        raise FileNotFoundError("Ch∆∞a t√¨m th·∫•y file model SPM (.model). H√£y ƒë·∫£m b·∫£o b·∫°n ƒë√£ train SPM tr∆∞·ªõc ƒë√≥.")

    sp_en = spm.SentencePieceProcessor()
    sp_en.load(sp_en_path)

    sp_vi = spm.SentencePieceProcessor()
    sp_vi.load(sp_vi_path)

    # This local PAD_IDX shadows the module-level constant of the same name.
    # NOTE(review): the original comment says the pad id is "usually 0"; SentencePiece
    # reports -1 when the model has no pad piece, which would be an invalid padding
    # value downstream — confirm the models were trained with an explicit pad id.
    # NOTE(review): the pad id of sp_en is also used for the target side — confirm
    # sp_vi uses the same id.
    PAD_IDX = sp_en.pad_id() # pad id as reported by the English model
    print(f"-> ƒê√£ load SPM. Vocab EN: {sp_en.get_piece_size()} | Vocab VI: {sp_vi.get_piece_size()}")
    print(f"-> PAD_IDX: {PAD_IDX}")

    # -----------------------------------------
    # STEP 4: STATISTICS (subword level)
    # -----------------------------------------
    print("\n--- Th·ªëng k√™ D·ªØ li·ªáu Train (Subword Level) ---")
    plot_statistics(en_train, sp_en, "English (Source)")
    plot_statistics(vi_train, sp_vi, "Vietnamese (Target)")

    # -----------------------------------------
    # STEP 5: BUILD DATASETS & DATALOADERS
    # -----------------------------------------
    print("\n--- T·∫°o Dataset v√† DataLoader ---")

    # Datasets tokenize and length-filter the sentence pairs.
    train_dataset = TranslationDatasetSPM(en_train, vi_train, sp_en, sp_vi, max_len=MAX_SEQ_LEN)
    val_dataset = TranslationDatasetSPM(en_val, vi_val, sp_en, sp_vi, max_len=MAX_SEQ_LEN)

    # Collate function that pads each batch with PAD_IDX.
    collate_fn = MyCollateSPM(pad_idx=PAD_IDX)

    # DataLoaders: the training loader shuffles, the validation loader does not.
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=2,    # worker processes to speed up data loading (optional)
        pin_memory=True   # intended to speed up host-to-GPU transfer
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=2,
        pin_memory=True
    )

    print(f"\n-> Train Loader: {len(train_loader)} batches")
    print(f"-> Val Loader:   {len(val_loader)} batches")

    # Sanity check: fetch one batch and show its shape.
    src_batch, trg_batch = next(iter(train_loader))
    print(f"\n[Test Batch Shape]")
    print(f"Src: {src_batch.shape} (Batch, Seq_len)")
    print(f"Trg: {trg_batch.shape}")

    return train_loader, val_loader, sp_en, sp_vi

# --- RUN THE PROGRAM ---
# NOTE(review): the pipeline call is commented out, so this cell only prints the
# completion banner; no data is processed and no loaders are created here.
if __name__ == "__main__":
    #train_dl, val_dl, sp_en, sp_vi = process_data_pipeline()
    print("\n>>> X·ª¨ L√ù D·ªÆ LI·ªÜU HO√ÄN T·∫§T. S·∫¥N S√ÄNG TRAIN MODEL <<<")


>>> X·ª¨ L√ù D·ªÆ LI·ªÜU HO√ÄN T·∫§T. S·∫¥N S√ÄNG TRAIN MODEL <<<


In [None]:
# Load the pre-trained SentencePiece models from DATA_DIR into notebook-level
# variables: sp_en for English, sp_vi for Vietnamese.
sp_en = spm.SentencePieceProcessor()
sp_en.load(os.path.join(DATA_DIR, 'spm_en.model'))

sp_vi = spm.SentencePieceProcessor()
sp_vi.load(os.path.join(DATA_DIR, 'spm_vi.model'))

True

# Scaled Dot-Product Attention & Multi-Head Attention

In [None]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F


def scaled_dot_product_attention(query, key, value, mask=None, dropout=None):
    """
    Compute softmax(Q K^T / sqrt(d_k)) V.

    Args:
        query, key, value: tensors whose last dimension is the per-head feature
            size; the last two dimensions of ``key`` are transposed for the
            score product. MultiHeadAttention calls this with tensors shaped
            (batch, heads, seq_len, d_k).
        mask: optional tensor broadcastable to the score tensor. Positions where
            ``mask == 0`` are excluded from attention.
        dropout: optional callable (e.g. ``nn.Dropout``) applied to the
            attention weights.

    Returns:
        tuple: ``(output, p_attn)`` — the attended values and the attention
        weights (after dropout, when dropout is given).
    """
    d_k = query.size(-1)

    # Raw attention scores, scaled by sqrt(d_k).
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    # 1. Apply the mask. The fill value is the most negative finite number of the
    #    score dtype: the previous constant -1e9 cannot be represented in float16
    #    and made masked_fill fail under half precision. Masked positions still get
    #    zero weight after softmax, and a fully masked row still becomes uniform.
    if mask is not None:
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

    # 2. Normalise over the key dimension.
    p_attn = torch.softmax(scores, dim=-1)

    # 3. Dropout on the attention weights.
    if dropout is not None:
        p_attn = dropout(p_attn)

    # 4. Weighted sum of the values.
    output = torch.matmul(p_attn, value)

    return output, p_attn

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention built from four bias-free linear projections
    (query, key, value, output) around ``scaled_dot_product_attention``.

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout_rate: dropout probability applied to the attention weights.
    """
    def __init__(self, d_model, num_heads, dropout_rate=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model ph·∫£i chia h·∫øt cho num_heads"

        self.d_k = d_model // num_heads   # per-head feature size
        self.num_heads = num_heads
        self.d_model = d_model

        # Projections for Q, K, V and the final output.
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model, bias=False)

        # Dropout module handed to the attention function.
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, query, key, value, mask=None):
        """
        Returns:
            tuple: ``(output, attention_weights)`` where ``output`` has shape
            (batch, query_len, d_model).
        """
        n_batch = query.size(0)

        # A 3-D mask gets a singleton head axis so it broadcasts over all heads.
        if mask is not None and mask.dim() == 3:
            mask = mask.unsqueeze(1)

        # 1. Project, then split the feature axis into (heads, d_k) and move the
        #    head axis in front of the sequence axis.
        per_head = []
        for tensor, projection in ((query, self.w_q), (key, self.w_k), (value, self.w_v)):
            split = projection(tensor).view(n_batch, -1, self.num_heads, self.d_k)
            per_head.append(split.transpose(1, 2))
        q, k, v = per_head

        # 2. Attention over every head at once; the dropout module is passed in so
        #    that it follows this module's train/eval mode.
        context, attn_weights = scaled_dot_product_attention(
            q, k, v, mask=mask, dropout=self.dropout
        )

        # 3. Merge the heads back into a single feature axis.
        merged = context.transpose(1, 2).contiguous().view(n_batch, -1, self.d_model)

        # 4. Final output projection.
        return self.w_o(merged), attn_weights

# Positional Encoding (Sinusoidal):

In [None]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F

class PositionalEncoding(nn.Module):
    """
    Fixed sinusoidal positional encoding from "Attention Is All You Need".

    The encoding table is precomputed once and stored as a buffer, so it is not
    a trainable parameter.

    Args:
        d_model: embedding width (even or odd).
        dropout_rate: dropout probability applied after adding the encoding.
        max_len: number of positions precomputed; inputs must not be longer.
    """
    def __init__(self, d_model, dropout_rate, max_len=5000):
        super(PositionalEncoding, self).__init__()

        self.dropout = nn.Dropout(p=dropout_rate)

        # Encoding table of shape (max_len, d_model).
        pe = torch.zeros(max_len, d_model)

        # 1. Position indices as a column vector: (max_len, 1).
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        # 2. Inverse frequencies 1 / 10000^(2i/d_model), computed through exp/log.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        # 3. Even columns (2i) receive the sine.
        pe[:, 0::2] = torch.sin(position * div_term)

        # Odd columns (2i+1) receive the cosine. For an odd d_model there is one
        # fewer odd column than even columns, so the frequency vector is trimmed
        # to match; without the trim the assignment raised a shape error.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Add a leading batch axis: (1, max_len, d_model).
        pe = pe.unsqueeze(0)

        # Registered as a buffer: saved with the module, excluded from gradients.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: embeddings of shape (batch, seq_len, d_model).

        Returns:
            ``dropout(x + pe[:, :seq_len])`` with the same shape as ``x``.
        """
        # Only the first seq_len rows of the table are needed for this batch.
        x = x + self.pe[:, :x.size(1)]

        return self.dropout(x)



# Transformer Encoder Layer

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class PositionwiseFeedForward(nn.Module):
    """
    Two-layer feed-forward network of the Transformer, applied to the last
    dimension of its input: Linear(d_model -> d_ff), ReLU, dropout,
    Linear(d_ff -> d_model).
    """
    def __init__(self, d_model, d_ff, dropout_rate=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)    # expansion
        self.w_2 = nn.Linear(d_ff, d_model)    # projection back to model width
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x):
        """Map ``x`` of shape (..., d_model) to a tensor of the same shape."""
        hidden = self.w_1(x)
        hidden = F.relu(hidden)
        hidden = self.dropout(hidden)
        return self.w_2(hidden)

In [None]:
import torch

class SublayerConnection(nn.Module):
    """
    Residual wrapper with pre-normalisation: ``x + dropout(sublayer(norm(x)))``.

    Args:
        size: feature width handed to ``nn.LayerNorm`` (the model width).
        dropout_rate: dropout probability applied to the sublayer output.
    """
    def __init__(self, size, dropout_rate=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x, sublayer):
        """
        Args:
            x: input tensor.
            sublayer: callable applied to the normalised input (attention or
                feed-forward in this notebook).
        """
        # Normalise first, run the sublayer, regularise, then add the skip path.
        return x + self.dropout(sublayer(self.norm(x)))


In [None]:
class EncoderLayer(nn.Module):
    """
    One Transformer encoder layer: self-attention followed by a feed-forward
    network, each wrapped in its own SublayerConnection.

    Args:
        size: model width.
        self_attn: attention module called as ``self_attn(q, k, v, mask=...)``
            and returning a tuple whose first item is the output.
        feed_forward: position-wise feed-forward module.
        dropout_rate: dropout probability of the two residual wrappers.
    """
    def __init__(self, size, self_attn, feed_forward, dropout_rate):
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        # Index 0 wraps self-attention, index 1 wraps the feed-forward network.
        self.sublayer = nn.ModuleList(
            [SublayerConnection(size, dropout_rate) for _ in range(2)]
        )
        self.size = size

    def forward(self, x, mask):
        """
        Args:
            x: input of shape (batch, seq_len, d_model).
            mask: padding mask forwarded to the attention module (non-zero = keep,
                zero = hide).
        """
        def attend(normed):
            # Self-attention: query, key and value are the same tensor. The
            # encoder needs no look-ahead mask, only the padding mask.
            attended, _weights = self.self_attn(normed, normed, normed, mask=mask)[:2]
            return attended

        after_attention = self.sublayer[0](x, attend)
        return self.sublayer[1](after_attention, self.feed_forward)

# Transformer Decoder Layer

In [None]:
import torch

def subsequent_mask(size):
    """
    Build a boolean look-ahead mask of shape (1, size, size).

    Entry [0, i, j] is True when j <= i, i.e. position i may attend to itself
    and to earlier positions only.
    """
    all_visible = torch.ones(1, size, size, dtype=torch.bool)
    lower_triangle = torch.tril(all_visible)
    return lower_triangle

def make_std_mask(tgt, pad_idx):
    """
    Build the decoder self-attention mask: padding mask AND look-ahead mask.

    Args:
        tgt: target token ids; the last dimension is the sequence length.
        pad_idx: id of the padding token.

    Returns:
        Boolean tensor that is True only where the key position is not padding
        and is not in the future of the query position. For a (batch, length)
        input the result is (batch, length, length).
    """
    # True wherever the token is real; the extra axis lets it broadcast over
    # the query positions.
    not_padding = (tgt != pad_idx).unsqueeze(-2)

    # Lower-triangular mask, converted to the same tensor type as the padding mask.
    causal = subsequent_mask(tgt.size(-1)).type_as(not_padding.data)

    combined = not_padding & causal
    return combined.requires_grad_(False)

# =================================================================
# PART B: BATCH CLASS - ORGANISES THE INPUT AND BUILDS THE MASKS FOR EACH RUN
# =================================================================

class Batch:
    """
    Bundle one source/target batch together with every mask the model needs.

    Attributes:
        src: source ids, (batch, src_len).
        tgt: decoder input — the target without its last column.
        tgt_y: expected output — the target without its first column.
        src_mask: padding mask of the source, (batch, 1, src_len).
        tgt_mask: padding + look-ahead mask for the decoder input.
        ntokens: number of non-padding tokens in ``tgt_y`` (a tensor), used to
            normalise the loss.
    """
    def __init__(self, src, tgt, pad_idx):
        self.src = src

        # Shift the target by one position: the decoder reads tokens 0..n-1
        # and is trained to predict tokens 1..n.
        decoder_input, expected_output = tgt[:, :-1], tgt[:, 1:]
        self.tgt = decoder_input
        self.tgt_y = expected_output

        # Source padding mask, used by the encoder and by cross-attention.
        self.src_mask = (src != pad_idx).unsqueeze(-2).requires_grad_(False)

        # Decoder self-attention mask.
        self.tgt_mask = make_std_mask(decoder_input, pad_idx)

        # Count of real (non-padding) target tokens.
        self.ntokens = (expected_output != pad_idx).data.sum()


In [None]:
class DecoderLayer(nn.Module):
    """
    One Transformer decoder layer: masked self-attention, cross-attention over
    the encoder output, and a feed-forward network — each wrapped in its own
    SublayerConnection.

    Args:
        size: model width.
        self_attn: attention module used for masked self-attention.
        src_attn: attention module used for encoder-decoder attention.
        feed_forward: position-wise feed-forward module.
        dropout_rate: dropout probability of the three residual wrappers.
    """
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout_rate):
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward

        # Index 0: masked self-attention, 1: cross-attention, 2: feed-forward.
        self.sublayer = nn.ModuleList(
            [SublayerConnection(size, dropout_rate) for _ in range(3)]
        )

    def forward(self, x, memory, src_mask, tgt_mask):
        """
        Args:
            x: decoder input / previous layer output, (batch, tgt_len, d_model).
            memory: encoder output, (batch, src_len, d_model).
            src_mask: padding mask for ``memory``.
            tgt_mask: look-ahead (and padding) mask for ``x``.
        """
        def masked_self_attention(normed):
            # Query, key and value are all the decoder stream.
            return self.self_attn(normed, normed, normed, mask=tgt_mask)[0]

        def cross_attention(normed):
            # Queries come from the decoder, keys and values from the encoder;
            # src_mask keeps attention away from source padding.
            return self.src_attn(normed, memory, memory, mask=src_mask)[0]

        hidden = self.sublayer[0](x, masked_self_attention)
        hidden = self.sublayer[1](hidden, cross_attention)
        return self.sublayer[2](hidden, self.feed_forward)

# Transformer

In [None]:
class Encoder(nn.Module):
    """
    Full Transformer encoder: a stack of N copies of an encoder layer followed
    by a final LayerNorm.

    Args:
        layer: an already constructed encoder layer exposing ``.size``; it is
            copied N times via ``clones``.
        N: number of layers in the stack.
    """
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = nn.LayerNorm(layer.size)
        self.size = layer.size

    def forward(self, x, mask):
        """
        Args:
            x: embedded input (embedding + positional encoding),
                (batch, seq_len, d_model).
            mask: source padding mask passed to every layer.
        """
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        # Final normalisation of the stack output.
        return self.norm(hidden)

In [None]:
class Decoder(nn.Module):
    """
    Full Transformer decoder: a stack of N copies of a decoder layer followed
    by a final LayerNorm.

    Args:
        layer: an already constructed decoder layer exposing ``.size``; it is
            copied N times via ``clones``.
        N: number of layers in the stack.
    """
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = nn.LayerNorm(layer.size)
        self.size = layer.size

    def forward(self, x, memory, src_mask, tgt_mask):
        """
        Args:
            x: embedded decoder input (embedding + positional encoding).
            memory: encoder output.
            src_mask: source padding mask.
            tgt_mask: target look-ahead mask.
        """
        hidden = x
        for block in self.layers:
            hidden = block(hidden, memory, src_mask, tgt_mask)
        # Final normalisation of the stack output.
        return self.norm(hidden)

In [None]:
import copy

class Transformer(nn.Module):
    """
    Complete encoder-decoder model (source sequence to target sequence).

    Args:
        encoder: module called as ``encoder(embedded_src, src_mask)``.
        decoder: module called as ``decoder(embedded_tgt, memory, src_mask, tgt_mask)``.
        src_embed: embedding (+ positional encoding) applied to source ids.
        tgt_embed: embedding (+ positional encoding) applied to target ids.
        generator: final projection applied to the decoder output.
    """
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        """
        Run encoder, decoder and generator in sequence.

        Args:
            src: source ids, (batch, src_len).
            tgt: target ids, (batch, tgt_len).
            src_mask: source padding mask, (batch, 1, src_len).
            tgt_mask: target look-ahead mask, (batch, tgt_len, tgt_len).

        Returns:
            The generator output for every target position.
        """
        memory = self.encode(src, src_mask)
        decoded = self.decode(memory, src_mask, tgt, tgt_mask)
        return self.generator(decoded)

    def encode(self, src, src_mask):
        """Embed the source ids and run them through the encoder."""
        embedded_src = self.src_embed(src)
        return self.encoder(embedded_src, src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        """Embed the target ids and run the decoder against ``memory``."""
        embedded_tgt = self.tgt_embed(tgt)
        return self.decoder(embedded_tgt, memory, src_mask, tgt_mask)

# --- Generator/Projection Layer ---
class Generator(nn.Module):
    """
    Final projection from model width to vocabulary size.

    Returns raw logits; no softmax is applied here.
    """
    def __init__(self, d_model, vocab_size):
        super().__init__()
        # Linear map d_model -> vocab_size.
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        logits = self.proj(x)
        return logits

class Embeddings(nn.Module):
    """
    Token embedding lookup whose output is multiplied by sqrt(d_model).

    Args:
        d_model: embedding width.
        vocab: number of entries in the lookup table.
    """
    def __init__(self, d_model, vocab):
        super().__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        scale = math.sqrt(self.d_model)
        return self.lut(x) * scale

In [None]:
# --- Utility function clones (very important) ---
def clones(module, N):
    """Return an ``nn.ModuleList`` holding N independent deep copies of ``module``."""
    copies = nn.ModuleList()
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    return copies


def make_model(src_vocab, tgt_vocab, N=N_LAYERS, d_model=D_MODEL, d_ff=D_FF, h=N_HEADS, dropout=DROPOUT_RATE):
    """
    Assemble a complete Transformer from the building blocks defined in this notebook.

    Args:
        src_vocab: size of the source vocabulary.
        tgt_vocab: size of the target vocabulary.
        N: number of encoder layers and of decoder layers.
        d_model: model width.
        d_ff: hidden width of the feed-forward networks.
        h: number of attention heads.
        dropout: dropout probability used throughout the model.

    Returns:
        Transformer: the model, with every parameter of more than one dimension
        initialised by Xavier-uniform.
    """
    # 1. Prototype sub-modules; every use below takes its own deep copy so that
    #    no weights are shared between layers.
    c = copy.deepcopy

    attn = MultiHeadAttention(d_model, h, dropout)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    pe = PositionalEncoding(d_model, dropout)

    # 2. Prototype layers. EncoderLayer / DecoderLayer already create their own
    #    independent SublayerConnection instances, so the former function-local
    #    "Safe*Layer" subclasses were redundant. They also made the model instance
    #    impossible to pickle (e.g. torch.save(model)) because local classes cannot
    #    be pickled. The module structure, parameter names and parameter order are
    #    unchanged.
    enc_layer_proto = EncoderLayer(d_model, c(attn), c(ff), dropout)
    dec_layer_proto = DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout)

    # 3. Stacks of N layers (Encoder / Decoder deep-copy the prototype N times).
    encoder = Encoder(enc_layer_proto, N)
    decoder = Decoder(dec_layer_proto, N)

    # 4. Embeddings (+ positional encoding) and the output projection.
    src_embed = nn.Sequential(Embeddings(d_model, src_vocab), c(pe))
    tgt_embed = nn.Sequential(Embeddings(d_model, tgt_vocab), c(pe))
    generator = Generator(d_model, tgt_vocab)

    # 5. Complete model.
    model = Transformer(
        encoder, decoder,
        src_embed, tgt_embed,
        generator
    )

    # 6. Xavier initialisation of all weight matrices (helps stabilise training).
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return model



# Training Loop

In [None]:
from torch.nn.utils import clip_grad_norm_

class LossCompute:
    """
    Compute the loss for one batch and, when an optimizer is supplied, perform
    the complete update step (backward, optional gradient clipping, step,
    zero_grad).

    Args:
        criterion: callable ``criterion(flat_logits, flat_targets)`` returning a
            scalar tensor.
        opt: optimizer-like object with ``step()`` and ``zero_grad()``. It may be
            a wrapper exposing the underlying optimizer as ``.optimizer`` (such
            as NoamOpt) or a plain optimizer with ``param_groups``. ``None``
            disables the update (evaluation).
        clip_norm: maximum gradient norm; ``None`` disables clipping.
    """
    def __init__(self, criterion, opt=None, clip_norm=None):
        self.criterion = criterion
        self.opt = opt
        self.clip_norm = clip_norm

    def __call__(self, x, y, norm):
        """
        Args:
            x: model output; flattened to (-1, last_dim) before the criterion.
            y: target ids; flattened to (-1).
            norm: divisor applied to the loss before the backward pass
                (presumably the number of target tokens — confirm with the caller).

        Returns:
            The un-normalised loss tensor returned by ``criterion``.
        """
        loss = self.criterion(
            x.contiguous().view(-1, x.size(-1)),
            y.contiguous().view(-1)
        )

        normalized_loss = loss / norm

        if self.opt is not None:
            normalized_loss.backward()

            if self.clip_norm is not None:
                # Unwrap scheduler wrappers such as NoamOpt; a plain optimizer has no
                # `.optimizer` attribute and is used directly (previously this raised
                # AttributeError).
                base_opt = getattr(self.opt, 'optimizer', self.opt)
                clip_grad_norm_(
                    (p for g in base_opt.param_groups for p in g['params']),
                    self.clip_norm
                )

            self.opt.step()
            self.opt.zero_grad()

        return loss


In [None]:
class LabelSmoothing(nn.Module):
    """
    Label-smoothed loss implemented as a summed KL divergence between the
    model's log-probabilities and a smoothed target distribution.

    For every non-padding target, the true class receives ``1 - smoothing`` and
    each other class receives ``smoothing / (size - 2)``; the padding class
    always receives 0. Rows whose target is the padding id contribute nothing.

    Args:
        size: number of classes (vocabulary size).
        padding_idx: id of the padding token.
        smoothing: probability mass spread over the wrong classes.
    """
    def __init__(self, size, padding_idx, smoothing=0.1):
        super().__init__()
        self.size = size
        self.padding_idx = padding_idx
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing
        # The true class and the padding class are excluded from the smoothing mass.
        self.denominator = size - 2
        self.criterion = nn.KLDivLoss(reduction='sum')

    def forward(self, logits, target):
        """
        Args:
            logits: raw scores of shape (num_tokens, size).
            target: class ids of shape (num_tokens,).

        Returns:
            Scalar tensor: KL divergence summed over all tokens and classes.
        """
        assert logits.size(1) == self.size

        log_probs = F.log_softmax(logits, dim=-1)

        # The target distribution is a constant; keep it out of autograd.
        with torch.no_grad():
            off_value = self.smoothing / self.denominator
            true_dist = torch.zeros_like(log_probs).fill_(off_value)
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
            # Never assign probability to the padding class ...
            true_dist[:, self.padding_idx] = 0.0
            # ... and ignore positions whose target is padding.
            padded_rows = target == self.padding_idx
            true_dist[padded_rows] = 0.0

        return self.criterion(log_probs, true_dist)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import math

class NoamOpt:
    """
    Wrapper around a base optimizer that applies the Noam learning-rate
    schedule (linear warm-up followed by inverse-square-root decay) and
    supports checkpointing.

    Args:
        model_size: model width used in the schedule (d_model).
        warmup_steps: number of warm-up steps.
        optimizer: the base optimizer whose ``param_groups`` learning rate is
            overwritten on every step.
    """
    def __init__(self, model_size, warmup_steps, optimizer):
        # Base optimizer (Adam, AdamW, ...).
        self.optimizer = optimizer

        # Scheduler state.
        self._step = 0
        self.warmup_steps = warmup_steps
        self.model_size = model_size
        self._rate = 0   # learning rate applied at the most recent step

        # The learning rate starts at 0 regardless of the base optimizer's setting.
        self.set_lr_in_param_groups(0)

    def set_lr_in_param_groups(self, rate):
        """Write ``rate`` into every parameter group of the base optimizer."""
        for p in self.optimizer.param_groups:
            p['lr'] = rate

    def step(self):
        """Advance the schedule by one step and update the parameters."""
        self._step += 1
        rate = self.rate()

        # Keep the cached value in sync; it used to stay at 0 forever, so any code
        # logging `_rate` reported a zero learning rate.
        self._rate = rate
        self.set_lr_in_param_groups(rate)

        self.optimizer.step()

    def rate(self, step=None):
        """
        Learning rate for ``step`` (defaults to the current step):
        ``model_size^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)``.
        Returns 0 for step 0.
        """
        if step is None:
            step = self._step
        if step == 0:
            return 0

        return self.model_size**(-0.5) * \
               min(step**(-0.5), step * self.warmup_steps**(-1.5))

    def zero_grad(self):
        """Clear the gradients through the base optimizer."""
        self.optimizer.zero_grad()

    # --- Checkpointing support ---
    def state_dict(self):
        """Return the scheduler state together with the base optimizer's state."""
        return {
            'optimizer': self.optimizer.state_dict(),
            'step': self._step,
            'warmup_steps': self.warmup_steps,
            'model_size': self.model_size
        }

    def load_state_dict(self, state_dict):
        """Restore the state produced by ``state_dict``."""
        self.optimizer.load_state_dict(state_dict['optimizer'])
        self._step = state_dict['step']
        self.warmup_steps = state_dict['warmup_steps']
        self.model_size = state_dict['model_size']
        # Re-derive the current learning rate from the restored step counter.
        self._rate = self.rate()
        self.set_lr_in_param_groups(self._rate)


In [None]:

def calculate_ppl(val_loader, model, pad_idx, criterion_val):
    """Compute average per-token loss and perplexity over a validation loader.

    Puts ``model`` in eval mode (it is not switched back here) and runs
    without gradient tracking.

    Args:
        val_loader: iterable yielding ``(src_batch, tgt_batch)`` tensors.
        model: callable as ``model(src, tgt, src_mask, tgt_mask)``.
        pad_idx: padding id forwarded to ``Batch``.
        criterion_val: loss applied to the flattened outputs and targets;
            its values are summed, so it is presumably a ``reduction='sum'``
            criterion (confirm at the call site).

    Returns:
        ``(avg_loss, ppl)``; ``(inf, inf)`` when no tokens were counted.
    """
    model.eval()
    loss_sum, token_count = 0.0, 0

    with torch.no_grad():
        for src_batch, tgt_batch in val_loader:
            batch = Batch(src_batch.to(DEVICE), tgt_batch.to(DEVICE), pad_idx)

            output_logits = model(
                batch.src, batch.tgt,
                batch.src_mask, batch.tgt_mask
            )

            vocab_dim = output_logits.size(-1)
            batch_loss = criterion_val(
                output_logits.view(-1, vocab_dim),
                batch.tgt_y.view(-1)
            )

            loss_sum += batch_loss.item()
            token_count += batch.ntokens.item()

    if token_count == 0:
        return float('inf'), float('inf')

    avg_loss = loss_sum / token_count
    return avg_loss, math.exp(avg_loss)


In [None]:
import torch.nn.functional as F

# H√†m n√†y s·∫Ω ƒë∆∞·ª£c g·ªçi t·ª´ b√™n trong run_full_training_pipeline
def calculate_ppl(val_loader, model, pad_idx, criterion_val):
    """Compute average per-token loss and perplexity over a validation loader.

    Puts ``model`` in eval mode (it is not switched back here) and runs
    without gradient tracking. This definition shadows any earlier
    ``calculate_ppl`` with the same signature.

    Args:
        val_loader: iterable yielding ``(src_batch, tgt_batch)`` tensors.
        model: callable as ``model(src, tgt, src_mask, tgt_mask)``.
        pad_idx: padding id forwarded to ``Batch``.
        criterion_val: loss applied to flattened log-probabilities and targets.

    Returns:
        ``(avg_loss, perplexity)``; ``(inf, inf)`` when no tokens were counted.

    NOTE(review): ``log_softmax`` is applied before ``criterion_val``. With
    ``nn.CrossEntropyLoss`` this is redundant (it applies log-softmax itself
    and log-softmax is idempotent); it would be required for ``nn.NLLLoss``.
    """
    model.eval()
    running_loss = 0
    running_tokens = 0

    with torch.no_grad():
        for src_batch, tgt_batch in val_loader:
            batch = Batch(src_batch.to(DEVICE), tgt_batch.to(DEVICE), pad_idx)

            output_logits = model(batch.src, batch.tgt, batch.src_mask, batch.tgt_mask)
            log_probs = F.log_softmax(output_logits, dim=-1)

            # Flatten to (N, vocab) and (N,) for the criterion.
            flat_log_probs = log_probs.contiguous().view(-1, output_logits.size(-1))
            flat_targets = batch.tgt_y.contiguous().view(-1)

            running_loss += criterion_val(flat_log_probs, flat_targets).item()
            running_tokens += batch.ntokens.item()

    if running_tokens == 0:
        return float('inf'), float('inf')

    avg_loss = running_loss / running_tokens
    return avg_loss, math.exp(avg_loss)

In [None]:
import torch
import os
import math
from tqdm import tqdm
from torch.nn.utils import clip_grad_norm_

# --- SHARED PARAMETERS AND GLOBAL VARIABLES (assumed) ---
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# PAD_IDX = 0
# N_EPOCHS = 10
# CHECKPOINT_DIR = '/content/drive/MyDrive/TransformerMT/checkpoints'
# Assumption: model, train_loader, val_loader, criterion, optimizer have already been initialized

# --- REQUIRED UTILITY FUNCTIONS ---
# Batch, calculate_ppl, LossCompute, NoamOpt have already been defined and revised

def save_checkpoint(model, optimizer, epoch, best_val_ppl,
                    is_best=False, filename='latest_checkpoint.pth'):
    """Serialize the training state into ``CHECKPOINT_DIR``.

    Args:
        model: object exposing ``state_dict()``.
        optimizer: object exposing ``state_dict()``.
        epoch: epoch number stored under ``'epoch'``.
        best_val_ppl: best validation perplexity so far.
        is_best: when true, the same state is also written to
            ``best_model.pth``.
        filename: file name (inside ``CHECKPOINT_DIR``) of the regular
            checkpoint.
    """
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    # Snapshot of the global settings in effect when the checkpoint was written.
    run_config = {
        'pad_idx': PAD_IDX,
        'n_epochs': EPOCHS,
        'clip_norm': GRAD_CLIP
    }
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_val_ppl': best_val_ppl,
        'config': run_config,
    }

    path = os.path.join(CHECKPOINT_DIR, filename)
    torch.save(state, path)

    if is_best:
        best_path = os.path.join(CHECKPOINT_DIR, 'best_model.pth')
        torch.save(state, best_path)
        print(f"üî• Best model saved at epoch {epoch}")

    print(f"üíæ Checkpoint saved: {path}")


def load_checkpoint(model, optimizer, filename='latest_checkpoint.pth'):
    """Restore model and optimizer state so training can resume.

    Args:
        model: object exposing ``load_state_dict``.
        optimizer: object exposing ``load_state_dict``.
        filename: checkpoint file name inside ``CHECKPOINT_DIR``.

    Returns:
        ``(start_epoch, best_val_ppl)``. When the file does not exist this is
        ``(1, inf)``; otherwise the stored epoch plus one and the stored best
        validation perplexity.
    """
    filepath = os.path.join(CHECKPOINT_DIR, filename)

    if os.path.exists(filepath):
        print(f"-> ƒêang t·∫£i checkpoint t·ª´ {filepath}...")

        saved = torch.load(filepath, map_location=DEVICE)
        model.load_state_dict(saved['model_state_dict'])
        optimizer.load_state_dict(saved['optimizer_state_dict'])
        # The stored epoch has already completed, so resume with the next one.
        start_epoch = saved['epoch'] + 1
        best_val_ppl = saved['best_val_ppl']

        print(f"-> T·∫£i th√†nh c√¥ng. Ti·∫øp t·ª•c training t·ª´ Epoch {start_epoch}, PPL t·ªët nh·∫•t: {best_val_ppl:.2f}")
        return start_epoch, best_val_ppl

    print(f"Kh√¥ng t√¨m th·∫•y checkpoint t·∫°i {filepath}. B·∫Øt ƒë·∫ßu training t·ª´ ƒë·∫ßu.")
    return 1, float('inf')


def run_full_training_pipeline(model, train_loader, val_loader,
                               criterion_train, criterion_val,
                               optimizer, pad_idx, n_epochs, clip_norm=1.0):
    """Train ``model``, validating and checkpointing after every epoch.

    Training resumes from the latest checkpoint when ``load_checkpoint``
    finds one.

    Args:
        model: callable as ``model(src, tgt, src_mask, tgt_mask)``.
        train_loader: iterable of ``(src_batch, tgt_batch)`` training tensors.
        val_loader: iterable of ``(src_batch, tgt_batch)`` validation tensors.
        criterion_train: criterion handed to ``LossCompute`` for training.
        criterion_val: criterion handed to ``calculate_ppl`` for validation.
        optimizer: scheduler-wrapped optimizer; must expose ``rate()`` (used
            for the progress bar) and ``state_dict``/``load_state_dict``.
        pad_idx: padding id forwarded to ``Batch`` and ``calculate_ppl``.
        n_epochs: last epoch number to run (inclusive).
        clip_norm: forwarded to ``LossCompute``.
    """
    # 1. Resume from the latest checkpoint if there is one.
    start_epoch, best_val_ppl = load_checkpoint(model, optimizer)

    # LossCompute is defined elsewhere; presumably it runs backward, gradient
    # clipping and the optimizer step and returns the batch loss - confirm there.
    loss_compute = LossCompute(criterion_train, optimizer, clip_norm=clip_norm)

    for epoch in range(start_epoch, n_epochs + 1):

        model.train()

        # Accumulate loss and token counts as tensors on DEVICE.
        total_loss = torch.tensor(0.0, device=DEVICE)
        total_tokens = torch.tensor(0, device=DEVICE)

        # Bound before the loop so the epoch summary below cannot raise
        # UnboundLocalError when the loader yields no batches (or no tokens).
        avg_loss = float('nan')

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{n_epochs} (Train)", unit="batch")

        for src_batch, tgt_batch in progress_bar:

            src_batch, tgt_batch = src_batch.to(DEVICE), tgt_batch.to(DEVICE)
            batch = Batch(src_batch, tgt_batch, pad_idx)

            # Forward pass, then loss computation with criterion_train.
            output_logits = model(batch.src, batch.tgt, batch.src_mask, batch.tgt_mask)
            loss_value = loss_compute(output_logits, batch.tgt_y, batch.ntokens)

            total_loss += loss_value
            total_tokens += batch.ntokens

            # Running per-token average for the progress bar.
            if total_tokens.item() > 0:
                avg_loss = total_loss.item() / total_tokens.item()

                progress_bar.set_postfix(
                    Loss=f"{avg_loss:.4f}",
                    PPL=f"{math.exp(avg_loss):.2f}",
                    LR=f"{optimizer.rate():.6f}"
                )

        # --- Validation with criterion_val ---
        val_loss, val_ppl = calculate_ppl(val_loader, model, pad_idx, criterion_val)

        print(f"\n===========================================================")
        print(f"EPOCH {epoch} K·∫æT TH√öC | TRAIN LOSS: {avg_loss:.4f} | VAL LOSS: {val_loss:.4f}")
        print(f"VAL PERPLEXITY (PPL): {val_ppl:.2f}")
        print(f"===========================================================\n")

        # Save every epoch; additionally keep the best model by validation PPL.
        is_best = val_ppl < best_val_ppl
        if is_best:
            best_val_ppl = val_ppl

        save_checkpoint(model, optimizer, epoch, best_val_ppl, is_best=is_best)

        # calculate_ppl switched the model to eval mode; switch back.
        model.train()

In [None]:
def process_data_pipeline_for_test():
    """Build the train and validation DataLoaders from the raw parallel corpus.

    Reads ``train.en``/``train.vi`` from ``DATA_DIR``, optionally keeps only the
    first ``TEST_MODE_LIMIT`` sentence pairs, splits them into train and
    validation parts, loads the two SentencePiece models and wraps both parts
    in ``TranslationDatasetSPM`` datasets.

    Returns:
        ``(train_loader, val_loader)``.
    """
    print("=== B·∫ÆT ƒê·∫¶U X·ª¨ L√ù D·ªÆ LI·ªÜU (TEST MODE) ===\n")

    # --- Step 1: load the raw sentences ---
    raw_en = load_raw_data(os.path.join(DATA_DIR, 'train.en'))
    raw_vi = load_raw_data(os.path.join(DATA_DIR, 'train.vi'))

    assert len(raw_en) == len(raw_vi), "L·ªói l·ªách d√≤ng!"

    # Optional truncation for quick experiments.
    if TEST_MODE_LIMIT is not None:
        print(f"\n  CH·∫æ ƒê·ªò TEST: ƒêang c·∫Øt d·ªØ li·ªáu xu·ªëng c√≤n {TEST_MODE_LIMIT} c√¢u...")
        raw_en, raw_vi = raw_en[:TEST_MODE_LIMIT], raw_vi[:TEST_MODE_LIMIT]

    print(f"-> T·ªïng s·ªë c√¢u s·ª≠ d·ª•ng: {len(raw_en)}")

    # --- Step 2: train / validation split ---
    print(f"\n--- Chia t·∫≠p Train/Val (Ratio: {VAL_SPLIT_RATIO}) ---")
    en_train, en_val, vi_train, vi_val = train_test_split(
        raw_en, raw_vi, test_size=VAL_SPLIT_RATIO, random_state=SEED
    )
    print(f"-> Train: {len(en_train)} | Val: {len(en_val)}")

    # --- Step 3: load the SentencePiece models ---
    print("\n--- Load SPM Models ---")
    sp_en = spm.SentencePieceProcessor()
    sp_en.load(os.path.join(DATA_DIR, 'spm_en.model'))
    sp_vi = spm.SentencePieceProcessor()
    sp_vi.load(os.path.join(DATA_DIR, 'spm_vi.model'))

    print(f"-> Vocab Size: {sp_en.get_piece_size()} / {sp_vi.get_piece_size()}")

    # --- Step 4: datasets and loaders ---
    print("\n--- T·∫°o DataLoader ---")
    train_dataset, val_dataset = (
        TranslationDatasetSPM(en_part, vi_part, sp_en, sp_vi, max_len=MAX_SEQ_LEN)
        for en_part, vi_part in ((en_train, vi_train), (en_val, vi_val))
    )

    collate_fn = MyCollateSPM(pad_idx=PAD_IDX)
    shared_kwargs = dict(batch_size=BATCH_SIZE, collate_fn=collate_fn)

    # Only the training data is shuffled.
    train_loader = DataLoader(train_dataset, shuffle=True, **shared_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **shared_kwargs)

    print(f"-> Train Batches: {len(train_loader)}")
    print(f"-> Val Batches:   {len(val_loader)}")

    return train_loader, val_loader

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import os
import math
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
import copy


def main():
    """Build the data loaders, model, optimizer and criteria, then train.

    Every hyperparameter comes from the module-level configuration constants,
    apart from the AdamW settings written inline below.
    """
    train_loader, val_loader = process_data_pipeline_for_test()
    print(f"-> Train/Val Loader ƒë√£ s·∫µn s√†ng (Batch Size: {BATCH_SIZE})")

    # ----------------------------------------------------
    # Model
    # ----------------------------------------------------
    print("3. ƒêang kh·ªüi t·∫°o m√¥ h√¨nh Transformer...")

    model = make_model(
        VOCAB_SIZE_SPM, VOCAB_SIZE_SPM,
        N=N_LAYERS, d_model=D_MODEL, h=N_HEADS, d_ff=D_FF, dropout=DROPOUT_RATE,
    )
    model.to(DEVICE)
    print(f"-> M√¥ h√¨nh ƒë√£ ƒë∆∞·ª£c kh·ªüi t·∫°o v√† chuy·ªÉn sang {DEVICE}. T·ªïng tham s·ªë: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    # ----------------------------------------------------
    # Optimizer and criteria
    # ----------------------------------------------------
    print("4. ƒêang thi·∫øt l·∫≠p Loss Function v√† Optimizer...")

    # Base optimizer starts at lr=0; NoamOpt sets the rate on every step.
    adamw = optim.AdamW(
        model.parameters(),
        lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=1e-4,
    )
    optimizer = NoamOpt(model_size=D_MODEL, warmup_steps=WARMUP_STEPS, optimizer=adamw)

    # Training uses label smoothing; validation uses plain summed
    # cross-entropy so the reported perplexity is not affected by smoothing.
    criterion_train = LabelSmoothing(
        size=VOCAB_SIZE_SPM, padding_idx=PAD_IDX, smoothing=LABEL_SMOOTHING
    ).to(DEVICE)
    criterion_val = nn.CrossEntropyLoss(ignore_index=PAD_IDX, reduction='sum').to(DEVICE)

    print(f"-> Optimizer (Noam) v√† Loss (Label Smoothing {LABEL_SMOOTHING}) ƒë√£ s·∫µn s√†ng.")

    # ----------------------------------------------------
    # Training loop
    # ----------------------------------------------------
    print("5. B·∫Øt ƒë·∫ßu V√≤ng L·∫∑p Hu·∫•n Luy·ªán...")

    training_kwargs = dict(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion_train=criterion_train,
        criterion_val=criterion_val,
        optimizer=optimizer,
        pad_idx=PAD_IDX,
        n_epochs=EPOCHS,
        clip_norm=GRAD_CLIP,
    )
    run_full_training_pipeline(**training_kwargs)


# --- B∆Ø·ªöC 2: KH·ªêI CH·∫†Y CH∆Ø∆†NG TR√åNH ---
if __name__ == "__main__":
    # Make sure the checkpoint directory exists. exist_ok=True replaces the
    # separate exists() check (no check-then-create race) and matches what
    # save_checkpoint does.
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    # Training is not started automatically; call main() explicitly.
    #main()
    print("\n--- ƒê√É HO√ÄN T·∫§T SETUP H√ÄM MAIN ---")
    print("B√¢y gi·ªù b·∫°n ch·ªâ c·∫ßn g·ªçi h√†m main() ƒë·ªÉ b·∫Øt ƒë·∫ßu training.")


--- ƒê√É HO√ÄN T·∫§T SETUP H√ÄM MAIN ---
B√¢y gi·ªù b·∫°n ch·ªâ c·∫ßn g·ªçi h√†m main() ƒë·ªÉ b·∫Øt ƒë·∫ßu training.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import sentencepiece as spm
import os
import math
import copy # V·∫´n c·∫ßn copy v√† math cho vi·ªác kh·ªüi t·∫°o model

# ==============================================================================
# 1. C·∫§U H√åNH & LOAD SPM (Kh·ªüi t·∫°o c√°c h·∫±ng s·ªë c·∫ßn thi·∫øt)
# ==============================================================================


# Load the English and Vietnamese SentencePiece models from DATA_DIR into the
# module-level names sp_en / sp_vi used by the decoding helpers below.
print("--- 1. LOADING SPM MODELS ---")
try:
    sp_en = spm.SentencePieceProcessor()
    sp_en.load(os.path.join(DATA_DIR, 'spm_en.model'))
    sp_vi = spm.SentencePieceProcessor()
    sp_vi.load(os.path.join(DATA_DIR, 'spm_vi.model'))

except Exception as e:
    # Any failure is reported and the session is ended via exit().
    # NOTE(review): this hides the traceback - confirm that is intended.
    print(f"L·ªñI LOAD SPM: {e}. Vui l√≤ng ki·ªÉm tra l·∫°i ƒë∆∞·ªùng d·∫´n.")
    exit()


try:
    # 1. Build an untrained model with the same architecture as training.
    # d_ff and dropout are passed explicitly, as main() does, so the inference
    # model cannot drift from the checkpointed one if make_model's defaults
    # ever differ from the configuration constants.
    model = make_model(VOCAB_SIZE_SPM, VOCAB_SIZE_SPM, N=N_LAYERS, d_model=D_MODEL, h=N_HEADS,
                       d_ff=D_FF, dropout=DROPOUT_RATE).to(DEVICE)
    print("-> ƒê√£ kh·ªüi t·∫°o c·∫•u tr√∫c m√¥ h√¨nh (c·∫ßn load tr·ªçng s·ªë).")
except NameError:
    # Raised when make_model (or a configuration constant) has not been
    # defined yet, i.e. an earlier cell was not run.
    print("\n--- L·ªñI QUAN TR·ªåNG ---")
    print("  L·ªõp ho·∫∑c h√†m `make_model` CH∆ØA ƒê∆Ø·ª¢C ƒê·ªäNH NGHƒ®A. Vui l√≤ng ch·∫°y l·∫°i ƒëo·∫°n code ƒë·ªãnh nghƒ©a ki·∫øn tr√∫c Transformer tr∆∞·ªõc ƒë√≥.")
    exit()


# ------------------------------------------------------------------------------
# 3. T·∫¢I TR·ªåNG S·ªê T·ª™ CHECKPOINT V√Ä H√ÄM DECODE
# ------------------------------------------------------------------------------

print("\n--- 3. T·∫¢I TR·ªåNG S·ªê V√Ä SETUP INFERENCE ---")

# 1. Load the weights from the latest training checkpoint, if present.
checkpoint_path = os.path.join(CHECKPOINT_DIR, 'latest_checkpoint.pth')

if os.path.exists(checkpoint_path):
    # map_location lets a checkpoint saved on another device load onto DEVICE.
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(DEVICE)
else:
    # No checkpoint: only a warning is printed and the model keeps whatever
    # weights it currently has.
    print(f"  C·∫¢NH B√ÅO: Kh√¥ng t√¨m th·∫•y checkpoint t·∫°i {checkpoint_path}.")
    print("Model s·∫Ω ch·∫°y v·ªõi tr·ªçng s·ªë ng·∫´u nhi√™n ho·∫∑c ƒë√£ ƒë∆∞·ª£c kh·ªüi t·∫°o tr∆∞·ªõc ƒë√≥.")

# H√†m h·ªó tr·ª£ cho Masking
def subsequent_mask(size):
    """Return a ``(1, size, size)`` boolean causal mask.

    Entry ``[0, i, j]`` is ``True`` when ``j <= i`` (position ``i`` may attend
    to position ``j``) and ``False`` for later positions.
    """
    future_positions = torch.ones((1, size, size)).triu(diagonal=1).to(torch.uint8)
    return future_positions.eq(0)

def make_std_mask_decode(tgt, pad_idx):
    """Combine the padding mask and the causal mask for a target tensor.

    Args:
        tgt: tensor of target token ids; the mask is built over its last
            dimension.
        pad_idx: id treated as padding (masked out).

    Returns:
        Boolean tensor: ``(tgt != pad_idx).unsqueeze(-2)`` AND-ed with
        ``subsequent_mask(tgt.size(-1))``.
    """
    tgt_mask = (tgt != pad_idx).unsqueeze(-2)
    # Follow the device of `tgt` instead of the global DEVICE so the two masks
    # are always on the same device, whatever device the caller passed in.
    causal_mask = subsequent_mask(tgt.size(-1)).to(device=tgt_mask.device, dtype=tgt_mask.dtype)
    return tgt_mask & causal_mask

# H√†m d·ªãch (Greedy Decoding)
def simple_greedy_decode(model, src, src_mask, max_len, sp_vi):
    """Greedy decoding of a single source sentence.

    Args:
        model: exposes ``eval``, ``encode``, ``decode`` and ``generator``.
        src: source id tensor; the ``.item()`` call below requires a batch of
            exactly one sentence.
        src_mask: mask passed to ``encode`` and ``decode``.
        max_len: upper bound on the decoded length, counting the BOS token.
        sp_vi: target SentencePiece processor supplying the BOS/EOS/PAD ids
            and ``decode_ids``.

    Returns:
        The decoded target string, without BOS and cut at the first EOS.
    """
    model.eval()
    bos_id, eos_id = sp_vi.bos_id(), sp_vi.eos_id()

    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(bos_id).type_as(src.data).to(DEVICE)

    for _step in range(1, max_len):
        step_mask = make_std_mask_decode(ys, sp_vi.pad_id())
        decoder_out = model.decode(memory, src_mask, ys, step_mask)
        # Only the last position is needed to pick the next token.
        scores = model.generator(decoder_out[:, -1])
        next_word = torch.max(scores, dim=1)[1].item()

        appended = torch.ones(1, 1).type_as(src.data).fill_(next_word).to(DEVICE)
        ys = torch.cat([ys, appended], dim=1)

        if next_word == eos_id:
            break

    # Drop BOS, then everything from the first EOS onwards.
    translated_ids = ys.squeeze(0).tolist()[1:]
    try:
        translated_ids = translated_ids[:translated_ids.index(eos_id)]
    except ValueError:
        pass  # no EOS was generated; keep the full sequence

    return sp_vi.decode_ids(translated_ids)


def beam_search_decode(model, src, src_mask, max_len, sp_vi, beam_size=12):
    """Beam-search decoding of a single source sentence.

    Hypotheses are ranked by their summed token log-probability (no length
    normalization).

    Args:
        model: exposes ``eval``, ``encode``, ``decode`` and ``generator``.
        src: source id tensor for one sentence.
        src_mask: mask passed to ``encode`` and ``decode``.
        max_len: upper bound on the decoded length, counting the BOS token.
        sp_vi: target SentencePiece processor supplying the BOS/EOS/PAD ids
            and ``decode_ids``.
        beam_size: number of hypotheses kept after every step.

    Returns:
        The decoded string of the best hypothesis, without BOS and cut at the
        first EOS.
    """
    model.eval()
    device = src.device
    bos_id, eos_id, pad_id = sp_vi.bos_id(), sp_vi.eos_id(), sp_vi.pad_id()

    memory = model.encode(src, src_mask)

    # Live hypotheses as (sequence, summed log-prob); finished ones move to `completed`.
    beams = [(torch.tensor([[bos_id]], device=device), 0.0)]
    completed = []

    for _ in range(max_len - 1):
        new_beams = []

        for seq, score in beams:
            tgt_mask = make_std_mask_decode(seq, pad_id)
            out = model.decode(memory, src_mask, seq, tgt_mask)
            logits = model.generator(out[:, -1])
            log_probs = F.log_softmax(logits, dim=-1)

            # topk raises if k exceeds the vocabulary size, so clamp it.
            k = min(beam_size, log_probs.size(-1))
            topk_log_probs, topk_ids = torch.topk(log_probs, k, dim=-1)

            for step_log_prob, next_id in zip(topk_log_probs[0].tolist(), topk_ids[0].tolist()):
                next_seq = torch.cat(
                    [seq, torch.tensor([[next_id]], device=device)], dim=1
                )
                new_beams.append((next_seq, score + step_log_prob))

        # Keep the best expansions, then retire the ones that just emitted EOS.
        # Retiring them here (rather than at the start of the next iteration)
        # means a hypothesis finishing on the final step is still considered.
        new_beams.sort(key=lambda x: x[1], reverse=True)
        beams = []
        for seq, score in new_beams[:beam_size]:
            if seq[0, -1].item() == eos_id:
                completed.append((seq, score))
            else:
                beams.append((seq, score))

        # Stop when enough hypotheses finished or nothing is left to expand.
        if not beams or len(completed) >= beam_size:
            break

    # No hypothesis reached EOS: fall back to the unfinished beams.
    if len(completed) == 0:
        completed = beams

    best_seq = max(completed, key=lambda x: x[1])[0]

    # Drop BOS and everything from the first EOS onwards.
    translated_ids = best_seq.squeeze(0).tolist()[1:]
    if eos_id in translated_ids:
        translated_ids = translated_ids[:translated_ids.index(eos_id)]

    return sp_vi.decode_ids(translated_ids)



# ------------------------------------------------------------------------------
# 4. TEST C√ÇU ƒê∆†N GI·∫¢N
# ------------------------------------------------------------------------------
def run_simple_test(input_sentence):
    """Translate one English sentence with the global model and print the result.

    The sentence is lower-cased, tokenized with ``sp_en``, wrapped in
    ``BOS_IDX``/``EOS_IDX`` and decoded with ``beam_search_decode`` using a
    maximum length of 100 and that function's default beam size.
    """
    model.eval()

    # 1. Tokenize and build a batch of one sentence.
    token_ids = [BOS_IDX, *sp_en.encode_as_ids(input_sentence.lower()), EOS_IDX]
    src_tensor = torch.tensor([token_ids], dtype=torch.long).to(DEVICE)
    src_mask = (src_tensor != PAD_IDX).unsqueeze(-2).to(DEVICE)

    # 2. Translate without tracking gradients.
    with torch.no_grad():
        translation = beam_search_decode(model, src_tensor, src_mask, 100, sp_vi)

    divider = "-" * 50
    print(divider)
    print(f"EN Input: {input_sentence}")
    print(f"VI Output: {translation}")
    print(divider)


# Section banner. The sample translations below are left commented out;
# uncomment them to run a quick translation check with the loaded model.
print("\n--- 4. TH·ª¨ NGHI·ªÜM D·ªäCH (Sau khi load model) ---")
#run_simple_test("I am happy to test the translation model.")
#run_simple_test("What is your name?")

--- 1. LOADING SPM MODELS ---
-> ƒê√£ kh·ªüi t·∫°o c·∫•u tr√∫c m√¥ h√¨nh (c·∫ßn load tr·ªçng s·ªë).

--- 3. T·∫¢I TR·ªåNG S·ªê V√Ä SETUP INFERENCE ---

--- 4. TH·ª¨ NGHI·ªÜM D·ªäCH (Sau khi load model) ---


# Đánh giá

In [None]:
!pip install sacrebleu

Collecting sacrebleu
  Downloading sacrebleu-2.5.1-py3-none-any.whl.metadata (51 kB)
[?25l     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m0.0/51.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m51.8/51.8 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting portalocker (from sacrebleu)
  Downloading portalocker-3.2.0-py3-none-any.whl.metadata (8.7 kB)
Collecting colorama (from sacrebleu)
  Downloading colorama-0.4.6-py2.py3-none-any.whl.metadata (17 kB)
Downloading sacrebleu-2.5.1-py3-none-any.whl (104 kB)
[?25l   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m0.0/104.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
import os, math, json
import sentencepiece as spm
import sacrebleu
import matplotlib.pyplot as plt

# English (source) tokenizer.
sp_en = spm.SentencePieceProcessor()
sp_en.load(os.path.join(DATA_DIR, 'spm_en.model'))

# Vietnamese (target) tokenizer.
sp_vi = spm.SentencePieceProcessor()
sp_vi.load(os.path.join(DATA_DIR, 'spm_vi.model'))

def subsequent_mask(size):
    """Return a ``(1, size, size)`` boolean causal mask created on ``DEVICE``.

    Entry ``[0, i, j]`` is ``True`` when ``j <= i``, i.e. a position may attend
    to itself and to earlier positions only.
    """
    allowed = torch.ones(size, size, device=DEVICE).tril().bool()
    return allowed.unsqueeze(0)

def make_std_mask_decode(tgt, pad_idx):
    """Combine the padding mask and the causal mask for a target tensor.

    Args:
        tgt: tensor of target token ids; the mask covers its last dimension.
        pad_idx: id treated as padding (masked out).

    Returns:
        Boolean tensor that is ``True`` where a position is not padding and is
        not in the future of the query position.
    """
    not_padding = tgt.ne(pad_idx).unsqueeze(-2)
    causal = subsequent_mask(tgt.size(-1))
    return not_padding & causal


# Build the evaluation model from the same configuration constants that
# main() uses for training, rather than repeating the numbers here, so the
# architecture always matches the checkpoint loaded afterwards.
model = make_model(
    VOCAB_SIZE_SPM,
    VOCAB_SIZE_SPM,
    N=N_LAYERS,
    d_model=D_MODEL,
    h=N_HEADS,
    d_ff=D_FF,
    dropout=DROPOUT_RATE,
).to(DEVICE)

# Load the weights of the best checkpoint (best_model.pth is written by
# save_checkpoint when is_best is true). There is no existence check, so a
# missing file raises here.
checkpoint = torch.load(
    os.path.join(CHECKPOINT_DIR, 'best_model.pth'),
    map_location=DEVICE
)
model.load_state_dict(checkpoint['model_state_dict'])
# Evaluation mode for the PPL / BLEU computations below.
model.eval()

print("Model & SentencePiece loaded successfully")




# ==============================================================================
# 4. DATASET CHO PPL
# ==============================================================================
class TestDatasetPPL(Dataset):
    """Parallel test set that yields ``(src_ids, tgt_ids)`` tensors for PPL.

    Sentences are lower-cased, tokenized with the module-level ``sp_en`` /
    ``sp_vi`` processors and wrapped in ``BOS_IDX`` ... ``EOS_IDX``. No length
    truncation is applied.
    """

    def __init__(self, raw_en, raw_vi):
        # Pair sentences positionally; zip stops at the shorter input.
        self.data = [pair for pair in zip(raw_en, raw_vi)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        en_text, vi_text = self.data[idx]
        src_ids = [BOS_IDX, *sp_en.encode_as_ids(en_text.lower()), EOS_IDX]
        tgt_ids = [BOS_IDX, *sp_vi.encode_as_ids(vi_text.lower()), EOS_IDX]
        return torch.tensor(src_ids), torch.tensor(tgt_ids)

class CollatePPL:
    """Collate function that pads a batch of ``(src, tgt)`` id tensors.

    Both sides are right-padded with ``pad_idx`` to the longest sequence in the
    batch (batch-first layout) and moved to ``DEVICE``.
    """

    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def _pad(self, sequences):
        # Pad one side of the batch and move it to DEVICE.
        padded = pad_sequence(sequences, batch_first=True, padding_value=self.pad_idx)
        return padded.to(DEVICE)

    def __call__(self, batch):
        src_seqs, tgt_seqs = zip(*batch)
        return self._pad(src_seqs), self._pad(tgt_seqs)

# ==============================================================================
# 5. PERPLEXITY
# ==============================================================================
def calculate_ppl(loader):
    """Compute average per-token cross-entropy and perplexity of the global model.

    Note: this definition shadows any earlier ``calculate_ppl`` that takes
    ``(val_loader, model, pad_idx, criterion_val)``.

    Args:
        loader: iterable of ``(src, tgt)`` id tensors, padded with ``PAD_IDX``
            and already on the model's device.

    Returns:
        ``(avg_loss, perplexity)``; ``(inf, inf)`` when the loader contributes
        no non-padding target tokens (instead of dividing by zero).
    """
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX, reduction='sum')
    total_loss, total_tokens = 0.0, 0

    # Make sure dropout is disabled even if the model was left in train mode.
    model.eval()
    with torch.no_grad():
        for src, tgt in tqdm(loader, desc="Calculating PPL"):
            # Teacher forcing: feed tgt[:-1], predict tgt[1:].
            tgt_in = tgt[:, :-1]
            tgt_y = tgt[:, 1:]

            src_mask = (src != PAD_IDX).unsqueeze(-2)
            tgt_mask = make_std_mask_decode(tgt_in, PAD_IDX)

            logits = model(src, tgt_in, src_mask, tgt_mask)

            # reshape (unlike view) also works if logits are not contiguous.
            loss = criterion(
                logits.reshape(-1, logits.size(-1)),
                tgt_y.reshape(-1)
            )

            total_loss += loss.item()
            total_tokens += (tgt_y != PAD_IDX).sum().item()

    if total_tokens == 0:
        return float('inf'), float('inf')

    avg_loss = total_loss / total_tokens
    return avg_loss, math.exp(avg_loss)

# ==============================================================================
# 6. BEAM SEARCH (1 c√¢u)
# ==============================================================================
def beam_translate(src_sentence):
    """Translate one source sentence with length-normalized beam search.

    Uses the module-level ``model``, ``sp_en`` and ``sp_vi``. Hypotheses are
    ranked by ``score / length ** 0.7`` and the search stops as soon as the
    top-ranked hypothesis ends with ``EOS_IDX`` or after ``MAX_LEN_DECODE``
    steps.

    Args:
        src_sentence: raw source string; it is lower-cased before tokenization.

    Returns:
        The decoded target string, without BOS and cut at the first EOS.
    """
    src_ids = [BOS_IDX] + sp_en.encode_as_ids(src_sentence.lower()) + [EOS_IDX]
    src = torch.tensor(src_ids).unsqueeze(0).to(DEVICE)
    src_mask = (src != PAD_IDX).unsqueeze(-2)

    # Inference only: without no_grad every decode step would record an
    # autograd graph, wasting memory and time over a whole test set.
    with torch.no_grad():
        memory = model.encode(src, src_mask)

        # Hypotheses as (summed log-prob, sequence).
        beams = [(0.0, torch.tensor([[BOS_IDX]], device=DEVICE))]

        for _ in range(MAX_LEN_DECODE):
            candidates = []

            for score, seq in beams:
                # Finished hypotheses are carried over unchanged.
                if seq[0, -1].item() == EOS_IDX:
                    candidates.append((score, seq))
                    continue

                tgt_mask = make_std_mask_decode(seq, PAD_IDX)
                out = model.decode(memory, src_mask, seq, tgt_mask)
                log_prob = F.log_softmax(model.generator(out[:, -1]), dim=-1)

                # topk raises if k exceeds the vocabulary size, so clamp it.
                k = min(BEAM_SIZE, log_prob.size(-1))
                topk = torch.topk(log_prob, k)

                for i in range(k):
                    new_score = score + topk.values[0, i].item()
                    new_seq = torch.cat(
                        [seq, topk.indices[0, i].view(1, 1)], dim=1
                    )
                    candidates.append((new_score, new_seq))

            # Length normalization keeps short hypotheses from dominating.
            beams = sorted(
                candidates,
                key=lambda x: x[0] / (x[1].size(1) ** 0.7),
                reverse=True
            )[:BEAM_SIZE]

            if beams[0][1][0, -1].item() == EOS_IDX:
                break

    # Drop BOS and everything from the first EOS onwards.
    best_seq = beams[0][1][0, 1:].tolist()
    if EOS_IDX in best_seq:
        best_seq = best_seq[:best_seq.index(EOS_IDX)]

    return sp_vi.decode_ids(best_seq)

# ==============================================================================
# 7. BLEU + REPORT
# ==============================================================================
def evaluate():
    """Evaluate the global model on the test set and print a report.

    Computes perplexity with ``calculate_ppl`` and corpus BLEU (sacrebleu)
    over beam-search translations of every test sentence.

    Returns:
        ``(ppl, bleu_score)``.

    NOTE(review): the source sentences are lower-cased before translation
    while the references are passed to sacrebleu as read from disk - confirm
    that the casing of hypotheses and references is meant to differ.
    """
    raw_en = load_raw_data(TEST_EN_PATH)
    raw_vi = load_raw_data(TEST_VI_PATH)

    # --- Perplexity ---
    ppl_loader = DataLoader(
        TestDatasetPPL(raw_en, raw_vi),
        batch_size=BATCH_SIZE,
        collate_fn=CollatePPL(PAD_IDX)
    )
    loss, ppl = calculate_ppl(ppl_loader)

    # --- BLEU ---
    print("\n Translating with Beam Search...")
    hypotheses = []
    for sentence in tqdm(raw_en):
        hypotheses.append(beam_translate(sentence))

    bleu = sacrebleu.corpus_bleu(hypotheses, [raw_vi])

    # --- Report ---
    rule = "=" * 60
    print("\n" + rule)
    print(" FINAL EVALUATION RESULTS")
    print(rule)
    print(f"Test Loss: {loss:.4f}")
    print(f"Test Perplexity (PPL): {ppl:.2f}")
    print(f"BLEU Score: {bleu.score:.2f}")
    print(rule)
    print(bleu.format())
    print(rule)

    return ppl, bleu.score

# ==============================================================================
# 8. RUN
# ==============================================================================
# Run the full test-set evaluation. As the last expression of the cell, the
# returned (ppl, bleu_score) tuple is also echoed as the cell output.
evaluate()


‚úÖ Model & SentencePiece loaded successfully
-> ƒêang ƒë·ªçc file: tst2013.en...
-> ƒêang ƒë·ªçc file: tst2013.vi...


Calculating PPL: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.06it/s]



üîÅ Translating with Beam Search...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1268/1268 [28:43<00:00,  1.36s/it]



üìä FINAL EVALUATION RESULTS
Test Loss: 2.0334
Test Perplexity (PPL): 7.64
BLEU Score: 26.37
BLEU = 26.37 59.3/33.9/20.7/12.9 (BP = 0.975 ratio = 0.975 hyp_len = 32905 ref_len = 33738)


(7.639712025620963, 26.36685262435577)