# Defensive Distillation

The authors of [Distillation as a Defense to Adversarial Perturbations against Deep Neural Networks](https://arxiv.org/pdf/1511.04508#page=16&zoom=100,416,109) describe four key ideas behind distilling image classifiers as a defense against adversarial examples.

1. Start with hard labels (the authors describe these as a series of one-hot vectors, but that is not necessarily how they would be stored in memory).
2. Train the initial model using a traditional procedure, but let the final layer have a softmax with a temperature greater than one.
3. Create a new training set using the outputs of this initial model. That is, instead of starting with hard labels as the initial model did, we start with soft labels output by the initial model.
4. Train a new model from scratch using the same architecture but with the soft labels (and with the same temperature as before).
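
Steps 3 and 4 amount to swapping the usual hard-label cross-entropy for a cross-entropy against soft labels. The following is a minimal sketch of such a loss, not the paper's or `xlab`'s implementation: the name `soft_label_cross_entropy` is hypothetical, and averaging over the batch is an assumption.

```python
import torch
import torch.nn.functional as F


def soft_label_cross_entropy(logits, soft_labels, T):
    """Cross-entropy between soft labels and the temperature-T softmax of logits.

    Hypothetical helper for illustration; averages over the batch (an assumption).
    """
    log_probs = F.log_softmax(logits / T, dim=1)
    return -(soft_labels * log_probs).sum(dim=1).mean()


# Made-up logits for two examples and three classes
logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
hard = torch.tensor([0, 2])
one_hot = F.one_hot(hard, num_classes=3).float()

# With one-hot labels and T = 1, the soft-label loss reduces to ordinary cross-entropy
loss_soft = soft_label_cross_entropy(logits, one_hot, T=1.0)
loss_hard = F.cross_entropy(logits, hard)
```

Checking the loss against `F.cross_entropy` in the one-hot case is a quick way to confirm that the soft-label version generalizes the hard-label one.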


In this notebook, you will implement the final two steps and evaluate the results.

In [10]:
from huggingface_hub import hf_hub_download

import torch
from torch import nn
import torch.nn.functional as F
import xlab

device = xlab.utils.get_best_device()

### Step 1: Train an image classifier on hard labels

We have already completed this step for you. We trained a simple MLP on the MNIST dataset for two epochs and achieved 94.90% accuracy on the test set. Importantly, we use a softmax with temperature ($T=20$), as described on our [explainer page](https://xlabaisecurity.com/adversarial/defensive-distillation/).
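
For reference, a softmax with temperature $T$ applied to logits $z_1, \dots, z_N$ is commonly written as follows. The symbols $z_i$ and $N$ are notation chosen here and may differ from the explainer page.

$$
\text{softmax}_T(z)_i = \frac{e^{z_i / T}}{\sum_{j=1}^{N} e^{z_j / T}}
$$

With $T = 1$ this is the ordinary softmax. Larger values of $T$ spread probability more evenly across the classes.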


If interested, you can see the output of our training run [here](https://github.com/zroe1/xlab-ai-security/blob/main/models/defensive_distillation/training_output.txt) and the complete code [here](https://github.com/zroe1/xlab-ai-security/tree/main/models/defensive_distillation). You will train your own version of this model for step 5 of this notebook.



In [6]:
# skeleton of the model we trained
class FeedforwardMNIST(nn.Module):
    """Simple 4-layer MLP for MNIST classification"""
    def __init__(self, num_classes=10):
        super(FeedforwardMNIST, self).__init__()
        
        input_size = 28 * 28
        self.fc1 = nn.Linear(input_size, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, num_classes)
        
        self.flatten = nn.Flatten()
        self.relu = nn.ReLU()
        
    
    def forward(self, x):
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model_path = hf_hub_download(repo_id="uchicago-xlab-ai-security/base-mnist-model", filename="mnist_mlp.pth")
model = torch.load(model_path, map_location=device, weights_only=False)

In [7]:
print(model)

FeedforwardMNIST(
  (fc1): Linear(in_features=784, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=10, bias=True)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (relu): ReLU()
)


## Step 3: Create new training set

We will be training our distilled model on the labels of the pretrained model you have loaded above. 


The model you loaded, however, outputs logits, not a temperature-smoothed softmax. To get the proper labels, you will first have to implement the function below, which returns a softmax with temperature.





In [8]:
train_loader = xlab.utils.get_mnist_train_loader(batch_size=128, shuffle=True)

In [36]:
def softmax_with_temp(inputs, T):
    out = inputs / T
    return F.softmax(out, dim=1)
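
To see what the temperature does, the sketch below applies the same function to one set of made-up logits at $T=1$ and $T=20$. The logit values are arbitrary and chosen only for illustration. The function is repeated so the snippet runs on its own.

```python
import torch
import torch.nn.functional as F


def softmax_with_temp(inputs, T):
    # Divide the logits by T before the softmax; dim=1 is the class dimension
    return F.softmax(inputs / T, dim=1)


# Made-up logits for one example with three classes
logits = torch.tensor([[4.0, 1.0, -2.0]])

probs_t1 = softmax_with_temp(logits, T=1)
probs_t20 = softmax_with_temp(logits, T=20)

print(probs_t1)   # sharply peaked on the first class
print(probs_t20)  # much closer to uniform, same ranking of classes
```

Both outputs are valid probability distributions with the same most likely class. The higher temperature only flattens the distribution.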

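The first step in constructing this new dataset is to implement `get_batch_labels` by calling the pretrained model and applying a softmax with temperature T to its logits.

In [35]:
def get_batch_labels(batch, T):
    # The pretrained model returns raw logits
    outs = model(batch)
    # Convert the logits to temperature-smoothed soft labels
    outs = softmax_with_temp(outs, T)
    return outs

In [38]:
# Sanity check: compute soft labels for a single batch
for x, y in train_loader:
    x, y = x.to(device), y.to(device)
    batch = get_batch_labels(x, 20)
    break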
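
Extending `get_batch_labels` from a single batch to the whole training set might look like the sketch below. It is one possible approach under stated assumptions, not the notebook's reference solution. The helper `build_soft_label_dataset` is hypothetical, and a small stand-in model with random data replaces the pretrained model and MNIST loader so the snippet runs without downloads.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def build_soft_label_dataset(model, loader, T, device):
    """Run `model` over `loader` and pair each input with its soft label.

    Hypothetical helper for illustration only.
    """
    model.eval()
    inputs, soft_labels = [], []
    with torch.no_grad():  # the labels are training targets, so no gradients are needed
        for x, _ in loader:  # the original hard labels are ignored
            x = x.to(device)
            probs = F.softmax(model(x) / T, dim=1)
            inputs.append(x.cpu())
            soft_labels.append(probs.cpu())
    return TensorDataset(torch.cat(inputs), torch.cat(soft_labels))


# Stand-in model and random data (assumptions, not the real MNIST setup)
device = torch.device("cpu")
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).to(device)
toy_data = TensorDataset(torch.randn(32, 1, 28, 28), torch.randint(0, 10, (32,)))
toy_loader = DataLoader(toy_data, batch_size=8, shuffle=True)

soft_dataset = build_soft_label_dataset(toy_model, toy_loader, T=20, device=device)
soft_loader = DataLoader(soft_dataset, batch_size=8, shuffle=True)
```

Storing each input next to its soft label keeps the pairs aligned even though the source loader shuffles, and the resulting `TensorDataset` can be wrapped in a new `DataLoader` for training the distilled model.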