# The AllConv Net

A parameter-efficient ConvNet as a baseline.

In [5]:
import os

import numpy as np
from numpy.random import choice
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import transforms
from torchvision import datasets
import torchvision

import sys
sys.path.append('../utils')
sys.path.append('../models')

import utils
import all_conv_net

np.random.seed(43)

In [0]:
# All functions for this notebook in one cell:

def train_validation_split(dataset, fraction=0.1, batchsize=64):
    """
    Randomly split @dataset into a training and a validation part.

    @ dataset: torchvision dataset (anything supporting len() and indexing)
    @ fraction: share of the samples that is held out for validation
    @ batchsize: batch size used by both DataLoaders
    return: training and validation set as DataLoaders

    The validation indices are drawn without replacement with numpy's
    `choice`, so the split follows numpy's global random state. Both loaders
    wrap the same @dataset and differ only in their SubsetRandomSampler.
    """
    n_samples = len(dataset)
    n_val_samples = round(n_samples * fraction)
    val_indices = choice(
        range(n_samples), size=n_val_samples, replace=False
    ).tolist()

    # set lookup keeps the complement computation linear in the dataset size
    # (membership tests against the index array scan it for every sample)
    held_out = set(val_indices)
    train_indices = [i for i in range(n_samples) if i not in held_out]

    train_loader = DataLoader(dataset,
                              sampler=SubsetRandomSampler(train_indices),
                              batch_size=batchsize
                              )
    val_loader = DataLoader(dataset,
                            sampler=SubsetRandomSampler(val_indices),
                            batch_size=batchsize
                            )
    return train_loader, val_loader


def prediction_from_output(model_output: torch.Tensor):
    """
    Convert raw network outputs into class predictions.

    @ model_output: output from the last linear layer, i.e. logits without
                    softmax (softmax is part of the CE-loss); the softmax
                    here is taken along dimension 1
    return: tuple (max_probs, predictions) - the highest class probability
            per row and the class-index at which it occurs
    """
    class_probs = torch.softmax(model_output, dim=1)
    best = class_probs.max(dim=1)
    return best.values, best.indices
  

def train_model(net, train, validation, optimizer, max_epoch=100):
    """
    Train @net and report training / validation statistics after every epoch.

    This function returns nothing. The parameters of @net are updated in-place
    and the error statistics are written to the global variable `error_stats`
    (a list of (training-loss, validation-loss) tuples, one per epoch). This
    allows stopping the training at any point while keeping the results.

    After every epoch the model's state_dict is saved to 'epoch{epoch}.pt' in
    the current working directory.

    @ net: a defined model - can also be pretrained
    @ train, validation: DataLoaders of training- and validation-set
    @ optimizer: optimizer that was created for the parameters of @net
    @ max_epoch: stop training after this number of epochs
    """
    global error_stats  # to track error log even when training aborted
    error_stats = []

    criterion = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=list(range(20, max_epoch, 20)),  # every 20th epoch
        gamma=0.5  # decrease learning rate by half each step
    )
    # use the GPU when there is one, otherwise fall back to the CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(device)

    print('epoch\ttraining-CE\tvalidation-CE\tvalidation-accuracy (%)')
    for epoch in range(max_epoch):
        net.train()
        training_loss = 0

        for images, labels in train:
            labels = labels.to(device)
            images = images.to(device)
            optimizer.zero_grad()

            # prediction and error:
            output = net(images)
            loss = criterion(output, labels)  # loss of current batch
            training_loss += loss.item()  # keep track of training error

            # update parameters:
            loss.backward()
            optimizer.step()

        with torch.no_grad():  # no backpropagation necessary
            validation_loss = 0
            n_correct = 0
            n_seen = 0
            net.eval()

            for images, labels in validation:
                labels = labels.to(device)
                images = images.to(device)

                # prediction and error:
                output = net(images)
                loss = criterion(output, labels)
                validation_loss += loss.item()

                # softmax is monotonic, so the arg-max of the raw outputs is
                # the predicted class; count hits over ALL validation batches
                predictions = output.argmax(dim=1)
                n_correct += (predictions == labels).sum().item()
                n_seen += labels.size(0)

        # convert to batch loss:
        training_loss = training_loss / len(train)
        validation_loss = validation_loss / len(validation)
        accuracy = 100 * n_correct / n_seen

        # MultiStepLR counts epochs by itself; its step() must not be given
        # the validation loss (it would be interpreted as the epoch number)
        scheduler.step()

        torch.save(net.state_dict(), f'epoch{epoch}.pt')
        error_stats.append( (training_loss, validation_loss) )
        print('{}\t{:.2f}\t\t{:.2f}\t\t{:.2f}'.format(
            epoch, training_loss, validation_loss, accuracy)
             )

    
def test_set_evaluation(net, test, just_print=False):
    """
    Calculate cross-entropy loss (mean batch loss) and accuracy on the test-set
    """
    total_loss = 0
    net.eval()
    criterion = nn.CrossEntropyLoss()
  
    with torch.no_grad():
        for images, labels in test:
            labels = labels.cuda()
            images = images.cuda()

            output = net(images)
            batch_loss = criterion(output, labels)
            total_loss += batch_loss.item()

            predictions = prediction_from_output(output)[1]
            batch_accuracy = (predictions == labels).float().mean() * 100
    
    mean_batch_loss = total_loss / len(test)
    accuracy = batch_accuracy.mean()
  
    if just_print:
        print('\nEvaluation on the test-set:')
        print(f'mean batch cross-entropy loss: {mean_batch_loss:.2f}')
        print(f'accuracy: {accuracy:.2f}')
        return None
  
    return mean_batch_loss, accuracy
    

def count_parameters(model, in_millions=False):
    """
    Return the total number of elements in all parameter tensors of @model.

    @ in_millions: if True, the count is divided by one million
    """
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total / 1000000 if in_millions else total


def show_image(processed_image: torch.Tensor, means: tuple, stdevs: tuple):
    """
    Recalculate original image from @processed_image and display it.

    @ processed_image: normalized image, indexed as [channel, row, column]
    @ means / @ stdevs: per-channel values used for normalizing the raw images
    """
    # work on a detached CPU copy: the input stays untouched and .numpy()
    # also works for tensors that live on the GPU
    img = processed_image.detach().cpu().clone()
    for channel in range(img.shape[0]):
        img[channel] = img[channel] * stdevs[channel] + means[channel]

    # rounding of the normalization constants can push values slightly out of
    # [0, 1]; imshow clips float RGB data to that range anyway (with a warning)
    img = img.clamp(0, 1).permute(1, 2, 0).numpy()

    plt.imshow(img)
    plt.grid(False)  # the string 'off' is truthy and would switch the grid ON
    plt.xticks([])
    plt.yticks([])
    plt.show()


def class_name(index: int) -> str:
    """
    Translate a class-index into its class-name.

    @ index: class-index between 0-99
    return: class-name, or the text 'invalid class index: <index>' if @index
            is not one of the known indices
    """
    names = (
        'apple', 'aquarium_fish',
        'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle',
        'bowl', 'boy', 'bridge', 'bus', 'butterfly',
        'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair',
        'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab',
        'crocodile', 'cup',
        'dinosaur', 'dolphin',
        'elephant',
        'flatfish', 'forest', 'fox',
        'girl',
        'hamster', 'house',
        'kangaroo', 'keyboard',
        'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster',
        'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom',
        'oak_tree', 'orange', 'orchid', 'otter',
        'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate',
        'poppy', 'porcupine', 'possum',
        'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
        'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail',
        'snake', 'spider', 'squirrel', 'streetcar', 'sunflower',
        'sweet_pepper',
        'table', 'tank', 'telephone', 'television', 'tiger', 'tractor',
        'train', 'trout', 'tulip', 'turtle',
        'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm',
    )
    lookup = {position: name for position, name in enumerate(names)}
    try:
        return lookup[index]
    except KeyError:
        return f'invalid class index: {index}'


def predict_and_display(net, testset, n=10):
    """
    Predict @n random examples from the test-set and show images + predictions

    @ net: trained model; inputs are moved to the device of its parameters
    @ testset: dataset yielding (image, label) pairs when indexed
    @ n: number of examples; indices are drawn with replacement, so an
         example can show up more than once

    Uses the module-level `channel_means` / `channel_standard_devs` to undo
    the normalization before displaying an image.
    """
    net.eval()

    # run the forward pass wherever the model lives; models without
    # parameters fall back to the GPU if available, else the CPU
    first_param = next(net.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    for i in choice(range(len(testset)), size=n):
        image, label = testset[i]  # the passed dataset, not a global one
        with torch.no_grad():  # inference only
            output = net(image.unsqueeze(0).to(device))
        # shape (1, classes) whether or not the net squeezed the batch axis
        prob, pred = prediction_from_output(output.reshape(1, -1))
        prob, pred = prob.item(), pred.item()
        evaluation = 'correct' if pred == label else 'mistake'

        plt.figure( figsize=(2, 2) )
        print(f'\ntruth: {label} | pred: {pred} | prob: {prob:.2f}')
        print(f'{evaluation}: ({class_name(label)} vs. {class_name(pred)})')
        show_image(image, means=channel_means, stdevs=channel_standard_devs)


def plot_error_curves(errors_over_time: list, error_name='error'):
    """
    Plot training and validation error against the epoch number.

    @ errors_over_time: list of tuples: (training-error, validation-error),
                        one tuple per epoch
    @ error_name: name of the error measure, used in the plot title
    raises: ValueError if @errors_over_time is empty
    """
    if not errors_over_time:
        raise ValueError('errors_over_time is empty - nothing to plot')

    error_train, error_validation = zip(*errors_over_time)
    n_epochs = len(error_train)

    plt.plot(range(n_epochs), error_train)
    plt.plot(range(n_epochs), error_validation)
    # ticks at start, middle and end; a step of 0 (single epoch) would make
    # range() raise, so the step is at least 1
    plt.xticks(range(0, n_epochs + 1, max(1, n_epochs // 2)))
    plt.xlabel('epoch')
    plt.ylabel('CE')
    plt.legend(('training', 'validation'))
    plt.title(f'{error_name} over time')
    plt.show()

## Load data and pre-process

No sophisticated pre-processing was done in this project. The data was normalized, as is best practice in computer vision. For each channel, its mean and standard deviation were calculated (in an exploratory notebook). Then each channel was normalized by subtracting its mean and dividing by its standard deviation.

CIFAR-100 consists of 50000 training images and 10000 test images (it could be split differently, of course, but we went with the standard split). The performance on the test set was of course only evaluated once for each model. The 50000 remaining images were split into 45000 for training and 5000 for measuring the validation loss after every training epoch (`fraction=0.1` in the cell below, which yields the 704 training and 79 validation batches of size 64 shown in its output).



In [5]:
# Per-channel means and standard deviations used for normalisation
# (one value per colour channel); also used by predict_and_display() to
# undo the normalisation before plotting.
channel_means = (0.5071, 0.4865, 0.4409)
channel_standard_devs = (0.2673, 0.2564, 0.2762)

# convert images to tensors, then normalise each channel
transformation = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(channel_means, channel_standard_devs)
])
batchsize=64

# NOTE(review): `data_dir` is not defined in the visible part of this
# notebook - it has to be set before this cell is run.
test_set = datasets.CIFAR100(data_dir, train=False, transform=transformation, download=True)
test_loader = DataLoader(test_set, batch_size=batchsize)

# the official training split is divided further: 10% of it is held out
# as validation set, the remainder is used for training
rest = datasets.CIFAR100(data_dir, train=True, transform=transformation, download=True)
train_loader, validation_loader = train_validation_split(rest, fraction=0.1, batchsize=batchsize)

# len() of a DataLoader is its number of batches
print(f'training batches: {len(train_loader)}')
print(f'validation batches: {len(validation_loader)}')
print(f'test batches: {len(test_loader)}')
print(f'batch size: {batchsize}')

Files already downloaded and verified
Files already downloaded and verified
training batches: 704
validation batches: 79
test batches: 157
batch size: 64


### Dimensions

#### convolutional layer

$$W_{output} = \Bigl\lfloor \frac{W_{input} - F + 2*P}{S} \Bigr\rfloor+ 1$$

#### pooling layer

$$W_{output} = \frac{W_{input} - F}{S} + 1$$

with:  
W: width
F: filter width  
P: padding  
S: stride

... the same goes for the height (the images are square anyway)

In [6]:
# calculate convolutional-layer dimension
# (W_input=16, F=3, P=1, S=2 plugged into the formula above):
# NOTE: true division is used here, so the result is 8.5; the formula floors
# the quotient, which gives an output width of 8.
( (16 - 3 + 2*1 ) / 2 ) + 1

8.5

### The all-convolution net

In search of a leaner architecture, we implemented the all-convolutional net described in the paper "Striving for Simplicity: The All Convolutional Net" by Springenberg et al. This architecture simply downsamples the input with a couple of convolutions, as in regular convolutional nets. It then arrives at a number of channels that is equal to the number of classes. Taking the mean of each channel (a feature map of very small height / width) gives a vector of the same length as the number of classes - the same output as in a fully connected layer. In our implementation, this architecture has only ~0.3 million parameters when starting with 32 channels. As before, downsampling is done with stride=2 convolutions.

We also added optional residual shortcuts to this network. Due to the small number of parameters, it is feasible to compare the residual version with the non-residual version of this network.

In [0]:
class AllCNN(nn.Module):
    """
    All-convolutional network (no fully connected layers) with low number of
    parameters. Can be converted into a ResNet by 'residuals=True'

    The last convolution produces one feature map per class; averaging each
    map over its spatial extent yields one value per class, which takes the
    place of a fully connected output layer.
    """

    def conv3x3(self, in_channels, out_channels, stride=1):
        """
        Basic 3x3 convolutional layer plus batch normalization and channel-wise
        dropout (p=0.2)
        Note: no ReLU in layer (ReLU after addition of residual)
        """
        layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3,
                      stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.Dropout2d(p=0.2)
        )
        return layer

    def __init__(self, c=32, n_classes=100, residuals=False):
        """
        @ c: number of channels after first convolution
        @ n_classes: number of target classes
        @ residuals: if True, the input of every shape-preserving block is
                     added to its output before the ReLU
        """
        super().__init__()

        self.residuals = residuals
        self.relu = nn.ReLU(inplace=False)

        # spatial sizes in the comments refer to 32x32 input images
        self.conv1 = self.conv3x3(3,   c,   stride=1)
        self.conv2 = self.conv3x3(c,   c,   stride=1)

        self.conv3 = self.conv3x3(c,   c*2, stride=2)  # -> 16**2
        self.conv4 = self.conv3x3(c*2, c*2, stride=1)
        self.conv5 = self.conv3x3(c*2, c*2, stride=1)

        self.conv6 = self.conv3x3(c*2, c*4, stride=2)  # -> 8**2
        self.conv7 = self.conv3x3(c*4, c*4, stride=1)

        self.conv8 = nn.Sequential(
            nn.Conv2d(c*4, c*4, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(c*4),
            nn.Dropout2d(p=0.2),
            self.relu
        )

        # conv9: make 1 channel per class -> take average of each feature map
        # -> n_classes real values: same as in fully connected output layer.
        # Adaptive pooling averages over the whole feature map, so the result
        # equals the 8x8 average for 32x32 inputs and other input sizes work
        # as well.
        self.conv9 = nn.Sequential(
            nn.Conv2d(c*4, n_classes, kernel_size=1, stride=1, padding=0),
            nn.Dropout2d(p=0.2),
            self.relu,
            nn.AdaptiveAvgPool2d(1)
        )

    def _block(self, layer, h):
        """
        Apply @layer to @h, add @h as residual shortcut if enabled, then ReLU
        """
        out = layer(h)
        if self.residuals:  # skip the addition entirely when not residual
            out = out + h
        return self.relu(out)

    def forward(self, x):
        """
        @ x: batch of images, (batch, 3, height, width)
        return: one value per class and image. All singleton dimensions are
                squeezed, so a batch of size 1 gives a 1-dim tensor of length
                n_classes instead of shape (1, n_classes).
        """
        h = self.relu(self.conv1(x))
        h = self._block(self.conv2, h)

        h = self.relu(self.conv3(h))  # -> 16**2
        h = self._block(self.conv4, h)
        h = self._block(self.conv5, h)

        h = self.relu(self.conv6(h))  # -> 8**2
        h = self._block(self.conv7, h)

        h = self._block(self.conv8, h)
        y = self.conv9(h).squeeze()

        return y


# non-residual variant with 64 channels after the first convolution
all_conv = AllCNN(c=64, residuals=False)

In [12]:
# report the model size, then set up SGD with momentum for training
print('MODEL: all-convolution net')
print(f'parameters: {count_parameters(all_conv, in_millions=True)} million\n')

optimizer = torch.optim.SGD(all_conv.parameters(),
                            lr=0.1,  # will be decreased by scheduler
                            momentum=0.9,
                            weight_decay=0.00002
                           )
# NOTE: the training call is disabled; enable it to (re-)train the model
#train_model(all_conv, train_loader, validation_loader, optimizer, max_epoch=100)

MODEL: all-convolution net
parameters: 1.387044 million



In [0]:
# NOTE: `error_stats` is the global list filled by train_model(); it only
# exists after train_model() has been run in this session.
plot_error_curves(error_stats, error_name='cross-entropy loss')