# In [None]:
import os
import re
import random
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms

from torch.utils.data import DataLoader

# Function to set random seed
def set_seed(seed=0):
    """Seed every random number generator this script relies on.

    Python's ``random`` module, NumPy's global RNG, PyTorch's CPU generator
    and all CUDA generators all receive the same ``seed``. cuDNN is then
    asked for deterministic kernels, and its benchmark-mode autotuning is
    turned off.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Function to prepare test data
def prepare_test_data(batch_size=64):
    """Return a non-shuffled DataLoader over the FashionMNIST test split.

    Each image is converted to a tensor and normalized with mean 0.5 and
    std 0.5 (a single channel). The split is downloaded into ``./data`` if
    it is not already there. For a closer-to-exact normalization, the
    dataset statistics ``Normalize((0.2860,), (0.3530,))`` could be used
    instead.
    """
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((0.5,), (0.5,))  # single-channel normalization
    pipeline = transforms.Compose([to_tensor, normalize])

    fashion_test = torchvision.datasets.FashionMNIST(
        root='./data',
        train=False,
        download=True,
        transform=pipeline,
    )
    # shuffle=False keeps samples in dataset order so labels line up with logits.
    return DataLoader(fashion_test, batch_size=batch_size, shuffle=False)

# Function to calculate accuracy
def _to_numpy(values):
    """Convert a torch tensor (on any device) or an array-like to a NumPy array."""
    if isinstance(values, torch.Tensor):
        return values.detach().cpu().numpy()
    return np.asarray(values)


def calculate_accuracy(logits, targets):
    """Return top-1 accuracy of ``logits`` against ``targets``, in percent.

    Args:
        logits: scores of shape (N, C). This can be a NumPy array, a torch
            tensor, or another array-like. The predicted class is the
            argmax over axis 1.
        targets: N integer class labels. This can be a torch tensor on any
            device, a NumPy array, or a list. A (N, 1) column is flattened
            to (N,).

    Returns:
        Accuracy as a Python float in [0, 100], rounded to 2 decimals.
        Returns 0.0 when there are no targets, instead of dividing by zero.

    Raises:
        ValueError: if the number of predictions differs from the number of
            targets.
    """
    pred = np.argmax(_to_numpy(logits), axis=1)
    # Flatten so a (N, 1) label column cannot broadcast against the (N,)
    # predictions into an (N, N) comparison.
    labels = _to_numpy(targets).reshape(-1)
    if pred.shape[0] != labels.shape[0]:
        raise ValueError(
            f"Got {pred.shape[0]} predictions but {labels.shape[0]} targets."
        )
    if labels.size == 0:
        return 0.0
    accuracy = 100.0 * np.count_nonzero(pred == labels) / labels.size
    return round(float(accuracy), 2)

# ----- Simple CNN definition -----
class CNN(torch.nn.Module):
    """Small two-convolution classifier that produces 10 logits per image.

    The spatial sizes follow from the layers, assuming 1x28x28 input:
      conv3x3(padding=1): 28x28 -> maxpool(2) -> 14x14
      conv3x3(no pad):    14x14 -> 12x12 -> maxpool(2) -> 6x6
    The flattened feature vector therefore has 64*6*6 entries. Other input
    sizes would not match the first Linear layer.
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn
        # Layers are built in the same order as before so that seeded
        # parameter initialization and state_dict indices stay the same.
        conv_stack = [
            nn.Conv2d(1, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),   # 28 -> 14
            nn.Conv2d(32, 64, 3),
            nn.ReLU(),
            nn.MaxPool2d(2),   # 14 -> 6
        ]
        self.features = nn.Sequential(*conv_stack)

        head = [
            nn.Linear(64 * 6 * 6, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return raw (unnormalized) class logits with shape (batch, 10)."""
        flat = self.features(x).view(x.size(0), -1)
        return self.classifier(flat)

# Main function to evaluate logits
def _build_train_loader(batch_size=64):
    """Return a shuffled DataLoader over the FashionMNIST training split.

    Preprocessing is ToTensor followed by Normalize((0.5,), (0.5,)). The
    dataset statistics (0.2860,), (0.3530,) are an alternative.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    train_dataset = torchvision.datasets.FashionMNIST(
        root='./data', train=True, download=True, transform=transform
    )
    return DataLoader(train_dataset, batch_size=batch_size, shuffle=True)


def _train_model(model, train_loader, device, epochs):
    """Train ``model`` in place with Adam (lr=1e-3) and cross-entropy loss.

    Prints the mean training loss after every epoch.
    """
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # Guard the per-epoch mean against an empty loader.
    num_batches = max(len(train_loader), 1)

    model.train()
    for epoch in range(epochs):
        running = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            running += loss.item()
        # Derive the total from `epochs` so the log always matches the loop.
        print(f"[Epoch {epoch+1}/{epochs}] loss: {running/num_batches:.4f}")


def _predict(model, loader, device):
    """Run ``model`` over ``loader`` in eval mode.

    Returns:
        (logits, targets): raw logits as a NumPy array of shape (N, C), with
        no softmax applied, and the labels as a CPU tensor of shape (N,).
        Both are in the loader's iteration order. Labels are gathered in
        this same pass, so the loader is not iterated a second time.
    """
    model.eval()
    logit_batches, target_batches = [], []
    with torch.no_grad():
        for images, labels in loader:
            logit_batches.append(model(images.to(device)).cpu().numpy())
            target_batches.append(labels.cpu())
    return np.concatenate(logit_batches, axis=0), torch.cat(target_batches)


def _score_logits_file(logits_filename, targets):
    """Validate and grade a saved logits file.

    The file name must match ``rYYYYMMDD.npy`` or ``LimJunsung[digits].npy``,
    the file must exist, and it must hold an array of shape
    ``(len(targets), 10)``. Any failed check is printed as
    ``[AssertionError] ...`` and scores 0.0. The checks raise
    AssertionError explicitly rather than using ``assert``, so they still
    run under ``python -O``.

    Returns:
        Accuracy in percent (float), or 0.0 if validation fails.
    """
    num_samples = len(targets)
    try:
        if not re.match(r"^(r\d{8}|LimJunsung\d*)\.npy$", logits_filename):
            raise AssertionError(
                "File name must be 'rYYYYMMDD.npy' or 'LimJunsung[digits].npy'."
            )
        if not os.path.isfile(logits_filename):
            raise AssertionError(f"Logits file not found: {logits_filename}")

        logits = np.load(logits_filename)
        if logits.shape != (num_samples, 10):
            raise AssertionError(
                f"Logits shape mismatch: expected ({num_samples}, 10)."
            )
    except AssertionError as e:
        print(f"[AssertionError] {e}")
        return 0.0

    accuracy = calculate_accuracy(logits, targets)
    print(f'{logits_filename[:-4]} - Accuracy: {accuracy:.2f}%')
    return accuracy


def evaluate(logits_filename, *, epochs=3, save_path="LimJunsung.npy"):
    """Train a CNN on FashionMNIST, save its test logits, then grade a logits file.

    Steps:
      1. Seed all RNGs and pick CUDA if it is available, otherwise CPU.
      2. Train a fresh ``CNN`` for ``epochs`` epochs on the training split.
      3. Save its raw test-set logits to ``save_path`` with ``np.save``.
      4. Validate ``logits_filename`` and report its accuracy on the test set.

    The freshly generated logits are the ones graded only when
    ``logits_filename`` equals ``save_path``.

    Args:
        logits_filename: name of the ``.npy`` file to grade.
        epochs: number of training epochs (keyword-only; default 3).
        save_path: where the generated logits are written (keyword-only).

    Returns:
        Accuracy in percent, or 0.0 if the file fails validation.
    """
    set_seed()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f'Using device: {device}')

    test_loader = prepare_test_data()
    print(f'Test dataset size: {len(test_loader.dataset)}')

    # Same construction order as before (loaders, then model) so seeded
    # initialization and shuffling are reproduced exactly.
    train_loader = _build_train_loader()
    model = CNN().to(device)
    _train_model(model, train_loader, device, epochs)

    logits_generated, targets = _predict(model, test_loader, device)
    np.save(save_path, logits_generated)

    return _score_logits_file(logits_filename, targets)

if __name__ == "__main__":
    # Train the model, export its logits, and grade this student's own file.
    evaluate(logits_filename="LimJunsung.npy")
