# Create a custom nn.Module

The goal of this notebook is to show how we can create a custom nn.Module that performs some kind of calculation as part of a neural net.

Starting with https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html, we:
- use the fastai MNIST dataset
- update the net to take 3-channel input (i.e. pass 3 rather than 1 to the 1st `Conv2d`), since `ImageFolder` loads the grayscale MNIST images as RGB
- refactor to use `nn.Sequential`
- create the custom modules `NormalizeActivation` and `View`
- make it easy to compare the accuracy of a trained model with and without normalize activation.

## What does the NormalizeActivation layer do?

- Before training starts, each NormalizeActivation layer learns the mean and standard deviation of its input from a few training batches. See: `learn_normalize_activation_stats`.
- From then on (during training and evaluation), each NormalizeActivation layer "normalizes" its input using the learned mean and standard deviation, so that the input to the next layer has a mean of roughly zero and a standard deviation of roughly 1.
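
The two modes boil down to simple arithmetic. A minimal sketch on plain Python lists, assuming the hypothetical helper names `learn_stats` and `normalize` (the notebook works on tensors): in learning mode each batch contributes one scalar mean and one scalar std, taken over every value in the batch, and these are averaged; in active mode the input is shifted and scaled by those two scalars.

```python
import statistics

def learn_stats(batches):
    # learning mode (hypothetical helper): one mean and one std per batch, then average them
    means = [statistics.mean(b) for b in batches]
    stds = [statistics.stdev(b) for b in batches]  # sample std, like torch's default x.std()
    return statistics.mean(means), statistics.mean(stds)

def normalize(values, mean, std):
    # active mode (hypothetical helper): shift by the learned mean, scale by the learned std
    return [(v - mean) / std for v in values]

mean, std = learn_stats([[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]])
out = normalize([1.0, 2.0, 3.0, 4.0, 5.0], mean, std)
print(mean, round(std, 4), round(statistics.mean(out), 4))
```

Because only two scalars are learned per layer, the normalized output is centred and scaled as a whole, not per channel or per feature.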

Running this notebook should give results similar to:

|                                                   | accuracy (%) |
|---------------------------------------------------|--------------|
| with normalize activation - setup_and_train(True) | ~98          |
| no normalize activation - setup_and_train(False)  | ~92          |

While it is interesting that normalizing activations improves accuracy in this simple example, there are many other techniques (batch norm, model-specific weight initialization ...) that should give better results.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from fastai.datasets import untar_data, URLs
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

In [2]:
path = untar_data(URLs.MNIST)
batch_size = 256
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'  # fall back to the CPU when no GPU is available

In [3]:
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Pad(2), # pad images so we don't lose too much in the conv layers (28x28 to 32x32)
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.1], [0.2752]) # one mean/std for all 3 channels; see: Calculate mean/standard deviation on the training data
])
def new_loader(folder, shuffle):
    # folder is 'training' or 'testing'
    return DataLoader(
        ImageFolder(root=path/folder, transform=transforms),
        batch_size=batch_size, num_workers=1, shuffle=shuffle)
train_loader = new_loader('training', True)
test_loader = new_loader('testing', False)

In [4]:
class NormalizeActivation(nn.Module):
    count = 0
    
    def __init__(self):
        super(NormalizeActivation, self).__init__()
        self.mode = 0 # 0=do nothing, 1=learning, 2=active
        self.id = NormalizeActivation.count
        NormalizeActivation.count += 1
        
    def start_learning(self):
        self.mean_list = torch.tensor([]).to(device)
        self.std_list = torch.tensor([]).to(device)
        self.mode = 1
        
    def stop_learning(self):
        self.mean = self.mean_list.mean()
        self.std = self.std_list.mean()
        self.mean_list = None
        self.std_list = None
        print('NormalizeActivation#stop_learning', self.id, self.mean, self.std)
        self.mode = 2
    
    def forward(self, x):
        if self.mode == 1:
            self.mean_list = torch.cat((self.mean_list, x.mean()[None]), 0)
            self.std_list = torch.cat((self.std_list, x.std()[None]), 0)
        if self.mode == 2:
            x = (x - self.mean) / self.std
        return x
    
    def extra_repr(self):
        mean = getattr(self, 'mean', None)
        std = getattr(self, 'std', None)
        return f'id={self.id} mode={self.mode} mean={mean} std={std}'
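
The stats stored by `NormalizeActivation` are plain tensor attributes, so they are not part of `net.state_dict()` and are not moved by `net.to(...)`. A minimal sketch of a variant, assuming you want the learned mean/std saved with the model, stores them with `register_buffer`. The class name `NormalizeActivationBuffered` is hypothetical, and recording Python floats instead of concatenating tensors is a choice of this sketch, not of the notebook.

```python
import torch
import torch.nn as nn

class NormalizeActivationBuffered(nn.Module):
    # hypothetical variant of NormalizeActivation that keeps mean/std as buffers
    def __init__(self):
        super().__init__()
        self.register_buffer('mean', torch.tensor(0.0))
        self.register_buffer('std', torch.tensor(1.0))
        self.mode = 0  # 0=do nothing, 1=learning, 2=active
        self.batch_means = []
        self.batch_stds = []

    def start_learning(self):
        self.batch_means, self.batch_stds = [], []
        self.mode = 1

    def stop_learning(self):
        # average the per-batch stats, as NormalizeActivation does
        self.mean.fill_(sum(self.batch_means) / len(self.batch_means))
        self.std.fill_(sum(self.batch_stds) / len(self.batch_stds))
        self.mode = 2

    def forward(self, x):
        if self.mode == 1:
            # record this batch's stats as plain floats
            self.batch_means.append(x.mean().item())
            self.batch_stds.append(x.std().item())
        elif self.mode == 2:
            x = (x - self.mean) / self.std
        return x

# usage: learn from 10 fake batches with mean ~5 and std ~3, then normalize a new batch
torch.manual_seed(0)
m = NormalizeActivationBuffered()
m.start_learning()
with torch.no_grad():
    for _ in range(10):
        m(torch.randn(256, 8) * 3 + 5)
m.stop_learning()
y = m(torch.randn(256, 8) * 3 + 5)
print(sorted(m.state_dict().keys()), round(y.mean().item(), 2), round(y.std().item(), 2))
```

Because `mean` and `std` are buffers, `net.to(device)` moves them together with the weights, and `state_dict()` / `load_state_dict()` save and restore them.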

In [5]:
def conv_block(in_channels, out_channels, kernel_size):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size),
        nn.ReLU(),
        nn.MaxPool2d((2,2)),
        NormalizeActivation())

def fc_block(in_features, out_features):
    return nn.Sequential(
        nn.Linear(in_features, out_features),
        nn.ReLU(),
        NormalizeActivation())

In [6]:
# how many features input to 1st fully connected layer
fc1_in_features = 576 # see: finding fc1_in_features

## What does View do?

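View reshapes the output of the conv blocks, shape (batch, 16, 6, 6), into (batch, fc1_in_features), the shape the 1st fully connected layer expects. Putting this reshape in an nn.Module lets it sit directly inside the nn.Sequential, which can only chain modules.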

In [7]:
class View(nn.Module):
    def forward(self, x): return x.view(-1, fc1_in_features)
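
PyTorch (1.2 and later) also ships `nn.Flatten`, which does the same job without hard-coding the feature count: it flattens every dimension after the batch dimension (`start_dim=1` by default). A quick check, assuming a fake batch shaped like the conv-block output, that `View` and `nn.Flatten` produce the same tensor:

```python
import torch
import torch.nn as nn

fc1_in_features = 576  # 16 channels * 6 * 6, as in the notebook

class View(nn.Module):  # same as the notebook's View
    def forward(self, x): return x.view(-1, fc1_in_features)

x = torch.randn(4, 16, 6, 6)  # fake batch shaped like the output of the conv blocks
a = View()(x)
b = nn.Flatten()(x)           # flattens all dims after the batch dim
print(a.shape, torch.equal(a, b))  # torch.Size([4, 576]) True
```

One difference: because `View` puts `-1` in the batch position, a wrong feature count that still divides the tensor size changes the batch dimension instead of raising an error at that layer, whereas `nn.Flatten` always keeps the batch dimension.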

In [8]:
def learn_normalize_activation_stats(net, batches_per_module=10):
    normalize_activation_modules = [
        m for _, m in net.named_modules() if isinstance(m, NormalizeActivation)]
    idx = 0
    with torch.no_grad():
        for batch_idx, (data, _) in enumerate(train_loader):
            # every batches_per_module batches, the current module stops learning and the next one starts
            if batch_idx % batches_per_module == 0:
                if idx > 0:
                    normalize_activation_modules[idx-1].stop_learning()
                if idx < len(normalize_activation_modules):
                    normalize_activation_modules[idx].start_learning()
                else:
                    break
                idx += 1
            net(data.to(device))  # forward pass only: the learning module records the stats of its input
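
The order in which the layers learn matters: while layer i is learning, layers 0 to i-1 are already in mode 2, so each layer measures the stats of input that the earlier layers have already normalized. A small sketch that replays the start/stop logic without a net (`learning_schedule` is a hypothetical helper) shows which batches each layer learns from with the default `batches_per_module=10`:

```python
def learning_schedule(num_modules, batches_per_module=10):
    # replays the start/stop logic of learn_normalize_activation_stats without a net:
    # returns {module index: [batch indices it learned from]}
    schedule = {}
    idx = 0
    batch_idx = 0
    while True:
        if batch_idx % batches_per_module == 0:
            # the previous module stops learning here; stop once every module has had its turn
            if idx >= num_modules:
                break
            idx += 1
        schedule.setdefault(idx - 1, []).append(batch_idx)
        batch_idx += 1
    return schedule

schedule = learning_schedule(4)
print({k: (v[0], v[-1]) for k, v in schedule.items()})
# {0: (0, 9), 1: (10, 19), 2: (20, 29), 3: (30, 39)}
```

With 4 NormalizeActivation layers this uses 40 of the roughly 235 training batches (60000 images, batch size 256).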

In [9]:
def accuracy(net):
    net.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for (data, target) in test_loader:
            target = target.to(device)
            output = net(data.to(device))
            predictions = torch.argmax(output, dim=1)
            correct += (predictions == target).sum().item()
            total += len(target)
    print(f'accuracy over {total} test images: {round(correct/total*100, 2)}')

In [10]:
def train(net, lrs):
    def f(x): return round(x.item(), 4)
    criterion = nn.CrossEntropyLoss()
    epoch = 0
    for lr in lrs:
        net.train()
        epoch += 1
        optimizer = optim.SGD(net.parameters(), lr=lr)
        losses = []
        total = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            total += len(data)
            optimizer.zero_grad()
            out = net(data.to(device))
            loss = criterion(out, target.to(device))
            if torch.isnan(loss):
                raise RuntimeError('loss is nan: re-build net and re-try (maybe with lower lr)')
            losses.append(loss.item())
            loss.backward()
            optimizer.step()
        losses = torch.tensor(losses)
        # 'loss' = mean of the first 25 batch losses of the epoch, 'last' = loss of the final batch
        print('epoch', epoch, 'lr', lr, 'loss', f(losses[:25].mean()), 'last', f(loss),
              'min', f(losses.min()), 'max', f(losses.max()), 'items', total)
        accuracy(net)

In [11]:
def setup_and_train(use_normalize_activation):
    net = nn.Sequential(
        conv_block(3, 6, 3),
        conv_block(6, 16, 3),
        View(),
        fc_block(fc1_in_features, 120),
        fc_block(120, 84),
        nn.Linear(84, 10)).to(device)
    if use_normalize_activation: 
        learn_normalize_activation_stats(net)
    print(net)
    lrs = [1.5e-2, 1e-2, 5e-3] # one lr per epoch; 2.5e-2 can work for the 1st epoch but may be too high, depending on the initial weights
    train(net, lrs)

In [12]:
setup_and_train(True) # run with NormalizeActivation enabled

NormalizeActivation#stop_learning 0 tensor(0.2082, device='cuda:0') tensor(0.3664, device='cuda:0')
NormalizeActivation#stop_learning 1 tensor(0.5126, device='cuda:0') tensor(0.6389, device='cuda:0')
NormalizeActivation#stop_learning 2 tensor(0.2028, device='cuda:0') tensor(0.3290, device='cuda:0')
NormalizeActivation#stop_learning 3 tensor(0.2072, device='cuda:0') tensor(0.3128, device='cuda:0')
Sequential(
  (0): Sequential(
    (0): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (3): NormalizeActivation(id=0 mode=2 mean=0.20817556977272034 std=0.3663758933544159)
  )
  (1): Sequential(
    (0): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (3): NormalizeActivation(id=1 mode=2 mean=0.512565553188324 std=0.6388803720474243)
  )
  (2): View()
  (3): Sequenti

In [13]:
setup_and_train(False) # run without NormalizeActivation

Sequential(
  (0): Sequential(
    (0): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (3): NormalizeActivation(id=4 mode=0 mean=None std=None)
  )
  (1): Sequential(
    (0): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (3): NormalizeActivation(id=5 mode=0 mean=None std=None)
  )
  (2): View()
  (3): Sequential(
    (0): Linear(in_features=576, out_features=120, bias=True)
    (1): ReLU()
    (2): NormalizeActivation(id=6 mode=0 mean=None std=None)
  )
  (4): Sequential(
    (0): Linear(in_features=120, out_features=84, bias=True)
    (1): ReLU()
    (2): NormalizeActivation(id=7 mode=0 mean=None std=None)
  )
  (5): Linear(in_features=84, out_features=10, bias=True)
)
epoch 1 lr 0.015 loss 2.3013 last 0.8092 min 0.6967 max 2.3111 items 60000
accuracy over

## Calculate mean/standard deviation on the training data

We need to calculate the stats on the data coming out of the train loader rather than on the unmodified input data, because the transforms change them: padding alone changes the stats from ([0.131], [0.308]) to ([0.1], [0.2752]).
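
The padded stats can be checked by hand. Padding with 2 zero pixels per side turns a 28x28 image (784 pixels) into a 32x32 one (1024 pixels) without changing the sum of the pixel values or of their squares, so both averages are scaled by 784/1024. A minimal sketch of that arithmetic, assuming the quoted std is the population std over all pixels (the notebook actually averages per-batch values, and the inputs are rounded, so expect agreement only to about 3 decimal places); `padded_stats` is a hypothetical helper:

```python
import math

def padded_stats(mean, std, orig_size=28, pad=2):
    # zero padding leaves sum(x) and sum(x**2) unchanged but divides them by more pixels
    padded_size = orig_size + 2 * pad
    frac = orig_size ** 2 / padded_size ** 2   # 784 / 1024
    new_mean = mean * frac
    mean_sq = (std ** 2 + mean ** 2) * frac    # E[x^2] over the padded image
    new_std = math.sqrt(mean_sq - new_mean ** 2)
    return new_mean, new_std

mean, std = padded_stats(0.131, 0.308)
print(round(mean, 4), round(std, 4))  # 0.1003 0.2752
```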

Note: `StatsHelper().gather_and_print()` is commented out because it takes a while to run.

In [14]:
class StatsHelper:
    def __init__(self):
        # do all transforms except Normalize
        transforms = torchvision.transforms.Compose([
            torchvision.transforms.Pad(2),
            torchvision.transforms.ToTensor()])
        self.loader = DataLoader(
            ImageFolder(root=path/'training', transform=transforms), batch_size=batch_size)
    def print_stats(self, a_list):
        print('mean', a_list.mean(), 'min', a_list.min(), 'max', a_list.max())
    def gather_and_print(self):
        # per-batch mean/std of channel 0 (the 3 RGB channels of a grayscale image are identical);
        # everything stays on the CPU
        self.mean_list = torch.tensor([])
        self.std_list = torch.tensor([])
        for (x, _) in self.loader:
            self.mean_list = torch.cat((self.mean_list, x[:,0].mean()[None]), 0)
            self.std_list = torch.cat((self.std_list, x[:,0].std()[None]), 0)
        self.print_stats(self.mean_list)
        self.print_stats(self.std_list)
# StatsHelper().gather_and_print()
# uncomment the line above and you'll get this output:
# mean tensor(0.1000) min tensor(0.0944) max tensor(0.1066)
# mean tensor(0.2752) min tensor(0.2673) max tensor(0.2837)
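
`StatsHelper` averages per-batch values. The average of the batch means equals the overall mean only when all batches have the same size (here the last batch is smaller, since 60000 is not a multiple of 256), and the average of the batch stds only approximates the overall std. A minimal sketch, assuming you want the exact pixel mean/std, accumulates a count, a sum and a sum of squares; with tensors you would add `x.numel()`, `x.sum()` and `(x ** 2).sum()` per batch. `RunningStats` is a hypothetical helper, shown on plain Python lists to keep it dependency-free.

```python
import math

class RunningStats:
    # hypothetical helper: exact mean/std over all values seen, fed batch by batch
    def __init__(self):
        self.n = 0
        self.total = 0.0
        self.total_sq = 0.0

    def update(self, values):
        for v in values:
            self.n += 1
            self.total += v
            self.total_sq += v * v

    def mean(self):
        return self.total / self.n

    def std(self):
        # population std; with tens of millions of pixels the n vs n-1 difference is negligible
        m = self.mean()
        return math.sqrt(self.total_sq / self.n - m * m)

stats = RunningStats()
stats.update([0.0, 0.0, 0.0, 1.0])  # batch of 4, batch mean 0.25
stats.update([1.0, 1.0])            # smaller last batch, batch mean 1.0
print(stats.mean(), stats.std(), (0.25 + 1.0) / 2)  # 0.5 0.5 0.625
```

The last number is what averaging the two batch means would give, which shows how unequal batch sizes bias that approach.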

## Finding fc1_in_features

To find the number of features that will go into the 1st fully connected layer, we need to know the shape of the output of the conv layers.
We can:
- create a "net" with just the conv blocks
- pass one batch of data through this "net"
- pass the output of this "net" to num_flat_features

This number could also be worked out by hand, but it is interesting that we can build part of a neural net and look at its output.
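
For this net the hand calculation is short. A minimal sketch using the standard output-size formula for `Conv2d` and `MaxPool2d` (no padding, dilation 1; stride 1 for the convs and stride 2 for the pools, since `MaxPool2d((2,2))` uses a stride equal to its kernel size by default); `conv_out` is a hypothetical helper:

```python
def conv_out(size, kernel_size, stride=1, padding=0):
    # output size of Conv2d / MaxPool2d along one spatial dimension (dilation=1)
    return (size + 2 * padding - kernel_size) // stride + 1

size = 32                    # 28x28 MNIST image padded by 2 on each side
size = conv_out(size, 3)     # Conv2d(3, 6, 3)    -> 30
size = conv_out(size, 2, 2)  # MaxPool2d((2, 2))  -> 15
size = conv_out(size, 3)     # Conv2d(6, 16, 3)   -> 13
size = conv_out(size, 2, 2)  # MaxPool2d((2, 2))  -> 6 (floor of 6.5)
features = 16 * size * size
print(size, features)        # 6 576
```

Running the same steps on an unpadded 28x28 image ends at 5x5 (400 features), so the `Pad(2)` transform keeps a 6x6 rather than a 5x5 feature map going into the fully connected layers.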

In [15]:
def num_flat_features(x): # taken from neural_networks_tutorial.html
    size = x.size()[1:]
    print('all dimensions except the batch dimension', size)
    num_features = 1
    for s in size:
        num_features *= s
    return num_features
conv_blocks = nn.Sequential(
    conv_block(3, 6, 3),
    conv_block(6, 16, 3))
data, _ = next(iter(train_loader))
print('fc1_in_features', num_flat_features(conv_blocks(data)))

all dimensions except the batch dimension torch.Size([16, 6, 6])
fc1_in_features 576
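
A variation, assuming you only need the shape, is to push a single dummy image through the conv layers instead of a real batch, and call `numel()` on one sample. The `NormalizeActivation` layers are left out of this sketch because they only shift and scale their input, so they never change its shape.

```python
import torch
import torch.nn as nn

# the conv part of the net, without the shape-preserving NormalizeActivation layers
conv_only = nn.Sequential(
    nn.Conv2d(3, 6, 3), nn.ReLU(), nn.MaxPool2d((2, 2)),
    nn.Conv2d(6, 16, 3), nn.ReLU(), nn.MaxPool2d((2, 2)))

with torch.no_grad():
    out = conv_only(torch.zeros(1, 3, 32, 32))  # one dummy 3-channel, padded 32x32 image
print(out.shape, out[0].numel())  # torch.Size([1, 16, 6, 6]) 576
```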
