In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset, random_split
from PIL import Image
import os
import random
import glob
import copy
import time
import torch.nn.functional as F

# ==========================================
# 1. CẤU HÌNH & DATASET (GIỮ NGUYÊN LOGIC CŨ)
# ==========================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Đang sử dụng thiết bị: {device}")

DATA_DIR = '/kaggle/input/teeth-dataset/Tooth dataset' 
UNKNOWN_DIR = '/kaggle/input/vehicle-classification/test'  
BATCH_SIZE = 32 
NUM_EPOCHS = 20 # Tăng epoch vì model lớn + CBAM cần thời gian hội tụ
LEARNING_RATE = 0.0005 # Giảm LR một chút vì ResNet finetune nhạy cảm hơn
MAX_IMAGES_PER_CLASS = 400 

# --- Định nghĩa Augmentation (Như cũ) ---
base_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

normal_aug = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

strong_aug = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomAffine(degrees=180, translate=(0.2, 0.2), shear=20, scale=(0.6, 1.4)),
    transforms.RandomPerspective(distortion_scale=0.4, p=0.5),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# --- Dataset Class (Giữ nguyên logic thông minh của bạn) ---
class BalancedTeethDataset(Dataset):
    def __init__(self, root_dir, unknown_dir, max_per_class=300, mode='train'):
        self.data = [] 
        self.classes = sorted(os.listdir(root_dir))
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}
        
        self.classes.append("Unknown")
        self.unknown_label = len(self.classes) - 1
        
        # Load bệnh lý
        for cls_name in os.listdir(root_dir):
            cls_folder = os.path.join(root_dir, cls_name)
            if not os.path.isdir(cls_folder): continue
            images = glob.glob(os.path.join(cls_folder, "*.*"))
            images = [f for f in images if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
            
            if len(images) > max_per_class:
                selected_images = random.sample(images, max_per_class)
            else:
                selected_images = images
            
            label = self.class_to_idx[cls_name]
            for img_path in selected_images:
                self.data.append((img_path, label))
                
        # Load Unknown
        if os.path.exists(unknown_dir):
            unknown_images = glob.glob(os.path.join(unknown_dir, "*.*"))
            unknown_images = [f for f in unknown_images if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
            if len(unknown_images) > max_per_class:
                selected_unknown = random.sample(unknown_images, max_per_class)
            else:
                selected_unknown = unknown_images
            for img_path in selected_unknown:
                self.data.append((img_path, self.unknown_label))

        self.mode = mode 

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path, label = self.data[idx]
        try:
            image = Image.open(img_path).convert('RGB')
        except:
            image = Image.new('RGB', (224, 224))

        if self.mode == 'val':
            image = base_transform(image)
        else:
            if label == self.unknown_label:
                image = strong_aug(image)
            else:
                image = normal_aug(image)
        return image, label

# ==========================================
# 2. ĐỊNH NGHĨA MODULE CBAM (Attention)
# ==========================================
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        
        # Lưu ý: ResNet18 channel nhỏ, nếu in_planes < 16 thì ratio=16 sẽ lỗi
        # Ta điều chỉnh ratio an toàn
        safe_ratio = ratio if in_planes >= ratio else 1
        
        self.fc1 = nn.Conv2d(in_planes, in_planes // safe_ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // safe_ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        return self.sigmoid(avg_out + max_out)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        return self.sigmoid(self.conv1(x))

class CBAM(nn.Module):
    def __init__(self, in_planes, ratio=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.ca = ChannelAttention(in_planes, ratio)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        out = x * self.ca(x)
        out = out * self.sa(out)
        return out

# ==========================================
# 2. CBAM TÍCH HỢP VÀO RESNET18
# ==========================================
class CBAMResNet18(nn.Module):
    def __init__(self, num_classes):
        super(CBAMResNet18, self).__init__()
        # Load ResNet18 pretrained
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        
        # --- CẤU HÌNH KÊNH CỦA RESNET18 ---
        # Layer 1: 64 channels
        # Layer 2: 128 channels
        # Layer 3: 256 channels
        # Layer 4: 512 channels (Nhỏ hơn nhiều so với 2048 của ResNet50)
        
        self.cbam1 = CBAM(64)
        self.cbam2 = CBAM(128)
        self.cbam3 = CBAM(256)
        self.cbam4 = CBAM(512)
        
        # Classifier
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Input của FC layer cuối cùng là 512
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        # Phần đầu (Stem)
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        # Layer 1 + CBAM
        x = self.backbone.layer1(x)
        x = self.cbam1(x)

        # Layer 2 + CBAM
        x = self.backbone.layer2(x)
        x = self.cbam2(x)

        # Layer 3 + CBAM
        x = self.backbone.layer3(x)
        x = self.cbam3(x)

        # Layer 4 + CBAM
        x = self.backbone.layer4(x)
        x = self.cbam4(x)

        # Head
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

# ==========================================
# 4. CHUẨN BỊ DỮ LIỆU & TRAINING
# ==========================================

# Dataset
full_dataset = BalancedTeethDataset(DATA_DIR, UNKNOWN_DIR, max_per_class=MAX_IMAGES_PER_CLASS, mode='train')
class_names = full_dataset.classes
NUM_CLASSES = len(class_names)

train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_subset, val_subset = random_split(full_dataset, [train_size, val_size])

dataloaders = {
    'train': DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2),
    'validation': DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
}
dataset_sizes = {'train': len(train_subset), 'validation': len(val_subset)}

print(f"Model ResNet50 + CBAM initialized for {NUM_CLASSES} classes.")
model = CBAMResNet18(num_classes=NUM_CLASSES)
model = model.to(device)

# Loss & Optimizer
criterion = nn.CrossEntropyLoss()
# Dùng LR nhỏ cho phần backbone (feature extractor) và LR lớn hơn cho phần CBAM + Classifier nếu muốn (ở đây dùng chung cho đơn giản)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Hàm Train (Giữ nguyên logic lưu model của bạn)
def train_model(model, criterion, optimizer, num_epochs=10):
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        for phase in ['train', 'validation']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if phase == 'validation' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(), 'resnet18_cbam_teeth_best.pth')
                print(f"--> New Best Model! Acc: {best_acc:.4f}")

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best Val Acc: {best_acc:.4f}')
    model.load_state_dict(best_model_wts)
    return model

if __name__ == '__main__':
    if not os.path.exists(UNKNOWN_DIR):
        print(f"LƯU Ý: Hãy tạo folder {UNKNOWN_DIR} hoặc sửa đường dẫn trước khi chạy!")
    else:
        model_ft = train_model(model, criterion, optimizer, num_epochs=NUM_EPOCHS)

Đang sử dụng thiết bị: cuda
Model ResNet50 + CBAM initialized for 6 classes.


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 219MB/s]


Epoch 1/20
----------
train Loss: 0.6563 Acc: 0.7770
validation Loss: 0.6243 Acc: 0.7851
--> New Best Model! Acc: 0.7851

Epoch 2/20
----------
train Loss: 0.3575 Acc: 0.8719
validation Loss: 0.5073 Acc: 0.8382
--> New Best Model! Acc: 0.8382

Epoch 3/20
----------
train Loss: 0.3154 Acc: 0.8938
validation Loss: 0.3304 Acc: 0.8859
--> New Best Model! Acc: 0.8859

Epoch 4/20
----------
train Loss: 0.2753 Acc: 0.9064
validation Loss: 0.2894 Acc: 0.8992
--> New Best Model! Acc: 0.8992

Epoch 5/20
----------
train Loss: 0.2175 Acc: 0.9336
validation Loss: 0.2179 Acc: 0.9257
--> New Best Model! Acc: 0.9257

Epoch 6/20
----------
train Loss: 0.1722 Acc: 0.9436
validation Loss: 0.3157 Acc: 0.9019

Epoch 7/20
----------
train Loss: 0.1915 Acc: 0.9303
validation Loss: 0.3614 Acc: 0.8780

Epoch 8/20
----------
train Loss: 0.1802 Acc: 0.9396
validation Loss: 0.2532 Acc: 0.9045

Epoch 9/20
----------
train Loss: 0.1285 Acc: 0.9628
validation Loss: 0.2072 Acc: 0.9125

Epoch 10/20
----------
train L