# Chapter 3: Transformers

Welcome to one of the most exciting chapters! 🚀

Transformers revolutionized AI starting in 2017. They power ChatGPT, Google Translate, protein structure prediction (AlphaFold), and much more. In this chapter, we'll demystify how they work.

**What you'll learn:**
- Why transformers are different from RNNs and CNNs
- The "attention mechanism" - the key innovation (simpler than it sounds!)
- How transformers process sequences (text, DNA, proteins)
- How to apply transformers to biological sequence analysis

**Prerequisites:**
- Chapter 1 (Neural Networks Basics)
- Basic understanding of sequences (like sentences or DNA)
- Comfort with matrix operations (we'll explain the concepts visually)

**Don't be intimidated!** Transformers may sound complex, but the core idea is intuitive: "pay attention to the important parts."

## 📚 Table of Contents
1. [Introduction to Attention](#intro)
2. [Self-Attention Mechanism](#self-attention)
3. [Multi-Head Attention](#multi-head)
4. [Positional Encoding](#positional)
5. [The Complete Transformer Architecture](#architecture)
6. [BERT and GPT Overview](#models)
7. [Biology Application: Protein Sequence Analysis](#biology-app)

---


In [None]:
# Import libraries
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import seaborn as sns
from transformers import BertTokenizer, BertModel, GPT2Tokenizer, GPT2Model

plt.style.use('seaborn-v0_8-darkgrid')
np.random.seed(42)
torch.manual_seed(42)

print('✓ Libraries imported')
print(f'PyTorch: {torch.__version__}')

## 1. Introduction to Attention <a id="intro"></a>

### The Problem with RNNs

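Traditional RNNs process a sequence one step at a time:
- **Slow**: Each step depends on the previous hidden state, so computation cannot be parallelized across time steps
- **Memory**: Information from early tokens must survive many updates, so long-range dependencies are hard to retain
- **Gradients**: Backpropagating through many steps causes vanishing or exploding gradients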

### The Attention Solution

**Key Idea**: Let the model learn which parts of the input to focus on.

**Attention Formula**:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

where:
- $Q$ = Query (what we're looking for)
- $K$ = Key (what we have)
- $V$ = Value (what we return)
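- $d_k$ = dimension of the query and key vectors; dividing by $\sqrt{d_k}$ keeps the dot products from growing with dimension and pushing the softmax into saturation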

### Intuition

Think of it like a dictionary lookup:
1. **Query**: "What am I looking for?"
2. **Key**: "Does this match what you want?"
3. **Value**: "Here's the information"

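Unlike a real dictionary, which returns only the value whose key matches exactly, attention scores the query against *every* key, turns the scores into weights with a softmax, and returns the weighted sum of all values.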
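A minimal NumPy sketch of this "soft lookup" for a single query, assuming toy 2-D keys and scalar values chosen only for illustration (`soft_lookup` is a hypothetical helper, not a library function):

```python
import numpy as np

def soft_lookup(query, keys, values):
    """Attention for one query: softmax(keys @ query / sqrt(d_k)) @ values."""
    d_k = keys.shape[-1]
    scores = keys @ query / np.sqrt(d_k)       # similarity of the query to every key
    weights = np.exp(scores - scores.max())    # subtract the max for numerical stability
    weights /= weights.sum()                   # softmax: non-negative, sums to 1
    return weights @ values, weights

keys = np.array([[1.0, 0.0],
                 [0.0, 1.0],
                 [-1.0, 0.0]])
values = np.array([10.0, 20.0, 30.0])

# A query aligned with the first key puts most of its weight on value 10 ...
out, w = soft_lookup(np.array([5.0, 0.0]), keys, values)
print(np.round(w, 3), round(float(out), 2))

# ... while an uninformative query spreads its weight evenly and returns the mean value
out2, w2 = soft_lookup(np.array([0.0, 0.0]), keys, values)
print(np.round(w2, 3), round(float(out2), 2))  # uniform weights -> 20.0
```

A hard dictionary would return exactly one value; here the output is always a blend, which is what makes the operation differentiable.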

### Attention Mechanism: Connections to Classical Methods

The attention mechanism [@vaswani2017attention; @bahdanau2014neural] has deep connections to classical statistical and machine learning concepts.

#### Attention as Weighted Regression

At its core, attention performs **weighted averaging**, the same operation behind **kernel smoothing** [@hastie2009elements]. For a single query vector $q$, with $k_i$ and $v_i$ denoting the $i$-th key and value (the $i$-th rows of $K$ and $V$):

$$\text{Attention}(q, K, V) = \sum_{i} \alpha_i v_i$$

where the weights $\alpha_i$ are computed from the similarity between the query $q$ and each key $k_i$:

$$\alpha_i = \frac{\exp(q \cdot k_i / \sqrt{d_k})}{\sum_j \exp(q \cdot k_j / \sqrt{d_k})}$$

**This is analogous to:**

1. **Kernel Regression** [@scholkopf2002learning]: 
   - Predict the output as a weighted average of training outputs
   - Weights come from a kernel similarity $\kappa(x, x_i)$
   - Attention uses $\exp(q \cdot k_i / \sqrt{d_k})$ as the "kernel"

2. **k-Nearest Neighbors (k-NN) with soft weights:**
   - Instead of hard selection of k neighbors, attention uses soft weights
   - All positions contribute, but closer ones (higher similarity) contribute more

3. **Nadaraya-Watson Estimator** (non-parametric regression):
   $$\hat{f}(x) = \frac{\sum_i \kappa(x, x_i) y_i}{\sum_i \kappa(x, x_i)}$$
   - Take the query $q$ as the test point $x$, the keys $k_i$ as the training inputs $x_i$, the values $v_i$ as the training outputs $y_i$, and use the kernel $\kappa(x, x_i) = \exp(q \cdot k_i / \sqrt{d_k})$: the normalized kernel weights are then exactly the softmax weights $\alpha_i$, so the estimator becomes attention!
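The Nadaraya-Watson correspondence can be checked numerically. This is a sketch under toy assumptions (random keys, values, and query; `nadaraya_watson` and `softmax_attention` are hypothetical helper names): using the exponentiated scaled dot product as the kernel, the kernel-weighted average and single-query softmax attention give the same output.

```python
import numpy as np

def nadaraya_watson(x, xs, ys, kernel):
    """Kernel-weighted average: sum_i K(x, x_i) y_i / sum_i K(x, x_i)."""
    k = np.array([kernel(x, xi) for xi in xs])
    return (k @ ys) / k.sum()

def softmax_attention(q, K, V):
    """Single-query scaled dot-product attention."""
    scores = K @ q / np.sqrt(K.shape[-1])
    alpha = np.exp(scores - scores.max())  # the shift cancels in the normalization
    alpha /= alpha.sum()
    return alpha @ V

rng = np.random.default_rng(0)
d_k = 4
K = rng.normal(size=(6, d_k))  # keys play the role of training inputs x_i
V = rng.normal(size=(6, 3))    # values play the role of training outputs y_i
q = rng.normal(size=d_k)       # the query is the test point x

# Exponentiated scaled dot product used as the (unnormalized) kernel
exp_kernel = lambda x, xi: np.exp(x @ xi / np.sqrt(d_k))

nw = nadaraya_watson(q, K, V, exp_kernel)
attn = softmax_attention(q, K, V)
print(np.allclose(nw, attn))  # True: the two estimators coincide
```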

#### Connection to Matrix Factorization and Low-Rank Approximation

The attention mechanism can be viewed as a form of **matrix factorization** [@koren2009matrix]:

Given input $X$ (sequence length $n$ × embedding dimension $d$):
- $Q = XW_Q$, $K = XW_K$, $V = XW_V$ (where $W$ matrices project to dimension $d_k$)
- Attention computes: $\text{softmax}(QK^T/\sqrt{d_k}) V$

**Key insight:** the $n \times n$ score matrix $QK^T$ has **rank at most $d_k$**, so it is a low-rank parameterization of the full pairwise interaction matrix!

If we wanted to model all pairwise interactions directly:
- We'd need a full $n \times n$ interaction matrix (quadratic in sequence length)
- Instead, attention factors it: $QK^T = (XW_Q)(XW_K)^T = X (W_Q W_K^T) X^T$
- The $d \times d$ bilinear form $W_Q W_K^T$ has rank at most $d_k$ and needs only $2 d d_k$ parameters; $d_k$ is typically much smaller than $n$ for long sequences
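A quick NumPy check of the rank claim, assuming random embeddings and projection matrices with toy sizes ($n = 100$, $d = 32$, $d_k = 8$): even though the score matrix is $100 \times 100$, its rank cannot exceed $d_k$.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, d_k = 100, 32, 8            # sequence length, embedding dim, key dim (toy sizes)
X = rng.normal(size=(n, d))       # token embeddings
W_Q = rng.normal(size=(d, d_k))   # query projection
W_K = rng.normal(size=(d, d_k))   # key projection

scores = (X @ W_Q) @ (X @ W_K).T  # n x n score matrix (before scaling and softmax)
print(scores.shape)                   # (100, 100)
print(np.linalg.matrix_rank(scores))  # at most d_k = 8

# Equivalent form: X (W_Q W_K^T) X^T, with a d x d bilinear form of rank <= d_k
M = W_Q @ W_K.T
print(np.allclose(scores, X @ M @ X.T), np.linalg.matrix_rank(M))
```

Note that the softmax applied afterwards is non-linear, so the final attention matrix is generally *not* low-rank; the low-rank structure lives in the scores.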

**Connection to classical methods:**
- **Principal Component Analysis (PCA):** Low-rank approximation to covariance matrix
- **Singular Value Decomposition (SVD):** Matrix factorization $A = U\Sigma V^T$
- **Attention:** Learned, query-dependent low-rank approximation to interaction matrix

**Why this matters:** Like PCA/SVD, attention captures the "most important" interactions in a compressed form, but unlike classical methods, it's:
- Learned end-to-end for the task
- Query-dependent (changes based on what we're looking for)
- Non-linear (through softmax and subsequent layers)

#### Attention as an Associative Memory

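From a neuroscience perspective, attention implements an **associative memory** closely related to Hopfield networks (the update rule of modern, continuous-state Hopfield networks has the same softmax form as attention):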

- **Key-value storage:** Keys $K$ are like memory addresses, Values $V$ are stored content
- **Retrieval:** Query $Q$ retrieves a weighted combination of values based on similarity to keys
- **Soft retrieval:** Unlike exact memory lookup, attention performs soft, differentiable retrieval

This is analogous to **content-addressable memory** in computer science or **prototype-based models** in psychology and machine learning.
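A small sketch of soft, content-addressable retrieval, assuming random ±1 "memory" patterns that serve as both keys and values, and a hypothetical inverse-temperature `beta` (playing the role of $1/\sqrt{d_k}$) that sharpens the softmax. A corrupted copy of one stored pattern is used as the query; this is an illustration of the idea, not a full Hopfield network implementation.

```python
import numpy as np

rng = np.random.default_rng(1)
num_patterns, dim = 5, 64
memories = rng.choice([-1.0, 1.0], size=(num_patterns, dim))  # stored patterns

def retrieve(query, keys, values, beta=1.0):
    """Soft retrieval: softmax(beta * keys @ query) @ values."""
    scores = beta * keys @ query
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return w @ values, w

# Corrupt pattern 2 by flipping 16 of its 64 entries
query = memories[2].copy()
flip = rng.choice(dim, size=16, replace=False)
query[flip] *= -1

recalled, weights = retrieve(query, memories, memories, beta=0.5)
print(np.argmax(weights))                         # index of the best-matching memory
print(np.mean(np.sign(recalled) == memories[2]))  # fraction of entries restored
```

The corrupted query still overlaps most with its original pattern, so nearly all the weight lands on that memory and the retrieved vector is close to the clean pattern.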



In [None]:
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute scaled dot-product attention.
    
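    Args:
        Q: Query tensor (..., seq_len_q, d_k)
        K: Key tensor (..., seq_len_k, d_k)
        V: Value tensor (..., seq_len_k, d_v)
        mask: Optional tensor broadcastable to (..., seq_len_q, seq_len_k);
            entries equal to 0 mark key positions a query may NOT attend to
    
    Returns:
        output: Attention output (..., seq_len_q, d_v)
        attention_weights: Softmax-normalized scores (..., seq_len_q, seq_len_k)
    """
    d_k = Q.size(-1)
    
    # Compute attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    
    # Apply mask if provided; the dtype's most negative value (rather than a
    # hard-coded -1e9) stays representable in float16
    if mask is not None:
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)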
    
    # Apply softmax
    attention_weights = F.softmax(scores, dim=-1)
    
    # Compute output
    output = torch.matmul(attention_weights, V)
    
    return output, attention_weights

# Test with simple example
print('Testing Attention Mechanism:')
seq_len = 4
d_model = 8

Q = torch.randn(1, seq_len, d_model)
K = torch.randn(1, seq_len, d_model)
V = torch.randn(1, seq_len, d_model)

output, weights = scaled_dot_product_attention(Q, K, V)

print(f'\nInput shapes:')
print(f'  Q: {Q.shape}')
print(f'  K: {K.shape}')
print(f'  V: {V.shape}')
print(f'\nOutput shapes:')
print(f'  Output: {output.shape}')
print(f'  Attention weights: {weights.shape}')

# Visualize attention weights
plt.figure(figsize=(8, 6))
sns.heatmap(weights[0].detach().numpy(), annot=True, fmt='.2f', 
            cmap='YlOrRd', cbar_kws={'label': 'Attention Weight'})
plt.xlabel('Key Position', fontsize=12)
plt.ylabel('Query Position', fontsize=12)
plt.title('Attention Weight Matrix', fontsize=14, weight='bold')
plt.tight_layout()
plt.show()

print('\n💡 Each row shows how one query distributes its attention over the keys (each row sums to 1)')

## 2. Self-Attention Mechanism <a id="self-attention"></a>

In **self-attention**, Q, K, and V all come from the same input!

### Why Self-Attention?

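Self-attention allows each position to attend to all positions in the sequence:
- Captures long-range dependencies (any two positions are connected in a single step)
- Computes all positions in parallel
- Has no built-in distance bias (which is also why positional information must be added separately)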

### How it Works

1. Start with input embeddings $X$
2. Create Q, K, V by linear transformations:
   - $Q = XW_Q$
   - $K = XW_K$
   - $V = XW_V$
3. Compute attention

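### Example: DNA Sequence

For the DNA sequence "ACGT":
- A might attend strongly to T, and C to G, because they are complementary bases (relevant, for example, when base pairing determines structure)
- The model is never told about base pairing; self-attention can learn such relationships from data!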

### Self-Attention: Graph-Theoretic and Statistical Perspectives

#### Self-Attention as Graph Neural Network

Self-attention can be viewed as a **graph neural network** operating on a complete graph, where:
- Nodes: Each position in the sequence
- Edges: Attention weights between all pairs of positions
- Message passing: Each node aggregates information from all nodes (including itself), weighted by attention

**Connection to graph theory:**
- Self-attention computes on a **fully-connected graph**
- Attention weights define edge strengths that are computed from the input, not fixed in advance
- This is closer to graph attention networks than to **spectral graph convolutions**: spectral methods use a fixed graph structure, whereas here the "adjacency matrix" is recomputed from the node features for every input
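The message-passing view can be made concrete with a short NumPy sketch, assuming random queries, keys, and values for a 5-node toy graph (a real layer would produce them with learned projections): each node sums attention-weighted messages from every node, which is exactly the matrix product $AV$.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 5, 4
Q, K, V = (rng.normal(size=(n, d)) for _ in range(3))

# Edge weights of the complete graph: row-wise softmax of scaled scores
scores = Q @ K.T / np.sqrt(d)
A = np.exp(scores - scores.max(axis=1, keepdims=True))
A /= A.sum(axis=1, keepdims=True)

# Message passing: node i aggregates the messages v_j from every node j (including itself)
messages = np.zeros_like(V)
for i in range(n):
    for j in range(n):
        messages[i] += A[i, j] * V[j]

# The same aggregation written as one matrix product
print(np.allclose(messages, A @ V))  # True
```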

#### Relationship to Covariance and Correlation

The $QK^T$ matrix in self-attention computes pairwise similarities between all positions:

$$(QK^T)_{ij} = q_i^T k_j = \sum_{d=1}^{d_k} q_{i,d} k_{j,d}$$

This is a **Gram matrix** (a matrix of inner products), similar in spirit to a **covariance or correlation matrix** in statistics, which measures how related different variables (here, positions) are.

**Classical covariance matrix:**
$$\Sigma_{ij} = \text{Cov}(X_i, X_j) = \mathbb{E}[(X_i - \mu_i)(X_j - \mu_j)]$$

**Self-attention similarity:**
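$$A_{ij} = \frac{\exp(q_i^T k_j / \sqrt{d_k})}{\sum_{l} \exp(q_i^T k_l / \sqrt{d_k})}$$

Both capture relationships between elements, but:
- Covariance is computed from data statistics (second moments)
- Attention is computed through learned projections (task-specific)
- Attention uses softmax normalization, so each row of $A$ is a probability distribution over positions
- Covariance is symmetric ($\Sigma_{ij} = \Sigma_{ji}$), whereas attention weights generally are not ($A_{ij} \neq A_{ji}$), because queries and keys use different projections and each row is normalized separately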
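A small NumPy comparison, assuming random embeddings and projections with toy sizes: the self-attention matrix has rows that sum to 1 but is not symmetric, while a covariance matrix between the same positions (computed here with `np.cov`, treating the feature dimensions as observations purely for illustration) is always symmetric.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, d_k = 6, 8, 4
X = rng.normal(size=(n, d))
W_Q = rng.normal(size=(d, d_k)) / np.sqrt(d)  # scaled so the softmax is not saturated
W_K = rng.normal(size=(d, d_k)) / np.sqrt(d)

# Self-attention weights: row-wise softmax of scaled q_i . k_j
S = (X @ W_Q) @ (X @ W_K).T / np.sqrt(d_k)
A = np.exp(S - S.max(axis=1, keepdims=True))
A /= A.sum(axis=1, keepdims=True)

# Covariance between the n positions (rows of X are treated as the variables)
C = np.cov(X)

print(np.allclose(A.sum(axis=1), 1.0))  # each row of A is a probability distribution
print(np.allclose(A, A.T))              # A is generally NOT symmetric (W_Q != W_K)
print(np.allclose(C, C.T))              # covariance is always symmetric
```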

#### Connection to Factor Analysis

Self-attention with multiple heads performs something similar to **factor analysis** [@bishop2006pattern]:

**Factor analysis:** Assumes observed variables are generated from latent factors:
$$X = WF + \epsilon$$

**Multi-head attention:** Different heads attend to different "latent aspects" of the sequence:
$$\text{MultiHead}(X) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O$$

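Each head can be thought of as capturing a different "factor" or "aspect" of the data. Illustratively, different heads might specialize in:
- Local patterns (neighboring positions)
- Long-range dependencies
- Specific semantic relationships

Which roles (if any) the heads actually take on is learned during training, not assigned in advance.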



In [None]:
class SelfAttention(nn.Module):
    """Self-Attention layer."""
    
    def __init__(self, d_model):
        super(SelfAttention, self).__init__()
        self.d_model = d_model
        
        # Linear transformations for Q, K, V
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        
    def forward(self, x):
        # Create Q, K, V
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        
        # Compute attention
        output, attention_weights = scaled_dot_product_attention(Q, K, V)
        
        return output, attention_weights

# Test
d_model = 64
seq_len = 5
batch_size = 2

x = torch.randn(batch_size, seq_len, d_model)
self_attn = SelfAttention(d_model)

output, weights = self_attn(x)

print('Self-Attention Layer:')
print(f'Input shape: {x.shape}')
print(f'Output shape: {output.shape}')
print(f'Attention weights shape: {weights.shape}')
print('\n✓ Self-attention preserves sequence length!')

## 3. Multi-Head Attention <a id="multi-head"></a>

### Why Multiple Heads?

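A single attention head produces one set of weights per query, so it can express only one pattern of relationships at a time. Multiple heads allow the model to:
- Use different representation subspaces
- Attend to different positions simultaneously
- Capture several kinds of relationships at once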

### Formula

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O$$

where each head is:
$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$

### Analogy

Think of multiple editors reviewing the same text:
- One checks grammar
- One checks style  
- One checks content
- Combined feedback is comprehensive!

### Multi-Head Attention: Ensemble Methods and Subspace Learning

#### Connection to Ensemble Learning

Multi-head attention is conceptually similar to **ensemble methods** in machine learning [@hastie2009elements]:

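- **Random forests:** Combine many decision trees trained on different bootstrap samples (and random feature subsets)
- **Boosting:** Combine a sequence of weak learners, each trained to focus on the errors of its predecessors, with different weights
- **Multi-head attention:** Combine multiple attention "experts", each focusing on a different representation subspace

**Key differences:**
- Random forests: Trees are trained independently, and their predictions are simply averaged (or voted)
- Multi-head attention: Heads are trained jointly, each on its own learned projection of the same input, and their outputs are combined by a learned matrix $W^O$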

Each attention head learns a different **view** of the relationships in the sequence, analogous to:
- Different features in feature selection
- Different components in PCA
- Different factors in factor analysis

#### Subspace Learning and Dimensionality Reduction

Each attention head operates in a lower-dimensional subspace:
- Full embedding dimension: $d_{\text{model}}$ (e.g., 512)
- Each head dimension: $d_k = d_{\text{model}} / h$ (e.g., 64 with 8 heads)

This is a form of **dimensionality reduction** similar to:

- **PCA projection:** Project data onto the top principal components
- **Random projection:** Project onto a random lower-dimensional subspace (Johnson-Lindenstrauss lemma)
- **Attention heads:** Project onto *learned*, task-specific subspaces

**Why multiple subspaces help:**
- High-dimensional data often lies near lower-dimensional manifolds or subspaces
- Different aspects of the data may live in different subspaces
- Multi-head attention can learn task-relevant subspaces during training instead of having them specified by hand

#### Mathematical Formulation as Block-Structured Computation

Multi-head attention can be written as:

$$\text{head}_i = \text{Attention}(XW_i^Q, XW_i^K, XW_i^V)$$
$$\text{MultiHead}(X) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O$$

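This creates a **block-structured transformation**, where:
- The projection matrices are partitioned column-wise into $h$ blocks, $W^Q = [W_1^Q \mid \cdots \mid W_h^Q]$ (and likewise for $W^K$ and $W^V$); each block maps the *full* $d_{\text{model}}$-dimensional input to $d_k$ dimensions
- Each head computes attention independently on its own block
- The head outputs are concatenated and then mixed by $W^O$, a dense matrix that combines information across heads

**Connection to linear algebra:** Up to the final $W^O$, this structure resembles **block-diagonal matrices** and **direct sum decompositions**: the attention computation splits into $h$ independent subproblems, one per head.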
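A minimal PyTorch check of how the heads split the work, assuming toy sizes: projecting once with a single `d_model -> d_model` linear layer and then splitting the output into `num_heads` chunks gives the same result as separate `d_model -> d_k` projections whose weights are row blocks of that layer's weight matrix (PyTorch stores `nn.Linear` weights as `(out_features, in_features)`). Every head still sees the full input; it is the projection that is divided among heads.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, d_model, num_heads = 2, 5, 16, 4
d_k = d_model // num_heads

x = torch.randn(batch, seq_len, d_model)
W_q = nn.Linear(d_model, d_model)  # one fused d_model -> d_model query projection

with torch.no_grad():
    # Fused: project once, then split the last dimension into num_heads chunks of size d_k
    fused = W_q(x).view(batch, seq_len, num_heads, d_k).transpose(1, 2)  # (batch, heads, seq, d_k)

    matches = []
    for h in range(num_heads):
        # Head h's own d_model -> d_k projection: rows h*d_k to (h+1)*d_k - 1 of the weight
        rows = slice(h * d_k, (h + 1) * d_k)
        head_h = x @ W_q.weight[rows].T + W_q.bias[rows]  # (batch, seq, d_k)
        matches.append(torch.allclose(fused[:, h], head_h, atol=1e-5))

print(all(matches))  # True: the fused layer is the per-head projections stacked together

# Parameter count matches one full-width projection: h blocks of d_model x d_k
print(num_heads * d_model * d_k == d_model * d_model)
```

This is why implementations usually use one large linear layer per Q, K, and V: it is mathematically the same as $h$ small ones, but runs as a single matrix multiply.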



In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-Head Attention layer."""
    
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Linear layers for Q, K, V
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def split_heads(self, x):
        """Split into multiple heads."""
        batch_size, seq_len, d_model = x.size()
        return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
    
    def combine_heads(self, x):
        """Combine multiple heads."""
        batch_size, _, seq_len, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
    
    def forward(self, x):
        # Linear transformations
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        
        # Split into heads
        Q = self.split_heads(Q)
        K = self.split_heads(K)
        V = self.split_heads(V)
        
        # Attention for each head
        output, attention = scaled_dot_product_attention(Q, K, V)
        
        # Combine heads
        output = self.combine_heads(output)
        
        # Final linear
        output = self.W_o(output)
        
        return output, attention

# Test
d_model = 512
num_heads = 8
seq_len = 10
batch_size = 2

x = torch.randn(batch_size, seq_len, d_model)
mha = MultiHeadAttention(d_model, num_heads)

output, attention = mha(x)

print('Multi-Head Attention:')
print(f'Number of heads: {num_heads}')
print(f'Model dimension: {d_model}')
print(f'Dimension per head: {d_model // num_heads}')
print(f'\nInput shape: {x.shape}')
print(f'Output shape: {output.shape}')
print(f'Attention shape: {attention.shape}')
print('\n✓ One attention matrix per head: shape is (batch, num_heads, seq_len, seq_len)')

## 4. Positional Encoding <a id="positional"></a>

### The Problem

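Self-attention has no built-in notion of order: it is **permutation-equivariant**, so shuffling the input tokens just shuffles the outputs in the same way. Without extra position information, "ACGT" and "TGCA" contain the same tokens, and each token receives exactly the same representation in both.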
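A quick numerical check of this property, assuming a bare single-head self-attention with random weights and no positional encoding (an illustrative sketch, not a trained model):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d = 4, 8
x = torch.randn(seq_len, d)  # stand-in embeddings for the tokens of "ACGT"
W_q, W_k, W_v = (torch.randn(d, d) for _ in range(3))

def self_attention(x):
    """Plain self-attention with no positional information."""
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    weights = F.softmax(Q @ K.T / d ** 0.5, dim=-1)
    return weights @ V

perm = torch.tensor([3, 2, 1, 0])  # reverse the order: "ACGT" -> "TGCA"
out = self_attention(x)
out_perm = self_attention(x[perm])

# Permuting the input only permutes the output rows in the same way
print(torch.allclose(out_perm, out[perm], atol=1e-5))
```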

### Solution: Positional Encoding

Add position information to embeddings:

$$PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d}}\right)$$
$$PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d}}\right)$$

where:
- $pos$ = position in sequence
- $i$ = index of a dimension pair ($0 \le i < d/2$); dimensions $2i$ and $2i+1$ share the same frequency
- $d$ = model dimension
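As a worked example with a toy model dimension $d = 4$ (chosen only for illustration), the two dimension pairs use the divisors $10000^{0} = 1$ (for $i = 0$) and $10000^{2/4} = 100$ (for $i = 1$), so position $pos = 1$ is encoded as:

$$PE_{1} = \left[\sin(1),\ \cos(1),\ \sin(0.01),\ \cos(0.01)\right] \approx [0.841,\ 0.540,\ 0.010,\ 1.000]$$

The first pair completes a full cycle every $2\pi \approx 6.3$ positions, while the second needs about $2\pi \times 100 \approx 628$ positions, so each pair resolves position at a different scale.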

### Why Sinusoidal?

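- For any fixed offset $k$, $PE_{pos+k}$ is a linear function (a rotation) of $PE_{pos}$, which makes relative positions easy to learn
- The encoding can be computed for any position, so it can be applied to sequences longer than those seen in training (although how well models extrapolate in practice varies)
- It is smooth and deterministic, with no parameters to learn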

### Positional Encoding: Fourier Analysis Connection

#### Connection to Fourier Series

The sinusoidal positional encoding:
$$PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d})$$
$$PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d})$$

is closely related to **Fourier analysis** [@strang2016introduction]!

**Classical Fourier series:** Any periodic function can be represented as:
$$f(t) = a_0 + \sum_{n=1}^{\infty} [a_n \cos(n\omega t) + b_n \sin(n\omega t)]$$

**Positional encoding:** Uses different frequencies for different dimensions:
- Low dimensions: High frequency (capture fine-grained position differences)
- High dimensions: Low frequency (capture coarse position information)

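**Why this works:**
- **Multiple scales:** As in a Fourier series, sinusoids at many frequencies describe both fine and coarse structure (although here the frequencies form a geometric sequence rather than the integer multiples $n\omega$ of a Fourier series)
- **Low redundancy:** Sinusoids of different frequencies are close to orthogonal over long ranges, so different dimensions carry largely non-overlapping information
- **Smooth interpolation:** Sinusoidal functions change smoothly, so nearby positions receive similar encodings
- **Relative position:** $PE_{pos+k}$ can be written as a linear function of $PE_{pos}$ (enables learning relative positions)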
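The relative-position property can be verified with a small NumPy sketch, assuming an even model dimension (`sinusoidal_pe` and `shift_matrix` are hypothetical helper names). For each frequency $\omega_i = 1/10000^{2i/d}$, shifting the position by $k$ rotates the pair $(\sin(\omega_i\, pos), \cos(\omega_i\, pos))$ by the angle $\omega_i k$, so one block-diagonal matrix that depends only on $k$ maps every $PE_{pos}$ to $PE_{pos+k}$.

```python
import numpy as np

def sinusoidal_pe(positions, d_model):
    """Sinusoidal encodings: even columns sin, odd columns cos (d_model assumed even)."""
    i = np.arange(d_model // 2)
    omega = 1.0 / 10000 ** (2 * i / d_model)  # one frequency per dimension pair
    angles = np.outer(positions, omega)       # (num_positions, d_model / 2)
    pe = np.zeros((len(positions), d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe, omega

def shift_matrix(k, omega):
    """Block-diagonal rotation mapping PE[pos] to PE[pos + k] for every pos."""
    d_model = 2 * len(omega)
    M = np.zeros((d_model, d_model))
    for idx, w in enumerate(omega):
        c, s = np.cos(w * k), np.sin(w * k)
        # [sin(w(p+k)), cos(w(p+k))] = [[c, s], [-s, c]] @ [sin(wp), cos(wp)]
        M[2 * idx:2 * idx + 2, 2 * idx:2 * idx + 2] = [[c, s], [-s, c]]
    return M

d_model, k = 16, 7
pe, omega = sinusoidal_pe(np.arange(100), d_model)
M = shift_matrix(k, omega)

# The same M works for every starting position: it depends only on the offset k
print(np.allclose(pe[:-k] @ M.T, pe[k:]))  # True
```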

#### Connection to Feature Engineering in Time Series

In classical time series analysis [@hastie2009elements], we often add temporal features:
- Hour of day (cyclical: 0-23)
- Day of week (cyclical: 0-6)
- Month of year (cyclical: 0-11)

For cyclical features, we use **sine-cosine encoding**:
$$x_{\text{hour}} \rightarrow [\sin(2\pi \cdot x_{\text{hour}}/24), \cos(2\pi \cdot x_{\text{hour}}/24)]$$

**Positional encoding does the same thing** but:
- Uses multiple frequencies (not just one cycle)
- Applies to abstract sequence positions (not necessarily time)
- Provides a rich, high-dimensional positional representation
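A minimal sketch of the single-cycle sine-cosine encoding for hour of day (`cyclical_encode` is a hypothetical helper). On the unit circle, 23:00 ends up next to midnight even though the raw integers are 23 apart.

```python
import numpy as np

def cyclical_encode(x, period):
    """Map a cyclical feature onto the unit circle: [sin(2*pi*x/period), cos(2*pi*x/period)]."""
    angle = 2 * np.pi * np.asarray(x, dtype=float) / period
    return np.stack([np.sin(angle), np.cos(angle)], axis=-1)

hours = np.arange(24)
enc = cyclical_encode(hours, period=24)

# As raw integers, 23:00 and 00:00 are 23 units apart; on the circle they are neighbors
near = np.linalg.norm(enc[23] - enc[0])  # chord between adjacent hours (about 0.26)
far = np.linalg.norm(enc[12] - enc[0])   # opposite side of the circle
print(round(near, 3), round(far, 3))
```

Positional encoding stacks many such pairs: the pair for index $i$ is a cyclical encoding of $pos$ with period $2\pi \cdot 10000^{2i/d}$.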

#### Alternative View: Positional Basis Functions

Positional encoding provides a set of **basis functions** over position space:
- Each dimension is a basis function (sine or cosine of different frequency)
- Position is represented as coordinates in this basis
- Similar to representing functions in terms of basis functions in functional analysis

This is analogous to:
- **Polynomial bases:** $1, x, x^2, x^3, ...$ (used in polynomial regression)
- **Radial basis functions (RBF):** $\exp(-||x - c||^2)$ (used in RBF networks)
- **Wavelet bases:** Multiscale basis functions (used in signal processing)

The transformer learns to use this positional basis in combination with content information to solve the task.
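To see that each position really has its own coordinates in this basis, here is a small sketch, assuming 200 positions and a 64-dimensional encoding (`sinusoidal_table` is a hypothetical helper): a slightly noisy encoding vector is decoded back to its position by finding the nearest row of the encoding table.

```python
import numpy as np

def sinusoidal_table(num_positions, d_model):
    """Sinusoidal positional encodings (d_model assumed even)."""
    pos = np.arange(num_positions)[:, None]
    omega = 1.0 / 10000 ** (np.arange(0, d_model, 2) / d_model)
    pe = np.zeros((num_positions, d_model))
    pe[:, 0::2] = np.sin(pos * omega)
    pe[:, 1::2] = np.cos(pos * omega)
    return pe

pe = sinusoidal_table(200, 64)

# Perturb the encodings of a few positions, then decode each one by nearest neighbor
rng = np.random.default_rng(0)
true_pos = np.array([3, 57, 150])
noisy = pe[true_pos] + 0.01 * rng.normal(size=(3, 64))
dists = np.linalg.norm(noisy[:, None, :] - pe[None, :, :], axis=-1)  # (3, 200)
print(dists.argmin(axis=1))  # recovers positions 3, 57 and 150
```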



In [None]:
class PositionalEncoding(nn.Module):
    """Positional encoding using sinusoidal functions."""
    
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        
        # Create positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
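        # 1 / 10000^(2i/d_model) for each dimension pair i, computed in log space
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                           (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        # Slice div_term so odd d_model also works (there is one fewer cosine column)
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])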
        
        pe = pe.unsqueeze(0)  # Add batch dimension
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        # Add positional encoding
        x = x + self.pe[:, :x.size(1), :]
        return x

# Visualize positional encoding
d_model = 128
max_len = 50
pos_enc = PositionalEncoding(d_model, max_len)

# Get encodings
encodings = pos_enc.pe[0, :max_len, :].numpy()

plt.figure(figsize=(14, 6))
plt.imshow(encodings.T, cmap='RdBu', aspect='auto')
plt.colorbar(label='Encoding Value')
plt.xlabel('Position in Sequence', fontsize=12)
plt.ylabel('Encoding Dimension', fontsize=12)
plt.title('Positional Encoding Visualization', fontsize=14, weight='bold')
plt.tight_layout()
plt.show()

print('\n💡 Key Properties:')
print('  - Each position has unique encoding')
print('  - Different frequencies for different dimensions')
print('  - Allows model to learn relative positions')