1. Imports — ทุกอย่างที่ต้องใช้ (everything the notebook needs)

In [None]:
import os
import time
import copy

import numpy as np
import pandas as pd
import cv2
from PIL import Image

import kagglehub   # used to download the dataset from Kaggle

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms

from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score,
    recall_score,
    roc_auc_score,
)
from sklearn.preprocessing import label_binarize

import matplotlib.pyplot as plt

# Reproducibility: seed numpy and torch. torch's global RNG drives weight
# initialisation, DataLoader shuffling and the torchvision random augmentations.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices (CPU and CUDA)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

2. Load Dataset (Kaggle BUSI)

In [None]:
# Download the BUSI dataset with kagglehub (imported in the first cell) and
# locate the folder that holds one sub-directory per class.
dpath = kagglehub.dataset_download("aryashah2k/breast-ultrasound-images-dataset")
print("Dataset path:", dpath)

folder = os.path.join(dpath, "Dataset_BUSI_with_GT")
# Sorted so the printed listing is the same on every file system.
print("Classes:", sorted(os.listdir(folder)))

3. Build DataFrame

In [None]:
# Index every ultrasound scan under <folder>/<class>/ and label it with the
# class's position in class_names.
# The BUSI "_with_GT" folders also hold a ground-truth segmentation mask next to
# each scan ("<image>_mask.png", occasionally "<image>_mask_1.png"). Masks must be
# skipped; otherwise they are trained on and evaluated as if they were scans.
class_names = ["benign", "malignant", "normal"]
IMAGE_EXTS = (".png", ".jpg", ".jpeg")


def _is_scan(fname):
    """Return True for image files that are not ground-truth mask files."""
    lower = fname.lower()
    return lower.endswith(IMAGE_EXTS) and "_mask" not in lower


data = []
for idx, cls in enumerate(class_names):
    cdir = os.path.join(folder, cls)
    # os.listdir order is arbitrary. Sorting makes row order, and therefore the
    # seeded train/val/test split, identical on every machine.
    for fname in sorted(os.listdir(cdir)):
        if _is_scan(fname):
            data.append([os.path.join(cdir, fname), idx])

df = pd.DataFrame(data, columns=["path", "label"])
assert len(df) > 0, f"No images found under {folder}"
print("Total images:", len(df))
print(df["label"].value_counts().sort_index())


4. Train/Val/Test split (stratified, ≈70/15/15 of all images)

In [None]:
# Stratified ~70/15/15 split (train_test_split is imported in the first cell).
# Carve off the test set first, then take a validation set of the same absolute
# size from the remainder. This replaces the hand-computed 0.1765 ≈ 0.15 / 0.85.
TEST_FRAC = 0.15
VAL_FRAC = 0.15
SPLIT_SEED = 42

train_df, test_df = train_test_split(
    df, test_size=TEST_FRAC, stratify=df["label"], random_state=SPLIT_SEED
)
n_val = round(VAL_FRAC * len(df))  # absolute count, so val is 15% of *all* images
train_df, val_df = train_test_split(
    train_df, test_size=n_val, stratify=train_df["label"], random_state=SPLIT_SEED
)

# The three splits must partition df exactly, with no shared rows.
assert len(train_df) + len(val_df) + len(test_df) == len(df)
for split_a, split_b in ((train_df, val_df), (train_df, test_df), (val_df, test_df)):
    assert split_a.index.intersection(split_b.index).empty, "train/val/test splits overlap"

print("train/val/test sizes:", len(train_df), len(val_df), len(test_df))

5. Dataset Class + CLAHE + Augmentation

In [None]:
# Contrast-limited adaptive histogram equalisation, applied by BUSIDataset to
# every grayscale scan before resizing.
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))

# One mean/std pair (the ImageNet R-channel statistics) replicated across the
# three identical grayscale channels.
GRAY_MEAN = [0.485] * 3
GRAY_STD = [0.229] * 3

# Training pipeline: convert to tensor, then apply light geometric and
# photometric augmentation on the tensor before normalising.
train_tf = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.15, contrast=0.15),
        transforms.Normalize(GRAY_MEAN, GRAY_STD),
    ]
)

# Validation/test pipeline: deterministic conversion and normalisation only.
val_tf = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(GRAY_MEAN, GRAY_STD),
    ]
)


class BUSIDataset(Dataset):
    """BUSI ultrasound dataset: CLAHE-enhanced grayscale scans replicated to 3 channels.

    Args:
        df: DataFrame with a ``path`` column (image file path) and an integer
            ``label`` column. Its index is reset so rows can be looked up by position.
        transform: callable applied to the PIL image (e.g. ``train_tf`` / ``val_tf``).
        image_size: side length in pixels that each scan is resized to (default 224).

    Each item is a dict ``{"image": transform(img), "label": int64 scalar tensor}``.
    Uses the module-level ``clahe`` object for contrast enhancement.
    """

    def __init__(self, df, transform, image_size=224):
        self.df = df.reset_index(drop=True)
        self.transform = transform
        self.image_size = image_size

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        path = self.df.loc[idx, "path"]
        label = self.df.loc[idx, "label"]

        gray = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if gray is None:
            # cv2.imread reports missing/unreadable files by returning None rather
            # than raising; fail here with the offending path.
            raise FileNotFoundError(f"Could not read image: {path}")
        gray = clahe.apply(gray)
        gray = cv2.resize(gray, (self.image_size, self.image_size))

        # Replicate the single channel so a 3-channel pretrained backbone can consume it.
        rgb = np.stack([gray, gray, gray], axis=-1).astype(np.uint8, copy=False)

        img = transforms.ToPILImage()(rgb)
        img = self.transform(img)

        return {"image": img, "label": torch.tensor(label).long()}

6. DataLoader

In [None]:
# One Dataset per split: augmentation for training, deterministic preprocessing
# for validation and test.
BATCH_SIZE = 32

train_ds = BUSIDataset(train_df, train_tf)
val_ds = BUSIDataset(val_df, val_tf)
test_ds = BUSIDataset(test_df, val_tf)

# Only the training loader shuffles, so evaluation order is stable.
train_loader, val_loader, test_loader = (
    DataLoader(split_ds, batch_size=BATCH_SIZE, shuffle=do_shuffle)
    for split_ds, do_shuffle in ((train_ds, True), (val_ds, False), (test_ds, False))
)

# train_model expects loaders and sample counts keyed by phase name.
dataloaders = {"train": train_loader, "val": val_loader}
dataset_sizes = {split: len(loader.dataset) for split, loader in dataloaders.items()}

7. Focal Loss

In [None]:
class FocalLoss(nn.Module):
    """Multi-class focal loss computed from raw logits.

    Per sample: ``alpha[y] * (1 - p_t) ** gamma * CE``, where ``CE`` is the
    cross-entropy and ``p_t = exp(-CE)`` is the predicted probability of the true class.

    Args:
        alpha: optional 1-D tensor of per-class weights, indexed by target label.
        gamma: focusing parameter; ``0`` reduces to (weighted) cross-entropy.
        reduction: ``"mean"``, ``"sum"`` or ``"none"`` (per-sample losses).

    Raises:
        ValueError: if ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(self, alpha=None, gamma=2.0, reduction="mean"):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        ce = F.cross_entropy(logits, targets, reduction="none")
        pt = torch.exp(-ce)  # probability assigned to the true class
        focal = (1 - pt) ** self.gamma * ce

        if self.alpha is not None:
            # alpha is not a registered buffer, so move it to the logits' device here.
            focal = self.alpha.to(logits.device)[targets] * focal

        if self.reduction == "mean":
            return focal.mean()
        if self.reduction == "sum":
            return focal.sum()
        return focal

In [None]:
# compute class weight
labels_np = train_df["label"].values
class_weights_np = compute_class_weight(
    class_weight="balanced",
    classes=np.unique(labels_np),
    y=labels_np
)

# boost malignant a bit
class_weights_np[1] *= 1.2

alpha_tensor = torch.tensor(class_weights_np, dtype=torch.float32).to(device)
criterion = FocalLoss(alpha=alpha_tensor, gamma=2.0)
print("Focal alpha:", alpha_tensor)

8. Hybrid CNN + Transformer Model

In [None]:
class CNNTransformerHybrid(nn.Module):
    """ResNet feature extractor followed by a Transformer encoder over spatial tokens.

    The ResNet stem and layer1-layer4 produce a (B, C, H, W) feature map. Each of
    the H*W positions becomes a C-dimensional token, a learnable [CLS] token is
    prepended, and learnable positional embeddings are added. The encoder output
    at the [CLS] position is classified by an MLP head.

    Args:
        num_classes: number of output logits.
        backbone: ``"resnet18"`` (C=512); any other value selects ResNet-50 (C=2048).
        num_layers, nhead, dim_feedforward, dropout: TransformerEncoder settings.
        max_tokens: size of the positional-embedding table, i.e. the maximum number
            of tokens (1 + H*W). The default 50 fits the 7x7 map a ResNet produces
            from a 224x224 input (total stride 32).
    """

    def __init__(self, num_classes=3, backbone="resnet18",
                 num_layers=2, nhead=8, dim_feedforward=1024, dropout=0.1,
                 max_tokens=50):
        super().__init__()

        if backbone == "resnet18":
            resnet = models.resnet18(pretrained=True)
            fdim = 512
        else:
            resnet = models.resnet50(pretrained=True)
            fdim = 2048

        # Reuse the pretrained ResNet up to layer4; its avgpool/fc are dropped.
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        self.fdim = fdim
        self.cls_token = nn.Parameter(torch.randn(1, 1, fdim))
        self.pos_embed = nn.Parameter(torch.randn(1, max_tokens, fdim))

        enc_layer = nn.TransformerEncoderLayer(
            d_model=fdim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(enc_layer, num_layers=num_layers)

        self.head = nn.Sequential(
            nn.LayerNorm(fdim),
            nn.Linear(fdim, fdim // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(fdim // 2, num_classes),
        )

    @staticmethod
    def _to_tokens(feat, cls_token, pos_embed):
        """Turn a (B, C, H, W) map into (B, 1 + H*W, C) tokens with [CLS] and positions added.

        Raises:
            ValueError: if 1 + H*W exceeds the positional-embedding table length.
        """
        batch = feat.size(0)
        tokens = feat.flatten(2).transpose(1, 2)                 # (B, H*W, C)
        tokens = torch.cat([cls_token.expand(batch, -1, -1), tokens], dim=1)
        n_tokens = tokens.size(1)
        if n_tokens > pos_embed.size(1):
            raise ValueError(
                f"{n_tokens} tokens exceed the positional-embedding table "
                f"({pos_embed.size(1)}); increase max_tokens or use a smaller input."
            )
        return tokens + pos_embed[:, :n_tokens]

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        tokens = self._to_tokens(x, self.cls_token, self.pos_embed)
        tokens = self.transformer(tokens)
        return self.head(tokens[:, 0])  # classify from the [CLS] position

In [None]:
# Hybrid model with a ResNet-18 backbone and 3 output classes, moved to the
# selected device.
model_hybrid = CNNTransformerHybrid(backbone="resnet18", num_classes=3)
model_hybrid = model_hybrid.to(device)

print(model_hybrid)

9. Two-Phase Fine-Tuning

In [None]:
# Phase 1: Freeze CNN backbone
for name, param in model_hybrid.named_parameters():
    if name.startswith(("conv1","bn1","layer1","layer2","layer3","layer4")):
        param.requires_grad = False
    else:
        param.requires_grad = True

optimizer1 = optim.Adam(
    filter(lambda p: p.requires_grad, model_hybrid.parameters()),
    lr=1e-3
)

# Training loop function
def _run_epoch_phase(model, criterion, optimizer, loader, size, device, train):
    """Run one pass over ``loader`` and return ``(mean loss, accuracy)``.

    With ``train=True`` the model is in training mode and the optimizer steps
    after each batch. Otherwise the model is in eval mode with gradients disabled.
    """
    model.train(train)
    running_loss = 0.0
    running_corrects = 0

    for batch in loader:
        imgs = batch["image"].to(device)
        labels = batch["label"].to(device)

        optimizer.zero_grad()
        with torch.set_grad_enabled(train):
            out = model(imgs)
            loss = criterion(out, labels)
            if train:
                loss.backward()
                optimizer.step()

        # Weight by batch size so the epoch loss is a per-sample mean.
        running_loss += loss.item() * imgs.size(0)
        running_corrects += (out.argmax(1) == labels).sum().item()

    return running_loss / size, running_corrects / size


def train_model(model, criterion, optimizer, dataloaders, sizes,
                num_epochs=10, scheduler=None, phase_name="Phase", device=None):
    """Train ``model`` and restore the weights with the lowest validation loss.

    Args:
        model: network to train; updated in place and also returned.
        criterion: loss called as ``criterion(logits, labels)``.
        optimizer: optimizer over the parameters being trained.
        dataloaders: ``{"train": loader, "val": loader}`` yielding dicts with
            ``"image"`` and ``"label"`` tensors.
        sizes: ``{"train": n, "val": n}`` sample counts used to average loss/accuracy.
        num_epochs: number of train+val epochs.
        scheduler: optional scheduler stepped with the validation loss after each
            epoch (e.g. ``ReduceLROnPlateau``).
        phase_name: label used in the progress printout.
        device: device that batches are moved to; defaults to the device of the
            model's first parameter (CPU if it has none).

    Returns:
        ``model`` with the best-validation-loss weights loaded. If no epoch
        improved on the baseline (e.g. only NaN losses), these are the starting weights.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # state_dict() returns references to the live tensors, so snapshots must be
    # deep-copied; otherwise later optimizer steps overwrite the "best" weights.
    best_w = copy.deepcopy(model.state_dict())
    best_loss = float("inf")

    for epoch in range(num_epochs):
        print(f"\n{phase_name} Epoch {epoch+1}/{num_epochs}")

        for phase in ["train", "val"]:
            epoch_loss, epoch_acc = _run_epoch_phase(
                model, criterion, optimizer, dataloaders[phase], sizes[phase],
                device, train=(phase == "train"),
            )
            print(f"{phase} loss={epoch_loss:.4f} acc={epoch_acc:.4f}")

            if phase == "val":
                if epoch_loss < best_loss:
                    best_loss = epoch_loss
                    best_w = copy.deepcopy(model.state_dict())
                if scheduler is not None:
                    scheduler.step(epoch_loss)

    model.load_state_dict(best_w)
    return model

# Train Phase 1: 8 epochs at lr=1e-3, backbone frozen, no LR scheduler.
model_hybrid = train_model(
    model=model_hybrid,
    criterion=criterion,
    optimizer=optimizer1,
    dataloaders=dataloaders,
    sizes=dataset_sizes,
    num_epochs=8,
    phase_name="Hybrid Phase 1",
)

Phase 2: Unfreeze layer4 + Transformer + Head

In [None]:
# Phase 2: additionally fine-tune layer4; the stem and layer1-layer3 stay frozen.
PHASE2_TRAINABLE = ("layer4", "transformer", "head", "cls_token", "pos_embed")
for pname, p in model_hybrid.named_parameters():
    p.requires_grad = pname.startswith(PHASE2_TRAINABLE)

optimizer2 = optim.Adam(
    [p for p in model_hybrid.parameters() if p.requires_grad],
    lr=1e-4,
)

# Halve the LR when validation loss has not improved for 2 epochs.
scheduler2 = ReduceLROnPlateau(optimizer2, mode="min", factor=0.5, patience=2)

model_hybrid = train_model(
    model=model_hybrid,
    criterion=criterion,
    optimizer=optimizer2,
    dataloaders=dataloaders,
    sizes=dataset_sizes,
    num_epochs=12,
    scheduler=scheduler2,
    phase_name="Hybrid Phase 2",
)

best_model_hybrid = model_hybrid

10. Evaluate on Test Set

In [None]:
# Collect softmax probabilities and true labels for every test image.
best_model_hybrid.eval()
probs = []
labels = []

with torch.no_grad():
    for batch in test_loader:
        logits = best_model_hybrid(batch["image"].to(device))
        probs.append(torch.softmax(logits, dim=1).cpu().numpy())
        labels.append(batch["label"].numpy())

y_pred_proba = np.concatenate(probs)
y_test = np.concatenate(labels)
y_pred = np.argmax(y_pred_proba, axis=1)
class_ids = list(range(len(class_names)))

print("Accuracy:", accuracy_score(y_test, y_pred))
print("Macro F1:", f1_score(y_test, y_pred, average="macro"))
print("Macro Recall:", recall_score(y_test, y_pred, average="macro"))
# Threshold-free metric computed from the full probability vectors (one-vs-rest).
print("Macro OvR ROC AUC:",
      roc_auc_score(y_test, y_pred_proba, multi_class="ovr", average="macro", labels=class_ids))
print(classification_report(y_test, y_pred, labels=class_ids, target_names=class_names))
# Fixed labels keep the matrix 3x3 even if a class is never predicted.
print(confusion_matrix(y_test, y_pred, labels=class_ids))

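11. Threshold tuning for the malignant class (note: choosing the threshold on the test set makes the test metrics optimistic — select it on the validation set, then report test results once)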

In [None]:
def eval_threshold(y_pred_proba, y_true, th):
    yp=[]
    for p in y_pred_proba:
        if p[1] >= th:
            yp.append(1)
        else:
            yp.append(0 if p[0]>=p[2] else 2)

    print("\n=== T =",th,"===")
    print("Acc:", accuracy_score(y_true, yp))
    print("Macro F1:", f1_score(y_true, yp, average="macro"))
    print("Macro Recall:", recall_score(y_true, yp, average="macro"))
    print(confusion_matrix(y_true, yp))

# Sweep the malignant-probability threshold over the test-set probabilities.
# NOTE(review): tuning on the test set biases these numbers; prefer validation data.
for threshold in (0.30, 0.35, 0.40, 0.45, 0.50):
    eval_threshold(y_pred_proba, y_test, threshold)