In [None]:
import torch
import torch.nn as nn
import snntorch as snn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import random
from deap import base, creator, tools, algorithms
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Rectangle
import math

plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

In [None]:
transform = transforms.Compose([transforms.ToTensor()])
train_data = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_data = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)

In [32]:
def create_snn(hidden1, hidden2, beta):
    class SNNModel(nn.Module):
        def __init__(self):
            super().__init__()
            # Define linear layers
            self.fc1 = nn.Linear(28*28, hidden1)
            self.lif1 = snn.Leaky(beta=beta)
            self.fc2 = nn.Linear(hidden1, hidden2)
            self.lif2 = snn.Leaky(beta=beta)
            self.fc3 = nn.Linear(hidden2, 10)
            self.lif3 = snn.Leaky(beta=beta)

        def forward(self, x, num_steps=10):
            # Initialize membrane potentials
            mem1 = self.lif1.init_leaky()
            mem2 = self.lif2.init_leaky()
            mem3 = self.lif3.init_leaky()
            spk_out = 0
            for _ in range(num_steps):
                cur1 = self.fc1(x.view(x.size(0), -1))
                spk1, mem1 = self.lif1(cur1, mem1)
                cur2 = self.fc2(spk1)
                spk2, mem2 = self.lif2(cur2, mem2)
                cur3 = self.fc3(spk2)
                spk3, mem3 = self.lif3(cur3, mem3)
                spk_out += spk3
            return spk_out / num_steps
    return SNNModel()

In [None]:
def bounded_mutation(individual, bounds, mu=0, sigma=0.2, indpb=0.2):
    """Custom mutation that keeps values within specified bounds"""
    for i, (low, high) in enumerate(bounds):
        if random.random() < indpb:
            # Apply Gaussian mutation
            individual[i] += random.gauss(mu, sigma * (high - low))
            # Clip to bounds
            individual[i] = max(low, min(high, individual[i]))
    return individual,

# --- 4. Evaluation Function using CrossEntropyLoss ---
def evaluate_model(individual):
    # Cast and sanitize hyperparameters with bounds checking
    hidden1 = max(64, min(256, int(round(individual[0]))))
    hidden2 = max(64, min(256, int(round(individual[1]))))
    beta = max(0.5, min(0.99, float(individual[2])))
    lr = max(0.0001, min(0.01, float(individual[3])))

    model = create_snn(hidden1, hidden2, beta)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    # Training (short for speed)
    model.train()
    for batch_idx, (data, targets) in enumerate(train_loader):
        data, targets = data.to(device), targets.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, targets)
        loss.backward()
        optimizer.step()
        if batch_idx >= 20:
            break

    # Validation on subset
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data, targets in test_loader:
            data, targets = data.to(device), targets.to(device)
            outputs = model(data)
            _, pred = outputs.max(1)
            total += targets.size(0)
            correct += (pred == targets).sum().item()
            if total >= 1000:
                break
    accuracy = correct / total
    return (accuracy,)

In [None]:
def plot_individual(individual, ax, color='blue', alpha=0.7):
    """Plot a single individual's parameters as a bar chart"""
    params = ['Hidden1', 'Hidden2', 'Beta', 'LR×1000']
    values = [individual[0], individual[1], individual[2], individual[3]*1000]  # Scale LR for visibility
    
    bars = ax.bar(params, values, color=color, alpha=alpha)
    ax.set_ylim(0, 300)
    ax.set_ylabel('Parameter Value')
    
    # Add value labels on bars
    for bar, val in zip(bars, values):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height + 5,
                f'{val:.1f}', ha='center', va='bottom', fontsize=8)

def plot_population(population, generation, min_fitness, max_fitness):
    """Plot population diversity and fitness distribution"""
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 10))
    
    # Extract parameters and fitness
    hidden1_vals = [ind[0] for ind in population]
    hidden2_vals = [ind[1] for ind in population]
    beta_vals = [ind[2] for ind in population]
    lr_vals = [ind[3] for ind in population]
    fitness_vals = [ind.fitness.values[0] if ind.fitness.valid else 0 for ind in population]
    
    # 1. Parameter scatter plots with fitness coloring
    scatter1 = ax1.scatter(hidden1_vals, hidden2_vals, c=fitness_vals, 
                          cmap='viridis', s=100, alpha=0.7, vmin=min_fitness, vmax=max_fitness)
    ax1.set_xlabel('Hidden Layer 1 Size')
    ax1.set_ylabel('Hidden Layer 2 Size')
    ax1.set_title(f'Gen {generation}: Layer Sizes vs Fitness')
    ax1.grid(True, alpha=0.3)
    
    # 2. Beta vs Learning Rate
    scatter2 = ax2.scatter(beta_vals, lr_vals, c=fitness_vals, 
                          cmap='viridis', s=100, alpha=0.7, vmin=min_fitness, vmax=max_fitness)
    ax2.set_xlabel('Beta (Leak Factor)')
    ax2.set_ylabel('Learning Rate')
    ax2.set_title(f'Gen {generation}: Beta vs Learning Rate')
    ax2.grid(True, alpha=0.3)
    
    # 3. Fitness distribution
    ax3.hist(fitness_vals, bins=max(3, len(population)//2), alpha=0.7, color='skyblue', edgecolor='black')
    ax3.axvline(np.mean(fitness_vals), color='red', linestyle='--', label=f'Mean: {np.mean(fitness_vals):.3f}')
    ax3.axvline(np.max(fitness_vals), color='green', linestyle='--', label=f'Max: {np.max(fitness_vals):.3f}')
    ax3.set_xlabel('Fitness (Accuracy)')
    ax3.set_ylabel('Count')
    ax3.set_title(f'Gen {generation}: Fitness Distribution')
    ax3.legend()
    ax3.grid(True, alpha=0.3)
    
    # 4. Best individual parameters
    best_idx = np.argmax(fitness_vals)
    best_individual = population[best_idx]
    plot_individual(best_individual, ax4, color='gold')
    ax4.set_title(f'Gen {generation}: Best Individual (Acc: {fitness_vals[best_idx]:.3f})')
    
    # Add colorbar
    cbar = plt.colorbar(scatter1, ax=ax4)
    cbar.set_label('Fitness (Accuracy)')
    
    plt.tight_layout()
    return fig

In [None]:
def plot_evolution_progress(logbook):
    """Plot evolution statistics over generations"""
    gen = logbook.select("gen")
    avg_fitness = logbook.select("avg")
    max_fitness = logbook.select("max")
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    
    # 1. Fitness evolution
    ax1.plot(gen, avg_fitness, 'b-', label='Average Fitness', linewidth=2, marker='o')
    ax1.plot(gen, max_fitness, 'r-', label='Maximum Fitness', linewidth=2, marker='s')
    ax1.fill_between(gen, avg_fitness, max_fitness, alpha=0.2)
    ax1.set_xlabel('Generation')
    ax1.set_ylabel('Fitness (Accuracy)')
    ax1.set_title('Evolution Progress: Fitness Over Generations')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # 2. Improvement rate
    improvement = np.diff(max_fitness)
    ax2.bar(gen[1:], improvement, alpha=0.7, color=['green' if x > 0 else 'red' for x in improvement])
    ax2.axhline(y=0, color='black', linestyle='-', alpha=0.5)
    ax2.set_xlabel('Generation')
    ax2.set_ylabel('Fitness Improvement')
    ax2.set_title('Generation-to-Generation Improvement')
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    return fig

def plot_architecture_evolution(populations_history):
    """Plot how architecture parameters evolved over time"""
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    generations = list(range(len(populations_history)))
    param_names = ['Hidden Layer 1', 'Hidden Layer 2', 'Beta (Leak Factor)', 'Learning Rate']
    param_indices = [0, 1, 2, 3]
    
    for idx, (ax, param_name, param_idx) in enumerate(zip(axes.flat, param_names, param_indices)):
        # Extract parameter values for each generation
        for gen, population in enumerate(populations_history):
            values = [ind[param_idx] for ind in population if ind.fitness.valid]
            fitness = [ind.fitness.values[0] for ind in population if ind.fitness.valid]
            
            # Scatter plot with fitness coloring
            scatter = ax.scatter([gen] * len(values), values, c=fitness, 
                              cmap='viridis', alpha=0.6, s=50)
        
        ax.set_xlabel('Generation')
        ax.set_ylabel(param_name)
        ax.set_title(f'Evolution of {param_name}')
        ax.grid(True, alpha=0.3)
    
    # Add colorbar to the last subplot
    cbar = plt.colorbar(scatter, ax=axes.flat[-1])
    cbar.set_label('Fitness (Accuracy)')
    
    plt.tight_layout()
    return fig

In [None]:
creator.create("FitnessMax", base.Fitness, weights=(1.0,))
creator.create("Individual", list, fitness=creator.FitnessMax)

# Define bounds for each parameter
BOUNDS = [(64, 256), (64, 256), (0.5, 0.99), (0.0001, 0.01)]

toolbox = base.Toolbox()
toolbox.register("hidden1", random.uniform, 64, 256)
toolbox.register("hidden2", random.uniform, 64, 256)
toolbox.register("beta", random.uniform, 0.5, 0.99)
toolbox.register("lr", random.uniform, 0.0001, 0.01)
toolbox.register("individual", tools.initCycle, creator.Individual,
                 (toolbox.hidden1, toolbox.hidden2, toolbox.beta, toolbox.lr), n=1)
toolbox.register("population", tools.initRepeat, list, toolbox.individual)
toolbox.register("evaluate", evaluate_model)
toolbox.register("mate", tools.cxBlend, alpha=0.5)
toolbox.register("mutate", bounded_mutation, bounds=BOUNDS, mu=0, sigma=0.1, indpb=0.2)
toolbox.register("select", tools.selTournament, tournsize=3)

In [31]:
def run_evolution():
    pop = toolbox.population(n=6)
    hof = tools.HallOfFame(1)
    stats = tools.Statistics(lambda ind: ind.fitness.values)
    stats.register("avg", np.mean)
    stats.register("max", np.max)

    pop, log = algorithms.eaSimple(pop, toolbox, cxpb=0.5, mutpb=0.3, ngen=10,
                                   stats=stats, halloffame=hof, verbose=True)
    print("\nBest individual:", hof[0])
    print("Accuracy:", hof[0].fitness.values[0])

if __name__ == '__main__':
    run_evolution()

gen	nevals	avg    	max     
0  	6     	0.71875	0.777344
1  	5     	0.735514	0.777344
2  	1     	0.750163	0.777344
3  	2     	0.764974	0.777344
4  	4     	0.731445	0.777344
5  	4     	0.708984	0.777344
6  	3     	0.720215	0.777344
7  	1     	0.74821 	0.777344


ValueError: Invalid learning rate: -0.033257832242173886