# Defensive Distillation

The authors of [Distillation as a Defense to Adversarial Perturbations against Deep Neural Networks](https://arxiv.org/pdf/1511.04508) describe four key steps for distilling an image classifier as a defense against adversarial examples:

1. Start with hard labels (the authors describe these as one-hot vectors, but that is not necessarily how they are stored in memory).
2. Train the initial model using a traditional procedure, but let the final layer have a softmax with a temperature greater than one.
3. Create a new training set from the outputs of this initial model. That is, instead of the hard labels used to train the initial model, the new training set uses the soft labels output by the initial model.
4. Train a new model from scratch using the same architecture but with the soft labels (and with the same temperature as before).
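
The four steps can be sketched end to end in PyTorch on toy data. This is a minimal illustration under our own assumptions, not the paper's code or the code used to train the pretrained model below: the random 20-feature dataset, the tiny `make_model` network, and the `fit` helper are all hypothetical. It relies on `F.cross_entropy` accepting class probabilities as targets, which requires PyTorch 1.10 or later.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
T = 30.0  # distillation temperature, used for both models

# Hypothetical toy data standing in for MNIST: 512 points, 20 features, 3 classes
X = torch.randn(512, 20)
y = (X[:, 0] > 0).long() + (X[:, 1] > 0).long()  # hard labels in {0, 1, 2}


def make_model():
    # Hypothetical small MLP; both models share this architecture
    return nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 3))


def fit(model, inputs, targets, T, steps=200):
    """Trains with a temperature-T softmax; targets may be class indices or probabilities."""
    opt = torch.optim.Adam(model.parameters(), lr=1e-2)
    for _ in range(steps):
        opt.zero_grad()
        loss = F.cross_entropy(model(inputs) / T, targets)
        loss.backward()
        opt.step()
    return model


# Steps 1-2: train the initial model on hard labels with a temperature-T softmax
teacher = fit(make_model(), X, y, T)

# Step 3: relabel the training set with the initial model's temperature-T probabilities
with torch.no_grad():
    soft_labels = F.softmax(teacher(X) / T, dim=1)

# Step 4: train a fresh model with the same architecture and temperature on the soft labels
student = fit(make_model(), X, soft_labels, T)
```

In this notebook, steps 1 and 2 have already been done for you; the tasks below cover steps 3 and 4.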


In this notebook, you will implement the final 2 steps and evaluate the results.

In [None]:
# IF YOU ARE IN COLAB OR HAVE NOT INSTALLED `xlab-security`
!pip install xlab-security # should not take more than a minute or two to run

In [None]:
from huggingface_hub import hf_hub_download

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Flatten, Linear, ReLU
from torch.utils.data import DataLoader, TensorDataset
import xlab

device = xlab.utils.get_best_device()

### Preliminaries: Train an image classifier on hard labels

We have already completed this step for you. We trained a simple MLP on the MNIST dataset for two epochs and achieved 94.90% accuracy on the test set. Importantly, we used a softmax with temperature ($T=30$) as described in our [explainer page](https://xlabaisecurity.com/adversarial/defensive-distillation/).


If you are interested, you can see the output of our training run [here](https://github.com/zroe1/xlab-ai-security/blob/main/models/defensive_distillation/training_output.txt) and the complete code [here](https://github.com/zroe1/xlab-ai-security/tree/main/models/defensive_distillation). In Task #4 of this notebook, you will train your own distilled model with the same architecture.



In [None]:
# skeleton of the model we trained
class FeedforwardMNIST(nn.Module):
    """Simple 4-layer MLP for MNIST classification"""

    def __init__(self, num_classes=10):
        super(FeedforwardMNIST, self).__init__()

        input_size = 28 * 28
        self.fc1 = Linear(input_size, 256)
        self.fc2 = Linear(256, 64)
        self.fc3 = Linear(64, num_classes)

        self.flatten = Flatten()
        self.relu = ReLU()

    def forward(self, x):
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model_path = hf_hub_download(
    repo_id="uchicago-xlab-ai-security/base-mnist-model",
    filename="mnist_mlp_temp_30.pth",
)
model = torch.load(model_path, map_location=device, weights_only=False)

In [None]:
print(model)

### Benchmark on PGD

We will benchmark the pretrained model you just loaded. Note that this model already has most of the resistance to adversarial attacks that you will see in this notebook, because it was trained with a temperature greater than one, which already accomplishes most of the smoothing. For comparison, the end of the notebook includes code for loading and benchmarking a model trained with a temperature of one, which you will see has almost no robustness against even 10 iterations of PGD.
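
One way to build intuition for this, assuming (as the benchmark cells presumably do) that the attack computes `CrossEntropyLoss` on the model's raw logits, i.e. at $T=1$: to become confident while training with a temperature-$T$ softmax, a model's logits must grow roughly $T$ times larger than usual. The gradient of cross-entropy with respect to the logits is $\mathrm{softmax}(z) - \mathrm{onehot}(y)$, which shrinks toward zero as the softmax saturates, leaving a gradient-based attack little signal. The sketch below uses made-up logits and simply multiplies them by 30 to mimic this; it is an illustration, not a measurement of the notebook's model.

```python
import torch
import torch.nn.functional as F

z = torch.tensor([2.0, 1.0, 0.5])  # hypothetical logits; the true class is 0
target = torch.tensor([0])

grad_size = {}
for scale in [1.0, 30.0]:
    # Scaling the logits mimics a model whose outputs grew large while training at T=30
    logits = (scale * z).unsqueeze(0).requires_grad_(True)
    loss = F.cross_entropy(logits, target)
    loss.backward()
    # d(loss)/d(logits) = softmax(logits) - onehot(target)
    grad_size[scale] = logits.grad.abs().max().item()
    print(f"scale={scale:>4}: loss={loss.item():.3e}, max |grad|={grad_size[scale]:.3e}")
```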

After you train your distilled model, you should see only a small reduction in attack success rate. This is expected! The authors of the original paper note that, in theory, the distilled model should converge to the original model, but empirically distillation can offer some additional protection.

If the original model is responsible for most of the protection, you may wonder why we don't have you train it yourself. The reason is that training it is nearly identical to what you will do in Task #4; the main difference is that the original model is trained on hard labels instead of soft labels. If you are interested, you should find it fairly easy to replace our pretrained model with your own implementation.
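
The benchmark cells call `xlab.utils.PGD` with `epsilon`, `alpha`, and `num_iters`. We have not reproduced that helper's internals here, so the following is only a generic sketch of untargeted $L_\infty$ PGD (without a random start) to show what those parameters control; `pgd_linf` is a hypothetical name, and it assumes batched inputs scaled to $[0, 1]$ and integer class labels.

```python
import torch
import torch.nn.functional as F
from torch import nn


def pgd_linf(model, x, y, epsilon, alpha, num_iters):
    """Untargeted L-infinity PGD (illustrative sketch).

    x: [B, ...] images in [0, 1]; y: [B] integer class labels.
    """
    x_adv = x.clone().detach()
    for _ in range(num_iters):
        x_adv.requires_grad_(True)
        loss = F.cross_entropy(model(x_adv), y)
        (grad,) = torch.autograd.grad(loss, x_adv)
        with torch.no_grad():
            x_adv = x_adv + alpha * grad.sign()  # step of size alpha that increases the loss
            x_adv = x + (x_adv - x).clamp(-epsilon, epsilon)  # project back into the epsilon-ball
            x_adv = x_adv.clamp(0.0, 1.0)  # keep a valid image
        x_adv = x_adv.detach()
    return x_adv


# Usage on a hypothetical untrained linear classifier
torch.manual_seed(0)
net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
x = torch.rand(2, 1, 28, 28)
y = torch.tensor([3, 7])
adv = pgd_linf(net, x, y, epsilon=12 / 255, alpha=2 / 255, num_iters=20)
```

Note that `grad.sign()` is zero wherever the gradient is exactly zero, so a saturated softmax can stall the attack.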



In [None]:
num_test_imgs = 100
imgs, ys = xlab.utils.load_mnist_test_samples(num_test_imgs)
loss_fn = torch.nn.CrossEntropyLoss()

num_success = 0

for img, y in zip(imgs, ys):
    adv_x = xlab.utils.PGD(
        model, loss_fn, img, y, epsilon=12 / 255, alpha=2 / 255, num_iters=20
    )
    adv_y = torch.argmax(model(adv_x))

    if adv_y.item() != y:
        num_success += 1

print(f"{(num_success / num_test_imgs) * 100:.4}% of attacks succeeded")

## Tasks #1 and #2: Create a new training set

We will train our distilled model on the soft labels produced by the pretrained model you loaded above.


However, the model you loaded outputs logits, not temperature-scaled softmax probabilities. To get the proper labels, you will first implement the function below, which applies a softmax with temperature.
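
Concretely, for a vector of logits $z$ and temperature $T$, the temperature-scaled softmax is the ordinary softmax applied to $z / T$:

$$
\mathrm{softmax}_T(z)_i = \frac{\exp(z_i / T)}{\sum_{j} \exp(z_j / T)}
$$

Setting $T = 1$ recovers the ordinary softmax, and larger values of $T$ produce a flatter distribution.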


<details>
<summary>🔐 <b>Solution for Task #1</b></summary>

```python
def softmax_with_temp(inputs, T):
    """Applies temperature-scaled softmax to inputs
    Args:
        inputs [batch, features]: Input logits tensor.
        T (float): Temperature scaling parameter.
    Returns:
        [batch, features]: Temperature-scaled softmax probabilities.
    """
    out = inputs / T
    return F.softmax(out, dim=1)
```
</details>

In [None]:
def softmax_with_temp(inputs, T):
    """Applies temperature-scaled softmax to inputs
    Args:
        inputs [batch, features]: Input logits tensor.
        T (float): Temperature scaling parameter.
    Returns:
        [batch, features]: Temperature-scaled softmax probabilities.
    """

    raise NotImplementedError("softmax_with_temp hasn't been implemented.")

In [None]:
_ = xlab.tests.distillation.task1(softmax_with_temp)
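
To see what the temperature does to a single row of logits, the short sketch below compares $T=1$ with the $T=30$ used in this notebook. The logits are made up for illustration, and the code calls `F.softmax` directly rather than your `softmax_with_temp`.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[8.0, 2.0, 1.0]])  # hypothetical logits for a single example

probs_by_T = {}
for T in [1.0, 30.0]:
    probs = F.softmax(logits / T, dim=1)  # temperature-scaled softmax
    probs_by_T[T] = probs
    print(f"T={T:>4}: {[round(p, 3) for p in probs[0].tolist()]}")
```

At $T=1$ almost all of the probability mass sits on the first class, while at $T=30$ the distribution is much flatter. In both cases the most likely class is unchanged, because dividing by a positive $T$ does not change the ordering of the logits.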

Next, you will compute the soft labels for each batch by calling the model and passing its outputs through `softmax_with_temp`.

<details>
<summary>🔐 <b>Solution for Task #2</b></summary>

```python
def get_batch_labels(batch, T):
    """Generates temperature-scaled probability distributions for a batch
    Args:
        batch [batch, *]: Input batch tensor.
        T (float): Temperature scaling parameter.
    Returns:
        [batch, num_classes]: Temperature-scaled softmax probabilities.
    """
    outs = model(batch)
    outs = softmax_with_temp(outs, T)
    return outs
```
</details>

In [None]:
def get_batch_labels(batch, T):
    """Generates temperature-scaled probability distributions for a batch
    Args:
        batch [batch, *]: Input batch tensor.
        T (float): Temperature scaling parameter.
    Returns:
        [batch, num_classes]: Temperature-scaled softmax probabilities.
    """

    raise NotImplementedError("get_batch_labels hasn't been implemented.")

In [None]:
_ = xlab.tests.distillation.task2(get_batch_labels, model)

In [None]:
train_loader = xlab.utils.get_mnist_train_loader(batch_size=128, shuffle=True)

In [None]:
imgs = []
soft_labels = []

with torch.no_grad():
    for x_batch, _ in train_loader:
        x_batch = x_batch.to(device)
        soft_labels_batch = get_batch_labels(x_batch, 30)

        imgs.append(x_batch.cpu())
        soft_labels.append(soft_labels_batch.cpu())

all_images = torch.cat(imgs, dim=0)
all_soft_labels = torch.cat(soft_labels, dim=0)
soft_label_dataset = TensorDataset(all_images, all_soft_labels)

batch_size = 128
soft_label_loader = DataLoader(
    soft_label_dataset,
    batch_size=batch_size,
    shuffle=True,
)

The cell above builds the new training set: it calls `get_batch_labels` with the pretrained model's temperature ($T=30$) on every training batch, then pairs each image with its soft label in a new `DataLoader`.

## Tasks #3 and #4: Train the distilled model

In [None]:
# a fresh, randomly initialized model with the same architecture as the pretrained one
distilled = FeedforwardMNIST().to(device)

The authors formalize the training objective of the distilled network as follows:

$$
\arg\min_{\theta_F} -\frac{1}{|\mathcal{X}|} \sum_{X \in \mathcal{X}} \sum_{i \in 0..N} F_i(X) \log F_i^d(X)
$$

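Here $F_i(X)$ is the soft label that the initial network assigns to class $i$, and $F^d_i(X)$ is the distilled network's temperature-scaled softmax output for class $i$. The loss for a single example is therefore just cross-entropy with soft labels: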

$$
\mathcal{L}(X) = -\sum_{i \in 0..N} F_i(X) \log F_i^d(X)
$$
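
As a sanity check on this loss, the sketch below computes the batch-averaged soft-label cross-entropy three ways on random, hypothetical tensors: directly from probabilities (as in Task #3), from logits via `F.log_softmax` (which avoids `log(0)` if a probability ever underflows), and with `F.cross_entropy`, which accepts class probabilities as the target in PyTorch 1.10 and later.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T = 30.0
logits = torch.randn(4, 10)  # hypothetical distilled-model logits
soft_labels = F.softmax(torch.randn(4, 10), dim=1)  # hypothetical soft labels F(X)

# Direct translation of the per-example loss, averaged over the batch
probs = F.softmax(logits / T, dim=1)
manual = -(soft_labels * torch.log(probs)).sum(dim=1).mean()

# Same quantity computed from logits with log_softmax
stable = -(soft_labels * F.log_softmax(logits / T, dim=1)).sum(dim=1).mean()

# Built-in version with probability targets (reduction="mean" averages over the batch)
builtin = F.cross_entropy(logits / T, soft_labels)
```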

<details>
<summary>🔐 <b>Solution for Task #3</b></summary>

```python
def cross_entropy_loss_soft(soft_labels, probs):
    """Computes cross-entropy loss between soft labels and predicted probabilities
    Args:
        soft_labels [batch, num_classes]: Target probability distributions.
        probs [batch, num_classes]: Predicted probability distributions.
    Returns:
        scalar tensor: Normalized cross-entropy loss value.
    """
    assert soft_labels.shape == probs.shape
    batch_size = soft_labels.shape[0]

    log_probs = torch.log(probs)
    return torch.sum(-1 * log_probs * soft_labels) / batch_size
```
</details>

In [None]:
def cross_entropy_loss_soft(soft_labels, probs):
    """Computes cross-entropy loss between soft labels and predicted probabilities
    Args:
        soft_labels [batch, num_classes]: Target probability distributions.
        probs [batch, num_classes]: Predicted probability distributions.
    Returns:
        scalar tensor: Normalized cross-entropy loss value.
    """

    assert soft_labels.shape == probs.shape
    raise NotImplementedError("cross_entropy_loss_soft hasn't been implemented.")

In [None]:
_ = xlab.tests.distillation.task3(cross_entropy_loss_soft)

Now you will fill in the function to train your distilled model. Most of this work has already been done for you.

Note that there are no tests for this task. You can evaluate the quality of your solution by benchmarking your model in the following section.

<details>
<summary>🔐 <b>Solution for Task #4</b></summary>

```python
def train(model, epochs, train_loader, T):
    """Trains model using soft label cross-entropy loss with temperature scaling
    Args:
        model: Neural network model to train.
        epochs (int): Number of training epochs.
        train_loader: DataLoader providing batches of images and soft labels.
        T (float): Temperature scaling parameter for softmax.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)

    for epoch in range(epochs):
        for i, (img, soft_label) in enumerate(train_loader):
            optimizer.zero_grad()

            # 1. get logits from model
            img, soft_label = img.to(device), soft_label.to(device)
            logits = model(img)

            # 2. process the logits with softmax_with_temp
            out = softmax_with_temp(logits, T)

            # 3. compute batch loss
            batch_loss = cross_entropy_loss_soft(soft_label, out)

            if i % 50 == 0:
                print(f"Epoch #{epoch + 1}: batch loss = {batch_loss.item():.4f}")

            batch_loss.backward()
            optimizer.step()
```
</details>

In [None]:
def train(model, epochs, train_loader, T):
    """Trains model using soft label cross-entropy loss with temperature scaling
    Args:
        model: Neural network model to train.
        epochs (int): Number of training epochs.
        train_loader: DataLoader providing batches of images and soft labels.
        T (float): Temperature scaling parameter for softmax.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)

    for epoch in range(epochs):
        for i, (img, soft_label) in enumerate(train_loader):
            optimizer.zero_grad()

            ######### YOUR CODE STARTS HERE #########
            # 1. get logits from model
            # 2. process the logits with softmax_with_temp
            # 3. compute batch loss
            ########## YOUR CODE ENDS HERE ##########

            if i % 50 == 0:
                print(f"Epoch #{epoch + 1}: batch loss = {batch_loss.item():.4f}")

            batch_loss.backward()
            optimizer.step()

In [None]:
train(distilled, 3, soft_label_loader, 30)

## Benchmarking our Defense

Below, you should see that the clean accuracy of the distilled model is comparable to the pretrained model's 94.90%, and that the attack success rate is somewhat below that of the pretrained model. As explained above, most of the protection comes from the original temperature smoothing, so do not be surprised if the improvement over the pretrained model is small.

In [None]:
clean_acc = xlab.utils.evaluate_mnist_accuracy(distilled)
print(f"Clean accuracy of distilled model: {clean_acc * 100:.2f}%")

In [None]:
num_test_imgs = 100
imgs, ys = xlab.utils.load_mnist_test_samples(num_test_imgs)
loss_fn = torch.nn.CrossEntropyLoss()

num_success = 0

for img, y in zip(imgs, ys):
    adv_x = xlab.utils.PGD(
        distilled, loss_fn, img, y, epsilon=12 / 255, alpha=2 / 255, num_iters=20
    )
    adv_y = torch.argmax(distilled(adv_x))

    if adv_y.item() != y:
        num_success += 1

print(f"{(num_success / num_test_imgs) * 100:.4}% of attacks succeeded")

## Benchmarking a Traditional Model

For reference, below you will see the clean accuracy and attack success rate of a model with the same architecture trained with a softmax temperature of 1. This benchmark uses 30 test images and 10 PGD iterations, a weaker attack than the one used above.

In [None]:
model_path = hf_hub_download(
    repo_id="uchicago-xlab-ai-security/base-mnist-model",
    filename="mnist_mlp_temp_1.pth",
)
standard = torch.load(model_path, map_location=device, weights_only=False)

In [None]:
clean_acc = xlab.utils.evaluate_mnist_accuracy(standard)
print(f"Clean accuracy of standard model: {clean_acc * 100:.2f}%")

num_test_imgs = 30
imgs, ys = xlab.utils.load_mnist_test_samples(num_test_imgs)
loss_fn = torch.nn.CrossEntropyLoss()

num_success = 0

for img, y in zip(imgs, ys):
    adv_x = xlab.utils.PGD(
        standard, loss_fn, img, y, epsilon=12 / 255, alpha=2 / 255, num_iters=10
    )
    adv_y = torch.argmax(standard(adv_x))

    if adv_y.item() != y:
        num_success += 1

print(f"{(num_success / num_test_imgs) * 100:.4}% of attacks succeeded")