In [None]:
%load_ext autoreload
%autoreload 2
import torch
from torch.optim import Adam
from matplotlib import pyplot as plt
from utils import get_mnist_data
from models import ConvNN
from training_and_evaluation import train_model, predict_model
from attacks import fast_gradient_attack
from  torch.nn.functional import cross_entropy
from typing import Tuple

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Part 2: Adversarial training
In this notebook we perform adversarial training on the convolutional neural network from Part 1.

In [None]:
# Seed torch's global RNG before anything stochastic happens, so that whatever
# draws from it afterwards (presumably weight initialisation and batch
# shuffling) is repeatable under "Restart Kernel -> Run All".
seed = 0
torch.manual_seed(seed)

# MNIST train / test splits as provided by the project helper.
mnist_trainset = get_mnist_data(train=True)
mnist_testset = get_mnist_data(train=False)

# Run on the GPU whenever one is available.
use_cuda = torch.cuda.is_available()

model = ConvNN()
if use_cuda:
    model = model.cuda()

# Training hyperparameters.
epochs = 2
batch_size = 128
test_batch_size = 1000  # feel free to change this
lr = 1e-3

opt = Adam(model.parameters(), lr=lr)

# Keyword arguments handed to the attack, both during adversarial training
# (as loss_args) and during the perturbed evaluation (as attack_args).
# NOTE(review): "2" presumably selects the L2 norm and epsilon is the
# perturbation budget -- confirm against attacks.fast_gradient_attack.
attack_args = {'norm': "2", "epsilon": 5}

### Loss function



In [None]:
def loss_function(x: torch.Tensor, y: torch.Tensor, model: torch.nn.Module,  **attack_args) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Loss function used for adversarial training. First computes adversarial examples on the input batch via
    fast_gradient_attack and then computes the logits and the loss on the adversarial examples.

    Parameters
    ----------
    x: torch.Tensor of shape [B, C, N, N], where B is the batch size, C is the number of channels, and N is the image width/height.
        The input batch to perturb. Note: gradient tracking is switched on for this tensor in place
        (x.requires_grad_()), because the attack needs gradients with respect to the input.
    y: torch.Tensor of shape [B].
        The class labels of the input batch (class indices as expected by cross_entropy).
        They are moved to the device of the logits if necessary.
    model: torch.nn.Module
        The classifier to be trained.
    attack_args: additional arguments passed to the adversarial attack function.
        Must contain the keys "epsilon" and "norm".

    Returns
    -------
    Tuple containing
        * loss_pert: torch.Tensor, scalar
            The cross-entropy loss on the adversarial examples, averaged over the batch
            (cross_entropy is called with its default reduction).
        * logits_pert: torch.Tensor, shape [B, K], where K is the number of classes.
            The logits obtained on the adversarial examples (moved to the CPU).

    Raises
    ------
    KeyError
        If "epsilon" or "norm" is missing from attack_args.
    """
    # The attack differentiates the loss with respect to the input, so the input must track gradients.
    x = x.requires_grad_()
    logits = model(x).cpu()

    # The logits were moved to the CPU above; the labels have to live on the same device,
    # otherwise the loss computation fails with a device mismatch when y sits on the GPU.
    y = y.to(logits.device)

    x_pert = fast_gradient_attack(logits=logits, x=x, y=y, epsilon=attack_args["epsilon"],
                                  norm=attack_args["norm"], loss_fn=cross_entropy)

    # Clear parameter gradients the attack may have accumulated so they do not leak into the training step.
    model.zero_grad()

    # Treat the adversarial examples as constants: the training loss must only update the model
    # parameters and must not backpropagate into the input through the attack.
    x_pert = x_pert.detach()

    logits_pert = model(x_pert).cpu()
    loss_pert = cross_entropy(logits_pert, y)
    return loss_pert, logits_pert

In [None]:
# Adversarial training: train_model is given loss_function as the per-batch
# loss and forwards attack_args to it via loss_args.
losses, accuracies = train_model(
    model,
    mnist_trainset,
    batch_size=batch_size,
    loss_function=loss_function,
    optimizer=opt,
    loss_args=attack_args,
    epochs=epochs,
)

HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

In [None]:
# Persist the adversarially trained weights (state dict only, not the module object).
torch.save(
    obj=model.state_dict(),
    f="models/adversarial_training.checkpoint",
)

In [None]:
# Training curves side by side: loss on the left, accuracy on the right.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(10, 3))

loss_ax.plot(losses)
loss_ax.set_xlabel("Iteration")
loss_ax.set_ylabel("Training Loss")

acc_ax.plot(accuracies)
acc_ax.set_xlabel("Iteration")
acc_ax.set_ylabel("Training Accuracy")

plt.show()

In [None]:
# Evaluate on the unmodified test set (no attack function supplied).
clean_accuracy = predict_model(
    model,
    mnist_testset,
    batch_size=test_batch_size,
    attack_function=None,
)

In [None]:
# Evaluate on the test set with fast_gradient_attack as the attack function,
# using the same attack settings as during adversarial training.
perturbed_accuracy = predict_model(
    model,
    mnist_testset,
    batch_size=test_batch_size,
    attack_function=fast_gradient_attack,
    attack_args=attack_args,
)

In [None]:
# Display the accuracy measured without an attack. The trailing comment records
# reference values from earlier runs ("ours" vs. the provided "template").
clean_accuracy # ours: 0.6959999799728394, template: 0.6869999766349792

In [None]:
# Display the accuracy measured under the fast gradient attack. The trailing comment
# records reference values from earlier runs ("ours" vs. the provided "template").
perturbed_accuracy # ours: 0.8047000169754028, template: 0.9200999736785889 