# Build and Test our LightningModule

This is where we will build our LightningModule and test it. The PyTorch Lightning Lightningmodule is a wrapper around a vanilla PyTorch nn.module that provides functionality for configuring optimizers, defining training and validation steps, and more.

In [1]:
import rootutils
import os
root = rootutils.setup_root(search_from=os.getcwd(), indicator=".project-root", dotenv=True, pythonpath=True, cwd=True)

In [2]:
import lightning as L
from monai import utils, transforms, networks, data, engines, losses, metrics, visualize, config, inferers, apps
from monai.data import CacheDataset, DataLoader, list_data_collate, pad_list_data_collate, decollate_batch
from monai.networks.nets import UNet
from monai.networks.layers import Norm
import torch
import matplotlib.pyplot as plt
# make sure you have numpy=1.26.0 bc of a bug between newer numpy and monai at the time of this writing. Will surely be solved by monai team in future.
import numpy as np
import pandas as pd
import glob
import os
import shutil
import tempfile
import rootutils

Let's paste our DataModule from the previous notebook.

In [3]:
class SpleenDatamodule(L.LightningDataModule):
    def __init__(
        self,
        batch_size: int = 2,
        num_workers: int = 4,
        pin_memory: bool = False,
    ) -> None:
        super().__init__()
        # this below line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)
        # self.batch_size = batch_size
        # self.num_workers = num_workers

    def prepare_data(self):
        # download, split, etc...
        # only called on 1 GPU/TPU in distributed
        # we already downloaded our data so we will not do anything here
        pass

    def setup(self, stage=None):
        # the stage is 'fit', 'validate', 'test', or 'predict'

        # We first need to define the transforms that we will apply to the data
        # Transforms may sound like they just augment the data, but they can also be used to preprocess the data, including loading the data and converting it to the correct format
        self.train_transforms = transforms.Compose(
            [
                transforms.LoadImaged(keys=["image", "label"]),
                transforms.EnsureChannelFirstd(keys=["image", "label"]),
                transforms.Orientationd(keys=["image", "label"], axcodes="RAS"),
                transforms.Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
                transforms.ScaleIntensityRanged(keys=["image"], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
                transforms.CropForegroundd(keys=["image", "label"], source_key="image"),
                # transforms.EnsureTyped(keys=["image", "label"]),
                transforms.RandCropByPosNegLabeld(keys=["image", "label"], label_key="label", spatial_size=(96, 96, 96), pos=1, neg=1, num_samples=4),
            ]
        )

        self.val_transforms = transforms.Compose(
            [
                transforms.LoadImaged(keys=["image", "label"]),
                transforms.EnsureChannelFirstd(keys=["image", "label"]),
                transforms.Orientationd(keys=["image", "label"], axcodes="RAS"),
                transforms.Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
                transforms.ScaleIntensityRanged(keys=["image"], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
                transforms.CropForegroundd(keys=["image", "label"], source_key="image"),
                # transforms.EnsureTyped(keys=["image", "label"]),
            ]
        )
        
        self.test_transforms = transforms.Compose(
            [
                transforms.LoadImaged(keys=["image", "label"]),
                transforms.EnsureChannelFirstd(keys=["image", "label"]),
                transforms.Orientationd(keys=["image", "label"], axcodes="RAS"),
                transforms.Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
                transforms.ScaleIntensityRanged(keys=["image"], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
                transforms.CropForegroundd(keys=["image", "label"], source_key="image"),
                # transforms.EnsureTyped(keys=["image", "label"]),
            ]
        )

        # make assignments here (val/train/test split)
        # called on every process in DDP
        # TODO: os.path.join() is not working???
        # IMAGE_SRC = os.path.join(root,"/data/Task09_Spleen/imagesTr")
        IMAGE_SRC = str(root) + "/data/Task09_Spleen/imagesTr"
        # LABEL_SRC = os.path.join(root,"/data/Task09_Spleen/labelsTr")
        LABEL_SRC = str(root) + "/data/Task09_Spleen/labelsTr"
        SPLIT_NAME = "MySplit"
        # this can be done by stage
        # filenames = None
        # if stage == "fit":
        #     filenames = pd.read_csv(f"/splits/{SPLIT_NAME}/train_{SPLIT_NAME}.csv")
        #     # Create a dictionary list of the image and label files labelled as 'image' and 'label'
        #     self.train_files = [{"image": os.path.join(IMAGE_SRC, f"{filename}.nii.gz"), "label": os.path.join(LABEL_SRC, f"{filename}.nii.gz")} for filename in filenames]
        #     self.train_ds = CacheDataset(data=self.train_files, transform=self.train_transforms, cache_rate=1.0, num_workers=self.num_workers)
        # elif stage == "validate":
        #     filenames = pd.read_csv(f"/splits/{SPLIT_NAME}/val_{SPLIT_NAME}.csv")
        #     self.val_files = [{"image": os.path.join(IMAGE_SRC, f"{filename}.nii.gz"), "label": os.path.join(LABEL_SRC, f"{filename}.nii.gz")} for filename in filenames]
        #     self.val_ds = CacheDataset(data=self.val_files, transform=self.val_transforms, cache_rate=1.0, num_workers=self.num_workers)
        # elif stage == "test":
        #     filenames = pd.read_csv(f"/splits/{SPLIT_NAME}/test_{SPLIT_NAME}.csv")
        #     self.test_files = [{"image": os.path.join(IMAGE_SRC, f"{filename}.nii.gz"), "label": os.path.join(LABEL_SRC, f"{filename}.nii.gz")} for filename in filenames]
        #     self.test_ds = CacheDataset(data=self.test_files, transform=self.test_transforms, cache_rate=1.0, num_workers=self.num_workers)
        # else:
        #     raise ValueError(f"Stage {stage} not supported")

        # TODO: os.path.join() is not working for some reason
        # train_csv = os.path.join(str(root),f"/splits/{SPLIT_NAME}/train_{SPLIT_NAME}.csv")
        train_csv = f"/splits/{SPLIT_NAME}/train_{SPLIT_NAME}.csv"
        # train_csv = os.path.join(str(root), train_csv)
        train_csv = str(root) + train_csv
        train_filenames = pd.read_csv(train_csv)
        self.train_files = [{"image": os.path.join(IMAGE_SRC, row["image"]), "label": os.path.join(LABEL_SRC, row["label"])} for index, row in train_filenames.iterrows()]
        # Take only first few files of the dataset for testing
        self.train_files = self.train_files[:2]     # take only the first 5 files for testing; COMMENT THIS LINE OUT FOR ACTUAL TRAINING
        self.train_ds = CacheDataset(data=self.train_files, transform=self.train_transforms, cache_rate=1.0, num_workers=self.hparams.num_workers)
        
        val_csv = f"/splits/{SPLIT_NAME}/val_{SPLIT_NAME}.csv"
        val_csv = str(root) + val_csv
        val_filenames = pd.read_csv(val_csv)
        self.val_files = [{"image": os.path.join(IMAGE_SRC, row["image"]), "label": os.path.join(LABEL_SRC, row["label"])} for index, row in val_filenames.iterrows()]
        self.val_files = self.val_files[:2]     # take only the first 5 files for testing; COMMENT THIS LINE OUT FOR ACTUAL TRAINING
        self.val_ds = CacheDataset(data=self.val_files, transform=self.val_transforms, cache_rate=1.0, num_workers=self.hparams.num_workers)
        
        test_csv = f"/splits/{SPLIT_NAME}/test_{SPLIT_NAME}.csv"
        test_csv = str(root) + test_csv
        test_filenames = pd.read_csv(test_csv)
        self.test_files = [{"image": os.path.join(IMAGE_SRC, row["image"]), "label": os.path.join(LABEL_SRC, row["label"])} for index, row in test_filenames.iterrows()]
        self.test_files = self.test_files[:2]     # take only the first 5 files for testing; COMMENT THIS LINE OUT FOR ACTUAL TRAINING
        self.test_ds = CacheDataset(data=self.test_files, transform=self.test_transforms, cache_rate=1.0, num_workers=self.hparams.num_workers)
        


    def train_dataloader(self):
        return DataLoader(
            dataset=self.train_ds,
            batch_size=self.hparams.batch_size,         # we can use this nifty trick and access the hyperparameters directly since we used self.save_hyperparameters() up top
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=list_data_collate,               # this collates our list of dictionaries into a dictionary of lists; not needed for if your dataset outputs something the default collate_fn can handle
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.val_ds,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=pad_list_data_collate,
            shuffle=False,
        )
    
    def test_dataloader(self):
        return DataLoader(
            dataset=self.test_ds,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=list_data_collate,
            shuffle=False,
        )

In [4]:
class SpleenLightningModule(L.LightningModule):
    def __init__(
        self,
        optimizer: torch.optim.Optimizer = None,
        scheduler: torch.optim.lr_scheduler = None,
        loss_fn: torch.nn.Module = None,
        compile: bool = False,
        lr: float = 1e-3,
        ) -> None:
        super().__init__()
        self.save_hyperparameters()

        self.model = networks.nets.UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm=Norm.BATCH,
        )

        self.loss_fn = losses.DiceLoss(to_onehot_y=True, softmax=True)
        self.metric = metrics.DiceMetric(include_background=False, reduction="mean")

        # self.post_pred = transforms.Compose([transforms.EnsureType("tensor", device="cpu"), transforms.AsDiscrete(armax=True, to_onehot=2)])
        self.post_pred = transforms.Compose([transforms.EnsureType("tensor", device="cpu"), transforms.AsDiscrete(armax=True)])
        # The post_pred transform is giving me an error that "labels should have a channel with length equal to on" so I think it needs to output a single channel image
        # I think the issue is that the output of the model is a 2 channel image, but the labels are a single channel image
        # I can fix this by adding a channel dimension to the labels
        self.post_label = transforms.Compose([transforms.EnsureType("tensor", device="cpu"), transforms.AsDiscrete(to_onehot=2)])

        self.best_val_dice = 0
        self.best_val_epoch = 0
        self.validation_step_outputs = []
    
    def forward(self, x):
        return self.model(x)
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.hparams.lr)
        return optimizer

    def training_step(self, batch, batch_idx):
        images, labels = batch["image"], batch["label"]
        outputs = self(images)
        loss = self.loss_fn(outputs, labels)
        self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        images, labels = batch["image"], batch["label"]
        roi_size = (160, 160, 160)
        sw_batch_size = 4
        # We should use sliding window inference not because it reduces memory usage (that's not a problem on HPG) but
        # rather because it apparently can increase accuracy by multiple percentage points for things like Dice.
        # We totally could just have outputs = self.forward(images)
        # However, sliding window inference is apparently more accurate for things like Dice.
        outputs = inferers.sliding_window_inference(images, roi_size, sw_batch_size, self.forward)
        loss = self.loss_fn(outputs, labels)
        print("output shape is:")
        print(outputs.shape)
        print("label shape is:")
        print(labels.shape)
        print("decollated batch shape is:")
        print(decollate_batch(outputs)[0].shape)
        outputs = [self.post_pred(i) for i in decollate_batch(outputs)]
        labels = [self.post_label(i) for i in decollate_batch(labels)]
        dice = self.metric(y_pred=outputs, y=labels)
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val/dice", dice.mean(), on_step=False, on_epoch=True, prog_bar=True)
        d = {"val_loss": loss, "val_number": len(outputs)}
        self.validation_step_outputs.append(d)
        return d
    
    def on_validation_epoch_end(self):
        val_loss, num_items = 0, 0
        for output in self.validation_step_outputs:
            val_loss += output["val_loss"].sum().item()
            num_items += output["val_number"]
        mean_val_dice = self.metric.aggregate().item()
        self.metric.reset()
        mean_val_loss = torch.tensor(val_loss / num_items)
        self.log("val/mean_dice", mean_val_dice, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val/mean_loss", mean_val_loss, on_step=False, on_epoch=True, prog_bar=True)
        if mean_val_dice > self.best_val_dice:
            self.best_val_dice = mean_val_dice
            self.best_val_epoch = self.current_epoch
        print(
            f"current epoch: {self.current_epoch} "
            f"current mean dice: {mean_val_dice:.4f}"
            f"\nbest mean dice: {self.best_val_dice:.4f} "
            f"at epoch: {self.best_val_epoch}"
        )
        self.validation_step_outputs.clear()  # free memory
        return
    
    def test_step(self, batch, batch_idx):
        images, labels = batch["image"], batch["label"]
        roi_size = (160, 160, 160)
        sw_batch_size = 4
        outputs = inferers.sliding_window_inference(images, roi_size, sw_batch_size, self.forward)
        loss = self.loss_fn(outputs, labels)
        outputs = [self.post_pred(i) for i in decollate_batch(outputs)]
        labels = [self.post_label(i) for i in decollate_batch(labels)]
        dice = self.metric(y_pred=outputs, y=labels)
        self.log("test/loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("test/dice", dice, on_step=False, on_epoch=True, prog_bar=True)
        return

In [5]:
# Run quick "training" run with fast_dev_run=True so that it just runs one batch of train and val

datamodule = SpleenDatamodule(batch_size=2, num_workers=4, pin_memory=False)
model = SpleenLightningModule(lr=1e-3)

logger = L.pytorch.loggers.csv_logs.CSVLogger(save_dir="logs", name="spleen_dev")

trainer = L.Trainer(
    fast_dev_run=True,
    max_epochs=1,
    devices="auto",
    accelerator="auto",
    logger=logger,
    log_every_n_steps=1,
)

trainer.fit(model=model, datamodule=datamodule)

The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/sasank.desaraju/.conda/envs/monai/lib/python3. ...
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.
monai.transforms.croppad.dictionary CropForegroundd.__init__:allow_smaller: Current default value of argument `allow_smaller=True` has been deprecated since version 1.2. It will be changed to `allow_smaller=False` in version 1.5.
Loading dataset: 100%|██████████| 2/2 [00:06<00:00,  3.31s/it]
Loading dataset: 100%|██████████| 2/2 [00:06<00:00,  3.15s/it]
Loading dataset: 100%|██████████| 2/2 [00:06<00:00,  3.43s/it]

  | Name    | Type     | Params | Mode 
---------------------------------------------
0 | model   | UNet     | 4.

Epoch 0: 100%|██████████| 1/1 [00:13<00:00,  0.08it/s]output shape is:
torch.Size([2, 2, 272, 264, 223])
label shape is:
torch.Size([2, 1, 272, 264, 223])
decollated batch shape is:
torch.Size([2, 272, 264, 223])
current epoch: 0 current mean dice: 0.0054
best mean dice: 0.0054 at epoch: 0
Epoch 0: 100%|██████████| 1/1 [00:42<00:00,  0.02it/s, val/loss=0.657, val/dice=0.00536, val/mean_dice=0.00536, val/mean_loss=0.329, train/loss=0.652]

`Trainer.fit` stopped: `max_steps=1` reached.


Epoch 0: 100%|██████████| 1/1 [00:42<00:00,  0.02it/s, val/loss=0.657, val/dice=0.00536, val/mean_dice=0.00536, val/mean_loss=0.329, train/loss=0.652]


Okay, we are ready to incorporate the DataModule and LightningModule into .py files in the `src` directory.