In [None]:
from deepview import DeepView
import matplotlib.pyplot as plt
import numpy as np
import time
import torch
# ---------------------------
import demo_utils as demo

# Reload edited local modules (e.g. demo_utils) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# Qt backend: figures open in separate interactive windows
# (DeepView's click-to-inspect presumably needs an interactive backend).
%matplotlib qt

In [None]:
# Re-issue the backend switch in its own cell: in notebooks `%matplotlib qt`
# does not always take effect on the first run, so it is executed again here.
%matplotlib qt

In [None]:
# Compute device: the first CUDA GPU if PyTorch can see one, otherwise the CPU.
# To force a device, replace the expression with torch.device('cpu') or torch.device('cuda:0').
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Test dataset and classifier both come from the local demo_utils helpers;
# the model factory receives the chosen device.
testset = demo.make_cifar_dataset()
torch_model = demo.create_torch_model(device)


In [None]:
# Module-level softmax over the last axis, applied to the model's logits in pred_wrapper.
softmax = torch.nn.Softmax(dim=-1)


def pred_wrapper(x):
    """Classifier entry point handed to DeepView.

    Converts ``x`` to a float32 tensor on ``device``, runs ``torch_model`` on it
    with autograd disabled and returns the softmax of the resulting logits.

    Args:
        x: batch of samples; anything ``np.array`` can convert to float32.

    Returns:
        numpy.ndarray: softmax probabilities over the last axis, copied back to the CPU.

    NOTE(review): the model is not switched to ``eval()`` here; presumably
    ``demo.create_torch_model`` already does that — verify.
    """
    batch = np.array(x, dtype=np.float32)
    with torch.no_grad():
        logits = torch_model(torch.from_numpy(batch).to(device))
        return softmax(logits).cpu().numpy()

def visualization(image, point2d, pred, label=None, title=None):
    """Display one sample in its own figure; used as DeepView's ``data_viz`` callback.

    Args:
        image: channel-first sample (this notebook uses ``data_shape == (3, 32, 32)``);
            it is transposed to channel-last for ``imshow`` and is not modified.
        point2d: 2-D location of the sample in the DeepView plot (not drawn).
        pred: prediction for the sample (presumably a class index); used for the
            default title.
        label: ground-truth class, or ``None`` if unknown.
        title: figure title. When ``None`` and ``pred`` is given, a title of the
            form ``'prediction: <pred>, label: <label>'`` is generated.

    Returns:
        tuple: the new ``(Figure, Axes)`` so callers can adjust the plot further.
    """
    img = np.asarray(image, dtype=float).transpose([1, 2, 0])   # CHW -> HWC for imshow
    # imshow clips float RGB data to [0, 1]; samples here are not guaranteed to be in
    # that range (e.g. the +6 red shift applied in this notebook), so min-max rescale.
    lo, hi = img.min(), img.max()
    if hi > lo:
        img = (img - lo) / (hi - lo)
    else:
        # Constant image: nothing to stretch; keep imshow's own [0, 1] range.
        img = np.clip(img, 0.0, 1.0)

    if title is None and pred is not None:
        title = 'prediction: %s' % (pred,)
        if label is not None:
            title += ', label: %s' % (label,)

    fig, ax = plt.subplots()
    if title is not None:
        ax.set_title(title)
    ax.imshow(img)
    return fig, ax

# Class names DeepView uses as labels in its plots.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# ---- DeepView parameters ----
# NOTE(review): the per-parameter meanings below are taken from the DeepView
# README, not from this notebook; confirm against the installed version.
batch_size = 512            # samples per pred_wrapper call
max_samples = 1000          # cap on samples kept in the visualization
data_shape = (3, 32, 32)    # per-sample input shape (channels first, see visualization)
n = 5                       # interpolation steps of the discriminative distance
lam = .65                   # weight of the discriminative vs. euclidean distance
resolution = 100            # resolution of the classifier background grid
cmap = 'tab10'              # matplotlib colormap for the classes
title = 'ResNet-20 - CIFAR10'

deepview = DeepView(
    pred_wrapper, classes, max_samples, batch_size, data_shape,
    n, lam, resolution, cmap,
    title=title, data_viz=visualization,
)

# UMAP settings handed to DeepView's mappers.
# NOTE(review): in umap-learn, `a` is only used together with `b`; when `b` is None,
# both are re-derived from `spread`/`min_dist`, so `a=600` alone may have no effect.
# Confirm how DeepView forwards these keys.
umapParms = dict(
    random_state=42 * 42,   # fixed random seed
    n_neighbors=30,
    spread=1,
    min_dist=0.1,
    a=600,
    # verbose=True,
)
# NOTE(review): `_init_mappers` is a private DeepView method; the positional
# `None, None` presumably select the default forward/inverse mappers — verify.
deepview._init_mappers(None, None, umapParms)


In [None]:
# Select random test points (distinct indices, fixed seed) and visualize them with the classifier.
n_samples = 300
sample_seed = 42
rng = np.random.default_rng(sample_seed)
# replace=False: otherwise the same image can be drawn (and added to DeepView) twice.
sample_ids = rng.choice(len(testset), size=n_samples, replace=False)
# Index the dataset once per id (each lookup may re-run the dataset transform).
samples = [testset[i] for i in sample_ids]
X = np.array([image.numpy() for image, _ in samples])
Y = np.array([label for _, label in samples])

t0 = time.time()
deepview.add_samples(X, Y)
deepview.show()

print('Time to calculate visualization for %d samples: %.2f sec' % (n_samples, time.time() - t0))

In [None]:
# Choose additional points from one class and add them to DeepView:
# the first `n_samples` test images labelled `pick_l`, with the red channel
# shifted by `red_shift` to simulate a "sunset" colour cast.
n_samples = 50
pick_l = 0       # class index to draw the extra samples from
red_shift = 6    # added to channel 0 of every picked sample

X = np.empty([n_samples, *data_shape], dtype=np.float32)
# Integer labels, matching the integer Y built in the random-sample cell.
Y = np.empty([n_samples], dtype=int)
count = 0
# Bounded scan: stops at the end of the test set instead of indexing past it.
for i in range(len(testset)):
    if count == n_samples:
        break
    # Index once per item (each lookup may re-run the dataset transform).
    image, label = testset[i]
    if label != pick_l:
        continue
    X[count] = image.numpy()
    X[count, 0] += red_shift
    Y[count] = label
    count += 1

assert count > 0, 'no test samples with label %d' % pick_l
if count < n_samples:
    # Fewer matches than requested: keep only the filled rows.
    X, Y = X[:count], Y[:count]
    n_samples = count

t0 = time.time()
deepview.resolution = 200
deepview.add_samples(X, Y)
deepview.show()

print('Time to calculate visualization for %d samples: %.2f sec' % (n_samples, time.time() - t0))

In [None]:
# Set `resolution` to 200 (it was 100 at construction), rebuild DeepView's
# mappings for all samples added so far, then redraw.
# NOTE(review): presumably update_mappings re-fits the 2-D embedding and its
# inverse; confirm against the DeepView source.
deepview.resolution = 200
deepview.update_mappings()
deepview.show()


In [None]:

# Inspect one of the red-shifted samples from the single-class cell above.
# Values are min-max rescaled to [0, 1] so imshow shows the shift instead of clipping it.
sample_idx = 1
curr_img = X[sample_idx].transpose([1, 2, 0])   # CHW -> HWC for imshow
print(curr_img.shape)
lo, hi = curr_img.min(), curr_img.max()
if hi > lo:
    curr_img = (curr_img - lo) / (hi - lo)
else:
    curr_img = np.clip(curr_img, 0, 1)   # constant image: avoid dividing by zero

f, a = plt.subplots()
a.set_title('Sample %d (label %s), min-max rescaled' % (sample_idx, Y[sample_idx]))
a.imshow(curr_img);
