In [1]:
# April 2023
# A walkthrough for batch normalization

![image](batch.png)

In [21]:
# What is important here is that the input and output of this layer
# always have the same shape: if you give it an n*n input, it gives you
# an n*n output which is normalized.

# I suggest you watch [2] and [3] first.

In [11]:
import torch
from torch import nn

In [12]:
# Build a tiny random batch so the normalized values can be checked by hand.
# PyTorch image tensors are laid out as (N, C, H, W), so the height comes
# before the width in the shape tuple.
sample_n   = 2
channel_n  = 2
img_width  = 2
img_height = 2
img = torch.rand((sample_n, channel_n, img_height, img_width))
img

tensor([[[[0.6728, 0.8347],
          [0.5499, 0.0040]],

         [[0.8692, 0.8842],
          [0.1897, 0.7183]]],


        [[[0.0265, 0.7271],
          [0.7613, 0.5915]],

         [[0.1974, 0.0166],
          [0.1354, 0.7594]]]])

In [13]:
# Batch size, number of channels, then the two spatial dimensions.
img.shape

torch.Size([2, 2, 2, 2])

In [15]:
# For BatchNorm2d, num_features is the number of channels (dimension 1 of an
# (N, C, H, W) input); one mean/variance pair is tracked per channel.
# Other normalization layers take different arguments (e.g. LayerNorm takes
# normalized_shape, GroupNorm takes num_groups and num_channels).
# momentum=1 makes the running statistics equal to those of the latest batch.
bn = nn.BatchNorm2d(2, momentum=1)

In [16]:
# Forward pass (the layer is still in its default training mode), so each
# channel is normalized with the statistics of this batch.
bn(img)

tensor([[[[ 0.4995,  1.0321],
          [ 0.0950, -1.7004]],

         [[ 1.1568,  1.2002],
          [-0.8185,  0.7181]]],


        [[[-1.6266,  0.6780],
          [ 0.7904,  0.2320]],

         [[-0.7961, -1.3217],
          [-0.9762,  0.8375]]]], grad_fn=<NativeBatchNormBackward0>)

In [19]:
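# Batch normalization is done 'across samples': each channel is normalized
# with the mean and variance computed over all samples in the batch
# (and, for BatchNorm2d, over the spatial positions as well).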

In [17]:
# weight and bias are the affine parameters; running_mean, running_var and
# num_batches_tracked are the statistics recorded by the forward pass.
bn.state_dict()

OrderedDict([('weight', tensor([1., 1.])),
             ('bias', tensor([0., 0.])),
             ('running_mean', tensor([0.5210, 0.4713])),
             ('running_var', tensor([0.1056, 0.1352])),
             ('num_batches_tracked', tensor(1))])

In [24]:
# A larger batch (100 samples, 3 channels, 28x28) for the network examples.
# PyTorch image tensors are laid out as (N, C, H, W), so the height comes
# before the width in the shape tuple.
sample_n   = 100
channel_n  = 3
img_width  = 28
img_height = 28
img = torch.rand((sample_n, channel_n, img_height, img_width))
# The tensor is too large to display usefully; inspect img.shape instead.

In [81]:
# Example 1: one convolution block followed by batch normalization

In [83]:
# Convolution -> average pooling -> batch normalization.
# BatchNorm2d receives 16 because the convolution produces 16 channels.
net = nn.Sequential(
    nn.Conv2d(channel_n, 16, kernel_size=3, padding=1),
    nn.AvgPool2d(2, stride=2),
    nn.BatchNorm2d(num_features=16),
)

In [84]:
# The 3x3 convolution with padding=1 keeps the 28x28 size, the pooling halves
# it to 14x14, and batch normalization leaves the shape unchanged.
net(img).shape

torch.Size([100, 16, 14, 14])

In [85]:
# Example 2: two convolution blocks, each followed by batch normalization

In [86]:
# Two conv -> pool -> batch-norm stages followed by a ReLU.
# Each BatchNorm2d receives the channel count of the convolution before it.
net = nn.Sequential(
    # Stage 1: channel_n -> 16 channels
    nn.Conv2d(channel_n, 16, kernel_size=3, padding=1),
    nn.AvgPool2d(2, stride=2),
    nn.BatchNorm2d(num_features=16),
    # Stage 2: 16 -> 8 channels
    nn.Conv2d(16, 8, kernel_size=3, padding=1),
    nn.AvgPool2d(2, stride=2),
    nn.BatchNorm2d(num_features=8),
    nn.ReLU(),
)

In [87]:
# Two pooling stages shrink 28x28 -> 14x14 -> 7x7; the last convolution
# outputs 8 channels.
net(img).shape

torch.Size([100, 8, 7, 7])

In [88]:
# Conv -> BatchNorm2d -> flatten -> ReLU -> Linear -> BatchNorm1d -> ReLU.
# The 3x3 convolution with padding=1 preserves the spatial size, so flattening
# yields 16 * img_height * img_width features per sample. Deriving the size
# from the image dimensions keeps the Linear layer correct if they change.
net = nn.Sequential(
                nn.Conv2d(in_channels=channel_n,
                          out_channels=16,
                          kernel_size=(3, 3), padding=1),

                nn.BatchNorm2d(16),
                nn.Flatten(start_dim=1),
                nn.ReLU(),
                nn.Linear(16 * img_height * img_width, 24),

                # The input is now 2-D (batch, features), hence BatchNorm1d.
                nn.BatchNorm1d(24),
                nn.ReLU(),
)

In [89]:
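# Look at the 16 and the 24 in the cell above: they give you
# the intuition for how to set the batch normalization arguments
# (each one matches the number of channels/features of the layer before it).
# Also be careful about the second batch-norm layer: it is
# BatchNorm1d instead of BatchNorm2d, because we flattened the tensor earlier.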

In [90]:
# After flattening and the linear layer, each sample is a 24-element vector.
net(img).shape

torch.Size([100, 24])

In [95]:
# It is often said that in real applications the momentum is chosen small,
# e.g. 0.1 (which is also the default value of BatchNorm2d).

In [96]:
# A single convolution followed by batch normalization with an explicit,
# small momentum for the running statistics.
net = nn.Sequential(
    nn.Conv2d(channel_n, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(num_features=16, momentum=0.1),
)

In [94]:
# No pooling here, so the 28x28 spatial size is preserved; the momentum
# setting has no effect on the output shape.
net(img).shape

torch.Size([100, 16, 28, 28])

In [None]:
# about the learnable params

In [98]:
# With learnable parameters: affine=True is the default, which gives the
# layer a learnable per-channel weight and bias.
nn.BatchNorm2d(100)

BatchNorm2d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

In [97]:
# Without learnable parameters: affine=False removes the weight and bias,
# so the layer only normalizes.
nn.BatchNorm2d(100, affine=False)

BatchNorm2d(100, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)

In [99]:
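# Note: never forget to call net.eval() whenever you test or run inference.
# In eval mode batch normalization uses the stored running_mean/running_var
# instead of the statistics of the current batch, so the output for a sample
# no longer varies with whatever else happens to be in the batch.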

In [None]:
# There are also group normalization, local (response) normalization, and so forth.