# Multi-Head Attention Explained Simply
Multi-head attention is a core building block of transformers. It lets the model attend to different parts of the input sequence at the same time. Here is a simple breakdown:

1. Attention mechanism: a way for the model to focus on specific words (tokens) in a sequence. When processing one word, the model can "attend" to the other words in the sentence that matter most for understanding its context.
2. Multiple heads: instead of computing a single attention pattern, the model splits its representation into several parts, called "heads." Each head looks at the input sequence in a slightly different way and captures a different aspect of the context.
3. Parallel processing: each head runs attention independently on its own slice of size `d_model / num_heads`, finding different patterns or relationships between the words.
4. Combining results: the outputs of all heads are concatenated and passed through a final linear layer, giving a richer representation of the input sequence.
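To make the breakdown above concrete, here is a minimal pure-Python sketch for a single query token. It is an illustration under simplifying assumptions, not the implementation used later in this notebook: it leaves out the learned linear projections, and the helper names (`softmax`, `attend`, `multi_head_attend`) and the toy numbers are made up for this example.

```python
import math


def softmax(scores):
    """Turn raw scores into positive weights that sum to 1."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query, keys, values):
    """Scaled dot-product attention for one query vector."""
    d_k = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d_k) for key in keys]
    weights = softmax(scores)
    dim = len(values[0])
    # Output is the weighted average of the value vectors
    output = [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]
    return output, weights


def multi_head_attend(query, keys, values, num_heads):
    """Split every vector into num_heads slices, attend per slice, concatenate."""
    d_head = len(query) // num_heads
    combined, per_head_weights = [], []
    for h in range(num_heads):
        sl = slice(h * d_head, (h + 1) * d_head)
        out, weights = attend(query[sl], [k[sl] for k in keys], [v[sl] for v in values])
        combined.extend(out)            # concatenation of head outputs
        per_head_weights.append(weights)
    return combined, per_head_weights


# Hypothetical toy data: three tokens, 4-dimensional vectors, two heads
keys = [[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0]]
values = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]]
query = [1.0, 0.0, 0.0, 1.0]

combined, per_head_weights = multi_head_attend(query, keys, values, num_heads=2)
for h, weights in enumerate(per_head_weights):
    print("head", h, "weights", [round(w, 3) for w in weights])
print("combined output", [round(x, 3) for x in combined])
```

The two heads see different slices of the same vectors, so they end up with different attention weights over the three tokens. The combined output has the same length as the input vectors.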

In [1]:
import torch.nn as nn
import torch
import math
import torch.nn.functional as F

In [6]:
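def scaled_dot_product(q, k, v, mask=None):
    # q, k, v: (batch, num_heads, seq_len, d_head)
    d_k = q.size(-1)
    # Similarity of every query with every key, scaled by sqrt(d_k)
    scaled = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Additive mask: 0 keeps a position, -inf blocks it
        scaled = scaled + mask
    attention = F.softmax(scaled, dim=-1)  # each row sums to 1
    values = torch.matmul(attention, v)    # weighted sum of the value vectors
    return values, attention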


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, batch_size, num_heads, sequence_length):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        # Stored for reference only; forward() reads the actual sizes from its input
        self.sequence_length = sequence_length
        self.batch_size = batch_size
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)  # joint projection for Q, K and V
        self.linear = nn.Linear(d_model, d_model)   # output projection after merging heads

    def forward(self, x):
        batch_size, sequence_length, _ = x.size()
        qkv = self.qkv(x)
        qkv = qkv.reshape(batch_size, sequence_length, self.num_heads, 3 * self.d_head)
        print("Dimension of Query key and value", qkv.size())
        qkv = qkv.permute(0, 2, 1, 3)  # (batch, heads, seq, 3 * d_head)
        q, k, v = qkv.chunk(3, dim=-1)
        print("Afer breaking it ", q.size(), "k", k.size(), 'v', v.size())
        values, attention = scaled_dot_product(q, k, v)
        print("After Scaled dot Product", values.size())
        # Bring the layout back to (batch, seq, heads, d_head) before merging heads.
        # Reshaping (batch, heads, seq, d_head) directly would mix different positions.
        values = values.permute(0, 2, 1, 3)
        values = values.reshape(batch_size, sequence_length, self.num_heads * self.d_head)
        return self.linear(values)
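
The `scaled_dot_product` function accepts a `mask` argument, but the class above never passes one. As a minimal sketch of how it could be used, the block below builds an additive causal mask (each token may attend only to itself and earlier tokens) and applies it to random tensors with the shapes used in this notebook. `causal_mask` is a hypothetical helper that is not part of the original code, and `scaled_dot_product` is repeated only so the block runs on its own.

```python
import math

import torch
import torch.nn.functional as F


def scaled_dot_product(q, k, v, mask=None):
    # Same computation as the function defined above
    d_k = q.size(-1)
    scaled = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scaled = scaled + mask
    attention = F.softmax(scaled, dim=-1)
    return torch.matmul(attention, v), attention


def causal_mask(seq_len):
    # Hypothetical helper: 0 on and below the diagonal, -inf above it
    mask = torch.full((seq_len, seq_len), float("-inf"))
    return torch.triu(mask, diagonal=1)


torch.manual_seed(0)
batch, heads, seq_len, d_head = 2, 8, 10, 64
q = torch.rand(batch, heads, seq_len, d_head)
k = torch.rand(batch, heads, seq_len, d_head)
v = torch.rand(batch, heads, seq_len, d_head)

mask = causal_mask(seq_len)  # (seq_len, seq_len), broadcast over batch and heads
values, attention = scaled_dot_product(q, k, v, mask)

print(values.shape)         # torch.Size([2, 8, 10, 64])
print(attention[0, 0, 0])   # first token: all weight on itself, zeros elsewhere
```

Because the mask is added before the softmax, blocked positions get a weight of exactly zero while every row of `attention` still sums to 1.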

In [7]:
batch_size = 2
seq_len = 10
d_model = 512
num_heads = 8
input_tensor = torch.rand(batch_size, seq_len, d_model)
print("Input Tensor",input_tensor.size())
multi_head_attention = MultiHeadAttention(d_model, batch_size,num_heads,seq_len)
values = multi_head_attention(input_tensor)
print("Shape :::",values.size())
print(values)

Input Tensor torch.Size([2, 10, 512])
Dimension of Query key and value torch.Size([2, 10, 8, 192])
Afer breaking it  torch.Size([2, 8, 10, 64]) k torch.Size([2, 8, 10, 64]) v torch.Size([2, 8, 10, 64])
After Scaled dot Product torch.Size([2, 8, 10, 64])
Shape ::: torch.Size([2, 10, 512])
tensor([[[ 0.0266,  0.0013, -0.2697,  ...,  0.0700, -0.1862, -0.3483],
         [ 0.1292,  0.2899, -0.0725,  ..., -0.1331, -0.0367, -0.5440],
         [ 0.0320,  0.1188, -0.1892,  ...,  0.0804, -0.3294,  0.1453],
         ...,
         [ 0.1085,  0.0963,  0.1025,  ...,  0.3230,  0.3623, -0.1558],
         [ 0.2542,  0.0635,  0.3222,  ...,  0.1230,  0.0415, -0.2237],
         [-0.1716,  0.0673,  0.1148,  ...,  0.1442, -0.0649, -0.0435]],

        [[ 0.0598,  0.0235, -0.2230,  ...,  0.0520, -0.1016, -0.3768],
         [ 0.0987,  0.3513, -0.0627,  ..., -0.0927,  0.0324, -0.5550],
         [ 0.0592,  0.1090, -0.1552,  ...,  0.0421, -0.2721,  0.1636],
         ...,
         [ 0.1132,  0.0910,  0.0424,  ...,