## Multi-Head Self-Attention (With Causal Masking)

Self-attention lets the model work out which words in a sentence are relevant to each other.

### How does self-attention work?

1. Each word in the sentence is turned into a vector of numbers (an embedding).
2. Each word creates three special versions of itself:

    a. Query (Q) -> "What am I looking for?"

    b. Key (K) -> "What do I contain?"

    c. Value (V) -> "What information do I give?"
3. The model compares each word with every other word using the dot product of their queries and keys.
    a. This tells the model how much attention one word should pay to another.
4. The model applies a softmax function so that each word's attention values sum to 1.
    a. This means every word in the sentence receives a percentage of that word's attention.
5. The final output for each word is a weighted sum of the Values (V), weighted by these attention values.
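
The steps above can be sketched for a single head in a few lines of PyTorch. This is only an illustrative sketch: the toy sizes and the random, untrained matrices `W_q`, `W_k`, and `W_v` are assumptions standing in for learned weights.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy sizes: 3 tokens, 4-dimensional embeddings
seq_len, embed_size = 3, 4
X = torch.rand(seq_len, embed_size)  # one embedding vector per token

# Random (untrained) projections stand in for learned weight matrices
W_q = torch.rand(embed_size, embed_size)
W_k = torch.rand(embed_size, embed_size)
W_v = torch.rand(embed_size, embed_size)
Q, K, V = X @ W_q, X @ W_k, X @ W_v

# Compare every token with every other token (scaled dot product)
scores = Q @ K.T / embed_size ** 0.5  # shape (3, 3)

# Softmax turns each row into weights that sum to 1
weights = F.softmax(scores, dim=-1)

# Each output row is a weighted sum of the value vectors
output = weights @ V  # shape (3, 4)

print(weights.sum(dim=-1))  # every row sums to 1
print(output.shape)
```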

## What is Multi-Head Attention?

Instead of computing attention once, we compute it several times in parallel (once per head), each with its own Q, K, and V, so that different heads can capture different relationships.
- One head may focus on subjects (cat -> sat).
- Another head may focus on objects (sat -> mat).
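
In practice the heads are usually created by slicing one embedding into `num_heads` equal parts. The sketch below only illustrates that reshaping, using made-up sizes and a counting tensor instead of real embeddings, so the split is easy to inspect.

```python
import torch

# Hypothetical sizes chosen for illustration
batch_size, seq_len, embed_size, num_heads = 1, 3, 8, 2
d_k = embed_size // num_heads  # 4 dimensions per head

# A counting tensor makes it easy to see which values land in which head
x = torch.arange(batch_size * seq_len * embed_size, dtype=torch.float32)
x = x.view(batch_size, seq_len, embed_size)

# Split the last dimension into (num_heads, d_k), then move heads forward
heads = x.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
print(heads.shape)  # torch.Size([1, 2, 3, 4])

# Merging the heads reverses the split exactly
merged = heads.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_size)
print(torch.equal(merged, x))  # True
```

Each head then runs attention on its own `(seq_len, d_k)` slice, and the results are concatenated back to `embed_size`.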


## What is Causal Masking?

When generating text, we don't want a word to see future words. We use a mask to block attention to future words.

For example, in:
"The cat sat..."
When generating "sat", the model should only look at "The cat" and NOT at words ahead. So, we set the attention scores for future words to -inf before the softmax, which makes their attention weights exactly 0.
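
A minimal sketch of the masking step, assuming the common approach of filling future positions with `-inf` before the softmax so that their weights come out as 0. The all-zero `scores` tensor is a placeholder chosen so the effect of the mask is easy to read.

```python
import torch
import torch.nn.functional as F

seq_len = 3  # e.g. the tokens "The", "cat", "sat"

# Lower-triangular mask: 1 = may attend, 0 = future position (blocked)
causal_mask = torch.tril(torch.ones(seq_len, seq_len))

# Placeholder raw scores; all zeros so only the mask shapes the result
scores = torch.zeros(seq_len, seq_len)

# Blocked positions become -inf, which softmax maps to a weight of 0
masked_scores = scores.masked_fill(causal_mask == 0, float("-inf"))
weights = F.softmax(masked_scores, dim=-1)

# Row 0 -> [1, 0, 0], row 1 -> [0.5, 0.5, 0], row 2 -> [1/3, 1/3, 1/3]
print(weights)
```

Row `i` of `weights` only has non-zero entries for positions `0..i`, so each token attends to itself and to earlier tokens only.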

## Why is Self-Attention Important?

- It helps the model focus on relevant words while ignoring unrelated ones.
- It captures the relationships between words in a sentence.
- Unlike older models (like RNNs), it can look at the entire sentence at once.

## Math of Multi-Head Self-Attention
- Step 1: Convert words to vectors (embeddings).
- Step 2: Create Query (Q), Key (K), and Value (V) matrices.

To get Q, K, and V, we multiply the embedding matrix X (one row per word) by three learned weight matrices:

- Q = X · W<sub>Q</sub>
- K = X · W<sub>K</sub>
- V = X · W<sub>V</sub>
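
As a rough illustration of the projection, the sketch below computes Q with a `torch.nn.Linear` layer and checks that it matches the plain matrix product. The sizes are made up, and `bias=False` is an assumption made here so the layer is exactly a matrix multiplication (by default `nn.Linear` also adds a bias term).

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes: 3 tokens with 8-dimensional embeddings
seq_len, embed_size = 3, 8
X = torch.rand(seq_len, embed_size)  # one embedding row per token

# bias=False so the layer is a pure matrix multiplication
proj_q = torch.nn.Linear(embed_size, embed_size, bias=False)
Q = proj_q(X)

# nn.Linear stores its weight as (out_features, in_features),
# so the matrix W_Q in "Q = X . W_Q" is the transpose of that weight
W_Q = proj_q.weight.T
print(torch.allclose(Q, X @ W_Q, atol=1e-6))
print(Q.shape)
```

K and V are obtained the same way, each with its own weight matrix.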

## Compute Attention Scores Using Scaled Dot-Product

After getting Q and K, we compute the attention scores using the following dot product.

Attention Score = Q · K<sup>T</sup>

- The dot product between Q and K tells us how similar two words are.
- If two words are similar, their dot product is high; if they are dissimilar, it is low.
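
A toy illustration of the similarity idea. The three vectors below are invented for this example and are not real embeddings; they are chosen so that `cat` and `kitten` point in a similar direction while `car` does not.

```python
import torch

# Hypothetical hand-made vectors, for illustration only
cat = torch.tensor([1.0, 0.9, 0.0])
kitten = torch.tensor([0.9, 1.0, 0.1])
car = torch.tensor([0.0, 0.1, 1.0])

similar = torch.dot(cat, kitten)    # 0.9 + 0.9 + 0.0 = 1.8
dissimilar = torch.dot(cat, car)    # 0.0 + 0.09 + 0.0 = 0.09
print(similar.item(), dissimilar.item())

# Stacking the vectors gives all pairwise scores at once, like Q . K^T
vectors = torch.stack([cat, kitten, car])
pairwise = vectors @ vectors.T      # shape (3, 3)
print(pairwise)
```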


## Apply Scaling

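We divide the attention scores above by √(d<sub>k</sub>), where d<sub>k</sub> is the dimension of each key vector. Without this scaling, the dot products grow with d<sub>k</sub> and push the softmax toward very peaked outputs with tiny gradients.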

Scaled Score = QK<sup>T</sup>/√(d<sub>k</sub>)
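
To see what the division by √(d<sub>k</sub>) does, the sketch below measures the spread of dot products between random vectors. It assumes query and key entries drawn independently from a standard normal distribution, which is an idealization and not a property of trained models.

```python
import torch

torch.manual_seed(0)
n = 10_000  # number of random query/key pairs per setting

results = {}
for d_k in (4, 64, 512):
    q = torch.randn(n, d_k)
    k = torch.randn(n, d_k)
    raw = (q * k).sum(dim=-1)      # unscaled dot products
    scaled = raw / d_k ** 0.5      # scaled as in the formula above
    results[d_k] = (raw.std().item(), scaled.std().item())
    print(d_k, results[d_k])
```

Under these assumptions the standard deviation of the raw scores grows roughly like √(d<sub>k</sub>), while the scaled scores stay close to 1 regardless of d<sub>k</sub>.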

## Apply Softmax

After scaling, we apply the softmax function row-wise:

Softmax(Scaled Scores)

### Why Softmax?
- Converts each row of scores into a probability distribution.
- Each row sums to 1, meaning that each word distributes its full attention across the words in the sentence.
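
A small example of the row-wise softmax, using hand-picked scores rather than the output of a real model:

```python
import torch
import torch.nn.functional as F

# Hypothetical scaled scores for two query rows
scaled_scores = torch.tensor([[2.0, 1.0, 0.0],
                              [0.0, 0.0, 0.0]])

# dim=-1 normalizes across each row (across the keys of one query)
probs = F.softmax(scaled_scores, dim=-1)
print(probs)              # approx. [0.665, 0.245, 0.090] and [1/3, 1/3, 1/3]
print(probs.sum(dim=-1))  # each row sums to 1
```

Larger scores receive larger weights, and equal scores share the attention evenly.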

### Compute Weighted Sum with V (Values)

Output = Softmax(Scaled Scores) · V
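
The weighted sum can be checked by hand on a tiny example. The `weights` and `V` below are invented values, picked so the arithmetic is easy to follow.

```python
import torch

# Hypothetical value vectors for 3 tokens (2 dimensions each)
V = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0],
                  [2.0, 2.0]])

# Hypothetical attention weights; each row sums to 1
weights = torch.tensor([[1.00, 0.00, 0.00],
                        [0.50, 0.50, 0.00],
                        [0.25, 0.25, 0.50]])

output = weights @ V
# Row 0 copies V[0]                      -> [1.00, 0.00]
# Row 1 averages V[0] and V[1]           -> [0.50, 0.50]
# Row 2 mixes all three value vectors    -> [1.25, 1.25]
print(output)
```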

In [1]:
import torch
import torch.nn.functional as F

In [2]:
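def scaled_dot_product_attention(Q, K, V, mask=None):
    """Return (output, attn_probs) for Q, K, V shaped (..., seq_len, d_k)."""
    d_k = Q.shape[-1]
    # Similarity of every query with every key, scaled by sqrt(d_k)
    attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)

    # Positions where mask == 0 get -inf, so softmax gives them weight 0
    if mask is not None:
        attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))

    # Normalize each row into attention weights, then mix the values
    attn_probs = F.softmax(attn_scores, dim=-1)
    output = torch.matmul(attn_probs, V)
    return output, attn_probs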

In [3]:
class MultiHeadAttention(torch.nn.Module):
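    def __init__(self, embed_size, num_heads):
        super().__init__()
        # The embedding is split evenly across heads, so it must divide cleanly
        assert embed_size % num_heads == 0, "embed_size must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_k = embed_size // num_heads  # Dimension per head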
        
        # Linear layers to project input to Q, K, V
        self.W_q = torch.nn.Linear(embed_size, embed_size)
        self.W_k = torch.nn.Linear(embed_size, embed_size)
        self.W_v = torch.nn.Linear(embed_size, embed_size)
        
        # Final linear layer after concatenation of heads
        self.W_o = torch.nn.Linear(embed_size, embed_size)
    
    def forward(self, Q, K, V, mask=None):
        batch_size = Q.shape[0]
        
        # Apply linear transformations and split into multiple heads
        Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        # Compute scaled dot-product attention
        attention_output, attn_probs = scaled_dot_product_attention(Q, K, V, mask)
        
        # Concatenate heads and pass through final linear layer
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        output = self.W_o(attention_output)
        
        return output, attn_probs

In [4]:
# Test Case
batch_size = 1
seq_len = 3
embed_size = 8
num_heads = 2

torch.manual_seed(42)
input_tensor = torch.rand(batch_size, seq_len, embed_size)

# Create Multi-Head Attention Module
multi_head_attn = MultiHeadAttention(embed_size, num_heads)

# Forward Pass
output, attn_probs = multi_head_attn(input_tensor, input_tensor, input_tensor)

print("Output Shape:", output.shape)  # Expected: (1, 3, 8)
print("Attention Weights:", attn_probs)

Output Shape: torch.Size([1, 3, 8])
Attention Weights: tensor([[[[0.3297, 0.3297, 0.3406],
          [0.3209, 0.3352, 0.3439],
          [0.3286, 0.3308, 0.3406]],

         [[0.3601, 0.2791, 0.3608],
          [0.3198, 0.3448, 0.3354],
          [0.3501, 0.3008, 0.3492]]]], grad_fn=<SoftmaxBackward0>)


In [5]:
# Test Case with Actual Words
vocab = {"she": 0, "loves": 1, "cats": 2, "dogs": 3}
embed_size = 8
num_heads = 2
seq_len = 3
batch_size = 1

# Define Embedding Layer
torch.manual_seed(42)
embedding_layer = torch.nn.Embedding(len(vocab), embed_size)

# Sample sentence: "She loves cats"
input_tokens = torch.tensor([[vocab["she"], vocab["loves"], vocab["cats"]]])  # Shape: (1, 3)
input_embeddings = embedding_layer(input_tokens)  # Convert tokens to embeddings

# Create Multi-Head Attention Module
multi_head_attn = MultiHeadAttention(embed_size, num_heads)

# Forward Pass
output, attn_probs = multi_head_attn(input_embeddings, input_embeddings, input_embeddings)

print("Input Words: ['She', 'loves', 'cats']")
print("Output Shape:", output.shape)  # Expected: (1, 3, 8)
print("Attention Weights:", attn_probs)


Input Words: ['She', 'loves', 'cats']
Output Shape: torch.Size([1, 3, 8])
Attention Weights: tensor([[[[0.1895, 0.1881, 0.6225],
          [0.2294, 0.3782, 0.3925],
          [0.3702, 0.4393, 0.1906]],

         [[0.4134, 0.3067, 0.2798],
          [0.2474, 0.3777, 0.3750],
          [0.3383, 0.3749, 0.2868]]]], grad_fn=<SoftmaxBackward0>)
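
The test cases above call the module without a mask, so every word attends to every other word. The sketch below shows how a causal mask could be applied to the same computation. It restates the attention math inline so that it runs on its own, uses random tensors in place of projected embeddings, and cross-checks the result against `torch.nn.functional.scaled_dot_product_attention` with `is_causal=True`, which assumes PyTorch 2.0 or newer.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(42)

# Hypothetical per-head tensors shaped (batch, heads, seq_len, d_k)
batch_size, num_heads, seq_len, d_k = 1, 2, 3, 4
Q = torch.rand(batch_size, num_heads, seq_len, d_k)
K = torch.rand(batch_size, num_heads, seq_len, d_k)
V = torch.rand(batch_size, num_heads, seq_len, d_k)

# A (seq_len, seq_len) lower-triangular mask broadcasts over batch and heads
mask = torch.tril(torch.ones(seq_len, seq_len))

# Same steps as the attention function: scale, mask, softmax, weighted sum
scores = Q @ K.transpose(-2, -1) / d_k ** 0.5
scores = scores.masked_fill(mask == 0, float("-inf"))
attn_probs = F.softmax(scores, dim=-1)
output = attn_probs @ V

# Cross-check against PyTorch's built-in causal attention (PyTorch >= 2.0)
reference = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
print(torch.allclose(output, reference, atol=1e-5))

# Entries above the diagonal are 0: no attention to future positions
print(attn_probs[0, 0])
```

The same `mask` tensor can be passed as the `mask` argument of `MultiHeadAttention.forward`, since it broadcasts against the `(batch, heads, seq_len, seq_len)` score tensor.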
