In [2]:
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

In [7]:
# Show the installed PyTorch version that the outputs in this notebook were produced with.
torch.__version__

'1.7.1'

In [3]:
# Attention hyperparameters: 4-dimensional embeddings and a single head.
embed_dim, num_heads = 4, 1

In [32]:
from pprint import pprint

# Build the attention layer (positional args: embed_dim, num_heads) with
# bias=False and dump its parameters; the cells further down reproduce the
# layer's forward pass by hand from exactly these tensors.
model = nn.MultiheadAttention(embed_dim, num_heads, bias=False)
pprint([*model.named_parameters()])

[('in_proj_weight',
  Parameter containing:
tensor([[ 0.0492,  0.1454, -0.3237, -0.5530],
        [-0.4122, -0.4595, -0.1724, -0.2761],
        [-0.4626,  0.2290, -0.5929, -0.3930],
        [-0.2985, -0.1988,  0.4165,  0.6081],
        [-0.1084, -0.1049,  0.2425, -0.0454],
        [-0.1500,  0.5269,  0.5555, -0.1842],
        [-0.3304,  0.6096, -0.3090, -0.0992],
        [ 0.0750, -0.1702,  0.3382, -0.3463],
        [ 0.0877, -0.2240, -0.4005, -0.3957],
        [ 0.4808, -0.0672, -0.2906, -0.3212],
        [-0.0817, -0.0162, -0.5907, -0.4768],
        [ 0.5461,  0.0227,  0.2328, -0.5729]], requires_grad=True)),
 ('out_proj.weight',
  Parameter containing:
tensor([[ 0.4502, -0.4243,  0.1935, -0.2146],
        [ 0.0306, -0.1774,  0.3395,  0.2386],
        [-0.2647,  0.4024,  0.0799, -0.2737],
        [-0.3745, -0.3256, -0.2285, -0.2394]], requires_grad=True)),
 ('out_proj.bias',
  Parameter containing:
tensor([-0.3763,  0.1134,  0.3692,  0.3514], requires_grad=True))]


In [33]:
# Random batch-first input: B sequences of L tokens, each with D features.
B, L = 16, 10
D = embed_dim
X = torch.randn((B, L, D))
X.shape

torch.Size([16, 10, 4])

In [34]:
# The layer is fed sequence-first input, i.e. batch second: (L, B, D).
# Bind the permuted view to Q/K/V instead of reassigning X: overwriting X
# made this cell non-idempotent, because running it a second time flipped the
# layout back to (B, L, D) and the layer was then given batch-first data.
# X itself stays (B, L, D), matching the shape displayed when it was created.
Q = K = V = X.permute(1, 0, 2)  # self-attention: one tensor plays all three roles
attn_output, attn_weights = model(Q, K, V)

In [35]:
# The output keeps the input's batch-second layout: (L, B, D).
attn_output.shape

torch.Size([10, 16, 4])

In [31]:
# The weights come back batch-first, one (L, L) matrix per sequence: (B, L, L).
attn_weights.shape

torch.Size([16, 10, 10])

$$ \text{MultiHeadAttention}(Q,K,V)=\text{Concat}(head_1, \dots, head_h)W^O \\
head_i = \text{Attention}(QW^Q_i, KW^K_i, VW^V_i) \\
\text{Attention}(Q,K,V) = \text{softmax}\left({QK^T \over \sqrt{d_k}}\right)V$$

In [43]:
# Snapshot the layer's parameters as plain tensors keyed by parameter name.
model_weights = {name: param.data for name, param in model.named_parameters()}
Wi = model_weights['in_proj_weight']   # packed input projection, split into q/k/v later
Wo = model_weights['out_proj.weight']  # output projection
# With bias=False, newer PyTorch releases drop the output-projection bias as
# well, so 'out_proj.bias' may be missing from the parameter list. Fall back
# to a zero bias (a no-op when added) instead of raising KeyError.
bias = model_weights.get('out_proj.bias', Wo.new_zeros(Wo.shape[0]))

In [40]:
# The packed input projection has 3 * D rows and D columns.
Wi.shape

torch.Size([12, 4])

In [41]:
# The output projection is square: (D, D).
Wo.shape

torch.Size([4, 4])

In [45]:
# Split the packed matrix row-wise into three equal blocks, used below as the
# query, key and value projections (in that order).
Wi_q, Wi_k, Wi_v = torch.chunk(Wi, 3, dim=0)
Wi_q.shape

torch.Size([4, 4])

$ head_i = Attention(QW_i^Q, KW_i^K, VW_i^V) $  
There is only a single head here, so $i = 1$.

In [79]:
# Go back to batch-first (B, L, D) so the batched products below line up.
Q_, K_, V_ = (t.transpose(0, 1) for t in (Q, K, V))
# Apply each projection as x @ W^T: (B, L, D) @ (D, D) -> (B, L, D).
QW = Q_ @ Wi_q.T
KW = K_ @ Wi_k.T
VW = V_ @ Wi_v.T
QW.shape  # (B, L, D)

torch.Size([16, 10, 4])

$$ \text{Attention}(Q,K,V)=(\text{Attention Weights})\,V \\
\text{Attention Weights} = \text{softmax}\left({QK^T \over \sqrt{d_k}}\right)
$$

In [80]:
import math

# Scaled dot-product scores: (B, L, D) x (B, D, L) -> (B, L, L).
# Dividing by sqrt(D) is only valid here because there is a single head, so
# the per-head dimension equals the full embedding dimension D.
KW_t = KW.transpose(1, 2)  # (B, D, L)
QK_t_scaled = QW.bmm(KW_t) / math.sqrt(D)
# Normalise over the last (key) axis so each query's weights sum to 1.
attn_weights_ = QK_t_scaled.softmax(dim=-1)
attn_weights_.shape

torch.Size([16, 10, 10])

In [81]:
# Compare the hand-computed weights with the ones returned by the layer.
torch.allclose(attn_weights, attn_weights_)

True

In [121]:
# Weighted sum of the projected values: (B, L, L) x (B, L, D) -> (B, L, D).
attention = attn_weights_.bmm(VW)
attention.shape

torch.Size([16, 10, 4])

$$ head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) \\ \text{MultiHeadAttention}(Q,K,V)=\text{Concat}(head_1, \dots, head_h)W^O $$

In [125]:
# Output projection (x @ Wo^T + bias), then swap the first two axes to get
# back to the batch-second (L, B, D) layout used by the layer's own output.
attn_output_ = (attention @ Wo.T + bias).transpose(0, 1)
attn_output_.shape

torch.Size([10, 16, 4])

In [123]:
# Compare the hand-computed output with the layer's output, using an
# absolute tolerance of 1e-4.
torch.allclose(attn_output, attn_output_, atol=1e-04)

True