# Multi-Head Attention Explained Simply
Multi-head attention is a key concept in transformers: it lets the model focus on different parts of the input sequence simultaneously. Here is a simple breakdown:

## Attention Mechanism

Think of it as a way for the model to focus on specific words or tokens in a sequence. When processing a word, the model can "attend" to other important words in the sentence to understand context better.
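To make "attending" concrete, here is a minimal sketch of a single attention step. It assumes the scaled dot-product form commonly used in transformers, which the text above does not spell out; the function name `scaled_dot_product_attention` and the tensor sizes are illustrative choices, not part of the original notebook.

```python
import math
import torch

def scaled_dot_product_attention(q, k, v):
    """q, k, v: (seq_len, d). Returns the attended output and the weights."""
    d = q.size(-1)
    # Similarity of every token with every other token, scaled by sqrt(d).
    scores = q @ k.transpose(-2, -1) / math.sqrt(d)
    # Softmax turns each row of scores into weights that sum to 1.
    weights = torch.softmax(scores, dim=-1)
    # Each output row is a weighted average of the value vectors.
    return weights @ v, weights

torch.manual_seed(0)
x = torch.rand(4, 8)  # 4 tokens with 8 features each (illustrative sizes)
out, weights = scaled_dot_product_attention(x, x, x)
print(out.shape)      # torch.Size([4, 8])
print(weights.shape)  # torch.Size([4, 4]), one row of weights per token
```

Row `i` of `weights` says how much token `i` draws on every other token when building its new representation.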

## Multiple Heads

Instead of using just one focus point (attention), the model splits its focus into multiple parts, called "heads." Each head looks at the input sequence in a slightly different way, capturing various aspects of the context.
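One common way to create the heads is to slice the feature dimension into equal chunks, so that each head works on its own slice of every token. The sketch below assumes that layout and uses the same sizes as the example later in this notebook; the variable names are illustrative.

```python
import torch

batch_size, seq_len, d_model, num_heads = 2, 10, 512, 8
d_head = d_model // num_heads  # 64 features per head

x = torch.rand(batch_size, seq_len, d_model)

# (batch, seq_len, d_model) -> (batch, seq_len, num_heads, d_head)
#                           -> (batch, num_heads, seq_len, d_head)
heads = x.view(batch_size, seq_len, num_heads, d_head).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 8, 10, 64])

# Head 1 holds features 64..127 of every token.
print(torch.equal(heads[:, 1], x[:, :, 64:128]))  # True
```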

## Parallel Processing

Each head processes the sequence independently, finding different patterns or relationships between the words.
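Because the heads do not interact, attention for all of them can be computed with one batched matrix multiplication instead of a loop. The following sketch, which again assumes scaled dot-product attention and uses random stand-in tensors, checks that the batched result matches a head-by-head loop.

```python
import math
import torch

torch.manual_seed(0)
# Stand-in per-head tensors: (batch, num_heads, seq_len, d_head)
Q = torch.rand(2, 8, 10, 64)
K = torch.rand(2, 8, 10, 64)
V = torch.rand(2, 8, 10, 64)
scale = math.sqrt(Q.size(-1))

# All heads at once: matmul broadcasts over the batch and head dimensions.
scores = Q @ K.transpose(-2, -1) / scale      # (2, 8, 10, 10)
context = torch.softmax(scores, dim=-1) @ V   # (2, 8, 10, 64)

# The same computation, one head at a time.
per_head = []
for h in range(Q.size(1)):
    s = Q[:, h] @ K[:, h].transpose(-2, -1) / scale
    per_head.append(torch.softmax(s, dim=-1) @ V[:, h])
looped = torch.stack(per_head, dim=1)

print(context.shape)
print(torch.allclose(context, looped, atol=1e-5))
```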

## Combining Results

After processing, the outputs from all heads are concatenated and passed through a final linear projection, giving a richer representation of the input sequence.
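As a rough illustration of the combining step, the sketch below concatenates per-head outputs back into one vector per token and applies a linear layer. The `context` tensor is random stand-in data, and `W_O` here is a freshly created layer used only for this example.

```python
import torch
import torch.nn as nn

batch_size, num_heads, seq_len, d_head = 2, 8, 10, 64
d_model = num_heads * d_head  # 512

# Stand-in for the per-head attention outputs.
context = torch.rand(batch_size, num_heads, seq_len, d_head)

# (batch, heads, seq, d_head) -> (batch, seq, heads, d_head) -> (batch, seq, d_model)
# .contiguous() is needed because transpose returns a non-contiguous view.
merged = context.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)

# Final linear layer mixes information across heads.
W_O = nn.Linear(d_model, d_model)
output = W_O(merged)

print(merged.shape)  # torch.Size([2, 10, 512])
print(output.shape)  # torch.Size([2, 10, 512])
```

Concatenation alone only places the heads side by side; the linear layer is what lets features from different heads influence each other.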

In [5]:
import torch.nn as nn
import torch

In [8]:
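class MultiHeadAttention(nn.Module):
    """Projects the input to queries, keys, and values and splits them into heads."""

    def __init__(self, d_model, num_heads):
        super().__init__()
        # Every head gets an equal slice of the model dimension.
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.W_Q = nn.Linear(d_model, d_model)  # Query projection
        self.W_K = nn.Linear(d_model, d_model)  # Key projection
        self.W_V = nn.Linear(d_model, d_model)  # Value projection
        self.W_O = nn.Linear(d_model, d_model)  # Output projection (defined, not yet used in forward)

    def forward(self, x):
        # x: (batch_size, seq_len, d_model)
        batch_size = x.size(0)
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)
        # Split the last dimension into heads:
        # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_head)
        Q = Q.view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)
        # This version stops after the split and returns the per-head tensors.
        return Q, K, V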
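The class above stops after splitting into heads, so `W_O` is never applied. As a sketch of how the remaining steps might be added, the hypothetical `FullMultiHeadAttention` below computes scaled dot-product attention per head, merges the heads, and applies the output projection. The class name, the absence of masking and dropout, and the choice of scaled dot-product scoring are all assumptions made for illustration.

```python
import math
import torch
import torch.nn as nn

class FullMultiHeadAttention(nn.Module):
    """Hypothetical completion: split, attend, merge, project."""

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

    def _split(self, t):
        # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_head)
        b, s, _ = t.shape
        return t.view(b, s, self.num_heads, self.d_head).transpose(1, 2)

    def forward(self, x):
        b, s, d_model = x.shape
        Q = self._split(self.W_Q(x))
        K = self._split(self.W_K(x))
        V = self._split(self.W_V(x))
        # Attention weights for every head: (batch, num_heads, seq_len, seq_len)
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_head)
        weights = torch.softmax(scores, dim=-1)
        context = weights @ V  # (batch, num_heads, seq_len, d_head)
        # Merge heads back to (batch, seq_len, d_model), then project.
        merged = context.transpose(1, 2).contiguous().view(b, s, d_model)
        return self.W_O(merged)

mha = FullMultiHeadAttention(d_model=512, num_heads=8)
out = mha(torch.rand(2, 10, 512))
print(out.shape)  # torch.Size([2, 10, 512])
```

The output has the same shape as the input, which is what allows attention layers to be stacked.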

In [9]:
batch_size = 2
seq_len = 10
d_model = 512
num_heads = 8
input_tensor = torch.rand(batch_size, seq_len, d_model)
print("Input Tensor",input_tensor.size())
multi_head_attention = MultiHeadAttention(d_model, num_heads)
Q, K, V = multi_head_attention(input_tensor)

# Each tensor has shape (batch_size, num_heads, seq_len, d_head).
print("Query shape:", Q.shape)
print("Key shape:", K.shape)
print("Value shape:", V.shape)

Input Tensor torch.Size([2, 10, 512])
Query shape: torch.Size([2, 8, 10, 64])
Key shape: torch.Size([2, 8, 10, 64])
Value shape: torch.Size([2, 8, 10, 64])
