In [None]:
# Re-import edited project modules (e.g. `utils`) automatically before each cell runs.
# `%reload_ext` (rather than `%load_ext`) keeps this cell safe to execute repeatedly.
%reload_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.utils.data as data_utils
from utils import epoch, epoch_robust_bound, epoch_calculate_robust_err, Flatten, generate_kappa_schedule_CIFAR, generate_epsilon_schedule_CIFAR

In [None]:
# Run on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG so weight init and DataLoader shuffling are repeatable.
# The trailing `;` stops the notebook from echoing the returned Generator object.
SEED = 0
torch.manual_seed(SEED);

In [None]:
# Mini-batch size shared by the training and test loaders, and the directory
# the CIFAR-10 files are downloaded to / read from.
BATCH_SIZE, dataset_path = 50, './cifar10'

In [None]:
# Untransformed training split; in this notebook it is only read to compute the
# normalisation statistics. Downloads the archive into dataset_path if missing.
trainset = datasets.CIFAR10(dataset_path, download=True, train=True)

In [None]:
# Per-channel normalisation statistics of the raw training pixels, divided by
# 255 so they are on the same [0, 1] scale that ToTensor() produces.
# torchvision >= 0.3 exposes the pixel array as `CIFAR10.data`; older releases
# called it `train_data` (which no longer exists), so support both.
_train_pixels = trainset.data if hasattr(trainset, 'data') else trainset.train_data
# Reducing over axes (0, 1, 2) leaves one value per channel (channel-last layout).
train_mean = _train_pixels.mean(axis=(0, 1, 2)) / 255  # [0.49139968  0.48215841  0.44653091]
train_std = _train_pixels.std(axis=(0, 1, 2)) / 255  # [0.24703223  0.24348513  0.26158784]

# Training-time augmentation: random 32x32 crop of the image padded by 4 pixels,
# random horizontal flip, then normalisation with the statistics above.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(train_mean, train_std),
])

# Evaluation uses the same normalisation but no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(train_mean, train_std),
])

# Pinned host memory only helps host-to-GPU copies; without CUDA it brings no
# benefit and (depending on the PyTorch version) warns or fails, so only
# request it when training on a CUDA device.
kwargs = {'num_workers': 1, 'pin_memory': device.type == 'cuda'}

# Augmented training split, reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(root=dataset_path, train=True, download=True,
                     transform=transform_train),
    batch_size=BATCH_SIZE, shuffle=True, **kwargs)

# Test split, iterated in a fixed order.
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(root=dataset_path, train=False, download=True,
                     transform=transform_test),
    batch_size=BATCH_SIZE, shuffle=False, **kwargs)

# Model

In [None]:
# Medium-size CNN: four unpadded convolutions followed by three fully connected
# layers. For a 32x32 input the spatial size goes 32 -> 30 -> 14 -> 12 -> 5
# (out = (in - kernel) // stride + 1), so the flattened feature vector has
# 64 * 5 * 5 entries.
model_cnn_medium = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=0),
    nn.ReLU(),
    nn.Conv2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=0),
    nn.ReLU(),
    nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=0),
    nn.ReLU(),
    nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=0),
    nn.ReLU(),
    Flatten(),  # from utils; presumably (N, 64, 5, 5) -> (N, 1600) — TODO confirm
    nn.Linear(in_features=64 * 5 * 5, out_features=512),
    nn.ReLU(),
    nn.Linear(in_features=512, out_features=512),
    nn.ReLU(),
    nn.Linear(in_features=512, out_features=10),  # one output per CIFAR-10 class
).to(device)

# Training

In [None]:
opt = optim.Adam(model_cnn_medium.parameters(), lr=1e-3)

EPSILON = 0.1        # epsilon passed to epoch_calculate_robust_err for the test-set robust error
EPSILON_TRAIN = 0.2  # target epsilon handed to the training epsilon schedule
epsilon_schedule = generate_epsilon_schedule_CIFAR(EPSILON_TRAIN)
kappa_schedule = generate_kappa_schedule_CIFAR()
batch_counter = 0  # global index of the next training batch across all epochs

N_EPOCHS = 350
# Learning rate to switch to once the epoch with the given (0-based) index has
# finished: e.g. 1e-4 is first used in epoch index 201, after 201 epochs at 1e-3.
LR_DECAY = {200: 1e-4, 250: 1e-5, 300: 1e-6}

print("Epoch   ", "Combined Loss", "Test Err", "Test Robust Err", sep="\t")

for t in range(N_EPOCHS):
    _, combined_loss = epoch_robust_bound(train_loader, model_cnn_medium, epsilon_schedule, device, kappa_schedule, batch_counter, opt)

    # check loss and accuracy on test set
    test_err, _ = epoch(test_loader, model_cnn_medium, device)
    robust_err = epoch_calculate_robust_err(test_loader, model_cnn_medium, EPSILON, device)

    # Advance by the number of batches just trained on (1000 for the 50 000
    # CIFAR-10 training images at BATCH_SIZE = 50) instead of a hard-coded 1000,
    # so the counter stays correct if BATCH_SIZE changes.
    # NOTE(review): assumes epoch_robust_bound consumes one schedule step per batch — confirm in utils.
    batch_counter += len(train_loader)

    if t in LR_DECAY:
        for param_group in opt.param_groups:
            param_group["lr"] = LR_DECAY[t]

    # Epoch index as an integer, metrics with six decimals.
    print("{:d}".format(t), *("{:.6f}".format(v) for v in (combined_loss, test_err, robust_err)), sep="\t")