1. 计算批量均值和方差：
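对于输入到某一层的每一个小批量数据（Batch）$\{x_1, \dots, x_m\}$，计算其均值和方差：$\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i$，$\sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_B)^2$。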
2. 标准化：
使用批量均值和方差对输入数据进行标准化处理，使其均值为0，方差为1：$\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}$
其中，$\mu_B$ 和 $\sigma_B^2$ 分别是批量均值和批量方差，ϵ 是一个很小的常数，用于防止除零错误。
3. 缩放和平移：
为了保持网络的表达能力，Batch Normalization 引入了两个可学习的参数：缩放参数 γ 和平移参数 β。标准化后的数据 $\hat{x}_i$ 再进行缩放和平移：$y_i = \gamma \hat{x}_i + \beta$。

In [None]:
import torch
import torch.nn as nn

class BatchNormalization(nn.Module):
    """Batch normalization with learnable scale (gamma) and shift (beta).

    In training mode the input is normalized with the statistics of the
    current mini-batch, and exponential moving averages of those statistics
    are updated. In evaluation mode the stored moving averages are used.

    Statistics are computed per feature. For inputs with more than two
    dimensions, dimension 1 is treated as the feature/channel dimension and
    every other dimension is reduced over (the ``torch.nn.BatchNorm*d``
    layout); 2-D ``(batch, num_features)`` input is reduced over dimension 0.

    Note: the running variance is accumulated from the *biased* batch
    variance, unlike ``torch.nn.BatchNorm1d`` which uses the unbiased
    estimate for its running statistic.

    Args:
        num_features: Number of features/channels to normalize.
        eps: Small constant added to the variance to avoid division by zero.
        momentum: Weight given to the current batch statistics when updating
            the running averages.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum

        # Learnable scale and shift.
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))

        # Registered as buffers (not plain attributes) so that they are saved
        # in ``state_dict()`` and moved by ``.to()`` / ``.cuda()`` together
        # with the parameters.
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x):
        """Normalize ``x`` and apply the learnable affine transform.

        Args:
            x: Input tensor whose feature dimension has size
                ``num_features`` (dimension 1, or the last dimension for
                2-D input).

        Returns:
            A tensor with the same shape as ``x``.
        """
        if x.dim() > 2:
            # (N, C, *): reduce over everything except the channel dim and
            # reshape per-channel vectors to (1, C, 1, ...) for broadcasting.
            reduce_dims = (0, *range(2, x.dim()))
            stat_shape = (1, -1) + (1,) * (x.dim() - 2)
        else:
            reduce_dims = (0,)
            stat_shape = (-1,)

        if self.training:
            mean = x.mean(dim=reduce_dims)
            var = x.var(dim=reduce_dims, unbiased=False)

            # Update the running statistics outside of autograd; otherwise
            # they would hold on to the graph of every batch seen so far.
            with torch.no_grad():
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * var)
        else:
            mean = self.running_mean
            var = self.running_var

        # Normalize, then scale and shift.
        x_hat = (x - mean.reshape(stat_shape)) / torch.sqrt(var.reshape(stat_shape) + self.eps)
        y = self.gamma.reshape(stat_shape) * x_hat + self.beta.reshape(stat_shape)
        return y

# Smoke test for the BatchNormalization layer
if __name__ == "__main__":
    feature_count = 5
    batch_size = 10

    # Build the layer and a random (batch_size, feature_count) input
    layer = BatchNormalization(num_features=feature_count)
    inputs = torch.randn(batch_size, feature_count)

    # Run one forward pass
    outputs = layer(inputs)

    # Show the input followed by the normalized output
    for label, tensor in (("Input:\n", inputs), ("Output:\n", outputs)):
        print(label, tensor)
