In [None]:
#%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.utils.data as data_utils
from utils import epoch, epoch_robust_bound, epoch_calculate_robust_err, Flatten, generate_kappa_schedule_MNIST, generate_epsilon_schedule_MNIST

In [None]:
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG so weight initialisation and DataLoader shuffling are
# reproducible across runs. (Bit-exact GPU reproducibility may additionally need
# deterministic cuDNN settings.) The trailing ';' suppresses the Generator repr
# that torch.manual_seed returns.
SEED = 0
torch.manual_seed(SEED);

In [None]:
# MNIST training split, downloaded into the working directory if missing; ToTensor turns each image into a float tensor.
mnist_train = datasets.MNIST(root="./", train=True, transform=transforms.ToTensor(), download=True)

In [None]:
# MNIST test split, same location and preprocessing as the training split.
mnist_test = datasets.MNIST(root="./", train=False, transform=transforms.ToTensor(), download=True)

In [None]:
# Mini-batch loaders: the training split is reshuffled every epoch,
# the test split is always iterated in its fixed order.
BATCH_SIZE = 100
train_loader = DataLoader(dataset=mnist_train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=mnist_test, batch_size=BATCH_SIZE, shuffle=False)

# Models

In [None]:
# Medium CNN: four conv layers followed by three fully connected layers.
# Assuming 28x28 single-channel MNIST inputs, the spatial size shrinks
#   28 -> 26 (3x3) -> 12 (4x4, stride 2) -> 10 (3x3) -> 4 (4x4, stride 2),
# which is where the 64*4*4 input width of the first linear layer comes from.
medium_layers = [
    nn.Conv2d(1, 32, 3, padding=0, stride=1),
    nn.ReLU(),
    nn.Conv2d(32, 32, 4, padding=0, stride=2),
    nn.ReLU(),
    nn.Conv2d(32, 64, 3, padding=0, stride=1),
    nn.ReLU(),
    nn.Conv2d(64, 64, 4, padding=0, stride=2),
    nn.ReLU(),
    Flatten(),
    nn.Linear(64 * 4 * 4, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 10),
]
model_cnn_medium = nn.Sequential(*medium_layers).to(device)

In [None]:
# Small CNN: two conv layers followed by two fully connected layers.
# Assuming 28x28 single-channel MNIST inputs: 28 -> 13 (4x4, stride 2) -> 10 (4x4),
# giving the 32*10*10 input width of the first linear layer.
# NOTE(review): the training cell in this section does not use this model, but
# constructing it still draws from the global torch RNG, so removing this cell would
# change the random stream later cells see (e.g. train_loader shuffling).
small_layers = [
    nn.Conv2d(1, 16, 4, padding=0, stride=2),
    nn.ReLU(),
    nn.Conv2d(16, 32, 4, padding=0, stride=1),
    nn.ReLU(),
    Flatten(),
    nn.Linear(32 * 10 * 10, 100),
    nn.ReLU(),
    nn.Linear(100, 10),
]
model_cnn_small = nn.Sequential(*small_layers).to(device)

# Training

In [None]:
# Robust training of model_cnn_medium using the epsilon/kappa schedules from utils
# (presumably an interval-bound style objective -- see epoch_robust_bound).
# After every epoch the model is evaluated on the test set: plain error and
# robust error at radius EPSILON.
NUM_EPOCHS = 100
# Zero-based epoch index t -> new learning rate, applied after epoch t finishes
# (i.e. drop to 1e-4 after 25 epochs and to 1e-5 after 41 epochs).
LR_SCHEDULE = {24: 1e-4, 40: 1e-5}

opt = optim.Adam(model_cnn_medium.parameters(), lr=1e-3)

EPSILON = 0.1        # radius at which the test-set robust error is reported
EPSILON_TRAIN = 0.2  # radius used to build the training epsilon schedule
epsilon_schedule = generate_epsilon_schedule_MNIST(EPSILON_TRAIN)
kappa_schedule = generate_kappa_schedule_MNIST()
# Running count of training batches, passed to epoch_robust_bound
# (presumably to index into the epsilon/kappa schedules -- see utils).
batch_counter = 0

print("Epoch   ", "Combined Loss", "Test Err", "Test Robust Err", sep="\t")

for t in range(NUM_EPOCHS):
    _, combined_loss = epoch_robust_bound(train_loader, model_cnn_medium, epsilon_schedule, device, kappa_schedule, batch_counter, opt)

    # check loss and accuracy on test set
    test_err, _ = epoch(test_loader, model_cnn_medium, device)
    robust_err = epoch_calculate_robust_err(test_loader, model_cnn_medium, EPSILON, device)

    # Advance by the real number of batches per epoch instead of a hard-coded 600,
    # so the counter stays correct if the batch size or dataset size changes.
    batch_counter += len(train_loader)

    if t in LR_SCHEDULE:
        for param_group in opt.param_groups:
            param_group["lr"] = LR_SCHEDULE[t]

    # Epoch index as an integer; metrics with 6 decimal places.
    print(t, *("{:.6f}".format(v) for v in (combined_loss, test_err, robust_err)), sep="\t")