In [None]:
import os
import random
import pickle
import numpy as np
from matplotlib import pyplot as plt

import torch
from torchvision import transforms
from torch.utils.data import DataLoader, TensorDataset

from utils import Cifar10
from model import ResNet50

# Local checkpoint file handed to ResNet50 through its `pretrained` argument
# (presumably ImageNet-pretrained ResNet-50 weights — confirm against model.py).
pretrained_path = '../model/ResNet/resnet50-19c8e357.pth'

In [None]:
# Build the 10-class ResNet-50 from the local checkpoint, then call the
# project-defined freeze() hook (presumably it fixes the backbone weights so
# that only the classifier head is trained — confirm in model.py).
model = ResNet50(num_classes=10, pretrained=pretrained_path)
model.freeze()

In [None]:
# Read the CIFAR-10 train/test arrays through the project's Cifar10 helper.
path = "../data/cifar-10-batches-py/"
dataset = Cifar10(path)
train_images, train_labels = dataset.get_train()
test_images, test_labels = dataset.get_test()

# Cell output: the shapes of the four arrays, in the order they were loaded.
tuple(
    array.shape
    for array in (train_images, train_labels, test_images, test_labels)
)

In [None]:
# Number of training images drawn from each class for the fine-tuning subset.
sample_number = 1024

In [None]:
# Indices (into the training set) of the images that belong to each class label.
cat = {
    label: np.where(train_labels == label)[0].tolist()
    for label in range(10)
}

# Draw `sample_number` distinct training indices per class.
# A dedicated, seeded generator keeps the subset identical on every
# "Restart & Run All" (and on every re-run of this cell) without touching the
# global `random` state. random.sample raises ValueError if a class holds
# fewer than `sample_number` images.
subset_rng = random.Random(0)
sample = {
    label: subset_rng.sample(cat[label], sample_number)
    for label in range(10)
}

In [None]:
# Preprocessing applied to image tensors before they reach the network:
# resize so the shorter side is 256 px, keep the central 224x224 crop, then
# standardize each of the three channels.
# NOTE(review): these mean/std values look like the usual ImageNet statistics
# for inputs scaled to [0, 1] — confirm Cifar10 returns images in that range.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]
preprocess = transforms.Compose([
    transforms.Resize(size=256, antialias=True),
    transforms.CenterCrop(size=224),
    transforms.Normalize(mean=channel_mean, std=channel_std),
])

In [None]:
# Assemble the class-balanced training subset in label order (all sampled
# images of class 0 first, then class 1, ...).
# One fancy-indexing gather replaces growing an array with np.vstack inside a
# loop, which re-copied everything accumulated so far on each iteration and
# hard-coded the (3, 32, 32) image shape. X_train_new now keeps the dtype of
# train_images instead of being promoted to float64 by the empty seed array;
# the tensors below are cast to float32 either way, so their values are
# unchanged.
subset_indices = [index for label in range(10) for index in sample[label]]
X_train_new = train_images[subset_indices]
y_train_new = train_labels[subset_indices]

# NOTE(review): preprocess() runs once over each whole image tensor here, so
# the resized/cropped copies of every train and test image are held in memory
# up front — confirm this fits the machine, or apply the transform per sample.
train_dataset = TensorDataset(
    preprocess(torch.from_numpy(X_train_new).float()),
    torch.from_numpy(y_train_new).long()
)
test_dataset = TensorDataset(
    preprocess(torch.from_numpy(test_images).float()),
    torch.from_numpy(test_labels).long()
)

In [None]:
# Mini-batch loaders. The training subset is reshuffled every epoch and read
# by four worker processes; the test set is iterated in its stored order.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=False)

In [None]:
# Optimisation settings: number of passes over the training subset and the
# Adam learning rate, followed by the loss and the optimizer built from them.
epochs, lr = 10, 1e-3
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

In [None]:
# Fine-tune for `epochs` epochs; after each epoch report the accuracy over the
# whole test loader.
for epoch in range(1, epochs+1):
    # NOTE(review): model.train() puts every submodule into training mode; if
    # freeze() is meant to keep the backbone's BatchNorm statistics fixed,
    # confirm that the model handles this.
    model.train()
    for i, (images, labels) in enumerate(train_dataloader):
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Log the loss of the current mini-batch every 100 iterations.
        if i % 100 == 0:
            print(f"Epoch {epoch}, Iter {i+1}, Loss {loss.item():.4f}")

    model.eval()
    correct = 0
    total = 0
    # Evaluation never calls backward(), so disable gradient tracking: without
    # it each forward pass records autograd state for the trainable layers and
    # keeps their activations alive for nothing.
    with torch.no_grad():
        for images, labels in test_dataloader:
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            # Accumulate a plain Python int so `correct` has one type whether
            # or not the loader yields any batch.
            correct += (predicted == labels).sum().item()
    print(f"Epoch {epoch}, Test Accuracy {correct/total:.4f}")

In [None]:
# Quick sanity check: print how many predictions are correct on the first
# test batch only (the loop stops after one iteration).
model.eval()
with torch.no_grad():
    for images, labels in test_dataloader:
        outputs = model(images)
        predicted = torch.max(outputs.data, 1)[1]
        hits = predicted == labels
        print(sum(hits))
        break