# Multi-head attention

In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import math

## single-head

In [2]:
### init data
# Problem sizes: B = batch, T = time steps, C = channels per token.
B, T, C = 2, 4, 12
head_size = 16  # width of the q/k/v vectors produced by this single head
x = torch.randn(B, T, C)  # random input, shape (B, T, C)

### define Wq, Wk, Wv
# Each weight is stored as (hs, C) and transposed when applied to x.
# The k -> q -> v creation order is kept so the random stream is consumed
# in the same sequence.
k_proj = torch.rand((head_size, C))
q_proj = torch.rand((head_size, C))
v_proj = torch.rand((head_size, C))

### compute q, k, v
# Every projection maps (B, T, C) @ (C, hs) -> (B, T, hs).
k = torch.matmul(x, k_proj.t())
q = torch.matmul(x, q_proj.t())
v = torch.matmul(x, v_proj.t())

### compute attention score
# Scaled dot-product scores: (B, T, hs) @ (B, hs, T) -> (B, T, T).
attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_size)
# Lower-triangular ones; zeros above the diagonal mark "future" positions.
tril = torch.ones(T, T).tril()
# Future positions get -inf so they receive exactly zero weight after softmax.
attn = attn.masked_fill(tril.eq(0), float('-inf'))
# Normalise each query row into a probability distribution over the keys.
attn = attn.softmax(dim=-1)
print("causal attention score")
print(attn)

### compute output
# Weighted sum of values: (B, T, T) @ (B, T, hs) -> (B, T, hs).
out = torch.matmul(attn, v)

causal attention score
tensor([[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [9.7902e-10, 1.0000e+00, 0.0000e+00, 0.0000e+00],
         [8.5029e-03, 9.9091e-01, 5.8267e-04, 0.0000e+00],
         [2.0480e-06, 3.0020e-10, 1.4393e-05, 9.9998e-01]],

        [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [1.0000e+00, 1.2679e-06, 0.0000e+00, 0.0000e+00],
         [1.5599e-11, 1.0883e-07, 1.0000e+00, 0.0000e+00],
         [9.7038e-02, 1.7277e-15, 7.7398e-38, 9.0296e-01]]])


## multi-head

In [None]:
### init data
B,T,C = 2, 4, 12 # batch, time, channels
head_size = 16
x = torch.randn(B,T,C)

# Causal mask shared by both heads: ones on and below the diagonal, zeros
# above it. Positions where tril == 0 are "future" tokens and get masked out.
tril = torch.tril(torch.ones(T, T))

########################## Head 1 ##########################
### define Wq, Wk, Wv
# Each head owns its own (hs, C) projection weights.
k_proj_1 = torch.rand((head_size, C))
q_proj_1 = torch.rand((head_size, C))
v_proj_1  = torch.rand((head_size, C))

### compute q, k, v
k_1 = x @ k_proj_1.T   # (B, T, C) @ (C, hs) -> (B, T, hs)
q_1 = x @ q_proj_1.T   # (B, T, C) @ (C, hs) -> (B, T, hs)
v_1 = x @ v_proj_1.T   # (B, T, C) @ (C, hs) -> (B, T, hs)

### compute attention score
attn_1 = q_1 @ k_1.transpose(-2, -1) / math.sqrt(head_size)  # (B, T, hs) @ (B, hs, T) ---> (B, T, T)
# Mask this head's own scores (attn_1), not the `attn` variable of the
# single-head cell, so the cell is self-contained and each head keeps the
# scores computed from its own q/k.
attn_1 = attn_1.masked_fill(tril == 0, float('-inf'))
attn_1 = F.softmax(attn_1, dim=-1)  # each query row sums to 1 over the keys

### compute output
out_1 = attn_1 @ v_1  # (B, T, T) @ (B, T, hs) -> (B, T, hs)

########################## Head 2 ##########################
### define Wq, Wk, Wv
k_proj_2 = torch.rand((head_size, C))
q_proj_2 = torch.rand((head_size, C))
v_proj_2  = torch.rand((head_size, C))

### compute q, k, v
k_2 = x @ k_proj_2.T   # (B, T, C) @ (C, hs) -> (B, T, hs)
q_2 = x @ q_proj_2.T   # (B, T, C) @ (C, hs) -> (B, T, hs)
v_2 = x @ v_proj_2.T   # (B, T, C) @ (C, hs) -> (B, T, hs)

### compute attention score
attn_2 = q_2 @ k_2.transpose(-2, -1) / math.sqrt(head_size)  # (B, T, hs) @ (B, hs, T) ---> (B, T, T)
# Same as head 1: mask attn_2 itself rather than an unrelated `attn`.
attn_2 = attn_2.masked_fill(tril == 0, float('-inf'))
attn_2 = F.softmax(attn_2, dim=-1)

### compute output
out_2 = attn_2 @ v_2  # (B, T, T) @ (B, T, hs) -> (B, T, hs)

########################## fuse multi head ##########################
# Output projection mixing the two heads. It maps the concatenated width
# (hs * 2) down to hs, so multi_head_output has last dimension hs (not C).
multi_head_proj = torch.rand((head_size, head_size * 2)) # [hs, hs * 2]

# Stack the per-head outputs side by side along the feature dimension.
concat_attention_output = torch.cat([out_1, out_2], dim = -1) # [B, T, hs * 2]

multi_head_output = concat_attention_output @ multi_head_proj.T # [B, T, hs]