# Defensive Distillation

The authors of [Distillation as a Defense to Adversarial Perturbations against Deep Neural Networks](https://arxiv.org/pdf/1511.04508#page=16&zoom=100,416,109) give a five-step process for distilling image classifiers as a defense against adversarial examples. In this notebook, you will implement the final four steps and evaluate the results.
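
The paper's distillation procedure relies on a softmax evaluated at a temperature `T`, i.e. `softmax(z / T)` over the logits `z`. The sketch below shows that operation in PyTorch. The helper name `softmax_with_temperature` and the example logits are illustrative assumptions, not part of `xlab` or the paper's code. Raising `T` flattens the probability vector while keeping the predicted class the same.

```python
import torch

def softmax_with_temperature(logits: torch.Tensor, T: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: softmax of logits / T over the last (class) dimension."""
    return torch.softmax(logits / T, dim=-1)

# Illustrative logits for a single input over three classes.
logits = torch.tensor([[8.0, 2.0, 1.0]])

for T in (1.0, 20.0, 100.0):
    # Larger T spreads probability mass more evenly across classes.
    print(T, softmax_with_temperature(logits, T))
```

At `T = 1` almost all mass sits on class 0; at `T = 100` the three probabilities are much closer together, which is the "softer" label information distillation is meant to exploit.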

In [5]:
from huggingface_hub import hf_hub_download

import torch
from torch import nn
import xlab

device = xlab.utils.get_best_device()

### Step 1: Train an image classifier on hard labels

We have already completed this step for you. We trained a simple MLP on the MNIST dataset for two epochs and achieved 96.86% accuracy on the test set. If you are interested, you can see the output of our training run [here](https://github.com/zroe1/xlab-ai-security/blob/main/models/defensive_distillation/training_output.txt) and the complete code [here](https://github.com/zroe1/xlab-ai-security/tree/main/models/defensive_distillation). You will train your own version of this model in step 5 of this notebook.
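
"Training on hard labels" means minimizing cross-entropy against integer class indices. The sketch below illustrates that loop for an MLP with the same layer sizes as the one we trained. It is not our actual training script (that is linked above): the Adam optimizer, learning rate, step count, and the random stand-in images and labels are all assumptions chosen so the example runs without downloading MNIST.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)

# Same layer sizes as FeedforwardMNIST (784 -> 256 -> 64 -> 10).
sketch_model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 256), nn.ReLU(),
    nn.Linear(256, 64), nn.ReLU(),
    nn.Linear(64, 10),
)

# Stand-in data (assumption): random "images" and integer class labels.
# Swap in real MNIST batches, e.g. from torchvision.datasets.MNIST.
images = torch.rand(128, 1, 28, 28)
labels = torch.randint(0, 10, (128,))

optimizer = torch.optim.Adam(sketch_model.parameters(), lr=1e-3)

def train_step(x, y):
    sketch_model.train()
    optimizer.zero_grad()
    logits = sketch_model(x)
    loss = F.cross_entropy(logits, y)  # hard labels: one class index per example
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def accuracy(x, y):
    sketch_model.eval()
    preds = sketch_model(x).argmax(dim=1)
    return (preds == y).float().mean().item()

losses = [train_step(images, labels) for _ in range(100)]
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}, "
      f"accuracy {accuracy(images, labels):.2%}")
```

On real data you would iterate over a `DataLoader` for each epoch and report `accuracy` on the held-out test set rather than on the training batch.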

In [6]:
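# Skeleton of the model we trained. torch.load(..., weights_only=False) below
# unpickles the whole module object, so this class definition needs to be available.
class FeedforwardMNIST(nn.Module):
    """Simple MLP for MNIST classification (three fully connected layers)."""
    def __init__(self, num_classes=10):
        super().__init__()

        input_size = 28 * 28  # flattened 28x28 grayscale image
        self.fc1 = nn.Linear(input_size, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, num_classes)

        self.flatten = nn.Flatten()
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.flatten(x)         # e.g. (N, 1, 28, 28) -> (N, 784)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)             # raw logits; no softmax applied here
        return x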

model_path = hf_hub_download(repo_id="uchicago-xlab-ai-security/base-mnist-model", filename="mnist_mlp.pth")
model = torch.load(model_path, map_location=device, weights_only=False)

In [7]:
print(model)

FeedforwardMNIST(
  (fc1): Linear(in_features=784, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=10, bias=True)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (relu): ReLU()
)
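
As a quick sanity check on the printed architecture, the sketch below rebuilds the same layer sizes as an `nn.Sequential` (an assumption for self-containment; it is not the loaded `model` object), passes a fake batch of four 28x28 images through it, and counts the trainable parameters. Since `Flatten` and `ReLU` have no parameters, `sum(p.numel() for p in model.parameters())` on the loaded model should give the same total.

```python
import torch
from torch import nn

# Layer sizes copied from the print-out above.
mlp = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 256), nn.ReLU(),
    nn.Linear(256, 64), nn.ReLU(),
    nn.Linear(64, 10),
)

x = torch.rand(4, 1, 28, 28)  # fake batch of four 28x28 grayscale images
logits = mlp(x)
num_params = sum(p.numel() for p in mlp.parameters())

print(logits.shape)  # torch.Size([4, 10])
print(num_params)    # 218058 = (784*256 + 256) + (256*64 + 64) + (64*10 + 10)
```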
