# Pre-Normalization

- Normalization 动机
- Post-Norm 数据分布分析
- Pre-Norm 数据分布分析
- Pre-Norm 与预训练

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
torch.manual_seed(42)

<torch._C.Generator at 0x1134bcdb0>

## Normalization 动机

深度神经网络均有 normalization，如 CV 领域常有 Batch Normalization

在 Resnet 中，结合 short-cut 操作， 形如:

$$
y = \text{Norm}(F(x) + x)
$$

将网络中的数据分布归置正态分布，以帮助模型更快收敛。经典的 Norm 归一化被称之为“后置归一化（Post-Norm）”

In [2]:
def static_mean_var(x):
    bs, seq_len, dim = x.shape
    batch_mean = x.reshape(bs*seq_len, dim).mean(dim = -1).mean()
    batch_var = x.reshape(bs*seq_len, dim).var(dim = -1).mean()
    print('batch_mean/var:', batch_mean.item(), batch_var.item())
    return 

In [3]:
# 数据归一化, 统计所有token其特征维度上的 mean 和 var，的均值。

x = torch.randn(2, 3, 512) # bs, seq_len, dim
print(x[0,0,:].mean(dim = -1)) # 单个 token mean
print(x[0,0,:].var(dim = -1)) # 单个 token var

static_mean_var(x)

tensor(0.0277)
tensor(0.9440)
batch_mean/var: 0.005180692300200462 0.9886613488197327


## PostNorm Model 实现

In [4]:
class LayerNorm(nn.Module):
    def __init__(self, dim = 512):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))
        self.eps = 1e-6
        
    def forward(self, x):
        mean = x.mean(dim = -1, keepdim=True)
        var = x.var(dim = -1, keepdim=True)
        x_ = (x - mean) / torch.sqrt(var + self.eps)
        x_ = x_ * self.gamma + self.beta
        return x_

class ReLU(nn.Module):
    def forward(self, x):
        return (x+torch.abs(x)) * 0.5
                   
class PostNormBlock(nn.Module):
    def __init__(self, dim = 512):
        super().__init__()
        self.w1 = nn.Linear(dim, dim)
        self.act = ReLU()
        self.w2 = nn.Linear(dim, dim)
        self.norm = LayerNorm(dim = dim)
    def forward(self, x):
        x_w1 = self.w1(x)
        x_act = self.act(x_w1)
        x_w2 = self.w2(x_act)
        y = self.norm(x_w2+x)
        return y

class PostNormModel(nn.Module):
    def __init__(self, dim = 512, num_layers = 6, num_class = 10):
        super().__init__()
        self.dim = dim
        self.num_class = 10
        self.num_layers = num_layers
        self.blocks = nn.ModuleList(
            [PostNormBlock(dim) for _ in range(self.num_layers)]
        )
        self.head = nn.Linear(self.dim, self.num_class)
        
    def forward(self, x, verbose = False):
        for block in self.blocks:
            x = block(x)
            if verbose:
                print(x.norm().item())
        last_hidden_state = x
        x = x.mean(dim = 1)
        logits = self.head(x) # logits: bs, seq_len, num_class
        return logits, last_hidden_state

model = PostNormModel() 

x = torch.randn(2, 3, 512) 

y, last_hidden_state = model(x, verbose = True)
print(y.shape)

55.37144470214844
55.3714485168457
55.3714485168457
55.3714485168457
55.37144470214844
55.37144088745117
torch.Size([2, 10])


In [5]:
# print(last_hidden_state.reshape(6, 512).mean(dim = -1).mean()) # 批量数据 mean mean
# print(last_hidden_state.reshape(6, 512).var(dim = -1).mean()) # 批量数据 var mean
static_mean_var(last_hidden_state)

batch_mean/var: 4.113341223188627e-09 0.9999990463256836


## PreNorm Model 实现


$$
y = x + F(\text{Norm}(x))
$$

In [6]:
class PreNormBlock(nn.Module):
    def __init__(self, dim = 512):
        super().__init__()
        self.w1 = nn.Linear(dim, dim)
        self.act = ReLU()
        self.w2 = nn.Linear(dim, dim)
        self.norm = LayerNorm(dim = dim)
    def forward(self, x):
        x_ = self.norm(x)
        x_w1 = self.w1(x_)
        x_act = self.act(x_w1)
        x_w2 = self.w2(x_act)
        y = x + x_w2
        return y

class PreNormModel(nn.Module):
    def __init__(self, dim = 512, num_layers = 6, num_class = 10):
        super().__init__()
        self.dim = dim
        self.num_class = 10
        self.num_layers = num_layers
        self.blocks = nn.ModuleList(
            [PreNormBlock(dim) for _ in range(self.num_layers)]
        )
        # 最后一层加入 layernorm
        self.last_norm = LayerNorm(dim = self.dim) 
        self.head = nn.Linear(self.dim, self.num_class)

    def forward(self, x, add_last_norm = False, verbose=False):
        old_norm = x.norm()
        for block in self.blocks:
            x = block(x)
            if verbose:
                print(x.norm().item(), '\t delta:',(x.norm() - old_norm).item())
            old_norm = x.norm()
        last_hidden_state = x
        if add_last_norm:
            last_hidden_state = self.last_norm(x)
            if verbose:
                print('last hidden state norm:', last_hidden_state.norm())
        y = last_hidden_state.mean(dim = 1)
        logits = self.head(y) # logits: bs, seq_len, num_class
        return logits, last_hidden_state


model = PreNormModel() 
x = torch.randn(2, 3, 512)
y, last_hidden_state = model(x, add_last_norm = True, verbose = True)
print(y.shape)
static_mean_var(last_hidden_state)

58.251708984375 	 delta: 1.43743896484375
59.79402160644531 	 delta: 1.5423126220703125
60.97715759277344 	 delta: 1.183135986328125
62.46257019042969 	 delta: 1.48541259765625
63.61300277709961 	 delta: 1.1504325866699219
64.91265106201172 	 delta: 1.2996482849121094
last hidden state norm: tensor(55.3715, grad_fn=<LinalgVectorNormBackward0>)
torch.Size([2, 10])
batch_mean/var: -1.1641532182693481e-09 0.9999993443489075


In [7]:
# 
y, last_hidden_state = model(x, add_last_norm = False, verbose = True)
static_mean_var(last_hidden_state)

58.251708984375 	 delta: 1.43743896484375
59.79402160644531 	 delta: 1.5423126220703125
60.97715759277344 	 delta: 1.183135986328125
62.46257019042969 	 delta: 1.48541259765625
63.61300277709961 	 delta: 1.1504325866699219
64.91265106201172 	 delta: 1.2996482849121094
batch_mean/var: -0.007094651460647583 1.3694543838500977


综上 PreNorm 设置下，不带 last norm 的情况下，方差变大且不为1

## 分析 Norm 


In [8]:
bs = 32
seq_len = 256
dim = 512
x = torch.randn(bs, seq_len, dim)
y = torch.randn(bs, seq_len, dim)
xy = x+y
xy_sqrt2 = xy / math.sqrt(2)
xy_2 = xy / 2
static_mean_var(x)
static_mean_var(y)
static_mean_var(xy)
static_mean_var(xy_sqrt2)
static_mean_var(xy_2) # 方差，标准差均不为1

batch_mean/var: 0.0006401605205610394 1.0001459121704102
batch_mean/var: -0.00033171908580698073 1.0000011920928955
batch_mean/var: 0.0003084417257923633 1.999786138534546
batch_mean/var: 0.00021810112230014056 0.9998931884765625
batch_mean/var: 0.00015422086289618164 0.4999465346336365


## 深层数据变化分析

In [9]:
model = PreNormModel(num_layers = 40) 
x = torch.randn(2, 3, 512)
y, last_hidden_state = model(x, add_last_norm = True, verbose = True)
print(y.shape)

56.46917724609375 	 delta: 1.7455635070800781
57.96708297729492 	 delta: 1.4979057312011719
59.74176788330078 	 delta: 1.7746849060058594
61.387413024902344 	 delta: 1.6456451416015625
62.77439498901367 	 delta: 1.3869819641113281
64.00464630126953 	 delta: 1.2302513122558594
65.04314422607422 	 delta: 1.0384979248046875
66.82162475585938 	 delta: 1.7784805297851562
68.3145980834961 	 delta: 1.4929733276367188
69.6028823852539 	 delta: 1.2882843017578125
70.9032211303711 	 delta: 1.3003387451171875
71.81951141357422 	 delta: 0.916290283203125
73.65388488769531 	 delta: 1.8343734741210938
74.57453918457031 	 delta: 0.920654296875
75.3595962524414 	 delta: 0.7850570678710938
76.23462677001953 	 delta: 0.875030517578125
77.53563690185547 	 delta: 1.3010101318359375
78.72762298583984 	 delta: 1.191986083984375
80.00578308105469 	 delta: 1.2781600952148438
80.94267272949219 	 delta: 0.9368896484375
81.98242950439453 	 delta: 1.0397567749023438
83.35127258300781 	 delta: 1.3688430786132812
8

随着深度变化，其方差越来越小，prenorm设定下，深层网络其层越虚，可视为宽层网络（注意：这是直觉的理解）

## 如何初始化

以下的实现中，哪怕没有 last norm，其 last hidden state 不会偏差太多

In [10]:
class PreNormBlockInit(nn.Module):
    def __init__(self, dim = 512, num_layers = 6):
        super().__init__()
        self.w1 = nn.Linear(dim, dim)
        self.act = ReLU()
        self.w2 = nn.Linear(dim, dim)
        self.norm = LayerNorm(dim = dim)

        # reference
        # transformers/src/transformers/models/gpt2/modeling_gpt2.py
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        
        initializer_range = 0.02
        std = initializer_range / math.sqrt(2 * num_layers)
        nn.init.normal_(self.w1.weight, mean=0.0, std=std)
        nn.init.normal_(self.w2.weight, mean=0.0, std=std)
        
    def forward(self, x):
        x_ = self.norm(x)
        x_w1 = self.w1(x_)
        x_act = self.act(x_w1)
        x_w2 = self.w2(x_act)
        y = x + x_w1
        return y

class PreNormModelInit(nn.Module):
    def __init__(self, dim = 512, num_layers = 6, num_class = 10):
        super().__init__()
        self.dim = dim
        self.num_class = 10
        self.num_layers = num_layers
        self.blocks = nn.ModuleList(
            [PreNormBlockInit(dim, num_layers = self.num_layers) for i in range(self.num_layers)]
        )
        # 最后一层加入 layernorm
        self.last_norm = LayerNorm(dim = self.dim) 
        self.head = nn.Linear(self.dim, self.num_class)

    def forward(self, x, add_last_norm = False, verbose=False):
        old_norm = x.norm()
        for block in self.blocks:
            x = block(x)
            if verbose:
                print(x.norm().item(), '\t delta:',(x.norm() - old_norm).item())
            old_norm = x.norm()
        last_hidden_state = x
        if add_last_norm:
            last_hidden_state = self.last_norm(x)
            if verbose:
                print('last hidden state norm:', last_hidden_state.norm())
        y = last_hidden_state.mean(dim = 1)
        logits = self.head(y) # logits: bs, seq_len, num_class
        return logits, last_hidden_state


model = PreNormModel(num_layers = 40) 
x = torch.randn(2, 3, 512)
y, last_hidden_state = model(x, add_last_norm = True, verbose = True)
static_mean_var(last_hidden_state) # 方差为 1

56.10917282104492 	 delta: 1.1105537414550781
57.41032028198242 	 delta: 1.3011474609375
58.757999420166016 	 delta: 1.3476791381835938
60.30143356323242 	 delta: 1.5434341430664062
61.772682189941406 	 delta: 1.4712486267089844
63.676788330078125 	 delta: 1.9041061401367188
64.64246368408203 	 delta: 0.9656753540039062
65.39842987060547 	 delta: 0.7559661865234375
66.78181457519531 	 delta: 1.3833847045898438
68.16849517822266 	 delta: 1.3866806030273438
69.3364028930664 	 delta: 1.16790771484375
70.53028106689453 	 delta: 1.193878173828125
71.57947540283203 	 delta: 1.0491943359375
73.1681137084961 	 delta: 1.5886383056640625
74.08837127685547 	 delta: 0.920257568359375
75.53704071044922 	 delta: 1.44866943359375
76.9328384399414 	 delta: 1.3957977294921875
77.95159149169922 	 delta: 1.0187530517578125
78.58661651611328 	 delta: 0.6350250244140625
79.57111358642578 	 delta: 0.9844970703125
80.55424499511719 	 delta: 0.9831314086914062
82.01812744140625 	 delta: 1.4638824462890625
83.

In [11]:
model = PreNormModel(num_layers = 40) 
x = torch.randn(2, 3, 512)
y, last_hidden_state = model(x, add_last_norm = False, verbose = True)
static_mean_var(last_hidden_state) # 方差为 1

56.08007049560547 	 delta: 1.649871826171875
57.136348724365234 	 delta: 1.0562782287597656
58.2601203918457 	 delta: 1.1237716674804688
59.80127716064453 	 delta: 1.5411567687988281
61.09972381591797 	 delta: 1.2984466552734375
62.547061920166016 	 delta: 1.4473381042480469
63.97510528564453 	 delta: 1.4280433654785156
65.08590698242188 	 delta: 1.1108016967773438
65.98751068115234 	 delta: 0.9016036987304688
67.04659271240234 	 delta: 1.05908203125
68.25686645507812 	 delta: 1.2102737426757812
69.205322265625 	 delta: 0.948455810546875
70.34056854248047 	 delta: 1.1352462768554688
71.13099670410156 	 delta: 0.7904281616210938
72.20435333251953 	 delta: 1.0733566284179688
73.48302459716797 	 delta: 1.2786712646484375
75.00140380859375 	 delta: 1.5183792114257812
76.28553771972656 	 delta: 1.2841339111328125
77.14724731445312 	 delta: 0.8617095947265625
78.60157012939453 	 delta: 1.4543228149414062
79.80459594726562 	 delta: 1.2030258178710938
80.9670639038086 	 delta: 1.16246795654296

## 补充推导

[浅谈Transformer的初始化、参数化与标准化](https://kexue.fm/archives/8620#残差连接)

该文章推导了 post-norm 和 pre-norm 的数据分布的方差变化。结论：

1. pre-norm 最后一层要加 norm
2. pre-norm 每层接受的原始输入分量是一样的
3. post-norm 随深度增加，原始输入分量会随着减少

本文代码从特征维度的方差进行分析，代码按照深层 normalization 对数据分布影响。

## 结论

1. PostNorm 实际上分布一直在归一化，PreNorm 在已学习好的分布上，继续强化，在深层上方差变化较小，增加深度收益变小
2. PreNorm 在预训练中可以不 warmup，PreNorm 更容易训练
3. PreNorm 在首层输入即进行了 norm， 对于 transformer 输入来说， 存在 E+PE，该数据不服从 0,1 正态分布
4. PreNorm 为当前 Transformer 类网络的标准归一化手段。