In [1]:
from utils import count_params
import preprocessing as preprocess
import torch
import numpy as np
from torch.utils import data
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
import pickle
import tensorflow as tf
import time

if torch.cuda.is_available():  
    DEVICE = "cuda:0" 
else:  
    DEVICE = "cpu" 
print(DEVICE)

cuda:0


In [2]:
writer = SummaryWriter(comment = '_all-cnn')

### Load & Preprocessing  Data

In [3]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data(label_mode='fine')
# x: N x 32 x 32 x 3
# y: N x 1

In [4]:
n_train = x_train.shape[0]
x_train = np.moveaxis(x_train, -1, 1).reshape(n_train, -1)
n_test = x_test.shape[0]
x_test = np.moveaxis(x_test, -1, 1).reshape(n_test, -1)

## Expects x_train, x_test in N x 3072, which reshapes to N x 3 x 32 x 32
# Returns: x_train, x_test in N x 3 x 32 x 32
X_train, X_test = preprocess.cifar_10_preprocess(x_train, x_test)

Pre-processing data


In [5]:
X_train = torch.tensor(X_train).float()
X_test = torch.tensor(X_test).float()
y_train = torch.tensor(y_train).squeeze().long()
y_test = torch.tensor(y_test).squeeze().long()

In [6]:
class Dataset(data.Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels
        
    def __len__(self):
        return len(self.labels)
    
    # Get one sample
    def __getitem__(self, index):
        return self.data[index], self.labels[index]

In [7]:
# Put into dataloader
trainset = Dataset(X_train, y_train)
trainloader = data.DataLoader(trainset, batch_size=500, shuffle=True, num_workers=2)
testset = Dataset(X_test, y_test)
testloader = data.DataLoader(testset, batch_size=500, shuffle=False, num_workers=2)

### Model

In [2]:
class Flatten(nn.Module):
    """
    Implement a simple custom module that reshapes (n, m, 1, 1) tensors to (n, m).
    """
    def forward(self, input):
        size = input.size()
        return input.view(size[0], size[1])

# CNN Model (2 conv layer)
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Conv2d(3, 96, kernel_size=3, padding=1),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.Conv2d(96, 96, kernel_size=3, padding=1),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.Conv2d(96, 96, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(96),
            nn.ReLU())
        self.layer2 = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Conv2d(96, 192, kernel_size=3, padding=1),
            nn.BatchNorm2d(192),
            nn.ReLU(),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),
            nn.BatchNorm2d(192),
            nn.ReLU(),
            nn.Conv2d(192, 192, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(192),
            nn.ReLU())
        self.layer3 = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Conv2d(192, 192, kernel_size=3, padding=0, stride=1),
            nn.BatchNorm2d(192),
            nn.ReLU(),
            nn.Conv2d(192, 192, kernel_size=1, padding=0, stride=1),
            nn.BatchNorm2d(192),
            nn.ReLU(),
            nn.Conv2d(192, 100, kernel_size=1, padding=0, stride=1),
            nn.BatchNorm2d(100),
            nn.ReLU(),
            nn.AvgPool2d(6),
            Flatten())
        
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        #out = out.view(out.size(0), -1)
        #out = self.fc(out)
        return out


In [9]:
## Training Routine
def training_routine(model, train_generator, test_generator, n_epochs, writer = writer,  
                     eval_every=5):
    model.to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr =0.01, 
                                momentum = 0.9, weight_decay = 0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma = 0.8)
    accuracies = []
    
    for i in range(n_epochs):
        # Iterate over batches
        batch_losses = []
        
        for X_batch, y_batch in train_generator:
            optimizer.zero_grad()
            # forward pass
            X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
            batch_output = model(X_batch)
            batch_loss = criterion(batch_output, y_batch)
            # backward pass and optimization
            batch_loss.backward()
            optimizer.step()
            batch_losses.append(batch_loss.cpu().detach())
        print("Epoch {} | training loss: {}".format(i, np.mean(batch_losses)))
        writer.add_scalar('Loss/train', np.mean(batch_losses), i)
        scheduler.step()
        
        # Once every 100 iterations, print statistics
        if i%eval_every==0:
            train_accuracy = []
            test_accuracy = []
            for X_batch, y_batch in train_generator:
                X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
                batch_output = model(X_batch)
                batch_prediction = batch_output.cpu().detach().argmax(dim=1)
                train_accuracy.append((batch_prediction.numpy()==y_batch.cpu().numpy()).mean())
                
            for X_batch, y_batch in test_generator:
                X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
                batch_output = model(X_batch)
                batch_prediction = batch_output.cpu().detach().argmax(dim=1)
                test_accuracy.append((batch_prediction.numpy()==y_batch.cpu().numpy()).mean())
            print("Epoch {} | train acc: {}, test acc: {}".format(i, np.mean(train_accuracy), np.mean(test_accuracy)))
            writer.add_scalar('Accuracy/train', np.mean(train_accuracy), i)
            writer.add_scalar('Accuracy/test', np.mean(test_accuracy), i)
            
            accuracies.append((i, np.mean(train_accuracy), np.mean(test_accuracy)))
            
    return model.cpu(), accuracies

In [4]:
cnn = CNN()
num = count_params(cnn)

In [5]:
num

1389804

In [None]:
start = time.time()
cnn = CNN()
trained_net, accuracies = training_routine(cnn, trainloader, testloader, 51)
end = time.time()
print(end - start)

Epoch 0 | training loss: 4.493959426879883
Epoch 0 | train acc: 0.05288, test acc: 0.05030000000000001
Epoch 1 | training loss: 4.209875106811523


In [None]:
(end-start)/60

In [None]:
model_dict = trained_net.state_dict()
torch.save(model_dict, "all-cnn")

In [None]:
# Load some version of saved model
model = CNN()
state_dict = torch.load('all-cnn')
model.load_state_dict(state_dict)

### Evaluation

In [7]:
import torchvision.datasets as datasets
from torchvision import transforms
y_pred = []
model = CNN()
model = model.cuda()
model.eval()
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

testloader = torch.utils.data.DataLoader(
        datasets.CIFAR100(root='./data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=100, shuffle=False,
        num_workers=1, pin_memory=True)

start = time.time()
for local_batch, local_labels in testloader:
    local_batch, local_labels = local_batch.cuda(), local_labels.cuda()
    batch_output = model(local_batch)
    batch_prediction = batch_output.cpu().detach().argmax(dim=1)
    y_pred.append(batch_prediction)
end = time.time()  

y_pred = torch.cat(y_pred).numpy()

In [8]:
(end-start)/100

0.021364531517028808

In [None]:
from sklearn.metrics import f1_score, precision_score, recall_score, classification_report, confusion_matrix
n_classes = 100

import matplotlib.pyplot as plt
%matplotlib inline
plt.imshow(confusion_matrix(y_test, y_pred, labels=range(n_classes)))
plt.colorbar()

In [None]:
accuracies = np.array(accuracies)
fig, ax = plt.subplots(1, 1, figsize = (6, 4))
ax.plot(accuracies[:, 0], accuracies[:, 1], label = "Train")
ax.plot(accuracies[:, 0], accuracies[:, 2], label = 'Test')
ax.legend(fontsize = 10)
ax.set_xlim((0,50))
ax.set_ylabel('Accuarcy', fontsize = 12)
ax.set_xlabel("# of Epochs", fontsize = 12)