In [1]:
import numpy as np
from matplotlib import pyplot as plt
from IPython import display

from keras.datasets import cifar10
from keras.models import Sequential, load_model
from keras.layers import Conv2D,Activation,MaxPooling2D,Dropout,Flatten,Dense
from keras.optimizers import SGD, Adam
from keras.utils import to_categorical

import os.path

from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, Matern

In [2]:
# Fetch the CIFAR-10 dataset through Keras and unpack the two splits
# into image arrays (x_*) and label arrays (y_*).
train_split, test_split = cifar10.load_data()
x_train, y_train = train_split
x_test, y_test = test_split

In [3]:
# Define a simple CNN
def base_model(input_shape=None, num_classes=10):
    """Build and compile a small convolutional image classifier.

    Architecture: two convolutional stages (32 then 64 filters), each made of
    two 3x3 Conv2D+ReLU layers followed by 2x2 max-pooling and 25% dropout,
    then a 512-unit ReLU dense layer with 50% dropout and a softmax output.

    Parameters
    ----------
    input_shape : tuple, optional
        Shape of a single input image. When omitted, it is taken from the
        notebook-level ``x_train`` array (``x_train.shape[1:]``), which is
        what the original zero-argument call relied on.
    num_classes : int, optional
        Number of output units of the final softmax layer (default 10).

    Returns
    -------
    keras.models.Sequential
        The model compiled with categorical cross-entropy loss, an Adam
        optimizer with default settings, and the accuracy metric.
    """
    if input_shape is None:
        # Backward-compatible default: infer the shape from the training images.
        input_shape = x_train.shape[1:]

    model = Sequential()

    # Convolutional stage 1: 32 filters
    model.add(Conv2D(32, (3, 3), padding='same', input_shape=input_shape))
    model.add(Activation('relu'))
    model.add(Conv2D(32, (3, 3)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))

    # Convolutional stage 2: 64 filters
    model.add(Conv2D(64, (3, 3), padding='same'))
    model.add(Activation('relu'))
    model.add(Conv2D(64, (3, 3)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))

    # Classifier head
    model.add(Flatten())
    model.add(Dense(512))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(Dense(num_classes))
    model.add(Activation('softmax'))

    # The optimizer is Adam (the local used to be misleadingly named `sgd`).
    optimizer = Adam()

    model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    return model

In [None]:
# Reuse a previously saved classifier when 'cnn.h5' exists; otherwise build a
# fresh model, train it, and save it so later runs can skip training.
if not os.path.isfile('cnn.h5'):
    cnn_n = base_model()
    # Pixel values are divided by 255 and labels are passed through
    # to_categorical; the test split doubles as the validation set.
    cnn_n.fit(
        x_train/255.0,
        to_categorical(y_train),
        validation_data=(x_test/255.0, to_categorical(y_test)),
        epochs=20,
        batch_size=256,
    )
    cnn_n.save('cnn.h5')
else:
    cnn_n = load_model('cnn.h5')

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20

In [None]:
# Select the test image to analyse and show it.
# The index is pinned to a fixed value; replace it with
# np.random.randint(x_test.shape[0]) to draw a random test image instead.
idx = 3976
im = x_test[idx]

fig_im, ax_im = plt.subplots(figsize=(15,5))
ax_im.imshow(im)
plt.show()

In [None]:
# Occlusion-sensitivity search driven by a Gaussian-process surrogate.
#
# Each iteration blanks a rectangle of the image, measures how much the
# classifier's output for the true class drops, refits a GP to all
# (rectangle, drop) samples so far, and picks the next rectangle by maximising
# an upper-confidence-bound acquisition (mean + KAPPA * std).

N_ITER = 150   # number of occlusion samples to draw
KAPPA = 1.5    # exploration weight on the GP standard deviation
SEED = 0       # seed for the initial sample and the jitter, for reproducible runs
rng = np.random.RandomState(SEED)


def _mark_samples(axis, pts):
    """Overlay sampled locations on `axis`; the one just evaluated is drawn in red."""
    # pts[-1] is the location queued for the next iteration and is not drawn yet.
    axis.plot(pts[:-2,1],pts[:-2,0],'*')
    axis.plot(pts[-2,1],pts[-2,0],'r*')


# Create sampling grid (position, and blanking window half-size).
# np.meshgrid defaults to indexing='xy', so the returned arrays have their
# first two axes ordered (y, x); maps reshaped from this grid are therefore
# transposed before plotting so that x (state[0]) runs along image rows.
x,y,w,h = np.meshgrid(np.linspace(0,32,32),np.linspace(0,32,32),np.linspace(1,8,7),np.linspace(1,8,7))
grid_shape = x.shape

X = np.vstack((x.ravel(),y.ravel(),w.ravel(),h.ravel())).T

# Storage for samples: Xm holds the GP inputs, fm the observed probability drops
Xm = []
fm = []

# First sample: a random grid point, truncated to integers so it can index pixels
state = X[rng.randint(0,X.shape[0]),:].astype(int)
Xm.append(state)

fig, ax = plt.subplots(1,3,figsize=(15,5))

n_rows, n_cols = im.shape[:2]

# Scalar index of the true class. np.ravel makes this work whether the labels
# are stored as a column (one entry per row) or as a flat vector; indexing with
# a length-1 array instead would make every target below a length-1 array too.
target = int(np.ravel(y_test[idx])[0])

# Base class probability on the unmodified image
base_prob = float(cnn_n.predict(im[np.newaxis]/255.0, verbose=0)[0][target])

for step in range(N_ITER):

    # Blank image at sample location: state[0]/state[1] are the row/column
    # centre, state[2]/state[3] the half-extents, clipped to the image bounds
    im_b = im.copy()
    r0, r1 = max(state[0]-state[2],0), min(n_rows,state[0]+state[2])
    c0, c1 = max(state[1]-state[3],0), min(n_cols,state[1]+state[3])
    im_b[r0:r1,c0:c1] = 0

    # Determine the drop in the true-class output caused by the blanking
    # (this is the model's output value, not a pre-softmax logit)
    prob = float(cnn_n.predict(im_b[np.newaxis]/255.0, verbose=0)[0][target])
    fm.append(base_prob-prob)

    # Fit surrogate model
    gp = GaussianProcessRegressor(kernel=Matern(length_scale_bounds=[1,5],length_scale=[3,3,2,2],nu=1.5),alpha=0.05)
    gp.fit(np.array(Xm),np.array(fm))

    mu,sig = gp.predict(X,return_std=True)
    # Force both to 1-D so the sum below is element-wise. Mixing a column
    # vector with a flat vector would broadcast to an N x N matrix.
    mu = np.ravel(mu)
    sig = np.ravel(sig)

    # Choose next blanking point, trading off exploration and exploitation
    aq_bin = np.argmax(mu+KAPPA*sig)

    state = X[aq_bin].astype(int)

    # Add some jitter so we aren't stuck to grid. The jittered point is what
    # the GP sees; the blanking itself uses the integer `state`.
    Xm.append(state+rng.randn(4)*[1,1,0.5,0.5])

    # Saliency map: predicted probability drop over all params,
    # averaged over the two box-size axes
    ov = mu.reshape(grid_shape).mean(axis=(2,3))

    # Predictive std summed over the box-size axes
    sxy = np.sum(sig.reshape(grid_shape[0],grid_shape[1],-1),-1)

    pts = np.array(Xm)

    # Refresh the live plot (redrawn on every iteration)
    ax[0].cla()
    ax[0].imshow(im_b)
    _mark_samples(ax[0],pts)

    ax[1].cla()
    ax[1].imshow(im)
    ax[1].imshow(ov.T,alpha=0.5)
    _mark_samples(ax[1],pts)
    ax[1].set_title('Mean sensitivity')

    ax[2].cla()
    ax[2].imshow(5*sxy.T+ov.T)
    _mark_samples(ax[2],pts)
    ax[2].set_title('Acquisition function')

    display.display(fig)
    display.clear_output(wait=True)
    