# 7. Batch Normalization

In [1]:
import torch
import torch.nn as nn

import numpy as np

## Batch Normalization

There are three main reasons why **batch normalization** is useful:

1. it standardizes the scale of the inputs to each layer
2. it removes large differences in scale, which would otherwise lead to training problems (e.g. the learning rate must be adjusted carefully)
3. it has a regularizing effect that can help reduce overfitting in deep neural networks

For an input $\mathbf{x} \in \mathcal{B}$, **batch normalization** is defined as follows:

$$\mathrm{BN}(\mathbf{x}) = \boldsymbol{\gamma} \odot \frac{\mathbf{x} - \hat{\boldsymbol{\mu}}_\mathcal{B}}{\hat{\boldsymbol{\sigma}}_\mathcal{B}} + \boldsymbol{\beta}$$

where $\hat{\boldsymbol{\mu}}_\mathcal{B}$ and ${\hat{\boldsymbol{\sigma}}_\mathcal{B}}$ are the **mean** and **standard deviation** of the minibatch $\mathcal{B}$:

$$\begin{aligned} \hat{\boldsymbol{\mu}}_\mathcal{B} &= \frac{1}{|\mathcal{B}|} \sum_{\mathbf{x} \in \mathcal{B}} \mathbf{x},\\
\hat{\boldsymbol{\sigma}}_\mathcal{B}^2 &= \frac{1}{|\mathcal{B}|} \sum_{\mathbf{x} \in \mathcal{B}} (\mathbf{x} - \hat{\boldsymbol{\mu}}_{\mathcal{B}})^2 + \epsilon.\end{aligned}$$

Here $\boldsymbol{\gamma}$ and $\boldsymbol{\beta}$ are the learnable **scale** and **shift** parameters, and $\epsilon > 0$ is a small constant that keeps the denominator away from zero.
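As a sanity check, the formula can be computed by hand and compared with `nn.BatchNorm1d` in training mode. This is a minimal sketch: the helper `batch_norm_train` and the random minibatch are made up for illustration, and $\boldsymbol{\gamma} = 1$, $\boldsymbol{\beta} = 0$ are chosen to match PyTorch's default initialization.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def batch_norm_train(x, gamma, beta, eps=1e-5):
    """Hypothetical helper: batch normalization with minibatch statistics."""
    # Per-feature mean over the batch dimension.
    mu = x.mean(dim=0)
    # Biased variance (divide by |B|), with epsilon added as in the formula.
    var = x.var(dim=0, unbiased=False) + eps
    return gamma * (x - mu) / torch.sqrt(var) + beta

# Illustrative minibatch: 8 examples, 4 features on a large, shifted scale.
x = torch.randn(8, 4) * 10 + 3
gamma = torch.ones(4)
beta = torch.zeros(4)
y_manual = batch_norm_train(x, gamma, beta)

# Built-in layer in training mode (default eps=1e-5, weight=1, bias=0).
bn = nn.BatchNorm1d(4)
bn.train()
y_torch = bn(x)

print(torch.allclose(y_manual, y_torch, atol=1e-4))
print(y_manual.mean(dim=0))                  # per-feature mean, close to 0
print(y_manual.std(dim=0, unbiased=False))   # per-feature std, close to 1
```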

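Note that batch normalization uses the **statistics of the current minibatch** during training, but the **statistics estimated over the entire training dataset** when making predictions. In PyTorch these are running averages accumulated during training, and they are used once the model is switched to evaluation mode with `.eval()`.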
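The difference between the two modes can be observed directly on a single layer. The sketch below assumes PyTorch's default `momentum=0.1` and uses a made-up minibatch; it shows that one training step moves `running_mean` towards the batch mean, and that evaluation mode normalizes with the stored running statistics, so it also works on a single example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm1d(3)            # running_mean starts at 0, running_var at 1
x = torch.randn(16, 3) * 5 + 2    # illustrative minibatch

# Training mode: normalize with minibatch statistics and update running stats.
bn.train()
_ = bn(x)
batch_mean = x.mean(dim=0)
# With momentum 0.1: running_mean = 0.9 * 0 + 0.1 * batch_mean
print(bn.running_mean)
print(0.1 * batch_mean)

# Evaluation mode: normalize with the stored running statistics instead.
bn.eval()
single = x[:1]                    # a "batch" containing a single example
y_eval = bn(single)
expected = (single - bn.running_mean) / torch.sqrt(bn.running_var + bn.eps)
print(torch.allclose(y_eval, expected, atol=1e-5))
```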

In [2]:
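# LeNet-style network with a batch normalization layer between each
# convolutional / fully connected layer and its activation.
net = nn.Sequential(
    # BatchNorm2d takes the number of channels of the preceding conv layer.
    nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    # 256 = 16 * 4 * 4, which assumes 1 x 28 x 28 input images.
    # BatchNorm1d takes the number of output features of the linear layer.
    nn.Linear(256, 120), nn.BatchNorm1d(120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),
    nn.Linear(84, 10))  # no batch normalization on the output layer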
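To check that the layer sizes fit together, a dummy minibatch can be pushed through the network one layer at a time. This is only a sketch: the input size of 1 x 28 x 28 is an assumption (it is the size for which the flattened feature map has 256 entries), and `build_net` is a hypothetical helper that repeats the definition above so that the snippet runs on its own.

```python
import torch
import torch.nn as nn

def build_net():
    """Hypothetical helper: same architecture as the network defined above."""
    return nn.Sequential(
        nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),
        nn.AvgPool2d(kernel_size=2, stride=2),
        nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),
        nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
        nn.Linear(256, 120), nn.BatchNorm1d(120), nn.Sigmoid(),
        nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),
        nn.Linear(84, 10))

net = build_net()

# Assumed input: a minibatch of 4 single-channel 28 x 28 images.
# In training mode BatchNorm1d needs more than one example per batch.
X = torch.randn(4, 1, 28, 28)

out = X
for layer in net:
    out = layer(out)
    print(f"{layer.__class__.__name__:12s} output shape: {tuple(out.shape)}")

# Learnable parameters (gamma and beta) contributed by the batch norm layers.
bn_params = sum(p.numel()
                for m in net.modules()
                if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d))
                for p in m.parameters())
print("batch norm parameters:", bn_params)
```

Each batch normalization layer leaves the shape of its input unchanged and adds two learnable parameters per channel (or per feature).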