In [2]:
import lightning as L
from monai import utils, transforms, networks, data, engines, losses, metrics, visualize, config, inferers, apps
import torch
import matplotlib.pyplot as plt
import glob
import os
import shutil
import tempfile
import dotenv
import rootutils

In [3]:
# rootutils is not working but maybe that's because this is a .ipynb and not a .py
# root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=False)

In [4]:
dotenv.load_dotenv()

# print the contents of the .env file
print("Contents of .env file:")
for key in os.environ:
    if key.startswith("MONAI_"):
        print(f"  {key} = {os.environ[key]}")

config.print_config()

Contents of .env file:
  MONAI_DATA_DIRECTORY = /home/sasank/projects/med-start/data
MONAI version: 1.3.0
Numpy version: 1.26.0
Pytorch version: 2.2.2+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 865972f7a791bf7b42efbcd87c8402bd865b329e
MONAI __file__: /home/<username>/miniconda3/envs/monai/lib/python3.9/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.11
ITK version: 5.3.0
Nibabel version: 5.2.1
scikit-image version: 0.22.0
scipy version: 1.13.1
Pillow version: 10.3.0
Tensorboard version: 2.16.2
gdown version: 4.7.3
TorchVision version: 0.17.2+cu121
tqdm version: 4.66.2
lmdb version: 1.4.1
psutil version: 6.0.0
pandas version: 2.2.2
einops version: 0.7.0
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: 2.12.1
pynrrd version: 1.0.0
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/

In [5]:
# Datamodule
class MyDataModule(L.LightningDataModule):
    def __init__(self) -> None:
        super().__init__()
        # Make sure data directory exists
        directory = os.environ.get("MONAI_DATA_DIRECTORY")
        if directory is not None:
            os.makedirs(directory, exist_ok=True)
        if directory is None or not os.path.exists(directory):
            # throw an error if the data directory is not set
            raise ValueError("Please set the environment variable MONAI_DATA_DIRECTORY to a valid directory.")
        # root_dir = tempfile.mkdtemp() if directory is None else directory
        self.root_dir = directory
        print(self.root_dir)

        self.data_dir = os.path.join(self.root_dir, "Task09_Spleen")
        # self.root_dir = '/home/sasank/projects/med-start/data/'
        print(self.root_dir)
        print(self.data_dir)

        utils.misc.set_determinism(seed=0)

    def prepare_data(self):
        
        # download the data if it's not already downloaded
        resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar"
        md5 = "410d4a301da4e5b2f6f86ec3ddba524e"

        compressed_file = os.path.join(self.root_dir, "Task09_Spleen.tar")
        if not os.path.exists(self.data_dir):
            # print the directory it will be downloaded to
            print(f"Data will be downloaded to {self.data_dir}")
            apps.download_and_extract(resource, compressed_file, self.root_dir, md5)
    
    def setup(self):
        # set up the correct data path
        train_images = sorted(glob.glob(os.path.join(self.data_dir, "imagesTr", "*.nii.gz")))
        train_labels = sorted(glob.glob(os.path.join(self.data_dir, "labelsTr", "*.nii.gz")))
        data_dicts = [
            {"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)
        ]
        print(f"training data: {len(data_dicts)}")
        # print the first few items to check
        print(data_dicts[:2])
        train_files, val_files = data_dicts[:-9], data_dicts[-9:]

        # set deterministic training for reproducibility
        # set_determinism(seed=42)
        # set_determinism()

        train_transforms = transforms.Compose(
            [
                transforms.LoadImaged(keys=["image", "label"]),
                transforms.EnsureChannelFirstd(keys=["image", "label"]),
                transforms.Orientationd(keys=["image", "label"], axcodes="RAS"),
                transforms.Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
                transforms.ScaleIntensityRanged(keys=["image"], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
                transforms.CropForegroundd(keys=["image", "label"], source_key="image"),
                # transforms.EnsureTyped(keys=["image", "label"]),
                transforms.RandCropByPosNegLabeld(keys=["image", "label"], label_key="label", spatial_size=(96, 96, 96), pos=1, neg=1, num_samples=4),
            ]
        )

        val_transforms = transforms.Compose(
            [
                transforms.LoadImaged(keys=["image", "label"]),
                transforms.EnsureChannelFirstd(keys=["image", "label"]),
                transforms.Orientationd(keys=["image", "label"], axcodes="RAS"),
                transforms.Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
                transforms.ScaleIntensityRanged(keys=["image"], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
                transforms.CropForegroundd(keys=["image", "label"], source_key="image"),
                # transforms.EnsureTyped(keys=["image", "label"]),
            ]
        )

        self.train_ds = data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4)
        self.val_ds = data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4)
        
    def train_dataloader(self):
        return data.DataLoader(self.train_ds, batch_size=2, shuffle=True, num_workers=4)
    
    def val_dataloader(self):
        return data.DataLoader(self.val_ds, batch_size=2, num_workers=4)

In [6]:
dm = MyDataModule()
dm.prepare_data()
dm.setup()

monai.transforms.croppad.dictionary CropForegroundd.__init__:allow_smaller: Current default value of argument `allow_smaller=True` has been deprecated since version 1.2. It will be changed to `allow_smaller=False` in version 1.5.


/home/sasank/projects/med-start/data
/home/sasank/projects/med-start/data
/home/sasank/projects/med-start/data/Task09_Spleen
training data: 41
[{'image': '/home/sasank/projects/med-start/data/Task09_Spleen/imagesTr/spleen_10.nii.gz', 'label': '/home/sasank/projects/med-start/data/Task09_Spleen/labelsTr/spleen_10.nii.gz'}, {'image': '/home/sasank/projects/med-start/data/Task09_Spleen/imagesTr/spleen_12.nii.gz', 'label': '/home/sasank/projects/med-start/data/Task09_Spleen/labelsTr/spleen_12.nii.gz'}]


Loading dataset: 100%|██████████| 32/32 [00:34<00:00,  1.09s/it]
Loading dataset: 100%|██████████| 9/9 [00:06<00:00,  1.35it/s]


In [7]:
# LightningModule
class MyNetwork(L.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.model = networks.nets.UNet(
            dimensions=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm=networks.Layers.Norm.BATCH,
        )
        self.loss = losses.DiceLoss(to_onehot_y=True, softmax=True)
        # self.post_pred = inferers.Activation(inferers.Argmax(), to_onehot=True, num_classes=2)
        # self.post_label = inferers.OneHot(num_classes=2)
        self.post_pred = transforms.Compose([transforms.EnsureType("tensor", device="cpu"), transforms.AsDiscrete(argmax=True, to_onehot=2)])
        self.post_label = transforms.Compose([transforms.EnsureType("tensor", device="cpu"), transforms.AsDiscrete(to_onehot=2)])
        # self.post_pred and self.post_label are for the case when the model output is not one-hot encoded and the labels are one-hot encoded.
        # These transforms are run after the model output and the labels are retrieved from the data loader.
        self.DiceMetric = metrics.DiceMetric(include_background=False, reduction="mean")
        # You would want to include background when you have a background class in your data.
        # You would not want to include background when you don't have a background class in your data.
        # Could you not want to include background when you have a background class in your data?
        
    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        inputs, labels = batch["image"], batch["label"]
        outputs = self.model(inputs)
        loss = self.loss(outputs, labels)
        self.log("train_loss", loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        inputs, labels = batch["image"], batch["label"]
        outputs = inferers.sliding_window_inference(inputs, roi_size=(96, 96, 96), sw_batch_size=4, predictor=self.model)
        # outputs = self.model(inputs)
        loss = self.loss(outputs, labels)
        self.log("val_loss", loss)
        return loss