In [18]:
import numpy as np

def attention(X, rng=None, verbose=True):
    """Single-head causal self-attention with randomly drawn projections.

    ``X`` is projected to queries, keys and values of width
    ``d_model // 2``. Scaled dot-product scores are computed, every position
    is prevented from attending to later positions, and the softmax-weighted
    sum of the values is returned.

    Args:
        X: Array of shape ``(batch, seq, d_model)``.
        rng: Optional ``np.random.Generator`` used to draw the three
            projection matrices. When ``None`` (default) the global
            ``np.random`` state is used, so ``np.random.seed`` still
            controls the result.
        verbose: When true (default), print the intermediate shapes.

    Returns:
        Array of shape ``(batch, seq, d_model // 2)``.

    Raises:
        ValueError: If ``X.shape`` does not unpack into three dimensions.
    """
    b, s, d_mod = X.shape
    d_k = d_v = d_mod // 2

    def log(*args):
        if verbose:
            print(*args)

    log(f"batch:{b}, seq:{s}, d_model:{d_mod}, d_k:{d_k}, d_v:{d_v}")

    # Draw order (W_Q, W_K, W_V) is kept so a seeded global state reproduces
    # the same weights as before.
    sampler = np.random if rng is None else rng
    W_Q = sampler.random((d_mod, d_k))
    W_K = sampler.random((d_mod, d_k))
    W_V = sampler.random((d_mod, d_v))

    Q = X @ W_Q
    K = X @ W_K
    V = X @ W_V
    log("Q: ", Q.shape)
    log("K: ", K.shape)
    log("V: ", V.shape)

    # (b, s, d_k) @ (b, d_k, s) -> (b, s, s)
    scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(d_k)
    log("Scores: ", scores.shape)

    # True strictly above the diagonal, i.e. where key index > query index.
    future = np.triu(np.ones((s, s), dtype=bool), k=1)
    # Replace (rather than offset) masked scores: a finite additive penalty
    # can be out-weighed by large scores, -inf cannot. The diagonal is never
    # masked, so every row keeps a finite maximum.
    scores = np.where(future, -np.inf, scores)

    # Numerically stable softmax over the key axis.
    row_max = np.max(scores, axis=-1, keepdims=True)
    weights = np.exp(scores - row_max)
    A = weights / np.sum(weights, axis=-1, keepdims=True)
    log("Softmax: ", A.shape)

    out = A @ V
    log("Output: ", out.shape)
    return out

# Seed the legacy global RNG so the demo input, and the weights that
# attention() draws from the global state, are the same on every
# Restart & Run All.
np.random.seed(0)

batch = 4
seq = 10
d_model = 64

X = np.random.random((batch, seq, d_model))
res = attention(X)

batch:4, seq:10, d_model:64, d_k:32, d_v:32
Q:  (4, 10, 32)
K:  (4, 10, 32)
V:  (4, 10, 32)
Scores:  (4, 10, 10)
Softmax:  (4, 10, 10)
Output:  (4, 10, 32)


In [30]:
def multi_head_attention(X, num_heads=4, rng=None, verbose=True):
    """Multi-head causal self-attention with randomly drawn projections.

    ``X`` is projected to queries, keys and values of width ``d_model``,
    split into ``num_heads`` heads of width ``d_model // num_heads``, and
    attended per head with a causal mask. The heads are then concatenated
    and passed through an output projection.

    Args:
        X: Array of shape ``(batch, seq, d_model)``.
        num_heads: Number of heads; must divide ``d_model``.
        rng: Optional ``np.random.Generator`` used to draw the four
            projection matrices. When ``None`` (default) the global
            ``np.random`` state is used, so ``np.random.seed`` still
            controls the result.
        verbose: When true (default), print the intermediate shapes.

    Returns:
        Array of shape ``(batch, seq, d_model)``.

    Raises:
        AssertionError: If ``d_model`` is not divisible by ``num_heads``.
        ValueError: If ``X.shape`` does not unpack into three dimensions.
    """
    b, s, d_mod = X.shape

    assert d_mod % num_heads == 0, (
        f"d_model ({d_mod}) must be divisible by num_heads ({num_heads})"
    )

    d_k = d_v = d_mod // num_heads

    def log(*args):
        if verbose:
            print(*args)

    def split_heads(T, width):
        # (b, s, d_mod) -> (b, s, h, width) -> (b, h, s, width)
        return np.reshape(T, (b, s, num_heads, width)).transpose((0, 2, 1, 3))

    log(f"batch:{b}, seq:{s}, d_model:{d_mod}, d_k:{d_k}, d_v:{d_v}")

    # Draw order (W_Q, W_K, W_V, W_0) is kept so a seeded global state
    # reproduces the same weights as before.
    sampler = np.random if rng is None else rng
    W_Q = sampler.random((d_mod, d_mod))
    W_K = sampler.random((d_mod, d_mod))
    W_V = sampler.random((d_mod, d_mod))
    W_0 = sampler.random((d_mod, d_mod))

    Q = split_heads(X @ W_Q, d_k)
    K = split_heads(X @ W_K, d_k)
    V = split_heads(X @ W_V, d_v)

    log("Q: ", Q.shape)
    log("K: ", K.shape)
    log("V: ", V.shape)

    # (b, h, s, d_k) @ (b, h, d_k, s) -> (b, h, s, s)
    scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(d_k)
    log("Scores: ", scores.shape)

    # True strictly above the diagonal, i.e. where key index > query index.
    future = np.triu(np.ones((s, s), dtype=bool), k=1)
    # Replace (rather than offset) masked scores: a finite additive penalty
    # can be out-weighed by large scores, -inf cannot. The diagonal is never
    # masked, so every row keeps a finite maximum. The (s, s) mask
    # broadcasts over the batch and head axes.
    scores = np.where(future, -np.inf, scores)

    # Numerically stable softmax over the key axis.
    row_max = np.max(scores, axis=-1, keepdims=True)
    weights = np.exp(scores - row_max)
    A = weights / np.sum(weights, axis=-1, keepdims=True)
    log("Softmax: ", A.shape)

    # (b, h, s, s) @ (b, h, s, d_v) -> (b, h, s, d_v)
    out = A @ V
    log("Per-head attention: ", out.shape)

    # (b, h, s, d_v) -> (b, s, h, d_v) -> concatenate heads to (b, s, d_mod)
    merged = out.transpose((0, 2, 1, 3)).reshape((b, s, d_mod))
    res = merged @ W_0
    log("Output: ", res.shape)

    return res

# Seed the legacy global RNG so the demo input, and the weights that
# multi_head_attention() draws from the global state, are the same on every
# Restart & Run All.
np.random.seed(0)

batch = 8
seq = 10
d_model = 64

X = np.random.random((batch, seq, d_model))
res = multi_head_attention(X)

batch:8, seq:10, d_model:64, d_k:16, d_v:16
Q:  (8, 4, 10, 16)
K:  (8, 4, 10, 16)
V:  (8, 4, 10, 16)
Scores:  (8, 4, 10, 10)
Softmax:  (8, 4, 10, 10)
Per-head attention:  (8, 4, 10, 16)
Output:  (8, 10, 64)
