In [1]:
import os
import sys
import pickle

# from collections import Counter

# import matplotlib.pyplot   as plt
# import matplotlib.gridspec as gridspec
# from tqdm import tqdm, tqdm_notebook

import numpy as np

# from sklearn.metrics.pairwise import cosine_distances
# from sklearn.neighbors        import NearestNeighbors

import torch
import torch.nn            as nn
import torch.nn.functional as F
import torch.optim         as optim
# from torch.utils.data.dataset import random_split
# from torchvision              import datasets 
# from torchvision              import transforms

from DkNN import CKNN
import dataset_input
import utilities
from cifar_model import CNN
from attack import PGD



def feature_space(cnnmod, num_rep, data, label, device, batch_size=None):
    """Run ``data`` through ``cnnmod`` and collect its flattened intermediate outputs.

    The model's return value is unpacked as ``*representations, prediction``:
    every element but the last is treated as an intermediate representation
    and flattened to 2-D (one row per sample); the last element is returned
    as the prediction array.

    Args:
        cnnmod: model called as ``cnnmod(batch)``. It is switched to eval mode.
        num_rep: number of intermediate representations the model returns.
        data: input tensor, indexed along its first dimension.
        label: tensor of targets belonging to ``data``.
        device: device the inputs are moved to before the forward pass.
        batch_size: optional number of samples per forward pass. ``None``
            (the default) feeds all of ``data`` in one pass, as before.
            Passing a value bounds peak memory for large inputs.

    Returns:
        Tuple ``(conv_features, targets, predictions)`` where
        ``conv_features`` is a list of ``num_rep`` NumPy arrays (one per
        representation, rows aligned with ``data``), ``targets`` is ``label``
        as a NumPy array and ``predictions`` is the model's last output as a
        NumPy array.

    Raises:
        ValueError: if ``batch_size`` is not positive, or if the model does
            not return exactly ``num_rep`` intermediate representations.
    """
    print('Building the feature spaces from the selected set.')

    if batch_size is not None and batch_size <= 0:
        raise ValueError('batch_size must be a positive integer or None, got {!r}'.format(batch_size))

    conv_features = [[] for _ in range(num_rep)]
    predictions   = []
    print('\tRunning predictions')
    cnnmod.eval()

    if batch_size is None:
        chunks = [data]
    else:
        # `or [data]` keeps the single (empty) forward pass for empty input.
        chunks = [data[i:i + batch_size] for i in range(0, len(data), batch_size)] or [data]

    # Inference only: without no_grad the autograd graph for the whole input
    # would be kept alive until the outputs are detached.
    with torch.no_grad():
        for chunk in chunks:
            # Only the current chunk is moved to the device.
            *out_convs, y_pred = cnnmod(chunk.to(device))
            if len(out_convs) != num_rep:
                raise ValueError(
                    'model returned {} intermediate representations, expected num_rep={}'
                    .format(len(out_convs), num_rep)
                )
            for i, out_conv in enumerate(out_convs):
                # reshape (not view) so non-contiguous outputs are handled too.
                conv_feat = out_conv.reshape(out_conv.size(0), -1).detach().cpu().numpy()
                conv_features[i].append(conv_feat)
            predictions.append(y_pred.detach().cpu().numpy())

    print('\tConcatenating results')
    conv_features = [np.concatenate(feats) for feats in conv_features]
    # .cpu() so labels living on an accelerator do not break the conversion;
    # .copy() so the result does not alias the label tensor's memory.
    targets       = label.detach().cpu().numpy().copy()
    predictions   = np.concatenate(predictions, axis = 0)

    return conv_features, targets, predictions




# NOTE: the GPU index is hard-coded; adjust it for machines with a different
# device layout.
device = torch.device('cuda:5')
config = utilities.config_to_namedtuple(utilities.get_config('config_cifar.json'))
dataset = dataset_input.CIFAR10Data(config, seed=config.training.np_random_seed)
filename = 'models/cifarmodel.pt'
model = CNN().to(device)

if os.path.isfile(filename):
    print("=> loading checkpoint '{}'".format(filename))
    # map_location remaps the stored tensors onto `device`, so a checkpoint
    # saved from a different GPU (or one that no longer exists) still loads.
    # SECURITY: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(filename, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
else:
    # NOTE: execution continues with the freshly constructed (untrained)
    # model in this case; every result below is then meaningless.
    print("=> no checkpoint found at '{}'".format(filename))



# Training split: request 49000 points in a single batch (per the original
# note, that is the whole training split).
batch = 49000
x_batch, y_batch = dataset.train_data.get_next_batch(batch, multiple_passes=True)
x_batch = x_batch / 255.0  # same 1/255 scaling is applied to every split below

# Axis order (0, 3, 1, 2) moves the last axis to position 1
# (presumably channels-last -> channels-first; confirm against the dataset).
x_batch_train = torch.from_numpy(
    np.transpose(x_batch.astype(np.float32), (0, 3, 1, 2))
)
y_batch_train = torch.from_numpy(y_batch.astype(np.int64))

# Features of the 4 intermediate representations for the training points.
conv_train, _, pred_train = feature_space(
    model, 4, x_batch_train, y_batch_train, device
)




# Calibration split: request 1000 points in a single batch (per the original
# note, that is the whole calibration split).
batch_cali = 1000
x_batch, y_batch = dataset.cali_data.get_next_batch(batch_cali, multiple_passes=True)
x_batch = x_batch / 255.0

# Same float32 cast and (0, 3, 1, 2) axis reordering as the training inputs.
x_batch_cali = torch.from_numpy(
    np.transpose(x_batch.astype(np.float32), (0, 3, 1, 2))
)
y_batch_cali = torch.from_numpy(y_batch.astype(np.int64))


#build a calibration class
class Calibration:
    """Indexable pairing of calibration inputs with their labels.

    Attributes:
        x: the input container given at construction.
        y: the label container given at construction.
        n_sample: ``len(y)`` captured at construction time.
    """

    def __init__(self, x_cali, y_cali):
        """Store both containers and record how many labels there are."""
        self.x, self.y = x_cali, y_cali
        self.n_sample = len(y_cali)

    def __len__(self):
        """Return the sample count recorded at construction."""
        return self.n_sample

    def __getitem__(self, index):
        """Return the ``(input, label)`` pair stored at ``index``."""
        sample = self.x[index]
        target = self.y[index]
        return sample, target


# Pair the calibration inputs with their labels for CKNN.
calib_dataset = Calibration(x_batch_cali, y_batch_cali)
model.eval()  # eval mode before the model is handed to CKNN

# Alternative kept from the original: extract the calibration features here.
# conv_cali, _, pred_cali = feature_space(model, 4, x_batch_cali, y_batch_cali, device)

batch_size = 1000
# n_embs matches the 4 representations extracted for the training set.
deep_knn = CKNN(
    model=model,
    device=device,
    train_conv=conv_train,
    y_train=y_batch_train,
    calib_dataset=calib_dataset,
    batch_size=batch_size,
    n_neighbors=3,
    n_embs=4,
)

=> loading checkpoint 'models/cifarmodel.pt'
Building the feature spaces from the selected set.
	Running predictions
	Concatenating results
Building Nearest Neighbor finders.
Building calibration set.


In [2]:
# Evaluation split: request 10000 points in a single batch (per the original
# note, that is the whole evaluation split).
batch_eval = 10000
x_batch, y_batch = dataset.eval_data.get_next_batch(batch_eval, multiple_passes=True)
x_batch = x_batch / 255.0

# Same float32 cast and (0, 3, 1, 2) axis reordering as the training inputs.
x_batch_eval = torch.from_numpy(
    np.transpose(x_batch.astype(np.float32), (0, 3, 1, 2))
)
y_batch_eval = torch.from_numpy(y_batch.astype(np.int64))

# Predictions on the clean (unperturbed) evaluation inputs.
y_pred, confidence, credibility = deep_knn.predict(x_batch_eval)

In [3]:
# Mask of evaluation samples whose prediction equals the true label; the
# confidence/credibility averages are taken over these samples only.
labels_eval_np = y_batch_eval.cpu().numpy()
is_correct = labels_eval_np == y_pred

print('Accuracy:                                  \t', is_correct.mean())
print('confidence for correct predictions:\t', confidence[is_correct].mean())
print('credibility for correct predictions:\t', credibility[is_correct].mean())

Accuracy:                                  	 0.5602
confidence for correct predictions:	 0.8143416636915389
credibility for correct predictions:	 0.9230003570153515


In [4]:
# PGD attack with eps = 2 (divided by 255 to match the inputs, which were
# scaled by 1/255) and 10 iterations.
eps, step = 2.0, 10
at = PGD(eps=eps / 255.0, sigma=2 / 255.0, nb_iter=step)

# Perturb the evaluation inputs on `device`, then classify the result.
pois_x_batch_eval = at.attack(
    model, x_batch_eval.to(device), y_batch_eval.to(device)
)
y_pred, confidence, credibility = deep_knn.predict(pois_x_batch_eval)

# Mask of perturbed samples that are still classified correctly.
labels_eval_np = y_batch_eval.cpu().numpy()
is_correct = labels_eval_np == y_pred

print('adversarial accuracy:                                  \t', is_correct.mean())
print('confidence for correct predictions:\t', confidence[is_correct].mean())
print('credibility for correct predictions:\t', credibility[is_correct].mean())

adversarial accuracy:                                  	 0.54
confidence for correct predictions:	 0.8112348148148147
credibility for correct predictions:	 0.9206657407407406


In [None]:
# PGD attack with eps = 8 (divided by 255 to match the inputs, which were
# scaled by 1/255) and 10 iterations.
eps, step = 8.0, 10
at = PGD(eps=eps / 255.0, sigma=2 / 255.0, nb_iter=step)

# Perturb the evaluation inputs on `device`, then classify the result.
pois_x_batch_eval = at.attack(
    model, x_batch_eval.to(device), y_batch_eval.to(device)
)
y_pred, confidence, credibility = deep_knn.predict(pois_x_batch_eval)

# Mask of perturbed samples that are still classified correctly.
labels_eval_np = y_batch_eval.cpu().numpy()
is_correct = labels_eval_np == y_pred

print('adversarial accuracy:                                  \t', is_correct.mean())
print('confidence for correct predictions:\t', confidence[is_correct].mean())
print('credibility for correct predictions:\t', credibility[is_correct].mean())