In [11]:
import os
import torch
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning as L
import numpy as np
import torchmetrics

import torch.utils.data as data
from torch.utils.data import random_split, DataLoader
import torchvision.transforms as transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [12]:
class MNISTDataModule(L.LightningDataModule):
    def __init__(self, root_dir: str, batch_size: int = 32):
        super().__init__()
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

    def prepare_data(self):
        # Downloads only, do not transform
        MNIST(self.root_dir, train=True, download=False)
        MNIST(self.root_dir, train=False, download=False)

    def setup(self, stage=None):
        # Transform and split datasets
        if stage == 'fit' or stage is None:
            mnist_train = MNIST(self.root_dir, train=True, transform=self.transform)
            train_set_size = int(len(mnist_train) * 0.8)
            valid_set_size = len(mnist_train) - train_set_size
            self.train_set, self.valid_set = random_split(mnist_train, [train_set_size, valid_set_size])
        if stage == 'test' or stage is None:
            self.test_set = MNIST(self.root_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.valid_set, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size)

In [19]:
class LitClassifierModel(L.LightningModule):
    def __init__(self, hidden_dim: int = 64, learning_rate=2e-4):
        super().__init__()
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.l1 = nn.Linear(28 * 28, self.hparams.hidden_dim)
        self.l2 = nn.Linear(self.hparams.hidden_dim, 10)
        
        self.accuracy = torchmetrics.Accuracy(num_classes=10, average='macro', task='multiclass')
    
    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        return x
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = self.accuracy(logits, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)
        return {'loss': loss, 'acc': acc}
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = self.accuracy(logits, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return {'val_loss': loss, 'val_acc': acc}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)


In [20]:
dm = MNISTDataModule(root_dir="/data2/eranario/data/MNIST-Dataset/lightning")

model = LitClassifierModel()
trainer = L.Trainer(max_epochs=50,
                    default_root_dir="/data2/eranario/intermediate_data/MNIST_logs/classifier")

trainer.fit(model, dm)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type               | Params
------------------------------------------------
0 | l1       | Linear             | 50.2 K
1 | l2       | Linear             | 650   
2 | accuracy | MulticlassAccuracy | 0     
------------------------------------------------
50.9 K    Trainable params
0         Non-trainable params
50.9 K    Total params
0.204     Total estimated model params size (MB)


Epoch 0: 100%|██████████| 1500/1500 [00:17<00:00, 83.88it/s, v_num=1]       
Validation: |          | 0/? [00:00<?, ?it/s]
Validation:   0%|          | 0/375 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/375 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 1/375 [00:00<00:02, 143.87it/s]
Validation DataLoader 0:   1%|          | 2/375 [00:00<00:02, 128.67it/s]
Validation DataLoader 0:   1%|          | 3/375 [00:00<00:02, 124.67it/s]
Validation DataLoader 0:   1%|          | 4/375 [00:00<00:03, 122.58it/s]
Validation DataLoader 0:   1%|▏         | 5/375 [00:00<00:03, 121.42it/s]
Validation DataLoader 0:   2%|▏         | 6/375 [00:00<00:03, 120.70it/s]
Validation DataLoader 0:   2%|▏         | 7/375 [00:00<00:03, 120.31it/s]
Validation DataLoader 0:   2%|▏         | 8/375 [00:00<00:03, 120.02it/s]
Validation DataLoader 0:   2%|▏         | 9/375 [00:00<00:03, 119.69it/s]
Validation DataLoader 0:   3%|▎         | 10/375 [00:00<00:03, 119.55it/s]
Validation DataLoad

/home/eranario/miniconda3/envs/lightning/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
