In [None]:
import os
import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, Dataset
from PIL import Image

import random

# Data-pipeline configuration; tune these here rather than in later cells.
DATA_DIR = "./data"   # root holding `train/` and `test/`, each with one sub-folder per class
TARGET_SIZE = 224     # images are resized to TARGET_SIZE x TARGET_SIZE
BATCH_SIZE = 32
SHUFFLE = True        # whether the training loader reshuffles every epoch
SEED = 42             # seeds Python's `random` and torch so runs are reproducible

# Seed every RNG the notebook uses (DataLoader shuffling, weight init, dropout,
# and the random sample of misclassified images) before any of them is drawn.
random.seed(SEED)
torch.manual_seed(SEED)

class SkinDiseaseDataset(Dataset):
    """Image-classification dataset laid out as one sub-directory per class.

    ``data_dir`` is expected to look like ``data_dir/<class_name>/<image file>``.
    Class indices are assigned in sorted order of the sub-directory names, and
    only regular files ending in .jpg/.jpeg/.png (case-insensitive) are kept.

    Args:
        data_dir: Root directory containing one sub-directory per class.
        transform: Optional callable applied to each loaded RGB PIL image.

    Attributes:
        class_names: Sorted list of class (sub-directory) names.
        class_to_idx: Mapping from class name to integer label.
        images: File paths of all collected images, grouped by class.
        labels: Integer label for each entry in ``images``.
    """

    # Accepted image filename suffixes (compared against the lower-cased name).
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.class_names = sorted(
            entry for entry in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, entry))
        )
        self.class_to_idx = {name: idx for idx, name in enumerate(self.class_names)}
        self.images = []
        self.labels = []

        for class_name in self.class_names:
            class_dir = os.path.join(data_dir, class_name)
            label = self.class_to_idx[class_name]
            # os.listdir order is filesystem-dependent; sort so the sample order
            # (and anything indexed by it) is the same on every machine.
            for filename in sorted(os.listdir(class_dir)):
                path = os.path.join(class_dir, filename)
                if filename.lower().endswith(self.IMAGE_EXTENSIONS) and os.path.isfile(path):
                    self.images.append(path)
                    self.labels.append(label)

    def __len__(self):
        """Number of collected image files."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply ``transform`` if set, return (image, label)."""
        img_path = self.images[idx]
        label = self.labels[idx]
        # Context manager releases the file handle deterministically; convert()
        # returns a new, fully-loaded image that outlives the `with` block.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, label

# Resize every image to a fixed square and convert it to a float tensor in [0, 1].
transform = transforms.Compose([
    transforms.Resize((TARGET_SIZE, TARGET_SIZE)),
    transforms.ToTensor(),
])

train_dataset = SkinDiseaseDataset(os.path.join(DATA_DIR, "train"), transform=transform)
test_dataset = SkinDiseaseDataset(os.path.join(DATA_DIR, "test"), transform=transform)

# Each dataset derives its label ids from its own sorted folder names, so both
# splits must contain exactly the same class folders for the ids to agree.
assert test_dataset.class_to_idx == train_dataset.class_to_idx, (
    f"train/test class folders differ: "
    f"{train_dataset.class_names} vs {test_dataset.class_names}"
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=SHUFFLE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"총 학습 이미지 수: {len(train_dataset)}")
print(f"총 테스트 이미지 수: {len(test_dataset)}")
print(f"클래스 매핑: {train_dataset.class_to_idx}")

assert len(train_dataset) > 0, f"no training images found under {os.path.join(DATA_DIR, 'train')}"

# Sanity check: pull one training batch and show its shape and labels.
imgs, labels = next(iter(train_loader))
print("배치 이미지 크기:", imgs.shape)
print("배치 라벨:", labels)

In [None]:
# Run on the GPU whenever PyTorch can see one; otherwise stay on the CPU.
# `device` is kept as a plain string ('cuda' / 'cpu') for use with `.to()`.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using device: {device}')

In [None]:
class CNN(nn.Module):
    """VGG-style convolutional classifier.

    There are four convolutional stages. Each stage has two convolutions with
    ReLU, followed by a 2x2 max-pool. Across the stages the channel count goes
    3 -> 64 -> 128 -> 256 -> 512, and the spatial size is halved each time.
    A 3-layer MLP head with dropout follows.

    The fully-connected head was sized for a 512 x 8 x 8 feature map, which is
    what a 128 x 128 input produces. An adaptive average pool resizes the final
    feature map to 8 x 8, so the model accepts any input of at least 16 x 16.
    This includes the 224 x 224 images from the data pipeline, which would
    otherwise reach fc1 as 512 x 14 x 14 and fail. For 128 x 128 inputs the
    adaptive pool is an identity. It has no parameters, so the state_dict is
    unchanged.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Stage 1: 3 -> 64 channels with 3x3 convolutions.
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding='same')
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding='same')
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        # NOTE(review): stages 2-4 use kernel_size=2 while stage 1 uses 3. With
        # padding='same' an even kernel cannot be padded symmetrically. This is
        # possibly a typo for 3; confirm before changing, because that changes
        # the weight shapes and invalidates saved checkpoints.
        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=2, padding='same')
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=2, padding='same')
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=2, padding='same')
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=2, padding='same')
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=2, padding='same')
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=2, padding='same')
        self.pool4 = nn.MaxPool2d(kernel_size=2)

        # Bring the final feature map to the 8 x 8 grid fc1 expects, whatever
        # the input resolution (identity when it is already 8 x 8).
        self.feature_pool = nn.AdaptiveAvgPool2d((8, 8))

        self.fc1 = nn.Linear(512 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc_out = nn.Linear(64, num_classes)

        # One dropout module (p=0.5) shared by all three hidden FC layers.
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Map an (N, 3, H, W) batch, H and W >= 16, to (N, num_classes) raw logits."""
        x = F.relu(self.conv1_1(x))
        x = F.relu(self.conv1_2(x))
        x = self.pool1(x)

        x = F.relu(self.conv2_1(x))
        x = F.relu(self.conv2_2(x))
        x = self.pool2(x)

        x = F.relu(self.conv3_1(x))
        x = F.relu(self.conv3_2(x))
        x = self.pool3(x)

        x = F.relu(self.conv4_1(x))
        x = F.relu(self.conv4_2(x))
        x = self.pool4(x)

        x = self.feature_pool(x)
        x = torch.flatten(x, 1)

        x = F.relu(self.fc1(x))
        x = self.dropout(x)

        x = F.relu(self.fc2(x))
        x = self.dropout(x)

        x = F.relu(self.fc3(x))
        x = self.dropout(x)

        # Unnormalised logits; nn.CrossEntropyLoss applies log-softmax itself.
        return self.fc_out(x)

In [None]:
def test(model, loader, criterion):
    model.eval()
    correct = 0
    total_loss = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device, dtype=torch.long)
            
            output = model(data)
            total_loss += criterion(output, target).item() * data.size(0)
            
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    avg_loss = total_loss / len(loader.dataset)
    acc = 100. * correct / len(loader.dataset)
    return avg_loss, acc

def train_model(model, train_loader, val_loader, criterion, optimizer, epochs):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Each epoch makes one pass over ``train_loader``: per batch it does a
    forward pass, computes the loss, backpropagates and takes an optimizer
    step. It then evaluates on ``val_loader`` with ``test``. Batches are moved
    to the notebook-level ``device``, and one progress line is printed per
    epoch.

    Returns:
        ``(train_losses, train_accs, val_losses, val_accs)``: four lists with
        one entry per epoch. Losses are sample-weighted means and accuracies
        are percentages.
    """
    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

    for epoch_no in range(1, epochs + 1):
        model.train()
        running_loss = 0
        n_correct = 0
        n_seen = 0

        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device, dtype=torch.long)

            optimizer.zero_grad()
            logits = model(inputs)
            batch_loss = criterion(logits, targets)
            batch_loss.backward()
            optimizer.step()

            batch_n = len(inputs)
            # Assuming a mean-reduced criterion (nn.CrossEntropyLoss's default),
            # re-weight by batch size so a short final batch is not over-counted.
            running_loss += batch_loss.item() * batch_n

            predicted = logits.argmax(dim=1, keepdim=True)
            n_correct += predicted.eq(targets.view_as(predicted)).sum().item()
            n_seen += batch_n

        epoch_loss = running_loss / n_seen
        train_acc = 100. * n_correct / n_seen
        history["train_loss"].append(epoch_loss)
        history["train_acc"].append(train_acc)

        val_loss, val_acc = test(model, val_loader, criterion)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

        print(f'Epoch: {epoch_no}/{epochs} | '
              f'Train Loss: {epoch_loss:.4f} | Train Acc: {train_acc:.2f}% | '
              f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')

    return (history["train_loss"], history["train_acc"],
            history["val_loss"], history["val_acc"])

In [None]:
# CNN requires num_classes; use one output logit per class folder in the training set.
model = CNN(num_classes=len(train_dataset.class_names)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adamax(model.parameters(), lr=0.0001)

In [None]:
import torchinfo

# Summarise at the resolution the data pipeline actually produces (TARGET_SIZE),
# so the reported layer shapes match what the model sees during training/eval.
torchinfo.summary(model, input_size=(1, 3, TARGET_SIZE, TARGET_SIZE), device=device)

In [None]:
import matplotlib.pyplot as plt
import random

GRID = 5  # show up to GRID x GRID randomly chosen misclassified test images

model.eval()

# Collect every misclassified test image as (HWC image tensor on CPU, predicted id, true id).
incorrect_preds = []
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)
        wrong = (preds != labels).nonzero(as_tuple=True)[0]
        # Transfer only the misclassified samples, once per batch; NCHW -> NHWC for imshow.
        wrong_images = images[wrong].cpu().permute(0, 2, 3, 1)
        incorrect_preds.extend(zip(wrong_images, preds[wrong].tolist(), labels[wrong].tolist()))

# Label ids index into the dataset's sorted class-folder names.
class_names = test_dataset.class_names
num_to_plot = min(len(incorrect_preds), GRID * GRID)
random_incorrect_preds = random.sample(incorrect_preds, num_to_plot)

fig, axes = plt.subplots(GRID, GRID, figsize=(15, 15))
fig.suptitle('Incorrect Predictions (Predicted vs. True Label)', fontsize=20)

# Hide every axis first; unused grid cells then stay blank.
for ax in axes.flat:
    ax.axis('off')
for ax, (image, pred, label) in zip(axes.flat, random_incorrect_preds):
    ax.imshow(image)
    ax.set_title(f'Pred: {class_names[pred]}, True: {class_names[label]}', fontsize=10)

plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

In [None]:
model.eval()

# Take the mapping from the dataset rather than hard-coding it, so names always
# match the label ids SkinDiseaseDataset assigned (sorted class-folder names).
class_mapping = test_dataset.class_to_idx
incorrect_by_label = {label: 0 for label in class_mapping.values()}
total_by_label = {label: 0 for label in class_mapping.values()}

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)
        # One host transfer per batch; count totals and misses per true label.
        for true_label, is_wrong in zip(labels.tolist(), (preds != labels).tolist()):
            total_by_label[true_label] += 1
            incorrect_by_label[true_label] += int(is_wrong)

print(f"--- 라벨별 오답 개수 (총 {sum(total_by_label.values())}개 테스트 이미지) ---")
print(f"클래스 매핑: {class_mapping}")
print("-" * 50)

for name, label_id in class_mapping.items():
    total = total_by_label[label_id]
    incorrect = incorrect_by_label[label_id]
    accuracy = (total - incorrect) / total * 100 if total > 0 else 0

    print(f"'{name}' (라벨 {label_id})")
    print(f"  총 이미지 수: {total}")
    print(f"  오답 예측 수: {incorrect}")
    print(f"  정확도: {accuracy:.2f}%")
    print("-" * 50)