# 04. CNN - CIFAR-10

CIFAR-10 컬러 이미지 분류를 위해 더 깊은 CNN을 학습하고 평가합니다.

In [None]:
import torch
import matplotlib.pyplot as plt

# Fix PyTorch's global RNG seed so repeated runs are reproducible
# (e.g. the torch.randperm shuffling used during training).
torch.manual_seed(42)

In [None]:
from mlfs.utils.data import load_cifar10, CIFAR10_CLASSES
from mlfs.utils.viz import plot_images

# Load the training and test splits; each call returns an (inputs, labels) pair.
X_train, y_train = load_cifar10(train=True)
X_test, y_test = load_cifar10(train=False)

# Sanity check: print the shape of each split's inputs and the class names.
print(f'Train: {X_train.shape}')
print(f'Test: {X_test.shape}')
print(f'Classes: {CIFAR10_CLASSES}')

In [None]:
# Sample images: display the first 10 training images together with their labels.
plot_images(X_train[:10], labels=y_train[:10])

In [None]:
from mlfs.nn.models import CNN
from mlfs.nn.losses import CrossEntropyLoss
from mlfs.nn.optim import Adam

# CNN for CIFAR (3 input channels, 10 output classes).
# NOTE(review): image_size=32 presumably means 32x32 inputs - confirm against CNN.
model = CNN(in_channels=3, num_classes=10, image_size=32)
# Adam over all of the model's parameters with a learning rate of 1e-3.
optimizer = Adam(model.parameters(), lr=0.001)
# Loss applied to the model's logits and the integer labels in the training loop.
criterion = CrossEntropyLoss()

In [None]:
# Training (on a subset, for a quick test).
# Clamp the subset size to the data actually loaded: torch.randperm(n_samples)
# below would otherwise produce out-of-range indices for a smaller dataset.
n_samples = min(5000, X_train.shape[0])
X_sub = X_train[:n_samples]
y_sub = y_train[:n_samples]

epochs = 5
batch_size = 64
# Number of test examples used for the per-epoch accuracy estimate.
n_eval = 1000

# Ceiling division: the final, partial batch is also processed by the inner
# loop, so it must be counted when averaging the per-batch losses.
num_batches = (n_samples + batch_size - 1) // batch_size

for epoch in range(epochs):
    model.train()
    # Fresh random ordering of the subset every epoch.
    indices = torch.randperm(n_samples)
    total_loss = 0.0

    for i in range(0, n_samples, batch_size):
        batch_idx = indices[i:i+batch_size]
        X_batch = X_sub[batch_idx]
        y_batch = y_sub[batch_idx]

        # Forward pass.
        logits = model(X_batch)
        loss = criterion(logits, y_batch)

        # Clear gradients from the previous step, backpropagate, then update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Evaluation on the first n_eval test examples, without gradient tracking.
    model.eval()
    with torch.no_grad():
        # NOTE(review): predict() presumably returns class indices, since the
        # result is compared element-wise with the labels - confirm.
        test_pred = model.predict(X_test[:n_eval])
        # Fraction of predictions that match the labels.
        test_acc = (test_pred == y_test[:n_eval]).float().mean()

    # Mean of the per-batch losses over every batch actually run this epoch.
    print(f'Epoch {epoch+1}: Loss = {total_loss/num_batches:.4f}, Test Acc = {test_acc:.4f}')

## 요약

CIFAR-10은 MNIST보다 복잡한 데이터셋이므로, 높은 정확도를 얻으려면 더 깊은 네트워크와 데이터 증강이 필요합니다.