In [41]:
import numpy as np
# Matrices Q,K,V as (N * d) in HBM, SRAM Size of M
N = 3
d = 2


Q = np.random.rand(N,d)
K = np.random.rand(N,d)
V = np.random.rand(N,d)
M = 8

In [42]:
# 1. Set block size Bc = M/4d, Br = min(M/4d, d)
B_c = M//(4*d)
B_r = min(B_c, d)

In [43]:
# 2. Initialize
O = np.zeros((N,d))
l = np.zeros(N)
m = np.full(N, -np.inf)

In [44]:
# 3. Divide Q into Tr which is N/B_r size of Br * d, K,V into Tr N/B_c Bc * d
Q_blocks = np.array_split(Q, N // B_r) # np.array_split returns List of array
K_blocks = np.array_split(K, N // B_c)
V_blocks = np.array_split(V, N // B_c)
T_r = len(Q_blocks)
T_c = len(K_blocks)

In [45]:
# Proof of Step 3
Q_blocks = np.array(np.array_split(Q, N // B_r))
print(f"Q_Block's T_r is {Q_blocks.shape[0]}, so Q_blacks shape must be (T_r, B_r, d) which is {Q_blocks.shape}")
print(f"Q_blacks shape : {Q_blocks.shape}, Br : {B_r}, N : {d}")

Q_Block's T_r is 3, so Q_blacks shape must be (T_r, B_r, d) which is (3, 1, 2)
Q_blacks shape : (3, 1, 2), Br : 1, N : 2


In [46]:
# 4. Divide O,l,m each tiles
O_tiles = np.array_split(O, T_r)
l_tiles = np.array_split(l, T_r)
m_tiles = np.array_split(m, T_r)

In [47]:
# 5. Loop in Tc
for j in range(T_c):
    # 6. Load K_j, V_j from HBM to SRAM
    K_j, V_j = K_blocks[j], V_blocks[j]
    
    # 7. Loop in Tr
    for i in range(T_r):
        # 8. Load Q 𝑖, O 𝑖, ℓ 𝑖, 𝑚 𝑖from HBM to on-chip SRAM.
        # Q_i, O_i, l_i, m_i = Q_blocks[i], O_tiles[i], l_tiles[i], m_tiles[i]
        Q_i, O_i, l_i, m_i = Q_blocks[i], O_tiles[i:i+B_r],l_tiles[i:i+B_r],m_tiles[i:i+B_r]
        
        # 9. Compute Sij = Q_iK^T_j which return Br X Bc
        S_ij = np.dot(Q_i, K_j.T) # Orginally on SRAM
        # Proof of shape
        # print(f"S_ij.shape:{S_ij.shape} vs B_r X B_c : {B_r,B_c}")

        # 10. Compute i) mhat_ij = rowmax(S_ij) return Br, 
        #            ii) P_ij = e^(S_ij - mhat_ij) return B_r X B_c, 
        #           iii) lhat_ij = rowsum(P_ij) return B_r
        mhat_ij = np.max(S_ij, axis=1)
        P_ij = np.exp(S_ij - mhat_ij)
        lhat_ij = np.sum(P_ij, axis=1)
        
        # 11. Compute i) mnew_i = max(m_i, mhat_ij) return B_r, 
        #            ii) lnew_i = exp(m_i - mnew_i)l_i + exp(mhat_ij-mnew_i)lhat_ij return B_r
        mnew_i = np.maximum(m_i, mhat_ij)
        lnew_i = np.exp(m_i - mnew_i)*l_i + np.exp(mhat_ij-mnew_i)*lhat_ij
        
        # 12. O_i = diag(lnew_i)^-1 * (diag(l_i)exp(m_i - mnew_i)*O_i + exp(mhat_ij-mnew_i)lhat_ij) return B_r
        O_i = (np.diag(l_i)*np.exp(m_i - mnew_i) * O_i + np.exp(mhat_ij-mnew_i)*np.dot(P_ij,V_j)) / np.diag(lnew_i)
        O_tiles[j][i:i+B_r] = O_i[0][0]
        # 13. override
        l[i:i+B_r], m[i:i+B_r] = lnew_i[0], mnew_i[0]

In [48]:
diag_l_new_i = np.diag(lnew_i)

In [49]:
O

array([[0.38913363, 0.8367397 ],
       [0.50596313, 0.5517325 ],
       [0.33130509, 0.89223909]])

In [50]:
import torch
torch.nn.functional.scaled_dot_product_attention(torch.Tensor(Q),torch.Tensor(K),torch.Tensor(V))

tensor([[0.4061, 0.7114],
        [0.4001, 0.7272],
        [0.4054, 0.7114]])