<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"></ul></div>

In [1]:
import os
import random
import sys
from collections import OrderedDict
from dataclasses import dataclass

import numpy as np
import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import cv2
from torchvision import datasets
import pytorch_lightning as pl
from torchvision import models
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning import Callback
import matplotlib.cm as cm
from pytorch_memlab import profile, set_target_gpu

from grad_cam import GradCAM

In [2]:
class MobileNet_MNIST(nn.Module):
    '''
    MNIST classifier on top of an ImageNet-pretrained MobileNetV2 backbone.

    The backbone's 1000-dim output goes through FC(1000 -> fc_size) + ReLU + dropout and a final
    FC(fc_size -> 10) whose output is returned as the class logits.
    Arguments:
    fc_size      : Width of the hidden fully connected layer
    dropout_prob : Dropout probability applied after the hidden layer
    output_relu  : If True, apply ReLU to the returned logits (the previous behaviour of this class).
                   Off by default: the logits feed CrossEntropyLoss, and ReLU-clamped logits can never be
                   negative and pass no gradient back for classes clamped at zero.
    '''
    def __init__(self, fc_size = 800, dropout_prob=0.8, output_relu=False):
        super(MobileNet_MNIST, self).__init__()
        self.base_model = models.mobilenet_v2(pretrained=True)
        self.fc_layer1 = nn.Linear(1000, fc_size)
        self.fc_layer2 = nn.Linear(fc_size, 10)
        self.dropout = nn.Dropout(dropout_prob)
        self.output_relu = output_relu

    def forward(self, x):
        x = self.base_model(x)
        x = self.dropout(F.relu(self.fc_layer1(x)))
        x = self.fc_layer2(x)
        if self.output_relu:
            x = F.relu(x)
        return x

In [3]:
class VGGNet_MNIST(nn.Module):
    '''
    MNIST classifier on top of an ImageNet-pretrained VGG16 backbone.

    The backbone's 1000-dim output goes through FC(1000 -> fc_size) + ReLU + dropout and a final
    FC(fc_size -> 10) whose output is returned as the class logits.
    Arguments:
    fc_size      : Width of the hidden fully connected layer
    dropout_prob : Dropout probability applied after the hidden layer
    output_relu  : If True, apply ReLU to the returned logits (the previous behaviour of this class).
                   Off by default: the logits feed CrossEntropyLoss, and ReLU-clamped logits can never be
                   negative and pass no gradient back for classes clamped at zero.
    '''
    def __init__(self, fc_size=800, dropout_prob=0.5, output_relu=False):
        super(VGGNet_MNIST, self).__init__()
        self.base_model = models.vgg16(pretrained=True)
        self.fc_layer1 = nn.Linear(1000, fc_size)
        self.fc_layer2 = nn.Linear(fc_size, 10)
        self.dropout = nn.Dropout(dropout_prob)
        self.output_relu = output_relu

    def forward(self, x):
        x = self.base_model(x)
        x = self.dropout(F.relu(self.fc_layer1(x)))
        x = self.fc_layer2(x)
        if self.output_relu:
            x = F.relu(x)
        return x

In [4]:
class Loss_Tracker:
    '''
    Keep track of per-batch training loss values and report the minimum, maximum and mean of the
    values logged so far (call reset() at the end of every training epoch).

    Values can arrive in two forms (see `update`): as a running mean loss, e.g. read from the
    progressbar callback dictionary, from which the individual batch loss is reverse-calculated,
    or directly as the loss tensor of the current batch.
    '''
    def __init__(self):
        self.loss_arr = np.array([])        # batch loss values logged so far
        self.current_running_loss = None    # mean of all logged batch losses
        self.current_batch_loss = None      # loss of the most recently logged batch

    def update(self, run_loss):
        '''
        Log the loss of one more batch.
        Arguments:
        run_loss : torch.Tensor -> treated as the loss of the current batch.
                   Python/NumPy real scalar -> treated as the running mean over all batches logged so
                   far including the current one; the current batch loss is recovered from it.
                   Values of any other type are not logged (a message is printed).
        '''
        if isinstance(run_loss, torch.Tensor):
            loss_val = run_loss.detach().cpu().numpy()
            self.loss_arr = np.append(self.loss_arr, loss_val)
            self.current_batch_loss = loss_val
            self.current_running_loss = np.mean(self.loss_arr)
        elif isinstance(run_loss, (float, int, np.floating, np.integer)):
            # np.float32 / ints are accepted too; previously only Python floats were logged
            run_loss = float(run_loss)
            n_logged = len(self.loss_arr)
            if n_logged == 0:
                # First value: running mean and batch loss coincide
                self.current_batch_loss = run_loss
            else:
                # run_loss is the mean of n_logged + 1 batch losses, so the newest batch loss is
                # the implied total minus the sum of the losses already logged
                self.current_batch_loss = (n_logged + 1) * run_loss - np.sum(self.loss_arr)
            self.loss_arr = np.append(self.loss_arr, self.current_batch_loss)
            self.current_running_loss = run_loss
        else:
            print("Unsupported loss type {}; value not logged".format(type(run_loss).__name__))

    def min_val(self):
        '''Smallest batch loss logged so far, or None (after printing a message) if nothing is logged.'''
        if len(self.loss_arr) == 0:
            print("No loss values logged in module")
            return None
        return np.min(self.loss_arr)

    def max_val(self):
        '''Largest batch loss logged so far, or None (after printing a message) if nothing is logged.'''
        if len(self.loss_arr) == 0:
            print("No loss values logged in module")
            return None
        return np.max(self.loss_arr)

    def _is_current_extreme(self, reduce_fn):
        # Shared implementation of is_current_max / is_current_min: compares the latest batch loss
        # with reduce_fn (np.max or np.min) over every logged loss.
        if self.current_batch_loss is None:
            print("No loss values are logged into the system.")
            return None
        return bool(reduce_fn(self.loss_arr) == self.current_batch_loss)

    def is_current_max(self):
        '''True if the current batch loss is the highest loss logged so far (None if nothing is logged).'''
        return self._is_current_extreme(np.max)

    def is_current_min(self):
        '''True if the current batch loss is the lowest loss logged so far (None if nothing is logged).'''
        return self._is_current_extreme(np.min)

    def reset(self):
        '''Forget all logged values; required to be done at the end of every training epoch.'''
        self.loss_arr = np.array([])
        self.current_running_loss = None
        self.current_batch_loss = None

In [5]:
class GradCam_Pipeline(nn.Module):
    '''
    Pytorch wrapper to generate gradient heatmaps using GradCam. It performs all the preprocessing as well as
    postprocessing operations leading to the generation of the final heatmap overlay images.
    Arguments:
    model        : Torch model on which model interpretation is to be done
    target_layer : Name of the layer from which the gradients are to be calculated. It is usually the last
                   convolution block before the FC layers.
    normal_val   : Normalisation values including ((mean for each channel), (standard deviation of each channel))
                   that were applied to the input images. Defaults to mean 0.5 / std 0.5 for every channel.
    save_dir     : Directory into which the overlay images are written as '<index>.jpg' (created if missing)
    '''
    def __init__(self, model, target_layer, normal_val=None, save_dir='trial_samples'):
        super(GradCam_Pipeline, self).__init__()
        self.model = model
        self.target_layer = target_layer
        self.gradcam_operator = GradCAM(model=self.model)
        self.normal_val = normal_val
        self.save_dir = save_dir

    def prepare_img(self, grad_region, batch):
        '''
        Blend every gradient map with its de-normalised input image, save the overlay and collect it.
        Arguments:
        grad_region : Gradient maps generated from the gradcam wrapper; assumed to be an (N, 1, H, W) array
                      with values in [0, 1] -- TODO confirm for the grad_cam version in use
        batch       : Torch tensor batch of normalised images, (N, C, H, W)
        Returns a list with one float RGB overlay image (H, W, 3), values nominally in 0-255, per sample.
        '''
        os.makedirs(self.save_dir, exist_ok=True)
        gradcam_lst = []
        for j in range(len(grad_region)):
            gcam = grad_region[j, 0]
            denorm_img = self._denormalise_img(batch[j], self.normal_val)
            cmap = cm.jet_r(gcam)[..., :3] * 255.0  # drop alpha channel -> RGB in 0-255
            alpha = gcam[..., None]                 # heatmap strength used as blending weight
            overlay = alpha * cmap + (1 - alpha) * denorm_img
            filename = os.path.join(self.save_dir, str(j) + '.jpg')
            # Clip before the uint8 cast (out-of-range floats do not convert cleanly), and reverse the
            # channels because cv2.imwrite expects BGR while the overlay is RGB.
            to_save = np.uint8(np.clip(overlay, 0, 255))[..., ::-1]
            cv2.imwrite(filename, np.ascontiguousarray(to_save))
            gradcam_lst.append(overlay)
        return gradcam_lst

    def forward(self, x):
        '''
        Run GradCAM on batch `x` (expected on the model's device), save the overlays and return them.
        The list is also kept as `self.gradcam_lst`.
        '''
        probs, ids = self.gradcam_operator.forward(x)
        self.gradcam_operator.backward(ids=ids)
        regions = self.gradcam_operator.generate(self.target_layer)
        self.gradcam_lst = self.prepare_img(regions, x)
        return self.gradcam_lst

    @classmethod
    def _denormalise_img(cls, image, normal_val=None):
        '''
        Reverse conversion of a normalised image to a non-normalised image
        Arguments:
        image      : Torch tensor containing a single normalised image, (C, H, W)
        normal_val : Normalisation values including ((mean for each channel), (standard deviation of each channel))
        Returns a new (H, W, C) numpy array scaled to the 0-255 range; `image` itself is not modified.
        '''
        if normal_val is None:
            normal_val = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # ((mean), (std))
        mean = np.asarray(normal_val[0], dtype=np.float32)
        std = np.asarray(normal_val[1], dtype=np.float32)
        # detach() so tensors that require grad can be converted; transpose CHW -> HWC
        hwc = image.detach().cpu().numpy().transpose(1, 2, 0)
        # Normalisation computes (x - mean) / std, so the inverse is x * std + mean. The arithmetic
        # builds a new array instead of writing into a view that may share memory with `image`.
        return (hwc * std + mean) * 255.0

In [6]:
class Interpretation_Callback(Callback):
    """
    Callback to calculate gradient heatmaps for model interpretation on the fly.

    After every training batch it remembers the input images of the batch with the highest and the lowest
    training loss seen since the loss tracker was last reset (it is reset at every epoch end). When training
    ends, GradCAM is run on up to `n_samples_max` images of the highest-loss batch.
    """
    def __init__(self, n_samples_max, layer, save_dir, normal_val=None, specified_input=None):
        """
        Args:
            n_samples_max   : Maximum no of images for which the gradient heatmaps are generated. If the mini batch size is
                              lower than n_samples_max, the heatmaps are generated for all the images in the batch
            layer           : Name of the layer on which gradient computations are to be performed
            save_dir        : Directory in which all the heatmap plots are to be saved
                              (NOTE(review): stored on the callback but not yet passed to GradCam_Pipeline)
            normal_val      : Normalisation values ((mean per channel), (std per channel)) applied to the inputs;
                              used to undo the normalisation before overlaying the heatmaps
            specified_input : If the dataloader returns a dictionary, the key that holds the image tensor.
                              If None, _fetch_tensors() tries to figure out the key that contains the 4-D input tensor
        """
        super().__init__()
        self.n_samples_max = n_samples_max
        self.grad_layer = layer
        self.save_dir = save_dir
        self.specified_input = specified_input
        self.normal_val = normal_val
        self.track_loss = Loss_Tracker()

    def on_train_start(self, trainer, pl_module):
        self.low_batch, self.high_batch = None, None #Stores the batch with lowest and highest training loss
        self.low_bound_loss, self.high_bound_loss = None, None #Stores the lowest and highest loss value recorded for any batch

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        # Prefer the exact batch loss stored by training_step as `train_loss`; otherwise fall back to the
        # running loss shown in the progress bar. Only a missing attribute triggers the fallback, so errors
        # raised while logging are no longer hidden by a bare `except`.
        batch_training_loss = getattr(pl_module, 'train_loss', None)
        if batch_training_loss is not None:
            self.track_loss.update(batch_training_loss)
        else:
            running_train_loss = self._get_loss(trainer.progress_bar_dict)
            self.track_loss.update(running_train_loss)
        # Honour the user-supplied dictionary key (it was previously stored but never used)
        batch_imgs = self._fetch_tensors(batch, self.specified_input)
        if self.track_loss.is_current_max():
            self.high_batch = batch_imgs
        if self.track_loss.is_current_min():
            self.low_batch = batch_imgs
        #Note to self: Write tests to check correct logging

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        self.track_loss.reset()

    def on_train_end(self, trainer, pl_module):
        high_batch = getattr(self, 'high_batch', None)
        if high_batch is None:
            print("No high-loss batch was recorded; skipping GradCAM")
            return
        print("Initiating GradCAM")
        self.gradcam = GradCam_Pipeline(model=pl_module.model, target_layer=self.grad_layer, normal_val=self.normal_val)
        # Respect n_samples_max and move the samples to wherever the module lives instead of assuming CUDA
        samples = high_batch[:self.n_samples_max].to(pl_module.device)
        self.high_loss_grads = self.gradcam(samples)

    @classmethod
    def _fetch_tensors(cls, batch, dict_inp=None):
        """
        Finds the image tensor in a given batch sample. Since a batch might come in different forms, it is necessary
        to extract the target image tensor to perform gradient interpretation.
        Arguments:
        batch    : A tensor, a dict, or an indexable sequence (tuple/list) returned by the dataloader
        dict_inp : Key of the image tensor when `batch` is a dict; if None, the last 4-D tensor value is used
        Returns the image tensor, or None (after printing a message) when no 4-D tensor can be found.
        """
        if type(batch) is torch.Tensor:
            return batch
        if type(batch) is dict:
            if dict_inp is not None:
                return batch[dict_inp]
            # No key given: look for a value with image-like dimensions such as (*,3,224,224),
            # i.e. any 4-dimensional tensor; the last such value wins.
            key = None
            for unique_key, inp in batch.items():
                if type(inp) is torch.Tensor and len(inp.shape) == 4:
                    key = unique_key
            if key is None:
                print("No value with 4 dimensions found in dictionary. Please check the dictionary structure or specify the key")
                return None
            return batch[key]
        # Tuple or list batch (e.g. (images, labels)): find which item is the 4-D image tensor
        tensor_index = None
        for ind in range(len(batch)):
            item = batch[ind]
            if type(item) is torch.Tensor and len(item.shape) == 4:
                tensor_index = ind
        if tensor_index is None:
            print("No value with 4 dimensional tensor found in batch tuple. Please check the dataloader or return dictionary to avoid such errors")
            return None
        return batch[tensor_index]

    @classmethod
    def _get_loss(cls, outputs):
        """
        Extract a loss value from `outputs`, which is either a 0-d loss tensor (returned as is) or a mapping
        with a 'loss' entry (e.g. the trainer's progress bar dictionary), converted with float().
        """
        if type(outputs) is torch.Tensor and len(outputs.shape) == 0:
            return outputs
        return float(outputs['loss'])
             

In [7]:
class Explainer_Classifier(pl.LightningModule):
    '''
    LightningModule that trains VGGNet_MNIST on MNIST. Images are resized to 224x224, replicated to 3 channels
    and normalised with ImageNet statistics.
    hparams: dictionary of hyper-parameters; fc_size, dropout, lr and batch_size are read from it.
    '''
    # ImageNet channel statistics ((mean), (std)) used to normalise the inputs
    NORMAL_VAL = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

    def __init__(self, hparams):
        super(Explainer_Classifier, self).__init__()
        self.hparams = hparams
        self.model = VGGNet_MNIST(self.hparams.fc_size, self.hparams.dropout)
        self.criterion = nn.CrossEntropyLoss()
        self.accuracy = pl.metrics.Accuracy()
        self.val_acc = pl.metrics.Accuracy()

    def forward(self, x):
        return self.model(x)

    def apply_GradCam(self, images, target_layer):
        '''
        Compute GradCAM region maps of `target_layer` (a module name) for a batch of images.
        The GradCAM wrapper is created on first use and cached as `self.gcam`.
        NOTE(review): the wrapper may leave hooks registered on self.model -- confirm against grad_cam.GradCAM.
        Returns the output of GradCAM.generate(target_layer).
        '''
        if getattr(self, 'gcam', None) is None:
            self.gcam = GradCAM(model=self.model)
        probs, ids = self.gcam.forward(images)
        self.gcam.backward(ids=ids)
        return self.gcam.generate(target_layer)

    def training_step(self, batch, batch_idx):
        data, labels = batch
        outputs = self.forward(data)
        train_loss = self.criterion(outputs, labels)
        # Detached copy for Interpretation_Callback, so the module does not keep the autograd graph alive
        self.train_loss = train_loss.detach()
        # Call the metric exactly once per step: every call also accumulates into the epoch-level state
        train_acc = self.accuracy(outputs, labels)
        self.log('loss', train_loss)
        self.log('train_acc_step', train_acc, prog_bar=True)
        return {'loss': train_loss}

    def training_epoch_end(self, outs):
        self.log('train_acc_epoch', self.accuracy.compute(), prog_bar=True)
        # Start the next epoch with a fresh accumulator
        self.accuracy.reset()

    def validation_step(self, batch, batch_idx):
        data, labels = batch
        outputs = self.forward(data)
        val_loss = self.criterion(outputs, labels)
        self.log('val_loss', val_loss, prog_bar=True)
        self.log('val_acc_step', self.val_acc(outputs, labels), prog_bar=True)
        return val_loss

    def validation_epoch_end(self, outputs):
        self.log('val_acc', self.val_acc.compute(), prog_bar=True)
        self.val_acc.reset()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        return [optimizer]

    @classmethod
    def _mnist_dataset(cls, train):
        # Shared preprocessing for both splits (changes due to vgg): 224x224, 3 channels, ImageNet normalisation
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
            transforms.Normalize(*cls.NORMAL_VAL),
        ])
        return datasets.MNIST('dataset/', train=train, download=True, transform=transform)

    def train_dataloader(self):
        # Shuffle the training set so batches differ from epoch to epoch
        return DataLoader(self._mnist_dataset(train=True), batch_size=self.hparams.batch_size,
                          shuffle=True, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self._mnist_dataset(train=False), batch_size=self.hparams.batch_size, num_workers=4)
 

In [8]:
# Training configuration (keys are read by Explainer_Classifier and the Trainer cell)
hyper_params = dict(
    batch_size=32,
    lr=0.002,
    n_epochs=3,
    fc_size=800,
    dropout=0.8,
)

In [9]:
# Seed every RNG the notebook uses (Python, NumPy, PyTorch) so runs are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
cudnn.deterministic = True
# Benchmark mode may select different convolution algorithms between runs; disable it as well.
cudnn.benchmark = False

In [10]:
# Build the LightningModule from the hyper-parameter dictionary defined earlier.
model = Explainer_Classifier(hparams=hyper_params)

In [11]:
grad_cam_callback = Interpretation_Callback(
    n_samples_max=10,
    layer='base_model.features.30',  # last module of the VGG16 feature extractor (see the named_modules listing)
    save_dir='gradcam/',
    # same statistics as transforms.Normalize in the dataloaders
    normal_val=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
)

In [12]:
# Single-GPU trainer; 32 mini-batches are accumulated per optimizer step.
trainer = pl.Trainer(
    gpus=1,
    max_epochs=hyper_params['n_epochs'],
    callbacks=[grad_cam_callback],
    accumulate_grad_batches=32,
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [13]:
# Train (with validation) using the Trainer configured above.
trainer.fit(model=model)


  | Name      | Type             | Params
-----------------------------------------------
0 | model     | VGGNet_MNIST     | 139 M 
1 | criterion | CrossEntropyLoss | 0     
2 | accuracy  | Accuracy         | 0     
3 | val_acc   | Accuracy         | 0     
-----------------------------------------------
139 M     Trainable params
0         Non-trainable params
139 M     Total params
Epoch 0:   7%|▋         | 144/2188 [00:48<11:34,  2.95it/s, loss=5.19e+07, v_num=19, train_acc_step=0.0938]

KeyboardInterrupt: 

In [19]:
# NOTE(review): Explainer_Classifier defines no test_step / test_dataloader in this notebook -- confirm intent.
trainer.test(model=model)

1

In [13]:
# Despite the variable name this is the VGG16-based model; its module names are listed below.
mnist_mobile = VGGNet_MNIST(fc_size=800, dropout_prob=0.5)

In [14]:
# One qualified sub-module name per line; the first (empty) name is the root module itself.
print('\n'.join(module_name for module_name, _ in mnist_mobile.named_modules()))


base_model
base_model.features
base_model.features.0
base_model.features.1
base_model.features.2
base_model.features.3
base_model.features.4
base_model.features.5
base_model.features.6
base_model.features.7
base_model.features.8
base_model.features.9
base_model.features.10
base_model.features.11
base_model.features.12
base_model.features.13
base_model.features.14
base_model.features.15
base_model.features.16
base_model.features.17
base_model.features.18
base_model.features.19
base_model.features.20
base_model.features.21
base_model.features.22
base_model.features.23
base_model.features.24
base_model.features.25
base_model.features.26
base_model.features.27
base_model.features.28
base_model.features.29
base_model.features.30
base_model.avgpool
base_model.classifier
base_model.classifier.0
base_model.classifier.1
base_model.classifier.2
base_model.classifier.3
base_model.classifier.4
base_model.classifier.5
base_model.classifier.6
fc_layer1
fc_layer2
dropout


In [20]:
# Random single-image batch on the GPU
sample = torch.randn((1, 3, 224, 224)).cuda()

In [3]:
# Random 4-image batch kept on the CPU
sample_cpu = torch.randn((4, 3, 224, 224))

In [4]:
sample_cpu[0].size()

torch.Size([3, 224, 224])

In [7]:
# Re-add the batch dimension to a single image
d = sample_cpu[0].unsqueeze(0)

In [8]:
d.size()

torch.Size([1, 3, 224, 224])

In [3]:
import torch
from pytorch_memlab import MemReporter

In [3]:
# MobileNet variant for the GradCAM memory experiments below (note: rebinds `model`)
model = MobileNet_MNIST()
model.cuda()

In [4]:
# GradCAM wrapper around `model`, used below to calculate heatmaps for a (random) high-loss-style batch
gradcam = GradCAM(single_usage=True, model=model)


In [7]:
# Memory reporter from pytorch_memlab (not used further in the cells shown here)
reporter = MemReporter()

In [5]:
# Clear any gradients left on the model's parameters before the GradCAM backward pass in the next cell
model.zero_grad()

In [10]:
# End-to-end GradCAM on a random 32-image batch: forward pass, backward for the returned ids,
# then region maps taken from block 17 of the MobileNetV2 backbone.
high_batch_seg = torch.randn((32, 3, 224, 224))
probs, ids = gradcam.forward(high_batch_seg.cuda())
gradcam.backward(ids=ids)
regions = gradcam.generate('base_model.features.17.conv.3')

In [12]:
np.shape(regions)

(32, 1, 224, 224)

In [3]:
def test_func(batch_size=4, target_layer='base_model.features.17.conv.3'):
    '''
    Run one GradCAM forward/backward pass of a freshly built MobileNet_MNIST on a random batch
    (e.g. for memory profiling) and return the generated region maps.
    Arguments:
    batch_size   : Number of random 3x224x224 images in the batch
    target_layer : Name of the layer passed to GradCAM.generate()
    Returns the output of GradCAM.generate(target_layer); it was previously computed and discarded.
    '''
    model = MobileNet_MNIST().cuda()
    gradcam = GradCAM(model=model)
    model.zero_grad()
    high_batch_seg = torch.randn(batch_size, 3, 224, 224)
    probs, ids = gradcam.forward(high_batch_seg.cuda())
    gradcam.backward(ids=ids)
    return gradcam.generate(target_layer)

In [13]:
# Ask PyTorch's CUDA caching allocator to release cached memory that is no longer in use
torch.cuda.empty_cache()

In [4]:
# Load pytorch_memlab's IPython extension (its magics) into this kernel
%load_ext pytorch_memlab

In [25]:
# Random 32-image CPU batch (rebinds `sample`)
sample = torch.randn((32, 3, 224, 224))

In [34]:
sample.size()

torch.Size([32, 3, 224, 224])

In [31]:
# First channel of the first image, (224, 224)
a = sample[0][0]

In [33]:
a.unsqueeze(-1).shape

torch.Size([224, 224, 1])