## What is multi-head attention

Multi-head attention showed as below image

<img src="./assert/multi-head-attentions.png" width="400" height="400" alt="MHA">


## Explain

Multi-head attention is pretty similar to self-attention.

On top of 1D query, key, value, now we have a n-d query, key, value.
After attention process, we simply concat them and flat into a single array for output.

## Why?

Single-head attention only apply for "one kind of understanding", for example, single head will understand "1 + 1" as **plain text**.

While multi-head attention could read "1 + 1" as both **plain text** and **formular**.

The more head we have, the more rich context we could extract from the input data.

In [1]:
import torch
import torch.nn as nn
import math


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        
        self.w_query = nn.Linear(d_model, d_model)
        self.w_key = nn.Linear(d_model, d_model)
        self.w_value = nn.Linear(d_model, d_model)
        self.attention_scores = nn.Linear(d_model, d_model)
        
    def forward(self, x):
       seq_len = x.shape[0]
       query = self.w_query(x)
       key = self.w_key(x)
       value = self.w_value(x)
       
       # Reshape to (seq_len, n_heads, head_dim), then cast to (n_heads, seq_len, head_dim)
       query = query.view(seq_len, self.n_heads, self.head_dim).transpose(0, 1)
       key = key.view(seq_len, self.n_heads, self.head_dim).transpose(0, 1)
       value = value.view(seq_len, self.n_heads, self.head_dim).transpose(0, 1)
       
       # score: (n_heads, seq_len, seq_len)
       score = torch.bmm(query, key.transpose(1, 2)) / math.sqrt(self.head_dim)
       
       # attention: (n_heads, seq_len, head_dim)
       attention = torch.bmm(torch.softmax(score, dim=-1), value)
       
       attention =attention.transpose(0, 1).contiguous().view(seq_len, self.d_model)
       return self.attention_scores(attention)
    

ModuleNotFoundError: No module named 'torch'

In [None]:
torch.manual_seed(42)
embed_dim = 32
n_heads = 4
seq_len = 8

x = torch.randn(seq_len, embed_dim)
target = torch.randn(seq_len, embed_dim)

model = MultiHeadAttention(embed_dim, n_heads)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

model.train()
for epoch in range(100):
    optimizer.zero_grad()
    output = model(x)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

Epoch 0, Loss: 0.9528
Epoch 10, Loss: 0.8633
Epoch 20, Loss: 0.7874
Epoch 30, Loss: 0.6941
Epoch 40, Loss: 0.5665
Epoch 50, Loss: 0.4330
Epoch 60, Loss: 0.3291
Epoch 70, Loss: 0.2463
Epoch 80, Loss: 0.1821
Epoch 90, Loss: 0.1270


In [8]:
# Add musk function to multi-head attention

import torch
import torch.nn as nn
import math

class MaskedMultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        self.w_query = nn.Linear(d_model, d_model)
        self.w_key = nn.Linear(d_model, d_model)
        self.w_value = nn.Linear(d_model, d_model)
        self.attention_scores = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        seq_len = query.shape[0]
        # from input, the query, key, value will be simply input matrix, input, input, input.
        query = self.w_query(query)
        key = self.w_key(key)
        value = self.w_value(value)

        query = query.view(seq_len, self.n_heads, self.head_dim).transpose(0, 1)
        key = key.view(seq_len, self.n_heads, self.head_dim).transpose(0, 1)
        value = value.view(seq_len, self.n_heads, self.head_dim).transpose(0, 1)

        attention_scores = torch.bmm(query, key.transpose(1, 2)) / math.sqrt(self.head_dim)
        
        if mask is not None:
            # The computation is pretty similar to the previous one:
            # we have seq_len of query ( sub matrix ), after processing query * key.T, we get
            # seq_len, (n_heads, n_heads), where n_head is the size of original attention matrix
            
            # Note:
            # Post mask, we don't want to have 0 as masked value,
            # because, softmax(0) = 1, which will make the attention score too high.
            # So, we use -inf to mask the value. The normalization will still remains.
            attention_scores.masked_fill(mask == 0, float("-inf"))

        attention = torch.matmul(torch.softmax(attention_scores, dim=-1), value)

        attention = attention.transpose(0, 1).contiguous().view(seq_len, self.d_model)
        return self.attention_scores(attention)
            
        
        

In [None]:
torch.manual_seed(42)
embed_dim = 32
n_heads = 4
seq_len = 8

x = torch.randn(seq_len, embed_dim)
target = torch.randn(seq_len, embed_dim)

# mask: (seq_len, seq_len)
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)

# model: (seq_len, embed_dim) -> (seq_len, embed_dim)
model = MaskedMultiHeadAttention(embed_dim, n_heads)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

model.train()
for epoch in range(100):
    optimizer.zero_grad()
    output = model(x, x, x, mask)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

Epoch 0, Loss: 0.9432
Epoch 10, Loss: 0.8329
Epoch 20, Loss: 0.7072
Epoch 30, Loss: 0.5641
Epoch 40, Loss: 0.4373
Epoch 50, Loss: 0.3366
Epoch 60, Loss: 0.2433
Epoch 70, Loss: 0.1577
Epoch 80, Loss: 0.0930
Epoch 90, Loss: 0.0512


In [18]:
import torch
import torch.nn as nn
import math

# In this example, we want to support batch processing.
# Means, the input will be now (batch_size, seq_size, d_model)

# Note:
# Batch will be as same as original output.
# The reason we need batch is because we want to support parallel processing.
# Consider GPU can handle larger memory, we can process more data in parallel.

class BatchedMaskedMultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        
        self.w_query = nn.Linear(d_model, d_model)
        self.w_key = nn.Linear(d_model, d_model)
        self.w_value = nn.Linear(d_model, d_model)

        self.attention_scores = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        batch_size = query.shape[0]
        seq_size = query.shape[1]
        
        # size for query, key, value: (batch_size, seq_size, d_model)
        query = self.w_query(query)
        key = self.w_key(key)
        value = self.w_value(value)
       
        # cast size to (batch_size, n_heads, seq_size, head_dim)
        query = query.view(batch_size, seq_size, self.n_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, seq_size, self.n_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, seq_size, self.n_heads, self.head_dim).transpose(1, 2)
        
        # the attention score will be computed from (seq_size, head_dim) * (seq_size, head_dim).T
        # attention_scores: (batch_size, n_heads, seq_size, seq_size)
        attention_scores = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(self.head_dim)

        if mask is not None:
            # following previous steps to calcuation masked value.
            attention_scores.masked_fill(mask== 0, float("-inf"))

        # matmul with value for final score.
        # attention_scores: (batch_size, n_heads, seq_size, seq_size) * (batch_size, n_heads, seq_size, head_dim) -> (batch_size, n_heads, seq_size, head_dim)
        attention_scores = torch.matmul(torch.softmax(attention_scores, dim=-1), value)
    
        # cast size to original format: (batch_size, n_heads, seq_size, head_dim) -> (batch_size, seq_size, n_heads, head_dim) -> (batch_size, seq_size, d_model)
        # Note:
        # 1. transpose(1, 2) is to move the n_heads dimension to the last dimension.
        # 2. contiguous() is to make the tensor contiguous in memory.
        # 3. view(batch_size, seq_size, self.d_model) is to reshape the tensor to the original size.
        attention_output = attention_scores.transpose(1, 2).contiguous().view(batch_size, seq_size, self.d_model)

        return self.attention_scores(attention_output)




        

In [19]:
torch.manual_seed(42)
batch_size = 4
embed_dim = 32
n_heads = 4
seq_len = 8

x = torch.randn(batch_size, seq_len, embed_dim)
target = torch.randn(batch_size, seq_len, embed_dim)

# mask: (batch_size, seq_len, seq_len)
mask = torch.triu(torch.ones(batch_size, seq_len, seq_len), diagonal=1)

# model: (batch_size, seq_len, embed_dim) -> (batch_size, seq_len, embed_dim)
model = BatchedMaskedMultiHeadAttention(embed_dim, n_heads)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

model.train()
for epoch in range(100):
    optimizer.zero_grad()
    output = model(x, x, x, mask)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

Epoch 0, Loss: 0.9783
Epoch 10, Loss: 0.9198
Epoch 20, Loss: 0.8733
Epoch 30, Loss: 0.8275
Epoch 40, Loss: 0.7754
Epoch 50, Loss: 0.7139
Epoch 60, Loss: 0.6447
Epoch 70, Loss: 0.5727
Epoch 80, Loss: 0.5028
Epoch 90, Loss: 0.4369


## Thinking

1. Since we are slicing by **head_num**, why not just random pick like dropout?

Well, actually in middle of MHA, order contains meaning, it called "position_embedding", without adding more training cycles to trace the order, we pretty much randomize the parameters again.

2. Also, check size carefully, since the computation of matrix will always apply on (A, B) @ (B, A), we need mention in matrix computation process, which two dimention we will need to compute.