# Certified Finetuning of a Classifier on the OCT-MNIST Dataset

In [1]:
%load_ext autoreload
%autoreload 2
import torch
import tqdm
import abstract_gradient_training as agt
from abstract_gradient_training.certified_training import utils as ct_utils
from models.deepmind import DeepMindSmall 
from datasets import oct_mnist

## Pre-train the model

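Pre-train on classes 0, 1 and 3 only: class 2 (Drusen) is excluded from pre-training so that it is learned later, during certified fine-tuning on the private Drusen data.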

In [2]:
# set up pre-training
torch.manual_seed(0)
# fall back to the CPU when no GPU is available so the notebook still runs end to end
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pretrain_batchsize = 100
pretrain_n_epochs = 20
pretrain_learning_rate = 0.001

In [3]:
# define model, dataset and optimizer
# NOTE(review): DeepMindSmall(1, 1) presumably means 1 input channel and 1 output unit — confirm.
# BCELoss expects probabilities in [0, 1], so this assumes the model ends in a sigmoid — confirm.
model = DeepMindSmall(1, 1).to(device)
dl_train, dl_test = oct_mnist.get_dataloaders(pretrain_batchsize, exclude_classes=[2])
criterion = torch.nn.BCELoss()
# build the optimizer only after the model is on its final device, as recommended by torch.optim
optimizer = torch.optim.Adam(model.parameters(), lr=pretrain_learning_rate)

In [4]:
# pre-train the model with a standard (non-certified) training loop
progress_bar = tqdm.trange(pretrain_n_epochs, desc="Epoch")
for epoch in progress_bar:
    for step, (images, labels) in enumerate(dl_train):
        images = images.to(device)
        labels = labels.to(device)
        # squeeze away size-1 dims and cast to float, since BCELoss needs float targets
        # (presumably the model output is (batch, 1) — confirm)
        predictions = model(images).squeeze().float()
        targets = labels.squeeze().float()
        loss = criterion(predictions, targets)
        # backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # refresh the displayed loss only every 100 batches
        if step % 100 == 0:
            progress_bar.set_postfix(loss=loss.item())

Epoch: 100%|██████████| 20/20 [00:45<00:00,  2.28s/it, loss=0.0207] 


## Fine-tune the model on the private Drusen data

In [5]:
# set up fine-tuning parameters for abstract gradient training
config = {
    "batchsize": 3000,
    # NOTE(review): presumably each batch is processed in fragments of this size — confirm
    "fragsize": 1000,
    "learning_rate": 0.1,
    "n_epochs": 1,
    # NOTE(review): presumably the number of training points whose removal the bounds cover — confirm
    "k_unlearn": 50,
    "interval_matmul": "rump",
    "forward_bound": "interval",
    # reuse the pre-training device so the two stages cannot silently disagree
    "device": str(device),
    "backward_bound": "interval",
    "loss": "binary_cross_entropy",
    "optimizer": "sgd",
    "optimizer_kwargs": {
        "decay_rate": 0.3,
        "lr_min": 0.001,
    },
}

In [6]:
# get dataloaders for fine-tuning and evaluation
# NOTE(review): the second positional argument of get_dataloaders is presumably the
# test-set batch size — confirm against oct_mnist.get_dataloaders
test_batchsize = 1000
# excluding classes 0, 1 and 3 leaves only class 2 (Drusen) for fine-tuning
dl_train_drusen, dl_test_drusen = oct_mnist.get_dataloaders(
    config["batchsize"], test_batchsize, exclude_classes=[0, 1, 3]
)
# only the test loaders are needed for the remaining classes and for all classes
_, dl_test_other = oct_mnist.get_dataloaders(config["batchsize"], test_batchsize, exclude_classes=[2])
_, dl_test_all = oct_mnist.get_dataloaders(config["batchsize"], test_batchsize)

In [7]:
# evaluate the pre-trained model on the Drusen class, the remaining classes and all classes
param_n, param_l, param_u = ct_utils.get_parameters(model)
test_sets = {
    "Class 2 (Drusen)": dl_test_drusen,
    "Classes 0, 1, 3": dl_test_other,
    "All Classes": dl_test_all,
}

print("=========== Pre-trained model accuracy ===========")
for label, dl_test_subset in test_sets.items():
    # index 1 of the returned tuple is reported as the nominal accuracy
    acc = agt.test_metrics.test_accuracy(
        param_n, param_l, param_u, dl_test_subset, model, ct_utils.propagate_conv_layers
    )
    # fixed-point format keeps the same precision on every row (".2g" drops trailing zeros)
    print(f"{label:<16} : nominal = {acc[1]:.3f}")

Class 2 (Drusen) : nominal = 0.6
Classes 0, 1, 3  : nominal = 0.95
All Classes      : nominal = 0.86


In [8]:
# fine-tune on the Drusen-only training data with abstract gradient training;
# the convolutional layers are kept fixed (they are passed in via `transform`)
# NOTE(review): this returns (l, n, u), whereas ct_utils.get_parameters returns (n, l, u) —
# confirm the order against agt.unlearning_certified_training
param_l, param_n, param_u, accuracy = agt.unlearning_certified_training(
    model,
    config,
    dl_train_drusen,
    dl_test_drusen,
    transform=ct_utils.propagate_conv_layers,
)

100%|██████████| 1/1 [00:01<00:00,  1.46s/it, eval: (0.7480000257492065, 0.7760000228881836, 0.7760000228881836) bound: 0.1 batch: 1 frag: 2]


In [9]:
# evaluate the fine-tuned model: nominal accuracy and certified bound on each test set
test_sets = {
    "Class 2 (Drusen)": dl_test_drusen,
    "Classes 0, 1, 3": dl_test_other,
    "All Classes": dl_test_all,
}

print("=========== Fine-tuned model accuracy + bounds ===========")
for label, dl_test_subset in test_sets.items():
    # index 1 of the returned tuple is reported as the nominal accuracy, index 0 as the certified bound
    acc = agt.test_metrics.test_accuracy(
        param_n, param_l, param_u, dl_test_subset, model, ct_utils.propagate_conv_layers
    )
    # fixed-point format keeps nominal and bound comparable (".2g" printed 0.84 next to 0.8)
    print(f"{label:<16} : nominal = {acc[1]:.3f}, certified bound = {acc[0]:.3f}")

Class 2 (Drusen) : nominal = 0.84, certified bound = 0.8
Classes 0, 1, 3  : nominal = 0.9, certified bound = 0.89
All Classes      : nominal = 0.88, certified bound = 0.87
