## Flash attention in Python

In [2]:
import torch

In [3]:
def normal_attention(Q, K, V, mask=None):
    '''
    Reference (non-tiled) scaled dot-product attention for a single head.

    Computes softmax((Q * d ** -0.5) @ K.T) @ V with d = Q.size(-1).

    shape of K, Q, V : (seq len, d)
    mask: (seq len); broadcast against the last (key) dimension of the score
        matrix. Positions where mask == 0 get a score of -inf before the
        softmax, i.e. zero attention weight.

    Returns a tensor of shape (seq len, d). The inputs are not modified.
    '''
    # scale = 1 / sqrt(d)
    scale = Q.size(-1) ** -0.5
    # Out-of-place multiply: an in-place `Q *= scale` would silently rescale
    # the caller's tensor (and K / V too when the same tensor is passed for all).
    scores = (Q * scale) @ K.T
    if mask is not None:
        # `scores` is a fresh tensor, so filling it in place is safe.
        scores.masked_fill_(mask == 0, float('-inf'))
    weights = scores.softmax(dim=-1)
    return weights @ V

In [4]:
bs = 512
def flash_attention_v1(Q, K, V, mask=None, block_size=None):
    '''
    forward attention, computed block by block with an online softmax

    Produces softmax((Q * d ** -0.5) @ K.T) @ V with d = Q.size(-1) without
    ever building the full (seq len, seq len) score matrix: for every block
    of query rows it streams over blocks of keys/values while maintaining a
    running row-wise maximum, a running softmax denominator and an
    unnormalised output accumulator. The division by the denominator is
    done once per query block, after all key blocks have been seen.

    shape of K, Q, V : (seq len, d)
    mask: (seq len); applied along the key dimension. Keys where mask == 0
        get a score of -inf (zero attention weight). If every key is
        masked the result is NaN (0 / 0).
    block_size: tile size for both the query and the key/value loops;
        defaults to the module-level `bs`. Values larger than the number of
        query rows are clamped to it. The sequence length does not have to
        be a multiple of the block size.

    Returns a new tensor of shape (Q.size(0), V.size(-1)); inputs are not
    modified. Raises ValueError if the block size is smaller than 1.
    '''
    n_q, n_kv = Q.size(0), K.size(0)
    out = Q.new_zeros((n_q, V.size(-1)))
    if n_q == 0 or n_kv == 0:
        # nothing to attend from / to
        return out

    if block_size is None:
        block_size = bs
    if block_size < 1:
        raise ValueError('block_size must be a positive integer')
    block_size = min(block_size, n_q)

    scale = Q.size(-1) ** -0.5
    neg_inf = float('-inf')

    for q_start in range(0, n_q, block_size):
        # Q (the last block may be shorter than block_size)
        q_end = min(q_start + block_size, n_q)
        rows = q_end - q_start
        Q_block = Q[q_start:q_end] * scale

        # per-query-row running statistics, allocated like Q (dtype / device)
        m_prev = Q.new_full((rows,), neg_inf)         # running row max
        d_prev = Q.new_zeros(rows)                    # running denominator
        acc = Q.new_zeros((rows, V.size(-1)))         # unnormalised output

        for k_start in range(0, n_kv, block_size):
            # K, V
            k_end = min(k_start + block_size, n_kv)
            K_block = K[k_start:k_end]
            V_block = V[k_start:k_end]
            qk = Q_block @ K_block.T  # qk: (q_block_size, kv_block_size)
            if mask is not None:
                qk.masked_fill_(mask[k_start:k_end] == 0, neg_inf)

            # new running maximum per query ROW (dim=1 reduces over keys)
            m_curr = torch.max(m_prev, qk.max(dim=1).values)
            # Rows that have only seen masked keys still have m_curr == -inf;
            # subtracting that would give (-inf) - (-inf) = NaN, so use 0 as
            # the reference point there (every exponent is exp(-inf) = 0).
            m_safe = torch.where(m_curr == neg_inf, torch.zeros_like(m_curr), m_curr)

            # rescale what was accumulated under the old maximum
            alpha = torch.exp(m_prev - m_safe)
            p = torch.exp(qk - m_safe[:, None])
            d_prev = d_prev * alpha + p.sum(dim=1)
            acc = acc * alpha[:, None] + p @ V_block
            m_prev = m_curr

        # normalise once per query block
        out[q_start:q_end] = acc / d_prev[:, None]

    return out

In [6]:
# record time cost
# Accumulates the wall time of 10 calls to each implementation on the same
# random (4096, 1024) input, then prints both totals.
import time

torch.manual_seed(0)  # make the random inputs reproducible across runs
hf_time = 0.0
flash_time = 0.0
for _ in range(10):
    V = torch.randn(4096, 1024)
    # perf_counter is a monotonic, high-resolution clock meant for measuring
    # intervals; time.time() can jump if the system clock is adjusted.
    t1 = time.perf_counter()
    A = normal_attention(V, V, V)
    hf_time += time.perf_counter() - t1
    t2 = time.perf_counter()
    B = flash_attention_v1(V, V, V)
    flash_time += time.perf_counter() - t2
print('hf time cost: ', hf_time)
print('flash time cost: ', flash_time)


hf time cost:  6.817880868911743
flash time cost:  6.2721312046051025
