<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"></ul></div>

In [5]:
import numpy as np
import pandas as pd
import torch
from torchvision import datasets
import pytorch_lightning as pl
from torchvision import models
from torchvision import transforms

In [7]:
class Explainer_Classifier(pl.LightningModule):
    """Fine-tune an ImageNet-pretrained MobileNetV2 on MNIST with Lightning.

    ``hparams`` may be an ``argparse.Namespace`` or a dict. Keys read:

    * ``lr`` -- Adam learning rate (default ``1e-3``).
    * ``batch_size`` -- DataLoader batch size (default ``32``).
    * ``num_classes`` -- size of the classification head (default ``10``).
    * ``data_dir`` -- MNIST root directory (default ``'dataset/'``).
    """

    def __init__(self, hparams):
        super().__init__()
        # Newer Lightning releases make ``hparams`` read-only and expect
        # save_hyperparameters(); older releases only support assignment.
        if hasattr(self, "save_hyperparameters"):
            self.save_hyperparameters(hparams)
        else:
            self.hparams = hparams
        # ``models.MobileNetV2`` (the class) has no ``pretrained`` argument;
        # the ``mobilenet_v2`` factory is what loads ImageNet weights.
        self.model = models.mobilenet_v2(pretrained=True)
        # Replace the 1000-way ImageNet head with one sized for the dataset.
        head = self.model.classifier[-1]
        self.model.classifier[-1] = torch.nn.Linear(
            head.in_features, self._hparam("num_classes", 10)
        )
        self.criterion = torch.nn.CrossEntropyLoss()

    def _hparam(self, name, default=None):
        """Read hyperparameter ``name`` from a dict- or Namespace-style hparams."""
        hp = self.hparams
        if isinstance(hp, dict):
            return hp.get(name, default)
        return getattr(hp, name, default)

    def forward(self, x):
        """Return class logits for a batch of 3-channel images ``x``."""
        return self.model(x)

    def _shared_step(self, batch):
        """Compute the cross-entropy loss for one ``(data, labels)`` batch."""
        data, labels = batch[0], batch[1]
        logits = self(data)
        # Lightning has already moved the batch to this module's device, so no
        # CUDA-specific cast is needed; CrossEntropyLoss wants int64 targets.
        return self.criterion(logits, labels.long())

    def training_step(self, batch, batch_idx):
        """Return the training loss plus progress-bar/log metrics."""
        train_loss = self._shared_step(batch)
        metrics = {"train_loss": train_loss}
        return {"loss": train_loss, "progress_bar": metrics, "log": metrics}

    def validation_step(self, batch, batch_idx):
        """Return the validation loss under ``val_loss`` (not ``train_loss``)."""
        val_loss = self._shared_step(batch)
        metrics = {"val_loss": val_loss}
        return {"val_loss": val_loss, "progress_bar": metrics, "log": metrics}

    def configure_optimizers(self):
        """Create an Adam optimizer over all of this module's parameters."""
        return torch.optim.Adam(self.parameters(), lr=self._hparam("lr", 1e-3))

    def _mnist_loader(self, train, shuffle):
        """Build a batched MNIST DataLoader shaped for MobileNetV2's 3-channel input."""
        dataset = datasets.MNIST(
            self._hparam("data_dir", "dataset/"),
            train=train,
            download=True,
            transform=transforms.Compose([
                # MNIST images are single-channel; replicate to 3 channels so
                # the 3-element Normalize stats and MobileNetV2's stem apply.
                transforms.Grayscale(num_output_channels=3),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]),
        )
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self._hparam("batch_size", 32),
            shuffle=shuffle,
        )

    def train_dataloader(self):
        """Shuffled DataLoader over the MNIST training split."""
        return self._mnist_loader(train=True, shuffle=True)

    def val_dataloader(self):
        """DataLoader over the MNIST test split, used for validation."""
        return self._mnist_loader(train=False, shuffle=False)

    def test_dataloader(self):
        """DataLoader over the MNIST test split."""
        return self._mnist_loader(train=False, shuffle=False)