In [1]:
import numpy as np

def softmax(x):
    """Return the softmax of ``x`` taken along its last axis.

    The last-axis maximum is subtracted before exponentiating so that large
    inputs do not overflow; the shift cancels in the normalisation, so the
    result is mathematically unchanged.
    """
    row_max = np.max(x, axis=-1, keepdims=True)
    weights = np.exp(x - row_max)
    denom = weights.sum(axis=-1, keepdims=True)
    return weights / denom

def compute_attention(Q, K, V):
    """Return unscaled softmax attention ``softmax(Q @ K^T) @ V``.

    No ``1/sqrt(d)`` temperature is applied; ``flash_attention_modified`` in
    this file computes the same unscaled scores, so the two are directly
    comparable.

    Args:
        Q: Queries, shape (n_q, d) or (..., n_q, d).
        K: Keys, shape (n_k, d) or (..., n_k, d).
        V: Values, shape (n_k, d_v) or (..., n_k, d_v).

    Returns:
        Attention output of shape (n_q, d_v) or (..., n_q, d_v).
    """
    # ``K.T`` reverses *all* axes, which is a key transpose only for 2-D K.
    # With leading batch dims, swap just the last two axes so np.matmul can
    # broadcast over the batch; 2-D (and lower) inputs keep the original path.
    K_t = np.swapaxes(K, -1, -2) if np.ndim(K) > 2 else K.T
    attention_scores = np.matmul(Q, K_t)
    # Normalise over the key axis (softmax works on the last axis).
    attention_weights = softmax(attention_scores)
    return np.matmul(attention_weights, V)


# Problem size: sequence length N and head dimension d.
N, d = 512, 64

# Random queries, keys and values, drawn in that order from NumPy's global
# RNG (no seed is set, so the values differ between runs).
Q, K, V = (np.random.randn(N, d) for _ in range(3))

# Reference output from the direct, non-tiled implementation.
O_standard = compute_attention(Q, K, V)


def flash_attention_modified(Q, K, V, on_chip_memory_size):
    """Tiled softmax attention following FlashAttention's Algorithm 1.

    Computes ``softmax(Q @ K.T) @ V`` (unscaled, like ``compute_attention``)
    without materialising the full score matrix.

    - K/V are walked in column blocks of ``B_c`` rows.
    - Q is walked in row blocks of ``B_r`` rows.
    - A running row maximum ``m`` and a running softmax denominator ``l``
      let each partial output be rescaled as later key blocks arrive.

    Args:
        Q: Query matrix of shape (n_q, d).
        K: Key matrix of shape (n_k, d).
        V: Value matrix of shape (n_k, d_v).
        on_chip_memory_size: Simulated on-chip memory budget ``M`` in bytes,
            assuming 4 bytes per float. Block sizes are
            ``B_c = M // (4 * d)`` and ``B_r = min(B_c, d)``, each clamped
            to at least 1 so a tiny budget still makes progress.

    Returns:
        Array of shape (n_q, d_v) holding the attention output.
    """
    # Derive sizes from the inputs instead of relying on module globals.
    n_q, d_k = Q.shape
    n_k = K.shape[0]
    d_v = V.shape[1]

    tile = max(1, on_chip_memory_size // (4 * d_k))  # 4 bytes per float
    B_c = tile
    B_r = max(1, min(tile, d_k))

    O = np.zeros((n_q, d_v))
    m = np.full(n_q, -np.inf)  # running row-wise max of the scores
    l = np.zeros(n_q)          # running row-wise softmax denominator

    for c0 in range(0, n_k, B_c):
        # Slicing tolerates a ragged final block when n_k % B_c != 0.
        K_j = K[c0:c0 + B_c]
        V_j = V[c0:c0 + B_c]

        for r0 in range(0, n_q, B_r):
            rows = slice(r0, r0 + B_r)
            m_i, l_i, O_i = m[rows], l[rows], O[rows]

            # Algorithm 1, line 8: block of raw scores.
            S_ij = Q[rows] @ K_j.T

            # Line 9: block-local max, stabilised exponentials, and row sums.
            m_ij = S_ij.max(axis=1)
            P_ij = np.exp(S_ij - m_ij[:, np.newaxis])
            l_ij = P_ij.sum(axis=1)

            # Line 10: merge block statistics into the running ones.
            m_new = np.maximum(m_i, m_ij)
            alpha = np.exp(m_i - m_new)   # rescales previously accumulated rows
            beta = np.exp(m_ij - m_new)   # rescales this block's contribution
            l_new = alpha * l_i + beta * l_ij

            # Line 11: every factor is per *row*, so broadcast it with
            # [:, np.newaxis]. Dividing by l_new replaces inv(diag(l_new)).
            O[rows] = (
                (alpha * l_i)[:, np.newaxis] * O_i
                + beta[:, np.newaxis] * (P_ij @ V_j)
            ) / l_new[:, np.newaxis]

            # Line 12: persist the updated statistics.
            m[rows] = m_new
            l[rows] = l_new

    return O

# A budget of exactly 4 * N * d bytes gives a column block size of N, so all
# key/value rows fit in a single column block of the tiled implementation.
M = 4 * N * d
O_flash = flash_attention_modified(Q, K, V, on_chip_memory_size=M)

# The tiled and reference outputs should agree within floating-point tolerance.
np.allclose(O_standard, O_flash)


True