Adversarial Examples Lab
==============================

This lab shows how to create adversarial examples against an image classifier. Specifically, we use the Fast Gradient Sign Attack to exploit the vulnerability of an MNIST classifier.

This lab demonstrates that adding imperceptible perturbations to an image can drastically change a model's predictions.

Credit: [PyTorch Tutorial](https://pytorch.org/tutorials/beginner/fgsm_tutorial.html)

### Threat Model ###

A white-box attack assumes the attacker has full knowledge and access to the model, including architecture, inputs, outputs, and weights. A black-box attack assumes the attacker only has access to the inputs and outputs of the model, and knows nothing about the underlying architecture or weights. 

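In this lab, we focus on a white-box attack whose goal is misclassification: the attacker only wants the model's prediction to be wrong and does not care which incorrect class it becomes.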

### Attack Method ###

We use the Fast Gradient Sign Attack ([Paper](https://arxiv.org/abs/1412.6572), [Blog](https://jaketae.github.io/study/fgsm/)) as our attack method. The famous panda example below comes from the same paper. For details of the method, see the paper or the blog linked above.


![](https://pytorch.org/tutorials/_static/img/fgsm_panda_image.png)
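The paper's intuition for why such a small, sign-based step works can be sketched with a linear score $\mathbf{w}^\top \mathbf{x}$. Moving every input by $\epsilon \cdot \operatorname{sign}(\mathbf{w})$ raises the score by exactly $\epsilon \lVert \mathbf{w} \rVert_1$, which grows with the input dimension even though no single pixel moves by more than $\epsilon$. The NumPy sketch below checks this on a random 784-dimensional (28×28) weight vector, ignoring clipping; the weights, input, and seed are arbitrary illustration values, not part of the lab.

```python
import numpy as np

rng = np.random.default_rng(0)          # arbitrary seed, illustration only
d = 28 * 28                             # one MNIST-sized input, flattened
w = rng.normal(size=d)                  # hypothetical linear weights
x = rng.uniform(0.0, 1.0, size=d)       # hypothetical clean input in [0, 1]
eps = 0.1

# The gradient of w @ x w.r.t. x is w, so the FGSM-style step is eps * sign(w)
x_adv = x + eps * np.sign(w)

max_pixel_change = np.max(np.abs(x_adv - x))
score_change = w @ x_adv - w @ x
print(f"max per-pixel change: {max_pixel_change:.3f}")
print(f"score change: {score_change:.2f}, eps * ||w||_1 = {eps * np.abs(w).sum():.2f}")
```

Each pixel moves by only 0.1, yet the score shifts by the sum of all 784 weight magnitudes scaled by 0.1, which is the effect FGSM exploits in a high-dimensional, locally linear model.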

Let's start!
=========================

## Import Relevant Packages

In [None]:
%matplotlib inline
import gdown
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm import tqdm


## Define inputs

-   `epsilons` - List of epsilon values to use for the run. These are the $\epsilon$ described in the FGSM attack method, and they should stay within $[0,1]$ because the perturbation is applied to pixel values in that range. Note that the $\epsilon=0$ case represents the original test accuracy, with no attack.
-   `pretrained_model` - Path to the pretrained MNIST model from the PyTorch tutorial, which is downloaded into `./pretrained_model/`.
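Because the attack in this lab perturbs images after undoing normalization, each $\epsilon$ is measured in raw pixel units on the $[0,1]$ scale. As a rough sketch, the arithmetic below converts each value to the normalized scale the model actually sees, by dividing by the MNIST standard deviation 0.3081 that the `Normalize` transform in this notebook uses.

```python
epsilons = [0, .05, .1, .15, .2, .25, .3]
MNIST_STD = 0.3081  # std passed to transforms.Normalize in this lab

normalized_eps = []
for eps in epsilons:
    # A pixel-space step of eps becomes eps / std after (x - mean) / std
    normalized_eps.append(eps / MNIST_STD)
    print(f"eps = {eps:.2f} pixel units -> {eps / MNIST_STD:.3f} normalized units")
```

So the largest value, $\epsilon = 0.3$, shifts each pixel by 30% of the full intensity range, which is close to one standard deviation of the normalized input.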


In [None]:
epsilons = [0, .05, .1, .15, .2, .25, .3]
pretrained_model = "./pretrained_model/lenet_mnist_model.pth"

os.makedirs('./pretrained_model', exist_ok=True)
url = 'https://drive.google.com/uc?id=1HJV2nUHJqclXQ8flKvcWmjZ-OU5DGatl'
gdown.download(url, pretrained_model, quiet=False)
# Set random seed for reproducibility
torch.manual_seed(42)

## Model Under Attack

The model under attack is the same MNIST model from [pytorch/examples/mnist](https://github.com/pytorch/examples/tree/master/mnist).

In this section, we define the model and dataloader, then initialize the model and load the pretrained weights.


In [None]:
# LeNet Model definition
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

# MNIST Test dataset and dataloader declaration
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, download=True, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])),
    batch_size=1, shuffle=True)

# The torch.accelerator API requires a recent PyTorch release (2.6 or later)
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

# Initialize the network
model = Net().to(device)

# Load the pretrained model
model.load_state_dict(torch.load(pretrained_model, map_location=device, weights_only=True))

# Set the model in evaluation mode. In this case this is for the Dropout layers
model.eval()
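The `fc1` input size of 9216 follows from the convolution arithmetic on a 28×28 image. Each unpadded 3×3 convolution with stride 1 shrinks the side by 2 (28 → 26 → 24), the 2×2 max pool halves it (24 → 12), and 64 channels × 12 × 12 = 9216. A quick sketch traces a dummy input through the same layer configuration (dropout omitted, since it does not change shapes):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

conv1 = nn.Conv2d(1, 32, 3, 1)
conv2 = nn.Conv2d(32, 64, 3, 1)

x = torch.zeros(1, 1, 28, 28)     # one dummy MNIST-sized image
x = F.relu(conv1(x))
print(x.shape)                    # torch.Size([1, 32, 26, 26])
x = F.relu(conv2(x))
print(x.shape)                    # torch.Size([1, 64, 24, 24])
x = F.max_pool2d(x, 2)
print(x.shape)                    # torch.Size([1, 64, 12, 12])
x = torch.flatten(x, 1)
print(x.shape)                    # torch.Size([1, 9216])
```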

## FGSM Attack

Now we can define the function that creates adversarial examples by
perturbing the original inputs. The `fgsm_attack` function takes three
inputs: *image* is the original clean image ($\mathbf{x}$), *epsilon* is the
pixel-wise perturbation amount ($\epsilon$), and *data\_grad* is the
gradient of the loss with respect to the input image
($\nabla_{\mathbf{x}} J(\boldsymbol{\theta}, \mathbf{x}, y)$). The function then
creates the perturbed image as

$$\mathit{perturbed\_image} = \mathit{image} + \epsilon \cdot \operatorname{sign}(\mathit{data\_grad}) = \mathbf{x} + \epsilon \cdot \operatorname{sign}\left(\nabla_{\mathbf{x}} J(\boldsymbol{\theta}, \mathbf{x}, y)\right)$$

Finally, to preserve the original range of the data, the perturbed image
is clipped to $[0,1]$.
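As a small worked example of the formula (the three pixel values, the gradient values, and $\epsilon = 0.1$ are arbitrary): a pixel with a positive gradient moves up by $\epsilon$, a pixel with a negative gradient moves down, a zero gradient leaves the pixel unchanged because $\operatorname{sign}(0) = 0$, and clipping pulls out-of-range results back into $[0,1]$.

```python
import torch

x = torch.tensor([0.00, 0.50, 0.98])    # clean pixel values in [0, 1]
grad = torch.tensor([-0.2, 0.0, 3.1])   # hypothetical loss gradient w.r.t. x
eps = 0.1

step = eps * grad.sign()                # sign gives [-1, 0, 1]
x_adv = torch.clamp(x + step, 0, 1)     # unclipped values are about [-0.10, 0.50, 1.08]
print(x_adv)                            # expected values: [0.0, 0.5, 1.0]
```

Note that clipping can only shrink each change, so every pixel still moves by at most $\epsilon$.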


In [None]:
# FGSM attack code
def fgsm_attack(image, epsilon, data_grad):
    # Collect the element-wise sign of the data gradient
    sign_data_grad = data_grad.sign()
    # Create the perturbed image by adjusting each pixel of the input image
    perturbed_image = image + epsilon*sign_data_grad
    # Adding clipping to maintain [0,1] range
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    # Return the perturbed image
    return perturbed_image

# restores the tensors to their original scale
def denorm(batch, mean=[0.1307], std=[0.3081]):
    """
    Convert a batch of tensors to their original scale.

    Args:
        batch (torch.Tensor): Batch of normalized tensors.
        mean (torch.Tensor or list): Mean used for normalization.
        std (torch.Tensor or list): Standard deviation used for normalization.

    Returns:
        torch.Tensor: batch of tensors without normalization applied to them.
    """
    if isinstance(mean, list):
        mean = torch.tensor(mean).to(device)
    if isinstance(std, list):
        std = torch.tensor(std).to(device)

    return batch * std.view(1, -1, 1, 1) + mean.view(1, -1, 1, 1)

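We define a wrapper function that launches the FGSM attack against the given model on every image in the test set. Because the model expects normalized inputs, the wrapper denormalizes each image back to $[0,1]$ before calling `fgsm_attack` and re-applies normalization before re-classifying it. It returns the accuracy under attack and up to five adversarial examples to visualize later.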
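`fgsm_attack_launcher` takes the gradient with respect to the *normalized* input but applies the step to the *denormalized* image. This is consistent for FGSM: with $\tilde{\mathbf{x}} = (\mathbf{x} - \mu)/\sigma$ and $\sigma > 0$, the chain rule gives $\nabla_{\mathbf{x}} J = \nabla_{\tilde{\mathbf{x}}} J / \sigma$, so both gradients have the same sign. A minimal check, using an arbitrary stand-in loss (a sum of sines) rather than the lab's model:

```python
import torch

torch.manual_seed(0)
MEAN, STD = 0.1307, 0.3081                    # MNIST stats used in this lab

x_pixel = torch.rand(1, 1, 28, 28, requires_grad=True)   # image in [0, 1]
x_norm = (x_pixel - MEAN) / STD
x_norm.retain_grad()                          # keep the grad of this non-leaf tensor

w = torch.randn(1, 1, 28, 28)                 # stand-in for a model plus loss
loss = torch.sin(w * x_norm).sum()
loss.backward()

# Chain rule: grad w.r.t. pixels = grad w.r.t. normalized input / STD
print(torch.allclose(x_pixel.grad, x_norm.grad / STD))        # expected: True
print(torch.equal(x_pixel.grad.sign(), x_norm.grad.sign()))   # expected: True
```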


In [None]:
def fgsm_attack_launcher(model, device, test_loader, epsilon):

    # Accuracy counter
    correct = 0
    adv_examples = []

    # Loop over all examples in test set
    for data, target in tqdm(test_loader, desc=f"Launching attack with epsilon={epsilon}"):

        # Send the data and label to the device
        data, target = data.to(device), target.to(device)

        # Set requires_grad attribute of tensor. Important for Attack
        data.requires_grad = True

        # Forward pass the data through the model
        output = model(data)
        init_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability

        # If the initial prediction is wrong, don't bother attacking, just move on
        if init_pred.item() != target.item():
            continue

        # Calculate the loss
        loss = F.nll_loss(output, target)

        # Zero all existing gradients
        model.zero_grad()

        # Calculate gradients of model in backward pass
        loss.backward()

        # Collect `data_grad`, the gradient of the loss w.r.t. the input
        data_grad = data.grad.data

        # Restore the data to its original scale
        data_denorm = denorm(data)

        # Call FGSM Attack
        perturbed_data = fgsm_attack(data_denorm, epsilon, data_grad)

        # Reapply normalization
        perturbed_data_normalized = transforms.Normalize((0.1307,), (0.3081,))(perturbed_data)

        # Re-classify the perturbed image
        output = model(perturbed_data_normalized)

        # Check for success
        final_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
        if final_pred.item() == target.item():
            correct += 1
            # Special case for saving 0 epsilon examples
            if epsilon == 0 and len(adv_examples) < 5:
                adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
                adv_examples.append( (init_pred.item(), final_pred.item(), adv_ex) )
        else:
            # Save some adv examples for visualization later
            if len(adv_examples) < 5:
                adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
                adv_examples.append( (init_pred.item(), final_pred.item(), adv_ex) )

    # Calculate final accuracy for this epsilon
    final_acc = correct/float(len(test_loader))
    print(f"Epsilon: {epsilon}\tTest Accuracy = {correct} / {len(test_loader)} = {final_acc}")

    # Return the accuracy and the saved adversarial examples
    return final_acc, adv_examples

## Launch Attack

In this section, we run a full test step for each epsilon value in the *epsilons* input. For each epsilon we also save the final accuracy and some successful adversarial examples to be plotted in the coming sections.


In [None]:
accuracies = []
examples = []

# Launch attack for each epsilon
for eps in epsilons:
    acc, ex = fgsm_attack_launcher(model, device, test_loader, eps)
    accuracies.append(acc)
    examples.append(ex)

## Visualization

- Accuracy vs Epsilon

First, let's look into the accuracy versus epsilon plot. As epsilon increases, we expect the test accuracy to decrease.
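The launcher skips images the model already misclassifies and counts them as failures, so the $\epsilon = 0$ accuracy is (up to floating-point round-off in the denormalize/normalize round trip) the fraction of test images that start out correct. One way to read the curve, sketched below as a hypothetical helper `attack_success_rates`, is to divide each accuracy by that baseline: one minus the ratio is the fraction of initially correct images the attack flips. It assumes `epsilons[0] == 0`, as in this lab.

```python
def attack_success_rates(accuracies):
    """Hypothetical helper: fraction of initially correct images flipped at each epsilon.

    Assumes accuracies[0] is the epsilon = 0 accuracy from fgsm_attack_launcher.
    """
    baseline = accuracies[0]
    return [1.0 - acc / baseline for acc in accuracies]
```

In this notebook, `attack_success_rates(accuracies)` returns one value per entry of `epsilons`, starting at 0 for $\epsilon = 0$.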

In [None]:
plt.figure(figsize=(5,5))
plt.plot(epsilons, accuracies, "*-")
plt.yticks(np.arange(0, 1.1, step=0.1))
plt.xticks(np.arange(0, .35, step=0.05))
plt.title("Accuracy vs Epsilon")
plt.xlabel("Epsilon")
plt.ylabel("Accuracy")
plt.show()

- Sample Adversarial Examples

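There is a tradeoff: as epsilon increases, test accuracy decreases, but the perturbations also become more easily perceptible. In each subplot below, the title shows the original prediction followed by the prediction on the perturbed image.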

In [None]:
# Plot several adversarial examples at each epsilon (one row per epsilon)
n_cols = max(len(ex) for ex in examples)
plt.figure(figsize=(8,10))
for i in range(len(epsilons)):
    for j in range(len(examples[i])):
        # Index the grid by (row, column) so each epsilon stays on its own row
        # even when fewer than n_cols examples were saved for it
        plt.subplot(len(epsilons), n_cols, i * n_cols + j + 1)
        plt.xticks([], [])
        plt.yticks([], [])
        if j == 0:
            plt.ylabel(f"Eps: {epsilons[i]}", fontsize=14)
        orig, adv, ex = examples[i][j]
        plt.title(f"{orig} -> {adv}")
        plt.imshow(ex, cmap="gray")
plt.tight_layout()
plt.show()