In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import random
import glob
import copy
import time

# ==========================================
# 1. CONFIGURATION & HYPERPARAMETERS
# ==========================================
# Use the GPU when PyTorch reports one as available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Đang sử dụng thiết bị: {device}")

# Root directory that contains the 'train' and 'validation' folders.
DATA_DIR = '/kaggle/input/nail-disease-detection-dataset/data'

# Directory with noise / out-of-domain images used as the "Unknown" class.
# Replace with your real path.
UNKNOWN_DIR = '/kaggle/input/vehicle-classification/test'

# Training hyperparameters.
BATCH_SIZE = 32
NUM_EPOCHS = 15
LEARNING_RATE = 5e-4
MAX_IMAGES_PER_CLASS = 800  # per-class cap; only applied to the training split

# ==========================================
# 2. AUGMENTATION PIPELINES
# ==========================================

# Shared settings: every pipeline resizes to the same square input and
# normalises with the same per-channel mean / std (these values match the
# commonly used ImageNet statistics).
_INPUT_SIZE = (224, 224)
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Deterministic pipeline (used for validation / test images).
_base_steps = [
    transforms.Resize(_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
base_transform = transforms.Compose(_base_steps)

# Light augmentation for the nail-disease classes (keeps the nail structure).
_normal_steps = [
    transforms.Resize(_INPUT_SIZE),
    transforms.RandomHorizontalFlip(),
    # Small rotation only, because nail orientation matters.
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
normal_aug = transforms.Compose(_normal_steps)

# Very strong augmentation for the "Unknown" class.
_strong_steps = [
    transforms.Resize(_INPUT_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomAffine(
        degrees=180,
        translate=(0.2, 0.2),
        shear=20,
        scale=(0.5, 1.5),
    ),
    transforms.RandomPerspective(distortion_scale=0.5, p=0.5),
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.2),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
strong_aug = transforms.Compose(_strong_steps)

# ==========================================
# 3. CUSTOM DATASET (XỬ LÝ LOGIC)
# ==========================================
class BalancedNailDataset(Dataset):
    """Folder-per-class image dataset with per-class capping and an optional
    extra "Unknown" class.

    Each sub-directory of ``root_dir`` is one class; class indices follow the
    sorted sub-directory names. In training mode every class is randomly
    down-sampled to at most ``max_per_class`` images, and images found directly
    inside ``unknown_dir`` are appended as a final class named ``"Unknown"``.
    In validation mode all images are kept and ``unknown_dir`` is ignored.

    Args:
        root_dir: Directory whose sub-directories are the class folders.
        unknown_dir: Optional directory of out-of-domain images. The "Unknown"
            class is only registered when this directory actually exists, so a
            missing path never produces a class without samples.
        max_per_class: Upper bound on images per class (training mode only).
        is_train: Selects training behaviour (capping, Unknown class, random
            augmentation) versus validation behaviour.

    Attributes set on the instance:
        data: list of ``(image_path, label)`` tuples.
        classes: list of class names; position equals the label.
        class_to_idx: mapping from class-folder name to label.
        unknown_label: label of the "Unknown" class, or ``-1`` when absent.
    """

    # Lower-case file extensions accepted as images, for every class
    # (including Unknown, so both sources are filtered the same way).
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')

    def __init__(self, root_dir, unknown_dir=None, max_per_class=800, is_train=True):
        self.data = []
        self.is_train = is_train

        # Class list comes from the sub-directories of the given split folder.
        self.classes = sorted(
            d for d in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, d))
        )
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}

        # Register "Unknown" only for training AND only if the folder exists;
        # otherwise the model would get an output unit with no training data.
        use_unknown = bool(self.is_train and unknown_dir and os.path.isdir(unknown_dir))
        if use_unknown:
            self.classes.append("Unknown")
            self.unknown_label = len(self.classes) - 1
        else:
            self.unknown_label = -1  # no Unknown class in this dataset

        # --- Images of the regular (disease) classes ---
        for cls_name, label in self.class_to_idx.items():
            images = self._list_images(os.path.join(root_dir, cls_name))
            # Balancing is applied to the training split only; validation
            # keeps every image.
            if self.is_train:
                images = self._cap(images, max_per_class)
            self.data.extend((img_path, label) for img_path in images)

        # --- Images of the Unknown class (training only) ---
        if use_unknown:
            unknown_images = self._cap(self._list_images(unknown_dir), max_per_class)
            self.data.extend((img_path, self.unknown_label) for img_path in unknown_images)
            print(f"[TRAIN] Đã thêm Unknown. Tổng ảnh Train: {len(self.data)}")
        elif not self.is_train:
            print(f"[VAL] Load validation set. Tổng ảnh Val: {len(self.data)}")

    @classmethod
    def _list_images(cls, folder):
        """Return the sorted image paths located directly inside ``folder``."""
        candidates = glob.glob(os.path.join(folder, "*.*"))
        # Sorting makes the listing independent of filesystem order.
        return sorted(p for p in candidates if p.lower().endswith(cls.IMAGE_EXTENSIONS))

    @staticmethod
    def _cap(paths, limit):
        """Randomly keep at most ``limit`` paths; shorter lists are returned unchanged."""
        if len(paths) > limit:
            return random.sample(paths, limit)
        return paths

    def __len__(self):
        """Number of ``(path, label)`` samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load, transform and return ``(image, label)`` for sample ``idx``.

        An unreadable image is replaced by a black 224x224 placeholder (the
        label is kept) so that one corrupt file does not abort training.
        """
        img_path, label = self.data[idx]
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as exc:
            # Best-effort fallback, but no longer silent and no longer
            # swallowing KeyboardInterrupt / SystemExit like a bare except.
            print(f"WARNING: Không đọc được ảnh {img_path} ({exc}); dùng ảnh đen thay thế.")
            image = Image.new('RGB', (224, 224))

        if not self.is_train:
            # Validation: deterministic transform only.
            return base_transform(image), label

        # Training: Unknown samples get the strong pipeline, others the light one.
        augment = strong_aug if label == self.unknown_label else normal_aug
        return augment(image), label
class ChannelAttention(nn.Module):
    """Per-channel gate computed from globally pooled features.

    The input is reduced to 1x1 by average pooling and by max pooling; both
    results go through the same two 1x1 convolutions (fc1 -> ReLU -> fc2), are
    summed and squashed with a sigmoid.

    Args:
        in_planes: Number of input (and output) channels.
        ratio: Channel reduction factor of the hidden 1x1 convolution.

    Returns (forward): a tensor with ``in_planes`` channels and 1x1 spatial
    size holding values in (0, 1).
    """

    def __init__(self, in_planes, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # With fewer channels than `ratio`, `in_planes // ratio` would be 0
        # hidden channels, so fall back to no reduction at all.
        if in_planes < ratio:
            ratio = 1
        hidden_planes = in_planes // ratio

        self.fc1 = nn.Conv2d(in_planes, hidden_planes, kernel_size=1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(hidden_planes, in_planes, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def _bottleneck(self, pooled):
        """Apply the shared fc1 -> ReLU -> fc2 stack to a pooled descriptor."""
        return self.fc2(self.relu1(self.fc1(pooled)))

    def forward(self, x):
        avg_branch = self._bottleneck(self.avg_pool(x))
        max_branch = self._bottleneck(self.max_pool(x))
        return self.sigmoid(avg_branch + max_branch)

class SpatialAttention(nn.Module):
    """Per-location gate computed from channel-wise statistics.

    The mean and the maximum over the channel dimension are stacked into a
    2-channel map, convolved down to a single channel and passed through a
    sigmoid.

    Args:
        kernel_size: Size of the convolution kernel; padding is
            ``kernel_size // 2``.

    Returns (forward): a single-channel tensor holding values in (0, 1).
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        pad = kernel_size // 2
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=pad, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        descriptor = torch.cat((channel_mean, channel_max), dim=1)
        return self.sigmoid(self.conv1(descriptor))

class CBAM(nn.Module):
    """Channel attention followed by spatial attention.

    The input is first scaled by the ``ChannelAttention`` gate; the result is
    then scaled by the ``SpatialAttention`` gate computed from that
    channel-refined tensor.

    Args:
        in_planes: Channel count of the feature map.
        ratio: Reduction factor forwarded to ``ChannelAttention``.
        kernel_size: Kernel size forwarded to ``SpatialAttention``.
    """

    def __init__(self, in_planes, ratio=16, kernel_size=7):
        super().__init__()
        self.ca = ChannelAttention(in_planes, ratio)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        channel_refined = x * self.ca(x)
        spatial_gate = self.sa(channel_refined)
        return channel_refined * spatial_gate

# ==========================================
# 4. CBAM INTEGRATED INTO RESNET18
# ==========================================
class CBAMResNet18(nn.Module):
    """ResNet18 backbone with a CBAM block after each of its four stages.

    The torchvision ResNet18 is created with its default pretrained weights.
    Its stem and ``layer1``..``layer4`` are reused; pooling and classification
    are done by this module's own ``avgpool`` and ``fc`` (the backbone's own
    head is not called in ``forward``).

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Pretrained ResNet18.
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

        # Output channels of the ResNet18 stages: 64, 128, 256, 512
        # (the last is much smaller than ResNet50's 2048).
        self.cbam1 = CBAM(64)
        self.cbam2 = CBAM(128)
        self.cbam3 = CBAM(256)
        self.cbam4 = CBAM(512)

        # Classification head; the final feature vector has 512 entries.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        net = self.backbone

        # Stem.
        x = net.maxpool(net.relu(net.bn1(net.conv1(x))))

        # Each residual stage is followed by its attention block.
        stages = (
            (net.layer1, self.cbam1),
            (net.layer2, self.cbam2),
            (net.layer3, self.cbam3),
            (net.layer4, self.cbam4),
        )
        for residual_stage, attention in stages:
            x = attention(residual_stage(x))

        # Head: global pooling, flatten, linear classifier.
        pooled = torch.flatten(self.avgpool(x), 1)
        return self.fc(pooled)

# ==========================================
# 5. DATA LOADERS
# ==========================================
train_dir, val_dir = (os.path.join(DATA_DIR, split) for split in ('train', 'validation'))

# Training dataset: class balancing, Unknown class, augmentation.
train_dataset = BalancedNailDataset(
    root_dir=train_dir,
    unknown_dir=UNKNOWN_DIR,
    max_per_class=MAX_IMAGES_PER_CLASS,
    is_train=True,
)

# Validation dataset: kept as-is, no Unknown class added (unless the
# validation folder already contains one), no per-class limit applied.
val_dataset = BalancedNailDataset(
    root_dir=val_dir,
    unknown_dir=None,
    max_per_class=9999,
    is_train=False,
)

_datasets = {'train': train_dataset, 'validation': val_dataset}
dataloaders = {
    split: DataLoader(
        ds,
        batch_size=BATCH_SIZE,
        shuffle=(split == 'train'),  # only the training split is shuffled
        num_workers=2,
    )
    for split, ds in _datasets.items()
}
dataset_sizes = {split: len(ds) for split, ds in _datasets.items()}

# Actual class list (includes Unknown when present).
class_names = train_dataset.classes
NUM_CLASSES = len(class_names)
print(f"Final Classes: {class_names} (Total: {NUM_CLASSES})")

# ==========================================
# 6. TRAINING LOOP
# ==========================================
model = CBAMResNet18(num_classes=NUM_CLASSES).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

def train_model(model, criterion, optimizer, num_epochs=10):
    """Train ``model`` and return it with the best validation weights loaded.

    Every epoch runs a 'train' phase and then a 'validation' phase over the
    module-level ``dataloaders``; batches are moved to the module-level
    ``device`` and per-epoch averages are taken over ``dataset_sizes``.
    Whenever validation accuracy improves, the weights are snapshotted in
    memory and written to 'resnet18_cbam_nail_best.pth'.

    Args:
        model: Network to optimise (modified in place).
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepping the model's parameters.
        num_epochs: Number of epochs to run.

    Returns:
        The same ``model`` object after loading the best snapshot.
    """
    started_at = time.time()
    best_acc = 0.0
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch_idx in range(num_epochs):
        print(f'Epoch {epoch_idx + 1}/{num_epochs}')
        print('-' * 10)

        for phase in ('train', 'validation'):
            is_training = phase == 'train'
            # train(False) is what eval() does, so this covers both phases.
            model.train(is_training)

            loss_total = 0.0
            correct_total = 0

            for batch_inputs, batch_labels in dataloaders[phase]:
                batch_inputs = batch_inputs.to(device)
                batch_labels = batch_labels.to(device)

                optimizer.zero_grad()

                # Gradients are tracked only while training.
                with torch.set_grad_enabled(is_training):
                    logits = model(batch_inputs)
                    batch_preds = torch.max(logits, 1).indices
                    loss = criterion(logits, batch_labels)

                    if is_training:
                        loss.backward()
                        optimizer.step()

                # loss is a batch mean, so weight it by the batch size.
                loss_total += loss.item() * batch_inputs.size(0)
                correct_total += (batch_preds == batch_labels.data).sum()

            sample_count = dataset_sizes[phase]
            epoch_loss = loss_total / sample_count
            epoch_acc = correct_total.double() / sample_count

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Keep the weights with the highest validation accuracy so far.
            if not is_training and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(), 'resnet18_cbam_nail_best.pth')
                print(f"--> New Best Model! Acc: {best_acc:.4f}")

        print()

    minutes, seconds = divmod(time.time() - started_at, 60)
    print(f'Training complete in {minutes:.0f}m {seconds:.0f}s')
    print(f'Best Val Acc: {best_acc:.4f}')
    model.load_state_dict(best_model_wts)
    return model

if __name__ == '__main__':
    # Check for the Unknown-image folder before starting the run.
    _unknown_dir_present = os.path.exists(UNKNOWN_DIR)
    if not _unknown_dir_present:
        print(f"WARNING: Không tìm thấy {UNKNOWN_DIR}. Model sẽ train mà không có lớp Unknown!")

    model_ft = train_model(
        model,
        criterion,
        optimizer,
        num_epochs=NUM_EPOCHS,
    )

Đang sử dụng thiết bị: cuda
[TRAIN] Đã thêm Unknown. Tổng ảnh Train: 3944
[VAL] Load validation set. Tổng ảnh Val: 91
Final Classes: ['Acral_Lentiginous_Melanoma', 'Healthy_Nail', 'Onychogryphosis', 'blue_finger', 'clubbing', 'pitting', 'Unknown'] (Total: 7)


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 188MB/s]


Epoch 1/15
----------
train Loss: 0.6426 Acc: 0.7766
validation Loss: 0.8699 Acc: 0.7143
--> New Best Model! Acc: 0.7143

Epoch 2/15
----------
train Loss: 0.3849 Acc: 0.8689
validation Loss: 0.2722 Acc: 0.8901
--> New Best Model! Acc: 0.8901

Epoch 3/15
----------
train Loss: 0.3131 Acc: 0.8895
validation Loss: 0.1937 Acc: 0.9121
--> New Best Model! Acc: 0.9121

Epoch 4/15
----------
train Loss: 0.2546 Acc: 0.9105
validation Loss: 0.4066 Acc: 0.8242

Epoch 5/15
----------
train Loss: 0.2281 Acc: 0.9217
validation Loss: 0.3015 Acc: 0.8901

Epoch 6/15
----------
train Loss: 0.1802 Acc: 0.9338
validation Loss: 0.3480 Acc: 0.8901

Epoch 7/15
----------
train Loss: 0.1871 Acc: 0.9366
validation Loss: 0.4169 Acc: 0.8791

Epoch 8/15
----------
train Loss: 0.1141 Acc: 0.9602
validation Loss: 0.5315 Acc: 0.8681

Epoch 9/15
----------
train Loss: 0.1620 Acc: 0.9462
validation Loss: 0.1713 Acc: 0.9341
--> New Best Model! Acc: 0.9341

Epoch 10/15
----------
train Loss: 0.1401 Acc: 0.9473
validati