In [6]:
import torch
from torch import nn
import torch.nn.functional as F
import lightning as L
import torchmetrics
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, random_split
from torchvision.models import resnet18
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

class ImagenetteDataModule(L.LightningDataModule):
    def __init__(self, data_dir="data/imagenette", batch_size=128):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

        self.train_transforms = transforms.Compose([
            transforms.CenterCrop(160),
            transforms.Resize(64),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
        self.test_transforms = self.train_transforms

    def prepare_data(self):
        datasets.Imagenette(self.data_dir, split="train", size="160px", download=True)
        datasets.Imagenette(self.data_dir, split="val", size="160px", download=True)

    def setup(self, stage=None):
        full_train = datasets.Imagenette(self.data_dir, split="train", size="160px", transform=self.train_transforms)
        self.train_set, self.val_set = random_split(full_train, [int(0.9*len(full_train)), len(full_train)-int(0.9*len(full_train))])
        self.val_set.dataset.transform = self.test_transforms
        self.test_set = datasets.Imagenette(self.data_dir, split="val", size="160px", transform=self.test_transforms)

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size)

class BaselineModel(L.LightningModule):
    def __init__(self, num_classes=10):
        super().__init__()
        self.save_hyperparameters()
        self.estimator = nn.Sequential(
            nn.Linear(64 * 64, 1024), nn.ReLU(),
            nn.Linear(1024, 512), nn.ReLU(),
            nn.Linear(512, 128), nn.ReLU(),
            nn.Linear(128, num_classes)
        )
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        return self.estimator(x.view(x.size(0), -1))

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = self.accuracy(logits, y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = self.accuracy(logits, y)
        self.log("test_loss", loss)
        self.log("test_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

class SimpleCNN(L.LightningModule):
    def __init__(self, num_classes=10):
        super().__init__()
        self.save_hyperparameters()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 16 * 16, 256), nn.ReLU(),
            nn.Linear(256, num_classes)
        )
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        return self.fc(self.conv(x))

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = self.accuracy(logits, y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = self.accuracy(logits, y)
        self.log("test_loss", loss)
        self.log("test_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# Training script
datamodule = ImagenetteDataModule()
model = SimpleCNN(num_classes=10)  # Change to BaselineModel() or ResNet18Model() as needed

callbacks = [
    EarlyStopping(monitor="val_loss", patience=5, mode="min"),
    ModelCheckpoint(monitor="val_acc", mode="max", filename="best-model", save_top_k=1)
]

trainer = L.Trainer(max_epochs=50, callbacks=callbacks)
trainer.fit(model, datamodule=datamodule)
trainer.test(model, datamodule=datamodule)

INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: 
  | Name     | Type               | Params | Mode 
--------------------------------------------------------
0 | conv     | Sequential         | 18.8 K | train
1 | fc       | Sequential         | 4.2 M  | train
2 | accuracy | MulticlassAccuracy | 0      | train
--------------------------------------------------------
4.2 M     Trainable params
0         Non-trainable params
4.2 M     Total params
16.864    Total estimated model params size (MB)
13        Modules in train mode
0         Modules in eval mode
INFO:lightning.pytorch.callbacks.model_summary:
  | Name     | Type               | Params | Mode 
--------------

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_loss': 2.2128238677978516, 'test_acc': 0.5391082763671875}]

In [4]:
!pip install lightning torchmetrics

Collecting lightning
  Downloading lightning-2.5.2-py3-none-any.whl.metadata (38 kB)
Collecting torchmetrics
  Downloading torchmetrics-1.7.4-py3-none-any.whl.metadata (21 kB)
Collecting lightning-utilities<2.0,>=0.10.0 (from lightning)
  Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)
Collecting pytorch-lightning (from lightning)
  Downloading pytorch_lightning-2.5.2-py3-none-any.whl.metadata (21 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<4.0,>=2.1.0->lightning)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<4.0,>=2.1.0->lightning)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<4.0,>=2.1.0->lightning)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9