In [1]:
import os
import random
import pickle
import numpy as np
from matplotlib import pyplot as plt

import torch
from torchvision import transforms

from utils import Cifar10
from model import ViT

# Vision Transformer (ViT) on a CIFAR-10 subset, using Hugging Face Transformers

In [2]:
# Build the project's ViT wrapper from the local model directory, then freeze it
# (presumably disables gradient updates on the backbone — confirm in model.ViT.freeze).
model = ViT(path="../model/ViT/")
model.freeze()

In [3]:
# Directory holding the CIFAR-10 python batches, loaded through the project helper.
path = "../data/cifar-10-batches-py/"
dataset = Cifar10(path)

In [4]:
# Full train/test splits from the project loader.
(train_images, train_labels) = dataset.get_train()
(test_images, test_labels) = dataset.get_test()
# Display the four shapes: images are (N, 3, 32, 32), labels are (N,).
(train_images.shape, train_labels.shape,
 test_images.shape, test_labels.shape)

((50000, 3, 32, 32), (50000,), (10000, 3, 32, 32), (10000,))

In [5]:
sample_number = 128  # training images drawn per class (10 classes -> 1280-image subset)

In [6]:
# Seeded generator for the per-class subsampling below. Without it, every kernel
# restart drew a different subset, so downstream results could not be reproduced.
# A local Random instance keeps the cell idempotent and leaves the global
# `random` state untouched.
SAMPLE_SEED = 0
sample_rng = random.Random(SAMPLE_SEED)

# For each class label 0-9: the indices of all training images with that label.
cat = {
    label: np.flatnonzero(train_labels == label).tolist()
    for label in range(10)
}

# For each class label 0-9: `sample_number` of those indices, drawn at random
# without replacement.
sample = {
    label: sample_rng.sample(cat[label], sample_number)
    for label in range(10)
}

In [7]:
# Build the balanced training subset. Take the sampled indices of class 0, then
# class 1, ..., then class 9, so the subset stays grouped by label. Then index
# the training arrays once, instead of growing an array with np.vstack in a loop,
# which re-copies all previous rows on every iteration.
subset_idx = np.array(
    [idx for label in range(10) for idx in sample[label]],
    dtype=np.intp,  # explicit integer dtype so an empty sample still indexes correctly
)
# Cast to float64 explicitly. This is the dtype the subset has always had, since
# the old code started from np.empty, which defaults to float64. The later
# torchvision Normalize step requires a floating-point tensor.
X_train_new = train_images[subset_idx].astype(np.float64)
y_train_new = train_labels[subset_idx]
X_train_new.shape, y_train_new.shape

((1280, 3, 32, 32), (1280,))

In [8]:
# Preprocessing for the ViT:
#   1. Resize the 32x32 CIFAR images up to 224x224, with antialiasing.
#   2. Normalize each channel as x -> (x - 0.5) / 0.5.
# NOTE(review): step 2 maps [0, 1] to [-1, 1], so it assumes the pixels are
# already scaled to [0, 1]. If utils.Cifar10 returns raw 0-255 values, a
# division by 255 is missing here. Confirm against the loader.
transform = transforms.Compose([
    transforms.Resize((224, 224), antialias=True),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

In [9]:
# Convert the sampled NumPy batch to a tensor and cast it to float32 before
# resizing. The batch built above is float64, and a 1280x3x224x224 float64 tensor
# takes ~1.5 GB after Resize, versus ~0.77 GB in float32. float32 is also
# presumably the dtype of the model's weights (Hugging Face default) — confirm in
# model.py.
# NOTE: this rebinds X_train_new from an ndarray to a tensor. Re-running this
# cell on its own therefore raises TypeError in torch.from_numpy; re-run from
# the subset-building cell instead.
X_train_new = transform(torch.from_numpy(X_train_new).float())
X_train_new.shape

torch.Size([1280, 3, 224, 224])