In [1]:
import torch
from lightning_datasets import MnistDataModule
import pytorch_lightning as L
from torch import nn
from torchmetrics import Accuracy
from pytorch_lightning.loggers import MLFlowLogger
import matplotlib.pyplot as plt

In [2]:
# Request the 'highest' float32 matmul precision setting. Lightning's startup
# message (see the training cell output) suggests 'medium' or 'high' on Tensor
# Core GPUs to trade precision for speed; 'highest' opts out of that trade-off.
torch.set_float32_matmul_precision('highest')

In [3]:
class NIN(L.LightningModule):
    """Network-in-Network (NiN) image classifier wrapped as a LightningModule.

    The network stacks four NiN blocks (one spatial convolution followed by
    two 1x1 convolutions, each followed by ReLU), interleaved with 3x3 /
    stride-2 max pooling and a single dropout layer. The last block emits
    ``num_classes`` channels which are globally average-pooled and flattened,
    so ``forward`` returns logits of shape ``(batch, num_classes)``.
    The first block expects single-channel input images.

    Args:
        lr: Learning rate handed to the SGD optimizer.
        num_classes: Number of target classes; sets both the channel count of
            the final NiN block and the accuracy metric.
    """

    def __init__(self, lr=0.1, num_classes=10):
        super().__init__()
        self.net = nn.Sequential(
            NIN.NiN_block(1, 96, kernel_size=11, stride=4, padding=0),
            nn.MaxPool2d(3, stride=2),
            NIN.NiN_block(96, 256, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(3, stride=2),
            NIN.NiN_block(256, 384, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(3, stride=2),
            nn.Dropout(0.5),
            # One output channel per class; global average pooling below
            # turns each channel into that class's logit.
            NIN.NiN_block(384, num_classes, kernel_size=3, stride=1, padding=1),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten()
        )

        # Re-initialise every convolution (apply() recurses into the blocks).
        self.net.apply(NIN.__init_vars)
        self.lr = lr
        self.loss = nn.CrossEntropyLoss()
        # A single metric object is shared by all stages. Only the per-batch
        # value returned by calling it is used; the state it accumulates
        # internally is never read via compute().
        self.accuracy = Accuracy(task='multiclass', num_classes=num_classes)

    @staticmethod
    def __init_vars(layer):
        """Kaiming-uniform initialise the weights of ``nn.Conv2d`` layers.

        Any other layer type is left untouched.
        """
        if isinstance(layer, nn.Conv2d):
            # NOTE(review): `a` is the negative slope of a leaky ReLU, but the
            # blocks use plain ReLU -- consider nonlinearity='relu'. Value kept
            # as-is so initialisation matches the original notebook.
            nn.init.kaiming_uniform_(layer.weight, a=0.2)

    def forward(self, X):
        """Return the class logits produced by the network for batch ``X``."""
        return self.net(X)

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch and log loss/accuracy.

        Returns:
            The loss tensor that Lightning back-propagates.
        """
        loss, output, y = self.__common_step(batch, batch_idx)
        accuracy = self.accuracy(output, y)
        # self.log_dict lets Lightning attach the step, detach the tensors and
        # forward the values to whatever logger (if any) the Trainer has.
        self.log_dict({"training_loss": loss, "train_acc": accuracy})
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log loss/accuracy for one validation batch."""
        loss, output, y = self.__common_step(batch, batch_idx)
        accuracy = self.accuracy(output, y)
        self.log_dict(
            {"validation_loss": loss, "validation_acc": accuracy})
        return loss

    def test_step(self, batch, batch_idx):
        """Compute and log loss/accuracy for one test batch."""
        loss, output, y = self.__common_step(batch, batch_idx)
        accuracy = self.accuracy(output, y)
        self.log_dict({"test_loss": loss, "test_acc": accuracy})
        return loss

    def on_train_start(self):
        """Record the learning rate as a hyperparameter when a logger exists."""
        # The Trainer may have been created with logging disabled.
        if self.logger is not None:
            self.logger.log_hyperparams({"learning_rate": self.lr})

    def __common_step(self, batch, batch_idx):
        """Run the forward pass and loss shared by train/val/test steps.

        Args:
            batch: An ``(inputs, targets)`` pair.
            batch_idx: Index of the batch (unused).

        Returns:
            Tuple ``(loss, output, y)`` of the loss, the logits and the targets.
        """
        X, y = batch
        output = self(X)
        loss = self.loss(output, y)
        return loss, output, y

    def configure_optimizers(self):
        """Return a plain SGD optimizer over all parameters using ``self.lr``."""
        return torch.optim.SGD(self.parameters(), lr=self.lr)

    @staticmethod
    def NiN_block(in_channel, out_channel, kernel_size, stride, padding):
        """Build one NiN block.

        The block is a ``kernel_size`` convolution mapping ``in_channel`` to
        ``out_channel`` channels, followed by two 1x1 convolutions that keep
        ``out_channel`` channels; every convolution is followed by ReLU.

        Returns:
            An ``nn.Sequential`` holding the six layers.
        """
        return nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding),
            nn.ReLU(),
            nn.Conv2d(out_channel, out_channel, 1),
            nn.ReLU(),
            nn.Conv2d(out_channel, out_channel, 1),
            nn.ReLU()
        )

In [4]:
model = NIN()

# Sanity check: feed one random single-channel 224x224 image through the
# network stage by stage and print the tensor shape after every stage.
feature_map = torch.randn(1, 1, 224, 224)

for stage in model.net.children():
    feature_map = stage(feature_map)
    print(feature_map.shape)

torch.Size([1, 96, 54, 54])
torch.Size([1, 96, 26, 26])
torch.Size([1, 256, 26, 26])
torch.Size([1, 256, 12, 12])
torch.Size([1, 384, 12, 12])
torch.Size([1, 384, 5, 5])
torch.Size([1, 384, 5, 5])
torch.Size([1, 10, 5, 5])
torch.Size([1, 10, 1, 1])
torch.Size([1, 10])


In [5]:
# Fresh model, 224x224 MNIST batches of 128, and an MLflow run to log into.
model = NIN()
dataset = MnistDataModule(image_size=(224, 224), batch_size=128)
logger = MLFlowLogger(run_name="NIN mnist")

# Name the accelerator explicitly so the call does not depend on
# `accelerator` happening to be the Trainer's first parameter.
trainer = L.Trainer(accelerator='gpu', logger=logger, max_epochs=20)


Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [6]:
# Train `model` on the data supplied by `dataset`; the run ends once the
# trainer's configured max_epochs is reached.
trainer.fit(model, dataset)

You are using a CUDA device ('NVIDIA GeForce RTX 4070 Ti SUPER') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type               | Params | Mode 
--------------------------------------------------------
0 | net      | Sequential         | 2.0 M  | train
1 | loss     | CrossEntropyLoss   | 0      | train
2 | accuracy | MulticlassAccuracy | 0      | train
--------------------------------------------------------
2.0 M     Trainable params
0         Non-trainable params
2.0 M     Total params
7.969     Total estimated model params size (MB)
37        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.
