# Multi-Head Attention (MHA)

As the name suggests, multi-head attention runs multiple attention mechanisms ("heads") in parallel, each free to attend to different information in the sequence.

In [1]:
num_heads = 8
d_model = 64
# width of each head's Q/K/V slice: d_model is shared equally among the heads
d_k = d_model // num_heads

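MHA:
1. Split Q, K, and V into multiple heads.
2. Apply scaled dot-product attention, $\mathrm{softmax}\left(QK^\top / \sqrt{d_k}\right)V$, to each head independently.
3. Concatenate the outputs of all heads.
4. Apply the output projection.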

In [2]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
import numpy as np

In [6]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention.

    Projects the input into queries, keys and values, splits each projection into
    ``num_heads`` heads of width ``d_k = d_model // num_heads``, runs scaled
    dot-product attention independently per head, concatenates the heads back to
    ``d_model`` and applies a final output projection.

    Args:
        d_model: model (embedding) width; must be divisible by ``num_heads``.
        num_heads: number of parallel attention heads.

    Raises:
        AssertionError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, (
            f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
        )

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Q/K/V projections: a single d_model -> d_model map each; the result is
        # later split into num_heads slices of width d_k.
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

        # output projection: mixes information across the concatenated heads
        self.W_o = nn.Linear(d_model, d_model)

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape [batch_size, seq_len, d_model] -> [batch_size, num_heads, seq_len, d_k]."""
        return t.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, x, mask=None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: tensor of shape [batch_size, seq_len, d_model].
            mask: optional tensor broadcastable to
                [batch_size, num_heads, seq_len, seq_len] (queries x keys).
                Entries that are False / 0 are excluded from attention. Examples:
                ``torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))`` for
                causal attention, or a key-padding mask of shape
                [batch_size, 1, 1, seq_len]. Every query row must keep at least
                one key; a fully masked row softmaxes over all ``-inf`` and
                yields NaN.

        Returns:
            Tensor of shape [batch_size, seq_len, d_model].
        """
        batch_size, seq_len, _ = x.shape

        # Splitting d_model into num_heads * d_k gives several parallel "views" of
        # the sequence, each with its own attention pattern (its own softmax).
        # Without the split there would be one giant head whose dot products span
        # all d_model dimensions.
        Q = self._split_heads(self.W_q(x), batch_size, seq_len)
        K = self._split_heads(self.W_k(x), batch_size, seq_len)
        V = self._split_heads(self.W_v(x), batch_size, seq_len)

        # [B, H, S, d_k] @ [B, H, d_k, S] -> [B, H, S, S]
        # Scaled by sqrt(d_k), the per-head width rather than d_model, so the
        # dot-product magnitudes stay moderate and softmax does not saturate.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            # -inf becomes exactly 0 weight after softmax
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attention_weights = F.softmax(scores, dim=-1)
        attn_outputs = torch.matmul(attention_weights, V)  # [B, H, S, d_k]

        # concatenate heads: [B, H, S, d_k] -> [B, S, H, d_k] -> [B, S, d_model]
        attn_outputs = (
            attn_outputs.transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_len, self.d_model)
        )

        return self.W_o(attn_outputs)

In [7]:
# Smoke test: push a random batch through MHA and confirm its shape is preserved.
torch.manual_seed(0)  # reproducible weights and input across Restart & Run All

d_model = 64
num_heads = 8
seq_len = 10
batch_size = 2

mha = MultiHeadAttention(d_model, num_heads)
x = torch.randn(batch_size, seq_len, d_model)

output = mha(x)

# MHA maps [batch_size, seq_len, d_model] -> [batch_size, seq_len, d_model]
assert output.shape == x.shape, f"expected {tuple(x.shape)}, got {tuple(output.shape)}"

print(f'Input shape : {x.shape}')
print(f'Output shape : {output.shape}')
print(f'no of heads : {num_heads}')

Input shape : torch.Size([2, 10, 64])
Output shape : torch.Size([2, 10, 64])
no of heads : 8
