In [1]:
import numpy as np
import sys
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.utils import to_categorical

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

In [2]:
MAX_TRAIN, MAX_TEST, MAX_VAL = 4000, 1000, 1000
MAX_EPOCHS = 20

In [3]:
def roll_my_axis(x):
    return np.rollaxis(x, 3, 1)
def reshape_label(y):
    return y.reshape(y.size)

In [4]:
cifar = tf.keras.datasets.cifar10 
(x_train_npy_o, y_train_npy_o), (x_test_npy, y_test_npy) = cifar.load_data()

x_train_npy, y_train_npy = roll_my_axis(x_train_npy_o[:MAX_TRAIN]), reshape_label(y_train_npy_o[:MAX_TRAIN])
x_test_npy, y_test_npy = roll_my_axis(x_test_npy[:MAX_TEST]), reshape_label(y_test_npy[:MAX_TEST])
val_idxs = range(int(1.5*MAX_TRAIN), int(1.5*MAX_TRAIN)+MAX_VAL)
x_val_npy, y_val_npy = roll_my_axis(x_train_npy_o[val_idxs]), reshape_label(y_train_npy_o[val_idxs])

x_train, x_test = torch.Tensor(x_train_npy), torch.Tensor(x_test_npy)
y_train, y_test = torch.Tensor(y_train_npy), torch.Tensor(y_test_npy)
x_val, y_val = torch.Tensor(x_val_npy), torch.Tensor(y_val_npy)

trainloader = DataLoader(TensorDataset(x_train.type('torch.FloatTensor'), y_train.type('torch.LongTensor')), 
                         batch_size = 128, shuffle=True, num_workers=4)
testloader = DataLoader(TensorDataset(x_test.type('torch.FloatTensor'), y_test.type('torch.LongTensor')), 
                        batch_size = 128, shuffle=True, num_workers=4)
valloader = DataLoader(TensorDataset(x_val.type('torch.FloatTensor'), y_val.type('torch.LongTensor')), 
                        batch_size = 128, shuffle=True, num_workers=4)

In [5]:
class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 48, 3) 
        self.pool1 = nn.MaxPool2d(2, 2)
        
        self.conv2 = nn.Conv2d(48, 96, 3)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.bn2 = nn.BatchNorm2d(96)
        
        self.conv22 = nn.Conv2d(96, 192, 2)
        self.bn22 = nn.BatchNorm2d(192)
        
        self.conv3 = nn.Conv2d(192, 192, 2)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.bn3 = nn.BatchNorm2d(192)
        
        self.conv4 = nn.Conv2d(192, 256, 2)
        self.pool4 = nn.MaxPool2d(2, 2)
        self.bn4 = nn.BatchNorm2d(256)
        
        self.fc1 = nn.Linear(1024, 512)
        self.dropout1 = nn.Dropout(.4)
        
        self.fc2 = nn.Linear(512, 256)
        self.dropout2 = nn.Dropout(.4)
        
        self.fc3 = nn.Linear(256, 10)
    def forward(self, x):
        out = (F.selu(self.conv1(x)))
        out = self.bn2(self.pool2(F.selu(self.conv2(out))))
        out = self.bn22(F.selu(self.conv22(out)))
        out = self.bn3(self.pool3(F.selu(self.conv3(out))))
        out = self.bn4(self.pool4(F.selu(self.conv4(out))))
        out = out.view(-1, 256*2*2)
        out = self.dropout1(F.selu(self.fc1(out)))
        out = self.dropout2(F.selu(self.fc2(out)))
        out = self.fc3(out)
        return out

In [6]:
def compute_validation_accuracy(model, dataLoader):
    correct, total = 0, 0
    with torch.no_grad():
        for data in dataLoader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct/total          

In [9]:
model = AlexNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

for epoch in range(MAX_EPOCHS):
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        
        # zero the parameter gradients
        optimizer.zero_grad()

        # Forward + Backward + Optimize 
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
    
    print ("Accuracy at the end of Epoch %d: %.3f" % (epoch+1, compute_validation_accuracy(model, valloader)))

Accuracy at the end of Epoch 1: 0.168
Accuracy at the end of Epoch 2: 0.216
Accuracy at the end of Epoch 3: 0.224
Accuracy at the end of Epoch 4: 0.246
Accuracy at the end of Epoch 5: 0.253
Accuracy at the end of Epoch 6: 0.257
Accuracy at the end of Epoch 7: 0.282
Accuracy at the end of Epoch 8: 0.279
Accuracy at the end of Epoch 9: 0.288
Accuracy at the end of Epoch 10: 0.327
Accuracy at the end of Epoch 11: 0.318
Accuracy at the end of Epoch 12: 0.315
Accuracy at the end of Epoch 13: 0.319
Accuracy at the end of Epoch 14: 0.343
Accuracy at the end of Epoch 15: 0.355
Accuracy at the end of Epoch 16: 0.361
Accuracy at the end of Epoch 17: 0.361
Accuracy at the end of Epoch 18: 0.363
Accuracy at the end of Epoch 19: 0.373
Accuracy at the end of Epoch 20: 0.354


In [10]:
print ("Test Accuracy = %.3f" % compute_validation_accuracy(model, testloader))

Test Accuracy = 0.377
