# A real NN: training on MNIST data
We'll see what overfitting means. On a trivial dataset, which we generate, we will have the neural network fit noise. Early stopping can solve this problem. In this exercise, we use PyTorch with Lightning to set up the neural network.

(c) Patrick van der Smagt, March 2023, building heavily on the Lightning documentation. Please do not distribute this without Patrick's consent (he will give it).

In [None]:
import os
# Directory where torchvision keeps (and downloads) MNIST. It can be
# overridden through the PATH_DATASETS environment variable and defaults to
# the current working directory.
PATH_DATASETS = os.getenv("PATH_DATASETS", ".")

# Mini-batch size used by the train, validation and test dataloaders.
BATCH_SIZE = 64

# Set here, before torch is imported in the next cell. Presumably this lets
# operations missing from Apple's MPS backend fall back to the CPU
# (TODO: confirm against the PyTorch MPS notes).
os.environ.update(PYTORCH_ENABLE_MPS_FALLBACK="1")

In [None]:
import torch
from torch import optim, nn
from torch.utils.data import DataLoader, random_split
import lightning.pytorch as pl
from lightning.pytorch.callbacks import TQDMProgressBar, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from torchvision import transforms
from torchvision.datasets import MNIST
from torchmetrics import Accuracy

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
class MNISTModel(pl.LightningModule):
    """Small fully connected MNIST classifier that also owns its data pipeline.

    The network is Flatten -> Linear -> ReLU -> Dropout -> Linear -> ReLU ->
    Dropout -> Linear. ``forward`` returns log-probabilities over the 10
    digit classes. ``hidden`` and ``dropout`` are the knobs to turn when
    studying over- and under-fitting.

    Args:
        data_dir: Directory in which torchvision stores (and downloads) MNIST.
        hidden: Width of both hidden layers (default 4).
        dropout: Dropout probability after each hidden layer (default 0.0,
            i.e. dropout has no effect).
        split_seed: Seed for the 55 000 / 5 000 train/validation split, so
            repeated ``fit``/``validate`` calls see the same validation set.
            Pass ``None`` for an unseeded split.
    """

    def __init__(self, data_dir=PATH_DATASETS, hidden=4, dropout=0.0, split_seed=42):
        super().__init__()
        self.hidden = hidden
        self.dropout = dropout
        self.split_seed = split_seed
        self.data_dir = data_dir

        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.dims = (1, 28, 28)  # (channels, height, width) of one image
        channels, height, width = self.dims
        # Tensor conversion followed by standardisation. 0.1307 / 0.3081 are
        # presumably the MNIST training-set mean / std.
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        # Define PyTorch model
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * height * width, self.hidden),
            nn.ReLU(),
            nn.Dropout(self.dropout),
            nn.Linear(self.hidden, self.hidden),
            nn.ReLU(),
            nn.Dropout(self.dropout),
            nn.Linear(self.hidden, self.num_classes),
        )

        # Separate metric objects so validation and test state never mix.
        self.val_accuracy = Accuracy(task="multiclass", num_classes=self.num_classes)
        self.test_accuracy = Accuracy(task="multiclass", num_classes=self.num_classes)

    def forward(self, x):
        """Return log-probabilities over the classes (log_softmax over dim 1)."""
        return nn.functional.log_softmax(self.model(x), dim=1)

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the NLL loss of one batch."""
        x, y = batch
        log_probs = self(x)
        # forward() already applies log_softmax, so NLL here is cross-entropy.
        loss = nn.functional.nll_loss(log_probs, y)
        # Logged so the training curve can be compared with val_loss in
        # TensorBoard; their divergence is what overfitting looks like.
        self.log("train_loss", loss)
        return loss

    def _evaluate(self, batch, metric, prefix):
        """Shared validation/test step: log ``<prefix>_loss`` and ``<prefix>_acc``."""
        x, y = batch
        log_probs = self(x)
        loss = nn.functional.nll_loss(log_probs, y)
        preds = torch.argmax(log_probs, dim=1)
        metric.update(preds, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log(f"{prefix}_loss", loss, prog_bar=True)
        self.log(f"{prefix}_acc", metric, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self._evaluate(batch, self.val_accuracy, "val")

    def test_step(self, batch, batch_idx):
        self._evaluate(batch, self.test_accuracy, "test")

    def configure_optimizers(self):
        """Adam with PyTorch's default hyper-parameters over all parameters."""
        return torch.optim.Adam(self.parameters())

    def prepare_data(self):
        """Download the MNIST training and test sets; no state is assigned here."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the dataset splits that ``stage`` needs.

        "fit" and "validate" need the train/validation split of the MNIST
        training set. "test" needs the MNIST test set. ``None`` builds both.
        """
        # "validate" must be handled too: trainer.validate() needs mnist_val.
        if stage in ("fit", "validate", None):
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            split_kwargs = {}
            if self.split_seed is not None:
                # A fixed seed keeps the split identical across setup() calls.
                # An unseeded re-split would move already-trained images into
                # the validation set.
                split_kwargs["generator"] = torch.Generator().manual_seed(self.split_seed)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000], **split_kwargs
            )

        if stage in ("test", None):
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        # Reshuffle every epoch so batches are not presented in a fixed order.
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)


mnist_model = MNISTModel()

In [None]:
# TensorBoard logs are written under ./tensorlogs/mnist.
logger = TensorBoardLogger("tensorlogs", name="mnist")
trainer = pl.Trainer(
    accelerator="auto",
    # One device keeps notebook (iPython) runs simple. Lightning 2.x types
    # `devices` as int / list / "auto", so the old `None` is rejected on CPU.
    # An explicit 1 behaves the same on GPU, MPS and CPU.
    devices=1,
    max_epochs=1,
    callbacks=[TQDMProgressBar(refresh_rate=20)],
    logger=logger,
)

In [None]:
# Run the training loop for mnist_model with the trainer configured above.
trainer.fit(model=mnist_model)

In [None]:
# Show the first six images of the training split together with their labels.
fig = plt.figure()
for i in range(6):
    # Index the training Subset itself. Its `.dataset` is the full MNIST
    # training set, whose first images may have landed in the validation split.
    image, label = mnist_model.mnist_train[i]
    plt.subplot(2, 3, i + 1)
    plt.imshow(image[0], cmap='gray', interpolation='none')  # image[0]: the single channel
    plt.title("Ground Truth: {}".format(label))
    plt.xticks([])
    plt.yticks([])
# Lay out once, after every subplot and title exists.
plt.tight_layout()

In [None]:
# Evaluate on the MNIST test set using the best checkpoint of the previous fit.
# This is what a bare `trainer.test()` falls back to after `fit`, but stating
# it explicitly documents which weights are tested and silences Lightning's
# warning about the implicit choice.
trainer.test(ckpt_path="best")

Print a specific data point and its corresponding label.

Create a Lightning trainer and start learning.

In [None]:
# Switch to inference mode (this affects the Dropout layers). Writing `eval`
# without parentheses would only look up the method and do nothing.
mnist_model.eval()

fig = plt.figure()
with torch.no_grad():  # inference only: no autograd graph needed
    for i in range(12):
        # Index the validation Subset itself. `mnist_val.dataset` is the full
        # MNIST training set, so it would mostly show images the model trained on.
        image, _ = mnist_model.mnist_val[i + 100]
        # Add a batch dimension and move the input to the model's device.
        log_probs = mnist_model(image.unsqueeze(0).to(mnist_model.device))
        pred = int(log_probs.argmax(dim=1))
        plt.subplot(3, 4, i + 1)
        plt.imshow(image[0], cmap='gray', interpolation='none')
        plt.title("prediction: {}".format(pred))
        plt.xticks([])
        plt.yticks([])
# Lay out once, after every subplot and title exists.
plt.tight_layout()