In [69]:
import torch
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.utils.data as data
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [70]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = dset.CIFAR10(root='./data', train=True, download=False, transform=transform)
trainloader = data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
testset = dset.CIFAR10(root='./data', train=False, download=False, transform=transform)
testloader=data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck' )

In [71]:
# showing some of the training images 
import matplotlib.pyplot as plt
import numpy as np

In [72]:
# functions to show an image
def imshow(img):
    img = img / 2 + 0.5 # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    
# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()

# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

  car plane  deer  ship


In [73]:
# Defining a CNN
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5) # flattens the tensor
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
net = Net()
        

In [74]:
# Loss functions and optimizers
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr = 0.001, momentum = 0.9)

In [75]:
# train the network
for epoch in range(2): #looping over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data
        
        # wrap them in Variable
        inputs, labels = Variable(inputs), Variable(labels)
            
        # zero the parameter gradients
        optimizer.zero_grad()
        
        # fwd + bwd + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        # print stastistics
        running_loss += loss.data[0]
        if i % 2000 == 1999:  # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0
print('Finished training')

[1,  2000] loss: 2.193
[1,  4000] loss: 1.853
[1,  6000] loss: 1.677
[1,  8000] loss: 1.580
[1, 10000] loss: 1.517
[1, 12000] loss: 1.472
[2,  2000] loss: 1.415
[2,  4000] loss: 1.368
[2,  6000] loss: 1.347
[2,  8000] loss: 1.330
[2, 10000] loss: 1.307
[2, 12000] loss: 1.262
Finished training


In [76]:
# Testing the network on the test data to check whether the network has learnt anything
dataiter = iter(testloader)
images, labels = dataiter.next()

# print images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth:', ' '.join('%5s' % classes[labels[j]] for j in range(4)))

outputs = net(Variable(images))

_, predicted = torch.max(outputs.data, 1)

print('Predicted:', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))
 

GroundTruth:   cat  ship  ship plane
Predicted:   cat  ship  ship plane


In [77]:
# looking at how the network performs on the whole dataset
correct = 0
total = 0
for data in testloader:
    images, labels = data
    ouputs = net(Variable(images))
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()
    print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

Accuracy of the network on the 10000 test images: 100 %
Accuracy of the network on the 10000 test images: 50 %
Accuracy of the network on the 10000 test images: 41 %
Accuracy of the network on the 10000 test images: 31 %
Accuracy of the network on the 10000 test images: 30 %
Accuracy of the network on the 10000 test images: 25 %
Accuracy of the network on the 10000 test images: 25 %
Accuracy of the network on the 10000 test images: 21 %
Accuracy of the network on the 10000 test images: 19 %
Accuracy of the network on the 10000 test images: 17 %
Accuracy of the network on the 10000 test images: 15 %
Accuracy of the network on the 10000 test images: 14 %
Accuracy of the network on the 10000 test images: 13 %
Accuracy of the network on the 10000 test images: 14 %
Accuracy of the network on the 10000 test images: 13 %
Accuracy of the network on the 10000 test images: 12 %
Accuracy of the network on the 10000 test images: 11 %
Accuracy of the network on the 10000 test images: 12 %
Accuracy 

Accuracy of the network on the 10000 test images: 10 %
Accuracy of the network on the 10000 test images: 10 %
Accuracy of the network on the 10000 test images: 10 %
Accuracy of the network on the 10000 test images: 10 %
Accuracy of the network on the 10000 test images: 10 %
Accuracy of the network on the 10000 test images: 10 %
Accuracy of the network on the 10000 test images: 10 %
Accuracy of the network on the 10000 test images: 10 %
Accuracy of the network on the 10000 test images: 10 %
Accuracy of the network on the 10000 test images: 10 %
Accuracy of the network on the 10000 test images: 10 %
Accuracy of the network on the 10000 test images: 10 %
Accuracy of the network on the 10000 test images: 10 %
Accuracy of the network on the 10000 test images: 10 %
Accuracy of the network on the 10000 test images: 10 %
Accuracy of the network on the 10000 test images: 10 %
Accuracy of the network on the 10000 test images: 10 %
Accuracy of the network on the 10000 test images: 10 %
Accuracy o

Accuracy of the network on the 10000 test images: 10 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on 

Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on t

Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on t

Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on t

Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on t

Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on t

Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on t

Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on t

Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on t

Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on t

Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on t

Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on t

Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %
Accuracy of the network on the 10000 test images: 9 %


In [78]:
# classes that performed well and classes that performed poorly
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
for data in testloader:
    images, labels = data
    ouputs = net(Variable(images))
    _, predicted = torch.max(outputs.data, 1)
    c = (predicted == labels).squeeze()
    for i in range(4):
        label = labels[i]
        class_correct[label] += c[i]
        class_total[label] += 1
for i in range(10):
    print('accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))

accuracy of plane : 24 %
accuracy of   car :  0 %
accuracy of  bird :  0 %
accuracy of   cat : 24 %
accuracy of  deer :  0 %
accuracy of   dog :  0 %
accuracy of  frog :  0 %
accuracy of horse :  0 %
accuracy of  ship : 50 %
accuracy of truck :  0 %
