## Certified Finetuning of a Classifier on the OCT-MNIST Dataset

In this notebook, we'll fine-tune a classifier on the OCT-MNIST dataset using Abstract Gradient Training (AGT). OCT-MNIST is a medical imaging dataset with four diagnostic classes (choroidal neovascularization, diabetic macular edema, drusen, and normal retina). We'll tackle the binary classification problem of distinguishing normal retina from the other three classes.

We'll assume the following setting:

- The model is first pre-trained on the dataset with the drusen class removed, i.e. trained to distinguish normal vs. (choroidal neovascularization, diabetic macular edema) classes.
- We'll then fine-tune the model on the full dataset (including the drusen class) using Abstract Gradient Training for privacy-safe certification.
- The model is a convolutional network with 2 convolutional layers and 2 fully connected layers. Pre-training updates all the layers, while fine-tuning only trains the dense layers using AGT.

This simulates a setting in which a model is pre-trained on public data and then fine-tuned on private or sensitive data (here, the drusen class). We'll then use the certificates provided by AGT to make privacy-preserving predictions:

- Running AGT for a range of k_private values, we can use the resulting parameter bounds to compute a bound on the smooth sensitivity of the model for a given prediction.
- Using the smooth sensitivity bounds, we can calibrate the noise to add to the prediction to ensure differential privacy for a given epsilon.
- We can then use the calibrated noise to make privacy-preserving predictions, which should maintain high utility when compared with noise calibrated to the global sensitivity. 


In [2]:
%load_ext autoreload
%autoreload 2
import os
import copy

import torch
import tqdm
import torchvision

import abstract_gradient_training as agt
from abstract_gradient_training import AGTConfig
from abstract_gradient_training.bounded_models import IntervalBoundedModel

from medmnist import OCTMNIST  # python -m pip install git+https://github.com/MedMNIST/MedMNIST.git

### 1. Load the dataset

In [3]:
def get_datasets(exclude_classes=None, balanced=False):
    """
    Get OCT MedMNIST dataset as a binary classification problem of class 3 (normal) vs classes 0, 1, 2.
    """

    # get the datasets
    train_dataset = OCTMNIST(split="train", transform=torchvision.transforms.ToTensor())
    test_dataset = OCTMNIST(split="test", transform=torchvision.transforms.ToTensor())
    train_imgs, train_labels = train_dataset.imgs, train_dataset.labels
    test_imgs, test_labels = test_dataset.imgs, test_dataset.labels

    # filter out excluded classes
    if exclude_classes is not None:
        for e in exclude_classes:
            train_imgs = train_imgs[(train_labels != e).squeeze()]
            train_labels = train_labels[(train_labels != e).squeeze()]
            test_imgs = test_imgs[(test_labels != e).squeeze()]
            test_labels = test_labels[(test_labels != e).squeeze()]

    # convert to a binary classification problem
    train_labels = train_labels != 3  # i.e. 0 = normal, 1 = abnormal
    test_labels = test_labels != 3

    # scale pixel values to [0, 1] and add a channel dimension
    train_imgs = torch.tensor(train_imgs, dtype=torch.float32).unsqueeze(1) / 255
    test_imgs = torch.tensor(test_imgs, dtype=torch.float32).unsqueeze(1) / 255
    train_labels = torch.tensor(train_labels, dtype=torch.int64)
    test_labels = torch.tensor(test_labels, dtype=torch.int64)

    # balance the training dataset such that the number of samples in each class is equal
    if balanced:
        n_ones = train_labels.sum().item()
        n_zeros = len(train_labels) - n_ones
        n_samples = min(n_ones, n_zeros)
        # find the indices of the ones, and then randomly sample n_samples from them
        idx_ones = torch.where(train_labels == 1)[0]
        ones_selection = torch.randperm(n_ones)
        idx_ones = idx_ones[ones_selection][:n_samples]
        # find the indices of the zeros, and then randomly sample n_samples from them
        idx_zeros = torch.where(train_labels == 0)[0]
        zeros_selection = torch.randperm(n_zeros)
        idx_zeros = idx_zeros[zeros_selection][:n_samples]
        idx = torch.cat([idx_ones, idx_zeros])
        train_imgs, train_labels = train_imgs[idx], train_labels[idx]
    
    # form dataloaders
    train_dataset = torch.utils.data.TensorDataset(train_imgs, train_labels)
    test_dataset = torch.utils.data.TensorDataset(test_imgs, test_labels)
    return train_dataset, test_dataset
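
To make the binarization and balancing steps in `get_datasets` concrete, here is a toy sketch on a handful of made-up labels. It uses plain Python lists rather than tensors, and the helper name `binarize_and_balance` is hypothetical, not part of the notebook.

```python
import random


def binarize_and_balance(labels, normal_class=3, seed=0):
    """Map labels to 0 (normal) / 1 (abnormal) and pick a class-balanced subset of indices."""
    binary = [int(label != normal_class) for label in labels]
    idx_ones = [i for i, b in enumerate(binary) if b == 1]
    idx_zeros = [i for i, b in enumerate(binary) if b == 0]
    # the smaller class determines how many samples we keep per class
    n_samples = min(len(idx_ones), len(idx_zeros))
    rng = random.Random(seed)
    # randomly keep n_samples indices from each class
    keep = rng.sample(idx_ones, n_samples) + rng.sample(idx_zeros, n_samples)
    return binary, keep


labels = [0, 1, 2, 3, 3, 2, 0, 1, 3, 2]  # made-up class labels, 3 = normal
binary, keep = binarize_and_balance(labels)
n_abnormal = sum(binary[i] for i in keep)
print(f"kept {len(keep)} samples, {n_abnormal} of them abnormal")
```

Because the minority class sets the subset size, balancing discards samples from the majority class rather than duplicating the minority class.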


### 2. Initialize the model

In [4]:
torch.manual_seed(1)
# small architecture from https://arxiv.org/abs/1810.12715
model = torch.nn.Sequential(
    torch.nn.Conv2d(1, 16, 4, 2, 0),
    torch.nn.ReLU(),
    torch.nn.Conv2d(16, 32, 4, 1, 0),
    torch.nn.ReLU(),
    torch.nn.Flatten(),
    torch.nn.Linear(3200, 100),
    torch.nn.ReLU(),
    torch.nn.Linear(100, 1),
)
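
The input size of the first dense layer (3200) follows from the convolution shapes. The short check below applies the standard output-size formula for a convolution without dilation, assuming 28 x 28 input images (the default OCT-MNIST resolution, which the notebook does not state explicitly). The helper `conv_out_size` is introduced here for illustration only.

```python
def conv_out_size(size, kernel, stride, padding=0):
    """Spatial output size of a square convolution (floor convention, dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28  # assumed input resolution
size = conv_out_size(size, kernel=4, stride=2)  # Conv2d(1, 16, 4, 2, 0)
size = conv_out_size(size, kernel=4, stride=1)  # Conv2d(16, 32, 4, 1, 0)
flat_features = 32 * size * size  # 32 output channels, flattened
print(size, flat_features)
```

Under that assumption the flattened feature count matches the `Linear(3200, 100)` layer in the model definition.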

### 3. Pre-train the model without the drusen class

In [5]:
# set up the pre-training configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrain_batchsize = 100
pretrain_n_epochs = 20
pretrain_learning_rate = 0.001
dataset_pretrain, _ = get_datasets(exclude_classes=[2], balanced=True)
dl_pretrain = torch.utils.data.DataLoader(dataset_pretrain, batch_size=pretrain_batchsize, shuffle=True)
model = model.to(device)

In [6]:
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=pretrain_learning_rate)
model = model.to(device)

In [7]:
models_dir = ".models"
ckpt_path = os.path.join(models_dir, "medmnist.ckpt")
os.makedirs(models_dir, exist_ok=True)

# check if a pre-trained model exists
if os.path.exists(ckpt_path):
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
else:  # pre-train the model
    progress_bar = tqdm.trange(pretrain_n_epochs, desc="Epoch")
    for epoch in progress_bar:
        for i, (x, u) in enumerate(dl_pretrain):
            # Forward pass
            u, x = u.to(device), x.to(device)
            output = model(x)
            loss = criterion(output.squeeze().float(), u.squeeze().float())
            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if i % 100 == 0:
                progress_bar.set_postfix(loss=loss.item())
    # save the model
    torch.save(model.state_dict(), ckpt_path)


### 4. Evaluate the pre-trained model

The pre-trained model performs poorly on the unseen drusen class, as expected: its accuracy there is close to chance (0.51), compared with 0.95 on the classes it was trained on.

In [8]:
# evaluate the pre-trained model
_, dataset_test_all = get_datasets()
x, u = dataset_test_all.tensors
u, x = u.to(device), x.to(device)
output = torch.sigmoid(model(x))
preds = (output > 0.5)
accuracy = (preds == u).float().mean().item()
print(f"Pre-trained model accuracy (all classes): {accuracy:.2f}")

_, dataset_test_no_drusen = get_datasets(exclude_classes=[2])
x, u = dataset_test_no_drusen.tensors
u, x = u.to(device), x.to(device)
output = torch.sigmoid(model(x))
preds = (output > 0.5)
accuracy = (preds == u).float().mean().item()
print(f"Pre-trained model accuracy (excluding drusen): {accuracy:.2f}")

_, dataset_test_drusen = get_datasets(exclude_classes=[0, 1, 3])
x, u = dataset_test_drusen.tensors
u, x = u.to(device), x.to(device)
output = torch.sigmoid(model(x))
preds = (output > 0.5)
accuracy = (preds == u).float().mean().item()
print(f"Pre-trained model accuracy (drusen class): {accuracy:.2f}")

Pre-trained model accuracy (all classes): 0.84
Pre-trained model accuracy (excluding drusen): 0.95
Pre-trained model accuracy (drusen class): 0.51


### 5. Fine-tune the model on the drusen class using AGT

In [26]:
# set up the AGT configuration
batchsize = 5000
nominal_config = AGTConfig(
    fragsize=2000,
    learning_rate=0.1,
    n_epochs=3,
    device="cuda:0",
    l2_reg=0.01,
    k_private=10,
    loss="binary_cross_entropy",
    log_level="INFO",
    lr_decay=2.0,
    clip_gamma=1.0,
    lr_min=0.001,
    optimizer="SGDM", # we'll use SGD with momentum
    optimizer_kwargs={"momentum": 0.9, "nesterov": True},
)

In [27]:
# get dataloaders, train dataloader is a mix of drusen and the "healthy" class
dataset_train, _ = get_datasets(exclude_classes=[0, 1], balanced=True)  # a mix of drusen (class 2) and normal (class 3)
_, dataset_test_drusen = get_datasets(exclude_classes=[0, 1, 3])  # drusen only (class 2)
torch.manual_seed(0)
dl_train = torch.utils.data.DataLoader(dataset_train, batch_size=batchsize, shuffle=True)
dl_test_drusen = torch.utils.data.DataLoader(dataset_test_drusen, batch_size=batchsize, shuffle=False)

In [28]:
# we are only going to fine-tune the dense layers of our model, so we'll handle that logic here
# we'll form a separate bounded model for the convolutional layers and
# use it as the `transform` argument of the bounded model of the dense layers
conv_layers, dense_layers = model[0:5], model[5:]
conv_bounded_model = IntervalBoundedModel(conv_layers, trainable=False)
bounded_model = IntervalBoundedModel(dense_layers, trainable=True, transform=conv_bounded_model)
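
To give some intuition for what an interval-bounded dense layer computes, the sketch below propagates a fixed input through a linear layer whose weights and biases are only known to lie in elementwise intervals. It is a plain-Python illustration of interval arithmetic, not the implementation used by `IntervalBoundedModel`, and the function name `interval_linear` is hypothetical.

```python
def interval_linear(x, w_lower, w_upper, b_lower, b_upper):
    """Bound y = W x + b when each weight and bias lies in an elementwise interval."""
    y_lower, y_upper = [], []
    for wl_row, wu_row, bl, bu in zip(w_lower, w_upper, b_lower, b_upper):
        lo, hi = bl, bu
        for xi, wl, wu in zip(x, wl_row, wu_row):
            # w * xi is monotone in w, so the extremes sit at the interval ends
            lo += min(wl * xi, wu * xi)
            hi += max(wl * xi, wu * xi)
        y_lower.append(lo)
        y_upper.append(hi)
    return y_lower, y_upper


x = [1.0, -2.0]  # made-up input features
w_lower = [[0.9, 0.4]]  # one output unit, two inputs
w_upper = [[1.1, 0.6]]
y_lower, y_upper = interval_linear(x, w_lower, w_upper, b_lower=[0.0], b_upper=[0.1])
print(y_lower, y_upper)
```

In this toy case the output interval straddles zero, so the sign of the logit could not be certified. Tighter parameter intervals would shrink the output interval.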

In [29]:
# fine-tune the model using abstract gradient training (keeping the convolutional layers fixed)
agt.privacy_certified_training(bounded_model, nominal_config, dl_train, dl_test_drusen)

[AGT] [INFO    ] [17:51:27] Starting epoch 1
[AGT] [INFO    ] [17:51:27] Batch 0. Loss (accuracy): 0.508 <= 0.508 <= 0.508
[AGT] [INFO    ] [17:51:28] Batch 1. Loss (accuracy): 0.660 <= 0.712 <= 0.740
[AGT] [INFO    ] [17:51:28] Batch 2. Loss (accuracy): 0.732 <= 0.796 <= 0.832
[AGT] [INFO    ] [17:51:29] Starting epoch 2
[AGT] [INFO    ] [17:51:29] Batch 3. Loss (accuracy): 0.748 <= 0.816 <= 0.868
[AGT] [INFO    ] [17:51:30] Batch 4. Loss (accuracy): 0.760 <= 0.840 <= 0.908
[AGT] [INFO    ] [17:51:31] Batch 5. Loss (accuracy): 0.740 <= 0.848 <= 0.928
[AGT] [INFO    ] [17:51:32] Starting epoch 3
[AGT] [INFO    ] [17:51:32] Batch 6. Loss (accuracy): 0.728 <= 0.864 <= 0.928
[AGT] [INFO    ] [17:51:33] Batch 7. Loss (accuracy): 0.708 <= 0.872 <= 0.932
[AGT] [INFO    ] [17:51:34] Batch 8. Loss (accuracy): 0.668 <= 0.872 <= 0.940
[AGT] [INFO    ] [17:51:34] Final Eval. Loss (accuracy): 0.648 <= 0.872 <= 0.952


<abstract_gradient_training.bounded_models.interval_bounded_model.IntervalBoundedModel at 0x707c64b0a5c0>

### 6. Evaluate the fine-tuned model

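Fine-tuning raises nominal accuracy on the drusen class from 0.51 to 0.80, with a certified lower bound of 0.67. This comes at some cost on the other classes, where nominal accuracy drops from 0.95 to 0.88, while accuracy over all classes rises slightly from 0.84 to 0.86.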

In [12]:
# evaluate the fine-tuned model
_, dataset_test_all = get_datasets()
accuracy = agt.test_metrics.test_accuracy(bounded_model, *dataset_test_all.tensors)
print(f"Fine-tuned model accuracy + certified bounds (all classes): {accuracy[0]:.2f} <= {accuracy[1]:.2f} <= {accuracy[2]:.2f}")

_, dataset_test_no_drusen = get_datasets(exclude_classes=[2])
accuracy = agt.test_metrics.test_accuracy(bounded_model, *dataset_test_no_drusen.tensors)
print(f"Fine-tuned model accuracy + certified bounds (excluding drusen): {accuracy[0]:.2f} <= {accuracy[1]:.2f} <= {accuracy[2]:.2f}")

_, dataset_test_drusen = get_datasets(exclude_classes=[0, 1, 3])
accuracy = agt.test_metrics.test_accuracy(bounded_model, *dataset_test_drusen.tensors)
print(f"Fine-tuned model accuracy + certified bounds (drusen class): {accuracy[0]:.2f} <= {accuracy[1]:.2f} <= {accuracy[2]:.2f}")

Fine-tuned model accuracy + certified bounds (all classes): 0.80 <= 0.86 <= 0.91
Fine-tuned model accuracy + certified bounds (excluding drusen): 0.84 <= 0.88 <= 0.92
Fine-tuned model accuracy + certified bounds (drusen class): 0.67 <= 0.80 <= 0.86


### 7. Use the AGT certificates to make privacy-preserving predictions on the test dataset

First, run AGT for a range of k_private values. Then, use the parameter bounds to compute the smooth sensitivity of the model on the test-set predictions. Finally, use the smooth sensitivity to calibrate the amount of noise to add to the predictions to ensure differential privacy.
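
The three steps above can be sketched for a single binary prediction. This is a minimal illustration under stated assumptions, not the code inside `agt.privacy_utils`. It assumes we know the largest `k_private` value (called `k_certified` here, a hypothetical name) for which the prediction is certified not to change. It then applies the standard smooth sensitivity construction with Cauchy noise, where the smoothing parameter is beta = epsilon / 6 and the noise scale is 6 * S / epsilon.

```python
import math
import random


def smooth_sensitivity_bound(k_certified, epsilon):
    """Upper bound on the smooth sensitivity of a 0/1 prediction.

    Assumes the prediction is certified unchanged for any dataset within
    distance k_certified of the training set, so the local sensitivity can
    only be non-zero (at most 1) from distance k_certified onwards.
    """
    beta = epsilon / 6.0
    return math.exp(-beta * k_certified)


def cauchy_noise_scale(k_certified, epsilon):
    """Scale of Cauchy noise calibrated to the smooth sensitivity bound."""
    return 6.0 * smooth_sensitivity_bound(k_certified, epsilon) / epsilon


def sample_cauchy(scale, rng):
    """Draw one Cauchy sample via the inverse CDF."""
    return scale * math.tan(math.pi * (rng.random() - 0.5))


epsilon = 1.0
rng = random.Random(0)
for k_certified in [0, 10, 50, 100]:
    scale = cauchy_noise_scale(k_certified, epsilon)
    noisy_prediction = 1.0 + sample_cauchy(scale, rng)  # nominal prediction is 1
    print(k_certified, f"{scale:.2e}", noisy_prediction > 0.5)
```

With no certificate (`k_certified = 0`) the Cauchy scale is 6 / epsilon, which is larger than Laplace noise at the global sensitivity. The benefit only appears when predictions are certified for large `k_private`, because the scale then decays exponentially.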

In [13]:
# to use privacy-safe certificates, we need to run AGT for a range of k_private values

# we'll just pick a reasonable range of k_private values. adding more values will increase the runtime
# but also result in tighter privacy results. even a few values are sufficient to demonstrate tighter privacy

k_private_values = [1, 2, 5, 10, 20, 50, 100] 
privacy_bounded_models = {}
config = copy.deepcopy(nominal_config)
config.log_level = "WARNING"

for k_private in tqdm.tqdm(k_private_values):
    # update config
    config.k_private = k_private
    # form bounded model
    conv_layers, dense_layers = model[0:5], model[5:]
    conv_bounded_model = IntervalBoundedModel(conv_layers, trainable=False)
    bounded_model = IntervalBoundedModel(dense_layers, trainable=True, transform=conv_bounded_model)
    torch.manual_seed(0)
    dl_train = torch.utils.data.DataLoader(dataset_train, batch_size=batchsize, shuffle=True)
    # run AGT
    agt.privacy_certified_training(bounded_model, config, dl_train, dl_test_drusen)
    privacy_bounded_models[k_private] = bounded_model

100%|██████████| 7/7 [00:36<00:00,  5.23s/it]


In [14]:
# make privacy-safe predictions using the global sensitivity
epsilon = 1.0
_, dataset_test_all = get_datasets()
accuracy = agt.privacy_utils.noisy_test_accuracy(
    bounded_model, *dataset_test_all.tensors, noise_level=1 / epsilon, noise_type="laplace"
)
print(f"Accuracy using global sensitivity: {accuracy:.2f}")

# make privacy-safe predictions using the smooth sensitivity bounds from AGT
noise_level = agt.privacy_utils.get_calibrated_noise_level(
    dataset_test_all.tensors[0], privacy_bounded_models, epsilon=epsilon, noise_type="cauchy" 
)
accuracy = agt.privacy_utils.noisy_test_accuracy(
    bounded_model, *dataset_test_all.tensors, noise_level=noise_level, noise_type="cauchy"
)
print(f"Accuracy using AGT smooth sensitivity bounds: {accuracy:.2f}")

Accuracy using global sensitivity: 0.65
Accuracy using AGT smooth sensitivity bounds: 0.86
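
As a rough sanity check on the global-sensitivity figure, the sketch below computes the expected accuracy when Laplace noise of scale 1 / epsilon is added to a 0/1 prediction and the result is re-thresholded at 0.5. That thresholding rule is an assumption about how `noisy_test_accuracy` behaves, not something this notebook confirms. The nominal accuracy of 0.86 is taken from the fine-tuned results on all classes in section 6.

```python
import math


def laplace_flip_probability(scale, margin=0.5):
    """Probability that zero-mean Laplace noise pushes a 0/1 prediction across the threshold."""
    return 0.5 * math.exp(-margin / scale)


def expected_noisy_accuracy(nominal_accuracy, flip_probability):
    """Correct predictions survive unless flipped; wrong binary predictions become correct if flipped."""
    return nominal_accuracy * (1 - flip_probability) + (1 - nominal_accuracy) * flip_probability


epsilon = 1.0
p_flip = laplace_flip_probability(scale=1.0 / epsilon)
expected = expected_noisy_accuracy(0.86, p_flip)
print(f"flip probability: {p_flip:.3f}")
print(f"expected accuracy: {expected:.3f}")
```

Under these assumptions roughly 30% of predictions flip and the expected accuracy comes out near 0.64, in line with the 0.65 reported above. Noise calibrated to the smooth sensitivity bounds is much smaller for well-certified predictions, which is consistent with the accuracy staying at 0.86.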
