In [0]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, RandomSampler
import torch.utils.data as data_utils
#from utils import epoch, epoch_robust_bound, epoch_calculate_robust_err, Flatten, generate_kappa_schedule_CIFAR, generate_epsilon_schedule_CIFAR

# Utils and helper functions

In [0]:
class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.shape[0], -1) 

In [0]:
def epoch(loader, model, device, opt=None):
    """Standard training/evaluation epoch over the dataset"""
    total_loss, total_err = 0.,0.
    for X,y in loader:
        X,y = X.to(device), y.to(device)
        yp,_ = model(X)
        loss = nn.CrossEntropyLoss()(yp,y)
        if opt:
            opt.zero_grad()
            loss.backward()
            opt.step()
        
        total_err += (yp.max(dim=1)[1] != y).sum().item()
        total_loss += loss.item() * X.shape[0]
    return total_err / len(loader.dataset), total_loss / len(loader.dataset)


def bound_propagation(model, initial_bound):
    l, u = initial_bound
    bounds = []
    bounds.append(initial_bound)
    list_of_layers = list(model.children())
    
    for i in range(len(list_of_layers)):
        layer = list_of_layers[i]
        
        if isinstance(layer, Flatten):
            l_ = Flatten()(l)
            u_ = Flatten()(u)

        elif isinstance(layer, nn.Linear):
            l_ = (layer.weight.clamp(min=0) @ l.t() + layer.weight.clamp(max=0) @ u.t() 
                  + layer.bias[:,None]).t()
            u_ = (layer.weight.clamp(min=0) @ u.t() + layer.weight.clamp(max=0) @ l.t() 
                  + layer.bias[:,None]).t()
            
        elif isinstance(layer, nn.Conv2d):
            l_ = (nn.functional.conv2d(l, layer.weight.clamp(min=0), bias=None, 
                                       stride=layer.stride, padding=layer.padding,
                                       dilation=layer.dilation, groups=layer.groups) +
                  nn.functional.conv2d(u, layer.weight.clamp(max=0), bias=None, 
                                       stride=layer.stride, padding=layer.padding,
                                       dilation=layer.dilation, groups=layer.groups) +
                  layer.bias[None,:,None,None])
            
            u_ = (nn.functional.conv2d(u, layer.weight.clamp(min=0), bias=None, 
                                       stride=layer.stride, padding=layer.padding,
                                       dilation=layer.dilation, groups=layer.groups) +
                  nn.functional.conv2d(l, layer.weight.clamp(max=0), bias=None, 
                                       stride=layer.stride, padding=layer.padding,
                                       dilation=layer.dilation, groups=layer.groups) + 
                  layer.bias[None,:,None,None])
            
        elif isinstance(layer, nn.ReLU):
            l_ = l.clamp(min=0)
            u_ = u.clamp(min=0)
            
        bounds.append((l_, u_))
        l,u = l_, u_
    return bounds


def interval_based_bound(model, c, bounds, idx):
    # requires last layer to be linear
    cW = c.t() @ model.last_linear.weight
    cb = c.t() @ model.last_linear.bias
    
    l,u = bounds[-2]
    return (cW.clamp(min=0) @ l[idx].t() + cW.clamp(max=0) @ u[idx].t() + cb[:,None]).t()


def epoch_robust_bound(loader, model, epsilon_schedule, device, kappa_schedule, batch_counter, opt=None):
    robust_err = 0
    total_robust_loss = 0
    total_mse_loss = 0
    total_combined_loss = 0
    
    C = [-torch.eye(10).to(device) for _ in range(10)]
    for y0 in range(10):
        C[y0][y0,:] += 1
    
    for i,data in enumerate(loader,0):
      
        #if i>99:  #calculate only for 100 batches
        #  break      
        
        mse_loss_list = []
        lower_bounds = []
        upper_bounds = []
        
        
        X,y = data
        X,y = X.to(device), y.to(device)
        
        ###### fit loss calculation ######
        yp,_ = model(X)
        fit_loss = nn.CrossEntropyLoss()(yp,y)
    
        ###### robust loss calculation ######
        initial_bound = (X - epsilon_schedule[batch_counter], X + epsilon_schedule[batch_counter])
        bounds = bound_propagation(model, initial_bound)
        robust_loss = 0
        for y0 in range(10):
            if sum(y==y0) > 0:
                lower_bound = interval_based_bound(model, C[y0], bounds, y==y0)
                robust_loss += nn.CrossEntropyLoss(reduction='sum')(-lower_bound, y[y==y0]) / X.shape[0]
                
                robust_err += (lower_bound.min(dim=1)[0] < 0).sum().item() #increment when true label is not winning       
        
        total_robust_loss += robust_loss.item() * X.shape[0]  
        
        ##### MSE Loss #####
        
        #indices_of_layers = [2,4,7,8] #CNN_small
        indices_of_layers = [2,4,6,8,11,13,14] #CNN_medium
        
        
        for i in range(len(indices_of_layers)):
            lower_bounds.append(Flatten()(bounds[indices_of_layers[i]][0])) #lower bounds 
            upper_bounds.append(Flatten()(bounds[indices_of_layers[i]][1])) #upper bounds 
            mse_loss_list.append(nn.MSELoss()(lower_bounds[i], upper_bounds[i]))
        
        mse_loss = mse_loss_list[0] + mse_loss_list[1] + mse_loss_list[2] + mse_loss_list[3] #+  mse_loss_list[4] +  mse_loss_list[5] +  mse_loss_list[6]
        total_mse_loss += mse_loss.item()
        
        ###### combined losss ######
        combined_loss = kappa_schedule[batch_counter]*fit_loss + (1-kappa_schedule[batch_counter])*robust_loss + mse_loss
        #combined_loss = kappa_schedule[batch_counter]*fit_loss + (1-kappa_schedule[batch_counter])*robust_loss
        
        total_combined_loss += combined_loss.item()
        
        batch_counter +=1
         
        if opt:
            opt.zero_grad()
            combined_loss.backward()
            opt.step() 
        
    return robust_err / len(loader.dataset), total_combined_loss / len(loader.dataset), total_mse_loss/ len(loader.dataset)


def epoch_calculate_robust_err (loader, model, epsilon, device):
    robust_err = 0.0
    
    C = [-torch.eye(10).to(device) for _ in range(10)]
    for y0 in range(10):
        C[y0][y0,:] += 1


    for X,y in loader:
        X,y = X.to(device), y.to(device)
        
        initial_bound = (X - epsilon, X + epsilon)
        bounds = bound_propagation(model, initial_bound)

        for y0 in range(10):
            if sum(y==y0) > 0:
                lower_bound = interval_based_bound(model, C[y0], bounds, y==y0)                
                robust_err += (lower_bound.min(dim=1)[0] < 0).sum().item() #increment when true label is not winning       
        
    return robust_err / len(loader.dataset)
        
        


def generate_kappa_schedule_MNIST():

    kappa_schedule = 2000*[1] # warm-up phase
    kappa_value = 1.0
    step = 0.5/58000
    
    for i in range(58000):
        kappa_value = kappa_value - step
        kappa_schedule.append(kappa_value)
    
    return kappa_schedule

def generate_epsilon_schedule_MNIST(epsilon_train):
    
    epsilon_schedule = []
    step = epsilon_train/10000
            
    for i in range(10000):
        epsilon_schedule.append(i*step) #ramp-up phase
    
    for i in range(50000):
        epsilon_schedule.append(epsilon_train)
        
    return epsilon_schedule


def generate_kappa_schedule_CIFAR():

    kappa_schedule = 10000*[1] # warm-up phase
    kappa_value = 1.0
    step = 0.5/340000
    
    for i in range(340000):
        kappa_value = kappa_value - step
        kappa_schedule.append(kappa_value)
    
    return kappa_schedule

def generate_epsilon_schedule_CIFAR(epsilon_train):
    
    epsilon_schedule = []
    step = epsilon_train/150000
            
    for i in range(150000):
        epsilon_schedule.append(i*step) #ramp-up phase
    
    for i in range(200000):
        epsilon_schedule.append(epsilon_train)
        
    return epsilon_schedule
  

def pgd_linf_rand(model, X, y, epsilon, alpha, num_iter, restarts):
    """ Construct PGD adversarial examples on the samples X, with random restarts"""
    max_loss = torch.zeros(y.shape[0]).to(y.device)
    max_delta = torch.zeros_like(X)
    
    for i in range(restarts):
        delta = torch.rand_like(X, requires_grad=True)
        delta.data = delta.data * 2 * epsilon - epsilon
        
        for t in range(num_iter):
            loss = nn.CrossEntropyLoss()(model(X + delta)[0], y)
            loss.backward()
            delta.data = (delta + alpha*delta.grad.detach().sign()).clamp(-epsilon,epsilon)
            delta.grad.zero_()
        
        all_loss = nn.CrossEntropyLoss(reduction='none')(model(X+delta)[0],y)
        max_delta[all_loss >= max_loss] = delta.detach()[all_loss >= max_loss]
        max_loss = torch.max(max_loss, all_loss)
        
    return max_delta


def epoch_adversarial(model, loader, attack, *args):
    total_loss, total_err = 0.,0.
    for X,y in loader:
        X,y = X.to(device), y.to(device)
        delta = attack(model, X, y, *args)
        yp = model(X+delta)[0]
        loss = nn.CrossEntropyLoss()(yp,y)
        
        total_err += (yp.max(dim=1)[1] != y).sum().item()
        total_loss += loss.item() * X.shape[0]
    return total_err / len(loader.dataset), total_loss / len(loader.dataset)

# Loading dataset

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)

<torch._C.Generator at 0x7f9c7d458490>

In [0]:
mnist_train = datasets.MNIST("./", train=True, download=True, transform=transforms.ToTensor())
mnist_test = datasets.MNIST("./", train=False, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(mnist_train, batch_size = 100, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size = 100, shuffle=False)

# Model

In [0]:
class CNN_large(torch.nn.Module):
    def __init__(self):

        super(CNN_large, self).__init__()
        
        self.conv1 = nn.Conv2d(1, 64, 3, padding=0, stride=1)
        self.relu1 = nn.ReLU() 
        self.conv2 = nn.Conv2d(64, 64, 3, padding=0, stride=1)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(64, 128, 3, padding=0, stride=2)
        self.relu3 = nn.ReLU()
        self.conv4 = nn.Conv2d(128, 128, 3, padding=0, stride=1)
        self.relu4 = nn.ReLU()
        self.conv5 = nn.Conv2d(128, 128, 3, padding=0, stride=1)
        self.relu5 = nn.ReLU()
        
        self.flat = Flatten()
        self.linear1 = nn.Linear(128*7*7, 200)
        self.relu6 = nn.ReLU()
        self.last_linear = nn.Linear(200, 10)                

        
    def forward(self, x):
        
        hidden_activations = []
        
        x = self.conv1(x)
        x = self.relu1(x)
        
        x = self.conv2(x)
        x = self.relu2(x)
        
        x = self.conv3(x)
        x = self.relu3(x)
        
        x = self.conv4(x)
        x = self.relu4(x)
       
        x = self.conv5(x)
        x = self.relu5(x)
        
        x = self.flat(x)

        x = self.linear1(x)
        x = self.relu6(x)
        
        out = self.last_linear(x)

        return out, hidden_activations

In [0]:
model = CNN_large().to(device)

# Training

In [0]:
opt = optim.Adam(model.parameters(), lr=1e-3)

EPSILON = 0.4
EPSILON_TRAIN = 0.4
epsilon_schedule = generate_epsilon_schedule_MNIST(EPSILON_TRAIN)
kappa_schedule = generate_kappa_schedule_MNIST()
batch_counter = 0

print("Epoch   ", "Combined Loss", "MSE Loss", "Test Err", "Test Robust Err", sep="\t")

for t in range(100):
    _, combined_loss, mse_loss = epoch_robust_bound(train_loader, model, epsilon_schedule, device, kappa_schedule, batch_counter, opt)
    
    # check loss and accuracy on test set
    test_err, _ = epoch(test_loader, model, device)
    robust_err = epoch_calculate_robust_err(test_loader, model, EPSILON, device)
    
    batch_counter += 600
    
    if t == 24:  #decrease learning rate after 25 epochs
        for param_group in opt.param_groups:
            param_group["lr"] = 1e-4
    
    if t == 40:  #decrease learning rate after 41 epochs
        for param_group in opt.param_groups:
            param_group["lr"] = 1e-5
    
    
    print(*("{:.6f}".format(i) for i in (t, combined_loss, mse_loss, test_err, robust_err)), sep="\t")

print ("END OF TRAINING ")

# checking resistance against PGD attack
print("PGD Error Rate:", epoch_adversarial(model, test_loader, pgd_linf_rand, EPSILON, 1e-2, 200, 10)[0])

Epoch   	Combined Loss	MSE Loss	Test Err	Test Robust Err
0.000000	0.001913	0.000162	0.018000	1.000000
1.000000	0.000610	0.000101	0.014500	1.000000
