# 手撕 Flash Attention Pytorch实现

![](./flash.png)

## 1. 设定分块矩阵

In [3]:
import torch
from einops import rearrange

# Configuration and state for the blocked (Flash Attention 1) demo.
NEG_INF = -1e10  # large negative stand-in for -infinity (initial running row max)
EPSILON = 1e-10  # small constant the attention loop adds to softmax row sums

Q_LEN = 6          # number of query rows
K_LEN = 6          # number of key/value rows
Q_BLOCK_SIZE = 3   # rows per Q block
KV_BLOCK_SIZE = 3  # rows per K/V block

# Inputs have shape (1, 1, seq_len, 4); blocks are cut along dim=2.
Q = torch.randn(1, 1, Q_LEN, 4, requires_grad=True).to(device='cpu')
K = torch.randn(1, 1, K_LEN, 4, requires_grad=True).to(device='cpu')
V = torch.randn(1, 1, K_LEN, 4, requires_grad=True).to(device='cpu')

# Per-row running state: output O, softmax row sum l, row max m.
O = torch.zeros_like(Q, requires_grad=True)
l = torch.zeros(Q.shape[:-1])[..., None]
m = torch.ones(Q.shape[:-1])[..., None] * NEG_INF

Q_BLOCKS = torch.split(Q, Q_BLOCK_SIZE, dim=2)
K_BLOCKS = torch.split(K, KV_BLOCK_SIZE, dim=2)
V_BLOCKS = torch.split(V, KV_BLOCK_SIZE, dim=2)
# The state blocks are lists because the attention loop overwrites entries.
O_BLOCKS = list(torch.split(O, Q_BLOCK_SIZE, dim=2))
l_BLOCKS = list(torch.split(l, Q_BLOCK_SIZE, dim=2))
m_BLOCKS = list(torch.split(m, Q_BLOCK_SIZE, dim=2))

# Block counts come from the actual splits: torch.split keeps a shorter
# trailing block when the length is not a multiple of the block size, and
# floor division (LEN // BLOCK_SIZE) would leave that block unprocessed.
Tr = len(Q_BLOCKS)  # number of Q blocks
Tc = len(K_BLOCKS)  # number of K/V blocks

## 2. 计算Flash Attention

In [5]:
# Outer loop walks the K/V blocks; the inner loop visits every Q block.
for j in range(Tc):
    Kj, Vj = K_BLOCKS[j], V_BLOCKS[j]
    for i in range(Tr):
        # State of Q block i accumulated from the K/V blocks seen so far.
        Qi, O_prev = Q_BLOCKS[i], O_BLOCKS[i]
        l_prev, m_prev = l_BLOCKS[i], m_BLOCKS[i]

        # Scores of this (Q block, K block) pair and its block-local softmax
        # pieces: row max, exponentials, row sum, and weighted values.
        scores = torch.einsum('... i d, ... j d -> ... i j', Qi, Kj)
        row_max = scores.max(dim=-1, keepdim=True).values
        probs = torch.exp(scores - row_max)
        row_sum = probs.sum(dim=-1, keepdim=True) + EPSILON
        weighted_v = torch.einsum('... i j, ... j d -> ... i d', probs, Vj)

        # Bring the old and new statistics under a common row max.
        m_next = torch.maximum(row_max, m_prev)
        old_scale = torch.exp(m_prev - m_next)
        new_scale = torch.exp(row_max - m_next)
        l_next = old_scale * l_prev + new_scale * row_sum

        # The stored output stays normalised after every step: the previous
        # output is re-weighted by l_prev / l_next, the new part by 1 / l_next.
        O_BLOCKS[i] = (l_prev / l_next) * old_scale * O_prev \
                    + (new_scale / l_next) * weighted_v
        l_BLOCKS[i] = l_next
        m_BLOCKS[i] = m_next

        print(f'-----------Attn : Q{i}xK{j}---------')
        print(O_BLOCKS[0])
        print(O_BLOCKS[1])
        print('\n')

# Stitch the per-block results back into full tensors along dim=2.
O = torch.cat(O_BLOCKS, dim=2)
l = torch.cat(l_BLOCKS, dim=2)
m = torch.cat(m_BLOCKS, dim=2)

# Flash Attention2

Flash Attention 2 相较 Flash Attention 1 最重要的特点是交换了内外循环：外层遍历 Q 块、内层遍历 KV 块，从而减少 O 的读写交换。

![](./flash2.png)

In [6]:
# author: 小冬瓜AIGC
import torch
from einops import rearrange

# Configuration and state for the Flash Attention 2 demo.
Q_BLOCK_SIZE = 3   # rows per Q block
KV_BLOCK_SIZE = 3  # rows per K/V block
NEG_INF = -1e10    # large negative stand-in for -infinity (initial running row max)
EPSILON = 1e-10    # small constant the attention loop adds to softmax row sums
Q_LEN = 6          # number of query rows
K_LEN = 6          # number of key/value rows

# Sequence lengths use Q_LEN / K_LEN rather than a literal 6, so the tensors
# always agree with the constants above.
Q = torch.randn(1, 1, Q_LEN, 4, requires_grad=True).to(device='cpu')
K = torch.randn(1, 1, K_LEN, 4, requires_grad=True).to(device='cpu')
V = torch.randn(1, 1, K_LEN, 4, requires_grad=True).to(device='cpu')

# Per-row running state: output O, softmax row sum l, row max m.
O = torch.zeros_like(Q, requires_grad=True)
l = torch.zeros(Q.shape[:-1])[..., None]
m = torch.ones(Q.shape[:-1])[..., None] * NEG_INF

Q_BLOCKS = torch.split(Q, Q_BLOCK_SIZE, dim=2)
K_BLOCKS = torch.split(K, KV_BLOCK_SIZE, dim=2)
V_BLOCKS = torch.split(V, KV_BLOCK_SIZE, dim=2)
# The state blocks are lists because the attention loop overwrites entries.
O_BLOCKS = list(torch.split(O, Q_BLOCK_SIZE, dim=2))
l_BLOCKS = list(torch.split(l, Q_BLOCK_SIZE, dim=2))
m_BLOCKS = list(torch.split(m, Q_BLOCK_SIZE, dim=2))

# Count the blocks that torch.split actually produced; a shorter trailing
# block (length not divisible by the block size) would be missed by LEN // SIZE.
Tr = len(Q_BLOCKS)  # Q loop length
Tc = len(K_BLOCKS)  # K/V loop length

Flash Attention 在每一步都要对输出做归一化（scaled）；Flash Attention 2 只在内层循环结束后做一次归一化。

$$
\begin{align} O^{(2)}&=diag(l^{(1)}/l^{(2)})\,diag(e^{m^{(1)}-m^{(2)}})\,O^{(1)}+diag(l^{(2)})^{-1}e^{S^{(2)}-m^{(2)}}V^{(2)} \end{align}
$$

![](./flash_scaled.png)

$$
\begin{align} \widetilde{O}^{(2)} &= diag(e^{m^{(1)}-m^{(2)}})\,\widetilde{O}^{(1)}+e^{S^{(2)}-m^{(2)}}V^{(2)} \\ O^{(2)} &= diag(l^{(2)})^{-1}\widetilde{O}^{(2)} \\  O^{(N)} &= diag(l^{(N)})^{-1}\widetilde{O}^{(N)} \end{align}
$$

In [2]:
# Flash Attention 2 ordering: outer loop over Q blocks, inner loop over K/V blocks.
for i in range(Tr):
    Qi = Q_BLOCKS[i]
    # Running state for this Q block. Oi is the UN-normalised accumulator,
    # li the running softmax row sum, mi the running row max.
    Oi = O_BLOCKS[i]
    li = l_BLOCKS[i]
    mi = m_BLOCKS[i]

    # No causal mask is applied: every K/V block contributes to every Q block.
    for j in range(Tc):
        Kj = K_BLOCKS[j]
        Vj = V_BLOCKS[j]

        S_ij = Qi @ Kj.transpose(2, 3)
        m_block_ij, _ = torch.max(S_ij, dim=-1, keepdim=True)
        mi_new = torch.maximum(m_block_ij, mi)
        # Exponentials are taken against the updated running max directly.
        P_ij_hat = torch.exp(S_ij - mi_new)
        l_block_ij = torch.sum(P_ij_hat, dim=-1, keepdim=True) + EPSILON

        # Re-express what was accumulated so far under the new running max.
        rescale = torch.exp(mi - mi_new)

        # Carry the state into the next K/V block. Without these updates each
        # iteration would merge against the initial zero state, and the final
        # result would reflect only the last K/V block.
        li = rescale * li + l_block_ij
        Oi = rescale * Oi + P_ij_hat @ Vj
        mi = mi_new

        print(f'-----------O{i} = attn( Q{i}, KV[{j}])---------')
        print(Oi)

    # Normalise once, after all K/V blocks have been accumulated.
    O_BLOCKS[i] = Oi / li
    l_BLOCKS[i] = li
    m_BLOCKS[i] = mi

# Stitch the per-block results back into full tensors along dim=2.
O = torch.cat(O_BLOCKS, dim=2)
l = torch.cat(l_BLOCKS, dim=2)
m = torch.cat(m_BLOCKS, dim=2)