# Attention and Self-Attention: Exercises
## NLP and LLMs Course - AIMS South Africa 2025

### Instructions
Complete all TODO sections in this notebook. Test your implementations with the provided test cases. 

**Learning Objectives:**
1. Implement dot-product attention mechanism
2. Implement scaled dot-product attention
3. Implement self-attention layer
4. Implement multi-head attention
5. Build a causal mask for autoregressive attention

**Submission:** Complete notebook with all cells executed successfully.

In [None]:
# Import required libraries
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)

# Device configuration - GPU support
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    torch.cuda.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    print("✓ GPU acceleration enabled!")
else:
    print("⚠️  Running on CPU (slower but works fine)")

## Exercise 1: Dot-Product Attention (20 points)

Implement the basic dot-product attention mechanism.

**Formula:**
```
scores = Q @ K^T
attention_weights = softmax(scores)
output = attention_weights @ V
```
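
To see the three formula lines in action, here is a hand-checkable NumPy trace for a single query (no batch dimension) against three keys. The toy vectors and the `softmax` helper are illustrative assumptions, not part of the exercise API; the batched PyTorch version is what the exercise asks for.

```python
import numpy as np

def softmax(x):
    # Subtracting the max avoids overflow in exp and leaves the result unchanged
    e = np.exp(x - np.max(x))
    return e / e.sum()

# Toy example: one query, seq_len = 3, hidden_dim = 2
q = np.array([1.0, 0.0])
K = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0]])
V = np.array([[10.0, 0.0],
              [0.0, 10.0],
              [5.0, 5.0]])

scores = q @ K.T           # one score per key
weights = softmax(scores)  # non-negative and sums to 1
output = weights @ V       # weighted average of the rows of V

print(scores)              # [1. 0. 1.]
print(weights.round(3))    # [0.422 0.155 0.422]
print(output.round(3))     # [6.335 3.665]
```

The two keys that point in the same direction as `q` receive equal, larger weights, so the output is pulled towards their value rows.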

**Args:**
- `query`: (batch_size, hidden_dim)
- `keys`: (batch_size, seq_len, hidden_dim)
- `values`: (batch_size, seq_len, hidden_dim)

**Returns:**
- `attention_output`: (batch_size, hidden_dim)
- `attention_weights`: (batch_size, seq_len)
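
The shape hints in the TODOs rely on a few PyTorch operations: `torch.bmm` (batched matrix multiply), `unsqueeze`/`squeeze` (add or drop a size-1 dimension) and `transpose`. The sketch below exercises them on arbitrary random tensors; the names and sizes are illustrative only.

```python
import torch

b, m = 2, 8
a = torch.randn(b, 1, m)
w = torch.randn(b, m, 4)
c = torch.bmm(a, w)               # (b, 1, m) @ (b, m, 4) -> (b, 1, 4)
print(c.shape)                    # torch.Size([2, 1, 4])

vec = torch.randn(b, m)
print(vec.unsqueeze(1).shape)     # torch.Size([2, 1, 8]): new size-1 dim at index 1
print(c.squeeze(1).shape)         # torch.Size([2, 4]): size-1 dim at index 1 removed

seq = torch.randn(b, 4, m)
print(seq.transpose(1, 2).shape)  # torch.Size([2, 8, 4]): last two dims swapped
```

`torch.bmm` requires two 3-D tensors with the same batch size; `torch.matmul` (or `@`) does the same job and also broadcasts batch dimensions.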

In [None]:
def dot_product_attention(query, keys, values):
    """
    Implement dot-product attention mechanism.
    """
    # TODO: Step 1 - Compute attention scores (query @ keys^T)
    # Hint: query needs to be (batch_size, 1, hidden_dim) for batch matrix multiplication
    # Hint: keys need to be transposed to (batch_size, hidden_dim, seq_len)
    scores = None  # YOUR CODE HERE
    
    # TODO: Step 2 - Apply softmax to get attention weights
    # The weights should sum to 1 along the sequence dimension
    attention_weights = None  # YOUR CODE HERE
    
    # TODO: Step 3 - Compute weighted sum of values
    # Hint: attention_weights should be (batch_size, 1, seq_len)
    attention_output = None  # YOUR CODE HERE
    
    return attention_output, attention_weights

In [None]:
# Test your implementation
print("Testing dot_product_attention...")
batch_size, seq_len, hidden_dim = 2, 4, 8
query = torch.randn(batch_size, hidden_dim)
keys = torch.randn(batch_size, seq_len, hidden_dim)
values = torch.randn(batch_size, seq_len, hidden_dim)

output, weights = dot_product_attention(query, keys, values)

# Assertions
assert output.shape == (batch_size, hidden_dim), f"Output shape should be {(batch_size, hidden_dim)}, got {output.shape}"
assert weights.shape == (batch_size, seq_len), f"Weights shape should be {(batch_size, seq_len)}, got {weights.shape}"
assert torch.allclose(weights.sum(dim=1), torch.ones(batch_size), atol=1e-6), "Attention weights should sum to 1"

print("✓ All tests passed!")
print(f"Attention output shape: {output.shape}")
print(f"Attention weights shape: {weights.shape}")
print(f"Attention weights (first sample): {weights[0]}")

## Exercise 2: Scaled Dot-Product Attention (20 points)

Implement scaled dot-product attention used in Transformers.

**Formula:**
```
scores = (Q @ K^T) / sqrt(d_k)
attention_weights = softmax(scores)
output = attention_weights @ V
```

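The scaling factor keeps the scores from growing with the dimension. If the components of a query and a key are independent with zero mean and unit variance, their dot product has variance d_k. Large scores push softmax into a nearly one-hot regime where its gradients are tiny; dividing by sqrt(d_k) brings the variance back to about 1.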
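
A quick empirical check of why scaling helps, assuming (for illustration only) that the entries of q and k are independent standard normal: the variance of the raw dot product q·k grows like d_k, while q·k / sqrt(d_k) stays near 1.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 10_000  # number of sampled (q, k) pairs
results = {}

for d_k in (4, 64, 256):
    q = rng.standard_normal((n, d_k))
    k = rng.standard_normal((n, d_k))
    dots = (q * k).sum(axis=1)      # n independent dot products q . k
    scaled = dots / np.sqrt(d_k)    # the Transformer scaling
    results[d_k] = (dots.var(), scaled.var())
    print(f"d_k={d_k:3d}  var(q.k)={dots.var():7.1f}  var(q.k/sqrt(d_k))={scaled.var():.2f}")
```

With d_k = 64 (the per-head size in the original Transformer), unscaled scores would have a standard deviation of about 8, which is typically enough to make softmax close to one-hot.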

In [None]:
def scaled_dot_product_attention(query, keys, values, mask=None):
    """
    Implement scaled dot-product attention.
    
    Args:
        query: (batch_size, hidden_dim)
        keys: (batch_size, seq_len, hidden_dim)
        values: (batch_size, seq_len, hidden_dim)
        mask: Optional (batch_size, seq_len) mask (1 for valid, 0 for masked)
    
    Returns:
        attention_output: (batch_size, hidden_dim)
        attention_weights: (batch_size, seq_len)
    """
    # TODO: Get the dimension for scaling
    d_k = None  # YOUR CODE HERE
    
    # TODO: Compute attention scores and scale by sqrt(d_k)
    scores = None  # YOUR CODE HERE
    
    # TODO: Apply mask if provided (set masked positions to -inf)
    if mask is not None:
        pass  # YOUR CODE HERE
    
    # TODO: Apply softmax to get attention weights
    attention_weights = None  # YOUR CODE HERE
    
    # TODO: Compute weighted sum of values
    attention_output = None  # YOUR CODE HERE
    
    return attention_output, attention_weights

In [None]:
# Test your implementation
print("Testing scaled_dot_product_attention...")
output_scaled, weights_scaled = scaled_dot_product_attention(query, keys, values)

# Test with mask
mask = torch.ones(batch_size, seq_len)
mask[:, 2:] = 0  # Mask out positions from index 2 onward
output_masked, weights_masked = scaled_dot_product_attention(query, keys, values, mask)

# Assertions
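assert output_scaled.shape == (batch_size, hidden_dim), f"Output shape should be {(batch_size, hidden_dim)}, got {output_scaled.shape}"
assert weights_scaled.shape == (batch_size, seq_len), f"Weights shape should be {(batch_size, seq_len)}, got {weights_scaled.shape}"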
assert torch.allclose(weights_scaled.sum(dim=1), torch.ones(batch_size), atol=1e-6), "Weights should sum to 1"
assert torch.allclose(weights_masked[:, 2:], torch.zeros(batch_size, seq_len-2), atol=1e-6), "Masked positions should have 0 weight"

print("✓ All tests passed!")
print(f"Weights without mask: {weights_scaled[0]}")
print(f"Weights with mask: {weights_masked[0]}")
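
The masking step in isolation: a minimal sketch, assuming the convention used above (1 = valid, 0 = masked). Filling masked scores with `-inf` via `Tensor.masked_fill` makes `exp` return exactly 0 there, so those positions get zero weight. The second row shows a caveat worth knowing: if every position in a row is masked, softmax returns NaN. The `demo_` names are illustrative.

```python
import torch

demo_scores = torch.tensor([[2.0, 1.0, 0.5, 3.0],
                            [0.0, 0.0, 0.0, 0.0]])
demo_mask = torch.tensor([[1, 1, 0, 0],
                          [0, 0, 0, 0]])  # second row: every position masked

# Masked positions get -inf so that exp(-inf) = 0 inside softmax
demo_masked = demo_scores.masked_fill(demo_mask == 0, float('-inf'))
demo_weights = torch.softmax(demo_masked, dim=-1)
print(demo_weights)
# First row: only the first two positions share the probability mass
# Second row: all NaN, because softmax over a row of -inf is 0 / 0
```

A causal mask never hits this case because the diagonal is always visible, but padding masks should not mask out an entire row.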

## Exercise 3: Self-Attention Layer (30 points)

Implement a complete self-attention layer with learnable projections for Q, K, and V.

**Key Points:**
- All Q, K, V come from the same input sequence
- Use learnable linear projections: W_q, W_k, W_v
- Output should have the same shape as input
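
The projections W_q, W_k and W_v are ordinary `nn.Linear` layers. With `bias=False`, `nn.Linear` computes `x @ weight.T` over the last dimension, where `weight` has shape `(out_features, in_features)`, so it applies directly to a `(batch_size, seq_len, embed_dim)` tensor. The small check below confirms this; `demo_dim`, `proj` and `x_demo` are illustrative names.

```python
import torch
import torch.nn as nn

demo_dim = 8
proj = nn.Linear(demo_dim, demo_dim, bias=False)
x_demo = torch.randn(2, 5, demo_dim)  # (batch_size, seq_len, embed_dim)

out = proj(x_demo)                    # applied to the last dimension only
manual = x_demo @ proj.weight.T       # the same computation written out

print(proj.weight.shape)              # torch.Size([8, 8])
print(out.shape)                      # torch.Size([2, 5, 8])
print(torch.allclose(out, manual, atol=1e-5))  # True
```

In self-attention the same `x` passes through three such layers with separate weights, which is what lets queries, keys and values differ even though they come from one input.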

In [None]:
class SelfAttention(nn.Module):
    def __init__(self, embed_dim):
        """
        Self-Attention layer.
        
        Args:
            embed_dim: Embedding dimension
        """
        super(SelfAttention, self).__init__()
        self.embed_dim = embed_dim
        
        # TODO: Initialize learnable projection matrices
        # Hint: Use nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_q = None  # YOUR CODE HERE
        self.W_k = None  # YOUR CODE HERE
        self.W_v = None  # YOUR CODE HERE
        
    def forward(self, x, mask=None):
        """
        Forward pass.
        
        Args:
            x: Input tensor (batch_size, seq_len, embed_dim)
            mask: Optional mask (batch_size, seq_len, seq_len)
        
        Returns:
            output: (batch_size, seq_len, embed_dim)
            attention_weights: (batch_size, seq_len, seq_len)
        """
        batch_size, seq_len, embed_dim = x.shape
        
        # TODO: Project input to queries, keys, and values
        Q = None  # YOUR CODE HERE - shape: (batch_size, seq_len, embed_dim)
        K = None  # YOUR CODE HERE - shape: (batch_size, seq_len, embed_dim)
        V = None  # YOUR CODE HERE - shape: (batch_size, seq_len, embed_dim)
        
        # TODO: Compute attention scores: Q @ K^T
        # Hint: Use torch.bmm or torch.matmul
        scores = None  # YOUR CODE HERE - shape: (batch_size, seq_len, seq_len)
        
        # TODO: Scale scores by sqrt(embed_dim)
        scores = None  # YOUR CODE HERE
        
        # TODO: Apply mask if provided (set positions where mask == 0 to -inf)
        if mask is not None:
            pass  # YOUR CODE HERE
        
        # TODO: Apply softmax along the last dimension
        attention_weights = None  # YOUR CODE HERE
        
        # TODO: Compute weighted sum: attention_weights @ V
        output = None  # YOUR CODE HERE - shape: (batch_size, seq_len, embed_dim)
        
        return output, attention_weights

In [None]:
# Test your implementation
print("Testing SelfAttention...")
batch_size, seq_len, embed_dim = 2, 5, 8
x = torch.randn(batch_size, seq_len, embed_dim)

self_attn = SelfAttention(embed_dim)
output, attn_weights = self_attn(x)

# Assertions
assert output.shape == (batch_size, seq_len, embed_dim), f"Output shape should be {(batch_size, seq_len, embed_dim)}, got {output.shape}"
assert attn_weights.shape == (batch_size, seq_len, seq_len), f"Attention weights shape should be {(batch_size, seq_len, seq_len)}, got {attn_weights.shape}"
assert torch.allclose(attn_weights.sum(dim=-1), torch.ones(batch_size, seq_len), atol=1e-6), "Attention weights should sum to 1 along last dimension"

print("✓ All tests passed!")
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")

## Exercise 4: Multi-Head Attention (30 points)

Implement multi-head attention, which allows the model to attend to different representation subspaces.

**Key Idea:**
- Split the embedding dimension into multiple heads
- Apply attention in parallel for each head
- Concatenate the head outputs and apply the output projection W_O

**Formula:**
```
head_i = Attention(Q_i, K_i, V_i)
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) @ W_O
```
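
Splitting into heads and concatenating them back is pure reshaping. The NumPy sketch below (sizes chosen to match the test cell; names are illustrative) shows that head h works on the contiguous feature slice `[h*head_dim, (h+1)*head_dim)`, and that the inverse transpose-and-reshape recovers the original tensor exactly. In PyTorch, `Tensor.view` needs a contiguous tensor, so after `transpose` use `.contiguous().view(...)` or `.reshape(...)`.

```python
import numpy as np

B, T, D, H = 2, 5, 16, 4      # batch, seq_len, embed_dim, num_heads
head_dim = D // H
x_np = np.arange(B * T * D, dtype=float).reshape(B, T, D)

# Split: (B, T, D) -> (B, T, H, head_dim) -> (B, H, T, head_dim)
heads = x_np.reshape(B, T, H, head_dim).transpose(0, 2, 1, 3)

# Head 1 holds exactly the feature slice [head_dim, 2 * head_dim)
print(np.array_equal(heads[:, 1], x_np[:, :, head_dim:2 * head_dim]))  # True

# Merge ("concatenate heads"): (B, H, T, head_dim) -> (B, T, H, head_dim) -> (B, T, D)
merged = heads.transpose(0, 2, 1, 3).reshape(B, T, D)
print(heads.shape, merged.shape)     # (2, 4, 5, 4) (2, 5, 16)
print(np.array_equal(merged, x_np))  # True
```

The merged tensor is what W_O is applied to in the formula above.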

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        """
        Multi-head attention.
        
        Args:
            embed_dim: Embedding dimension (must be divisible by num_heads)
            num_heads: Number of attention heads
        """
        super(MultiHeadAttention, self).__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        # TODO: Initialize projection matrices
        self.W_q = None  # YOUR CODE HERE
        self.W_k = None  # YOUR CODE HERE
        self.W_v = None  # YOUR CODE HERE
        self.W_o = None  # YOUR CODE HERE - output projection
        
    def forward(self, x, mask=None):
        """
        Forward pass.
        
        Args:
            x: (batch_size, seq_len, embed_dim)
            mask: Optional mask broadcastable to (batch_size, num_heads, seq_len, seq_len) (1 for valid, 0 for masked)
        
        Returns:
            output: (batch_size, seq_len, embed_dim)
            attention_weights: (batch_size, num_heads, seq_len, seq_len)
        """
        batch_size, seq_len, embed_dim = x.shape
        
        # TODO: Project to Q, K, V
        Q = None  # YOUR CODE HERE
        K = None  # YOUR CODE HERE
        V = None  # YOUR CODE HERE
        
        # TODO: Reshape for multi-head attention
        # Split embed_dim into (num_heads, head_dim)
        # Final shape: (batch_size, num_heads, seq_len, head_dim)
        Q = None  # YOUR CODE HERE
        K = None  # YOUR CODE HERE
        V = None  # YOUR CODE HERE
        
        # TODO: Compute attention scores for all heads
        # Shape: (batch_size, num_heads, seq_len, seq_len)
        scores = None  # YOUR CODE HERE
        
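        # TODO: Scale by sqrt(head_dim), since each head's dot products run over head_dim features
        scores = None  # YOUR CODE HERE
        
        # TODO: Apply mask if provided (set positions where mask == 0 to -inf)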
        if mask is not None:
            pass  # YOUR CODE HERE
        
        # TODO: Apply softmax
        attention_weights = None  # YOUR CODE HERE
        
        # TODO: Apply attention to values
        # Shape: (batch_size, num_heads, seq_len, head_dim)
        attended = None  # YOUR CODE HERE
        
        # TODO: Reshape back to (batch_size, seq_len, embed_dim)
        # Hint: transpose and reshape to concatenate heads
        attended = None  # YOUR CODE HERE
        
        # TODO: Apply output projection
        output = None  # YOUR CODE HERE
        
        return output, attention_weights

In [None]:
# Test your implementation
print("Testing MultiHeadAttention...")
batch_size, seq_len, embed_dim, num_heads = 2, 5, 16, 4
x = torch.randn(batch_size, seq_len, embed_dim)

mha = MultiHeadAttention(embed_dim, num_heads)
output, attn_weights = mha(x)

# Assertions
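assert output.shape == (batch_size, seq_len, embed_dim), f"Output shape should be {(batch_size, seq_len, embed_dim)}, got {output.shape}"
assert attn_weights.shape == (batch_size, num_heads, seq_len, seq_len), f"Attention weights shape should be {(batch_size, num_heads, seq_len, seq_len)}, got {attn_weights.shape}"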
assert torch.allclose(attn_weights.sum(dim=-1), torch.ones(batch_size, num_heads, seq_len), atol=1e-6), "Weights should sum to 1"

print("✓ All tests passed!")
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")

## Bonus Exercise: Causal Masking (10 bonus points)

Implement a causal mask for autoregressive generation (like in GPT).

**Purpose:** Prevent the model from attending to future tokens.
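
What "prevent attending to future tokens" buys you can be checked directly: with a causal mask, changing a later token must not change the outputs at earlier positions. The sketch below is a simplified NumPy illustration, assuming Q = K = V = x with no learned projections, and it builds the mask by comparing index grids rather than with `torch.tril`. `causal_self_attention` is a hypothetical helper, not part of the exercise.

```python
import numpy as np

def causal_self_attention(x):
    """Toy single-head self-attention: Q = K = V = x, no learned projections."""
    seq_len, d = x.shape
    scores = x @ x.T / np.sqrt(d)
    # Query position i may only attend to key positions j <= i
    allowed = np.arange(seq_len)[:, None] >= np.arange(seq_len)[None, :]
    scores = np.where(allowed, scores, -np.inf)
    # Numerically stable softmax over keys; exp(-inf) = 0 for future positions
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ x, weights

rng = np.random.default_rng(0)
tokens = rng.standard_normal((5, 4))
out_before, w_before = causal_self_attention(tokens)

tokens_changed = tokens.copy()
tokens_changed[-1] += 100.0        # perturb only the last token
out_after, _ = causal_self_attention(tokens_changed)

print(np.allclose(out_before[:-1], out_after[:-1]))  # True: earlier positions never see the last token
print(np.allclose(out_before[-1], out_after[-1]))    # False: the last position does
print(np.triu(w_before, k=1).max())                  # 0.0: no weight above the diagonal
```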

In [None]:
def create_causal_mask(seq_len):
    """
    Create a causal (lower-triangular) mask.
    
    Args:
        seq_len: Sequence length
    
    Returns:
        mask: (seq_len, seq_len) boolean tensor
              True for valid positions, False for masked
    """
    # TODO: Create lower triangular mask
    # Hint: Use torch.tril
    mask = None  # YOUR CODE HERE
    
    return mask

In [None]:
# Test causal mask
mask = create_causal_mask(5)
print("Causal mask:")
print(mask.int())

# Visualize
plt.figure(figsize=(6, 6))
sns.heatmap(mask.int().numpy(), cmap='Blues', cbar=False, square=True, 
            xticklabels=range(1, 6), yticklabels=range(1, 6))
plt.title('Causal Attention Mask\n(1 = can attend, 0 = masked)')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.show()

## Submission Checklist

Before submitting, ensure:

- [ ] All TODO sections are completed
- [ ] All test cells run without errors
- [ ] All assertions pass
- [ ] Code is well-commented
- [ ] You understand what each part does

**Total Points: 100 + 10 bonus**

Good luck! 🚀