# 11. Activation Function Comparison

This notebook compares different activation functions (ReLU, Sigmoid, Tanh) on the same neural network architecture.

## Experiment Overview
- **Goal**: Compare ReLU vs Sigmoid vs Tanh activation functions
- **Model**: Simple MLP (784 → 128 → 128 → 10 with dropout), identical except for the hidden-layer activation
- **Features**: Training dynamics, convergence analysis, gradient flow
- **Learning**: Understanding activation function effects on training

## What You'll Learn
- Activation function properties
- Training dynamics comparison
- Gradient flow analysis
- Convergence behavior differences


In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import sys
import os

# Add scripts directory to path
sys.path.append('../scripts')
from utils import load_mnist_data, get_device, set_seed, create_synthetic_data

# Set random seed for reproducibility.
# Called before the data loaders are built; models are created in a later
# cell. NOTE(review): which RNGs set_seed covers (torch / numpy / random)
# is defined in scripts/utils.py -- confirm there.
set_seed(42)

# Get device.
# `device` is reused below when the models are moved with `.to(device)`.
# NOTE(review): presumably picks an accelerator when available and falls
# back to CPU -- confirm in utils.get_device.
device = get_device()
print(f"Using device: {device}")

# Load MNIST dataset.
# The helper returns three loaders, unpacked as train / validation / test.
# NOTE(review): `test_split=0.2` is passed through as-is; whether it sizes
# the validation or the test portion is decided inside load_mnist_data --
# verify against the helper before quoting split ratios.
print("Loading MNIST dataset...")
train_loader, val_loader, test_loader = load_mnist_data(batch_size=64, test_split=0.2)

# Report how many samples each loader's underlying dataset holds.
print(f"Training samples: {len(train_loader.dataset)}")
print(f"Validation samples: {len(val_loader.dataset)}")
print(f"Test samples: {len(test_loader.dataset)}")


In [None]:
# Define models with different activation functions
class ActivationComparisonModel(nn.Module):
    """Three-layer MLP whose hidden-layer activation is selected by name.

    Data flow: ``fc1 -> act -> dropout -> fc2 -> act -> dropout -> fc3``.
    The output of ``fc3`` is returned unchanged (no softmax is applied).

    Args:
        activation: One of ``'relu'``, ``'sigmoid'`` or ``'tanh'``.
        input_size: Number of features per sample once the input is flattened.
        hidden_size: Width of both hidden layers.
        num_classes: Number of output units.
        dropout_rate: Drop probability of the dropout layer applied after
            each hidden activation.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """

    # Name -> callable lookup. torch.sigmoid / torch.tanh are used instead of
    # the F.* aliases, which emit deprecation warnings on some PyTorch
    # releases; the computed values are the same.
    _ACTIVATIONS = {
        'relu': F.relu,
        'sigmoid': torch.sigmoid,
        'tanh': torch.tanh,
    }

    def __init__(self, activation: str = 'relu', input_size: int = 784,
                 hidden_size: int = 128, num_classes: int = 10,
                 dropout_rate: float = 0.2):
        # Reject unknown names before any layer is allocated. The isinstance
        # guard keeps unhashable inputs on the ValueError path as well.
        act_fn = self._ACTIVATIONS.get(activation) if isinstance(activation, str) else None
        if act_fn is None:
            raise ValueError(f"Unknown activation: {activation}")

        super().__init__()

        # Layers are created in the same order as before so that seeded
        # weight initialisation stays reproducible.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, num_classes)
        self.dropout = nn.Dropout(dropout_rate)

        self.activation = act_fn

    def forward(self, x):
        """Flatten each sample and return the output of the final linear layer."""
        # Collapse every non-batch dimension into one feature axis.
        x = x.view(x.size(0), -1)

        x = self.activation(self.fc1(x))
        x = self.dropout(x)

        x = self.activation(self.fc2(x))
        x = self.dropout(x)

        # No activation on the output layer.
        x = self.fc3(x)
        return x

# Create one model per activation function and print a short size summary.
activations = ['relu', 'sigmoid', 'tanh']
models = {}

for activation in activations:
    models[activation] = ActivationComparisonModel(activation=activation).to(device)

    # Walk the parameters once and reuse the totals for both printed lines.
    params = list(models[activation].parameters())
    num_params = sum(p.numel() for p in params)
    # Use each tensor's real element size rather than assuming 4 bytes, so
    # the figure stays correct if the model is ever cast to another dtype.
    size_mb = sum(p.numel() * p.element_size() for p in params) / 1024 / 1024

    print(f"{activation.upper()} Model:")
    print(f"  Parameters: {num_params:,}")
    print(f"  Model size: {size_mb:.2f} MB")
    print()

# Visualize activation functions: one subplot per function over [-5, 5].
x = torch.linspace(-5, 5, 100)
activations_funcs = {
    'ReLU': F.relu,
    # torch.sigmoid / torch.tanh compute the same values as the F.* aliases
    # without the deprecation warnings some PyTorch releases emit for them.
    'Sigmoid': torch.sigmoid,
    'Tanh': torch.tanh
}

plt.figure(figsize=(12, 4))
for i, (name, func) in enumerate(activations_funcs.items(), 1):
    plt.subplot(1, 3, i)
    y = func(x)
    plt.plot(x.numpy(), y.numpy(), linewidth=2)
    plt.title(f'{name} Activation')
    plt.xlabel('Input')
    plt.ylabel('Output')
    plt.grid(True)
    # Faint dashed axes through the origin make sign and saturation visible.
    plt.axhline(y=0, color='k', linestyle='--', alpha=0.3)
    plt.axvline(x=0, color='k', linestyle='--', alpha=0.3)

plt.tight_layout()

# savefig does not create missing directories, so make sure the target
# folder exists before writing the figure.
plot_path = '../results/plots/activation_functions.png'
os.makedirs(os.path.dirname(plot_path), exist_ok=True)
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
plt.show()
