[1] 모듈 로딩 및 환경 설정

In [1]:
import torch
print(torch.cuda.is_available())            # True 여야 함
print(torch.cuda.get_device_name(0))        # NVIDIA GeForce RTX 5090


True
NVIDIA GeForce RTX 5090


NVIDIA GeForce RTX 5090 with CUDA capability sm_120 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_50 sm_60 sm_61 sm_70 sm_75 sm_80 sm_86 sm_90.
If you want to use the NVIDIA GeForce RTX 5090 GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/



In [2]:
#%pip install matplotlib

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from torchvision.models import VGG16_Weights

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import os

In [4]:
# 설정
IMG_ROOT = './data'
BATCH_SIZE = 64
EPOCHS = 10
DEVICE = torch.device('cuda' if torch.cuda.is_available() 
                      else 'cpu')
PATIENCE = 3  # 조기 종료 기준
MODEL_PATH = 'best_vgg16_fashionmnist.pth'

[2] 데이터 전처리 및 로딩

In [5]:
transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]) ])

In [6]:
trainDS = datasets.FashionMNIST(root=IMG_ROOT,
                                download=True,
                                train=True,
                                transform=transform)

testDS  = datasets.FashionMNIST(root=IMG_ROOT,
                                download=True,
                                train=False,
                                transform=transform)

In [7]:
trainDL = DataLoader(trainDS, 
                     batch_size=BATCH_SIZE, 
                     shuffle=True)

testDL = DataLoader(testDS, 
                     batch_size=BATCH_SIZE, 
                     shuffle=False)

In [8]:
idx_to_class = { v : k for k, v in 
                trainDS.class_to_idx.items()}

[3] 모델 로딩 및 수정

In [9]:
model = models.vgg16(weights=VGG16_Weights.DEFAULT)

for param in model.features.parameters():
    param.requires_grad = False
    
model.classifier[6] = nn.Linear(4096, 10)
model = model.to(DEVICE)

CRITERION = nn.CrossEntropyLoss()
OPTIMIZER = optim.Adam(model.classifier.parameters(), lr=0.0001)

-----------------------------------------------------


[4] 학습 루프 + 조기 종료 + 최적 모델 저장

In [10]:
best_acc = 0.0
early_stop_counter = 0

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    loop = tqdm(trainDL, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for images, labels in loop:
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        OPTIMIZER.zero_grad()
        outputs = model(images)
        loss = CRITERION(outputs, labels)
        loss.backward()
        OPTIMIZER.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        loop.set_postfix(loss=loss.item(), acc=100*correct/total)

    train_acc = 100 * correct / total
    print(f"\n→ Epoch {epoch+1} Training Accuracy: {train_acc:.2f}%")
    
    # [평가]
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testDL:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    test_acc = 100 * correct / total
    print(f"→ Test Accuracy: {test_acc:.2f}%")
    
    # [조기 종료 조건]
    if test_acc > best_acc:
        best_acc = test_acc
        torch.save(model.state_dict(), MODEL_PATH)
        early_stop_counter = 0
        print(f"✔️ Best model saved! Accuracy: {best_acc:.2f}%\n")
    else:
        early_stop_counter += 1
        print(f"⚠️ No improvement. ({early_stop_counter}/{PATIENCE})\n")
        if early_stop_counter >= PATIENCE:
            print("⏹️ Early stopping triggered.")
            break

Epoch 1/10:   0%|          | 0/938 [00:00<?, ?it/s]


RuntimeError: CUDA error: no kernel image is available for execution on the device
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


[5] 저장된 모델 불러오기 및 최종 평가

In [None]:
model.load_state_dict(torch.load(MODEL_PATH))
model = model.to(DEVICE)
model.eval()

correct = 0
total = 0
with torch.no_grad():
    for images, labels in testDL:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

final_acc = 100 * correct / total
print(f"✅ Final Loaded Model Accuracy: {final_acc:.2f}%")

[6] 시각화

In [None]:
def denormalize(img_tensor):
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)
    return img_tensor * std + mean

images, labels = next(iter(testDL))
images, labels = images.to(DEVICE), labels.to(DEVICE)
outputs = model(images)
_, preds = torch.max(outputs, 1)

plt.figure(figsize=(12, 6))
for i in range(6):
    img = denormalize(images[i].cpu()).clamp(0, 1)
    npimg = img.permute(1, 2, 0).numpy()
    plt.subplot(2, 3, i+1)
    plt.imshow(npimg)
    plt.title(f"T: {idx_to_class[labels[i].item()]}\nP: {idx_to_class[preds[i].item()]}")
    plt.axis('off')

plt.tight_layout()
plt.show()

In [None]:
train_acc_list = []
test_acc_list = []


In [None]:
train_acc_list.append(train_acc)
test_acc_list.append(test_acc)


In [None]:
import matplotlib.pyplot as plt

epochs = range(1, len(train_acc_list) + 1)

plt.figure(figsize=(8, 5))
plt.plot(epochs, train_acc_list, label='Train Accuracy')
plt.plot(epochs, test_acc_list, label='Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.title('Accuracy over Epochs')
plt.xticks(epochs)
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()
