https://colab.research.google.com/github/dlsyscourse/public_notebooks/blob/main/transformer_implementation.ipynb

`y = softmax(K Q^T / sqrt(d)) V`
`y = ( softmax(X W_K W_Q^T X) / sqrt(d)) X W_v ) W_out`

W_out seem to be uncessary since in above formular it can lum into W_v. But it will be useful in multi-head attn.

In [1]:
import numpy as np
import torch

def softmax(Z):
    Z = np.exp(Z - Z.max(axis=-1, keepdims=True)) # always take the softmax of the last dim
    return Z / Z.sum(axis=-1, keepdims=True)

def self_attn(X, mask, W_KQV, W_out):
    K,Q,V = np.split(X @ W_KQV, 3, axis=-1)
    # print(K.shape,Q.shape,V.shape)
    bach_size, d = X.shape
    attn = softmax(K @ Q.swapaxes(-1, -2) / np.sqrt(d) + mask)
    # print(">>>", attn.shape, V.shape, W_out.shape)
    return attn @ V @ W_out, attn

In [10]:
T, d = 100, 64
attn = torch.nn.MultiheadAttention(d, 1, bias=False, batch_first=True).cuda() # 1 head
mask = torch.triu(-float("inf")*torch.ones(T,T), 1).cuda()
X = torch.randn(1,T,d).cuda()
Y_, A_ = attn(X,X,X, attn_mask=mask)
# print(mask, X, "\n")
# print(Y_, A_)

In [13]:
W_KQV = attn.in_proj_weight.detach().cpu().numpy().T
W_out = attn.out_proj.weight.detach().cpu().numpy().T

assert W_out.shape == (64, 64), W_out.shape
assert W_KQV.shape[1] == 64 * 3, W_KQV.shape # k, q, v = 64 + 64 + 64

Y, A = self_attn(X[0].cpu().numpy(), mask.cpu().numpy(), W_KQV, W_out)

In [15]:
print(np.linalg.norm(A - A_[0].detach().cpu().numpy()))
print(np.linalg.norm(Y - Y_[0].detach().cpu().numpy()))


3.5984883e-07
2.4547944e-06


In [28]:
from flash_attn.flash_attention import FlashMHA
fmha = FlashMHA(d, 1, bias=False, batch_first=True, causal=True).cuda().half()
Y_, A_ = fmha(X.half())
print(Y_, Y)
print(np.linalg.norm(Y - Y_[0].detach().cpu().numpy()))


tensor([[[ 0.1205,  0.0696, -0.3733,  ..., -0.0801,  0.1240,  0.1228],
         [ 0.1050,  0.0545, -0.2861,  ...,  0.0406, -0.1459,  0.1060],
         [ 0.2625,  0.0136, -0.1875,  ...,  0.2360,  0.0106,  0.0776],
         ...,
         [-0.0457, -0.0076,  0.0082,  ..., -0.0568, -0.0705, -0.0358],
         [-0.0189, -0.0238,  0.0074,  ..., -0.0469, -0.1003, -0.0058],
         [-0.0105, -0.0269,  0.0029,  ..., -0.0108, -0.0815, -0.0106]]],
       device='cuda:0', dtype=torch.float16, grad_fn=<UnsafeViewBackward0>) [[ 2.2795397e-01 -9.3587011e-02  4.9632975e-01 ... -6.0758632e-01
   7.2729014e-02  6.1132383e-01]
 [ 1.2217862e-01 -8.9289369e-03  3.0261526e-01 ... -4.8677298e-01
   4.0799819e-04  4.9533185e-01]
 [ 1.8917602e-01  1.3045375e-01 -4.4965193e-02 ... -6.3009659e-04
  -1.0004650e-01  1.5622777e-01]
 ...
 [-2.3968823e-02 -7.2586916e-02 -1.6450675e-02 ...  4.3055572e-02
  -2.2163093e-02 -5.0915599e-02]
 [-1.6038366e-02 -7.5717814e-02 -2.9027700e-02 ...  1.0709321e-01
   1.9591867e-0