# Layerwise Sequential Unit Variance (LSUV)

The goal of this notebook is to see if "model specific weight initialization" improves the training of a small neural net.

See: All you need is a good init https://arxiv.org/abs/1511.06422

Starting from pytorch-custom-layer.ipynb, we:
- get rid of `NormalizeActivation`
- "borrow" the hooks, LSUV approach etc from https://github.com/fastai/course-v3/blob/master/nbs/dl2/07a_lsuv.ipynb
- make it easy to compare accuracy of a trained model with/without LSUV.

## What does do_lsuv do?

`do_lsuv` can adjust the initialisation in response to the standard deviation (std) and the mean of each ReLU's activation:

### std
- for each ReLU in the model
    - run a batch of data through the model
    - if the std of the ReLU's activation is too far from 1, divide the weights of the previous layer (conv or linear) by that std
    - repeat the previous 2 steps until std is close to 1
    
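A minimal, framework-free sketch of the std loop above, assuming a single hypothetical linear unit with 16 inputs followed by a ReLU (the notebook instead works on whole conv/linear layers through forward hooks). `relu_std` and `lsuv_std` are illustrative names, not part of the notebook. With zero bias one update is exact, because relu(z/s) = relu(z)/s for any s > 0; the bias is not rescaled, so with a non-zero bias the loop may need several passes.

```python
import random
import statistics

random.seed(0)

# Hypothetical toy layer: one linear unit with 16 inputs followed by a ReLU,
# evaluated on a fixed batch of 512 Gaussian samples.
N_IN, BATCH = 16, 512
batch = [[random.gauss(0, 1) for _ in range(N_IN)] for _ in range(BATCH)]

def relu_std(w, bias):
    """Standard deviation of relu(w . x + bias) over the batch."""
    out = [max(0.0, sum(wi * xi for wi, xi in zip(w, x)) + bias) for x in batch]
    return statistics.stdev(out)

def lsuv_std(w, bias, tolerance=1e-3, max_iters=1000):
    """Divide the weights (not the bias) by the output std until it is within tolerance of 1."""
    for updates in range(max_iters):
        std = relu_std(w, bias)
        if abs(std - 1) <= tolerance:
            return w, std, updates
        w = [wi / std for wi in w]
    raise RuntimeError('std did not converge')

for bias in (0.0, 0.5):
    _, std, updates = lsuv_std([0.1] * N_IN, bias)
    print(f'bias={bias}: std={std:.4f} after {updates} weight update(s)')
```

As in `do_lsuv`, only the weights are divided (`module._previous.weight.data /= hook.std`); the bias is left alone.
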
### mean
- for each ReLU in the model
    - run a batch of data through the model
    - if the mean of the ReLU's activation is too far from 0, add that mean to `ReLU#sub` (which makes a "post-ReLU" adjustment)
    - repeat the previous 2 steps until mean is close to 0

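A similar sketch of the mean loop, on hypothetical data (ReLU outputs of standard Gaussian samples); `lsuv_mean` is an illustrative name. Because `sub` only shifts the output, mean(relu(x) - sub) = mean(relu(x)) - sub and the std is unchanged, so when run on its own the loop converges after a single update (up to floating-point rounding).

```python
import random
import statistics

random.seed(0)

# Hypothetical post-ReLU activations for one batch: ReLU of standard Gaussian samples.
activations = [max(0.0, random.gauss(0, 1)) for _ in range(1000)]

def lsuv_mean(activations, tolerance=1e-3, max_iters=100):
    """Mimic the mean loop: sub += mean until mean(relu(x) - sub) is within tolerance of 0."""
    sub = 0.0
    for updates in range(max_iters):
        shifted = [a - sub for a in activations]
        mean = statistics.mean(shifted)
        if abs(mean) <= tolerance:
            return sub, mean, statistics.stdev(shifted), updates
        sub += mean
    raise RuntimeError('mean did not converge')

sub, mean, std, updates = lsuv_mean(activations)
print(f'sub={sub:.4f} mean={mean:.2e} std={std:.4f} updates={updates}')
```
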
We try 3 variations of LSUV:
- 1 std only - like https://arxiv.org/abs/1511.06422
    - expect an initial mean of ~.5 because ReLU clamps negative values to zero, so its output is never negative
- 2 mean then std - like https://github.com/fastai/course-v3/blob/master/nbs/dl2/07a_lsuv.ipynb
    - expect an initial mean of ~.25 because the std adjustment (rescaling the previous layer's weights) undoes part of the mean adjustment we just made
- 3 mean then std in a loop
    - to get as close as possible to an initial mean of 0 and std of 1

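To see the three variations side by side, here is a toy sketch that mirrors the control flow of `do_lsuv` on one hypothetical linear unit (16 inputs, bias 0.5, small initial weights) followed by a ReLU with a post-ReLU shift. The layer, the data and the helper names `stats` and `lsuv` are assumptions for illustration, so the numbers will differ from the notebook's network, but the pattern should be the same: std only leaves a clearly positive mean, mean then std leaves a smaller but non-zero mean, and the loop brings both stats within tolerance.

```python
import random
import statistics

random.seed(0)

# Hypothetical toy layer: one linear unit (16 inputs, bias 0.5) followed by a ReLU
# with a post-ReLU shift `sub`, evaluated on a fixed batch of 512 Gaussian samples.
N_IN, BATCH, BIAS = 16, 512, 0.5
batch = [[random.gauss(0, 1) for _ in range(N_IN)] for _ in range(BATCH)]

def stats(w, sub):
    """Mean and std of relu(w . x + BIAS) - sub over the batch."""
    out = [max(0.0, sum(wi * xi for wi, xi in zip(w, x)) + BIAS) - sub for x in batch]
    return statistics.mean(out), statistics.stdev(out)

def lsuv(lsuv_type, tolerance=1e-3, max_iters=1000):
    """Toy version of do_lsuv: 1 std only, 2 mean then std, 3 mean then std in a loop."""
    w, sub = [0.1] * N_IN, 0.0  # small initial weights, so the initial std is below 1
    if lsuv_type == 2:
        # mean step: shift the output until its mean is ~0
        for _ in range(max_iters):
            mean, std = stats(w, sub)
            if abs(mean) <= tolerance:
                break
            sub += mean
    if lsuv_type in (1, 2):
        # std step: rescale the weights (not the bias) until the std is ~1
        for _ in range(max_iters):
            mean, std = stats(w, sub)
            if abs(std - 1) <= tolerance:
                break
            w = [wi / std for wi in w]
    else:
        # type 3: update the shift and the weights together until both stats are close
        for _ in range(max_iters):
            mean, std = stats(w, sub)
            if abs(mean) <= tolerance and abs(std - 1) <= tolerance:
                break
            sub += mean
            w = [wi / std for wi in w]
    return stats(w, sub)

for lsuv_type in (1, 2, 3):
    mean, std = lsuv(lsuv_type)
    print(f'LSUV type {lsuv_type}: mean={mean:.3f} std={std:.3f}')
```

In this toy the std step follows the same weight trajectory in all three variants, because `sub` does not affect the std; variant 2 therefore ends with the std-only mean minus the shift chosen before the weights were rescaled.
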
Running this notebook should give results similar to:

| LSUV type             |1st epoch accuracy (%) |3rd epoch accuracy (%) |
|-----------------------|-------------------|-------------------|
| mean then std in loop | ~93               | ~95               |
| mean then std         | ~91               | ~95               |
| std only              | ~89               | ~95               |
| none                  | ~84               | ~92               |

While all 3 LSUV variations end up with the same accuracy after 3 epochs (and have very similar stats after training too), "mean then std in loop" does consistently better on the 1st epoch.

Interestingly, `NormalizeActivation` (from pytorch-custom-layer.ipynb) gives better results (o:

In [1]:
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from fastai.datasets import untar_data, URLs
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

In [2]:
path = untar_data(URLs.MNIST)
batch_size = 256
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [3]:
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Pad(2), # pad images so we don't lose too much in the conv layers (28x28 to 32x32)
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.1], [0.2752]) # see: calculate mean/standard deviation ...
])
def new_loader(folder, shuffle):
    # ImageFolder's default loader converts images to RGB, so each digit has 3 identical channels
    return DataLoader(
        ImageFolder(root=path/folder, transform=transforms),
        batch_size=batch_size, num_workers=1, shuffle=shuffle)
train_loader = new_loader('training', True)
test_loader = new_loader('testing', False)

A ReLU that can make a "post-ReLU" adjustment by subtracting `sub` from its output. When `sub=0`, this is just a normal ReLU.

In [4]:
class ReLU(nn.Module):
    def __init__(self, sub=0):
        super(ReLU, self).__init__()
        self.sub = sub
     
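    def forward(self, x):
        # out-of-place subtraction: an in-place sub_ would modify the ReLU output,
        # which autograd may need for the backward pass
        return F.relu(x) - self.sub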
    
    def extra_repr(self):
        return f'sub={self.sub}'

In [5]:
def conv_block(in_channels, out_channels, kernel_size):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size),
        ReLU(),
        nn.MaxPool2d((2,2)))

def fc_block(in_features, out_features):
    return nn.Sequential(
        nn.Linear(in_features, out_features),
        ReLU())

In [6]:
# number of features input to the 1st fully connected layer (measured with throwaway conv layers)
def num_flat_features(x): # taken from neural_networks_tutorial.html
    size = x.size()[1:]
    print('all dimensions except the batch dimension', size)
    return torch.tensor(size).prod().item() # product of all elements in size
conv_blocks = nn.Sequential(
    conv_block(3, 6, 3),
    conv_block(6, 16, 3))
data, _ = next(iter(train_loader))
fc1_in_features = num_flat_features(conv_blocks(data))
print('fc1_in_features', fc1_in_features)

all dimensions except the batch dimension torch.Size([16, 6, 6])
fc1_in_features 576

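The 576 can also be checked by hand. A small sketch with a hypothetical helper `conv_out` (not in the notebook), using PyTorch's output-size formula for `Conv2d`/`MaxPool2d` with dilation 1, floor((size + 2·padding - kernel_size) / stride) + 1, and `MaxPool2d`'s default stride (equal to its kernel size):

```python
def conv_out(size, kernel_size, stride=1, padding=0):
    """Output height/width of Conv2d or MaxPool2d with dilation 1 (the division is floored)."""
    return (size + 2 * padding - kernel_size) // stride + 1

size = 28 + 2 * 2                   # Pad(2): 28x28 MNIST digits become 32x32
size = conv_out(size, 3)            # Conv2d(3, 6, 3)  -> 30x30
size = conv_out(size, 2, stride=2)  # MaxPool2d((2,2)) -> 15x15
size = conv_out(size, 3)            # Conv2d(6, 16, 3) -> 13x13
size = conv_out(size, 2, stride=2)  # MaxPool2d((2,2)) -> 6x6 (the 13th row/column is dropped)
flat_features = 16 * size * size    # 16 channels x 6 x 6
print(flat_features)                # 576
```
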

## What does View do?

`View` flattens the output of the conv blocks (16 channels x 6 x 6 per image) into a vector of `fc1_in_features` = 576 values for the 1st fully connected layer. Having this logic in an `nn.Module` makes building the `nn.Sequential` easy.

In [7]:
class View(nn.Module):
    def forward(self, x): return x.view(-1, fc1_in_features)

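Search a model and return the ReLU from each block whose first child is a conv or linear layer.

When we find such a ReLU, we save the preceding module (conv or linear) in `_previous` so that `do_lsuv` can easily rescale that module's weights to change the std. Because `_previous` is an `nn.Module`, this assignment also registers it as a child of the ReLU, which is why it appears a second time under each ReLU in the printed models of the LSUV runs.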

In [8]:
def find_hookable_modules(net):
    children = list(net.children())
    if len(children) >= 2 and isinstance(children[0], (nn.Conv2d, nn.Linear)):
        relu = children[1]
        relu._previous = children[0] # need easy access to the previous layer
        return [relu] # hook the ReLU that follows the Conv2d or Linear
    return sum([find_hookable_modules(m) for m in net.children()], [])

"borrow" hooks and `append_stat` from https://github.com/fastai/course-v3/blob/master/nbs/dl2/07a_lsuv.ipynb

In [9]:
class Hook():
    def __init__(self, module, f): 
        self.handle = module.register_forward_hook(partial(f, self))
    def remove(self): 
        self.handle.remove()
    def __del__(self): 
        self.remove()
        
class Hooks():
    def __init__(self, modules, f): 
        self.hooks = [Hook(module, f) for module in modules]
    def __enter__(self, *args):
        return self.hooks
    def __exit__(self, *args):
        self.remove()
    def __del__(self): 
        self.remove()
    def remove(self):
        for hook in self.hooks: hook.remove()
            
def append_stat(hook, mod, inp, outp):
    d = outp.data
    hook.mean, hook.std = d.mean().item(), d.std().item()

In [10]:
def accuracy(net):
    net.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for (data, target) in test_loader:
            target = target.to(device)
            output = net(data.to(device))
            predictions = torch.argmax(output, dim=1)
            correct += (predictions == target).sum().item()
            total += len(target)
    print(f'accuracy over {total} test images: {round(correct/total*100, 2)}')

In [11]:
def train(net, lrs):
    def f(x): return round(x.item(), 4)
    criterion = nn.CrossEntropyLoss()
    epoch = 0
    for lr in lrs:
        net.train()
        epoch += 1
        optimizer = optim.SGD(net.parameters(), lr=lr)
        losses = []
        total = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            total += len(data)
            optimizer.zero_grad()
            out = net(data.to(device))
            loss = criterion(out, target.to(device))
            if torch.isnan(loss):
                raise RuntimeError('loss is nan: re-build net and re-try (maybe with lower lr)')
            losses.append(loss.item())
            loss.backward()
            optimizer.step()
        losses = torch.tensor(losses)
        # 'loss' is the mean of the first 25 batch losses, 'last' is the loss of the final batch
        print('epoch', epoch, 'lr', lr, 'loss', f(losses[:25].mean()), 'last', f(loss), 
              'min', f(losses.min()), 'max', f(losses.max()), 'items', total)
        accuracy(net)

In [12]:
# grab a batch of data for do_lsuv and print_stats
xb, _ = next(iter(train_loader))
xb = xb.to(device)

In [13]:
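def do_lsuv(lsuv_type, net, module, tolerance=1e-3):
    # lsuv_type: 1 std only, 2 mean then std, 3 mean then std in loop
    # module is a ReLU from find_hookable_modules; module._previous is the conv/linear layer before it
    hook = Hook(module, append_stat)
    
    with torch.no_grad():
        # each net(xb) is a forward pass that refreshes hook.mean and hook.std;
        # `net(xb) is not None` is always True, it just runs the pass before each check
        if lsuv_type == 2:
            # shift the ReLU output via sub until its mean is ~0
            while net(xb) is not None and abs(hook.mean) > tolerance: module.sub += hook.mean
        if lsuv_type in [1, 2]:
            # scale the previous layer's weights until the ReLU output std is ~1
            while net(xb) is not None and abs(hook.std-1) > tolerance: module._previous.weight.data /= hook.std
        else:
            # type 3: adjust mean and std together until both are within tolerance
            net(xb)
            while abs(hook.mean) > tolerance or abs(hook.std-1) > tolerance:
                module.sub += hook.mean
                module._previous.weight.data /= hook.std
                net(xb)
        
    hook.remove()
    return 'LSUV type', lsuv_type, hook.mean, hook.std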

We'll use `print_stats` to check the activation stats at the end of training.

In [14]:
def print_stats(net):
    with torch.no_grad():
        with Hooks(find_hookable_modules(net), append_stat) as hooks:
            net(xb)
            for i, hook in enumerate(hooks):
                print('layer', i, 'mean', hook.mean, 'std', hook.std)

In [15]:
def setup_and_train(lsuv_type):
    net = nn.Sequential(
        conv_block(3, 6, 3),
        conv_block(6, 16, 3),
        View(),
        fc_block(fc1_in_features, 120),
        fc_block(120, 84),
        nn.Linear(84, 10)).to(device)
    if lsuv_type: 
        for module in find_hookable_modules(net): 
            print(do_lsuv(lsuv_type, net, module))
    print(net)
    lrs = [1.5e-2, 1e-2, 5e-3] # 2.5e-2 can work for the 1st epoch but may be too high, depending on the init
    train(net, lrs)
    print_stats(net)

In [16]:
setup_and_train(3) # run with LSUV mean then std in loop

('LSUV type', 3, 0.00039939384441822767, 1.0000386238098145)
('LSUV type', 3, 5.230880924500525e-05, 0.9999986290931702)
('LSUV type', 3, 0.0001786887733032927, 1.0)
('LSUV type', 3, -0.0006501650204882026, 0.9999986886978149)
Sequential(
  (0): Sequential(
    (0): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU(
      sub=0.32031676825135946
      (_previous): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))
    )
    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (1): Sequential(
    (0): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU(
      sub=0.4957392776850611
      (_previous): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
    )
    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (2): View()
  (3): Sequential(
    (0): Linear(in_features=576, out_features=120, bias=True)
    (1): ReLU(
      sub=0.6864708364009857
      (_previous): Linear(in_features

In [17]:
setup_and_train(2) # run with LSUV mean then std

('LSUV type', 2, 0.2810233235359192, 0.9991238117218018)
('LSUV type', 2, 0.2697175145149231, 0.9998458623886108)
('LSUV type', 2, 0.2996945083141327, 0.9990943074226379)
('LSUV type', 2, 0.4509803354740143, 0.999989926815033)
Sequential(
  (0): Sequential(
    (0): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU(
      sub=0.12537558376789093
      (_previous): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))
    )
    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (1): Sequential(
    (0): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU(
      sub=0.2298620641231537
      (_previous): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
    )
    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (2): View()
  (3): Sequential(
    (0): Linear(in_features=576, out_features=120, bias=True)
    (1): ReLU(
      sub=0.3807474374771118
      (_previous): Linear(in_features

In [18]:
setup_and_train(1) # run with LSUV std only

('LSUV type', 1, 0.3195320963859558, 0.9998204708099365)
('LSUV type', 1, 0.40966179966926575, 0.9998111724853516)
('LSUV type', 1, 0.6362127065658569, 1.0006959438323975)
('LSUV type', 1, 0.7175247073173523, 0.9999966025352478)
Sequential(
  (0): Sequential(
    (0): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU(
      sub=0
      (_previous): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))
    )
    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (1): Sequential(
    (0): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU(
      sub=0
      (_previous): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
    )
    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (2): View()
  (3): Sequential(
    (0): Linear(in_features=576, out_features=120, bias=True)
    (1): ReLU(
      sub=0
      (_previous): Linear(in_features=576, out_features=120, bias=True)
    )
  )
  (4)

In [19]:
setup_and_train(0) # run without LSUV

Sequential(
  (0): Sequential(
    (0): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU(sub=0)
    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (1): Sequential(
    (0): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU(sub=0)
    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (2): View()
  (3): Sequential(
    (0): Linear(in_features=576, out_features=120, bias=True)
    (1): ReLU(sub=0)
  )
  (4): Sequential(
    (0): Linear(in_features=120, out_features=84, bias=True)
    (1): ReLU(sub=0)
  )
  (5): Linear(in_features=84, out_features=10, bias=True)
)
epoch 1 lr 0.015 loss 2.2933 last 0.3849 min 0.3849 max 2.3155 items 60000
accuracy over 10000 test images: 84.36
epoch 2 lr 0.01 loss 0.4391 last 0.3158 min 0.2247 max 0.572 items 60000
accuracy over 10000 test images: 90.4
epoch 3 lr 0.005 loss 0.3011 last 0.467 min 0.1641 max 0.467 items 60000
accuracy over 