In [1]:
import torch
torch.cuda.empty_cache()

In [None]:
import os
import time
import torch
import csv
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from datetime import datetime

from model.encoder import Encoder
from model.decoder import Decoder
from model.seq2seq import Seq2Seq
from dataset.preprocess import collate_fn, TranslationDataset

from config import *
from loss import get_loss, compute_loss
from metrics import calculate_bleu, calculate_meteor, calculate_corpus_bleu
from optimizer import get_optimizer, get_plateau_scheduler
from early_stopping import EarlyStopping
from eval import evaluate

In [3]:
# === Load data ===
data = torch.load(DATA_PATH)
src_lines = data["src_lines"]
tgt_lines = data["tgt_lines"]
vocab_en = data["vocab_en"]
vocab_vi = data["vocab_vi"]
pad_idx = data["pad_idx"]

# === Kh·ªüi t·∫°o l·∫°i Dataset
train_dataset = TranslationDataset(src_lines["train"], tgt_lines["train"], vocab_en, vocab_vi)
val_dataset   = TranslationDataset(src_lines["val"], tgt_lines["val"], vocab_en, vocab_vi)

# === Dataloader
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          collate_fn=lambda x: collate_fn(x, pad_idx))
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        collate_fn=lambda x: collate_fn(x, pad_idx))

In [4]:
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn= lambda x: collate_fn(x, pad_idx)
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn= lambda x: collate_fn(x, pad_idx)
)

In [5]:
INPUT_DIM = len(vocab_en)
OUTPUT_DIM = len(vocab_vi)

In [6]:
# === Init model ===
encoder = Encoder(INPUT_DIM, EMBED_DIM, HIDDEN_DIM, N_LAYERS, DROPOUT).to(DEVICE)
decoder = Decoder(EMBED_DIM, HIDDEN_DIM, OUTPUT_DIM, N_LAYERS, DROPOUT).to(DEVICE)
model = Seq2Seq(encoder, decoder).to(DEVICE)

In [7]:
loss_fn = get_loss(pad_idx, use_label_smoothing=True)

optimizer = get_optimizer(model, LEARNING_RATE)
scheduler = get_plateau_scheduler(optimizer, factor=0.5, patience=2)

early_stopping = EarlyStopping(patience=PATIENCE, path=CHECKPOINT_PATH)



In [8]:
# === Logging setup ===
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
os.makedirs("logs", exist_ok=True)

log_file = f"logs/train_log_{timestamp}.txt"
csv_file = f"logs/train_metrics_{timestamp}.csv"

def log(message):
    print(message)
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(message + "\n")

In [9]:
CSV_LOG_PATH = "logs/train_metrics.csv"
if not os.path.exists(CSV_LOG_PATH):
    with open(CSV_LOG_PATH, "w", newline='', encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["epoch", "train_loss", "val_loss", "bleu"])

In [10]:
# === Training function ===
def train(model, loader, optimizer, loss_fn, clip, epoch_num):
    model.train()
    epoch_loss = 0
    progress_bar = tqdm(loader, desc=f"üîÅ Training Epoch {epoch_num}", leave=False)

    for i, (src, trg) in enumerate(progress_bar):
        src = src.transpose(0, 1).to(DEVICE)
        trg = trg.transpose(0, 1).to(DEVICE)

        optimizer.zero_grad()
        output = model(src, trg, teacher_forcing_ratio=TEACHER_FORCING_RATIO)

        output = output[1:].reshape(-1, output.shape[-1])
        trg = trg[1:].reshape(-1)

        loss = compute_loss(output, trg, loss_fn)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        batch_loss = loss.item()
        epoch_loss += batch_loss
        progress_bar.set_postfix(batch_loss=f"{batch_loss:.4f}")

    return epoch_loss / len(loader)

In [11]:
# === Decode function for BLEU/METEOR
def decode_sequence(seq, idx2word):
    return [idx2word.get(idx, "<UNK>") for idx in seq if idx2word.get(idx) not in ["<PAD>", "<EOS>", "<SOS>"]]

In [12]:
# === Epoch timing ===
def epoch_time(start, end):
    elapsed = end - start
    return int(elapsed // 60), int(elapsed % 60)

In [13]:
# === Inverse vocab
idx2vi = {v: k for k, v in vocab_vi.items()}

In [14]:
# === Main training loop ===
log(f"üöÄ Training started at {timestamp}")
log(f"üß† DEVICE: {DEVICE}")
log(f"üìä Total Epochs: {N_EPOCHS} | Batch Size: {BATCH_SIZE} | Teacher Forcing: {TEACHER_FORCING_RATIO}")

best_val_loss = float("inf")

for epoch in range(N_EPOCHS):
    log(f"\nüìÖ Epoch {epoch + 1}/{N_EPOCHS}")
    start_time = time.time()

    train_loss = train(model, train_loader, optimizer, loss_fn, CLIP, epoch + 1)
    val_loss = evaluate(model, val_loader, loss_fn, device=DEVICE)

    scheduler.step(val_loss)
    early_stopping(val_loss, model)

    mins, secs = epoch_time(start_time, time.time())

    # === Calculate BLEU only every N epochs
    bleu = "-"
    if (epoch + 1) % EVAL_BLEU_EVERY == 0 or (epoch + 1) == N_EPOCHS:
        model.eval()
        preds, refs = [], []
        with torch.no_grad():
            for src, trg in val_loader:
                src = src.transpose(0, 1).to(DEVICE)
                trg = trg.transpose(0, 1).to(DEVICE)
                output = model(src, trg, teacher_forcing_ratio=0.0)
                pred_ids = output.argmax(-1).transpose(0, 1).tolist()
                ref_ids = trg.transpose(0, 1).tolist()
                for pred_seq, ref_seq in zip(pred_ids, ref_ids):
                    pred = decode_sequence(pred_seq, idx2vi)
                    ref = decode_sequence(ref_seq, idx2vi)
                    preds.append(pred)
                    refs.append(ref)
        bleu = calculate_bleu(preds, refs)
        log(f"üìè BLEU this epoch: {bleu:.4f}")

    log(f"üïí Time: {mins}m {secs}s | üî• Train Loss: {train_loss:.4f} | ‚úÖ Val Loss: {val_loss:.4f}")

    with open(CSV_LOG_PATH, "a", newline='', encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow([epoch + 1, round(train_loss, 4), round(val_loss, 4), bleu if isinstance(bleu, float) else "NA"])


    if early_stopping.early_stop:
        log("‚õî Early stopping triggered!")
        break

log("‚úÖ Training complete.")

üöÄ Training started at 2025-04-27_15-27-21
üß† DEVICE: cuda
üìä Total Epochs: 10 | Batch Size: 32 | Teacher Forcing: 0.5

üìÖ Epoch 1/10


                                                                                           

üïí Time: 32m 48s | üî• Train Loss: 5.6119 | ‚úÖ Val Loss: 5.9422

üìÖ Epoch 2/10


                                                                                           

üïí Time: 32m 17s | üî• Train Loss: 5.1408 | ‚úÖ Val Loss: 5.7722

üìÖ Epoch 3/10


                                                                                           

üïí Time: 32m 24s | üî• Train Loss: 4.9740 | ‚úÖ Val Loss: 5.7905

üìÖ Epoch 4/10


                                                                                           

üïí Time: 32m 12s | üî• Train Loss: 4.8774 | ‚úÖ Val Loss: 5.7637

üìÖ Epoch 5/10


                                                                                           

üìè BLEU this epoch: 0.1070
üïí Time: 32m 7s | üî• Train Loss: 4.8128 | ‚úÖ Val Loss: 5.7570

üìÖ Epoch 6/10


                                                                                           

üïí Time: 31m 58s | üî• Train Loss: 4.7744 | ‚úÖ Val Loss: 5.8098

üìÖ Epoch 7/10


                                                                                           

üïí Time: 31m 58s | üî• Train Loss: 4.7508 | ‚úÖ Val Loss: 5.7852

üìÖ Epoch 8/10


                                                                                           

üïí Time: 31m 52s | üî• Train Loss: 4.7396 | ‚úÖ Val Loss: 5.7897

üìÖ Epoch 9/10


                                                                                           

üïí Time: 31m 45s | üî• Train Loss: 4.5644 | ‚úÖ Val Loss: 5.7214

üìÖ Epoch 10/10


                                                                                            

üìè BLEU this epoch: 0.1130
üïí Time: 31m 44s | üî• Train Loss: 4.4805 | ‚úÖ Val Loss: 5.7453
‚úÖ Training complete.
