In [None]:
import os
import torch
import time
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from tqdm import tqdm

from model.dlinknet3 import DLinkNet34
from data.dataset import RoadDatasetLegacyAugment
from loss.loss import bce_dice_loss
from utils.early_stopping import EarlyStopping
from utils.save_load import save_model
from config.optimizer import get_optimizer, get_scheduler
from utils.metrics import compute_metrics

In [2]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# === Configuration ===
# Training hyper-parameters.
batch_size = 4
num_epochs = 50
lr = 1e-4

# Output locations: log directory and checkpoint file.
log_dir = 'logs'
model_path = 'checkpoints/dlinknet_model_best.pth'

# Dataset locations, all derived from a single root folder.
data_root = 'data/1024'
csv_path = f'{data_root}/split.csv'
image_dir = f'{data_root}/images'
mask_dir = f'{data_root}/masks'

In [4]:
# === Dataloader ===
# Both splits read from the same CSV / image / mask locations; only the split
# name and the augmentation flag differ.
dataset_args = (csv_path, image_dir, mask_dir)
train_dataset = RoadDatasetLegacyAugment(*dataset_args, split='train', augment=True)
val_dataset = RoadDatasetLegacyAugment(*dataset_args, split='val', augment=False)

# Loader settings shared by both splits; only the training loader shuffles.
loader_kwargs = dict(batch_size=batch_size, num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [5]:
# === Model, Loss, Optimizer ===
# Build the network with num_classes=1, then place it on the selected device.
# NOTE(review): assumes DLinkNet34 is an nn.Module, so .to() returns the same
# object - confirm.
model = DLinkNet34(num_classes=1)
model = model.to(device)

# Optimizer and LR scheduler come from the project's config helpers; the
# scheduler is configured with mode='min'.
optimizer = get_optimizer(model, lr)
scheduler = get_scheduler(optimizer, mode='min')

# Early stopping with patience 10 in 'min' mode; switch to mode='max' if early
# stopping should be driven by Dice instead.
early_stopping = EarlyStopping(patience=10, mode='min')



In [6]:
# === Logger ===
# Create the log directory explicitly so opening the text log below does not
# depend on SummaryWriter having created it as a side effect.
os.makedirs(log_dir, exist_ok=True)

# TensorBoard event writer for the scalar curves.
writer = SummaryWriter(log_dir)

# Plain-text epoch log. Mode 'w' truncates any previous log. Line buffering
# (buffering=1) pushes every completed line to the file immediately, so an
# interrupted or crashed run does not lose the epochs logged so far.
log_txt = open(
    os.path.join(log_dir, 'train_log.txt'),
    'w',
    buffering=1,
    encoding='utf-8',
)

In [7]:
# === Training Loop ===
# Each epoch runs:
#   1. one optimisation pass over train_loader,
#   2. one no-grad pass over val_loader (loss + Dice / IoU / mIoU),
#   3. console, text-file and TensorBoard logging,
#   4. a checkpoint save when validation Dice improves,
#   5. LR scheduling and early stopping, both driven by validation loss.
# The loop is wrapped in try/finally so the log file and the TensorBoard writer
# are closed even when training is interrupted or raises.
best_dice = 0.0
try:
    for epoch in range(num_epochs):
        print(f"\n🟢 Epoch [{epoch+1}/{num_epochs}]")
        model.train()
        train_loss = 0

        for imgs, masks in tqdm(train_loader, desc='Training'):
            imgs, masks = imgs.to(device), masks.to(device)
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = bce_dice_loss(outputs, masks)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Mean of the per-batch losses.
        train_loss /= len(train_loader)

        # === Validation ===
        model.eval()
        val_loss = 0
        all_preds = []
        all_masks = []

        with torch.no_grad():
            for imgs, masks in tqdm(val_loader, desc='Validation'):
                imgs, masks = imgs.to(device), masks.to(device)
                outputs = model(imgs)
                loss = bce_dice_loss(outputs, masks)
                val_loss += loss.item()

                # Kept on the CPU so metrics are computed over the whole
                # validation set in a single call below.
                # NOTE(review): this holds every prediction and mask in host
                # memory at once - confirm that is acceptable for the dataset size.
                all_preds.append(outputs.cpu())
                all_masks.append(masks.cpu())

        val_loss /= len(val_loader)
        all_preds = torch.cat(all_preds)
        all_masks = torch.cat(all_masks)

        # NOTE(review): threshold=0.5 presumably expects probabilities rather
        # than raw logits - confirm against the model head and compute_metrics.
        val_dice, val_iou, val_miou = compute_metrics(all_preds, all_masks, threshold=0.5)

        print(f"🔹 Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Dice: {val_dice:.4f} | IoU: {val_iou:.4f} | mIoU: {val_miou:.4f}")
        log_txt.write(f"Epoch {epoch+1}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}, dice={val_dice:.4f}, iou={val_iou:.4f}, mIoU: {val_miou:.4f}\n")
        # Flush so the epoch line is on disk even if a later epoch crashes.
        log_txt.flush()

        # === TensorBoard log ===
        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('Loss/val', val_loss, epoch)
        writer.add_scalar('Dice/val', val_dice, epoch)
        writer.add_scalar('IoU/val', val_iou, epoch)
        writer.add_scalar('mIoU/val', val_miou, epoch)

        # === Save model if improved ===
        # Runs BEFORE the early-stopping check: early stopping tracks val_loss
        # while checkpointing tracks Dice, so an epoch can set a new best Dice
        # and still trigger early stopping. Saving first ensures that epoch's
        # weights are not lost to the `break` below.
        if val_dice > best_dice:
            best_dice = val_dice
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            save_model(model, model_path, epoch, best_dice)
            print("✅ Model improved. Saved.")

        # === Scheduler + EarlyStopping ===
        scheduler.step(val_loss)
        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print("⛔ Early stopping triggered!")
            early_stopping.load_best_weights(model)
            break
finally:
    # Always release the log file and the TensorBoard writer.
    log_txt.close()
    writer.close()

print("🏁 Training completed.")



🟢 Epoch [1/50]


Training: 100%|██████████| 1295/1295 [05:02<00:00,  4.28it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.44it/s]


🔹 Train Loss: 0.6345 | Val Loss: 0.4631 | Dice: 0.6597 | IoU: 0.5112 | mIoU: 0.7404
✅ Model improved. Saved.

🟢 Epoch [2/50]


Training: 100%|██████████| 1295/1295 [05:04<00:00,  4.25it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.39it/s]


🔹 Train Loss: 0.4757 | Val Loss: 0.4308 | Dice: 0.6807 | IoU: 0.5342 | mIoU: 0.7534
✅ Model improved. Saved.

🟢 Epoch [3/50]


Training: 100%|██████████| 1295/1295 [05:04<00:00,  4.25it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.49it/s]


🔹 Train Loss: 0.4527 | Val Loss: 0.4276 | Dice: 0.6824 | IoU: 0.5384 | mIoU: 0.7557
✅ Model improved. Saved.

🟢 Epoch [4/50]


Training: 100%|██████████| 1295/1295 [05:04<00:00,  4.25it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.50it/s]


🔹 Train Loss: 0.4374 | Val Loss: 0.4226 | Dice: 0.6862 | IoU: 0.5413 | mIoU: 0.7571
✅ Model improved. Saved.

🟢 Epoch [5/50]


Training: 100%|██████████| 1295/1295 [05:04<00:00,  4.25it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.48it/s]


🔹 Train Loss: 0.4238 | Val Loss: 0.4141 | Dice: 0.6916 | IoU: 0.5506 | mIoU: 0.7623
✅ Model improved. Saved.

🟢 Epoch [6/50]


Training: 100%|██████████| 1295/1295 [05:03<00:00,  4.26it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.55it/s]


🔹 Train Loss: 0.4149 | Val Loss: 0.4130 | Dice: 0.6907 | IoU: 0.5504 | mIoU: 0.7626

🟢 Epoch [7/50]


Training: 100%|██████████| 1295/1295 [05:04<00:00,  4.26it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.52it/s]


🔹 Train Loss: 0.4096 | Val Loss: 0.3991 | Dice: 0.7026 | IoU: 0.5616 | mIoU: 0.7688
✅ Model improved. Saved.

🟢 Epoch [8/50]


Training: 100%|██████████| 1295/1295 [05:03<00:00,  4.27it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.49it/s]


🔹 Train Loss: 0.4017 | Val Loss: 0.3982 | Dice: 0.7013 | IoU: 0.5592 | mIoU: 0.7676

🟢 Epoch [9/50]


Training: 100%|██████████| 1295/1295 [05:03<00:00,  4.26it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.50it/s]


🔹 Train Loss: 0.3986 | Val Loss: 0.3878 | Dice: 0.7096 | IoU: 0.5703 | mIoU: 0.7733
✅ Model improved. Saved.

🟢 Epoch [10/50]


Training: 100%|██████████| 1295/1295 [05:04<00:00,  4.25it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.45it/s]


🔹 Train Loss: 0.3947 | Val Loss: 0.3950 | Dice: 0.7079 | IoU: 0.5670 | mIoU: 0.7704

🟢 Epoch [11/50]


Training: 100%|██████████| 1295/1295 [05:03<00:00,  4.27it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.49it/s]


🔹 Train Loss: 0.3913 | Val Loss: 0.3786 | Dice: 0.7171 | IoU: 0.5787 | mIoU: 0.7773
✅ Model improved. Saved.

🟢 Epoch [12/50]


Training: 100%|██████████| 1295/1295 [05:04<00:00,  4.25it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.51it/s]


🔹 Train Loss: 0.3880 | Val Loss: 0.3970 | Dice: 0.7021 | IoU: 0.5626 | mIoU: 0.7696

🟢 Epoch [13/50]


Training: 100%|██████████| 1295/1295 [05:04<00:00,  4.25it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.51it/s]


🔹 Train Loss: 0.3842 | Val Loss: 0.4094 | Dice: 0.6899 | IoU: 0.5500 | mIoU: 0.7631

🟢 Epoch [14/50]


Training: 100%|██████████| 1295/1295 [05:04<00:00,  4.25it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.51it/s]


🔹 Train Loss: 0.3841 | Val Loss: 0.3809 | Dice: 0.7143 | IoU: 0.5760 | mIoU: 0.7763

🟢 Epoch [15/50]


Training: 100%|██████████| 1295/1295 [05:02<00:00,  4.27it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.48it/s]


🔹 Train Loss: 0.3776 | Val Loss: 0.3957 | Dice: 0.7032 | IoU: 0.5645 | mIoU: 0.7707

🟢 Epoch [16/50]


Training: 100%|██████████| 1295/1295 [05:03<00:00,  4.27it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.51it/s]


🔹 Train Loss: 0.3777 | Val Loss: 0.3694 | Dice: 0.7222 | IoU: 0.5858 | mIoU: 0.7816
✅ Model improved. Saved.

🟢 Epoch [17/50]


Training: 100%|██████████| 1295/1295 [05:03<00:00,  4.27it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.47it/s]


🔹 Train Loss: 0.3733 | Val Loss: 0.3814 | Dice: 0.7152 | IoU: 0.5769 | mIoU: 0.7763

🟢 Epoch [18/50]


Training: 100%|██████████| 1295/1295 [05:04<00:00,  4.26it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.48it/s]


🔹 Train Loss: 0.3686 | Val Loss: 0.3715 | Dice: 0.7216 | IoU: 0.5844 | mIoU: 0.7807

🟢 Epoch [19/50]


Training: 100%|██████████| 1295/1295 [05:03<00:00,  4.27it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.49it/s]


🔹 Train Loss: 0.3705 | Val Loss: 0.3826 | Dice: 0.7114 | IoU: 0.5738 | mIoU: 0.7752

🟢 Epoch [20/50]


Training: 100%|██████████| 1295/1295 [05:03<00:00,  4.27it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.43it/s]


🔹 Train Loss: 0.3679 | Val Loss: 0.3802 | Dice: 0.7143 | IoU: 0.5763 | mIoU: 0.7767

🟢 Epoch [21/50]


Training: 100%|██████████| 1295/1295 [05:03<00:00,  4.26it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.48it/s]


🔹 Train Loss: 0.3665 | Val Loss: 0.3665 | Dice: 0.7254 | IoU: 0.5895 | mIoU: 0.7835
✅ Model improved. Saved.

🟢 Epoch [22/50]


Training: 100%|██████████| 1295/1295 [05:03<00:00,  4.26it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.47it/s]


🔹 Train Loss: 0.3637 | Val Loss: 0.3689 | Dice: 0.7231 | IoU: 0.5875 | mIoU: 0.7826

🟢 Epoch [23/50]


Training: 100%|██████████| 1295/1295 [05:04<00:00,  4.26it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.44it/s]


🔹 Train Loss: 0.3604 | Val Loss: 0.3677 | Dice: 0.7242 | IoU: 0.5879 | mIoU: 0.7825

🟢 Epoch [24/50]


Training: 100%|██████████| 1295/1295 [05:03<00:00,  4.26it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.46it/s]


🔹 Train Loss: 0.3616 | Val Loss: 0.3671 | Dice: 0.7255 | IoU: 0.5903 | mIoU: 0.7840
✅ Model improved. Saved.

🟢 Epoch [25/50]


Training: 100%|██████████| 1295/1295 [05:03<00:00,  4.26it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.44it/s]


🔹 Train Loss: 0.3606 | Val Loss: 0.3596 | Dice: 0.7296 | IoU: 0.5943 | mIoU: 0.7862
✅ Model improved. Saved.

🟢 Epoch [26/50]


Training: 100%|██████████| 1295/1295 [05:02<00:00,  4.28it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.47it/s]


🔹 Train Loss: 0.3571 | Val Loss: 0.3647 | Dice: 0.7254 | IoU: 0.5897 | mIoU: 0.7838

🟢 Epoch [27/50]


Training: 100%|██████████| 1295/1295 [05:03<00:00,  4.27it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.39it/s]


🔹 Train Loss: 0.3571 | Val Loss: 0.3584 | Dice: 0.7311 | IoU: 0.5967 | mIoU: 0.7875
✅ Model improved. Saved.

🟢 Epoch [28/50]


Training: 100%|██████████| 1295/1295 [05:03<00:00,  4.26it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.50it/s]


🔹 Train Loss: 0.3535 | Val Loss: 0.3654 | Dice: 0.7233 | IoU: 0.5876 | mIoU: 0.7829

🟢 Epoch [29/50]


Training: 100%|██████████| 1295/1295 [05:04<00:00,  4.26it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.48it/s]


🔹 Train Loss: 0.3514 | Val Loss: 0.3704 | Dice: 0.7215 | IoU: 0.5853 | mIoU: 0.7811

🟢 Epoch [30/50]


Training: 100%|██████████| 1295/1295 [05:03<00:00,  4.27it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.52it/s]


🔹 Train Loss: 0.3496 | Val Loss: 0.3557 | Dice: 0.7324 | IoU: 0.5983 | mIoU: 0.7882
✅ Model improved. Saved.

🟢 Epoch [31/50]


Training: 100%|██████████| 1295/1295 [05:03<00:00,  4.26it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.52it/s]


🔹 Train Loss: 0.3486 | Val Loss: 0.3575 | Dice: 0.7310 | IoU: 0.5975 | mIoU: 0.7878

🟢 Epoch [32/50]


Training: 100%|██████████| 1295/1295 [05:03<00:00,  4.26it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.50it/s]


🔹 Train Loss: 0.3518 | Val Loss: 0.3576 | Dice: 0.7311 | IoU: 0.5965 | mIoU: 0.7873

🟢 Epoch [33/50]


Training: 100%|██████████| 1295/1295 [05:04<00:00,  4.26it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.49it/s]


🔹 Train Loss: 0.3470 | Val Loss: 0.3664 | Dice: 0.7240 | IoU: 0.5886 | mIoU: 0.7834

🟢 Epoch [34/50]


Training: 100%|██████████| 1295/1295 [05:04<00:00,  4.26it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.49it/s]


🔹 Train Loss: 0.3483 | Val Loss: 0.3538 | Dice: 0.7340 | IoU: 0.6005 | mIoU: 0.7891
✅ Model improved. Saved.

🟢 Epoch [35/50]


Training: 100%|██████████| 1295/1295 [05:04<00:00,  4.25it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.51it/s]


🔹 Train Loss: 0.3463 | Val Loss: 0.3541 | Dice: 0.7343 | IoU: 0.5999 | mIoU: 0.7888
✅ Model improved. Saved.

🟢 Epoch [36/50]


Training: 100%|██████████| 1295/1295 [05:04<00:00,  4.25it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.49it/s]


🔹 Train Loss: 0.3455 | Val Loss: 0.3522 | Dice: 0.7352 | IoU: 0.6017 | mIoU: 0.7900
✅ Model improved. Saved.

🟢 Epoch [37/50]


Training: 100%|██████████| 1295/1295 [05:03<00:00,  4.27it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.52it/s]


🔹 Train Loss: 0.3405 | Val Loss: 0.3562 | Dice: 0.7299 | IoU: 0.5969 | mIoU: 0.7877

🟢 Epoch [38/50]


Training: 100%|██████████| 1295/1295 [05:02<00:00,  4.28it/s]
Validation: 100%|██████████| 278/278 [00:28<00:00,  9.66it/s]


🔹 Train Loss: 0.3450 | Val Loss: 0.3590 | Dice: 0.7285 | IoU: 0.5957 | mIoU: 0.7870

🟢 Epoch [39/50]


Training: 100%|██████████| 1295/1295 [05:04<00:00,  4.25it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.50it/s]


🔹 Train Loss: 0.3402 | Val Loss: 0.3651 | Dice: 0.7265 | IoU: 0.5920 | mIoU: 0.7843

🟢 Epoch [40/50]


Training: 100%|██████████| 1295/1295 [05:03<00:00,  4.26it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.48it/s]


🔹 Train Loss: 0.3379 | Val Loss: 0.3535 | Dice: 0.7335 | IoU: 0.5995 | mIoU: 0.7889

🟢 Epoch [41/50]


Training: 100%|██████████| 1295/1295 [05:02<00:00,  4.28it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.53it/s]


🔹 Train Loss: 0.3410 | Val Loss: 0.3563 | Dice: 0.7318 | IoU: 0.5988 | mIoU: 0.7886

🟢 Epoch [42/50]


Training: 100%|██████████| 1295/1295 [05:04<00:00,  4.26it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.51it/s]


🔹 Train Loss: 0.3386 | Val Loss: 0.3556 | Dice: 0.7327 | IoU: 0.5979 | mIoU: 0.7878

🟢 Epoch [43/50]


Training: 100%|██████████| 1295/1295 [05:04<00:00,  4.25it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.42it/s]


🔹 Train Loss: 0.3289 | Val Loss: 0.3505 | Dice: 0.7346 | IoU: 0.6028 | mIoU: 0.7908

🟢 Epoch [44/50]


Training: 100%|██████████| 1295/1295 [05:04<00:00,  4.25it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.43it/s]


🔹 Train Loss: 0.3222 | Val Loss: 0.3436 | Dice: 0.7413 | IoU: 0.6090 | mIoU: 0.7938
✅ Model improved. Saved.

🟢 Epoch [45/50]


Training: 100%|██████████| 1295/1295 [05:03<00:00,  4.27it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.47it/s]


🔹 Train Loss: 0.3208 | Val Loss: 0.3408 | Dice: 0.7430 | IoU: 0.6121 | mIoU: 0.7955
✅ Model improved. Saved.

🟢 Epoch [46/50]


Training: 100%|██████████| 1295/1295 [05:02<00:00,  4.28it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.51it/s]


🔹 Train Loss: 0.3197 | Val Loss: 0.3450 | Dice: 0.7395 | IoU: 0.6081 | mIoU: 0.7938

🟢 Epoch [47/50]


Training: 100%|██████████| 1295/1295 [05:04<00:00,  4.25it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.49it/s]


🔹 Train Loss: 0.3178 | Val Loss: 0.3455 | Dice: 0.7392 | IoU: 0.6086 | mIoU: 0.7938

🟢 Epoch [48/50]


Training: 100%|██████████| 1295/1295 [05:04<00:00,  4.25it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.46it/s]


🔹 Train Loss: 0.3168 | Val Loss: 0.3417 | Dice: 0.7422 | IoU: 0.6120 | mIoU: 0.7956

🟢 Epoch [49/50]


Training: 100%|██████████| 1295/1295 [05:04<00:00,  4.25it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.44it/s]


🔹 Train Loss: 0.3170 | Val Loss: 0.3381 | Dice: 0.7450 | IoU: 0.6147 | mIoU: 0.7970
✅ Model improved. Saved.

🟢 Epoch [50/50]


Training: 100%|██████████| 1295/1295 [05:04<00:00,  4.25it/s]
Validation: 100%|██████████| 278/278 [00:29<00:00,  9.44it/s]


🔹 Train Loss: 0.3146 | Val Loss: 0.3423 | Dice: 0.7424 | IoU: 0.6123 | mIoU: 0.7959
🏁 Training completed.
