<a href="https://colab.research.google.com/github/ttntbn/Deep-Learning/blob/main/real_Helmet_Detection_using_Faster_RCNN_clean.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Helmet Detection with Faster R-CNN — Clean Notebook (ไทย)

เวอร์ชันนี้จัดระเบียบใหม่ให้ **อ่านง่าย**, **แก้ไขสะดวก**, และ **เทรนพร้อมกราฟ loss สด**
- รวม **hyperparameters** ไว้ในที่เดียว
- คอมเมนต์อธิบายทุกส่วน: Config → Setup → Data → Transforms → Dataloaders → Model → Optimizer → Train → Save
- ระหว่างเทรนจะเห็นกราฟ Loss ต่อ step แบบ Live และกราฟ Loss ราย epoch (train/val) พร้อม log ค่าไปที่ Comet

## 1) Config — ปรับค่าที่นี่เพื่อคุมการเทรน

In [None]:
# ==============================
# Configurable Hyperparameters
# ==============================
from pathlib import Path

# --- Paths: point DATA_ROOT at the dataset (images + Pascal-VOC XML) ---
DATA_ROOT = Path("/content/drive/MyDrive/helmet_dataset")
TRAIN_IMG, TRAIN_ANN = DATA_ROOT / "train/images", DATA_ROOT / "train/annotations"
VAL_IMG, VAL_ANN = DATA_ROOT / "val/images", DATA_ROOT / "val/annotations"
OUTPUT_DIR = Path("./outputs_frcnn")  # checkpoints are written here
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# --- Classes: slot 0 is background, then the dataset's class names ---
CLASSES = ["__background__", "helmet", "no_helmet"]

# --- Optimisation (SGD) ---
EPOCHS, BATCH_SIZE = 15, 8
LR, MOMENTUM, WEIGHT_DECAY = 5e-4, 0.9, 1e-4
WARMUP_STEPS = 500  # length of the linear LR warm-up, counted in optimizer steps

# --- LR schedule: one of "none", "multistep", "cosine" ---
LR_SCHEDULER = "multistep"
# NOTE: MultiStepLR counts scheduler.step() calls, which the training loop makes
# once per epoch, so a milestone equal to EPOCHS only takes effect after the
# last epoch has already finished.
MILESTONES = [10, 15]
GAMMA = 0.1
COSINE_TMAX = EPOCHS
ETA_MIN = 1e-6

# --- Augmentation ---
H_FLIP_PROB = 0.5
RAND_RESIZE = (640, 1024)  # (min, max) length of the longest image side
COLOR_JITTER = True

SEED = 1337

print("OUTPUT_DIR:", OUTPUT_DIR.resolve())

## 2) Setup — ไลบรารี, Comet logging และอุปกรณ์ (device)

In [None]:
# ===== run this FIRST (top of notebook, before importing torch) =====
!pip -q install comet-ml

import os
# ใส่ค่า "จริง" ของคุณ (หรือใช้ environment variables ก็ได้)
os.environ["COMET_API_KEY"]      = "KlDdmMhprhNGWTot1PPhnMo4u"             # คีย์จริงจาก Account Settings
os.environ["COMET_WORKSPACE"]    = "boonyapon-boontub-0272"     # <- จากลิงก์ของคุณ
os.environ["COMET_PROJECT_NAME"] = "helmet-fasterrcnn"

from comet_ml import Experiment

experiment = Experiment(
    api_key=os.getenv("COMET_API_KEY"),
    workspace=os.getenv("COMET_WORKSPACE"),
    project_name=os.getenv("COMET_PROJECT_NAME"),
    auto_metric_logging=False,
    auto_param_logging=False,
    auto_output_logging="simple",
)
experiment.set_name("fasterrcnn-helmet-run")


In [None]:
import os, math, time, random, xml.etree.ElementTree as ET
import numpy as np
import torch, torch.utils.data as data
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import torchvision.transforms.functional as F
from PIL import Image
import matplotlib.pyplot as plt

# Reproducibility: seed the Python, NumPy and PyTorch RNGs (helps make runs
# repeatable; does not by itself force deterministic GPU kernels).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device  # shown as the cell output

## 3) Dataset (VOC XML) — อ่านกล่องจากไฟล์ .xml

In [None]:
def _read_voc_xml(xml_path, class_to_idx):
    # อ่าน annotation แบบ VOC XML -> boxes, labels (index เริ่ม 1 สำหรับคลาสของเรา)
    tree = ET.parse(str(xml_path))
    root = tree.getroot()
    boxes, labels = [], []
    for obj in root.findall("object"):
        name = obj.find("name").text.strip()
        if name not in class_to_idx:
            continue
        bnd = obj.find("bndbox")
        x1 = float(bnd.find("xmin").text); y1 = float(bnd.find("ymin").text)
        x2 = float(bnd.find("xmax").text); y2 = float(bnd.find("ymax").text)
        if x2 <= x1 or y2 <= y1:   # กรองกล่องผิดรูป
            continue
        boxes.append([x1, y1, x2, y2])
        labels.append(class_to_idx[name])
    return boxes, labels

class HelmetVOCDataset(data.Dataset):
    """Detection dataset pairing each image with the VOC XML file of the same stem.

    Args:
        img_dir: folder with ``.jpg`` / ``.jpeg`` / ``.png`` images (suffix
            compared case-insensitively).
        ann_dir: folder with ``<image stem>.xml`` annotation files.
        classes: class names; the ``"__background__"`` entry is skipped and the
            remaining names get indices 1, 2, ... (0 is left for background).
        transforms: optional callable ``(img, target) -> (img, target)``.

    Images without a matching XML file are not included.
    """

    def __init__(self, img_dir: Path, ann_dir: Path, classes, transforms=None):
        self.img_dir = Path(img_dir)
        self.ann_dir = Path(ann_dir)
        self.transforms = transforms
        self.classes = classes

        # Class name -> label, numbered from 1 in the order given.
        named = [c for c in classes if c != "__background__"]
        self.class_to_idx = {c: i for i, c in enumerate(named, start=1)}

        # (image, xml) pairs, in sorted image order, for images that have an XML.
        image_exts = {".jpg", ".jpeg", ".png"}
        candidates = (
            (img, self.ann_dir / f"{img.stem}.xml")
            for img in sorted(self.img_dir.iterdir())
            if img.suffix.lower() in image_exts
        )
        self.items = [(img, xml) for img, xml in candidates if xml.exists()]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample *idx*.

        Before ``transforms`` run, ``target`` holds ``boxes`` (float32, one
        ``[x1, y1, x2, y2]`` row per object, shape (0, 4) when there are none),
        ``labels`` (int64) and ``image_id`` (``tensor([idx])``).
        """
        img_path, xml_path = self.items[idx]
        img = Image.open(img_path).convert("RGB")
        box_list, label_list = _read_voc_xml(xml_path, self.class_to_idx)
        if box_list:
            boxes = torch.as_tensor(box_list, dtype=torch.float32)
        else:
            boxes = torch.zeros((0, 4), dtype=torch.float32)
        if label_list:
            labels = torch.as_tensor(label_list, dtype=torch.int64)
        else:
            labels = torch.zeros((0,), dtype=torch.int64)
        target = {"boxes": boxes, "labels": labels, "image_id": torch.tensor([idx])}
        if self.transforms is None:
            return img, target
        img, target = self.transforms(img, target)
        return img, target

### Transforms — Augmentations เพื่อลด overfit / เพิ่มความทนทาน

In [None]:
class ComposeTransforms:
    """Chain of ``(img, target) -> (img, target)`` callables applied in order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, target):
        """Feed the pair through every transform and return the final pair."""
        for step in self.transforms:
            img, target = step(img, target)
        return img, target

class RandomHorizontalFlip:
    """Mirror the image left-right with probability ``p`` and remap the boxes.

    Boxes are ``[x1, y1, x2, y2]``; flipping maps ``x -> width - x``, so the new
    x1 comes from the old x2 and vice versa. The boxes tensor is modified in place.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, target):
        if not random.random() < self.p:
            return img, target
        width, _ = img.size
        img = F.hflip(img)
        boxes = target["boxes"]
        if boxes.numel() > 0:
            new_x1 = width - boxes[:, 2]
            new_x2 = width - boxes[:, 0]
            boxes[:, 0] = new_x1
            boxes[:, 2] = new_x2
            target["boxes"] = boxes
        return img, target

class RandomResizeLongestSide:
    """Resize so the longer side becomes a random integer in ``[min_long, max_long]``.

    The aspect ratio is preserved up to integer truncation of the output size.
    Boxes are rescaled with the *actual* per-axis factors (``new_w / w`` and
    ``new_h / h``) so they stay aligned with the truncated image. Each output
    side is at least 1 pixel.
    """

    def __init__(self, min_long, max_long):
        self.min_long = min_long
        self.max_long = max_long

    @staticmethod
    def _scale_boxes(boxes, sx, sy):
        """Scale ``[x1, y1, x2, y2]`` rows by *sx* horizontally and *sy* vertically."""
        factors = torch.tensor([sx, sy, sx, sy], dtype=boxes.dtype, device=boxes.device)
        return boxes * factors

    def __call__(self, img, target):
        w, h = img.size
        new_long = random.randint(self.min_long, self.max_long)
        scale = new_long / max(w, h)
        # Truncate like before, but never let a very thin image collapse to 0 px.
        new_w, new_h = max(1, int(w * scale)), max(1, int(h * scale))
        img = F.resize(img, [new_h, new_w])
        boxes = target["boxes"]
        if boxes.numel() > 0:
            target["boxes"] = self._scale_boxes(boxes, new_w / w, new_h / h)
        return img, target

class OptionalColorJitter:
    """Mild photometric jitter (brightness/contrast/saturation/hue) that can be disabled.

    Only the image changes; the target is returned untouched because colour
    jitter does not move pixels.
    """

    def __init__(self, enable=True):
        self.enable = enable
        self.jitter = T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.02)

    def __call__(self, img, target):
        out = self.jitter(img) if self.enable else img
        return out, target

def make_transforms(train=True):
    """Build the ``(img, target)`` preprocessing pipeline.

    train=True : random longest-side resize in RAND_RESIZE -> random h-flip
                 (H_FLIP_PROB) -> optional colour jitter (COLOR_JITTER) -> tensor
    train=False: deterministic resize (longest side = RAND_RESIZE[1]) -> tensor

    Evaluation uses a fixed size so the validation loss of the same weights does
    not change between runs. That loss is what selects the best checkpoint.
    """
    min_long, max_long = RAND_RESIZE
    if train:
        steps = [
            RandomResizeLongestSide(min_long, max_long),
            RandomHorizontalFlip(H_FLIP_PROB),
            OptionalColorJitter(COLOR_JITTER),
        ]
    else:
        # A [max_long, max_long] range makes the "random" resize deterministic.
        steps = [RandomResizeLongestSide(max_long, max_long)]

    def to_tensor(img, target):
        return F.to_tensor(img), target

    steps.append(to_tensor)
    return ComposeTransforms(steps)

## 4) Dataloaders

In [None]:
# Mount Google Drive so the dataset under /content/drive/MyDrive is readable.
from google.colab import drive

drive.mount("/content/drive")

In [None]:
from pathlib import Path

# Reuse the paths from the Config cell so DATA_ROOT is edited in ONE place.
# Re-declaring DATA_ROOT here would silently override whatever Config sets.
TRAIN_IMG_DIR = Path(TRAIN_IMG)
TRAIN_ANN_DIR = Path(TRAIN_ANN)
VAL_IMG_DIR   = Path(VAL_IMG)
VAL_ANN_DIR   = Path(VAL_ANN)

def _count(p, exts):
    return sum(1 for q in p.iterdir() if q.is_file() and q.suffix.lower() in exts)

# Fail fast with a readable error when a dataset folder is missing
# (an explicit raise, because `assert` is stripped under `python -O`).
for p in [TRAIN_IMG_DIR, TRAIN_ANN_DIR, VAL_IMG_DIR, VAL_ANN_DIR]:
    if not p.exists():
        raise FileNotFoundError(f"ไม่พบโฟลเดอร์: {p}")

IMG_EXTS = {'.jpg', '.jpeg', '.png'}
print("train:", _count(TRAIN_IMG_DIR, IMG_EXTS), "images /",
      _count(TRAIN_ANN_DIR, {'.xml'}), "xml")
print("val  :", _count(VAL_IMG_DIR, IMG_EXTS), "images /",
      _count(VAL_ANN_DIR, {'.xml'}), "xml")

# Detection batches cannot be stacked (images are randomly resized and carry a
# variable number of boxes), so keep images and targets as Python lists.
# Only defined when an earlier cell has not already provided a collate_fn.
try:
    collate_fn  # noqa: F821
except NameError:
    def collate_fn(batch):
        imgs, tgts = list(zip(*batch))
        return list(imgs), list(tgts)

# Datasets & Dataloaders
train_ds = HelmetVOCDataset(TRAIN_IMG_DIR, TRAIN_ANN_DIR, CLASSES, transforms=make_transforms(train=True))
val_ds   = HelmetVOCDataset(VAL_IMG_DIR,   VAL_ANN_DIR,   CLASSES, transforms=make_transforms(train=False))

# The dataset keeps only images with a same-stem .xml file. An empty split would
# otherwise fail later inside the DataLoader, or yield a meaningless 0.0 val loss.
for split, ds, img_dir, ann_dir in [("train", train_ds, TRAIN_IMG_DIR, TRAIN_ANN_DIR),
                                     ("val", val_ds, VAL_IMG_DIR, VAL_ANN_DIR)]:
    if len(ds) == 0:
        raise RuntimeError(f"{split}: no image/XML pairs (same file stem) found in {img_dir} and {ann_dir}")

from torch.utils.data import DataLoader

# Pinned host memory only speeds up host->GPU copies; skip it on CPU-only runtimes.
PIN_MEMORY = torch.cuda.is_available()
train_loader = DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=2, pin_memory=PIN_MEMORY, prefetch_factor=2,
    persistent_workers=False,
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    val_ds, batch_size=1, shuffle=False,
    num_workers=2, pin_memory=PIN_MEMORY, prefetch_factor=2,
    persistent_workers=False,
    collate_fn=collate_fn,
)

print("OK -> loaders ready | train:", len(train_ds), "| val:", len(val_ds))


## 5) Model — Faster R-CNN ResNet50 FPN

In [None]:
def build_model(num_classes: int):
    """Return a pretrained Faster R-CNN (ResNet-50 FPN) with a fresh box predictor.

    The backbone is loaded with torchvision's default pretrained weights; only
    ``roi_heads.box_predictor`` is replaced by a ``FastRCNNPredictor`` sized for
    ``num_classes`` classes (background included).
    """
    detector = fasterrcnn_resnet50_fpn(weights="DEFAULT")
    head_in_features = detector.roi_heads.box_predictor.cls_score.in_features
    detector.roi_heads.box_predictor = FastRCNNPredictor(head_in_features, num_classes)
    return detector

# One output class per CLASSES entry (background included), placed on `device`.
model = build_model(num_classes=len(CLASSES))
model = model.to(device)
model

## 6) Optimizer & LR Scheduler — พร้อม Linear Warmup

In [None]:
# SGD over the trainable parameters only.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)

# Epoch-level LR schedule chosen by LR_SCHEDULER; any other value means no scheduler.
scheduler = (
    torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=MILESTONES, gamma=GAMMA)
    if LR_SCHEDULER == "multistep"
    else torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=COSINE_TMAX, eta_min=ETA_MIN)
    if LR_SCHEDULER == "cosine"
    else None
)

# Global optimizer-step counter; the training loop uses it for linear warm-up.
global_step = 0

def set_lr(opt, lr):
    """Overwrite the learning rate of every parameter group of *opt* with *lr*."""
    for group in opt.param_groups:
        group["lr"] = lr

def get_warmup_lr(base_lr, step, warmup_steps):
    """Linear warm-up: ``base_lr * min(1, step / warmup_steps)``.

    A non-positive *warmup_steps* disables warm-up and returns *base_lr*.
    """
    if warmup_steps > 0:
        fraction = min(1.0, step / float(warmup_steps))
        return base_lr * fraction
    return base_lr

In [None]:
# ===== core training setup =====
# Rebuilds device / optimizer / scheduler from the Config values, so this cell
# can be re-run on its own after changing hyperparameters.
import torch

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)

# "none" (or any unrecognised value) leaves the LR unscheduled.
scheduler = None
if LR_SCHEDULER == "multistep":
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=MILESTONES, gamma=GAMMA)
elif LR_SCHEDULER == "cosine":
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=COSINE_TMAX, eta_min=ETA_MIN)

# Keep an existing step counter (from an earlier run); otherwise start at 0.
if "global_step" not in globals():
    global_step = 0


In [None]:
def update_epoch_plot():
    """Redraw the per-epoch train/val loss figure in place.

    Reads the globals ``train_epoch_loss`` / ``val_epoch_loss`` and the figure
    objects ``fig_ep``, ``ax_ep``, ``tr_line``, ``va_line`` and ``ep_handle``
    created by the epoch-plot cell. Each curve gets its own x values (epochs
    1..len), so the plot still works if one list is ever longer than the other.
    """
    tr_line.set_data(np.arange(1, len(train_epoch_loss) + 1), train_epoch_loss)
    va_line.set_data(np.arange(1, len(val_epoch_loss) + 1), val_epoch_loss)

    # X axis: show at least 5 epochs, otherwise fit the longer series.
    n_epochs = max(len(train_epoch_loss), len(val_epoch_loss))
    ax_ep.set_xlim(1, max(5, n_epochs + 0.5))

    # Y axis: fit all values with a 10% margin; widen a flat range to +/-1.
    y_all = list(train_epoch_loss) + list(val_epoch_loss)
    if y_all:
        y_min, y_max = float(min(y_all)), float(max(y_all))
        if y_max == y_min:
            y_min, y_max = y_min - 1.0, y_max + 1.0
        margin = 0.1 * (y_max - y_min)
        ax_ep.set_ylim(y_min - margin, y_max + margin)

    # Redraw into the existing output; fall back to a new output without a handle.
    fig_ep.canvas.draw()
    try:
        ep_handle.update(fig_ep)  # requires display(fig_ep, display_id=True) earlier
    except Exception:
        display(fig_ep)


In [None]:
# ==== Epoch-level loss figure (train vs val), updated in place while training ====
import numpy as np, matplotlib.pyplot as plt
from IPython.display import display

# Keep loss history from an earlier run if it exists; otherwise start empty.
train_epoch_loss = globals().get("train_epoch_loss", [])
val_epoch_loss = globals().get("val_epoch_loss", [])

fig_ep, ax_ep = plt.subplots(figsize=(7,4))
tr_line = ax_ep.plot([], [], '-o', label='train loss')[0]
va_line = ax_ep.plot([], [], '-s', label='val loss')[0]
ax_ep.set_xlabel('Epoch')
ax_ep.set_ylabel('Loss')
ax_ep.set_title('Loss per Epoch')
ax_ep.grid(True)
ax_ep.legend()

# display_id=True returns a handle so later updates redraw this same output
# instead of appending a new image.
ep_handle = display(fig_ep, display_id=True)


In [None]:
# ==== Live step-loss plot (สร้างครั้งเดียว) ====
import numpy as np, matplotlib.pyplot as plt
from IPython.display import display

class LiveLossPlot:
    """Per-step training-loss chart (raw curve + EMA) redrawn in one notebook output.

    Args:
        title: axes title.
        tail: y-limits are fitted to the last ``tail`` losses only, so large
            early losses do not squash the recent part of the curve.
        ema_alpha: weight of the newest value in the exponential moving average.
    """

    def __init__(self, title="Training Loss (live)", tail=500, ema_alpha=0.1):
        self.tail, self.ema_alpha = tail, ema_alpha
        self.fig, self.ax = plt.subplots(figsize=(7,4))
        (self.line,) = self.ax.plot([], [], label="loss/step")
        (self.ema_line,) = self.ax.plot([], [], ls="--", label=f"EMA α={ema_alpha}")
        self.ax.set_xlabel("Step"); self.ax.set_ylabel("Loss")
        self.ax.set_title(title); self.ax.grid(True); self.ax.legend()
        # Handle used by update() to redraw this same output.
        self.handle = display(self.fig, display_id=True)

    @staticmethod
    def _ema(values, alpha):
        """Exponential moving average of *values*, seeded with the first value."""
        smoothed = []
        for v in values:
            smoothed.append(v if not smoothed else alpha * v + (1 - alpha) * smoothed[-1])
        return smoothed

    def update(self, loss_hist):
        """Redraw from the full list of per-step losses (no-op when it is empty)."""
        if not loss_hist:
            return
        x = np.arange(1, len(loss_hist) + 1)
        self.line.set_data(x, loss_hist)
        self.ema_line.set_data(x, self._ema(loss_hist, self.ema_alpha))

        tail = loss_hist[-min(self.tail, len(loss_hist)):]
        y0, y1 = float(min(tail)), float(max(tail))
        if y1 == y0:
            y0, y1 = y0 - 1.0, y1 + 1.0
        m = 0.1 * (y1 - y0)
        self.ax.set_xlim(1, max(50, len(loss_hist) + 5))
        self.ax.set_ylim(y0 - m, y1 + m)
        self.fig.canvas.draw()
        # `except Exception` (not a bare except) so Ctrl-C / KeyboardInterrupt
        # during a redraw still stops training instead of being swallowed.
        try:
            self.handle.update(self.fig)
        except Exception:
            display(self.fig)

UPDATE_EVERY = 5  # redraw the live step-loss chart every N steps of an epoch
live_plot = LiveLossPlot()  # read directly as a global by train_one_epoch


In [None]:
# ==== MNIST-style Epoch Loss Plot (create once) ====
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display

# Per-epoch loss lists: create only if an earlier run has not made them.
if "train_epoch_loss" not in globals():
    train_epoch_loss = []
if "val_epoch_loss" not in globals():
    val_epoch_loss = []

# Figure with one (initially empty) line per split.
fig_ep, ax_ep = plt.subplots(figsize=(7,4))
(tr_line,) = ax_ep.plot([], [], '-o', label='train loss')
(va_line,) = ax_ep.plot([], [], '-s', label='val loss')
ax_ep.set(xlabel='Epoch', ylabel='Loss', title='Loss per Epoch')
ax_ep.grid(True)
ax_ep.legend()

# Keep a display handle so each epoch updates this output in place.
ep_handle = display(fig_ep, display_id=True)


In [None]:
def update_epoch_plot():
    """Redraw the per-epoch train/val loss figure in place.

    Reads the globals ``train_epoch_loss`` / ``val_epoch_loss`` and the figure
    objects ``fig_ep``, ``ax_ep``, ``tr_line``, ``va_line`` and ``ep_handle``
    created by the epoch-plot cell. Each curve gets its own x values (epochs
    1..len), so the plot still works if one list is ever longer than the other.
    """
    tr_line.set_data(np.arange(1, len(train_epoch_loss) + 1), train_epoch_loss)
    va_line.set_data(np.arange(1, len(val_epoch_loss) + 1), val_epoch_loss)

    # X axis: show at least 5 epochs, otherwise fit the longer series.
    n_epochs = max(len(train_epoch_loss), len(val_epoch_loss))
    ax_ep.set_xlim(1, max(5, n_epochs + 0.5))

    # Y axis: fit all values with a 10% margin; widen a flat range to +/-1.
    y_all = list(train_epoch_loss) + list(val_epoch_loss)
    if y_all:
        y_min, y_max = float(min(y_all)), float(max(y_all))
        if y_max == y_min:
            y_min, y_max = y_min - 1.0, y_max + 1.0
        margin = 0.1 * (y_max - y_min)
        ax_ep.set_ylim(y_min - margin, y_max + margin)

    # Redraw into the existing output; fall back to a new output without a handle.
    fig_ep.canvas.draw()
    try:
        ep_handle.update(fig_ep)  # requires display(fig_ep, display_id=True) earlier
    except Exception:
        display(fig_ep)


## 7) Train — วาดกราฟ Loss สดระหว่างเทรน

In [None]:
# ===== Drop-in patch: define missing pieces BEFORE the epoch loop =====
import numpy as np
import torch, time

# 1) Averager: running mean used for per-epoch train/val losses
class Averager:
    """Running (optionally weighted) mean of scalar values."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything accumulated so far."""
        self.total = 0.0
        self.count = 0

    def update(self, v, n: int = 1):
        """Add value *v* with weight *n* (e.g. a batch mean and its batch size)."""
        self.total = self.total + float(v) * n
        self.count = self.count + n

    @property
    def value(self):
        """Current mean; 0.0 before any update (denominator floored at 1)."""
        return self.total / max(1, self.count)

# Shared training state. Each name is created only if it does not exist yet,
# so re-running this cell keeps the history of an earlier run.
if "train_loss_hist" not in globals():
    train_loss_hist = Averager()
if "val_loss_hist" not in globals():
    val_loss_hist = Averager()
if "train_step_loss" not in globals():
    train_step_loss = []
if "train_epoch_loss" not in globals():
    train_epoch_loss = []
if "val_epoch_loss" not in globals():
    val_epoch_loss = []
if "global_step" not in globals():
    global_step = 0

# 2) Learning-rate helpers
def set_lr(opt, lr):
    """Set the learning rate of every parameter group in *opt* to *lr*."""
    for group in opt.param_groups:
        group.update(lr=lr)

def get_warmup_lr(base_lr, step, warmup_steps):
    """Return *base_lr* ramped linearly from 0 over *warmup_steps* steps.

    Non-positive *warmup_steps* means no warm-up: *base_lr* is returned as-is.
    """
    return base_lr if warmup_steps <= 0 else base_lr * min(1.0, step / float(warmup_steps))

# 3) Validation "pseudo-loss" for Faster R-CNN (train mode + no_grad)
@torch.no_grad()
def validate(loader, model):
    """Return the mean summed detection loss of *model* over *loader*.

    The model is put in ``train()`` mode, the mode in which ``train_one_epoch``
    calls it with targets and gets back a dict of losses. The decorator
    disables gradient tracking, so no weights are updated. The model is left
    in train mode afterwards. Resets and fills the global ``val_loss_hist``;
    tensors are moved to the global ``device``.

    NOTE(review): any layer whose train-mode forward updates internal state
    (e.g. non-frozen BatchNorm running stats) would still be updated here -
    confirm the backbone's normalisation layers are frozen.
    """
    model.train()
    val_loss_hist.reset()
    for batch_images, batch_targets in loader:
        batch_images = [image.to(device) for image in batch_images]
        batch_targets = [
            {key: value.to(device) for key, value in target.items()}
            for target in batch_targets
        ]
        loss_terms = model(batch_images, batch_targets)
        val_loss_hist.update(sum(loss_terms.values()).item())
    return val_loss_hist.value

# 4) Train one epoch (per-step metrics are logged to Comet)
def train_one_epoch(loader, model, epoch: int):
    """Train *model* for one pass over *loader* and return the mean step loss.

    Args:
        loader: yields ``(images, targets)`` batches as produced by
            ``collate_fn`` (a list of image tensors and a list of target dicts).
        model: detection model called as ``model(images, targets)`` and
            expected to return a dict of loss tensors.
        epoch: 1-based epoch number, used only in log messages.

    Side effects: advances the global ``global_step`` once per batch, resets
    and fills ``train_loss_hist``, appends to ``train_step_loss``, steps the
    global ``optimizer``, refreshes ``live_plot`` every ``UPDATE_EVERY`` steps
    and logs step loss / LR to the Comet ``experiment``.

    Learning rate: while ``global_step < WARMUP_STEPS`` each param group's LR is
    scaled by the linear warm-up factor for that single optimizer step and then
    restored. After warm-up the LR is never written here, so the epoch-level
    ``scheduler`` (MultiStepLR / cosine) controls it. Writing the constant
    ``LR`` on every step would silently undo every scheduler decay.
    """
    global global_step
    model.train()
    train_loss_hist.reset()

    num_batches = len(loader)
    print(f"[train] batches: {num_batches}, batch_size: {BATCH_SIZE}")

    t_epoch0 = time.time()

    for step, (images, targets) in enumerate(loader, start=1):
        t0 = time.time()

        images  = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Warm-up scales the scheduler's current LR for this step only.
        scheduled_lrs = [g["lr"] for g in optimizer.param_groups]
        warming_up = global_step < WARMUP_STEPS
        if warming_up:
            for g, base_lr in zip(optimizer.param_groups, scheduled_lrs):
                g["lr"] = get_warmup_lr(base_lr, global_step, WARMUP_STEPS)
        step_lr = optimizer.param_groups[0]["lr"]  # LR actually used for this step

        try:
            loss_dict = model(images, targets)
            losses = sum(loss_dict.values())
            optimizer.zero_grad()
            losses.backward()
            # Cap the global gradient norm at 10 for stability.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
            optimizer.step()
        finally:
            # Restore the scheduler's LR even if the step raised or was interrupted.
            if warming_up:
                for g, base_lr in zip(optimizer.param_groups, scheduled_lrs):
                    g["lr"] = base_lr

        # Statistics
        loss_val = float(losses.item())
        train_loss_hist.update(loss_val)
        train_step_loss.append(loss_val)
        global_step += 1

        # Live step-loss chart in Colab
        if live_plot is not None and step % UPDATE_EVERY == 0:
            live_plot.update(train_step_loss)

        # Comet: per-step metrics
        experiment.log_metric("train/step_loss", loss_val, step=global_step)
        experiment.log_metric("lr", step_lr, step=global_step)

        # Timing / console log on steps 1, 11, 21, ...
        t1 = time.time()
        step_time = t1 - t0
        avg_step_time = (t1 - t_epoch0) / step
        eta = avg_step_time * (num_batches - step)
        if step % 10 == 1:
            mem = (torch.cuda.memory_allocated() / 1e9) if torch.cuda.is_available() else 0.0
            print(f"[epoch {epoch} step {step}/{num_batches}] "
                  f"loss={loss_val:.3f} "
                  f"step_time={step_time:.2f}s avg={avg_step_time:.2f}s ETA={eta/60:.1f}m "
                  f"lr={step_lr:.6f} "
                  f"gpu_mem~{mem:.2f}GB")

    return train_loss_hist.value

# 5) Checkpoint helpers: SaveBestModel and save_model
class SaveBestModel:
    """Checkpoint the model whenever the validation loss reaches a new minimum.

    The checkpoint goes to ``OUTPUT_DIR / "best_model.pth"``; each new best
    overwrites the previous file. It is then uploaded to the Comet
    ``experiment`` as an asset. The upload is best-effort: a failure is
    reported but never stops training.
    """

    def __init__(self):
        # Lowest validation loss seen so far.
        self.best = float("inf")

    def __call__(self, current_val_loss, epoch, model, optimizer, scheduler):
        """Save a checkpoint if *current_val_loss* beats the best so far.

        Args:
            current_val_loss: validation loss of this epoch.
            epoch: 0-based epoch index (stored as-is; messages show epoch+1).
            model, optimizer, scheduler: their ``state_dict()`` is saved;
                *scheduler* may be ``None``.
        """
        # Written as `not <` so a NaN loss never counts as an improvement.
        if not current_val_loss < self.best:
            return
        self.best = current_val_loss
        path = OUTPUT_DIR / "best_model.pth"
        torch.save({
            "epoch": epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict() if scheduler else None,
            "best_val_loss": self.best,
        }, path)
        print(f">>> Saved best (val_loss={current_val_loss:.4f}) at epoch {epoch+1}")
        try:
            experiment.log_asset(str(path), step=epoch+1)
        except Exception as err:
            print(f"[comet] could not upload {path.name}: {err}")

def save_model(epoch, model, optimizer, scheduler):
    """Write a checkpoint for this epoch to ``OUTPUT_DIR / f"epoch_{epoch+1}.pth"``.

    *epoch* is the 0-based index; the file name uses epoch+1. The file is also
    uploaded to the Comet ``experiment`` on a best-effort basis: an upload
    failure is reported, not raised.
    """
    path = OUTPUT_DIR / f"epoch_{epoch+1}.pth"
    torch.save({
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict() if scheduler else None,
    }, path)
    try:
        experiment.log_asset(str(path), step=epoch+1)
    except Exception as err:
        print(f"[comet] could not upload {path.name}: {err}")

# ====== Main training loop ======
save_best_model = SaveBestModel()

NUM_EPOCHS = EPOCHS  # or set a fixed number
start_epoch = 0

try:
    for ep in range(start_epoch, NUM_EPOCHS):
        print(f"\nEPOCH {ep+1} of {NUM_EPOCHS}")

        t0 = time.time()
        train_mean = train_one_epoch(train_loader, model, epoch=ep+1)  # live_plot is read as a global
        val_mean   = validate(val_loader, model)
        elapsed = time.time() - t0

        train_epoch_loss.append(train_mean)
        val_epoch_loss.append(val_mean)
        update_epoch_plot()

        # Comet: per-epoch metrics
        experiment.log_metrics({
            "epoch/train_loss": train_mean,
            "epoch/val_loss": val_mean,
        }, step=ep+1, epoch=ep+1)

        # Current LR, read before the scheduler advances to the next epoch.
        try:
            lr_now = scheduler.get_last_lr() if scheduler else [g['lr'] for g in optimizer.param_groups]
        except Exception:
            lr_now = [g['lr'] for g in optimizer.param_groups]
        print(f"Epoch #{ep+1} train loss: {train_mean:.3f} | val loss: {val_mean:.3f} | lr: {lr_now} | time: {elapsed:.1f}s")

        if scheduler is not None:
            scheduler.step()

        save_best_model(val_mean, ep, model, optimizer, scheduler)
        save_model(ep, model, optimizer, scheduler)
finally:
    # Close the Comet run even when training crashes or is interrupted (Ctrl-C).
    experiment.end()
