In [1]:
import numpy as np
import tensorflow as tf
import os
import matplotlib.pyplot as plt
%matplotlib inline
import _pickle
import time

In [2]:
# Reload imported modules automatically before each cell executes, so edits to
# the local .py helpers are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# from alexnet_backprop import *
from alexnet_guided_bp_vanilla import *
from utils import *

In [4]:
# Load the pickled test-set images and derive the actor lookups from them.
# NOTE(review): unpickling can execute arbitrary code -- only load trusted files.
# The `with` block guarantees the file handle is closed once loading finishes.
with open('test_set_17_227.pkl', "rb") as test_images:
    data_set = _pickle.load(test_images)

# get_actor_code returns an indexable pair; call it once so both lookups come
# from the same result. actor_code is indexed by class id (see the np.argmax
# lookup in generate_true_class_saliency); reversed_actor_code is indexed by
# actor name and yields the class id.
actor_lookups = get_actor_code(data_set)
actor_code = actor_lookups[0]
reversed_actor_code = actor_lookups[1]


In [5]:
def generate_true_class_saliency(data_set, graph, sess, fname, mask_quantile=0.9):
    """Plot and save true-class guided-backprop saliency masks for a data set.

    For every actor in ``data_set`` and every image of that actor, this runs a
    forward pass, draws the image next to a binary mask of its highest-ranked
    saliency pixels, and collects the mask. The figure grid is written to
    ``saliency masks/<fname>.png`` and the masks are pickled to
    ``saliency masks/<fname>.pkl``.

    Relies on the module-level ``actor_code`` and ``reversed_actor_code``
    lookups (class id -> actor name, and actor name -> class id).

    Args:
        data_set: Mapping from actor name to an integer-indexable collection of
            images (each image must support ``.astype``).
        graph: Object exposing ``probabilities`` and ``inputs`` tensors; it is
            also passed through to ``guided_backprop``.
        sess: Session used for the forward pass and for ``guided_backprop``.
        fname: Base name (without extension) of the .png and .pkl outputs.
        mask_quantile: Fraction of the pixel count used as the rank cut-off; a
            pixel is kept when its saliency rank exceeds
            ``n_pixels * mask_quantile``. The default 0.9 keeps roughly the top
            10% of pixels, matching the previous hard-coded ``227*227*0.9``.

    Returns:
        dict mapping each actor to a float array of shape (n_images, H, W)
        holding its masks (shape (0, 227, 227) for an actor with no images).
        This is the same object that gets pickled.
    """
    folder = 'saliency masks/'
    n_cols = len(data_set) * 2
    # Use one grid geometry for every subplot. The index formula below assumes
    # a fixed-width grid, so the row count must be the maximum over all actors
    # rather than the current actor's image count.
    n_rows = max((len(data_set[actor]) for actor in data_set), default=0)

    # Collect masks in plain lists and stack once at the end; growing an array
    # with vstack inside the loop re-copies all earlier masks every iteration.
    mask_lists = {actor: [] for actor in data_set}

    fig = plt.figure(dpi=70, figsize=(140, 100))
    try:
        # Layout parameters do not depend on the loop, so set them once.
        fig.subplots_adjust(bottom=0., wspace=0.15, hspace=0.1, top=0.5)

        for actor_idx, actor in enumerate(data_set):
            col = 2 * actor_idx  # each actor owns an (image, mask) column pair
            for row in range(len(data_set[actor])):

                # forward pass
                image = data_set[actor][row].astype(np.float64)
                image_feed = np.expand_dims(image, 0)
                probabilities = sess.run(graph.probabilities, feed_dict={graph.inputs: image_feed})[0]

                # plot original image
                index = row * n_cols + col + 1
                plt.subplot(n_rows, n_cols, index)
                plt.axis('off')
                plt.title("True: " + actor + "\n Pred: " + actor_code[np.argmax(probabilities)], fontsize=14)
                plt.imshow(image.astype('uint8'))

                # plot true class saliency
                class_idx = reversed_actor_code[actor]
                one_hot = np.zeros(len(actor_code))
                one_hot[class_idx] = 1
                plt.subplot(n_rows, n_cols, index + 1)
                plt.axis('off')
                plt.title(actor + '\n' + str(probabilities[class_idx]), fontsize=14)
                saliency = guided_backprop(graph, image, one_hot, sess)
                saliency = np.sum(saliency, axis=-1)  # collapse the last (channel) axis
                # Double argsort turns raw values into dense ranks 0..n_pixels-1.
                saliency_rank = saliency.ravel().argsort().argsort().reshape(saliency.shape)
                # Threshold on the actual pixel count instead of a fixed 227*227.
                saliency_mask = (saliency_rank > saliency.size * mask_quantile).astype(float)

                plt.imshow(saliency_mask)

                mask_lists[actor].append(saliency_mask)

        # Create the output folder if needed so saving cannot fail on a fresh checkout.
        os.makedirs(folder, exist_ok=True)
        fig.savefig(folder + fname + '.png', bbox_inches='tight')
    finally:
        # Always release the (very large) figure, even if a step above raised.
        plt.close(fig)

    # save saliencies in an external array
    saliency_masks = {
        actor: np.stack(masks) if masks else np.zeros((0, 227, 227))
        for actor, masks in mask_lists.items()
    }

    # The `with` block closes the pickle file; previously the handle was leaked.
    with open(folder + fname + ".pkl", "wb") as pickle_out:
        _pickle.dump(saliency_masks, pickle_out)

    return saliency_masks

In [6]:
# hyperparameters and backprop variants
# tau is passed as the `temp` argument of b_graph.classifier_graph in the next
# cell (presumably a softmax temperature -- TODO confirm in alexnet_guided_bp_vanilla).
tau = 1.0

In [7]:
# Clear the default TensorFlow graph so re-running this cell starts from an
# empty graph instead of adding to the previous one.
tf.reset_default_graph()
# NOTE(review): 17 and 100 match the fc2 output width and the fc1 width printed
# when the weights are loaded -- presumably class count and hidden units;
# confirm against the backprop_graph signature.
b_graph = backprop_graph(17, 100, alexnet_face_classifier)
# Build the classifier with tau passed as `temp`; the second call presumably
# adds the guided-backprop ops used by guided_backprop -- verify in the module.
b_graph.classifier_graph(temp=tau)
b_graph.guided_backprop_graph()

In [8]:
# Produce and save true-class saliency masks for each trained weight set
# (trial numbers 2 through 11).
weight_fpath = 'e2e weights/'  # loop-invariant, so defined once
# Use a real name for the trial number: `_` conventionally marks an unused
# value, and in IPython it also holds the last cell output.
for trial in range(2, 12):
    # load weights
    weight_fname = weight_fpath + 'end_to_end_learning_weights_17_trial{}.pkl'.format(trial)
    # One session per trial; the `with` block closes it before the next trial.
    with tf.Session() as sess:
        b_graph.cnn.load_weights(weight_fname, sess)
        generate_true_class_saliency(data_set, b_graph, sess, 'true_class_saliency_mask_trial{}_10'.format(trial))

0 conv1W (11, 11, 3, 96)
1 conv1b (96,)
2 conv2W (5, 5, 48, 256)
3 conv2b (256,)
4 conv3W (3, 3, 256, 384)
5 conv3b (384,)
6 conv4W (3, 3, 192, 384)
7 conv4b (384,)
8 conv5W (3, 3, 192, 256)
9 conv5b (256,)
10 fc1W (43264, 100)
11 fc1b (100,)
12 fc2W (100, 17)
13 fc2b (17,)
0 conv1W (11, 11, 3, 96)
1 conv1b (96,)
2 conv2W (5, 5, 48, 256)
3 conv2b (256,)
4 conv3W (3, 3, 256, 384)
5 conv3b (384,)
6 conv4W (3, 3, 192, 384)
7 conv4b (384,)
8 conv5W (3, 3, 192, 256)
9 conv5b (256,)
10 fc1W (43264, 100)
11 fc1b (100,)
12 fc2W (100, 17)
13 fc2b (17,)
0 conv1W (11, 11, 3, 96)
1 conv1b (96,)
2 conv2W (5, 5, 48, 256)
3 conv2b (256,)
4 conv3W (3, 3, 256, 384)
5 conv3b (384,)
6 conv4W (3, 3, 192, 384)
7 conv4b (384,)
8 conv5W (3, 3, 192, 256)
9 conv5b (256,)
10 fc1W (43264, 100)
11 fc1b (100,)
12 fc2W (100, 17)
13 fc2b (17,)
0 conv1W (11, 11, 3, 96)
1 conv1b (96,)
2 conv2W (5, 5, 48, 256)
3 conv2b (256,)
4 conv3W (3, 3, 256, 384)
5 conv3b (384,)
6 conv4W (3, 3, 192, 384)
7 conv4b (384,)
8 conv5W (