# MNIST: PyTorch to PyTorch Lightning, Sample 1
Based on the 'PyTorch Lightning' YouTube channel, 'Episode 3: From PyTorch to PyTorch Lightning'<br/>
https://youtu.be/DbESHcCoWbM<br/>
<br/>
Original plain-PyTorch version of this code:<br/>
https://www.kaggle.com/stpeteishii/mnist-pytorch-linear-sample


This notebook is based on Niko Gamulin's scripts.<br/>
https://colab.research.google.com/drive/1_YYYHRA-blncinGFz3dJeZExAoWxIkGh?usp=sharing

In [1]:
import torch
from torch import nn
from torch import optim
from torchvision import datasets, transforms
from torch.utils.data import random_split, DataLoader

In [2]:
!pip install pytorch-lightning

In [3]:
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy

In [4]:
# Module-level accuracy metric objects built from the `pl.metrics` API.
# The validation metric is constructed with compute_on_step=False.
# NOTE(review): `pl.metrics` was later moved out of pytorch_lightning into the
# separate torchmetrics package; with an unpinned `pip install` this cell may
# fail on newer Lightning releases - confirm the installed version.
train_accuracy, valid_accuracy = (
    pl.metrics.Accuracy(),
    pl.metrics.Accuracy(compute_on_step=False),
)

In [5]:
class MNISTDataModule(pl.LightningDataModule):
    """Lightning data module providing MNIST train/validation/test loaders.

    The MNIST training set is randomly split into a training part and a
    ``val_size``-image validation part; the separate MNIST test set is used
    for testing.

    Args:
        batch_size: Mini-batch size used by every data loader.
        data_dir: Directory MNIST is downloaded to and read from.
        val_size: Number of training images held out for validation
            (the default reproduces the original 55000/5000 split).
    """

    def __init__(self, batch_size=64, data_dir='data', val_size=5000):
        super().__init__()
        self.batch_size = batch_size
        self.data_dir = data_dir
        self.val_size = val_size

    def prepare_data(self):
        # Download only; the dataset objects actually used are built in setup().
        datasets.MNIST(self.data_dir, train=True, download=True)
        datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets for ``stage`` (``None`` means every stage)."""
        transform = transforms.Compose([transforms.ToTensor()])

        # Skip the split for a test-only setup so that a setup('test') call
        # after fitting does not redraw the random train/validation partition.
        if stage != 'test':
            full_train = datasets.MNIST(
                self.data_dir, train=True, download=False, transform=transform
            )
            train_size = len(full_train) - self.val_size
            self.train_dataset, self.val_dataset = random_split(
                full_train, [train_size, self.val_size]
            )

        self.test_dataset = datasets.MNIST(
            self.data_dir, train=False, download=False, transform=transform
        )

    def train_dataloader(self):
        # shuffle=True so each epoch iterates the training data in a new order.
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)

In [6]:
class ImageClassifier(pl.LightningModule):
    """Small MLP classifier for flattened 28x28 images into 10 classes.

    Two 64-unit ReLU layers with a residual sum (h1 + h2) followed by dropout
    and a linear output layer. Training and validation metrics are reported
    through ``self.log`` so they reach the progress bar and loggers.
    """

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(28 * 28, 64)
        self.l2 = nn.Linear(64, 64)
        self.l3 = nn.Linear(64, 10)
        self.do = nn.Dropout(0.1)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return class logits for ``x`` of shape (batch, 784)."""
        h1 = nn.functional.relu(self.l1(x))
        h2 = nn.functional.relu(self.l2(h1))
        # Residual connection: h1 and h2 share width 64.
        do = self.do(h2 + h1)
        logits = self.l3(do)
        return logits

    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=1e-2)
        return optimizer

    def _shared_step(self, batch):
        """Compute (loss, accuracy, batch size) for one (images, labels) batch."""
        x, y = batch
        n = x.size(0)
        logits = self(x.view(n, -1))
        loss = self.loss(logits, y)
        # Plain per-batch accuracy: avoids sharing one stateful global metric
        # object between the training and validation loops.
        acc = (logits.argmax(dim=1) == y).float().mean()
        return loss, acc, n

    def training_step(self, batch, batch_idx):
        loss, acc, _ = self._shared_step(batch)
        # self.log replaces the removed {'progress_bar': ...} return-dict API.
        self.log('train_loss', loss)
        self.log('train_acc', acc, prog_bar=True)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        loss, acc, n = self._shared_step(batch)
        # Keep the batch size so the epoch average can be weighted correctly
        # (the last validation batch is usually smaller).
        return {'loss': loss, 'val_acc': acc, 'n': n}

    def validation_epoch_end(self, val_step_outputs):
        if not val_step_outputs:
            return
        losses = torch.stack([out['loss'] for out in val_step_outputs])
        accs = torch.stack([out['val_acc'] for out in val_step_outputs])
        sizes = torch.tensor(
            [out['n'] for out in val_step_outputs],
            dtype=losses.dtype,
            device=losses.device,
        )
        total = sizes.sum()
        avg_val_loss = (losses * sizes).sum() / total
        avg_val_acc = (accs * sizes).sum() / total
        # Returned dicts are ignored by Lightning >= 1.0; logging is required.
        self.log('val_loss', avg_val_loss)
        self.log('avg_val_acc', avg_val_acc, prog_bar=True)

In [7]:
# Create the classifier and its MNIST data module, then fit for five epochs.
model = ImageClassifier()
mnist_dm = MNISTDataModule()
trainer = pl.Trainer(
    max_epochs=5,
    progress_bar_refresh_rate=20,  # progress bar update interval, in steps
)
trainer.fit(model, datamodule=mnist_dm)

In [8]:
# Show the trained module's structure: the notebook displays the repr of the
# cell's last expression.
model