# Building LLMs From Scratch (Part 8): Causal Attention

Welcome to Part 8! In this notebook, we'll implement **Causal Attention** (also known as Masked Self-Attention). This prevents the model from "cheating" by looking at future tokens, which is essential for autoregressive text generation.

### 🔗 Quick Links
- **Medium Article**: [Part 8: Causal Attention](https://soloshun.medium.com/building-llms-from-scratch-part-8-causal-attention)
- **GitHub Repository**: [llm-from-scratch](https://github.com/soloeinsteinmit/llm-from-scratch)

### 📋 What We'll Cover
1. **The Problem**: Why standard self-attention "cheats"
2. **The Solution**: Diagonal masking to hide future tokens
3. **Implementation**: Why we mask BEFORE softmax
4. **Causal Attention Class**: Production-ready implementation
5. **Dropout**: Regularization in attention


## Setup and Imports

Let's import the necessary libraries.


In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns

print("✅ All imports successful!")
print(f"PyTorch version: {torch.__version__}")


## The Problem: Future Leakage

Let's set up our example and see the problem with standard self-attention.


In [None]:
words = ['Your', 'journey', 'starts', 'with', 'one', 'step']

inputs = torch.tensor([
    [0.43, 0.15, 0.89], # Your
    [0.55, 0.87, 0.66], # journey
    [0.57, 0.85, 0.64], # starts
    [0.22, 0.58, 0.33], # with
    [0.77, 0.25, 0.10], # one
    [0.05, 0.80, 0.55], # step
])

d_in = inputs.shape[-1]  # embedding dimension = 3
d_out = 2  # output dimension

print(f"📝 Sentence: '{' '.join(words)}'")
print(f"🧠 Input shape: {inputs.shape}")
print(f"📐 d_in={d_in}, d_out={d_out}")


Let's create a standard self-attention mechanism (from Part 7) to see the problem:


In [None]:
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        context_vec = attn_weights @ values
        return context_vec

# Test it
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
output = sa_v2(inputs)

print(f"Output shape: {output.shape}")
print(f"Output:\n{output}")


Now let's look at the attention weights to see the problem:


In [None]:
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

print("Attention Weights (each row sums to 1):")
print(attn_weights)
print(f"\n❌ Problem: Token 'journey' (row 1) can see 'starts', 'with', 'one', 'step'")
print("   These are FUTURE tokens that shouldn't exist yet during generation!")


## The Solution: Causal Masking

To fix this, we need a **mask** that prevents tokens from attending to future positions. We start with a **lower triangular mask** where:
- **1** = allowed (past and current tokens)
- **0** = masked (future tokens)


In [None]:
# Create a causal mask using torch.tril (lower triangular)
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))

print("Lower Triangular Mask (1 = keep, 0 = hide):")
print(mask_simple)
print(f"\n✅ Each row can only see tokens up to and including its position")
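
This notebook uses two mask conventions: the keep-mask built here with `torch.tril`, and a hide-mask built with `torch.triu(..., diagonal=1)` in the cells that follow. A minimal sketch showing that the two are complements of each other (the names `keep_mask` and `hide_mask` and the size `n = 4` are illustrative, not part of the notebook):

```python
import torch

n = 4  # small context length, chosen only for illustration

# 1 = keep: the diagonal and everything below it
keep_mask = torch.tril(torch.ones(n, n))

# 1 = hide: everything strictly above the diagonal
hide_mask = torch.triu(torch.ones(n, n), diagonal=1)

# Every position is either kept or hidden, never both
print(torch.equal(keep_mask + hide_mask, torch.ones(n, n)))
print(torch.equal(hide_mask.bool(), ~keep_mask.bool()))
```

Which convention to use depends on the operation: a keep-mask suits elementwise multiplication, while a hide-mask suits `masked_fill`, which overwrites the positions where the mask is `True`.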


### Approach 1: Mask After Softmax (Naive ❌)

Let's first try the naive approach of masking AFTER softmax:


In [None]:
# Multiply attention weights by mask
masked_simple = attn_weights * mask_simple

print("After masking (but rows don't sum to 1):")
print(masked_simple)

# Need to renormalize
row_sums = masked_simple.sum(dim=1, keepdim=True)
masked_simple_norm = masked_simple / row_sums

print("\nAfter renormalization:")
print(masked_simple_norm)

print("\n❌ Drawback: this needs an extra renormalization step after softmax.")
print("   Without it, the rows no longer sum to 1. With it, the shared softmax")
print("   denominator cancels, so the result matches masking before softmax.")
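
A quick numerical check can make the relationship between the two masking routes concrete. The sketch below uses random stand-in scores rather than the notebook's attention scores, and the names `route_a` and `route_b` are illustrative. Dividing each kept softmax weight by the sum of the kept weights cancels the full softmax denominator, so masking after softmax with renormalization should agree with masking before softmax up to floating-point error:

```python
import torch

torch.manual_seed(0)
n = 5
scores = torch.randn(n, n)  # random stand-in attention scores

keep = torch.tril(torch.ones(n, n))  # 1 = keep, 0 = hide

# Route A: softmax first, then zero out future positions and renormalize
weights = torch.softmax(scores, dim=-1)
masked = weights * keep
route_a = masked / masked.sum(dim=-1, keepdim=True)

# Route B: set future positions to -inf, then apply softmax once
route_b = torch.softmax(scores.masked_fill(keep == 0, float("-inf")), dim=-1)

print(torch.allclose(route_a, route_b, atol=1e-6))
```

The practical case for masking before softmax is therefore simplicity: one softmax call, no separate renormalization step, and no intermediate tensor whose rows fail to sum to 1.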


### Approach 2: Mask Before Softmax (CORRECT ✅)

The correct way is to mask BEFORE applying softmax. We replace future positions with `-inf`, which becomes 0 after softmax.


In [None]:
# Create upper triangular mask (1 = hide, 0 = keep)
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)

print("Upper Triangular Mask (1 = hide, 0 = keep):")
print(mask)

# Replace masked positions with -inf
masked_scores = attn_scores.masked_fill(mask.bool(), -torch.inf)

print("\nAttention Scores after masking with -inf:")
print(masked_scores)

# Apply softmax
attn_weights_correct = torch.softmax(masked_scores / keys.shape[-1]**0.5, dim=-1)

print("\nAttention Weights after softmax:")
print(attn_weights_correct)

print("\n✅ Correct! Future tokens had NO influence on the probabilities.")
print("   Each row properly sums to 1.0")


## Dropout in Attention

**Dropout** is a regularization technique that randomly zeros out a fraction of values during training to prevent overfitting.


In [None]:
torch.manual_seed(123)
dropout = nn.Dropout(0.5)

# Example with simple tensor
example = torch.ones(6, 6)
print("Original tensor:")
print(example)

dropped = dropout(example)
print(f"\nAfter dropout (p=0.5):")
print(dropped)

print("\n💡 Key points:")
print("- Dropout zeros each value independently with probability p = 0.5")
print("- Remaining values are scaled by 2x (1/(1-p))")
print("- This helps prevent overfitting during training")
print("- Dropout is only active in training mode; model.eval() disables it")
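
To see the difference between training and evaluation mode, here is a minimal sketch using a small tensor of ones (the tensor shape and the names `train_out` and `eval_out` are illustrative). A freshly created module starts in training mode, and calling `.eval()` turns dropout into a pass-through:

```python
import torch
import torch.nn as nn

torch.manual_seed(123)
dropout = nn.Dropout(0.5)
x = torch.ones(2, 4)

dropout.train()          # training mode (the default for a new module)
train_out = dropout(x)   # each entry is either 0.0 or 2.0

dropout.eval()           # evaluation mode: dropout does nothing
eval_out = dropout(x)

print(train_out)
print(torch.equal(eval_out, x))
```

For a full model, calling `model.eval()` switches every dropout layer inside it at once, and `model.train()` switches them back.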


Now let's apply dropout to our attention weights:


In [None]:
torch.manual_seed(123)
dropped_weights = dropout(attn_weights_correct)

print("Attention weights after dropout:")
print(dropped_weights)
print("\n✅ Some attention weights are now zero!")


## Building the CausalAttention Class

Now let's put everything together into a production-ready PyTorch module.


In [None]:
class CausalAttention(nn.Module):
    """
    Causal self-attention with trainable weights and masking.
    
    Prevents tokens from attending to future positions,
    which is essential for autoregressive language modeling.
    """
    
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        
        # Register causal mask as a buffer
        # (non-trainable, but moves with model and gets saved/loaded)
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )
    
    def forward(self, x):
        """
        Forward pass with causal masking.
        
        Args:
            x: Input tensor of shape (batch_size, num_tokens, d_in)
            
        Returns:
            context_vectors: Output of shape (batch_size, num_tokens, d_out)
        """
        b, num_tokens, d_in = x.shape
        
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        
        # Compute attention scores
        # Note: transpose(1, 2) swaps the sequence and embedding dimensions
        attn_scores = queries @ keys.transpose(1, 2)
        
        # Apply causal mask (hide future tokens)
        # Slice mask to match current sequence length
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens],
            -torch.inf
        )
        
        # Scale and apply softmax
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        
        # Apply dropout
        attn_weights = self.dropout(attn_weights)
        
        # Compute context vectors
        context_vec = attn_weights @ values
        return context_vec

print("✅ CausalAttention class defined!")


### Understanding `register_buffer`

The `register_buffer` method is crucial for our mask. Let's understand why:

**What it does:**
- Tells PyTorch: "This tensor is part of the module's state, but NOT a trainable parameter"

**Benefits:**
- ✅ Saved with `model.state_dict()` and restored with `model.load_state_dict()`
- ✅ Automatically moved to GPU/CPU with `model.to(device)`
- ❌ NOT updated during backpropagation (no gradients)

**Why we need it:**
- The causal mask is fixed for a given context length
- It doesn't change during training
- But it needs to follow the model when we save/load or move to GPU
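
The points above can be checked on a toy module. In this sketch, `MaskHolder` is a hypothetical class made up for illustration (it is not part of the notebook); it holds one linear layer and one registered mask so that parameters and buffers can be compared side by side:

```python
import torch
import torch.nn as nn

class MaskHolder(nn.Module):  # hypothetical toy module, for illustration only
    def __init__(self, context_length):
        super().__init__()
        self.linear = nn.Linear(3, 2)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

m = MaskHolder(context_length=4)

param_names = [name for name, _ in m.named_parameters()]
buffer_names = [name for name, _ in m.named_buffers()]

print(param_names)             # trainable tensors only
print(buffer_names)            # the mask shows up here instead
print("mask" in m.state_dict())
print(m.mask.requires_grad)
```

Because the mask is absent from `named_parameters()`, an optimizer built from `m.parameters()` never touches it.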


## Testing CausalAttention

Let's test our causal attention with a single sequence:


In [None]:
torch.manual_seed(789)

# Create causal attention
ca = CausalAttention(
    d_in=d_in,
    d_out=d_out,
    context_length=len(words),
    dropout=0.0  # No dropout for testing
)

# Add batch dimension
inputs_batched = inputs.unsqueeze(0)

# Forward pass
context_vecs = ca(inputs_batched)

print(f"Input shape: {inputs_batched.shape}")
print(f"Output shape: {context_vecs.shape}")
print(f"\nContext vectors:\n{context_vecs[0]}")
print("\n✅ Causal attention working correctly!")
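
One way to verify causality directly is to change only the later tokens and confirm that earlier outputs stay the same. The sketch below repeats a condensed copy of the `CausalAttention` class defined above so that it runs on its own; it uses the out-of-place `masked_fill` and random inputs, both of which are choices made for this illustration:

```python
import torch
import torch.nn as nn

class CausalAttention(nn.Module):
    # Condensed copy of the class defined above, repeated so this cell runs alone
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys, queries, values = self.W_key(x), self.W_query(x), self.W_value(x)
        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores = attn_scores.masked_fill(
            self.mask.bool()[:num_tokens, :num_tokens], float("-inf")
        )
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        return self.dropout(attn_weights) @ values

torch.manual_seed(789)
ca = CausalAttention(d_in=3, d_out=2, context_length=6, dropout=0.0)

x = torch.rand(1, 6, 3)
x_changed = x.clone()
x_changed[0, 4:] = torch.rand(2, 3)  # alter only the last two tokens

with torch.no_grad():
    out = ca(x)
    out_changed = ca(x_changed)
    out_short = ca(x[:, :3])  # a sequence shorter than context_length

# Positions 0-3 never see tokens 4 and 5, so their outputs should match
print(torch.allclose(out[0, :4], out_changed[0, :4], atol=1e-6))
# Positions 4 and 5 do see the altered tokens, so their outputs should differ
print(torch.allclose(out[0, 4:], out_changed[0, 4:], atol=1e-6))
# Slicing the mask lets shorter sequences reuse the same module
print(torch.allclose(out_short, out[:, :3], atol=1e-6))
```

The last check also illustrates why the mask is sliced to `[:num_tokens, :num_tokens]`: the first three outputs are the same whether or not later tokens are present.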


## Testing with Batches

Modern deep learning processes multiple sequences at once. Let's test with a batch:


In [None]:
# Create a batch (duplicate the inputs for demonstration)
batch = torch.stack((inputs, inputs), dim=0)

print(f"Batch shape: {batch.shape}")
print(f"  - Batch size: {batch.shape[0]}")
print(f"  - Sequence length: {batch.shape[1]}")
print(f"  - Embedding dim: {batch.shape[2]}")

# Apply causal attention
torch.manual_seed(123)
ca_batch = CausalAttention(
    d_in=d_in,
    d_out=d_out,
    context_length=batch.shape[1],
    dropout=0.0
)

context_vecs_batch = ca_batch(batch)

print(f"\nOutput shape: {context_vecs_batch.shape}")
print(f"Output:\n{context_vecs_batch}")

print("\n✅ Causal attention handles batches seamlessly!")
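
Batching works because `keys.transpose(1, 2)` swaps only the last two dimensions, leaving the batch dimension in place, and `@` then multiplies each sequence independently. A small shape-level sketch with random tensors (the sizes and the name `per_sequence` are illustrative):

```python
import torch

torch.manual_seed(0)
b, num_tokens, d_out = 2, 6, 2
queries = torch.randn(b, num_tokens, d_out)
keys = torch.randn(b, num_tokens, d_out)

# (b, num_tokens, d_out) @ (b, d_out, num_tokens) -> (b, num_tokens, num_tokens)
attn_scores = queries @ keys.transpose(1, 2)
print(attn_scores.shape)

# The batched product matches doing a 2D product for each sequence separately
per_sequence = torch.stack([queries[i] @ keys[i].T for i in range(b)])
print(torch.allclose(attn_scores, per_sequence, atol=1e-6))
```

Using `.T` on the 3D `keys` tensor would reverse all three dimensions, which is why the class uses `transpose(1, 2)` instead.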


## Visualizing the Mask

Let's visualize how the causal mask works:


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

context_length = 6
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

# Plot 1: Mask
im1 = ax1.imshow(mask.numpy(), cmap='RdYlGn_r', vmin=0, vmax=1)
ax1.set_title('Causal Mask\n(0 = allowed, 1 = masked)', fontsize=14, pad=20)
ax1.set_xlabel('Key Position (attending TO)')
ax1.set_ylabel('Query Position (attending FROM)')

for i in range(context_length):
    for j in range(context_length):
        text = ax1.text(j, i, 'X' if mask[i, j] == 1 else '✓',
                      ha="center", va="center", color="white", 
                      fontsize=16, weight='bold')

plt.colorbar(im1, ax=ax1, shrink=0.8)

# Plot 2: Attention weights after masking
torch.manual_seed(123)
sample_scores = torch.randn(context_length, context_length)
masked_scores = sample_scores.masked_fill(mask.bool(), -torch.inf)
attn_weights_vis = torch.softmax(masked_scores, dim=-1)

im2 = ax2.imshow(attn_weights_vis.numpy(), cmap='Blues', vmin=0, vmax=1)
ax2.set_title('Causal Attention Weights\n(after masking & softmax)', fontsize=14, pad=20)
ax2.set_xlabel('Key Position')
ax2.set_ylabel('Query Position')

for i in range(context_length):
    for j in range(context_length):
        text = ax2.text(j, i, f'{attn_weights_vis[i, j]:.2f}',
                      ha="center", va="center",
                      color="white" if attn_weights_vis[i, j] > 0.5 else "black",
                      fontsize=10)

plt.colorbar(im2, ax=ax2, shrink=0.8)
plt.tight_layout()
plt.show()

print("✅ Visualization complete!")


## Summary and Key Takeaways

### 🎯 What We Learned

1. **The Problem**: Standard self-attention allows tokens to see future tokens, causing data leakage during training and failure during generation.

2. **The Solution**: Causal masking prevents tokens from attending to future positions by:
   - Creating an upper triangular mask
   - Replacing future positions with `-inf` BEFORE softmax
   - The softmax then naturally converts these to zero

3. **Implementation Details**:
   - Use `torch.triu()` to create the upper triangular mask
   - Use `masked_fill()` to replace masked positions with `-inf`
   - Use `register_buffer()` to store the mask with the model
   - Use `transpose(1, 2)` to transpose the keys within each batch item before the batched matrix multiplication

4. **Dropout**: Regularization technique that randomly zeros out attention weights during training to prevent overfitting.

### ✅ Our Implementation Now Has

- ✅ Trainable Weights (Q, K, V)
- ✅ Scaled Dot-Product Attention
- ✅ Causal Masking
- ✅ Dropout
- ✅ Batch Support

### 🔜 What's Next?

In **Part 9**, we'll implement **Multi-Head Attention**, which runs multiple causal attention mechanisms in parallel to capture different types of relationships!

---

### 📚 Resources

- **Medium Article**: [Building LLMs From Scratch (Part 8): Causal Attention](https://soloshun.medium.com/building-llms-from-scratch-part-8-causal-attention)
- **GitHub**: [llm-from-scratch](https://github.com/soloeinsteinmit/llm-from-scratch)
- **Previous Part**: [Part 7: Self-Attention with Trainable Weights](https://medium.com/@soloshun/building-llms-from-scratch-part-7-self-attention-with-trainable-weights)

---

Thank you for following along! 🙏
