In [1]:
import torch
import os
from torch import nn
import torch.nn.functional as F
import pickle
from torchsummary import summary
import numpy as np
import torchvision.transforms as transforms
import torch.nn as nn
from torch.utils.data import Dataset
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader, Subset
from torch.optim.lr_scheduler import MultiStepLR
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
from tqdm import tqdm
from model.res3net import Res3Net, BasicBlock
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.cuda.amp import GradScaler, autocast

In [2]:
def load_cifar_batch(file):
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

class CustomCIFAR10Dataset(Dataset):
    def __init__(self, images, labels, transform):
        self.images = images
        self.labels = torch.tensor(labels, dtype=torch.long)  # 保持为numpy数组
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # img = Image.fromarray(self.images[idx])  # 直接转PIL图像（更高效）
        img = self.transform(self.images[idx])  # 应用transform
        label = self.labels[idx]
        return img, label

In [3]:
# Load test dataset
cifar_test_path = 'deep-learning-spring-2025-project-1/cifar_test_nolabel.pkl'
test_batch = load_cifar_batch(cifar_test_path)
test_images = test_batch[b'data'].astype(np.float32) / 255.0
test_idx = test_batch[b'ids']
test_transform = transforms.Compose([
    # transforms.ToPILImage(),  # Convert numpy array to PIL Image
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
])
# Convert test dataset to Tensor
test_dataset = [(test_transform(img),) for img in test_images]
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Importing Model and printing Summary
model = Res3Net(BasicBlock).to(device)
# summary(model, input_size=(3,32,32))
param = torch.load('res3net_200/model_res3net_200.pth')
model.load_state_dict(param)

<All keys matched successfully>

In [5]:
model.eval()
result = []
pbar = tqdm(test_dataset)
for idx, img in enumerate(pbar):
    img = img[0].to(device).unsqueeze(0)
    output = model(img)
    pred = output.argmax(dim=1, keepdim=True)
    result.append([test_idx[idx], int(pred)])

100%|██████████| 10000/10000 [00:18<00:00, 527.59it/s]


In [6]:
def make_output(result):
    out = ['ID,Labels']
    for i in range(len(result)):
        print(result[i])
        cur = str(result[i][0]) + str(',') + str(int(result[i][1]))
        out.append(cur)
    return '\n'.join(out)
with open("output.csv", "w", encoding="utf-8") as f:
    f.write(make_output(result))

[0, 6]
[1, 1]
[2, 8]
[3, 6]
[4, 9]
[5, 3]
[6, 0]
[7, 2]
[8, 9]
[9, 5]
[10, 2]
[11, 0]
[12, 8]
[13, 8]
[14, 7]
[15, 0]
[16, 5]
[17, 8]
[18, 5]
[19, 8]
[20, 9]
[21, 7]
[22, 7]
[23, 0]
[24, 6]
[25, 7]
[26, 4]
[27, 1]
[28, 9]
[29, 3]
[30, 1]
[31, 3]
[32, 3]
[33, 3]
[34, 3]
[35, 7]
[36, 2]
[37, 4]
[38, 7]
[39, 4]
[40, 1]
[41, 5]
[42, 8]
[43, 9]
[44, 9]
[45, 9]
[46, 1]
[47, 7]
[48, 3]
[49, 3]
[50, 4]
[51, 6]
[52, 1]
[53, 1]
[54, 2]
[55, 8]
[56, 1]
[57, 7]
[58, 9]
[59, 3]
[60, 4]
[61, 2]
[62, 8]
[63, 1]
[64, 6]
[65, 3]
[66, 6]
[67, 0]
[68, 5]
[69, 9]
[70, 0]
[71, 6]
[72, 2]
[73, 5]
[74, 3]
[75, 0]
[76, 0]
[77, 9]
[78, 8]
[79, 2]
[80, 2]
[81, 3]
[82, 8]
[83, 4]
[84, 1]
[85, 3]
[86, 7]
[87, 6]
[88, 9]
[89, 3]
[90, 0]
[91, 5]
[92, 8]
[93, 5]
[94, 1]
[95, 6]
[96, 7]
[97, 8]
[98, 1]
[99, 5]
[100, 5]
[101, 5]
[102, 4]
[103, 2]
[104, 1]
[105, 0]
[106, 9]
[107, 9]
[108, 7]
[109, 9]
[110, 7]
[111, 3]
[112, 4]
[113, 8]
[114, 8]
[115, 4]
[116, 4]
[117, 5]
[118, 7]
[119, 1]
[120, 8]
[121, 6]
[122, 2]
[12