## What is multi-head attention?

The figure below shows the structure of multi-head attention.

<img src="./assert/multi-head-attentions.png" width="400" height="400" alt="MHA">


## Explanation

Multi-head attention is built from the same scaled dot-product attention used in self-attention, but it runs several copies of it ("heads") in parallel.
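
In the usual notation, with $h$ heads and $d_h = d_{\text{model}} / h$, each head is ordinary scaled dot-product attention on its own slice of the projected queries, keys, and values. The heads are then concatenated and projected once more:

$$
\mathrm{head}_i = \mathrm{softmax}\!\left(\frac{Q_i K_i^\top}{\sqrt{d_h}}\right) V_i,
\qquad
\mathrm{MultiHead}(X) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^O
$$

Here $Q = X W^Q$, $K = X W^K$, $V = X W^V$ are each split column-wise into $h$ blocks $Q_i, K_i, V_i$ of width $d_h$. Bias terms are omitted for brevity; the `nn.Linear` layers in the code below also add a bias.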

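Instead of computing attention once over the full `d_model`-dimensional query, key, and value vectors, the projections are split into `n_heads` smaller pieces of size `head_dim = d_model // n_heads`, and each head computes attention on its own piece independently.
After attention, the head outputs are concatenated back into a single `d_model`-dimensional vector per token and passed through a final linear projection.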
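
The split and the concatenation are just reshapes. Below is a minimal sketch of the round trip, using a random tensor as a stand-in for a projected query matrix; the sizes are arbitrary choices for illustration.

```python
import torch

# Illustrative sizes (assumed for this sketch)
seq_len, d_model, n_heads = 8, 32, 4
head_dim = d_model // n_heads

# Stand-in for a projected query matrix, shape (seq_len, d_model)
q = torch.randn(seq_len, d_model)

# Split: (seq_len, d_model) -> (seq_len, n_heads, head_dim) -> (n_heads, seq_len, head_dim)
heads = q.view(seq_len, n_heads, head_dim).transpose(0, 1)
print(heads.shape)  # torch.Size([4, 8, 8])

# Head i only sees columns [i * head_dim, (i + 1) * head_dim) of every token
assert torch.equal(heads[1], q[:, head_dim:2 * head_dim])

# Concatenate: undo the transpose, then flatten the heads back into d_model
merged = heads.transpose(0, 1).contiguous().view(seq_len, d_model)
print(torch.equal(merged, q))  # True
```

Because the concatenation exactly undoes the split, all mixing between heads has to come from the final output projection.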

## Why?

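A single attention head can capture only one kind of relationship between tokens (one "way of reading" the input). For example, a single head might treat "1 + 1" only as **plain text**.

With multiple heads, the model can read "1 + 1" both as **plain text** and as a **formula**, because each head has its own slice of the query, key, and value projections and can attend to different patterns.

More heads allow more kinds of relationships to be captured in parallel. The total width stays the same, though: each head works in a smaller subspace of size `d_model / n_heads`, so adding heads is a trade-off rather than a free gain.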
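
One way to see the trade-off: with the split-and-concatenate design, changing the number of heads does not change the number of parameters; it only changes how `d_model` is divided among the heads. A quick check, using PyTorch's built-in `torch.nn.MultiheadAttention` as a stand-in (it has the same four `d_model x d_model` projections as the class below):

```python
import torch.nn as nn

d_model = 32
for n_heads in (1, 2, 4, 8):
    mha = nn.MultiheadAttention(d_model, n_heads)
    # Q/K/V/output projections: 4 weight matrices (d_model x d_model) + 4 bias vectors
    n_params = sum(p.numel() for p in mha.parameters())
    # params is 4224 (= 4 * 32 * 32 + 4 * 32) for every head count
    print(f"heads={n_heads}, head_dim={d_model // n_heads}, params={n_params}")
```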

```python
import torch
import torch.nn as nn
import math


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention for one unbatched sequence of shape (seq_len, d_model)."""

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        # Learned projections for queries, keys, and values
        self.w_query = nn.Linear(d_model, d_model)
        self.w_key = nn.Linear(d_model, d_model)
        self.w_value = nn.Linear(d_model, d_model)
        # Output projection applied after the heads are concatenated
        self.w_out = nn.Linear(d_model, d_model)

    def forward(self, x):
        seq_len = x.shape[0]
        query = self.w_query(x)
        key = self.w_key(x)
        value = self.w_value(x)

        # Reshape to (seq_len, n_heads, head_dim), then transpose to (n_heads, seq_len, head_dim)
        query = query.view(seq_len, self.n_heads, self.head_dim).transpose(0, 1)
        key = key.view(seq_len, self.n_heads, self.head_dim).transpose(0, 1)
        value = value.view(seq_len, self.n_heads, self.head_dim).transpose(0, 1)

        # score: (n_heads, seq_len, seq_len)
        score = torch.bmm(query, key.transpose(1, 2)) / math.sqrt(self.head_dim)

        # attention: (n_heads, seq_len, head_dim)
        attention = torch.bmm(torch.softmax(score, dim=-1), value)

        # Concatenate the heads back to (seq_len, d_model), then project
        attention = attention.transpose(0, 1).contiguous().view(seq_len, self.d_model)
        return self.w_out(attention)
```

```python
torch.manual_seed(42)
embed_dim = 32
n_heads = 4
seq_len = 8

# Input and a random regression target, both of shape (seq_len, d_model)
x = torch.randn(seq_len, embed_dim)
target = torch.randn(seq_len, embed_dim)

model = MultiHeadAttention(embed_dim, n_heads)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# Sanity check: the module should be able to fit the target
model.train()
for epoch in range(100):
    optimizer.zero_grad()
    output = model(x)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
```

```text
Epoch 0, Loss: 0.9528
Epoch 10, Loss: 0.8633
Epoch 20, Loss: 0.7874
Epoch 30, Loss: 0.6941
Epoch 40, Loss: 0.5665
Epoch 50, Loss: 0.4330
Epoch 60, Loss: 0.3291
Epoch 70, Loss: 0.2463
Epoch 80, Loss: 0.1821
Epoch 90, Loss: 0.1270
```
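
A falling loss on a random target shows that the module trains, but not that it computes multi-head attention correctly. A stronger check is to redo the same steps by hand using the weights of PyTorch's built-in `torch.nn.MultiheadAttention` and compare the outputs. This sketch assumes PyTorch 1.11 or newer (for unbatched `(seq_len, d_model)` inputs) and the default configuration, where the query, key, and value weights are packed into a single `in_proj_weight`.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, n_heads, seq_len = 32, 4, 8
head_dim = d_model // n_heads


def split_heads(t):
    # (seq_len, d_model) -> (n_heads, seq_len, head_dim)
    return t.view(seq_len, n_heads, head_dim).transpose(0, 1)


# Reference implementation; unbatched input has shape (seq_len, d_model)
ref = nn.MultiheadAttention(d_model, n_heads)
x = torch.randn(seq_len, d_model)

with torch.no_grad():
    # The built-in packs W_q, W_k, W_v (and their biases) into one tensor each
    w_q, w_k, w_v = ref.in_proj_weight.chunk(3, dim=0)
    b_q, b_k, b_v = ref.in_proj_bias.chunk(3, dim=0)

    q = split_heads(x @ w_q.T + b_q)
    k = split_heads(x @ w_k.T + b_k)
    v = split_heads(x @ w_v.T + b_v)

    # Scaled dot-product attention for all heads at once
    scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(head_dim)
    heads = torch.bmm(torch.softmax(scores, dim=-1), v)

    # Concatenate the heads and apply the output projection
    concat = heads.transpose(0, 1).contiguous().view(seq_len, d_model)
    manual = ref.out_proj(concat)

    # Self-attention: the same tensor is used as query, key, and value
    expected, _ = ref(x, x, x, need_weights=False)

print(torch.allclose(manual, expected, atol=1e-5))
```

The same comparison can be pointed at the `MultiHeadAttention` class above by copying these weight slices into its query, key, value, and output projection layers.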
