# Certified Fine-Tuning of a Classifier on the OCT-MNIST Dataset

In [1]:
%load_ext autoreload
%autoreload 2
import os
import sys
import torch
import tqdm
import abstract_gradient_training as agt
from abstract_gradient_training import AGTConfig
from abstract_gradient_training import model_utils
from models.deepmind import DeepMindSmall 
from datasets import oct_mnist
from models.robust_regularizer import parameter_gradient_interval_regularizer


## Test the robustness of a non-robustly pre-trained classifier

In [3]:
# Evaluate a classifier that was trained without the robustness regularizer.
# Layers [0:5] are applied as a fixed input transform; only layers [5:-1] are bounded by AGT.
# NOTE(review): the last layer is excluded, presumably the output sigmoid since BCELoss is used
# for training below — confirm against DeepMindSmall.
device = torch.device("cuda:1")
_, dl_test = oct_mnist.get_dataloaders(1000, exclude_classes=[2], balanced=True)
standard_model = DeepMindSmall(1, 1).to(device)
# weights_only=True restricts unpickling to tensors and plain containers (and silences the
# torch.load FutureWarning); map_location loads the tensors directly onto `device` regardless
# of the device the checkpoint was saved from.
standard_model.load_state_dict(
    torch.load(".models/medmnist.ckpt", map_location=device, weights_only=True)
)
params_l, params_n, params_u = model_utils.get_parameters(standard_model[5:-1])
epsilon = 0.01
test_batch, test_labels = next(iter(dl_test))
# test_accuracy returns a sequence of accuracies; later cells label entry [0] as the certified
# bound and entry [1] as the nominal accuracy (entry [2] is presumably the upper bound).
accs = agt.test_metrics.test_accuracy(
    params_l,
    params_n,
    params_u,
    test_batch,
    test_labels,
    transform=model_utils.get_conv_model_transform(standard_model[0:5]),
    epsilon=epsilon,
)
accs = ", ".join([f"{a:.2f}" for a in accs])

print(f"Accuracy of non-robustly trained classifier on test set with epsilon={epsilon}: [{accs}]")

  standard_model.load_state_dict(torch.load(".models/medmnist.ckpt"))


Accuracy of non-robustly trained classifier on test set with epsilon=0.01: [0.00, 0.96, 1.00]


## Pre-train the model

Class 2 (Drusen) is excluded from pre-training; its data is only introduced in the fine-tuning stage below.

In [4]:
# set up pre-training
torch.manual_seed(1)
pretrain_batchsize = 100
pretrain_n_epochs = 10
pretrain_learning_rate = 0.001
pretrain_epsilon = 0.55
pretrain_model_epsilon = 0.001
pretrain_reg_strength = 0.4
model_path = f".models/medmnist_robust_eps{pretrain_epsilon}_alpha{pretrain_reg_strength}_meps{pretrain_model_epsilon}.ckpt"

In [5]:
# define model, dataset and optimizer
model = DeepMindSmall(1, 1)
dl_pretrain, _ = oct_mnist.get_dataloaders(pretrain_batchsize, exclude_classes=[2], balanced=True)
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=pretrain_learning_rate)
model = model.to(device)

In [6]:
# Load the regularized pre-trained checkpoint if it is cached; otherwise pre-train and cache it.
if os.path.exists(model_path):
    # weights_only=True avoids unpickling arbitrary objects (and the torch.load FutureWarning);
    # map_location places the tensors on `device` regardless of where they were saved from.
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
else:  # pre-train the model
    progress_bar = tqdm.trange(pretrain_n_epochs, desc="Epoch")
    for _ in progress_bar:
        for i, (x, u) in enumerate(dl_pretrain):
            # Forward pass
            u, x = u.to(device), x.to(device)
            output = model(x)
            bce_loss = criterion(output.squeeze().float(), u.squeeze().float())
            # Robustness regularizer; pretrain_epsilon / pretrain_model_epsilon are presumably the
            # input- and parameter-perturbation radii (TODO confirm in models.robust_regularizer).
            regularization = parameter_gradient_interval_regularizer(
                model, x, u, "binary_cross_entropy", pretrain_epsilon, pretrain_model_epsilon
            )
            loss = bce_loss + pretrain_reg_strength * regularization
            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if i % 100 == 0:
                progress_bar.set_postfix(loss=loss.item(), bce_loss=bce_loss.item(), reg=regularization.item())
    # Save the model. torch.save does not create missing directories, so make sure the checkpoint
    # directory (".models" for the default path) exists on the first run.
    os.makedirs(os.path.dirname(model_path) or ".", exist_ok=True)
    torch.save(model.state_dict(), model_path)

  model.load_state_dict(torch.load(model_path))


### Test the robustness of the model pre-trained with the gradient interval regularization term

In [8]:
# Split the pre-trained model: layers [0:5] become a fixed input transform, and the parameters of
# layers [5:-1] are the ones bounded (and later fine-tuned) by AGT.
conv_layers = model[0:5]
linear_layers = model[5:-1]
conv_transform = model_utils.get_conv_model_transform(conv_layers)
params_l, params_n, params_u = model_utils.get_parameters(linear_layers)

_, dl_test = oct_mnist.get_dataloaders(1000, exclude_classes=[2], balanced=True)
test_batch, test_labels = next(iter(dl_test))
# Evaluate with the same `epsilon` as the non-robust baseline (cell above). This keeps the two
# results comparable and guarantees that the value printed below is the one actually evaluated.
accs = agt.test_metrics.test_accuracy(
    params_l,
    params_n,
    params_u,
    test_batch,
    test_labels,
    transform=conv_transform,
    epsilon=epsilon,
)
accs = ", ".join([f"{a:.2f}" for a in accs])
print(model_path, accs)
print(f"Accuracy of robustly trained classifier on test set with epsilon={epsilon}: [{accs}]")

.models/medmnist_robust_eps0.55_alpha0.4_meps0.001.ckpt 0.32, 0.87, 1.00
Accuracy of robustly trained classifier on test set with epsilon=0.01: [0.32, 0.87, 1.00]


## Fine-tune the model

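Fine-tuning uses all classes, but only the class 2 (Drusen) data is treated as untrusted: it is treated as potentially poisoned in the first run below and as private in the second. Classes 0, 1 and 3 are treated as clean (public) data in both runs.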

In [9]:
# set up fine-tuning parameters
clean_batchsize = 3000
drusen_batchsize = 3000
test_batchsize = 1000

In [11]:
from abstract_gradient_training.poisoning import poison_certified_training

torch.manual_seed(0)


def _evaluate_splits(param_l, param_n, param_u, test_loaders, transform):
    """Return ``agt.test_metrics.test_accuracy`` results for the first batch of each test loader.

    The results are returned as a list in the same order as ``test_loaders``. Later prints use
    entry [0] of each result as the certified bound and entry [1] as the nominal accuracy.
    """
    return [
        agt.test_metrics.test_accuracy(param_l, param_n, param_u, *next(iter(dl)), transform=transform)
        for dl in test_loaders
    ]


# Row labels for the three evaluation splits, padded so the printed columns line up.
_SPLIT_LABELS = ("Class 2 (Drusen) ", "Classes 0, 1, 3  ", "All Classes      ")

# get dataloaders: clean data (classes 0, 1, 3), Drusen data (class 2 only) and all classes
dl_clean, dl_test_clean = oct_mnist.get_dataloaders(clean_batchsize, test_batchsize, exclude_classes=[2])
dl_drusen, dl_test_drusen = oct_mnist.get_dataloaders(drusen_batchsize, test_batchsize, exclude_classes=[0, 1, 3])
_, dl_test_all = oct_mnist.get_dataloaders(clean_batchsize, test_batchsize)
test_loaders = (dl_test_drusen, dl_test_clean, dl_test_all)

# evaluate the pre-trained model
param_l, param_n, param_u = model_utils.get_parameters(linear_layers)
pretrained_accs = _evaluate_splits(param_l, param_n, param_u, test_loaders, conv_transform)

print("=========== Pre-trained model accuracy ===========", file=sys.stderr)
for label, acc in zip(_SPLIT_LABELS, pretrained_accs):
    print(f"{label}: nominal = {acc[1]:.2g}", file=sys.stderr)

config = AGTConfig(
    fragsize=2000,
    learning_rate=0.06,
    n_epochs=2,
    k_poison=50,
    epsilon=0.01,
    # clip_gamma is left at its default; the AGT debug log for this configuration reports gamma=inf.
    forward_bound="interval",
    # Run AGT on the same device as the pre-trained model instead of repeating a hard-coded string.
    device=str(device),
    backward_bound="interval",
    loss="binary_cross_entropy",
    log_level="DEBUG",
    lr_decay=4.0,
    lr_min=0.001,
)

# fine-tune the model using abstract gradient training (keeping the convolutional layers fixed);
# the Drusen data is the potentially-poisoned dataset and the other classes are passed as clean data
param_l, param_n, param_u = poison_certified_training(
    linear_layers, config, dl_drusen, dl_test_drusen, dl_clean=dl_clean, transform=conv_transform
)

# evaluate the fine-tuned model
finetuned_accs = _evaluate_splits(param_l, param_n, param_u, test_loaders, conv_transform)

print("=========== Fine-tuned model accuracy + bounds ===========", file=sys.stderr)
for label, acc in zip(_SPLIT_LABELS, finetuned_accs):
    print(f"{label}: nominal = {acc[1]:.2g}, certified bound = {acc[0]:.2g}", file=sys.stderr)

Class 2 (Drusen) : nominal = 0.51
Classes 0, 1, 3  : nominal = 0.85
All Classes      : nominal = 0.77
[AGT] [DEBUG   ] [16:55:46] 	Optimizer params: n_epochs=2, learning_rate=0.06, l1_reg=0.0, l2_reg=0.0
[AGT] [DEBUG   ] [16:55:46] 	Learning rate schedule: lr_decay=4.0, lr_min=0.001, early_stopping=True
[AGT] [DEBUG   ] [16:55:46] 	Adversary feature-space budget: epsilon=0.01, k_poison=50
[AGT] [DEBUG   ] [16:55:46] 	Adversary label-space budget: label_epsilon=0, label_k_poison=0, poison_target=-1
[AGT] [DEBUG   ] [16:55:46] 	Clipping: gamma=inf, method=clamp
[AGT] [DEBUG   ] [16:55:46] 	Bounding methods: forward=interval, loss=binary_cross_entropy, backward=interval
[AGT] [INFO    ] [16:55:46] Starting epoch 1
[AGT] [DEBUG   ] [16:55:46] Initialising dataloader batchsize to 6000
[AGT] [INFO    ] [16:55:46] Training batch 1: Network eval bounds=(0.51, 0.51, 0.51), W0 Bound=0.0 
[AGT] [INFO    ] [16:55:47] Training batch 2: Network eval bounds=(0.87, 0.89, 0.89), W0 Bound=0.0135 
[AGT] 

In [13]:
from abstract_gradient_training.privacy import privacy_certified_training

torch.manual_seed(1)


def _evaluate_splits(param_l, param_n, param_u, test_loaders, transform):
    """Return ``agt.test_metrics.test_accuracy`` results for the first batch of each test loader.

    The results are returned as a list in the same order as ``test_loaders``. Later prints use
    entry [0] of each result as the certified bound and entry [1] as the nominal accuracy.
    """
    return [
        agt.test_metrics.test_accuracy(param_l, param_n, param_u, *next(iter(dl)), transform=transform)
        for dl in test_loaders
    ]


# Row labels for the three evaluation splits, padded so the printed columns line up.
_SPLIT_LABELS = ("Class 2 (Drusen) ", "Classes 0, 1, 3  ", "All Classes      ")

# get dataloaders: clean data (classes 0, 1, 3), Drusen data (class 2 only) and all classes
dl_clean, dl_test_clean = oct_mnist.get_dataloaders(clean_batchsize, test_batchsize, exclude_classes=[2])
dl_drusen, dl_test_drusen = oct_mnist.get_dataloaders(drusen_batchsize, test_batchsize, exclude_classes=[0, 1, 3])
_, dl_test_all = oct_mnist.get_dataloaders(clean_batchsize, test_batchsize)
test_loaders = (dl_test_drusen, dl_test_clean, dl_test_all)

# evaluate the pre-trained model
param_l, param_n, param_u = model_utils.get_parameters(linear_layers)
pretrained_accs = _evaluate_splits(param_l, param_n, param_u, test_loaders, conv_transform)

print("=========== Pre-trained model accuracy ===========", file=sys.stderr)
for label, acc in zip(_SPLIT_LABELS, pretrained_accs):
    print(f"{label}: nominal = {acc[1]:.2g}", file=sys.stderr)

config = AGTConfig(
    fragsize=500,
    learning_rate=0.08,
    n_epochs=2,
    k_private=50,
    forward_bound="interval",
    # Use the same device as the pre-trained model and conv transform. The original hard-coded
    # "cuda:0" here, while the model lives on `device` (cuda:1).
    device=str(device),
    clip_gamma=5.0,
    backward_bound="interval",
    loss="binary_cross_entropy",
    lr_decay=4.0,
    lr_min=0.001,
)

# fine-tune the model using abstract gradient training (keeping the convolutional layers fixed);
# the Drusen data is the private dataset and the other classes are passed as public data
param_l, param_n, param_u = privacy_certified_training(
    linear_layers, config, dl_drusen, dl_test_drusen, dl_public=dl_clean, transform=conv_transform
)

# evaluate the fine-tuned model
finetuned_accs = _evaluate_splits(param_l, param_n, param_u, test_loaders, conv_transform)

print("=========== Fine-tuned model accuracy + bounds ===========", file=sys.stderr)
for label, acc in zip(_SPLIT_LABELS, finetuned_accs):
    print(f"{label}: nominal = {acc[1]:.2g}, certified bound = {acc[0]:.2g}", file=sys.stderr)

Class 2 (Drusen) : nominal = 0.51
Classes 0, 1, 3  : nominal = 0.85
All Classes      : nominal = 0.77
[AGT] [INFO    ] [16:56:29] Starting epoch 1
[AGT] [INFO    ] [16:56:29] Training batch 1: Network eval bounds=(0.51, 0.51, 0.51), W0 Bound=0.0 
[AGT] [INFO    ] [16:56:30] Training batch 2: Network eval bounds=(0.68, 0.94, 0.99), W0 Bound=3.77 
[AGT] [INFO    ] [16:56:31] Starting epoch 2
[AGT] [INFO    ] [16:56:31] Training batch 3: Network eval bounds=(0.51, 0.94, 1   ), W0 Bound=4.53 
[AGT] [INFO    ] [16:56:32] Training batch 4: Network eval bounds=(0.36, 0.94, 1   ), W0 Bound=4.95 
[AGT] [INFO    ] [16:56:33] Final network eval: Network eval bounds=(0.24, 0.94, 1   ), W0 Bound=5.24 
Class 2 (Drusen) : nominal = 0.94, certified bound = 0.24
Classes 0, 1, 3  : nominal = 0.76, certified bound = 0.43
All Classes      : nominal = 0.81, certified bound = 0.38
