# Chapter 3: Transformers

Welcome to Chapter 3! Here we explore the Transformer architecture, which reshaped natural language processing (NLP) and is increasingly used to model biological sequences such as proteins and DNA.

## 📚 Table of Contents
1. [Introduction to Attention](#intro)
2. [Self-Attention Mechanism](#self-attention)
3. [Multi-Head Attention](#multi-head)
4. [Positional Encoding](#positional)
5. [Complete Transformer Architecture](#architecture)
6. [BERT and GPT](#bert-gpt)
7. [Biology Application: Protein Sequence Analysis](#biology-app)

---

In [None]:
# Import libraries
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import seaborn as sns
from transformers import BertTokenizer, BertModel, GPT2Tokenizer, GPT2Model

plt.style.use('seaborn-v0_8-darkgrid')
np.random.seed(42)
torch.manual_seed(42)

print('✓ Libraries imported')
print(f'PyTorch: {torch.__version__}')

## 1. Introduction to Attention <a id="intro"></a>

### The Problem with RNNs

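Traditional RNNs process a sequence one step at a time, which causes three problems:
- **Slow**: each hidden state depends on the previous one, so computation cannot be parallelized across time steps
- **Memory**: information from early tokens must survive many update steps, so long-range dependencies are hard to retain
- **Gradients**: backpropagation through many time steps tends to make gradients vanish or explode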

### The Attention Solution

**Key Idea**: Let the model learn which parts of the input to focus on.

**Attention Formula**:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

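where:
- $Q$ = queries (what each position is looking for)
- $K$ = keys (what each position offers to be matched against)
- $V$ = values (the information each position returns)
- $d_k$ = dimension of the query and key vectors; dividing by $\sqrt{d_k}$ keeps the dot products from growing with dimension and saturating the softmax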
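Why divide by $\sqrt{d_k}$? Assume, for illustration, that the entries of a query $q$ and a key $k$ are independent with mean 0 and variance 1. Then $q \cdot k$ is a sum of $d_k$ unit-variance terms, so its variance grows like $d_k$, and large scores push the softmax toward a one-hot distribution with tiny gradients. The NumPy sketch below checks this empirically; the helper `dot_product_variance` is hypothetical and exists only for this demo.

```python
import numpy as np

rng = np.random.default_rng(0)

def dot_product_variance(d_k, n_samples=5_000, scale=False):
    """Empirical variance of q.k for random q, k with i.i.d. N(0, 1) entries."""
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    scores = np.sum(q * k, axis=1)          # one dot product per sample
    if scale:
        scores = scores / np.sqrt(d_k)      # the 1/sqrt(d_k) factor from the formula
    return scores.var()

for d_k in (4, 64, 512):
    raw = dot_product_variance(d_k)
    scaled = dot_product_variance(d_k, scale=True)
    print(f"d_k={d_k:4d}  var(q.k)={raw:8.1f}  var(q.k / sqrt(d_k))={scaled:5.2f}")
```

The unscaled variance tracks $d_k$, while the scaled variance stays close to 1 regardless of dimension.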

### Intuition

Think of it like a dictionary lookup:
1. **Query**: "What am I looking for?"
2. **Key**: "Does this match what you want?"
3. **Value**: "Here's the information"

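The attention mechanism computes a similarity score (a scaled dot product) between the query and every key, turns the scores into weights that sum to 1 with a softmax, and returns the weighted sum of the values. Unlike a real dictionary, the lookup is *soft*: every value contributes, in proportion to how well its key matches the query.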
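A toy version of this soft lookup in NumPy, with made-up keys, values and query chosen only for illustration. The query points in the same direction as key 0, so a hard dictionary would return exactly `10.0`; attention instead returns a blend dominated by that value.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax over the last axis."""
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

# Three stored items: their keys (d_k = 2) and the values they return
keys = np.array([[1.0, 0.0],
                 [0.0, 1.0],
                 [-1.0, 0.0]])
values = np.array([[10.0], [20.0], [30.0]])

# A query aligned with key 0
query = np.array([[4.0, 0.0]])

d_k = keys.shape[1]
scores = query @ keys.T / np.sqrt(d_k)   # similarity of the query to every key
weights = softmax(scores)                # non-negative, sums to 1
output = weights @ values                # weighted sum of the values

print("weights:", np.round(weights, 3))
print("output:", np.round(output, 2))
```

Most of the weight lands on key 0, so the output sits close to `10.0` but is pulled slightly toward the other values.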

In [None]:
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute scaled dot-product attention.
    
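    Args:
        Q: Query tensor (..., seq_len_q, d_k)
        K: Key tensor (..., seq_len_k, d_k)
        V: Value tensor (..., seq_len_k, d_v)
        mask: Optional tensor broadcastable to (..., seq_len_q, seq_len_k);
              positions where mask == 0 are blocked from being attended to
    
    Returns:
        output: Attention output (..., seq_len_q, d_v)
        attention_weights: Softmax weights (..., seq_len_q, seq_len_k)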
    """
    d_k = Q.size(-1)
    
    # Compute attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    
    # Apply mask if provided
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    
    # Apply softmax
    attention_weights = F.softmax(scores, dim=-1)
    
    # Compute output
    output = torch.matmul(attention_weights, V)
    
    return output, attention_weights

# Test with simple example
print('Testing Attention Mechanism:')
seq_len = 4
d_model = 8

Q = torch.randn(1, seq_len, d_model)
K = torch.randn(1, seq_len, d_model)
V = torch.randn(1, seq_len, d_model)

output, weights = scaled_dot_product_attention(Q, K, V)

print(f'\nInput shapes:')
print(f'  Q: {Q.shape}')
print(f'  K: {K.shape}')
print(f'  V: {V.shape}')
print(f'\nOutput shapes:')
print(f'  Output: {output.shape}')
print(f'  Attention weights: {weights.shape}')

# Visualize attention weights
plt.figure(figsize=(8, 6))
sns.heatmap(weights[0].detach().numpy(), annot=True, fmt='.2f', 
            cmap='YlOrRd', cbar_kws={'label': 'Attention Weight'})
plt.xlabel('Key Position', fontsize=12)
plt.ylabel('Query Position', fontsize=12)
plt.title('Attention Weight Matrix', fontsize=14, weight='bold')
plt.tight_layout()
plt.show()

print('\n💡 Each row shows how one query spreads its attention over the keys (each row sums to 1).')

## 2. Self-Attention Mechanism <a id="self-attention"></a>

In **self-attention**, Q, K, and V all come from the same input!

### Why Self-Attention?

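Self-attention lets every position attend directly to every other position in the sequence. This:
- Captures long-range dependencies, since any two positions are connected in a single step
- Computes all positions in parallel, unlike an RNN
- Has no built-in bias toward nearby positions (which is also why positional information must be added separately, see Section 4)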

### How it Works

1. Start with input embeddings $X$
2. Create Q, K, V with three learned linear transformations:
   - $Q = XW_Q$
   - $K = XW_K$
   - $V = XW_V$
3. Compute scaled dot-product attention on $Q$, $K$, $V$
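The three steps can be sketched in NumPy for a single, unbatched sequence. This is a minimal illustration, not the `SelfAttention` class below: it omits the bias terms that `nn.Linear` adds, and the random projection matrices stand in for learned weights.

```python
import numpy as np

rng = np.random.default_rng(42)

seq_len, d_model = 4, 8
X = rng.standard_normal((seq_len, d_model))   # step 1: input embeddings

# Step 2: projection matrices (random stand-ins for learned weights)
W_Q = rng.standard_normal((d_model, d_model)) / np.sqrt(d_model)
W_K = rng.standard_normal((d_model, d_model)) / np.sqrt(d_model)
W_V = rng.standard_normal((d_model, d_model)) / np.sqrt(d_model)
Q, K, V = X @ W_Q, X @ W_K, X @ W_V           # all three come from the same X

# Step 3: scaled dot-product attention
scores = Q @ K.T / np.sqrt(d_model)           # (seq_len, seq_len)
scores -= scores.max(axis=-1, keepdims=True)  # stabilize the exponentials
weights = np.exp(scores)
weights /= weights.sum(axis=-1, keepdims=True)
output = weights @ V                          # (seq_len, d_model)

print(weights.shape, output.shape)
```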

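### Example: DNA Sequence

For the nucleotide sequence "ACGT":
- A might attend strongly to T, and C to G (complementary bases under Watson-Crick pairing)
- In a protein sequence, a residue might similarly attend to a distant residue it contacts in the folded structure
- Self-attention learns such relationships from data; none of them is hard-coded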

In [None]:
class SelfAttention(nn.Module):
    """Self-Attention layer."""
    
    def __init__(self, d_model):
        super(SelfAttention, self).__init__()
        self.d_model = d_model
        
        # Linear transformations for Q, K, V
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        
    def forward(self, x):
        # Create Q, K, V
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        
        # Compute attention
        output, attention_weights = scaled_dot_product_attention(Q, K, V)
        
        return output, attention_weights

# Test
d_model = 64
seq_len = 5
batch_size = 2

x = torch.randn(batch_size, seq_len, d_model)
self_attn = SelfAttention(d_model)

output, weights = self_attn(x)

print('Self-Attention Layer:')
print(f'Input shape: {x.shape}')
print(f'Output shape: {output.shape}')
print(f'Attention weights shape: {weights.shape}')
print('\n✓ Self-attention preserves sequence length!')

## 3. Multi-Head Attention <a id="multi-head"></a>

### Why Multiple Heads?

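A single head produces one attention distribution per query, so it must blend every kind of relationship into one set of weights. Multiple heads let the model:
- Project the input into different representation subspaces
- Attend to different positions at the same time
- Capture several kinds of relationships in parallel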

### Formula

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O$$

where each head is:
$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$
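The formula uses a separate projection $W_i^Q \in \mathbb{R}^{d_{model} \times d_k}$ per head, with $d_k = d_{model}/h$ (the usual choice, assumed here). Implementations, including the `MultiHeadAttention` class below, typically apply one $d_{model} \times d_{model}$ projection and reshape its output into $h$ heads. The NumPy check below illustrates that the two views agree when each $W_i^Q$ is taken as a block of columns of the large matrix; the reshape mirrors `split_heads`.

```python
import numpy as np

rng = np.random.default_rng(0)

seq_len, d_model, h = 5, 12, 3
d_k = d_model // h                              # per-head dimension

X = rng.standard_normal((seq_len, d_model))
W_Q = rng.standard_normal((d_model, d_model))   # all heads' W_i^Q side by side

# View 1: one large projection, then split the columns into h heads
Q_big = X @ W_Q                                               # (seq_len, d_model)
Q_split = Q_big.reshape(seq_len, h, d_k).transpose(1, 0, 2)  # (h, seq_len, d_k)

# View 2: a separate (d_model, d_k) matrix per head, cut from W_Q's columns
Q_heads = np.stack([X @ W_Q[:, i * d_k:(i + 1) * d_k] for i in range(h)])

print(np.allclose(Q_split, Q_heads))
```

So the single large projection costs the same parameters as $h$ small ones, while letting all heads be computed with one matrix multiply.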

### Analogy

Think of multiple editors reviewing the same text:
- One checks grammar
- One checks style  
- One checks content
- Combined feedback is comprehensive!

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-Head Attention layer."""
    
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, 'd_model must be divisible by num_heads'
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Linear layers for Q, K, V
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def split_heads(self, x):
        """Split into multiple heads."""
        batch_size, seq_len, d_model = x.size()
        return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
    
    def combine_heads(self, x):
        """Combine multiple heads."""
        batch_size, _, seq_len, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
    
    def forward(self, x):
        # Linear transformations
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        
        # Split into heads
        Q = self.split_heads(Q)
        K = self.split_heads(K)
        V = self.split_heads(V)
        
        # Attention for each head
        output, attention = scaled_dot_product_attention(Q, K, V)
        
        # Combine heads
        output = self.combine_heads(output)
        
        # Final linear
        output = self.W_o(output)
        
        return output, attention

# Test
d_model = 512
num_heads = 8
seq_len = 10
batch_size = 2

x = torch.randn(batch_size, seq_len, d_model)
mha = MultiHeadAttention(d_model, num_heads)

output, attention = mha(x)

print('Multi-Head Attention:')
print(f'Number of heads: {num_heads}')
print(f'Model dimension: {d_model}')
print(f'Dimension per head: {d_model // num_heads}')
print(f'\nInput shape: {x.shape}')
print(f'Output shape: {output.shape}')
print(f'Attention shape: {attention.shape}')
print('\n✓ One attention matrix per head: (batch, num_heads, seq_len, seq_len)')

## 4. Positional Encoding <a id="positional"></a>

### The Problem

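Self-attention has no built-in notion of order. It is permutation-equivariant: shuffling the input tokens only shuffles the outputs in the same way, so without extra information "ACGT" and "TGCA" yield exactly the same per-nucleotide representations.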
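To check this numerically, the NumPy sketch below runs the same self-attention on "ACGT" and on its reverse "TGCA". The embedding table and projection matrices are random, hypothetical stand-ins for learned weights, and no positional information is added.

```python
import numpy as np

def self_attention(X, W_Q, W_K, W_V):
    """Unbatched scaled dot-product self-attention, no positional information."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(1)
d = 6
# Hypothetical embedding table: one random vector per nucleotide
embed = {base: rng.standard_normal(d) for base in "ACGT"}
W_Q, W_K, W_V = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))

X1 = np.stack([embed[b] for b in "ACGT"])
X2 = np.stack([embed[b] for b in "TGCA"])   # same tokens, reversed order

out1 = self_attention(X1, W_Q, W_K, W_V)
out2 = self_attention(X2, W_Q, W_K, W_V)

# Each nucleotide gets the same output vector; only the row order differs
print(np.allclose(out1, out2[::-1]))
```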

### Solution: Positional Encoding

Add position information to embeddings:

$$PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d}}\right)$$
$$PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d}}\right)$$

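where:
- $pos$ = position in the sequence (0, 1, 2, ...)
- $i$ = index of a (sin, cos) dimension pair, $0 \le i < d/2$, so dimensions $2i$ and $2i+1$ share one frequency
- $d$ = model dimension (assumed even)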
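To make the indices concrete, the NumPy sketch below evaluates the two formulas directly (the function name `sinusoidal_pe` is hypothetical). Position 0 always encodes as alternating 0s and 1s, and the first pair oscillates fastest. The PyTorch cell below computes the same factor $1/10000^{2i/d}$ in log space instead.

```python
import numpy as np

def sinusoidal_pe(max_len, d):
    """PE[pos, 2i] = sin(pos / 10000**(2i/d)), PE[pos, 2i+1] = cos(same angle)."""
    assert d % 2 == 0, "this sketch assumes an even model dimension"
    pos = np.arange(max_len)[:, None]          # (max_len, 1)
    two_i = np.arange(0, d, 2)[None, :]        # 2i = 0, 2, 4, ..., d-2
    angles = pos / (10000 ** (two_i / d))      # (max_len, d/2)
    pe = np.zeros((max_len, d))
    pe[:, 0::2] = np.sin(angles)               # even dimensions
    pe[:, 1::2] = np.cos(angles)               # odd dimensions
    return pe

pe = sinusoidal_pe(max_len=6, d=4)
print(np.round(pe, 3))
```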

### Why Sinusoidal?

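- **Relative positions**: for a fixed offset $k$, $PE_{pos+k}$ is a linear function (a rotation within each sin/cos pair) of $PE_{pos}$, which makes relative offsets easy for the model to express
- **No length limit**: encodings can be computed for positions beyond any sequence length seen in training, although good generalization to longer sequences is not guaranteed
- **Smooth and bounded**: values lie in $[-1, 1]$ and change smoothly with position, with frequencies ranging from fast (low dimensions) to very slow (high dimensions)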
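The relative-position property has a concrete form. For the pair with frequency $\omega_i = 1/10000^{2i/d}$, shifting the position by $k$ applies the rotation $\begin{pmatrix}\cos\omega_i k & \sin\omega_i k\\ -\sin\omega_i k & \cos\omega_i k\end{pmatrix}$ to $(\sin\omega_i pos, \cos\omega_i pos)$, so $PE_{pos+k} = M_k\,PE_{pos}$ for a block-diagonal $M_k$ that depends only on $k$. The NumPy check below verifies this numerically; whether a trained model actually exploits it is a separate question.

```python
import numpy as np

d, k = 8, 3                                       # model dimension and position offset
omega = 1.0 / 10000 ** (np.arange(0, d, 2) / d)   # one frequency per (sin, cos) pair

def pe(pos):
    """Sinusoidal encoding of one position, interleaved as sin, cos, sin, cos, ..."""
    out = np.empty(d)
    out[0::2] = np.sin(pos * omega)
    out[1::2] = np.cos(pos * omega)
    return out

# Block-diagonal rotation matrix that depends only on the offset k
M = np.zeros((d, d))
for i, w in enumerate(omega):
    c, s = np.cos(w * k), np.sin(w * k)
    M[2 * i:2 * i + 2, 2 * i:2 * i + 2] = [[c, s], [-s, c]]

for pos in (0, 5, 40):
    print(pos, np.allclose(pe(pos + k), M @ pe(pos)))
```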

In [None]:
class PositionalEncoding(nn.Module):
    """Positional encoding using sinusoidal functions."""
    
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        
        # Create positional encoding matrix (assumes d_model is even)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # div_term[i] = 1 / 10000^(2i / d_model), computed in log space
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                           (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0)  # Add batch dimension
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        # Add positional encoding
        x = x + self.pe[:, :x.size(1), :]
        return x

# Visualize positional encoding
d_model = 128
max_len = 50
pos_enc = PositionalEncoding(d_model, max_len)

# Get encodings
encodings = pos_enc.pe[0, :max_len, :].numpy()

plt.figure(figsize=(14, 6))
plt.imshow(encodings.T, cmap='RdBu', aspect='auto')
plt.colorbar(label='Encoding Value')
plt.xlabel('Position in Sequence', fontsize=12)
plt.ylabel('Encoding Dimension', fontsize=12)
plt.title('Positional Encoding Visualization', fontsize=14, weight='bold')
plt.tight_layout()
plt.show()

print('\n💡 Key Properties:')
print('  - Each position has a unique encoding')
print('  - Each dimension pair uses a different frequency')
print('  - Lets the model learn relative positions')