# Chapter 4: Vision Transformers (ViT)

Welcome to the cutting edge of computer vision! 🔬

In 2020, researchers asked: "Can we apply transformers (from Chapter 3) directly to images?" The answer was a resounding YES! Given enough pretraining data, Vision Transformers (ViT) now match or beat CNNs on many image-classification benchmarks.

**What you'll learn:**
- How to turn images into sequences (so transformers can process them)
- The Vision Transformer architecture
- When to use ViT vs CNN, including for medical and biological images

**Prerequisites:**
- Chapter 2 (CNNs) - to understand the comparison
- Chapter 3 (Transformers) - ViT builds directly on these concepts

**The Big Idea:** Instead of treating an image as a grid of pixels, we break it into patches (like puzzle pieces) and feed them to a transformer!

## 📚 Table of Contents
1. [From CNNs to Transformers](#cnn-to-vit)
2. [Patch Embeddings: Breaking Images into Pieces](#patch-embed)
3. [Vision Transformer Architecture](#vit-arch)
4. [ViT vs CNN: When to Use Each](#comparison)
5. Key Takeaways

---


In [None]:
# Import libraries
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import seaborn as sns
from PIL import Image

plt.style.use('seaborn-v0_8-darkgrid')
np.random.seed(42)
torch.manual_seed(42)

print('✓ Libraries imported')
print(f'PyTorch: {torch.__version__}')

## 1. From CNNs to Transformers <a id="cnn-to-vit"></a>

### CNNs: The Traditional Approach

**Strengths**:
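- Strong inductive biases (locality, translation equivariance)
- Parameter-efficient: the same kernel weights are shared across all spatial positions
- Well-established architectures

**Limitations**:
- Limited receptive field: each layer sees only a small neighborhood, so the receptive field grows gradually with depth
- Fixed spatial relationships: kernel weights do not adapt to the content of each image
- Difficulty with long-range dependencies: relating distant regions requires many stacked layers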
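To make "limited receptive field" concrete, the sketch below computes the theoretical receptive field of a stack of convolution (or pooling) layers with the standard recurrence: each layer adds (kernel_size - 1) times the product of all earlier strides. `receptive_field` is a hypothetical helper written for illustration; it ignores padding and dilation, and the value is an upper bound on what a unit can actually see.

```python
def receptive_field(layers):
    """Theoretical receptive field (in input pixels) of one output unit.

    layers: list of (kernel_size, stride) tuples, first layer first.
    `jump` is the spacing, in input pixels, between neighboring units of the
    current layer (the product of all earlier strides).
    """
    rf, jump = 1, 1
    for kernel_size, stride in layers:
        rf += (kernel_size - 1) * jump
        jump *= stride
    return rf

# Ten stacked 3x3 convolutions with stride 1: each output sees only a 21x21 window
print(receptive_field([(3, 1)] * 10))   # 21
# Downsampling grows it faster: five 3x3 convolutions with stride 2
print(receptive_field([(3, 2)] * 5))    # 63
```

By contrast, with 16×16 patches on a 224×224 image, every ViT token can attend to all 196 patches already in the first layer.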

### Vision Transformers: The New Paradigm

**Key Idea**: Treat an image as a sequence of patches!

**Process**:
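1. Split image into patches (e.g., 16×16)
2. Flatten each patch into a vector and linearly project it to the embedding dimension $D$
3. Prepend a learnable class token and add positional embeddings
4. Feed the sequence through a Transformer encoder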

**Benefits**:
- Global receptive field from layer 1
- Flexible attention patterns
- Scalable with data

### The Trade-off

- **CNNs**: Better with small datasets (strong inductive bias built into the architecture)
- **ViT**: Better with large datasets (must learn spatial structure from data)

### Paper: "An Image is Worth 16x16 Words"

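This groundbreaking paper (Dosovitskiy et al., posted in 2020 and published at ICLR 2021) showed that ViT can match or exceed state-of-the-art CNNs when pretrained on sufficiently large datasets (ImageNet-21k or JFT-300M). Trained on ImageNet alone, without strong regularization, it fell a few points short of comparably sized ResNets.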

In [None]:
def visualize_patch_extraction():
    """Visualize how an image is split into patches."""
    
    # Create a sample image (8x8 for visualization)
    image = np.random.rand(8, 8, 3)
    patch_size = 2
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    
    # Original image with grid
    axes[0].imshow(image)
    axes[0].set_title('Original Image (8×8)', fontsize=14, weight='bold')
    
    # Draw patch boundaries
    for i in range(0, 8, patch_size):
        axes[0].axhline(y=i-0.5, color='red', linewidth=2)
        axes[0].axvline(x=i-0.5, color='red', linewidth=2)
    axes[0].axhline(y=7.5, color='red', linewidth=2)
    axes[0].axvline(x=7.5, color='red', linewidth=2)
    
    # Number patches
    patch_num = 0
    for i in range(0, 8, patch_size):
        for j in range(0, 8, patch_size):
            axes[0].text(j + patch_size/2 - 0.5, i + patch_size/2 - 0.5, 
                        str(patch_num), ha='center', va='center',
                        fontsize=16, weight='bold', color='yellow',
                        bbox=dict(boxstyle='round', facecolor='black', alpha=0.7))
            patch_num += 1
    
    axes[0].set_xticks([])
    axes[0].set_yticks([])
    
    # Extract and visualize patches as sequence
    num_patches_per_dim = 8 // patch_size
    num_patches = num_patches_per_dim ** 2
    
    # Show patches as a sequence
    patch_display = np.zeros((patch_size * 2, patch_size * num_patches, 3))
    patch_num = 0
    for i in range(0, 8, patch_size):
        for j in range(0, 8, patch_size):
            patch = image[i:i+patch_size, j:j+patch_size, :]
            col_start = patch_num * patch_size
            patch_display[0:patch_size, col_start:col_start+patch_size, :] = patch
            patch_display[patch_size:, col_start:col_start+patch_size, :] = patch
            patch_num += 1
    
    axes[1].imshow(patch_display)
    axes[1].set_title(f'Patch Sequence ({num_patches} patches)', fontsize=14, weight='bold')
    axes[1].set_xlabel('Patch sequence →', fontsize=12)
    axes[1].set_xticks([])
    axes[1].set_yticks([])
    
    plt.tight_layout()
    plt.show()
    
    print(f'\n📊 Patch Statistics:')
    print(f'  Image size: 8×8×3 = {8*8*3} values')
    print(f'  Patch size: {patch_size}×{patch_size}')
    print(f'  Number of patches: {num_patches}')
    print(f'  Each patch: {patch_size}×{patch_size}×3 = {patch_size*patch_size*3} values')
    print(f'\n  For standard ViT (224×224 image, 16×16 patches):')
    print(f'    Number of patches: (224/16)² = 196 patches')
    print(f'    Sequence length: 196 + 1 (class token) = 197')

visualize_patch_extraction()

## 2. Patch Embeddings <a id="patch-embed"></a>

### Converting Images to Sequences

Given an image $\mathbf{x} \in \mathbb{R}^{H \times W \times C}$:

1. **Split into patches**: $\mathbf{x}_p \in \mathbb{R}^{N \times (P^2 \cdot C)}$
   - $N = HW/P^2$ (number of patches)
   - $P$ = patch size

2. **Linear projection**: $\mathbf{z}_0 = [\mathbf{x}_{class}; \mathbf{x}_p^1\mathbf{E}; ...; \mathbf{x}_p^N\mathbf{E}] + \mathbf{E}_{pos}$
   - $\mathbf{E} \in \mathbb{R}^{(P^2 \cdot C) \times D}$ (patch embedding)
   - $\mathbf{E}_{pos} \in \mathbb{R}^{(N+1) \times D}$ (positional embedding)
   - $\mathbf{x}_{class}$ = learnable class token
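The equation above can be traced with plain tensor operations. This sketch assumes ViT-Base/16 sizes (224×224×3 input, $P=16$, $D=768$), so $N = 224 \cdot 224 / 16^2 = 196$ and $P^2 \cdot C = 768$. `patchify` is a hypothetical helper that takes patches in row-major order and flattens each one channel-first; the zero-initialized class token and $\mathbf{E}_{pos}$ are placeholders for learnable parameters.

```python
import torch
import torch.nn as nn

def patchify(x, patch_size):
    """Split (B, C, H, W) images into (B, N, P*P*C) flattened patches, N = HW / P^2."""
    B, C, H, W = x.shape
    P = patch_size
    assert H % P == 0 and W % P == 0, "image size must be divisible by patch size"
    x = x.reshape(B, C, H // P, P, W // P, P)     # split each spatial axis into (grid, within-patch)
    x = x.permute(0, 2, 4, 1, 3, 5)               # (B, H/P, W/P, C, P, P), patches in row-major order
    return x.reshape(B, (H // P) * (W // P), C * P * P)

B, C, H, W, P, D = 2, 3, 224, 224, 16, 768
x = torch.randn(B, C, H, W)
x_p = patchify(x, P)                               # (B, N, P^2 * C) = (2, 196, 768)

# nn.Linear stores a (D, P^2*C) weight W and computes x_p @ W.T, so E = W.T
E = nn.Linear(P * P * C, D, bias=False)
x_class = nn.Parameter(torch.zeros(1, 1, D))                   # learnable class token
E_pos = nn.Parameter(torch.zeros(1, x_p.shape[1] + 1, D))      # (N+1) x D positional embedding

# z_0 = [x_class; x_p^1 E; ...; x_p^N E] + E_pos
z0 = torch.cat([x_class.expand(B, -1, -1), E(x_p)], dim=1) + E_pos
print(x_p.shape, z0.shape)  # torch.Size([2, 196, 768]) torch.Size([2, 197, 768])
```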

### Class Token

A learnable embedding prepended to the sequence:
- Aggregates information from all patches through self-attention
- Its output at the final layer is fed to the classification head
- Similar to [CLS] token in BERT

In [None]:
class PatchEmbedding(nn.Module):
    """Convert image to patch embeddings."""
    
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super(PatchEmbedding, self).__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        
        # Use Conv2d for efficiency (equivalent to splitting and linear projection)
        self.proj = nn.Conv2d(
            in_channels, 
            embed_dim, 
            kernel_size=patch_size, 
            stride=patch_size
        )
        
    def forward(self, x):
        # x: (batch_size, channels, height, width)
        x = self.proj(x)  # (batch_size, embed_dim, img_size // patch_size, img_size // patch_size)
        x = x.flatten(2)  # (batch_size, embed_dim, n_patches)
        x = x.transpose(1, 2)  # (batch_size, n_patches, embed_dim)
        return x

# Test patch embedding
img_size = 224
patch_size = 16
in_channels = 3
embed_dim = 768
batch_size = 4

patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)

# Create sample images
images = torch.randn(batch_size, in_channels, img_size, img_size)

# Get patch embeddings
embeddings = patch_embed(images)

print('Patch Embedding:')
print(f'Input image shape: {images.shape}')
print(f'  (batch_size, channels, height, width)')
print(f'\nOutput embedding shape: {embeddings.shape}')
print(f'  (batch_size, n_patches, embed_dim)')
print(f'\nNumber of patches: {patch_embed.n_patches}')
print(f'Embedding dimension: {embed_dim}')
print('\n✓ Image successfully converted to sequence!')
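The `PatchEmbedding` comment says a `Conv2d` with `kernel_size=stride=patch_size` is equivalent to splitting the image into patches and applying a linear projection. The check below is a small sketch with arbitrary sizes (P=16, D=64, 64×64 inputs): it copies the conv weights into an `nn.Linear` and compares both routes numerically.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
P, C, D = 16, 3, 64
conv = nn.Conv2d(C, D, kernel_size=P, stride=P)   # same layer type as PatchEmbedding.proj

# A Linear layer sharing the conv's weights: (D, C, P, P) -> (D, C*P*P)
linear = nn.Linear(C * P * P, D)
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(D, -1))
    linear.bias.copy_(conv.bias)

x = torch.randn(2, C, 64, 64)
B, _, H, W = x.shape

# Route 1: convolution, then flatten the (H/P, W/P) grid into a sequence
out_conv = conv(x).flatten(2).transpose(1, 2)     # (B, N, D)

# Route 2: explicit patch split (channel-first within each patch), then Linear
patches = (x.reshape(B, C, H // P, P, W // P, P)
            .permute(0, 2, 4, 1, 3, 5)            # (B, H/P, W/P, C, P, P)
            .reshape(B, -1, C * P * P))           # (B, N, C*P*P)
out_linear = linear(patches)

print(out_conv.shape, torch.allclose(out_conv, out_linear, atol=1e-4))  # torch.Size([2, 16, 64]) True
```

The convolution route is preferred in practice simply because it performs the split and projection in one optimized call.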

## 3. Vision Transformer Architecture <a id="vit-arch"></a>

### Complete ViT Pipeline

```
Image (224×224×3)
    ↓
Patch Embedding (196 patches of 16×16)
    ↓
Add Class Token (197 tokens)
    ↓
Add Positional Embedding
    ↓
Transformer Encoder (L layers)
    ↓
Extract Class Token
    ↓
MLP Head (Classification)
    ↓
Output (num_classes)
```

### Transformer Encoder Block

Each block applies these steps in order (a *pre-norm* design, with LayerNorm before each sublayer):
1. Layer Normalization
2. Multi-Head Self-Attention
3. Residual connection
4. Layer Normalization
5. MLP (Feed-Forward)
6. Residual connection
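In equation form, block $\ell$ computes $\mathbf{z}'_\ell = \mathrm{MSA}(\mathrm{LN}(\mathbf{z}_{\ell-1})) + \mathbf{z}_{\ell-1}$ and then $\mathbf{z}_\ell = \mathrm{MLP}(\mathrm{LN}(\mathbf{z}'_\ell)) + \mathbf{z}'_\ell$. A minimal sketch of one such block, written out step by step with `nn.MultiheadAttention`; the class name `PreNormEncoderBlock` and the ViT-Tiny-like sizes (192 dim, 3 heads) are illustrative assumptions, and the two-layer GELU MLP follows the ViT paper.

```python
import torch
import torch.nn as nn

class PreNormEncoderBlock(nn.Module):
    """One ViT-style encoder block following the six steps above (pre-norm)."""

    def __init__(self, embed_dim=192, num_heads=3, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)                                   # step 1
        self.attn = nn.MultiheadAttention(embed_dim, num_heads,
                                          dropout=dropout, batch_first=True)   # step 2
        self.norm2 = nn.LayerNorm(embed_dim)                                   # step 4
        hidden = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(                                              # step 5
            nn.Linear(embed_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, z):
        h = self.norm1(z)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        z = z + attn_out                  # step 3: residual around attention
        z = z + self.mlp(self.norm2(z))   # step 6: residual around the MLP
        return z

block = PreNormEncoderBlock()
tokens = torch.randn(2, 197, 192)         # (batch, 1 class token + 196 patches, embed_dim)
print(block(tokens).shape)                # torch.Size([2, 197, 192])
```

PyTorch's built-in `nn.TransformerEncoderLayer` gives the same structure when constructed with `norm_first=True` and `activation="gelu"`; its default is post-norm with ReLU.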

### ViT Variants

- **ViT-Base**: 12 layers, 768 dim, 12 heads, 86M params
- **ViT-Large**: 24 layers, 1024 dim, 16 heads, 307M params
- **ViT-Huge**: 32 layers, 1280 dim, 16 heads, 632M params
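Where do these parameter counts come from? Each encoder block has about $12D^2$ weights (Q, K, V and output projections give $4D^2$; the MLP with hidden size $4D$ gives $8D^2$), so depth and width dominate the total. The hypothetical helper below is a sketch that counts parameters for a standard ViT with a single linear head on the class token; exact totals depend on details such as the head, so they need not land exactly on the rounded figures above.

```python
def vit_param_count(embed_dim, depth, patch_size=16, img_size=224,
                    in_channels=3, num_classes=1000, mlp_ratio=4):
    """Count parameters of a standard ViT (single linear head on the class token)."""
    D = embed_dim
    n_patches = (img_size // patch_size) ** 2
    patch_embed = in_channels * patch_size ** 2 * D + D      # projection E (+ bias)
    cls_and_pos = D + (n_patches + 1) * D                    # class token + E_pos
    per_block = (2 * D                                       # LayerNorm 1
                 + 3 * D * D + 3 * D                         # Q, K, V projections
                 + D * D + D                                 # attention output projection
                 + 2 * D                                     # LayerNorm 2
                 + D * mlp_ratio * D + mlp_ratio * D         # MLP layer 1
                 + mlp_ratio * D * D + D)                    # MLP layer 2
    head = 2 * D + D * num_classes + num_classes             # final LayerNorm + linear head
    return patch_embed + cls_and_pos + depth * per_block + head

print(f"ViT-Base/16:  {vit_param_count(768, 12):,}")    # 86,567,656
print(f"ViT-Large/16: {vit_param_count(1024, 24):,}")   # 304,326,632
```

Calling `vit_param_count(192, 6, num_classes=10)` should reproduce the "Total parameters" printed by the demo model in the next cell.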

In [None]:
class VisionTransformer(nn.Module):
    """Simplified Vision Transformer implementation."""
    
    def __init__(self, img_size=224, patch_size=16, in_channels=3, 
                 num_classes=1000, embed_dim=768, depth=12, num_heads=12, 
                 mlp_ratio=4.0, dropout=0.1):
        super(VisionTransformer, self).__init__()
        
        # Patch embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.n_patches
        
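        # Class token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        
        # Positional embedding
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(dropout)
        
        # Small random init (std=0.02, a common choice in ViT implementations) instead of all zeros
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        
        # Transformer encoder: pre-norm blocks with GELU, matching the ViT block described above
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=int(embed_dim * mlp_ratio),
            dropout=dropout,
            activation='gelu',
            norm_first=True,
            batch_first=True
        )
        # The nested-tensor fast path does not support norm_first=True; disabling it avoids a warning
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth,
                                                 enable_nested_tensor=False)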
        
        # Classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        
    def forward(self, x):
        # Patch embedding
        x = self.patch_embed(x)  # (B, N, D)
        batch_size = x.shape[0]
        
        # Add class token
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, N+1, D)
        
        # Add positional embedding
        x = x + self.pos_embed
        x = self.pos_drop(x)
        
        # Transformer encoding
        x = self.transformer(x)
        
        # Classification from class token
        x = self.norm(x[:, 0])  # Take class token
        x = self.head(x)
        
        return x

# Create a ViT-Tiny-sized model for demonstration (ViT-Ti has 12 layers; 6 keeps the demo fast)
model = VisionTransformer(
    img_size=224,
    patch_size=16,
    num_classes=10,  # Reduced for demo
    embed_dim=192,   # Smaller model
    depth=6,         # Fewer layers
    num_heads=3,
    mlp_ratio=4.0,
    dropout=0.1
)

# Test forward pass
x = torch.randn(2, 3, 224, 224)
output = model(x)

print('Vision Transformer:')
print(f'Input shape: {x.shape}')
print(f'Output shape: {output.shape}')

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f'\nTotal parameters: {total_params:,}')
print('\n✓ Complete ViT model!')
print('\n💡 For comparison:')
print('  ViT-Base: ~86M parameters')
print('  ResNet-50: ~25M parameters')
print('  ViT needs more data but scales better!')

## 4. ViT vs CNN: When to Use Each <a id="comparison"></a>

### When to Use CNNs

- ✅ **Small to medium datasets** (roughly < 1M images)
- ✅ **Limited computational resources**
- ✅ **Need for translation equivariance** (strong inductive bias)
- ✅ **Real-time inference** (typically faster, especially at high resolution)
- ✅ **Local pattern recognition**

### When to Use ViT

- ✅ **Large datasets** (roughly > 10M images, when training from scratch)
- ✅ **Abundant computational resources**
- ✅ **Global context important** (long-range dependencies)
- ✅ **Transfer learning from large pretrained models**
- ✅ **When you can pretrain on huge datasets**

### Performance Comparison

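| Aspect | CNN (ResNet) | ViT |
|--------|--------------|-----|
| Inductive bias | Strong (locality, weight sharing) | Weak |
| Data requirement | Low | High (unless pretrained) |
| Training time | Typically faster | Typically slower |
| Inference speed | Typically faster | Typically slower, especially at high resolution |
| Scalability with data and model size | Limited | Excellent |
| Interpretability | Medium | Medium to high (attention maps give a built-in but partial view) |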
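A minimal sketch of how a class-token attention map is read out, assuming a single `nn.MultiheadAttention` layer (PyTorch 1.11 or later for `average_attn_weights`) applied to random tokens; `cls_attention_map` is a hypothetical helper. In a trained model you would take the weights of a chosen layer, upsample the 14×14 grid to the image size, and overlay it. Such a map shows where the class token looks, which is only a partial explanation of the prediction.

```python
import torch
import torch.nn as nn

def cls_attention_map(tokens, attn, grid_size):
    """How strongly the class token attends to each patch, as a (B, grid, grid) map.

    tokens: (B, 1 + grid_size**2, D) with the class token at index 0.
    attn:   an nn.MultiheadAttention built with batch_first=True.
    """
    _, weights = attn(tokens, tokens, tokens,
                      need_weights=True, average_attn_weights=True)  # (B, 1+N, 1+N), mean over heads
    cls_to_patches = weights[:, 0, 1:]   # row 0 = class token as query; drop its self-attention
    return cls_to_patches.reshape(-1, grid_size, grid_size)

attn = nn.MultiheadAttention(embed_dim=192, num_heads=3, batch_first=True)
tokens = torch.randn(2, 1 + 14 * 14, 192)
with torch.no_grad():
    maps = cls_attention_map(tokens, attn, grid_size=14)
print(maps.shape)  # torch.Size([2, 14, 14])
```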

### Biology Applications

**Use CNNs for**:
- Cell counting (limited data)
- Quick microscopy analysis
- Real-time diagnosis

**Use ViT for**:
- Large medical image datasets
- Multi-scale tissue analysis
- When pretrained models are available
- Complex pathology images

### Practical Tip

🎯 **Best of both worlds**: Use pretrained ViT models! They are pretrained on ImageNet or larger datasets (such as ImageNet-21k; the original paper also used Google's internal JFT-300M), and they often transfer well to small target datasets through fine-tuning.

## 5. Key Takeaways

### Vision Transformers

1. **Patch Embedding**: Treat images as sequences of patches
2. **Self-Attention**: Global receptive field from layer 1
3. **Class Token**: Learnable token for classification
4. **Positional Encoding**: Preserve spatial information
5. **Scalability**: Performance improves with data and model size
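Why positional encoding matters: without it, self-attention is permutation-equivariant, so the encoder cannot tell which patch came from where. This sketch checks that numerically with a single randomly initialized `nn.TransformerEncoderLayer`; the sizes (10 tokens, 32 dims) and the random `pos` tensor standing in for learned positional embeddings are arbitrary assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# One encoder layer in eval mode with dropout off, so outputs are deterministic
layer = nn.TransformerEncoderLayer(d_model=32, nhead=4, dropout=0.0, batch_first=True).eval()

tokens = torch.randn(1, 10, 32)   # 10 patch embeddings, no positional information
perm = torch.randperm(10)         # a random reordering of the patches

with torch.no_grad():
    # Without positional embeddings, shuffling the input only shuffles the output
    out = layer(tokens)
    out_shuffled = layer(tokens[:, perm])
    print(torch.allclose(out_shuffled, out[:, perm], atol=1e-5))   # True

    # With an embedding tied to each position, the shuffled image yields different features
    pos = torch.randn(1, 10, 32)
    a = layer(tokens + pos)[:, perm]
    b = layer(tokens[:, perm] + pos)
    print(torch.allclose(a, b, atol=1e-5))                          # False
```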

### Design Choices

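- **Patch Size**: 16×16 is standard; halving the patch size quadruples the number of patches, and self-attention cost grows quadratically with sequence length
- **Model Size**: Bigger is better (with sufficient data)
- **Pretraining**: Usually essential for good performance (DeiT showed that strong augmentation and regularization can partly compensate on ImageNet-scale data)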
- **Regularization**: Important to prevent overfitting
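The patch-size trade-off is simple arithmetic. This sketch (the helper `vit_sequence_cost` is hypothetical) counts tokens per 224×224 image and the number of entries in one attention matrix, per head and per layer, for a few patch sizes.

```python
def vit_sequence_cost(img_size=224, patch_size=16):
    """Tokens per image and attention-matrix entries (per head, per layer)."""
    n_patches = (img_size // patch_size) ** 2
    seq_len = n_patches + 1          # + class token
    return seq_len, seq_len ** 2

for p in (32, 16, 8):
    seq_len, attn_entries = vit_sequence_cost(patch_size=p)
    print(f"patch {p:>2}x{p:<2}: {seq_len:>4} tokens, {attn_entries:>7,} attention scores")
```

Going from 16×16 to 8×8 patches raises the sequence from 197 to 785 tokens and each attention matrix from 38,809 to 616,225 entries, roughly 16× more attention computation.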

### Future Directions

1. **Efficiency**: Swin Transformer, DeiT (data-efficient)
2. **Hybrid Models**: Combining CNNs and Transformers
3. **Self-Supervised Learning**: MAE (Masked Autoencoders)
4. **Multi-Modal**: CLIP (vision + language)

---

## 🎓 Congratulations!

You've completed the Vision Transformers chapter! You now understand:
- How to convert images to sequences
- The complete ViT architecture
- When to use ViT vs CNNs
- How to implement ViT components

### Next Steps

1. **Practice**: Implement ViT from scratch
2. **Experiment**: Try different patch sizes
3. **Apply**: Use pretrained ViT on your data
4. **Explore**: Check out Swin Transformer, DeiT, DINO

### Additional Resources

- Original ViT paper: "An Image is Worth 16x16 Words"
- timm library: Pre-implemented ViT variants
- Hugging Face Transformers: Easy-to-use ViT models

---

**🎉 You've completed the Deep Learning Biology Codebook!**

**You've learned**:
- ✅ Neural Networks fundamentals
- ✅ Convolutional Neural Networks
- ✅ Transformers architecture
- ✅ Vision Transformers

**Keep learning and applying these powerful techniques to biological problems!** 🧬🤖