相比于 Batch Norm，Layer Norm (LN) 其实更好写，逻辑也更简单。

因为它对每个样本自身的特征维度（最后几个维度）求均值和方差，不依赖 batch 统计量，所以不需要维护 Running Mean/Var，训练和推理的逻辑完全一样。

In [1]:
import torch
import torch.nn as nn


In [4]:
class MyLayerNorm(nn.Module):
    """Layer Normalization implemented from scratch (mirrors ``nn.LayerNorm``).

    Each sample is normalized over its trailing ``normalized_shape``
    dimensions using the biased (population) variance. An optional
    element-wise affine transform ``gamma * x_norm + beta`` is then applied.

    Statistics are computed per sample, unlike BatchNorm. There is therefore
    no running mean/var buffer, and training and inference behave identically.

    Args:
        normalized_shape: Size of the trailing dimension(s) to normalize over.
            An ``int`` is treated as a 1-element shape.
        eps: Value added to the variance for numerical stability.
        elementwise_affine: If ``True`` (default), learnable ``gamma``
            (initialized to 1) and ``beta`` (initialized to 0) of shape
            ``normalized_shape`` are created. If ``False``, both are ``None``
            and no affine step is applied.
    """

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super().__init__()

        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        # Store as a tuple so it compares equal to ``tuple(x.shape[-n:])``
        # whether the caller passed a list, a tuple or a torch.Size.
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        # 1. Learnable parameters (optional, as in nn.LayerNorm).
        if elementwise_affine:
            self.gamma = nn.Parameter(torch.ones(self.normalized_shape))
            self.beta = nn.Parameter(torch.zeros(self.normalized_shape))
        else:
            self.register_parameter("gamma", None)
            self.register_parameter("beta", None)

        # Note: LayerNorm does NOT need register_buffer for running_mean/var.

    def forward(self, x):
        """Apply layer normalization to ``x``.

        Args:
            x: Tensor of shape ``(..., *normalized_shape)``. For example,
                ``[B, Seq_len, Hidden_Dim]`` with
                ``normalized_shape == (Hidden_Dim,)``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            RuntimeError: If ``normalized_shape`` is empty, or if the trailing
                dimensions of ``x`` do not match it. This is the same
                exception type that ``nn.LayerNorm`` raises.
        """
        n = len(self.normalized_shape)
        if n == 0:
            # nn.LayerNorm rejects an empty normalized_shape. Reject it here
            # too rather than pass ``dim=()`` to the reductions below.
            raise RuntimeError(
                "Expected normalized_shape to contain at least one element"
            )
        # Without this check a mismatched input can silently broadcast
        # gamma/beta (e.g. normalized_shape=(1,) on a (..., 4) input).
        if tuple(x.shape[-n:]) != self.normalized_shape:
            raise RuntimeError(
                f"Given normalized_shape={list(self.normalized_shape)}, "
                f"expected input with shape "
                f"[*, {', '.join(map(str, self.normalized_shape))}], "
                f"but got input of size {list(x.shape)}"
            )

        # 2. Reduce over the last n dims: (D,) -> (-1,), (H, W) -> (-2, -1).
        dims = tuple(range(-n, 0))
        mean = x.mean(dim=dims, keepdim=True)
        # Biased variance (divide by N, not N - 1), as LayerNorm specifies.
        var = x.var(dim=dims, keepdim=True, unbiased=False)

        # 3. Normalize.
        x_norm = (x - mean) / torch.sqrt(var + self.eps)

        if self.gamma is None:
            return x_norm

        # 4. Affine transform. gamma/beta have shape normalized_shape and
        #    broadcast over the leading dimensions of x_norm.
        return self.gamma * x_norm + self.beta





验证代码

In [5]:
# --- Verification ---
# Simulated Transformer input: Batch=2, Seq_Len=5, Dim=4
torch.manual_seed(0)  # make inputs and parameters reproducible

dim = 4
x = torch.randn(2, 5, dim)

# Instantiate both implementations
my_ln = MyLayerNorm(dim)
torch_ln = nn.LayerNorm(dim)

# Use non-trivial affine parameters. With the defaults (gamma=1, beta=0)
# the comparison would not exercise the affine step at all. Copy the same
# values into both modules under no_grad instead of aliasing `.data`.
with torch.no_grad():
    my_ln.gamma.copy_(torch.randn(dim))
    my_ln.beta.copy_(torch.randn(dim))
    torch_ln.weight.copy_(my_ln.gamma)
    torch_ln.bias.copy_(my_ln.beta)

# Compare outputs
out_my = my_ln(x)
out_torch = torch_ln(x)

print("LayerNorm 误差:", (out_my - out_torch).abs().max().item())
assert torch.allclose(out_my, out_torch, atol=1e-5), "MyLayerNorm does not match nn.LayerNorm"

LayerNorm 误差: 2.384185791015625e-07
