## Notebook adapted to run in cloud environments such as Google Colab and Kaggle.

In [None]:
!pip install -q lightning wandb git+https://github.com/sampath017/extentions.git pillow

In [None]:
import torch
import os
import torch.nn.functional as F
import lightning.pytorch as pl
import numpy as np
import matplotlib.pyplot as plt
import os
import logging
import wandb
import torch

from collections import Counter

from torchvision.transforms import Compose, ToTensor
from torchvision.datasets import MNIST
from torchvision.utils import make_grid


from torchmetrics.functional import accuracy
from lightning.pytorch import LightningModule
from torch import nn

from pathlib import Path
from PIL import Image
from torch.utils.data import random_split, DataLoader, Dataset
from torchvision import datasets, transforms


from pathlib import Path
from lightning.pytorch import (
    callbacks,
    loggers,
    Trainer,
    utilities
)

from model import Digits
from data_module import MNISTDataModule
from extentions.callbacks import DiffEarlyStopping, EarlyStopping

In [None]:
# Load the MNIST training split into ../data (downloaded on first use).
# No transform is attached here, so samples come back as PIL images.
root_path = Path('../')
dataset = MNIST(
    (root_path / 'data').as_posix(),
    download=True,
    train=True,
)
dataset  # notebook display of the dataset summary

In [None]:
## Single data point
# Index the first (image, label) pair directly instead of looping and breaking.
img, label = dataset[0]

# PIL's getextrema() reports the minimum/maximum pixel values of the image.
print(img, img.getextrema())
plt.imshow(img, cmap='gray')
plt.axis(False)
plt.title(label)

plt.show()

In [None]:
## A grid of images from each class

dataset.transform = Compose([
    ToTensor()
])

num_classes = 10
images_per_class = 10
# make_grid's default padding, passed explicitly so the label offsets below stay in sync.
grid_padding = 2

# Create a list to store the images for each class
class_images = [[] for _ in range(num_classes)]

# Collect the first `images_per_class` images of every class; stop once every row is full.
for image, label in dataset:
    if len(class_images[label]) < images_per_class:
        class_images[label].append(image)
    if all(len(images) == images_per_class for images in class_images):
        break

# One grid row per class, `images_per_class` images per row.
grid = make_grid(
    [img for images in class_images for img in images],
    nrow=images_per_class,
    padding=grid_padding,
)
plt.imshow(grid.permute(1, 2, 0))

# make_grid places row i at y = grid_padding + i * (height + grid_padding), so derive the
# label position from the tile size instead of a hand-tuned stride (31 drifted ~1 px per row).
img_h = image.shape[-2]
row_stride = img_h + grid_padding
for i in range(num_classes):
    plt.text(-20, grid_padding + i * row_stride + img_h / 2, str(i), fontsize=16, va='center')

plt.axis(False)
plt.show()

In [None]:
## Random images

dataloader = DataLoader(dataset, batch_size=7, shuffle=True)

# A single random batch is enough for the preview.
images, labels = next(iter(dataloader))

# Put the whole batch in one row. make_grid places tile i at
# x = grid_padding + i * (width + grid_padding).
grid_padding = 2
grid = make_grid(images, nrow=len(images), padding=grid_padding)
plt.imshow(grid.permute(1, 2, 0))

# Centre each label horizontally above its own tile (the old `i * 32 + 8` drifted off).
img_w = images.shape[-1]
for i, label in enumerate(labels):
    x_center = grid_padding + i * (img_w + grid_padding) + img_w / 2
    plt.text(x_center, -4, str(label.item()), fontsize=16, ha='center')

plt.axis(False)
plt.show()

In [None]:
## Class distributions

# Read the labels from the dataset's stored targets. Iterating the dataset would load and
# ToTensor-transform every image just to discard it.
labels = dataset.targets.tolist()

# Count the number of occurrences of each label
label_counts = Counter(labels)
label_counts_array = np.array(list(label_counts.values()))

In [None]:
# Mean and standard deviation of the per-class sample counts.
print("Mean: {}, STD: {}".format(label_counts_array.mean(), label_counts_array.std()))

# Bar chart: one bar per class, height = number of samples of that class.
plt.bar(list(label_counts.keys()), list(label_counts.values()))
plt.title('Class Distribution')
plt.xlabel('Class')
plt.ylabel('Count')
plt.xticks(range(10))
plt.show()

In [None]:
def _conv_block(in_channels, out_channels):
    """Return the layers of one feature-extractor stage.

    Two 3x3 convolutions (the first keeps the spatial size via ``padding='same'``,
    the second is unpadded), each followed by ReLU, then 2x2 max-pooling and 25% dropout.
    """
    return [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding='same'),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=3),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),
        nn.Dropout(p=0.25),
    ]


# CNN classifier. For a 1x28x28 input the three stages shrink the feature map
# 28 -> 13 -> 5 -> 1, so the flattened feature vector has 128 values.
model_cnn = nn.Sequential(
    # Feature extractor
    *_conv_block(1, 32),
    *_conv_block(32, 64),
    *_conv_block(64, 128),

    # Classifier head
    nn.Flatten(),
    nn.Linear(128, 256),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(256, 10),
)

# Fully-connected baseline operating on the flattened 28*28 image.
model_ffn = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 10),
)


class Digits(LightningModule):
    """Lightning wrapper around the MNIST classifier defined in the model cell.

    Args:
        optimizer_name: ``'SGD'`` or ``'Adam'``; selects the optimizer built in
            ``configure_optimizers``.
        optimizer_hparams: keyword arguments forwarded to that optimizer's constructor
            (e.g. ``{'lr': 1e-3}``). ``None`` (the default) means no extra arguments.

    Raises:
        ValueError: from ``configure_optimizers`` when ``optimizer_name`` is unsupported.
    """

    def __init__(self, optimizer_name, optimizer_hparams=None):
        super().__init__()
        # Stores optimizer_name / optimizer_hparams in self.hparams (read by configure_optimizers).
        self.save_hyperparameters()

        import copy

        # `model_cnn` is a single module-level object. Assigning it directly would make every
        # Digits instance (a new training run, or one built by load_from_checkpoint) share and
        # overwrite the same parameters, so each instance gets its own re-initialised copy.
        # Swap in `model_ffn` here to train the fully-connected baseline instead.
        self.model = copy.deepcopy(model_cnn)
        for layer in self.model.modules():
            if hasattr(layer, 'reset_parameters'):
                layer.reset_parameters()

    def forward(self, X):
        return self.model(X)

    def _shared_step(self, batch, stage, **log_kwargs):
        """Compute and log the cross-entropy loss and accuracy (in %) for one ``(X, y)`` batch.

        Metrics are logged as ``f"{stage}_loss"`` and ``f"{stage}_acc"``. ``log_kwargs``
        are passed through to ``self.log``. Returns the loss tensor.
        """
        X, y = batch
        logits = self(X)
        loss = F.cross_entropy(logits, y)
        acc = accuracy(logits, y, task='multiclass', num_classes=10)

        self.log(f"{stage}_loss", loss, **log_kwargs)
        self.log(f"{stage}_acc", acc * 100.0, **log_kwargs)

        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train", on_step=True, on_epoch=True)

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val", on_step=True, on_epoch=True)

    def test_step(self, batch, batch_idx):
        # No on_step/on_epoch override: keep Lightning's defaults for test metrics.
        self._shared_step(batch, "test")

    def predict_step(self, batch, batch_idx):
        # Predict batches hold images only (PredictMnist yields no labels).
        logits = self(batch)
        # Explicit class dimension; calling softmax without `dim` is deprecated.
        return F.softmax(logits, dim=1)

    def configure_optimizers(self):
        optimizer_classes = {'SGD': torch.optim.SGD, 'Adam': torch.optim.Adam}
        name = self.hparams.optimizer_name  # type: ignore
        if name not in optimizer_classes:
            # A real exception rather than `assert`, which is stripped under `python -O`.
            raise ValueError(f'Unknown optimizer: "{name}"')
        # `or {}` handles the None default (and older checkpoints that stored `{}`).
        optimizer_hparams = self.hparams.optimizer_hparams or {}  # type: ignore
        return optimizer_classes[name](self.parameters(), **optimizer_hparams)

In [None]:
class PredictMnist(Dataset):
    """Unlabelled dataset of image files used for ``Trainer.predict``.

    Every regular file directly inside ``data_dir`` is treated as one image. Files are
    visited in sorted order, so prediction outputs line up with a stable file order.

    Args:
        data_dir: directory containing the images (``str`` or ``Path``).
        transform: optional callable applied to each grayscale PIL image.
    """

    def __init__(self, data_dir=Path("."), transform=None):
        self.data_dir = Path(data_dir)
        self.transform = transform
        # Sorted for a deterministic order (iterdir() order is filesystem-dependent).
        # Sub-directories are skipped because they cannot be opened as images.
        self.data = sorted(p for p in self.data_dir.iterdir() if p.is_file())

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Close the file right away; convert() returns an independent, fully loaded image.
        with Image.open(self.data[index]) as img:
            img = img.convert(mode="L")

        if self.transform is not None:
            img = self.transform(img)
        return img


class MNISTDataModule(pl.LightningDataModule):
    """MNIST train/validation/test splits plus a folder of images for prediction.

    Args:
        data_dir: directory torchvision downloads MNIST into (``str`` or ``Path``).
            Prediction images are read from ``<data_dir>/predict``.
        batch_size: batch size used by every dataloader.
        split_seed: seed for the 70/30 train/validation split of the MNIST training set.
            A fixed seed keeps the split identical across ``setup`` calls (fit, validate)
            and notebook re-runs, so validation images never end up in the training
            subset. Pass ``None`` to draw a fresh split from the global RNG each time.
    """

    def __init__(self, data_dir, batch_size=64, split_seed=42):
        super().__init__()
        self.data_dir = data_dir
        self.transforms = transforms.Compose([
            transforms.ToTensor(),
        ])
        self.batch_size = batch_size
        self.split_seed = split_seed
        # os.cpu_count() may return None; fall back to 8 dataloader workers in that case.
        self.cpu_count = os.cpu_count() or 8

    def prepare_data(self):
        # Download both splits (one-time process)
        datasets.MNIST(self.data_dir, train=True, download=True)
        datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str):
        # "validate" needs the same train/val split as "fit"; without it trainer.validate()
        # would fail because valid_dataset was never created.
        if stage in ("fit", "validate"):
            dataset = datasets.MNIST(
                self.data_dir,
                train=True,
                transform=self.transforms
            )
            train_size = int(0.7 * len(dataset))
            valid_size = len(dataset) - train_size

            split_kwargs = {}
            if self.split_seed is not None:
                split_kwargs["generator"] = torch.Generator().manual_seed(self.split_seed)
            self.train_dataset, self.valid_dataset = random_split(
                dataset,
                [train_size, valid_size],
                **split_kwargs
            )

        # Assign test dataset for use in dataloaders
        if stage == "test":
            self.test_dataset = datasets.MNIST(
                self.data_dir,
                train=False,
                transform=self.transforms
            )

        if stage == "predict":
            # data_dir is typically a str (the notebook passes Path.as_posix()), and
            # `str / str` raises TypeError, so build the path with Path.
            self.predict_dataset = PredictMnist(
                Path(self.data_dir) / "predict", transform=self.transforms)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.cpu_count)

    def val_dataloader(self):
        return DataLoader(self.valid_dataset, batch_size=self.batch_size, num_workers=self.cpu_count)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.cpu_count)

    def predict_dataloader(self):
        return DataLoader(self.predict_dataset, batch_size=self.batch_size, num_workers=self.cpu_count)


## Training

In [None]:
# Notebook-wide setup: set WANDB_NOTEBOOK_NAME, raise Lightning's logger to INFO,
# and build the MNIST datamodule rooted at ../data.
os.environ['WANDB_NOTEBOOK_NAME'] = "train.ipynb"
logging.getLogger("lightning.pytorch").setLevel(logging.INFO)

root_path = Path('../')
dm = MNISTDataModule(data_dir=(root_path / 'data').as_posix())

In [None]:
# Checkpointing: keep the 3 checkpoints with the lowest val_loss, with file names such as
# "epoch=7-loss=0.031" (metric names are spelled out in the template itself).
checkpoint_callback = callbacks.ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=3,
    filename="epoch={epoch}-loss={val_loss:.3f}",
    auto_insert_metric_name=False,
)

# Adam with its default settings: no optimizer_hparams are passed
# (supply e.g. optimizer_hparams={'lr': 0.001} to override them).
model = Digits(optimizer_name='Adam')

# Both early-stopping callbacks come from the `extentions` package.
earlystopping_callbacks = [
    # NOTE(review): presumably stops once val_loss - train_loss stays above diff_threshold
    # for `patience` epochs (e.g. val_loss=0.09 vs train_loss=0.04) — confirm in extentions.
    DiffEarlyStopping(
        monitor1="val_loss",
        monitor2="train_loss",
        diff_threshold=0.05,
        patience=5,
        verbose=True,
    ),
    # NOTE(review): arguments mirror Lightning's EarlyStopping — presumably stops when val_acc
    # stalls for 5 epochs or reaches 99.99 (%); confirm in extentions.
    EarlyStopping(
        monitor="val_acc",
        mode='max',
        min_delta=0.0,
        patience=5,
        stopping_threshold=99.99,
        verbose=True,
    ),
]

In [None]:
# Table of the LightningModule's top-level submodules and parameter counts.
utilities.model_summary.ModelSummary(model=model, max_depth=1)

In [None]:
# W&B logger writing its local files under ../logs; log_model='all' logs every
# checkpoint to W&B.
log_dir = root_path / 'logs'
log_dir.mkdir(exist_ok=True)
logger = loggers.WandbLogger(project='Digits', save_dir=log_dir, log_model='all')

# Wall-clock training budget: 20 minutes with a GPU, 2 hours on CPU only.
if torch.cuda.is_available():
    max_time = {'minutes': 20}
else:
    max_time = {'hours': 2}

trainer = Trainer(
    min_epochs=10,
    max_epochs=50,
    max_time=max_time,
    log_every_n_steps=1,
    logger=logger,
    callbacks=[checkpoint_callback, *earlystopping_callbacks],  # type: ignore
    enable_model_summary=False,
)

In [None]:
trainer.fit(model, datamodule=dm)
# The W&B run is deliberately left open: the Evaluate section passes the same `logger`
# to its test Trainer, and its last cell calls wandb.finish(). Finishing the run here
# would send the test metrics to an already-closed run.

In [None]:
# Display the path of the best checkpoint kept by ModelCheckpoint (lowest val_loss).
checkpoint_callback.best_model_path

## Evaluate

In [None]:
from pathlib import Path

import wandb
import torch
from lightning.pytorch import loggers, utilities, Trainer

from model import Digits
from data_module import MNISTDataModule

In [None]:
# `model_path` is not defined by any earlier cell; evaluate the best checkpoint
# (lowest val_loss) that `checkpoint_callback` recorded during training.
model_path = checkpoint_callback.best_model_path
if not model_path:
    raise RuntimeError("No checkpoint was saved; run the training cells before evaluating.")
model = Digits.load_from_checkpoint(model_path, map_location=torch.device('cpu'))
utilities.model_summary.ModelSummary(model)

In [None]:
# Evaluation-only Trainer (inference mode) that reports to the same W&B logger.
trainer = Trainer(logger=logger, inference_mode=True, enable_model_summary=False)

In [None]:
# Run the test loop; the trailing `;` suppresses the notebook display of the returned results.
trainer.test(model=model, datamodule=dm);

In [None]:
# Mark the active W&B run as finished now that training and testing are done.
wandb.finish()