In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Pick the compute device once for the whole script: CUDA when PyTorch
# reports an available GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

class SimpleMLP(nn.Module):
    """Small convolutional classifier for 28x28 single-channel images.

    Despite the name, the network is two stride-2 convolutions followed by
    two fully connected layers that produce 10 output scores.

    Args:
        spectral: When true, both convolutions and the first linear layer are
            wrapped in ``nn.utils.spectral_norm``. The final 64 -> 10 layer is
            left unwrapped in either configuration.
    """

    def __init__(self, spectral=False):
        super().__init__()
        self.spectral = spectral

        # Choose the per-layer wrapper once so the layer stack is written a
        # single time: spectral normalization, or a pass-through.
        if spectral:
            wrap = nn.utils.spectral_norm
        else:
            def wrap(layer):
                return layer

        layers = [
            # 1x28x28 -> 8x14x14
            wrap(nn.Conv2d(1, 8, 3, stride=2, padding=1)),
            nn.ReLU(),
            # 8x14x14 -> 16x7x7
            wrap(nn.Conv2d(8, 16, 3, stride=2, padding=1)),
            nn.ReLU(),
            nn.Flatten(),
            wrap(nn.Linear(16 * 7 * 7, 64)),
            nn.ReLU(),
            # Output layer is never spectrally normalized.
            nn.Linear(64, 10),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Reshape ``x`` to ``(-1, 1, 28, 28)`` and run it through the network."""
        images = x.view(-1, 1, 28, 28)
        return self.net(images)


# Every MNIST sample is passed through transforms.ToTensor(); no other
# preprocessing is configured.
transform = transforms.ToTensor()

# train=True split stored under ./data (download=True), served in shuffled
# mini-batches of 64.
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# train=False split read from the same directory, served in order in
# batches of 1000.
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=1000)

def train(model, optimizer, criterion, loader, record_grads=False, epochs=10):
    """Train ``model`` on ``loader`` and collect loss / gradient statistics.

    Args:
        model: Module to optimize; it is switched to training mode.
        optimizer: Optimizer that updates ``model``'s parameters.
        criterion: Callable ``criterion(output, target)`` returning a scalar
            loss tensor.
        loader: Iterable of ``(data, target)`` batches; it is iterated once
            per epoch and does not need to support ``len()``.
        record_grads: When true, record the global L2 norm of all parameter
            gradients for every batch (measured before the optimizer step).
        epochs: Number of passes over ``loader``. Defaults to 10, the value
            that was previously hard-coded.

    Returns:
        A ``(loss_list, grad_list)`` tuple: the mean batch loss of each epoch,
        and the per-batch gradient norms (empty unless ``record_grads``).

    Raises:
        ValueError: If ``loader`` yields no batches during an epoch.
    """
    model.train()

    # Move batches to wherever the model's parameters live instead of relying
    # on a module-level global, so data and weights can never disagree.
    first_param = next(model.parameters(), None)
    batch_device = None if first_param is None else first_param.device

    loss_list, grad_list = [], []
    for _ in range(epochs):
        total_loss = 0.0
        batches_seen = 0
        for data, target in loader:
            if batch_device is not None:
                data, target = data.to(batch_device), target.to(batch_device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()

            if record_grads:
                # Global L2 norm: square root of the summed squared
                # per-parameter gradient norms.
                squared_norms = (
                    p.grad.norm().item() ** 2
                    for p in model.parameters()
                    if p.grad is not None
                )
                grad_list.append(sum(squared_norms) ** 0.5)

            optimizer.step()
            total_loss += loss.item()
            batches_seen += 1

        if batches_seen == 0:
            # Averaging over zero batches is meaningless; fail with a clear
            # message rather than an opaque ZeroDivisionError.
            raise ValueError("loader produced no batches")
        loss_list.append(total_loss / batches_seen)
    return loss_list, grad_list

def test(model, loader):
    """Return the classification accuracy of ``model`` over ``loader``.

    The model is put in evaluation mode and gradients are disabled. Each
    batch is moved to the module-level ``device``, the predicted class is the
    arg-max along dimension 1 of the model output, and the number of correct
    predictions is divided by ``len(loader.dataset)``.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        loader: Iterable of ``(data, target)`` batches exposing a ``dataset``
            attribute that supports ``len()``.

    Returns:
        Fraction of correctly classified samples as a float.
    """
    model.eval()
    num_correct = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            scores = model(inputs)
            predicted = scores.argmax(dim=1)
            num_correct += predicted.eq(labels).sum().item()
    dataset_size = len(loader.dataset)
    return num_correct / dataset_size

def _fit_and_score(spectral, lr):
    """Build, train and evaluate one SimpleMLP configuration.

    Returns ``(model, optimizer, losses, grad_norms, test_accuracy)`` for a
    model trained with Adam at learning rate ``lr`` on ``train_loader`` and
    scored on ``test_loader``.
    """
    model = SimpleMLP(spectral=spectral).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    losses, grad_norms = train(
        model, optimizer, nn.CrossEntropyLoss(), train_loader, record_grads=True
    )
    accuracy = test(model, test_loader)
    return model, optimizer, losses, grad_norms, accuracy


# Runs execute in this order: plain baseline, then the two spectral-norm
# variants. NOTE: despite its name, lipschitz_lr01 is trained with lr=0.005.
baseline, opt_b, loss_b, grad_b, acc_b = _fit_and_score(spectral=False, lr=0.001)
lipschitz_lr01, opt_l1, loss_l1, grad_l1, acc_l1 = _fit_and_score(spectral=True, lr=0.005)
lipschitz_lr001, opt_l2, loss_l2, grad_l2, acc_l2 = _fit_and_score(spectral=True, lr=0.001)

# Report test-set accuracy of each run.
print(f"Baseline Accuracy: {acc_b:.4f}")
print(f"Spectral Norm (lr=0.005) Accuracy: {acc_l1:.4f}")
print(f"Spectral Norm (lr=0.001) Accuracy: {acc_l2:.4f}")

# One figure, two panels: per-epoch training loss on the left and per-batch
# gradient norm on the right, each with one curve per training run.
run_labels = (
    'Baseline (lr=0.001)',
    'Spectral Norm (lr=0.005)',
    'Spectral Norm (lr=0.001)',
)
fig, (loss_ax, grad_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: training losses
for curve, run_label in zip((loss_b, loss_l1, loss_l2), run_labels):
    loss_ax.plot(curve, label=run_label)
loss_ax.set_title('Training Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend()

# Right panel: gradient norms
for curve, run_label in zip((grad_b, grad_l1, grad_l2), run_labels):
    grad_ax.plot(curve, label=run_label)
grad_ax.set_title('Gradient Norm per Batch')
grad_ax.set_xlabel('Batch')
grad_ax.set_ylabel('Gradient Norm')
grad_ax.legend()

fig.tight_layout()
plt.show()

# Robustness testing
def add_noise(inputs, std=0.2):
    """Return a copy of ``inputs`` perturbed by Gaussian noise.

    The noise is drawn with ``torch.randn_like(inputs)`` and multiplied by
    ``std`` before being added; ``inputs`` itself is not modified.

    Args:
        inputs: Tensor to perturb.
        std: Scale factor applied to the standard-normal noise.

    Returns:
        A new tensor equal to ``inputs + std * noise``.
    """
    noise = torch.randn_like(inputs)
    perturbed = inputs + std * noise
    return perturbed

# Perturb every test sample once up front so that all three models are
# scored on exactly the same noisy inputs.
noisy_test_data = [
    (add_noise(image), target) for image, target in test_loader.dataset
]
noisy_test_loader = DataLoader(noisy_test_data, batch_size=1000)

# Evaluation order matches the training order: baseline, spectral norm with
# lr=0.005, spectral norm with lr=0.001.
trained_models = (baseline, lipschitz_lr01, lipschitz_lr001)

# Accuracy on the training split
train_acc_b, train_acc_l1, train_acc_l2 = [
    test(net, train_loader) for net in trained_models
]

# Accuracy on the noise-perturbed test split
robust_acc_b, robust_acc_l1, robust_acc_l2 = [
    test(net, noisy_test_loader) for net in trained_models
]

print("\nTraining Accuracies:")
print(f"Baseline: {train_acc_b:.4f}")
print(f"Spectral Norm (lr=0.005): {train_acc_l1:.4f}")
print(f"Spectral Norm (lr=0.001): {train_acc_l2:.4f}")

print("\nRobust Accuracies (noisy data):")
print(f"Baseline: {robust_acc_b:.4f}")
print(f"Spectral Norm (lr=0.005): {robust_acc_l1:.4f}")
print(f"Spectral Norm (lr=0.001): {robust_acc_l2:.4f}")


Using device: cpu
