# Build and Test our LightningModule

This is where we will build our LightningModule and test it. The PyTorch Lightning Lightningmodule is a wrapper around a vanilla PyTorch nn.module that provides functionality for configuring optimizers, defining training and validation steps, and more.

In [1]:
import rootutils
import os
root = rootutils.setup_root(search_from=os.getcwd(), indicator=".project-root", dotenv=True, pythonpath=True, cwd=True)

In [5]:
import lightning as L
from monai import utils, transforms, networks, data, engines, losses, metrics, visualize, config, inferers, apps
from monai.data import CacheDataset, DataLoader, list_data_collate, pad_list_data_collate, decollate_batch
from monai.networks.nets import UNet
from monai.networks.layers import Norm
import torch
import matplotlib.pyplot as plt
# make sure you have numpy=1.26.0 bc of a bug between newer numpy and monai at the time of this writing. Will surely be solved by monai team in future.
import numpy as np
import pandas as pd
import glob
import os
import shutil
import tempfile
import rootutils

Let's paste our DataModule from the previous notebook.

In [3]:
class SpleenDatamodule(L.LightningDataModule):
    def __init__(
        self,
        batch_size: int = 2,
        num_workers: int = 4,
        pin_memory: bool = False,
    ) -> None:
        super().__init__()
        # this below line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)
        # self.batch_size = batch_size
        # self.num_workers = num_workers

    def prepare_data(self):
        # download, split, etc...
        # only called on 1 GPU/TPU in distributed
        # we already downloaded our data so we will not do anything here
        pass

    def setup(self, stage=None):
        # the stage is 'fit', 'validate', 'test', or 'predict'

        # We first need to define the transforms that we will apply to the data
        # Transforms may sound like they just augment the data, but they can also be used to preprocess the data, including loading the data and converting it to the correct format
        self.train_transforms = transforms.Compose(
            [
                transforms.LoadImaged(keys=["image", "label"]),
                transforms.EnsureChannelFirstd(keys=["image", "label"]),
                transforms.Orientationd(keys=["image", "label"], axcodes="RAS"),
                transforms.Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
                transforms.ScaleIntensityRanged(keys=["image"], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
                transforms.CropForegroundd(keys=["image", "label"], source_key="image"),
                # transforms.EnsureTyped(keys=["image", "label"]),
                transforms.RandCropByPosNegLabeld(keys=["image", "label"], label_key="label", spatial_size=(96, 96, 96), pos=1, neg=1, num_samples=4),
            ]
        )

        self.val_transforms = transforms.Compose(
            [
                transforms.LoadImaged(keys=["image", "label"]),
                transforms.EnsureChannelFirstd(keys=["image", "label"]),
                transforms.Orientationd(keys=["image", "label"], axcodes="RAS"),
                transforms.Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
                transforms.ScaleIntensityRanged(keys=["image"], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
                transforms.CropForegroundd(keys=["image", "label"], source_key="image"),
                # transforms.EnsureTyped(keys=["image", "label"]),
            ]
        )
        
        self.test_transforms = transforms.Compose(
            [
                transforms.LoadImaged(keys=["image", "label"]),
                transforms.EnsureChannelFirstd(keys=["image", "label"]),
                transforms.Orientationd(keys=["image", "label"], axcodes="RAS"),
                transforms.Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
                transforms.ScaleIntensityRanged(keys=["image"], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
                transforms.CropForegroundd(keys=["image", "label"], source_key="image"),
                # transforms.EnsureTyped(keys=["image", "label"]),
            ]
        )

        # make assignments here (val/train/test split)
        # called on every process in DDP
        # TODO: os.path.join() is not working???
        # IMAGE_SRC = os.path.join(root,"/data/Task09_Spleen/imagesTr")
        IMAGE_SRC = str(root) + "/data/Task09_Spleen/imagesTr"
        # LABEL_SRC = os.path.join(root,"/data/Task09_Spleen/labelsTr")
        LABEL_SRC = str(root) + "/data/Task09_Spleen/labelsTr"
        SPLIT_NAME = "MySplit"
        # this can be done by stage
        # filenames = None
        # if stage == "fit":
        #     filenames = pd.read_csv(f"/splits/{SPLIT_NAME}/train_{SPLIT_NAME}.csv")
        #     # Create a dictionary list of the image and label files labelled as 'image' and 'label'
        #     self.train_files = [{"image": os.path.join(IMAGE_SRC, f"{filename}.nii.gz"), "label": os.path.join(LABEL_SRC, f"{filename}.nii.gz")} for filename in filenames]
        #     self.train_ds = CacheDataset(data=self.train_files, transform=self.train_transforms, cache_rate=1.0, num_workers=self.num_workers)
        # elif stage == "validate":
        #     filenames = pd.read_csv(f"/splits/{SPLIT_NAME}/val_{SPLIT_NAME}.csv")
        #     self.val_files = [{"image": os.path.join(IMAGE_SRC, f"{filename}.nii.gz"), "label": os.path.join(LABEL_SRC, f"{filename}.nii.gz")} for filename in filenames]
        #     self.val_ds = CacheDataset(data=self.val_files, transform=self.val_transforms, cache_rate=1.0, num_workers=self.num_workers)
        # elif stage == "test":
        #     filenames = pd.read_csv(f"/splits/{SPLIT_NAME}/test_{SPLIT_NAME}.csv")
        #     self.test_files = [{"image": os.path.join(IMAGE_SRC, f"{filename}.nii.gz"), "label": os.path.join(LABEL_SRC, f"{filename}.nii.gz")} for filename in filenames]
        #     self.test_ds = CacheDataset(data=self.test_files, transform=self.test_transforms, cache_rate=1.0, num_workers=self.num_workers)
        # else:
        #     raise ValueError(f"Stage {stage} not supported")

        # TODO: os.path.join() is not working for some reason
        # train_csv = os.path.join(str(root),f"/splits/{SPLIT_NAME}/train_{SPLIT_NAME}.csv")
        train_csv = f"/splits/{SPLIT_NAME}/train_{SPLIT_NAME}.csv"
        # train_csv = os.path.join(str(root), train_csv)
        train_csv = str(root) + train_csv
        train_filenames = pd.read_csv(train_csv)
        self.train_files = [{"image": os.path.join(IMAGE_SRC, row["image"]), "label": os.path.join(LABEL_SRC, row["label"])} for index, row in train_filenames.iterrows()]
        # self.train_ds = CacheDataset(data=self.train_files, transform=self.train_transforms, cache_rate=1.0, num_workers=self.hparams.num_workers)
        
        val_csv = f"/splits/{SPLIT_NAME}/val_{SPLIT_NAME}.csv"
        val_csv = str(root) + val_csv
        val_filenames = pd.read_csv(val_csv)
        self.val_files = [{"image": os.path.join(IMAGE_SRC, row["image"]), "label": os.path.join(LABEL_SRC, row["label"])} for index, row in val_filenames.iterrows()]
        self.val_ds = CacheDataset(data=self.val_files, transform=self.val_transforms, cache_rate=1.0, num_workers=self.hparams.num_workers)
        
        test_csv = f"/splits/{SPLIT_NAME}/test_{SPLIT_NAME}.csv"
        test_csv = str(root) + test_csv
        test_filenames = pd.read_csv(test_csv)
        self.test_files = [{"image": os.path.join(IMAGE_SRC, row["image"]), "label": os.path.join(LABEL_SRC, row["label"])} for index, row in test_filenames.iterrows()]
        # self.test_ds = CacheDataset(data=self.test_files, transform=self.test_transforms, cache_rate=1.0, num_workers=self.hparams.num_workers)
        


    def train_dataloader(self):
        return DataLoader(
            dataset=self.train_ds,
            batch_size=self.hparams.batch_size,         # we can use this nifty trick and access the hyperparameters directly since we used self.save_hyperparameters() up top
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=list_data_collate,               # this collates our list of dictionaries into a dictionary of lists; not needed for if your dataset outputs something the default collate_fn can handle
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.val_ds,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=pad_list_data_collate,
            shuffle=False,
        )
    
    def test_dataloader(self):
        return DataLoader(
            dataset=self.test_ds,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=list_data_collate,
            shuffle=False,
        )

In [4]:
class SpleenLightningModule(L.LightningModule):
    def __init__(
        self,
        optimizer: torch.optim.Optimizer = None,
        scheduler: torch.optim.lr_scheduler._LRScheduler = None,
        loss: torch.nn.Module = None,
        compile: bool = False,
        lr: float = 1e-3,
        ) -> None:
        super().__init__()
        self.model = networks.nets.UNet(
            dimensions=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm=Norm.BATCH,
        )
        self.loss = losses.DiceLoss(to_onehot_y=True, softmax=True)
        self.metric = metrics.DiceMetric(include_background=False, reduction="mean")
        self.post_pred = inferers.Activation(inferers.Softmax())
        self.post_label = inferers.Argmax()
        self.val_dice = metrics.DiceMetric(include_background=False, reduction="mean")
        self.test_dice = metrics.DiceMetric(include_background=False, reduction="mean")

SyntaxError: incomplete input (3259313981.py, line 1)