<a href="https://colab.research.google.com/github/soemthlng/lightning_MNIST/blob/main/Ligntning_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
! pip install pytorch-lightning

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [4]:
import os
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torch.optim as optim
from torchvision import transforms
import pytorch_lightning as pl
from collections import OrderedDict


class LightNet(pl.LightningModule):

    def __init__(self, in_channels=1, out_channels=10):
        super(LightNet, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, 32, 3, 1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.bn2 = nn.BatchNorm2d(64)

        self.maxpool = nn.MaxPool2d(2)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, out_channels)
        

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.dropout1(x)

        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout2(x)
        logits = self.fc2(x)
        return logits
        
    def criterion(self, logits, targets):
        return F.cross_entropy(logits, targets)

    def training_step(self, train_batch, batch_idx):
        inputs, targets = train_batch
        outputs = self.forward(inputs)
        loss = self.criterion(outputs, targets)

        #inbuilt tensorboard for logs
        tensorboard_logs = {'train_loss': loss}

        return {'loss': loss, 'log': tensorboard_logs}   


    def validation_step(self, train_batch, batch_idx):
        inputs, targets = train_batch
        outputs = self.forward(inputs)
        loss = self.criterion(outputs, targets)
        pred = outputs.data.max(1)[1]  # get the index of the max log-probability
        incorrect = pred.ne(targets.long().data).cpu().sum()
        err = incorrect.item()/targets.numel()
        val_acc = torch.tensor(1.0-err)

        return {'val_loss': loss, 'val_acc': val_acc}    

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['val_acc'] for x in outputs]).mean()  
        Accuracy = 100 * avg_acc.item()
        tensorboard_logs = {'val_loss': avg_loss, 'avg_val_acc': avg_acc}
        print('Val Loss:', round(avg_loss.item(),2), 'Val Accuracy: %f %%' % Accuracy) 
        return {'avg_val_loss': avg_loss, 'progress_bar': tensorboard_logs}

    def prepare_data(self):
        transform=transforms.Compose([transforms.ToTensor(), 
                                      transforms.Normalize((0.1307,), (0.3081,))])
        train_dataset = MNIST('data', train=True, download=True, transform=transform)
        test_dataset = MNIST('data', train=False, download=True, transform=transform)
        self.mnist_train, self.mnist_val = train_dataset, test_dataset

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=64, num_workers=2)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=128, num_workers=2)

    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=1e-3, momentum=0.9, nesterov=False,weight_decay=5e-4)
        return optimizer

In [5]:
model = LightNet()
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Missing logger folder: /content/lightning_logs

  | Name     | Type        | Params
-----------------------------------------
0 | conv1    | Conv2d      | 320   
1 | bn1      | BatchNorm2d | 64    
2 | relu     | ReLU        | 0     
3 | conv2    | Conv2d      | 18.5 K
4 | bn2      | BatchNorm2d | 128   
5 | maxpool  | MaxPool2d   | 0     
6 | dropout1 | Dropout2d   | 0     
7 | dropout2 | Dropout2d   | 0     
8 | fc1      | Linear      | 1.2 M 
9 | fc2      | Linear      | 1.3 K 
-----------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.800     Total estimated model params size (MB)


Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



Sanity Checking: 0it [00:00, ?it/s]



Val Loss: 2.31 Val Accuracy: 3.125000 %


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Val Loss: 0.21 Val Accuracy: 96.766216 %


Validation: 0it [00:00, ?it/s]

Val Loss: 0.14 Val Accuracy: 97.873813 %


Validation: 0it [00:00, ?it/s]

Val Loss: 0.13 Val Accuracy: 98.299050 %


Validation: 0it [00:00, ?it/s]

Val Loss: 0.11 Val Accuracy: 98.388052 %


Validation: 0it [00:00, ?it/s]

Val Loss: 0.09 Val Accuracy: 98.546284 %


Validation: 0it [00:00, ?it/s]

Val Loss: 0.09 Val Accuracy: 98.615509 %


Validation: 0it [00:00, ?it/s]

Val Loss: 0.08 Val Accuracy: 98.714399 %


Validation: 0it [00:00, ?it/s]

Val Loss: 0.09 Val Accuracy: 98.417723 %


Validation: 0it [00:00, ?it/s]

Val Loss: 0.07 Val Accuracy: 98.793513 %


Validation: 0it [00:00, ?it/s]

Val Loss: 0.08 Val Accuracy: 98.635286 %
