[霍华德-解密RWKV线性注意力的进化过程](https://www.bilibili.com/video/BV1zW4y1D7Qg/?spm_id_from=333.999.0.0&vd_source=b843d04bfd7e977261b41de336930b9f)

In [1]:
import torch

In [10]:
T, D = 4, 5

In [11]:
Q, K, V = torch.randn(T, D), torch.randn(T, D), torch.randn(T, D)

## Transformer Attention 向量版本

$$
Attn(Q,K,V)_t = \frac{\sum_{i=0}^Te^{q^T_tk_i}v_i}{\sum_{i=0}^Te^{q^T_tk_i}}
$$

In [12]:
O = torch.zeros(T, D)
for t in range(T):
    Z = 0
    ot = torch.zeros(D)
    for i in range(T):
        attn = Q[t] @ K[i]
        attn = attn.exp()
        ot += attn * V[i]
        Z += attn
    ot = ot / Z
    O[t] = ot
print(O)

tensor([[-0.6032,  1.6079, -0.1444,  0.7193,  0.2736],
        [-0.3621, -1.4309, -1.9961, -1.0469, -0.0746],
        [-0.3389, -1.3303, -1.6987, -0.9762, -0.0429],
        [-0.3587, -1.2902, -1.5771, -0.9975, -0.0362]])


In [13]:
# 矩阵版本
res = torch.softmax(Q @ K.t(), dim=-1)
res = res @ V
res

tensor([[-0.6032,  1.6079, -0.1444,  0.7193,  0.2736],
        [-0.3621, -1.4309, -1.9961, -1.0469, -0.0746],
        [-0.3389, -1.3303, -1.6987, -0.9762, -0.0429],
        [-0.3587, -1.2902, -1.5771, -0.9975, -0.0362]])

## AFT Attention

$$
Attn^+(W, K, V)_t = \frac{\sum_{i=0}^Te^{w_{t,i}+k_i}v_i}{\sum_{i=0}^Te^{w_{t,i}+k_i}}
$$

In [15]:
O_aft = torch.zeros(T, D)
W = torch.randn(T, T)

for t in range(T):
    Z = 0
    ot = torch.zeros(D)
    for i in range(T):
        attn = W[t][i] + K[i]
        attn = attn.exp()
        ot += attn * V[i]
        Z += attn
    ot = ot / Z
    O[t] = ot
O

tensor([[-0.3950, -0.3049, -0.1834, -1.0573,  0.0850],
        [-0.3483, -1.0490, -1.2270, -0.7161,  0.0475],
        [-0.3572, -0.5709, -0.7710, -0.6966,  0.1757],
        [-0.3642, -1.1184, -1.4454, -1.0483,  0.0406]])

## RWKV Attention

$$
wkv_t = \frac{\sum_{i=1}^{t-1}e^{-(t-1-i)w+k_i}v_i+e^{u+k_t}\odot v_t}{\sum_{i=1}^{t-1}e^{-(t-1-i)w+k_i}v_i+e^{u+k_t}}
$$

In [21]:
W_rwkv = torch.randn(D)
U_rwkv = torch.randn(D) 

In [24]:
O_rwkv = torch.zeros(T, D)
for t in range(T):
    Z_rwkv = 0
    o_rwkv = torch.zeros(D)
    # 计算前 t-1个token的attn结果，其中不同位置会 有不同的W的取值，分别是 U，0, -W, -2W, ...
    for i in range(t):
        attn_rwkv = (-(t-1-i) * W_rwkv + K[i]).exp()
        o_rwkv += attn_rwkv * V[i]
        Z_rwkv += attn_rwkv
    # t这个位置使用U作为W权重
    div_x = o_rwkv + (U_rwkv + K[t]).exp() * V[t]
    div_y = Z_rwkv + (U_rwkv  + K[t]).exp()
    O_rwkv[t] = div_x / div_y
O_rwkv    
# 这里的计算需要依赖于前 t-1个token的值，但是除了位置向量其它的值都不会因为位置的变化而变化，所以是可以改变计算的形式的

tensor([[-1.6146,  1.0547,  1.1158, -1.6478,  0.0832],
        [-0.7832, -0.4170,  1.0957, -0.8807,  0.1356],
        [-0.3561, -1.3810, -0.2674, -1.2018, -0.0232],
        [-0.3593,  0.3675, -1.1717, -1.2065,  0.1966]])

## RWKV 的RNN形式

$$
wkv_t = \frac{a_{t-1}+e^{u+k_t}\odot v_t}{b_{t-1}+e^{u+k_t}}
$$

其中：
$$
a_{t} = e^{-w}\odot a_{t-1} + e^{k_t}\odot v_t
$$

$$
b_{t} = e^{-w}\odot b_{t-1} + e^{k_t}
$$

In [26]:
O_rwkv = torch.zeros(T, D)
O_rwkv[0] = V[0]

a = K[0].exp() * V[0]
b = K[0].exp()

for t in range(1, T):
    O_rwkv[t] = (a+(U_rwkv + K[t]).exp()*V[t]) / (b+(U_rwkv + K[t]).exp())
    a = -W_rwkv.exp() * a + K[t].exp() * V[t]
    b = -W_rwkv.exp() * b + K[t].exp()
O_rwkv # 可以看到递归形式的结果和上面的一致

tensor([[-1.6146,  1.0547,  1.1158, -1.6478,  0.0832],
        [-0.7832, -0.4170,  1.0957, -0.8807,  0.1356],
        [-0.2076, -1.7188,  2.5090, -0.6300, -0.1688],
        [-0.5078,  0.6518,  0.1357, -1.6221,  0.1939]])

## RWKV 数值稳定版本

fp16 和 bf16 中exp(k)很容易溢出

In [27]:
O_rwkv = torch.zeros(T, D)
O_rwkv[0] = V[0]

pp = K[0]
aa = V[0]
bb = 1

for t in range(1, T):
    ww = U_rwkv + K[t]
    
    qq = torch.max(ww, pp) # 取ww和pp的最大值
    # 都减去最大值防止溢出，相对大小没有改变
    e1 = torch.exp(pp - qq)
    e2 = torch.exp(ww - qq)
    
    a = e1 * aa + e2 * V[t]
    b = e1 * bb + e2
    O_rwkv[t] = a / b
    
    # 缩放cache值 a 和 b
    ww = pp - W_rwkv 
    qq = torch.maximum(ww, K[t])
    e1 = torch.exp(ww - qq)
    e2 = torch.exp(K[t] - qq)
    aa = e1 * aa + e2 * V[t]
    bb = e1 * bb + e2
    pp = qq
O_rwkv 

tensor([[-1.6146,  1.0547,  1.1158, -1.6478,  0.0832],
        [-0.7832, -0.4170,  1.0957, -0.8807,  0.1356],
        [-0.3561, -1.3810, -0.2674, -1.2018, -0.0232],
        [-0.3593,  0.3675, -1.1717, -1.2065,  0.1966]])