### What does de-bias do in Running Batch Norm?

It looks like the de-bias logic doesn't change the calculated means or vars. I can't explain why, but I can demonstrate it using 07_batchnorm.ipynb as a starting point.

It turns out Stas knows the answer (https://forums.fast.ai/u/stas/summary):

https://forums.fast.ai/t/lesson-10-discussion-wiki-2019/42781/339: the count and the sums are biased in the same way, so when they are used together (means = sums/count, vars = sqrs/count - means*means) the two biases cancel out!
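A quick numeric check of that argument: `forward` divides `sums`, `sqrs` and `count` by the same `dbias`, and both the mean and the variance are ratios in which that factor cancels. The sketch below uses made-up scalar stand-ins for the buffers (the helper `stats` is hypothetical, written to mirror the formulas in `forward`).

```python
import math

def stats(sums, sqrs, count):
    """Mean and variance from running sums, mirroring RunningBatchNorm.forward."""
    mean = sums / count
    var = sqrs / count - mean * mean
    return mean, var

# Hypothetical buffer values after a few EMA steps (not taken from the notebook)
sums, sqrs, count, dbias = 12.5, 40.0, 9.0, 0.73

raw = stats(sums, sqrs, count)
debiased = stats(sums / dbias, sqrs / dbias, count / dbias)

print(raw)
print(debiased)
# dbias appears in numerator and denominator of every ratio, so it cancels
print(all(math.isclose(r, d, rel_tol=1e-12) for r, d in zip(raw, debiased)))  # True
```

The two tuples agree up to floating-point round-off, which is why de-biasing only matters for quantities that are used on their own rather than as a ratio.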

In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
#export
from exp.nb_06 import *

## ConvNet

Let's get the data and training interface from where we left off in the last notebook.

[Jump_to lesson 10 video](https://course.fast.ai/videos/?lesson=10&t=5899)

In [3]:
x_train,y_train,x_valid,y_valid = get_data()

x_train,x_valid = normalize_to(x_train,x_valid)
train_ds,valid_ds = Dataset(x_train, y_train),Dataset(x_valid, y_valid)

nh,bs = 50,512
c = y_train.max().item()+1
loss_func = F.cross_entropy

data = DataBunch(*get_dls(train_ds, valid_ds, bs), c)

In [4]:
mnist_view = view_tfm(1,28,28)
cbfs = [Recorder,
        partial(AvgStatsCallback,accuracy),
        CudaCallback,
        partial(BatchTransformXCallback, mnist_view)]

In [5]:
nfs = [8,16,32,64,64]

### Running Batch Norm

In [6]:
pre_dbias_means, pre_dbias_vars, post_dbias_means, post_dbias_vars = {}, {}, {}, {}

class RunningBatchNorm(nn.Module):
    def __init__(self, nf, mom=0.1, eps=1e-5):
        super().__init__()
        self.mom,self.eps = mom,eps
        self.mults = nn.Parameter(torch.ones (nf,1,1))
        self.adds = nn.Parameter(torch.zeros(nf,1,1))
        self.register_buffer('sums', torch.zeros(1,nf,1,1))
        self.register_buffer('sqrs', torch.zeros(1,nf,1,1))
        self.register_buffer('batch', tensor(0.))
        self.register_buffer('count', tensor(0.))
        self.register_buffer('step', tensor(0.))
        self.register_buffer('dbias', tensor(0.))

    def update_stats(self, x):
        bs,nc,*_ = x.shape
        self.sums.detach_()
        self.sqrs.detach_()
        dims = (0,2,3)
        s = x.sum(dims, keepdim=True)
        ss = (x*x).sum(dims, keepdim=True)
        c = self.count.new_tensor(x.numel()/nc)
        mom1 = 1 - (1-self.mom)/math.sqrt(bs-1)
        self.mom1 = self.dbias.new_tensor(mom1)
        self.sums.lerp_(s, self.mom1)
        self.sqrs.lerp_(ss, self.mom1)
        self.count.lerp_(c, self.mom1)
        self.dbias = self.dbias*(1-self.mom1) + self.mom1
        self.batch += bs
        self.step += 1

    def forward(self, x):
        global pre_dbias_means, pre_dbias_vars, post_dbias_means, post_dbias_vars
        if self.training: self.update_stats(x)
        sums = self.sums
        sqrs = self.sqrs
        c = self.count
        # calculate means/vars before dbias and save them
        means = sums/c
        vars = (sqrs/c).sub_(means*means)
        if means.shape not in pre_dbias_means:
            pre_dbias_means[means.shape] = []
            post_dbias_means[means.shape] = []
            pre_dbias_vars[means.shape] = []
            post_dbias_vars[means.shape] = []
        pre_dbias_means[means.shape].append(means)
        pre_dbias_vars[means.shape].append(vars)
        # end
        if self.step<100:
            sums = sums / self.dbias
            sqrs = sqrs / self.dbias
            c    = c    / self.dbias
        means = sums/c
        vars = (sqrs/c).sub_(means*means)
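        # save means/vars after dbias
        # (note: `vars` is modified in place below by clamp_min_/add_, so this saved
        # entry changes too; append vars.clone() instead to keep a true snapshot)
        post_dbias_means[means.shape].append(means)
        post_dbias_vars[means.shape].append(vars)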
        # end
        if bool(self.batch < 20): vars.clamp_min_(0.01)
        x = (x-means).div_((vars.add_(self.eps)).sqrt())
        return x.mul_(self.mults).add_(self.adds)
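Why divide by `dbias` at all? The buffers start at zero, so the `lerp_` updates form an exponential moving average that is pulled toward zero early on. Since `dbias` also starts at 0 and follows `dbias*(1-mom1) + mom1`, it equals 1 - (1-mom1)^t after t steps, assuming `mom1` stays constant (i.e. a constant batch size). That is exactly the total weight the EMA has given to real data. A minimal pure-Python sketch, with a hypothetical constant per-step sum and count and a hand-written `lerp` following the rule documented for `torch.lerp` (start + weight * (end - start)):

```python
import math

def lerp(start, end, weight):
    """Same rule as torch.Tensor.lerp_: start + weight * (end - start)."""
    return start + weight * (end - start)

bs, mom = 512, 0.1
m = 1 - (1 - mom) / math.sqrt(bs - 1)   # mom1 as computed in update_stats

ema_sum, ema_count, dbias = 0.0, 0.0, 0.0
batch_sum, batch_count = 3.0, 8.0       # hypothetical constant per-step statistics

for t in range(1, 4):
    ema_sum = lerp(ema_sum, batch_sum, m)
    ema_count = lerp(ema_count, batch_count, m)
    dbias = dbias * (1 - m) + m
    # dbias tracks 1 - (1-m)^t, the total weight given to real data so far
    assert abs(dbias - (1 - (1 - m) ** t)) < 1e-12
    # biased EMA, de-biased EMA, and the ratio that never needed de-biasing
    print(t, ema_sum, ema_sum / dbias, ema_sum / ema_count)
```

`ema_sum` on its own stays below 3.0, and dividing by `dbias` recovers it, but `ema_sum / ema_count` is already 0.375 at every step because both carry the same factor.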

In [7]:
def conv_rbn(ni, nf, ks=3, stride=2, bn=True, **kwargs):
    layers = [nn.Conv2d(ni, nf, ks, padding=ks//2, stride=stride, bias=not bn),
              GeneralRelu(**kwargs)]
    if bn: layers.append(RunningBatchNorm(nf))
    return nn.Sequential(*layers)

When we initialize this CNN, `l[0].bias` is `None`, i.e. the Conv2d has no bias, because `conv_rbn` passes `bias=not bn`.

Do we need to update init_cnn in the Generalized ReLU section of 06_cuda_cnn_hooks_init.ipynb?

In [8]:
def init_cnn(m, uniform=False):
    f = init.kaiming_uniform_ if uniform else init.kaiming_normal_
    for l in m:
        if isinstance(l, nn.Sequential):
            f(l[0].weight, a=0.1)
            if l[0].bias is not None: # Conv2d built with bias=False has bias None
                l[0].bias.data.zero_()

In [9]:
learn,run = get_learn_run(nfs, data, 0.4, conv_rbn, cbs=cbfs)

In [10]:
%time run.fit(1, learn)

train: [0.29878693359375, tensor(0.9120, device='cuda:0')]
valid: [0.11251114501953124, tensor(0.9675, device='cuda:0')]
CPU times: user 3.23 s, sys: 535 ms, total: 3.77 s
Wall time: 3.47 s


Now check whether the means and vars match before and after de-biasing.
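`do_check` relies on `torch.allclose`, which is documented as testing |input - other| <= atol + rtol * |other| element-wise, with defaults rtol=1e-5 and atol=1e-8. Under that rule a constant offset of 1e-5 fails for values well below 1 but passes once atol is raised to 1e-5. The helper `allclose` below is a hypothetical pure-Python re-implementation of that rule, and the values are illustrative, not the notebook's tensors.

```python
def allclose(a, b, rtol=1e-5, atol=1e-8):
    """Element-wise |a - b| <= atol + rtol * |b|, the rule torch.allclose documents."""
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))

eps = 1e-5
pre = [0.30314, 0.0430, 4.9190]   # illustrative variances
post = [v + eps for v in pre]     # the same values shifted by an eps-sized offset

print(allclose(pre, post))             # False: default atol=1e-8 is too tight for 0.30314
print(allclose(pre, post, atol=1e-5))  # True: the loosened atol used for vars in do_check
```

Note that the large entry (4.9190) passes even with the defaults, since the rtol term alone (about 4.9e-5) exceeds the offset.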

In [11]:
def do_check():
    for shape in pre_dbias_means:
        a = torch.stack(pre_dbias_means[shape])
        b = torch.stack(post_dbias_means[shape])
        print('means', shape, torch.allclose(a, b))
        a = torch.stack(pre_dbias_vars[shape])
        b = torch.stack(post_dbias_vars[shape])
        print('vars', shape, torch.allclose(a, b, atol=1e-5)) # atol=1e-5 absorbs the eps that forward() adds in place to the saved post vars
do_check()

means torch.Size([1, 8, 1, 1]) True
vars torch.Size([1, 8, 1, 1]) True
means torch.Size([1, 16, 1, 1]) True
vars torch.Size([1, 16, 1, 1]) True
means torch.Size([1, 32, 1, 1]) True
vars torch.Size([1, 32, 1, 1]) True
means torch.Size([1, 64, 1, 1]) True
vars torch.Size([1, 64, 1, 1]) True


In case it helps, here is an example where the pre- and post-de-bias vars are not "all close" with the default tolerance:

In [12]:
size = torch.Size([1, 8, 1, 1])
for i in range(len(pre_dbias_means[size])):
    a = pre_dbias_vars[size][i]
    b = post_dbias_vars[size][i]
    if not torch.allclose(a, b):
        print('found a difference in vars', size, 'i', i)
        print(a.view(size[1],))
        print(b.view(size[1],))
        break

found a difference in vars torch.Size([1, 8, 1, 1]) i 0
tensor([0.3031, 0.0430, 0.0756, 4.9190, 0.1070, 0.2733, 0.4318, 5.0159],
       device='cuda:0', grad_fn=<ViewBackward>)
tensor([0.3032, 0.0430, 0.0756, 4.9190, 0.1070, 0.2733, 0.4319, 5.0159],
       device='cuda:0', grad_fn=<ViewBackward>)
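A likely explanation for this difference, rather than anything the de-bias step does: `forward` appends `vars` to `post_dbias_vars` and then updates that same tensor in place with `vars.add_(self.eps)` (and `vars.clamp_min_(0.01)` while `self.batch < 20`), so the stored entry ends up about eps = 1e-5 larger. That fits the output above, where 0.3031 vs 0.3032 and 0.4318 vs 0.4319 look like values sitting near a rounding boundary. The pure-Python sketch below mimics the aliasing, with a list standing in for the tensor and illustrative values.

```python
EPS = 1e-5
saved_alias, saved_copy = [], []

vars_ = [0.30314, 0.43185]        # stand-in for the `vars` tensor (illustrative values)
saved_alias.append(vars_)         # what forward() does: store a reference
saved_copy.append(list(vars_))    # a snapshot, analogous to vars.clone()

# in-place update, analogous to vars.add_(self.eps)
for i in range(len(vars_)):
    vars_[i] += EPS

print(saved_alias[0])   # shifted by EPS: the saved entry was mutated
print(saved_copy[0])    # unchanged
```

In the notebook, appending `vars.clone()` (or avoiding the in-place ops) should make the vars check pass without loosening `atol`; this is an untested suggestion.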


### Re-init and check over a full epoch

In [13]:
pre_dbias_means, pre_dbias_vars, post_dbias_means, post_dbias_vars = {}, {}, {}, {}

In [14]:
data = DataBunch(*get_dls(train_ds, valid_ds, 32), c)
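This run uses a batch size of 32, which changes `mom1` in `update_stats`. The `if self.step<100` guard switches de-biasing off after 100 steps. Assuming a constant batch size (the last batch of an epoch may be smaller), `dbias` = 1 - (1-mom1)^t, so it is essentially 1 after a handful of steps for either batch size used in this notebook. The sketch below counts those steps; the helper `mom1` and the 1e-6 threshold are illustrative choices.

```python
import math

def mom1(bs, mom=0.1):
    """Batch-size-adjusted momentum, as computed in RunningBatchNorm.update_stats."""
    return 1 - (1 - mom) / math.sqrt(bs - 1)

results = {}
for bs in (32, 512):
    m = mom1(bs)
    dbias, steps = 0.0, 0
    # iterate dbias <- dbias*(1-m) + m until it is within 1e-6 of 1
    while 1 - dbias > 1e-6:
        dbias = dbias * (1 - m) + m
        steps += 1
    results[bs] = (m, steps)
    print(f"bs={bs}: mom1={m:.4f}, dbias within 1e-6 of 1 after {steps} steps")
```

Larger batches give a larger `mom1`, so `dbias` saturates even faster; in both cases the 100-step cutoff is far beyond the point where dividing by `dbias` has any effect.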

In [15]:
learn,run = get_learn_run(nfs, data, 0.9, conv_rbn, cbs=cbfs
                          +[partial(ParamScheduler,'lr', sched_lin(1., 0.2))])

In [16]:
%time run.fit(1, learn)

train: [0.15330603515625, tensor(0.9519, device='cuda:0')]
valid: [0.3575843994140625, tensor(0.9660, device='cuda:0')]
CPU times: user 9.77 s, sys: 38.3 ms, total: 9.81 s
Wall time: 9.59 s


In [17]:
do_check()

means torch.Size([1, 8, 1, 1]) True
vars torch.Size([1, 8, 1, 1]) True
means torch.Size([1, 16, 1, 1]) True
vars torch.Size([1, 16, 1, 1]) True
means torch.Size([1, 32, 1, 1]) True
vars torch.Size([1, 32, 1, 1]) True
means torch.Size([1, 64, 1, 1]) True
vars torch.Size([1, 64, 1, 1]) True
