<a href="https://colab.research.google.com/github/sameekshya1999/A-Riemann-Zeta-Function-Inspired-Optimizer-for-Deep-Learning/blob/main/Zeta_Explore.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch torchvision scipy



In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import math

# Simple Feedforward NN
class SimpleNN(nn.Module):
    """Fully connected classifier with a single hidden ReLU layer.

    Args:
        input_dim: Number of features per sample after flattening.
        hidden_dim: Width of the hidden layer.
        output_dim: Number of output logits.
    """

    def __init__(self, input_dim, hidden_dim=128, output_dim=10):
        super().__init__()
        # Flatten -> Linear -> ReLU -> Linear; the position of each layer in
        # the Sequential determines its state-dict key (net.1.*, net.3.*).
        layers = [
            nn.Flatten(),
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the logits produced by the layer stack for batch ``x``."""
        logits = self.net(x)
        return logits

# Enhanced Zeta-inspired optimizer
class ZetaOptimizer(optim.Optimizer):
    """Hybrid optimizer blending an Adam step with a zeta-scaled momentum step.

    For every parameter the optimizer keeps Adam-style first and second
    moment estimates, an EMA of the gradient norm taken over the last
    dimension, the previous (clamped) gradient and an EMA of the loss
    returned by the closure. The applied update is
    ``adam_mix * adam_update + (1 - adam_mix) * zeta_update`` multiplied by a
    cosine-decayed learning rate.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Base learning rate (> 0).
        s_min: Initial exponent ``s`` of the truncated zeta sum (> 1).
        s_max: Exponent reached after ``total_steps`` steps (> 1).
        beta1: Decay rate of the first-moment EMA, in [0, 1).
        beta2: Decay rate of the second-moment EMA, in [0, 1).
        eps: Small positive constant added to denominators.
        clip_norm: Bound of the element-wise clamp applied to gradients (> 0).
        zeta_damp: Base scale of the zeta damping factor.
        adam_mix: Weight of the Adam update in the hybrid update, in [0, 1].
        total_steps: Steps over which ``s`` ramps from ``s_min`` to ``s_max``;
            the cosine learning-rate decay spans ``1.2 * total_steps`` (> 0).

    Raises:
        ValueError: If a hyperparameter lies outside its valid range.
    """

    def __init__(self, params, lr=1e-3, s_min=1.001, s_max=2.0, beta1=0.85, beta2=0.999, eps=1e-8, clip_norm=1.0, zeta_damp=0.3, adam_mix=0.5, total_steps=5000):
        if lr <= 0.0:
            raise ValueError("Learning rate must be positive")
        if s_min <= 1.0 or s_max <= 1.0:
            raise ValueError("s_min and s_max must be > 1 for convergence")
        if not 0.0 <= beta1 < 1.0:
            raise ValueError("Beta1 must be in [0, 1)")
        if not 0.0 <= beta2 < 1.0:
            raise ValueError("Beta2 must be in [0, 1)")
        if eps <= 0.0:
            raise ValueError("Epsilon must be positive")
        if clip_norm <= 0.0:
            raise ValueError("Clip norm must be positive")
        if not 0.0 <= adam_mix <= 1.0:
            raise ValueError("adam_mix must be in [0, 1]")
        if total_steps <= 0:
            # total_steps is a divisor in both schedules inside step().
            raise ValueError("total_steps must be positive")

        defaults = dict(lr=lr, s_min=s_min, s_max=s_max, beta1=beta1, beta2=beta2, eps=eps, clip_norm=clip_norm, zeta_damp=zeta_damp, adam_mix=adam_mix, total_steps=total_steps)
        super().__init__(params, defaults)

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable returning the scalar loss. It is called
                once with gradients enabled and its value feeds the
                per-parameter loss EMA. No backward pass is run here, so
                gradients must already be populated by the caller.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        # Convert once per step: a per-parameter .item() would synchronise
        # with the device for every tensor.
        loss_value = None if loss is None else float(loss)

        max_terms = 100  # number of terms kept in the truncated zeta sum
        zeta_cache = {}  # s -> partial sum, shared by parameters at the same step
        for group in self.param_groups:
            lr = group['lr']
            s_min = group['s_min']
            s_max = group['s_max']
            beta1 = group['beta1']
            beta2 = group['beta2']
            eps = group['eps']
            clip_norm = group['clip_norm']
            zeta_damp = group['zeta_damp']
            adam_mix = group['adam_mix']
            total_steps = group['total_steps']

            for p in group['params']:
                if p.grad is None:
                    continue
                # Element-wise clamp of the raw gradient (not a norm clip).
                grad = p.grad.data.clamp(min=-clip_norm, max=clip_norm)
                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['m'] = torch.zeros_like(p.data)
                    state['v'] = torch.zeros_like(p.data)
                    state['grad_norm_ema'] = torch.zeros_like(p.data)
                    state['prev_grad'] = torch.zeros_like(p.data)
                    state['loss_ema'] = 1.0

                state['step'] += 1
                step = state['step']
                m, v = state['m'], state['v']
                grad_norm_ema = state['grad_norm_ema']
                prev_grad = state['prev_grad']

                # Fold the current loss into the EMA before it is read, so the
                # damping and the plateau test below see this step's loss
                # rather than the value from the previous step.
                if loss_value is not None:
                    state['loss_ema'] = 0.9 * state['loss_ema'] + 0.1 * loss_value
                loss_ema = state['loss_ema']

                # Dynamic s scheduling: linear ramp from s_min to s_max.
                s = s_min + (s_max - s_min) * min(step / total_steps, 1.0)
                zeta_s = zeta_cache.get(s)
                if zeta_s is None:
                    zeta_s = sum(1.0 / (n ** s) for n in range(1, max_terms + 1))
                    zeta_cache[s] = zeta_s

                # Gradient norm over the last dimension, broadcast into the EMA.
                grad_norm_raw = torch.sqrt((grad ** 2).sum(dim=-1, keepdim=True))
                grad_norm_ema.mul_(0.9).add_(grad_norm_raw, alpha=0.1)

                # Adaptive damping based on loss and gradient norm
                adaptive_damp = zeta_damp * (1.0 + grad_norm_ema / (1.0 + grad_norm_ema)) * (1.0 / max(0.1, loss_ema))
                zeta_factor = adaptive_damp / zeta_s * (1.0 / (1.0 + step * 0.005))

                # Gradient consistency (cosine similarity with the previous
                # gradient); reshape also handles non-contiguous tensors.
                grad_flat = grad.reshape(-1)
                prev_grad_flat = prev_grad.reshape(-1)
                cos_sim = torch.dot(grad_flat, prev_grad_flat) / (grad_flat.norm() * prev_grad_flat.norm() + eps)
                momentum_boost = 1.0 + zeta_factor * 0.2 * max(0.0, cos_sim.item())
                prev_grad.copy_(grad)

                m.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                v.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                # Bias-corrected moment estimates.
                m_hat = m / (1.0 - beta1 ** step)
                v_hat = v / (1.0 - beta2 ** step)

                # Adam update
                adam_update = m_hat / (torch.sqrt(v_hat) + eps)

                # Zeta update
                # NOTE(review): lr enters here and again through current_lr
                # below, so the zeta term scales with lr squared - confirm
                # this is intended.
                grad_norm = torch.sqrt(v_hat).add_(eps)
                norm_factor = torch.clamp(grad_norm, max=0.5)
                zeta_scaled_lr = lr * zeta_factor / (grad_norm ** (s - 1.0) * norm_factor)
                zeta_update = zeta_scaled_lr * m_hat * momentum_boost

                # Hybrid update
                final_update = adam_mix * adam_update + (1.0 - adam_mix) * zeta_update

                # Halve the learning rate once the loss EMA drops to 0.1 or below.
                lr_mult = 1.0 if loss_ema > 0.1 else 0.5
                # Cosine decay over 1.2 * total_steps. The angle is capped at
                # pi so the rate stays at zero afterwards instead of rising
                # again as the cosine wraps around.
                angle = min(math.pi * step / (total_steps * 1.2), math.pi)
                current_lr = lr * lr_mult * (0.5 * (1.0 + math.cos(angle)))

                p.data.add_(-current_lr * final_update)

        return loss

# Training function with accuracy
def train(model, optimizer, data_loader, criterion, epochs=5, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Train ``model`` and print the mean loss and accuracy of every epoch.

    Args:
        model: Module to optimise; moved to ``device`` and put in train mode.
        optimizer: Optimizer whose ``step`` accepts a ``closure`` argument.
        data_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Callable mapping ``(outputs, labels)`` to a scalar loss.
        epochs: Number of passes over ``data_loader``.
        device: Device on which the model and batches are placed.

    Returns:
        The trained ``model`` (the same object that was passed in).
    """
    model.to(device)
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            # The optimizer only needs the loss value; hand it the loss that
            # was just computed instead of running a second forward pass
            # (and building a second autograd graph) inside the closure.
            batch_loss = loss.detach()
            optimizer.step(closure=lambda batch_loss=batch_loss: batch_loss)
            total_loss += loss.item()
            predicted = outputs.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            num_batches += 1
        # An empty loader yields NaN metrics rather than a ZeroDivisionError.
        avg_loss = total_loss / num_batches if num_batches else float('nan')
        accuracy = 100 * correct / total if total else float('nan')
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_loss:.4f}, Train Accuracy: {accuracy:.2f}%")
    return model

# Test function
def test(model, data_loader, criterion, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Evaluate ``model`` on ``data_loader`` and print mean loss and accuracy.

    The model is moved to ``device`` and switched to eval mode; gradients
    are disabled for the whole pass. The loss is averaged over batches and
    the accuracy over samples. Nothing is returned.
    """
    model.to(device)
    model.eval()
    loss_sum, hits, seen = 0.0, 0, 0
    with torch.no_grad():
        for batch_inputs, batch_labels in data_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_inputs)
            loss_sum += criterion(logits, batch_labels).item()
            # torch.max over dim 1 returns (values, indices); keep the indices.
            predicted = torch.max(logits.data, 1)[1]
            seen += batch_labels.size(0)
            hits += (predicted == batch_labels).sum().item()
    mean_loss = loss_sum / len(data_loader)
    accuracy = 100 * hits / seen
    print(f"Test Loss: {mean_loss:.4f}, Test Accuracy: {accuracy:.2f}%")

# Function to get data loaders for different datasets
def get_data_loaders(dataset_name, batch_size=64):
    """Build train and test loaders for one of the supported datasets.

    Args:
        dataset_name: One of "FashionMNIST", "SVHN", "STL10", "Flowers102".
        batch_size: Batch size used by both loaders.

    Returns:
        Tuple ``(train_loader, test_loader, input_dim, output_dim)`` where
        ``input_dim`` is the flattened image size and ``output_dim`` the
        number of classes.

    Raises:
        ValueError: If ``dataset_name`` is not a supported name.
    """
    # Reject unknown names up front, before any dataset is constructed.
    if dataset_name not in ("FashionMNIST", "SVHN", "STL10", "Flowers102"):
        raise ValueError("Unsupported dataset")

    # Per-dataset settings: optional resize/crop steps, normalisation
    # statistics, image geometry and class count.
    if dataset_name == "FashionMNIST":
        resize_steps = []
        mean, std = (0.2860,), (0.3530,)
        channels, side, output_dim = 1, 28, 10
    elif dataset_name == "SVHN":
        resize_steps = []
        mean, std = (0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970)
        channels, side, output_dim = 3, 32, 10
    elif dataset_name == "STL10":
        # Resize to match other datasets
        resize_steps = [transforms.Resize(32)]
        mean, std = (0.4467, 0.4398, 0.4066), (0.2603, 0.2566, 0.2713)
        channels, side, output_dim = 3, 32, 10
    else:  # Flowers102
        resize_steps = [transforms.Resize(32), transforms.CenterCrop(32)]
        mean, std = (0.4330, 0.3819, 0.2964), (0.2620, 0.2344, 0.2534)
        channels, side, output_dim = 3, 32, 102

    transform = transforms.Compose(
        resize_steps + [transforms.ToTensor(), transforms.Normalize(mean, std)]
    )

    # FashionMNIST selects its subset with a boolean `train` flag; the other
    # datasets take a `split` name.
    if dataset_name == "FashionMNIST":
        train_args, test_args = {'train': True}, {'train': False}
    else:
        train_args, test_args = {'split': 'train'}, {'split': 'test'}

    dataset_cls = getattr(torchvision.datasets, dataset_name)
    train_dataset = dataset_cls(root='./data', download=True, transform=transform, **train_args)
    test_dataset = dataset_cls(root='./data', download=True, transform=transform, **test_args)
    input_dim = channels * side * side

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader, input_dim, output_dim

def main():
    """Train and evaluate a SimpleNN with ZetaOptimizer on every dataset."""
    criterion = nn.CrossEntropyLoss()

    # Hyperparameters
    batch_size, epochs = 64, 5
    zeta_kwargs = dict(
        lr=0.0015,
        s_min=1.001,
        s_max=2.0,
        beta1=0.85,
        beta2=0.999,
        eps=1e-8,
        clip_norm=1.0,
        zeta_damp=0.3,
        adam_mix=0.5,
        total_steps=5000,
    )

    # Multiple datasets, processed one after another with a fresh model each.
    for dataset_name in ("FashionMNIST", "SVHN", "STL10", "Flowers102"):
        print(f"\n=== Processing Dataset: {dataset_name} ===")
        loaders = get_data_loaders(dataset_name, batch_size)
        train_loader, test_loader, input_dim, output_dim = loaders

        # Model and Optimizer
        model_zeta = SimpleNN(input_dim=input_dim, output_dim=output_dim)
        optimizer_zeta = ZetaOptimizer(model_zeta.parameters(), **zeta_kwargs)

        # Train with ZetaOptimizer, then report test metrics
        print(f"Training with ZetaOptimizer on {dataset_name}:")
        model_zeta = train(model_zeta, optimizer_zeta, train_loader, criterion, epochs)
        print(f"Evaluating ZetaOptimizer on {dataset_name} test set:")
        test(model_zeta, test_loader, criterion)

# Run the experiment only when this code is executed directly (script or
# notebook cell), not when it is imported as a module.
if __name__ == "__main__":
    main()


=== Processing Dataset: FashionMNIST ===


100%|██████████| 26.4M/26.4M [00:01<00:00, 17.2MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 274kB/s]
100%|██████████| 4.42M/4.42M [00:00<00:00, 5.11MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 15.5MB/s]


Training with ZetaOptimizer on FashionMNIST:
Epoch 1/5, Train Loss: 0.4724, Train Accuracy: 83.07%
Epoch 2/5, Train Loss: 0.3582, Train Accuracy: 86.94%
Epoch 3/5, Train Loss: 0.3176, Train Accuracy: 88.34%
Epoch 4/5, Train Loss: 0.2871, Train Accuracy: 89.50%
Epoch 5/5, Train Loss: 0.2656, Train Accuracy: 90.36%
Evaluating ZetaOptimizer on FashionMNIST test set:
Test Loss: 0.3331, Test Accuracy: 87.89%

=== Processing Dataset: SVHN ===


100%|██████████| 182M/182M [00:05<00:00, 34.8MB/s]
100%|██████████| 64.3M/64.3M [00:03<00:00, 17.6MB/s]


Training with ZetaOptimizer on SVHN:
Epoch 1/5, Train Loss: 1.2025, Train Accuracy: 62.53%
Epoch 2/5, Train Loss: 0.8060, Train Accuracy: 75.99%
Epoch 3/5, Train Loss: 0.6655, Train Accuracy: 80.44%
Epoch 4/5, Train Loss: 0.5705, Train Accuracy: 83.48%
Epoch 5/5, Train Loss: 0.5159, Train Accuracy: 85.37%
Evaluating ZetaOptimizer on SVHN test set:
Test Loss: 0.6703, Test Accuracy: 81.85%

=== Processing Dataset: STL10 ===


100%|██████████| 2.64G/2.64G [01:26<00:00, 30.6MB/s]


Training with ZetaOptimizer on STL10:
Epoch 1/5, Train Loss: 1.9016, Train Accuracy: 32.34%
Epoch 2/5, Train Loss: 1.5717, Train Accuracy: 43.90%
Epoch 3/5, Train Loss: 1.3885, Train Accuracy: 49.82%
Epoch 4/5, Train Loss: 1.2591, Train Accuracy: 55.90%
Epoch 5/5, Train Loss: 1.1427, Train Accuracy: 61.12%
Evaluating ZetaOptimizer on STL10 test set:
Test Loss: 1.7054, Test Accuracy: 41.90%

=== Processing Dataset: Flowers102 ===


100%|██████████| 345M/345M [00:11<00:00, 30.0MB/s]
100%|██████████| 502/502 [00:00<00:00, 2.32MB/s]
100%|██████████| 15.0k/15.0k [00:00<00:00, 32.9MB/s]


Training with ZetaOptimizer on Flowers102:
Epoch 1/5, Train Loss: 4.5370, Train Accuracy: 3.24%
Epoch 2/5, Train Loss: 3.5739, Train Accuracy: 18.14%
Epoch 3/5, Train Loss: 2.7715, Train Accuracy: 37.45%
Epoch 4/5, Train Loss: 2.1261, Train Accuracy: 53.92%
Epoch 5/5, Train Loss: 1.5991, Train Accuracy: 68.33%
Evaluating ZetaOptimizer on Flowers102 test set:
Test Loss: 3.9310, Test Accuracy: 13.66%


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import math

# Simple Feedforward NN
class SimpleNN(nn.Module):
    """Fully connected classifier with a single hidden ReLU layer.

    Args:
        input_dim: Number of features per sample after flattening.
        hidden_dim: Width of the hidden layer.
        output_dim: Number of output logits.
    """

    def __init__(self, input_dim, hidden_dim=128, output_dim=10):
        super().__init__()
        # Flatten -> Linear -> ReLU -> Linear; the position of each layer in
        # the Sequential determines its state-dict key (net.1.*, net.3.*).
        layers = [
            nn.Flatten(),
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the logits produced by the layer stack for batch ``x``."""
        logits = self.net(x)
        return logits

# Enhanced Zeta-inspired optimizer
class ZetaOptimizer(optim.Optimizer):
    """Hybrid optimizer blending an Adam step with a zeta-scaled momentum step.

    For every parameter the optimizer keeps Adam-style first and second
    moment estimates, an EMA of the gradient norm taken over the last
    dimension, the previous (clamped) gradient and an EMA of the loss
    returned by the closure. The applied update is
    ``adam_mix * adam_update + (1 - adam_mix) * zeta_update`` multiplied by a
    cosine-decayed learning rate.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Base learning rate (> 0).
        s_min: Initial exponent ``s`` of the truncated zeta sum (> 1).
        s_max: Exponent reached after ``total_steps`` steps (> 1).
        beta1: Decay rate of the first-moment EMA, in [0, 1).
        beta2: Decay rate of the second-moment EMA, in [0, 1).
        eps: Small positive constant added to denominators.
        clip_norm: Bound of the element-wise clamp applied to gradients (> 0).
        zeta_damp: Base scale of the zeta damping factor.
        adam_mix: Weight of the Adam update in the hybrid update, in [0, 1].
        total_steps: Steps over which ``s`` ramps from ``s_min`` to ``s_max``;
            the cosine learning-rate decay spans ``1.2 * total_steps`` (> 0).

    Raises:
        ValueError: If a hyperparameter lies outside its valid range.
    """

    def __init__(self, params, lr=1e-3, s_min=1.001, s_max=2.0, beta1=0.85, beta2=0.999, eps=1e-8, clip_norm=1.0, zeta_damp=0.3, adam_mix=0.5, total_steps=5000):
        if lr <= 0.0:
            raise ValueError("Learning rate must be positive")
        if s_min <= 1.0 or s_max <= 1.0:
            raise ValueError("s_min and s_max must be > 1 for convergence")
        if not 0.0 <= beta1 < 1.0:
            raise ValueError("Beta1 must be in [0, 1)")
        if not 0.0 <= beta2 < 1.0:
            raise ValueError("Beta2 must be in [0, 1)")
        if eps <= 0.0:
            raise ValueError("Epsilon must be positive")
        if clip_norm <= 0.0:
            raise ValueError("Clip norm must be positive")
        if not 0.0 <= adam_mix <= 1.0:
            raise ValueError("adam_mix must be in [0, 1]")
        if total_steps <= 0:
            # total_steps is a divisor in both schedules inside step().
            raise ValueError("total_steps must be positive")

        defaults = dict(lr=lr, s_min=s_min, s_max=s_max, beta1=beta1, beta2=beta2, eps=eps, clip_norm=clip_norm, zeta_damp=zeta_damp, adam_mix=adam_mix, total_steps=total_steps)
        super().__init__(params, defaults)

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable returning the scalar loss. It is called
                once with gradients enabled and its value feeds the
                per-parameter loss EMA. No backward pass is run here, so
                gradients must already be populated by the caller.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        # Convert once per step: a per-parameter .item() would synchronise
        # with the device for every tensor.
        loss_value = None if loss is None else float(loss)

        max_terms = 100  # number of terms kept in the truncated zeta sum
        zeta_cache = {}  # s -> partial sum, shared by parameters at the same step
        for group in self.param_groups:
            lr = group['lr']
            s_min = group['s_min']
            s_max = group['s_max']
            beta1 = group['beta1']
            beta2 = group['beta2']
            eps = group['eps']
            clip_norm = group['clip_norm']
            zeta_damp = group['zeta_damp']
            adam_mix = group['adam_mix']
            total_steps = group['total_steps']

            for p in group['params']:
                if p.grad is None:
                    continue
                # Element-wise clamp of the raw gradient (not a norm clip).
                grad = p.grad.data.clamp(min=-clip_norm, max=clip_norm)
                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['m'] = torch.zeros_like(p.data)
                    state['v'] = torch.zeros_like(p.data)
                    state['grad_norm_ema'] = torch.zeros_like(p.data)
                    state['prev_grad'] = torch.zeros_like(p.data)
                    state['loss_ema'] = 1.0

                state['step'] += 1
                step = state['step']
                m, v = state['m'], state['v']
                grad_norm_ema = state['grad_norm_ema']
                prev_grad = state['prev_grad']

                # Fold the current loss into the EMA before it is read, so the
                # damping and the plateau test below see this step's loss
                # rather than the value from the previous step.
                if loss_value is not None:
                    state['loss_ema'] = 0.9 * state['loss_ema'] + 0.1 * loss_value
                loss_ema = state['loss_ema']

                # Dynamic s scheduling: linear ramp from s_min to s_max.
                s = s_min + (s_max - s_min) * min(step / total_steps, 1.0)
                zeta_s = zeta_cache.get(s)
                if zeta_s is None:
                    zeta_s = sum(1.0 / (n ** s) for n in range(1, max_terms + 1))
                    zeta_cache[s] = zeta_s

                # Gradient norm over the last dimension, broadcast into the EMA.
                grad_norm_raw = torch.sqrt((grad ** 2).sum(dim=-1, keepdim=True))
                grad_norm_ema.mul_(0.9).add_(grad_norm_raw, alpha=0.1)

                # Adaptive damping based on loss and gradient norm
                adaptive_damp = zeta_damp * (1.0 + grad_norm_ema / (1.0 + grad_norm_ema)) * (1.0 / max(0.1, loss_ema))
                zeta_factor = adaptive_damp / zeta_s * (1.0 / (1.0 + step * 0.005))

                # Gradient consistency (cosine similarity with the previous
                # gradient); reshape also handles non-contiguous tensors.
                grad_flat = grad.reshape(-1)
                prev_grad_flat = prev_grad.reshape(-1)
                cos_sim = torch.dot(grad_flat, prev_grad_flat) / (grad_flat.norm() * prev_grad_flat.norm() + eps)
                momentum_boost = 1.0 + zeta_factor * 0.2 * max(0.0, cos_sim.item())
                prev_grad.copy_(grad)

                m.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                v.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                # Bias-corrected moment estimates.
                m_hat = m / (1.0 - beta1 ** step)
                v_hat = v / (1.0 - beta2 ** step)

                # Adam update
                adam_update = m_hat / (torch.sqrt(v_hat) + eps)

                # Zeta update
                # NOTE(review): lr enters here and again through current_lr
                # below, so the zeta term scales with lr squared - confirm
                # this is intended.
                grad_norm = torch.sqrt(v_hat).add_(eps)
                norm_factor = torch.clamp(grad_norm, max=0.5)
                zeta_scaled_lr = lr * zeta_factor / (grad_norm ** (s - 1.0) * norm_factor)
                zeta_update = zeta_scaled_lr * m_hat * momentum_boost

                # Hybrid update
                final_update = adam_mix * adam_update + (1.0 - adam_mix) * zeta_update

                # Halve the learning rate once the loss EMA drops to 0.1 or below.
                lr_mult = 1.0 if loss_ema > 0.1 else 0.5
                # Cosine decay over 1.2 * total_steps. The angle is capped at
                # pi so the rate stays at zero afterwards instead of rising
                # again as the cosine wraps around.
                angle = min(math.pi * step / (total_steps * 1.2), math.pi)
                current_lr = lr * lr_mult * (0.5 * (1.0 + math.cos(angle)))

                p.data.add_(-current_lr * final_update)

        return loss

# Training function with accuracy
def train(model, optimizer, data_loader, criterion, epochs=5, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Train ``model`` and print the mean loss and accuracy of every epoch.

    Args:
        model: Module to optimise; moved to ``device`` and put in train mode.
        optimizer: Optimizer to step. A ``ZetaOptimizer`` additionally
            receives the batch loss through its ``closure`` argument.
        data_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Callable mapping ``(outputs, labels)`` to a scalar loss.
        epochs: Number of passes over ``data_loader``.
        device: Device on which the model and batches are placed.

    Returns:
        The trained ``model`` (the same object that was passed in).
    """
    model.to(device)
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            if isinstance(optimizer, ZetaOptimizer):
                # ZetaOptimizer only needs the loss value; hand it the loss
                # that was just computed instead of running a second forward
                # pass (and building a second autograd graph) in the closure.
                batch_loss = loss.detach()
                optimizer.step(closure=lambda batch_loss=batch_loss: batch_loss)
            else:
                optimizer.step()
            total_loss += loss.item()
            predicted = outputs.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            num_batches += 1
        # An empty loader yields NaN metrics rather than a ZeroDivisionError.
        avg_loss = total_loss / num_batches if num_batches else float('nan')
        accuracy = 100 * correct / total if total else float('nan')
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_loss:.4f}, Train Accuracy: {accuracy:.2f}%")
    return model

# Test function
def test(model, data_loader, criterion, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Evaluate ``model`` on ``data_loader`` and print mean loss and accuracy.

    The model is moved to ``device`` and switched to eval mode; gradients
    are disabled for the whole pass. The loss is averaged over batches and
    the accuracy over samples. Nothing is returned.
    """
    model.to(device)
    model.eval()
    loss_sum, hits, seen = 0.0, 0, 0
    with torch.no_grad():
        for batch_inputs, batch_labels in data_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_inputs)
            loss_sum += criterion(logits, batch_labels).item()
            # torch.max over dim 1 returns (values, indices); keep the indices.
            predicted = torch.max(logits.data, 1)[1]
            seen += batch_labels.size(0)
            hits += (predicted == batch_labels).sum().item()
    mean_loss = loss_sum / len(data_loader)
    accuracy = 100 * hits / seen
    print(f"Test Loss: {mean_loss:.4f}, Test Accuracy: {accuracy:.2f}%")

# Function to get data loaders for different datasets
def get_data_loaders(dataset_name, batch_size=64):
    """Build train and test loaders for one of the supported datasets.

    Args:
        dataset_name: One of "FashionMNIST", "SVHN", "STL10", "Flowers102".
        batch_size: Batch size used by both loaders.

    Returns:
        Tuple ``(train_loader, test_loader, input_dim, output_dim)`` where
        ``input_dim`` is the flattened image size and ``output_dim`` the
        number of classes.

    Raises:
        ValueError: If ``dataset_name`` is not a supported name.
    """
    # Reject unknown names up front, before any dataset is constructed.
    if dataset_name not in ("FashionMNIST", "SVHN", "STL10", "Flowers102"):
        raise ValueError("Unsupported dataset")

    # Per-dataset settings: optional resize/crop steps, normalisation
    # statistics, image geometry and class count.
    if dataset_name == "FashionMNIST":
        resize_steps = []
        mean, std = (0.2860,), (0.3530,)
        channels, side, output_dim = 1, 28, 10
    elif dataset_name == "SVHN":
        resize_steps = []
        mean, std = (0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970)
        channels, side, output_dim = 3, 32, 10
    elif dataset_name == "STL10":
        # Resize to match other datasets
        resize_steps = [transforms.Resize(32)]
        mean, std = (0.4467, 0.4398, 0.4066), (0.2603, 0.2566, 0.2713)
        channels, side, output_dim = 3, 32, 10
    else:  # Flowers102
        resize_steps = [transforms.Resize(32), transforms.CenterCrop(32)]
        mean, std = (0.4330, 0.3819, 0.2964), (0.2620, 0.2344, 0.2534)
        channels, side, output_dim = 3, 32, 102

    transform = transforms.Compose(
        resize_steps + [transforms.ToTensor(), transforms.Normalize(mean, std)]
    )

    # FashionMNIST selects its subset with a boolean `train` flag; the other
    # datasets take a `split` name.
    if dataset_name == "FashionMNIST":
        train_args, test_args = {'train': True}, {'train': False}
    else:
        train_args, test_args = {'split': 'train'}, {'split': 'test'}

    dataset_cls = getattr(torchvision.datasets, dataset_name)
    train_dataset = dataset_cls(root='./data', download=True, transform=transform, **train_args)
    test_dataset = dataset_cls(root='./data', download=True, transform=transform, **test_args)
    input_dim = channels * side * side

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader, input_dim, output_dim

def main():
    """Compare Adam, SGD and ZetaOptimizer on every benchmark dataset."""
    criterion = nn.CrossEntropyLoss()

    # Hyperparameters
    batch_size, epochs = 64, 5
    lr_adam, lr_sgd = 0.001, 0.01
    zeta_kwargs = dict(
        lr=0.0015,
        s_min=1.001,
        s_max=2.0,
        beta1=0.85,
        beta2=0.999,
        eps=1e-8,
        clip_norm=1.0,
        zeta_damp=0.3,
        adam_mix=0.5,
        total_steps=5000,
    )

    # Multiple datasets, each compared with freshly initialised models.
    for dataset_name in ("FashionMNIST", "SVHN", "STL10", "Flowers102"):
        print(f"\n=== Processing Dataset: {dataset_name} ===")
        loaders = get_data_loaders(dataset_name, batch_size)
        train_loader, test_loader, input_dim, output_dim = loaders

        def fit_and_score(model, optimizer, train_banner, eval_banner):
            # Train one model, then report its metrics on the test split.
            print(train_banner)
            fitted = train(model, optimizer, train_loader, criterion, epochs)
            print(eval_banner)
            test(fitted, test_loader, criterion)

        # Models (built in the order Adam, SGD, Zeta)
        model_adam = SimpleNN(input_dim=input_dim, output_dim=output_dim)
        model_sgd = SimpleNN(input_dim=input_dim, output_dim=output_dim)
        model_zeta = SimpleNN(input_dim=input_dim, output_dim=output_dim)

        # Optimizers
        optimizer_adam = optim.Adam(model_adam.parameters(), lr=lr_adam)
        optimizer_sgd = optim.SGD(model_sgd.parameters(), lr=lr_sgd, momentum=0.9)
        optimizer_zeta = ZetaOptimizer(model_zeta.parameters(), **zeta_kwargs)

        # Train and evaluate Adam
        fit_and_score(
            model_adam,
            optimizer_adam,
            f"Training with Adam on {dataset_name}:",
            f"Evaluating Adam on {dataset_name} test set:",
        )

        # Train and evaluate SGD
        fit_and_score(
            model_sgd,
            optimizer_sgd,
            f"\nTraining with SGD on {dataset_name}:",
            f"Evaluating SGD on {dataset_name} test set:",
        )

        # Train and evaluate ZetaOptimizer
        fit_and_score(
            model_zeta,
            optimizer_zeta,
            f"\nTraining with ZetaOptimizer on {dataset_name}:",
            f"Evaluating ZetaOptimizer on {dataset_name} test set:",
        )

# Run the comparison only when this code is executed directly (script or
# notebook cell), not when it is imported as a module.
if __name__ == "__main__":
    main()


=== Processing Dataset: FashionMNIST ===
Training with Adam on FashionMNIST:
Epoch 1/5, Train Loss: 0.4657, Train Accuracy: 83.18%
Epoch 2/5, Train Loss: 0.3525, Train Accuracy: 87.11%
Epoch 3/5, Train Loss: 0.3159, Train Accuracy: 88.38%
Epoch 4/5, Train Loss: 0.2917, Train Accuracy: 89.20%
Epoch 5/5, Train Loss: 0.2746, Train Accuracy: 89.71%
Evaluating Adam on FashionMNIST test set:
Test Loss: 0.3580, Test Accuracy: 87.02%

Training with SGD on FashionMNIST:
Epoch 1/5, Train Loss: 0.4911, Train Accuracy: 82.24%
Epoch 2/5, Train Loss: 0.3668, Train Accuracy: 86.57%
Epoch 3/5, Train Loss: 0.3284, Train Accuracy: 88.08%
Epoch 4/5, Train Loss: 0.3063, Train Accuracy: 88.75%
Epoch 5/5, Train Loss: 0.2871, Train Accuracy: 89.43%
Evaluating SGD on FashionMNIST test set:
Test Loss: 0.3374, Test Accuracy: 87.68%

Training with ZetaOptimizer on FashionMNIST:
Epoch 1/5, Train Loss: 0.4743, Train Accuracy: 83.08%
Epoch 2/5, Train Loss: 0.3551, Train Accuracy: 87.12%
Epoch 3/5, Train Loss: 0.31