<hr style="border: solid 3px blue;">

# Introduction

![](https://thumbs.gfycat.com/SecondhandSourGuineapig-max-1mb.gif)

Picture Credit: https://thumbs.gfycat.com

In this notebook, we build a simple CNN, train and test it on the MNIST dataset, and then examine why the model makes the decisions it does.

In [None]:
!pip install captum

In [None]:
# import libraries
import torch
import numpy as np

import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF

from torchvision import models

from captum.attr import IntegratedGradients
from captum.attr import Saliency
from captum.attr import DeepLift
from captum.attr import NoiseTunnel
from captum.attr import visualization as viz

--------------------------------------------------------------------
# Loading Datasets
Download and load the MNIST dataset provided by Torchvision. `batch_size` can be chosen according to your needs. Data loaders are then created from the MNIST dataset and used as the input to the neural network.

In [None]:
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler

# number of subprocesses to use for data loading
num_workers = 0
# how many samples per batch to load
batch_size = 20
# fraction (0-1) of the training set to hold out for validation
valid_size = 0.2
# seed for the train/validation split, so that Restart & Run All reproduces the same partition
split_seed = 42

# convert data to torch.FloatTensor
transform = transforms.ToTensor()

# choose the training and test datasets (downloaded into ../data if not already present)
train_data = datasets.MNIST(root='../data', train=True,
                            download=True, transform=transform)
test_data = datasets.MNIST(root='../data', train=False,
                           download=True, transform=transform)

# shuffle the training indices with a seeded generator; the first `split`
# shuffled indices become the validation set, the rest the training set
num_train = len(train_data)
split = int(np.floor(valid_size * num_train))
indices = np.random.default_rng(split_seed).permutation(num_train).tolist()
train_idx, valid_idx = indices[split:], indices[:split]

# define samplers for obtaining training and validation batches
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

# prepare data loaders; training and validation both read from train_data
# and are kept disjoint by their samplers
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
    sampler=train_sampler, num_workers=num_workers)
valid_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
    sampler=valid_sampler, num_workers=num_workers)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size,
    num_workers=num_workers)

-------------------------------------------------------
# Checking Dataset

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
    
# obtain one batch of training images
dataiter = iter(train_loader)
# use the built-in next(): the iterator's .next() alias is not available in recent PyTorch releases
images, labels = next(dataiter)
# convert to NumPy for plotting (later cells read `images` as an array)
images = images.numpy()

# plot the first 10 images in the batch, along with the corresponding labels
fig = plt.figure(figsize=(25, 4))
for idx in np.arange(10):
    ax = fig.add_subplot(2, 10, idx+1, xticks=[], yticks=[])
    # drop the singleton channel axis so imshow receives a 2-D array
    ax.imshow(np.squeeze(images[idx]), cmap='gray')
    ax.set_title(str(labels[idx].item()),fontsize=30)

In [None]:
# Take the image at index 1 of the batch and drop its singleton axes.
img = np.squeeze(images[1])

fig = plt.figure(figsize=(14, 14))
ax = fig.add_subplot(111)
ax.imshow(img, cmap='gray')

# Pixels below this cut-off are labelled in white, the others in black,
# so the text stays readable against the grayscale background.
thresh = img.max()/2.5
width, height = img.shape

# Write every pixel's value (rounded to two decimals, exact zeros shown as 0)
# on top of the pixel itself.
for x, y in np.ndindex(width, height):
    pixel = img[x][y]
    val = round(pixel, 2) if pixel != 0 else 0
    text_color = 'white' if pixel < thresh else 'black'
    # annotate takes (horizontal, vertical) coordinates, hence the swapped index order
    ax.annotate(str(val), xy=(y, x),
                horizontalalignment='center',
                verticalalignment='center',
                color=text_color)

--------------------------------------------
# Defining Model

* It consists of two Conv2d layers.
* Use dropout to avoid overfitting.

In [None]:
import torch.nn as nn
import torch.nn.functional as F

# define the NN architecture
class Net(nn.Module):
    """Small CNN with two convolutional and two fully connected layers.

    The first linear layer takes 320 features, which is what the two
    conv + 2x2 max-pool stages produce (20 channels x 4 x 4) for a
    single-channel 28x28 input. The network returns 10 raw scores per
    sample; no softmax is applied.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        # channel-wise dropout applied to the output of conv2
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return the 10 class scores for a batch of images `x`."""
        # stage 1: convolution -> 2x2 max-pool -> ReLU
        features = self.conv1(x)
        features = F.max_pool2d(features, 2)
        features = F.relu(features)

        # stage 2: convolution -> dropout -> 2x2 max-pool -> ReLU
        features = self.conv2(features)
        features = self.conv2_drop(features)
        features = F.max_pool2d(features, 2)
        features = F.relu(features)

        # flatten to 320 features per sample for the classifier head
        flat = features.view(-1, 320)
        hidden = F.relu(self.fc1(flat))
        # functional dropout must be told explicitly whether we are training
        hidden = F.dropout(hidden, training=self.training)
        return self.fc2(hidden)

# initialize the NN
model = Net()
print(model)

**Defining Loss Function and Optimizer**

For classification, the loss function uses cross-entropy.

In [None]:
# Loss: categorical cross-entropy.
criterion = nn.CrossEntropyLoss()

# Optimizer: stochastic gradient descent over all model parameters, learning rate 0.01.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)

-------------------------------------------
# Training

For each batch, the training proceeds in the following order.

* Clear the gradients of all optimized variables
* Forward pass: compute predicted outputs by passing inputs to the model 
* Calculate the loss
* Backward pass: compute gradient of the loss with respect to model parameters
* Perform a single optimization step (parameter update)
* Update average training loss

Train for the configured number of epochs and save the model whenever the validation loss reaches a new minimum.

In [None]:
# number of epochs to train the model
n_epochs = 10

# initialize tracker for minimum validation loss
valid_loss_min = np.inf # set initial "min" to infinity (the np.Inf alias was removed in NumPy 2.0)

# Each loader draws from a sampler, so the number of samples seen per epoch is
# the sampler length. Dividing by len(loader.dataset) would use the size of the
# whole underlying dataset and under-report both average losses.
n_train_samples = len(train_loader.sampler)
n_valid_samples = len(valid_loader.sampler)

for epoch in range(n_epochs):
    # running sums of the per-sample loss for this epoch
    train_loss = 0.0
    valid_loss = 0.0

    ###################
    # train the model #
    ###################
    model.train() # prep model for training (enables dropout)
    for data, target in train_loader:
        # clear the gradients of all optimized variables
        optimizer.zero_grad()
        # forward pass: compute predicted outputs by passing inputs to the model
        output = model(data)
        # calculate the loss
        loss = criterion(output, target)
        # backward pass: compute gradient of the loss with respect to model parameters
        loss.backward()
        # perform a single optimization step (parameter update)
        optimizer.step()
        # the loss is a batch mean; weight it by the batch size to accumulate a per-sample sum
        train_loss += loss.item()*data.size(0)

    ######################
    # validate the model #
    ######################
    model.eval() # prep model for evaluation (disables dropout)
    # no parameter update happens here, so skip gradient tracking
    with torch.no_grad():
        for data, target in valid_loader:
            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(data)
            # calculate the loss
            loss = criterion(output, target)
            # update running validation loss
            valid_loss += loss.item()*data.size(0)

    # print training/validation statistics
    # calculate average loss over an epoch
    train_loss = train_loss/n_train_samples
    valid_loss = valid_loss/n_valid_samples

    print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
        epoch+1,
        train_loss,
        valid_loss
        ))

    # save model if validation loss has decreased
    if valid_loss <= valid_loss_min:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
        valid_loss_min,
        valid_loss))
        torch.save(model.state_dict(), 'model_cnn.pt')
        valid_loss_min = valid_loss

Load the model with the minimum valid loss stored during training.

In [None]:
# Read the checkpoint written by the training loop and restore those weights into the model.
best_state = torch.load('model_cnn.pt')
model.load_state_dict(best_state)

-----------------------------------------------------
# Testing

* Finally, the best model (the one with the lowest validation loss) is evaluated on the test data.
* Evaluate the model using test dataset not used for training.
* It can be evaluated by test loss and accuracy.

In [None]:
# initialize trackers for the test loss and the per-class accuracy counts
test_loss = 0.0
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))

model.eval() # prep model for evaluation

# evaluation needs no gradients
with torch.no_grad():
    for data, target in test_loader:
        # forward pass: compute predicted outputs by passing inputs to the model
        output = model(data)
        # calculate the loss
        loss = criterion(output, target)
        # the loss is a batch mean; weight it by the batch size to accumulate a per-sample sum
        test_loss += loss.item()*data.size(0)
        # convert the raw class scores to a predicted class (index of the largest score)
        _, pred = torch.max(output, 1)
        # compare predictions to true labels; no squeeze, so that a batch
        # holding a single sample can still be indexed below
        correct = pred.eq(target.view_as(pred))
        # update the per-class counters
        for i in range(len(target)):
            label = target[i].item()
            class_correct[label] += correct[i].item()
            class_total[label] += 1

# calculate and print avg test loss
test_loss = test_loss/len(test_loader.dataset)
print('Test Loss: {:.6f}\n'.format(test_loss))

for i in range(10):
    if class_total[i] > 0:
        print('Test Accuracy of %5s: %2d%% (%2d/%2d)' % (
            str(i), 100 * class_correct[i] / class_total[i],
            np.sum(class_correct[i]), np.sum(class_total[i])))
    else:
        # the class labels are the digits themselves; no separate `classes` list exists
        print('Test Accuracy of %5s: N/A (no test examples)' % (str(i)))

print('\nTest Accuracy (Overall): %2d%% (%2d/%2d)' % (
    100. * np.sum(class_correct) / np.sum(class_total),
    np.sum(class_correct), np.sum(class_total)))

---------------------------------------------------------
# Interpreting Model

![](https://i.pinimg.com/originals/ff/04/31/ff0431d11ff6b73e937280252f58f371.gif)

Picture Credit: https://i.pinimg.com

We want to understand on what basis our CNN model decided which digit an image shows. Therefore, we want to visually understand the judgment basis of our model through the following two methods.

**Saliency detection**

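Saliency detection refers to separating an object of interest from a background that is not of interest; in classical saliency detection the result is often a binarized image. These detection methods help you spend less time and energy in determining the most relevant parts of an image. In other words, they simplify the representation of an image to make it more meaningful and easier to analyze. Captum's `Saliency` is a gradient-based variant of this idea: it computes the gradient of the target class score with respect to every input pixel, and the magnitude of that gradient shows how strongly the pixel influences the prediction. The results are shown below as heat maps overlaid on the input images.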

**IntegratedGradients**

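Integrated Gradients computes the gradient of the model output with respect to each pixel along a straight-line path of images interpolated between a baseline (here, an all-zero image) and the actual input. Integrating these gradients along the path attributes the prediction to the individual pixels. More details about integrated gradients can be found in the original paper: https://arxiv.org/abs/1703.01365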

In [None]:
def imshow(img, transpose = True):
    """Display an image tensor with matplotlib.

    The pixel values are shown as they are. The data pipeline in this
    notebook applies only ``ToTensor`` (no ``Normalize``), so there is
    nothing to "unnormalize"; rescaling with ``img / 2 + 0.5`` would
    wash the picture out.

    Args:
        img: torch tensor holding the image, channel-first (C, H, W) when
            ``transpose`` is True.
        transpose: if True (default), reorder the axes from (C, H, W) to
            (H, W, C) before plotting; if False, plot the array as given.
    """
    # detach/cpu so tensors that track gradients or live on another device can be shown too
    npimg = img.detach().cpu().numpy()
    if transpose:
        npimg = np.transpose(npimg, (1, 2, 0))
    plt.imshow(npimg)
    plt.show()

# take one batch from the validation loader
dataiter = iter(valid_loader)
# use the built-in next(): the iterator's .next() alias is not available in recent PyTorch releases
images, labels = next(dataiter)

# print images
imshow(torchvision.utils.make_grid(images))

# make sure dropout is disabled regardless of which cell ran last
model.eval()
# plain prediction: no gradients needed
with torch.no_grad():
    outputs = model(images)

# predicted class = index of the largest score for each image
_, predicted = torch.max(outputs, 1)

In [None]:
def attribute_image_features(algorithm, input, *, target=None, **kwargs):
    """Run an attribution algorithm on `input` and return its result.

    Args:
        algorithm: object exposing ``attribute(input, target=..., **kwargs)``,
            e.g. a captum attribution method wrapping the model.
        input: the input passed through to ``algorithm.attribute``.
        target: class index to attribute. When omitted, falls back to the
            notebook globals ``labels[ind]`` (the label of the sample that is
            currently being processed), which is the original behaviour.
        **kwargs: forwarded unchanged to ``algorithm.attribute``.

    Returns:
        Whatever ``algorithm.attribute`` returns.
    """
    if target is None:
        # backward-compatible default: rely on the loop variable `ind` and the current batch's `labels`
        target = labels[ind]
    # clear gradients accumulated on the global model before attributing
    model.zero_grad()
    tensor_attributions = algorithm.attribute(input,
                                              target=target,
                                              **kwargs
                                             )

    return tensor_attributions

In [None]:
def _to_hwc(attr):
    """Convert an attribution tensor for one single-channel image into the (H, W, 1) NumPy array used for plotting."""
    # squeeze all singleton axes, then re-add one channel axis in front -> (1, H, W)
    chw = attr.squeeze().cpu().detach().unsqueeze(0)
    return np.transpose(chw.numpy(), (1, 2, 0))


# The attribution objects only wrap the model, so build them once instead of on every iteration.
saliency = Saliency(model)
ig = IntegratedGradients(model)

for ind in range(5):
    print('---'*20)
    print(f'indicaton = {ind}')
    # add a batch axis and enable gradients with respect to the pixels
    # (named input_tensor so the builtin `input` is not shadowed)
    input_tensor = images[ind].unsqueeze(0)
    input_tensor.requires_grad = True

    # saliency: gradient of the true-class score with respect to the input
    grads = saliency.attribute(input_tensor, target=labels[ind].item())
    grads = _to_hwc(grads)

    # integrated gradients against an all-zero baseline; the helper attributes
    # to labels[ind]. `delta` is the convergence delta (approximation error).
    attr_ig, delta = attribute_image_features(ig, input_tensor, baselines=input_tensor * 0, return_convergence_delta=True)
    attr_ig = _to_hwc(attr_ig)

    # The images were loaded with ToTensor only (no Normalize), so they are used
    # as they are; the former "/ 2 + 0.5" rescale would brighten and wash them out.
    original_image = np.transpose(images[ind].cpu().detach().numpy(), (1, 2, 0))

    _ = viz.visualize_image_attr(None, original_image,
                      method="original_image", title="Original Image")

    _ = viz.visualize_image_attr(grads, original_image, method="blended_heat_map", sign="absolute_value",
                          show_colorbar=True, title="Overlayed Gradient Magnitudes")

    _ = viz.visualize_image_attr(attr_ig, original_image, method="blended_heat_map",sign="all",
                          show_colorbar=True, title="Overlayed Integrated Gradients")

Looking at the pictures above, we can see the pixels that our CNN model has determined to be important. Visually, we can see that our model is looking at the appropriate pixels.

<hr style="border: solid 3px blue;">