<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"></ul></div>

In [118]:
import sys
import numpy as np
import pandas as pd
import torch
from torchvision import datasets
from dataclasses import dataclass
import pytorch_lightning as pl
from torchvision import models
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from torch.utils.data import random_split, DataLoader

In [87]:
class MobileNet_MNIST(nn.Module):
    """MobileNetV2 backbone with a two-layer fully-connected head for MNIST.

    ``forward`` returns raw (unnormalised) class scores of shape
    ``(batch, num_classes)``, which is the input format ``nn.CrossEntropyLoss``
    expects. Apply ``self.softmax`` to the output to obtain probabilities.

    Args:
        fc_size: width of the hidden fully-connected layer.
        dropout_prob: dropout probability applied after the hidden layer.
        base_model: optional backbone mapping a ``(batch, 3, H, W)`` tensor to
            ``(batch, backbone_out)`` features. When ``None`` (default), an
            ImageNet-pretrained torchvision ``mobilenet_v2`` is built, which
            downloads weights on first use.
        num_classes: number of output classes (10 digits for MNIST).
        backbone_out: feature size produced by ``base_model``; 1000 matches the
            default MobileNetV2 classifier output.
    """

    def __init__(self, fc_size=600, dropout_prob=0.5, base_model=None,
                 num_classes=10, backbone_out=1000):
        super(MobileNet_MNIST, self).__init__()
        if base_model is None:
            base_model = models.mobilenet_v2(pretrained=True)
        self.base_model = base_model
        self.fc_layer1 = nn.Linear(backbone_out, fc_size, bias=False)
        self.fc_layer2 = nn.Linear(fc_size, num_classes, bias=False)
        self.dropout = nn.Dropout(dropout_prob)
        # Not used in forward (the loss expects logits). It is kept so callers
        # can turn logits into probabilities. dim=1 is the class axis; an
        # implicit dim is deprecated in PyTorch.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return class logits for an image batch.

        Args:
            x: image batch; assumes ``(batch, 1, H, W)`` grayscale MNIST input
               (TODO confirm). The channel axis is repeated 3x so the batch
               matches the 3-channel input of the ImageNet backbone.

        Returns:
            Tensor of unnormalised scores, shape ``(batch, num_classes)``.
        """
        x = x.repeat(1, 3, 1, 1)
        features = self.base_model(x)
        hidden = self.dropout(F.relu(self.fc_layer1(features)))
        # There is no ReLU/softmax on the final layer. Negative logits must
        # survive, and nn.CrossEntropyLoss applies log-softmax internally.
        return self.fc_layer2(hidden)

In [120]:
class Explainer_Classifier(pl.LightningModule):
    """LightningModule that trains ``MobileNet_MNIST`` on MNIST.

    Args:
        hparams: dict of hyperparameters. Keys read here:
            ``batch_size`` (batch size of every DataLoader),
            ``lr`` (Adam learning rate), and
            ``lr_gamma`` (optional, default 0.1), the per-epoch ExponentialLR
            decay factor.
    """

    def __init__(self, hparams):
        super(Explainer_Classifier, self).__init__()
        # save_hyperparameters exposes the dict as attribute-accessible
        # self.hparams and records it with checkpoints. Newer Lightning
        # releases reject direct assignment to self.hparams.
        self.save_hyperparameters(hparams)
        self.model = MobileNet_MNIST()
        # nn.CrossEntropyLoss applies log-softmax itself, so it expects
        # unnormalised logits and integer class labels.
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Delegate to the wrapped network."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the cross-entropy loss for one training batch."""
        data, labels = batch
        outputs = self(data)
        train_loss = self.criterion(outputs, labels)
        self.log('train_loss', train_loss, prog_bar=True)
        return train_loss

    def validation_step(self, batch, batch_idx):
        """Compute and log validation loss and accuracy for one batch."""
        data, labels = batch
        outputs = self(data)
        val_loss = self.criterion(outputs, labels)
        val_acc = (outputs.argmax(dim=1) == labels).float().mean()
        # The loss is logged under its own key; it was previously
        # mislabelled 'train_loss'.
        self.log('val_loss', val_loss, prog_bar=True)
        self.log('val_acc', val_acc, prog_bar=True)
        return val_loss

    def configure_optimizers(self):
        """Build Adam plus an ExponentialLR schedule.

        Lightning steps the scheduler once per epoch by default, so with the
        default gamma of 0.1 the learning rate is divided by 10 every epoch.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=self.hparams.get('lr_gamma', 0.1))
        return [optimizer], [scheduler]

    def _mnist_loader(self, train, shuffle=False):
        """Return a DataLoader over the MNIST train/test split stored in 'dataset/'."""
        dataset = datasets.MNIST(
            'dataset/', train=train, download=True,
            transform=transforms.Compose([transforms.ToTensor()]))
        return DataLoader(dataset, batch_size=self.hparams.batch_size, shuffle=shuffle)

    def train_dataloader(self):
        # Shuffle so batch composition and order differ between epochs.
        return self._mnist_loader(train=True, shuffle=True)

    def val_dataloader(self):
        # NOTE(review): validation uses the MNIST *test* split, the same data
        # as test_dataloader, so test metrics are not independent of model
        # selection. Consider holding out part of the training split.
        return self._mnist_loader(train=False)

    def test_dataloader(self):
        return self._mnist_loader(train=False)
    

In [121]:
# Hyperparameters read by Explainer_Classifier through self.hparams:
# batch_size for its DataLoaders, lr for the Adam optimiser.
hyper_params = dict(batch_size=64, lr=0.001)

In [122]:
# Seed Python, NumPy and torch RNGs before building the model so the randomly
# initialised head layers and dropout masks are reproducible across runs.
pl.seed_everything(42)
model = Explainer_Classifier(hyper_params)

In [123]:
# Bound training explicitly. Without max_epochs, Lightning falls back to a
# default of 1000 epochs. With the per-epoch ExponentialLR(gamma=0.1) schedule,
# the learning rate is negligible after a handful of epochs anyway.
MAX_EPOCHS = 5
trainer = pl.Trainer(max_epochs=MAX_EPOCHS)

GPU available: False, used: False
TPU available: None, using: 0 TPU cores


In [124]:
# Run the training loop, with validation batches interleaved by Lightning.
trainer.fit(model=model)


  | Name      | Type             | Params
-----------------------------------------------
0 | model     | MobileNet_MNIST  | 4.1 M 
1 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
4.1 M     Trainable params
0         Non-trainable params
4.1 M     Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…



HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…






1

In [81]:
# Random dummy batch with MNIST-like layout: (batch=64, channels=1, 28, 28).
f = torch.randn((64, 1, 28, 28))

In [82]:
# Display the dummy batch's dimensions.
f.size()

torch.Size([64, 1, 28, 28])

In [86]:
# Repeat the single channel three times (the same call MobileNet_MNIST.forward
# makes) and display the resulting dimensions.
z = f.repeat((1, 3, 1, 1))
z.size()

torch.Size([64, 3, 28, 28])