<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"></ul></div>

In [3]:
import sys
import numpy as np
import pandas as pd
import torch
from torchvision import datasets
from dataclasses import dataclass
import pytorch_lightning as pl
from torchvision import models
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from torch.utils.data import random_split, DataLoader
from grad_cam import GradCAM
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning import Callback

In [6]:
class MobileNet_MNIST(nn.Module):
    """MobileNetV2 backbone followed by a two-layer head that outputs 10 class logits.

    The backbone output is passed through ``fc_layer1 -> ReLU -> dropout -> fc_layer2``.

    Args:
        fc_size: Width of the hidden fully connected layer.
        dropout_prob: Dropout probability applied after the hidden layer.
        pretrained: Forwarded to ``models.mobilenet_v2``; ``True`` (the default)
            loads the pretrained backbone weights.
    """

    def __init__(self, fc_size=800, dropout_prob=0.8, pretrained=True):
        super().__init__()
        self.base_model = models.mobilenet_v2(pretrained=pretrained)
        # 1000 = width of the backbone's default classification output.
        self.fc_layer1 = nn.Linear(1000, fc_size)
        self.fc_layer2 = nn.Linear(fc_size, 10)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        """Return raw, unnormalised class logits (10 per sample).

        ``x`` must already have 3 channels; the MNIST dataloaders in this
        notebook repeat the grey channel before it reaches the model.
        """
        features = self.base_model(x)
        hidden = self.dropout(F.relu(self.fc_layer1(features)))
        # No activation on the output: nn.CrossEntropyLoss expects unbounded
        # logits, and a ReLU here would clamp every negative logit to zero.
        return self.fc_layer2(hidden)

In [5]:
class Interpretation_Callback(Callback):
    """
    Callback to calculate gradient heatmaps for model interpretation on the fly.

    NOTE(review): the training hooks are still placeholders; no heatmap is
    computed yet.
    """
    def __init__(self, n_samples_max=10, layer=None, specified_input=None):
        """
        Args:
            n_samples_max   : Maximum number of images for which gradient heatmaps are plotted. If the mini-batch
                              size is lower than n_samples_max, the heatmap is plotted for every image in the batch.
            layer           : The layer on which gradient computations are to be performed. Defaults to None because
                              a parameter without a default cannot follow one that has a default.
            specified_input : If the dataloader returns more than a single input, the key of the batch dictionary
                              that holds the model input. If None, _fetch_tensors() tries to find it.
        """
        super().__init__()
        self.n_samples_max = n_samples_max
        self.layer = layer
        self.specified_input = specified_input

    def on_train_start(self, trainer, pl_module):
        """Placeholder hook; does nothing yet."""
        return

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        """Placeholder hook; does nothing yet."""
        return

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        """Placeholder hook; does nothing yet."""
        return

    def on_train_end(self, trainer, pl_module):
        """Placeholder hook; does nothing yet."""
        return

    @classmethod
    def _fetch_tensors(cls, batch, specified_input=None):
        """
        Find the image tensor in a batch sample.

        Because a batch can come in different shapes, the target image must be
        extracted before gradient interpretation can be performed.

        Args:
            batch           : A tensor, a dict of name -> value, or a list/tuple such as (images, labels).
            specified_input : Key to read when `batch` is a dict. If None, the first tensor value is used.

        Returns:
            The input tensor, or None if no tensor can be found.

        Raises:
            KeyError: if `specified_input` is given but missing from a dict batch.
        """
        # isinstance also matches CUDA tensors, which are torch.Tensor instances.
        if isinstance(batch, torch.Tensor):
            return batch
        if isinstance(batch, dict):
            if specified_input is not None:
                return batch[specified_input]
            candidates = batch.values()
        elif isinstance(batch, (list, tuple)):
            # e.g. (data, labels) pairs: the first tensor is taken as the model input.
            candidates = batch
        else:
            return None
        return next((item for item in candidates if isinstance(item, torch.Tensor)), None)

In [7]:
class Explainer_Classifier(pl.LightningModule):
    """LightningModule that trains MobileNet_MNIST on MNIST and builds a GradCAM wrapper after training.

    Args:
        hparams: Hyper-parameters; `lr` and `batch_size` are read as attributes.
            NOTE(review): the notebook passes a plain dict, so attribute access relies on
            the installed Lightning version converting it on assignment; confirm when upgrading.
    """

    # Module name GradCAM is intended to target (not yet used by apply_GradCam).
    GRADCAM_LAYER = 'base_model.features.17.conv.3'

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.model = MobileNet_MNIST()
        self.criterion = nn.CrossEntropyLoss()
        self.dset = []

    def forward(self, x):
        """Return the wrapped network's class logits for `x`."""
        return self.model(x)

    def apply_GradCam(self, images, labels, target_layers):
        """Run GradCAM's forward and backward passes on `images`.

        Requires `self.gcam`, which is created in `on_train_end`.
        NOTE(review): `labels` and `target_layers` are currently unused and nothing is
        returned; no heatmap is generated yet.
        """
        probs, ids = self.gcam.forward(images)
        self.gcam.backward(ids=ids)

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one training batch.

        Returns a result dict: 'loss' is optimised, 'progress_bar' holds
        progress-bar metrics and 'log' holds metrics sent to the logger.
        """
        data, labels = batch
        logits = self.forward(data)
        train_loss = self.criterion(logits, labels)
        tqdm_dict = {'train_loss': train_loss}
        # The batch itself is logged next to its loss so low/high-loss samples can be found later.
        # NOTE(review): this copies every batch to CPU on every step; confirm the cost is acceptable.
        logger_stats = {'train_loss': train_loss, 'batch_id': batch_idx,
                        'data_samples': [data.cpu()], 'labels': [labels.cpu()]}
        return OrderedDict({
            'loss': train_loss,
            # Lightning's dict API reads 'progress_bar'; the former 'progressbar' key was ignored.
            'progress_bar': tqdm_dict,
            'log': logger_stats,
        })

    def on_train_end(self):
        """Create the GradCAM wrapper around the trained network."""
        print("Initialising GradCam")
        self.gcam = GradCAM(model=self.model)

    def validation_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one validation batch."""
        data, labels = batch
        logits = self.forward(data)
        # Labels are used as delivered by the dataloader; the former unused
        # torch.cuda.LongTensor cast required CUDA and broke CPU runs.
        val_loss = self.criterion(logits, labels)
        tqdm_dict = {'val_loss': val_loss}
        return OrderedDict({
            'loss': val_loss,
            'progress_bar': tqdm_dict,
            'log': tqdm_dict,
        })

    def configure_optimizers(self):
        """Adam over all parameters with learning rate `hparams.lr`."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        return [optimizer]

    @staticmethod
    def _mnist_transform():
        """Resize to 224x224, normalise with mean/std 0.5, and repeat the grey channel 3 times for MobileNetV2."""
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
            transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
        ])

    def _mnist_dataloader(self, train):
        """Build a DataLoader over the MNIST train (`train=True`) or test split, downloading into 'dataset/'."""
        dataset = datasets.MNIST('dataset/', train=train, download=True,
                                 transform=self._mnist_transform())
        # No shuffling, matching the original loaders (keeps logged batch_id -> samples stable).
        return DataLoader(dataset, batch_size=self.hparams.batch_size)

    def train_dataloader(self):
        """Loader over the MNIST training split."""
        return self._mnist_dataloader(train=True)

    def val_dataloader(self):
        """Loader over the MNIST test split, used for validation."""
        return self._mnist_dataloader(train=False)
 

In [8]:
# Training configuration: mini-batch size, Adam learning rate, number of epochs.
hyper_params = dict(batch_size=64, lr=0.001, n_epochs=1)

In [9]:
# Build the LightningModule from the hyper-parameters defined above.
model = Explainer_Classifier(hparams=hyper_params)

In [10]:
# CSV logger: keeps the per-step metrics (including each batch's loss) on disk so that
# low-loss and high-loss examples can be looked up after training.
csv_logger = CSVLogger(save_dir="training_logs", name="mnist_gradcam")

In [11]:
# Train on one GPU when CUDA is available; otherwise fall back to the CPU so the
# notebook still runs on machines without a GPU.
trainer = pl.Trainer(gpus=1 if torch.cuda.is_available() else 0,
                     max_epochs=hyper_params['n_epochs'],
                     logger=csv_logger)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]


In [12]:
# Run training (and the validation loop) for the configured number of epochs.
trainer.fit(model=model)


  | Name      | Type             | Params
-----------------------------------------------
0 | model     | MobileNet_MNIST  | 4 M   
1 | criterion | CrossEntropyLoss | 0     


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Saving latest checkpoint..

Initialising GradCam


1

In [None]:
# NOTE(review): Explainer_Classifier defines no test_step or test_dataloader;
# confirm this call does what is intended.
trainer.test(model=model)

In [4]:
# Standalone instance of the network, used below only to list its sub-module names.
mnist_mobile = MobileNet_MNIST()

In [6]:
# Print every sub-module's qualified name, one per line (the root module's name is
# the empty string), e.g. to pick the layer GradCAM should target.
print("\n".join(qualified for qualified, _ in mnist_mobile.named_modules()))


base_model
base_model.features
base_model.features.0
base_model.features.0.0
base_model.features.0.1
base_model.features.0.2
base_model.features.1
base_model.features.1.conv
base_model.features.1.conv.0
base_model.features.1.conv.0.0
base_model.features.1.conv.0.1
base_model.features.1.conv.0.2
base_model.features.1.conv.1
base_model.features.1.conv.2
base_model.features.2
base_model.features.2.conv
base_model.features.2.conv.0
base_model.features.2.conv.0.0
base_model.features.2.conv.0.1
base_model.features.2.conv.0.2
base_model.features.2.conv.1
base_model.features.2.conv.1.0
base_model.features.2.conv.1.1
base_model.features.2.conv.1.2
base_model.features.2.conv.2
base_model.features.2.conv.3
base_model.features.3
base_model.features.3.conv
base_model.features.3.conv.0
base_model.features.3.conv.0.0
base_model.features.3.conv.0.1
base_model.features.3.conv.0.2
base_model.features.3.conv.1
base_model.features.3.conv.1.0
base_model.features.3.conv.1.1
base_model.features.3.conv.1.2
b