# Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

The paper starts from a conjecture:
> The change in the distributions of layers' inputs presents a problem because the layers need to continuously adapt to the new distribution.

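The paper calls the change in the distribution of a deep network's internal nodes during training **Internal Covariate Shift**.
Batch Normalization offers one way to address vanishing gradients: by keeping the input distribution of each layer consistent, it makes deep networks easier to train.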

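The idea is simple (although understanding it is less direct): standardize the input of each layer (the paper calls this "normalize each dimension", but the operation is standardization), so that every dimension of the standardized input has zero mean and unit variance. For a $d$-dimensional input $x=(x^{(1)} \cdots x^{(d)})$, the standardization is:
$$
\hat{x}^{(k)} = \frac{x^{(k)} - \mathbf{E}[x^{(k)}]}{\sqrt{\mathrm{Var}[x^{(k)}]}}.
$$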

The expectation and variance here are computed over the mini-batch.
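As a minimal sketch of this step, the snippet below standardizes a made-up mini-batch per feature with PyTorch. The batch size, feature count, and data are assumptions chosen only for illustration, and $\epsilon$ is left out to match the formula above.

```python
import torch

torch.manual_seed(0)

# A hypothetical mini-batch: m = 8 examples, d = 3 features.
X = torch.randn(8, 3) * 5.0 + 2.0

# Per-feature statistics over the batch dimension (dim=0).
mean = X.mean(dim=0)
var = X.var(dim=0, unbiased=False)  # biased variance, dividing by m

# Standardize every feature with its own mini-batch mean and variance.
X_hat = (X - mean) / torch.sqrt(var)

print(X_hat.mean(dim=0))                 # approximately 0 for every feature
print(X_hat.var(dim=0, unbiased=False))  # approximately 1 for every feature
```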

However, standardization alone can cause a problem: *simply normalizing each input of a layer may change what the layer can represent.* To address this, batch norm makes sure that *the transformation inserted in the network can represent the identity transform*. Concretely, it introduces two extra parameters $\gamma^{(k)}, \beta^{(k)}$ for each dimension, which scale and shift the standardized value:

$$
y^{(k)} = \gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}.
$$
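To see why the scale and shift let the layer represent the identity transform, the sketch below sets $\gamma^{(k)} = \sqrt{\mathrm{Var}[x^{(k)}] + \epsilon}$ and $\beta^{(k)} = \mathbf{E}[x^{(k)}]$ and checks that the original input is recovered. The data and the value of `eps` are assumptions for illustration, and `eps` is included here as in the full algorithm.

```python
import torch

torch.manual_seed(0)
eps = 1e-5  # assumed small constant for numerical stability

# A hypothetical mini-batch of 16 examples with 4 features.
X = torch.randn(16, 4) * 3.0 - 1.0
mean = X.mean(dim=0)
var = X.var(dim=0, unbiased=False)
X_hat = (X - mean) / torch.sqrt(var + eps)

# Choosing gamma = sqrt(Var[x] + eps) and beta = E[x] undoes the standardization.
gamma = torch.sqrt(var + eps)
beta = mean
Y = gamma * X_hat + beta

# Largest absolute difference between Y and the original input.
print((Y - X).abs().max())
```

In practice $\gamma$ and $\beta$ are not set by hand like this. The point is only that the optimizer can reach this setting if the identity transform is what the network needs.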

Note that $\gamma^{(k)}$ and $\beta^{(k)}$ are also learned along with the rest of the network. The **Batch Normalization** algorithm is as follows:
![batch_norm](img/batch_norm_alg1.png)

Here $\epsilon$ avoids division by zero and thus ensures numerical stability. The backpropagation formulas for Batch Normalization are:

$$
\begin{aligned}
\frac{\partial l}{\partial \hat{x}_i} &= \frac{\partial l}{\partial y_i} \cdot \gamma, \\
\frac{\partial l}{\partial \sigma_{\mathcal{B}}^2} &= \sum_{i=1}^m \frac{\partial l}{\partial \hat{x}_i} \cdot (x_i - \mu_{\mathcal{B}}) \cdot \frac{-1}{2} (\sigma^2_{\mathcal{B}} + \epsilon)^{-\frac{3}{2}}, \\
\frac{\partial l}{\partial \mu_{\mathcal{B}}} &= \left(\sum_{i=1}^m \frac{\partial l}{\partial \hat{x}_i} \cdot \frac{-1}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}}\right) + \frac{\partial l}{\partial \sigma^2_{\mathcal{B}}} \cdot \frac{\sum_{i=1}^m -2(x_i - \mu_{\mathcal{B}})}{m}, \\
\frac{\partial l}{\partial x_i} &= \frac{\partial l}{\partial \hat{x}_i} \cdot \frac{1}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}} + \frac{\partial l}{\partial \sigma_{\mathcal{B}}^2} \cdot \frac{2(x_i - \mu_{\mathcal{B}})}{m} + \frac{\partial l}{\partial \mu_{\mathcal{B}}} \cdot \frac{1}{m}, \\
\frac{\partial l}{\partial \gamma} &= \sum_{i=1}^m \frac{\partial l}{\partial y_i} \cdot \hat{x}_i, \\
\frac{\partial l}{\partial \beta} &= \sum_{i=1}^m \frac{\partial l}{\partial y_i}.
\end{aligned}
$$
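One way to sanity-check these gradients is to code them by hand and compare against autograd. The sketch below does that for a small fully connected case. It uses the exponent $-\frac{3}{2}$ on $(\sigma^2_{\mathcal{B}} + \epsilon)$ and keeps $\epsilon$ under every square root, as in the paper's derivation. The shapes, the random upstream gradient `dy`, and the variable names are assumptions for illustration.

```python
import torch

torch.manual_seed(0)
m, d, eps = 6, 3, 1e-5  # hypothetical batch size, feature count, and epsilon

# Double precision keeps the comparison with autograd tight.
x = torch.randn(m, d, dtype=torch.float64, requires_grad=True)
gamma = torch.randn(d, dtype=torch.float64, requires_grad=True)
beta = torch.randn(d, dtype=torch.float64, requires_grad=True)

# Forward pass of batch normalization.
mu = x.mean(dim=0)
var = ((x - mu) ** 2).mean(dim=0)
x_hat = (x - mu) / torch.sqrt(var + eps)
y = gamma * x_hat + beta

# An arbitrary upstream gradient dl/dy, realized by the loss l = sum(dy * y).
dy = torch.randn(m, d, dtype=torch.float64)
(y * dy).sum().backward()

# Manual backward pass, one line per formula.
with torch.no_grad():
    dx_hat = dy * gamma
    dvar = (dx_hat * (x - mu)).sum(dim=0) * (-0.5) * (var + eps) ** (-1.5)
    dmu = (-dx_hat / torch.sqrt(var + eps)).sum(dim=0) \
        + dvar * (-2.0 * (x - mu)).sum(dim=0) / m
    dx = dx_hat / torch.sqrt(var + eps) + dvar * 2.0 * (x - mu) / m + dmu / m
    dgamma = (dy * x_hat).sum(dim=0)
    dbeta = dy.sum(dim=0)

# Each comparison should report agreement with autograd.
print(torch.allclose(dx, x.grad))
print(torch.allclose(dgamma, gamma.grad))
print(torch.allclose(dbeta, beta.grad))
```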

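Finally, note that batch normalization behaves differently during training and prediction (quoted from *Dive into Deep Learning*):
>***During training, we cannot use the entire dataset to estimate the mean and variance, so we can only train the model with the mean and variance of each mini-batch. In prediction mode, the mean and variance needed for batch normalization can be computed exactly from the entire dataset.***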

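However, once training is finished we do not need to compute the mean and variance from the entire dataset. One trick is to ***update a moving estimate of the mean and variance during training. When training finishes, we can directly use the most recently updated mean and variance to standardize the input of each layer***.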
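The moving update can be sketched as an exponential moving average over mini-batch statistics. In the snippet below, the data stream, the number of steps, and the convention that `momentum` is the weight kept on the old estimate are all assumptions for illustration.

```python
import torch

torch.manual_seed(0)
momentum = 0.9  # assumed convention: weight kept on the old estimate

moving_mean = torch.zeros(3)
moving_var = torch.ones(3)

# Hypothetical stream of mini-batches with true mean 2 and true variance 25.
for _ in range(200):
    X = torch.randn(64, 3) * 5.0 + 2.0
    mean = X.mean(dim=0)
    var = X.var(dim=0, unbiased=False)
    # Blend the previous estimate with the statistics of the current batch.
    moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
    moving_var = momentum * moving_var + (1.0 - momentum) * var

print(moving_mean)  # roughly 2 for every feature
print(moving_var)   # roughly 25 for every feature
```

Because the estimate forgets old batches geometrically, it stays noisy at roughly the level of an average over the last few dozen batches. That is usually acceptable for inference.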

The batch normalization algorithm covering both training and prediction is as follows:
![batch_norm](img/batch_norm_alg2.png)



An implementation of Batch Normalization is as follows:

In [None]:
import torch
from torch import nn

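def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
  # Whether autograd is enabled serves as a proxy for training vs. inference.
  if not torch.is_grad_enabled():
    # Inference mode: normalize with the moving mean and variance.
    X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
  else:
    # Training mode: normalize with the statistics of the current mini-batch.
    assert len(X.shape) in (2, 4)
    if len(X.shape) == 2:
      # Fully connected layer: per-feature statistics over the batch dimension.
      mean = X.mean(dim=0)
      var = ((X - mean) ** 2).mean(dim=0)
    else:
      # 2D convolutional layer: per-channel statistics over batch, height, and width.
      mean = X.mean(dim=(0, 2, 3), keepdim=True)
      var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)

    # Standardize the input.
    X_hat = (X - mean) / torch.sqrt(var + eps)
    # Update the moving averages of the mean and variance.
    moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
    moving_var = momentum * moving_var + (1.0 - momentum) * var

  # Scale and shift the standardized input.
  Y = gamma * X_hat + beta
  # Detach the moving statistics so that autograd does not track them.
  return Y, moving_mean.detach(), moving_var.detach()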


class BatchNorm(nn.Module):
  def __init__(self, num_features, num_dims):
    super().__init__()
    if num_dims == 2:
      shape = (1, num_features)
    else:
      shape = (1, num_features, 1, 1)

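    # Learnable scale and shift, initialized to 1 and 0 so that Y equals X_hat at first.
    self.gamma = nn.Parameter(torch.ones(shape))
    self.beta = nn.Parameter(torch.zeros(shape))
    # Moving statistics are plain tensors, not parameters, so they are not trained.
    self.moving_mean = torch.zeros(shape)
    self.moving_var = torch.ones(shape)

  def forward(self, X):
    # The moving statistics are not registered with the module, so move them
    # to the device of X manually.
    if self.moving_mean.device != X.device:
      self.moving_mean = self.moving_mean.to(X.device)
      self.moving_var = self.moving_var.to(X.device)
    # Save the updated moving_mean and moving_var.
    Y, self.moving_mean, self.moving_var = batch_norm(
      X, self.gamma, self.beta, self.moving_mean,
      self.moving_var, eps=1e-5, momentum=0.9)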

    return Y
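
For comparison, the sketch below runs PyTorch's built-in `nn.BatchNorm1d` in training and evaluation mode on a made-up batch. Note that the built-in layer uses the opposite momentum convention from the code above: its `momentum` (default 0.1) is the weight on the new batch statistic, so the running mean after one step is `0.1 * batch_mean`. The batch shape and data are assumptions for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(num_features=3)  # defaults: eps=1e-5, momentum=0.1
X = torch.randn(32, 3) * 5.0 + 2.0   # hypothetical mini-batch

# Training mode: normalize with batch statistics and update the running estimates.
bn.train()
Y_train = bn(X)

# Evaluation mode: normalize with the stored running estimates instead.
bn.eval()
with torch.no_grad():
    Y_eval = bn(X)

print(Y_train.mean(dim=0))  # approximately 0, since batch statistics were used
print(bn.running_mean)      # 0.1 * batch mean after a single update
print(Y_eval.mean(dim=0))   # generally not 0, since running statistics were used
```

Switching between `bn.train()` and `bn.eval()` is the usual way to select the mode. The from-scratch version above instead infers the mode from `torch.is_grad_enabled()`.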