# Cell Classification Captum Interpretation 


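Here, we interpret the ResNet model used to predict the cell type, using the Integrated Gradients technique implemented in Captum (GradientSHAP and a NoiseTunnel/SmoothGrad variant of Integrated Gradients are also shown for comparison).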

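Integrated Gradients is an interpretability (explainability) technique for deep neural networks that attributes a model's prediction to its input features, i.e. it estimates how much each input pixel contributes to the prediction. Concretely, for each pixel we integrate the gradient of the model output for the predicted class with respect to that pixel along the straight-line path from a baseline image $x'$ (here an all-black image) to our input image $x$:

$$\mathrm{IG}_i(x) = (x_i - x'_i) \int_0^1 \frac{\partial F\big(x' + \alpha\,(x - x')\big)}{\partial x_i}\, d\alpha$$

where $F$ is the model output for the predicted class. In practice the integral is approximated with a finite number of steps (`n_steps`).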

The original paper ([available here](https://arxiv.org/pdf/1703.01365.pdf)) discusses Integrated Gradients in much more detail and introduces the axioms (such as Sensitivity and Implementation Invariance) that attribution methods should satisfy.

In [None]:
%%capture
!pip install captum

# Imports

In [None]:
import os
import time
import json
import random
import collections
import cv2

import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

import torch
import torchvision
from torchvision import transforms
from torchvision import models
from torchvision.transforms import ToPILImage
from torchvision.transforms import functional as F
from torch.utils.data import Dataset, DataLoader

from torchvision.models import resnet34

import torch.nn as nn

import torch
import torch.nn.functional as F

from PIL import Image

from captum.attr import IntegratedGradients
from captum.attr import GradientShap
from captum.attr import Occlusion
from captum.attr import NoiseTunnel
from captum.attr import visualization as viz

In [None]:
SEED = 3011


def fix_seeds(seed):
    """Seed every random number generator used in this notebook.

    Covers Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    every CUDA device), and exports ``PYTHONHASHSEED``. Note that the hash
    seed only affects interpreters started after it is set, not this one.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Current CUDA device first, then all devices (no-ops are deferred if CUDA is absent).
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed)


fix_seeds(SEED)

## Configuration

In [None]:
# Location of the competition data (Kaggle input mount); all paths share one root.
_DATA_ROOT = "../input/sartorius-cell-instance-segmentation"
TRAIN_CSV = f"{_DATA_ROOT}/train.csv"
TRAIN_PATH = f"{_DATA_ROOT}/train"
TEST_PATH = f"{_DATA_ROOT}/test"

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Build a ResNet-34 and replace its final layer with a 3-output head
# (presumably one output per cell type — confirm against the training notebook).
# NOTE: the next cell rebinds `model` to the fully pickled fine-tuned network,
# so the ImageNet weights downloaded here are not used for the interpretation.
model = resnet34(pretrained=True)  # keyword instead of an opaque positional boolean
model.fc = nn.Linear(model.fc.in_features, 3)  # in_features is 512 for ResNet-34

In [None]:
# Load the fine-tuned classifier. The file holds a whole pickled nn.Module
# (it is called and displayed directly below), so it must come from a trusted
# source and the classes it references must be importable.
# NOTE(review): PyTorch >= 2.6 defaults torch.load(weights_only=True), which
# rejects full-module pickles; pass weights_only=False there if loading fails.
model = torch.load("../input/cell-classification-vanilla-torch/resnet34-finetuned.bin", map_location=DEVICE)
# Inference mode: otherwise BatchNorm normalises each single-image forward with
# batch statistics and every attribution forward pass updates its running stats.
model.eval()

In [None]:
# Echo the loaded model; Jupyter renders its module tree (the layer-by-layer architecture).
model

In [None]:
class ToTensorNew:
    """Paired transform: tensorise ``image`` and hand ``target`` back unchanged.

    Delegates to ``torchvision.transforms.functional.to_tensor``, which turns a
    PIL image / HxWxC array into a CxHxW float tensor.
    """

    def __call__(self, image, target):
        to_tensor = torchvision.transforms.functional.to_tensor
        return to_tensor(image), target

class ClassificationInterpDataset(Dataset):
    """Map-style dataset yielding one RGB image per unique ``id`` in ``df``.

    Images are enumerated in sorted ``id`` order (the order ``df.groupby('id')``
    produces), so ``image_info[i]`` describes the ``i``-th item.

    Args:
        image_dir: Directory containing ``<id>.png`` files.
        df: DataFrame with at least an ``id`` column; several rows may share
            the same id (e.g. one row per annotation).
    """

    def __init__(self, image_dir, df):
        self.image_dir = image_dir
        self.df = df

        # Presumably the dataset's image size; not read anywhere in this class.
        self.height = 520
        self.width = 704

        # Only the unique, sorted ids are needed (groupby drops NaN ids, as before).
        # A plain dict rather than a defaultdict: an out-of-range lookup must not
        # silently insert an empty entry and grow len(self).
        unique_ids = self.df.groupby('id').size().index
        self.image_info = {
            index: {
                'image_id': image_id,
                'image_path': os.path.join(self.image_dir, image_id + '.png'),
            }
            for index, image_id in enumerate(unique_ids)
        }

    def __getitem__(self, idx):
        """Return ``(image, target)`` for the ``idx``-th image.

        ``image`` is the tensor produced by ``ToTensorNew``; ``target`` is
        ``{'image_id': tensor([idx])}`` (the integer index, not the string id).

        Raises:
            IndexError: if ``idx`` is not a valid index (keeps the sequence
                protocol working, e.g. plain ``for`` iteration stops cleanly).
        """
        try:
            info = self.image_info[idx]
        except KeyError:
            raise IndexError(
                f"index {idx} is out of range for dataset of size {len(self)}"
            ) from None

        # Close the file handle deterministically; convert() returns a new image.
        with Image.open(info["image_path"]) as pil_image:
            img = pil_image.convert("RGB")
        img, _ = ToTensorNew()(image=img, target=None)

        target = {
            'image_id': torch.tensor([idx]),
        }
        return img, target

    def __len__(self):
        return len(self.image_info)

In [None]:
# Training annotation table (possibly several rows per image) and the
# per-image dataset built from it.
df_train = pd.read_csv(filepath_or_buffer=TRAIN_CSV)
ds_train = ClassificationInterpDataset(image_dir=TRAIN_PATH, df=df_train)

In [None]:
# Pick one training image, run the fine-tuned model on it and report the
# model's top-1 class alongside the image's ground-truth cell type.
idx = 20
image = ds_train[idx][0]
# Add a batch dimension. NOTE: `input` shadows the builtin; the name is kept
# because the attribution cells below refer to it.
input = torch.unsqueeze(image, dim=0).to(DEVICE)

model.eval()  # BatchNorm in inference mode; harmless if already set
with torch.no_grad():  # gradients are not needed just to get the prediction
    output = model(input)

# `output` holds raw logits; softmax turns the top-1 score into a probability.
# The arg-max, and therefore pred_label_idx, is the same either way.
probabilities = torch.softmax(output, dim=1)
prediction_score, pred_label_idx = torch.topk(probabilities, 1)
pred_label_idx.squeeze_()  # 0-d index tensor, used as the attribution target

# The CSV gives this image's *ground-truth* cell type, not the model's
# prediction. NOTE(review): the class-index -> cell-type mapping is defined by
# the training notebook and is not visible here, so only the index is printed.
true_label = df_train[df_train.id == ds_train.image_info[idx]['image_id']].iloc[0].cell_type
predicted_label = true_label  # backward-compatible alias; this is NOT the model output
print('Predicted class index:', pred_label_idx.item(), '(', prediction_score.squeeze().item(), ')')
print('Ground-truth cell type:', true_label)

In [None]:
# Integrated Gradients for the predicted class: integrate along the straight
# path from an all-zero (black) baseline to `input`, using 20 steps.
integrated_gradients = IntegratedGradients(model)
attributions_ig = integrated_gradients.attribute(
    input,
    baselines=torch.zeros_like(input),  # same as Captum's default zero baseline
    target=pred_label_idx,
    n_steps=20,
)

In [None]:
# Show the image being explained: CHW tensor -> HWC array for matplotlib.
plt.imshow(image.squeeze().detach().cpu().numpy().transpose(1, 2, 0))
plt.axis('off')

In [None]:
# Colormap for the heat maps in this notebook: white at 0, ramping to black by
# 0.25 and staying black up to 1 (the 'custom blue' name is only a label).
default_cmap = LinearSegmentedColormap.from_list('custom blue', 
                                                 [(0, '#ffffff'),
                                                  (0.25, '#000000'),
                                                  (1, '#000000')], N=256)

# Draw into an axis we own so it can be hidden. visualize_image_attr calls
# plt.show() itself, so a plt.axis('off') issued afterwards would act on a new,
# empty figure instead of the heat map.
fig, ax = plt.subplots(figsize=(6, 6))  # (6, 6) matches Captum's default fig_size
ax.axis('off')
_ = viz.visualize_image_attr(np.transpose(attributions_ig.squeeze().cpu().detach().numpy(), (1,2,0)),
                             np.transpose(image.squeeze().cpu().detach().numpy(), (1,2,0)),
                             method='heat_map',
                             cmap=default_cmap,
                             show_colorbar=True,
                             sign='positive',
                             outlier_perc=1,
                             plt_fig_axis=(fig, ax))

In [None]:
# GradientSHAP: expectation of gradient x (input - baseline) over n_samples
# draws, each pairing a baseline from `rand_img_dist` with the input plus
# Gaussian noise of standard deviation `stdevs`.
gradient_shap = GradientShap(model)

# Defining baseline distribution of images: an all-black image and the input itself.
# NOTE(review): nothing here is random despite the name, and draws that pick the
# input as baseline contribute ~0 attribution (input - baseline is only the tiny
# noise) — confirm that is intended.
rand_img_dist = torch.cat((input * 0, input * 1), dim=0)

attributions_gs = gradient_shap.attribute(
    input,
    baselines=rand_img_dist,
    target=pred_label_idx,
    n_samples=5,
    stdevs=0.0001,
)
_ = viz.visualize_image_attr_multiple(
    attributions_gs.squeeze().cpu().detach().numpy().transpose(1, 2, 0),
    image.squeeze().cpu().detach().numpy().transpose(1, 2, 0),
    methods=["original_image", "heat_map"],
    signs=["all", "absolute_value"],
    cmap=default_cmap,
    show_colorbar=True,
)

In [None]:
# SmoothGrad-squared over Integrated Gradients: Captum averages the squared IG
# attributions of `nt_samples` noisy copies of the input.
# NOTE(review): the default noise std (stdevs) is Captum's default; tune it if the
# map looks noise-dominated.
noise_tunnel = NoiseTunnel(integrated_gradients)

attributions_ig_nt = noise_tunnel.attribute(input, nt_samples=2, nt_type='smoothgrad_sq', target=pred_label_idx)
# The background image is the one being explained, `image` (the previous name
# `transformed_img` was never defined and raised NameError).
_ = viz.visualize_image_attr_multiple(np.transpose(attributions_ig_nt.squeeze().cpu().detach().numpy(), (1,2,0)),
                                      np.transpose(image.squeeze().cpu().detach().numpy(), (1,2,0)),
                                      ["original_image", "heat_map"],
                                      ["all", "positive"],
                                      cmap=default_cmap,
                                      show_colorbar=True)

## References

The attributions were computed with the [Captum library](https://captum.ai/), a model-interpretability package for PyTorch.