In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import TensorDataset, DataLoader
import torch.utils.data as data_utils
import numpy as np
from random import randint
from sklearn.impute import SimpleImputer  # used to impute missing values

# Constants

In [2]:
# Mini-batch size used when building the dataloaders.
BATCH_SIZE = 100
# Size of the missing square (presumably its side length in pixels -- confirm
# against the data-preparation code, which is not part of this notebook).
MISSING_SQUARE = 13
# Shape of a single image (presumably (height, width)).
IMAGE_SHAPE = (28, 28)

# Prefer the first CUDA device when one is available, otherwise use the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Seed PyTorch's global RNG so weight initialisation and shuffling are repeatable.
torch.manual_seed(0)

<torch._C.Generator at 0x7f6f3ca9abd0>

# Load data and create dataloaders

In [3]:
# Load the three training arrays (files are read in the order listed):
#   data_after_imputation -- the images fed to the network,
#   labels                -- the class targets,
#   mask                  -- scaled by epsilon and added to / subtracted from the
#                            images in epoch_with_intervals (presumably 1 on the
#                            imputed pixels and 0 elsewhere -- TODO confirm).
data_after_imputation, labels, mask = (
    np.load(npy_path)
    for npy_path in (
        "./data/MNIST_data_imputation.npy",
        "./data/MNIST_labels.npy",
        "./data/MNIST_mask.npy",
    )
)

In [4]:
# Wrap the NumPy arrays as tensors; every sample is an (image, mask, label) triple,
# which is the order epoch_with_intervals unpacks each batch in.
dataset = TensorDataset(*map(torch.from_numpy, (data_after_imputation, mask, labels)))
# Training batches are reshuffled at the start of every epoch.
train_dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

# Utils and helper functions

In [5]:
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, x):
        """Return a view of ``x`` with shape ``(x.shape[0], -1)``."""
        leading_dim = x.size(0)
        return x.view(leading_dim, -1)

In [6]:
def bound_propagation(model, initial_bound):
    """Propagate element-wise interval bounds through the layers of ``model``.

    Every direct child of ``model`` except the last one is visited in
    registration order (``model.children()``), and the interval ``[l, u]`` is
    pushed through it with interval arithmetic.

    Args:
        model: module whose direct children are ``nn.Linear``, ``nn.Conv2d``,
            ``nn.ReLU`` or ``Flatten`` layers (the last child is not visited,
            so it may be of any type).
        initial_bound: ``(lower, upper)`` pair of tensors bounding the input.

    Returns:
        A list of ``(lower, upper)`` pairs: ``initial_bound`` first, followed by
        the bounds after each visited layer.

    Raises:
        NotImplementedError: if a visited layer has an unsupported type.
            Silently passing the bounds through unchanged would make the
            result unsound.
    """
    l, u = initial_bound
    bounds = [initial_bound]

    # All layers except the last one.
    for layer in list(model.children())[:-1]:

        if isinstance(layer, nn.Linear):
            # Positive weights map lower->lower and upper->upper; negative
            # weights swap the roles of the two ends of the interval.
            w_pos = layer.weight.clamp(min=0)
            w_neg = layer.weight.clamp(max=0)
            l_ = l @ w_pos.t() + u @ w_neg.t()
            u_ = u @ w_pos.t() + l @ w_neg.t()
            if layer.bias is not None:  # layers built with bias=False
                l_ = l_ + layer.bias
                u_ = u_ + layer.bias

        elif isinstance(layer, nn.Conv2d):
            w_pos = layer.weight.clamp(min=0)
            w_neg = layer.weight.clamp(max=0)
            conv_kwargs = dict(bias=None, stride=layer.stride, padding=layer.padding,
                               dilation=layer.dilation, groups=layer.groups)
            l_ = (nn.functional.conv2d(l, w_pos, **conv_kwargs) +
                  nn.functional.conv2d(u, w_neg, **conv_kwargs))
            u_ = (nn.functional.conv2d(u, w_pos, **conv_kwargs) +
                  nn.functional.conv2d(l, w_neg, **conv_kwargs))
            if layer.bias is not None:  # layers built with bias=False
                channel_bias = layer.bias[None, :, None, None]
                l_ = l_ + channel_bias
                u_ = u_ + channel_bias

        elif isinstance(layer, nn.ReLU):
            # ReLU is monotone, so it can be applied to each end separately.
            l_ = l.clamp(min=0)
            u_ = u.clamp(min=0)

        elif isinstance(layer, Flatten):
            l_ = layer(l)
            u_ = layer(u)

        else:
            raise NotImplementedError(
                "bound_propagation does not support layers of type "
                + type(layer).__name__
            )

        bounds.append((l_, u_))
        l, u = l_, u_
    return bounds

In [7]:
def epoch_with_intervals(loader, model, epsilon_schedule, device, batch_counter, opt=None):
    """Run one pass over ``loader`` using interval bounds as the network input.

    For every batch the mask is scaled by the current epsilon, the interval
    ``[X - mask, X + mask]`` is clamped to ``[0, 1]`` and propagated through the
    model with ``bound_propagation``; the final lower and upper bounds are
    concatenated and classified by ``model.intervals_combined``.

    Args:
        loader: dataloader yielding ``(X, mask, y)`` batches.
        model: network exposing an ``intervals_combined`` layer.
        epsilon_schedule: sequence of epsilon values indexed by global batch
            number. Once the counter runs past the end, the final value is
            held instead of raising ``IndexError``.
        device: device the batches are moved to.
        batch_counter: global index of the first batch of this pass. It is
            only advanced locally, so the caller has to advance its own
            counter by the number of batches processed.
        opt: optimizer. When given, the model is updated after every batch;
            when ``None`` the pass is evaluation-only and runs without
            building the autograd graph.

    Returns:
        ``(error_rate, mean_loss)``, both averaged over ``len(loader.dataset)``.
    """
    is_training = opt is not None
    last_step = len(epsilon_schedule) - 1

    total_loss, total_err = 0., 0.
    for X, mask, y in loader:
        X, mask, y = X.float().to(device), mask.float().to(device), y.long().to(device)

        # Hold the final epsilon if the schedule is shorter than the run.
        epsilon = epsilon_schedule[min(batch_counter, last_step)]
        mask = mask * epsilon

        lower_bound = torch.clamp(X - mask, min=0.0, max=1.0)
        upper_bound = torch.clamp(X + mask, min=0.0, max=1.0)

        # Gradients are only needed when an optimizer will consume them.
        with torch.set_grad_enabled(is_training):
            bounds = bound_propagation(model, (lower_bound, upper_bound))
            bounds_concatenated = torch.cat((bounds[-1][0], bounds[-1][1]), dim=1)

            predictions = model.intervals_combined(bounds_concatenated)
            loss = nn.functional.cross_entropy(predictions, y)
        batch_counter += 1

        if is_training:
            opt.zero_grad()
            loss.backward()
            opt.step()

        total_err += (predictions.max(dim=1)[1] != y).sum().item()
        # The loss is a batch mean, so weight it by the batch size.
        total_loss += loss.item() * X.shape[0]
    return total_err / len(loader.dataset), total_loss / len(loader.dataset)

In [8]:
def generate_epsilon_schedule_MNIST(epsilon_train, ramp_up_steps=10000, constant_steps=50000):
    """Build a per-batch epsilon schedule: a linear ramp-up followed by a plateau.

    Args:
        epsilon_train: final epsilon value used during the plateau.
        ramp_up_steps: number of steps in the ramp-up phase, during which
            epsilon grows linearly from 0 (step ``i`` gets
            ``i * epsilon_train / ramp_up_steps``). ``0`` skips the ramp-up.
        constant_steps: number of steps for which epsilon stays at
            ``epsilon_train`` after the ramp-up.

    Returns:
        A list of ``ramp_up_steps + constant_steps`` epsilon values.

    Raises:
        ValueError: if ``ramp_up_steps`` or ``constant_steps`` is negative.
    """
    if ramp_up_steps < 0 or constant_steps < 0:
        raise ValueError("ramp_up_steps and constant_steps must be non-negative")

    if ramp_up_steps:
        step = epsilon_train / ramp_up_steps
        ramp_up = [i * step for i in range(ramp_up_steps)]  # ramp-up phase
    else:
        # No ramp-up requested; also avoids dividing by zero above.
        ramp_up = []

    return ramp_up + [epsilon_train] * constant_steps

# Model

In [9]:
class CNN_small(torch.nn.Module):
    """Small convolutional classifier with an extra layer for interval bounds.

    ``forward`` computes ten logits for an ordinary input. ``intervals_combined``
    is not used by ``forward``; it maps the concatenated lower and upper bounds
    of the ``last_linear`` output (10 + 10 values) to ten logits.

    The order in which the sub-modules are assigned below is the order reported
    by ``children()``, which ``bound_propagation`` relies on (it visits every
    child except the last), so ``intervals_combined`` must stay last.
    """

    def __init__(self):
        super().__init__()

        # Two convolutions; for a 28x28 input the spatial size goes 28 -> 13 -> 10.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=4, stride=2, padding=0)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=1, padding=0)
        self.relu2 = nn.ReLU()

        # Fully connected part, fed by the flattened 32 x 10 x 10 feature map.
        self.flat = Flatten()
        self.linear1 = nn.Linear(in_features=32 * 10 * 10, out_features=100)
        self.relu3 = nn.ReLU()
        self.last_linear = nn.Linear(in_features=100, out_features=10)

        # Layer for combining the upper and lower bounds of the intervals.
        self.intervals_combined = nn.Linear(in_features=10 + 10, out_features=10)

    def forward(self, x):
        """Return the ``last_linear`` logits for the input batch ``x``."""
        hidden_stack = (
            self.conv1, self.relu1,
            self.conv2, self.relu2,
            self.flat,
            self.linear1, self.relu3,
        )
        for layer in hidden_stack:
            x = layer(x)
        return self.last_linear(x)

# Training

In [10]:
EPSILON = 0.2    # final epsilon reached at the end of the ramp-up
NUM_EPOCHS = 20  # number of passes over the training set
epsilon_schedule = generate_epsilon_schedule_MNIST(EPSILON)
batch_counter = 0

model = CNN_small().to(device)
opt = optim.Adam(model.parameters(), lr=1e-3)

for t in range(NUM_EPOCHS):

    train_err, loss = epoch_with_intervals(train_dataloader, model, epsilon_schedule, device, batch_counter, opt)
    # epoch_with_intervals only advances its own local copy of the counter, so
    # move the global position in the schedule forward by one epoch of batches
    # (derived from the loader instead of a hard-coded 600).
    batch_counter += len(train_dataloader)
    print(train_err)
    print(loss)

0.2056
0.635987416903178
0.10241666666666667
0.30714157688121
0.08556666666666667
0.2541016906748215
0.07623333333333333
0.22268366693208616
0.0694
0.2030802104063332
0.0648
0.18687705857058365
0.05976666666666667
0.1714403611732026
0.05586666666666667
0.15989502164224784
0.05335
0.1505830330774188
0.049883333333333335
0.13918364140825967
0.047516666666666665
0.1328047810215503
0.04438333333333333
0.12284485846447447
0.041433333333333336
0.11607960573087137
0.0385
0.10768518236155311
0.03605
0.10069064659066498
0.03481666666666667
0.09657419217129548
0.0327
0.09121930407360196
0.029716666666666666
0.08349764007764558
0.027566666666666666
0.07460733168758452
0.024633333333333333
0.06894644362696757


# Testing

In [11]:
# Load the held-out arrays (images, labels and mask), read in the order listed.
test_data_after_imputation, test_labels, test_mask = (
    np.load(npy_path)
    for npy_path in (
        "./data/MNIST_test_data_imputation.npy",
        "./data/MNIST_test_labels.npy",
        "./data/MNIST_test_mask.npy",
    )
)

In [12]:
# Same (image, mask, label) layout as the training set; evaluation order is fixed.
test_dataset = TensorDataset(
    *map(torch.from_numpy, (test_data_after_imputation, test_mask, test_labels))
)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [13]:
# Evaluate at the epsilon the model was trained with: a constant schedule with
# one entry per test batch (instead of a duplicated 0.2 and an arbitrary length).
epsilon_schedule = [EPSILON] * len(test_dataloader)
batch_counter = 0
# No optimizer is passed, so the model is only evaluated, not updated.
test_err, loss = epoch_with_intervals(test_dataloader, model, epsilon_schedule, device, batch_counter)

print(test_err)

0.0804
