# Lab 1: Transformer Attention Mechanism

**Module 1 - Foundations of Modern LLMs**

| Duration | Difficulty | Framework | Exercises |
|----------|------------|-----------|----------|
| 90 min | Intermediate | PyTorch | 4 |

## Learning Objectives

- Implement scaled dot-product attention from scratch
- Build a multi-head attention mechanism
- Visualize attention patterns
- Understand masking for causal attention

## Setup

First, let's install and import the required libraries.

In [None]:
# Install dependencies if needed
# !pip install torch numpy matplotlib

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import math

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print(f"PyTorch version: {torch.__version__}")

---

## Exercise 1: Scaled Dot-Product Attention

Implement the core attention mechanism used in transformers.

**Formula:**
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Where:
- Q (Query): What am I looking for?
- K (Key): What do I contain?
- V (Value): What information do I provide?
- $d_k$: Dimension of the keys; dividing the scores by $\sqrt{d_k}$ is the scaling step
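
To see why the $\sqrt{d_k}$ factor is there, the sketch below estimates the variance of raw and scaled dot products for a few key dimensions. It assumes query and key components are independent with zero mean and unit variance (as `torch.randn` produces); the helper name `score_variance` and the sample count are illustrative choices, not part of the lab.

```python
import torch

torch.manual_seed(0)
n_samples = 10_000  # illustrative sample count

def score_variance(d_k, scale):
    """Estimate the variance of q . k for random unit-variance q and k."""
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, d_k)
    scores = (q * k).sum(dim=-1)        # one dot product per sample
    if scale:
        scores = scores / d_k ** 0.5    # the 1/sqrt(d_k) factor from the formula
    return scores.var().item()

for d_k in (8, 64, 512):
    raw = score_variance(d_k, scale=False)
    scaled = score_variance(d_k, scale=True)
    print(f"d_k={d_k:4d}  raw variance={raw:8.2f}  scaled variance={scaled:.2f}")
```

Under these assumptions the raw variance should come out close to `d_k`, while the scaled variance stays near 1. Without scaling, larger `d_k` produces larger scores, which pushes the softmax toward a near one-hot output with very small gradients.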

**Your Task:** Complete the `scaled_dot_product_attention` function below.

In [None]:
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute scaled dot-product attention.
    
    Args:
        Q: Query tensor of shape (..., seq_len_q, d_k)
        K: Key tensor of shape (..., seq_len_k, d_k)
        V: Value tensor of shape (..., seq_len_k, d_v)
        mask: Optional boolean tensor broadcastable to (..., seq_len_q, seq_len_k); True = keep, False = mask out
    
    Returns:
        output: Attention output of shape (..., seq_len_q, d_v)
        attention_weights: Attention weights of shape (..., seq_len_q, seq_len_k)
    """
    d_k = Q.size(-1)
    
    # TODO: Step 1 - Compute attention scores (Q @ K^T)
    scores = None  # Your code here
    
    # TODO: Step 2 - Scale by sqrt(d_k)
    scores = None  # Your code here
    
    # TODO: Step 3 - Apply mask if provided (set positions where mask is False to -inf)
    if mask is not None:
        pass  # Your code here
    
    # TODO: Step 4 - Apply softmax to get attention weights
    attention_weights = None  # Your code here
    
    # TODO: Step 5 - Multiply by values
    output = None  # Your code here
    
    return output, attention_weights

In [None]:
# Test your implementation
batch_size = 2
seq_len = 4
d_k = 8

Q = torch.randn(batch_size, seq_len, d_k)
K = torch.randn(batch_size, seq_len, d_k)
V = torch.randn(batch_size, seq_len, d_k)

output, weights = scaled_dot_product_attention(Q, K, V)

print(f"Output shape: {output.shape}")  # Expected: (2, 4, 8)
print(f"Weights shape: {weights.shape}")  # Expected: (2, 4, 4)
print(f"Weights sum per row: {weights.sum(dim=-1)}")  # Should be all 1s

---

## Exercise 2: Multi-Head Attention

Implement multi-head attention, which lets the model attend to several representation subspaces in parallel.

**Formula:**
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O$$

Where each head is:
$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$
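
In practice the per-head projections are usually computed with one `d_model`-wide linear layer whose output is then split into heads by reshaping. The sketch below shows only that shape bookkeeping, using one common `view`/`transpose` pattern; the sizes match the lab's test cell, and this is one possible approach rather than the required solution.

```python
import torch

batch, seq_len, d_model, num_heads = 2, 10, 64, 8
d_k = d_model // num_heads

x = torch.randn(batch, seq_len, d_model)

# Split: (batch, seq, d_model) -> (batch, seq, heads, d_k) -> (batch, heads, seq, d_k)
heads = x.view(batch, seq_len, num_heads, d_k).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 8, 10, 8])

# Head 0 holds the first d_k features of every token.
print(torch.equal(heads[:, 0], x[..., :d_k]))  # True

# Merge: transpose back, then flatten the last two dimensions.
# .contiguous() is a no-op in this round trip, but it is required when the
# tensor being merged was produced by a computation (as the attention output
# is), because .view cannot flatten a transposed tensor.
merged = heads.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
print(torch.equal(merged, x))  # True
```

Because the head dimension sits before the sequence dimension after the split, a matrix product over the last two dimensions treats every (batch, head) pair independently.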

**Your Task:** Complete the `MultiHeadAttention` class below.

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # TODO: Create linear projections for Q, K, V and output
        self.W_q = None  # Your code here
        self.W_k = None  # Your code here
        self.W_v = None  # Your code here
        self.W_o = None  # Your code here
    
    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
        
        # TODO: Step 1 - Linear projections
        Q = None  # Your code here
        K = None  # Your code here
        V = None  # Your code here
        
        # TODO: Step 2 - Reshape for multi-head: (batch, seq, d_model) -> (batch, heads, seq, d_k)
        Q = None  # Your code here
        K = None  # Your code here
        V = None  # Your code here
        
        # TODO: Step 3 - Apply scaled dot-product attention
        attn_output, attn_weights = None, None  # Your code here
        
        # TODO: Step 4 - Concatenate heads: (batch, heads, seq, d_k) -> (batch, seq, d_model)
        attn_output = None  # Your code here
        
        # TODO: Step 5 - Final linear projection
        output = None  # Your code here
        
        return output, attn_weights

In [None]:
# Test your implementation
d_model = 64
num_heads = 8
seq_len = 10
batch_size = 2

mha = MultiHeadAttention(d_model, num_heads)
x = torch.randn(batch_size, seq_len, d_model)

output, weights = mha(x, x, x)  # Self-attention

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")  # Expected: (2, 10, 64)
print(f"Attention weights shape: {weights.shape}")  # Expected: (2, 8, 10, 10)

---

## Exercise 3: Attention Visualization

Visualize attention patterns to understand what the model is "looking at".

**Your Task:** Complete the visualization function and analyze the attention patterns.
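
A numeric summary can complement the heatmap when analyzing patterns. The sketch below computes the mean row entropy of each head: values near $\log(\text{seq\_len})$ indicate nearly uniform attention, and values near 0 indicate that each token focuses on a single position. The helper `attention_entropy` is a hypothetical name introduced here, not part of the lab code.

```python
import math
import torch

def attention_entropy(attention_weights, eps=1e-12):
    """Mean row entropy per head, in nats.

    attention_weights: tensor of shape (batch, heads, seq, seq), rows summing to 1.
    Returns a tensor of shape (heads,).
    """
    safe = attention_weights.clamp_min(eps)                       # avoid log(0)
    row_entropy = -(attention_weights * safe.log()).sum(dim=-1)   # (batch, heads, seq)
    return row_entropy.mean(dim=(0, 2))

seq_len = 6
uniform = torch.full((1, 2, seq_len, seq_len), 1.0 / seq_len)   # every token attends equally
peaked = torch.eye(seq_len).expand(1, 2, seq_len, seq_len)      # every token attends to itself

print(attention_entropy(uniform))  # close to log(6) for both heads
print(attention_entropy(peaked))   # close to 0 for both heads
print(math.log(seq_len))
```

An untrained `MultiHeadAttention` layer fed random inputs will typically land much closer to the uniform case, so expect fairly diffuse heatmaps in this exercise.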

In [None]:
def visualize_attention(attention_weights, tokens, head_idx=0):
    """
    Visualize attention weights as a heatmap.
    
    Args:
        attention_weights: Tensor of shape (batch, heads, seq, seq)
        tokens: List of token strings
        head_idx: Which attention head to visualize
    """
    # TODO: Extract weights for first batch item and specified head
    weights = None  # Your code here - shape should be (seq, seq)
    
    # TODO: Create heatmap using matplotlib
    # Hint: Use plt.imshow() with 'Blues' colormap
    
    # Your visualization code here
    pass

In [None]:
# Create sample attention scenario
tokens = ["The", "cat", "sat", "on", "the", "mat"]
seq_len = len(tokens)

# Create sample input
x = torch.randn(1, seq_len, d_model)

# Get attention weights
mha = MultiHeadAttention(d_model, num_heads)
_, attention_weights = mha(x, x, x)

# Visualize different heads
for head in range(min(4, num_heads)):
    visualize_attention(attention_weights, tokens, head_idx=head)

---

## Exercise 4: Causal (Masked) Attention

Implement causal masking for autoregressive models like GPT. With a causal mask, each position attends only to itself and earlier positions, which prevents the model from "seeing the future".

**Your Task:** Create a causal mask and apply it to attention.
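
One way to check what "not seeing the future" means is to perturb a later token and confirm that earlier outputs do not move. The sketch below does this with PyTorch's built-in `F.scaled_dot_product_attention` and its `is_causal` flag, which assumes PyTorch 2.0 or later; the tensor sizes and the size of the perturbation are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_k = 5, 8

# Shape (batch, heads, seq, d_k) with a single head
Q = torch.randn(1, 1, seq_len, d_k)
K = torch.randn(1, 1, seq_len, d_k)
V = torch.randn(1, 1, seq_len, d_k)

out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

# Change only the key and value of the LAST position
K2, V2 = K.clone(), V.clone()
K2[:, :, -1] += 10.0
V2[:, :, -1] += 10.0
out2 = F.scaled_dot_product_attention(Q, K2, V2, is_causal=True)

# Earlier positions never attend to the last one, so their outputs are unchanged
earlier_unchanged = torch.allclose(out[:, :, :-1], out2[:, :, :-1], atol=1e-6)
# The last position attends to itself, so its output does change
last_changed = not torch.allclose(out[:, :, -1], out2[:, :, -1])

print(f"Earlier positions unchanged: {earlier_unchanged}")
print(f"Last position changed: {last_changed}")
```

The same perturbation test can be run against your own `scaled_dot_product_attention` with the mask from `create_causal_mask` once both are implemented.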

In [None]:
def create_causal_mask(seq_len):
    """
    Create a causal (lower triangular) mask.
    
    Args:
        seq_len: Length of the sequence
    
    Returns:
        mask: Boolean tensor where True = keep, False = mask
    """
    # TODO: Create lower triangular mask
    # Hint: Apply torch.tril() to a torch.ones(seq_len, seq_len) matrix, then convert to bool
    mask = None  # Your code here
    
    return mask

In [None]:
# Test causal mask
seq_len = 5
mask = create_causal_mask(seq_len)

print("Causal Mask:")
print(mask.int())

# Expected output:
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])

In [None]:
# Apply causal mask to attention (reuses d_model from the Exercise 2 test cell)
Q = torch.randn(1, seq_len, d_model)
K = torch.randn(1, seq_len, d_model)
V = torch.randn(1, seq_len, d_model)

# Attention without mask
output_no_mask, weights_no_mask = scaled_dot_product_attention(Q, K, V)

# Attention with causal mask
causal_mask = create_causal_mask(seq_len)
output_masked, weights_masked = scaled_dot_product_attention(Q, K, V, mask=causal_mask)

print("Attention weights WITHOUT mask:")
print(weights_no_mask[0].detach().numpy().round(2))

print("\nAttention weights WITH causal mask:")
print(weights_masked[0].detach().numpy().round(2))

---

## Checkpoint

Congratulations! You've completed Lab 1. You should now understand:

- How scaled dot-product attention computes relevance between tokens
- How multi-head attention allows learning multiple attention patterns
- How to visualize and interpret attention weights
- How causal masking enables autoregressive generation

**Next:** Lab 2 - Building LangChain Agents