#### Why is single-head self-attention insufficient?

The main limitation is that a single attention "head" computes just one attention distribution per token, so it can only capture one type of relationship or feature pattern at a time. Real-world data such as language involves many interacting relationships, including:

- Syntax (e.g., subject-verb agreement)
- Semantic relationships (e.g., co-reference)
- Long-range dependencies (e.g., linking beginning and end of a paragraph)
- Different positional patterns

Trying to encode all of these diverse patterns with a single attention head is restrictive and inefficient.

`The cat sat on the mat because it was tired`

1. One attention head might want to focus on resolving pronouns ("it" → "cat").
2. Another head might want to focus on positional relations (which word is next to which).
3. Another might focus on important keywords ("tired", "mat").

With only one head, the model tries to mix all these focuses into a single attention pattern, which can dilute the quality and richness of learned features.

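**How does multi-head attention solve this?**
Instead of having one set of Q, K, V projections, multi-head attention splits the model's capacity across several heads. Each head learns its own projections into a smaller subspace of size `head_dim = d_model / n_heads`, so together the projections use the same number of parameters as one full-width head: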

- Each head attends to different parts or aspects of the input independently.
- The outputs of all heads are concatenated and linearly transformed, allowing the model to jointly attend to information from different representation subspaces.

This way, each attention head is free to specialize in a specific aspect (e.g., syntax vs. semantics), which reduces the interference between patterns and tends to improve overall performance.
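To make "each head learns different projections" concrete: the per-head projections are usually fused into one `d_model × d_model` linear layer, which is also how this notebook implements them. The minimal sketch below uses toy sizes matching this notebook (an assumption for illustration) and checks that splitting the fused layer's output into heads gives the same result as giving head `h` its own slice of the weight rows.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
b, seq, d_model, n_heads = 2, 4, 9, 3
head_dim = d_model // n_heads

x = torch.randn(b, seq, d_model)
wq = nn.Linear(d_model, d_model)  # one fused projection shared by all heads

# Fused path: project once, then split the last dim into (n_heads, head_dim)
q_fused = wq(x).view(b, seq, n_heads, head_dim)

# Per-head path: head h uses rows h*head_dim:(h+1)*head_dim of the weight and bias
q_heads = []
for h in range(n_heads):
    rows = slice(h * head_dim, (h + 1) * head_dim)
    q_h = x @ wq.weight[rows].T + wq.bias[rows]  # (b, seq, head_dim)
    q_heads.append(q_h)
q_separate = torch.stack(q_heads, dim=2)  # (b, seq, n_heads, head_dim)

print(torch.allclose(q_fused, q_separate, atol=1e-5))  # True
```

In other words, a fused layer is just `n_heads` independent projections stored side by side, which is cheaper to run as one matrix multiplication.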


![MH](media/mh.png)

In [26]:
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)

![MH](media/mhqkv.png)

In [27]:
input = torch.randn(2, 4, 9)  # toy input of shape (batch, seq, d_model)
b, seq, d_model = input.size()
n_heads = 3

wq = nn.Linear(d_model, d_model)
wk = nn.Linear(d_model, d_model)
wv = nn.Linear(d_model, d_model)

dropout = nn.Dropout(0.4)

# project the input into queries, keys and values: each (b, seq, d_model)
Q = wq(input)
K = wk(input)
V = wv(input)

print(f"Q:{Q.size()}, K: {K.size()}, V:{V.size()}")

Q:torch.Size([2, 4, 9]), K: torch.Size([2, 4, 9]), V:torch.Size([2, 4, 9])


In [28]:
# make sure d_model splits evenly across the heads
assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

head_dim = d_model // n_heads

# split the last dimension into (n_heads, head_dim)
Q = Q.view(b, seq, n_heads, head_dim)
K = K.view(b, seq, n_heads, head_dim)
V = V.view(b, seq, n_heads, head_dim)

print(f"Q:{Q.size()}, K: {K.size()}, V:{V.size()}")

Q:torch.Size([2, 4, 3, 3]), K: torch.Size([2, 4, 3, 3]), V:torch.Size([2, 4, 3, 3])
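`.view` does not reorder any data here: it just reinterprets the last dimension as `(n_heads, head_dim)`, so each head receives a contiguous chunk of the projected features. A small illustrative sketch with made-up feature values:

```python
import torch

b, seq, d_model, n_heads = 2, 4, 9, 3
head_dim = d_model // n_heads

# Give every token the feature values 0..8 so it is easy to see which features land in which head
Q = torch.arange(d_model, dtype=torch.float32).repeat(b, seq, 1)  # (b, seq, d_model)
Q_split = Q.view(b, seq, n_heads, head_dim)                        # (b, seq, n_heads, head_dim)

print(Q_split[0, 0])
# tensor([[0., 1., 2.],
#         [3., 4., 5.],
#         [6., 7., 8.]])

# Head h always receives the contiguous feature slice [h*head_dim, (h+1)*head_dim)
for h in range(n_heads):
    assert torch.equal(Q_split[..., h, :], Q[..., h * head_dim:(h + 1) * head_dim])
```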


**Let's transpose so that the head dimension comes before the sequence dimension. Each head then holds its own (seq, head_dim) matrix, so the following matrix multiplications run per head over all tokens in the sequence.**


In [29]:
Q = Q.transpose(1,2) # (b, n_heads, seq, head_dim)
K = K.transpose(1,2) # (b, n_heads, seq, head_dim)
V = V.transpose(1,2) # (b, n_heads, seq, head_dim)

print(f"Q:{Q.size()}, K: {K.size()}, V:{V.size()}\nQ data:\n{Q}")

Q:torch.Size([2, 3, 4, 3]), K: torch.Size([2, 3, 4, 3]), V:torch.Size([2, 3, 4, 3])
Q data:
tensor([[[[-0.1936, -0.1699,  0.1577],
          [-1.3494,  0.1358, -0.4783],
          [-0.0148, -0.5717,  0.0037],
          [ 0.2234, -0.3064, -0.5063]],

         [[ 0.1638, -0.0595,  0.2435],
          [ 0.5134,  0.2944,  0.9202],
          [-0.1530,  0.1878, -0.7437],
          [-0.3977, -0.7100, -0.6491]],

         [[-0.0672, -1.8729,  1.1167],
          [-0.0991, -0.0460,  1.0454],
          [ 0.4609,  0.6065, -0.0997],
          [-0.0332, -1.4115,  0.9683]]],


        [[[-0.3469,  0.2196, -0.0338],
          [-0.4134, -0.3903, -0.3228],
          [ 0.0970, -0.2078,  0.5537],
          [ 0.2554, -1.3217, -0.7302]],

         [[ 0.8727, -0.3630,  1.0433],
          [-0.0310, -0.4599,  0.5259],
          [ 0.0860,  0.0208, -0.6648],
          [-0.4612, -1.1273, -0.9428]],

         [[-0.2181,  0.0292,  0.0248],
          [ 0.9741, -0.6951,  0.6917],
          [ 0.2260, -0.2608, -0.5675],
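Note that `transpose` does not copy anything; it returns a view with swapped strides. The following sketch (random data, same shapes as above) shows that the transposed tensor shares memory with the original, is no longer contiguous, and that `Qt[:, h]` is exactly head `h`'s `(seq, head_dim)` matrix over all tokens. The non-contiguity matters later, when the heads are merged back with `.view`.

```python
import torch

torch.manual_seed(0)
b, seq, n_heads, head_dim = 2, 4, 3, 3
Q = torch.randn(b, seq, n_heads, head_dim)

Qt = Q.transpose(1, 2)  # (b, n_heads, seq, head_dim)

# transpose() only swaps strides; it returns a view that shares Q's storage
print(Qt.data_ptr() == Q.data_ptr())  # True
print(Qt.is_contiguous())             # False

# Slicing head h from the transposed tensor gives all seq tokens for that head,
# i.e. the (seq, head_dim) matrix that head h attends over
for h in range(n_heads):
    assert torch.equal(Qt[:, h], Q[:, :, h])
```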

In [30]:
# causal mask: True above the diagonal marks future positions that must be hidden
mask = torch.triu(torch.ones((seq, seq), dtype=torch.bool), diagonal=1) # (seq, seq)
# expanding mask across batch and heads dimensions
mask = mask.unsqueeze(0).unsqueeze(0).expand(b, n_heads, -1, -1) # (b,n_heads, seq, seq )
print(f"S:{mask.shape}\n{mask}")

S:torch.Size([2, 3, 4, 4])
tensor([[[[False,  True,  True,  True],
          [False, False,  True,  True],
          [False, False, False,  True],
          [False, False, False, False]],

         [[False,  True,  True,  True],
          [False, False,  True,  True],
          [False, False, False,  True],
          [False, False, False, False]],

         [[False,  True,  True,  True],
          [False, False,  True,  True],
          [False, False, False,  True],
          [False, False, False, False]]],


        [[[False,  True,  True,  True],
          [False, False,  True,  True],
          [False, False, False,  True],
          [False, False, False, False]],

         [[False,  True,  True,  True],
          [False, False,  True,  True],
          [False, False, False,  True],
          [False, False, False, False]],

         [[False,  True,  True,  True],
          [False, False,  True,  True],
          [False, False, False,  True],
          [False, False, False, False]]]]
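Expanding the mask to `(b, n_heads, seq, seq)` is optional: `masked_fill` broadcasts the mask against the tensor it fills, so a plain `(seq, seq)` mask works too. A quick sketch checking that both give the same result on random scores:

```python
import torch

torch.manual_seed(0)
b, n_heads, seq = 2, 3, 4
scores = torch.randn(b, n_heads, seq, seq)

mask = torch.triu(torch.ones(seq, seq, dtype=torch.bool), diagonal=1)  # (seq, seq)

# masked_fill broadcasts the (seq, seq) mask over the batch and head dimensions
masked_broadcast = scores.masked_fill(mask, float("-inf"))
masked_expanded = scores.masked_fill(mask.expand(b, n_heads, seq, seq), float("-inf"))

print(torch.equal(masked_broadcast, masked_expanded))  # True
```

Either way no memory is duplicated, since `expand` also returns a view; the explicit expansion mainly documents the intended shape.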

In [31]:
# compute raw attention scores per head: (b, n_heads, seq, head_dim) @ (b, n_heads, head_dim, seq) -> (b, n_heads, seq, seq)
mh_attention_score = Q @ K.transpose(2,3)
print(f"S:{mh_attention_score.shape}\n{mh_attention_score[1]}")

S:torch.Size([2, 3, 4, 4])
tensor([[[ 0.2569, -0.2507,  0.0530,  0.0344],
         [ 0.0221, -0.5462, -0.1012,  0.1261],
         [-0.2866, -0.0408, -0.2315, -0.1548],
         [-0.6317, -0.2327, -0.2567,  0.1943]],

        [[ 0.4582,  0.8343,  1.8337, -0.7559],
         [ 0.0463,  0.6201,  0.7675, -0.1347],
         [ 0.1563, -0.1418, -0.5324,  0.1354],
         [ 0.2650,  0.9619, -0.3649,  0.4824]],

        [[-0.0670, -0.0365, -0.0631, -0.4189],
         [ 0.2537,  0.5420,  0.5959,  1.8983],
         [ 0.2429, -0.4177, -0.3470,  0.4181],
         [ 0.2617,  0.0530,  0.0512,  0.6890]]], grad_fn=<SelectBackward0>)
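Because `@` treats the leading `(b, n_heads)` dimensions as batch dimensions, each head only compares its own queries and keys; heads never mix at this stage. A sketch on random tensors confirming that the matmul matches an explicit `einsum` and a single hand-computed dot product:

```python
import torch

torch.manual_seed(0)
b, n_heads, seq, head_dim = 2, 3, 4, 3
Q = torch.randn(b, n_heads, seq, head_dim)
K = torch.randn(b, n_heads, seq, head_dim)

# Batched matmul over the leading (b, n_heads) dims: (seq, head_dim) @ (head_dim, seq)
scores = Q @ K.transpose(2, 3)  # (b, n_heads, seq, seq)

# Same computation with einsum: score[b,h,i,j] = sum_d Q[b,h,i,d] * K[b,h,j,d]
scores_einsum = torch.einsum("bhid,bhjd->bhij", Q, K)
assert torch.allclose(scores, scores_einsum, atol=1e-5)

# One entry by hand: query token 2 vs key token 1, in batch 0, head 1
manual = torch.dot(Q[0, 1, 2], K[0, 1, 1])
assert torch.allclose(scores[0, 1, 2, 1], manual, atol=1e-5)
```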


In [32]:
# apply mask to prevent looking ahead (causal attention)
mh_attention_score.masked_fill_(mask, -torch.inf)
print(f"S:{mh_attention_score.shape}\n{mh_attention_score[1]}")

S:torch.Size([2, 3, 4, 4])
tensor([[[ 0.2569,    -inf,    -inf,    -inf],
         [ 0.0221, -0.5462,    -inf,    -inf],
         [-0.2866, -0.0408, -0.2315,    -inf],
         [-0.6317, -0.2327, -0.2567,  0.1943]],

        [[ 0.4582,    -inf,    -inf,    -inf],
         [ 0.0463,  0.6201,    -inf,    -inf],
         [ 0.1563, -0.1418, -0.5324,    -inf],
         [ 0.2650,  0.9619, -0.3649,  0.4824]],

        [[-0.0670,    -inf,    -inf,    -inf],
         [ 0.2537,  0.5420,    -inf,    -inf],
         [ 0.2429, -0.4177, -0.3470,    -inf],
         [ 0.2617,  0.0530,  0.0512,  0.6890]]], grad_fn=<SelectBackward0>)
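Masking with `-inf` works because `exp(-inf) = 0`, so masked positions receive exactly zero weight after softmax while each row still sums to 1. Since the causal mask never hides the diagonal, no row is entirely `-inf` (a fully masked row would produce `NaN`). A minimal sketch with all-equal raw scores so the resulting weights are easy to read:

```python
import torch
import torch.nn.functional as F

seq = 4
scores = torch.zeros(seq, seq)  # equal raw scores for every query/key pair
mask = torch.triu(torch.ones(seq, seq, dtype=torch.bool), diagonal=1)

weights = F.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)
print(weights)
# row i spreads its weight evenly over tokens 0..i: 1, 1/2, 1/3, 1/4 each,
# and every future position gets exactly 0
```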


In [33]:
# scale and apply softmax
attention_weight = F.softmax(mh_attention_score / (head_dim ** 0.5), dim=-1) # (b, n_heads, seq, seq)
print(f"Shape:{attention_weight.shape}\n{attention_weight[1]}")

Shape:torch.Size([2, 3, 4, 4])
tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.5813, 0.4187, 0.0000, 0.0000],
         [0.3140, 0.3619, 0.3241, 0.0000],
         [0.1956, 0.2463, 0.2429, 0.3152]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.4179, 0.5821, 0.0000, 0.0000],
         [0.3978, 0.3349, 0.2673, 0.0000],
         [0.2313, 0.3458, 0.1607, 0.2622]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.4585, 0.5415, 0.0000, 0.0000],
         [0.4177, 0.2852, 0.2971, 0.0000],
         [0.2468, 0.2188, 0.2186, 0.3159]]], grad_fn=<SelectBackward0>)
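Why divide by `sqrt(head_dim)`? Assuming the query and key components are independent with zero mean and unit variance, their dot product has variance `head_dim`, so unscaled scores grow with head size and push softmax toward a one-hot distribution with tiny gradients. Dividing by `sqrt(head_dim)` brings the variance back to about 1. Here `head_dim` is only 3, so the effect is small; the sketch uses an illustrative `head_dim = 64` to make it visible:

```python
import torch

torch.manual_seed(0)
n, head_dim = 100_000, 64
q = torch.randn(n, head_dim)
k = torch.randn(n, head_dim)

dots = (q * k).sum(dim=-1)      # n independent dot products
scaled = dots / head_dim ** 0.5

print(f"var(q.k)         ~ {dots.var():.1f}")    # close to head_dim = 64
print(f"var(q.k/sqrt(d)) ~ {scaled.var():.2f}")  # close to 1
```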


In [34]:
# apply dropout to the attention weights (modules start in training mode, so dropout is active)
attention_weight = dropout(attention_weight)
print(f'After Dropout:\n{attention_weight[1]}')

After Dropout:
tensor([[[0.0000, 0.0000, 0.0000, 0.0000],
         [0.9688, 0.0000, 0.0000, 0.0000],
         [0.5233, 0.6031, 0.0000, 0.0000],
         [0.3260, 0.0000, 0.0000, 0.0000]],

        [[1.6667, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.9701, 0.0000, 0.0000],
         [0.6630, 0.5582, 0.4455, 0.0000],
         [0.0000, 0.5764, 0.0000, 0.4370]],

        [[1.6667, 0.0000, 0.0000, 0.0000],
         [0.7641, 0.9025, 0.0000, 0.0000],
         [0.0000, 0.4754, 0.0000, 0.0000],
         [0.4113, 0.0000, 0.0000, 0.0000]]], grad_fn=<SelectBackward0>)
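PyTorch's `nn.Dropout` uses inverted dropout: during training each element is zeroed with probability `p` and the survivors are scaled by `1 / (1 - p)`, which is why a weight of `1.0000` became `1.6667` (`1 / 0.6`) above. The rows therefore no longer sum to exactly 1, although their expected value is unchanged. In `eval()` mode dropout does nothing. A short sketch on a constant tensor:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
p = 0.4
dropout = nn.Dropout(p)
w = torch.full((1000,), 0.5)

dropout.train()                    # training mode (the default for new modules)
out = dropout(w)
kept = out[out != 0]
# survivors are scaled by 1/(1-p): 0.5 / 0.6
print(torch.allclose(kept, torch.full_like(kept, 0.5 / (1 - p))))  # True

dropout.eval()                     # inference mode: dropout is the identity
print(torch.equal(dropout(w), w))  # True
```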


In [35]:
# compute the attention output: attention_weight @ V, i.e. a weighted average of value vectors per head
context_vector = attention_weight @ V # (b, n_heads, seq, head_dim)
print(f"Shape:{context_vector.shape}\n{context_vector[1]}")

Shape:torch.Size([2, 3, 4, 3])
tensor([[[ 0.0000,  0.0000,  0.0000],
         [ 0.5083,  0.6182,  0.3072],
         [ 0.4235,  0.4775, -0.2888],
         [ 0.1710,  0.2080,  0.1034]],

        [[ 1.4620,  0.1345, -1.6095],
         [ 0.1398, -0.9007, -0.8149],
         [ 0.9255, -0.4715, -1.2611],
         [-0.2885, -0.5021,  0.0798]],

        [[ 0.3246,  0.7681,  0.8544],
         [ 1.0397,  1.1539,  0.8462],
         [ 0.4692,  0.4223,  0.2394],
         [ 0.0801,  0.1896,  0.2109]]], grad_fn=<SelectBackward0>)
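As a sanity check (assuming PyTorch 2.0 or later), the same causal computation without dropout can be compared against `torch.nn.functional.scaled_dot_product_attention`, whose `is_causal=True` flag applies the same upper-triangular mask and whose default scale is `1 / sqrt(head_dim)`. A sketch on random per-head tensors:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, n_heads, seq, head_dim = 2, 3, 4, 3
Q, K, V = (torch.randn(b, n_heads, seq, head_dim) for _ in range(3))

# Manual causal attention, following the steps above but without dropout
mask = torch.triu(torch.ones(seq, seq, dtype=torch.bool), diagonal=1)
scores = (Q @ K.transpose(2, 3)).masked_fill(mask, float("-inf"))
weights = F.softmax(scores / head_dim ** 0.5, dim=-1)
context = weights @ V  # (b, n_heads, seq, head_dim)

# Built-in fused version; is_causal=True applies the same causal mask
reference = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
print(torch.allclose(context, reference, atol=1e-5))  # True
```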


In [36]:
# Now combine the heads so that d_model = n_heads * head_dim.
# First move the head axis back next to head_dim: (b, n_heads, seq, head_dim) -> (b, seq, n_heads, head_dim).
# The transposed tensor is not contiguous in memory, so .contiguous() is required before .view()
context_vector = context_vector.transpose(1, 2).contiguous().view(b, seq, d_model) # (b, seq, d_model)
print(f"Shape:{context_vector.shape}\n{context_vector[1]}")

Shape:torch.Size([2, 4, 9])
tensor([[ 0.0000,  0.0000,  0.0000,  1.4620,  0.1345, -1.6095,  0.3246,  0.7681,
          0.8544],
        [ 0.5083,  0.6182,  0.3072,  0.1398, -0.9007, -0.8149,  1.0397,  1.1539,
          0.8462],
        [ 0.4235,  0.4775, -0.2888,  0.9255, -0.4715, -1.2611,  0.4692,  0.4223,
          0.2394],
        [ 0.1710,  0.2080,  0.1034, -0.2885, -0.5021,  0.0798,  0.0801,  0.1896,
          0.2109]], grad_fn=<SelectBackward0>)
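Merging heads needs two steps: move the head axis back next to `head_dim` with `transpose(1, 2)`, then make the tensor contiguous before flattening, since `.view` refuses to flatten a transposed layout (`.reshape` would copy automatically instead). This sketch, on random data, shows the `view` failure, checks that each token's merged row is the concatenation of its per-head vectors, and shows that viewing without the transpose gives a different, token-mixing result:

```python
import torch

torch.manual_seed(0)
b, n_heads, seq, head_dim = 2, 3, 4, 3
d_model = n_heads * head_dim
ctx = torch.randn(b, n_heads, seq, head_dim)  # per-head context vectors

merged_t = ctx.transpose(1, 2)                # (b, seq, n_heads, head_dim), non-contiguous
try:
    merged_t.view(b, seq, d_model)            # fails: strides no longer match a flat layout
except RuntimeError as err:
    print("view failed:", type(err).__name__)  # view failed: RuntimeError

merged = merged_t.contiguous().view(b, seq, d_model)  # copy, then flatten heads per token

# Token t's row is the concatenation [head 0 | head 1 | head 2] for that token
for t in range(seq):
    assert torch.equal(merged[:, t], torch.cat([ctx[:, h, t] for h in range(n_heads)], dim=-1))

# Skipping the transpose still produces the right shape, but mixes tokens across rows
wrong = ctx.contiguous().view(b, seq, d_model)
print(torch.equal(wrong, merged))  # False
```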


**Multi-Head Attention**

In [37]:
class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, d_model, qkv_bias=True, dropout=0.3):
        # we could have used separate d_in and d_out, but here we assume d_model is both the input embedding size and the desired output size
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.n_heads = n_heads
        self.d_model = d_model
        self.head_dim = d_model // n_heads

        self.wq = nn.Linear(d_model, d_model, bias=qkv_bias)
        self.wk = nn.Linear(d_model, d_model, bias=qkv_bias)
        self.wv = nn.Linear(d_model, d_model, bias=qkv_bias)

        self.dropout = nn.Dropout(dropout)

        self.wo = nn.Linear(d_model, d_model)

    def forward(self, x):
        b, seq, d_model = x.size()
        assert self.d_model == d_model, "input embedding size must match the d_model passed to the constructor"
        Q = self.wq(x) # (b,seq,d_model) 
        K = self.wk(x) # (b,seq,d_model)
        V = self.wv(x) # (b,seq,d_model)

        # split heads and transpose 
        # (b,seq,d_model) -> (b, seq, n_heads, head_dim) -> (b, n_heads, seq, head_dim)
        Q = Q.view(b, seq, self.n_heads, self.head_dim).transpose(1, 2) 
        K = K.view(b, seq, self.n_heads, self.head_dim).transpose(1,2)
        V = V.view(b, seq, self.n_heads, self.head_dim).transpose(1,2)

        # compute the attention scores
        attn_scores = Q @ K.transpose(2,3)
        # create a boolean causal mask of shape (seq, seq) on the same device as the input
        mask = torch.triu(torch.ones((seq, seq), dtype=torch.bool, device=x.device), diagonal=1)
        # expand it to match the attention scores shape (b, n_heads, seq, seq)
        mask = mask.unsqueeze(0).unsqueeze(0).expand(b, self.n_heads, -1, -1)

        # hide future positions by setting their scores to -inf
        attn_scores = attn_scores.masked_fill(mask, -torch.inf)

        # compute the scaled attention weights
        attn_weights = F.softmax(attn_scores / (self.head_dim ** 0.5), dim=-1)

        # apply dropout to the attention weights
        attn_weights = self.dropout(attn_weights)
        
        # compute the output of the scaled dot product attention
        # (b, n_heads, seq, seq) @ (b, n_heads, seq, head_dim) -> (b, n_heads, seq, head_dim)
        context_vec = attn_weights @ V
        # (b, seq, n_heads, head_dim)
        context_vec = context_vec.transpose(1, 2)
        #reshape it back to (b, seq, d_model)
        context_vec = context_vec.contiguous().view(b, seq, self.d_model)
        output = self.wo(context_vec)
        return output



In [38]:
mha = MultiHeadAttention(n_heads, d_model)
output = mha(input)
print(f"Shape:{output.shape}\n{output[0]}")

Shape:torch.Size([2, 4, 9])
tensor([[ 0.3305,  0.8490, -0.7650, -0.5499, -0.5673,  0.6997,  0.4305, -1.4169,
         -0.1181],
        [ 0.0656,  0.1773, -0.2291, -0.2755,  0.0037,  0.8044, -0.2406, -0.8244,
         -0.1588],
        [-0.2307,  0.1386,  0.0015, -0.3392,  0.2659,  0.7047, -0.2284, -0.4015,
         -0.2836],
        [ 0.3415,  0.0311, -0.0630, -0.3728,  0.0957,  0.3276,  0.0122, -0.6248,
         -0.0620]], grad_fn=<SelectBackward0>)
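For comparison, PyTorch ships a built-in `nn.MultiheadAttention`. This sketch uses it with the same toy sizes (its weights are initialized independently, so its numbers will not match the class above) and checks causality: perturbing the last token must leave the outputs for all earlier tokens unchanged. For a boolean `attn_mask`, `True` marks positions that may not be attended to, matching the mask convention used here.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
b, seq, d_model, n_heads = 2, 4, 9, 3
x = torch.randn(b, seq, d_model)

# batch_first=True keeps the (b, seq, d_model) layout used in this notebook
mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, dropout=0.0, batch_first=True)
mha.eval()

# Boolean attn_mask: True marks positions a query is NOT allowed to attend to
causal_mask = torch.triu(torch.ones(seq, seq, dtype=torch.bool), diagonal=1)
out, _ = mha(x, x, x, attn_mask=causal_mask, need_weights=False)
print(out.shape)  # torch.Size([2, 4, 9])

# Causality check: changing the last token must not change outputs for earlier tokens
x2 = x.clone()
x2[:, -1] = torch.randn(b, d_model)
out2, _ = mha(x2, x2, x2, attn_mask=causal_mask, need_weights=False)
print(torch.allclose(out[:, :-1], out2[:, :-1], atol=1e-5))  # True
```

The same perturbation test is a cheap way to check any hand-written causal attention module, including the class above.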
