In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torchvision import models
from sklearn.metrics import accuracy_score, confusion_matrix
%matplotlib inline

In [2]:
# Batch size for both loaders. With 10000 the whole CIFAR-10 test split
# (10 classes x 1000 images, see the confusion matrix below) is a single batch.
N = 10000

# ToTensor scales pixel values to [0, 1]; Normalize with mean = std = 0.5
# then maps each RGB channel to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform,
)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform,
)

# Only the training data is shuffled; the test order stays fixed.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=N, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=N, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Short names for the ten CIFAR-10 classes, in label-index order
# (0 = plane ... 9 = truck).
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)
# The same ten label indices spelled as digit strings '0' .. '9'.
classes2 = tuple(str(label) for label in range(10))

In [4]:
class CNN(torch.nn.Module):
    """Small four-stage convolutional classifier for 3-channel images.

    Each stage is a 3x3 same-padded convolution, a ReLU and a 2x2 max pool.
    The input's spatial size must shrink to a 64-channel 2x2 map after the
    four poolings (e.g. 32x32 CIFAR-10 images: 32 -> 16 -> 8 -> 4 -> 2).
    Two fully connected layers follow.

    Args:
        num_output_classes: Number of logits produced per image
            (default 10, matching CIFAR-10).
    """

    def __init__(self, num_output_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # 64 channels x 2 x 2 spatial positions remain after the four poolings.
        self.fc1 = nn.Linear(64 * 2 * 2, 256)
        # Output width follows num_output_classes, so the constructor argument
        # actually takes effect (the default of 10 keeps the checkpoint layout).
        self.fc2 = nn.Linear(256, num_output_classes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return unnormalised class scores of shape (batch, num_output_classes).

        x is expected as (batch, 3, H, W) with H, W reducing to 2 after four
        2x2 poolings. Other sizes raise in fc1 instead of being silently
        reshaped into a wrong batch size.
        """
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.pool(self.relu(conv(x)))
        # flatten(x, 1) keeps the batch dimension intact. The old
        # view(-1, 256) folded any extra spatial positions into bogus samples.
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)

In [5]:
# Fresh, randomly initialised network with ten output classes; its weights are
# overwritten by the saved checkpoint below.
net = CNN(num_output_classes=10)

In [6]:
# Load the saved torch parameters (state_dict) into the model.
# map_location='cpu' remaps tensors that were saved from a GPU, so the
# checkpoint also loads on CPU-only machines. `net` is never moved off the CPU
# in these cells.
# NOTE(review): torch.load unpickles the file and can execute arbitrary code.
# Only load trusted checkpoints; on torch >= 1.13 consider weights_only=True.
load_weights = torch.load('model72.pt', map_location='cpu')
net.load_state_dict(load_weights)

<All keys matched successfully>

In [8]:
# Evaluate the restored model on the test split.
# eval() puts layers such as dropout/batch-norm into inference mode. This CNN
# has none, but the call keeps the cell correct if the architecture changes.
# no_grad() stops autograd from keeping every intermediate activation, which
# is a lot of memory with batches of N images.
net.eval()
ans = []   # true labels, in test-loader order
pred = []  # predicted labels (argmax over the logits)
with torch.no_grad():
    for inputs, labels in testloader:
        outputs = net(inputs)
        ans += labels.tolist()
        pred += torch.argmax(outputs, 1).tolist()

print('accuracy:', accuracy_score(ans, pred))
print('confusion matrix:')
print(confusion_matrix(ans, pred))

accuracy: 0.797
confusion matrix:
[[838  12  26  11   9   5   7   6  55  31]
 [  8 898   2   1   1   2   6   0  19  63]
 [ 68   3 650  36  75  61  67  19  10  11]
 [ 22   6  56 571  62 152  75  31  10  15]
 [ 16   1  27  35 814  12  36  48   9   2]
 [ 11   1  37 138  42 709  18  37   2   5]
 [  6   3  20  36  21  17 885   3   6   3]
 [ 16   1  14  32  43  46   4 835   1   8]
 [ 40  21   9   6   5   1   4   5 888  21]
 [ 22  55   1   5   1   3   6   7  18 882]]
