<a href="https://colab.research.google.com/github/pawelmorawiecki/ABIDE_experiments/blob/master/CIFAR_interval_bound_prop.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
#%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [0]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, RandomSampler
import torch.utils.data as data_utils
#from utils import epoch, epoch_robust_bound, epoch_calculate_robust_err, Flatten, generate_kappa_schedule_CIFAR, generate_epsilon_schedule_CIFAR

# Utils and helper functions

In [0]:
class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.shape[0], -1) 

def epoch(loader, model, device, opt=None):
    """Standard training/evaluation epoch over the dataset"""
    total_loss, total_err = 0.,0.
    for X,y in loader:
        X,y = X.to(device), y.to(device)
        yp,_ = model(X)
        loss = nn.CrossEntropyLoss()(yp,y)
        if opt:
            opt.zero_grad()
            loss.backward()
            opt.step()
        
        total_err += (yp.max(dim=1)[1] != y).sum().item()
        total_loss += loss.item() * X.shape[0]
    return total_err / len(loader.dataset), total_loss / len(loader.dataset)


def bound_propagation(model, initial_bound):
    l, u = initial_bound
    bounds = []
    bounds.append(initial_bound)
    list_of_layers = list(model.children())
    
    for i in range(len(list_of_layers)):
        layer = list_of_layers[i]
        
        if isinstance(layer, Flatten):
            l_ = Flatten()(l)
            u_ = Flatten()(u)

        elif isinstance(layer, nn.Linear):
            l_ = (layer.weight.clamp(min=0) @ l.t() + layer.weight.clamp(max=0) @ u.t() 
                  + layer.bias[:,None]).t()
            u_ = (layer.weight.clamp(min=0) @ u.t() + layer.weight.clamp(max=0) @ l.t() 
                  + layer.bias[:,None]).t()
            
        elif isinstance(layer, nn.Conv2d):
            l_ = (nn.functional.conv2d(l, layer.weight.clamp(min=0), bias=None, 
                                       stride=layer.stride, padding=layer.padding,
                                       dilation=layer.dilation, groups=layer.groups) +
                  nn.functional.conv2d(u, layer.weight.clamp(max=0), bias=None, 
                                       stride=layer.stride, padding=layer.padding,
                                       dilation=layer.dilation, groups=layer.groups) +
                  layer.bias[None,:,None,None])
            
            u_ = (nn.functional.conv2d(u, layer.weight.clamp(min=0), bias=None, 
                                       stride=layer.stride, padding=layer.padding,
                                       dilation=layer.dilation, groups=layer.groups) +
                  nn.functional.conv2d(l, layer.weight.clamp(max=0), bias=None, 
                                       stride=layer.stride, padding=layer.padding,
                                       dilation=layer.dilation, groups=layer.groups) + 
                  layer.bias[None,:,None,None])
            
        elif isinstance(layer, nn.ReLU):
            l_ = l.clamp(min=0)
            u_ = u.clamp(min=0)
            
        bounds.append((l_, u_))
        l,u = l_, u_
    return bounds


def interval_based_bound(model, c, bounds, idx):
    # requires last layer to be linear
    cW = c.t() @ model.last_linear.weight
    cb = c.t() @ model.last_linear.bias
    
    l,u = bounds[-2]
    return (cW.clamp(min=0) @ l[idx].t() + cW.clamp(max=0) @ u[idx].t() + cb[:,None]).t()


def epoch_robust_bound(loader, model, epsilon_schedule, device, kappa_schedule, batch_counter, opt=None):
    robust_err = 0
    total_robust_loss = 0
    total_mse_loss = 0
    total_combined_loss = 0
    
    C = [-torch.eye(10).to(device) for _ in range(10)]
    for y0 in range(10):
        C[y0][y0,:] += 1
    
    for i,data in enumerate(loader,0):
      
        if i>99:  #calculate only for 100 batches
          break      
        
        mse_loss_list = []
        lower_bounds = []
        upper_bounds = []
        
        
        X,y = data
        X,y = X.to(device), y.to(device)
        
        ###### fit loss calculation ######
        yp,_ = model(X)
        fit_loss = nn.CrossEntropyLoss()(yp,y)
    
        ###### robust loss calculation ######
        initial_bound = (X - epsilon_schedule[batch_counter], X + epsilon_schedule[batch_counter])
        bounds = bound_propagation(model, initial_bound)
        robust_loss = 0
        for y0 in range(10):
            if sum(y==y0) > 0:
                lower_bound = interval_based_bound(model, C[y0], bounds, y==y0)
                robust_loss += nn.CrossEntropyLoss(reduction='sum')(-lower_bound, y[y==y0]) / X.shape[0]
                
                robust_err += (lower_bound.min(dim=1)[0] < 0).sum().item() #increment when true label is not winning       
        
        total_robust_loss += robust_loss.item() * X.shape[0]  
        
        ##### MSE Loss #####
        
        #indices_of_layers = [2,4,7,8] #CNN_small
        indices_of_layers = [2,4,6,8,11,13,14] #CNN_medium
        
        
        for i in range(len(indices_of_layers)):
            lower_bounds.append(Flatten()(bounds[indices_of_layers[i]][0])) #lower bounds 
            upper_bounds.append(Flatten()(bounds[indices_of_layers[i]][1])) #upper bounds 
            mse_loss_list.append(nn.MSELoss()(lower_bounds[i], upper_bounds[i]))
        
        mse_loss = mse_loss_list[0] + mse_loss_list[1] + mse_loss_list[2] + mse_loss_list[3] + mse_loss_list[4] + mse_loss_list[5] + mse_loss_list[6]
        total_mse_loss += mse_loss.item()
        
        ###### combined losss ######
        combined_loss = kappa_schedule[batch_counter]*fit_loss + (1-kappa_schedule[batch_counter])*robust_loss + mse_loss
        total_combined_loss += combined_loss.item()
        
        batch_counter +=1
         
        if opt:
            opt.zero_grad()
            combined_loss.backward()
            opt.step() 
        
    return robust_err / len(loader.dataset), total_combined_loss / len(loader.dataset), total_mse_loss/ len(loader.dataset)



def new_epoch_robust_bound(loader, model, epsilon, alpha, device, opt=None):
    
    robust_err = 0
    total_robust_loss = 0
    total_fit_loss = 0
    
    C = [-torch.eye(10).to(device) for _ in range(10)]
    for y0 in range(10):
        C[y0][y0,:] += 1

    for X,y in loader:
        X,y = X.to(device), y.to(device)
        
        ###### fit loss calculation ######
        yp,_ = model(X)
        fit_loss = nn.CrossEntropyLoss()(yp,y)
    
        ###### robust loss calculation ######
        initial_bound = (X - epsilon, X + epsilon)
        bounds = bound_propagation(model, initial_bound, how_many_layers=14)
        robust_loss = 0
        for y0 in range(10):
            if sum(y==y0) > 0:
                lower_bound = interval_based_bound(model, C[y0], bounds, y==y0)
                robust_loss += nn.CrossEntropyLoss(reduction='sum')(-lower_bound, y[y==y0]) / X.shape[0]        
                robust_err += (lower_bound.min(dim=1)[0] < 0).sum().item() #increment when true label is not winning       
        
        total_robust_loss += robust_loss.item() * X.shape[0]
        total_fit_loss += fit_loss.item() * X.shape[0]
        
                ###### combined losss ######
        combined_loss = (1-alpha)*fit_loss + alpha*robust_loss
      
        if opt:
            opt.zero_grad()
            combined_loss.backward()
            opt.step()
            
    return total_fit_loss / len(loader.dataset), total_robust_loss / len(loader.dataset)
        



def epoch_calculate_robust_err (loader, model, epsilon, device):
    robust_err = 0.0
    
    C = [-torch.eye(10).to(device) for _ in range(10)]
    for y0 in range(10):
        C[y0][y0,:] += 1


    for X,y in loader:
        X,y = X.to(device), y.to(device)
        
        initial_bound = (X - epsilon, X + epsilon)
        bounds = bound_propagation(model, initial_bound)

        for y0 in range(10):
            if sum(y==y0) > 0:
                lower_bound = interval_based_bound(model, C[y0], bounds, y==y0)                
                robust_err += (lower_bound.min(dim=1)[0] < 0).sum().item() #increment when true label is not winning       
        
    return robust_err / len(loader.dataset)
        
        


def generate_kappa_schedule_MNIST():

    kappa_schedule = 2000*[1] # warm-up phase
    kappa_value = 1.0
    step = 0.5/58000
    
    for i in range(58000):
        kappa_value = kappa_value - step
        kappa_schedule.append(kappa_value)
    
    return kappa_schedule

def generate_epsilon_schedule_MNIST(epsilon_train):
    
    epsilon_schedule = []
    step = epsilon_train/10000
            
    for i in range(10000):
        epsilon_schedule.append(i*step) #ramp-up phase
    
    for i in range(50000):
        epsilon_schedule.append(epsilon_train)
        
    return epsilon_schedule


def generate_kappa_schedule_CIFAR():

    kappa_schedule = 10000*[1] # warm-up phase
    kappa_value = 1.0
    step = 0.5/340000
    
    for i in range(340000):
        kappa_value = kappa_value - step
        kappa_schedule.append(kappa_value)
    
    return kappa_schedule

def generate_epsilon_schedule_CIFAR(epsilon_train):
    
    epsilon_schedule = []
    step = epsilon_train/150000
            
    for i in range(150000):
        epsilon_schedule.append(i*step) #ramp-up phase
    
    for i in range(200000):
        epsilon_schedule.append(epsilon_train)
        
    return epsilon_schedule 

# Loading dataset

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)

<torch._C.Generator at 0x7f1231366050>

In [0]:
BATCH_SIZE = 50
dataset_path = './cifar10'

In [5]:
trainset = datasets.CIFAR10(root=dataset_path, train=True, download=True)

0it [00:00, ?it/s]

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar10/cifar-10-python.tar.gz


100%|█████████▉| 170352640/170498071 [00:39<00:00, 8792945.74it/s]

In [6]:
train_mean = trainset.data.mean(axis=(0,1,2))/255  # [0.49139968  0.48215841  0.44653091]
train_std = trainset.data.std(axis=(0,1,2))/255  # [0.24703223  0.24348513  0.26158784]

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(train_mean, train_std),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(train_mean, train_std),
])
kwargs = {'num_workers': 1, 'pin_memory': True}


train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
    root=dataset_path, train=True, download=True,
    transform=transform_train),
    batch_size=BATCH_SIZE, shuffle=True, **kwargs)

test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(root=dataset_path, train=False, download=True,
    transform=transform_test),
    batch_size=BATCH_SIZE, shuffle=False, **kwargs)

Files already downloaded and verified
Files already downloaded and verified


In [0]:
len(train_loader)

1000

# Model

In [0]:
class CNN_small(torch.nn.Module):
    def __init__(self):

        super(CNN_small, self).__init__()
        
        self.conv1 = nn.Conv2d(3, 16, 4, padding=0, stride=2)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(16, 32, 4, padding=0, stride=1)
        self.relu2 = nn.ReLU()
        self.flat = Flatten()
        self.linear1 = nn.Linear(32*12*12, 100)
        self.relu3 = nn.ReLU()
        self.last_linear = nn.Linear(100, 10)                
        
    def forward(self, x):
        
        hidden_activations = []
        
        x = self.conv1(x)
        x = self.relu1(x)
        hidden_activations.append(x)
        
        x = self.conv2(x)
        x = self.relu2(x)
        hidden_activations.append(x)
        
        x = self.flat(x)
        x = self.linear1(x)
        x = self.relu3(x)
        hidden_activations.append(x)
        
        out = self.last_linear(x)
        hidden_activations.append(out)
        
        return out, hidden_activations

In [0]:
class CNN_medium(torch.nn.Module):
    def __init__(self):

        super(CNN_medium, self).__init__()
        
        self.conv1 = nn.Conv2d(3, 32, 3, padding=0, stride=1)
        self.relu1 = nn.ReLU() 
        self.conv2 = nn.Conv2d(32, 32, 4, padding=0, stride=2)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(32, 64, 3, padding=0, stride=1)
        self.relu3 = nn.ReLU()
        self.conv4 = nn.Conv2d(64, 64, 4, padding=0, stride=2)
        self.relu4 = nn.ReLU()
        self.flat = Flatten()
        self.linear1 = nn.Linear(64*5*5, 512)
        self.relu5 = nn.ReLU()
        self.linear2 = nn.Linear(512, 512)
        self.relu6 = nn.ReLU()
        self.last_linear = nn.Linear(512, 10)                
        
    def forward(self, x):
        
        hidden_activations = []
        
        x = self.conv1(x)
        x = self.relu1(x)
        hidden_activations.append(x)
        
        x = self.conv2(x)
        x = self.relu2(x)
        hidden_activations.append(x)
        
        x = self.conv3(x)
        x = self.relu3(x)
        hidden_activations.append(x)
        
        x = self.conv4(x)
        x = self.relu4(x)
        x = self.flat(x)
        hidden_activations.append(x)

        x = self.linear1(x)
        x = self.relu5(x)
        hidden_activations.append(x)

        x = self.linear2(x)
        x = self.relu6(x)
        hidden_activations.append(x)
        
        out = self.last_linear(x)
        hidden_activations.append(out)
        
        return out, hidden_activations

In [0]:
model = CNN_medium().to(device)

# Training

In [0]:
opt = optim.Adam(model.parameters(), lr=1e-3)
for i in range(20):
  train_err, _ = epoch(train_loader, model, device, opt)
  print (train_err)

test_err, _ = epoch(test_loader, model, device)

0.60354
0.46068
0.39654
0.35646
0.3253
0.30708
0.28894
0.28002
0.26874
0.2589
0.2539
0.24514
0.2391
0.23464
0.23322
0.22548
0.22256
0.221
0.2186
0.2134


In [0]:
print (test_err)

0.2185


In [0]:
opt = optim.Adam(model.parameters(), lr=1e-3)

EPSILON = 8/255
EPSILON_TRAIN = 8/255
epsilon_schedule = generate_epsilon_schedule_CIFAR(EPSILON_TRAIN)
kappa_schedule = generate_kappa_schedule_CIFAR()
batch_counter = 0

print("Epoch   ", "Combined Loss", "MSE Loss", "Test Err", "Test Robust Err", sep="\t")

for t in range(350):
    _, combined_loss, mse_loss = epoch_robust_bound(train_loader, model, epsilon_schedule, device, kappa_schedule, batch_counter, opt)
    
    # check loss and accuracy on test set
    test_err, _ = epoch(test_loader, model, device)
    robust_err = epoch_calculate_robust_err(test_loader, model, EPSILON, device)
    
    batch_counter += 1000
    
    if t == 200:  #decrease learning rate after 200 epochs
        for param_group in opt.param_groups:
            param_group["lr"] = 1e-4
    
    if t == 250:  #decrease learning rate after 250 epochs
        for param_group in opt.param_groups:
            param_group["lr"] = 1e-5
    
    if t == 300:  #decrease learning rate after 300 epochs
        for param_group in opt.param_groups:
            param_group["lr"] = 1e-6
    
    print(*("{:.6f}".format(i) for i in (t, combined_loss, mse_loss, test_err, robust_err)), sep="\t")

Epoch   	Combined Loss	MSE Loss	Test Err	Test Robust Err
0.000000	0.004207	0.000007	0.768200	1.000000
1.000000	0.004033	0.000048	0.751700	1.000000
2.000000	0.003908	0.000034	0.732000	1.000000
3.000000	0.003840	0.000038	0.709100	1.000000
4.000000	0.003864	0.000042	0.698100	1.000000
5.000000	0.003800	0.000046	0.680800	1.000000
6.000000	0.003805	0.000049	0.703300	0.999400
7.000000	0.003781	0.000050	0.686400	0.999900
8.000000	0.003763	0.000051	0.700700	0.998300
9.000000	0.003740	0.000051	0.680700	0.999000
10.000000	0.003720	0.000060	0.672400	0.991300
11.000000	0.003725	0.000060	0.667800	0.995400
12.000000	0.003684	0.000066	0.646200	0.991600
13.000000	0.003637	0.000074	0.636000	0.985600
14.000000	0.003619	0.000074	0.629300	0.978100
15.000000	0.003605	0.000078	0.625000	0.964500
16.000000	0.003592	0.000075	0.622600	0.956500
17.000000	0.003529	0.000087	0.618000	0.970200
18.000000	0.003490	0.000093	0.609800	0.937900
19.000000	0.003516	0.000085	0.611200	0.945000
20.000000	0.003506	0.000085	0.613