In [None]:
import os
import sys

# Make the parent of the current working directory (normally this notebook's folder)
# importable, so the project-local `dataset` and `mobilenet` packages used below resolve.
# The membership guard keeps re-runs of this cell from piling duplicate entries onto sys.path.
PROJECT_ROOT = os.path.abspath("..")
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)
print(PROJECT_ROOT)

In [None]:
from dataset import download_dataset

# Build the data loaders with the project's helper and keep both splits.
loader_dict = download_dataset()
train_loader, test_loader = loader_dict["train"], loader_dict["test"]

# CIFAR-10 class names, indexed by label id.
# NOTE(review): assumes this order matches the label indices produced by download_dataset() — confirm.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

# Take a single batch from the test loader for the inference / visualization cells.
dataiter = iter(test_loader)
images, labels = next(dataiter)

In [None]:
from mobilenet.mobilenet_4 import MobileNetV4ConvMedium

import torch
import torch.nn as nn
import torch.optim as optim

# Pick the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device : {device}")
model = MobileNetV4ConvMedium(num_classes=10).to(device)

# NOTE(review): criterion/optimizer are never used in this notebook (inference only);
# kept so any code expecting these names still finds them.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

model_path = "models/mobilenetv4_medium_cifar10.pth"
print(f'model exist : {os.path.exists(model_path)}')
# Fail early with a clear message instead of an opaque error from inside torch.load.
if not os.path.exists(model_path):
    raise FileNotFoundError(f"Checkpoint not found: {os.path.abspath(model_path)}")

# Load the trained weights onto the selected device. weights_only=True restricts
# unpickling to tensors and plain containers, which is all a state_dict needs
# (requires torch >= 1.13).
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
# Switch the model to evaluation mode (affects layers such as BatchNorm/Dropout, if present).
model.eval()


In [None]:
# Move the batch onto the same device as the model.
images = images.to(device)
labels = labels.to(device)

# Forward pass without autograd bookkeeping; the predicted class is the
# index of the largest logit along dimension 1.
with torch.no_grad():
    outputs = model(images)
    _, predicted = outputs.max(dim=1)

# Copy everything back to the CPU so it can be converted to NumPy and plotted.
images, labels, predicted = images.cpu(), labels.cpu(), predicted.cpu()


In [None]:
import matplotlib.pyplot as plt

# Visualize the first (up to) 8 samples of the batch: predicted vs. true class.
# Titles are green for correct predictions and red for mistakes.
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
axes = axes.flatten()

# A batch smaller than the 2x4 grid would otherwise raise IndexError.
n_show = min(len(axes), len(images))
for idx, ax in enumerate(axes):
    if idx < n_show:
        # Undo Normalize(mean=0.5, std=0.5): [-1, 1] -> [0, 1].
        # NOTE(review): assumes download_dataset() normalizes each channel with mean=std=0.5 — confirm.
        img = images[idx] / 2 + 0.5
        npimg = img.numpy().transpose((1, 2, 0))  # channels-first -> channels-last for imshow

        pred, actual = predicted[idx].item(), labels[idx].item()
        ax.imshow(npimg)
        ax.set_title(f"Pred: {classes[pred]}\nTrue: {classes[actual]}",
                     color=('green' if pred == actual else 'red'),
                     fontsize=10)
    ax.axis('off')  # hide ticks; also leaves any unused panel blank

# Prevent the two-line titles from overlapping neighbouring panels.
fig.tight_layout()
plt.show()

In [None]:
# Count correct / incorrect predictions for the batch visualized above.
# NOTE: this covers only the single batch taken from test_loader, not the whole test set.
total_count = len(labels)
corrected_count = int((predicted == labels).sum().item())
failed_count = total_count - corrected_count
print(f'Corrected Count : {corrected_count}, Failed Count : {failed_count}')
print(f'Total Count : {corrected_count + failed_count}')
# Exact integer floor: int(c / t * 100) can undershoot (e.g. 29/100*100 == 28.999...),
# and an empty batch would divide by zero.
corrected_percent = corrected_count * 100 // total_count if total_count else 0
print(f'Corrected Percent : {corrected_percent}%')