In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

In [2]:
# Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)

train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 139089672.39it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 22399351.67it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz



100%|██████████| 1648877/1648877 [00:00<00:00, 123262541.16it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3303368.96it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [5]:
import math
import torch
from torch.optim import Optimizer

class AdEMAMix(Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999, 0.9999), alpha=5.0, eps=1e-8, weight_decay=0.0, T_alpha=0, T_beta3=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, alpha=alpha, T_alpha=T_alpha, T_beta3=T_beta3)
        super(AdEMAMix, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['m1'] = torch.zeros_like(p.data)  # Fast EMA
                    state['m2'] = torch.zeros_like(p.data)  # Slow EMA
                    state['v'] = torch.zeros_like(p.data)   # Second moment (like ADAM)

                m1, m2, v = state['m1'], state['m2'], state['v']
                beta1, beta2, beta3_final = group['betas']
                eps, alpha_final = group['eps'], group['alpha']
                lr, weight_decay = group['lr'], group['weight_decay']
                T_alpha, T_beta3 = group['T_alpha'], group['T_beta3']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Schedulers for alpha and beta3
                alpha = alpha_scheduler(state['step'], alpha_final, T_alpha)
                beta3 = beta3_scheduler(state['step'], beta1, beta3_final, T_beta3)

                # Update fast EMA
                m1.mul_(beta1).add_(1 - beta1, grad)
                
                # Update slow EMA
                m2.mul_(beta3).add_(1 - beta3, grad)

                # Update second moment estimate (similar to ADAM)
                v.mul_(beta2).addcmul_(1 - beta2, grad, grad)

                # Compute bias-corrected first moment estimate
                m1_hat = m1 / bias_correction1

                # Compute bias-corrected second moment estimate
                v_hat = v / bias_correction2

                # Parameter update step
                denom = (v_hat.sqrt() + eps)
                update = (m1_hat + alpha * m2) / denom

                if weight_decay != 0:
                    update.add_(p.data, alpha=weight_decay)

                p.data.add_(-lr * update)

        return loss

# Schedulers for alpha and beta3 based on training steps
def alpha_scheduler(step, alpha_final, T_alpha):
    if T_alpha == 0:
        return alpha_final
    return min(step / T_alpha, 1.0) * alpha_final

def beta3_scheduler(step, beta_start, beta3_final, T_beta3):
    if T_beta3 == 0:
        return beta3_final
    return beta_start + (beta3_final - beta_start) * min(step / T_beta3, 1.0)


In [6]:
# Define a CNN model
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(64 * 12 * 12, 128)
        self.fc2 = nn.Linear(128, 10)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2)
        self.dropout = nn.Dropout(0.25)
    
    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.maxpool(self.relu(self.conv2(x)))
        x = x.view(-1, 64 * 12 * 12)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# Instantiate the model
model = CNNModel()

# Loss function
criterion = nn.CrossEntropyLoss()

# Using AdEMAMix optimizer
optimizer_ademamix = AdEMAMix(model.parameters(), lr=0.001, betas=(0.9, 0.999, 0.9999), alpha=5.0, eps=1e-8, T_alpha=0, T_beta3=0)

# Training function
def train_model(model, optimizer, criterion, train_loader, epochs=5):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(train_loader):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}')

# Testing function
def test_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')

# Train and test using AdEMAMix
print("Training with AdEMAMix Optimizer")
train_model(model, optimizer_ademamix, criterion, train_loader, epochs=5)
test_model(model, test_loader)

# Now, switch to ADAM for comparison
model_adam = CNNModel()
optimizer_adam = optim.Adam(model_adam.parameters(), lr=0.001)

print("\nTraining with Adam Optimizer")
train_model(model_adam, optimizer_adam, criterion, train_loader, epochs=5)
test_model(model_adam, test_loader)

Training with AdEMAMix Optimizer


	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at ../torch/csrc/utils/python_arg_parser.cpp:1630.)
  m1.mul_(beta1).add_(1 - beta1, grad)


Epoch [1/5], Loss: 0.1489
Epoch [2/5], Loss: 0.0455
Epoch [3/5], Loss: 0.0272
Epoch [4/5], Loss: 0.0207
Epoch [5/5], Loss: 0.0155
Test Accuracy: 98.99%

Training with Adam Optimizer
Epoch [1/5], Loss: 0.1634
Epoch [2/5], Loss: 0.0552
Epoch [3/5], Loss: 0.0369
Epoch [4/5], Loss: 0.0276
Epoch [5/5], Loss: 0.0227
Test Accuracy: 98.94%
