# 前言

动机： 
带宽：NRAM 为 19TB/s，dBM 为 1.5TB/s，NRAM 快了 12.67 倍，但是 NRAM 只有 20MB 可以用。
例如：
假设 NRAM 只能放置 10000 个数据元素，而 Q[100,100]、K[100,100]，那么同时加载 Q、K 需要 2x100x100 = 20000 个数据，超出了 NRAM 的容量，因此需要反复对 dBM 进行 write/read。

- split Q -> [Q1[50,100] , Q2[50,100]]
- split K -> [K1[50,100] , K2[50,100]]

N_ij = Qi @ Kj^T（每个分块的形状为 [50,50]），最后想要求得完整的

N = [[N11,N12] , [N21,N22]]

目的：
解决 NRAM <--> dBM 之间的数据搬运成为访存瓶颈的问题。
传统的 Attn 需要 7 次交换；如果能拆成 20MB 放得下的矩阵块，访存时间就能加速 7.6 倍。

# Flash Attn原理

safe softmax >> 3-pass online softmax >> 2-pass online softmax >>

online softmax self-attention >> flash attn >> flash attn(tiling) 


## 2-pass online softmax 

In [None]:

import torch
import torch.nn.functional as F


def online_softmax_stats(values):
    """Return the running maximum and rescaled exp-sum of a 1-D float tensor.

    Both statistics are accumulated in a single sweep: whenever the running
    maximum grows, the sum collected so far is rescaled by
    ``exp(old_max - new_max)`` so that it stays expressed relative to the
    current maximum. Together with the normalisation sweep this gives the
    "2-pass" online softmax.

    Args:
        values: 1-D floating-point tensor with finite entries.

    Returns:
        Tuple ``(max, total)`` of 0-d tensors where ``max`` is ``values.max()``
        and ``total`` is ``sum(exp(values - max))``. For an empty input the
        result is ``(-inf, 0)``.
    """
    # Start from -inf rather than a finite sentinel such as -1e6: a finite
    # sentinel yields a zero sum (and NaN after normalisation) whenever every
    # input lies below it.
    running_max = torch.full(
        (), float("-inf"), dtype=values.dtype, device=values.device
    )
    running_sum = torch.zeros((), dtype=values.dtype, device=values.device)
    for value in values:
        new_max = torch.maximum(running_max, value)
        # Merging the max update and the sum update is what removes one pass.
        running_sum = (
            running_sum * torch.exp(running_max - new_max)
            + torch.exp(value - new_max)
        )
        running_max = new_max
    return running_max, running_sum


"""1d"""
N = 4
x = torch.arange(N, dtype=torch.float32)
x_clone = x.clone()

# Pass 1: update x_max and x_sum together.
x_max_old, x_sum_old = online_softmax_stats(x)
# Pass 2: normalise x.
x = torch.exp(x - x_max_old) / x_sum_old

x_direct = F.softmax(x_clone, dim=0)
print(f"x is {x},x_direct is {x_direct}")
assert torch.allclose(x, x_direct)



"""2d"""
# x = torch.arange(16,dtype=torch.float32).reshape(4,4)
# x_exp  = torch.exp(x)
# print(f"x_exp is {x_exp}")

x is tensor([0.0321, 0.0871, 0.2369, 0.6439]),x_direct is tensor([0.0321, 0.0871, 0.2369, 0.6439])


'2d'

# 2-pass self-attn >> 1-pass self-attn

In [None]:
def two_pass_attention_row(q, K, V):
    """Compute ``softmax(q @ K.T) @ V`` for one query row in two passes.

    Pass 1 accumulates the running maximum ``m`` and the rescaled exp-sum
    ``d`` of the scores ``x_i = q . K[i, :]``. Pass 2 recomputes each score,
    turns it into the probability ``a_i = exp(x_i - m) / d`` and accumulates
    ``o += a_i * V[i, :]``.

    Args:
        q: 1-D float tensor of length ``d`` (a single row ``Q[k, :]``).
        K: 2-D float tensor of shape ``(N, d)``.
        V: 2-D float tensor with ``N`` rows.

    Returns:
        1-D tensor of length ``V.shape[-1]``; all zeros when ``K`` is empty.
    """
    m = torch.full((), float("-inf"), dtype=q.dtype, device=q.device)
    d = torch.zeros((), dtype=q.dtype, device=q.device)

    # Pass 1 (loop over K): running max and rescaled sum.
    for k_i in K:
        x_i = q @ k_i
        m_new = torch.maximum(m, x_i)
        d = d * torch.exp(m - m_new) + torch.exp(x_i - m_new)
        m = m_new

    # Pass 2 (loop over K and V): a_i = exp(x_i - m_N) / d_N.
    o = V.new_zeros(V.shape[-1])
    for k_i, v_i in zip(K, V):
        a_i = torch.exp(q @ k_i - m) / d
        o = o + a_i * v_i
    return o


# 1-pass flash-attn
def one_pass_attention_row(q, K, V):
    """Compute ``softmax(q @ K.T) @ V`` for one query row in a single pass.

    The output is kept normalised at every step, so each update first
    rescales the previous output from the old statistics to the new ones::

        o_i = o_(i-1) * d_(i-1) * exp(m_(i-1) - m_i) / d_i
              + exp(x_i - m_i) / d_i * V[i, :]

    Args:
        q: 1-D float tensor of length ``d`` (a single row ``Q[k, :]``).
        K: 2-D float tensor of shape ``(N, d)``.
        V: 2-D float tensor with ``N`` rows.

    Returns:
        1-D tensor of length ``V.shape[-1]``; all zeros when ``K`` is empty.
    """
    m = torch.full((), float("-inf"), dtype=q.dtype, device=q.device)
    d = torch.zeros((), dtype=q.dtype, device=q.device)
    o = V.new_zeros(V.shape[-1])

    for k_i, v_i in zip(K, V):
        x_i = q @ k_i
        m_new = torch.maximum(m, x_i)
        # On the first step this is exp(-inf) == 0, which discards the
        # (empty) history.
        rescale = torch.exp(m - m_new)
        p_i = torch.exp(x_i - m_new)
        d_new = d * rescale + p_i
        o = o * (d * rescale / d_new) + (p_i / d_new) * v_i
        m, d = m_new, d_new
    return o


if None:  # Demo kept disabled so the cell stays a no-op, as in the original.

    N = 4
    d = 2

    Q = torch.arange(N*d,dtype=torch.float32).reshape(N,d)
    K = torch.arange(N*d,dtype=torch.float32).reshape(N,d)
    V = torch.arange(N*d,dtype=torch.float32).reshape(N,d)

    for k in range(N):  # k is the index of the query row being processed
        expected = torch.softmax(K @ Q[k], dim=0) @ V
        assert torch.allclose(two_pass_attention_row(Q[k], K, V), expected, atol=1e-5)
        assert torch.allclose(one_pass_attention_row(Q[k], K, V), expected, atol=1e-5)

## Flash-attn实现

In [23]:
# Reference (non-tiled) self-attention, without the 1/sqrt(d) scaling.
import torch
import torch.nn.functional as F

NEG_INF = -1e10
EPSILON = 1e-10

"""
SRAM size is M
"""
# Original Q/K/V size: N x d = 6 x 2
N = 6
d = 2
# Create the Q, K, V matrices, each with shape (1, 1, N, d).
Q = torch.randn(1, 1, N, d, requires_grad=True)
K = torch.randn(1, 1, N, d, requires_grad=True)
V = torch.randn(1, 1, N, d, requires_grad=True)

# Row-wise softmax of the score matrix Q @ K^T: one distribution per query row.
S = torch.softmax(Q @ K.transpose(-2, -1), dim=-1)
print(f"S.shape is {S.shape}")
# Each output row is the probability-weighted sum of the rows of V.
O = S @ V
print(f"O.shape is {O.shape}")


S.shape is torch.Size([1, 1, 6, 6])
O.shape is torch.Size([1, 1, 6, 2])


In [35]:
import torch
import torch.nn.functional as F

NEG_INF = -1e10
EPSILON = 1e-10

"""
SRAM size is M
"""


def flash_attention_forward(Q, K, V, Q_block_size, KV_block_size, *, verbose=False):
    """Tiled (flash-attention style) forward pass of ``softmax(Q K^T) V``.

    Q is cut into row blocks and K/V into column blocks along the sequence
    axis (second-to-last dimension). For every (row block, column block) pair
    the block-local softmax statistics are merged into the per-row running
    maximum ``m`` and running sum ``l``, and the already accumulated output is
    rescaled accordingly. No ``1/sqrt(d)`` scaling is applied.

    The loops run over the blocks that ``torch.split`` actually returns, so a
    sequence length that is not a multiple of the block size is handled: the
    trailing, smaller block is processed like any other.

    Args:
        Q: float tensor of shape ``(..., N_q, d)``.
        K: float tensor of shape ``(..., N_kv, d)``.
        V: float tensor of shape ``(..., N_kv, d_v)``.
        Q_block_size: number of query rows per block (positive int).
        KV_block_size: number of key/value rows per block (positive int).
        verbose: when true, print the statistic shapes and the shape of every
            score block.

    Returns:
        Tuple ``(O, l, m)``: the attention output of shape ``(..., N_q, d_v)``
        and the final row sums and row maxima, each of shape ``(..., N_q, 1)``.
    """
    out = Q.new_zeros(*Q.shape[:-1], V.shape[-1])
    row_sum = Q.new_zeros(Q.shape[:-1])[..., None]  # trailing axis of size 1
    # Finite sentinel for the initial row maximum (NEG_INF = -1e10).
    row_max = torch.full_like(row_sum, NEG_INF)
    if verbose:
        print(f"l.shape is {row_sum.shape},m.shape is {row_max.shape}")

    # Tiling: the second argument of torch.split is the chunk size, and the
    # last chunk may be smaller than the others.
    Q_blocks = torch.split(Q, Q_block_size, dim=-2)
    K_blocks = torch.split(K, KV_block_size, dim=-2)
    V_blocks = torch.split(V, KV_block_size, dim=-2)

    O_blocks = list(torch.split(out, Q_block_size, dim=-2))
    l_blocks = list(torch.split(row_sum, Q_block_size, dim=-2))
    m_blocks = list(torch.split(row_max, Q_block_size, dim=-2))

    for j, (Kj, Vj) in enumerate(zip(K_blocks, V_blocks)):  # column loop: load K and V
        for i, Qi in enumerate(Q_blocks):  # row loop
            Oi, li, mi = O_blocks[i], l_blocks[i], m_blocks[i]

            S_ij = Qi @ Kj.transpose(-2, -1)
            if verbose:
                print(f"(i,j) is ({i},{j}),S_ij.shape is {S_ij.shape}")

            # Block-local online softmax statistics.
            m_block_ij = torch.max(S_ij, dim=-1, keepdim=True).values  # row max
            P_ij = torch.exp(S_ij - m_block_ij)  # unnormalised probabilities
            # Row sum; EPSILON keeps the later division away from zero.
            l_block_ij = torch.sum(P_ij, dim=-1, keepdim=True) + EPSILON

            # Merge the block statistics into the running row max and row sum.
            mi_new = torch.maximum(m_block_ij, mi)
            old_scale = torch.exp(mi - mi_new)  # correction for the history
            new_scale = torch.exp(m_block_ij - mi_new)  # correction for this block
            li_new = li * old_scale + new_scale * l_block_ij

            # Rescale the previous output and add this block's contribution.
            P_ij_Vj = P_ij @ Vj
            O_blocks[i] = (li / li_new) * old_scale * Oi + (new_scale / li_new) * P_ij_Vj

            l_blocks[i] = li_new
            m_blocks[i] = mi_new

    return (
        torch.cat(O_blocks, dim=-2),
        torch.cat(l_blocks, dim=-2),
        torch.cat(m_blocks, dim=-2),
    )


# Original Q/K/V size: N x d = 6 x 2
N = 6
d = 2

# Block sizes; in the paper Bc = M/4d and Br = min(M/4d, d).
Q_block_size = 2
KV_block_size = 2
# Number of row blocks (Tr) and column blocks (Tc); ceiling division so that
# a trailing partial block is counted.
Tr = -(-N // Q_block_size)
Tc = -(-N // KV_block_size)

# Create the Q, K, V matrices.
Q = torch.randn(1, 1, N, d, requires_grad=True)
K = torch.randn(1, 1, N, d, requires_grad=True)
V = torch.randn(1, 1, N, d, requires_grad=True)

O, l, m = flash_attention_forward(Q, K, V, Q_block_size, KV_block_size, verbose=True)

print(f"O.shape is {O.shape}")

# Reference result from the direct (non-tiled) computation.
Sd = F.softmax(torch.einsum('...id,...jd -> ... ij', Q, K), dim=-1)
Od = torch.einsum('...Nk,...kd -> ... Nd', Sd, V)
print(f"Od.shape is {Od.shape}")

assert torch.allclose(Od, O, atol=1e-5)

l.shape is torch.Size([1, 1, 6, 1]),m.shape is torch.Size([1, 1, 6, 1])
(i,j) is (0,0),S_ij.shape is torch.Size([1, 1, 2, 2])
(i,j) is (1,0),S_ij.shape is torch.Size([1, 1, 2, 2])
(i,j) is (2,0),S_ij.shape is torch.Size([1, 1, 2, 2])
(i,j) is (0,1),S_ij.shape is torch.Size([1, 1, 2, 2])
(i,j) is (1,1),S_ij.shape is torch.Size([1, 1, 2, 2])
(i,j) is (2,1),S_ij.shape is torch.Size([1, 1, 2, 2])
(i,j) is (0,2),S_ij.shape is torch.Size([1, 1, 2, 2])
(i,j) is (1,2),S_ij.shape is torch.Size([1, 1, 2, 2])
(i,j) is (2,2),S_ij.shape is torch.Size([1, 1, 2, 2])
O.shape is torch.Size([1, 1, 6, 2])
Od.shape is torch.Size([1, 1, 6, 2])
