In [21]:
# from https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/01-mnist-hello-world.ipynb


In [22]:
import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.loggers import TensorBoardLogger

In [23]:
from argparse import ArgumentParser
import __main__ as main

parser = ArgumentParser()
parser.add_argument('--epochs', default=2, type=int)
parser.add_argument('--gpu', default=0, type=int)
parser.add_argument('--lr', default=0.02, type=float)

parser.add_argument('--hidden_size', default=64, type=int)
parser.add_argument('--batch_size', default=64, type=int)
parser.add_argument('--data_dir', default=f"{os.getcwd()}/mnist", type=str)  # should be datastore
parser.add_argument('--tb_dir', default="TB", type=str)
parser.add_argument('--tb_name', default="MNIST/ex_02", type=str)

if hasattr(main, '__file__'): 
    args = parser.parse_args()      
else:
    args = parser.parse_args("")    # take defaults in Jupyter 

In [24]:
class MNISTModel(pl.LightningModule):
    
    def __init__(self, data_dir=args.data_dir, hidden_size=args.hidden_size, learning_rate=args.lr):

        super().__init__()

        # Set our init args as class attributes
        self.data_dir = data_dir
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, width, height = self.dims
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

        # Define PyTorch model
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, self.num_classes)
        )

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):

        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)

In [25]:
# Init our model
mnist_model = MNISTModel()

# Initialize a trainer
trainer = pl.Trainer(gpus=args.gpu, max_epochs=args.epochs, progress_bar_refresh_rate=20)

# Train the model ⚡
trainer.fit(mnist_model)

GPU available: False, used: False
TPU available: None, using: 0 TPU cores

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 55.1 K
-------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params
0.220     Total estimated model params size (MB)
Epoch 0:  92%|█████████▏| 1720/1876 [00:15<00:01, 113.00it/s, loss=0.647, v_num=1, val_loss=2.320, val_acc=0.141]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/157 [00:00<?, ?it/s][A
Epoch 0:  93%|█████████▎| 1740/1876 [00:15<00:01, 113.19it/s, loss=0.647, v_num=1, val_loss=2.320, val_acc=0.141]
Epoch 0:  94%|█████████▍| 1760/1876 [00:15<00:01, 112.69it/s, loss=0.647, v_num=1, val_loss=2.320, val_acc=0.141]
Epoch 0:  95%|█████████▍| 1780/1876 [00:15<00:00, 112.33it/s, loss=0.647, v_num=1, val_loss=2.320, val_acc=0.141]
Epoch 0:  96%|█████████▌| 1800/1876 [00:16<00:00, 112.15it/s, loss=0.647, v_num=1, val_loss=2.320, val_acc=0.1

1