# 🚀 Day 1 Morning: Attention Mechanism from Scratch

## Learning Objectives
- ✅ Understand the attention mechanism deeply
- ✅ Implement scaled dot-product attention in PyTorch
- ✅ Visualize attention weights
- ✅ Understand Q, K, V matrices

## Key Formula
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Where:
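- **Q** (Query): What each position is looking for
- **K** (Key): What each position offers to be matched against the queries
- **V** (Value): The content we actually get back, weighted by how well each key matches the query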
- **d_k**: Dimension of keys (for scaling)
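
To make the formula concrete before moving to PyTorch, here is a minimal sketch in plain Python that evaluates it by hand on a toy input. The helper names `softmax` and `attention` and the numbers in `Q`, `K`, `V` are made up for illustration; they are not part of the notebook's code.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """Scaled dot-product attention on plain nested lists (no batching)."""
    d_k = len(K[0])
    output, weights = [], []
    for q in Q:
        # Scaled similarity between this query and every key
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in K]
        w = softmax(scores)
        weights.append(w)
        # Weighted average of the value vectors
        output.append([sum(wj * v[c] for wj, v in zip(w, V)) for c in range(len(V[0]))])
    return output, weights

# Toy inputs (illustrative only): 2 queries, 2 keys, d_k = 2, d_v = 1
Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[10.0], [20.0]]

output, weights = attention(Q, K, V)
for q_idx, (w, o) in enumerate(zip(weights, output)):
    print(f"query {q_idx}: weights={[round(x, 3) for x in w]}, output={[round(x, 3) for x in o]}")
```

Each query puts about two thirds of its weight on the key it matches, so each output is pulled toward the matching value (10 or 20) without reaching it.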

In [None]:
# Install required packages (Colab already has most of these)
!pip install -q torch matplotlib seaborn

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("✅ Libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")

## 1️⃣ Implement Scaled Dot-Product Attention

In [None]:
class ScaledDotProductAttention(nn.Module):
    """
    Scaled Dot-Product Attention
    
    Formula: Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) * V
    """
    def __init__(self, d_k):
        super().__init__()
        self.d_k = d_k
        
    def forward(self, Q, K, V, mask=None):
        """
        Args:
            Q: Query (batch_size, seq_len, d_k)
            K: Key (batch_size, seq_len, d_k)
            V: Value (batch_size, seq_len, d_v)
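            mask: Optional tensor broadcastable to (batch_size, seq_len, seq_len);
                positions where mask == 0 are blocked from attention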
        Returns:
            output: (batch_size, seq_len, d_v)
            attention_weights: (batch_size, seq_len, seq_len)
        """
        # Step 1: Compute attention scores (QK^T)
        scores = torch.matmul(Q, K.transpose(-2, -1))
        
        # Step 2: Scale by sqrt(d_k)
        scores = scores / (self.d_k ** 0.5)
        
        # Step 3: Apply mask (optional)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # Step 4: Apply softmax
        attention_weights = F.softmax(scores, dim=-1)
        
        # Step 5: Apply attention to values
        output = torch.matmul(attention_weights, V)
        
        return output, attention_weights

print("✅ Attention class defined!")
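
The docstring above uses a single `seq_len`, but the formula only needs Q and K to share `d_k`, and K and V to share their length. The sketch below checks this with a stand-alone function that mirrors the forward pass without the mask. The name `sdpa` and the tensor sizes are assumptions chosen for illustration.

```python
import torch
import torch.nn.functional as F

def sdpa(Q, K, V):
    """Stand-alone scaled dot-product attention (hypothetical helper, no mask)."""
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / d_k ** 0.5
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, V), weights

torch.manual_seed(0)
Q = torch.randn(2, 3, 8)   # 3 query positions
K = torch.randn(2, 5, 8)   # 5 key positions, same d_k as Q
V = torch.randn(2, 5, 4)   # 5 value vectors, d_v = 4

output, weights = sdpa(Q, K, V)
print(output.shape)   # torch.Size([2, 3, 4])
print(weights.shape)  # torch.Size([2, 3, 5])
```

The output has one row per query and `d_v` columns, while the weight matrix has one row per query and one column per key.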

## 2️⃣ Visualization Helper Function

In [None]:
def visualize_attention(attention_weights, tokens=None, title="Attention Weights"):
    """
    Visualize attention weights as a heatmap
    """
    plt.figure(figsize=(10, 8))
    
    weights = attention_weights.detach().cpu().numpy()
    
    sns.heatmap(weights, annot=True, fmt='.2f', cmap='viridis',
                xticklabels=tokens if tokens else range(weights.shape[1]),
                yticklabels=tokens if tokens else range(weights.shape[0]),
                cbar_kws={'label': 'Attention Weight'})
    
    plt.xlabel('Key Position', fontsize=12)
    plt.ylabel('Query Position', fontsize=12)
    plt.title(title, fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

print("✅ Visualization function ready!")

## 3️⃣ Example 1: Simple Attention

In [None]:
print("="*60)
print("Example 1: Simple Attention Mechanism")
print("="*60)

# Hyperparameters
batch_size = 1
seq_len = 4
d_k = 8  # Dimension of queries and keys
d_v = 8  # Dimension of values

# Create random Q, K, V matrices
Q = torch.randn(batch_size, seq_len, d_k)
K = torch.randn(batch_size, seq_len, d_k)
V = torch.randn(batch_size, seq_len, d_v)

print(f"\nInput shapes:")
print(f"  Q (Query): {Q.shape}")
print(f"  K (Key): {K.shape}")
print(f"  V (Value): {V.shape}")

# Apply attention
attention = ScaledDotProductAttention(d_k)
output, attention_weights = attention(Q, K, V)

print(f"\nOutput shapes:")
print(f"  Output: {output.shape}")
print(f"  Attention weights: {attention_weights.shape}")

print(f"\n📊 Attention weights (should sum to 1 for each query):")
print(attention_weights[0])
print(f"\nSum per row: {attention_weights[0].sum(dim=-1)}")

# Visualize
tokens = ['The', 'cat', 'sat', 'down']
visualize_attention(attention_weights[0], tokens, "Simple Attention Example")

## 4️⃣ Example 2: Self-Attention

In self-attention, Q, K, and V are all computed from the same input sequence, each through its own learned linear projection!

In [None]:
print("="*60)
print("Example 2: Self-Attention")
print("="*60)

batch_size = 1
seq_len = 5
d_model = 16

# Simulate input embeddings
X = torch.randn(batch_size, seq_len, d_model)

# Linear projections to get Q, K, V from same input
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

Q = W_q(X)
K = W_k(X)
V = W_v(X)

print(f"\nInput X shape: {X.shape}")
print(f"Q, K, V shapes: {Q.shape}")

# Apply self-attention
attention = ScaledDotProductAttention(d_model)
output, attention_weights = attention(Q, K, V)

print(f"\nSelf-attention output shape: {output.shape}")

# Visualize
tokens = ['I', 'love', 'machine', 'learning', '!']
visualize_attention(attention_weights[0], tokens, "Self-Attention Example")

## 5️⃣ Example 3: Masked (Causal) Attention

Used in GPT-like models: each position can attend only to itself and earlier positions!

In [None]:
print("="*60)
print("Example 3: Masked (Causal) Attention")
print("="*60)

batch_size = 1
seq_len = 5
d_k = 8

Q = torch.randn(batch_size, seq_len, d_k)
K = torch.randn(batch_size, seq_len, d_k)
V = torch.randn(batch_size, seq_len, d_k)

# Create causal mask (lower triangular)
mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0)

print(f"\nCausal mask (1 = can attend, 0 = cannot attend):")
print(mask[0])

# Apply masked attention
attention = ScaledDotProductAttention(d_k)
output, attention_weights = attention(Q, K, V, mask)

print(f"\n📊 Masked attention weights:")
print(attention_weights[0])
print("\n💡 Notice: Each position only attends to itself and previous positions!")

# Visualize
tokens = ['The', 'cat', 'is', 'very', 'cute']
visualize_attention(attention_weights[0], tokens, "Causal (Masked) Attention")
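
One way to check the causal property beyond reading the heatmap is to change a future token and confirm that earlier outputs do not move. The sketch below does this with a stand-alone helper. The name `causal_attention` is hypothetical, and the helper rebuilds the same lower-triangular mask as the example above.

```python
import torch
import torch.nn.functional as F

def causal_attention(Q, K, V):
    """Scaled dot-product attention with a lower-triangular mask (hypothetical helper)."""
    seq_len, d_k = Q.size(-2), Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / d_k ** 0.5
    mask = torch.tril(torch.ones(seq_len, seq_len))
    scores = scores.masked_fill(mask == 0, -1e9)
    return torch.matmul(F.softmax(scores, dim=-1), V)

torch.manual_seed(0)
Q = torch.randn(1, 5, 8)
K = torch.randn(1, 5, 8)
V = torch.randn(1, 5, 8)
out_before = causal_attention(Q, K, V)

# Overwrite the key and value at the last position only
K2, V2 = K.clone(), V.clone()
K2[:, -1] = torch.randn(1, 8)
V2[:, -1] = torch.randn(1, 8)
out_after = causal_attention(Q, K2, V2)

# Positions 0..3 never look at position 4, so their outputs should match
print(torch.allclose(out_before[:, :-1], out_after[:, :-1]))
# The last position attends to itself, so its output should change
print(torch.allclose(out_before[:, -1], out_after[:, -1]))
```

This is the property that lets an autoregressive model be trained on all positions of a sequence at once without leaking future tokens.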

## 📝 Key Takeaways

1. **Attention allows the model to focus on relevant parts of the input**
   - Each query attends to all keys with different weights
   - Weights sum to 1 (softmax normalization)

2. **Scaling by √d_k prevents softmax saturation**
   - Without scaling, dot products grow in magnitude as d_k increases
   - Large values → softmax becomes too peaked → gradients vanish

3. **Masks enable different attention patterns**
   - Causal mask: For autoregressive models (GPT)
   - Padding mask: Ignore padding tokens

4. **Attention weights are interpretable**
   - We can visualize what the model focuses on
   - Useful for debugging and understanding model behavior
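
The scaling takeaway can be checked numerically. For a query and key whose components are independent with zero mean and unit variance, the dot product has variance d_k, so unscaled scores spread out as d_k grows. The sketch below is a rough simulation under that assumption; the helper `max_weight` and its defaults (8 keys, 200 trials) are arbitrary choices for illustration.

```python
import math
import random

def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def max_weight(d_k, scale, n_keys=8, trials=200, seed=0):
    """Average of the largest attention weight for random unit-variance q and k."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(trials):
        q = [rng.gauss(0, 1) for _ in range(d_k)]
        keys = [[rng.gauss(0, 1) for _ in range(d_k)] for _ in range(n_keys)]
        scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
        if scale:
            scores = [s / math.sqrt(d_k) for s in scores]
        total += max(softmax(scores))
    return total / trials

for d_k in (4, 64, 256):
    print(f"d_k={d_k:>3}  unscaled={max_weight(d_k, False):.2f}  scaled={max_weight(d_k, True):.2f}")
```

In this sketch the unscaled column should drift toward 1 as d_k grows (one key takes nearly all the weight), while the scaled column should stay roughly flat.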

## 🎯 Next Steps

Continue to **Day 1 Afternoon: Multi-Head Attention** to learn:
- Why use multiple attention heads?
- How do heads specialize?
- Implementing multi-head attention from scratch

## 🧪 Experiment!

Try modifying the code:
1. Change `d_k` and observe how attention patterns change
2. Create your own sentences and visualize attention
3. Try different mask patterns
4. What happens without scaling (remove `/ sqrt(d_k)`)?
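
As a starting point for trying different mask patterns, here is a minimal sketch of a padding mask. It follows the same convention as the class above (0 means "cannot attend"). The `lengths` tensor and the batch sizes are hypothetical values chosen for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len, d_k = 2, 4, 8
Q = torch.randn(batch_size, seq_len, d_k)
K = torch.randn(batch_size, seq_len, d_k)
V = torch.randn(batch_size, seq_len, d_k)

# Hypothetical lengths: the first sequence has 4 real tokens, the second has 2
lengths = torch.tensor([4, 2])

# Shape (batch_size, 1, seq_len): 1 for real key positions, 0 for padding.
# The middle dimension broadcasts over all query positions.
positions = torch.arange(seq_len).unsqueeze(0)  # (1, seq_len)
padding_mask = (positions < lengths.unsqueeze(1)).unsqueeze(1).int()

scores = torch.matmul(Q, K.transpose(-2, -1)) / d_k ** 0.5
scores = scores.masked_fill(padding_mask == 0, -1e9)
weights = F.softmax(scores, dim=-1)

print(padding_mask[1])  # real tokens at key positions 0 and 1 only
print(weights[1])       # the last two columns should be 0 for every query
```

Because the mask broadcasts against the score matrix, it should also work when passed as the `mask` argument of `ScaledDotProductAttention`.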