<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"></ul></div>

In [1]:
import sys
import numpy as np
import pandas as pd
import torch
from torchvision import datasets
from dataclasses import dataclass
import pytorch_lightning as pl
from torchvision import models
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from torch.utils.data import random_split, DataLoader
from grad_cam import GradCAM
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning import Callback

In [2]:
class MobileNet_MNIST(nn.Module):
    def __init__(self, fc_size = 800, dropout_prob=0.8):
        super(MobileNet_MNIST, self).__init__()
        self.base_model = models.mobilenet_v2(pretrained=True)
        self.fc_layer1 = nn.Linear(1000, fc_size)
        self.fc_layer2 = nn.Linear(fc_size, 10)
        self.dropout = nn.Dropout(dropout_prob)
        #self.softmax = nn.Softmax()
    def forward(self,x):
        #x = x.repeat(1,3,1,1)
        x = self.base_model(x)
        x = self.fc_layer1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc_layer2(x)
        x = F.relu(x)
        #x = self.softmax(x)
        return x

In [106]:
class Interpretation_Callback(Callback):
    """
    Callback to calculate gradient heatmap for model interpretation on the fly. 
    """
    def __init__(self, n_samples_max, layer, save_dir, specified_input=None):
        """
        Args:
            n_samples_max   : Maximum no of images for whom the gradient images are required to be plotted. If the mini batch size is lower 
                              than_samples_max, it will plot the heatmap for all the images in a batch
            layer           : The layer on which gradient computations is to be performed
            save_dir        : Directory in which all the heatmap plots are to be saved
            specified_input : If dataloader returns more than a single input, specify the key for the batch dictionary to fetch the correct output. 
                              If None, _fetch_tensors() method tries to figure out the key that contains the tensor input for model 
        """
        super().__init__()
        self.n_samples_max = n_samples_max
        self.grad_layer = layer
        self.save_dir = save_dir
        self.specified_input = specified_input

    def on_train_start(self, trainer, pl_module):
        self.low_batch, self.high_batch = None, None #Stores the batch with lowest and highest training loss
        self.low_bound_loss, self.high_bound_loss = None, None #Stores the lowest and highest loss value recorded for any batch

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        print("outputs:",outputs)
        batch_train_loss = self._get_loss(outputs)
        batch_imgs = self._fetch_tensors(batch)
        self.low_bound_loss = self._update_val(self.low_bound_loss, batch_train_loss, condition='lower')
        self.high_bound_loss = self._update_val(self.high_bound_loss, batch_train_loss, condition='higher')

    @classmethod
    def _fetch_tensors(cls, batch, dict_inp=None):
        """
        Finds out the image tensors in a given batch sample. Since batch might return different kind of inputs, it is necessary to extract the target image
        to perform gradient interpretation
        """
        inp_type = type(batch)
        if inp_type is torch.Tensor:
            return batch
        elif inp_type is dict:
            if not dict_inp is None:
                #When the specific key in a dictionary is already given
                tensor_inp = batch[dict_inp]
                return tensor_inp
            else:
                #Interesting case when a specific key is not provided. In such case, try to search for a value with common tensor 
                #dimensions like (*,3,224,224) or (*,3,512,512) or any 4-dimensional tensor
                key = None
                for unique_key in batch.keys():
                    inp = batch[unique_key]
                    if type(inp) is torch.Tensor:
                        tensor_dim = len(inp.shape)
                        if tensor_dim == 4:
                            key = unique_key
                if key is None:
                    #replace it with error message
                    print("No value with 4 dimensions found in dictionary. Please check the dictionary structure or specify the key")
                else:
                    return batch[key]
        else:
            #This might be a case where the model returns multiple values at once in form of a tuple or a list. Then it is required to find out which of the
            #returned items from batch is an image tensor
            n_items = len(batch)
            tensor_index = None
            for ind in range(n_items):
                item = batch[ind]
                if type(item) is torch.Tensor:
                    tensor_dim = len(item.shape)
                    if tensor_dim == 4:
                        tensor_index = ind
            if tensor_index is None:
                #replace it with error message
                print("No value with 4 dimensional tensor found in batch tuple. Please check the dataloader or return dictionary to avoid such errors")
            else:
                return batch[tensor_index]

    @classmethod
    def _get_loss(cls, outputs):
        """
        Get the loss value from the outputs of training_step. Generally, the output will be either directly a loss value or a dictionary containing one 
        of the entries as loss value 
        """
        if type(outputs) is torch.Tensor and len(outputs.shape) == 0:
            #Usually a rare case when training_loop is directly returning loss value
            return outputs
        else:
            #In such case, the training_step is in form of an ordered dictionary that will definitely contain a 'progress_bar' key
            #print(outputs)
            return 0
    
    @classmethod
    def _update_val(cls, var, update, condition='greater'):
        """
        Update the value of var with update value based on a condition
        Two conditions are present - 1.)'greater' : update should be greater than current value of var 
                                     2.)'lower'   : update should be lower than current value of var 
        """
        if var is None:
            #var is not assigned any value till now. In such case, assign the update to var
            return update
        else:
            #Now a comparison is needed to be done between var and update based on condition
            if condition == 'greater':
                if update > var:
                    return update
                else:
                    return var
            elif condition == 'lower':
                if update < var:
                    return update
                else:
                    return var            

In [107]:
class Explainer_Classifier(pl.LightningModule):
    def __init__(self, hparams):
        super(Explainer_Classifier, self).__init__()
        self.hparams = hparams
        self.model = MobileNet_MNIST()
        self.criterion = nn.CrossEntropyLoss()
        self.dset = []
        
    def forward(self, x):
        return self.model(x)

    def apply_GradCam(self, images, labels, target_layers):
        probs, ids = self.gcam.forward(images)    
        #self.gcam.backward(ids=ids)

    def training_step(self, batch, batch_idx):
        data, labels = batch
        outputs = self.forward(data)
        train_loss = self.criterion(outputs, labels)
        #print(len(train_loss.shape))
        tqdm_dict = {'train_loss':train_loss}
        outputs = OrderedDict({
            'loss':train_loss,
            'progressbar':tqdm_dict
        })
        return outputs
    
    def on_train_end(self):
        print("Initialising GradCam")
        self.gcam = GradCAM(model=self.model)
        model_layer = 'base_model.features.17.conv.3'
        #self.apply_GradCam(data, labels, target_layers=model_layer)
    
    def validation_step(self, batch, batch_idx):
        data, labels = batch
        outputs = self.forward(data)
        lbls = labels.type(torch.cuda.LongTensor)
        val_loss = self.criterion(outputs, labels)
        tqdm_dict = {'val_loss':val_loss}
        outputs = OrderedDict({
            'loss':val_loss,
            'progressbar':tqdm_dict,
            'log':tqdm_dict
        })
        return outputs

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        #scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma = 0.1)
        return [optimizer]#, [scheduler]

    def train_dataloader(self):
        train_dataset = datasets.MNIST(
                            'dataset/', train=True, download=True,
                            transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(),
                            transforms.Normalize((0.5),(0.5)), transforms.Lambda(lambda x:x.repeat(3,1,1))]))
        train_dataloader = DataLoader(train_dataset, batch_size=self.hparams.batch_size)
        return train_dataloader

    def val_dataloader(self):
        val_dataset = datasets.MNIST(
                            'dataset/', train=False, download=True,
                            transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(),
                            transforms.Normalize((0.5),(0.5)), transforms.Lambda(lambda x:x.repeat(3,1,1))]))
        val_dataloader = DataLoader(val_dataset, batch_size=self.hparams.batch_size)
        return val_dataloader
 

In [108]:
hyper_params = {'batch_size':64, 'lr':0.001, 'n_epochs':1}

In [109]:
#hyper_params = hparams_vals
model = Explainer_Classifier(hyper_params)

In [110]:
#Adding CSV logger to keep track of sample loss in order to look out for low loss and high loss examples
csv_logger = CSVLogger("training_logs", name="mnist_gradcam")

In [111]:
grad_cam_callback = Interpretation_Callback(10,'conv','gradcam/')

In [112]:
trainer = pl.Trainer(gpus=1, 
            max_epochs=hyper_params['n_epochs'],
            callbacks=[grad_cam_callback])

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [113]:
trainer.fit(model)


  | Name      | Type             | Params
-----------------------------------------------
0 | model     | MobileNet_MNIST  | 4.3 M 
1 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
4.3 M     Trainable params
0         Non-trainable params
4.3 M     Total params
Epoch 0:   0%|          | 0/1095 [00:00<?, ?it/s] outputs: []
Epoch 0:   0%|          | 1/1095 [00:00<04:36,  3.95it/s, loss=2.62, v_num=23]outputs: []
Epoch 0:   0%|          | 2/1095 [00:00<04:06,  4.44it/s, loss=2.43, v_num=23]outputs: []
Epoch 0:   0%|          | 3/1095 [00:00<03:58,  4.58it/s, loss=2.32, v_num=23]outputs: []
Epoch 0:   0%|          | 4/1095 [00:00<03:51,  4.72it/s, loss=2.22, v_num=23]outputs: []
Epoch 0:   0%|          | 5/1095 [00:01<03:48,  4.78it/s, loss=2.11, v_num=23]outputs: []
Epoch 0:   1%|          | 6/1095 [00:01<03:47,  4.79it/s, loss=2.04, v_num=23]outputs: []
Epoch 0:   1%|          | 7/1095 [00:01<03:45,  4.83it/s, loss=1.98, v_num=23]outputs: []
Epoc

1

In [None]:
trainer.test(model)

In [4]:
mnist_mobile = MobileNet_MNIST()

In [6]:
for name, module in mnist_mobile.named_modules():
    print(name)


base_model
base_model.features
base_model.features.0
base_model.features.0.0
base_model.features.0.1
base_model.features.0.2
base_model.features.1
base_model.features.1.conv
base_model.features.1.conv.0
base_model.features.1.conv.0.0
base_model.features.1.conv.0.1
base_model.features.1.conv.0.2
base_model.features.1.conv.1
base_model.features.1.conv.2
base_model.features.2
base_model.features.2.conv
base_model.features.2.conv.0
base_model.features.2.conv.0.0
base_model.features.2.conv.0.1
base_model.features.2.conv.0.2
base_model.features.2.conv.1
base_model.features.2.conv.1.0
base_model.features.2.conv.1.1
base_model.features.2.conv.1.2
base_model.features.2.conv.2
base_model.features.2.conv.3
base_model.features.3
base_model.features.3.conv
base_model.features.3.conv.0
base_model.features.3.conv.0.0
base_model.features.3.conv.0.1
base_model.features.3.conv.0.2
base_model.features.3.conv.1
base_model.features.3.conv.1.0
base_model.features.3.conv.1.1
base_model.features.3.conv.1.2
b

In [20]:
sample = torch.randn(1,3,224,224).cuda()

In [23]:
sample_cpu = torch.randn(1,3,224,224)

In [53]:
len(sample_cpu.shape)

4

In [25]:
type(sample)

torch.Tensor

In [22]:
sample.is_cuda

True

In [24]:
sample_cpu.is_cuda

False

In [19]:
type(gpu_var)

torch.Tensor

In [26]:
a = {'adsa':'a'}

In [27]:
type(a)

dict

In [29]:
 type(a) is dict

False

In [9]:
if type(sample) is torch.Tensor:
    print("r")
else:
    print('z')

r


In [30]:
a = 1,2

In [44]:
len(a)

2

In [60]:
d = OrderedDict({'a':2,'b':3})

In [61]:
d

OrderedDict([('a', 2), ('b', 3)])

KeyError: 0