In [None]:
import numpy as np

class BatchNormLayer:
    """Forward-only batch normalization for channels-last NumPy arrays.

    In training mode the statistics are computed per feature over every axis
    except the last one, so ``(N, C)``, ``(N, L, C)`` and ``(N, H, W, C)``
    inputs are all supported. The running statistics are updated as an
    exponential moving average and are used in place of the batch statistics
    when ``training=False``.

    Args:
        num_features: Size of the last (feature) axis of the inputs.
        eps: Constant added to the variance before the square root.
        momentum: Weight given to the current batch statistics when the
            running statistics are updated.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # Running statistics start at mean 0 / variance 1.
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)
        # Affine parameters start as the identity transform.
        self.gamma = np.ones(num_features)
        self.beta = np.zeros(num_features)

    def forward(self, x, training=True):
        """Normalize ``x`` and apply the affine transform ``gamma * x_hat + beta``.

        Args:
            x: Array whose last axis is the feature axis.
            training: If True, normalize with the statistics of ``x`` and update
                the running statistics; otherwise normalize with the running
                statistics and leave them untouched.

        Returns:
            Array with the same shape as ``x``.

        Raises:
            ValueError: If ``training`` is True and ``x`` has fewer than two
                dimensions (there would be no batch axis to reduce over).
        """
        x = np.asarray(x)
        if training:
            if x.ndim < 2:
                raise ValueError(
                    "BatchNormLayer.forward expects at least 2 dimensions in "
                    f"training mode, got shape {x.shape}"
                )
            # Reduce over every axis except the trailing feature axis so the
            # statistics always have shape (num_features,), whatever the rank.
            reduce_axes = tuple(range(x.ndim - 1))
            batch_mean = np.mean(x, axis=reduce_axes)
            batch_var = np.var(x, axis=reduce_axes)  # biased variance (ddof=0)
            x_normalized = (x - batch_mean) / np.sqrt(batch_var + self.eps)
            # Exponential moving average of the batch statistics.
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var
        else:
            x_normalized = (x - self.running_mean) / np.sqrt(self.running_var + self.eps)

        out = self.gamma * x_normalized + self.beta
        return out

class LayerNormLayer:
    """Forward-only layer normalization over the last (feature) axis.

    Every position is normalized independently using the mean and biased
    variance of its own feature vector, then scaled by ``gamma`` and shifted
    by ``beta``.

    Args:
        num_features: Size of the last axis of the inputs.
        eps: Constant added to the variance before the square root.
    """

    def __init__(self, num_features, eps=1e-5):
        self.num_features = num_features
        self.eps = eps
        # Identity affine transform by default.
        self.gamma = np.ones(num_features)
        self.beta = np.zeros(num_features)

    def forward(self, x):
        """Return ``gamma * (x - mean) / sqrt(var + eps) + beta``.

        ``mean`` and ``var`` are taken along the last axis with
        ``keepdims=True`` so they broadcast back against ``x``.
        """
        mu = np.mean(x, axis=-1, keepdims=True)
        sigma_sq = np.var(x, axis=-1, keepdims=True)
        centered = x - mu
        scale = np.sqrt(sigma_sq + self.eps)
        return self.gamma * (centered / scale) + self.beta

class GroupNormLayer:
    """Forward-only group normalization for channels-last NumPy arrays.

    The feature axis is split into ``num_groups`` contiguous groups. For each
    sample and each group, the mean and biased variance are computed over all
    positions and over the features belonging to that group.

    Args:
        num_features: Size of the last (feature) axis of the inputs.
        num_groups: Number of feature groups; must divide ``num_features``.
        eps: Constant added to the variance before the square root.

    Raises:
        AssertionError: If ``num_features`` is not divisible by ``num_groups``.
    """

    def __init__(self, num_features, num_groups=32, eps=1e-5):
        assert num_features % num_groups == 0, "num_features must be divisible by num_groups"
        self.num_features = num_features
        self.num_groups = num_groups
        self.eps = eps
        # Identity affine transform by default.
        self.gamma = np.ones(num_features)
        self.beta = np.zeros(num_features)

    def forward(self, x):
        """Normalize ``x`` per (sample, group) and apply the affine transform.

        Args:
            x: Array of shape ``(N, ..., C)`` with the features on the last axis.

        Returns:
            Array with exactly the same shape as ``x``.
        """
        # Remember the input shape so the result can be restored to it; a
        # fixed (N, -1, C) reshape would turn (N, C) into (N, 1, C) and
        # flatten the spatial axes of higher-rank inputs.
        input_shape = x.shape
        N, C = x.shape[0], x.shape[-1]
        G = self.num_groups
        grouped = x.reshape(N, -1, G, C // G)  # (N, positions, G, C // G)

        # Statistics over positions (axis 1) and features within a group (axis 3).
        mean = np.mean(grouped, axis=(1, 3), keepdims=True)
        var = np.var(grouped, axis=(1, 3), keepdims=True)

        x_normalized = (grouped - mean) / np.sqrt(var + self.eps)
        x_normalized = x_normalized.reshape(input_shape)

        out = self.gamma * x_normalized + self.beta
        return out
    

# Smoke test for the normalization layers.
# A seeded generator keeps the printed numbers identical across
# "Restart Kernel -> Run All" executions.
rng = np.random.default_rng(42)
batch_data = rng.standard_normal((4, 8, 64))  # (batch, seq_len, features)
print("输入数据形状:", batch_data.shape)

# BatchNorm: normalize each feature over the batch and sequence axes.
batch_norm = BatchNormLayer(num_features=64)
batch_norm_output = batch_norm.forward(batch_data)

# LayerNorm: normalize each position over its feature axis.
layer_norm = LayerNormLayer(num_features=64)
layer_norm_output = layer_norm.forward(batch_data)

print(f"BatchNorm输出形状: {batch_norm_output.shape}")
print(f"LayerNorm输出形状: {layer_norm_output.shape}")

# Both layers must preserve the input shape.
assert batch_norm_output.shape == batch_data.shape
assert layer_norm_output.shape == batch_data.shape

# Verify the normalization: with the default gamma=1 / beta=0 the means along
# the normalized axes must be (numerically) zero.
print(f"\nBatchNorm后每个特征的均值(应该接近0): {np.mean(batch_norm_output, axis=(0,1))[:5]}")
print(f"LayerNorm后每个样本的均值(应该接近0): {np.mean(layer_norm_output, axis=-1)[0, :3]}")
np.testing.assert_allclose(np.mean(batch_norm_output, axis=(0, 1)), 0.0, atol=1e-6)
np.testing.assert_allclose(np.mean(layer_norm_output, axis=-1), 0.0, atol=1e-6)

Batch Norm Output:
 [[[ 1.70863676 -0.13629827 -1.72076935 ...  1.37047155  1.5369013
   -1.06560742]
  [-1.20663801 -0.84436428 -1.68322335 ... -0.54353955  1.09234572
   -0.04310508]
  [ 0.60521268 -1.09602991 -1.09374313 ... -0.48319579 -1.59319968
    0.32805316]
  ...
  [-1.28043934 -0.86907098 -1.49960198 ...  0.41882009  0.14630969
   -1.5746244 ]
  [-0.15961256  0.21473206  1.06304736 ... -0.56396214  1.17502287
    0.81103531]
  [ 1.33351844 -1.26034282 -0.35312419 ...  0.54800968  0.62573213
    0.07953197]]

 [[-0.83197789  1.36378319  0.53981356 ... -1.44848487  0.21912869
   -0.90460137]
  [-0.15420669 -0.40543469  0.55032146 ...  0.62326496  0.51842754
   -1.60165968]
  [-0.06761847  1.13465329  1.17519583 ...  1.72640612  0.08282305
    0.22611451]
  ...
  [-0.4092746   1.51054338  0.14379908 ...  1.34138976  0.9147942
   -0.12218639]
  [ 1.19074324 -1.68765569  0.52295837 ... -1.35556804  0.22664898
   -1.70986371]
  [-0.61642744  1.23352504  1.22393031 ...  0.44433029 

In [4]:
# Worked example for `batch_mean = np.mean(x, axis=(0, 1))`: compare the axes
# that BatchNorm and LayerNorm reduce over, using a tiny hand-checkable array.
simple_data = np.array(
    [
        [[1, 2], [3, 4]],
        [[5, 6], [7, 8]],
    ]
)  # shape (2, 2, 2)
print("\n简单数据:\n", simple_data)

# BatchNorm-style statistics: one value per feature, reduced over the
# batch and sequence axes.
batch_mean = simple_data.mean(axis=(0, 1))
batch_var = simple_data.var(axis=(0, 1))
print("BatchNorm均值:\n", batch_mean)
print("BatchNorm方差:\n", batch_var)

# LayerNorm-style statistics: one value per position, reduced over the
# feature axis (kept as a size-1 axis so it broadcasts against the input).
layer_mean = simple_data.mean(axis=-1, keepdims=True)
layer_var = simple_data.var(axis=-1, keepdims=True)
print("LayerNorm均值:\n", layer_mean)
print("LayerNorm方差:\n", layer_var)


简单数据:
 [[[1 2]
  [3 4]]

 [[5 6]
  [7 8]]]
BatchNorm均值:
 [4. 5.]
BatchNorm方差:
 [5. 5.]
LayerNorm均值:
 [[[1.5]
  [3.5]]

 [[5.5]
  [7.5]]]
LayerNorm方差:
 [[[0.25]
  [0.25]]

 [[0.25]
  [0.25]]]
