In [1]:
import os
from typing import List, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from PIL import Image
import timm  # pip install timm n·∫øu ch∆∞a c√≥


# ============= CONFIG =============
# Prefer the GPU whenever torch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Root directory holding the ImageFolder-style "train/" and "test/" splits;
# edit this to match where the dataset lives on your machine.
DATA_ROOT = "/kaggle/input/dermnet"

# Training hyper-parameters.
BATCH_SIZE = 32
IMAGE_SIZE = 224  # square side length (pixels) of the images fed to the model
EPOCHS = 50

# File that receives the best-scoring weights.
# NOTE(review): the name mentions resnet18/vits, but the model defined in this
# notebook is a Swin-Tiny + ConvNeXt-Tiny hybrid; kept as-is so existing
# checkpoint files still resolve.
BEST_MODEL_PATH = "best_medagen_resnet18_vits_cbam.pth"

# Class-folder names to train on. The position of a name in this list is the
# label index used by load_filtered_dataset.
SELECTED_CLASSES: List[str] = [
    "Acne and Rosacea Photos",
    "Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions",
    "Atopic Dermatitis Photos",
    "Bullous Disease Photos",
    "Cellulitis Impetigo and other Bacterial Infections",
    "Eczema Photos",
    "Exanthems and Drug Eruptions",
    "Hair Loss Photos Alopecia and other Hair Diseases",
    "Herpes HPV and other STDs Photos",
    "Light Diseases and Disorders of Pigmentation",
    "Lupus and other Connective Tissue diseases",
    "Melanoma Skin Cancer Nevi and Moles",
    "Nail Fungus and other Nail Disease",
    "Poison Ivy Photos and other Contact Dermatitis",
    "Psoriasis pictures Lichen Planus and related diseases",
    "Scabies Lyme Disease and other Infestations and Bites",
    "Seborrheic Keratoses and other Benign Tumors",
    "Systemic Disease",
    "Tinea Ringworm Candidiasis and other Fungal Infections",
    "Urticaria Hives",
    "Vascular Tumors",
    "Vasculitis Photos",
    "Warts Molluscum and other Viral Infections",
]


# Per-channel ImageNet statistics used to normalise every image tensor.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: light random crop / flips / affine / colour jitter,
# then conversion to a normalised float tensor.
train_tfms = transforms.Compose(
    [
        transforms.RandomResizedCrop(
            IMAGE_SIZE, scale=(0.85, 1.0), ratio=(0.9, 1.1)
        ),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomAffine(degrees=15, translate=(0.05, 0.05)),
        # hue=0 gives a zero-width hue range, so only brightness, contrast and
        # saturation are jittered.
        transforms.ColorJitter(
            brightness=0.1, contrast=0.1, saturation=0.1, hue=0
        ),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)

# Evaluation pipeline: deterministic resize to IMAGE_SIZE x IMAGE_SIZE,
# no augmentation, same normalisation as training.
test_tfms = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)


def load_filtered_dataset(split: str, transform):
    """Build an ImageFolder for ``DATA_ROOT/<split>`` restricted to SELECTED_CLASSES.

    Samples whose class folder is not in SELECTED_CLASSES are dropped, and the
    surviving labels are re-indexed so that label ``i`` means
    ``SELECTED_CLASSES[i]`` (list order, not ImageFolder's own folder order).

    Args:
        split: Sub-directory of DATA_ROOT to load (e.g. "train" or "test").
        transform: Transform applied by the dataset to each loaded image.

    Returns:
        The ImageFolder instance with ``samples``, ``imgs``, ``targets``,
        ``classes`` and ``class_to_idx`` rewritten for the selected classes.

    Raises:
        ValueError: if a name in SELECTED_CLASSES has no matching folder.
    """
    root = os.path.join(DATA_ROOT, split)
    ds = datasets.ImageFolder(root=root, transform=transform)
    orig_class_to_idx = ds.class_to_idx

    # Fail fast with a descriptive message instead of a bare KeyError below.
    for cls in SELECTED_CLASSES:
        if cls not in orig_class_to_idx:
            raise ValueError(f"Kh√¥ng t√¨m th·∫•y class: {cls} trong {root}")

    # Original ImageFolder index -> position in SELECTED_CLASSES.
    orig_to_new = {orig_class_to_idx[c]: i for i, c in enumerate(SELECTED_CLASSES)}

    filtered_samples = [
        (path, orig_to_new[target])
        for path, target in ds.samples
        if target in orig_to_new
    ]

    ds.samples = filtered_samples
    # ImageFolder exposes ``imgs`` as an alias of ``samples``; re-point it so
    # the two attributes cannot disagree after filtering.
    ds.imgs = filtered_samples
    ds.targets = [t for _, t in filtered_samples]
    # Copy so that mutating ds.classes later cannot alter the global config list.
    ds.classes = list(SELECTED_CLASSES)
    ds.class_to_idx = {cls: i for i, cls in enumerate(SELECTED_CLASSES)}
    return ds

class CBAMBlock(nn.Module):
    """Convolutional Block Attention Module: channel attention, then spatial attention.

    Channel attention is squeeze-and-excitation style: global average pool,
    a 1x1-conv bottleneck, then a sigmoid. The CBAM paper also adds a
    max-pooled path, which is not used here. Spatial attention convolves the
    per-pixel channel mean and channel max into one sigmoid map. Both maps
    rescale the input multiplicatively.

    Args:
        channels: Number of input (and output) channels.
        reduction: Bottleneck ratio of the channel-attention MLP.
        spatial_kernel: Kernel size of the spatial-attention conv. Odd values
            keep the spatial size unchanged (padding is ``kernel // 2``).
    """

    def __init__(self, channels, reduction=16, spatial_kernel=7):
        super().__init__()
        # Clamp so that channels < reduction still gets a 1-channel bottleneck
        # instead of a degenerate 0-channel convolution.
        hidden = max(1, channels // reduction)

        # Channel attention: (B, C, H, W) -> (B, C, 1, 1) gate in (0, 1).
        self.channel_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1, bias=False),
            nn.Sigmoid(),
        )
        # Spatial attention: 2-channel (mean, max) descriptor -> 1-channel gate.
        self.spatial_att = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=spatial_kernel, padding=spatial_kernel // 2, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Apply channel then spatial attention to ``x`` of shape (B, C, H, W)."""
        x = x * self.channel_att(x)
        # Per-pixel statistics across channels form the spatial descriptor.
        descriptor = torch.cat(
            [
                torch.mean(x, dim=1, keepdim=True),
                torch.max(x, dim=1, keepdim=True).values,
            ],
            dim=1,
        )
        return x * self.spatial_att(descriptor)

# ============== Hybrid Model (Gated Sum Fusion) ==============
class ViTCNNHybrid(nn.Module):
    """Swin-Tiny + ConvNeXt-Tiny hybrid classifier fused by a gated weighted sum.

    Both timm backbones receive the same image batch:

    * the Swin branch (``self.vit``) is created with ``num_classes=0``; its
      output is reshaped to ``(B, 768, 1, 1)`` and broadcast over the spatial
      grid of the CNN feature map;
    * the ConvNeXt branch (``self.cnn``) is created with ``global_pool=''`` so
      it keeps a spatial map, which ``cnn_pool`` resizes to 7x7.

    Each branch is re-weighted by its own channel gate. The two are mixed as
    ``a * vit + (1 - a) * cnn`` with a learned ``a = sigmoid(alpha_param)``.
    A 1x1 conv (+ optional CBAM), global average pooling and an MLP head then
    produce the logits.

    Args:
        num_classes: Number of output logits.
        use_cbam: Insert a CBAMBlock after the 1x1 fusion conv.
        pretrained: Forwarded to ``timm.create_model`` for both backbones.
            Pass ``False`` when the weights are about to be overwritten from a
            checkpoint (skips the download). Defaults to ``True`` as before.
    """

    def __init__(self, num_classes, use_cbam=True, *, pretrained=True):
        super().__init__()

        # Swin-Tiny branch. NOTE(review): with num_classes=0 timm is expected to
        # return one pooled feature vector per image (forward() reshapes it to
        # (B, 768, 1, 1)); confirm for the installed timm version.
        self.vit = timm.create_model(
            'swin_tiny_patch4_window7_224', pretrained=pretrained, num_classes=0, drop_rate=0.3
        )
        self.vit_out_features = 768

        # ConvNeXt-Tiny branch; global_pool='' keeps the spatial feature map.
        self.cnn = timm.create_model(
            'convnext_tiny', pretrained=pretrained, num_classes=0, drop_rate=0.3, global_pool=''
        )
        self.cnn_out_features = 768
        # Forces a 7x7 grid (identity when the map is already 7x7).
        self.cnn_pool = nn.AdaptiveAvgPool2d((7, 7))

        # Per-branch SE-style channel gates (built in the original order).
        self.vit_gate = self._make_gate(self.vit_out_features)
        self.cnn_gate = self._make_gate(self.cnn_out_features)

        # NOTE(review): never used in forward(). It is kept so the parameter set,
        # the parameter-initialisation order and saved state_dict keys stay
        # compatible with existing checkpoints.
        self.match_dim = nn.Conv2d(self.vit_out_features, self.cnn_out_features, 1)

        # Learnable mixing logit; the effective weight is sigmoid(alpha_param),
        # about 0.62 at initialisation.
        self.alpha_param = nn.Parameter(torch.tensor(0.5))

        # Fusion: 1x1 conv to 256 channels, optional CBAM, global average pool.
        fusion_layers = [
            nn.Conv2d(self.cnn_out_features, 256, kernel_size=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
        ]
        if use_cbam:
            fusion_layers.append(CBAMBlock(256))
        fusion_layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.fusion = nn.Sequential(*fusion_layers)

        # Classification head.
        self.fc = nn.Sequential(
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(512, num_classes),
        )

    @staticmethod
    def _make_gate(channels):
        """Return a global-avg-pool -> bottleneck 1x1 convs -> sigmoid channel gate."""
        return nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // 16, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // 16, channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return class logits of shape (B, num_classes) for the image batch ``x``."""
        # Swin branch: one feature vector per image.
        vit_vec = self.vit(x)

        # ConvNeXt branch: feature map pooled to 7x7, then channel-gated.
        cnn_out = self.cnn_pool(self.cnn(x))
        cnn_out = cnn_out * self.cnn_gate(cnn_out)

        # Broadcast the Swin vector over the CNN grid. The grid size is read
        # from cnn_out rather than hard-coded, so the two maps always line up.
        height, width = cnn_out.shape[-2:]
        vit_out = vit_vec.reshape(-1, self.vit_out_features, 1, 1).expand(-1, -1, height, width)
        vit_out = vit_out * self.vit_gate(vit_out)

        # Dynamic fusion with a learned convex weight.
        alpha = torch.sigmoid(self.alpha_param)
        combined = alpha * vit_out + (1 - alpha) * cnn_out

        combined = self.fusion(combined)  # (B, 256, 1, 1)
        combined = combined.flatten(1)  # (B, 256)
        return self.fc(combined)


# ============= LOAD DATA =============
# Training split is augmented (train_tfms); test split is only resized + normalised.
train_dataset = load_filtered_dataset(split="train", transform=train_tfms)
test_dataset = load_filtered_dataset(split="test", transform=test_tfms)

# Shared DataLoader settings; only the shuffle flag differs between the splits.
_loader_opts = {"batch_size": BATCH_SIZE, "num_workers": 4}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_opts)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_opts)

print("S·ªë l·ªõp:", len(SELECTED_CLASSES))
print("S·ªë ·∫£nh train:", len(train_dataset), "| S·ªë ·∫£nh test:", len(test_dataset))


# ============= INIT MODEL =============
num_classes = len(SELECTED_CLASSES)
model = ViTCNNHybrid(num_classes=num_classes).to(DEVICE)
total_params = sum(param.numel() for param in model.parameters())
print('params:', total_params)

# Plain cross-entropy and Adam with a fixed learning rate (no scheduler).
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Running best test accuracy; starts at 0.0 before any epoch has run.
best_test_acc = 0.0


# ============= TRAIN LOOP =============
# NOTE(review): the "best" checkpoint is chosen by accuracy on the test split,
# so best_test_acc is an optimistic (selection-biased) estimate; hold out a
# separate validation split if an unbiased test score is needed.
for epoch in range(1, EPOCHS + 1):
    # ---- Train ----
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # CrossEntropyLoss defaults to a per-batch mean; weight it by batch size
        # so the epoch value is a per-sample mean even when the last batch is short.
        running_loss += loss.item() * imgs.size(0)
        preds = outputs.argmax(dim=1)
        correct += preds.eq(labels).sum().item()
        total += labels.size(0)

    # NaN instead of a ZeroDivisionError if the loader yielded no samples.
    train_loss = running_loss / total if total else float("nan")
    train_acc = correct / total if total else float("nan")

    # ---- Eval ----
    model.eval()
    correct_test, total_test = 0, 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            outputs = model(imgs)
            preds = outputs.argmax(dim=1)
            correct_test += preds.eq(labels).sum().item()
            total_test += labels.size(0)

    test_acc = correct_test / total_test if total_test else float("nan")

    print(f"Epoch {epoch}/{EPOCHS} | "
          f"TrainLoss {train_loss:.4f} | TrainAcc {train_acc:.4f} | TestAcc {test_acc:.4f}")

    # ---- Save best ----
    # A NaN accuracy compares False, so an empty test split never overwrites
    # the checkpoint.
    if test_acc > best_test_acc:
        best_test_acc = test_acc
        torch.save(
            {"model_state": model.state_dict(), "classes": SELECTED_CLASSES, "epoch": epoch},
            BEST_MODEL_PATH,
        )
        print(f"üî• L∆∞u best model (TestAcc={test_acc:.4f}) ‚Üí {BEST_MODEL_PATH}")


# ============= SINGLE-IMAGE INFERENCE =============
IMG_PATH = "example.jpg"  # point this at your own image to test
TOP_K = 3  # number of top predictions to print

# os.path.isfile already returns False for paths that do not exist.
if os.path.isfile(IMG_PATH):
    ckpt = torch.load(BEST_MODEL_PATH, map_location=DEVICE)
    model.load_state_dict(ckpt["model_state"])
    model.eval()
    # Prefer the class list saved with the weights so labels match the checkpoint.
    class_names = ckpt.get("classes", SELECTED_CLASSES)

    # The context manager closes the file handle once the RGB copy is made.
    with Image.open(IMG_PATH) as raw_img:
        img = raw_img.convert("RGB")
    img_t = test_tfms(img).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logits = model(img_t)
        probs = torch.softmax(logits, dim=1)[0]
        # Clamp k so topk cannot fail when there are fewer than TOP_K classes.
        topk = torch.topk(probs, k=min(TOP_K, probs.numel()))

    print(f"\n=== K·∫øt qu·∫£ inference cho: {IMG_PATH} ===")
    ranked = zip(topk.values.tolist(), topk.indices.tolist())
    for rank, (prob, idx) in enumerate(ranked, start=1):
        cls_name = class_names[idx]
        p = prob * 100
        print(f"{rank}. {cls_name} ‚Äî {p:.2f}%")
else:
    print("‚ö†Ô∏è Kh√¥ng t√¨m th·∫•y IMG_PATH, b·ªè qua inference.")




Số lớp: 23
Số ảnh train: 15557 | Số ảnh test: 4002


model.safetensors:   0%|          | 0.00/114M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/114M [00:00<?, ?B/s]

params: 56428212
Epoch 1/50 | TrainLoss 2.6410 | TrainAcc 0.2392 | TestAcc 0.3236
🔥 Lưu best model (TestAcc=0.3236) → best_medagen_resnet18_vits_cbam.pth
Epoch 2/50 | TrainLoss 2.1647 | TrainAcc 0.3600 | TestAcc 0.4185
🔥 Lưu best model (TestAcc=0.4185) → best_medagen_resnet18_vits_cbam.pth
Epoch 3/50 | TrainLoss 1.8516 | TrainAcc 0.4530 | TestAcc 0.5067
🔥 Lưu best model (TestAcc=0.5067) → best_medagen_resnet18_vits_cbam.pth
Epoch 4/50 | TrainLoss 1.5916 | TrainAcc 0.5316 | TestAcc 0.5482
🔥 Lưu best model (TestAcc=0.5482) → best_medagen_resnet18_vits_cbam.pth
Epoch 5/50 | TrainLoss 1.3681 | TrainAcc 0.5926 | TestAcc 0.5670
🔥 Lưu best model (TestAcc=0.5670) → best_medagen_resnet18_vits_cbam.pth
Epoch 6/50 | TrainLoss 1.1868 | TrainAcc 0.6449 | TestAcc 0.5920
🔥 Lưu best model (TestAcc=0.5920) → best_medagen_resnet18_vits_cbam.pth
Epoch 7/50 | TrainLoss 1.0141 | TrainAcc 0.6943 | TestAcc 0.6052
üî• L∆∞u best model (TestAcc=0.6052) ‚Üí best_medagen_resne