<a href="https://colab.research.google.com/github/sameekshya1999/A-Riemann-Zeta-Function-Inspired-Optimizer-for-Deep-Learning/blob/main/Mnist_and_CIFAR_10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torch scipy


Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

# **MNIST**

In [25]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import math

# Simple Feedforward NN for CIFAR-10 classification
class SimpleNN(nn.Module):
    def __init__(self, input_dim=3*32*32, hidden_dim=128, output_dim=10):
        super(SimpleNN, self).__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.net(x)

# Enhanced Zeta-inspired optimizer
class ZetaOptimizer(optim.Optimizer):
    def __init__(self, params, lr=1e-3, s_min=1.001, s_max=2.0, beta1=0.85, beta2=0.999, eps=1e-8, clip_norm=1.0, zeta_damp=0.3, adam_mix=0.5, total_steps=5000):
        if lr <= 0.0:
            raise ValueError("Learning rate must be positive")
        if s_min <= 1.0 or s_max <= 1.0:
            raise ValueError("s_min and s_max must be > 1 for convergence")
        if not 0.0 <= beta1 < 1.0:
            raise ValueError("Beta1 must be in [0, 1)")
        if not 0.0 <= beta2 < 1.0:
            raise ValueError("Beta2 must be in [0, 1)")
        if eps <= 0.0:
            raise ValueError("Epsilon must be positive")
        if clip_norm <= 0.0:
            raise ValueError("Clip norm must be positive")
        if not 0.0 <= adam_mix <= 1.0:
            raise ValueError("adam_mix must be in [0, 1]")

        defaults = dict(lr=lr, s_min=s_min, s_max=s_max, beta1=beta1, beta2=beta2, eps=eps, clip_norm=clip_norm, zeta_damp=zeta_damp, adam_mix=adam_mix, total_steps=total_steps)
        super(ZetaOptimizer, self).__init__(params, defaults)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        max_terms = 100
        for group in self.param_groups:
            lr = group['lr']
            s_min = group['s_min']
            s_max = group['s_max']
            beta1 = group['beta1']
            beta2 = group['beta2']
            eps = group['eps']
            clip_norm = group['clip_norm']
            zeta_damp = group['zeta_damp']
            adam_mix = group['adam_mix']
            total_steps = group['total_steps']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                grad = grad.clamp(min=-clip_norm, max=clip_norm)
                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['m'] = torch.zeros_like(p.data)
                    state['v'] = torch.zeros_like(p.data)
                    state['grad_norm_ema'] = torch.zeros_like(p.data)
                    state['prev_grad'] = torch.zeros_like(p.data)
                    state['loss_ema'] = 1.0

                state['step'] += 1
                m, v = state['m'], state['v']
                grad_norm_ema = state['grad_norm_ema']
                prev_grad = state['prev_grad']
                loss_ema = state['loss_ema']
                step = state['step']

                # Dynamic s scheduling
                s = s_min + (s_max - s_min) * min(step / total_steps, 1.0)
                zeta_s = sum(1.0 / (n ** s) for n in range(1, max_terms + 1))

                # Per-parameter gradient norm
                grad_norm_raw = torch.sqrt((grad ** 2).sum(dim=-1, keepdim=True))
                grad_norm_ema.mul_(0.9).add_(grad_norm_raw, alpha=0.1)

                # Adaptive damping based on loss and gradient norm
                if closure is not None and loss is not None:
                    state['loss_ema'] = 0.9 * loss_ema + 0.1 * loss.item()
                adaptive_damp = zeta_damp * (1.0 + grad_norm_ema / (1.0 + grad_norm_ema)) * (1.0 / max(0.1, loss_ema))
                zeta_factor = adaptive_damp / zeta_s * (1.0 / (1.0 + step * 0.005))

                # Gradient consistency (cosine similarity)
                grad_flat = grad.view(-1)
                prev_grad_flat = prev_grad.view(-1)
                cos_sim = torch.dot(grad_flat, prev_grad_flat) / (grad_flat.norm() * prev_grad_flat.norm() + eps)
                momentum_boost = 1.0 + zeta_factor * 0.2 * max(0.0, cos_sim.item())
                state['prev_grad'].copy_(grad)

                m.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                v.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                m_hat = m / (1.0 - beta1 ** step)
                v_hat = v / (1.0 - beta2 ** step)

                # Adam update
                adam_update = m_hat / (torch.sqrt(v_hat) + eps)

                # Zeta update
                grad_norm = torch.sqrt(v_hat).add_(eps)
                norm_factor = torch.clamp(grad_norm, max=0.5)
                zeta_scaled_lr = lr * zeta_factor / (grad_norm ** (s - 1.0) * norm_factor)
                zeta_update = zeta_scaled_lr * m_hat * momentum_boost

                # Hybrid update
                final_update = adam_mix * adam_update + (1.0 - adam_mix) * zeta_update

                # Step-wise LR decay on plateau
                lr_mult = 1.0 if loss_ema > 0.1 else 0.5
                current_lr = lr * lr_mult * (0.5 * (1.0 + math.cos(math.pi * step / (total_steps * 1.2))))

                p.data.add_(-current_lr * final_update)

        return loss

# Load CIFAR-10 dataset
def get_cifar10_data(batch_size=64):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])
    train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

# Training function with accuracy
def train(model, optimizer, data_loader, criterion, epochs=5, device='cuda' if torch.cuda.is_available() else 'cpu'):
    model.to(device)
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        correct = 0
        total = 0
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step(closure=lambda: criterion(model(inputs), labels))
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        avg_loss = total_loss / len(data_loader)
        accuracy = 100 * correct / total
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_loss:.4f}, Train Accuracy: {accuracy:.2f}%")
    return model

# Test function
def test(model, data_loader, criterion, device='cuda' if torch.cuda.is_available() else 'cpu'):
    model.to(device)
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    avg_loss = total_loss / len(data_loader)
    accuracy = 100 * correct / total
    print(f"Test Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.2f}%")

def main():
    # Hyperparameters
    batch_size = 64
    lr_adam = 0.001
    lr_sgd = 0.01
    lr_zeta = 0.0015
    epochs = 5
    s_min = 1.001
    s_max = 2.0
    zeta_damp = 0.3
    adam_mix = 0.5
    total_steps = 5000  # Approx steps for 5 epochs (50000 samples / batch_size)

    # Data
    train_loader, test_loader = get_cifar10_data(batch_size=batch_size)

    # Models
    model_adam = SimpleNN()
    model_sgd = SimpleNN()
    model_zeta = SimpleNN()

    # Loss
    criterion = nn.CrossEntropyLoss()

    # Optimizers
    optimizer_adam = optim.Adam(model_adam.parameters(), lr=lr_adam)
    optimizer_sgd = optim.SGD(model_sgd.parameters(), lr=lr_sgd, momentum=0.9)
    optimizer_zeta = ZetaOptimizer(model_zeta.parameters(), lr=lr_zeta, s_min=s_min, s_max=s_max, beta1=0.85, beta2=0.999, eps=1e-8, clip_norm=1.0, zeta_damp=zeta_damp, adam_mix=adam_mix, total_steps=total_steps)

    # Train with Adam
    print("Training with Adam optimizer:")
    model_adam = train(model_adam, optimizer_adam, train_loader, criterion, epochs)
    print("Evaluating Adam on test set:")
    test(model_adam, test_loader, criterion)

    # Train with SGD
    print("\nTraining with SGD optimizer (with momentum):")
    model_sgd = train(model_sgd, optimizer_sgd, train_loader, criterion, epochs)
    print("Evaluating SGD on test set:")
    test(model_sgd, test_loader, criterion)

    # Train with ZetaOptimizer
    print("\nTraining with Final Enhanced Zeta-inspired optimizer:")
    model_zeta = train(model_zeta, optimizer_zeta, train_loader, criterion, epochs)
    print("Evaluating ZetaOptimizer on test set:")
    test(model_zeta, test_loader, criterion)

if __name__ == "__main__":
    main()

Training with Adam optimizer:
Epoch 1/5, Train Loss: 1.7064, Train Accuracy: 41.01%
Epoch 2/5, Train Loss: 1.4876, Train Accuracy: 48.25%
Epoch 3/5, Train Loss: 1.4124, Train Accuracy: 50.54%
Epoch 4/5, Train Loss: 1.3599, Train Accuracy: 52.79%
Epoch 5/5, Train Loss: 1.3094, Train Accuracy: 54.31%
Evaluating Adam on test set:
Test Loss: 1.4453, Test Accuracy: 50.72%

Training with SGD optimizer (with momentum):
Epoch 1/5, Train Loss: 1.7495, Train Accuracy: 39.24%
Epoch 2/5, Train Loss: 1.6289, Train Accuracy: 44.28%
Epoch 3/5, Train Loss: 1.5783, Train Accuracy: 46.55%
Epoch 4/5, Train Loss: 1.5474, Train Accuracy: 47.82%
Epoch 5/5, Train Loss: 1.4977, Train Accuracy: 49.82%
Evaluating SGD on test set:
Test Loss: 1.6464, Test Accuracy: 45.32%

Training with Final Enhanced Zeta-inspired optimizer:
Epoch 1/5, Train Loss: 1.6772, Train Accuracy: 41.58%
Epoch 2/5, Train Loss: 1.4535, Train Accuracy: 49.20%
Epoch 3/5, Train Loss: 1.3517, Train Accuracy: 53.09%
Epoch 4/5, Train Loss: 1.259

## **CIFAR-10**

In [27]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import math

# Neural Network for CIFAR-100
class SimpleNN(nn.Module):
    def __init__(self, input_channels=3, hidden_dim=256, output_dim=100):
        super(SimpleNN, self).__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(input_channels * 32 * 32, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.net(x)

# Enhanced Zeta-inspired optimizer
class ZetaOptimizer(optim.Optimizer):
    def __init__(self, params, lr=1e-3, s_min=1.001, s_max=2.0, beta1=0.85, beta2=0.999, eps=1e-8, clip_norm=1.0, zeta_damp=0.3, adam_mix=0.7, total_steps=5000):
        if lr <= 0.0:
            raise ValueError("Learning rate must be positive")
        if s_min <= 1.0 or s_max <= 1.0:
            raise ValueError("s_min and s_max must be > 1 for convergence")
        if not 0.0 <= beta1 < 1.0:
            raise ValueError("Beta1 must be in [0, 1)")
        if not 0.0 <= beta2 < 1.0:
            raise ValueError("Beta2 must be in [0, 1)")
        if eps <= 0.0:
            raise ValueError("Epsilon must be positive")
        if clip_norm <= 0.0:
            raise ValueError("Clip norm must be positive")
        if not 0.0 <= adam_mix <= 1.0:
            raise ValueError("adam_mix must be in [0, 1]")

        defaults = dict(lr=lr, s_min=s_min, s_max=s_max, beta1=beta1, beta2=beta2, eps=eps, clip_norm=clip_norm, zeta_damp=zeta_damp, adam_mix=adam_mix, total_steps=total_steps)
        super(ZetaOptimizer, self).__init__(params, defaults)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        max_terms = 100
        for group in self.param_groups:
            lr = group['lr']
            s_min = group['s_min']
            s_max = group['s_max']
            beta1 = group['beta1']
            beta2 = group['beta2']
            eps = group['eps']
            clip_norm = group['clip_norm']
            zeta_damp = group['zeta_damp']
            adam_mix = group['adam_mix']
            total_steps = group['total_steps']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                grad = grad.clamp(min=-clip_norm, max=clip_norm)
                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['m'] = torch.zeros_like(p.data)
                    state['v'] = torch.zeros_like(p.data)
                    state['grad_norm_ema'] = torch.zeros_like(p.data)
                    state['prev_grad'] = torch.zeros_like(p.data)
                    state['loss_ema'] = 1.0

                state['step'] += 1
                m, v = state['m'], state['v']
                grad_norm_ema = state['grad_norm_ema']
                prev_grad = state['prev_grad']
                loss_ema = state['loss_ema']
                step = state['step']

                s = s_min + (s_max - s_min) * min(step / total_steps, 1.0)
                zeta_s = sum(1.0 / (n ** s) for n in range(1, max_terms + 1))

                grad_norm_raw = torch.sqrt((grad ** 2).sum(dim=-1, keepdim=True))
                grad_norm_ema.mul_(0.9).add_(grad_norm_raw, alpha=0.1)

                if closure is not None and loss is not None:
                    state['loss_ema'] = 0.9 * loss_ema + 0.1 * loss.item()
                adaptive_damp = zeta_damp * (1.0 + grad_norm_ema / (1.0 + grad_norm_ema)) * (1.0 / max(0.1, loss_ema))
                zeta_factor = adaptive_damp / zeta_s * (1.0 / (1.0 + step * 0.005))

                grad_flat = grad.view(-1)
                prev_grad_flat = prev_grad.view(-1)
                cos_sim = torch.dot(grad_flat, prev_grad_flat) / (grad_flat.norm() * prev_grad_flat.norm() + eps)
                momentum_boost = 1.0 + zeta_factor * 0.2 * max(0.0, cos_sim.item())
                state['prev_grad'].copy_(grad)

                m.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                v.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                m_hat = m / (1.0 - beta1 ** step)
                v_hat = v / (1.0 - beta2 ** step)

                adam_update = m_hat / (torch.sqrt(v_hat) + eps)

                grad_norm = torch.sqrt(v_hat).add_(eps)
                norm_factor = torch.clamp(grad_norm, max=0.5)
                zeta_scaled_lr = lr * zeta_factor / (grad_norm ** (s - 1.0) * norm_factor)
                zeta_update = zeta_scaled_lr * m_hat * momentum_boost

                final_update = adam_mix * adam_update + (1.0 - adam_mix) * zeta_update

                lr_mult = 1.0 if loss_ema > 0.1 else 0.5
                current_lr = lr * lr_mult * (0.5 * (1.0 + math.cos(math.pi * step / (total_steps * 1.2))))

                p.data.add_(-current_lr * final_update)

        return loss

# Load CIFAR-100 dataset
def get_cifar100_data(batch_size=64):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
    ])
    train_dataset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
    test_dataset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

# Training function with accuracy
def train(model, optimizer, data_loader, criterion, epochs=5, device='cuda' if torch.cuda.is_available() else 'cpu'):
    model.to(device)
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        correct = 0
        total = 0
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step(closure=lambda: criterion(model(inputs), labels))
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        avg_loss = total_loss / len(data_loader)
        accuracy = 100 * correct / total
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_loss:.4f}, Train Accuracy: {accuracy:.2f}%")
    return model

# Test function
def test(model, data_loader, criterion, device='cuda' if torch.cuda.is_available() else 'cpu'):
    model.to(device)
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    avg_loss = total_loss / len(data_loader)
    accuracy = 100 * correct / total
    print(f"Test Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.2f}%")

def main():
    # Hyperparameters
    batch_size = 64
    lr_adam = 0.001
    lr_sgd = 0.01
    lr_zeta = 0.001
    epochs = 20
    s_min = 1.001
    s_max = 2.0
    zeta_damp = 0.3
    adam_mix = 0.7
    total_steps = 5000  # Approx steps for 5 epochs (50000 samples / batch_size)

    # Data
    train_loader, test_loader = get_cifar100_data(batch_size=batch_size)

    # Models
    model_adam = SimpleNN()
    model_sgd = SimpleNN()
    model_zeta = SimpleNN()

    # Loss
    criterion = nn.CrossEntropyLoss()

    # Optimizers
    optimizer_adam = optim.Adam(model_adam.parameters(), lr=lr_adam)
    optimizer_sgd = optim.SGD(model_sgd.parameters(), lr=lr_sgd, momentum=0.9)
    optimizer_zeta = ZetaOptimizer(model_zeta.parameters(), lr=lr_zeta, s_min=s_min, s_max=s_max, beta1=0.85, beta2=0.999, eps=1e-8, clip_norm=1.0, zeta_damp=zeta_damp, adam_mix=adam_mix, total_steps=total_steps)

    # Train with Adam
    print("Training with Adam optimizer:")
    model_adam = train(model_adam, optimizer_adam, train_loader, criterion, epochs)
    print("Evaluating Adam on test set:")
    test(model_adam, test_loader, criterion)

    # Train with SGD
    print("\nTraining with SGD optimizer (with momentum):")
    model_sgd = train(model_sgd, optimizer_sgd, train_loader, criterion, epochs)
    print("Evaluating SGD on test set:")
    test(model_sgd, test_loader, criterion)

    # Train with ZetaOptimizer
    print("\nTraining with Final Enhanced Zeta-inspired optimizer:")
    model_zeta = train(model_zeta, optimizer_zeta, train_loader, criterion, epochs)
    print("Evaluating ZetaOptimizer on test set:")
    test(model_zeta, test_loader, criterion)

if __name__ == "__main__":
    main()

Training with Adam optimizer:
Epoch 1/20, Train Loss: 3.7533, Train Accuracy: 13.49%
Epoch 2/20, Train Loss: 3.3757, Train Accuracy: 19.42%
Epoch 3/20, Train Loss: 3.1953, Train Accuracy: 22.67%
Epoch 4/20, Train Loss: 3.0651, Train Accuracy: 24.91%
Epoch 5/20, Train Loss: 2.9407, Train Accuracy: 27.60%
Epoch 6/20, Train Loss: 2.8321, Train Accuracy: 29.29%
Epoch 7/20, Train Loss: 2.7379, Train Accuracy: 31.10%
Epoch 8/20, Train Loss: 2.6352, Train Accuracy: 33.03%
Epoch 9/20, Train Loss: 2.5409, Train Accuracy: 35.02%
Epoch 10/20, Train Loss: 2.4469, Train Accuracy: 36.48%
Epoch 11/20, Train Loss: 2.3633, Train Accuracy: 38.74%
Epoch 12/20, Train Loss: 2.2834, Train Accuracy: 40.04%
Epoch 13/20, Train Loss: 2.1946, Train Accuracy: 42.20%
Epoch 14/20, Train Loss: 2.1243, Train Accuracy: 43.64%
Epoch 15/20, Train Loss: 2.0560, Train Accuracy: 45.03%
Epoch 16/20, Train Loss: 1.9885, Train Accuracy: 46.66%
Epoch 17/20, Train Loss: 1.9149, Train Accuracy: 48.38%
Epoch 18/20, Train Loss: 1.