# Day 15: Attention Mechanisms and Transformers

## Phase 2: NLP Basics (Days 11-20)

**Estimated Time: 4-5 hours**

### Learning Objectives
- Understand attention as a soft addressing mechanism
- Implement Bahdanau (additive) and Luong (multiplicative) attention
- Master the Transformer architecture from "Attention Is All You Need"
- Implement self-attention and multi-head attention from scratch
- Understand positional encoding for sequence order
- Build a complete Transformer encoder
- Appreciate why Transformers revolutionized NLP

### Prerequisites
- Day 14: LSTM and GRU Networks
- Understanding of sequence-to-sequence models
- Linear algebra (matrix operations, softmax)
- Neural network fundamentals

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import warnings
warnings.filterwarnings('ignore')

np.random.seed(42)
torch.manual_seed(42)
plt.style.use('seaborn-v0_8-darkgrid')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print("Libraries loaded successfully!")

## 1. The Bottleneck Problem

### 1.1 Sequence-to-Sequence Limitation

In encoder-decoder architectures (Day 14):
- The encoder compresses the entire input into a fixed-size vector
- This single vector must capture all of the information in the input
- **Bottleneck**: Information from long sequences is lost in the compression

**Example**: Translating a long sentence
- Information about the first words may be washed out by the time encoding is complete
- The decoder has no way to "look back" at specific input positions

### 1.2 Attention: The Solution

**Key Idea**: Instead of compressing everything into a single context vector, let the decoder "attend" to all encoder states, computing a new context vector at each decoding step $t$:

$$\text{context}_t = \sum_{s=1}^{S} \alpha_{t,s} \cdot h_s$$

Where:
- $h_s$: Encoder hidden state at position $s$
- $\alpha_{t,s}$: Attention weight (how much to focus on position $s$ when generating token $t$)
- $\sum_s \alpha_{t,s} = 1$ (normalized via softmax)

**Intuition**: Like a spotlight that moves across the input, focusing on relevant parts.
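To make the weighted sum concrete, here is a small NumPy sketch with made-up encoder states (hidden size 3) and the same weights used in the illustration that follows. The values are toy assumptions, not outputs of a trained model.

```python
import numpy as np

# Toy encoder states h_1..h_5, one row per source position (made-up values)
H = np.array([
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [1.0, 1.0, 0.0],
    [0.0, 1.0, 1.0],
])

# Attention weights alpha_{t,s} for one decoder step t; they must sum to 1
alpha = np.array([0.05, 0.1, 0.5, 0.3, 0.05])
assert np.isclose(alpha.sum(), 1.0)

# context_t = sum_s alpha_{t,s} * h_s, written as a vector-matrix product
context = alpha @ H

# Same result with an explicit loop over source positions
context_loop = sum(a * h for a, h in zip(alpha, H))

print(context)  # approximately [0.35 0.45 0.55]
```

Because the weights are non-negative and sum to 1, $c_t$ is a convex combination of the encoder states; here it is pulled toward $h_3$, which receives half of the weight.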

In [None]:
# Visualize the bottleneck problem and attention solution

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# 1. Bottleneck Problem
ax = axes[0]
ax.set_title('Seq2Seq Bottleneck Problem', fontsize=14, fontweight='bold')

# Encoder states
for i in range(5):
    rect = plt.Rectangle((i*1.5, 3), 1, 1, color='lightblue', alpha=0.8)
    ax.add_patch(rect)
    ax.text(i*1.5 + 0.5, 3.5, f'$h_{i+1}$', ha='center', va='center', fontsize=10)
    if i < 4:
        ax.arrow(i*1.5 + 1, 3.5, 0.4, 0, head_width=0.1, fc='black')

# Final context vector (bottleneck)
circle = plt.Circle((7.5, 3.5), 0.6, color='red', alpha=0.8)
ax.add_patch(circle)
ax.text(7.5, 3.5, 'c', ha='center', va='center', fontsize=12, color='white', fontweight='bold')
ax.text(7.5, 2.5, 'BOTTLENECK\n(single vector)', ha='center', fontsize=10, color='red')

# Arrow to decoder
ax.arrow(7.5, 2.9, 0, -1, head_width=0.2, fc='red')

# Decoder states
for i in range(3):
    rect = plt.Rectangle((i*2 + 2, 0.5), 1, 1, color='lightgreen', alpha=0.8)
    ax.add_patch(rect)
    ax.text(i*2 + 2.5, 1, f'$s_{i+1}$', ha='center', va='center', fontsize=10)

ax.text(4.5, 4.8, 'Encoder', ha='center', fontsize=12)
ax.text(4.5, -0.2, 'Decoder', ha='center', fontsize=12)
ax.set_xlim(-0.5, 9)
ax.set_ylim(-0.5, 5.5)
ax.axis('off')

# 2. Attention Solution
ax = axes[1]
ax.set_title('Attention Mechanism Solution', fontsize=14, fontweight='bold')

# Encoder states
for i in range(5):
    rect = plt.Rectangle((i*1.5, 4), 1, 1, color='lightblue', alpha=0.8)
    ax.add_patch(rect)
    ax.text(i*1.5 + 0.5, 4.5, f'$h_{i+1}$', ha='center', va='center', fontsize=10)

# Decoder state
rect = plt.Rectangle((3.5, 1), 1, 1, color='lightgreen', alpha=0.8)
ax.add_patch(rect)
ax.text(4, 1.5, '$s_t$', ha='center', va='center', fontsize=10)

# Attention weights (arrows with different widths)
attention_weights = [0.05, 0.1, 0.5, 0.3, 0.05]  # Focus on position 3
colors = plt.cm.Reds(np.array(attention_weights) / max(attention_weights))

for i, (w, c) in enumerate(zip(attention_weights, colors)):
    ax.plot([i*1.5 + 0.5, 4], [4, 2], color=c, linewidth=w*10 + 0.5, alpha=0.8)
    ax.text(i*1.5 + 0.5, 3.3, f'{w:.2f}', ha='center', fontsize=9)

# Context vector
circle = plt.Circle((4, 2.7), 0.3, color='purple', alpha=0.8)
ax.add_patch(circle)
ax.text(4, 2.7, 'c', ha='center', va='center', fontsize=10, color='white')
ax.text(5, 2.7, '$c_t = \\sum_i \\alpha_i h_i$', fontsize=11)

ax.text(3.5, 5.3, 'Encoder', ha='center', fontsize=12)
ax.text(4, 0.3, 'Decoder', ha='center', fontsize=12)
ax.text(7, 3.5, 'Attention\nweights', ha='center', fontsize=10, color='red')

ax.set_xlim(-0.5, 9)
ax.set_ylim(-0.5, 6)
ax.axis('off')

plt.tight_layout()
plt.show()

print("Left: Information compressed into single vector (bottleneck)")
print("Right: Decoder can attend to all encoder states with learned weights")

## 2. Attention Mechanisms

### 2.1 General Attention Framework

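Attention computes:
1. **Score**: How relevant is each source position $s$ to the current decoder state $h_t$?
2. **Weights**: Normalize the scores with a softmax
3. **Context**: Take the weighted sum of the values (here, the encoder states)

The score is an *alignment function* that maps a pair of states to a scalar; the variants below differ only in how $\text{score}(h_t, h_s)$ is defined: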
$$\alpha_{t,s} = \frac{\exp(\text{score}(h_t, h_s))}{\sum_{s'} \exp(\text{score}(h_t, h_{s'}))}$$
$$c_t = \sum_s \alpha_{t,s} h_s$$
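The three steps map directly onto a few lines of NumPy. This sketch uses random toy states and the plain dot product as the alignment function (one of the choices described next); `softmax` is a small local helper, not a library call.

```python
import numpy as np

def softmax(z):
    # Subtracting the max improves numerical stability without changing the result
    e = np.exp(z - z.max())
    return e / e.sum()

rng = np.random.default_rng(0)
h_s = rng.normal(size=(4, 8))   # 4 encoder states, hidden size 8 (toy values)
h_t = rng.normal(size=8)        # current decoder state (toy values)

# 1. Score: dot-product alignment h_t^T h_s for every source position s
scores = h_s @ h_t              # shape (4,)
# 2. Weights: softmax over source positions
alpha = softmax(scores)         # shape (4,), non-negative, sums to 1
# 3. Context: weighted sum of the encoder states
c_t = alpha @ h_s               # shape (8,)

print(alpha.round(3), alpha.sum())
print(c_t.shape)
```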

### 2.2 Bahdanau Attention (Additive)

Also called "concat" attention:
$$\text{score}(h_t, h_s) = v^T \tanh(W_1 h_t + W_2 h_s)$$

### 2.3 Luong Attention (Multiplicative)

Also called "dot-product" attention:
$$\text{score}(h_t, h_s) = h_t^T W h_s \quad \text{(general)}$$
$$\text{score}(h_t, h_s) = h_t^T h_s \quad \text{(dot)}$$

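Multiplicative scoring reduces to matrix multiplications, so it is faster and more memory-efficient in practice. Additive scoring runs a small feed-forward layer on every query-key pair; Vaswani et al. (2017) report that it outperforms *unscaled* dot-product attention for large key dimensions $d_k$, which motivates the $\sqrt{d_k}$ scaling introduced in Section 3.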
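The score functions differ only in how they compare $h_t$ with $h_s$. The NumPy sketch below (random toy matrices, not trained weights; `general_scores` is a hypothetical helper) checks that Luong's *general* score with $W = I$ reduces to *dot*, and computes the additive score, whose extra $\tanh$ layer over every source position is where its additional cost comes from.

```python
import numpy as np

rng = np.random.default_rng(1)
d = 6
h_t = rng.normal(size=d)         # decoder state (toy)
H_s = rng.normal(size=(5, d))    # 5 encoder states (toy)

def general_scores(h_t, H_s, W):
    # Luong "general": row s of H_s @ W.T is (W h_s)^T, so the result is h_t^T W h_s
    return H_s @ W.T @ h_t

# Luong "dot": h_t^T h_s for every source position
dot_scores = H_s @ h_t

# With W = I, "general" reduces exactly to "dot"
assert np.allclose(general_scores(h_t, H_s, np.eye(d)), dot_scores)

# Bahdanau "additive": v^T tanh(W1 h_t + W2 h_s), broadcast over all source positions
W1, W2 = rng.normal(size=(d, d)), rng.normal(size=(d, d))
v = rng.normal(size=d)
additive_scores = np.tanh(h_t @ W1.T + H_s @ W2.T) @ v

print(dot_scores.shape, additive_scores.shape)  # (5,) (5,)
```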

In [None]:
class BahdanauAttention(nn.Module):
    """
    Bahdanau (Additive) Attention.
    
    score(h_t, h_s) = v^T tanh(W_1 h_t + W_2 h_s)
    """
    
    def __init__(self, hidden_dim):
        super().__init__()
        self.W1 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.W2 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.v = nn.Linear(hidden_dim, 1, bias=False)
    
    def forward(self, query, keys, values, mask=None):
        """
        query: [batch, hidden_dim] - decoder state
        keys: [batch, seq_len, hidden_dim] - encoder states
        values: [batch, seq_len, hidden_dim] - encoder states (usually same as keys)
        mask: [batch, seq_len] - optional mask for padding
        """
        # Expand query for broadcasting
        query = query.unsqueeze(1)  # [batch, 1, hidden_dim]
        
        # Compute scores
        # W1(query): [batch, 1, hidden_dim]
        # W2(keys): [batch, seq_len, hidden_dim]
        scores = self.v(torch.tanh(self.W1(query) + self.W2(keys)))  # [batch, seq_len, 1]
        scores = scores.squeeze(-1)  # [batch, seq_len]
        
        # Apply mask: padded positions get a large negative score (-1e9), so softmax gives them ~0 weight
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # Softmax to get weights
        attention_weights = F.softmax(scores, dim=-1)  # [batch, seq_len]
        
        # Compute context vector
        context = torch.bmm(attention_weights.unsqueeze(1), values)  # [batch, 1, hidden_dim]
        context = context.squeeze(1)  # [batch, hidden_dim]
        
        return context, attention_weights

class LuongAttention(nn.Module):
    """
    Luong (Multiplicative) Attention.
    
    Three variants:
    - dot: score = h_t^T h_s
    - general: score = h_t^T W h_s
    - concat: score = v^T tanh(W[h_t; h_s]) (similar to Bahdanau)
    """
    
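    def __init__(self, hidden_dim, method='general'):
        super().__init__()
        # Fail early on typos; otherwise forward() would hit an undefined `scores`
        if method not in ('dot', 'general', 'concat'):
            raise ValueError(f"Unknown method {method!r}; expected 'dot', 'general' or 'concat'")
        self.method = method
        
        if method == 'general':
            self.W = nn.Linear(hidden_dim, hidden_dim, bias=False)
        elif method == 'concat':
            self.W = nn.Linear(hidden_dim * 2, hidden_dim, bias=False)
            self.v = nn.Linear(hidden_dim, 1, bias=False)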
    
    def forward(self, query, keys, values, mask=None):
        """
        query: [batch, hidden_dim]
        keys: [batch, seq_len, hidden_dim]
        values: [batch, seq_len, hidden_dim]
        """
        if self.method == 'dot':
            # [batch, seq_len]
            scores = torch.bmm(keys, query.unsqueeze(-1)).squeeze(-1)
        
        elif self.method == 'general':
            # [batch, seq_len]
            scores = torch.bmm(self.W(keys), query.unsqueeze(-1)).squeeze(-1)
        
        elif self.method == 'concat':
            query_expanded = query.unsqueeze(1).expand(-1, keys.size(1), -1)
            concat = torch.cat([query_expanded, keys], dim=-1)
            scores = self.v(torch.tanh(self.W(concat))).squeeze(-1)
        
        # Apply mask
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # Softmax
        attention_weights = F.softmax(scores, dim=-1)
        
        # Context
        context = torch.bmm(attention_weights.unsqueeze(1), values).squeeze(1)
        
        return context, attention_weights

# Test attention mechanisms
print("Attention Mechanism Implementations")
print("="*60)

batch_size = 2
seq_len = 10
hidden_dim = 64

# Dummy data
query = torch.randn(batch_size, hidden_dim)
keys = torch.randn(batch_size, seq_len, hidden_dim)
values = keys.clone()  # Usually same as keys

# Test Bahdanau
bahdanau = BahdanauAttention(hidden_dim)
context_b, weights_b = bahdanau(query, keys, values)
print(f"Bahdanau Attention:")
print(f"  Context shape: {context_b.shape}")
print(f"  Weights shape: {weights_b.shape}")
print(f"  Weights sum: {weights_b.sum(dim=-1)}")

# Test Luong variants
for method in ['dot', 'general']:
    luong = LuongAttention(hidden_dim, method=method)
    context_l, weights_l = luong(query, keys, values)
    print(f"\nLuong ({method}) Attention:")
    print(f"  Context shape: {context_l.shape}")
    print(f"  Weights sum: {weights_l.sum(dim=-1)}")
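Both attention modules above handle padding with `masked_fill(mask == 0, -1e9)` *before* the softmax. A minimal NumPy sketch of why the mask must act on the scores (hypothetical `masked_softmax` helper, toy scores):

```python
import numpy as np

def masked_softmax(scores, mask):
    # mask: 1 = real token, 0 = padding; masked scores become -1e9 before the softmax
    scores = np.where(mask == 0, -1e9, scores)
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

scores = np.array([[2.0, 1.0, 0.5, 0.0, 0.0]])  # toy scores for one query over 5 keys
mask = np.array([[1, 1, 1, 0, 0]])              # the last two keys are padding

w = masked_softmax(scores, mask)
print(w.round(3))  # the padded positions get exactly 0 weight

# Zeroing padded scores (instead of masking them) still leaves them weight, since exp(0) = 1
naive = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
print(naive[0, 3:])
```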

In [None]:
# Visualize attention weights

# Simulate translation attention
source_sentence = "The cat sat on the mat".split()
target_sentence = "Le chat assis sur le tapis".split()

# Simulated attention weights (what we'd expect)
attention_matrix = np.array([
    [0.8, 0.1, 0.0, 0.0, 0.1, 0.0],  # Le -> The
    [0.1, 0.8, 0.0, 0.0, 0.1, 0.0],  # chat -> cat
    [0.0, 0.1, 0.8, 0.1, 0.0, 0.0],  # assis -> sat
    [0.0, 0.0, 0.1, 0.7, 0.1, 0.1],  # sur -> on
    [0.0, 0.0, 0.0, 0.1, 0.8, 0.1],  # le -> the
    [0.0, 0.0, 0.0, 0.0, 0.1, 0.9],  # tapis -> mat
])

plt.figure(figsize=(10, 8))
plt.imshow(attention_matrix, cmap='YlOrRd', aspect='auto')
plt.colorbar(label='Attention Weight')

plt.xticks(range(len(source_sentence)), source_sentence, rotation=45, ha='right')
plt.yticks(range(len(target_sentence)), target_sentence)

plt.xlabel('Source (English)', fontsize=12)
plt.ylabel('Target (French)', fontsize=12)
plt.title('Attention Weights in Translation\n(Simulated)', fontsize=14, fontweight='bold')

# Annotate with values
for i in range(len(target_sentence)):
    for j in range(len(source_sentence)):
        plt.text(j, i, f'{attention_matrix[i, j]:.1f}', ha='center', va='center',
                color='white' if attention_matrix[i, j] > 0.5 else 'black', fontsize=10)

plt.tight_layout()
plt.show()

print("Note: this matrix is hand-made for illustration.")
print("Trained attention models often learn similar soft word alignments,")
print("with each target word focusing on the relevant source word(s).")

## 3. Self-Attention: Attending to Yourself

### 3.1 From Cross-Attention to Self-Attention

**Cross-attention**: Query from decoder, keys/values from encoder
- "What parts of the input should I focus on?"

**Self-attention**: Query, keys, and values all from same sequence
- "How do different parts of this sequence relate to each other?"
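The only difference between the two is where the queries come from. In this NumPy sketch (hypothetical `attention` helper; learned projections are omitted, so the raw states serve directly as queries, keys and values), compare the shapes of the weight matrices:

```python
import numpy as np

def attention(Q, K, V):
    # softmax(Q K^T / sqrt(d_k)) V with a row-wise softmax over the keys
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    w = w / w.sum(axis=-1, keepdims=True)
    return w @ V, w

rng = np.random.default_rng(2)
d = 16
src = rng.normal(size=(7, d))   # encoder states for 7 source tokens (toy values)
tgt = rng.normal(size=(3, d))   # decoder states for 3 target tokens (toy values)

# Cross-attention: queries from the decoder, keys/values from the encoder
out_x, w_x = attention(tgt, src, src)
# Self-attention: queries, keys and values all from the same sequence
out_s, w_s = attention(src, src, src)

print(w_x.shape)  # (3, 7): one row per target token, one column per source token
print(w_s.shape)  # (7, 7): every source token attends to every source token
```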

### 3.2 Query-Key-Value Framework

For each position $i$ in sequence:
- **Query** $q_i = W_Q x_i$: "What am I looking for?"
- **Key** $k_i = W_K x_i$: "What do I contain?"
- **Value** $v_i = W_V x_i$: "What information do I provide?"

**Attention**:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

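The $\sqrt{d_k}$ scaling keeps the softmax from saturating at high dimensions. If the components of $q$ and $k$ are independent with mean 0 and variance 1, then $q \cdot k$ has variance $d_k$; for large $d_k$ the scores become large, the softmax becomes nearly one-hot, and its gradients become vanishingly small. Dividing by $\sqrt{d_k}$ brings the variance back to 1.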
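A quick NumPy check of how the spread of $q \cdot k$ grows with $d_k$, and what dividing by $\sqrt{d_k}$ does to it. It assumes i.i.d. standard-normal queries and keys, a simplification (learned vectors are not i.i.d.); `softmax` is a local helper.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

variances = {}
for d_k in (4, 64, 512):
    # 5,000 random query/key pairs with i.i.d. N(0, 1) components
    q = rng.normal(size=(5_000, d_k))
    k = rng.normal(size=(5_000, d_k))
    raw = (q * k).sum(axis=1)        # unscaled dot products q . k
    scaled = raw / np.sqrt(d_k)      # scaled dot products
    variances[d_k] = (raw.var(), scaled.var())
    print(f"d_k={d_k:4d}  var(q.k)={raw.var():7.1f}  var(q.k/sqrt(d_k))={scaled.var():.2f}")

# Effect on the softmax: one query against 10 keys with d_k = 512
q = rng.normal(size=512)
K = rng.normal(size=(10, 512))
max_unscaled = softmax(K @ q).max()
max_scaled = softmax(K @ q / np.sqrt(512)).max()
print(f"largest weight, unscaled: {max_unscaled:.3f}   scaled: {max_scaled:.3f}")
```

The unscaled variance tracks $d_k$ while the scaled one stays near 1, and the unscaled softmax concentrates its mass on fewer keys.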

### 3.3 Why Self-Attention?

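1. **Captures long-range dependencies**: Any two positions are connected directly, in a single step
2. **Parallelizable**: All positions are processed at once, with no sequential recurrence as in RNNs
3. **Inspectable**: Attention weights show which positions each token draws information from (a useful diagnostic, though not a complete explanation of the model's behavior)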

In [None]:
class ScaledDotProductAttention(nn.Module):
    """
    Scaled Dot-Product Attention.
    
    Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
    """
    
    def __init__(self, temperature=None):
        super().__init__()
        self.temperature = temperature
    
    def forward(self, query, key, value, mask=None):
        """
        query: [batch, ..., seq_len_q, d_k]
        key: [batch, ..., seq_len_k, d_k]
        value: [batch, ..., seq_len_v, d_v] (seq_len_v == seq_len_k)
        mask: [batch, ..., seq_len_q, seq_len_k]
        """
        d_k = query.size(-1)
        
        if self.temperature is None:
            temperature = math.sqrt(d_k)
        else:
            temperature = self.temperature
        
        # Compute attention scores
        # [batch, ..., seq_len_q, seq_len_k]
        scores = torch.matmul(query, key.transpose(-2, -1)) / temperature
        
        # Apply mask
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # Softmax
        attention_weights = F.softmax(scores, dim=-1)
        
        # Apply attention to values
        # [batch, ..., seq_len_q, d_v]
        output = torch.matmul(attention_weights, value)
        
        return output, attention_weights

class SelfAttention(nn.Module):
    """
    Self-Attention layer.
    
    Projects input to Q, K, V and computes scaled dot-product attention.
    """
    
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        
        # Projection matrices
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        
        self.attention = ScaledDotProductAttention()
    
    def forward(self, x, mask=None):
        """
        x: [batch, seq_len, d_model]
        """
        # Project to Q, K, V
        Q = self.W_Q(x)  # [batch, seq_len, d_model]
        K = self.W_K(x)
        V = self.W_V(x)
        
        # Self-attention
        output, attention_weights = self.attention(Q, K, V, mask)
        
        return output, attention_weights

# Test self-attention
print("Self-Attention Implementation")
print("="*60)

batch_size = 2
seq_len = 5
d_model = 64

x = torch.randn(batch_size, seq_len, d_model)
self_attn = SelfAttention(d_model)

output, weights = self_attn(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {weights.shape}")
print(f"\nAttention weights for first sample:")
print(weights[0].detach().numpy().round(3))
print(f"\nEach row sums to 1: {weights[0].sum(dim=-1).detach().numpy()}")

In [None]:
# Visualize self-attention pattern

sentence = ["The", "cat", "sat", "on", "mat"]
seq_len = len(sentence)
d_model = 32

# Create embeddings (random for demo)
embeddings = torch.randn(1, seq_len, d_model)

# Compute self-attention
self_attn = SelfAttention(d_model)
_, attention_weights = self_attn(embeddings)

# Visualize
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# 1. Attention weights matrix
ax = axes[0]
weights_np = attention_weights[0].detach().numpy()
im = ax.imshow(weights_np, cmap='Blues', aspect='auto')
plt.colorbar(im, ax=ax)

ax.set_xticks(range(seq_len))
ax.set_yticks(range(seq_len))
ax.set_xticklabels(sentence)
ax.set_yticklabels(sentence)
ax.set_xlabel('Keys (attending to)')
ax.set_ylabel('Queries (from)')
ax.set_title('Self-Attention Weights')

for i in range(seq_len):
    for j in range(seq_len):
        ax.text(j, i, f'{weights_np[i, j]:.2f}', ha='center', va='center', fontsize=9)

# 2. Illustration of self-attention connections
ax = axes[1]
ax.set_title('Self-Attention: Each position attends to all positions')

# Draw words
for i, word in enumerate(sentence):
    rect = plt.Rectangle((i*2, 0), 1.5, 0.8, color='lightblue', alpha=0.8)
    ax.add_patch(rect)
    ax.text(i*2 + 0.75, 0.4, word, ha='center', va='center', fontsize=11)

# Draw attention from "sat" (index 2)
query_idx = 2
query_weights = weights_np[query_idx]

for i, w in enumerate(query_weights):
    # Arrow from query to key
    if w > 0.05:  # Only show significant connections
        ax.annotate('', 
                   xy=(i*2 + 0.75, 0.8), 
                   xytext=(query_idx*2 + 0.75, 2),
                   arrowprops=dict(arrowstyle='->', 
                                  color=plt.cm.Reds(w), 
                                  lw=w*10,
                                  alpha=0.8))
        ax.text(i*2 + 0.75, 1.3, f'{w:.2f}', ha='center', fontsize=9)

# Query position
rect = plt.Rectangle((query_idx*2, 2), 1.5, 0.8, color='salmon', alpha=0.8)
ax.add_patch(rect)
ax.text(query_idx*2 + 0.75, 2.4, f'Query: "{sentence[query_idx]}"', ha='center', va='center', fontsize=11)

ax.set_xlim(-0.5, 10)
ax.set_ylim(-0.5, 3.5)
ax.axis('off')

plt.suptitle('Self-Attention Visualization', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 4. Multi-Head Attention

### 4.1 Motivation

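A single attention head produces one attention distribution per query, so it tends to capture one type of relationship at a time.

**Multi-head attention**: Several attention heads run in parallel, each with its own learned projections
- Each head can learn a different relationship; in trained models, heads have been observed to specialize, e.g.:
- One head: syntactic relationships
- Another head: semantic relationships
- Another head: coreference resolution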

### 4.2 Mathematics

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O$$

where:
$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$

- Split $d_{model}$ into $h$ heads, each of dimension $d_k = d_{model} / h$
- Learn different attention patterns
- Concatenate and project back
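The split into heads is just a reshape. This NumPy sketch assumes $Q$, $K$, $V$ are already projected (the per-head $W_i^Q$, $W_i^K$, $W_i^V$ stacked into one $d_{model} \times d_{model}$ matrix, as in the implementation below) and omits $W^O$. It checks that the batched reshape/transpose trick gives the same result as looping over heads on column slices (`split_heads` and `softmax` are hypothetical helpers).

```python
import numpy as np

rng = np.random.default_rng(3)
batch, seq_len, d_model, h = 2, 5, 16, 4
d_k = d_model // h

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

# Already-projected queries, keys and values (toy values)
Q = rng.normal(size=(batch, seq_len, d_model))
K = rng.normal(size=(batch, seq_len, d_model))
V = rng.normal(size=(batch, seq_len, d_model))

def split_heads(x):
    # [batch, seq, d_model] -> [batch, h, seq, d_k]; head i owns columns i*d_k:(i+1)*d_k
    return x.reshape(batch, seq_len, h, d_k).transpose(0, 2, 1, 3)

# Batched: all heads at once
w = softmax(split_heads(Q) @ split_heads(K).transpose(0, 1, 3, 2) / np.sqrt(d_k))
heads = w @ split_heads(V)                                    # [batch, h, seq, d_k]
concat = heads.transpose(0, 2, 1, 3).reshape(batch, seq_len, d_model)

# Loop: head i attends using only its own slice of feature columns
looped = []
for i in range(h):
    sl = slice(i * d_k, (i + 1) * d_k)
    wi = softmax(Q[..., sl] @ K[..., sl].transpose(0, 2, 1) / np.sqrt(d_k))
    looped.append(wi @ V[..., sl])
looped = np.concatenate(looped, axis=-1)

print(np.allclose(concat, looped))  # True
```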

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention from "Attention Is All You Need".
    """
    
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Linear projections for Q, K, V
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        
        # Output projection
        self.W_O = nn.Linear(d_model, d_model)
        
        self.attention = ScaledDotProductAttention()
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, query, key, value, mask=None):
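        """
        query: [batch, seq_len_q, d_model]
        key: [batch, seq_len_k, d_model]
        value: [batch, seq_len_k, d_model] (same length as key)
        mask: [batch, 1, seq_len_k] (padding) or [batch, seq_len_q, seq_len_k] (e.g. causal);
              0 marks positions that must not be attended to. A head dimension is
              inserted at dim 1, so the mask broadcasts over all heads.
        """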
        batch_size = query.size(0)
        
        # 1. Linear projections
        Q = self.W_Q(query)  # [batch, seq_len, d_model]
        K = self.W_K(key)
        V = self.W_V(value)
        
        # 2. Reshape to [batch, num_heads, seq_len, d_k]
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        # 3. Apply attention
        if mask is not None:
            mask = mask.unsqueeze(1)  # Add head dimension
        
        attn_output, attention_weights = self.attention(Q, K, V, mask)
        # attn_output: [batch, num_heads, seq_len, d_k]
        # attention_weights: [batch, num_heads, seq_len_q, seq_len_k]
        
        # 4. Concatenate heads
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.d_model)
        # [batch, seq_len, d_model]
        
        # 5. Final projection
        output = self.W_O(attn_output)
        output = self.dropout(output)
        
        return output, attention_weights

# Test multi-head attention
print("Multi-Head Attention Implementation")
print("="*60)

batch_size = 2
seq_len = 10
d_model = 64
num_heads = 8

x = torch.randn(batch_size, seq_len, d_model)
mha = MultiHeadAttention(d_model, num_heads)

# Self-attention: query = key = value = x
output, weights = mha(x, x, x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {weights.shape}")
print(f"  - {num_heads} attention heads")
print(f"  - Each head has d_k = {d_model // num_heads}")

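# Parameter count (each nn.Linear has a d_model x d_model weight plus a d_model bias)
params = sum(p.numel() for p in mha.parameters())
print(f"\nTotal parameters: {params:,}")
print(f"  W_Q, W_K, W_V: {d_model * d_model + d_model} each (weight + bias)")
print(f"  W_O: {d_model * d_model + d_model}")
print(f"  Total: 4 * ({d_model}^2 + {d_model}) = {4 * (d_model**2 + d_model):,}")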

In [None]:
# Visualize multiple attention heads

sentence = ["The", "cat", "sat", "on", "the", "mat", "."]
seq_len = len(sentence)
d_model = 64
num_heads = 4

x = torch.randn(1, seq_len, d_model)
mha = MultiHeadAttention(d_model, num_heads)
_, attention_weights = mha(x, x, x)

# attention_weights: [1, num_heads, seq_len, seq_len]
weights_np = attention_weights[0].detach().numpy()

fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.flatten()

for head_idx in range(num_heads):
    ax = axes[head_idx]
    im = ax.imshow(weights_np[head_idx], cmap='viridis', aspect='auto')
    
    ax.set_xticks(range(seq_len))
    ax.set_yticks(range(seq_len))
    ax.set_xticklabels(sentence, rotation=45, ha='right')
    ax.set_yticklabels(sentence)
    ax.set_title(f'Head {head_idx + 1}', fontsize=12, fontweight='bold')
    
    if head_idx >= 2:
        ax.set_xlabel('Keys')
    if head_idx % 2 == 0:
        ax.set_ylabel('Queries')

plt.suptitle('Multi-Head Attention: Each Head Produces a Different Pattern (untrained)',
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("These heads are randomly initialized, so their patterns are arbitrary.")
print("After training, heads can specialize in different types of relationships, e.g.:")
print("- Syntactic structure (subject-verb)")
print("- Semantic similarity")
print("- Positional patterns")
print("- Coreference (pronouns to nouns)")

## 5. Positional Encoding

### 5.1 The Position Problem

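Self-attention is **permutation equivariant**:
- Shuffling the input tokens ("The cat sat" → "sat cat The") only shuffles the outputs; each token receives exactly the same output vector as before
- Order information is lost: without extra input, the model cannot distinguish different word orders!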

**Solution**: Add position information to embeddings
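This NumPy sketch checks the claim with a random single-head self-attention layer (toy embeddings and projections; `self_attention` is a hypothetical helper): reordering the tokens only reorders the outputs, and adding a position-dependent vector breaks that symmetry.

```python
import numpy as np

rng = np.random.default_rng(4)
seq_len, d = 3, 8
X = rng.normal(size=(seq_len, d))    # toy embeddings for "The", "cat", "sat"
# Random projections, scaled by 1/sqrt(d) so the softmax is not saturated
W_Q, W_K, W_V = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))

def self_attention(X):
    # Single-head scaled dot-product self-attention with no positional information
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    s = Q @ K.T / np.sqrt(d)
    w = np.exp(s - s.max(axis=1, keepdims=True))
    w = w / w.sum(axis=1, keepdims=True)
    return w @ V

perm = [2, 1, 0]                     # "sat cat The"
out = self_attention(X)
out_perm = self_attention(X[perm])
print(np.allclose(out_perm, out[perm]))        # True: same vectors, reordered

# Add a position-dependent offset (random here, standing in for a positional encoding)
pe = rng.normal(size=(seq_len, d))
out_pe = self_attention(X + pe)
out_perm_pe = self_attention(X[perm] + pe)
print(np.allclose(out_perm_pe, out_pe[perm]))  # False: order now matters
```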

### 5.2 Sinusoidal Positional Encoding

$$PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}})$$
$$PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{model}})$$

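**Properties:**
1. Each position gets a distinct encoding
2. Fixed for all sequences (no parameters to learn)
3. Defined for any position, so it can in principle be applied to sequences longer than those seen in training
4. For any fixed offset $k$, $PE_{pos+k}$ is a linear function of $PE_{pos}$ (a rotation within each sine/cosine pair), which the authors hypothesized would make relative positions easy to attend to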
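Property 4 can be checked numerically. For each frequency $\omega_i = 1/10000^{2i/d_{model}}$, shifting the pair $(\sin(\omega_i\,pos), \cos(\omega_i\,pos))$ by $k$ positions is a fixed 2x2 rotation, so one block-diagonal matrix $M_k$ maps $PE_{pos}$ to $PE_{pos+k}$ for every $pos$. A sketch with hypothetical helpers `sinusoidal_pe` and `shift_matrix`:

```python
import numpy as np

def sinusoidal_pe(positions, d_model):
    # PE[pos, 2i] = sin(pos * omega_i), PE[pos, 2i+1] = cos(pos * omega_i)
    positions = np.asarray(positions, dtype=float)[:, None]
    omega = 1.0 / 10000 ** (np.arange(0, d_model, 2) / d_model)  # one frequency per pair
    pe = np.zeros((len(positions), d_model))
    pe[:, 0::2] = np.sin(positions * omega)
    pe[:, 1::2] = np.cos(positions * omega)
    return pe, omega

def shift_matrix(k, omega):
    # Block-diagonal M_k with one 2x2 rotation per frequency: PE[pos + k] = M_k @ PE[pos]
    n = 2 * len(omega)
    M = np.zeros((n, n))
    for i, w in enumerate(omega):
        c, s = np.cos(k * w), np.sin(k * w)
        # [sin(a+b), cos(a+b)] = [[cos b, sin b], [-sin b, cos b]] @ [sin a, cos a]
        M[2 * i:2 * i + 2, 2 * i:2 * i + 2] = [[c, s], [-s, c]]
    return M

d_model, k = 16, 5
pe, omega = sinusoidal_pe(range(50), d_model)
M = shift_matrix(k, omega)

# One matrix, independent of pos, maps every row pe[pos] to pe[pos + k]
print(np.allclose(pe[k:], pe[:-k] @ M.T))  # True
```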

In [None]:
class PositionalEncoding(nn.Module):
    """
    Sinusoidal Positional Encoding from the Transformer paper.
    """
    
    def __init__(self, d_model, max_seq_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        
        # Create positional encoding matrix
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len).unsqueeze(1).float()
        
        # Compute the positional encodings
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * 
            (-math.log(10000.0) / d_model)
        )
        
        pe[:, 0::2] = torch.sin(position * div_term)  # Even indices
        pe[:, 1::2] = torch.cos(position * div_term)  # Odd indices
        
        # Register as buffer (not a parameter)
        pe = pe.unsqueeze(0)  # [1, max_seq_len, d_model]
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        """
        x: [batch, seq_len, d_model]
        """
        seq_len = x.size(1)
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)

# Visualize positional encoding
d_model = 128
max_len = 100

pe = PositionalEncoding(d_model, max_len, dropout=0.0)
pe_values = pe.pe[0].numpy()  # [max_len, d_model]

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# 1. Full encoding matrix
ax = axes[0, 0]
im = ax.imshow(pe_values[:50, :64], cmap='RdBu', aspect='auto')
plt.colorbar(im, ax=ax)
ax.set_xlabel('Embedding Dimension')
ax.set_ylabel('Position')
ax.set_title('Positional Encoding Matrix (first 50 pos, 64 dims)')

# 2. Individual dimensions
ax = axes[0, 1]
positions = range(50)
for dim in [0, 1, 10, 11, 50, 51]:
    ax.plot(positions, pe_values[:50, dim], label=f'dim {dim}')
ax.set_xlabel('Position')
ax.set_ylabel('Encoding Value')
ax.set_title('Positional Encoding for Different Dimensions')
ax.legend(loc='upper right', fontsize=8)
ax.grid(True, alpha=0.3)

# 3. Frequency vs dimension
ax = axes[1, 0]
dims = range(0, d_model, 2)
wavelengths = 2 * np.pi * (10000 ** (np.array(dims) / d_model))
ax.plot(dims, wavelengths)
ax.set_xlabel('Dimension')
ax.set_ylabel('Wavelength')
ax.set_title('Wavelength of Sinusoid by Dimension')
ax.set_yscale('log')
ax.grid(True, alpha=0.3)

# 4. Similarity between positions
ax = axes[1, 1]
# Compute cosine similarity between positions
num_positions = 30
pos_encodings = pe_values[:num_positions]
norms = np.linalg.norm(pos_encodings, axis=1, keepdims=True)
normalized = pos_encodings / norms
similarity = normalized @ normalized.T

im = ax.imshow(similarity, cmap='coolwarm', vmin=-1, vmax=1)
plt.colorbar(im, ax=ax)
ax.set_xlabel('Position')
ax.set_ylabel('Position')
ax.set_title('Cosine Similarity Between Positional Encodings')

plt.suptitle('Sinusoidal Positional Encoding Analysis', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("Key insights:")
print("1. Low dimensions: High frequency (rapid changes with position)")
print("2. High dimensions: Low frequency (slow changes)")
print("3. Each position has unique encoding")
print("4. Nearby positions have higher similarity")

## 6. The Transformer Architecture

### 6.1 Overview

The Transformer (Vaswani et al., 2017) uses:
- **Only attention**, no recurrence
- **Encoder-Decoder** structure
- **Multi-head self-attention**
- **Position-wise feedforward networks**
- **Residual connections and layer normalization**

### 6.2 Encoder Block

Each encoder layer:
1. Multi-Head Self-Attention
2. Add & Norm (residual connection + layer norm)
3. Position-wise Feed-Forward Network
4. Add & Norm

$$\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$$
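"Position-wise" means the same two-layer network, with the same weights, is applied to every position independently. A NumPy sketch with random toy weights (`ffn` is a hypothetical helper) confirms that running it on the whole sequence equals running it one position at a time:

```python
import numpy as np

rng = np.random.default_rng(5)
seq_len, d_model, d_ff = 4, 8, 32
W1, b1 = rng.normal(size=(d_model, d_ff)), rng.normal(size=d_ff)
W2, b2 = rng.normal(size=(d_ff, d_model)), rng.normal(size=d_model)

def ffn(x):
    # FFN(x) = max(0, x W1 + b1) W2 + b2, acting on the last axis only
    return np.maximum(0, x @ W1 + b1) @ W2 + b2

X = rng.normal(size=(seq_len, d_model))
whole = ffn(X)                                   # all positions at once
per_position = np.stack([ffn(x) for x in X])     # one position at a time
print(whole.shape, np.allclose(whole, per_position))  # (4, 8) True
```

All mixing of information across positions therefore happens in the attention sub-layer; the FFN only transforms each position's vector.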

### 6.3 Decoder Block

Each decoder layer:
1. Masked Multi-Head Self-Attention
2. Add & Norm
3. Multi-Head Cross-Attention (to encoder)
4. Add & Norm
5. Position-wise Feed-Forward Network
6. Add & Norm

In [None]:
class PositionwiseFeedForward(nn.Module):
    """
    Position-wise Feed-Forward Network.
    
    FFN(x) = max(0, xW_1 + b_1)W_2 + b_2
    """
    
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        return self.fc2(self.dropout(F.relu(self.fc1(x))))

class TransformerEncoderLayer(nn.Module):
    """
    Single Transformer Encoder Layer.
    """
    
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        
        self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        # MultiHeadAttention already applies dropout to its output, so only the
        # feed-forward sub-layer needs its own residual dropout here
        self.dropout2 = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        """
        x: [batch, seq_len, d_model]
        mask: optional, same shapes as in MultiHeadAttention.forward
        """
        # Self-attention with residual connection (post-norm: LayerNorm(x + Sublayer(x)))
        attn_output, attention_weights = self.self_attention(x, x, x, mask)
        x = self.norm1(x + attn_output)
        
        # Feed-forward with residual connection
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout2(ff_output))
        
        return x, attention_weights

class TransformerEncoder(nn.Module):
    """
    Complete Transformer Encoder.
    """
    
    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, 
                 max_seq_len=5000, dropout=0.1):
        super().__init__()
        
        self.d_model = d_model
        
        # Token embedding
        self.embedding = nn.Embedding(vocab_size, d_model)
        
        # Positional encoding
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
        
        # Encoder layers
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        """
        x: [batch, seq_len] token indices
        """
        # Embed tokens and scale by sqrt(d_model)
        x = self.embedding(x) * math.sqrt(self.d_model)
        
        # Add positional encoding
        x = self.pos_encoding(x)
        
        # Pass through encoder layers
        attention_weights_all = []
        for layer in self.layers:
            x, attn_weights = layer(x, mask)
            attention_weights_all.append(attn_weights)
        
        return x, attention_weights_all

# Build Transformer Encoder
print("Transformer Encoder Implementation")
print("="*60)

vocab_size = 10000
d_model = 512
num_heads = 8
d_ff = 2048
num_layers = 6

encoder = TransformerEncoder(
    vocab_size=vocab_size,
    d_model=d_model,
    num_heads=num_heads,
    d_ff=d_ff,
    num_layers=num_layers
)

print(f"Configuration:")
print(f"  Vocabulary size: {vocab_size}")
print(f"  Model dimension: {d_model}")
print(f"  Number of heads: {num_heads}")
print(f"  Feed-forward dimension: {d_ff}")
print(f"  Number of layers: {num_layers}")

total_params = sum(p.numel() for p in encoder.parameters())
print(f"\nTotal parameters: {total_params:,}")

# Test
batch_size = 2
seq_len = 20
x = torch.randint(0, vocab_size, (batch_size, seq_len))

output, attention_weights = encoder(x)
print(f"\nInput shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Number of attention weight tensors: {len(attention_weights)}")

In [None]:
# Visualize Transformer architecture

fig, ax = plt.subplots(figsize=(10, 14))

# Colors
emb_color = '#90EE90'
attn_color = '#FFB6C1'
ffn_color = '#87CEEB'
norm_color = '#DDA0DD'

y_start = 12

# Input
rect = plt.Rectangle((3, y_start), 4, 0.6, color='lightgray', alpha=0.8)
ax.add_patch(rect)
ax.text(5, y_start + 0.3, 'Input Tokens', ha='center', va='center', fontsize=11)

# Embedding
rect = plt.Rectangle((3, y_start - 1.2), 4, 0.6, color=emb_color, alpha=0.8)
ax.add_patch(rect)
ax.text(5, y_start - 0.9, 'Token Embedding + Positional Encoding', ha='center', va='center', fontsize=10)
ax.arrow(5, y_start, 0, -0.3, head_width=0.1, fc='black')

# Encoder layers (draw 2 as example)
for layer_idx in range(2):
    y_offset = layer_idx * 4
    y_base = y_start - 2.5 - y_offset
    
    # Layer box
    rect = plt.Rectangle((2.5, y_base - 3.5), 5, 3.8, fill=False, 
                         edgecolor='black', linewidth=2, linestyle='--')
    ax.add_patch(rect)
    ax.text(0.5, y_base - 1.7, f'Encoder\nLayer {layer_idx + 1}', ha='center', 
           va='center', fontsize=11, fontweight='bold')
    
    # Multi-Head Attention
    rect = plt.Rectangle((3, y_base - 0.8), 4, 0.6, color=attn_color, alpha=0.8)
    ax.add_patch(rect)
    ax.text(5, y_base - 0.5, 'Multi-Head\nSelf-Attention', ha='center', va='center', fontsize=9)
    
    # Add & Norm
    rect = plt.Rectangle((3, y_base - 1.8), 4, 0.6, color=norm_color, alpha=0.8)
    ax.add_patch(rect)
    ax.text(5, y_base - 1.5, 'Add & Norm', ha='center', va='center', fontsize=9)
    
    # Feed-Forward
    rect = plt.Rectangle((3, y_base - 2.8), 4, 0.6, color=ffn_color, alpha=0.8)
    ax.add_patch(rect)
    ax.text(5, y_base - 2.5, 'Feed Forward', ha='center', va='center', fontsize=9)
    
    # Add & Norm
    rect = plt.Rectangle((3, y_base - 3.8), 4, 0.6, color=norm_color, alpha=0.8)
    ax.add_patch(rect)
    ax.text(5, y_base - 3.5, 'Add & Norm', ha='center', va='center', fontsize=9)
    
    # Arrows
    ax.arrow(5, y_base, 0, -0.1, head_width=0.1, fc='black')
    ax.arrow(5, y_base - 0.8, 0, -0.3, head_width=0.1, fc='black')
    ax.arrow(5, y_base - 1.8, 0, -0.3, head_width=0.1, fc='black')
    ax.arrow(5, y_base - 2.8, 0, -0.3, head_width=0.1, fc='black')
    
    # Residual connections
    ax.plot([7.5, 8, 8, 7.5], [y_base, y_base, y_base - 1.5, y_base - 1.5], 'g-', linewidth=2)
    ax.plot([7.5, 8, 8, 7.5], [y_base - 1.8, y_base - 1.8, y_base - 3.5, y_base - 3.5], 'g-', linewidth=2)

# ... (more layers)
ax.text(5, y_start - 11, '⋮', ha='center', fontsize=20)
ax.text(5, y_start - 11.5, f'(×{num_layers} layers)', ha='center', fontsize=10)

# Output
rect = plt.Rectangle((3, y_start - 13), 4, 0.6, color='lightyellow', alpha=0.8)
ax.add_patch(rect)
ax.text(5, y_start - 12.7, 'Encoder Output', ha='center', va='center', fontsize=11)

ax.set_xlim(0, 10)
ax.set_ylim(-2, 13)
ax.axis('off')
ax.set_title('Transformer Encoder Architecture', fontsize=16, fontweight='bold', pad=20)

# Legend
legend_elements = [
    plt.Rectangle((0, 0), 1, 1, color=emb_color, alpha=0.8, label='Embedding'),
    plt.Rectangle((0, 0), 1, 1, color=attn_color, alpha=0.8, label='Attention'),
    plt.Rectangle((0, 0), 1, 1, color=ffn_color, alpha=0.8, label='Feed-Forward'),
    plt.Rectangle((0, 0), 1, 1, color=norm_color, alpha=0.8, label='Layer Norm'),
]
ax.legend(handles=legend_elements, loc='upper left')

plt.tight_layout()
plt.show()

## 7. Transformer for Classification

### 7.1 Using Encoder for Sentence Classification

In [None]:
class TransformerClassifier(nn.Module):
    """
    Transformer Encoder for sequence classification.
    """
    
    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers,
                 num_classes, max_seq_len=512, dropout=0.1, pooling='cls'):
        super().__init__()
        
        self.pooling = pooling
        self.encoder = TransformerEncoder(
            vocab_size, d_model, num_heads, d_ff, num_layers, max_seq_len, dropout
        )
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, num_classes)
        )
    
    def forward(self, x, mask=None):
        """
        x: [batch, seq_len]
        """
        # Encode
        encoded, _ = self.encoder(x, mask)  # [batch, seq_len, d_model]
        
        # Pool
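        if self.pooling == 'cls':
            # Use the first position's output. This mirrors BERT's [CLS] pooling only
            # if every input sequence starts with a dedicated classification token.
            pooled = encoded[:, 0, :]
        elif self.pooling == 'mean':
            # Mean pooling (note: also averages over any padding positions)
            pooled = encoded.mean(dim=1)
        elif self.pooling == 'max':
            # Max pooling over the sequence dimension
            pooled = encoded.max(dim=1).values
        else:
            raise ValueError(f"Unknown pooling: {self.pooling!r}")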
        
        # Classify
        logits = self.classifier(pooled)
        
        return logits

# Test classifier
print("Transformer Classifier")
print("="*60)

model = TransformerClassifier(
    vocab_size=10000,
    d_model=128,
    num_heads=4,
    d_ff=512,
    num_layers=2,
    num_classes=2,
    pooling='mean'
)

print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

# Test forward pass
batch_size = 4
seq_len = 50
x = torch.randint(0, 10000, (batch_size, seq_len))
output = model(x)

print(f"\nInput shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Output (logits): {output}")
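Mean pooling as written averages over every position, including padding tokens when sequences in a batch have different lengths. Below is a minimal sketch of padding-aware mean pooling. It assumes a hypothetical boolean `keep_mask` of shape `[batch, seq_len]` where `True` marks real tokens; the notebook's encoder mask may use a different convention, and `masked_mean_pool` is an illustrative name, not part of the classifier above.

```python
import torch

def masked_mean_pool(encoded, keep_mask):
    """Average encoder outputs over real (non-padding) tokens only.

    encoded:   [batch, seq_len, d_model]
    keep_mask: [batch, seq_len] bool, True = real token (assumed convention)
    """
    weights = keep_mask.unsqueeze(-1).to(encoded.dtype)  # [batch, seq_len, 1]
    summed = (encoded * weights).sum(dim=1)              # [batch, d_model]
    counts = weights.sum(dim=1).clamp(min=1.0)           # avoid division by zero
    return summed / counts

# Example: the second sequence has 2 padding positions at the end
encoded = torch.randn(2, 5, 8)
keep_mask = torch.tensor([[True, True, True, True, True],
                          [True, True, True, False, False]])
pooled = masked_mean_pool(encoded, keep_mask)
print(pooled.shape)
```

Inside `TransformerClassifier.forward`, a call like this could replace `encoded.mean(dim=1)` once the mask format is known.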

## 8. Transformers vs RNNs

### 8.1 Computational Comparison

| Aspect | RNN/LSTM | Transformer |
|--------|----------|-------------|
| Parallelization across positions | Sequential (O(n) dependent steps) | Fully parallel (O(1) dependent steps) |
| Long-range dependencies | Difficult (vanishing gradients) | Direct connections between all positions |
| Compute per layer | O(n·d²) | O(n²·d) |
| Memory | O(d) hidden state at inference; O(n·d) stored activations for training | O(n²) attention matrix per head |
| Training speed | Slow (sequential) | Fast on parallel hardware |
| Interpretability | Hidden state is opaque | Attention weights give a partial (debated) view of model focus |

Here n is the sequence length and d the hidden/model dimension.
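The complexity rows can be made concrete with back-of-the-envelope counts. This sketch uses the asymptotic per-layer costs from Vaswani et al. (2017), O(n·d²) for a recurrent layer and O(n²·d) for self-attention, ignoring constant factors. The helper names, the width d = 512, and the 8 heads are illustrative assumptions.

```python
def per_layer_ops(n, d):
    """Approximate per-layer operation counts (constant factors ignored).

    Recurrent layer: n sequential steps, each a d x d matrix-vector product -> n * d^2
    Self-attention:  n x n scores, each a d-dimensional dot product        -> n^2 * d
    """
    return {"recurrent": n * d * d, "self_attention": n * n * d}


def attention_matrix_bytes(n, num_heads, bytes_per_elem=4):
    """Size of one layer's attention weights for a single sequence (float32 by default)."""
    return num_heads * n * n * bytes_per_elem


d = 512  # assumed model width
for n in [128, 512, 2048]:
    ops = per_layer_ops(n, d)
    ratio = ops["self_attention"] / ops["recurrent"]  # equals n / d
    mib = attention_matrix_bytes(n, num_heads=8) / 2**20
    print(f"n={n:5d}: attention/recurrent cost ratio = {ratio:.2f}, "
          f"attention weights = {mib:.1f} MiB per sequence per layer")
```

The ratio is simply n/d, so self-attention is cheaper per layer while the sequence is shorter than the model width; the quadratic memory term is what motivates sparse-attention variants such as the one in Exercise 6.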

### 8.2 Path Length Comparison

Maximum path length between any two positions:
- **RNN**: O(n), since information must pass through every intermediate hidden state
- **Transformer**: O(1), since every position attends directly to every other position
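One way to see the O(1) path length is to check that a single self-attention layer already links the first output position to the last input position. This sketch uses PyTorch's built-in `nn.MultiheadAttention` rather than the notebook's own attention class, with arbitrary sizes chosen for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n, d_model = 50, 16
attn = nn.MultiheadAttention(d_model, num_heads=4, batch_first=True)

x = torch.randn(1, n, d_model, requires_grad=True)
out, weights = attn(x, x, x, need_weights=True)  # weights: [1, n, n], averaged over heads

# Softmax outputs are strictly positive, so every query attends to every key in one hop
print(weights.shape)
print(bool((weights > 0).all()))

# Gradient flows from output position 0 back to input position n-1 through one layer
out[0, 0].sum().backward()
print(x.grad[0, -1].abs().sum().item() > 0)
```

An RNN would need n - 1 recurrent steps to carry the same signal between those two positions.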

In [None]:
# Compare Transformer vs LSTM on sequence length

import time

def measure_forward_time(model, seq_lengths, vocab_size, batch_size=32, num_runs=3):
    """Measure mean forward-pass time (seconds) for different sequence lengths."""
    model.eval()  # disable dropout so timings reflect inference
    times = []
    
    with torch.no_grad():  # skip autograd bookkeeping during timing
        for seq_len in seq_lengths:
            x = torch.randint(0, vocab_size, (batch_size, seq_len))
            
            # Warm up
            _ = model(x)
            
            # Measure with a high-resolution clock
            run_times = []
            for _ in range(num_runs):
                start = time.perf_counter()
                _ = model(x)
                run_times.append(time.perf_counter() - start)
            
            times.append(np.mean(run_times))
    
    return times

# Create models
vocab_size = 5000
d_model = 64

# LSTM classifier
class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, d_model, num_classes=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.lstm = nn.LSTM(d_model, d_model, batch_first=True)
        self.fc = nn.Linear(d_model, num_classes)
    
    def forward(self, x):
        embedded = self.embedding(x)
        _, (hidden, _) = self.lstm(embedded)
        return self.fc(hidden.squeeze(0))

lstm_model = LSTMClassifier(vocab_size, d_model)
transformer_model = TransformerClassifier(
    vocab_size, d_model, num_heads=4, d_ff=256, num_layers=2, num_classes=2
)

# Measure times
seq_lengths = [10, 25, 50, 100, 200]

print("Measuring forward pass times...")
lstm_times = measure_forward_time(lstm_model, seq_lengths, vocab_size)
transformer_times = measure_forward_time(transformer_model, seq_lengths, vocab_size)

# Plot comparison
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Time comparison
ax = axes[0]
ax.plot(seq_lengths, lstm_times, 'b-o', linewidth=2, markersize=8, label='LSTM')
ax.plot(seq_lengths, transformer_times, 'r-s', linewidth=2, markersize=8, label='Transformer')
ax.set_xlabel('Sequence Length')
ax.set_ylabel('Forward Pass Time (seconds)')
ax.set_title('Inference Time vs Sequence Length')
ax.legend()
ax.grid(True, alpha=0.3)

# Path length illustration
ax = axes[1]
rnn_path = seq_lengths  # O(n)
transformer_path = [1] * len(seq_lengths)  # O(1)

ax.plot(seq_lengths, rnn_path, 'b-o', linewidth=2, markersize=8, label='RNN (O(n))')
ax.plot(seq_lengths, transformer_path, 'r-s', linewidth=2, markersize=8, label='Transformer (O(1))')
ax.set_xlabel('Sequence Length')
ax.set_ylabel('Maximum Path Length')
ax.set_title('Path Length Between Distant Positions')
ax.legend()
ax.grid(True, alpha=0.3)

plt.suptitle('Transformer vs LSTM Comparison', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

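print("\nKey Differences:")
print("1. Transformer: Constant path length between any two positions")
print("2. LSTM: Path length grows with distance")
print("3. Transformer: Long-range dependencies are easier to learn")
print("4. Transformer: More parallelizable (but O(n²) attention memory)")
print("\nNote: measured times depend on hardware and model size; the parallelism")
print("advantage is most visible when training on GPUs.")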

## 9. Why Transformers Changed Everything

### 9.1 The Revolution

Since "Attention Is All You Need" (2017):

**NLP:**
- BERT (2018): Bidirectional encoder, pretraining
- GPT (2018-2023): Autoregressive generation
- T5 (2020): Text-to-text framework
- ChatGPT, Claude, etc.

**Computer Vision:**
- Vision Transformer (ViT, 2020)
- DETR (2020) for object detection
- Swin Transformer (2021)

**Other Domains:**
- AlphaFold 2 (protein structure)
- Music generation
- Code generation

### 9.2 Key Innovations

1. **Scalability**: Parallel training makes massive datasets and models practical
2. **Transfer learning**: Pre-train once, fine-tune for many tasks
3. **Long-range dependencies**: Direct attention connections between any two positions
4. **Interpretability**: Attention weights offer a partial (and debated) view of what the model focuses on

In [None]:
# Timeline of Transformer-based models

fig, ax = plt.subplots(figsize=(14, 8))

models = [
    (2017, 'Transformer', 'Original', 65e6),
    (2018.4, 'GPT-1', 'Decoder', 117e6),
    (2018.8, 'BERT', 'Encoder', 340e6),
    (2019.2, 'GPT-2', 'Decoder', 1.5e9),
    (2019.8, 'XLNet', 'Encoder', 340e6),
    (2020.0, 'T5', 'Enc-Dec', 11e9),
    (2020.5, 'GPT-3', 'Decoder', 175e9),
]
# ChatGPT (2022) and GPT-4 (2023) are omitted: OpenAI has not published
# their parameter counts, so any value plotted here would be a guess.

years = [m[0] for m in models]
names = [m[1] for m in models]
types = [m[2] for m in models]
params = [m[3] for m in models]

# Color by type
type_colors = {'Original': 'gray', 'Encoder': 'blue', 'Decoder': 'green', 'Enc-Dec': 'orange'}
colors = [type_colors[t] for t in types]

# Plot
ax.scatter(years, np.log10(params), c=colors, s=200, alpha=0.7, zorder=5)

for year, name, param in zip(years, names, params):
    ax.annotate(f'{name}\n({param/1e9:.1f}B)' if param >= 1e9 else f'{name}\n({param/1e6:.0f}M)',
               (year, np.log10(param)),
               textcoords="offset points",
               xytext=(0, 15),
               ha='center',
               fontsize=9)

ax.set_xlabel('Year', fontsize=12)
ax.set_ylabel('Parameters (log scale)', fontsize=12)
ax.set_yticks([7, 8, 9, 10, 11, 12])
ax.set_yticklabels(['10M', '100M', '1B', '10B', '100B', '1T'])
ax.set_title('Growth of Transformer-Based Models', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3)

# Legend
for type_name, color in type_colors.items():
    ax.scatter([], [], c=color, s=100, label=type_name, alpha=0.7)
ax.legend(loc='upper left', title='Architecture')

plt.tight_layout()
plt.show()

growth = models[-1][3] / models[0][3]
print("Exponential growth in model size!")
print(f"From {models[0][3]/1e6:.0f}M parameters (2017) to {models[-1][3]/1e9:.0f}B parameters ({models[-1][1]}, 2020)")
print(f"That's a ~{growth:,.0f}x increase in about 3 years!")

## 10. Summary and Key Takeaways

### What We Learned

1. **Attention solves the bottleneck problem**
   - Decoder can access all encoder states
   - Learned alignments between sequences

2. **Self-attention captures relationships**
   - Query-Key-Value framework
   - Scaled dot-product: $\text{softmax}(QK^T/\sqrt{d_k})V$
   - Direct connection between any positions

3. **Multi-head attention enables diverse patterns**
   - Multiple attention heads run in parallel on different learned projections
   - Different heads can learn to focus on different relationships

4. **Positional encoding provides order information**
   - Sinusoidal encodings
   - Learnable alternatives

5. **Transformer architecture:**
   - Multi-Head Attention → Add & Norm → FFN → Add & Norm
   - Residual connections help gradients flow through deep stacks
   - Parallel across sequence positions during training

6. **Advantages over RNNs:**
   - Constant path length for long-range dependencies
   - Parallelizable training
   - Better scalability
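The scaled dot-product formula from point 2 fits in a few lines. This sketch implements softmax(QKᵀ/√d_k)V directly and compares it against PyTorch's fused `F.scaled_dot_product_attention`; that function is available in PyTorch 2.0 and later, which is an assumption about the installed version. Tensor sizes are arbitrary.

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V for tensors shaped [..., seq_len, d_k]."""
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # [..., seq_q, seq_k]
    return F.softmax(scores, dim=-1) @ V

torch.manual_seed(0)
Q, K, V = (torch.randn(2, 4, 10, 16) for _ in range(3))  # [batch, heads, seq_len, d_k]
ours = scaled_dot_product(Q, K, V)
builtin = F.scaled_dot_product_attention(Q, K, V)  # default scale is 1/sqrt(d_k)
print(torch.allclose(ours, builtin, atol=1e-5))
```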

### Next Steps (Days 16-20)

- **Day 16**: Sequence-to-Sequence models with attention
- **Day 17**: BERT and pretraining
- **Day 18**: GPT and autoregressive models
- **Day 19**: Fine-tuning pretrained models
- **Day 20**: Project 2 - NLP Application

## Exercises

### Exercise 1: Masked Self-Attention
Implement causal (masked) self-attention for autoregressive generation.

### Exercise 2: Transformer Decoder
Build the complete Transformer decoder with masked self-attention and cross-attention.

### Exercise 3: Learnable Positional Encodings
Implement learnable positional embeddings and compare with sinusoidal.

### Exercise 4: Relative Position Attention
Implement relative positional encodings (Shaw et al., 2018).

### Exercise 5: Attention Visualization Tool
Build an interactive tool to visualize attention patterns for different inputs.

### Exercise 6: Sparse Attention
Implement local/sparse attention patterns (like Longformer) for long sequences.

### Exercise 7: Vision Transformer
Apply the Transformer architecture to image classification (ViT-style).

In [None]:
# Starter code for Exercise 1: Masked Self-Attention

def create_causal_mask(seq_len):
    """
    Create causal mask for autoregressive attention.
    
    Position i can only attend to positions j <= i
    """
    # Strictly upper-triangular entries (j > i) are future positions. In attention,
    # their scores are typically filled with -inf so softmax gives them zero weight.
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    return ~mask  # True = attend, False = mask

# Visualize causal mask
seq_len = 6
mask = create_causal_mask(seq_len)

plt.figure(figsize=(6, 6))
plt.imshow(mask.float().numpy(), cmap='Blues')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.title('Causal Mask\n(White = Masked, Blue = Attend)')
plt.xticks(range(seq_len))
plt.yticks(range(seq_len))

for i in range(seq_len):
    for j in range(seq_len):
        text = '✓' if mask[i, j] else '✗'
        plt.text(j, i, text, ha='center', va='center', fontsize=12)

plt.tight_layout()
plt.show()

print("Causal mask ensures position i only attends to positions <= i")
print("This enables autoregressive (left-to-right) generation.")
print("\nExercise: Integrate this mask into MultiHeadAttention for decoder!")
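As a hint for the exercise, a common way to use such a boolean mask is to fill disallowed scores with -inf before the softmax. The sketch below assumes `True` means "attend" (matching `create_causal_mask`) and works on a single score matrix; `masked_softmax_scores` is a hypothetical helper, and wiring it into `MultiHeadAttention` is still left to you.

```python
import torch
import torch.nn.functional as F

def masked_softmax_scores(scores, keep_mask):
    """Softmax over the last dim after masking; keep_mask is bool, True = attend.

    Masked scores become -inf, and exp(-inf) = 0 gives those positions zero weight.
    """
    scores = scores.masked_fill(~keep_mask, float('-inf'))
    return F.softmax(scores, dim=-1)

seq_len = 6
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))  # same pattern as create_causal_mask
scores = torch.randn(seq_len, seq_len)
weights = masked_softmax_scores(scores, causal)

print(weights.triu(diagonal=1).abs().sum())  # weight on future positions (should be 0)
print(weights.sum(dim=-1))                   # each row still sums to 1
```

A row in which every position is masked would produce NaNs, which cannot happen with a causal mask because each query can always attend to itself.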

## References

1. Vaswani, A., et al. (2017). "Attention Is All You Need." NeurIPS.
2. Bahdanau, D., et al. (2015). "Neural Machine Translation by Jointly Learning to Align and Translate." ICLR.
3. Luong, M. T., et al. (2015). "Effective Approaches to Attention-based Neural Machine Translation." EMNLP.
4. Devlin, J., et al. (2019). "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding." NAACL.
5. Radford, A., et al. (2018-2023). "GPT Series." OpenAI.
6. Shaw, P., et al. (2018). "Self-Attention with Relative Position Representations." NAACL.
7. Dosovitskiy, A., et al. (2021). "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale." ICLR.