### 1. Thiết lập đường dẫn (thêm `src/` vào `sys.path`)

In [1]:
import sys
from pathlib import Path

# The notebook normally runs from notebooks/, so the project root is usually the
# parent of the working directory. Walk upwards and take the first directory that
# contains both src/ and config/; this also works when Jupyter was started from the
# project root. If nothing matches, fall back to the original cwd.parent guess.
_cwd = Path.cwd()
PROJECT_ROOT = next(
    (p for p in (_cwd, *_cwd.parents) if (p / "src").is_dir() and (p / "config").is_dir()),
    _cwd.parent,
)
SRC_DIR = PROJECT_ROOT / "src"

# Put src/ at the front so generic module names (data, model, config_loader)
# resolve to this project rather than to any installed package. Remove earlier
# copies first so that re-running this cell does not keep growing sys.path.
_src = str(SRC_DIR)
while _src in sys.path:
    sys.path.remove(_src)
sys.path.insert(0, _src)

print("PROJECT_ROOT:", PROJECT_ROOT)
print("SRC_DIR added to path:", SRC_DIR)

PROJECT_ROOT: d:\studyhk1-25-26\NLP\project
SRC_DIR added to path: d:\studyhk1-25-26\NLP\project\src


### 2. Import module

In [None]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import sacrebleu

from config_loader import load_config
from data import (
    read_lines, tokenize_en, tokenize_fr,
    TranslationDataset, build_collate_fn,
    Vocab, PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN
)
from model import (
    Encoder, Decoder, Seq2Seq,
    AttentionDecoder, Seq2SeqWithAttention
)

# Use the GPU when torch can see one; otherwise evaluate on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("DEVICE:", DEVICE)

DEVICE: cpu


### 3. Helper functions (vocab, decode, BLEU)

In [3]:
def vocab_from_itos(itos_list):
    """Rebuild a ``Vocab`` from a saved index-to-string token list.

    ``Vocab.__init__`` is skipped on purpose. Only ``itos`` and ``stoi`` are
    populated, so any other attribute that ``__init__`` would set is absent on
    the returned object.

    Args:
        itos_list: iterable of tokens; position ``i`` holds the token for id ``i``.

    Returns:
        A ``Vocab`` whose ``itos`` is a fresh list copy of ``itos_list``.
        Its ``stoi`` maps each token back to its id; for a duplicated token,
        the last id wins.
    """
    itos = list(itos_list)
    vocab = Vocab.__new__(Vocab)  # bypass __init__
    vocab.itos = itos
    vocab.stoi = dict(zip(itos, range(len(itos))))
    return vocab

def safe_tok_from_idx(vocab, idx: int):
    """Map a token id to its string, falling back to ``UNK_TOKEN``.

    Ids outside ``[0, len(vocab.itos))`` return ``UNK_TOKEN``. This includes
    negative ids, which plain list indexing would silently wrap around.
    """
    in_range = 0 <= idx < len(vocab.itos)
    return vocab.itos[idx] if in_range else UNK_TOKEN

def _ids_to_sentence(ids, tgt_vocab, eos_idx, skip_tokens):
    """Convert one sequence of target ids into a space-joined token string.

    Decoding stops at the first ``eos_idx``. Tokens listed in ``skip_tokens`` are dropped.
    """
    tokens = []
    for raw_idx in ids:
        idx = int(raw_idx)
        if idx == eos_idx:
            break
        tok = safe_tok_from_idx(tgt_vocab, idx)
        if tok not in skip_tokens:
            tokens.append(tok)
    return " ".join(tokens)


def generate_hyps_from_loader(model, dataloader, tgt_vocab, max_len=50, device="cpu"):
    """Greedy-decode every batch of ``dataloader``; return (hypotheses, references).

    Args:
        model: a model exposing ``greedy_decode(src, src_lens, max_len=, sos_idx=, eos_idx=)``.
        dataloader: yields ``(src, src_lens, tgt_in, tgt_out)`` batches; ``tgt_in`` is unused.
        tgt_vocab: target vocabulary with ``stoi`` / ``itos``.
        max_len: forwarded to ``model.greedy_decode``.
        device: device that ``src`` and ``src_lens`` are moved to.

    Returns:
        Two lists of space-joined token strings, in loader order. Each sentence
        is cut at the first EOS, and PAD/SOS/EOS tokens are removed.

    Note:
        References are rebuilt from the ``tgt_out`` ids, not from the raw test
        text. Only PAD/SOS/EOS are filtered, so reference words that were
        presumably mapped to ``UNK_TOKEN`` stay as ``<unk>``. A hypothesis
        ``<unk>`` can then match them, which may make BLEU optimistic compared
        with scoring against the original references.

    NOTE(review): in this notebook's saved sample output (section 9), every
    hypothesis is missing the reference's first word ("un", "l'", "une").
    Worth checking whether ``model.greedy_decode`` drops its first prediction.
    """
    model.eval()
    # Loop-invariant lookups, hoisted out of the per-token loops.
    sos_idx = tgt_vocab.stoi[SOS_TOKEN]
    eos_idx = tgt_vocab.stoi[EOS_TOKEN]
    skip_tokens = {PAD_TOKEN, SOS_TOKEN, EOS_TOKEN}
    hyps, refs = [], []

    with torch.no_grad():
        for src, src_lens, _tgt_in, tgt_out in tqdm(dataloader, desc="generate"):
            src, src_lens = src.to(device), src_lens.to(device)

            preds = model.greedy_decode(
                src, src_lens,
                max_len=max_len,
                sos_idx=sos_idx,
                eos_idx=eos_idx,
            )
            # Convert once per batch to plain ints. This avoids one
            # int(tensor) call (and a possible GPU sync) per token.
            if torch.is_tensor(preds):
                preds = preds.tolist()

            hyps.extend(_ids_to_sentence(seq, tgt_vocab, eos_idx, skip_tokens) for seq in preds)
            refs.extend(
                _ids_to_sentence(seq, tgt_vocab, eos_idx, skip_tokens) for seq in tgt_out.tolist()
            )

    return hyps, refs

def compute_corpus_bleu(hyps, refs):
    """Corpus-level BLEU of ``hyps`` against one aligned reference stream ``refs``.

    Args:
        hyps: list of hypothesis strings.
        refs: list of reference strings, one per hypothesis, in the same order.

    Returns:
        The sacrebleu BLEU score as a float, or 0.0 for an empty corpus.

    Raises:
        ValueError: if ``hyps`` and ``refs`` have different lengths.

    The inputs are already space-tokenized. ``force=True`` only silences
    sacrebleu's "looks tokenized" warning; sacrebleu still applies its default
    tokenizer on top. The score is therefore not directly comparable to
    officially reported (detokenized) SacreBLEU numbers.
    """
    if len(hyps) != len(refs):
        raise ValueError(
            f"hyps and refs must be aligned: got {len(hyps)} hypotheses and {len(refs)} references"
        )
    if not hyps:
        return 0.0
    return sacrebleu.corpus_bleu(hyps, [refs], force=True).score

### 4. Helper: tìm file dữ liệu & nạp checkpoint

In [4]:
def pick_existing_file(data_dir: Path, candidates):
    """Return the first ``data_dir / name`` that exists, trying ``candidates`` in order.

    ``Path.exists`` is used, so a directory with a matching name also counts.
    Returns ``None`` when no candidate exists.
    """
    paths = (data_dir / name for name in candidates)
    return next((path for path in paths if path.exists()), None)

def load_checkpoint_safely(ckpt_path: Path, map_location=None):
    """Load a ``torch.save`` checkpoint, restricting unpickling when torch allows it.

    Args:
        ckpt_path: path to the checkpoint file.
        map_location: forwarded to ``torch.load``; defaults to the notebook's ``DEVICE``.

    Returns:
        The deserialized object (in this notebook, a dict).
    """
    import inspect

    if map_location is None:
        map_location = DEVICE
    # Older torch releases have no ``weights_only`` argument. Detect that from the
    # signature instead of catching TypeError, so an unrelated TypeError does not
    # silently fall back to unrestricted unpickling.
    if "weights_only" in inspect.signature(torch.load).parameters:
        return torch.load(ckpt_path, map_location=map_location, weights_only=True)
    # NOTE(review): full pickle loading can execute arbitrary code; only use it on
    # trusted checkpoints.
    return torch.load(ckpt_path, map_location=map_location)

def is_attention_checkpoint(state_dict, emb_dim=None):
    """Guess whether ``state_dict`` comes from a Seq2SeqWithAttention model.

    The AttentionDecoder's RNN input width is ``emb_dim + hid_dim``, while the
    plain Decoder's is just ``emb_dim``. The checkpoint is therefore treated as
    attention-based when the input width of ``decoder.rnn.weight_ih_l0`` is
    larger than the decoder embedding size.

    Args:
        state_dict: mapping of parameter name -> tensor (anything with ``.shape``).
        emb_dim: decoder embedding size. If omitted, it is read from the
            ``decoder.embedding.weight`` entry. If that entry is also missing,
            the legacy fixed threshold of 512 is used.

    Returns:
        True if the checkpoint looks attention-based. False otherwise,
        including when no decoder RNN input weight is present.
    """
    rnn_in_dim = None
    for key, value in state_dict.items():
        if rnn_in_dim is None and "decoder.rnn.weight_ih_l0" in key:
            rnn_in_dim = value.shape[1]
        # endswith() also accepts prefixed keys such as "module.decoder.embedding.weight".
        if emb_dim is None and key.endswith("decoder.embedding.weight"):
            emb_dim = value.shape[1]
        if rnn_in_dim is not None and emb_dim is not None:
            break

    if rnn_in_dim is None:
        return False
    threshold = 512 if emb_dim is None else emb_dim
    return rnn_in_dim > threshold

### 5. Load config & test data

In [5]:
config = load_config(PROJECT_ROOT / "config" / "config.yml")
data_dir = PROJECT_ROOT / config["data"]["data_dir"]

# Choose the en/fr files as a *pair* with the same stem. This avoids scoring
# test.en against another split's French file when only one side of a pair
# exists.
TEST_STEMS = ["test", "test_2018_flickr", "test_2017_flickr"]
test_en = test_fr = None
for stem in TEST_STEMS:
    en_path, fr_path = data_dir / f"{stem}.en", data_dir / f"{stem}.fr"
    if en_path.exists() and fr_path.exists():
        test_en, test_fr = en_path, fr_path
        break

if test_en is None or test_fr is None:
    raise FileNotFoundError(
        f"Không tìm thấy file test: no <stem>.en/<stem>.fr pair for stems {TEST_STEMS} in {data_dir}"
    )

test_src = list(read_lines(test_en))
test_tgt = list(read_lines(test_fr))
# Source and target are parallel line by line. A length mismatch means the
# files are truncated or do not belong together.
if len(test_src) != len(test_tgt):
    raise ValueError(
        f"Test files are not parallel: {test_en.name} has {len(test_src)} lines, "
        f"{test_fr.name} has {len(test_tgt)}"
    )

test_src_tok = [tokenize_en(s) for s in test_src]
test_tgt_tok = [tokenize_fr(s) for s in test_tgt]

print("Loaded test samples:", len(test_src_tok))

Loaded test samples: 1071


### 6. Load checkpoint & rebuild model

In [6]:
ckpt_path = PROJECT_ROOT / "checkpoints" / "best_model.pth"
ckpt = load_checkpoint_safely(ckpt_path)

# Vocabularies are stored in the checkpoint as plain index-to-string lists.
src_vocab, tgt_vocab = (vocab_from_itos(ckpt[key]) for key in ("src_itos", "tgt_itos"))

# Prefer the hyper-parameters saved with the checkpoint; fall back to config.yml.
cfg = ckpt.get("config", config)
m_cfg = cfg["model"]
emb_dim, hid_dim, n_layers, dropout = (
    m_cfg[name] for name in ("emb_dim", "hid_dim", "n_layers", "dropout")
)

# The decoder's weight shapes reveal which architecture the checkpoint was trained with.
use_attn = is_attention_checkpoint(ckpt["model_state_dict"])

enc = Encoder(len(src_vocab), emb_dim, hid_dim, n_layers, dropout)
decoder_cls, seq2seq_cls, arch_msg = (
    (AttentionDecoder, Seq2SeqWithAttention, "→ Model: Seq2SeqWithAttention")
    if use_attn
    else (Decoder, Seq2Seq, "→ Model: Seq2Seq (no attention)")
)
dec = decoder_cls(len(tgt_vocab), emb_dim, hid_dim, n_layers, dropout)
model = seq2seq_cls(enc, dec, DEVICE).to(DEVICE)
print(arch_msg)

model.load_state_dict(ckpt["model_state_dict"])
model.eval()

→ Model: Seq2SeqWithAttention


Seq2SeqWithAttention(
  (encoder): Encoder(
    (embedding): Embedding(5893, 256, padding_idx=0)
    (rnn): LSTM(256, 512, num_layers=2, batch_first=True, dropout=0.5)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (decoder): AttentionDecoder(
    (embedding): Embedding(6470, 256, padding_idx=0)
    (attention): LuongAttention()
    (rnn): LSTM(768, 512, num_layers=2, batch_first=True, dropout=0.5)
    (fc_out): Linear(in_features=1024, out_features=6470, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
)

### 7. Build DataLoader

In [7]:
# shuffle=False keeps the dataset order. Later cells index test_src_tok next to the
# generated outputs, which assumes TranslationDataset / collate_fn do not reorder
# samples. NOTE(review): confirm.
test_ds = TranslationDataset(test_src_tok, test_tgt_tok, src_vocab, tgt_vocab)
collate_fn = build_collate_fn(src_vocab, tgt_vocab)

test_loader = DataLoader(
    test_ds,
    collate_fn=collate_fn,
    batch_size=cfg["training"]["batch_size"],
    shuffle=False,
)

### 8. Generate & BLEU

In [8]:
# Forwarded to greedy_decode as max_len (presumably the cap on generated tokens per sentence).
MAX_DECODE_LEN = 50

hyps_test, refs_test = generate_hyps_from_loader(
    model, test_loader, tgt_vocab, max_len=MAX_DECODE_LEN, device=DEVICE
)

# Corpus-level BLEU of the greedy hypotheses against the test references.
test_bleu = compute_corpus_bleu(hyps_test, refs_test)
print("Test BLEU:", test_bleu)

generate: 100%|██████████| 17/17 [00:09<00:00,  1.72it/s]

Test BLEU: 13.87869996913708





### 9. Xem nhanh vài câu dịch 

In [9]:
# Print a few (source, reference, hypothesis) triples. zip() stops at the shortest
# slice, so this also works when the test set has fewer than N_PREVIEW sentences,
# where indexing with range(5) would raise IndexError.
N_PREVIEW = 5
for src_tokens, ref, hyp in zip(
    test_src_tok[:N_PREVIEW],
    refs_test[:N_PREVIEW],
    hyps_test[:N_PREVIEW],
):
    print("SRC:", " ".join(src_tokens))
    print("REF:", ref)
    print("HYP:", hyp)
    print("-" * 60)

SRC: a young man participates in a career while the subject who records it smiles .
REF: un jeune homme participe à une course pendant que le sujet qui le filme sourit .
HYP: jeune homme fait dans un <unk> tandis que le le qui regarde .
------------------------------------------------------------
SRC: the man is scratching the back of his neck while looking for a book in a book store .
REF: l' homme se gratte l' arrière du cou tout en cherchant un livre dans une librairie .
HYP: homme frappe le arrière arrière son dos tandis en regardant un livre un livre livre un livre .
------------------------------------------------------------
SRC: a person wearing goggles and a hat is sled riding .
REF: une personne portant des lunettes de protection et un chapeau fait de la luge .
HYP: personne portant des lunettes et un chapeau fait du du .
------------------------------------------------------------
SRC: a girl in a pink coat and flowered goloshes sledding down a hill .
REF: une fille avec une

### 10. Lưu các câu dịch mẫu ra file

In [10]:
# Number of leading test sentences written to the samples file.
N_SAVED_SAMPLES = 200

save_path = PROJECT_ROOT / "results" / "samples.txt"
save_path.parent.mkdir(parents=True, exist_ok=True)

# Each record is SRC / REF / HYP lines followed by a blank separator line.
records = [
    "SRC: " + " ".join(src_tokens) + "\n" + "REF: " + ref + "\n" + "HYP: " + hyp + "\n\n"
    for src_tokens, ref, hyp in zip(
        test_src_tok[:N_SAVED_SAMPLES],
        refs_test[:N_SAVED_SAMPLES],
        hyps_test[:N_SAVED_SAMPLES],
    )
]
with open(save_path, "w", encoding="utf-8") as f:
    f.writelines(records)

print("Saved samples to:", save_path)

Saved samples to: d:\studyhk1-25-26\NLP\project\results\samples.txt
