<a href="https://colab.research.google.com/github/nassma2019/PracticalSessions/blob/master/vision/visualisation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Part II: Visualise saliency maps
- Import an already trained baseline model.
- Visualise the gradients of class probabilities w.r.t. the inputs to obtain saliency maps.
- Generate inputs that maximise class probabilities.

#### Exercises:

1. Retrieve the gradient of the most probable class w.r.t. the input image using `tf.gradients` and plot saliency maps.
2. Iterate the above, taking steps in the direction of this gradient, starting from a test image.

>* The gradient indicates how to modify the input image so that, according to the network, it looks more like the class the gradient is taken from.
>* The network weights are kept fixed and only the input is transformed: we retrieve gradients, but we never apply them to the network weights.
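
To make the two notes above concrete, here is a minimal sketch on a toy linear "network" rather than the real model. The weights `w`, the input `x` and the step size `step` are hypothetical values chosen for illustration. For a linear score the gradient w.r.t. the input is simply the weight vector, so one step on the input along that gradient raises the score while the weights stay untouched.

```python
# Toy linear "network": one class score y = sum_i w[i] * x[i].
# The weights are hypothetical and are never modified below.
w = [0.5, -1.0, 2.0]
x = [0.1, 0.2, -0.3]


def score(weights, inputs):
    """Class score of the toy model."""
    return sum(wi * xi for wi, xi in zip(weights, inputs))


# For a linear score, the gradient w.r.t. the input equals the weights.
grad_x = list(w)

# Take one small step on the INPUT in the direction of the gradient.
step = 0.1
x_new = [xi + step * gi for xi, gi in zip(x, grad_x)]

score_before = score(w, x)
score_after = score(w, x_new)
print(score_before, score_after)
```

The score increases by `step` times the squared norm of the gradient, which is the behaviour the exercises rely on with the real network.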
 


### Imports

In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import time

import tensorflow as tf

# Don't forget to select GPU runtime environment in Runtime -> Change runtime type
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))

# we will use Sonnet on top of TF 
!pip install -q dm-sonnet
import sonnet as snt

import numpy as np

# Plotting library.
from matplotlib import pyplot as plt
import pylab as pl
from IPython import display
from skimage import data, color
from skimage.transform import rescale, resize, downscale_local_mean

In [0]:
# Reset graph
tf.reset_default_graph()

In [0]:
# Display function
class_mapping = [u'airplane', u'automobile', u'bird', u'cat', u'deer', 
                 u'dog', u'frog', u'horse', u'ship', u'truck']
def gallery(maps, imgs, pclass, gt, scale=4.0):
  num_images= maps.shape[0]
  maps = np.abs(maps).mean(axis=-1)
  ff, axes = plt.subplots(2, num_images,
                          subplot_kw={'xticks': [], 
                                      'yticks': []})
  for i in range(0, num_images):
    tt_pred = class_mapping[pclass[i]]
    tt_gt = class_mapping[gt[i]]
    mm = maps[i]/np.amax(maps[i])
    mm_rescale = rescale(mm, scale)                         
    axes[0,i].imshow(mm_rescale)
    img = (imgs[i]+1.0)/2.0
    img_rescale = rescale(img, scale)
    axes[1,i].imshow(img_rescale)
    plt.setp(axes[0,i].get_xticklabels(), visible=False)
    plt.setp(axes[0,i].get_yticklabels(), visible=False)
    axes[0,i].set_title('pred={}'.format(tt_pred))
    axes[1,i].set_title('gt={}'.format(tt_gt))
  plt.show()

### Copying the pretrained weights of baseline model on the virtual machine
- We download the three checkpoint files to the Colab virtual machine.
- We then load a model with the same architecture that you defined earlier, but fully trained.

In [0]:
!wget https://github.com/nassma2019/PracticalSessions/blob/master/vision/baseline/baseline.ckpt.data-00000-of-00001?raw=true -O baseline.ckpt.data-00000-of-00001
!wget https://github.com/nassma2019/PracticalSessions/blob/master/vision/baseline/baseline.ckpt.index?raw=true -O baseline.ckpt.index
!wget https://github.com/nassma2019/PracticalSessions/blob/master/vision/baseline/checkpoint?raw=true -O checkpoint

### Get dataset to be used for visualisation
- CIFAR-10 is the equivalent of MNIST for natural RGB images.
- 60000 32x32 colour images in 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck
- train: 50000; test: 10000

In [0]:
cifar10 = tf.keras.datasets.cifar10
# (down)load dataset
(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()

### Retrieve batches from the test set

In [0]:
# define dimension of the batches to sample from the datasets
BATCH_SIZE_TEST = 5 #@param

In [0]:
dataset_test = tf.data.Dataset.from_tensor_slices((test_images, test_labels))
batched_dataset_test = dataset_test.repeat().batch(BATCH_SIZE_TEST)
iterator_test = batched_dataset_test.make_one_shot_iterator() 
(batch_test_images, batch_test_labels) = iterator_test.get_next()

### Model on which we will run the visualisation

In [0]:
class Baseline(snt.AbstractModule):
  
  def __init__(self, num_classes, name="baseline"):
    super(Baseline, self).__init__(name=name)
    self._num_classes = num_classes
    self._output_channels = [
        64, 64, 128, 128, 128, 256, 256, 256, 512, 512, 512
        ]
    self._num_layers = len(self._output_channels)

    self._kernel_shapes = [[3, 3]] * self._num_layers  # All kernels are 3x3.
    self._strides = [1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1]
    self._paddings = [snt.SAME] * self._num_layers
   
  def _build(self, inputs, is_training=None, test_local_stats=False):
    net = inputs
    # instantiate all the convolutional layers 
    layers = [snt.Conv2D(name="conv_2d_{}".format(i),
                         output_channels=self._output_channels[i],
                         kernel_shape=self._kernel_shapes[i],
                         stride=self._strides[i],
                         padding=self._paddings[i],
                         use_bias=True) for i in range(self._num_layers)]
    # connect them to the graph, adding batch norm and non-linearity
    for i, layer in enumerate(layers):
      net = layer(net)
      bn = snt.BatchNorm(name="batch_norm_{}".format(i))
      net = bn(net, is_training=is_training, test_local_stats=test_local_stats)
      net = tf.nn.relu(net)

    # Global average pooling over the two spatial dimensions.
    net = tf.reduce_mean(net, axis=[1, 2], keepdims=False,
                         name="avg_pool")

    logits = snt.Linear(self._num_classes)(net)

    return logits
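
As a rough sanity check on the architecture above, the sketch below traces the spatial size of the feature maps through the 11 convolutions. It assumes 32x32 inputs (CIFAR-10) and the usual rule for `SAME` padding, where the output size is the input size divided by the stride, rounded up. The variable names are illustrative only.

```python
import math

# Same stride list as in the Baseline module above.
strides = [1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1]

size = 32  # assumed input height/width (CIFAR-10 images)
sizes = []
for s in strides:
    # With SAME padding: output size = ceil(input size / stride).
    size = math.ceil(size / s)
    sizes.append(size)

print(sizes)
```

Under these assumptions the last convolution produces 4x4 feature maps with 512 channels, which the average pooling reduces to one 512-dimensional vector per image before the final linear layer.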

In [0]:
num_classes = 10

In [0]:
# Test preprocessing: only scale to [-1,1].
def test_image_preprocess():
  def fn(image):
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    image = image * 2.0 - 1.0
    return image
  return fn
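
The preprocessing above can be mirrored in NumPy to see the value ranges involved. This is a sketch, not the notebook's code: `preprocess` and `to_display` are hypothetical helper names, and it assumes `uint8` inputs, for which `tf.image.convert_image_dtype` scales values to [0, 1] before the shift to [-1, 1]. The inverse mapping is the one `gallery` applies before plotting.

```python
import numpy as np


def preprocess(image_uint8):
    """uint8 in [0, 255] -> float32 in [0, 1] -> float32 in [-1, 1]."""
    image = image_uint8.astype(np.float32) / 255.0
    return image * 2.0 - 1.0


def to_display(image):
    """Inverse mapping used for plotting: [-1, 1] -> [0, 1]."""
    return (image + 1.0) / 2.0


pixels = np.array([0, 128, 255], dtype=np.uint8)
scaled = preprocess(pixels)
restored = to_display(scaled)
print(scaled, restored)
```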

In [0]:
# Instantiate the model 
with tf.variable_scope("baseline"):
  model = Baseline(num_classes)

In [0]:
# Connect the model to data
preprocess_op = test_image_preprocess()
batch_test_images = preprocess_op(batch_test_images)
test_predictions = model(batch_test_images, is_training=False)

In [0]:
# Create saver to restore the pre-trained model
# First remove the scope name from the variable names, since the names in the checkpoint don't include it
var_list = snt.get_variables_in_scope("baseline", 
                                      collection=tf.GraphKeys.GLOBAL_VARIABLES)  
var_map = {}
for i in range(0, len(var_list)):
  name = var_list[i].name[len("baseline/"):-2]
  var_map[name] = var_list[i]
  
saver = tf.train.Saver(var_map, reshape=True)
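
The renaming in the cell above is plain string slicing. The sketch below shows it in isolation; the graph variable names are hypothetical examples of what the scoped names might look like, and `checkpoint_name` is an illustrative helper, not part of the notebook.

```python
scope_prefix = "baseline/"

# Hypothetical graph variable names (outer scope + module name + ":0" suffix).
graph_names = [
    "baseline/baseline/conv_2d_0/w:0",
    "baseline/baseline/linear/b:0",
]


def checkpoint_name(graph_name, prefix=scope_prefix):
    """Drop the outer scope prefix and the trailing ':0' output index."""
    return graph_name[len(prefix):-2]


# Map checkpoint names to graph names, as var_map does with variables.
name_map = {checkpoint_name(n): n for n in graph_names}
print(name_map)
```

Note that the slice `[:-2]` assumes every name ends in a two-character suffix such as `:0`.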

In [0]:
# For evaluation, we look at top_k_accuracy since it's easier to interpret; normally k=1 or k=5
def top_k_accuracy(k, labels, logits):
  in_top_k = tf.nn.in_top_k(predictions=tf.squeeze(logits), 
                            targets=tf.squeeze(tf.cast(labels, tf.int32)), k=k)
  return tf.reduce_mean(tf.cast(in_top_k, tf.float32))
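
To illustrate what top-k accuracy measures, here is a small NumPy sketch on made-up logits. `top_k_accuracy_np` is a hypothetical stand-in for the TensorFlow function above, and ties between equal logits may be handled differently than in `tf.nn.in_top_k`.

```python
import numpy as np


def top_k_accuracy_np(k, labels, logits):
    """Fraction of rows whose label is among the k largest logits."""
    # Indices of the k largest logits in each row.
    top_k = np.argsort(logits, axis=1)[:, -k:]
    hits = np.any(top_k == labels[:, None], axis=1)
    return hits.mean()


# Made-up logits for 3 examples and 3 classes.
logits = np.array([[0.1, 2.0, 0.3],
                   [1.5, 0.2, 1.0],
                   [0.0, 0.1, 0.9]])
labels = np.array([1, 2, 0])

top1 = top_k_accuracy_np(1, labels, logits)
top2 = top_k_accuracy_np(2, labels, logits)
print(top1, top2)
```

Only the first example is correct at k=1, while the second also counts at k=2, so accuracy can only grow as k increases.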

In [0]:
test_acc = top_k_accuracy(1, batch_test_labels, test_predictions)

### Visualise saliency maps

- We retrieve gradients w.r.t. the inputs to obtain a saliency map over the input pixels, i.e. to see which pixels in an image most strongly affect the selected output logit (here, the largest one).
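
As a sketch of the idea, the NumPy example below computes a saliency map for a toy linear classifier instead of the real network. Everything in it (the random weights `W`, bias `b`, the 4x4 image) is a hypothetical stand-in; for a linear model the gradient of a logit w.r.t. the image is just the matching row of `W`. The reduction to a 2D map follows what `gallery` does: absolute value, mean over channels, division by the maximum.

```python
import numpy as np

rng = np.random.RandomState(0)
height, width, channels, n_classes = 4, 4, 3, 10

# Hypothetical linear classifier: logits = W . flatten(image) + b
W = rng.randn(n_classes, height * width * channels)
b = rng.randn(n_classes)
image = rng.uniform(-1.0, 1.0, size=(height, width, channels))

logits = W.dot(image.ravel()) + b
top_class = int(np.argmax(logits))

# Gradient of the largest logit w.r.t. the image (exact for a linear model).
grad = W[top_class].reshape(height, width, channels)

# Reduce to one value per pixel and normalise to [0, 1].
saliency = np.abs(grad).mean(axis=-1)
saliency = saliency / saliency.max()
print(saliency.shape)
```

With the real network the gradient has no closed form, which is why the exercise asks for `tf.gradients`.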


In [0]:
#@title Exercise.
# Get the maximum output prediction
# maximum_prediction =  ############## YOUR CODE ##############

# Get the gradient w.r.t. input images
# saliency_op = ############## YOUR CODE ##############

In [0]:
#@title Solution.
# Get the maximum output prediction
maximum_prediction = tf.reduce_max(test_predictions, 1)

# Get the gradient w.r.t. input images
# tf.gradients returns a list with one gradient per input tensor
saliency_op = tf.gradients(maximum_prediction, batch_test_images)[0]

In [0]:
#@title Exercise.
# Get the predicted class index for visualisation purposes.
# pred_class_op = ############## YOUR CODE ##############

In [0]:
#@title Solution.
# Get the predicted class index for visualisation purposes.
pred_class_op = tf.argmax(test_predictions, axis=-1)

In [0]:
# Create the session and initialize variables
sess = tf.Session()
sess.run(tf.global_variables_initializer())

In [0]:
# Restore pre-trained weights
saver.restore(sess, "baseline.ckpt")

In [0]:
# Check if import was done correctly by running eval on cifar test set
# expected_accuracy = 0.94
num_batches = 1000  # 1000 batches * 5 samples per batch = 5000
avg_accuracy = 0.0
for _ in range(num_batches):
  accuracy = sess.run(test_acc)
  avg_accuracy += accuracy
avg_accuracy /= num_batches

print("Accuracy {:.3f}".format(avg_accuracy))

In [0]:
# Get saliency maps
smap, inp_img, predicted_class, ground_truth = sess.run(
    [saliency_op, batch_test_images, 
     pred_class_op, tf.squeeze(batch_test_labels)])

# Display 
gallery(smap, inp_img, predicted_class, ground_truth)


### Not that impressive, right?

### Let's generate the image that maximises the probability of a given class $c$

The previous exercise computed
$$
\frac{\partial y_{c}}{\partial x}
$$

Now we modify $x$ to search for an input $\hat x$ that maximises the class output $y_c$, using iterative gradient ascent with clipping to the valid input range $[-1, 1]$:

$$
x_{t+1} = \min\left(1, \max\left(-1, x_t + \alpha \left.\frac{\partial y_{c}}{\partial x}\right|_{x = x_t}\right)\right), \quad t \in \{0, \dots, N-1\}
$$
$$
x_0 = \text{initial test image from class } c 
$$

Use e.g. $\alpha = 0.1$ and $N=10000$.
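
The update rule can be tried on a toy problem before applying it to the network. The sketch below assumes a hypothetical linear class score $y_c = w \cdot x$, whose gradient w.r.t. $x$ is the constant vector `w`; the values of `w`, `x` and `num_steps` are made up for illustration. With the real model the gradient would be recomputed from the network at every step.

```python
import numpy as np

w = np.array([0.5, -1.0, 2.0, -0.25])  # hypothetical fixed weights
x = np.array([0.0, 0.3, -0.9, 0.1])    # starting input x_0
alpha = 0.1
num_steps = 50

for t in range(num_steps):
    grad = w  # gradient of y_c = w . x with respect to x
    # Gradient ascent step, then clip back into the valid range [-1, 1].
    x = np.clip(x + alpha * grad, -1.0, 1.0)

print(x)
```

In this toy case every coordinate ends up saturated at the bound matching the sign of its weight, which hints at why images optimised this way can drift towards extreme pixel values.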

In [0]:
#@title Exercise.
alpha = 0.1
N = 10000

# get saliency maps
smap, inp_img, predicted_class, ground_truth = sess.run(
      [saliency_op, batch_test_images, 
       pred_class_op, tf.squeeze(batch_test_labels)])

for t in range(N):
  #############
  #           #
  # YOUR CODE #
  #           #
  #############
  
  # display transformed input image at every 1000 iterations
  if t % 1000 == 0:
    print ('Transformed input at iter {0:5d} out of {1:5d}'.format(int(t), int(N)))
    gallery(smap, inp_img, predicted_class, ground_truth)
  

In [0]:
#@title Solution.
alpha = 0.1
N = 10000

# get saliency maps
smap, inp_img, predicted_class, ground_truth = sess.run(
      [saliency_op, batch_test_images, 
       pred_class_op, tf.squeeze(batch_test_labels)])

for t in range(N):
  inp_img = inp_img + alpha * smap
  inp_img = np.minimum(1, np.maximum(-1, inp_img))
  
  smap = sess.run(saliency_op, 
                  feed_dict={batch_test_images: inp_img})
  # display transformed input image at every 1000 iterations
  if t % 1000 == 0:
    print ('Transformed input at iter {0:5d} out of {1:5d}'.format(int(t), int(N)))
    gallery(smap, inp_img, predicted_class, ground_truth)
  