# Chapter 2: Convolutional Neural Networks (CNNs)

Welcome to Chapter 2! Here we'll explore CNNs, the powerhouse architecture for image analysis and a key tool in biological image processing.

## 📚 Table of Contents
1. [Why Convolutions?](#why-convolutions)
2. [Understanding Convolution Operation](#convolution-op)
3. [CNN Components](#cnn-components)
4. [Building Your First CNN](#first-cnn)
5. [Popular CNN Architectures](#architectures)
6. [Transfer Learning](#transfer-learning)
7. [Biology Application: Cell Image Classification](#biology-app)

---

In [None]:
# Import libraries
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, TensorDataset
import seaborn as sns
from PIL import Image

# Plot appearance: apply matplotlib's 'seaborn-v0_8-darkgrid' style sheet and
# make seaborn's 'husl' palette the default color cycle.
# NOTE(review): the 'seaborn-v0_8-*' style names presumably require
# matplotlib >= 3.6 — confirm against the course environment.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')

# Fix the NumPy and PyTorch random seeds so that random draws made later in
# the notebook (e.g. torch.randn inputs, layer weight initialization) repeat
# from run to run.
np.random.seed(42)
torch.manual_seed(42)

# Report the environment: confirms the imports worked, shows the installed
# PyTorch version, and whether a CUDA GPU is visible to PyTorch.
print('Libraries imported successfully!')
print(f'PyTorch version: {torch.__version__}')
print(f'GPU available: {torch.cuda.is_available()}')

## 1. Why Convolutions? <a id="why-convolutions"></a>

### Problems with Fully Connected Networks for Images

Imagine a small 28×28 grayscale image (like MNIST digits):
- Input size: 28 × 28 = 784 pixels
- Hidden layer: 1000 neurons
- **Parameters needed**: 784 × 1000 = 784,000!

For a color image (224×224×3):
- Input size: 224 × 224 × 3 = 150,528 values (50,176 pixels × 3 color channels)
- Hidden layer: 1000 neurons
- **Parameters needed**: 150,528 × 1000 ≈ 150 million!

### Key Insights for Images

1. **Local Connectivity**: Nearby pixels are more related than distant ones
2. **Translation Invariance**: A cat is a cat whether it's in the top-left or bottom-right
3. **Hierarchical Features**: Low-level features (edges) combine to form high-level features (objects)

### Solution: Convolutional Layers

Convolutions build all three insights directly into the architecture:
- **Local receptive fields**: Each neuron only looks at a small region
- **Shared weights**: Same filters scan across the entire image
- **Hierarchical learning**: Stack layers to build complexity

Let's visualize a convolution operation!

In [None]:
def visualize_convolution():
    """Visualize a simple 2D convolution operation.

    Slides a 3×3 edge-detection kernel over a 5×5 binary image with stride 1
    and no padding, then shows the input, the kernel and the resulting
    feature map side by side, with every cell's value printed on top of it.
    Finally prints the output-size formula for this "valid" convolution.

    The kernel is applied without flipping (i.e. as a cross-correlation);
    for this symmetric kernel the result equals a true convolution.

    Returns:
        None. The function only draws a figure and prints to stdout.
    """

    def draw_panel(ax, data, title, cmap, line_color, text_color, fmt,
                   **imshow_kwargs):
        """Draw *data* as a heatmap with cell borders, labels and a colorbar."""
        rows, cols = data.shape
        image = ax.imshow(data, cmap=cmap, interpolation='nearest',
                          **imshow_kwargs)
        ax.set_title(title, fontsize=13, weight='bold')
        ax.grid(True, which='both', color=line_color, linewidth=0.5, alpha=0.3)
        # Minor ticks sit on the cell edges (half-integers) so that the grid
        # lines outline the individual cells.
        ax.set_xticks(np.arange(-0.5, cols, 1), minor=True)
        ax.set_yticks(np.arange(-0.5, rows, 1), minor=True)
        for (row, col), value in np.ndenumerate(data):
            # imshow puts columns on the x axis and rows on the y axis.
            ax.text(col, row, fmt(value),
                    ha='center', va='center', color=text_color,
                    fontsize=11, weight='bold')
        plt.colorbar(image, ax=ax)

    # Create a simple input (5x5)
    input_img = np.array([
        [1, 1, 1, 0, 0],
        [0, 1, 1, 1, 0],
        [0, 0, 1, 1, 1],
        [0, 0, 1, 1, 0],
        [0, 1, 1, 0, 0]
    ])

    # Edge detection kernel (3x3)
    kernel = np.array([
        [-1, -1, -1],
        [-1,  8, -1],
        [-1, -1, -1]
    ])

    # Perform the convolution manually. All sizes are derived from the arrays
    # themselves so the window always matches the kernel.
    in_rows, in_cols = input_img.shape
    k_rows, k_cols = kernel.shape
    out_rows = in_rows - k_rows + 1
    out_cols = in_cols - k_cols + 1
    output = np.zeros((out_rows, out_cols))

    for i in range(out_rows):
        for j in range(out_cols):
            # Element-wise product of the kernel with the window whose
            # top-left corner is at (i, j), summed to a single number.
            window = input_img[i:i + k_rows, j:j + k_cols]
            output[i, j] = np.sum(window * kernel)

    # Visualize
    _, axes = plt.subplots(1, 3, figsize=(15, 4))

    def as_int(value):
        return str(int(value))

    draw_panel(axes[0], input_img,
               f'Input Image ({in_rows}×{in_cols})',
               cmap='gray', line_color='red', text_color='red', fmt=as_int)
    # Symmetric color limits keep zero at the center of the diverging map.
    draw_panel(axes[1], kernel,
               f'Edge Detection Kernel ({k_rows}×{k_cols})',
               cmap='RdBu', line_color='black', text_color='black',
               fmt=as_int, vmin=-8, vmax=8)
    draw_panel(axes[2], output,
               f'Output Feature Map ({out_rows}×{out_cols})',
               cmap='viridis', line_color='white', text_color='white',
               fmt=lambda value: f'{value:.0f}')

    plt.tight_layout()
    plt.show()

    print('\n📊 Convolution Operation:')
    print(f'Input ({in_rows}×{in_cols}) * Kernel ({k_rows}×{k_cols}) '
          f'= Output ({out_rows}×{out_cols})')
    print('\nOutput size formula: (input_size - kernel_size + 1)')
    print(f'In this case: ({in_rows} - {k_rows} + 1) = {out_rows}')

visualize_convolution()

## 2. Understanding Convolution Operation <a id="convolution-op"></a>

### Mathematical Definition

For a 2D input $I$ and kernel $K$, the convolution at position $(i, j)$ is:

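$$S(i, j) = (I * K)(i, j) = \sum_{m} \sum_{n} I(i+m, j+n) \cdot K(m, n)$$

Strictly speaking, this formula (with $i+m$, $j+n$) is *cross-correlation*; a true convolution flips the kernel and uses $I(i-m, j-n)$. Deep learning libraries such as PyTorch implement cross-correlation and call it "convolution"; because the kernel weights are learned, the distinction makes no practical difference.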

### Key Parameters

1. **Kernel Size**: Size of the filter (e.g., 3×3, 5×5)
2. **Stride**: Step size when sliding kernel (default: 1)
3. **Padding**: Add zeros around input to control output size
4. **Dilation**: Spacing between kernel elements

### Output Size Calculation

$$O = \left\lfloor \frac{W - K + 2P}{S} \right\rfloor + 1$$

where:
- $O$ = output size
- $W$ = input size
- $K$ = kernel size
- $P$ = padding
- $S$ = stride

In [None]:
def demonstrate_conv_parameters():
    """Demonstrate effect of different convolution parameters.

    Passes one random 7×7 single-channel input through four ``nn.Conv2d``
    layers that differ in kernel size, stride and padding. For each layer it
    prints the actual output height×width next to the size predicted by
    ``floor((W - K + 2P) / S) + 1``, followed by a short list of takeaways.

    Returns:
        None. All results are printed to stdout.
    """

    # Create a sample input
    x = torch.randn(1, 1, 7, 7)  # batch=1, channels=1, height=7, width=7
    input_size = x.shape[-1]  # square input, so one side length is enough

    print('Input shape:', x.shape)
    print('Format: (batch_size, channels, height, width)\n')

    # Different configurations
    configs = [
        {'kernel_size': 3, 'stride': 1, 'padding': 0, 'name': 'Default'},
        {'kernel_size': 3, 'stride': 2, 'padding': 0, 'name': 'Stride=2'},
        {'kernel_size': 3, 'stride': 1, 'padding': 1, 'name': 'Padding=1 (same)'},
        {'kernel_size': 5, 'stride': 1, 'padding': 0, 'name': 'Kernel=5×5'},
    ]

    for config in configs:
        conv = nn.Conv2d(in_channels=1, out_channels=1,
                         kernel_size=config['kernel_size'],
                         stride=config['stride'],
                         padding=config['padding'])
        output = conv(x)

        # Predicted size from the formula; floor division matters whenever
        # the stride does not divide (W - K + 2P) evenly.
        calc_size = (
            (input_size - config['kernel_size'] + 2 * config['padding'])
            // config['stride'] + 1
        )

        print(f"{config['name']:20s} | Output: {output.shape[2]}×{output.shape[3]} (calculated: {calc_size}×{calc_size})")

    print('\n💡 Key Insight:')
    print('  - Stride > 1: Reduces spatial dimensions (downsampling)')
    print('  - Padding = (kernel_size-1)/2: Maintains input size ("same" padding)')
    # A K×K kernel has K*K weights per input/output channel pair, so a larger
    # kernel means more parameters, not fewer.
    print('  - Larger kernels: See more context but need more parameters')

demonstrate_conv_parameters()

## 3. CNN Components <a id="cnn-components"></a>

A typical CNN consists of:

### 1. Convolutional Layers
- Learn spatial hierarchies of features
- Share weights across spatial locations
- Each filter detects a specific pattern

### 2. Activation Functions
- Usually ReLU: $\text{ReLU}(x) = \max(0, x)$
- Introduces non-linearity

### 3. Pooling Layers
- Reduce spatial dimensions
- Provide translation invariance
- Types: Max pooling, Average pooling

### 4. Fully Connected Layers
- Final classification
- Combine all features

### 5. Dropout (Regularization)
- Randomly drop neurons during training
- Prevents overfitting

In [None]:
def visualize_pooling():
    """Illustrate 2×2 max pooling on a small 4×4 example.

    Draws the 4×4 input with its four non-overlapping 2×2 pooling windows
    outlined in blue, next to the 2×2 result that holds the maximum of each
    window, and then prints a short summary of what max pooling does.
    """
    window = 2
    side_in = 4
    side_out = 2

    input_data = np.array([
        [1, 3, 2, 4],
        [5, 6, 7, 8],
        [9, 2, 3, 1],
        [0, 4, 5, 2]
    ])

    # Split both axes into (window index, offset inside window) and take the
    # maximum over the two offset axes: one value per 2×2 window.
    output = (
        input_data
        .reshape(side_out, window, side_out, window)
        .max(axis=(1, 3))
        .astype(float)
    )

    def outline_cells(ax, side):
        """Draw thick black grid lines on the cell borders of *ax*."""
        ax.grid(True, which='both', color='black', linewidth=2)
        edges = np.arange(-0.5, side, 1)
        ax.set_xticks(edges, minor=True)
        ax.set_yticks(edges, minor=True)

    def label_cells(ax, values, size):
        """Write each cell's integer value at the center of the cell."""
        for (row, col), val in np.ndenumerate(values):
            ax.text(col, row, str(int(val)),
                    ha='center', va='center', color='black',
                    fontsize=size, weight='bold')

    _, (ax_in, ax_out) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: the input with the pooling windows marked.
    heat_in = ax_in.imshow(input_data, cmap='YlOrRd', interpolation='nearest')
    ax_in.set_title('Input (4×4)', fontsize=13, weight='bold')
    outline_cells(ax_in, side_in)
    for top in range(0, side_in, window):
        for left in range(0, side_in, window):
            # Cell (0, 0) is centered on the origin, hence the -0.5 offset.
            ax_in.add_patch(
                plt.Rectangle((left - 0.5, top - 0.5), 2, 2,
                              fill=False, edgecolor='blue', linewidth=3)
            )
    label_cells(ax_in, input_data, 12)
    plt.colorbar(heat_in, ax=ax_in)

    # Right panel: the pooled result.
    heat_out = ax_out.imshow(output, cmap='YlOrRd', interpolation='nearest')
    ax_out.set_title('Max Pooled Output (2×2)', fontsize=13, weight='bold')
    outline_cells(ax_out, side_out)
    label_cells(ax_out, output, 14)
    plt.colorbar(heat_out, ax=ax_out)

    plt.tight_layout()
    plt.show()

    print(
        '\n📊 Max Pooling (2×2, stride=2):',
        'Takes maximum value from each 2×2 region',
        'Reduces spatial dimensions by half',
        'Provides translation invariance',
        sep='\n',
    )

visualize_pooling()