# Visualising CNNs

### Introduction to Pretrained PyTorch models 
### Based on lecture by Dr. Antonin Vacheret


<hr style="border:2px solid gray">

## Index: <a id='index'></a>
1. [Pre-trained Legacy computer vision classifier models](#PCL)
1. [AlexNet Model](#ANM)
1. [Resnet 101](#101)
1. [Convolution](#LTA)
1. [Visualise CNN](#CNN)


<hr style="border:2px solid gray">
A quick run through some PyTorch basics, starting with a look at the pretrained models that are readily available.

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

import torch
import torch.nn as nn
torch.version.__version__

<hr style="border:2px solid gray">

## I. Pre-trained Legacy computer vision classifier models [^](#index)
<a id='PCL'></a>

In [None]:
from torchvision import models
dir(models)

This is the famous AlexNet [^](#index) <a id='ANM'></a> model that shook the field of machine learning in 2012:
https://proceedings.neurips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf

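Note: the lowercase names (e.g. `models.alexnet`) are functions that build a model with a fixed architecture, while the capitalised names (e.g. `models.AlexNet`) are the model classes themselves.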

In [None]:
alexnet_function = models.AlexNet() # this is the "empty shell" of AlexNet (untrained weights)
alexnet_trained = models.alexnet(pretrained=True) # fixed architecture, already pretrained (newer torchvision versions use the `weights=` argument instead)

This one is **Resnet 101** [^](#index) <a id='101'></a>, where Resnet stands for **residual network** and 101 is the number of layers in this version.
https://arxiv.org/abs/1512.03385
It beat several benchmarks in 2015, including winning the ILSVRC 2015 classification challenge, and showed that very deep networks can be trained effectively. It is trained on ImageNet, with 1.2M images in 1000 categories.


In [None]:
resnet = models.resnet101(pretrained=True) # beware: the weights can take a few minutes to download

<hr style="border:2px solid gray">

## Convolution [^](#index)
<a id='LTA'></a>

*From homl...*

Convolutional neural networks (CNNs) emerged from the study of the brain’s visual cortex, and they have been used in image recognition since the 1980s. In the last few years, thanks to the increase in computational power and the amount of available data for training deep nets, CNNs have managed to achieve superhuman performance on some complex visual tasks. They power image search services, self-driving cars, automatic video classification systems, and more. Moreover, CNNs are not restricted to visual perception: they are also successful at many other tasks, such as voice recognition or natural language processing (NLP); however, we will focus on visual applications for now.

In this chapter we will present where CNNs came from, what their building blocks look like, and how to implement them using TensorFlow and Keras. Then we will discuss some of the best CNN architectures, and discuss other visual tasks, including object detection (classifying multiple objects in an image and placing bounding boxes around them) and semantic segmentation (classifying each pixel according to the class of the object it belongs to).
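
To make the basic building block concrete, here is a minimal sketch of a single 2D convolution in plain Python. The function name `conv2d_valid`, the toy image and the kernel are all made up for illustration. Like `nn.Conv2d`, it slides the kernel over the image without flipping it (strictly speaking a cross-correlation), here with no padding and a stride of 1.

```python
def conv2d_valid(image, kernel):
    """Slide `kernel` over `image` (lists of lists), no padding, stride 1."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    return [
        [
            sum(image[r + i][c + j] * kernel[i][j]
                for i in range(kh) for j in range(kw))
            for c in range(out_w)
        ]
        for r in range(out_h)
    ]

# Toy 4x4 image: dark on the left, bright on the right
image = [
    [0, 0, 1, 1],
    [0, 0, 1, 1],
    [0, 0, 1, 1],
    [0, 0, 1, 1],
]
# Toy 2x2 kernel that responds to a dark-to-bright vertical edge
vertical_edge = [
    [-1, 1],
    [-1, 1],
]

feature_map = conv2d_valid(image, vertical_edge)
print(feature_map)
```

The resulting feature map is non-zero only in the column where the edge sits, which is the kind of pattern the feature maps of a trained network show on a real image.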

Let's take a look at a high-definition picture of a dog. You can replace this one with your preferred one.

In [None]:
from PIL import Image
img = Image.open("img/mydoge.jpg")

In [None]:
img

The high-definition image has been loaded from the `img` folder. Before passing it to the network, we define some **transformations** (a very powerful feature of PyTorch!) to preprocess the image and get the right input size for the network.

In [None]:
from torchvision import transforms
preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )])
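
As a rough illustration of what the `Normalize` step does, the sketch below applies `(value - mean) / std` to a single made-up RGB pixel in plain Python and then inverts it. The helper names `normalize_pixel` and `denormalize_pixel` are hypothetical; the mean and standard deviation are the ones used in the cell above, and the pixel is assumed to lie in [0, 1], as it does after `ToTensor` for a standard 8-bit image.

```python
MEAN = [0.485, 0.456, 0.406]  # per-channel means from the preprocessing cell
STD = [0.229, 0.224, 0.225]   # per-channel standard deviations from the same cell

def normalize_pixel(rgb):
    """Apply (value - mean) / std to one RGB pixel with values in [0, 1]."""
    return [(v - m) / s for v, m, s in zip(rgb, MEAN, STD)]

def denormalize_pixel(rgb_norm):
    """Invert the normalisation, e.g. before displaying an image."""
    return [v * s + m for v, m, s in zip(rgb_norm, MEAN, STD)]

pixel = [1.0, 0.5, 0.0]  # made-up pixel values
normalized = normalize_pixel(pixel)
restored = denormalize_pixel(normalized)
print(normalized)
print(restored)
```

After normalisation the values are no longer confined to [0, 1], which is why a normalised channel plotted directly looks different from the original picture.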

In [None]:
img_t = preprocess(img)

In [None]:
img_t

In [None]:
plt.imshow(img_t[2,:,:])

In [None]:
batch_t = torch.unsqueeze(img_t, 0)
batch_t

In [None]:
resnet.eval() # put the model in inference mode (batch norm layers use their stored statistics)

In [None]:
out = resnet(batch_t)
out

In [None]:
scores  = out.detach().numpy()
plt.plot(scores[0])
plt.show()

#### An operation involving a massive 44.5M parameters has just taken place!
This has produced a vector of 1000 scores, one for each label of the ImageNet training set. Let's get the file that has the ImageNet list of labels.

We now need to find out which label ranks highest for our dog picture.

In [None]:
with open('data/data/imagenet_classes.txt') as f:
    labels = [line.strip() for line in f.readlines()]
labels

In [None]:
_, index = torch.max(out, 1) # this returns the value and index of the highest score
print(index)

**Resnet** gives us a raw score per class, but what we are really interested in is something like the probability of belonging to each category. We will use the **softmax function** for that (multi-class classifier).

In [None]:
percentage = torch.nn.functional.softmax(out, dim=1)[0] # softmax over the class dimension; [0] selects the only image in the batch
percentage
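
For reference, softmax maps each score s_i to exp(s_i) / sum_j exp(s_j). The following is a minimal plain-Python sketch of that formula on three made-up scores; the names `softmax` and `toy_scores` are illustrative only and are not part of the notebook's pipeline.

```python
import math

def softmax(scores):
    """Convert raw scores (logits) into probabilities that sum to 1."""
    shift = max(scores)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

toy_scores = [2.0, 1.0, 0.1]  # made-up scores for three classes
probs = softmax(toy_scores)
best = max(range(len(probs)), key=probs.__getitem__)  # index of the largest probability
print(probs, best)
```

Softmax preserves the ordering of the scores, so the most likely class is the same one that `torch.max` picked out from the raw output.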

In [None]:
labels[index[0]], percentage[index[0]].item() 

Exercises:

* Sort the output so that the five highest probabilities from the resnet output are shown

* Download alexnet and look at its output for our dog image. Which model is best?

<hr style="border:2px solid gray">

## I-a. Visualize CNNs [^](#index)
<a id='CNN'></a>

The following cell extracts the hidden `Conv2d` layers and stores them in `conv_layers`; their weights are available through the `weight` attribute of each layer. Most of the `Conv2d` layers are contained in `Sequential` blocks, so those layers are extracted as `grandchildren`.
The type check also accepts `nn.MaxPool2d`, so the pooling layer is collected as well and its output can be inspected in the same way.

In [None]:
conv_layers = []  # Conv2d and MaxPool2d layers, in forward order
target_types = (nn.Conv2d, nn.MaxPool2d)

for child in resnet.children():
    if isinstance(child, target_types):
        # top-level layers: the first convolution and the max pooling
        conv_layers.append(child)
    elif isinstance(child, nn.Sequential):
        # layer1..layer4: each entry is one residual block
        for block in child:
            for grandchild in block.children():
                if isinstance(grandchild, target_types):
                    # grandchild.weight holds the kernels of a Conv2d layer
                    conv_layers.append(grandchild)

print('len(conv_layers):', len(conv_layers))
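
As a sanity check on the printed length, the expected count can be worked out by hand. This sketch assumes the standard torchvision ResNet-101 layout: one top-level convolution and one max pooling layer, followed by four stages of 3, 4, 23 and 3 bottleneck blocks with three convolutions each. The convolutions on the shortcut path sit inside a nested `Sequential`, so the loop above does not reach them.

```python
BLOCKS_PER_STAGE = [3, 4, 23, 3]  # bottleneck blocks in layer1..layer4 (assumed layout)
CONVS_PER_BLOCK = 3               # conv1, conv2, conv3 inside each bottleneck

stem_layers = 2  # the initial Conv2d plus the MaxPool2d at the top level
block_convs = sum(BLOCKS_PER_STAGE) * CONVS_PER_BLOCK
expected_len = stem_layers + block_convs
print(expected_len)
```

With this layout, index 0 of `conv_layers` is the first convolution, index 1 is the max pooling layer, and the last index holds the final convolution of the last block.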

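The feature maps for `batch_t` are obtained in the following part. The output of each layer in `conv_layers` is fed into the next one and stored in `outputs_from_layer`. Note that this chains only the collected layers: batch normalisation, ReLU and the residual connections are skipped, so the maps approximate the activations inside the real network rather than reproduce them.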

In [None]:
outputs_from_layer = []
img_from_prev_layer = batch_t # a tensor containing a batch of image data

for layer in conv_layers:
    img_from_prev_layer = layer(img_from_prev_layer)
    outputs_from_layer.append(img_from_prev_layer)
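
An alternative that records what each layer produces during a real forward pass is to attach forward hooks. The sketch below is only an illustration under stated assumptions: `toy` is a small hypothetical stand-in network (not ResNet) so that it runs without downloading weights, and `captured` and `make_hook` are made-up names. The same pattern should carry over to `resnet` by iterating over `resnet.named_modules()`.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in network, used only so the sketch runs without a download
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
)
toy.eval()

captured = {}  # layer name -> output tensor recorded during the forward pass

def make_hook(name):
    def hook(module, inputs, output):
        captured[name] = output.detach()
    return hook

# Register one hook per Conv2d layer and keep the handles for later removal
handles = [
    module.register_forward_hook(make_hook(name))
    for name, module in toy.named_modules()
    if isinstance(module, nn.Conv2d)
]

with torch.no_grad():
    toy(torch.randn(1, 3, 32, 32))  # random input in place of a real image

for handle in handles:
    handle.remove()  # detach the hooks once the outputs are captured

shapes = {name: tuple(t.shape) for name, t in captured.items()}
print(shapes)
```

Because the hooks fire inside the normal forward pass, the second convolution here sees the batch-normalised, rectified and pooled output of the first, which the manual chaining above would not give.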

The following cells show examples of the visualised feature maps.
   * 1st Conv2d layer (index 0 in `conv_layers`)
      * All 64 filters.
      * The most active filter and the least active filter.
   * 20th Conv2d layer (index 20)
      * 64 of the 128 filters.
      * The most active filter and the least active filter.
   * 100th (last) Conv2d layer (index 100)
      * 64 of the 2048 filters.
      * The most active filter and the least active filter.

In [None]:
# Feature maps of the first Conv2d layer
# There are 64 filters, shown on an 8x8 grid

layer_number = 0
feature_maps = outputs_from_layer[layer_number].detach().numpy()
figs, axes = plt.subplots(8, 8, figsize=[16,16])
for i in range(feature_maps.shape[1]):
    row, col = divmod(i, 8)  # grid position of filter i
    axes[row, col].set_title('idx: {0}'.format(i))
    axes[row, col].imshow(feature_maps[0,i,:,:])

plt.tight_layout()
plt.show()

In [None]:
# Find the most active filter in the first Conv2d layer
layer_number = 0
fmaps = outputs_from_layer[layer_number].detach().numpy()

# Sum each feature map over height and width to get one activity value per filter
output_from_filters = fmaps.sum(axis=3).sum(axis=2)
idx_max = output_from_filters.argmax()
max_val = output_from_filters.max()  # avoid shadowing the built-in max()
idx_min = output_from_filters.argmin()
min_val = output_from_filters.min()  # avoid shadowing the built-in min()

print('Max, idx: ', max_val, idx_max)
print('Min, idx: ', min_val, idx_min)

img_max = fmaps[0,idx_max,:,:]
img_min = fmaps[0,idx_min,:,:]

figs, axes = plt.subplots(1,2, figsize=[8,16])
axes[0].set_title('Max, idx {0}'.format(idx_max))
axes[0].imshow(img_max)
axes[1].set_title('Min, idx {0}'.format(idx_min))
axes[1].imshow(img_min)
plt.show()

In [None]:
# Feature maps of the 20th Conv2d layer
# There are 128 filters; pick every second one (64 in total)

pickup_idx = [2*x for x in range(64)]
layer_number = 20
feature_maps = outputs_from_layer[layer_number].detach().numpy()
figs, axes = plt.subplots(8, 8, figsize=[16,16])
for n, i in enumerate(pickup_idx):
    row, col = divmod(n, 8)  # grid position of the n-th selected filter
    axes[row, col].set_title('idx: {0}'.format(i))
    axes[row, col].imshow(feature_maps[0,i,:,:])

plt.tight_layout()
plt.show()

In [None]:
# Find the most active filter in the 20th layer
layer_number = 20
fmaps = outputs_from_layer[layer_number].detach().numpy()

# Sum each feature map over height and width to get one activity value per filter
output_from_filters = fmaps.sum(axis=3).sum(axis=2)
idx_max = output_from_filters.argmax()
max_val = output_from_filters.max()  # avoid shadowing the built-in max()
idx_min = output_from_filters.argmin()
min_val = output_from_filters.min()  # avoid shadowing the built-in min()

print('Max, idx: ', max_val, idx_max)
print('Min, idx: ', min_val, idx_min)

img_max = fmaps[0,idx_max,:,:]
img_min = fmaps[0,idx_min,:,:]

figs, axes = plt.subplots(1,2, figsize=[8,16])
axes[0].set_title('Max, idx {0}'.format(idx_max))
axes[0].imshow(img_max)
axes[1].set_title('Min, idx {0}'.format(idx_min))
axes[1].imshow(img_min)
plt.show()

In [None]:
# Feature maps of the last Conv2d layer
# There are 2048 filters; pick every 32nd one (64 in total)

pickup_idx = [32*x for x in range(64)]
layer_number = 100
feature_maps = outputs_from_layer[layer_number].detach().numpy()
figs, axes = plt.subplots(8, 8, figsize=[16,16])
for n, i in enumerate(pickup_idx):
    row, col = divmod(n, 8)  # grid position of the n-th selected filter
    axes[row, col].set_title('idx: {0}'.format(i))
    axes[row, col].imshow(feature_maps[0,i,:,:])

plt.tight_layout()
plt.show()

In [None]:
# Find the most active filter in the 100th layer
layer_number = 100
fmaps = outputs_from_layer[layer_number].detach().numpy()

# Sum each feature map over height and width to get one activity value per filter
output_from_filters = fmaps.sum(axis=3).sum(axis=2)
idx_max = output_from_filters.argmax()
max_val = output_from_filters.max()  # avoid shadowing the built-in max()
idx_min = output_from_filters.argmin()
min_val = output_from_filters.min()  # avoid shadowing the built-in min()

print('Max, idx: ', max_val, idx_max)
print('Min, idx: ', min_val, idx_min)

img_max = fmaps[0,idx_max,:,:]
img_min = fmaps[0,idx_min,:,:]

figs, axes = plt.subplots(1,2, figsize=[8,16])
axes[0].set_title('Max, idx {0}'.format(idx_max))
axes[0].imshow(img_max)
axes[1].set_title('Min, idx {0}'.format(idx_min))
axes[1].imshow(img_min)
plt.show()