In [1]:
# Hacky but for now..
import sys
sys.path.append('../')

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from torchvision.datasets import CIFAR10
from torchvision.datasets import MNIST
from torchvision import transforms
import torchmetrics

from torch.utils.data import DataLoader
import foolbox as fb
from torch.utils.tensorboard import SummaryWriter
from pytorch_lightning.callbacks import ModelCheckpoint
import torch
from torch import nn
import itertools

import numpy as np
import PIL.Image
#Make sure of deterministic behaviour
# torch.use_deterministic_algorithms(True)


from utils import rescale
from DataSet.lightning_cifar import CIFARDataModule
from models.cifar_nn import Cifar_nn
from lightning_trainer import LitModelTrainer


In [2]:
### Set hyperparameters / global variables

# Data location variables:
cifar_loc = '../../../data/cifar'
mnist_loc = '../../../data/mnist'

# Lightning checkpoint of the model that is validated in this notebook.
val_model_loc = '../trained_models/CIFAR/base/version_1/checkpoints/model.ckpt'


# Name of this validation run; used for the TensorBoard log directory and the
# results text file. Example of another run: val_MNIST_adv_eps05_lr1e5_livedreams
model_identifier = "val_CIFAR_base"

# Device settings:
torch_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Only make CUDA tensors the default when a GPU is actually present. Setting it
# unconditionally breaks the CPU fallback chosen on the line above.
if torch_device.type == 'cuda':
    torch.set_default_tensor_type(torch.cuda.FloatTensor)


# Transform for data:
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Dataloader
batch_size = 1000    # Num images per batch
num_procs = 0       # Number of workers to fetch data

# How the L2 multiplier relates Linf budgets to L2 budgets:
# 20 -> max L2 eps of 6 (20 * 0.3).
# 28 -> MNIST: sqrt(28*28) = 28, so Linf 0.3 corresponds to L2 8.4.
# 16 -> CIFAR: Linf 0.03 corresponds to L2 sqrt((32*32*3) * (0.03*0.03)) ~= 1.66,
#       which is close to 16 * 0.1 = 1.6.
L2_eps_multiplier = 16      # L2 norm requires a much larger epsilon value as "budget" to work with. CIFAR.



In [3]:
# Validation data and its loader.
# The loader is left unshuffled: shuffling is not needed for validation runs.

# Display name for each class index (used as axis labels of the confusion matrix).
# classes_list = list(range(10)) # Just numbers for MNIST
classes_list = [
    "airplanes", "cars", "birds", "cats", "deer",
    "dogs", "frogs", "horses", "ships", "trucks",
]

# val_data = MNIST(root=mnist_loc, train=False, download=True, transform=transform)
val_data = CIFAR10(
    root=cifar_loc,
    train=False,
    download=True,
    transform=transform,
)
# val_subset = torch.utils.data.Subset(val_data, list(range(500)))


val_loader = DataLoader(
    val_data,
    batch_size=batch_size,
    num_workers=num_procs,
    drop_last=True,
)

Files already downloaded and verified


In [4]:
# Load model from checkpoint:

# The loaded Lightning module is bound to `lit_module` instead of
# `ModelCheckpoint`, so the ModelCheckpoint callback class imported from
# pytorch_lightning.callbacks is no longer shadowed.
# map_location keeps the load consistent with the CPU fallback of torch_device
# (a checkpoint saved on GPU can then still be loaded without CUDA).
lit_module = LitModelTrainer.load_from_checkpoint(val_model_loc, map_location=torch_device)
val_model : nn.Module = lit_module.model

val_model = val_model.to(torch_device)

# Evaluation mode for the whole validation run.
val_model.eval()

# Setup model in FB for attacks; inputs are expected in the [0, 1] range.
fmodel = fb.models.pytorch.PyTorchModel(model=val_model, bounds=(0, 1), device=torch_device)

# Criterion to use for model loss calculation (not for training!)
model_crit = nn.CrossEntropyLoss(reduction="mean")

# SummaryWriter for metric info and saved images:
logging_Obj = SummaryWriter(log_dir="validation_results/" + model_identifier, comment='')




In [5]:

# Attacks to validate against, keyed by a string identifier.
# Identifiers containing "_L2" are treated as L2-norm attacks by
# validate_attacks, which scales their epsilons by L2_eps_multiplier.
val_attacks = {
    "FGSM-RS": fb.attacks.FGSM(random_start=True),
    "FGSM-NORS": fb.attacks.FGSM(random_start=False),

    "PGD4-0.5": fb.attacks.LinfPGD(steps=4, rel_stepsize=1/2),
    "PGD5-0.4": fb.attacks.LinfPGD(steps=5, rel_stepsize=2/5),

    "PGD10-0.1": fb.attacks.LinfPGD(steps=10, rel_stepsize=1/10),
    "PGD10-0.2": fb.attacks.LinfPGD(steps=10, rel_stepsize=2/10),
    "PGD10-0.4": fb.attacks.LinfPGD(steps=10, rel_stepsize=2/5),

    "PGD20-0.1": fb.attacks.LinfPGD(steps=20, rel_stepsize=1/10),
    "PGD20-0.2": fb.attacks.LinfPGD(steps=20, rel_stepsize=1/5),

    "CW_L2_lr=0.01_steps=200_bstep=5": fb.attacks.carlini_wagner.L2CarliniWagnerAttack(steps=200, stepsize=0.01, abort_early=True, binary_search_steps=5, initial_const=1e-3),
    "CW_L2_lr=0.05_steps=100_bstep=5": fb.attacks.carlini_wagner.L2CarliniWagnerAttack(steps=100, stepsize=0.05, abort_early=True, binary_search_steps=5, initial_const=1e-3),
    "CW_L2_lr=0.1_steps=100_bstep=5":  fb.attacks.carlini_wagner.L2CarliniWagnerAttack(steps=100, stepsize=0.1, abort_early=True, binary_search_steps=5, initial_const=1e-3)
}


# Epsilon budgets every attack is evaluated at (before any L2 scaling).
# epsilons = [0.01, 0.03, 0.1, 0.2, 0.3, 0.4, 0.5] # For MNIST
epsilons = [0.001, 0.005, 0.01, 0.03, 0.1, 0.2, 0.3] # For CIFAR10

# Criterion to use (targeted vs untargeted) for attacks
adv_crit = fb.criteria.Misclassification

# Metrics to calculate, keyed by identifier. Each value is a pair
# (factory, reduction): the zero-argument factory builds a fresh torchmetrics
# object, so every (attack, epsilon) combination accumulates independently.
# The reduction string is currently always "" and is not read by
# validate_attacks.
# Every factory places its metric on torch_device, so the metric state lives on
# the same device as the predictions regardless of the default tensor type.
metricsPerAttack = {
    "Loss" :  (lambda : torchmetrics.MeanMetric().to(device=torch_device), ""),
    "Average_perturbation" : (lambda : torchmetrics.MeanMetric().to(device=torch_device), ""),
    "Top_1_Accuracy" : (lambda : torchmetrics.Accuracy(top_k=1).to(device=torch_device), ""),
    "Top_3_Accuracy" : (lambda : torchmetrics.Accuracy(top_k=3).to(device=torch_device), ""),
    "Confusion_Matrix" : (lambda : torchmetrics.ConfusionMatrix(num_classes=10).to(device=torch_device), "")
}


In [6]:
import io
from matplotlib import pyplot as plt


def plot_confusion_matrix(cm, eps = "", att_id = "", class_names = list(range(10)), top_1_acc=0.0):
    """
    Render a confusion matrix to an in-memory PNG image.

    Cell colours show the raw counts in ``cm``. The number printed in each cell
    is the count divided by its row total, rounded to 2 decimals. Rows whose
    total is zero are printed as 0.0 instead of NaN.

    Args:
       cm (array, shape = [n, n]): a confusion matrix of integer classes
       eps: epsilon value shown in the caption above the plot
       att_id: attack identifier shown in the caption above the plot
       class_names (array, shape = [n]): String names of the integer classes
       top_1_acc: accuracy shown in the caption, formatted with 3 decimals

    Returns:
       io.BytesIO: buffer holding the PNG, rewound to position 0.
    """
    cm = np.asarray(cm)

    figure = plt.figure(figsize=(8, 8))
    try:
        plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
        plt.title("Confusion matrix")
        plt.colorbar()
        tick_marks = np.arange(len(class_names))
        plt.xticks(tick_marks, class_names, rotation=45)
        plt.yticks(tick_marks, class_names)

        # Normalize the confusion matrix per row; rows without any samples
        # stay at 0 rather than dividing by zero.
        row_totals = cm.sum(axis=1, keepdims=True)
        cm_norm = np.divide(
            cm.astype('float'),
            row_totals,
            out=np.zeros(cm.shape, dtype=float),
            where=row_totals != 0,
        )
        cm_norm = np.around(cm_norm, decimals=2)

        # Use white text if squares are dark; otherwise black. The squares are
        # coloured from the raw counts, so the threshold uses the raw counts too.
        threshold = cm.max() / 2.

        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            color = "white" if cm[i, j] > threshold else "black"
            plt.text(j, i, cm_norm[i, j], horizontalalignment="center", color=color)

        plt.text(0, -1.5, f'attack_id: {att_id}; eps:{eps}; acc: {float(top_1_acc):.3f}',
            horizontalalignment='center',
            verticalalignment='center')

        # Set the axis labels before tight_layout so they are included in the
        # layout computation and not clipped.
        plt.ylabel('True label')
        plt.xlabel('Predicted label')
        plt.tight_layout()

        buf = io.BytesIO()
        plt.savefig(buf, format='png')
    finally:
        # Close this specific figure even if rendering failed, so repeated
        # calls do not accumulate open figures.
        plt.close(figure)
    buf.seek(0)
    return buf

In [7]:
# Validation function and its helper.
def _per_image_perturbation_range(x, x_adv):
    """Return max(x - x_adv) - min(x - x_adv) separately for each batch item.

    Dimension 0 is treated as the batch dimension; all remaining dimensions
    are flattened so the min/max are taken within a single item.
    """
    diffs = torch.sub(x, x_adv).flatten(start_dim=1)
    per_image_min, per_image_max = torch.aminmax(diffs, dim=1)
    return torch.sub(per_image_max, per_image_min)


def validate_attacks():
    """
    Run every attack in ``val_attacks`` over ``val_loader`` and collect metrics.

    For each attack and each value in ``epsilons`` the adversarial examples
    are fed through ``val_model`` and the metrics from ``metricsPerAttack``
    are accumulated over all batches. Attacks whose identifier contains
    "_L2" are run with epsilons multiplied by ``L2_eps_multiplier``.

    Side effects: prints one progress line per attack and writes images,
    scalars and confusion-matrix plots to ``logging_Obj``.

    "Average_perturbation" is the mean over images of
    max(diff) - min(diff) within that image, the same definition used for the
    single logged example image.

    Returns:
        dict: attack id -> str(epsilon, unscaled) -> metric id -> computed value.
    """
    metric_val_dict = {}

    for attack_idx, (attack_id, val_attack) in enumerate(val_attacks.items()):
        # L2 attacks are recognised by name and get a scaled epsilon budget.
        epsilon_mult = L2_eps_multiplier if "_L2" in attack_id else 1
        epsilon_set = epsilons

        print("Doing val of " + attack_id + ", attack step# = ", attack_idx)
        metric_val_dict[attack_id] = {}
        metric_dict = {}
        for step, batch in enumerate(val_loader):
            x:torch.Tensor
            y:torch.Tensor
            x, y = batch
            x, y = x.to(torch_device), y.to(torch_device)
            # The attack itself runs outside torch.no_grad().
            _, clipped_advs, _ = val_attack(fmodel, x, epsilons=np.multiply(epsilon_set, epsilon_mult), criterion=adv_crit(y))

            for eps_idx, epsilon_val in enumerate(epsilon_set):
                eps_key = str(epsilon_val)

                # On the first batch, create fresh metric objects for this epsilon.
                if step == 0:
                    # Placing the metrics on torch_device here keeps their state
                    # on the same device as the predictions.
                    metric_dict[eps_key] = {
                        metric_id: metric_func().to(torch_device)
                        for metric_id, (metric_func, _) in metricsPerAttack.items()
                    }
                    metric_val_dict[attack_id][eps_key] = {}
                x_adv = clipped_advs[eps_idx].to(torch_device)

                with torch.no_grad():      # Do forward pass
                    y_hat_adv = val_model(x_adv)

                    for metric_id, metric_obj in metric_dict[eps_key].items():
                        if metric_id == "Average_perturbation":
                            # Perturbation size per image, averaged over all images.
                            metric_obj(_per_image_perturbation_range(x, x_adv))
                        elif metric_id == "Loss":  # Calculate model loss with model criterion
                            metric_obj(model_crit(y_hat_adv, y))
                        else:
                            metric_obj(y_hat_adv, y)

                    # Log the first image of the first batch for each attack/epsilon,
                    # together with its adversarial version and the perturbation.
                    if step == 0:
                        log_step = 10 * attack_idx + eps_idx
                        logging_Obj.add_image(tag=f"original", img_tensor=x[0], global_step=log_step)
                        logging_Obj.add_image(tag=f"adv_eps{epsilon_val:.3f}", img_tensor=x_adv[0], global_step=log_step)

                        diff_tensor = torch.sub(x[0], x_adv[0])
                        logging_Obj.add_image(tag=f"diff_eps{epsilon_val:.3f}", img_tensor=rescale(diff_tensor), global_step=log_step)
                        min_val, max_val = torch.aminmax(diff_tensor)
                        perturb_size = max_val.item() - min_val.item()
                        logging_Obj.add_scalar(tag=f"pertsize_of_eps{epsilon_val:.3f}", scalar_value=perturb_size, global_step=log_step)

        # Compound metrics for this validation attack vector
        for epsilon_num, (epsilon_val_str, metric_pair) in enumerate(metric_dict.items()):
            eps_results = metric_val_dict[attack_id][epsilon_val_str]

            # Compute every metric first, so the confusion-matrix caption can use
            # the Top-1 accuracy regardless of the order of metricsPerAttack.
            for metric_id, metric_f in metric_pair.items():
                eps_results[metric_id] = metric_f.compute()

            for metric_num, metric_id in enumerate(metric_pair):
                metric_value = eps_results[metric_id]
                # Step layout: hundreds = attack, tens = epsilon, units = metric.
                metric_step = 100 * attack_idx + 10 * epsilon_num + metric_num
                if metric_id == "Confusion_Matrix":
                    cm_buffer = plot_confusion_matrix(
                        cm=metric_value.detach().cpu().numpy(),
                        eps=f"{(float(epsilon_val_str) * epsilon_mult):.3f}",
                        att_id=attack_id,
                        class_names=classes_list,
                        top_1_acc=eps_results.get("Top_1_Accuracy", 0.0),
                    )

                    cm_img = PIL.Image.open(cm_buffer)
                    cm_img = transforms.ToTensor()(cm_img)
                    logging_Obj.add_image(tag=metric_id, img_tensor=cm_img, global_step=metric_step)
                else:
                    logging_Obj.add_scalar(tag=metric_id, scalar_value=metric_value, global_step=metric_step)

    # Make sure everything logged above is written out before returning.
    logging_Obj.flush()
    return metric_val_dict


In [8]:
# Run every configured attack over the validation set and keep the nested
# result dict (attack id -> epsilon string -> metric id -> value) for the
# reporting cells below.
dict_result = validate_attacks()


Doing val of FGSM-RS, attack step# =  0
Doing val of FGSM-NORS, attack step# =  1
Doing val of PGD4-0.5, attack step# =  2
Doing val of PGD5-0.4, attack step# =  3
Doing val of PGD10-0.1, attack step# =  4
Doing val of PGD10-0.2, attack step# =  5
Doing val of PGD10-0.4, attack step# =  6
Doing val of PGD20-0.1, attack step# =  7
Doing val of PGD20-0.2, attack step# =  8
Doing val of CW_L2_lr=0.01_steps=200_bstep=5, attack step# =  9
Doing val of CW_L2_lr=0.05_steps=100_bstep=5, attack step# =  10
Doing val of CW_L2_lr=0.1_steps=100_bstep=5, attack step# =  11


In [9]:
# Legend for the TensorBoard step numbers: which (epsilon, metric) pair each
# step offset within one attack refers to. validate_attacks adds
# 100 * attack index on top of this offset.
for eps_count, epsilon_val in enumerate(epsilons):
    for metric_count, metric_id in enumerate(metricsPerAttack):
        print(f"Metric step {10*eps_count + metric_count} logs eps {str(epsilon_val)} for metric {metric_id}")

Metric step 0 logs eps 0.001 for metric Loss
Metric step 1 logs eps 0.001 for metric Average_perturbation
Metric step 2 logs eps 0.001 for metric Top_1_Accuracy
Metric step 3 logs eps 0.001 for metric Top_3_Accuracy
Metric step 4 logs eps 0.001 for metric Confusion_Matrix
Metric step 10 logs eps 0.005 for metric Loss
Metric step 11 logs eps 0.005 for metric Average_perturbation
Metric step 12 logs eps 0.005 for metric Top_1_Accuracy
Metric step 13 logs eps 0.005 for metric Top_3_Accuracy
Metric step 14 logs eps 0.005 for metric Confusion_Matrix
Metric step 20 logs eps 0.01 for metric Loss
Metric step 21 logs eps 0.01 for metric Average_perturbation
Metric step 22 logs eps 0.01 for metric Top_1_Accuracy
Metric step 23 logs eps 0.01 for metric Top_3_Accuracy
Metric step 24 logs eps 0.01 for metric Confusion_Matrix
Metric step 30 logs eps 0.03 for metric Loss
Metric step 31 logs eps 0.03 for metric Average_perturbation
Metric step 32 logs eps 0.03 for metric Top_1_Accuracy
Metric step 33 

In [10]:
# Print every metric per attack and epsilon, and show the confusion matrices.
print("Results:")

for attack_id in val_attacks:
    print("Attack vector: " , attack_id)

    # L2 attacks were run with scaled epsilons; report the value actually used.
    eps_mult = L2_eps_multiplier if "_L2" in attack_id else 1

    for eps in epsilons:
        # Results are keyed by the unscaled epsilon as a string.
        eps_results = dict_result[attack_id][str(eps)]

        print(f"With epsilon value: {eps * eps_mult:.3f}")
        for metric_id in metricsPerAttack:
            metric_val = eps_results[metric_id].detach().cpu().numpy()
            if metric_id == "Confusion_Matrix":
                print("Confusion matrix:")
                cm_buffer = plot_confusion_matrix(
                    metric_val,
                    eps=f"{eps * eps_mult:.3f}",
                    att_id=attack_id,
                    class_names=classes_list,
                    top_1_acc=float(eps_results.get("Top_1_Accuracy", 0.0)),
                )
                # plot_confusion_matrix only returns a PNG buffer, so the image
                # has to be shown explicitly to appear below the heading.
                plt.figure(figsize=(8, 8))
                plt.imshow(PIL.Image.open(cm_buffer))
                plt.axis("off")
                plt.show()
            else:
                print("Metric: ", metric_id, f" has value: {metric_val:.5f}")

Results:
Attack vector:  FGSM-RS
With epsilon value: 0.001
Metric:  Loss  has value: 0.73643
Metric:  Average_perturbation  has value: 0.00200
Metric:  Top_1_Accuracy  has value: 0.76090
Metric:  Top_3_Accuracy  has value: 0.94500
Confusion matrix:
With epsilon value: 0.005
Metric:  Loss  has value: 1.70092
Metric:  Average_perturbation  has value: 0.01000
Metric:  Top_1_Accuracy  has value: 0.53150
Metric:  Top_3_Accuracy  has value: 0.85890
Confusion matrix:
With epsilon value: 0.010
Metric:  Loss  has value: 3.20288
Metric:  Average_perturbation  has value: 0.02000
Metric:  Top_1_Accuracy  has value: 0.31530
Metric:  Top_3_Accuracy  has value: 0.73500
Confusion matrix:
With epsilon value: 0.030
Metric:  Loss  has value: 8.17782
Metric:  Average_perturbation  has value: 0.06000
Metric:  Top_1_Accuracy  has value: 0.03890
Metric:  Top_3_Accuracy  has value: 0.36540
Confusion matrix:
With epsilon value: 0.100
Metric:  Loss  has value: 15.28640
Metric:  Average_perturbation  has value: 

In [12]:
# Persist the string form of the result dict as a text file named after the run.
results_path = 'validation_results/' + model_identifier + '.txt'
with open(results_path, 'w') as f:
    f.write(str(dict_result) + '\n')