### Multi-Head Attention

In [1]:
import torch
import torch.nn.functional as F
import math

In [5]:
# random input with batch=1, tokens 4 with embedding vector of 8
x = torch.randn(1,4,8)

In [12]:
embed_dim = 8
num_heads = 2
head_dim = embed_dim // num_heads
head_dim # from 1 (1,4,8) now we have two heads with dim = (1,4,2,4) - 2 head with 4d vector

4

In [7]:
# weight matrices for Q,K,V and output projection
W_q = torch.randn(embed_dim,embed_dim)
W_k = torch.randn(embed_dim,embed_dim)
W_v = torch.randn(embed_dim,embed_dim)
W_o = torch.randn(embed_dim,embed_dim)

In [9]:
# Computing Q, K and V
Q = x @ W_q
K = x @ W_q
V = x @ W_v
Q.shape, K.shape, V.shape

(torch.Size([1, 4, 8]), torch.Size([1, 4, 8]), torch.Size([1, 4, 8]))

In [13]:
def split_heads(t):
  return t.view(1,4,num_heads,head_dim).transpose(1,2) # first (1,4,8) -> (1,4,2,4) -> (1,2,4,4)

In [14]:
Qh = split_heads(Q)
Kh = split_heads(K)
Vh = split_heads(V)
Qh.shape, Kh.shape, Vh.shape

(torch.Size([1, 2, 4, 4]), torch.Size([1, 2, 4, 4]), torch.Size([1, 2, 4, 4]))

In [17]:
## Compute attention per heads
scores = Qh @ Kh.transpose(-2,-1)

#  Scaling
scores = scores / math.sqrt(head_dim)

# attention weights
attn = F.softmax(scores, dim=-1)

head_output = attn @ Vh

In [18]:
## concatenate heads back to get original shape
concate = head_output.transpose(1,2).reshape(1,4,embed_dim)

In [20]:
print('shape of head output: ',head_output.shape)
print('concatenate result: ',concate.shape)

shape of head output:  torch.Size([1, 2, 4, 4])
concatenate result:  torch.Size([1, 4, 8])


In [21]:
print('input\n',x)
print('multihead attention output\n',concate)

input
 tensor([[[ 0.4994, -1.6186, -1.7349,  1.4511,  0.3490, -2.0378,  1.6342,
          -0.2907],
         [-0.0035, -0.6745, -1.4670, -2.0123, -0.5813,  0.1806,  0.1866,
           0.1056],
         [ 0.7883,  0.1376,  0.9274, -1.1609, -3.1396, -0.4939, -1.9834,
          -0.5342],
         [ 0.2885, -0.0091, -1.0703, -0.8698,  0.0283, -0.6103, -0.1859,
          -0.9951]]])
multihead attention output
 tensor([[[ 0.9287,  0.6518, -4.3290, -2.4623,  2.7119,  6.4714, -2.5216,
          -5.8881],
         [ 0.9205,  0.6397, -4.2778, -2.4562,  4.0769, -0.3710,  1.6185,
          -0.1193],
         [ 0.7359, -4.3003,  2.2744,  2.4425,  0.4189, -1.2835, -0.9693,
           2.7203],
         [ 0.7298,  0.2853, -3.3072, -2.2036,  0.9025, -0.0401, -1.1982,
           1.2232]]])
