In [9]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import time
import matplotlib.pyplot as plt

In [10]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
batch_size = 64
np.random.seed(42)
torch.manual_seed(42)

## Dataloaders
train_dataset = datasets.MNIST('mnist_data/', train=True, download=True, transform=transforms.Compose(
    [transforms.ToTensor()]
))
test_dataset = datasets.MNIST('mnist_data/', train=False, download=True, transform=transforms.Compose(
    [transforms.ToTensor()]
))

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [36]:
#Implement Neural Network
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 50),
            nn.ReLU(),
            nn.Linear(50, 50),
            nn.ReLU(),
            nn.Linear(50, 10),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.flatten(x)
        x = self.linear_relu_stack(x)
        return x
        
model = NeuralNetwork().to(device)
model.train()
    

NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=50, bias=True)
    (1): ReLU()
    (2): Linear(in_features=50, out_features=50, bias=True)
    (3): ReLU()
    (4): Linear(in_features=50, out_features=10, bias=True)
    (5): ReLU()
  )
)

In [31]:
def train_model(model, num_epochs):
    # TODO: implement this function that trains a given model on the MNIST dataset.
    # this is a general-purpose function for both standard training and adversarial training.
    # (toggle enable_defense parameter to switch between training schemes)
    model.train()
    optimizer = optim.Adadelta(model.parameters(), lr=1)
    for i in range(num_epochs):
        for idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = nn.CrossEntropyLoss()(output, target)
            loss.backward()
            optimizer.step()
            #if idx % 10 == 0:
                #print(loss)

In [32]:
train_model(model, 1)

In [33]:
def standard_test(model, device, test_loader):
    model.eval()
    correct = 0
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        pred = output.argmax(dim=1, keepdim=True) 
        correct += pred.eq(target.view_as(pred)).sum().item()

    print('\n Accuracy: {}/{} ({:.0f}%)\n'.format(
        correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [34]:
standard_test(model, device, test_loader)


 Accuracy: 7807/10000 (78%)



In [192]:
def interval_analysis(lower, upper):
    layers = list(model.linear_relu_stack)
    flatten_lower = lower.view(-1, 28 * 28)
    flatten_upper = upper.view(-1, 28 * 28)
    curr_lower = flatten_lower
    curr_upper = flatten_upper
    for layer in layers:
        if isinstance(layer, nn.Linear):
            curr_lower = torch.matmul(curr_lower, layer.weight.T)+layer.bias
            curr_upper = torch.matmul(curr_upper, layer.weight.T)+layer.bias
        if isinstance(layer, nn.ReLU):
            curr_lower = torch.relu(curr_lower)
            curr_upper = torch.relu(curr_upper)
    return (curr_lower, curr_upper)

In [207]:

for LInf in torch.linspace(0.01, 0.1, 10):
    total = 0
    count = 0
    for data, target in test_loader:
        for image in data:
            total+=1
            lower = torch.clamp(image - LInf, 0, 1)
            upper = torch.clamp(image + LInf, 0, 1)
            output = interval_analysis(lower, upper)
            prediction_lower = output[0].argmax(1, keepdim=True)
            prediction_upper = output[1].argmax(1, keepdim=True)
            #print(output[0][0][prediction_lower], prediction_upper)
            
            if output[0][0][prediction_lower] > output[1][0][prediction_lower]:
                count += 1
    print(1-count/total)

0.08799999999999997
0.0807
0.07469999999999999
0.06779999999999997
0.061899999999999955
0.05679999999999996
0.0524
0.04720000000000002
0.04369999999999996
0.03859999999999997
