In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Hyperparameters
batch_size = 128
lr = 0.01
num_epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data preparation (MNIST)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Logistic Regression Model
class LogisticRegression(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)
    
    def forward(self, x):
        return self.linear(x)

# Function to compute full gradient
def full_gradient(model, criterion, data_loader):
    model.train()
    grad = torch.zeros_like(model.linear.weight)
    bias_grad = torch.zeros_like(model.linear.bias)
    
    for inputs, labels in data_loader:
        inputs, labels = inputs.view(inputs.size(0), -1).to(device), labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        grad += model.linear.weight.grad.data
        bias_grad += model.linear.bias.grad.data
    
    grad /= len(data_loader)
    bias_grad /= len(data_loader)
    return grad, bias_grad

# SVRG Optimization
def svrg(model, criterion, train_loader, num_epochs, lr):
    model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr)
    
    for epoch in range(num_epochs):
        print(f'Epoch [{epoch+1}/{num_epochs}]')
        # Compute the full gradient
        full_grad, full_bias_grad = full_gradient(model, criterion, train_loader)
        
        for inputs, labels in train_loader:
            inputs, labels = inputs.view(inputs.size(0), -1).to(device), labels.to(device)
            # Snapshot model weights
            snapshot_weights = model.linear.weight.data.clone()
            snapshot_bias = model.linear.bias.data.clone()
            
            # Compute snapshot gradient
            outputs_snapshot = model(inputs)
            loss_snapshot = criterion(outputs_snapshot, labels)
            loss_snapshot.backward()
            
            snapshot_grad = model.linear.weight.grad.data
            snapshot_bias_grad = model.linear.bias.grad.data
            
            # Update rule for SVRG
            with torch.no_grad():
                model.linear.weight.data -= lr * (snapshot_grad - full_grad + full_grad)
                model.linear.bias.data -= lr * (snapshot_bias_grad - full_bias_grad + full_bias_grad)
            
            # Zero gradients
            model.zero_grad()
    
    print("Optimization with SVRG completed.")

# Loss function and model initialization
model = LogisticRegression(28*28, 10)
criterion = nn.CrossEntropyLoss()

# Apply SVRG on MNIST
svrg(model, criterion, train_loader, num_epochs=num_epochs, lr=lr)


Epoch [1/10]
Epoch [2/10]
Epoch [3/10]
Epoch [4/10]
Epoch [5/10]
Epoch [6/10]
Epoch [7/10]
Epoch [8/10]
Epoch [9/10]
Epoch [10/10]
Optimization with SVRG completed.


In [2]:
def evaluate_accuracy(model, data_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.view(inputs.size(0), -1).to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

# After training with SVRG, evaluate accuracy
train_accuracy = evaluate_accuracy(model, train_loader)
test_accuracy = evaluate_accuracy(model, test_loader)
print(f"Train Accuracy: {train_accuracy:.2f}%")
print(f"Test Accuracy: {test_accuracy:.2f}%")

Train Accuracy: 91.60%
Test Accuracy: 91.60%
