In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision.datasets as dsets
import torchvision.transforms as trans

from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset

In [2]:
# Global CUDA configuration for the whole notebook.
# GPU_INDEX selects which CUDA device becomes the current one; the default
# tensor type makes newly created floating-point tensors float64 on the GPU.
GPU_INDEX = 1
DEFAULT_TENSOR_TYPE = 'torch.cuda.DoubleTensor'

torch.cuda.set_device(GPU_INDEX)
torch.set_default_tensor_type(DEFAULT_TENSOR_TYPE)

In [3]:
# Both CIFAR-10 splits are read from the same local directory. download=False,
# so the data is expected to already exist under this root.
CIFAR10_ROOT = './CIFAR10/'

# Training split, with the ToTensor transform attached.
train_set = dsets.CIFAR10(root=CIFAR10_ROOT, train=True,
                          transform=trans.ToTensor(), download=False)

# Test split, configured the same way.
test_set = dsets.CIFAR10(root=CIFAR10_ROOT, train=False,
                         transform=trans.ToTensor(), download=False)

In [4]:
def select_classes(dataset, classes=(0, 1), limit=1000):
    """Collect the first `limit` samples of `dataset` whose label is in `classes`.

    The dataset is walked in order and iteration stops as soon as `limit`
    matching samples have been found, so the rest of the split is never
    decoded. Images are stacked directly as tensors instead of being
    round-tripped through nested Python lists.

    Args:
        dataset: iterable of (image_tensor, label) pairs.
        classes: labels to keep.
        limit: maximum number of samples to return.

    Returns:
        (inputs, targets): `inputs` stacks the kept images along a new first
        dimension, cast to the current default floating dtype; `targets` holds
        the matching labels as int64. Both are moved to the current CUDA device.

    Raises:
        ValueError: if no sample of the requested classes is found.
    """
    images, labels = [], []
    for image, label in dataset:
        if label in classes:
            images.append(image)
            labels.append(int(label))
            if len(images) == limit:
                break

    if not images:
        raise ValueError('no samples with labels {} found'.format(tuple(classes)))

    # torch.tensor(list_of_floats) used to yield the default float dtype;
    # the explicit cast keeps the resulting dtype the same.
    inputs = torch.stack(images).to(torch.get_default_dtype()).cuda()
    targets = torch.tensor(labels, dtype=torch.long).cuda()
    return inputs, targets


# Binary sub-problem: keep only labels 0 and 1, first 1000 matches per split.
# Note: `train_set` / `test_set` are rebound from the full CIFAR-10 splits to
# the filtered TensorDatasets because the following cells use these names.
A, B = select_classes(train_set, classes=(0, 1), limit=1000)
train_set = TensorDataset(A, B)

C, D = select_classes(test_set, classes=(0, 1), limit=1000)
test_set = TensorDataset(C, D)

In [5]:
# Mini-batch sizes for the two loaders.
TRAIN_BATCH_SIZE = 50
TEST_BATCH_SIZE = 100

# Shuffling is left at the DataLoader default; num_workers=0 is passed
# explicitly for both loaders.
train_dl = DataLoader(train_set, batch_size=TRAIN_BATCH_SIZE, num_workers=0)
test_dl = DataLoader(test_set, batch_size=TEST_BATCH_SIZE, num_workers=0)

In [6]:
class VGG(nn.Module):
    def __init__(self):
        super(VGG, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size = 3, padding = 1)
        self.pool1 = nn.AvgPool2d(2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size = 3, padding = 1)
        self.pool2 = nn.AvgPool2d(2)
        self.conv3 = nn.Conv2d(32, 64, kernel_size = 3, padding = 1)
        self.pool3 = nn.AvgPool2d(2)
        self.conv4 = nn.Conv2d(64, 64, kernel_size = 3, padding = 1)
        self.pool4 = nn.AvgPool2d(2)
        self.fc1 = nn.Linear(2*2*64, 128)
        self.fc2 = nn.Linear(128, 2)
        
    def forward(self, x):
        o = self.conv1(x)
        o = self.pool1(o)
        o = F.relu(o)
        
        o = self.conv2(o)
        o = self.pool2(o)
        o = F.relu(o)
        
        o = self.conv3(o)
        o = self.pool3(o)
        o = F.relu(o)
        
        o = self.conv4(o)
        o = self.pool4(o)
        o = F.relu(o)
        
        # Flat 
        o = o.view(x.size(0),-1)
        
        o = self.fc1(o)
        o = F.relu(o)
        
        o = self.fc2(o)
        return o

In [7]:
def train_epoch(model, optimizer, criterion, dataloader):
    """Run one optimisation pass over `dataloader` and return the mean batch loss.

    The model is explicitly switched to training mode first. `eval_accuracy`
    puts the model in eval mode, so without this call every epoch after the
    first evaluation would train in eval mode.

    Args:
        model: module being optimised.
        optimizer: optimizer built over `model`'s parameters.
        criterion: callable `criterion(output, target)` returning a scalar loss.
        dataloader: iterable of `(batch_x, batch_y)` pairs.

    Returns:
        float: unweighted mean of the per-batch loss values.

    Raises:
        ZeroDivisionError: if `dataloader` yields no batches.
    """
    model.train()
    loss_tot, nbatch = 0, 0
    for batch_x, batch_y in dataloader:
        optimizer.zero_grad()  # gradients accumulate by default; reset per batch
        o = model(batch_x)
        loss = criterion(o, batch_y)
        loss.backward()
        optimizer.step()
        loss_tot += loss.item()
        nbatch += 1
    return loss_tot / nbatch

In [8]:
def eval_accuracy(model, dataloader):
    """Return the top-1 accuracy of `model` over `dataloader`, in percent.

    Accuracy is computed as correct samples / total samples, so a smaller
    final batch is weighted correctly (averaging per-batch accuracies would
    over-weight it). The model is evaluated in eval mode with gradients
    disabled, and its previous train/eval mode is restored afterwards.

    Args:
        model: module mapping a batch of inputs to per-class scores along
            dimension 1.
        dataloader: iterable of `(batch_x, batch_y)` pairs; it does not need
            to support `len()`.

    Returns:
        float: percentage of samples whose highest-scoring class equals the label.

    Raises:
        ZeroDivisionError: if `dataloader` yields no samples.
    """
    was_training = model.training
    model.eval()
    n_correct, n_seen = 0, 0
    try:
        with torch.no_grad():
            for batch_x, batch_y in dataloader:
                batch_o = model(batch_x)
                # Index of the highest score per sample, kept as a column.
                pred = batch_o.max(1, keepdim=True)[1]
                n_correct += pred.eq(batch_y.view_as(pred)).sum().item()
                n_seen += pred.size(0)
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)
    return n_correct * 100.0 / n_seen

In [9]:
# Build the network, then move its parameters and buffers to the GPU.
net = VGG()
net = net.cuda()

In [10]:
# Cross-entropy over the raw class scores, averaged over the batch
# ('mean' is the library default, spelled out here for the reader).
criterion = nn.CrossEntropyLoss(reduction='mean')

In [11]:
# Adam over every parameter of the network, with the library's default
# hyper-parameters (none are overridden here).
optimizer = optim.Adam(params=net.parameters())

In [12]:
# Number of full passes over the training loader.
nepochs = 300
progress_fmt = '{:}/{:}, tr_loss {:.2e}, tr_acc {:.2f}'

for epoch in range(nepochs):
    # One optimisation pass, then score the model on the same training data.
    train_loss = train_epoch(net, optimizer, criterion, train_dl)
    train_acc = eval_accuracy(net, train_dl)
    print(progress_fmt.format(epoch + 1, nepochs, train_loss, train_acc))

1/300, tr_loss 6.84e-01, tr_acc 71.90
2/300, tr_loss 5.94e-01, tr_acc 73.30
3/300, tr_loss 5.92e-01, tr_acc 70.50
4/300, tr_loss 5.61e-01, tr_acc 73.90
5/300, tr_loss 5.37e-01, tr_acc 74.70
6/300, tr_loss 5.25e-01, tr_acc 75.70
7/300, tr_loss 5.06e-01, tr_acc 78.50
8/300, tr_loss 4.88e-01, tr_acc 79.40
9/300, tr_loss 4.65e-01, tr_acc 80.80
10/300, tr_loss 4.46e-01, tr_acc 82.30
11/300, tr_loss 4.16e-01, tr_acc 83.20
12/300, tr_loss 3.96e-01, tr_acc 83.80
13/300, tr_loss 3.79e-01, tr_acc 84.10
14/300, tr_loss 3.73e-01, tr_acc 85.80
15/300, tr_loss 3.65e-01, tr_acc 86.30
16/300, tr_loss 3.48e-01, tr_acc 86.60
17/300, tr_loss 3.39e-01, tr_acc 86.70
18/300, tr_loss 3.30e-01, tr_acc 87.90
19/300, tr_loss 3.24e-01, tr_acc 88.40
20/300, tr_loss 3.21e-01, tr_acc 87.90
21/300, tr_loss 3.11e-01, tr_acc 87.70
22/300, tr_loss 2.95e-01, tr_acc 89.20
23/300, tr_loss 2.83e-01, tr_acc 89.20
24/300, tr_loss 2.75e-01, tr_acc 89.70
25/300, tr_loss 2.66e-01, tr_acc 90.20
26/300, tr_loss 2.63e-01, tr_acc 8

In [13]:
# Accuracy (in percent) of the trained network on the test loader;
# the bare last expression is what the cell displays.
test_acc = eval_accuracy(net, test_dl)
test_acc

89.59999740123749