# 2. Self-Attention

In self-attention, a sequence attends to itself: each position can attend to every position in the same sequence, including its own.
Self-attention is the core building block of transformers.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


## Self-Attention: Q, K, V from Same Input

In self-attention, Q, K, and V all come from the same input sequence.
We use learned linear transformations to create Q, K, V from the input.


In [None]:
# Step 1: derive Q, K, V from a single input
seq_len, d_model = 5, 8

# Toy input sequence (stand-in for word embeddings): one row per position
x = torch.randn(seq_len, d_model)
print(f"Input shape: {x.shape}")

# Three independent, bias-free learned projections: query, key, value
W_q, W_k, W_v = (nn.Linear(d_model, d_model, bias=False) for _ in range(3))

# Each projection is applied to the very same x
Q, K, V = W_q(x), W_k(x), W_v(x)

print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"V shape: {V.shape}")
print("\nIn self-attention, Q, K, V all come from the same input x!")


## Self-Attention Layer

Let's build a complete self-attention layer from scratch!


In [None]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention, written from scratch.

    Queries, keys and values are all bias-free linear projections of the
    same input, so every position can attend to every position (itself
    included). No masking or dropout is applied.

    Args:
        d_model: Size of the feature (last) dimension of the input. Each
            projection maps d_model -> d_model.
    """

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model

        # Linear layers to create Q, K, V
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Input tensor [..., seq_len, d_model]. The usual case is
                [batch_size, seq_len, d_model]; an unbatched
                [seq_len, d_model] tensor is accepted as well.

        Returns:
            A tuple ``(output, attention_weights)``:
                output: [..., seq_len, d_model]
                attention_weights: [..., seq_len, seq_len]; every row
                    along the last axis sums to 1.
        """
        # Create Q, K, V from the same input
        Q = self.W_q(x)  # [..., seq_len, d_model]
        K = self.W_k(x)  # [..., seq_len, d_model]
        V = self.W_v(x)  # [..., seq_len, d_model]

        # Only the feature size is needed for scaling. Reading it from the
        # last axis (instead of unpacking a 3-tuple from x.shape) is what
        # lets inputs with any number of leading dimensions through.
        scale = x.size(-1) ** 0.5

        # Attention scores QK^T, scaled by sqrt(d_model)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / scale  # [..., seq_len, seq_len]

        # Softmax over the key axis to get attention weights
        attention_weights = F.softmax(scores, dim=-1)  # [..., seq_len, seq_len]

        # Weighted sum of the values
        output = torch.matmul(attention_weights, V)  # [..., seq_len, d_model]

        return output, attention_weights

# Smoke-test the layer on a small random batch
d_model, seq_len, batch_size = 64, 10, 2

# Build the layer before drawing the input so the RNG is consumed in the
# same order (weight init first, then the input sample)
self_attn = SelfAttention(d_model)
x = torch.randn(batch_size, seq_len, d_model)

# The layer returns both the attended output and the attention matrix
output, attn_weights = self_attn(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")

# How the first position of the first sequence distributes its attention
first_row = attn_weights[0, 0]
print(f"\nAttention weights for first sequence, first position:")
print(first_row)
print(f"Sum: {first_row.sum():.3f} (should be 1.0)")
