In [1]:
import scipy.io as sio
from extract_features import KNN_test
from torchvision import transforms
import torchvision
from torch.utils.data import DataLoader



In [2]:
!python extract_features.py

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
100%|████████████████████████| 170498071/170498071 [00:17<00:00, 9607681.31it/s]
Extracting ./data/cifar-10-python.tar.gz to ./data
iter: 1, F_vgg size: torch.Size([500, 4096])
iter: 1, F_alex size: torch.Size([500, 256])
iter: 1, len alex_feature: 1, vgg_feature: 1
iter: 50, F_vgg size: torch.Size([500, 4096])
iter: 50, F_alex size: torch.Size([500, 256])
iter: 50, len alex_feature: 50, vgg_feature: 50
iter: 100, F_vgg size: torch.Size([500, 4096])
iter: 100, F_alex size: torch.Size([500, 256])
iter: 100, len alex_feature: 100, vgg_feature: 100


In [3]:
# AlexNet feature and label shapes
alexnet_mat = sio.loadmat('alexnet.mat')
print('AlexNet Features Size: {}\nAlexNet Label Size: {}'.format(alexnet_mat['feature'].shape, alexnet_mat['label'].shape))

AlexNet Features Size: (50000, 256)
AlexNet Label Size: (1, 50000)


In [4]:
# VGG16 feature and label shapes
vgg16_mat = sio.loadmat('vgg16.mat')
print('VGG16 Features Size: {}\nVGG16 Label Size: {}'.format(vgg16_mat['feature'].shape, vgg16_mat['label'].shape))

VGG16 Features Size: (50000, 4096)
VGG16 Label Size: (1, 50000)


In [2]:
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

test_data = torchvision.datasets.CIFAR10(root='./data', train=False,
                                          download=True, transform=transform)

batch_size = 500
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

vgg16_acc, alexnet_acc = KNN_test('.', test_loader)
print('VGG16 Accuracy: {}\nAlexNet Accuracy: {}'.format(vgg16_acc, alexnet_acc))

Files already downloaded and verified
iter: 1, total evaluated sample: 500, curr vgg acc: 0.474, curr alex acc: 0.296
iter: 2, total evaluated sample: 1000, curr vgg acc: 0.484, curr alex acc: 0.309
iter: 3, total evaluated sample: 1500, curr vgg acc: 0.49933333333333335, curr alex acc: 0.312
iter: 4, total evaluated sample: 2000, curr vgg acc: 0.5, curr alex acc: 0.3155
iter: 5, total evaluated sample: 2500, curr vgg acc: 0.5, curr alex acc: 0.318
iter: 6, total evaluated sample: 3000, curr vgg acc: 0.5093333333333333, curr alex acc: 0.31866666666666665
iter: 7, total evaluated sample: 3500, curr vgg acc: 0.5145714285714286, curr alex acc: 0.318
iter: 8, total evaluated sample: 4000, curr vgg acc: 0.514, curr alex acc: 0.32
iter: 9, total evaluated sample: 4500, curr vgg acc: 0.516, curr alex acc: 0.324
iter: 10, total evaluated sample: 5000, curr vgg acc: 0.5144, curr alex acc: 0.325
iter: 11, total evaluated sample: 5500, curr vgg acc: 0.51, curr alex acc: 0.32236363636363635
iter: 