# 批量归一化

Batch Normalization对于训练很深的网络是一种非常有效的提高收敛速度的方法，它可以让我们使用更大的学习率。

$$y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta$$

一般而言，使用或不使用BN，对于模型的精度是没有影响的，核心影响的是网络训练的速度。

在较深的网络结构中，因为梯度弥散的问题，导致靠近输出的部分收敛快，而靠近输入的网络收敛慢。如果每个批量的样本的特征在分布上差异较大，就会使得网络的output产生较大的差别，这样已经收敛的尾部就会开启大的调整，等于之前就白训练了。

所以我们可以对网络中每一层的输出在一个批量上进行归一化，让它们的分布都服从均值为0、方差为1的分布。

更多BatchNorm为什么有效的Intuition：[Batch Normalization — an intuitive explanation](https://towardsdatascience.com/batch-normalization-an-intuitive-explanation-42e473fa753f)

In [1]:
import torch
from torch import nn

In [2]:
def batch_norm(X, gamma, beta, moving_mu, moving_var, momentum, eps):
    """Normalize ``X`` and return it with the updated running statistics.

    Computes ``(X - mean) / sqrt(var + eps) * gamma + beta``.

    When autograd is enabled (treated here as "training"), the mean and the
    biased variance are computed from ``X`` itself and the running statistics
    are updated as ``momentum * old + (1 - momentum) * batch_stat``. When
    autograd is disabled (treated as "inference"), ``moving_mu`` and
    ``moving_var`` are used for normalization and returned unchanged.

    Args:
        X: 2-D tensor (statistics taken over dim 0) or 4-D tensor
            (statistics taken over dims 0, 2 and 3).
        gamma: Scale, broadcastable against ``X``.
        beta: Shift, broadcastable against ``X``.
        moving_mu: Running mean.
        moving_var: Running variance.
        momentum: Weight given to the previous running statistics.
        eps: Constant added to the variance before the square root.

    Returns:
        Tuple ``(Y, moving_mu, moving_var)``.
    """
    assert X.dim() in (2, 4)  # only fully connected and 2-D conv inputs
    # torch.is_grad_enabled() is used to tell training from inference:
    # with autograd off, normalize with the running statistics directly.
    if not torch.is_grad_enabled():
        Y_hat = (X - moving_mu) / torch.sqrt(moving_var + eps)
    else:
        if X.dim() == 2:
            # Fully connected layer: one statistic per column.
            mu = X.mean(dim=0)
            var = ((X - mu) ** 2).mean(dim=0)
        else:
            # Conv layer: dim 1 is the feature dim; keepdim=True keeps the
            # reduced dims so the result broadcasts against X.
            mu = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mu) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        Y_hat = (X - mu) / torch.sqrt(var + eps)
        # Update the running statistics from detached batch statistics so
        # they do not require grad or keep this batch's autograd graph alive.
        moving_mu = momentum * moving_mu + (1 - momentum) * mu.detach()
        moving_var = momentum * moving_var + (1 - momentum) * var.detach()
    Y = Y_hat * gamma + beta
    return Y, moving_mu, moving_var

In [3]:
class BatchNorm(nn.Module):
    """Batch-normalization layer built on the ``batch_norm`` function.

    Args:
        num_features: Size of the feature dimension (dim 1 of the input).
        num_dim: Rank of the expected input; 2 gives statistics of shape
            ``(1, num_features)``, 4 gives ``(1, num_features, 1, 1)``.
        momentum: Weight given to the previous running statistics.
        eps: Constant added to the variance before the square root.

    Raises:
        ValueError: If ``num_dim`` is neither 2 nor 4.
    """

    def __init__(self, num_features, num_dim, momentum=0.9, eps=1e-5):
        super().__init__()
        if num_dim == 2:
            shape = (1, num_features)
        elif num_dim == 4:
            shape = (1, num_features, 1, 1)
        else:
            raise ValueError(f"num_dim must be 2 or 4, got {num_dim}")
        # Running mean starts at 0 and running variance at 1. They are
        # registered as buffers so that they follow the module across devices
        # and are included in state_dict() without being trainable.
        self.register_buffer('moving_mu', torch.zeros(shape))
        self.register_buffer('moving_var', torch.ones(shape))
        # gamma and beta are the learnable scale and shift.
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        self.momentum = momentum
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` and refresh the running statistics."""
        # Follow the input's device so the layer also works when the caller
        # did not move the module explicitly.
        if x.device != self.moving_mu.device:
            self.to(x.device)
        # Pass the parameters themselves (not ``.data``) so gradients can
        # flow back to gamma and beta.
        y, moving_mu, moving_var = batch_norm(
            x, self.gamma, self.beta, self.moving_mu, self.moving_var,
            self.momentum, self.eps)
        # Store the statistics detached so no autograd graph is retained
        # between calls.
        self.moving_mu = moving_mu.detach()
        self.moving_var = moving_var.detach()
        return y

In [4]:
# Sample NCHW input; use the GPU when one is available, otherwise fall back
# to the CPU so the example still runs.
a = torch.randn(1, 3, 5, 5, device='cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
# Build the hand-written layer from the sample's channel count and rank,
# then apply it to the sample.
bn = BatchNorm(a.shape[1], a.dim())
bn(a)

tensor([[[[ 0.6553,  0.4040,  0.5476, -1.3837,  1.3100],
          [-0.4144,  1.7244, -0.7444, -0.9298, -0.7004],
          [ 1.2725,  1.9125, -0.3266, -0.8016, -0.4329],
          [ 0.4292,  0.9642,  0.8395,  0.4539, -0.2588],
          [-1.5279, -0.2640,  0.2332, -1.8499, -1.1118]],

         [[ 0.5982, -0.5397,  1.4367,  0.3117, -0.1409],
          [ 1.4547, -0.3516,  0.2519, -0.7490,  0.1770],
          [ 1.6657,  0.3715,  0.3305,  1.3777, -0.3522],
          [-0.5594, -1.5442,  0.6910, -2.7335, -0.6035],
          [-0.8145,  0.5104,  0.7009, -0.1551, -1.3341]],

         [[ 1.3260, -0.0765,  2.7826,  0.7548,  0.3470],
          [-0.1259, -0.6698, -0.7616, -0.2149,  0.7231],
          [-0.8911, -0.1881, -0.7951, -1.0108,  1.7675],
          [ 0.7457, -1.2737, -0.1661, -1.6098,  0.9449],
          [-0.6425,  0.5503, -0.0344, -0.5444, -0.9371]]]], device='cuda:0')

# BatchNorm in PyTorch

https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html

In [6]:
# torch's BatchNorm2d only accepts 4-D NCHW input; 2-D NC input is not
# supported.
torch_bn = nn.BatchNorm2d(a.shape[1]).to(a.device)
torch_bn(a)

tensor([[[[ 0.6553,  0.4040,  0.5476, -1.3837,  1.3100],
          [-0.4144,  1.7244, -0.7444, -0.9298, -0.7004],
          [ 1.2725,  1.9125, -0.3266, -0.8016, -0.4329],
          [ 0.4292,  0.9642,  0.8395,  0.4539, -0.2588],
          [-1.5279, -0.2640,  0.2332, -1.8499, -1.1118]],

         [[ 0.5982, -0.5397,  1.4367,  0.3117, -0.1409],
          [ 1.4547, -0.3516,  0.2519, -0.7490,  0.1770],
          [ 1.6657,  0.3715,  0.3305,  1.3777, -0.3522],
          [-0.5594, -1.5442,  0.6910, -2.7335, -0.6035],
          [-0.8145,  0.5104,  0.7009, -0.1551, -1.3341]],

         [[ 1.3260, -0.0765,  2.7826,  0.7548,  0.3470],
          [-0.1259, -0.6698, -0.7616, -0.2149,  0.7231],
          [-0.8911, -0.1881, -0.7951, -1.0108,  1.7675],
          [ 0.7457, -1.2737, -0.1661, -1.6098,  0.9449],
          [-0.6425,  0.5503, -0.0344, -0.5444, -0.9371]]]], device='cuda:0',
       grad_fn=<CudnnBatchNormBackward>)