In [6]:
import numpy as np

def softmax(x):
    """Return the softmax of *x* taken along its last axis.

    The maximum along the last axis is subtracted before exponentiating,
    so the largest exponent is exactly 0 and ``np.exp`` cannot overflow.
    The shift cancels in the ratio, leaving the result unchanged.
    """
    peak = np.max(x, axis=-1, keepdims=True)
    exponentials = np.exp(x - peak)
    normaliser = np.sum(exponentials, axis=-1, keepdims=True)
    return exponentials / normaliser

def compute_attention(Q, K, V):
    """Reference attention: ``softmax(Q @ K.T) @ V`` computed in one shot.

    The full score matrix between every row of Q and every row of K is
    built, normalised row-wise with ``softmax``, and used to mix the rows
    of V. No scaling factor is applied to the scores.
    """
    scores = Q @ K.T
    weights = softmax(scores)
    return weights @ V


# Q, K and V are each an N-by-d matrix of standard-normal samples.
N, d = 512, 64

# Drawn in the order Q, K, V.
Q, K, V = (np.random.randn(N, d) for _ in range(3))

# Baseline result that the tiled implementation is compared against.
O_standard = compute_attention(Q, K, V)


def flash_attention(Q, K, V, on_chip_memory_size):
    """Compute ``softmax(Q @ K.T) @ V`` tile by tile with an online softmax.

    K/V are walked in row blocks of size ``B_c`` and Q in row blocks of
    size ``B_r``. For every query row three running quantities are kept:
    the largest score seen so far, the sum of ``exp(score - max)`` seen so
    far, and the un-normalised output. Whenever a new tile raises the
    running maximum, the previously accumulated sum and output are
    rescaled by ``exp(old_max - new_max)`` so that all contributions share
    the same reference point. The output is divided by the running sum
    once, after every key tile has been processed.

    Args:
        Q: 2-D array of shape (n_q, d_k).
        K: 2-D array of shape (n_k, d_k).
        V: 2-D array of shape (n_k, d_v).
        on_chip_memory_size: Memory budget used only to derive the tile
            sizes. It is interpreted as bytes with 4 bytes per element,
            regardless of the arrays' actual dtype.

    Returns:
        Array of shape (n_q, d_v) equal (up to floating-point rounding) to
        the un-tiled attention result, for any tile sizes.

    Note:
        Scores are assumed finite. A query row whose scores are all
        ``-inf`` would produce NaN in the rescaling factor.
    """
    # Take the problem size from the arguments rather than module globals,
    # so the function works for any input shapes.
    n_q, d_k = Q.shape
    n_k = K.shape[0]
    d_v = V.shape[1]

    # Tile-size heuristic kept from the original, clamped to at least one
    # row so a small budget cannot yield empty tiles or a division by zero.
    B_c = max(1, on_chip_memory_size // max(1, 4 * d_k))
    B_r = max(1, min(B_c, d_k))

    O = np.zeros((n_q, d_v))
    row_max = np.full(n_q, -np.inf)  # running max score per query row
    row_sum = np.zeros(n_q)          # running softmax denominator per row

    for col_start in range(0, n_k, B_c):
        K_j = K[col_start:col_start + B_c]
        V_j = V[col_start:col_start + B_c]
        for row_start in range(0, n_q, B_r):
            rows = slice(row_start, row_start + B_r)

            S_ij = np.dot(Q[rows], K_j.T)
            old_max = row_max[rows]
            new_max = np.maximum(old_max, S_ij.max(axis=1))

            # Un-normalised probabilities relative to the updated maximum.
            P_ij = np.exp(S_ij - new_max[:, np.newaxis])

            # Factor that re-expresses earlier accumulations relative to
            # new_max. On the first tile old_max is -inf, so this is 0 and
            # the (all-zero) history is simply discarded.
            rescale = np.exp(old_max - new_max)

            row_sum[rows] = rescale * row_sum[rows] + P_ij.sum(axis=1)
            O[rows] = rescale[:, np.newaxis] * O[rows] + np.dot(P_ij, V_j)
            row_max[rows] = new_max

    # Single normalisation at the end, using the complete denominator.
    return O / row_sum[:, np.newaxis]

# Memory budget handed to flash_attention: ten times N * d * 4.
# NOTE(review): flash_attention derives its key-block size as M // (4 * d),
# which is 10 * N here, so every key row falls into a single block and the
# comparison below never exercises more than one key tile.
M = 10 * 4 * N * d
O_flash = flash_attention(Q, K, V, on_chip_memory_size=M)

# Element-wise agreement (within np.allclose's default tolerances) between
# the reference result and the tiled result.
np.allclose(O_standard, O_flash)


True