# Batch Normalisation in Detail - First Lecture

When we batch normalise a batch of data, we obtain a tensor with the same shape as the input tensor.

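We have to pass to `nn.BatchNorm2d` the number of features (channels) of our batch.
For the output of a convolutional layer, shaped `(N, C, H, W)`, this is the size of the second dimension (index 1) of the `torch.Tensor`.
For the output of a linear layer, shaped `(N, C)`, it is again the size of the second dimension (index 1); such 2D inputs are handled by `nn.BatchNorm1d`, because `nn.BatchNorm2d` expects a 4D input.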
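
As a minimal sketch of where the number of features comes from, the snippet below attaches a normalisation layer to a convolutional layer and to a linear layer. The layer sizes and input shapes are hypothetical values chosen only for illustration.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration.
conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1)
bn2d = nn.BatchNorm2d(num_features=8)   # must match out_channels
x = torch.randn(4, 3, 5, 5)             # (N, C_in, H, W)
conv_out = bn2d(conv(x))                # (N, 8, H, W)

linear = nn.Linear(in_features=10, out_features=6)
bn1d = nn.BatchNorm1d(num_features=6)   # must match out_features
v = torch.randn(4, 10)                  # (N, in_features)
linear_out = bn1d(linear(v))            # (N, 6)

print(conv_out.shape)
print(linear_out.shape)
```

In both cases the feature axis is the one at index 1, right after the batch axis.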

In [46]:
import torch
import torch.nn as nn
import numpy as np

norm = nn.BatchNorm2d(2,track_running_stats=False)


In [54]:
elements = np.arange(24)
n_features = 2
img = torch.from_numpy(elements).view((3,n_features,1,4)).type(dtype=torch.float)

out = norm(img)

print(f'{img.shape = }')
print(f'{out.shape = }')
# out.shape is the same as the shape of the input

img.shape = torch.Size([3, 2, 1, 4])
out.shape = torch.Size([3, 2, 1, 4])


We now show how to correctly compute the mean and variance of a batch to be used in
the batch normalisation formula.
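
For reference, the formula can be written per feature (channel) $c$ as below. The symbols $N$, $H$ and $W$ (batch size, height, width) are notation assumed here and do not appear in the original notes.

```latex
\mu_c = \frac{1}{N H W} \sum_{n=1}^{N} \sum_{h=1}^{H} \sum_{w=1}^{W} x_{n,c,h,w}
\qquad
\sigma_c^2 = \frac{1}{N H W} \sum_{n=1}^{N} \sum_{h=1}^{H} \sum_{w=1}^{W} \left( x_{n,c,h,w} - \mu_c \right)^2
\qquad
\hat{x}_{n,c,h,w} = \frac{x_{n,c,h,w} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}}
```

`nn.BatchNorm2d` additionally applies a learnable scale and shift, which are initialised to 1 and 0, so a freshly created layer returns $\hat{x}$ itself.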

1. Compute the mean over every axis except the second one (the feature axis), either in a single call or one axis after the other.
Either way, we obtain the correct per-feature mean of the batch, with shape `(1,n_features,1,1)`.
2. Compute the variance from the mean you have just computed, using the explicit formula
(the mean of the squared deviations, through the `torch.mean` function).

*Note*: If you compute the variance one axis after the other, you will obtain a wrong result,
because each step measures the deviations from the mean of that axis alone,
which is not the mean of the whole batch. This is not the same as always using the final mean computed over all the axes.
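
The sketch below illustrates the note under one possible reading of "one axis after the other": the squared deviations are taken from the axis-0 mean and then averaged over the remaining axes. The tensor is the same `img` as in the notebook; the names `var_whole` and `var_chained` are introduced here for illustration.

```python
import torch

img = torch.arange(24, dtype=torch.float).view(3, 2, 1, 4)

# Biased variance of each channel around the mean of the whole batch.
mean_channels = img.mean(dim=[0, 2, 3], keepdim=True)
var_whole = ((img - mean_channels) ** 2).mean(dim=[0, 2, 3], keepdim=True)

# Chained version: deviations are measured from the axis-0 mean only,
# then the per-position variances are averaged over the other axes.
mean_axis0 = img.mean(dim=0, keepdim=True)
var_axis0 = ((img - mean_axis0) ** 2).mean(dim=0, keepdim=True)
var_chained = var_axis0.mean(dim=[2, 3], keepdim=True)

print(var_whole.flatten())
print(var_chained.flatten())
print(torch.allclose(var_whole, var_chained))
```

The chained value comes out smaller, because it ignores how much the intermediate means themselves vary across the remaining axes.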

The `keepdim` parameter retains the axes over which the statistic is computed (each with size 1), so the result still broadcasts against the input.
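
A short sketch of the effect of `keepdim` on the same `img` tensor; the names `mean_keep`, `mean_drop` and `centred` are introduced here for illustration.

```python
import torch

img = torch.arange(24, dtype=torch.float).view(3, 2, 1, 4)

mean_keep = torch.mean(img, dim=[0, 2, 3], keepdim=True)
mean_drop = torch.mean(img, dim=[0, 2, 3], keepdim=False)

print(mean_keep.shape)  # torch.Size([1, 2, 1, 1])
print(mean_drop.shape)  # torch.Size([2])

# With keepdim=True the mean broadcasts directly against img.
centred = img - mean_keep

# Without keepdim the reduced axes must be restored by hand before subtracting.
centred_manual = img - mean_drop.view(1, -1, 1, 1)
print(torch.equal(centred, centred_manual))
```

Subtracting `mean_drop` directly would fail, since a tensor of shape `(2,)` is aligned with the last axis of `img`, which has size 4.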

In [None]:
# correct way to compute the batch statistics for batch normalisation
mean_channels = torch.mean(img,dim=[0,2,3], keepdim=True)
# the mean above is equivalent to chaining the means axis by axis:
mean_axis0 = torch.mean(img,dim=[0], keepdim=True)
mean_axis2 = torch.mean(mean_axis0,dim=[2], keepdim=True)
mean_axis3 = torch.mean(mean_axis2,dim=[3], keepdim=True)
# biased variance (mean of the squared deviations from the batch mean),
# which is the one used to normalise the batch
var1_through_ch = torch.mean((img - mean_axis3)**2,dim=[0,2,3], keepdim=True)

print(f'{mean_axis0.shape}')
print(f'{mean_axis2.shape}')
print(f'{mean_axis3.shape}')

# check that the two ways of computing the mean agree
# (allclose is more robust than exact equality on floats)
print(torch.allclose(mean_axis3, mean_channels))

# this does not give the variance used for batch normalisation:
# by default torch.var returns the unbiased estimate (it divides by n-1, not n)
var_through_ch = torch.var(img,dim=[0,2,3], keepdim=True)
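
The mismatch with `torch.var` comes from Bessel's correction, which `torch.var` applies by default. The sketch below, on the same `img` tensor, assumes the `unbiased=False` keyword is available in the installed PyTorch version (newer releases also offer `correction=0`); the variable names are introduced here for illustration.

```python
import torch

img = torch.arange(24, dtype=torch.float).view(3, 2, 1, 4)
mean_channels = torch.mean(img, dim=[0, 2, 3], keepdim=True)

# Explicit formula: mean of the squared deviations (divides by n).
var_manual = torch.mean((img - mean_channels) ** 2, dim=[0, 2, 3], keepdim=True)

# torch.var without Bessel's correction (divides by n).
var_biased = torch.var(img, dim=[0, 2, 3], unbiased=False, keepdim=True)

# torch.var with its default settings (divides by n - 1).
var_unbiased = torch.var(img, dim=[0, 2, 3], keepdim=True)

n = img.numel() // img.shape[1]  # number of elements per channel
print(torch.allclose(var_manual, var_biased))
print(torch.allclose(var_unbiased, var_manual * n / (n - 1)))
```

The two variances differ only by the factor `n / (n - 1)`, which matters most for small batches such as this one.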


In the following from-scratch computation of the 2D batch normalisation, we use the default
value of the numerical stability term, `eps=1e-5`.
The batch-normalised tensor differs slightly between the built-in
function `nn.BatchNorm2d` and the computation from scratch.
*We saw that the difference starts at the 7th decimal digit*: indeed, with an absolute tolerance of `1e-6` in `torch.isclose`,
all the elements of the two normalised tensors are the same.
If we want to compare more digits, for example with a tolerance of `1e-7`, not all the elements
are the same.
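
One way to quantify this difference, rather than testing a single tolerance, is to print the largest absolute deviation next to the float32 machine epsilon. This is a self-contained sketch that rebuilds the same `img` and layer as above; `manual` and `max_abs_diff` are names introduced here.

```python
import torch
import torch.nn as nn

img = torch.arange(24, dtype=torch.float).view(3, 2, 1, 4)
norm = nn.BatchNorm2d(2, track_running_stats=False)
eps = norm.eps  # default value of the layer, 1e-5

with torch.no_grad():
    out = norm(img)

# From-scratch normalisation with the biased batch variance.
mean = img.mean(dim=[0, 2, 3], keepdim=True)
var = ((img - mean) ** 2).mean(dim=[0, 2, 3], keepdim=True)
manual = (img - mean) / torch.sqrt(var + eps)

max_abs_diff = (out - manual).abs().max().item()
print(f'max abs diff: {max_abs_diff:.3e}')
print(f'float32 eps : {torch.finfo(torch.float32).eps:.3e}')
```

A deviation of the same order as the float32 epsilon suggests that the two results differ only by rounding, not by a different formula.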

In [None]:
torch.set_printoptions(precision=8)
eps = 1e-5
abs_tol = 1e-6

norm1_batch = (img-mean_channels)/torch.sqrt(var1_through_ch + eps)
norm1_batch.requires_grad_(False)

print(out)
print(norm1_batch)

# norm1_batch.eq(out)
print((norm1_batch.isclose(out, atol=abs_tol, rtol=0).sum()==img.numel()).item())