In [None]:
import numpy as np

# Matrices Q, K, V of shape (N, d) live in HBM; M is the on-chip SRAM size.
N = 6  # sequence length (rows of Q, K, V)
d = 3  # head dimension (columns of Q, K, V)
M = 24  # SRAM size used to derive the block sizes

# Seed the generator so every Restart & Run All produces the same Q, K, V.
rng = np.random.default_rng(0)
Q = rng.random((N, d))
K = rng.random((N, d))
V = rng.random((N, d))

In [None]:
# 1. Set block sizes B_c = ceil(M / 4d) and B_r = min(ceil(M / 4d), d),
#    as in step 1 of FlashAttention Algorithm 1.
#    Ceiling division also guarantees B_c >= 1 for any M > 0, so the block
#    sizes are never zero even when M < 4d (floor division would give 0 there).
B_c = -(-M // (4 * d))  # integer ceiling division
B_r = min(B_c, d)

In [None]:
# 2. Initialize the state.
#    - Out: output accumulator, zeros, shape (N, d)
#    - l:   running softmax denominators, zeros, shape (N,)
#    - m:   running row maxima, -inf, shape (N,)
# NOTE: in IPython, assigning `Out` shadows the kernel's output-history dict `Out`.
Out = np.zeros(shape=(N, d))
l = np.zeros(shape=N)  # noqa: E741
m = np.full(shape=N, fill_value=-np.inf)

In [None]:
# 3. Divide Q into T_r = ceil(N / B_r) blocks of shape (B_r, d), and K, V into
#    T_c = ceil(N / B_c) blocks of shape (B_c, d).
#    Splitting at explicit row offsets keeps every block at most B_r (resp. B_c)
#    rows; only the last block is shorter when N is not a multiple of the block size.
#    Splitting into N // B_r equal sections instead would produce blocks larger
#    than B_r in that case.
Q_blocks = np.array_split(Q, np.arange(B_r, N, B_r))  # list of arrays
K_blocks = np.array_split(K, np.arange(B_c, N, B_c))
V_blocks = np.array_split(V, np.arange(B_c, N, B_c))
T_r = len(Q_blocks)
T_c = len(K_blocks)

In [None]:
# Proof of Step 3: stacking the Q blocks gives shape (T_r, B_r, d).
# A separate name is used so the `Q_blocks` list read by the main loop is not
# overwritten. np.stack requires equal-sized blocks, i.e. N divisible by B_r.
Q_blocks_stacked = np.stack(Q_blocks)
print(
    f"Q_Block's T_r is {Q_blocks_stacked.shape[0]}, so Q_blocks shape must be (T_r, B_r, d) which is {Q_blocks_stacked.shape}"
)
print(f"Q_blocks shape : {Q_blocks_stacked.shape}, Br : {B_r}, d : {d}")
assert Q_blocks_stacked.shape == (T_r, B_r, d)

In [None]:
# 4. Divide O, l, m into T_r tiles.
#    The row counts are taken from the Q blocks, so tile i always lines up with
#    Q_blocks[i], even when the Q blocks are not all the same size.
row_offsets = np.cumsum([len(block) for block in Q_blocks])[:-1]
O_tiles = np.array_split(Out, row_offsets)
l_tiles = np.array_split(l, row_offsets)
m_tiles = np.array_split(m, row_offsets)
assert len(O_tiles) == len(l_tiles) == len(m_tiles) == T_r

In [None]:
# Steps 5-14: the outer loop runs over K/V blocks (j), the inner loop over Q/O/l/m blocks (i).
# Work on shallow copies of the tile lists. Re-running this cell then starts from
# the initial O, l, m tiles instead of accumulating on top of a previous run.
O_acc = list(O_tiles)
l_acc = list(l_tiles)
m_acc = list(m_tiles)

# 5. Loop in Tc
for j in range(T_c):
    # 6. Load K_j, V_j from HBM to SRAM
    K_j, V_j = K_blocks[j], V_blocks[j]

    # 7. Loop in Tr
    for i in range(T_r):
        # 8. Load Q_i, O_i, l_i, m_i from HBM to on-chip SRAM.
        Q_i, O_i, l_i, m_i = Q_blocks[i], O_acc[i], l_acc[i], m_acc[i]

        # 9. Compute S_ij = Q_i K_j^T.
        #    Shape: (rows of Q_i, rows of K_j), i.e. B_r x B_c.
        #    No 1/sqrt(d) scaling is applied, matching the vanilla attention cell.
        S_ij = Q_i @ K_j.T

        # 10. Compute:
        #    i)   mhat_ij = rowmax(S_ij)              -> (B_r,)
        #    ii)  P_ij = exp(S_ij - mhat_ij)          -> (B_r, B_c)
        #    iii) lhat_ij = rowsum(P_ij)              -> (B_r,)
        mhat_ij = np.max(S_ij, axis=1)  # max per row
        # mhat_ij[:, np.newaxis] broadcasts each row's max across that row of S_ij
        P_ij = np.exp(S_ij - mhat_ij[:, np.newaxis])
        lhat_ij = np.sum(P_ij, axis=1)

        # 11. Compute:
        #    i)  mnew_i = max(m_i, mhat_ij)                                      -> (B_r,)
        #    ii) lnew_i = exp(m_i - mnew_i) l_i + exp(mhat_ij - mnew_i) lhat_ij  -> (B_r,)
        # On the first K/V block, m_i = -inf and l_i = 0, so exp(m_i - mnew_i) = 0
        # and the (empty) previous statistics contribute nothing.
        mnew_i = np.maximum(m_i, mhat_ij)
        old_scale = np.exp(m_i - mnew_i)
        new_scale = np.exp(mhat_ij - mnew_i)
        lnew_i = old_scale * l_i + new_scale * lhat_ij

        # 12. O_i = diag(lnew_i)^-1 (diag(l_i) exp(m_i - mnew_i) O_i + exp(mhat_ij - mnew_i) P_ij V_j)
        #     -> (B_r, d)
        #     This rescales the previous partial output to the new row max and
        #     adds this K/V block's contribution.
        O_acc[i] = (1 / lnew_i)[:, np.newaxis] * (
            (l_i * old_scale)[:, np.newaxis] * O_i
            + (new_scale[:, np.newaxis] * P_ij) @ V_j
        )
        # 13. Write the updated statistics back.
        l_acc[i], m_acc[i] = lnew_i, mnew_i

# 14. O = concatenate(O_1, O_2, ..., O_T_r) -> (N, d)
out = np.concatenate(O_acc, axis=0)
out

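Vanilla (standard) attention: compute $S = QK^\top$, the row-wise $P = \operatorname{softmax}(S)$, and $O = PV$ directly, as a reference for the tiled FlashAttention result.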

In [None]:
# Reference: standard (non-tiled) attention, O = softmax(Q K^T) V, unscaled.
S = Q @ K.T
# Numerically stable row-wise softmax: subtract each row's max before exponentiating.
P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)
O_ = P @ V

In [None]:
# The tiled (FlashAttention) output must match standard attention.
np.testing.assert_allclose(out, O_)
O_

In [None]:
import torch

# PyTorch reference.
# scaled_dot_product_attention scales the scores by 1/sqrt(d) by default, but the
# NumPy implementations above use the unscaled Q @ K.T. Pass scale=1.0 so the
# comparison is like for like (the `scale` argument requires torch >= 2.1).
# torch.from_numpy keeps float64 precision instead of downcasting to float32.
O_torch = torch.nn.functional.scaled_dot_product_attention(
    torch.from_numpy(Q), torch.from_numpy(K), torch.from_numpy(V), scale=1.0
)
np.testing.assert_allclose(O_torch.numpy(), O_, rtol=1e-6)
O_torch

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation

# Visualize A @ v as a linear combination of A's columns:
#   A @ v = sum_i v[i] * A[:, i]
A = np.array([[2, 1], [1, 3]])
v = np.array([1, 2])

# Scaled columns v[i] * A[:, i]. Each is a flat 2-vector (x, y), so res[0] and
# res[1] are scalars when passed to Line2D.set_data.
# A (2, 1) column would instead give a ragged [0, array([...])] list.
results = [A[:, i] * v[i] for i in range(A.shape[1])]

# Set up the figure, the axis, and one (initially empty) line per scaled column
fig, ax = plt.subplots(figsize=(6, 6))
ax.set_xlim(-10, 10)
ax.set_ylim(-10, 10)
ax.set_title("A @ v as scaled columns of A")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.grid(True)

vector_color = ["yellow", "blue"]

vectors = []
for i in range(len(results)):
    # Cycle the colors so matrices with more than two columns still work.
    color = vector_color[i % len(vector_color)]
    (vec,) = ax.plot([], [], lw=2, color=color, label=f"{v[i]} * Column {i+1}")
    vectors.append(vec)

ax.legend()


def init():
    """Clear every line; FuncAnimation calls this to draw the background frame."""
    for vec in vectors:
        vec.set_data([], [])
    return vectors


def animate(i):
    """Show only the i-th scaled column as a segment from the origin; hide the rest."""
    for j, res in enumerate(results):
        if i == j:
            vectors[j].set_data([0, res[0]], [0, res[1]])
        else:
            vectors[j].set_data([], [])
    return vectors


# One frame per column, one second per frame, looping.
anim = FuncAnimation(
    fig, animate, init_func=init, frames=len(results), repeat=True, blit=True, interval=1000
)
anim.save("matrix-vector-multiplication.gif", writer="pillow", fps=1)