In [2]:
import torch
import torch.nn as nn
import torch.functional as F
import matplotlib.pyplot as plt
from torch import Tensor

### Implementation of Batch Normalization

In [3]:
# batch
# the batch size is 20 and inputs are 5d arrays

X = torch.randn(20,1,1,5)


In [4]:
# sanity check of normalization operation

# Note: greater precision is required to pass these tests when using floating point inputs 

x = torch.tensor([[4,2], [6,10]], dtype=torch.float32)
n_x = (x - x.mean(dim=0)) / x.var(0).sqrt() 
print(f'mean: {n_x.mean(0)}') # should be zero 
print(f'var: {n_x.var(0)}') # should be one

mean: tensor([0., 0.])
var: tensor([1.0000, 1.0000])


In [8]:
class MyBatchNorm(nn.Module):
    def __init__(self, num_feats:int, track_running = False):
        super(MyBatchNorm, self).__init__()
        
        self.g = nn.Parameter(torch.ones(1, num_feats)) # gamma (scale)
        self.b = nn.Parameter(torch.zeros(1, num_feats)) # beta (shift)
        
        self.eps = torch.tensor(0.001) # epsilon
        
        self.track_running = track_running
        self.running_mean = 0 if track_running else None
        self.running_var = 1 if track_running else None
        self.momentum = 0.1 if track_running else None
          
    def forward(self, x:Tensor):
        mean = x.mean(dim=0)
        var = x.var(dim=0)
                
        n_x = (x - mean) / (var + self.eps).sqrt() 
        
        out = self.g * n_x + self.b # apply affine transformation
        
        if self.track_running:
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean  
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        
        return out

In [14]:
torch.manual_seed(42) # for reproducibility

n_out = MyBatchNorm(X.shape[-1], track_running=True)(X)

In [15]:
n_out

tensor([[[[ 0.4136,  1.0143,  0.9512, -0.5985,  1.2037]]],


        [[[ 0.6544, -1.0698, -0.8482, -0.4151, -0.1756]]],


        [[[ 0.4210,  0.2651,  0.2415,  1.6280,  1.2487]]],


        [[[-0.8418, -0.5261, -0.9307,  1.2942, -0.9302]]],


        [[[ 0.6667,  0.8937,  1.1935,  0.0924,  1.5194]]],


        [[[-0.1991, -0.3589,  1.0229,  0.6780,  1.0011]]],


        [[[-2.1048, -1.0118, -0.7832,  1.2787,  0.0391]]],


        [[[-0.9924, -1.0418, -0.9804, -0.6197, -0.6769]]],


        [[[ 0.5459, -0.7028, -2.0770,  2.1612,  0.7681]]],


        [[[ 0.5409,  1.7770, -0.6982, -1.1653,  0.2736]]],


        [[[-0.4789, -0.1909,  0.5242, -0.4069, -1.2589]]],


        [[[-1.5301,  2.3469,  0.5438,  0.1440,  0.0774]]],


        [[[ 1.9908, -0.9825,  0.2264, -0.3796, -1.6928]]],


        [[[ 0.7313, -0.4806,  0.9011, -0.2341, -2.1980]]],


        [[[-0.1189,  1.0230, -0.3180,  0.2879,  0.2401]]],


        [[[ 1.5075,  0.2922, -1.4836, -1.6857, -0.5960]]],


        [[[-0.7173, -0.1

In [23]:
a = MyBatchNorm(5, track_running=True)

In [24]:
a(X)
a.running_mean

tensor([[[ 0.0105, -0.0249,  0.0293,  0.0077,  0.0426]]])