In [1]:
import torch 
import numpy as np

def softmax(x):
    """Numerically stable softmax over the last dimension.

    The row maximum is subtracted before exponentiating so that ``exp`` cannot
    overflow; softmax is shift-invariant along the reduced axis, so the result is
    unchanged.

    Args:
        x: floating-point tensor of any rank >= 1. The reduction always runs over
            the last dimension (the previous ``[:, None]`` indexing only worked for
            2-D input and silently mis-broadcast batched input).

    Returns:
        Tensor of the same shape as ``x`` whose entries along the last dimension
        sum to 1.
    """
    # keepdim=True keeps the reduced axis so the broadcast is correct for any rank.
    shifted = x - torch.max(x, dim=-1, keepdim=True).values
    e_x = torch.exp(shifted)
    return e_x / torch.sum(e_x, dim=-1, keepdim=True)

def compute_attention(Q, K, V, scale=None):
    """Reference (non-tiled) attention: ``softmax(scale * Q @ K^T) @ V``.

    The full (n_q, n_k) score matrix is materialised; this is the dense ground
    truth that the tiled implementation is compared against.

    Args:
        Q: queries, shape (n_q, d).
        K: keys, shape (n_k, d).
        V: values, shape (n_k, d_v).
        scale: multiplier applied to the raw scores. When omitted, the
            notebook-level ``tau`` is used (the original always read that global,
            so existing calls behave the same). Pass it explicitly to avoid
            depending on cell execution order.

    Returns:
        Attention output of shape (n_q, d_v).
    """
    if scale is None:
        scale = tau
    attention_scores = Q @ K.transpose(-2, -1)
    attention_weights = softmax(attention_scores * scale)
    return attention_weights @ V

def flash_attention(Q, K, V, row_block=None, col_block=None, scale=None):
    """Tiled (FlashAttention-style) forward pass for ``softmax(scale * Q @ K.T) @ V``.

    Query rows are processed in tiles of ``row_block`` rows and key/value rows in
    tiles of ``col_block`` rows. For each query tile it keeps three running
    quantities and updates them one key tile at a time (the "online softmax"):
    the row maximum ``m``, the softmax denominator ``l``, and the unnormalised
    output. The full (n_q, n_k) score matrix is therefore never materialised.

    Args:
        Q: queries, shape (n_q, d).
        K: keys, shape (n_k, d).
        V: values, shape (n_k, d_v).
        row_block: query rows per tile. Defaults to the notebook-level ``B_r``.
        col_block: key/value rows per tile. Defaults to the notebook-level ``B_c``.
        scale: score multiplier. Defaults to the notebook-level ``tau``.

    Returns:
        (O, L): ``O`` has shape (n_q, d_v). ``L`` has shape (n_q,) and holds the
        row-wise logsumexp of the scaled scores, from which the backward pass
        rebuilds the softmax probabilities.

    Raises:
        ValueError: if a block size is smaller than 1.

    Computation runs in at least float32, and in float64 when any input is
    float64. The previous version always stored ``O`` and ``L`` as float32,
    silently dropping precision for float64 inputs.
    """
    # Fall back to the notebook-level configuration, as the original always did.
    # Passing the values explicitly removes the dependency on cell execution order.
    if row_block is None:
        row_block = B_r
    if col_block is None:
        col_block = B_c
    if scale is None:
        scale = tau
    if row_block < 1 or col_block < 1:
        raise ValueError(
            f"block sizes must be positive, got row_block={row_block}, col_block={col_block}"
        )

    acc_dtype = torch.promote_types(
        torch.promote_types(Q.dtype, K.dtype), torch.promote_types(V.dtype, torch.float32)
    )
    Q, K, V = Q.to(acc_dtype), K.to(acc_dtype), V.to(acc_dtype)
    # Sizes come from the tensors themselves, so Q and K may have different lengths
    # and V may have a different width than Q/K.
    n_q, n_k, d_v = Q.shape[0], K.shape[0], V.shape[1]

    O = torch.zeros((n_q, d_v), dtype=acc_dtype, device=Q.device)
    L = torch.zeros(n_q, dtype=acc_dtype, device=Q.device)

    for r0 in range(0, n_q, row_block):
        Q_i = Q[r0:r0 + row_block]
        rows = Q_i.shape[0]  # the last tile may be shorter than row_block
        # Running online-softmax state for this query tile.
        m_i = torch.full((rows,), float("-inf"), dtype=acc_dtype, device=Q.device)  # row max so far
        l_i = torch.zeros(rows, dtype=acc_dtype, device=Q.device)  # sum of exp(S - m_i) so far
        O_i = torch.zeros((rows, d_v), dtype=acc_dtype, device=Q.device)  # unnormalised output

        for c0 in range(0, n_k, col_block):
            K_j = K[c0:c0 + col_block]
            V_j = V[c0:c0 + col_block]

            S_ij = scale * (Q_i @ K_j.T)
            m_new = torch.maximum(m_i, S_ij.max(dim=1).values)
            P_ij = torch.exp(S_ij - m_new[:, None])
            # Rescale everything accumulated under the old maximum to the new one.
            # On the first key tile m_i is -inf, so alpha is 0 and the zero state drops out.
            alpha = torch.exp(m_i - m_new)
            l_i = alpha * l_i + P_ij.sum(dim=1)
            O_i = alpha[:, None] * O_i + P_ij @ V_j
            m_i = m_new

        O[r0:r0 + rows] = O_i / l_i[:, None]
        L[r0:r0 + rows] = m_i + torch.log(l_i)

    return O, L

def _row_logsumexp(Q, K, scale, col_block):
    """Row-wise logsumexp of ``scale * Q @ K.T``, accumulated over key tiles of ``col_block`` rows."""
    m = torch.full((Q.shape[0],), float("-inf"), dtype=Q.dtype, device=Q.device)
    s = torch.zeros(Q.shape[0], dtype=Q.dtype, device=Q.device)
    for c0 in range(0, K.shape[0], col_block):
        S = scale * (Q @ K[c0:c0 + col_block].T)
        m_new = torch.maximum(m, S.max(dim=1).values)
        s = torch.exp(m - m_new) * s + torch.exp(S - m_new[:, None]).sum(dim=1)
        m = m_new
    return m + torch.log(s)


def flash_attention_backward(Q, K, V, O_flash, dO, logsumexp=None,
                             row_block=None, col_block=None, scale=None):
    """Tiled backward pass of ``O = softmax(scale * Q @ K.T) @ V``.

    Computes (dQ, dK, dV) from the upstream gradient ``dO`` without materialising
    the full attention matrix. The softmax probabilities of each
    (query-tile, key-tile) pair are rebuilt as ``P_ij = exp(S_ij - logsumexp_i)``.

    Args:
        Q, K, V: forward inputs, shapes (n_q, d), (n_k, d), (n_k, d_v).
        O_flash: forward output, shape (n_q, d_v).
        dO: gradient of the loss w.r.t. the output, shape (n_q, d_v).
        logsumexp: row-wise logsumexp of the scaled scores (the ``L`` returned by
            ``flash_attention``), shape (n_q,). If omitted, it is recomputed from Q
            and K. The previous version read a global ``L`` instead, which silently
            produced wrong gradients whenever that global came from another forward
            call.
        row_block: query rows per tile. Defaults to the notebook-level ``B_r``.
        col_block: key/value rows per tile. Defaults to the notebook-level ``B_c``.
        scale: score multiplier used in the forward pass. Defaults to the
            notebook-level ``tau``.

    Returns:
        (dQ, dK, dV): float64 tensors shaped like Q, K, V.

    Raises:
        ValueError: if a block size is smaller than 1.
    """
    if row_block is None:
        row_block = B_r
    if col_block is None:
        col_block = B_c
    if scale is None:
        scale = tau
    if row_block < 1 or col_block < 1:
        raise ValueError(
            f"block sizes must be positive, got row_block={row_block}, col_block={col_block}"
        )

    # Gradients are accumulated in float64, as before.
    Q, K, V, O_flash, dO = (t.double() for t in (Q, K, V, O_flash, dO))
    if logsumexp is None:
        logsumexp = _row_logsumexp(Q, K, scale, col_block)
    else:
        logsumexp = logsumexp.double()

    n_q, n_k = Q.shape[0], K.shape[0]
    dQ = torch.zeros_like(Q)
    dK = torch.zeros_like(K)
    dV = torch.zeros_like(V)

    # D_i = rowsum(dO_i * O_i), which equals rowsum(P_i * dP_i): the correction term
    # of the softmax Jacobian, computed once instead of per tile.
    D = torch.sum(dO * O_flash, dim=1)

    for c0 in range(0, n_k, col_block):
        cols = slice(c0, c0 + col_block)
        K_j, V_j = K[cols], V[cols]
        dK_j = torch.zeros_like(K_j)
        dV_j = torch.zeros_like(V_j)

        for r0 in range(0, n_q, row_block):
            rows = slice(r0, r0 + row_block)
            Q_i, dO_i = Q[rows], dO[rows]

            S_ij = scale * (Q_i @ K_j.T)
            P_ij = torch.exp(S_ij - logsumexp[rows][:, None])

            dV_j += P_ij.T @ dO_i
            dP_ij = dO_i @ V_j.T
            dS_ij = P_ij * (dP_ij - D[rows][:, None])

            dQ[rows] += scale * (dS_ij @ K_j)
            dK_j += scale * (dS_ij.T @ Q_i)

        dK[cols] = dK_j
        dV[cols] = dV_j

    return dQ, dK, dV

In [2]:
# Problem size; tau = 1/sqrt(d) is the scaled-dot-product attention factor.
N = 334
d = 233
# On-chip memory budget M, counted in elements. d * 25 is a small demo budget
# (presumably arbitrary) that yields block sizes of 6 here, so many tiles are exercised.
M = d * 25
tau = 1.0 / np.sqrt(d)

np.random.seed(2)
Q_np = np.random.randn(N, d)
K_np = np.random.randn(N, d)
V_np = np.random.randn(N, d)

# Reference: dense attention with autograd, in float64.
Q_standard = torch.tensor(Q_np.copy(), requires_grad=True, dtype=torch.float64)
K_standard = torch.tensor(K_np.copy(), requires_grad=True, dtype=torch.float64)
V_standard = torch.tensor(V_np.copy(), requires_grad=True, dtype=torch.float64)

O_standard = compute_attention(Q_standard, K_standard, V_standard)
loss = O_standard.sum()
loss.backward()

# Inputs for the tiled implementation: plain float64 tensors, no autograd.
Q = torch.tensor(Q_np, dtype=torch.float64)
K = torch.tensor(K_np, dtype=torch.float64)
V = torch.tensor(V_np, dtype=torch.float64)

# Block sizes follow the FlashAttention rule B_c = M/(4d), B_r = min(M/(4d), d).
# The factor 4 is part of that rule (room for the Q_i, O_i, K_j, V_j tiles held on
# chip together); it is NOT a bytes-per-float factor, and these tensors are float64
# anyway. Floor division keeps the tiles within the budget (the paper uses ceil).
on_chip_memory_size = M
B_c = on_chip_memory_size // (4 * d)
B_r = min(on_chip_memory_size // (4 * d), d)
assert B_c >= 1 and B_r >= 1, "on-chip budget M is too small for this d"

O_flash, L = flash_attention(Q, K, V)
# loss = O.sum() upstream, so dL/dO is all ones, matching the reference above.
dO = torch.ones_like(O_flash)
dQ, dK, dV = flash_attention_backward(Q, K, V, O_flash, dO)

comparisons = [
    ("Forward O", O_standard, O_flash),
    ("Backward V", V_standard.grad, dV),
    ("Backward K", K_standard.grad, dK),
    ("Backward Q", Q_standard.grad, dQ),
]
for label, expected, actual in comparisons:
    print(label, np.allclose(expected.detach().cpu().numpy(), actual.detach().cpu().numpy(), rtol=1.e-4, atol=1.e-4))


Forward O True
Backward V True
Backward K True
Backward Q True
