In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os

from components.imports import *
from components.lightning import train as _train, LightningModel
from components.lightning.telemetry import Callback
from components.models import ModelBuilder
from components.models.base import basic_model_head
from components.metrics import accuracy
from components.datasets import imagenette
from components.utils import find_all
from components.models.resnet import ResNet

In [None]:
import seaborn as sns
sns.set_theme('paper')

This is the accompanying notebook to the "Replacing BatchNorm" blog post. For a full description of the ideas and outcomes of this notebook, please see the article.

We cover a few different methods of normalising layers in a neural network:
- BatchNorm
- Increasing Epsilon
- Running BatchNorm
- LayerNorm, GroupNorm, InstanceNorm
- Initialise to normalise
- Weight standardisation

On the Imagenette leaderboard, the best result after 5 epochs at 128px is 85%. The networks and image sizes we use here are smaller than those used on the leaderboard, so we should expect a somewhat lower but still comparable result (peaking at around 75%).

In [None]:
# Data: Imagenette at 96px. Set the IMAGENETTE_DIR environment variable to point elsewhere.
IMAGENETTE_DIR = os.environ.get('IMAGENETTE_DIR', '/home/sara/datasets/imagenette2-160/')
IMAGE_SIZE = 96
VAL_SUBSET_FRAC = 0.05  # fraction of the validation set used for the (expensive) Hessian analysis
SEED = 0

train_ds, val_ds = imagenette(IMAGENETTE_DIR, IMAGE_SIZE)

# Pick a reproducible subset of *distinct* validation indices. np.random.choice defaults to
# replace=True, which would silently duplicate some images and shrink the effective subset.
rng = np.random.default_rng(SEED)
sub_val_idx = rng.choice(len(val_ds), size=int(VAL_SUBSET_FRAC * len(val_ds)), replace=False)
sub_val_dl = torch.utils.data.DataLoader(
    val_ds, batch_size=256,
    sampler=torch.utils.data.SubsetRandomSampler(sub_val_idx))

## The Baseline (No BatchNorm)

In [None]:
class ConvBlock(nn.Sequential):
    """Stride-2 3x3 convolution, then an activation, then an optional norm layer.

    Args:
        cin: number of input channels.
        cout: number of output channels.
        bn: optional factory called as ``bn(cout)``; when None no norm layer is appended.
        act: activation class, instantiated with no arguments.

    Note the ordering used throughout this notebook: conv -> activation -> norm.
    """
    def __init__(self, cin, cout, bn=None, act=nn.ReLU):
        conv = nn.Conv2d(cin, cout, 3, padding=1, stride=2, bias=False)
        stages = [conv, act()]
        if bn is not None:
            stages.append(bn(cout))
        super().__init__(*stages)
    
class Network(ModelBuilder):
    """Small 4-layer CNN ("SmallCNN") made of ``ConvBlock``s plus a basic classification head.

    Args:
        bn: norm-layer factory forwarded to every ``ConvBlock`` (None for no normalisation).
        categories: number of output classes.

    ``ConvBlock`` is looked up when the network is constructed, so redefining ``ConvBlock``
    in a later cell changes every ``Network`` built afterwards.
    """
    __name__ = 'SmallCNN'

    def __init__(self, bn, categories=10):
        channel_pairs = [(3, 8), (8, 16), (16, 32), (32, 32)]
        # One [cin, cout, bn] argument list per block (presumably unpacked into
        # ConvBlock(cin, cout, bn) by ModelBuilder).
        conv_params = [[cin, cout, bn] for cin, cout in channel_pairs]
        super().__init__(ConvBlock, conv_params, head=basic_model_head(32, categories))

class Baseline(LightningModel):
    """LightningModel wrapper shared by every experiment in this notebook.

    Args:
        hparams: hyper-parameter dict passed through to ``LightningModel``.
        norm_layer: factory called as ``norm_layer(num_features)``, or None for no normalisation.
        use_resnet: if True, build ``ResNet(18, 10)`` and replace each of its ``nn.BatchNorm2d``
            layers with ``norm_layer`` (or ``nn.Identity`` when None); otherwise build the small
            ``Network`` CNN.

    The model is re-initialised with ``init_cnn_`` and printed on construction.
    """
    def __init__(self, hparams, norm_layer, use_resnet=False):
        if not use_resnet:
            model = Network(norm_layer)
        else:
            model = ResNet(18, 10)
            model.__name__ = 'ResNet18'
            bn_modules, bn_paths = find_all(model, nn.BatchNorm2d, path=True)
            for bn_module, bn_path in zip(bn_modules, bn_paths):
                if norm_layer is None:
                    model[bn_path] = nn.Identity()
                else:
                    model[bn_path] = norm_layer(bn_module.num_features)
        super().__init__(hparams, model, nn.CrossEntropyLoss(),
                         train_ds=train_ds, val_ds=val_ds, metrics=[accuracy])
        self.reset()
        print(model)

    def reset(self):
        """Re-initialise every sub-module of ``self.model`` with ``init_cnn_``."""
        self.model.apply(init_cnn_)
        
def init_cnn_(m):
    """In-place initialiser meant to be used as ``model.apply(init_cnn_)``.

    - zeroes every ``bias`` that exists;
    - Kaiming-normal initialises Conv2d (including subclasses such as ``WSConv``) and Linear weights;
    - resets BatchNorm1d/2d/3d, GroupNorm and LayerNorm scales to ones (when affine) and clears
      BatchNorm running statistics, so ``Baseline.reset()`` is a genuine reset even after training.
      Custom layers (e.g. ``RunningBatchNorm``) only get their bias zeroed.
    """
    if getattr(m, 'bias', None) is not None:
        nn.init.zeros_(m.bias)
    # Use the class from its defining module: later cells rebind ``nn.Conv2d = WSConv``, which
    # would otherwise stop plain Conv2d layers from being initialised here.
    if isinstance(m, (nn.modules.conv.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)
    batch_norms = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    # Non-affine norm layers have weight=None; skip them instead of crashing.
    if isinstance(m, batch_norms + (nn.GroupNorm, nn.LayerNorm)) and getattr(m, 'weight', None) is not None:
        nn.init.ones_(m.weight)
    if isinstance(m, batch_norms):
        m.reset_running_stats()

This `train` function is the core of the experiments. It lets us choose the normalisation method, the hyper-parameters to grid search over, whether to show telemetry, how many runs to do (reporting the standard deviation across them), and whether to use a ResNet18 or the small baseline network (4 conv layers).

In [None]:
def train(normalisation=None, hparams=None, max_epochs=3, telemetry=True, 
          n_runs=3, use_resnet=False, **kwargs):
    """Build a ``Baseline`` with the given norm layer, train it, and plot its logged metrics.

    Args:
        normalisation: norm-layer factory (e.g. ``nn.BatchNorm2d``) or None for no normalisation.
        hparams: overrides merged into the default grid
            ``{'lr': [1e-2, 1e-3], 'bs': [128, 2], 'sched': None}`` (list values are the
            candidates grid-searched over).
        max_epochs, telemetry, n_runs, **kwargs: forwarded to ``_train``.
        use_resnet: use ResNet18 instead of the small CNN.

    Returns:
        The trained ``Baseline`` module.
    """
    grid = dict(lr=[1e-2, 1e-3], bs=[128, 2], sched=None)
    if hparams is not None:
        grid.update(hparams)
    model = Baseline(grid, normalisation, use_resnet)
    _train(model, max_epochs=max_epochs, save_top_k=0, telemetry=telemetry,
           n_runs=n_runs, **kwargs)
    model.logger.plot()
    return model

In [None]:
# Baseline: small 4-layer CNN with no normalisation, default lr/bs grid.
train()

In [None]:
# Baseline: ResNet18 with every BatchNorm2d replaced by nn.Identity.
train(use_resnet=True)

## BatchNorm Baseline

In [None]:
# BatchNorm after every conv block of the small CNN.
train(nn.BatchNorm2d)

In [None]:
# ResNet18 with freshly constructed BatchNorm2d layers.
train(nn.BatchNorm2d, use_resnet=True)

### Increasing epsilon

In [None]:
# BatchNorm with a much larger epsilon (1e-1) on the small CNN.
train(lambda x: nn.BatchNorm2d(x, eps=1e-1))

In [None]:
# BatchNorm with a much larger epsilon (1e-1) on ResNet18.
train(lambda x: nn.BatchNorm2d(x, eps=1e-1), use_resnet=True)

## Running BatchNorm

This idea is in fact considered in the WS paper: its BCN layer can use estimates of the statistics at training time. The authors also cite Batch Renormalization (https://arxiv.org/abs/1702.03275) as having proposed the same idea first.

However, neither of them uses the "true" statistics accumulated up to that point; both rely on a moving average of the statistics instead.

Interestingly, I think it is important to note that the gradients of a "detached" normalisation and of batch normalisation are very different, and it is this difference that seemingly smooths training. If simply normalising the activations were all that was needed, then the initialise-to-normalise method described below would perform equally well. This interpretation, however, seems to be missing from the BCN part of the WS paper: there the authors say that BN helps remove "elimination singularities" and that CN (GN) helps stabilise BN for small batches.

This means that the running batch norm updates are, in fact, quite different from BN updates in terms of the gradients propagated: they lack the smoothing effect that comes from backpropagating through the batch statistics.

In [None]:
class RunningBatchNorm(nn.Module):
    """Running Batch Norm layer from fast.ai:
    https://github.com/fastai/course-v3/blob/master/nbs/dl2/07_batchnorm.ipynb

    Uses the running calculations at training time for batch norm.

    Instead of the current batch's statistics, it normalises with exponentially-weighted
    running sums of x and x**2 per channel, in both training and eval mode. In training
    mode the current batch is first folded into those sums, so gradients flow only through
    the current batch's share of the statistics.

    Args:
        nf: number of channels.
        mom: base momentum; the per-step momentum is adjusted for batch size in
            ``update_stats``.
        eps: added to the variance before the square root.

    NOTE(review): calling ``forward`` in eval mode before any training step divides by
    ``dbias == 0`` and yields NaNs.
    """
    def __init__(self, nf, mom=0.1, eps=1e-5):
        super().__init__()
        self.mom,self.eps = mom,eps
        # Per-channel affine, broadcast against (N, C, H, W) inputs.
        self.weight = nn.Parameter(torch.ones(nf,1,1))
        self.bias = nn.Parameter(torch.zeros(nf,1,1))
        # Running (moving-average) sums of x and x*x per channel.
        self.register_buffer('sums', torch.zeros(1,nf,1,1))
        self.register_buffer('sqrs', torch.zeros(1,nf,1,1))
        # Total number of samples seen so far.
        self.register_buffer('batch', torch.tensor(0.))
        # Moving average of the number of elements per channel per batch.
        self.register_buffer('count', torch.tensor(0.))
        # Number of statistics updates performed.
        self.register_buffer('step', torch.tensor(0.))
        # Debiasing factor for the zero-initialised moving averages.
        self.register_buffer('dbias', torch.tensor(0.))

    def update_stats(self, x):
        """Fold batch ``x`` into the running statistics (reduces over dims 0, 2, 3, so
        ``x`` is assumed to be (N, C, H, W))."""
        bs,nc,*_ = x.shape
        # Cut the autograd history from previous batches; only this batch's contribution
        # (mixed in by lerp_ below) remains differentiable.
        self.sums.detach_()
        self.sqrs.detach_()
        dims = (0,2,3)
        s = x.sum(dims, keepdim=True)
        ss = (x*x).sum(dims, keepdim=True)
        # Elements per channel in this batch (N*H*W).
        c = self.count.new_tensor(x.numel()/nc)
        # Earlier variant, kept for reference:
#         mom1 = 1 - (1-self.mom)/math.sqrt(bs-1)
        # Batch-size-adjusted momentum: larger batches give the new statistics more weight.
        mom1 = 1 - (1-self.mom)/math.sqrt(bs)
        self.mom1 = self.dbias.new_tensor(mom1)
        # Exponential moving averages: a <- a + mom1 * (new - a).
        self.sums.lerp_(s, self.mom1)
        self.sqrs.lerp_(ss, self.mom1)
        self.count.lerp_(c, self.mom1)
        # Keeps dbias == 1 - prod(1 - mom1) over all steps; dividing by it removes the
        # startup bias of the zero-initialised averages.
        self.dbias = self.dbias*(1-self.mom1) + self.mom1
        self.batch += bs
        self.step += 1

    def forward(self, x):
        """Normalise ``x`` with the running statistics, then apply the per-channel affine."""
        # Only training mode updates the statistics; eval reuses the stored ones.
        if self.training: self.update_stats(x)
        sums = self.sums
        sqrs = self.sqrs
        c = self.count
        # Debias during the first 100 steps; afterwards dbias is ~1 (mom1 >= mom, so
        # 1 - dbias <= (1 - mom)**100).
        if self.step<100:
            sums = sums / self.dbias
            sqrs = sqrs / self.dbias
            c    = c    / self.dbias
        means = sums/c
        # Var = E[x^2] - E[x]^2 per channel.
        vars = (sqrs/c).sub_(means*means)
        # Floor the variance while fewer than 20 samples have been seen.
        if bool(self.batch < 20): vars.clamp_min_(0.01)
        x = (x-means).div_((vars.add_(self.eps)).sqrt())
        return x.mul_(self.weight).add_(self.bias)

In [None]:
# RunningBatchNorm on the small CNN: running statistics are used even during training.
train(RunningBatchNorm)

In [None]:
# RunningBatchNorm in place of every BatchNorm2d in ResNet18.
train(RunningBatchNorm, use_resnet=True)

## LayerNorm, GroupNorm, InstanceNorm

In [None]:
# LayerNorm equivalent: GroupNorm with a single group normalises over all of (C, H, W) per sample.
train(lambda x: nn.GroupNorm(1,x))

In [None]:
# LayerNorm equivalent (GroupNorm with one group) on ResNet18.
train(lambda x: nn.GroupNorm(1,x), use_resnet=True)

In [None]:
# GroupNorm on the small CNN. The GroupNorm paper defaults to 32 groups; candidates
# 2/4/8 were considered here, and 4 seems to work well for the small CNN
# (try more groups for ResNet).
n_groups = 4
train(lambda x: nn.GroupNorm(n_groups, x))

In [None]:
# GroupNorm with 32 groups (the GroupNorm paper's default) on ResNet18.
n_groups = 32
train(lambda x: nn.GroupNorm(n_groups, x), use_resnet=True)

## Which Part of BN Actually Helps?

Is it the normalisation itself, or the learnable scaling and bias parameters?

In [None]:
class BatchScaler2d(nn.Module):
    """Learnable per-channel scale and bias with no normalisation: ``y = x * weight + bias``.

    Isolates the affine part of BatchNorm. ``weight`` starts at ones and ``bias`` at zeros,
    so the layer is the identity at initialisation; no LSUV/data-dependent initialisation
    is performed.
    """
    def __init__(self, channels):
        super().__init__()
        shape = (1, channels, 1, 1)
        self.weight = nn.Parameter(torch.ones(*shape))
        self.bias = nn.Parameter(torch.zeros(*shape))

    def forward(self, x):
        scaled = x * self.weight
        return scaled + self.bias
    
class Normalise(nn.Module):
    """Per-channel normalisation with no learnable scale or bias.

    Training mode normalises with the batch mean/variance over dims (0, 2, 3) and updates
    exponential running estimates; eval mode normalises with the running estimates.
    ``torch.var_mean`` is used with its default (unbiased) variance for both.

    Args:
        c: number of channels.
        detach: if True, the batch statistics are detached, so no gradient flows through them.
        eps: added to the variance before the square root.
        momentum: weight given to the new batch statistics in the running averages.
    """
    def __init__(self, c, detach=False, eps=1e-5, momentum=0.1):
        super().__init__()
        self.register_buffer('running_mean', torch.zeros(1,c,1,1))
        self.register_buffer('running_var', torch.ones(1,c,1,1))
        self.d = detach
        self.eps = eps
        self.momentum = momentum

    def forward(self, x):
        if self.training:
            var, mean = torch.var_mean(x, dim=(0,2,3), keepdim=True)
            if self.d:
                var, mean = var.detach(), mean.detach()
            with torch.no_grad():
                self.running_mean.lerp_(mean, self.momentum)
                self.running_var.lerp_(var, self.momentum)
        else:
            var, mean = self.running_var, self.running_mean
        return x.sub(mean).div_((var + self.eps).sqrt())

In [None]:
# Scale and bias only (BatchScaler2d), no normalisation.
# NOTE: the small CNN is quite small for this test; ResNet18 (use_resnet=True) is probably a better choice.
train(BatchScaler2d)

In [None]:
# Normalisation only (batch statistics, no learnable scale/bias); gradients flow through the mean/variance.
train(Normalise)

In [None]:
# Normalisation only, with the batch mean/variance detached: no gradient flows through the statistics.
train(lambda x: Normalise(x, True))

## Weight Standardisation

In [None]:
eps = 1e-4  # added to the per-filter variance in weight standardisation


def ws(weight):
    """Weight Standardisation: give each output filter zero mean and unit variance.

    Statistics are taken over every dimension except the first (output channels), so this
    works for Conv1d/2d/3d and Linear weights alike. Reads the module-level ``eps`` at call
    time. No Kaiming rescaling is applied.

    Args:
        weight: tensor with at least 2 dims whose first dim indexes output channels/features.

    Returns:
        A new tensor shaped like ``weight`` (differentiable w.r.t. ``weight``).

    Raises:
        ValueError: if ``weight`` has fewer than 2 dimensions.
    """
    if weight.dim() < 2:
        raise ValueError('ws expects a weight with at least 2 dims, got shape %s' % (tuple(weight.shape),))
    dims = tuple(range(1, weight.dim()))
    mu = weight.mean(dims, keepdim=True)
    std = torch.sqrt(weight.var(dims, keepdim=True) + eps)
    return (weight - mu) / std


class WSConv(nn.Conv2d):
    """Conv2d whose weight is weight-standardised (see ``ws``) on every forward pass.

    The stored ``weight`` parameter is left untouched; standardisation is applied on the
    fly, so gradients flow through it.
    """
    def forward(self, input):
        weight = ws(self.weight)
        # Call the functional conv directly rather than the private ``_conv_forward``, whose
        # signature gained a ``bias`` argument in newer PyTorch versions (the two-argument
        # call raises a TypeError there). This mirrors nn.Conv2d's own logic.
        if self.padding_mode != 'zeros':
            input = nn.functional.pad(input, self._reversed_padding_repeated_twice,
                                      mode=self.padding_mode)
            padding = 0
        else:
            padding = self.padding
        return nn.functional.conv2d(input, weight, self.bias, self.stride, padding,
                                    self.dilation, self.groups)

class ConvBlock(nn.Sequential):
    """Weight-standardised block: WSConv (3x3, stride 2) -> activation -> optional norm layer.

    This deliberately shadows the earlier ConvBlock. ``Network`` looks the name up when it is
    constructed, so every small CNN built after this cell uses weight standardisation.
    """
    def __init__(self, cin, cout, bn=None, act=nn.ReLU):
        modules = [WSConv(cin, cout, 3, padding=1, stride=2, bias=False), act()]
        if bn is not None:
            modules += [bn(cout)]
        super().__init__(*modules)

In [None]:
# Weight standardisation + GroupNorm (4 groups) on the small CNN; Network now builds the
# WSConv-based ConvBlock defined in the previous cell.
n_groups = 4
train(lambda x: nn.GroupNorm(n_groups, x))

In [None]:
# Weight standardisation + GroupNorm (32 groups) on ResNet18. nn.Conv2d is swapped for
# WSConv only for this run (presumably ResNet looks nn.Conv2d up at construction time) and
# then restored, so later cells are not silently built with weight-standardised convs.
n_groups = 32
_orig_conv2d = nn.Conv2d
nn.Conv2d = WSConv
try:
    train(lambda x: nn.GroupNorm(n_groups, x), use_resnet=True)
finally:
    nn.Conv2d = _orig_conv2d

## Batch Channel Normalisation

In [None]:
import torch
from torch import nn

In [None]:
# From author implementation: https://github.com/joe-siyuan-qiao/Batch-Channel-Normalization
class BCNorm(nn.Module):
    """Batch-Channel Normalisation: a batch-norm step followed by a channel (group) norm.

    Args:
        num_channels: number of input channels.
        num_groups: number of channel groups for the channel-normalisation step.
        eps: epsilon for the channel-normalisation step.
        estimate: if True the batch step is ``EstBN`` (normalises with running estimates),
            otherwise a standard ``nn.BatchNorm2d``.

    The learnable affine applied after grouping is per group (shape ``(1, num_groups, 1)``).
    """

    def __init__(self, num_channels, num_groups, eps, estimate=True):
        super().__init__()
        self.num_channels = num_channels
        self.num_groups = num_groups
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(1, num_groups, 1))
        self.bias = nn.Parameter(torch.zeros(1, num_groups, 1))
        self.bn = EstBN(num_channels) if estimate else nn.BatchNorm2d(num_channels)

    def forward(self, inp):
        batch = inp.size(0)
        x = self.bn(inp)
        # Treat each (sample, group) pair as one channel of a single-sample batch, so
        # batch_norm (training mode, no running stats, no affine) normalises within every
        # group of every sample, i.e. group normalisation.
        x = x.view(1, batch * self.num_groups, -1)
        x = torch.batch_norm(x, None, None, None, None, True, 0, self.eps, True)
        x = x.view(batch, self.num_groups, -1)
        x = self.weight * x + self.bias
        return x.view_as(inp)

class EstBN(nn.Module):
    """Batch norm that normalises with running estimates, even in training mode.

    In training mode each forward pass first updates ``running_mean`` / ``running_var``
    from the batch (without gradient). It then normalises ``inp`` with the running
    statistics and applies a per-channel affine. The update rate is the
    ``estbn_moving_speed`` buffer.

    Args:
        num_features: number of channels (dim 1 of the input).
        momentum: initial value of ``estbn_moving_speed``. The default 0.0 keeps the
            running statistics at their initial values (mean 0, var 1) unless
            ``estbn_moving_speed`` is changed externally.

    NOTE(review): nothing in this notebook sets ``estbn_moving_speed``. With the default
    ``momentum``, BCNorm's batch step is therefore effectively a per-channel affine layer.
    """

    def __init__(self, num_features, momentum=0.0):
        super().__init__()
        self.num_features = num_features
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        self.register_buffer('estbn_moving_speed', torch.full((1,), float(momentum)))

    def forward(self, inp):
        ms = self.estbn_moving_speed.item()
        if self.training:
            with torch.no_grad():
                # (C, N*H*W) view so statistics are per channel.
                inp_t = inp.transpose(0, 1).contiguous().view(self.num_features, -1)
                batch_mean = inp_t.mean(dim=1)
                # The variance is measured around the *current* running mean (before it
                # is updated below), as in the reference implementation.
                centred = inp_t - self.running_mean.view(-1, 1)
                batch_var = torch.mean(centred * centred, dim=1)
                self.running_mean.mul_(1 - ms).add_(ms * batch_mean)
                self.running_var.mul_(1 - ms).add_(ms * batch_var)
        shape = (1, -1, 1, 1)
        out = inp - self.running_mean.view(shape)
        out = out / torch.sqrt(self.running_var + 1e-5).view(shape)
        return self.weight.view(shape) * out + self.bias.view(shape)

In [None]:
class BCN(nn.Module):
    """Batch-Channel Normalisation from this notebook's layers: a batch-norm step then GroupNorm.

    Args:
        c: number of channels.
        groups: number of GroupNorm groups; must divide ``c``.
        bn: optional factory called as ``bn(c)`` for the batch step. Defaults to
            ``RunningBatchNorm``; pass ``nn.BatchNorm2d`` for the standard layer.

    Raises:
        ValueError: if ``c`` is not divisible by ``groups``.
    """
    def __init__(self, c, groups, bn=None):
        super().__init__()
        if c % groups != 0:
            raise ValueError('channels (%d) must be divisible by groups (%d)' % (c, groups))
        self.bn = RunningBatchNorm(c) if bn is None else bn(c)
        self.cn = nn.GroupNorm(groups, c)

    def forward(self, x):
        return self.cn(self.bn(x))

In [None]:
# BCNorm (EstBN batch step + 4-group channel norm) on the small CNN.
# NOTE(review): EstBN's running statistics only move if its estbn_moving_speed is non-zero.
train(lambda x: BCNorm(x, 4, 1e-5))

In [None]:
# BCNorm (32 groups) + weight standardisation on ResNet18. nn.Conv2d is swapped for WSConv
# only for this run and then restored, so later cells (e.g. the eigenvalue and
# visualisation experiments) are not silently built with weight-standardised convs.
_orig_conv2d = nn.Conv2d
nn.Conv2d = WSConv
try:
    train(lambda x: BCNorm(x, 32, 1e-5), use_resnet=True)
finally:
    nn.Conv2d = _orig_conv2d

In [None]:
# BCN built from RunningBatchNorm + GroupNorm (4 groups) on the small CNN.
train(lambda x: BCN(x, 4))

## Eigenvalues

In [None]:
# From https://github.com/tomgoldstein/loss-landscape

import torch
import time
import numpy as np
from torch import nn
from torch.autograd import Variable
from scipy.sparse.linalg import LinearOperator, eigsh

################################################################################
#                              Supporting Functions
################################################################################
def npvec_to_tensorlist(vec, params):
    """Split a flat numpy vector into float32 tensors shaped like ``params``.

    Args:
        vec: 1-D numpy array holding all parameter values back to back.
        params: sequence of tensors/parameters giving the target shapes, in order.

    Returns:
        List of CPU float32 tensors, one per entry of ``params``.

    Raises:
        AssertionError: if ``vec`` is not exactly consumed by ``params``.
    """
    pieces = []
    start = 0
    for p in params:
        stop = start + p.data.numel()
        pieces.append(torch.from_numpy(vec[start:stop]).view(p.data.shape).float())
        start = stop
    assert start == vec.size, 'The vector has more elements than the net has parameters'
    return pieces


def gradtensor_to_npvec(net, include_bn=False):
    """Concatenate the parameter gradients of ``net`` into one flat numpy vector.

    Args:
        net: module whose selected parameters all have ``.grad`` populated.
        include_bn: if True, include every parameter's gradient. Otherwise only parameters
            with more than one dimension are included (biases and norm scales are skipped).

    Returns:
        1-D numpy array, in ``net.parameters()`` order.
    """
    chunks = []
    for p in net.parameters():
        if not include_bn and p.data.dim() <= 1:
            continue
        chunks.append(p.grad.data.cpu().numpy().ravel())
    return np.concatenate(chunks)


################################################################################
#                  For computing Hessian-vector products
################################################################################
def eval_hess_vec_prod(vec, params, net, criterion, dataloader, use_cuda=False):
    """
    Evaluate product of the Hessian of the loss function with a direction vector "vec".
    The product result is saved in the grad of net.

    The Hessian is that of the sum over batches of ``criterion``'s per-batch value, with
    ``net`` in eval mode. It is computed by double backprop:
    H v = d/dparams <grad(loss), v>.

    Args:
        vec: a list of tensor with the same dimensions as "params".
        params: the parameter list of the net (ignoring biases and BN parameters).
        net: model with trained parameters.
        criterion: loss function.
        dataloader: iterable of (inputs, targets) batches.
        use_cuda: use GPU.
    """

    if use_cuda:
        net.cuda()
        vec = [v.cuda() for v in vec]

    net.eval()
    net.zero_grad() # clears grad for every parameter in the net

    for inputs, targets in dataloader:
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()

        loss = criterion(net(inputs), targets)
        grad_f = torch.autograd.grad(loss, inputs=params, create_graph=True)

        # Inner product <grad, vec>, kept on the gradients' device. This avoids a host
        # round trip per term and the deprecated ``Variable`` wrapper.
        prod = sum((g * v).sum() for g, v in zip(grad_f, vec))

        # Differentiating <grad, vec> w.r.t. the parameters gives H*vec; backward()
        # accumulates it into each parameter's .grad across batches.
        prod.backward()


################################################################################
#                  For computing Eigenvalues of Hessian
################################################################################
def min_max_hessian_eigs(net, dataloader, criterion, rank=0, use_cuda=False, verbose=False):
    """
        Compute the largest and the smallest eigenvalues of the Hessian matrix.

        The Hessian is never formed explicitly. scipy's ``eigsh`` is driven by a
        ``LinearOperator`` whose matvec is a Hessian-vector product
        (``eval_hess_vec_prod``) w.r.t. the parameters with more than one dimension.

        Args:
            net: the trained model.
            dataloader: dataloader for the dataset, may use a subset of it.
            criterion: loss function.
            rank: rank of the working node.
            use_cuda: use GPU
            verbose: print more information

        Returns:
            maxeig: max eigenvalue
            mineig: min eigenvalue
            hess_vec_prod.count: number of Hessian-vector products evaluated for both
                eigenvalues (NOTE(review): may include dtype-probing calls scipy's
                LinearOperator makes when no dtype is given; confirm for your scipy version).
    """

    # Weight tensors only (dim > 1). This matches gradtensor_to_npvec's default filter,
    # so matvec inputs and outputs both have length N.
    params = [p for p in net.parameters() if len(p.size()) > 1]
    N = sum(p.numel() for p in params)

    def hess_vec_prod(vec):
        hess_vec_prod.count += 1  # simulates a static variable
        vec = npvec_to_tensorlist(vec, params)
        start_time = time.time()
        eval_hess_vec_prod(vec, params, net, criterion, dataloader, use_cuda)
        prod_time = time.time() - start_time
        if verbose and rank == 0: print("   Iter: %d  time: %f" % (hess_vec_prod.count, prod_time))
        # eval_hess_vec_prod leaves H*vec in the parameters' .grad fields.
        return gradtensor_to_npvec(net)

    hess_vec_prod.count = 0
    if verbose and rank == 0: print("Rank %d: computing max eigenvalue" % rank)

    # NOTE(review): eigsh's default which='LM' returns the largest-*magnitude* eigenvalue,
    # which can be negative.
    A = LinearOperator((N, N), matvec=hess_vec_prod)
    eigvals, eigvecs = eigsh(A, k=1, tol=1e-2)
    maxeig = eigvals[0]
    if verbose and rank == 0: print('max eigenvalue = %f' % maxeig)

    # If the largest eigenvalue is positive, shift matrix so that any negative eigenvalue is now the largest
    # We assume the smallest eigenvalue is zero or less, and so this shift is more than what we need
    shift = maxeig*.51
    def shifted_hess_vec_prod(vec):
        # (H - shift*I) v
        return hess_vec_prod(vec) - shift*vec

    if verbose and rank == 0: print("Rank %d: Computing shifted eigenvalue" % rank)

    A = LinearOperator((N, N), matvec=shifted_hess_vec_prod)
    eigvals, eigvecs = eigsh(A, k=1, tol=1e-2)
    # Undo the shift to get an eigenvalue of H itself.
    eigvals = eigvals + shift
    mineig = eigvals[0]
    if verbose and rank == 0: print('min eigenvalue = ' + str(mineig))

    # If the first search returned a non-positive eigenvalue and the shifted search a
    # positive one, swap them so that maxeig > mineig.
    if maxeig <= 0 and mineig > 0:
        maxeig, mineig = mineig, maxeig

    return maxeig, mineig, hess_vec_prod.count

In [None]:
import copy
import math
import torch
import numpy as np
from torch import nn
from tqdm.autonotebook import tqdm
import matplotlib.pyplot as plt

from interpret.misc import get_state_dicts, normalize_direction, get_rand_dir, plot_loss_landscape

In [None]:
# Small CNN without normalisation, trained for 10 epochs, for the Hessian eigenvalue comparison.
# NOTE(review): ConvBlock was redefined in the weight-standardisation section to use WSConv,
# and Network looks ConvBlock up at construction, so this "no normalisation" model also
# uses weight-standardised convolutions. Confirm that is intended.
trained_none = train(hparams={'lr': 0.001, 'bs': 2}, n_runs=1, telemetry=False, max_epochs=10)

In [None]:
# Small CNN with BatchNorm, trained for 10 epochs, for the Hessian eigenvalue comparison.
# NOTE(review): like trained_none, this is built from the WSConv-based ConvBlock redefinition.
trained_bn = train(nn.BatchNorm2d, hparams={'lr': 0.005, 'bs': 128}, n_runs=1, telemetry=False, max_epochs=10)

In [None]:
def hessian_eigs(network, dataloader, loss_fn=nn.CrossEntropyLoss(), dir1=None, dir2=None,
                   dir1_bound=(-1,1,20), dir2_bound=(-1,1,20), device=None, verbose=False):
    """Map |min eigenvalue / max eigenvalue| of the loss Hessian over a 2-D slice of weight space.

    ``get_state_dicts`` yields the trained weights perturbed along ``dir1`` / ``dir2`` on
    the ``(start, stop, num)`` grid given by the bounds. At each point
    ``min_max_hessian_eigs`` is run on ``dataloader``.

    Args:
        network: model to probe. Its trained weights are restored afterwards, even on error.
        dataloader: data for the Hessian-vector products (a small subset keeps this tractable).
        loss_fn: loss criterion.
        dir1, dir2: perturbation directions forwarded to ``get_state_dicts``.
        dir1_bound, dir2_bound: ``(start, stop, num)`` grid along each direction.
        device: 'cuda' / 'cpu' / torch.device; defaults to CUDA when available. Determines
            whether the Hessian-vector products run on the GPU.
        verbose: print the raw (max, min) eigenvalue pair at every grid point.

    Returns:
        ``(X, Y, Z)``: meshgrid coordinates of shape ``(dir2 num, dir1 num)`` and the matching
        ratio array ``Z``. NOTE(review): assumes ``get_state_dicts`` iterates dir1 in the
        outer loop; confirm.
    """
    device = ('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
    use_cuda = torch.device(device).type == 'cuda'
    trained_sd = copy.deepcopy(network.state_dict())

    x_pts, y_pts = dir1_bound[2], dir2_bound[2]
    eigs = []
    try:
        grid = get_state_dicts(trained_sd, dir1, dir2,
                               dir1_bound=dir1_bound, dir2_bound=dir2_bound)
        for sd in tqdm(grid, total=x_pts * y_pts, desc='Generating eigs'):
            network.load_state_dict(sd)
            maxeig, mineig, _ = min_max_hessian_eigs(network, dataloader, loss_fn, use_cuda=use_cuda)
            if verbose:
                tqdm.write('%s %s' % (maxeig, mineig))
            eigs.append(abs(mineig / maxeig))
    finally:
        # Restore the trained weights even if the sweep fails or is interrupted.
        network.load_state_dict(trained_sd)

    X, Y = np.meshgrid(np.linspace(*dir1_bound), np.linspace(*dir2_bound))
    Z = np.array(eigs).reshape((x_pts, y_pts)).T
    return X, Y, Z

def plot_eigs(out, title, show_title=False, vmin=0, vmax=0.5):
    """Plot the eigenvalue-ratio grid returned by ``hessian_eigs`` as a heat map.

    Args:
        out: ``(X, Y, Z)`` tuple from ``hessian_eigs``; only ``Z`` is drawn.
        title: figure title. Drawn only when ``show_title`` is True; titles are off by
            default so the figures stay untitled.
        show_title: draw ``title`` above the plot.
        vmin, vmax: colour-scale limits.

    Side effect: sets the global seaborn style to 'white'.
    """
    sns.set_style('white')
    fig, ax = plt.subplots(figsize=(12, 10))
    image = ax.imshow(out[2], cmap='viridis', vmin=vmin, vmax=vmax)
    if show_title:
        ax.set_title(title)
    fig.colorbar(image, ax=ax)
    ax.axis('off')

In [None]:
# Eigenvalue-ratio map around the BatchNorm network (16x16 grid, validation subset).
out = hessian_eigs(trained_bn, sub_val_dl, dir1_bound=(-1,1,16), dir2_bound=(-1,1,16))
plot_eigs(out, 'BatchNorm Hessian Eigs')

In [None]:
# Eigenvalue-ratio map around the network without normalisation (same grid and data).
out = hessian_eigs(trained_none, sub_val_dl, dir1_bound=(-1,1,16), dir2_bound=(-1,1,16))
plot_eigs(out, 'No BatchNorm Hessian Eigs')

## Visualisations

In [None]:
class ReLUFuncOverride:
    """Context manager that temporarily replaces ``F.relu`` with a straight-through ReLU.

    While active, calls to ``F.relu`` go through ``GradReLU`` (ReLU forward, identity
    backward). After more than ``iters`` calls the original ``F.relu`` is reinstated
    automatically; ``__exit__`` always restores it. Only code that looks up ``F.relu`` at
    call time is affected, and any ``inplace`` argument is ignored.

    Args:
        iters: number of ``F.relu`` calls after which the override removes itself.
    """
    def __init__(self, iters=50):
        self._orig_relu = F.relu
        self.iters = iters
        self.i = 0

        def fn(*args, **kwargs):
            self.i += 1
            if self.i > self.iters:
                F.relu = self._orig_relu
            return GradReLU.apply(args[0])

        self.fn = fn

    def __enter__(self):
        # Restart the call budget so the same instance can be reused.
        self.i = 0
        F.relu = self.fn
        return self

    def __exit__(self, *args):
        F.relu = self._orig_relu
        
class GradReLU(torch.autograd.Function):
    """ReLU in the forward pass, identity (straight-through) in the backward pass.

    Gradients pass through unchanged even where the input was negative (presumably so
    feature visualisation can optimise units that are currently inactive).
    """
    @staticmethod
    def forward(ctx, x, inplace=False):
        # ``inplace`` mirrors F.relu's signature but is ignored; the input is never modified.
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, grads):
        # One gradient per forward input: identity for ``x``, None for the ``inplace`` flag.
        # Autograd drops the trailing None when apply() was called with ``x`` only.
        return grads, None

In [None]:
import pytorch_lightning as pl
from interpret import OptVis, denorm, unfreeze
import torch.nn.functional as F
from IPython.display import Image
class VisList(pl.callbacks.base.Callback):
    """Lightning callback that snapshots a feature visualisation before every training batch.

    Args:
        layer: layer path passed to ``OptVis.from_layer`` (e.g. ``'model/body/4'``).
        channel: channel index within ``layer`` to visualise.

    Frames accumulate in ``self.images`` in order, one per batch.
    """
    def __init__(self, layer, channel):
        self.layer = layer
        self.c = channel
        self.images = []

    def on_batch_start(self, trainer, pl_module):
        if getattr(self, 'images', None) is None:
            self.images = []
        # Optimise the visualisation with the straight-through ReLU gradient patched in.
        with ReLUFuncOverride():
            render = OptVis.from_layer(pl_module, self.layer, self.c).vis(verbose=False, thresh=(200,))
        frame = denorm(render())
        self.images.append(frame)
        # NOTE(review): presumably undoes any freezing done by OptVis and returns the
        # module to training mode.
        unfreeze(pl_module.train())

In [None]:
# Train ResNet18 with BatchNorm for 8 epochs, recording a visualisation of channel 33 of
# layer 'model/body/4' before every batch.
# NOTE(review): if an earlier cell left nn.Conv2d rebound to WSConv, this ResNet is built
# with weight-standardised convolutions. Confirm.
vis = VisList('model/body/4', 33)
net = train(nn.BatchNorm2d, {'lr': 1e-2, 'bs': 256}, n_runs=1, 
      callbacks=[vis], use_resnet=True, max_epochs=8)

In [None]:
images = vis.images
name = 'vis-randinit-saved2.gif'
if not images:
    raise ValueError('vis.images is empty - run the training cell with the VisList callback first')
# Animated GIF: one frame per training batch, 200 ms per frame, looping forever.
images[0].save(name,
               save_all=True, append_images=images[1:], 
               optimize=False, duration=200, loop=0)

Image(name)

In [None]:
# NOTE(review): same as the previous cell except for the output file name; both save the
# current vis.images frames as an animated GIF.
images = vis.images
name = 'vis-randinit3.gif'
images[0].save(name,
               save_all=True, append_images=images[1:], 
               optimize=False, duration=200, loop=0)

Image(name)