In [None]:
import os

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split

In [None]:
from torchvision.datasets import MNIST
from torchvision import datasets, transforms

In [None]:
import pytorch_lightning as pl

In [None]:
from pytorch_lightning import Trainer

In [None]:
# Preprocessing pipeline: turn each image into a tensor, then normalise it
# with the mean / std pair (0.1307, 0.3081) that is standard for MNIST.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Training split rooted at the current working directory (download=True asks
# torchvision to fetch the files), served in batches of 64.
mnist_train = MNIST(
    root=os.getcwd(),
    train=True,
    download=True,
    transform=transform,
)
mnist_train_loader = DataLoader(dataset=mnist_train, batch_size=64)

In [None]:
# build your model
class CustomMNIST(pl.LightningModule):
    """Three-layer fully connected classifier for MNIST digits.

    Each sample is flattened to a 784-element vector and passed through
    Linear(784 -> 128) -> ReLU -> Linear(128 -> 256) -> ReLU -> Linear(256 -> 10);
    the output is the log-softmax over the 10 classes.
    """

    def __init__(self):
        super().__init__()
        # mnist images are (1, 28, 28) (channels, height, width)
        self.layer1 = torch.nn.Linear(28 * 28, 128)
        self.layer2 = torch.nn.Linear(128, 256)
        self.layer3 = torch.nn.Linear(256, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images.

        Args:
            x: Tensor whose first dimension is the batch and whose remaining
                dimensions hold 28 * 28 = 784 values per sample, e.g.
                (b, 1, 28, 28) or an already-flattened (b, 784).

        Returns:
            Tensor of shape (b, 10) holding log-softmax outputs.
        """
        # (b, 1, 28, 28) -> (b, 1*28*28). Only the batch dimension is read, so
        # the input is not required to be exactly 4-D.
        x = x.view(x.size(0), -1)

        x = self.layer1(x)
        x = torch.relu(x)

        x = self.layer2(x)
        x = torch.relu(x)

        x = self.layer3(x)
        x = torch.log_softmax(x, dim=1)

        return x

    def training_step(self, batch, batch_idx):
        """Compute the negative log-likelihood loss for one training batch.

        Args:
            batch: ``(data, target)`` pair as yielded by the training loader.
            batch_idx: Index of the batch (unused here).

        Returns:
            Dict carrying the loss tensor under the ``'loss'`` key.
        """
        data, target = batch
        # Call the module itself rather than ``forward`` so that any hooks
        # registered on the nn.Module are executed.
        log_probs = self(data)
        # forward() already applies log_softmax, so nll_loss is the matching criterion.
        loss = F.nll_loss(log_probs, target)
        return {'loss': loss}

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr=1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

In [8]:
# Create the classifier and a trainer limited to five epochs with no GPUs
# requested (gpus=0), then fit on the MNIST training batches.
model = CustomMNIST()
trainer = Trainer(gpus=0, max_epochs=5)
trainer.fit(model, mnist_train_loader)

Epoch 4: 100%|██████████| 938/938 [00:36<00:00, 25.46it/s, loss=0.0348, v_num=0]


Epoch 1:  45%|████▌     | 423/938 [00:10<00:12, 41.42it/s, loss=0.125, v_num=2]


1