In [None]:
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ProgressBar
import os
import sys

In [None]:
# Shape trace for one MNIST image, written as H*W*C:
# (28*28*1) -conv5x5-> (24*24*28) -maxpool2-> (12*12*28) -conv2x2-> (11*11*10)
# -maxpool2 (floors 11 to 5)-> (5*5*10) -flatten-> 250 -fc1-> 18 -fc2-> 10
class MyLightningModel(pl.LightningModule):
    """Small CNN classifier for MNIST, bundled with its own data pipeline.

    The module downloads MNIST, randomly splits the training images into
    train and validation subsets, and after each train / validation / test
    loop prints the accuracy it accumulated over that loop and returns it
    (together with the mean loss) under the ``'log'`` key.

    Args:
        lr: Learning rate passed to Adam.
        batch_size: Batch size for the train, validation and test loaders.
        val_size: How many of the MNIST training images are held out
            (randomly, via ``random_split``) for validation.
        data_dir: Where MNIST is downloaded to / read from. ``None`` means
            the current working directory at the time ``prepare_data`` runs.
    """

    def __init__(self, lr=0.01, batch_size=128, val_size=5000, data_dir=None):
        super().__init__()
        self.lr = lr
        self.batch_size = batch_size
        self.val_size = val_size
        self.data_dir = data_dir

        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 28, kernel_size=5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2))
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(28, 10, kernel_size=2),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2))
        self.dropout1 = torch.nn.Dropout(0.25)
        self.fc1 = torch.nn.Linear(250, 18)  # 250 = 5*5*10 flattened conv features
        self.dropout2 = torch.nn.Dropout(0.08)
        self.fc2 = torch.nn.Linear(18, 10)  # one output per digit class

        # Running counts of correct / seen samples for each loop. The *_step
        # hooks add to them; each *_epoch_end hook reports and then zeroes its
        # own pair, so separate runs (e.g. two trainer.test() calls) never mix.
        self.valTotal = 0
        self.valCorrect = 0
        self.trainTotal = 0
        self.trainCorrect = 0
        self.testTotal = 0
        self.testCorrect = 0
        # Number of validation_epoch_end calls so far. The first call is
        # presumably Lightning's pre-training sanity check, so no "Epoch:"
        # header is printed for it.
        self.epoch = 0

    def prepare_data(self):
        """Download MNIST and build the train / validation / test datasets.

        NOTE(review): newer Lightning versions run ``prepare_data`` on a single
        process and expect dataset attributes to be assigned in ``setup``;
        assigning them here is fine for the single-process training used here.
        """
        data_dir = self.data_dir if self.data_dir is not None else os.getcwd()
        transform = transforms.Compose([
            transforms.ToTensor(),
            # commonly quoted MNIST pixel mean / std
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        # download once; the transformed datasets below then read from disk
        MNIST(data_dir, train=True, download=True)
        MNIST(data_dir, train=False, download=True)

        mnist_train = MNIST(data_dir, train=True, download=False, transform=transform)
        train_size = len(mnist_train) - self.val_size
        self.train_set, self.val_set = random_split(mnist_train, [train_size, self.val_size])
        self.test_set = MNIST(data_dir, train=False, download=False, transform=transform)

    def train_dataloader(self):
        # shuffle so every epoch visits the training samples in a new order
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size)

    def forward(self, x):
        """Return per-class log-probabilities, shape (batch, 10).

        ``x`` is presumably a batch of normalized images of shape
        (batch, 1, 28, 28), matching the shape trace above the class.
        """
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.dropout1(x)
        x = self.fc1(x.view(x.size(0), -1))
        x = torch.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return torch.log_softmax(x, dim=1)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def on_epoch_start(self):
        # Defensive reset; each *_epoch_end hook also zeroes its own counters.
        self.valTotal = 0
        self.valCorrect = 0
        self.testTotal = 0
        self.testCorrect = 0
        self.trainTotal = 0
        self.trainCorrect = 0

    @staticmethod
    def _count_correct(pred, y):
        """Number of rows of ``pred`` whose arg-max class equals the label in ``y``."""
        return pred.argmax(dim=1).eq(y).sum().item()

    @staticmethod
    def _accuracy(correct, total):
        """Fraction ``correct / total``, or 0.0 when no samples were counted."""
        return correct / total if total else 0.0

    # Step is called for every batch in our dataset
    def training_step(self, batch, batch_index):
        x, y = batch
        pred = self.forward(x)
        self.trainCorrect += self._count_correct(pred, y)
        self.trainTotal += len(y)
        # forward() already returns log-probabilities, so NLL is the matching
        # loss; cross_entropy would apply log_softmax a second time.
        loss = F.nll_loss(pred, y)
        return {"loss": loss,  # essential: Lightning backpropagates this
                "log": {"train_loss": loss}}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        pred = self.forward(x)
        self.valCorrect += self._count_correct(pred, y)
        self.valTotal += len(y)
        loss = F.nll_loss(pred, y)
        return {'val_loss': loss}

    def test_step(self, batch, batch_idx):
        x, y = batch
        pred = self.forward(x)
        self.testCorrect += self._count_correct(pred, y)
        self.testTotal += len(y)
        loss = F.nll_loss(pred, y)
        return {'test_loss': loss}

    def validation_epoch_end(self, outputs):
        """Print / log validation accuracy and mean loss, then reset the counters."""
        if self.epoch != 0:
            print("Epoch:{}".format(self.epoch))
        self.epoch += 1
        acc = self._accuracy(self.valCorrect, self.valTotal)
        print("Validation Accuracy= {}\nNumber of Correctly identified Validation Images {} from a set of {}.".format(acc, self.valCorrect, self.valTotal))
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        # stage-specific key so train and val accuracy get separate curves
        logs = {'val_loss': avg_loss, "val_acc": acc}
        self.valCorrect = 0
        self.valTotal = 0
        return {'val_loss': avg_loss, 'log': logs}

    def training_epoch_end(self, outputs):
        """Print / log training accuracy and mean loss, then reset the counters."""
        acc = self._accuracy(self.trainCorrect, self.trainTotal)
        print("\nTraining Accuracy= {}\nNumber of Correctly identified Training Set Images {} from a set of {}.".format(acc, self.trainCorrect, self.trainTotal))
        print("---------------------------------------------------")
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        logs = {'loss': avg_loss, "train_acc": acc}
        self.trainCorrect = 0
        self.trainTotal = 0
        return {'loss': avg_loss, 'log': logs}

    def test_epoch_end(self, outputs):
        """Print / log test accuracy and mean loss, then reset the counters."""
        acc = self._accuracy(self.testCorrect, self.testTotal)
        print("Testing Accuracy= {}\nNumber of Correctly identified Testing Images {} from a set of {}.".format(acc, self.testCorrect, self.testTotal))
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        logs = {'test_loss': avg_loss, "test_acc": acc}
        self.testCorrect = 0
        self.testTotal = 0
        return {'test_loss': avg_loss, 'log': logs, 'progress_bar': logs}
    

class LitProgressBar(ProgressBar):
    """Progress bar that additionally writes how much of the current training
    epoch has completed, overwriting the same console line each batch.

    Enabling / disabling is inherited from ``ProgressBar`` (``enable()``,
    ``disable()``, ``is_enabled``), so Lightning's own calls to those methods
    also silence the percentage line.

    NOTE(review): because this extends the tqdm-based ``ProgressBar``, the tqdm
    bars are still drawn alongside this text; subclass ``ProgressBarBase``
    instead if only the percentage line is wanted.
    """

    def on_batch_end(self, trainer, pl_module):
        super().on_batch_end(trainer, pl_module)  # don't forget this :) -- advances train_batch_idx
        if not self.is_enabled:
            return
        total = self.total_train_batches
        if not total:
            # nothing to report a percentage of (avoids dividing by zero)
            return
        percent = (self.train_batch_idx / total) * 100
        sys.stdout.write(f'{percent:.01f} percent of epoch complete \r')
        # flush after writing so this batch's line is shown immediately
        sys.stdout.flush()
    
       



In [None]:
# Seed torch's RNG (CPU and CUDA) so the random train/val split, weight
# initialisation and dropout masks repeat on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

bar = LitProgressBar()
# the trainer abstracts away batch iteration, epoch iteration, optimizer.step() and the validation loop
trainer = pl.Trainer(
    gpus=[0] if torch.cuda.is_available() else None,  # fall back to CPU when no GPU is present
    max_epochs=5,
    checkpoint_callback=False,
    callbacks=[bar],
)

model = MyLightningModel()
trainer.fit(model)

In [None]:
# Evaluate the just-trained model on the MNIST test split. Passing `model`
# explicitly makes this cell's dependency on the training cell visible and
# avoids relying on the trainer's remembered model or a saved checkpoint.
trainer.test(model)

In [None]:
# Load the TensorBoard notebook extension, which provides the %tensorboard magic.
%load_ext tensorboard

In [None]:
# Show TensorBoard inline for the logs under lightning_logs/ (Lightning's
# default TensorBoard log directory, relative to the working directory).
%tensorboard --logdir lightning_logs/