In [1]:
#%matplotlib inline
# Automatically reload edited local modules (e.g. utils) before each cell executes.
%reload_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.utils.data as data_utils
from utils import epoch, epoch_robust_bound, epoch_calculate_robust_err, Flatten, generate_kappa_schedule_MNIST, generate_epsilon_schedule_MNIST
from utils import bound_propagation, new_epoch_robust_bound, epoch_robust_bound
import torch.nn.functional as F

In [3]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Seed PyTorch's global RNG for reproducible runs (the Generator is the cell's output).
torch.manual_seed(0)

<torch._C.Generator at 0xed682ed950>

In [4]:
# Mini-batch size shared by the train and test loaders.
BATCH_SIZE = 50
# Directory where torchvision stores / downloads CIFAR-10.
dataset_path = "./cifar10"

In [5]:
# Untransformed CIFAR-10 training split; below it is only used to compute normalisation statistics.
trainset = datasets.CIFAR10(dataset_path, train=True, download=True)

Files already downloaded and verified


In [6]:
# Per-channel mean/std of the raw training images, divided by 255 so they match the
# [0, 1] range produced by ToTensor(). Newer torchvision releases store the CIFAR10
# images in `.data` rather than `.train_data`, so accept either attribute.
raw_train_images = getattr(trainset, "data", None)
if raw_train_images is None:
    raw_train_images = trainset.train_data
train_mean = raw_train_images.mean(axis=(0, 1, 2)) / 255  # [0.49139968  0.48215841  0.44653091]
train_std = raw_train_images.std(axis=(0, 1, 2)) / 255    # [0.24703223  0.24348513  0.26158784]

transform_train = transforms.Compose([
    # Data augmentation is currently disabled:
    # transforms.RandomCrop(32, padding=4),
    # transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(train_mean, train_std),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(train_mean, train_std),
])

# Pinned host memory only speeds up host-to-GPU copies, so enable it only on CUDA.
kwargs = {'num_workers': 1, 'pin_memory': device.type == 'cuda'}

train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(root=dataset_path, train=True, download=True,
                     transform=transform_train),
    batch_size=BATCH_SIZE, shuffle=True, **kwargs)

test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(root=dataset_path, train=False, download=True,
                     transform=transform_test),
    batch_size=BATCH_SIZE, shuffle=False, **kwargs)

Files already downloaded and verified
Files already downloaded and verified


# Models

In [7]:
class CNN_medium(torch.nn.Module):
    """Four conv layers plus three fully connected layers for 3x32x32 inputs.

    ``forward`` returns ``(logits, hidden_activations)``. ``hidden_activations``
    lists, in order: the outputs of relu1, relu2 and relu3; the relu4 output after
    ``self.flat`` (presumably flattened to (batch, 64*5*5), as linear1 expects);
    the outputs of relu5 and relu6; and finally the logits themselves.

    Attribute names and their registration order are deliberately fixed: saved
    checkpoints are loaded by state_dict key. Code that walks the child modules
    may also rely on this order (TODO confirm in utils.bound_propagation).
    """

    def __init__(self):
        super().__init__()

        # Convolutional trunk: 3x32x32 -> 32x30x30 -> 32x14x14 -> 64x12x12 -> 64x5x5.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=0)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(32, 32, kernel_size=4, stride=2, padding=0)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=0)
        self.relu3 = nn.ReLU()
        self.conv4 = nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=0)
        self.relu4 = nn.ReLU()
        self.flat = Flatten()

        # Fully connected head on the 64 * 5 * 5 = 1600 features.
        self.linear1 = nn.Linear(64 * 5 * 5, 512)
        self.relu5 = nn.ReLU()
        self.linear2 = nn.Linear(512, 512)
        self.relu6 = nn.ReLU()
        self.last_linear = nn.Linear(512, 10)

    def forward(self, x):
        # Each stage is applied in full, and its final output is recorded.
        stages = (
            (self.conv1, self.relu1),
            (self.conv2, self.relu2),
            (self.conv3, self.relu3),
            (self.conv4, self.relu4, self.flat),
            (self.linear1, self.relu5),
            (self.linear2, self.relu6),
            (self.last_linear,),
        )
        activations = []
        for stage in stages:
            for layer in stage:
                x = layer(x)
            activations.append(x)
        # After the last stage, x is the logits tensor (also activations[-1]).
        return x, activations

In [8]:
# Two independently initialised CNN_medium instances, created in this order:
# the reference model first, then the model whose bounds are trained.
model_ref, model_robust = (CNN_medium().to(device) for _ in range(2))

In [9]:
# Both models start from the same standard-trained checkpoint. map_location lets a
# checkpoint saved on GPU load on a CPU-only machine; the file is read only once
# because load_state_dict copies the tensors into each model's own parameters.
pretrained_state = torch.load("CIFAR_trained_model.pth", map_location=device)
model_ref.load_state_dict(pretrained_state)
model_robust.load_state_dict(pretrained_state)

# Normal training

First, train the model with standard (non-robust) training until it reaches >80% accuracy, then save its weights to `CIFAR_trained_model.pth`.

In [None]:
# Standard (non-robust) training of the reference model. The original referenced an
# undefined `model`; model_ref is the model this checkpoint is later loaded into.
opt = optim.Adam(model_ref.parameters(), lr=1e-3)

for t in range(45):
    train_err, _ = epoch(train_loader, model_ref, device, opt)
    print(f"Epoch {t}: train error {train_err}")

In [None]:
PATH = 'CIFAR_trained_model.pth'  # checkpoint file for the standard-trained model

In [None]:
# Save the standard-trained reference model (`model` was never defined in this notebook).
torch.save(model_ref.state_dict(), PATH)

In [None]:
# Reload the reference weights; map_location allows a GPU-saved checkpoint on CPU.
model_ref.load_state_dict(torch.load(PATH, map_location=device))

In [None]:
# Number of samples in the test dataset (displayed as the cell output).
len(test_loader.dataset)

# Robust training

In [None]:
# Resume robust training from a saved checkpoint; map_location allows loading on CPU.
model_robust.load_state_dict(torch.load("robust_layer_14_epoch_4.pth", map_location=device))

In [16]:
def epoch_robust_train(loader, model_ref, model_robust, epsilon, device, opt=None,
                       bound_indices=(2, 4, 6, 9, 11, 13, 14)):
    """Run one epoch that pulls model_robust's propagated bounds toward model_ref's activations.

    For every batch, ``model_ref`` is evaluated on the clean input to obtain its
    hidden activations. ``bound_propagation`` is then run on ``model_robust`` over
    the box ``[X - epsilon, X + epsilon]``. For each bounded layer the spec loss is
    ``MSE(upper, act) + MSE(lower, act)``, and the sum over all layers is
    minimised when ``opt`` is given.

    Args:
        loader: DataLoader yielding ``(X, y)`` batches. ``len(loader.dataset)`` is
            used to average the returned losses.
        model_ref: reference model whose forward returns
            ``(logits, hidden_activations)``. It is only evaluated here, never updated.
        model_robust: model whose bounds are being trained.
        epsilon: half-width of the input box.
        device: device the batches are moved to.
        opt: optimizer to step (in this notebook, over model_robust's parameters).
            When ``None``, no update is made and no autograd graph is built.
        bound_indices: for each entry of ``hidden_activations`` (in order), the
            index into the list returned by ``bound_propagation`` holding the
            matching ``(lower, upper)`` pair. The default lines up with CNN_medium's
            relu1/relu2/relu3/flat/relu5/relu6/last_linear outputs, assuming
            ``bounds[k]`` is the bound after the k-th child module with
            ``bounds[0]`` the input. TODO confirm against utils.bound_propagation.

    Returns:
        ``(mean_loss_fit, mean_loss_spec)``. ``mean_loss_fit`` is model_ref's
        cross-entropy. ``mean_loss_spec`` is the spec loss of the LAST entry of
        ``bound_indices`` only, not the summed objective that is optimised.

    Raises:
        ValueError: if ``bound_indices`` is empty, or if its length differs from
            the number of hidden activations returned by model_ref.
    """
    bound_indices = tuple(bound_indices)
    if not bound_indices:
        raise ValueError("bound_indices must contain at least one layer index")
    # Propagate just deep enough to reach the deepest requested layer
    # (14 for the default indices, the value that was previously hard-coded).
    how_many_layers = max(bound_indices)
    flatten = Flatten()

    total_loss_fit = 0.0
    total_loss_spec = 0.0

    for X, y in loader:
        X, y = X.to(device), y.to(device)
        batch_size = X.shape[0]

        # model_ref is not optimised here, and its activations only serve as fixed
        # targets, so skip building an autograd graph for its forward pass.
        with torch.no_grad():
            yp, hidden_activations = model_ref(X)
            loss_fit = nn.CrossEntropyLoss(reduction="mean")(yp, y)

        if len(hidden_activations) != len(bound_indices):
            raise ValueError(
                f"model_ref returned {len(hidden_activations)} hidden activations "
                f"but {len(bound_indices)} bound indices were given"
            )

        # Track gradients through the bounds only when an optimiser step follows.
        with torch.set_grad_enabled(opt is not None):
            initial_bounds = (X - epsilon, X + epsilon)
            bounds = bound_propagation(model_robust, initial_bounds, how_many_layers=how_many_layers)

            loss_spec = []
            for activation, idx in zip(hidden_activations, bound_indices):
                real = flatten(activation)
                lower = flatten(bounds[idx][0])
                upper = flatten(bounds[idx][1])
                # Pull both the upper and the lower bound toward the reference activation.
                loss_spec.append(F.mse_loss(upper, real) + F.mse_loss(lower, real))
            combined_loss = sum(loss_spec)

        total_loss_fit += loss_fit.item() * batch_size
        # As before, only the last bounded layer's spec loss is reported.
        total_loss_spec += loss_spec[-1].item() * batch_size

        if opt is not None:
            opt.zero_grad()
            combined_loss.backward()
            opt.step()

    n_samples = len(loader.dataset)
    return total_loss_fit / n_samples, total_loss_spec / n_samples

In [None]:
# Resume robust training from a saved checkpoint; map_location allows loading on CPU.
model_robust.load_state_dict(torch.load("robust_layer_14_loss_63.pth", map_location=device))

In [None]:
# Train model_robust's bounds (input half-width epsilon) toward model_ref's activations.
epsilon = 8/255
opt = optim.Adam(model_robust.parameters(), lr=1e-3)

for t in range(50):
    loss_fit, loss_spec = epoch_robust_train(
        train_loader, model_ref, model_robust, epsilon, device, opt
    )
    print(f'Epoch {t}: Loss_fit: {loss_fit}     Loss_spec: {loss_spec}')

Epoch 0: Loss_fit: 0.41829184756428     Loss_spec: 11915808678056.191
Epoch 1: Loss_fit: 0.41829185001552105     Loss_spec: 10648875179.946
Epoch 2: Loss_fit: 0.41829185144603254     Loss_spec: 2891487214.9535
Epoch 3: Loss_fit: 0.41829184702038763     Loss_spec: 1053530492.9274563
Epoch 4: Loss_fit: 0.4182918484508991     Loss_spec: 460019156.77685696
Epoch 5: Loss_fit: 0.41829184933006763     Loss_spec: 221333956.88957378
Epoch 6: Loss_fit: 0.4182918483838439     Loss_spec: 107087007.41236605
Epoch 7: Loss_fit: 0.4182918471172452     Loss_spec: 51570353.14583178
Epoch 8: Loss_fit: 0.4182918494567275     Loss_spec: 24586724.287693422
Epoch 9: Loss_fit: 0.4182918496504426     Loss_spec: 11696254.194039695
Epoch 10: Loss_fit: 0.4182918484434485     Loss_spec: 5715137.133328239
Epoch 11: Loss_fit: 0.41829184725880625     Loss_spec: 2752379.4642795636
Epoch 12: Loss_fit: 0.41829184879362585     Loss_spec: 1323863.570231842
Epoch 13: Loss_fit: 0.41829184836149214     Loss_spec: 635488.0026

In [None]:
# Save the robust model's weights.
# NOTE(review): the filename says "epoch_100", but the robust training loop in this
# notebook runs range(50) and its logged output stops at epoch 13; confirm the name
# before relying on it.
torch.save(model_robust.state_dict(), "robust_bigger_epsilon_epoch_100.pth")

In [None]:
# Robust error of model_robust on the test set, bounding 14 layers at epsilon = 8/255.
epsilon, how_many_layers = 8/255, 14
robust_err = epoch_calculate_robust_err(test_loader, model_robust, epsilon, how_many_layers, device)

In [None]:
# Show the robust test error computed in the previous cell.
print (robust_err)

In [None]:
# Test error of model_robust from utils.epoch, called without an optimizer
# (presumably evaluation only; confirm in utils). The second return value is discarded.
test_err, _ = epoch(test_loader, model_robust, device)
print (f'Test error: {test_err}')

# Training with `epoch_robust_bound` (epsilon/kappa schedules)

In [None]:
# Train a fresh CNN_medium with epoch_robust_bound under epsilon/kappa schedules.
# The original referenced an undefined `model_cnn_medium`; it is now created here.
# NOTE(review): the schedule generators are the MNIST variants from utils; confirm
# they are appropriate for CIFAR-10.
model_cnn_medium = CNN_medium().to(device)
opt = optim.Adam(model_cnn_medium.parameters(), lr=1e-3)

EPSILON = 0.1         # epsilon used to evaluate robust error on the test set
EPSILON_TRAIN = 0.2   # epsilon passed to the training epsilon schedule
HOW_MANY_LAYERS = 14  # layers bounded when evaluating robust error (as in the evaluation cell above)
NUM_EPOCHS = 100
# Epoch index -> learning rate applied from the following epoch onward.
LR_DECAY = {24: 1e-4, 40: 1e-5}

epsilon_schedule = generate_epsilon_schedule_MNIST(EPSILON_TRAIN)
kappa_schedule = generate_kappa_schedule_MNIST()
batch_counter = 0

print("Epoch   ", "Combined Loss", "Test Err", "Test Robust Err", sep="\t")

for t in range(NUM_EPOCHS):
    _, combined_loss = epoch_robust_bound(train_loader, model_cnn_medium, epsilon_schedule, device, kappa_schedule, batch_counter, opt)

    # Check error and robust error on the test set. The arguments follow the
    # (loader, model, epsilon, how_many_layers, device) call used earlier; the old
    # call omitted how_many_layers, which shifted `device` into its slot.
    test_err, _ = epoch(test_loader, model_cnn_medium, device)
    robust_err = epoch_calculate_robust_err(test_loader, model_cnn_medium, EPSILON, HOW_MANY_LAYERS, device)

    # Advance the schedule position by the number of batches in one epoch. This was
    # hard-coded to 600, presumably MNIST's 60000 / 100, which is wrong for CIFAR-10
    # with BATCH_SIZE = 50. Assumes the schedules are indexed per batch; confirm in utils.
    batch_counter += len(train_loader)

    if t in LR_DECAY:
        for param_group in opt.param_groups:
            param_group["lr"] = LR_DECAY[t]

    print(f"{t:d}", *("{:.6f}".format(v) for v in (combined_loss, test_err, robust_err)), sep="\t")