In [1]:
import os
from typing import List, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from PIL import Image
import timm  # pip install timm if it is not already installed


# ============= CONFIG =============
# Device string used for the model and for every batch: "cuda" when PyTorch
# reports an available GPU, otherwise "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DATA_ROOT = "/kaggle/input/dermnet"   # adjust this path to match your environment
BATCH_SIZE = 32    # samples per batch for both data loaders
IMAGE_SIZE = 224   # images are resized to IMAGE_SIZE x IMAGE_SIZE
EPOCHS = 5         # number of passes over the training set
# File the best-scoring checkpoint is written to (and read back from).
BEST_MODEL_PATH = "best_medagen_resnet18_vits_cbam.pth"

# Class folder names to keep from the dataset. The position of a name in this
# list is the label index the model is trained to predict for that class.
SELECTED_CLASSES: List[str] = [
    "Acne and Rosacea Photos",
    "Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions",
    "Atopic Dermatitis Photos",
    "Cellulitis Impetigo and other Bacterial Infections",
    "Eczema Photos",
    "Hair Loss Photos Alopecia and other Hair Diseases",
    "Melanoma Skin Cancer Nevi and Moles",
    "Nail Fungus and other Nail Disease",
    "Poison Ivy Photos and other Contact Dermatitis",
    "Psoriasis pictures Lichen Planus and related diseases",
    "Scabies Lyme Disease and other Infestations and Bites",
    "Seborrheic Keratoses and other Benign Tumors",
    "Tinea Ringworm Candidiasis and other Fungal Infections",
    "Warts Molluscum and other Viral Infections",
]


# ============= TRANSFORMS =============
# Training pipeline: stretch to a fixed square size, apply light augmentation
# (random horizontal flip, random rotation within +/- 8 degrees), convert to a
# tensor and normalise each channel with the commonly used ImageNet mean/std.
train_tfms = transforms.Compose([
    transforms.Resize(size=(IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=8),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Evaluation pipeline: same resize and normalisation, but no random
# augmentation so results are repeatable.
test_tfms = transforms.Compose([
    transforms.Resize(size=(IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


def load_filtered_dataset(split: str, transform):
    """Load one split as an ``ImageFolder`` restricted to ``SELECTED_CLASSES``.

    The dataset is read from ``DATA_ROOT/<split>``. Samples whose class is not
    listed in ``SELECTED_CLASSES`` are dropped, and the remaining labels are
    renumbered so that a label equals the class's position in
    ``SELECTED_CLASSES``.

    Args:
        split: Sub-directory of ``DATA_ROOT`` to read (e.g. "train", "test").
        transform: Transform applied by the dataset to every loaded image.

    Returns:
        The ``ImageFolder`` instance with ``samples``, ``imgs``, ``targets``,
        ``classes`` and ``class_to_idx`` rewritten to describe only the
        selected classes.

    Raises:
        ValueError: If a name in ``SELECTED_CLASSES`` has no folder in the
            split directory.
    """
    root = os.path.join(DATA_ROOT, split)
    ds = datasets.ImageFolder(root=root, transform=transform)

    orig_class_to_idx = ds.class_to_idx

    for cls in SELECTED_CLASSES:
        if cls not in orig_class_to_idx:
            raise ValueError(f"Kh√¥ng t√¨m th·∫•y class: {cls} trong {root}")

    # Original ImageFolder index -> position in SELECTED_CLASSES. Its keys
    # double as the set of indices that are allowed to stay.
    orig_to_new = {orig_class_to_idx[c]: i for i, c in enumerate(SELECTED_CLASSES)}

    filtered_samples = [
        (path, orig_to_new[target])
        for path, target in ds.samples
        if target in orig_to_new
    ]

    ds.samples = filtered_samples
    # ImageFolder exposes `imgs` as an alias of `samples`; rebind it too so it
    # does not keep pointing at the unfiltered list with the old labels.
    ds.imgs = filtered_samples
    ds.targets = [t for _, t in filtered_samples]
    # Store a copy so mutating ds.classes can never alter the global list.
    ds.classes = list(SELECTED_CLASSES)
    ds.class_to_idx = {cls: i for i, cls in enumerate(SELECTED_CLASSES)}
    return ds


# ============= CBAM MODULE =============
class CBAM(nn.Module):
    """Channel attention followed by spatial attention on a 4-D feature map.

    The channel branch pools the map globally (average and max), sends both
    pooled vectors through one shared two-layer MLP, sums the two results and
    turns them into a per-channel sigmoid gate. The spatial branch then stacks
    the channel-wise mean and max of the gated map and convolves them into a
    single-channel sigmoid gate over spatial positions.

    Args:
        channels: Number of channels of the input feature map.
        reduction: Bottleneck ratio of the channel MLP. The hidden width is
            ``channels // reduction`` but never less than 1.
        spatial_kernel: Kernel size of the spatial-attention convolution.
            Padding is ``spatial_kernel // 2``, which preserves the spatial
            size only for odd kernel sizes.
    """

    def __init__(self, channels: int, reduction: int = 16, spatial_kernel: int = 7):
        super().__init__()
        # With channels < reduction the integer division gives 0, which would
        # build a zero-width bottleneck: the MLP would output all zeros and the
        # channel gate would be a constant 0.5. Keep at least one hidden unit.
        hidden = max(channels // reduction, 1)

        # Channel attention
        self.mlp = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
        )
        # Spatial attention
        self.spatial = nn.Conv2d(2, 1, kernel_size=spatial_kernel,
                                 padding=spatial_kernel // 2, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both attention gates to ``x`` of shape ``[B, C, H, W]``.

        Returns a tensor of the same shape as ``x``.
        """
        b, c, h, w = x.size()

        # ----- Channel attention -----
        avg_pool = F.adaptive_avg_pool2d(x, 1).view(b, c)
        max_pool = F.adaptive_max_pool2d(x, 1).view(b, c)
        # The same MLP weights score both pooled descriptors.
        ch_att = torch.sigmoid(self.mlp(avg_pool) + self.mlp(max_pool)).view(b, c, 1, 1)
        x = x * ch_att

        # ----- Spatial attention -----
        avg = torch.mean(x, dim=1, keepdim=True)
        mx, _ = torch.max(x, dim=1, keepdim=True)
        s = torch.cat([avg, mx], dim=1)   # [B, 2, H, W]
        sp_att = torch.sigmoid(self.spatial(s))
        x = x * sp_att
        return x


# ============= FUSION MODEL =============
class ResNet18_ViTS_CBAM(nn.Module):
    """Classifier that fuses ResNet18 and ViT-Small features through CBAM.

    The same image batch goes through both backbones. Their feature vectors
    are concatenated, reshaped to ``[B, C, 1, 1]``, re-weighted by a ``CBAM``
    block and classified by a two-layer MLP head.

    Args:
        num_classes: Number of output logits.
        pretrained: When True (the default, matching the previous behaviour)
            both backbones start from their pretrained weights. Pass False to
            build the architecture only, e.g. when a checkpoint is loaded
            right afterwards and no weight download is wanted.
    """

    def __init__(self, num_classes: int, pretrained: bool = True):
        super().__init__()

        # ResNet18 backbone: everything except the final fc layer.
        resnet_weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        rn = models.resnet18(weights=resnet_weights)
        # Read the feature width from the layer being dropped instead of
        # hard-coding it (512 for ResNet18).
        res_dim = rn.fc.in_features
        self.resnet_backbone = nn.Sequential(*list(rn.children())[:-1])  # [B,res_dim,1,1]

        # ViT small patch16 224 (timm)
        self.vit = timm.create_model("vit_small_patch16_224", pretrained=pretrained)
        vit_dim = self.vit.embed_dim
        # Drop the classification head so the ViT returns embeddings only.
        if hasattr(self.vit, "head"):
            self.vit.reset_classifier(0)

        fused_dim = res_dim + vit_dim
        self.cbam = CBAM(fused_dim, reduction=16, spatial_kernel=3)
        self.classifier = nn.Sequential(
            nn.Linear(fused_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape ``[B, num_classes]`` for image batch ``x``."""
        # ResNet branch
        r = self.resnet_backbone(x)              # [B,res_dim,1,1]
        r = r.flatten(1)                         # [B,res_dim]

        # ViT branch
        v = self.vit(x)                          # expected [B,vit_dim] with the head removed

        feat = torch.cat([r, v], dim=1)          # [B, C]
        # CBAM expects a 4-D map, so present the vector as a 1x1 feature map.
        feat_4d = feat.unsqueeze(-1).unsqueeze(-1)  # [B,C,1,1]
        feat_4d = self.cbam(feat_4d)             # CBAM attention
        feat = feat_4d.flatten(1)
        return self.classifier(feat)


# ============= LOAD DATA =============
# Build both splits with their matching transform pipeline.
train_dataset = load_filtered_dataset(split="train", transform=train_tfms)
test_dataset = load_filtered_dataset(split="test", transform=test_tfms)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
)

print("S·ªë l·ªõp:", len(SELECTED_CLASSES))
print("S·ªë ·∫£nh train:", len(train_dataset), "| S·ªë ·∫£nh test:", len(test_dataset))


# ============= INIT MODEL =============
# One output logit per selected class; the model is moved to DEVICE before the
# optimizer is created from its parameters.
num_classes = len(SELECTED_CLASSES)
model = ResNet18_ViTS_CBAM(num_classes=num_classes)
model = model.to(DEVICE)
print('params:', sum(param.numel() for param in model.parameters()))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Best test accuracy seen so far; the training loop updates it.
best_test_acc = 0.0


# ============= TRAIN LOOP =============
# Each epoch: one optimisation pass over train_loader, one scoring pass over
# test_loader, then a checkpoint whenever test accuracy beats the best so far.
# NOTE(review): the checkpoint is selected on the test split, so the best
# TestAcc is an optimistic estimate; consider a separate validation split.
for epoch in range(1, EPOCHS + 1):
    # ---- Train ----
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for imgs, labels in train_loader:
        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so that the epoch figure
        # stays a per-sample average even if the last batch is smaller.
        running_loss += loss.item() * imgs.size(0)
        preds = outputs.max(dim=1).indices
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    train_loss = running_loss / total
    train_acc = correct / total

    # ---- Eval ----
    model.eval()
    correct_test = 0
    total_test = 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs = imgs.to(DEVICE)
            labels = labels.to(DEVICE)
            outputs = model(imgs)
            preds = outputs.max(dim=1).indices
            correct_test += (preds == labels).sum().item()
            total_test += labels.size(0)

    test_acc = correct_test / total_test

    epoch_summary = (
        f"Epoch {epoch}/{EPOCHS} | "
        f"TrainLoss {train_loss:.4f} | TrainAcc {train_acc:.4f} | TestAcc {test_acc:.4f}"
    )
    print(epoch_summary)

    # ---- Save best ----
    # Strict ">" means a tie with the current best does not overwrite the file.
    if test_acc > best_test_acc:
        best_test_acc = test_acc
        checkpoint = {"model_state": model.state_dict(), "classes": SELECTED_CLASSES}
        torch.save(checkpoint, BEST_MODEL_PATH)
        print(f"üî• L∆∞u best model (TestAcc={test_acc:.4f}) ‚Üí {BEST_MODEL_PATH}")


# ============= SINGLE-IMAGE INFERENCE =============
# Loads the best checkpoint and prints the top predictions for one image.
IMG_PATH = "example.jpg"  # point this at your own image to try the model

# isfile() is already False for a missing path, so no separate exists() check.
if os.path.isfile(IMG_PATH):
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    ckpt = torch.load(BEST_MODEL_PATH, map_location=DEVICE)
    model.load_state_dict(ckpt["model_state"])
    model.eval()
    # Prefer the class names stored with the weights so the printed labels
    # always match the checkpoint; fall back to the configured list.
    class_names = ckpt.get("classes", SELECTED_CLASSES)

    # The context manager closes the file handle; convert() returns an
    # independent in-memory image that stays usable afterwards.
    with Image.open(IMG_PATH) as raw_img:
        img = raw_img.convert("RGB")
    img_t = test_tfms(img).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logits = model(img_t)
        probs = torch.softmax(logits, dim=1)[0]
        # topk raises if k exceeds the number of classes, so cap it.
        topk = torch.topk(probs, k=min(3, probs.numel()))

    print(f"\n=== K·∫øt qu·∫£ inference cho: {IMG_PATH} ===")
    ranked = zip(topk.indices.tolist(), topk.values.tolist())
    for rank, (idx, prob) in enumerate(ranked, start=1):
        cls_name = class_names[idx]
        p = prob * 100
        print(f"{rank}. {cls_name} ‚Äî {p:.2f}%")
else:
    print("‚ö†Ô∏è Kh√¥ng t√¨m th·∫•y IMG_PATH, b·ªè qua inference.")




S·ªë l·ªõp: 14
S·ªë ·∫£nh train: 11596 | S·ªë ·∫£nh test: 3007


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 44.7M/44.7M [00:00<00:00, 204MB/s]


model.safetensors:   0%|          | 0.00/88.2M [00:00<?, ?B/s]

params: 33408992
Epoch 1/5 | TrainLoss 1.6815 | TrainAcc 0.4588 | TestAcc 0.5647
üî• L∆∞u best model (TestAcc=0.5647) ‚Üí best_medagen_resnet18_vits_cbam.pth
Epoch 2/5 | TrainLoss 1.1282 | TrainAcc 0.6319 | TestAcc 0.6498
üî• L∆∞u best model (TestAcc=0.6498) ‚Üí best_medagen_resnet18_vits_cbam.pth
Epoch 3/5 | TrainLoss 0.8151 | TrainAcc 0.7352 | TestAcc 0.6698
üî• L∆∞u best model (TestAcc=0.6698) ‚Üí best_medagen_resnet18_vits_cbam.pth
Epoch 4/5 | TrainLoss 0.5581 | TrainAcc 0.8126 | TestAcc 0.6874
üî• L∆∞u best model (TestAcc=0.6874) ‚Üí best_medagen_resnet18_vits_cbam.pth
Epoch 5/5 | TrainLoss 0.3884 | TrainAcc 0.8696 | TestAcc 0.6997
üî• L∆∞u best model (TestAcc=0.6997) ‚Üí best_medagen_resnet18_vits_cbam.pth
‚ö†Ô∏è Kh√¥ng t√¨m th·∫•y IMG_PATH, b·ªè qua inference.
