In [2]:
import torch
import torch.nn as nn

### Define input of size(no of samples = 20, seq_length = 10, embed_dim = 8)

In [16]:
n = 20
seq_len = 10
embed_dim = 8
batch = 2
no_head = 2
input_data = torch.arange(n * seq_len * embed_dim).float().reshape(n, seq_len, embed_dim)
input_data.shape

torch.Size([20, 10, 8])

### Now intialize wq, wk, wv of shape (embed_dim, embed_dim) bcz the output must match the input_data dim (batch, seq_len, embed_dim)

In [17]:
wq = nn.Linear(embed_dim, embed_dim, bias=False)
wk = nn.Linear(embed_dim, embed_dim, bias=False)
wv = nn.Linear(embed_dim, embed_dim, bias=False)

wq.weight.shape, wv.weight.shape, wk.weight.shape

(torch.Size([8, 8]), torch.Size([8, 8]), torch.Size([8, 8]))

### Calculate q, k, v vector with input_data later looking into whole data through loop

In [18]:
q, k , v = wq(input_data[:2]), wk(input_data[:2]), wv(input_data[:2])
q.shape, k.shape, v.shape

(torch.Size([2, 10, 8]), torch.Size([2, 10, 8]), torch.Size([2, 10, 8]))

### split the data to both the heads(single self attention shouldn't get whole sequence rather a slice it should get another head get remaining)
- i.e. each head should get (embed_dim / no_head) to capture different perspective of sentence
- so q, k, v vector dimension should become (batch, no_head, seq_len, embed_dim / no_head) lets make it

In [22]:
assert embed_dim / no_head, "embed_dim / no_head should leave reminder 0."
q = q.reshape(batch, no_head, seq_len, int(embed_dim / no_head))
k = k.reshape(batch, no_head, seq_len, int(embed_dim / no_head))
v = v.reshape(batch, no_head, seq_len, int(embed_dim / no_head))

q.shape, k.shape, v.shape

(torch.Size([2, 2, 10, 4]),
 torch.Size([2, 2, 10, 4]),
 torch.Size([2, 2, 10, 4]))

### Calculate the attention score

- Attention score = (Q . KT) / sqrt(d) 
- Attention  = softmax(Attention score) . V

In [33]:
import math
K_transpose = k.transpose(-2, -1) # only make transpose to seq_length & embedding dimensio
# K_transpose.shape
head_dim = embed_dim // no_head
# dim = -1 make sure it apply softmax along seq_length axis
# (batch, num_head, seq_length, embed_dim) . (batch, num_head, embed_dim, seq_length) = (batch, num_head, seq_length, seq_length)
attention_score = torch.softmax((q @ K_transpose) / math.sqrt(head_dim), dim=-1)
# (batch, num_head, seq_length, seq_length) . (batch, num_head, seq_length,  head_dim) = (batch, num_head, seq_length,  head_dim)
attention = attention_score @ v
attention

tensor([[[[ 1.0065e+01,  2.5342e+01,  1.2975e+01,  1.5621e+00],
          [ 2.3878e+00, -3.3438e+00, -1.7794e+00,  2.0573e+00],
          [ 9.8347e+00,  2.9392e+01,  1.5190e+01,  2.9186e-01],
          [ 2.3956e+00, -3.4459e+00, -1.8458e+00,  2.0694e+00],
          [ 9.8347e+00,  2.9392e+01,  1.5190e+01,  2.9186e-01],
          [ 2.3956e+00, -3.4460e+00, -1.8459e+00,  2.0694e+00],
          [ 9.8347e+00,  2.9392e+01,  1.5190e+01,  2.9186e-01],
          [ 2.3956e+00, -3.4460e+00, -1.8459e+00,  2.0694e+00],
          [ 9.8347e+00,  2.9392e+01,  1.5190e+01,  2.9186e-01],
          [ 2.3956e+00, -3.4460e+00, -1.8459e+00,  2.0694e+00]],

         [[ 1.9714e+01,  6.3272e+01,  3.1811e+01, -1.0734e+00],
          [ 1.1811e+01,  3.6168e+01,  1.8514e+01,  1.8804e-02],
          [ 1.9714e+01,  6.3272e+01,  3.1811e+01, -1.0734e+00],
          [ 1.1811e+01,  3.6168e+01,  1.8514e+01,  1.8804e-02],
          [ 1.9714e+01,  6.3272e+01,  3.1811e+01, -1.0734e+00],
          [ 1.1811e+01,  3.6168e+01,  

### Combine all attention outputs
- Step 1: attention = (batch, num_head, seq_length,  head_dim) now i want to make it (batch, seq_length, embed_dim)
      
        - convert (batch, num_head, seq_length,  head_dim) -> (batch, seq_length,  num_head,  head_dim)
        - then reshape (batch, seq_length,  num_head,  head_dim) -> (batch, seq_length,  num_head * head_dim)

In [41]:
print(attention.shape)
attention = attention.transpose(2, 1)
attention = attention.reshape(batch, seq_len, embed_dim)
print(attention.shape)

torch.Size([2, 10, 8])
torch.Size([2, 10, 8])


### Final Linear Projection:- Mixing Information Across Heads
Each attention head learns different relationships between tokens:

- Head 1 might capture position-based attention.

- Head 2 might focus on syntax.

- Head 3 might capture coreference resolution.

When you concatenate these heads → you get embed_dim = num_heads × head_dim. But it's just raw concatenation — there’s no interaction between heads yet.

 - The final linear projection mixes and combines the information from each head, similar to how an MLP might blend features.

In [47]:
# Dimension for output_proj is (embed_dim, embed_dim) bcz attention: (batch, seq_length,  embed_dim) . (embed_dim, embed_dim) = (batch, seq_length,  embed_dim)
output_projection = nn.Linear(embed_dim, embed_dim)
mha_output = output_projection(attention)

In [None]:
import math
import torch
import torch.nn as nn
from typing import List, Annotated

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads: Annotated[int, "No of self attention needed"],
                 embed_dim: Annotated[int, "dimension of each word"],
                 seq_length: Annotated[int, "Length of sentence after padding"],
                 bias : Annotated[bool, "Required bias during trining"] = False,
                 mask: Annotated[bool, "normal MHA or masked MHA?"] = False) -> None:
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim % num_heads != 0"
        self.num_heads = num_heads
        self.seq_length = seq_length
        self.embed_dim = embed_dim
        self.head_dim = self.embed_dim // self.num_heads
        self.wq = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.wk = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.wv = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.output_projection = nn.Linear(self.embed_dim, self.embed_dim)
        self.require_mask = mask
        print("All parameters are set for multihead attention")

    def forward(self, batched_input_data:Annotated[torch.Tensor, "batch of data from the input data"]) -> torch.Tensor:
        batch = batched_input_data.size(0)
        q = self.wq(batched_input_data)
        k = self.wk(batched_input_data)
        v = self.wv(batched_input_data)

        # Split the q, k, v(embed_dim) dimension as (num_head, embed_dim / num_head)
        q = q.reshape(batch, self.seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.reshape(batch, self.seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(batch, self.seq_length, self.num_heads, self.head_dim).transpose(1, 2)


        # Calculate Attention
        k_transpose = k.transpose(-2, -1)
        score = (q @ k_transpose) / math.sqrt(self.head_dim)
        mask = torch.triu(torch.ones(self.seq_length, self.seq_length), diagonal=1).bool() if self.require_mask else torch.zeros(self.seq_length, self.seq_length).bool()

        # Anyhow broadcasting works no need of unsqeeze but its good practice to 
        # avoid broadcasting in Attentions, but clearly this step is optional
        mask = mask.unsqueeze(0).unsqueeze(0)
        score = score.masked_fill(mask, float("-inf"))
        attention_score = torch.softmax(score, dim=-1)
        attention = attention_score @ v

        # concat output of all heads
        attention = attention.transpose(1, 2)
        attention = attention.reshape(batch, self.seq_length, self.embed_dim)

        # Since they are simple concatination to acutally mix all heads details we need a linear layer
        
        mha_output = self.output_projection(attention)
        return mha_output
    
    def __call__(self, batched_input_data:Annotated[torch.Tensor, "batch of data from the input data"]):
        return self.forward(batched_input_data=batched_input_data)

In [91]:
mha = MultiHeadAttention(2, 10, 8)

All parameters are set for multihead attention


In [95]:
from torch.utils.data import DataLoader, TensorDataset

batch_size = 2
data = torch.randn(100, 10, 8)
dataset = TensorDataset(data)
loader = DataLoader(dataset, batch_size=batch_size)

mha = MultiHeadAttention(embed_dim=8, num_heads=2, seq_length=10)

for batch in loader:
    x = batch[0]  # shape: (2, 10, 8)
    out = mha(x)  # shape: (2, 10, 8)
    print(out.shape)
    print(out)
    break  # test one batch for now


All parameters are set for multihead attention
torch.Size([2, 10, 8])
tensor([[[-0.2354, -0.1877,  0.1168, -0.0770,  0.0946, -0.1002, -0.0696,
           0.3571],
         [-0.0869, -0.0991,  0.1638, -0.1026, -0.2098, -0.3868,  0.1530,
           0.3094],
         [-0.2698, -0.1626,  0.1131, -0.0524,  0.0779, -0.0901, -0.0244,
           0.3530],
         [-0.1712, -0.1059,  0.1276, -0.2048,  0.0497, -0.2817,  0.0845,
           0.3163],
         [-0.1566, -0.1520,  0.1457, -0.1357,  0.0571, -0.2097,  0.0937,
           0.3678],
         [-0.1077, -0.1232,  0.1558, -0.1313, -0.0856, -0.3185,  0.1598,
           0.3364],
         [-0.2102, -0.1894,  0.1443, -0.1102,  0.1576, -0.0961,  0.0029,
           0.3992],
         [-0.1075, -0.1111,  0.1154, -0.1638, -0.1535, -0.3867,  0.1337,
           0.2721],
         [-0.0326, -0.0999,  0.1546, -0.2402, -0.1516, -0.4747,  0.1807,
           0.2834],
         [-0.1586, -0.0969,  0.1165, -0.1488, -0.1341, -0.3506,  0.1090,
           0.2774]],

In [60]:
n = 20
seq_len = 10
embed_dim = 8
batch = 2
no_head = 2
input_data = torch.arange(n * seq_len * embed_dim).float().reshape(n, seq_len, embed_dim)
input_data.shape

torch.Size([20, 10, 8])

In [81]:
torch.triu(torch.ones(6, 6), diagonal=1).bool()

tensor([[False,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True],
        [False, False, False, False,  True,  True],
        [False, False, False, False, False,  True],
        [False, False, False, False, False, False]])

In [82]:
scores = torch.tensor([
    [0.1, 0.3, 0.5, 0.7],
    [0.2, 0.4, 0.6, 0.8],
    [0.9, 0.1, 0.2, 0.3],
    [0.3, 0.4, 0.5, 0.6]
])

In [87]:
mask = torch.triu(torch.ones(4, 4), diagonal=1).bool()
mask

tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])

In [88]:
scores.masked_fill(mask, float("-inf"))

tensor([[0.1000,   -inf,   -inf,   -inf],
        [0.2000, 0.4000,   -inf,   -inf],
        [0.9000, 0.1000, 0.2000,   -inf],
        [0.3000, 0.4000, 0.5000, 0.6000]])

In [89]:
torch.zeros(5, 5).bool()

tensor([[False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False]])

### But if you carefully observe transformers archetecture there are 3 types of attention 1) Multi-Head Attention 2) Masked-Multi head Attention 3) Cross-Multihead attention

### But if you observe our implementation till now it only suppoerts MHA & MMHA

- Now to integrate CMHA with our code we need to tweek changes in forward pass
- forward method will accept 3 parameter (q_input, k_input = None, v_input = None) if k_input & v_input is None which means self attention else cross attention
- also make changes in the dimension i.e. instead of seq_len we need to consider both qdim(output seq) and kdim(inp seq)

### After making above changes final MultiHead attention class looks something like this

In [2]:
import math
import torch
import torch.nn as nn
from typing import List, Annotated, Union


class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads: Annotated[int, "No of self attention needed"],
                 embed_dim: Annotated[int, "dimension of each word"],
                 bias: Annotated[bool, "Required bias during trining"] = False,) -> None:
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim % num_heads != 0"
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = self.embed_dim // self.num_heads
        self.wq = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.wk = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.wv = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.output_projection = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(self,
                q_input: Annotated[torch.Tensor, "batch of data from the input data"],
                k_input: Union[torch.Tensor, None] = None,
                v_input: Union[torch.Tensor, None] = None,
                mask: Annotated[bool, "normal MHA or masked MHA?"] = False) -> torch.Tensor:
        batch = q_input.size(0)

        # Below code make sures algorithm is self attention not cross attention
        if k_input is None:
            k_input = q_input
        if v_input is None:
            v_input = q_input

        q = self.wq(q_input)
        k = self.wk(k_input)
        v = self.wv(v_input)

        # Sequence length can be diffrent for input data and output data
        T_q, T_k = q_input.size(1), k_input.size(1)

        # Split the q, k, v(embed_dim) dimension as (num_head, embed_dim / num_head)
        q = q.reshape(batch, T_q, self.num_heads,
                      self.head_dim).transpose(1, 2)
        k = k.reshape(batch, T_k, self.num_heads,
                      self.head_dim).transpose(1, 2)
        v = v.reshape(batch, T_k, self.num_heads,
                      self.head_dim).transpose(1, 2)

        # Calculate Attention
        k_transpose = k.transpose(-2, -1)  # (b, k, d) (b, d, p) = (b, k, p)
        score = (q @ k_transpose) / math.sqrt(self.head_dim)
        mask = torch.triu(torch.ones(T_q, T_k), diagonal=1).bool(
        ) if mask else torch.zeros(T_q, T_k).bool()

        # Anyhow broadcasting works no need of unsqeeze but its good practice to
        # avoid broadcasting in Attentions, but clearly this step is optional
        mask = mask.unsqueeze(0).unsqueeze(0)
        score = score.masked_fill(mask, float("-inf"))
        attention_score = torch.softmax(score, dim=-1)
        attention = attention_score @ v

        # concat output of all heads
        attention = attention.transpose(1, 2)

        # Attention should have sequence of length = output sequence length.
        attention = attention.reshape(batch, T_q, self.embed_dim)

        # Since they are simple concatination to acutally mix all heads details we need a linear layer

        mha_output = self.output_projection(attention)
        return mha_output


### Testing Cross-Attention with simple example

In [3]:
# --- Setup for the Test ---

# 1. Define hyperparameters
batch_size = 32
embed_dim = 256
num_heads = 8
decoder_seq_len = 15  # The length of the target sequence
encoder_seq_len = 20  # The length of the source sequence

# 2. Instantiate the MultiHeadAttention layer
mha_layer = MultiHeadAttention(num_heads=num_heads, embed_dim=embed_dim)

# 3. Create the input data
# This represents the decoder's input (e.g., the partially generated translation)
decoder_input = torch.randn(batch_size, decoder_seq_len, embed_dim) 

# This represents the encoder's output, which provides the context
encoder_output = torch.randn(batch_size, encoder_seq_len, embed_dim)

print("--- Testing Cross-Attention ---")
print(f"Query Input Shape (from decoder):      {decoder_input.shape}")
print(f"Key/Value Input Shape (from encoder):  {encoder_output.shape}\n")

# --- Perform the Cross-Attention ---

# The decoder_input provides the queries.
# The encoder_output provides the keys and values.
output = mha_layer(q_input=decoder_input, k_input=encoder_output, v_input=encoder_output)

print(f"Final Output Shape: {output.shape}")
print("\nNotice the output sequence length matches the query's sequence length.")


--- Testing Cross-Attention ---
Query Input Shape (from decoder):      torch.Size([32, 15, 256])
Key/Value Input Shape (from encoder):  torch.Size([32, 20, 256])

Final Output Shape: torch.Size([32, 15, 256])

Notice the output sequence length matches the query's sequence length.
