In [None]:
from deepview import DeepView
import numpy as np
import time
import torch
# ---------------------------
import demo_utils as demo

%load_ext autoreload
%autoreload 2
%matplotlib qt

In [None]:
# The Qt matplotlib backend can be unreliable in notebooks, so `%matplotlib qt` is executed a second time here.
%matplotlib qt

In [None]:
# Select the compute device, then load the test set and the model.
# The device is detected automatically; set it to 'cpu' or 'cuda:0' to override.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

print(device)

testset = demo.make_cifar_dataset()
torch_model = demo.create_torch_model(device)
# The model is only used for inference below. Eval mode makes layers such as
# BatchNorm/Dropout deterministic: no batch statistics, no running-stat updates.
# This is a no-op if create_torch_model already returned the model in eval mode.
torch_model.eval()


In [None]:
softmax = torch.nn.Softmax(dim=-1)


def pred_wrapper(x):
    """Classify a batch of images with ``torch_model`` and return class probabilities.

    Parameters
    ----------
    x : array-like
        Batch of images convertible to a float32 numpy array. Presumably shaped
        (batch, *data_shape) — TODO confirm against how DeepView calls it.

    Returns
    -------
    numpy.ndarray
        Softmax probabilities over the last axis of the model output, as a CPU array.
    """
    as_float32 = np.array(x, dtype=np.float32)
    with torch.no_grad():
        batch = torch.from_numpy(as_float32).to(device)
        probs = softmax(torch_model(batch))
        return probs.cpu().numpy()


# CIFAR-10 class names, in label-index order (classes[label] gives the name).
classes = tuple('plane car bird cat deer dog frog horse ship truck'.split())

# --- DeepView parameters ---------------------------------------------------
# NOTE(review): the meaning of each value is defined by DeepView's constructor;
# the descriptions below are best guesses — confirm against DeepView's docs.
batch_size = 128            # presumably the batch size used when calling pred_wrapper
max_samples = 1000          # presumably the maximum number of samples DeepView keeps
data_shape = (3, 32, 32)    # shape of one input image (channels, height, width)
n = 5                       # DeepView hyperparameter
lam = 0.65                  # DeepView hyperparameter
resolution = 100            # resolution of the rendered visualization
cmap = 'tab10'              # matplotlib colormap name
title = 'ResNet-20 - CIFAR10'

# DeepView takes these positionally, in this order.
deepview = DeepView(
    pred_wrapper, classes, max_samples, batch_size, data_shape,
    n, lam, resolution, cmap,
    title=title,
)

In [None]:
# Draw a reproducible random subset of the test set. The last `add_points`
# slots are placeholders that a later cell overwrites with hand-crafted
# (adversarial) points.
n_samples = 300
add_points = 1

np.random.seed(42)
sample_ids = np.random.permutation(len(testset))[:n_samples + add_points]
print(sample_ids[:5])

# Fetch every sample exactly once (each access runs the dataset transforms),
# then split it into image and label.
samples = [testset[i] for i in sample_ids]
X = np.array([img.numpy() for img, _ in samples])
Y = np.array([label for _, label in samples])
# The slice above silently truncates if the test set is too small.
assert len(X) == n_samples + add_points, "test set smaller than n_samples + add_points"

print(X.shape)

In [None]:
from compute_fisher_matrix_comps import compute_fisher_matrix_comps
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib as plotlib
import torch.nn.functional as F

# --- Normalization statistics -----------------------------------------------
# Per-channel mean/std of the input normalization.
# NOTE(review): assumes demo.make_cifar_dataset normalizes with exactly these
# per-channel values — confirm against demo_utils.
mean_ = [0.485, 0.456, 0.406]
std_  = [0.229, 0.224, 0.225]
mean_t = torch.tensor(mean_, device=device).view(-1, 1, 1)  # (C, 1, 1) for broadcasting
std_t  = torch.tensor(std_, device=device).view(-1, 1, 1)
# Range of valid normalized values (raw pixels 0 and 1), per channel, shape (C, 1, 1).
# Every channel has its own range; using channel 0's bounds for all channels
# would clamp the other channels to the wrong interval.
minV = (0 - mean_t) / std_t
maxV = (1 - mean_t) / std_t

# --- Adversarial example settings -------------------------------------------
adv_source_idx = 1  # test-set index of the image to perturb (7 is an alternative)
epsilon = 0.1       # gradient-sign step size, in normalized units


def to_display(img_chw):
    """Undo the normalization of a (C, H, W) tensor; return an (H, W, C) numpy image clipped to [0, 1]."""
    img = img_chw.detach() * std_t + mean_t
    return img.clamp(0, 1).cpu().numpy().transpose([1, 2, 0])


def grad_to_display(grad_chw):
    """Min-max rescale a (C, H, W) gradient to [0, 1]; return it as an (H, W, C) numpy image.

    Input gradients are typically not on the pixel scale, so un-normalizing them
    like an image would not make their structure visible.
    """
    g = grad_chw.detach()
    g = (g - g.min()) / (g.max() - g.min()).clamp(min=1e-12)
    return g.cpu().numpy().transpose([1, 2, 0])


# --- Gradient of the loss w.r.t. the clean input ----------------------------
data      = testset[adv_source_idx]
data_item = torch.zeros([1, *data_shape], device=device)  # batch of one image
data_item[0] = data[0]
labs_curr = data[1]

print("curr label", labs_curr)

# NOTE(review): n_data and used_values are not used in this cell; kept in case
# later cells rely on them.
n_data = n_samples
p_cgivenx = torch.zeros([1, len(classes)], device=device)
used_values = -1 * torch.ones([1, len(classes) + 1], dtype=torch.long)
used_values[:, -1] = len(classes)

data_item.requires_grad = True
output = torch_model(data_item)
p_cgivenx[-1, :] = softmax(output.detach())[0]
loss = F.cross_entropy(output, torch.tensor([labs_curr], device=device))
torch_model.zero_grad()
loss.backward()  # single backward pass; the graph does not need to be retained
data_grad = data_item.grad.data

# True label, predicted label and their probabilities on the clean image.
print("all predictions ", p_cgivenx)
print("True Lab / cert:", classes[labs_curr], "/", p_cgivenx[-1, labs_curr].item())
val, idx = torch.max(p_cgivenx, 1)
print("Pred Lab / cert:", classes[idx.item()], "/", p_cgivenx[-1, idx.item()].item())

# --- Fast-gradient-sign perturbation ----------------------------------------
perturbed_dat = data_item[0].detach() + epsilon * data_grad[-1].sign()
# Keep each channel inside its valid normalized range.
perturbed_dat = torch.min(torch.max(perturbed_dat, minV), maxV)
tmp = perturbed_dat.unsqueeze(0)  # (1, C, H, W) batch holding the adversarial image
with torch.no_grad():  # evaluation only, no graph needed
    output = torch_model(tmp)
p_cgivenx[-1, :] = softmax(output)[0]

# --- Plot: original, gradient, gradient sign, adversarial -------------------
fig, axes = plt.subplots(1, 4, figsize=(8, 6))
panels = [
    ("original", to_display(data_item[0])),
    ("gradient (rescaled)", grad_to_display(data_grad[-1])),
    ("sign(gradient)", to_display(data_grad[-1].sign())),
    ("adversarial", to_display(perturbed_dat)),
]
for ax, (panel_title, img) in zip(axes, panels):
    ax.imshow(img)
    ax.set_title(panel_title)
    ax.axis('off')
plt.show()

print("all predictions adv", p_cgivenx)
val, idx = torch.sort(p_cgivenx[0], descending=True)
print("indices after sorting", idx)
print("Top predictions are", ", ".join(classes[i] for i in idx[:3].tolist()))

# Overwrite the last sampled point (the slot reserved via add_points) with the
# adversarial example; it keeps the true label of its source image.
X[-1] = tmp[0].cpu().numpy()
Y[-1] = labs_curr


In [None]:
# Configure DeepView's UMAP embedding with a reproducible seed, embed all
# samples (including the adversarial point in X[-1]) and show the result.
np.random.seed(42)
rseed = round(np.random.rand() * 1e10)  # reproducible seed derived from the global seed

umapParms = {
    "random_state": rseed,
    "n_neighbors": 30,
    "spread": 1,
    "min_dist": 0.1,
}
# NOTE(review): _init_mappers is a private DeepView method and its first two
# arguments are passed as None — confirm their meaning against DeepView.
deepview._init_mappers(None, None, umapParms)

t0 = time.perf_counter()

deepview.add_samples(X, Y)
deepview.show()

# Report the number of samples actually embedded (n_samples + add_points).
print('Time to calculate visualization for %d samples: %.2f sec' % (len(X), time.perf_counter() - t0))

In [None]:
# Re-render the same visualization at a higher resolution (200) for a prettier figure.
# The resolution is set before update_mappings(), which presumably recomputes the
# plot at the new resolution — TODO confirm against DeepView.
deepview.resolution = 200
deepview.update_mappings()

deepview.show()

In [None]:
# Inspect how far the adversarial example is from every other sample under several
# distance measures. Each panel: x = distance from the adversarial point,
# y = class of the other sample.
idx_plt = -1  # row of the adversarial point (written into X[-1] / Y[-1])

n_classes = len(classes)
msize_s = 8
cmap = plt.get_cmap("tab10")

from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import floyd_warshall

# All-pairs shortest paths over the graph whose edge weights are deepview.distances
# (deepview.discr_distances is an alternative). Converting to csr_matrix keeps
# only nonzero entries, so a zero distance is treated as "no edge".
fish_dists_spaths = csr_matrix(deepview.distances)
fish_dists_spaths = floyd_warshall(csgraph=fish_dists_spaths, directed=False, unweighted=False)

panels = [
    (deepview.eucl_distances,  'euclidean dist from adv'),
    (deepview.discr_distances, 'fisher dist from adv'),
    (deepview.distances,       'regul fisher dist from adv'),
    (fish_dists_spaths,        'shortest paths fisher dist'),
]

fig, axes = plt.subplots(1, 4, figsize=(14, 8))
for ax, (dist_matrix, panel_title) in zip(axes, panels):
    for c in range(n_classes):
        in_class = Y == c
        ax.plot(dist_matrix[idx_plt, in_class],
                np.full(np.count_nonzero(in_class), c, dtype=float),
                'o', c=cmap(c / (n_classes - 1)), label=classes[c], markersize=msize_s)
    ax.set_title(panel_title, fontsize=12)
    ax.set_xlabel('distance')
    # The y coordinate is the class index; label it with the class names.
    ax.set_yticks(range(n_classes))
    ax.set_yticklabels(classes)
#labs_plot