In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import os

import torch
import pytorch_lightning as pl 

from torch import nn
from torch.nn import functional as F 
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.datasets import MNIST
from pytorch_lightning.metrics.functional import accuracy

In [2]:
class Config(dict):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        for k, v in kwargs.items():
            setattr(self, k, v)

    def set(self, key, val):
        self[key] = val
        setattr(self, key, val)

In [8]:
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, conf):
        super().__init__()
        self.data_path = conf.data_path
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        self.dims = conf.dims 
        self.num_classes = conf.num_classes
    
    def prepare_data(self):
        MNIST(self.data_path, train=True, download=True)
        MNIST(self.data_path, train=False, download=True)
    
    def setup(self, stage):
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_path, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_path, train=False, transform=self.transform)
    
    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32)
    
    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)
    
    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)

In [4]:
class LitMNIST(pl.LightningModule):
    def __init__(self, conf):
        super().__init__()

        self.conf = conf 
        channels, width, height = conf.dims
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels*width*height, conf.hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(conf.hidden_size, conf.hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(conf.hidden_size, conf.num_classes)
        )
    
    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)
    
    def training_step(self, batch, batch_idx):
        x, y = batch 
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss 
    
    def validation_step(self, batch, batch_idx):
        x, y = batch 
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
    
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss 
    
    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.conf.learning_rate)
        return optimizer

In [5]:
conf = Config(
    data_path=r'data/',
    dims=(1, 28, 28),
    num_classes=10,
    hidden_size=64,
    learning_rate=2e-4,
)
conf

{'data_path': 'data/',
 'dims': (1, 28, 28),
 'num_classes': 10,
 'hidden_size': 64,
 'learning_rate': 0.0002}

In [9]:
dm = MNISTDataModule(conf)
model = LitMNIST(conf)
trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores


In [10]:
trainer.fit(model, dm)


  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 55 K  
Epoch 0:  91%|█████████ | 1700/1876 [00:11<00:01, 148.95it/s, loss=0.277, v_num=10]
Validating: 0it [00:00, ?it/s][A
Epoch 0:  92%|█████████▏| 1720/1876 [00:11<00:01, 147.80it/s, loss=0.277, v_num=10]
Epoch 0:  93%|█████████▎| 1740/1876 [00:11<00:00, 148.15it/s, loss=0.277, v_num=10]
Epoch 0:  94%|█████████▍| 1760/1876 [00:11<00:00, 148.44it/s, loss=0.277, v_num=10]
Epoch 0:  95%|█████████▍| 1780/1876 [00:11<00:00, 148.77it/s, loss=0.277, v_num=10]
Epoch 0:  96%|█████████▌| 1800/1876 [00:12<00:00, 149.17it/s, loss=0.277, v_num=10]
Epoch 0:  97%|█████████▋| 1820/1876 [00:12<00:00, 149.57it/s, loss=0.277, v_num=10]
Epoch 0:  98%|█████████▊| 1840/1876 [00:12<00:00, 148.65it/s, loss=0.365, v_num=10, val_loss=0.26, val_acc=0.923]
Epoch 1:  91%|█████████ | 1700/1876 [00:10<00:01, 155.88it/s, loss=0.214, v_num=10, val_loss=0.26, val_acc=0.923]
Validating: 0it [00:00, ?it/s][A
Epoch 1:  92%

1

In [11]:
trainer.test()

Testing:  89%|████████▉ | 280/313 [00:01<00:00, 202.73it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'val_acc': tensor(0.9559), 'val_loss': tensor(0.1470)}
--------------------------------------------------------------------------------
Testing:  96%|█████████▌| 300/313 [00:01<00:00, 193.84it/s]


[{'val_loss': 0.14701983332633972, 'val_acc': 0.9559000134468079}]

In [10]:
%load_ext tensorboard
%tensorboard --logdir lightning_logs/