프로젝트: Mixup 또는 CutMix 비교실험 하기


In [1]:
import torch
import numpy as np

# Report the installed PyTorch and NumPy versions, one per line.
print(torch.__version__, np.__version__, sep="\n")

2.7.1+cu118
2.2.6


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torchvision import models
import numpy as np
import matplotlib.pyplot as plt

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)

Device: cuda


In [3]:
# ----------------------------------------------------
# 1. 전처리 & 기본 Augmentation + DataLoader 함수
# ----------------------------------------------------

def base_transform():
    """Build the default preprocessing pipeline.

    The pipeline resizes to 224x224, converts the image to a tensor and
    normalizes each of the three channels with mean 0.5 and std 0.5.
    """
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
    return transforms.Compose(steps)

def augment_transform():
    """Build the basic augmentation pipeline.

    Applies a random horizontal flip (p=0.5) followed by a brightness
    jitter of 0.2.
    """
    flip = transforms.RandomHorizontalFlip(p=0.5)
    brightness = transforms.ColorJitter(brightness=0.2)
    return transforms.Compose([flip, brightness])

def apply_normalize_on_dataset(dataset,
                               is_test=False,
                               batch_size=32,
                               with_aug=False):
    """Wrap ``dataset`` in a DataLoader with preprocessing attached.

    The transform is ``base_transform()`` (resize, to-tensor, normalize),
    preceded by ``augment_transform()`` when ``with_aug`` is true and
    ``is_test`` is false.

    Args:
        dataset: A dataset, or a ``torch.utils.data.Subset`` of one. The
            transform is attached through a ``transform`` attribute.
        is_test: When true, augmentation is skipped and the loader does
            not shuffle.
        batch_size: Batch size passed to the DataLoader.
        with_aug: Whether to prepend the augmentation pipeline.

    Returns:
        A ``DataLoader`` over the same samples (and the same indices, for
        a Subset) with the selected transform.

    The transform is set on a shallow copy of the underlying dataset, not
    on the object that was passed in. When several Subsets share one
    underlying dataset (as the results of ``random_split`` do), assigning
    the transform in place would let a later call silently overwrite the
    pipeline used by an earlier loader.
    """
    import copy  # local import: only this function needs it

    # For a Subset, separate the wrapped dataset from the selected indices.
    if isinstance(dataset, torch.utils.data.Subset):
        raw_dataset = dataset.dataset
        indices = dataset.indices
    else:
        raw_dataset = dataset
        indices = None

    tf_list = []
    if with_aug and not is_test:
        tf_list.append(augment_transform())
    tf_list.append(base_transform())
    transform = transforms.Compose(tf_list)

    # Shallow copy: sample lists are shared, only the attribute binding for
    # `transform` is private to this loader.
    private_dataset = copy.copy(raw_dataset)
    private_dataset.transform = transform

    if indices is not None:
        dataset = torch.utils.data.Subset(private_dataset, indices)
    else:
        dataset = private_dataset

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=not is_test,
        num_workers=2,
        pin_memory=True
    )
    return dataloader

In [4]:
# ----------------------------------------------------
# 2. Mixup & CutMix 유틸 함수 + soft CE
# ----------------------------------------------------

def onehot(labels, num_classes):
    """Return float one-hot encodings of ``labels`` with ``num_classes`` columns."""
    encoded = F.one_hot(labels, num_classes=num_classes)
    return encoded.float()

def mixup_data(x, y, alpha=1.0):
    """Blend each sample of a batch with a randomly chosen partner (Mixup).

    Args:
        x: Input batch; the first dimension is the batch dimension.
        y: Labels aligned with ``x`` along the first dimension.
        alpha: Beta(alpha, alpha) parameter for the mixing weight. When it
            is not positive the weight is fixed to 1.0 (no mixing).

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where ``mixed_x`` is
        ``lam * x + (1 - lam) * x[perm]``, ``y_a`` is ``y`` and ``y_b`` is
        ``y[perm]``.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0.0 else 1.0

    # Random pairing of samples within the batch.
    perm = torch.randperm(x.size(0)).to(x.device)
    partner_x = x[perm, :]

    blended = lam * x + (1 - lam) * partner_x
    return blended, y, y[perm], lam

def rand_bbox(size, lam):
    """Sample a random CutMix box for a batch of the given size.

    Args:
        size: Batch size tuple; index 2 is treated as the height (y extent)
            and index 3 as the width (x extent). This matches how
            ``cutmix_data`` uses the result: y indexes dimension 2 and x
            indexes dimension 3.
        lam: Target fraction of the image to keep; the box covers roughly
            ``1 - lam`` of the area before clipping.

    Returns:
        ``(x1, y1, x2, y2)`` as plain ints with ``0 <= x1 <= x2 <= W`` and
        ``0 <= y1 <= y2 <= H``.
    """
    # Height comes from dim 2 and width from dim 3 so that each coordinate
    # is clipped against the dimension it will later index.
    H = size[2]
    W = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Sample the box centre uniformly over the image.
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    # Clip to the image bounds; the box may shrink near the borders.
    x1 = int(np.clip(cx - cut_w // 2, 0, W))
    x2 = int(np.clip(cx + cut_w // 2, 0, W))
    y1 = int(np.clip(cy - cut_h // 2, 0, H))
    y2 = int(np.clip(cy + cut_h // 2, 0, H))

    return x1, y1, x2, y2

def cutmix_data(x, y, alpha=1.0):
    """Paste a random box from a partner sample into each sample (CutMix).

    Args:
        x: Input batch; dimensions 2 and 3 are the ones the box is cut from.
        y: Labels aligned with ``x`` along the first dimension.
        alpha: Beta(alpha, alpha) parameter for the initial mixing weight.
            When it is not positive the weight starts at 1.0.

    Returns:
        ``(new_x, y_a, y_b, lam)`` where ``new_x`` is a copy of ``x`` with
        the box region replaced by the partner's pixels, ``y_a`` is ``y``,
        ``y_b`` is the partner labels, and ``lam`` is one minus the box
        area divided by ``x.size(2) * x.size(3)``.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0.0 else 1.0

    # Random pairing of samples within the batch.
    perm = torch.randperm(x.size(0)).to(x.device)
    donor_x, donor_y = x[perm], y[perm]

    left, top, right, bottom = rand_bbox(x.size(), lam)
    patched = x.clone()
    patched[:, :, top:bottom, left:right] = donor_x[:, :, top:bottom, left:right]

    # Replace the sampled weight with the one implied by the box area,
    # since clipping can make the box smaller than requested.
    pasted_area = (right - left) * (bottom - top)
    total_area = x.size(2) * x.size(3)
    lam = 1.0 - pasted_area / total_area

    return patched, y, donor_y, lam

def soft_cross_entropy(pred, target):
    """Cross-entropy against soft (or one-hot) targets, averaged over the batch.

    Args:
        pred: (B, C) logits.
        target: (B, C) one-hot or soft label distribution.

    Returns:
        Scalar tensor: mean over the batch of ``-sum_c target * log_softmax(pred)``.
    """
    weighted = target * F.log_softmax(pred, dim=1)
    per_sample = weighted.sum(dim=1)
    return -per_sample.mean()


In [6]:
# ----------------------------------------------------
# 3. ResNet-50 모델 생성 함수
# ----------------------------------------------------

def create_resnet50(num_classes):
    """Create a ResNet-50 without pretrained weights and a ``num_classes``-way head."""
    # weights=None means no pretrained weights; change it to use pretrained ones.
    net = models.resnet50(weights=None)
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net


In [7]:
# ----------------------------------------------------
# 4. 데이터셋 준비
# ----------------------------------------------------

# Load the image-folder dataset (one sub-directory per class) with the base
# preprocessing attached.
dataset_dir = "~/work/data_augmentation/data/Images/"
full_dataset = ImageFolder(root=dataset_dir, transform=base_transform())

# Number of classes = number of class folders found by ImageFolder.
num_classes = len(full_dataset.classes)
print("num_classes:", num_classes)

# 80/20 train/validation split. No generator is passed to random_split,
# so this call does not fix the split itself.
total_size = len(full_dataset)
train_size = int(0.8 * total_size)
val_size   = total_size - train_size
ds_train, ds_val = random_split(full_dataset, [train_size, val_size])

# DataLoaders: training loader with the basic augmentation, and a validation
# loader without augmentation or shuffling. Both subsets come from the same
# full_dataset.
train_loader_baseline = apply_normalize_on_dataset(
    ds_train, is_test=False, batch_size=32, with_aug=True
)
val_loader = apply_normalize_on_dataset(
    ds_val, is_test=True, batch_size=32, with_aug=False
)

num_classes: 120


In [8]:
# ----------------------------------------------------
# 5. 학습 루프 (baseline / mixup / cutmix 공통)
# ----------------------------------------------------

def train_model(model,
                train_loader,
                val_loader,
                num_epochs=5,
                mode="baseline",
                alpha=1.0):
    """Train ``model`` and validate it after every epoch.

    Args:
        model: Network producing (B, C) logits.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        num_epochs: Number of passes over ``train_loader``.
        mode: Training variant.
            - "mixup"  : apply Mixup with a soft-label cross-entropy
            - "cutmix" : apply CutMix with a soft-label cross-entropy
            - anything else (e.g. "baseline") : plain cross-entropy on the
              batches as delivered by the loader
        alpha: Beta-distribution parameter forwarded to Mixup / CutMix.

    Returns:
        ``(model, history)`` where ``history`` maps "train_loss",
        "val_loss" and "val_accuracy" to per-epoch lists.

    Losses are averaged over the number of samples actually processed in
    the epoch. Validation always uses the hard-label cross-entropy.
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    criterion_plain = nn.CrossEntropyLoss()

    # Map the mode to its batch-mixing function; None selects the baseline.
    mix_fn = {"mixup": mixup_data, "cutmix": cutmix_data}.get(mode)

    history = {
        "train_loss": [],
        "val_loss": [],
        "val_accuracy": []
    }

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        seen = 0

        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            if mix_fn is not None:
                loss = _mixed_batch_loss(model, images, labels, mix_fn, alpha)
            else:
                # Baseline: ordinary cross-entropy on integer labels.
                loss = criterion_plain(model(images), labels)

            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch value is a per-sample mean.
            running_loss += loss.item() * images.size(0)
            seen += images.size(0)

        # max(..., 1) keeps an empty loader from dividing by zero.
        epoch_train_loss = running_loss / max(seen, 1)

        epoch_val_loss, epoch_val_acc = _evaluate(model, val_loader,
                                                  criterion_plain)

        history["train_loss"].append(epoch_train_loss)
        history["val_loss"].append(epoch_val_loss)
        history["val_accuracy"].append(epoch_val_acc)

        print(f"[{mode}][Epoch {epoch+1}/{num_epochs}] "
              f"Train Loss: {epoch_train_loss:.4f} | "
              f"Val Loss: {epoch_val_loss:.4f} | "
              f"Val Acc: {epoch_val_acc:.4f}")

    return model, history


def _mixed_batch_loss(model, images, labels, mix_fn, alpha):
    """Mix one batch with ``mix_fn`` and return its soft-label loss."""
    mixed_x, y_a, y_b, lam = mix_fn(images, labels, alpha)
    outputs = model(mixed_x)

    # Take the one-hot width from the logits so it always matches the model
    # head instead of relying on a module-level variable.
    n_classes = outputs.size(1)
    # The mixing helpers may hand back a NumPy scalar; use a plain float.
    lam = float(lam)

    loss_a = soft_cross_entropy(outputs, onehot(y_a, n_classes))
    loss_b = soft_cross_entropy(outputs, onehot(y_b, n_classes))
    return lam * loss_a + (1 - lam) * loss_b


def _evaluate(model, loader, criterion):
    """Return ``(mean loss, accuracy)`` of ``model`` over ``loader``."""
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss_sum += criterion(outputs, labels).item() * images.size(0)

            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        # Nothing to evaluate: report zeros rather than dividing by zero.
        return 0.0, 0.0
    return loss_sum / total, correct / total


In [None]:
# ----------------------------------------------------
# 7. 결과 시각화 (루브릭용 비교 그래프)
# ----------------------------------------------------

def plot_histories(hist_base, hist_mix, hist_cut):
    """Plot the validation accuracy of the three runs on one figure.

    Args:
        hist_base: History dict of the baseline run.
        hist_mix: History dict of the Mixup run.
        hist_cut: History dict of the CutMix run.

    Each dict must provide a "val_accuracy" list with one value per epoch.
    Every curve gets its own epoch axis, so the runs may have been trained
    for different numbers of epochs.
    """
    curves = (
        (hist_base, 'r-', 'Baseline Aug'),
        (hist_mix, 'b-', 'Mixup'),
        (hist_cut, 'g-', 'CutMix'),
    )

    plt.figure(figsize=(8, 5))
    for history, line_style, label in curves:
        accuracy = history["val_accuracy"]
        # Epochs are numbered from 1 and sized to this curve's own length.
        epochs = range(1, len(accuracy) + 1)
        plt.plot(epochs, accuracy, line_style, label=label)
    plt.title('Model validation accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend(loc='upper left')
    plt.grid(True)
    plt.show()

# NOTE(review): hist_baseline, hist_mixup and hist_cutmix are not defined in
# the visible cells — presumably the histories returned by train_model runs
# for each mode; confirm they exist before running this cell.
plot_histories(hist_baseline, hist_mixup, hist_cutmix)