In [1]:
import torch
torch.cuda.empty_cache()

In [None]:
import os
import time
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from datetime import datetime

from model.encoder import Encoder
from model.decoder import Decoder
from model.seq2seq import Seq2Seq
from dataset.preprocess import collate_fn, TranslationDataset

from config import *
from loss import get_loss, compute_loss
from optimizer import get_optimizer, get_plateau_scheduler
from early_stopping import EarlyStopping
from eval import evaluate

In [3]:
# === Load data ===
data = torch.load(DATA_PATH)
src_lines = data["src_lines"]
tgt_lines = data["tgt_lines"]
vocab_en = data["vocab_en"]
vocab_vi = data["vocab_vi"]
pad_idx = data["pad_idx"]

# === Khởi tạo lại Dataset
train_dataset = TranslationDataset(src_lines["train"], tgt_lines["train"], vocab_en, vocab_vi)
val_dataset   = TranslationDataset(src_lines["val"], tgt_lines["val"], vocab_en, vocab_vi)

# === Dataloader
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          collate_fn=lambda x: collate_fn(x, pad_idx))
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        collate_fn=lambda x: collate_fn(x, pad_idx))

In [4]:
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn= lambda x: collate_fn(x, pad_idx)
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn= lambda x: collate_fn(x, pad_idx)
)

In [5]:
INPUT_DIM = len(vocab_en)
OUTPUT_DIM = len(vocab_vi)

In [6]:
# === Init model ===
encoder = Encoder(INPUT_DIM, EMBED_DIM, HIDDEN_DIM, N_LAYERS, DROPOUT).to(DEVICE)
decoder = Decoder(EMBED_DIM, HIDDEN_DIM, OUTPUT_DIM, N_LAYERS, DROPOUT).to(DEVICE)
model = Seq2Seq(encoder, decoder).to(DEVICE)

In [7]:
loss_fn = get_loss(pad_idx, use_label_smoothing=True)

optimizer = get_optimizer(model, LEARNING_RATE)
scheduler = get_plateau_scheduler(optimizer, factor=0.5, patience=2)

early_stopping = EarlyStopping(patience=PATIENCE, path=CHECKPOINT_PATH)



In [8]:
# === Logging setup ===
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
log_file = f"train_log_{timestamp}.txt"
os.makedirs("logs", exist_ok=True)
log_path = os.path.join("logs", log_file)

def log(message):
    print(message)
    with open(log_path, "a", encoding="utf-8") as f:
        f.write(message + "\n")

In [9]:
# === Training function ===
def train(model, loader, optimizer, loss_fn, clip, epoch_num):
    model.train()
    epoch_loss = 0
    progress_bar = tqdm(loader, desc=f"🔁 Training Epoch {epoch_num}", leave=False)

    for i, (src, trg) in enumerate(progress_bar):
        src = src.transpose(0, 1).to(DEVICE)
        trg = trg.transpose(0, 1).to(DEVICE)

        optimizer.zero_grad()
        output = model(src, trg, teacher_forcing_ratio=TEACHER_FORCING_RATIO)

        output = output[1:].reshape(-1, output.shape[-1])
        trg = trg[1:].reshape(-1)

        loss = compute_loss(output, trg, loss_fn)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        batch_loss = loss.item()
        epoch_loss += batch_loss
        progress_bar.set_postfix(batch_loss=f"{batch_loss:.4f}")

    return epoch_loss / len(loader)

In [10]:
# === Decode function for BLEU/METEOR
def decode_sequence(seq, idx2word):
    return [idx2word.get(idx, "<UNK>") for idx in seq if idx2word.get(idx) not in ["<PAD>", "<EOS>", "<SOS>"]]

In [11]:
# === Epoch timing ===
def epoch_time(start, end):
    elapsed = end - start
    return int(elapsed // 60), int(elapsed % 60)

In [12]:
# === Inverse vocab
idx2vi = {v: k for k, v in vocab_vi.items()}

In [13]:
# === Main training loop ===
log(f"🚀 Training started at {timestamp}")
log(f"🧠 DEVICE: {DEVICE}")
log(f"📊 Total Epochs: {N_EPOCHS} | Batch Size: {BATCH_SIZE} | Teacher Forcing: {TEACHER_FORCING_RATIO}")

best_val_loss = float("inf")

for epoch in range(N_EPOCHS):
    log(f"\n📅 Epoch {epoch + 1}/{N_EPOCHS}")
    start_time = time.time()

    train_loss = train(model, train_loader, optimizer, loss_fn, CLIP, epoch + 1)
    val_loss = evaluate(model, val_loader, loss_fn, device=DEVICE)

    scheduler.step(val_loss)
    early_stopping(val_loss, model)

    mins, secs = epoch_time(start_time, time.time())

    log(f"🕒 Time: {mins}m {secs}s | 🔥 Train Loss: {train_loss:.4f} | ✅ Val Loss: {val_loss:.4f}")

    if early_stopping.early_stop:
        log("⛔ Early stopping triggered!")
        break

log("✅ Training complete.")

🚀 Training started at 2025-05-19_03-07-37
🧠 DEVICE: cuda
📊 Total Epochs: 10 | Batch Size: 32 | Teacher Forcing: 0.5

📅 Epoch 1/10


                                                                                              

🕒 Time: 50m 28s | 🔥 Train Loss: 5.5492 | ✅ Val Loss: 5.8539

📅 Epoch 2/10


                                                                                              

🕒 Time: 50m 11s | 🔥 Train Loss: 5.0975 | ✅ Val Loss: 5.7253

📅 Epoch 3/10


                                                                                             

🕒 Time: 49m 44s | 🔥 Train Loss: 4.9575 | ✅ Val Loss: 5.6614

📅 Epoch 4/10


                                                                                             

🕒 Time: 49m 47s | 🔥 Train Loss: 4.8865 | ✅ Val Loss: 5.6938

📅 Epoch 5/10


                                                                                             

🕒 Time: 49m 45s | 🔥 Train Loss: 4.8476 | ✅ Val Loss: 5.6422

📅 Epoch 6/10


                                                                                             

🕒 Time: 49m 44s | 🔥 Train Loss: 4.8205 | ✅ Val Loss: 5.6642

📅 Epoch 7/10


                                                                                              

🕒 Time: 49m 46s | 🔥 Train Loss: 4.8106 | ✅ Val Loss: 5.6881

📅 Epoch 8/10


                                                                                              

🕒 Time: 49m 49s | 🔥 Train Loss: 4.8237 | ✅ Val Loss: 5.6933

📅 Epoch 9/10


                                                                                             

🕒 Time: 49m 53s | 🔥 Train Loss: 4.7347 | ✅ Val Loss: 5.6285

📅 Epoch 10/10


                                                                                               

🕒 Time: 50m 45s | 🔥 Train Loss: 4.6947 | ✅ Val Loss: 5.6565
✅ Training complete.
