# Lab 1: Transformer Attention Mechanism - SOLUTIONS

**Module 1 - Foundations of Modern LLMs**

| Duration | Difficulty | Framework | Exercises |
|----------|------------|-----------|----------|
| 90 min | Intermediate | PyTorch | 4 |

## Learning Objectives

- Implement scaled dot-product attention from scratch
- Build a multi-head attention mechanism
- Visualize attention patterns
- Understand masking for causal attention

## Setup

In [None]:
# Install dependencies if needed
# !pip install torch numpy matplotlib

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import math

# Seed both random number generators used in this lab (PyTorch and NumPy)
# with the same value so repeated runs draw the same random numbers.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Report the installed PyTorch version.
print(f"PyTorch version: {torch.__version__}")

---

## Exercise 1: Scaled Dot-Product Attention - SOLUTION

**Formula:**
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

In [None]:
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k)) @ V.

    Args:
        Q: Query tensor of shape (batch, seq_len, d_k) or (batch, heads, seq_len, d_k)
        K: Key tensor of shape (batch, seq_len, d_k) or (batch, heads, seq_len, d_k)
        V: Value tensor of shape (batch, seq_len, d_v) or (batch, heads, seq_len, d_v)
        mask: Optional mask tensor that broadcasts against the score matrix
            (..., query positions, key positions). Entries equal to 0 (or
            False) are excluded from attention; every other entry is kept.

    Returns:
        output: Attention output (attention weights applied to V)
        attention_weights: Attention weight matrix. Each query row sums to 1,
            except rows whose keys are all masked: those rows are all zeros
            (and produce an all-zero output row) instead of NaN.
    """
    d_k = Q.size(-1)

    # Steps 1 & 2 - Similarity scores (Q @ K^T), scaled by sqrt(d_k)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    # Step 3 - Apply mask if provided (set masked positions to -inf)
    fully_masked = None
    if mask is not None:
        blocked = mask == 0
        scores = scores.masked_fill(blocked, float('-inf'))

        # A row that is entirely -inf makes softmax return NaN (0 / 0).
        # Give such rows finite placeholder scores here and zero their
        # weights after the softmax, so neither the forward pass nor the
        # gradients contain NaN.
        fully_masked = blocked.all(dim=-1, keepdim=True)
        scores = scores.masked_fill(fully_masked, 0.0)

    # Step 4 - Apply softmax over the key dimension to get attention weights
    attention_weights = F.softmax(scores, dim=-1)
    if fully_masked is not None:
        attention_weights = attention_weights.masked_fill(fully_masked, 0.0)

    # Step 5 - Weighted sum of the values
    output = torch.matmul(attention_weights, V)

    return output, attention_weights

In [None]:
# Sanity-check scaled_dot_product_attention on small random tensors
batch_size, seq_len, d_k = 2, 4, 8

# Q, K and V are drawn in this order from the seeded generator.
Q, K, V = (torch.randn(batch_size, seq_len, d_k) for _ in range(3))

output, weights = scaled_dot_product_attention(Q, K, V)

print(f"Output shape: {output.shape}")  # Expected: (2, 4, 8)
print(f"Weights shape: {weights.shape}")  # Expected: (2, 4, 4)
print(f"Weights sum per row: {weights.sum(dim=-1)}")  # Should be all 1s

---

## Exercise 2: Multi-Head Attention - SOLUTION

**Formula:**
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O, \quad \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention built on scaled_dot_product_attention.

    Inputs are projected with learned linear layers, split into
    ``num_heads`` heads of width ``d_model // num_heads``, attended
    independently, concatenated again, and passed through a final linear
    projection.

    Args:
        d_model: Width of the input and output features
        num_heads: Number of attention heads; must divide d_model evenly
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head feature width

        # Linear projections for Q, K, V and output
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def _split_heads(self, x):
        """Reshape (batch, seq, d_model) -> (batch, heads, seq, d_k)."""
        # Use this tensor's own sequence length: in cross-attention the
        # keys/values may be longer or shorter than the queries.
        batch_size, seq_len = x.size(0), x.size(1)
        return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, Q, K, V, mask=None):
        """
        Apply multi-head attention.

        Args:
            Q: Query tensor of shape (batch, query_len, d_model)
            K: Key tensor of shape (batch, key_len, d_model)
            V: Value tensor of shape (batch, key_len, d_model)
            mask: Optional mask passed unchanged to
                scaled_dot_product_attention; it must broadcast against the
                per-head scores of shape (batch, heads, query_len, key_len)

        Returns:
            output: Tensor of shape (batch, query_len, d_model)
            attn_weights: Tensor of shape (batch, heads, query_len, key_len)
        """
        batch_size, query_len = Q.size(0), Q.size(1)

        # Steps 1 & 2 - Linear projections, then split into heads
        Q = self._split_heads(self.W_q(Q))
        K = self._split_heads(self.W_k(K))
        V = self._split_heads(self.W_v(V))

        # Step 3 - Apply scaled dot-product attention to every head at once
        attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)

        # Step 4 - Concatenate heads: (batch, heads, seq, d_k) -> (batch, seq, d_model)
        # transpose() returns a non-contiguous view, so contiguous() is
        # required before view().
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, query_len, self.d_model)

        # Step 5 - Final linear projection
        output = self.W_o(attn_output)

        return output, attn_weights

In [None]:
# Smoke-test multi-head self-attention on a random batch
d_model, num_heads = 64, 8
seq_len, batch_size = 10, 2

# The layer is constructed before the input is drawn.
mha = MultiHeadAttention(d_model, num_heads)
x = torch.randn(batch_size, seq_len, d_model)

# Self-attention: the same tensor serves as query, key and value
output, weights = mha(x, x, x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")  # Expected: (2, 10, 64)
print(f"Attention weights shape: {weights.shape}")  # Expected: (2, 8, 10, 10)

---

## Exercise 3: Attention Visualization - SOLUTION

In [None]:
def visualize_attention(attention_weights, tokens, head_idx=0):
    """
    Visualize attention weights as an annotated heatmap.

    Only the first item of the batch is plotted. Rows are queries and
    columns are keys; every cell is labelled with its weight rounded to
    two decimals.

    Args:
        attention_weights: Tensor of shape (batch, heads, seq, seq)
        tokens: List of token strings used as tick labels; the annotation
            loop indexes the weights up to len(tokens) on both axes
        head_idx: Which attention head to visualize
    """
    # Extract weights for first batch item and specified head.
    # .cpu() lets this also work for tensors on another device, since
    # Tensor.numpy() only supports CPU tensors.
    weights = attention_weights[0, head_idx].detach().cpu().numpy()
    num_tokens = len(tokens)

    # Create figure
    fig, ax = plt.subplots(figsize=(8, 6))

    # Create heatmap
    im = ax.imshow(weights, cmap='Blues', aspect='auto')

    # Set ticks and labels
    ax.set_xticks(range(num_tokens))
    ax.set_yticks(range(num_tokens))
    ax.set_xticklabels(tokens, rotation=45, ha='right')
    ax.set_yticklabels(tokens)

    # Add colorbar
    plt.colorbar(im, ax=ax)

    # Labels
    ax.set_xlabel('Key (attending to)')
    ax.set_ylabel('Query (attending from)')
    ax.set_title(f'Attention Weights - Head {head_idx}')

    # Add value annotations; switch to white text on cells with weight
    # above 0.5 so the label stays readable on the darker colour.
    for i in range(num_tokens):
        for j in range(num_tokens):
            ax.text(j, i, f'{weights[i, j]:.2f}',
                    ha='center', va='center', fontsize=8,
                    color='white' if weights[i, j] > 0.5 else 'black')

    plt.tight_layout()
    plt.show()

In [None]:
# Toy sentence used to label the heatmap axes
tokens = "The cat sat on the mat".split()
seq_len = len(tokens)

# One random "sentence" of embeddings, drawn before the layer is built
x = torch.randn(1, seq_len, d_model)

# Run self-attention with a freshly initialised layer and keep only the weights
mha = MultiHeadAttention(d_model, num_heads)
_, attention_weights = mha(x, x, x)

# Plot at most the first four heads
for head in range(min(4, num_heads)):
    visualize_attention(attention_weights, tokens, head_idx=head)

---

## Exercise 4: Causal (Masked) Attention - SOLUTION

In [None]:
def create_causal_mask(seq_len):
    """
    Create a causal (lower triangular) mask.

    Position (i, j) is True when j <= i, i.e. query position i may attend
    to key position j only if j is not in its future.

    Args:
        seq_len: Length of the sequence

    Returns:
        mask: Boolean tensor of shape (seq_len, seq_len) where
            True = keep, False = mask
    """
    # Lower triangle (including the diagonal) of an all-ones matrix,
    # converted to bool so the result matches the documented contract.
    lower_triangle = torch.tril(torch.ones(seq_len, seq_len))

    return lower_triangle.bool()

In [None]:
# Build a 5x5 causal mask and display it as integers
seq_len = 5
mask = create_causal_mask(seq_len)

print("Causal Mask:", mask.int(), sep="\n")

# Expected output:
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])

In [None]:
# Compare attention with and without a causal mask.
# Full-width (d_model) random tensors; the comparison below runs on the
# smaller *_test tensors instead.
Q, K, V = (torch.randn(1, seq_len, d_model) for _ in range(3))

# The demo calls scaled_dot_product_attention directly on fresh random
# tensors whose feature width is d_k_test.
d_k_test = 8
Q_test, K_test, V_test = (torch.randn(1, seq_len, d_k_test) for _ in range(3))

# Baseline: every query may attend to every key
output_no_mask, weights_no_mask = scaled_dot_product_attention(Q_test, K_test, V_test)

# Causal: the mask is passed so each query is limited to the allowed keys
causal_mask = create_causal_mask(seq_len)
output_masked, weights_masked = scaled_dot_product_attention(
    Q_test, K_test, V_test, mask=causal_mask
)

print("Attention weights WITHOUT mask:")
print(weights_no_mask[0].detach().numpy().round(2))

print("\nAttention weights WITH causal mask:")
print(weights_masked[0].detach().numpy().round(2))

print("\nNotice: With causal mask, upper triangular values are 0 (future tokens masked)")

In [None]:
# Side-by-side heatmaps of the unmasked and causally masked attention weights
def _plot_weights(ax, weight_batch, title):
    """Draw the first batch item of weight_batch on ax with a colorbar; return the image."""
    image = ax.imshow(weight_batch[0].detach().numpy(), cmap='Blues')
    ax.set_title(title)
    ax.set_xlabel('Key Position')
    ax.set_ylabel('Query Position')
    plt.colorbar(image, ax=ax)
    return image


fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left panel: without mask; right panel: with mask
im1 = _plot_weights(axes[0], weights_no_mask, 'Attention WITHOUT Causal Mask')
im2 = _plot_weights(axes[1], weights_masked, 'Attention WITH Causal Mask')

plt.tight_layout()
plt.show()

---

## Bonus: Complete Transformer Block

Here's how multi-head attention fits into a complete transformer block:

In [None]:
class TransformerBlock(nn.Module):
    """
    Transformer block: multi-head self-attention followed by a feed-forward net.

    Each sub-layer's output goes through dropout, is added back to the
    sub-layer's input (residual connection), and the sum is layer-normalized.

    Args:
        d_model: Feature width of the block's input and output
        num_heads: Number of attention heads
        d_ff: Hidden width of the feed-forward network
        dropout: Dropout probability used inside the feed-forward network
            and on both sub-layer outputs
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        # Self-attention sub-layer
        self.attention = MultiHeadAttention(d_model, num_heads)

        # Feed-forward sub-layer: expand to d_ff, ReLU, dropout, project back
        feed_forward_layers = [
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        ]
        self.ffn = nn.Sequential(*feed_forward_layers)

        # One layer norm per sub-layer
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Dropout applied to each sub-layer output before the residual add
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Run the block on x.

        Args:
            x: Input tensor, used as query, key and value for self-attention
            mask: Optional attention mask forwarded to the attention layer

        Returns:
            The transformed tensor and the attention weights returned by the
            attention layer.
        """
        # Sub-layer 1: self-attention, residual add, layer norm
        attended, attn_weights = self.attention(x, x, x, mask)
        after_attention = self.norm1(x + self.dropout(attended))

        # Sub-layer 2: feed-forward, residual add, layer norm
        transformed = self.ffn(after_attention)
        result = self.norm2(after_attention + self.dropout(transformed))

        return result, attn_weights


# Smoke-test the transformer block: a random batch goes in and the printed
# output shape can be compared with the input shape (2, 10, 64).
block = TransformerBlock(d_model=64, num_heads=8, d_ff=256)
x = torch.randn(2, 10, 64)
output, weights = block(x, mask=None)
print(f"Transformer block output shape: {output.shape}")

---

## Checkpoint

Congratulations! You've completed Lab 1. You should now understand:

- How scaled dot-product attention computes relevance between tokens
- How multi-head attention allows learning multiple attention patterns
- How to visualize and interpret attention weights
- How causal masking enables autoregressive generation

**Next:** Lab 2 - Building LangChain Agents