# Transformer



In [None]:
# (Colab) Install the third-party libraries this notebook imports below
!pip install -q datasets tokenizers sacrebleu tqdm


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.1/104.1 kB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import math
import random
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Sampler

from datasets import load_dataset, Dataset, DatasetDict
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

from sacrebleu.metrics import BLEU
from tqdm import tqdm

# Report the installed PyTorch build, then pick the compute device:
# the GPU when CUDA is usable, otherwise the CPU.
print("torch:", torch.__version__)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("device:", device)


torch: 2.9.0+cu126
device: cuda


In [None]:
# Seeding for reproducibility
def set_seed(seed: int = 42):
    """Seed Python's `random` module and PyTorch (CPU and all CUDA devices).

    Args:
        seed: integer seed applied to every generator.
    """
    for seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

set_seed(42)


## 1) Load dữ liệu + tạo dataset 2 chiều

Dùng dataset `thainq107/iwslt2015-en-vi` (HuggingFace).  
Ta tạo thêm trường:
- `src`: câu nguồn
- `tgt`: câu đích
- `tgt_lang`: `"vi"` hoặc `"en"` (để biết cần gắn tag nào lên input)
- `direction`: `"en->vi"` hoặc `"vi->en"`


In [None]:
# Load the `thainq107/iwslt2015-en-vi` parallel corpus through Hugging Face `datasets`.
raw = load_dataset("thainq107/iwslt2015-en-vi")
# Bare expression: the notebook displays the resulting object as the cell output.
raw


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/522 [00:00<?, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/17.8M [00:00<?, ?B/s]

data/validation-00000-of-00001.parquet:   0%|          | 0.00/181k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/133317 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1268 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1268 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['en', 'vi'],
        num_rows: 133317
    })
    validation: Dataset({
        features: ['en', 'vi'],
        num_rows: 1268
    })
    test: Dataset({
        features: ['en', 'vi'],
        num_rows: 1268
    })
})

In [None]:
def build_bidirectional_split(ds_split) -> Dataset:
    """Turn one en/vi split into a split holding both translation directions.

    Every input example (with keys "en" and "vi") produces two rows, in this
    order: an en->vi row followed by a vi->en row.

    Args:
        ds_split: iterable of examples indexable by "en" and "vi".

    Returns:
        A `Dataset` with columns `src`, `tgt`, `tgt_lang` and `direction`.
    """
    columns = {"src": [], "tgt": [], "tgt_lang": [], "direction": []}

    def _add_row(src_text, tgt_text, tgt_lang, direction):
        # Append one row across all four parallel columns.
        columns["src"].append(src_text)
        columns["tgt"].append(tgt_text)
        columns["tgt_lang"].append(tgt_lang)
        columns["direction"].append(direction)

    for pair in ds_split:
        english = pair["en"]
        vietnamese = pair["vi"]
        _add_row(english, vietnamese, "vi", "en->vi")
        _add_row(vietnamese, english, "en", "vi->en")

    return Dataset.from_dict(columns)

# Expand every split of the raw corpus into its bidirectional form.
datasets = DatasetDict({
    split_name: build_bidirectional_split(raw[split_name])
    for split_name in ("train", "validation", "test")
})

datasets


DatasetDict({
    train: Dataset({
        features: ['src', 'tgt', 'tgt_lang', 'direction'],
        num_rows: 266634
    })
    validation: Dataset({
        features: ['src', 'tgt', 'tgt_lang', 'direction'],
        num_rows: 2536
    })
    test: Dataset({
        features: ['src', 'tgt', 'tgt_lang', 'direction'],
        num_rows: 2536
    })
})

In [None]:
# Look at a few examples (the first three rows of the bidirectional training split)
for i in range(3):
    print(datasets["train"][i])


{'src': 'Rachel Pike : The science behind a climate headline', 'tgt': 'Khoa học đằng sau một tiêu đề về khí hậu', 'tgt_lang': 'vi', 'direction': 'en->vi'}
{'src': 'Khoa học đằng sau một tiêu đề về khí hậu', 'tgt': 'Rachel Pike : The science behind a climate headline', 'tgt_lang': 'en', 'direction': 'vi->en'}
{'src': 'In 4 minutes , atmospheric chemist Rachel Pike provides a glimpse of the massive scientific effort behind the bold headlines on climate change , with her team -- one of thousands who contributed -- taking a risky flight over the rainforest in pursuit of data on a key molecule .', 'tgt': 'Trong 4 phút , chuyên gia hoá học khí quyển Rachel Pike giới thiệu sơ lược về những nỗ lực khoa học miệt mài đằng sau những tiêu đề táo bạo về biến đổi khí hậu , cùng với đoàn nghiên cứu của mình -- hàng ngàn người đã cống hiến cho dự án này -- một chuyến bay mạo hiểm qua rừng già để tìm kiếm thông tin về một phân tử then chốt .', 'tgt_lang': 'vi', 'direction': 'en->vi'}


## 2) Tokenizer joint BPE + special tokens (PAD/UNK/BOS/EOS + tag ngôn ngữ)

Ta train 1 tokenizer chung cho cả 2 ngôn ngữ bằng cách lấy corpus gồm:
- `src` và `tgt` của mọi split (train/val/test)

Special tokens:
- `[PAD] [UNK] [BOS] [EOS]`
- `[2EN]` (mục tiêu dịch sang tiếng Anh)
- `[2VI]` (mục tiêu dịch sang tiếng Việt)


In [None]:
# Surface forms of the special tokens shared by both languages.
PAD_TOKEN = "[PAD]"
UNK_TOKEN = "[UNK]"
BOS_TOKEN = "[BOS]"
EOS_TOKEN = "[EOS]"
TAG_EN = "[2EN]"  # tag meaning "translate into English"
TAG_VI = "[2VI]"  # tag meaning "translate into Vietnamese"

# This list is handed to the BPE trainer as its `special_tokens`.
special_tokens = [PAD_TOKEN, UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, TAG_EN, TAG_VI]

# Tokenizer hyperparameters
vocab_size = 16000

# One BPE tokenizer for both languages, pre-tokenised with `Whitespace`.
# NOTE(review): no `tokenizer.decoder` is configured here — confirm that
# `tokenizer.decode` re-joins sub-word pieces the way BLEU evaluation expects.
tokenizer = Tokenizer(BPE(unk_token=UNK_TOKEN))
tokenizer.pre_tokenizer = Whitespace()
trainer = BpeTrainer(vocab_size=vocab_size, special_tokens=special_tokens)

def joint_corpus():
    """Yield the `src` and then the `tgt` text of every example in every split.

    NOTE(review): the validation and test splits are part of the tokenizer
    training corpus — confirm this is acceptable for the reported scores.
    """
    for split in ["train", "validation", "test"]:
        for ex in datasets[split]:
            yield ex["src"]
            yield ex["tgt"]

print("Training joint tokenizer...")
tokenizer.train_from_iterator(joint_corpus(), trainer)

# Persist the trained tokenizer next to the notebook.
tokenizer.save("tokenizer_joint_envi.json")
print("Saved tokenizer_joint_envi.json")
print("Vocab size:", tokenizer.get_vocab_size())


Training joint tokenizer...
Saved tokenizer_joint_envi.json
Vocab size: 16000


In [None]:
# Look up the integer id of every special token
pad_id, unk_id, bos_id, eos_id, tag_en_id, tag_vi_id = (
    tokenizer.token_to_id(token)
    for token in (PAD_TOKEN, UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, TAG_EN, TAG_VI)
)

# Every special token must be present in the trained vocabulary.
assert all(
    token_id is not None
    for token_id in (pad_id, unk_id, bos_id, eos_id, tag_en_id, tag_vi_id)
)

# Maps a target-language code to the tag id that is placed on the source side.
lang2tag_id = dict(en=tag_en_id, vi=tag_vi_id)

print(dict(
    pad_id=pad_id, unk_id=unk_id, bos_id=bos_id, eos_id=eos_id,
    tag_en_id=tag_en_id, tag_vi_id=tag_vi_id,
))


{'pad_id': 0, 'unk_id': 1, 'bos_id': 2, 'eos_id': 3, 'tag_en_id': 4, 'tag_vi_id': 5}


## 3) Encode/Pad + DataLoader (Bucket batching)

Quy ước encode:
- **Source**: `[BOS] [2XX] <tokens> [EOS]` (tag cho biết *ngôn ngữ đích*)
- **Target**: `[BOS] <tokens> [EOS]`


In [None]:
# Data hyperparameters
batch_size = 64
max_src_len = 80  # includes BOS + TAG + EOS
max_tgt_len = 80  # includes BOS + EOS

def encode_src(text: str, tgt_lang: str, max_len: int) -> List[int]:
    """Encode a source sentence as `[BOS] [tag] <tokens> [EOS]`.

    Args:
        text: raw source sentence.
        tgt_lang: key of `lang2tag_id` selecting the target-language tag.
        max_len: maximum length of the returned list, specials included.

    Returns:
        Token ids, with the sentence tokens truncated to `max_len - 3`.
    """
    prefix = [bos_id, lang2tag_id[tgt_lang]]
    # three positions are reserved for BOS + TAG + EOS
    body = tokenizer.encode(text).ids[: max_len - 3]
    return prefix + body + [eos_id]

def encode_tgt(text: str, max_len: int) -> List[int]:
    """Encode a target sentence as `[BOS] <tokens> [EOS]`.

    Args:
        text: raw target sentence.
        max_len: maximum length of the returned list, specials included.

    Returns:
        Token ids, with the sentence tokens truncated to `max_len - 2`.
    """
    budget = max_len - 2  # room left after BOS + EOS
    token_ids = tokenizer.encode(text).ids
    return [bos_id, *token_ids[:budget], eos_id]

def pad_sequences(seqs: List[List[int]], pad_value: int) -> torch.Tensor:
    """Right-pad variable-length id lists into one `[len(seqs), longest]` long tensor.

    Args:
        seqs: non-empty list of token-id lists.
        pad_value: value written into the unused trailing positions.

    Returns:
        A `torch.long` tensor whose row `i` starts with `seqs[i]`.
    """
    longest = max(map(len, seqs))
    padded = torch.full((len(seqs), longest), pad_value, dtype=torch.long)
    # Iterating a tensor yields row views, so writing to `row` fills `padded`.
    for row, seq in zip(padded, seqs):
        row[: len(seq)] = torch.tensor(seq, dtype=torch.long)
    return padded

def collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """Encode, pad and shift a list of examples into model-ready tensors.

    Args:
        batch: examples with keys "src", "tgt" and "tgt_lang".

    Returns:
        Dict with "src" (padded source ids), "tgt_input" (target without its
        last position) and "tgt_output" (target without its first position).
    """
    src_batch = pad_sequences(
        [encode_src(ex["src"], ex["tgt_lang"], max_src_len) for ex in batch],
        pad_id,
    )
    tgt_batch = pad_sequences(
        [encode_tgt(ex["tgt"], max_tgt_len) for ex in batch],
        pad_id,
    )

    # shift-right: the decoder input at step t is paired with the label at t+1
    return {
        "src": src_batch,
        "tgt_input": tgt_batch[:, :-1].contiguous(),
        "tgt_output": tgt_batch[:, 1:].contiguous(),
    }

class BucketBatchSampler(Sampler[List[int]]):
    """Batch sampler that groups sentences of similar length to reduce padding."""

    def __init__(self, lengths: List[int], batch_size: int, shuffle: bool = True):
        """Store the per-example lengths, the batch size and the shuffle flag."""
        self.lengths = lengths
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        """Yield lists of dataset indices, one list per batch."""
        order = list(range(len(self.lengths)))
        if self.shuffle:
            random.shuffle(order)

        # Sort by length only inside large chunks (100 batches each) rather
        # than globally, to reduce bias.
        chunk_size = self.batch_size * 100
        for chunk_start in range(0, len(order), chunk_size):
            bucket = sorted(
                order[chunk_start:chunk_start + chunk_size],
                key=self.lengths.__getitem__,
            )
            for batch_start in range(0, len(bucket), self.batch_size):
                yield bucket[batch_start:batch_start + self.batch_size]

    def __len__(self):
        """Number of batches per pass: ceil(len(lengths) / batch_size)."""
        return -(-len(self.lengths) // self.batch_size)

def compute_lengths(ds_split) -> List[int]:
    """Return a cheap length estimate per example for bucketing.

    The estimate is the larger of the whitespace-split word counts of "src"
    and "tgt"; it is fast and good enough for grouping by length.
    """
    return [
        max(len(ex["src"].split()), len(ex["tgt"].split()))
        for ex in ds_split
    ]

# Length estimates for the training split, consumed by the bucketing sampler.
train_lengths = compute_lengths(datasets["train"])
train_sampler = BucketBatchSampler(train_lengths, batch_size=batch_size, shuffle=True)

# Training batches come from the bucketing sampler; validation uses fixed-size,
# unshuffled batches. Both encode and pad on the fly in `collate_fn`.
train_loader = DataLoader(datasets["train"], batch_sampler=train_sampler, collate_fn=collate_fn)
valid_loader = DataLoader(datasets["validation"], batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

print("train batches:", len(train_loader), "valid batches:", len(valid_loader))


train batches: 4167 valid batches: 40


## 4) Mask (padding & causal)

In [None]:
def create_src_mask(src: torch.Tensor, pad_idx: int) -> torch.Tensor:
    """Build the source padding mask.

    Args:
        src: token ids, `[B, S]`.
        pad_idx: id of the padding token.

    Returns:
        Boolean mask `[B, 1, 1, S]`, True at non-padding positions.
    """
    is_token = src.ne(pad_idx)
    return is_token[:, None, None]

def create_tgt_mask(tgt: torch.Tensor, pad_idx: int) -> torch.Tensor:
    """Build the decoder self-attention mask (padding AND causal).

    Args:
        tgt: target token ids, `[B, T]`.
        pad_idx: id of the padding token.

    Returns:
        Boolean mask `[B, 1, T, T]`; entry `[b, 0, i, j]` is True when key
        position `j` is not padding and `j <= i`.
    """
    _, seq_len = tgt.size()
    not_pad = tgt.ne(pad_idx)[:, None, None]  # [B,1,1,T]

    # lower-triangular causal mask, broadcast over the batch: [1,1,T,T]
    lower_tri = torch.ones((seq_len, seq_len), device=tgt.device).tril().bool()
    return not_pad & lower_tri[None, None]


## 5) Transformer (from scratch)

In [None]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to token embeddings.

    The table is precomputed once for `max_len` positions and stored as the
    (non-trainable) buffer `pe` of shape `[1, max_len, d_model]`.

    Args:
        d_model: embedding width. Both even and odd widths are supported.
        max_len: number of positions to precompute.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # Even columns take the sine, odd columns the cosine of the same angles.
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even columns,
        # so the frequencies are trimmed to match (a no-op for even d_model).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))  # [1,max_len,d_model]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the encodings of positions `0..T-1` to `x` of shape `[B,T,d_model]`."""
        return x + self.pe[:, :x.size(1)]

class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention over `num_heads` parallel heads.

    Args:
        d_model: model width; must be divisible by `num_heads`.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, projected: torch.Tensor, batch: int) -> torch.Tensor:
        """Reshape `[B,T,d_model]` into per-head form `[B,H,T,d_k]`."""
        return projected.view(batch, -1, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, query, key, value, mask: Optional[torch.Tensor] = None):
        """Attend from `query` over `key`/`value`.

        Args:
            query: `[B,Tq,d_model]`.
            key, value: `[B,Tk,d_model]`.
            mask: optional mask broadcastable to `[B,H,Tq,Tk]`; positions
                where it equals 0/False are excluded from attention.

        Returns:
            Tensor `[B,Tq,d_model]`.
        """
        batch = query.size(0)

        q_heads = self._split_heads(self.q_linear(query), batch)  # [B,H,Tq,d_k]
        k_heads = self._split_heads(self.k_linear(key), batch)    # [B,H,Tk,d_k]
        v_heads = self._split_heads(self.v_linear(value), batch)  # [B,H,Tk,d_k]

        scores = torch.matmul(q_heads, k_heads.transpose(-2, -1)) / math.sqrt(self.d_k)  # [B,H,Tq,Tk]

        if mask is not None:
            # True = keep, False = mask out
            scores = scores.masked_fill(mask == 0, float("-inf"))

        weights = self.dropout(torch.softmax(scores, dim=-1))

        merged = torch.matmul(weights, v_heads)  # [B,H,Tq,d_k]
        merged = merged.transpose(1, 2).contiguous().view(batch, -1, self.d_model)  # [B,Tq,d_model]
        return self.out_proj(merged)

class EncoderLayer(nn.Module):
    """One encoder block: self-attention then a position-wise feed-forward net.

    Each sub-layer is wrapped as `LayerNorm(x + Dropout(sublayer(x)))`.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        feed_forward = (
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )
        self.ffn = nn.Sequential(*feed_forward)
        self.norm1, self.norm2 = (nn.LayerNorm(d_model) for _ in range(2))
        self.drop1, self.drop2 = (nn.Dropout(dropout) for _ in range(2))

    def forward(self, x, src_mask):
        """Apply the block to `x` using `src_mask` for self-attention."""
        # self-attention sub-layer
        x = self.norm1(x + self.drop1(self.self_attn(x, x, x, src_mask)))
        # feed-forward sub-layer
        return self.norm2(x + self.drop2(self.ffn(x)))

class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, cross-attention, feed-forward.

    Each sub-layer is wrapped as `LayerNorm(x + Dropout(sublayer(x)))`.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
        feed_forward = (
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )
        self.ffn = nn.Sequential(*feed_forward)
        self.norm1, self.norm2, self.norm3 = (nn.LayerNorm(d_model) for _ in range(3))
        self.drop1, self.drop2, self.drop3 = (nn.Dropout(dropout) for _ in range(3))

    def forward(self, x, memory, src_mask, tgt_mask):
        """Apply the block.

        `tgt_mask` restricts self-attention over `x`; `src_mask` restricts
        cross-attention over the encoder output `memory`.
        """
        x = self.norm1(x + self.drop1(self.self_attn(x, x, x, tgt_mask)))
        x = self.norm2(x + self.drop2(self.cross_attn(x, memory, memory, src_mask)))
        return self.norm3(x + self.drop3(self.ffn(x)))

class TransformerEncoder(nn.Module):
    """Token embedding + positional encoding followed by a stack of `EncoderLayer`s."""

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos = PositionalEncoding(d_model)
        self.drop = nn.Dropout(dropout)
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(EncoderLayer(d_model, num_heads, d_ff, dropout))

    def forward(self, src, src_mask):
        """Encode source ids `src` under `src_mask` and return the hidden states."""
        # embeddings are scaled by sqrt(d_model) before adding positions
        scale = math.sqrt(self.embedding.embedding_dim)
        hidden = self.drop(self.pos(self.embedding(src) * scale))
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, src_mask)
        return hidden

class TransformerDecoder(nn.Module):
    """Token embedding + positional encoding followed by a stack of `DecoderLayer`s."""

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos = PositionalEncoding(d_model)
        self.drop = nn.Dropout(dropout)
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(DecoderLayer(d_model, num_heads, d_ff, dropout))

    def forward(self, tgt, memory, src_mask, tgt_mask):
        """Decode target ids `tgt` against encoder output `memory`; returns hidden states."""
        # embeddings are scaled by sqrt(d_model) before adding positions
        scale = math.sqrt(self.embedding.embedding_dim)
        hidden = self.drop(self.pos(self.embedding(tgt) * scale))
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, memory, src_mask, tgt_mask)
        return hidden

class TransformerNMT(nn.Module):
    """Encoder-decoder translation model with a linear projection to the vocabulary."""

    def __init__(self, vocab_size, d_model, num_encoder_layers, num_decoder_layers, num_heads, d_ff, dropout):
        super().__init__()
        # settings shared by the encoder and decoder stacks
        shared = dict(d_model=d_model, num_heads=num_heads, d_ff=d_ff, dropout=dropout)
        self.encoder = TransformerEncoder(vocab_size, num_layers=num_encoder_layers, **shared)
        self.decoder = TransformerDecoder(vocab_size, num_layers=num_decoder_layers, **shared)
        self.output_layer = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt_input, src_mask, tgt_mask):
        """Return vocabulary logits `[B,T,V]` for every position of `tgt_input`."""
        memory = self.encoder(src, src_mask)
        decoded = self.decoder(tgt_input, memory, src_mask, tgt_mask)
        return self.output_layer(decoded)


## 6) Loss (Label smoothing) + Noam LR scheduler

In [None]:
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against a label-smoothed target distribution.

    For every non-ignored target the gold class receives `1 - label_smoothing`
    and the remaining mass is spread uniformly over the other classes,
    excluding the `ignore_index` (padding) class, so that each target row sums
    to 1 and no probability is ever assigned to padding.

    Args:
        label_smoothing: smoothing mass in `[0, 1]`.
        vocab_size: number of classes `V`.
        ignore_index: target value marking positions excluded from the loss.
    """

    def __init__(self, label_smoothing: float, vocab_size: int, ignore_index: int = 0):
        super().__init__()
        assert 0.0 <= label_smoothing <= 1.0
        self.smoothing = label_smoothing
        self.confidence = 1.0 - label_smoothing
        self.vocab_size = vocab_size
        self.ignore_index = ignore_index

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the mean smoothed loss over non-ignored positions.

        Args:
            pred: `[N,V]` unnormalised logits.
            target: `[N]` class indices (`ignore_index` marks skipped rows).

        Returns:
            Scalar tensor; exactly 0 (still attached to the graph) when every
            position is ignored, instead of the NaN a mean over nothing gives.
        """
        pred = pred.log_softmax(dim=-1)

        # The padding class only needs to be carved out of the distribution
        # when `ignore_index` is a real column of the vocabulary.
        pad_in_vocab = 0 <= self.ignore_index < self.vocab_size
        # Classes that share the smoothing mass: everything except the gold
        # class and (when applicable) the padding class.
        num_smoothed = self.vocab_size - 1 - int(pad_in_vocab)

        with torch.no_grad():
            ignore = target == self.ignore_index
            true_dist = torch.full_like(pred, self.smoothing / max(num_smoothed, 1))
            if pad_in_vocab:
                # never place probability mass on the padding token
                true_dist[:, self.ignore_index] = 0.0

            # ignored rows get a dummy index for scatter_; they are zeroed below
            tgt = target.masked_fill(ignore, 0)
            true_dist.scatter_(1, tgt.unsqueeze(1), self.confidence)
            true_dist[ignore] = 0.0

        loss = torch.sum(-true_dist * pred, dim=-1)
        non_pad = ~ignore
        if not bool(non_pad.any()):
            return loss.sum() * 0.0
        return loss[non_pad].mean()

class NoamOpt:
    """Optimizer wrapper applying the Noam (warm-up then inverse-sqrt decay) schedule.

    The learning rate at step `s` is
    `factor * model_size**-0.5 * min(s**-0.5, s * warmup**-1.5)`.

    Args:
        model_size: model width used to scale the rate.
        factor: constant multiplier of the schedule.
        warmup: number of linear warm-up steps; a value <= 0 disables warm-up.
        optimizer: wrapped optimizer exposing `param_groups`, `step()` and
            `zero_grad(set_to_none=...)`.
    """

    def __init__(self, model_size: int, factor: float, warmup: int, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0.0

    def step(self):
        """Advance the schedule by one step, set the new rate, then step the optimizer."""
        self._step += 1
        lr = self.rate()
        for p in self.optimizer.param_groups:
            p["lr"] = lr
        self._rate = lr
        self.optimizer.step()

    def zero_grad(self):
        """Clear gradients of the wrapped optimizer (setting them to None)."""
        self.optimizer.zero_grad(set_to_none=True)

    def rate(self, step: Optional[int] = None):
        """Return the learning rate for `step` (default: the current step).

        Step 0 — i.e. before the first `step()` call — yields 0.0, the value
        the warm-up ramp starts from, instead of raising on `0 ** -0.5`.
        """
        if step is None:
            step = self._step
        if step <= 0:
            return 0.0
        decay = step ** (-0.5)
        if self.warmup > 0:
            # linear warm-up until `warmup` steps, inverse-sqrt decay afterwards
            decay = min(decay, step * (self.warmup ** (-1.5)))
        return self.factor * (self.model_size ** (-0.5) * decay)


## 7) Train/Eval loop (CE theo token)

In [None]:
# Model / training hyperparameters
d_model = 512
num_encoder_layers = 4
num_decoder_layers = 4
num_heads = 8
d_ff = 2048
dropout = 0.1

label_smoothing = 0.1
warmup_steps = 4000
lr_factor = 1.0
max_grad_norm = 1.0
num_epochs = 10

# The embedding and output sizes follow the trained tokenizer's vocabulary.
vocab_sz = tokenizer.get_vocab_size()

model = TransformerNMT(
    vocab_size=vocab_sz,
    d_model=d_model,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    num_heads=num_heads,
    d_ff=d_ff,
    dropout=dropout,
).to(device)

print("Params (M):", sum(p.numel() for p in model.parameters()) / 1e6)

# Adam is created without an explicit lr: NoamOpt overwrites the rate of every
# param group on each step.
base_opt = torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9)
optimizer = NoamOpt(d_model, lr_factor, warmup_steps, base_opt)

# Padding positions are excluded from the loss via ignore_index.
criterion = LabelSmoothingLoss(label_smoothing, vocab_sz, ignore_index=pad_id)


Params (M): 54.017664


In [None]:
def train_epoch(model, data_loader, optimizer, criterion):
    """Run one optimisation pass over `data_loader`.

    Args:
        model: network called as `model(src, tgt_input, src_mask, tgt_mask)`.
        data_loader: yields dicts with "src", "tgt_input" and "tgt_output".
        optimizer: object exposing `zero_grad()` and `step()`.
        criterion: loss taking `[N,V]` logits and `[N]` targets.

    Returns:
        Token-weighted average of the per-batch losses.
    """
    model.train()
    loss_sum = 0.0
    token_count = 0

    for batch in tqdm(data_loader, desc="train", leave=False):
        src, tgt_input, tgt_output = (
            batch[key].to(device) for key in ("src", "tgt_input", "tgt_output")
        )

        optimizer.zero_grad()
        logits = model(
            src,
            tgt_input,
            create_src_mask(src, pad_id),
            create_tgt_mask(tgt_input, pad_id),
        )

        # flatten to [B*T, V] vs [B*T] for the token-level criterion
        vocab_dim = logits.size(-1)
        loss = criterion(logits.view(-1, vocab_dim), tgt_output.view(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

        # weight each batch loss by its number of non-padding target tokens
        batch_tokens = (tgt_output != pad_id).sum().item()
        loss_sum += loss.item() * batch_tokens
        token_count += batch_tokens

    return loss_sum / max(token_count, 1)

@torch.no_grad()
def eval_epoch(model, data_loader, criterion):
    """Evaluate `model` on `data_loader` without gradient tracking.

    Args:
        model: network called as `model(src, tgt_input, src_mask, tgt_mask)`.
        data_loader: yields dicts with "src", "tgt_input" and "tgt_output".
        criterion: loss taking `[N,V]` logits and `[N]` targets.

    Returns:
        Token-weighted average of the per-batch losses.
    """
    model.eval()
    loss_sum = 0.0
    token_count = 0

    for batch in tqdm(data_loader, desc="valid", leave=False):
        src, tgt_input, tgt_output = (
            batch[key].to(device) for key in ("src", "tgt_input", "tgt_output")
        )

        logits = model(
            src,
            tgt_input,
            create_src_mask(src, pad_id),
            create_tgt_mask(tgt_input, pad_id),
        )
        vocab_dim = logits.size(-1)
        loss = criterion(logits.view(-1, vocab_dim), tgt_output.view(-1))

        # weight each batch loss by its number of non-padding target tokens
        batch_tokens = (tgt_output != pad_id).sum().item()
        loss_sum += loss.item() * batch_tokens
        token_count += batch_tokens

    return loss_sum / max(token_count, 1)


In [None]:
# Lowest validation loss seen so far; drives "best" checkpoint selection.
best_valid = float("inf")

for epoch in range(1, num_epochs + 1):
    train_loss = train_epoch(model, train_loader, optimizer, criterion)
    valid_loss = eval_epoch(model, valid_loader, criterion)

    # Keep the weights of the epoch with the lowest validation loss.
    if valid_loss < best_valid:
        best_valid = valid_loss
        torch.save(model.state_dict(), "model_best_bidirectional.pth")

    # Perplexities are exp() of the (label-smoothed) average losses.
    print(
        f"Epoch {epoch:02d} | "
        f"Train Loss: {train_loss:.4f} | Train PPL: {math.exp(train_loss):.2f} | "
        f"Valid Loss: {valid_loss:.4f} | Valid PPL: {math.exp(valid_loss):.2f} | "
        f"LR: {optimizer._rate:.6f}"
    )

# Always keep the final-epoch weights as well.
torch.save(model.state_dict(), "model_last_bidirectional.pth")
print("Saved: model_best_bidirectional.pth & model_last_bidirectional.pth")




Epoch 01 | Train Loss: 5.4641 | Train PPL: 236.06 | Valid Loss: 4.9797 | Valid PPL: 145.43 | LR: 0.000685




Epoch 02 | Train Loss: 4.5766 | Train PPL: 97.19 | Valid Loss: 4.5720 | Valid PPL: 96.74 | LR: 0.000484




Epoch 03 | Train Loss: 4.2758 | Train PPL: 71.94 | Valid Loss: 4.3968 | Valid PPL: 81.19 | LR: 0.000395




Epoch 04 | Train Loss: 4.1081 | Train PPL: 60.83 | Valid Loss: 4.2823 | Valid PPL: 72.41 | LR: 0.000342




Epoch 05 | Train Loss: 3.9930 | Train PPL: 54.22 | Valid Loss: 4.1806 | Valid PPL: 65.41 | LR: 0.000306




Epoch 06 | Train Loss: 3.9065 | Train PPL: 49.72 | Valid Loss: 4.1633 | Valid PPL: 64.29 | LR: 0.000279




Epoch 07 | Train Loss: 3.8370 | Train PPL: 46.38 | Valid Loss: 4.1093 | Valid PPL: 60.90 | LR: 0.000259




Epoch 08 | Train Loss: 3.7795 | Train PPL: 43.79 | Valid Loss: 4.0944 | Valid PPL: 60.00 | LR: 0.000242




Epoch 09 | Train Loss: 3.7302 | Train PPL: 41.69 | Valid Loss: 4.0495 | Valid PPL: 57.37 | LR: 0.000228




Epoch 10 | Train Loss: 3.6888 | Train PPL: 40.00 | Valid Loss: 4.0520 | Valid PPL: 57.51 | LR: 0.000216
Saved: model_best_bidirectional.pth & model_last_bidirectional.pth


## 8) Suy diễn (Beam search) cho cả 2 chiều

Chỉ cần chỉ định `target_lang="vi"` hoặc `"en"`.

Ví dụ:
- dịch **EN→VI**: `target_lang="vi"`
- dịch **VI→EN**: `target_lang="en"`


In [None]:
@torch.no_grad()
def translate_sentence(
    model: nn.Module,
    src_sentence: str,
    target_lang: str,
    max_len: int = 80,
    beam_size: int = 5,
) -> str:
    """Translate one sentence with beam search.

    Args:
        model: trained model exposing `encoder`, `decoder` and `output_layer`.
        src_sentence: raw source text.
        target_lang: key of `lang2tag_id` ("vi" or "en") selecting the
            target-language tag placed on the source side.
        max_len: maximum number of decoding steps.
        beam_size: number of hypotheses kept after every step.

    Returns:
        The decoded text of the finished hypothesis with the highest summed
        log-probability (or of the best unfinished one if none finished),
        with special tokens removed.
    """
    model.eval()

    # Encode source with language tag
    src_ids = encode_src(src_sentence, target_lang, max_src_len)
    src = torch.tensor(src_ids, dtype=torch.long, device=device).unsqueeze(0)  # [1,S]
    src_mask = create_src_mask(src, pad_id)

    # The encoder runs once; every decoding step reuses `memory`.
    memory = model.encoder(src, src_mask)

    # Each hypothesis is (token ids including BOS, summed log-probability).
    beams: List[Tuple[List[int], float]] = [([bos_id], 0.0)]
    completed: List[Tuple[List[int], float]] = []

    for _ in range(max_len):
        new_beams: List[Tuple[List[int], float]] = []

        for tokens, score in beams:
            if tokens[-1] == eos_id:
                # finished hypotheses leave the beam and are never expanded
                completed.append((tokens, score))
                continue

            tgt = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)  # [1,t]
            tgt_mask = create_tgt_mask(tgt, pad_id)

            dec_out = model.decoder(tgt, memory, src_mask, tgt_mask)
            logits = model.output_layer(dec_out)[:, -1, :]  # [1,V]
            log_probs = F.log_softmax(logits, dim=-1).squeeze(0)  # [V]

            # topk cannot return more candidates than there are vocabulary entries
            k = min(beam_size, log_probs.size(-1))
            topk_logp, topk_ids = torch.topk(log_probs, k)
            for lp, idx in zip(topk_logp.tolist(), topk_ids.tolist()):
                new_beams.append((tokens + [idx], score + lp))

        if not new_beams:
            break

        new_beams.sort(key=lambda x: x[1], reverse=True)
        beams = new_beams[:beam_size]

        # stop early once enough hypotheses have finished
        if len(completed) >= beam_size and all(b[0][-1] == eos_id for b in beams):
            break

    # Hypotheses that reached EOS on the last executed step are still sitting
    # in `beams` (they are only moved to `completed` at the start of the next
    # step), so collect them here; otherwise the early stop or hitting
    # `max_len` would silently drop them. Duplicates are harmless for max().
    completed.extend(b for b in beams if b[0][-1] == eos_id)

    if not completed:
        completed = beams

    best_tokens = max(completed, key=lambda x: x[1])[0]

    # Remove special tokens (BOS/EOS/PAD + tag)
    special_ids = {pad_id, bos_id, eos_id, tag_en_id, tag_vi_id}
    best_tokens = [t for t in best_tokens if t not in special_ids]

    return tokenizer.decode(best_tokens)

# Quick demo: one sentence in each translation direction
print("EN->VI:", translate_sentence(model, "How are you today?", target_lang="vi"))
print("VI->EN:", translate_sentence(model, "Tôi đang học Transformer để dịch máy.", target_lang="en"))


EN->VI: Bạn như thế nào hôm nay ?
VI->EN: I & apos ; m learning how to translate the machine .


## 9) Đánh giá BLEU-1..4 (tách theo 2 chiều)

Ta tính:
- BLEU1 = BLEU với max_ngram_order=1
- BLEU2 = max_ngram_order=2
- BLEU3 = max_ngram_order=3
- BLEU4 = max_ngram_order=4 (BLEU chuẩn)

Mỗi chiều (`en->vi`, `vi->en`) được report riêng.


In [None]:
@torch.no_grad()
def compute_bleu_1to4(
    model: nn.Module,
    dataset_split,
    max_sentences: Optional[int] = 500,
    beam_size: int = 5,
) -> Dict[str, Dict[str, float]]:
    """Compute corpus BLEU-1..4 separately for each translation direction.

    Args:
        model: translation model passed to `translate_sentence`.
        dataset_split: bidirectional split with `src`, `tgt`, `tgt_lang`
            and `direction` columns.
        max_sentences: cap on sentences decoded per direction (None = all).
        beam_size: beam width used for decoding.

    Returns:
        `{direction: {"BLEU1": ..., "BLEU2": ..., "BLEU3": ..., "BLEU4": ...}}`.
    """

    def _score_direction(direction: str) -> Dict[str, float]:
        # Decode up to `max_sentences` examples of one direction and score them.
        subset = dataset_split.filter(lambda row: row["direction"] == direction)
        limit = len(subset) if max_sentences is None else min(max_sentences, len(subset))

        hyps, refs = [], []
        for idx in tqdm(range(limit), desc=f"decode {direction}", leave=False):
            example = subset[idx]
            hyps.append(
                translate_sentence(
                    model,
                    src_sentence=example["src"],
                    target_lang=example["tgt_lang"],
                    max_len=max_tgt_len,
                    beam_size=beam_size,
                )
            )
            refs.append(example["tgt"])

        # BLEU-k = BLEU with max_ngram_order=k
        return {
            f"BLEU{order}": BLEU(max_ngram_order=order).corpus_score(hyps, [refs]).score
            for order in (1, 2, 3, 4)
        }

    return {direction: _score_direction(direction) for direction in ("en->vi", "vi->en")}
# Score the first 500 test sentences of each direction with beam size 5.
bleu_scores = compute_bleu_1to4(model, datasets["test"], max_sentences=500, beam_size=5)
bleu_scores


Filter:   0%|          | 0/2536 [00:00<?, ? examples/s]



Filter:   0%|          | 0/2536 [00:00<?, ? examples/s]



{'en->vi': {'BLEU1': 40.980254702652175,
  'BLEU2': 27.402589055183864,
  'BLEU3': 18.472398131972067,
  'BLEU4': 12.6376224067564},
 'vi->en': {'BLEU1': 28.185187118325583,
  'BLEU2': 17.708425692875462,
  'BLEU3': 12.04129079382029,
  'BLEU4': 8.559670874154843}}

In [None]:

# One line per direction, each score rounded to two decimals.
for direction, scores in bleu_scores.items():
    print(
        direction,
        " | ",
        ", ".join(f"{metric}: {value:.2f}" for metric, value in scores.items()),
    )


en->vi  |  BLEU1: 40.98, BLEU2: 27.40, BLEU3: 18.47, BLEU4: 12.64
vi->en  |  BLEU1: 28.19, BLEU2: 17.71, BLEU3: 12.04, BLEU4: 8.56


## 10) Load lại model tốt nhất (nếu cần)

Chạy cell này khi bạn đã train xong và muốn reload.


In [None]:
# Load the best checkpoint (weights saved at the lowest validation loss)
ckpt_path = "model_best_bidirectional.pth"
try:
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state)
    print("Loaded:", ckpt_path)
except FileNotFoundError:
    # The Vietnamese message below reports that the checkpoint was not found.
    print("Không tìm thấy checkpoint:", ckpt_path)


Loaded: model_best_bidirectional.pth


In [None]:
# Mount Google Drive (Colab only).
# NOTE(review): the next cell repeats this mount inside a try/except, so this
# unguarded cell looks redundant and fails outside Colab — consider removing it.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
from pathlib import Path

# Mount Google Drive when running in Colab; otherwise fall back to a local folder.
IN_COLAB = False
try:
    from google.colab import drive  # type: ignore
    drive.mount("/content/drive")
    IN_COLAB = True
except Exception as e:
    # Best effort: the Vietnamese message says this is not Colab or that Drive
    # could not be mounted; IN_COLAB stays False.
    print("Không phải Colab hoặc không mount được Drive:", e)

# Directory where checkpoints are stored
SAVE_DIR = "/content/drive/MyDrive/transformer_envi_bidirectional" if IN_COLAB else "./checkpoints_transformer_envi_bidirectional"
os.makedirs(SAVE_DIR, exist_ok=True)
print("SAVE_DIR =", SAVE_DIR)

# Save the tokenizer into SAVE_DIR (on Drive when in Colab)
TOKENIZER_PATH = os.path.join(SAVE_DIR, "tokenizer_joint_envi.json")
tokenizer.save(TOKENIZER_PATH)
print("Saved tokenizer:", TOKENIZER_PATH)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
SAVE_DIR = /content/drive/MyDrive/transformer_envi_bidirectional
Saved tokenizer: /content/drive/MyDrive/transformer_envi_bidirectional/tokenizer_joint_envi.json


In [None]:
def pack_checkpoint(model, optimizer, epoch: int, best_valid: float, config: dict) -> dict:
    """Collect everything needed to resume or reload training into one dict.

    Args:
        model: object exposing `state_dict()`.
        optimizer: NoamOpt-style wrapper (its `.optimizer` is the wrapped
            optimizer), or None / any object without an `.optimizer` attribute.
        epoch: epoch number stored as int.
        best_valid: best validation loss stored as float.
        config: configuration dict stored as-is.

    Returns:
        Dict with "epoch", "best_valid", "config" and "model_state_dict";
        when a wrapper is given also "base_optimizer_state_dict",
        "noam_step" and "noam_rate".
    """
    ckpt = dict(
        epoch=int(epoch),
        best_valid=float(best_valid),
        config=config,
        model_state_dict=model.state_dict(),
    )

    # Only a wrapper carrying an inner `.optimizer` contributes optimizer state.
    if optimizer is None or not hasattr(optimizer, "optimizer"):
        return ckpt

    ckpt.update(
        base_optimizer_state_dict=optimizer.optimizer.state_dict(),
        noam_step=int(getattr(optimizer, "_step", 0)),
        noam_rate=float(getattr(optimizer, "_rate", 0.0)),
    )
    return ckpt

def save_checkpoint(path: str, ckpt: dict) -> None:
    """Write `ckpt` to `path` through a temporary sibling file.

    The checkpoint is first saved to `path + ".tmp"` and then moved over
    `path` with `os.replace`, so `path` is only swapped once the write is done.
    """
    staging_path = path + ".tmp"
    torch.save(ckpt, staging_path)
    os.replace(staging_path, path)
    print("Saved checkpoint:", path)

def load_checkpoint(path: str, device=device) -> dict:
    """Load a checkpoint from `path`, mapping its tensors onto `device`.

    The default `device` is the notebook-level device as it was when this
    function was defined.
    """
    checkpoint = torch.load(path, map_location=device)
    return checkpoint


In [None]:
# NOTE(review): `model` and `optimizer` are reused from the cells above, so this
# loop appears to continue training the already-trained model while
# `best_valid` and the epoch counter start over — confirm that is intended.
best_valid = float("inf")

BEST_CKPT_PATH = os.path.join(SAVE_DIR, "checkpoint_best.pt")
LAST_CKPT_PATH = os.path.join(SAVE_DIR, "checkpoint_last.pt")

# Stored inside every checkpoint so the model and tokenizer can be rebuilt later.
TRAIN_CONFIG = {
    "d_model": d_model,
    "num_encoder_layers": num_encoder_layers,
    "num_decoder_layers": num_decoder_layers,
    "num_heads": num_heads,
    "d_ff": d_ff,
    "dropout": dropout,
    "label_smoothing": label_smoothing,
    "warmup_steps": warmup_steps,
    "lr_factor": lr_factor,
    "max_grad_norm": max_grad_norm,
    "max_src_len": max_src_len,
    "max_tgt_len": max_tgt_len,
    "batch_size": batch_size,
    "tokenizer_path": TOKENIZER_PATH,
    "special_tokens": {
        "PAD_TOKEN": PAD_TOKEN,
        "UNK_TOKEN": UNK_TOKEN,
        "BOS_TOKEN": BOS_TOKEN,
        "EOS_TOKEN": EOS_TOKEN,
        "TAG_EN": TAG_EN,
        "TAG_VI": TAG_VI,
    },
}

for epoch in range(1, num_epochs + 1):
    train_loss = train_epoch(model, train_loader, optimizer, criterion)
    valid_loss = eval_epoch(model, valid_loader, criterion)

    # Save BEST according to valid_loss
    if valid_loss < best_valid:
        best_valid = valid_loss
        ckpt = pack_checkpoint(model, optimizer, epoch, best_valid, TRAIN_CONFIG)
        save_checkpoint(BEST_CKPT_PATH, ckpt)

    # Save LAST after every epoch
    ckpt_last = pack_checkpoint(model, optimizer, epoch, best_valid, TRAIN_CONFIG)
    save_checkpoint(LAST_CKPT_PATH, ckpt_last)

    print(f"Epoch {epoch} | train_loss={train_loss:.4f} | valid_loss={valid_loss:.4f} | lr={optimizer._rate:.6f}")

print("Done.")
print("BEST:", BEST_CKPT_PATH)
print("LAST:", LAST_CKPT_PATH)




Saved checkpoint: /content/drive/MyDrive/transformer_envi_bidirectional/checkpoint_best.pt
Saved checkpoint: /content/drive/MyDrive/transformer_envi_bidirectional/checkpoint_last.pt
Epoch 1 | train_loss=3.6801 | valid_loss=4.0184 | lr=0.000206




Saved checkpoint: /content/drive/MyDrive/transformer_envi_bidirectional/checkpoint_best.pt
Saved checkpoint: /content/drive/MyDrive/transformer_envi_bidirectional/checkpoint_last.pt
Epoch 2 | train_loss=3.6435 | valid_loss=4.0120 | lr=0.000198




Saved checkpoint: /content/drive/MyDrive/transformer_envi_bidirectional/checkpoint_best.pt
Saved checkpoint: /content/drive/MyDrive/transformer_envi_bidirectional/checkpoint_last.pt
Epoch 3 | train_loss=3.6121 | valid_loss=3.9871 | lr=0.000190




Saved checkpoint: /content/drive/MyDrive/transformer_envi_bidirectional/checkpoint_best.pt
Saved checkpoint: /content/drive/MyDrive/transformer_envi_bidirectional/checkpoint_last.pt
Epoch 4 | train_loss=3.5847 | valid_loss=3.9791 | lr=0.000183




Saved checkpoint: /content/drive/MyDrive/transformer_envi_bidirectional/checkpoint_last.pt
Epoch 5 | train_loss=3.5593 | valid_loss=3.9844 | lr=0.000177




Saved checkpoint: /content/drive/MyDrive/transformer_envi_bidirectional/checkpoint_best.pt
Saved checkpoint: /content/drive/MyDrive/transformer_envi_bidirectional/checkpoint_last.pt
Epoch 6 | train_loss=3.5368 | valid_loss=3.9395 | lr=0.000171




Saved checkpoint: /content/drive/MyDrive/transformer_envi_bidirectional/checkpoint_last.pt
Epoch 7 | train_loss=3.5152 | valid_loss=3.9705 | lr=0.000166




Saved checkpoint: /content/drive/MyDrive/transformer_envi_bidirectional/checkpoint_best.pt
Saved checkpoint: /content/drive/MyDrive/transformer_envi_bidirectional/checkpoint_last.pt
Epoch 8 | train_loss=3.4953 | valid_loss=3.9337 | lr=0.000161




Saved checkpoint: /content/drive/MyDrive/transformer_envi_bidirectional/checkpoint_best.pt
Saved checkpoint: /content/drive/MyDrive/transformer_envi_bidirectional/checkpoint_last.pt
Epoch 9 | train_loss=3.4766 | valid_loss=3.9323 | lr=0.000157




Saved checkpoint: /content/drive/MyDrive/transformer_envi_bidirectional/checkpoint_best.pt
Saved checkpoint: /content/drive/MyDrive/transformer_envi_bidirectional/checkpoint_last.pt
Epoch 10 | train_loss=3.4581 | valid_loss=3.9235 | lr=0.000153
Done.
BEST: /content/drive/MyDrive/transformer_envi_bidirectional/checkpoint_best.pt
LAST: /content/drive/MyDrive/transformer_envi_bidirectional/checkpoint_last.pt


In [None]:
import os
from tokenizers import Tokenizer

# Rebuild the tokenizer, every global the inference helpers read, and the model
# from the best checkpoint alone.
CKPT_PATH = os.path.join(SAVE_DIR, "checkpoint_best.pt")
ckpt = torch.load(CKPT_PATH, map_location="cpu")
cfg = ckpt["config"]

# Load the tokenizer whose vocabulary matches the checkpoint
tokenizer = Tokenizer.from_file(cfg["tokenizer_path"])
vocab_sz = tokenizer.get_vocab_size()

# Restore the special-token strings and ids so encoding matches training
PAD_TOKEN = cfg["special_tokens"]["PAD_TOKEN"]
UNK_TOKEN = cfg["special_tokens"]["UNK_TOKEN"]
BOS_TOKEN = cfg["special_tokens"]["BOS_TOKEN"]
EOS_TOKEN = cfg["special_tokens"]["EOS_TOKEN"]
TAG_EN = cfg["special_tokens"]["TAG_EN"]
TAG_VI = cfg["special_tokens"]["TAG_VI"]

pad_id = tokenizer.token_to_id(PAD_TOKEN)
unk_id = tokenizer.token_to_id(UNK_TOKEN)
bos_id = tokenizer.token_to_id(BOS_TOKEN)
eos_id = tokenizer.token_to_id(EOS_TOKEN)
# translate_sentence strips these two ids from its output, so they must be
# rebuilt here as well rather than left over from an earlier cell.
tag_en_id = tokenizer.token_to_id(TAG_EN)
tag_vi_id = tokenizer.token_to_id(TAG_VI)

assert None not in [pad_id, unk_id, bos_id, eos_id, tag_en_id, tag_vi_id]

lang2tag_id = {
    "en": tag_en_id,
    "vi": tag_vi_id,
}

# encode_src / compute_bleu_1to4 read these length limits from the globals;
# use the values the checkpoint was trained with.
max_src_len = cfg["max_src_len"]
max_tgt_len = cfg["max_tgt_len"]

# Re-create the model from the stored configuration
model = TransformerNMT(
    vocab_size=vocab_sz,
    d_model=cfg["d_model"],
    num_encoder_layers=cfg["num_encoder_layers"],
    num_decoder_layers=cfg["num_decoder_layers"],
    num_heads=cfg["num_heads"],
    d_ff=cfg["d_ff"],
    dropout=cfg["dropout"],
).to(device)

model.load_state_dict(ckpt["model_state_dict"])
model.eval()

# Try a translation in each direction
print("EN->VI:", translate_sentence(model, "I love machine learning.", target_lang="vi", beam_size=5))
print("VI->EN:", translate_sentence(model, "Tôi thích học máy.", target_lang="en", beam_size=5))


EN->VI: Tôi yêu máy học .
VI->EN: I love the machine .
