<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"></ul></div>

In [1]:
import sys
import numpy as np
import pandas as pd
import torch
from torchvision import datasets
from dataclasses import dataclass
import pytorch_lightning as pl
from torchvision import models
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from torch.utils.data import random_split, DataLoader
from grad_cam import GradCAM

In [2]:
class MobileNet_MNIST(nn.Module):
    """MobileNetV2 backbone followed by a small two-layer classification head.

    The backbone's 1000-dimensional output is passed through
    ``fc_layer1`` -> ReLU -> dropout -> ``fc_layer2``, producing one raw
    (unnormalised) score per class.

    Args:
        fc_size: Width of the hidden fully-connected layer.
        dropout_prob: Dropout probability applied after the hidden layer.
        num_classes: Number of output classes (10 for MNIST).
        pretrained: Forwarded to ``models.mobilenet_v2``; set to False to skip
            loading pretrained backbone weights (e.g. when only inspecting
            layer names).
    """

    def __init__(self, fc_size=800, dropout_prob=0.8, num_classes=10, pretrained=True):
        super(MobileNet_MNIST, self).__init__()
        self.base_model = models.mobilenet_v2(pretrained=pretrained)
        self.fc_layer1 = nn.Linear(1000, fc_size)
        self.fc_layer2 = nn.Linear(fc_size, num_classes)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        """Return class logits for a batch of images.

        NOTE(review): assumes ``x`` already has 3 channels (the MNIST loaders
        in this notebook repeat the grayscale channel) — confirm for other callers.
        """
        x = self.base_model(x)
        x = F.relu(self.fc_layer1(x))
        x = self.dropout(x)
        # No activation on the output layer: nn.CrossEntropyLoss applies
        # log-softmax itself, and a ReLU here would clamp every negative logit
        # to 0 and block its gradient.
        return self.fc_layer2(x)

In [3]:
class Explainer_Classifier(pl.LightningModule):
    """Lightning module that trains ``MobileNet_MNIST`` on MNIST and runs Grad-CAM.

    Args:
        hparams: Hyper-parameters; must provide ``lr`` and ``batch_size``
            (read as ``self.hparams.lr`` / ``self.hparams.batch_size``).
    """

    # Module path (as listed by ``named_modules()`` on MobileNet_MNIST) used as
    # the Grad-CAM target layer.
    GRADCAM_LAYER = 'base_model.features.17.conv.3'
    # Directory where torchvision downloads / caches MNIST.
    DATA_DIR = 'dataset/'

    def __init__(self, hparams):
        super(Explainer_Classifier, self).__init__()
        # NOTE(review): direct assignment works with the Lightning version this
        # notebook was run with (a plain dict is then readable as
        # self.hparams.lr); newer releases reject it and require
        # self.save_hyperparameters(hparams) instead.
        self.hparams = hparams
        self.model = MobileNet_MNIST()
        self.criterion = nn.CrossEntropyLoss()
        self.gcam = GradCAM(model=self.model)

    def forward(self, x):
        """Return the wrapped model's class logits for ``x``."""
        return self.model(x)

    def apply_GradCam(self, images, labels, target_layers):
        """Run the Grad-CAM forward and backward passes on ``images``.

        Args:
            images: Input batch handed to the Grad-CAM wrapper.
            labels: Currently unused; kept for interface compatibility.
            target_layers: Module path to visualise. Currently unused because
                heat-map generation is not implemented yet (see TODO below).

        Returns:
            ``(probs, ids)`` exactly as returned by ``self.gcam.forward``.
        """
        probs, ids = self.gcam.forward(images)
        self.gcam.backward(ids=ids)
        # TODO: produce the heat-maps, e.g.
        #   regions = self.gcam.generate(target_layer=target_layers)
        return probs, ids

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy training loss for one batch.

        NOTE(review): Grad-CAM is deliberately not run here any more. Its
        backward pass presumably goes through the same parameters being
        trained, so it risks leaving extra gradients that the optimiser step
        would pick up, and its result was discarded anyway.
        """
        data, labels = batch
        outputs = self.forward(data)
        train_loss = self.criterion(outputs, labels)
        tqdm_dict = {'train_loss': train_loss}
        return OrderedDict({
            'loss': train_loss,
            'progress_bar': tqdm_dict,
            'log': tqdm_dict,
        })

    def validation_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one validation batch."""
        data, labels = batch
        outputs = self.forward(data)
        val_loss = self.criterion(outputs, labels)
        tqdm_dict = {'val_loss': val_loss}
        return OrderedDict({
            'val_loss': val_loss,
            'progress_bar': tqdm_dict,
            'log': tqdm_dict,
        })

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation losses so they reach the logger."""
        if not outputs:
            return {}
        avg_loss = torch.stack([out['val_loss'] for out in outputs]).mean()
        tqdm_dict = {'val_loss': avg_loss}
        return {'val_loss': avg_loss, 'progress_bar': tqdm_dict, 'log': tqdm_dict}

    def test_step(self, batch, batch_idx):
        """Run Grad-CAM on one test batch."""
        data, labels = batch
        # Lightning runs evaluation with autograd disabled, but Grad-CAM needs a
        # backward pass, so gradients are re-enabled locally. NOTE(review):
        # Lightning versions that evaluate under torch.inference_mode need
        # Trainer(inference_mode=False) instead.
        with torch.enable_grad():
            self.apply_GradCam(data, labels, target_layers=self.GRADCAM_LAYER)

    def configure_optimizers(self):
        """Adam over all parameters with learning rate ``hparams.lr``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        return [optimizer]

    def _mnist_dataloader(self, train, shuffle=False):
        """Build an MNIST DataLoader with the preprocessing the model expects.

        Images are resized to 224x224, scaled from [0, 1] to [-1, 1], and the
        single grayscale channel is repeated to 3 channels for the backbone.
        """
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
            transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
        ])
        dataset = datasets.MNIST(self.DATA_DIR, train=train, download=True,
                                 transform=transform)
        return DataLoader(dataset, batch_size=self.hparams.batch_size, shuffle=shuffle)

    def train_dataloader(self):
        """Shuffled MNIST training split."""
        return self._mnist_dataloader(train=True, shuffle=True)

    def val_dataloader(self):
        """MNIST test split used for validation.

        NOTE(review): validation and test share the MNIST test split; consider
        carving a validation set out of the training data (e.g. random_split).
        """
        return self._mnist_dataloader(train=False)

    def test_dataloader(self):
        """MNIST test split."""
        return self._mnist_dataloader(train=False)

In [4]:
# Training configuration: batch size, Adam learning rate, and number of epochs.
hyper_params = dict(batch_size=1, lr=1e-3, n_epochs=1)

In [5]:
# Build the Lightning module from the hyper-parameter dict defined above.
model = Explainer_Classifier(hparams=hyper_params)

In [6]:
# Train on one GPU when CUDA is available; otherwise fall back to the CPU
# instead of failing at Trainer construction.
trainer = pl.Trainer(gpus=1 if torch.cuda.is_available() else 0,
                     max_epochs=hyper_params['n_epochs'])

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]


In [7]:
# Run the training loop (with validation) for the configured number of epochs.
trainer.fit(model=model)


  | Name      | Type             | Params
-----------------------------------------------
0 | model     | MobileNet_MNIST  | 4 M   
1 | criterion | CrossEntropyLoss | 0     


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

Saving latest checkpoint..



1

In [None]:
# Run the test loop, which applies Grad-CAM to each test batch.
trainer.test(model=model)

In [4]:
# Stand-alone network (default head settings spelled out) used only to inspect
# module names when choosing a Grad-CAM target layer.
mnist_mobile = MobileNet_MNIST(fc_size=800, dropout_prob=0.8)

In [6]:
# One sub-module path per line (the first, empty line is the root module);
# these paths have the same form as the 'base_model.features.17.conv.3'
# target used for Grad-CAM.
print('\n'.join(name for name, _ in mnist_mobile.named_modules()))


base_model
base_model.features
base_model.features.0
base_model.features.0.0
base_model.features.0.1
base_model.features.0.2
base_model.features.1
base_model.features.1.conv
base_model.features.1.conv.0
base_model.features.1.conv.0.0
base_model.features.1.conv.0.1
base_model.features.1.conv.0.2
base_model.features.1.conv.1
base_model.features.1.conv.2
base_model.features.2
base_model.features.2.conv
base_model.features.2.conv.0
base_model.features.2.conv.0.0
base_model.features.2.conv.0.1
base_model.features.2.conv.0.2
base_model.features.2.conv.1
base_model.features.2.conv.1.0
base_model.features.2.conv.1.1
base_model.features.2.conv.1.2
base_model.features.2.conv.2
base_model.features.2.conv.3
base_model.features.3
base_model.features.3.conv
base_model.features.3.conv.0
base_model.features.3.conv.0.0
base_model.features.3.conv.0.1
base_model.features.3.conv.0.2
base_model.features.3.conv.1
base_model.features.3.conv.1.0
base_model.features.3.conv.1.1
base_model.features.3.conv.1.2
b