# Defensive Distillation

The authors of [Distillation as a Defense to Adversarial
Perturbations against Deep Neural Networks](https://arxiv.org/pdf/1511.04508#page=16&zoom=100,416,109) describe four key ideas behind distilling image classifiers as a defense against adversarial examples:

1. Start with hard labels (the authors describe these as a series of one-hot vectors, but that is not necessarily how they would be stored in memory).
2. Train the initial model using a traditional procedure, but let the final layer have a softmax with a temperature greater than one.
3. Create a new training set from the outputs of this initial model. That is, instead of the hard labels used to train the initial model, use the soft labels that the initial model outputs.
4. Train a new model from scratch using the same architecture but with the soft labels (and with the same temperature as before).
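To make steps 1 through 3 concrete, here is a small pure-Python sketch (no PyTorch) that compares a one-hot hard label with soft labels produced from the same teacher logits at $T=1$ and $T=20$. The logits are hypothetical numbers chosen for illustration, and `softmax_with_temp` here is a standalone stand-in for the PyTorch function implemented later in the notebook.

```python
import math

def softmax_with_temp(logits, T):
    """Softmax of logits / T (pure-Python illustration, not the notebook's torch version)."""
    scaled = [z / T for z in logits]
    m = max(scaled)                      # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical teacher logits for one image whose true label is 7;
# the teacher also finds the image somewhat similar to a 1.
logits = [0.5, 6.0, 0.2, 0.1, 0.3, 0.0, -0.4, 9.0, 0.8, 1.2]

hard_label = [1.0 if i == 7 else 0.0 for i in range(10)]   # step 1: one-hot label
soft_t1 = softmax_with_temp(logits, T=1)     # sharp: almost all mass on class 7
soft_t20 = softmax_with_temp(logits, T=20)   # smooth: mass spread over all classes

for name, dist in [("hard", hard_label), ("T=1", soft_t1), ("T=20", soft_t20)]:
    print(name, [round(p, 3) for p in dist])
```

At $T=20$ the label still ranks class 7 first, but it also records that this image looks more like a 1 than like the other digits, information the one-hot label discards.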


In this notebook, you will implement the final 2 steps and evaluate the results.

In [2]:
from huggingface_hub import hf_hub_download

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Flatten, Linear, ReLU
import xlab

device = xlab.utils.get_best_device()

## Steps 1 and 2: Train an image classifier on hard labels

We have already completed this for you: we trained a simple MLP on the MNIST dataset for two epochs, reaching 94.90% accuracy on the test set. Importantly, the model was trained with a softmax at temperature $T=20$, as described on our [explainer page](https://xlabaisecurity.com/adversarial/defensive-distillation/).


If you are interested, you can see the output of our training run [here](https://github.com/zroe1/xlab-ai-security/blob/main/models/defensive_distillation/training_output.txt) and the complete code [here](https://github.com/zroe1/xlab-ai-security/tree/main/models/defensive_distillation). You will train your own version of this model for step 5 of this notebook.



In [3]:
# skeleton of the model we trained
class FeedforwardMNIST(nn.Module):
    """Simple MLP for MNIST classification: three fully connected layers (784 -> 256 -> 64 -> num_classes)."""
    def __init__(self, num_classes=10):
        super(FeedforwardMNIST, self).__init__()
        
        input_size = 28 * 28
        self.fc1 = Linear(input_size, 256)
        self.fc2 = Linear(256, 64)
        self.fc3 = Linear(64, num_classes)
        
        self.flatten = Flatten()
        self.relu = ReLU()
        
    
    def forward(self, x):
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

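# The checkpoint stores the whole pickled model object (not just a state_dict),
# so weights_only=False is required and FeedforwardMNIST must be defined above.
# Only load checkpoints you trust with weights_only=False.
model_path = hf_hub_download(repo_id="uchicago-xlab-ai-security/base-mnist-model", filename="mnist_mlp.pth")
model = torch.load(model_path, map_location=device, weights_only=False)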

In [4]:
print(model)

FeedforwardMNIST(
  (fc1): Linear(in_features=784, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=10, bias=True)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (relu): ReLU()
)
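As a quick check on the printout above, the number of trainable parameters follows directly from the three `Linear` shapes: each layer has `in_features * out_features` weights plus `out_features` biases, while `Flatten` and `ReLU` have no parameters. The sketch below just does that arithmetic.

```python
# Layer shapes copied from the printed model: (in_features, out_features)
layers = [(784, 256), (256, 64), (64, 10)]

def linear_params(n_in, n_out):
    """Weights plus biases of one fully connected layer with bias=True."""
    return n_in * n_out + n_out

per_layer = [linear_params(n_in, n_out) for n_in, n_out in layers]
total = sum(per_layer)
print(per_layer, total)  # → [200960, 16448, 650] 218058
```

In PyTorch the same count can be read off the model with `sum(p.numel() for p in model.parameters())`, which should give the same total for this architecture.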


## Step 3: Create a new training set

We will train the distilled model on the soft labels produced by the pretrained model loaded above.

However, the model you loaded outputs logits, not temperature-smoothed probabilities. To get the proper labels, you will first implement the function below, which returns the softmax with temperature.

In [5]:
train_loader = xlab.utils.get_mnist_train_loader(batch_size=128, shuffle=True)

In [6]:
def softmax_with_temp(inputs, T):
    """Softmax over the class dimension of the logits `inputs` divided by temperature T."""
    out = inputs / T
    return F.softmax(out, dim=1)
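Exponentiating large logits can overflow. PyTorch's `F.softmax` is implemented in a numerically stable way, so the function above is fine as written; the pure-Python sketch below (an illustration, not part of the notebook) shows the standard max-subtraction trick and how the temperature behaves at its extremes.

```python
import math

def naive_softmax_with_temp(logits, T):
    # exponentiates z / T directly; math.exp overflows above roughly 709
    exps = [math.exp(z / T) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def stable_softmax_with_temp(logits, T):
    # subtracting the max leaves the result unchanged but keeps every exponent <= 0
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

big_logits = [1000.0, 999.0, 990.0]
try:
    naive_softmax_with_temp(big_logits, T=1)
    overflowed = False
except OverflowError:
    overflowed = True

stable = stable_softmax_with_temp(big_logits, T=1)
# a very large T pushes the output toward uniform, a very small T toward one-hot
near_uniform = stable_softmax_with_temp([1.0, 2.0, 3.0], T=1e6)
near_one_hot = stable_softmax_with_temp([1.0, 2.0, 3.0], T=0.01)
print(overflowed, stable, near_uniform, near_one_hot)
```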

In [7]:
def get_batch_labels(batch, T):
    """Soft labels for a batch: the pretrained model's softmax output at temperature T."""
    outs = model(batch)
    outs = softmax_with_temp(outs, T)
    return outs

In [8]:
from torch.utils.data import DataLoader, TensorDataset

In [9]:
imgs = []
soft_labels = []

# Label every training image with the pretrained model's softmax at T = 20,
# the temperature it was trained with.
with torch.no_grad():
    for x_batch, _ in train_loader:
        x_batch = x_batch.to(device)
        soft_labels_batch = get_batch_labels(x_batch, 20)

        imgs.append(x_batch.cpu())
        soft_labels.append(soft_labels_batch.cpu())

all_images = torch.cat(imgs, dim=0)
all_soft_labels = torch.cat(soft_labels, dim=0)
soft_label_dataset = TensorDataset(all_images, all_soft_labels)

batch_size = 128
soft_label_loader = DataLoader(
    soft_label_dataset,
    batch_size=batch_size,
    shuffle=True,
)
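The batch size also matters for the loss function in Step 4. Assuming `train_loader` covers the standard 60,000-image MNIST training split (an assumption; the notebook does not print the dataset size), the soft-label loader yields 468 full batches plus one final batch of 96 images, because `DataLoader` keeps an incomplete last batch by default (`drop_last=False`). Any per-batch normalization should therefore use the actual batch size rather than the constant 128.

```python
# Assumption: 60,000 training images (the standard MNIST training split)
num_train = 60_000
loader_batch_size = 128

full_batches, last_batch_size = divmod(num_train, loader_batch_size)
# drop_last=False (the DataLoader default) keeps the partial final batch
num_batches = full_batches + (1 if last_batch_size else 0)
print(full_batches, last_batch_size, num_batches)  # → 468 96 469
```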

In [10]:
# print(f"Created soft-label dataset with {len(all_images)} samples")
# print(f"Image shape: {all_images.shape}")
# print(f"Soft label shape: {all_soft_labels.shape}")


The first step in constructing this new dataset is to implement `get_batch_labels`, which calls the pretrained model and applies a softmax with temperature T to its logits.

## Step 4: Train the distilled model

The original paper formalizes the training objective for the distilled model as:

$$
\arg\min_{\theta_F} -\frac{1}{|\mathcal{X}|} \sum_{X \in \mathcal{X}} \sum_{i \in 0..N} F_i(X) \log F_i^d(X)
$$

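Here $F_i(X)$ is the initial model's soft label for class $i$ and $F_i^d(X)$ is the distilled model's predicted probability for class $i$, both computed with a softmax at temperature $T$. The loss for a single example is therefore the cross-entropy loss with soft labels: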

$$
\mathcal{L}(X) = -\sum_{i \in 0..N} F_i(X) \log F_i^d(X)
$$
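A few properties of this loss can be checked by hand. The pure-Python sketch below (the soft label is a hypothetical example) shows that a student predicting uniformly over the 10 classes has loss $\ln 10 \approx 2.303$ whatever the soft label is, that the loss is smallest (equal to the entropy of the soft label) when the student matches the teacher exactly, and that with a one-hot label the formula reduces to ordinary cross-entropy, $-\log F_y^d(X)$.

```python
import math

def soft_cross_entropy(p, q):
    """Per-example loss -sum_i p_i * log q_i (p: teacher soft label, q: student probabilities)."""
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0)

def entropy(p):
    """Entropy -sum_i p_i * log p_i, with 0 * log 0 taken as 0."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)

uniform = [0.1] * 10                      # an untrained student guessing uniformly
teacher = [0.05, 0.30, 0.05, 0.05, 0.05,
           0.05, 0.05, 0.30, 0.05, 0.05]  # hypothetical soft label (sums to 1)
one_hot = [1.0 if i == 7 else 0.0 for i in range(10)]

ce_uniform = soft_cross_entropy(teacher, uniform)  # ln(10) for any soft label
ce_self = soft_cross_entropy(teacher, teacher)     # the minimum: student matches teacher
h = entropy(teacher)
ce_one_hot = soft_cross_entropy(one_hot, teacher)  # ordinary cross-entropy for class 7
print(ce_uniform, math.log(10), ce_self, h, ce_one_hot, -math.log(teacher[7]))
```

The first property is a handy sanity check for the training log: at $T=20$ a freshly initialized student produces nearly uniform probabilities, and `cross_entropy_loss_soft` additionally divides by the number of classes, which is consistent with the first logged losses of about $0.2303 \approx \ln(10)/10$.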

In [12]:
# fresh, randomly initialized model with the same architecture as the pretrained one
distilled = FeedforwardMNIST().to(device)

In [13]:
def cross_entropy_loss_soft(soft_labels, probs):
    """Soft-label cross-entropy, averaged over the batch and over the classes."""
    assert soft_labels.shape == probs.shape

    # clamping avoids log(0) = -inf if a probability underflows
    log_probs = torch.log(probs.clamp(min=1e-12))
    # Divide by the actual batch size (the last batch can be smaller than 128).
    # The extra division by the number of classes (10) only rescales the loss.
    return torch.sum(-log_probs * soft_labels) / (soft_labels.size(0) * soft_labels.size(1))
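Taking the log of a softmax output can lose precision when probabilities become tiny. A common alternative, sketched below in pure Python, is to compute log-probabilities directly from the logits with the log-sum-exp trick; in PyTorch this is `F.log_softmax(logits / T, dim=1)`. Using it would mean passing logits rather than probabilities into the loss, so treat this as a possible design change, not what the notebook does. The hypothetical `soft_ce_from_logits` helper also averages over the rows actually present in the batch.

```python
import math

def log_softmax_with_temp(logits, T):
    """log(softmax(logits / T)) via the log-sum-exp trick, for one example."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    log_norm = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - log_norm for s in scaled]

def soft_ce_from_logits(soft_labels, logits, T):
    """Soft-label cross-entropy averaged over the rows actually in the batch."""
    losses = [
        -sum(p * lq for p, lq in zip(p_row, log_softmax_with_temp(z_row, T)))
        for p_row, z_row in zip(soft_labels, logits)
    ]
    return sum(losses) / len(losses)

# Extreme logits: computing the log-probabilities directly stays finite...
lp = log_softmax_with_temp([0.0, -2000.0], T=1)
# ...while the probability itself underflows to 0.0, and math.log(0.0) would raise ValueError.
two_step_underflows = math.exp(lp[1]) == 0.0

# A hypothetical batch of two examples with two classes, at T = 20
batch_loss = soft_ce_from_logits([[0.5, 0.5], [0.9, 0.1]], [[1.0, 0.0], [2.0, 1.0]], T=20)
print(lp, two_step_underflows, batch_loss)
```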

In [14]:
def train(model, epochs, train_loader, T):
    """Train `model` on (image, soft label) batches using softmax temperature T."""
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(epochs):
        for i, (img, soft_label) in enumerate(train_loader):
            img, soft_label = img.to(device), soft_label.to(device)
            optimizer.zero_grad()

            # the student uses the same temperature as the teacher's soft labels
            logits = model(img)
            out = softmax_with_temp(logits, T)
            l = cross_entropy_loss_soft(soft_label, out)

            print(l)  # loss for every batch

            l.backward()
            optimizer.step()
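For intuition about what `l.backward()` computes, the gradient of the per-example soft-label loss with respect to the student's logits $z$ has a simple closed form, $\partial \mathcal{L} / \partial z_j = (F_j^d(X) - F_j(X)) / T$, which uses the fact that the soft label sums to 1. The pure-Python sketch below checks this against central finite differences on made-up numbers; it ignores the extra division by batch size and number of classes in `cross_entropy_loss_soft`, which only rescales the gradient.

```python
import math

def softmax_t(z, T):
    m = max(v / T for v in z)
    exps = [math.exp(v / T - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def loss(z, p, T):
    """Per-example soft-label cross-entropy: -sum_i p_i * log softmax(z / T)_i."""
    return -sum(pi * math.log(qi) for pi, qi in zip(p, softmax_t(z, T)))

def analytic_grad(z, p, T):
    """Closed form dL/dz_j = (q_j - p_j) / T; relies on sum(p) == 1."""
    return [(qj - pj) / T for qj, pj in zip(softmax_t(z, T), p)]

def numeric_grad(z, p, T, h=1e-6):
    """Central finite differences, one logit at a time."""
    grad = []
    for j in range(len(z)):
        z_plus = list(z)
        z_plus[j] += h
        z_minus = list(z)
        z_minus[j] -= h
        grad.append((loss(z_plus, p, T) - loss(z_minus, p, T)) / (2 * h))
    return grad

z = [2.0, -1.0, 0.5]  # hypothetical student logits
p = [0.2, 0.1, 0.7]   # hypothetical teacher soft label (sums to 1)
T = 20.0

analytic = analytic_grad(z, p, T)
numeric = numeric_grad(z, p, T)
print(analytic)
print(numeric)
```

The closed form makes the $1/T$ scaling of the gradient explicit, and the gradient entries always sum to zero because both distributions sum to 1.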

        


In [16]:
train(distilled, 3, soft_label_loader, 20)

tensor(0.2303, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.2303, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.2301, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.2301, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.2298, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.2297, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.2293, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.2292, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.2288, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.2284, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.2281, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.2281, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.2273, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.2269, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.2261, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.2258, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.2253, device='mps:0', grad_fn=<DivBackward0>)
tensor(0.2243, device='mps:0', grad_fn=<DivBackward0>)

In [17]:
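# Sanity check on the original (teacher) model: 0.949 matches the 94.90% test accuracy reported above.
# The same call with `distilled` evaluates the distilled model.
xlab.utils.evaluate_mnist_accuracy(model)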

0.949
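Accuracy does not depend on the softmax temperature: dividing logits by any $T > 0$ preserves their order, so the argmax prediction is the same at every temperature, and only the confidence of the softmax changes. This means the distilled model and the teacher can be compared on accuracy regardless of the temperature used at evaluation. The pure-Python sketch below checks the invariance on hypothetical logits.

```python
def argmax(values):
    """Index of the largest value."""
    return max(range(len(values)), key=values.__getitem__)

# Hypothetical logits for a batch of three examples with four classes
batch_logits = [
    [0.5, 6.0, 0.2, 9.0],
    [-1.0, -0.5, -3.0, -2.0],
    [2.0, 2.5, 1.0, 0.0],
]

# Predictions after dividing every logit by T, for several temperatures
preds_by_T = {
    T: [argmax([z / T for z in row]) for row in batch_logits]
    for T in (0.5, 1.0, 20.0, 100.0)
}
print(preds_by_T)
```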