# In [18]:
import numpy as np

def softmax(x):
    """Return the softmax of ``x`` taken along its last axis.

    The per-row maximum is subtracted before exponentiating so that large
    inputs cannot overflow; the shift cancels out in the final ratio.
    """
    row_max = np.max(x, axis=-1, keepdims=True)
    exps = np.exp(x - row_max)
    denom = np.sum(exps, axis=-1, keepdims=True)
    return exps / denom

def compute_attention(Q, K, V):
    """Reference (non-tiled) attention: ``softmax(Q @ K^T) @ V``.

    Scores are used as-is (no 1/sqrt(d) scaling).

    Parameters
    ----------
    Q : ndarray, shape (..., N_q, d)
        Queries.
    K : ndarray, shape (..., N_k, d)
        Keys.
    V : ndarray, shape (..., N_k, d_v)
        Values. Leading batch dimensions, if any, broadcast as in
        ``np.matmul``.

    Returns
    -------
    O : ndarray, shape (..., N_q, d_v)
        Attention output.
    L : ndarray, shape (..., N_q)
        Row-wise log-sum-exp of the scores, ``log(sum(exp(Q @ K^T), axis=-1))``.
    """
    # swapaxes on the last two axes equals K.T for 2-D input and also
    # works for batched (..., N_k, d) keys.
    attention_scores = np.matmul(Q, np.swapaxes(K, -1, -2))

    # Shift by the row max before exponentiating so large scores cannot
    # overflow; the shift cancels in the softmax and is added back into L.
    max_scores = np.max(attention_scores, axis=-1, keepdims=True)
    exp_scores = np.exp(attention_scores - max_scores)
    row_sums = np.sum(exp_scores, axis=-1, keepdims=True)

    L = max_scores[..., 0] + np.log(row_sums[..., 0])

    # Normalise the same shifted exponentials to get the softmax weights,
    # so the (N_q, N_k) exponential is computed only once.
    attention_weights = exp_scores / row_sums
    O = np.matmul(attention_weights, V)
    return O, L


N = 512  # number of rows in Q, K and V
d = 8  # feature dimension of each row

# Fixed seed so the standard-vs-tiled comparison below is reproducible
# from run to run.
rng = np.random.default_rng(0)
Q = rng.standard_normal((N, d))
K = rng.standard_normal((N, d))
V = rng.standard_normal((N, d))

# Reference result from the plain (non-tiled) implementation.
O_standard, L_standard = compute_attention(Q, K, V)


def flash_attention_modified(Q, K, V, on_chip_memory_size):
    """Compute ``softmax(Q @ K.T) @ V`` and its row-wise log-sum-exp, tile by tile.

    Q is processed in row tiles of ``B_r`` rows and K/V in column tiles of
    ``B_c`` rows. For each Q tile this keeps three running quantities,
    updated one K/V tile at a time:

    * ``m`` -- the row maximum of the scores;
    * ``l`` -- the softmax denominator;
    * an unnormalised output accumulator.

    The full score matrix is therefore never materialised.

    Parameters
    ----------
    Q : ndarray, shape (N_q, d)
        Query matrix.
    K : ndarray, shape (N_k, d)
        Key matrix.
    V : ndarray, shape (N_k, d_v)
        Value matrix.
    on_chip_memory_size : int
        Simulated on-chip memory budget in bytes, assuming 4 bytes per float.
        It sets the tile sizes ``B_c = on_chip_memory_size // (4 * d)`` and
        ``B_r = min(B_c, d)``.

    Returns
    -------
    O : ndarray, shape (N_q, d_v)
        Attention output, ``softmax(Q @ K.T) @ V``.
    L : ndarray, shape (N_q,)
        Row-wise log-sum-exp of the scores ``Q @ K.T``.

    Raises
    ------
    ValueError
        If the budget cannot hold a single row of ``d`` floats, which would
        make ``B_c`` zero.
    """
    # Derive sizes from the inputs rather than module-level globals.
    n_q, d = Q.shape
    n_k = K.shape[0]
    d_v = V.shape[1]

    row_bytes = 4 * d  # Using 4 bytes per float
    if on_chip_memory_size < row_bytes:
        raise ValueError(
            f"on_chip_memory_size={on_chip_memory_size} bytes cannot hold one "
            f"row of {d} floats ({row_bytes} bytes)"
        )
    B_c = on_chip_memory_size // row_bytes
    B_r = min(B_c, d)

    O = np.zeros((n_q, d_v))
    L = np.zeros(n_q)

    for r0 in range(0, n_q, B_r):
        Q_i = Q[r0:r0 + B_r]
        rows = Q_i.shape[0]  # the last row tile may be shorter than B_r
        m_i = np.full(rows, -np.inf)
        l_i = np.zeros(rows)
        O_i = np.zeros((rows, d_v))

        for c0 in range(0, n_k, B_c):
            # The last column tile may be shorter than B_c; slicing handles it.
            K_j = K[c0:c0 + B_c]
            V_j = V[c0:c0 + B_c]

            S_ij = Q_i @ K_j.T

            # Line 9: new running max, shifted exponentials, new denominator.
            m_new = np.maximum(m_i, S_ij.max(axis=1))
            P_ij = np.exp(S_ij - m_new[:, np.newaxis])
            # Rescales sums accumulated against the old max onto the new max.
            # On the first tile m_i is -inf, so alpha is exp(-inf) == 0.
            alpha = np.exp(m_i - m_new)
            l_i = alpha * l_i + P_ij.sum(axis=1)

            # Line 10: O_i <- diag(exp(m_old - m_new)) @ O_i + P_ij @ V_j.
            # Broadcasting alpha avoids building a (rows, rows) diagonal matrix.
            O_i = alpha[:, np.newaxis] * O_i + P_ij @ V_j
            m_i = m_new

        # Normalise by the final denominator: diag(l_i)^-1 @ O_i.
        O[r0:r0 + rows] = O_i / l_i[:, np.newaxis]
        L[r0:r0 + rows] = m_i + np.log(l_i)

    return O, L

# A budget of N*d*4 bytes holds every K/V row at once, so only a single
# column tile is processed. A smaller budget splits K/V into several tiles
# and exercises the rescaling of the running output between tiles.
M = N * d * 4
O_flash, L_flash = flash_attention_modified(Q, K, V, M)

M_tiled = 64 * d * 4
O_tiled, L_tiled = flash_attention_modified(Q, K, V, M_tiled)

print(np.allclose(O_standard, O_flash) and np.allclose(O_standard, O_tiled))
print(np.allclose(L_standard, L_flash) and np.allclose(L_standard, L_tiled))


# Cell output:
# True
# True
