In [39]:
%pip install tensorboardX
%pip install torch torchvision

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import time

from torchvision import datasets, transforms
from tensorboardX import SummaryWriter



In [40]:
use_cuda = False
device = torch.device("cuda" if use_cuda else "cpu")
batch_size = 64

np.random.seed(42)
torch.manual_seed(42)


## Dataloaders
train_dataset = datasets.MNIST('mnist_data/', train=True, download=True, transform=transforms.Compose(
    [transforms.ToTensor()]
))
test_dataset = datasets.MNIST('mnist_data/', train=False, download=True, transform=transforms.Compose(
    [transforms.ToTensor()]
))

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

## Simple NN. You can change this if you want.
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()


        self.layer1 = nn.Linear(28*28, 50)
        self.layer2 = nn.Linear(50, 50)
        self.layer3 = nn.Linear(50, 50)
        self.layer4 = nn.Linear(50, 10)

        self.nn = nn.Sequential(
          nn.Flatten(),
          self.layer1,
          nn.ReLU(),
          self.layer2,
          nn.ReLU(),
          self.layer3,
          nn.ReLU(),
          self.layer4,
        )

    def forward(self, x):
        return self.nn.forward(x)

    def box_forward_linear(self, Ltensor, Utensor, layer):
      Ltensor_forward = layer.forward(Ltensor)
      Utensor_forward = layer.forward(Utensor)

      return torch.minimum(Ltensor_forward, Utensor_forward), torch.maximum(Ltensor_forward, Utensor_forward)

    def box_forward_relu(self, Ltensor, Utensor):
      return nn.ReLU().forward(Ltensor), nn.ReLU().forward(Utensor)

    def box_forward(self, Ltensor, Utensor):
      Ltensor = nn.Flatten().forward(Ltensor)
      Utensor = nn.Flatten().forward(Utensor)

      Ltensor, Utensor = self.box_forward_linear(Ltensor, Utensor, self.layer1)
      Ltensor, Utensor = self.box_forward_relu(Ltensor, Utensor)

      Ltensor, Utensor = self.box_forward_linear(Ltensor, Utensor, self.layer2)
      Ltensor, Utensor = self.box_forward_relu(Ltensor, Utensor)

      Ltensor, Utensor = self.box_forward_linear(Ltensor, Utensor, self.layer3)
      Ltensor, Utensor = self.box_forward_relu(Ltensor, Utensor)

      Ltensor, Utensor = self.box_forward_linear(Ltensor, Utensor, self.layer4)

      return Ltensor, Utensor

class Normalize(nn.Module):
    def forward(self, x):
        return (x - 0.1307)/0.3081

# Add the data normalization as a first "layer" to the network
# this allows us to search for adverserial examples to the real image, rather than
# to the normalized image
model = nn.Sequential(Normalize(), Net())

model = model.to(device)
model.train()

Sequential(
  (0): Normalize()
  (1): Net(
    (layer1): Linear(in_features=784, out_features=50, bias=True)
    (layer2): Linear(in_features=50, out_features=50, bias=True)
    (layer3): Linear(in_features=50, out_features=50, bias=True)
    (layer4): Linear(in_features=50, out_features=10, bias=True)
    (nn): Sequential(
      (0): Flatten(start_dim=1, end_dim=-1)
      (1): Linear(in_features=784, out_features=50, bias=True)
      (2): ReLU()
      (3): Linear(in_features=50, out_features=50, bias=True)
      (4): ReLU()
      (5): Linear(in_features=50, out_features=50, bias=True)
      (6): ReLU()
      (7): Linear(in_features=50, out_features=10, bias=True)
    )
  )
)

In [62]:
def train_model(model, num_epochs, enable_defense=False, epsilon=None):
    learning_rate = 0.0001

    opt = optim.Adam(params=model.parameters(), lr=learning_rate)

    ce_loss = torch.nn.CrossEntropyLoss()

    tot_steps = 0

    for epoch in range(1,num_epochs+1):
        t1 = time.time()
        for batch_idx, (x_batch, y_batch) in enumerate(train_loader):

            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            tot_steps += 1
            opt.zero_grad()
            if not enable_defense:
              out = model(x_batch)
              batch_loss = ce_loss(out, y_batch)
              batch_loss.backward()
            else:
              xmin = x_batch - epsilon
              xmax = x_batch + epsilon

              outL, outU = model.box_forward(xmin, xmax)
              # max of incorrect and min of correct

              for i in range(len(outU)):
                outU[i, y_batch[i]] = outL[i, y_batch[i]]

              out = outU

              batch_loss = ce_loss(out, y_batch)
              batch_loss.backward()
            opt.step()

        tot_test, tot_acc = 0.0, 0.0
        for batch_idx, (x_batch, y_batch) in enumerate(test_loader):
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            out = model(x_batch)
            pred = torch.max(out, dim=1)[1]
            acc = pred.eq(y_batch).sum().item()
            tot_acc += acc
            tot_test += x_batch.size()[0]
        t2 = time.time()

        print('Epoch %d: Accuracy %.5lf [%.2lf seconds]' % (epoch, tot_acc/tot_test, t2-t1))

In [42]:
model = Net()
model = model.to(device)
train_model(model, 10, False)

Epoch 1: Accuracy 0.85570 [13.69 seconds]
Epoch 2: Accuracy 0.89720 [8.86 seconds]
Epoch 3: Accuracy 0.90810 [8.41 seconds]
Epoch 4: Accuracy 0.91590 [9.10 seconds]
Epoch 5: Accuracy 0.92170 [9.13 seconds]
Epoch 6: Accuracy 0.92790 [8.36 seconds]
Epoch 7: Accuracy 0.93070 [8.76 seconds]
Epoch 8: Accuracy 0.93580 [8.94 seconds]
Epoch 9: Accuracy 0.93850 [9.51 seconds]
Epoch 10: Accuracy 0.93930 [8.24 seconds]


The network is now implemented, we can run Box analysis using `network.box_forward()` and passing in the lower and upper values of the box as tensors. We define the robustness as =1 for an example if no adversarial examples exist within the L-infinity ball of size epsilon (as provable by Box), and =0 otherwise. We will look at the average robustness over the training dataset.

In [52]:
def measure_robustness(epsilon, model, test_loader):
  tot_acc = 0
  tot_test = 0

  for batch_idx, (x_batch, y_batch) in enumerate(test_loader):
    x_batch, y_batch = x_batch.to(device), y_batch.to(device)
    out = model(x_batch)
    pred = torch.max(out, dim=1)[1]

    Ltensor = x_batch - epsilon
    Utensor = x_batch + epsilon

    acc = pred.eq(y_batch).sum().item()

    Lbox, Ubox = model.box_forward(Ltensor, Utensor)

    Lcorrect = Lbox[:,torch.argmax(out, dim=1)][1]
    jcorrect = torch.argmax(out, dim=1)
    Uother = Ubox

    for i in range(len(Lcorrect)):
      proved = True
      for j in range(Uother.shape[1]):
        if jcorrect[i] == j:
          continue
        if Uother[i, j] >= Lcorrect[i]:
          proved = False
      if proved:
        tot_acc += 1

    tot_test += x_batch.size()[0]

  print('Epoch NA: Accuracy %.5lf' % (tot_acc/tot_test))

In [44]:
model.eval()
i = 0.01
while i <= 0.1:
  measure_robustness(i, model, train_loader)
  i += 0.01

Epoch NA: Accuracy 0.16957
Epoch NA: Accuracy 0.16123
Epoch NA: Accuracy 0.14982
Epoch NA: Accuracy 0.13988
Epoch NA: Accuracy 0.13195
Epoch NA: Accuracy 0.12767
Epoch NA: Accuracy 0.11668
Epoch NA: Accuracy 0.11543
Epoch NA: Accuracy 0.10300
Epoch NA: Accuracy 0.09772


We will now implement robust training using Box and compare the results.

In [63]:
i = 0.01
while i <= 0.1:
  model = Net()
  model = model.to(device)
  train_model(model, 10, True, i)
  model.eval()
  measure_robustness(i, model, train_loader)
  i += 0.01


Epoch 1: Accuracy 0.84830 [13.31 seconds]
Epoch 2: Accuracy 0.88320 [13.50 seconds]
Epoch 3: Accuracy 0.90150 [13.61 seconds]
Epoch 4: Accuracy 0.91090 [13.44 seconds]
Epoch 5: Accuracy 0.91960 [13.54 seconds]
Epoch 6: Accuracy 0.92490 [13.58 seconds]
Epoch 7: Accuracy 0.93200 [13.59 seconds]
Epoch 8: Accuracy 0.93540 [13.37 seconds]
Epoch 9: Accuracy 0.93830 [13.47 seconds]
Epoch 10: Accuracy 0.94010 [13.44 seconds]
Epoch NA: Accuracy 0.17613
Epoch 1: Accuracy 0.85940 [13.27 seconds]
Epoch 2: Accuracy 0.89490 [13.19 seconds]
Epoch 3: Accuracy 0.90910 [13.33 seconds]
Epoch 4: Accuracy 0.91460 [13.29 seconds]
Epoch 5: Accuracy 0.92150 [13.40 seconds]
Epoch 6: Accuracy 0.92770 [13.30 seconds]
Epoch 7: Accuracy 0.93110 [13.48 seconds]
Epoch 8: Accuracy 0.93410 [13.66 seconds]
Epoch 9: Accuracy 0.93870 [13.25 seconds]
Epoch 10: Accuracy 0.94110 [13.18 seconds]
Epoch NA: Accuracy 0.18318
Epoch 1: Accuracy 0.86540 [13.44 seconds]
Epoch 2: Accuracy 0.89980 [13.79 seconds]
Epoch 3: Accuracy 0.