## Setup environment

In [2]:
!python3 -c "import monai" || pip install -q "monai-weekly[nibabel]"
!python3 -c "import matplotlib" || pip install -q matplotlib
!pip install -q pytorch-lightning~=2.0
%matplotlib inline

2024-07-11 15:43:32.831357: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Setup imports

In [1]:
import pytorch_lightning
from pytorch_lightning.callbacks import ModelCheckpoint
from monai.utils import set_determinism
from monai.transforms import (
AsDiscrete,
EnsureChannelFirstd,
Compose,
CropForegroundd,
LoadImaged,
Orientationd,
RandCropByPosNegLabeld,
ScaleIntensityRanged,
Spacingd,
EnsureType,
EnsureTyped,
Resized,
RandAdjustContrastd, 
RandFlipd, 
RandAffined, 
RandAdjustContrastd

)
from monai.networks.nets import UNet, UNETR
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, list_data_collate, decollate_batch, DataLoader
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
import nibabel as nib
import numpy as np
from natsort import natsorted
from sklearn.model_selection import KFold
from pytorch_lightning.plugins import MixedPrecisionPlugin

print_config()

2024-07-12 18:55:05.918851: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


MONAI version: 1.4.dev2427
Numpy version: 1.26.4
Pytorch version: 2.0.0+cu117
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: cbf90d0ddb27dc96a91385e4dd2f4eb239dea976
MONAI __file__: /home/<username>/.local/lib/python3.10/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: 5.4.0
Nibabel version: 5.2.1
scikit-image version: 0.24.0
scipy version: 1.14.0
Pillow version: 10.4.0
Tensorboard version: 2.12.2
gdown version: 5.2.0
TorchVision version: 0.15.1+cu117
tqdm version: 4.66.4
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.8
pandas version: 1.5.3
einops version: 0.8.0
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: 1.0.0
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/inst

## Setup data directory

You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  
This allows you to save results and reuse downloads.  
If not specified a temporary directory will be used.

In [2]:
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

/tmp/tmpm4b4_nf6


In [3]:
data_dir = os.path.join(os.getcwd(), 'AeroPath')

## Download dataset

Downloads and extracts the dataset

In [6]:
resource = "https://zenodo.org/records/10069289/files/AeroPath.zip?download=1"
md5 = "3fd5106c175c85d60eaece220f5dfd87"

compressed_file = os.path.join(root_dir, "AeroPath.zip")
if not os.path.exists(data_dir):
    download_and_extract(resource, compressed_file, root_dir, md5)

## Define the LightningModule

The LightningModule contains a refactoring of your training code. The following module is a refactoring of the code in `spleen_segmentation_3d.ipynb`:

In [4]:
class UNetClass(pytorch_lightning.LightningModule):
    def __init__(self, mode, roi_size, spatial_size):
        super().__init__()
        self._model = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm=Norm.BATCH,
        )
        # self._model = UNETR(
        #     in_channels=1,
        #     out_channels=2,
        #     img_size=roi_size,
        #     feature_size=16,
        #     hidden_size=768,
        #     mlp_dim=3072,
        #     num_heads=12,
        #     pos_embed="perceptron",
        #     norm_name="instance",
        #     res_block=True,
        #     dropout_rate=0.0,
        # )
        
        self.loss_function = DiceLoss(to_onehot_y=True, softmax=True)
        self.post_pred = Compose([EnsureType("tensor", device="cpu"), AsDiscrete(argmax=True, to_onehot=2)])
        self.post_label = Compose([EnsureType("tensor", device="cpu"), AsDiscrete(to_onehot=2)])
        self.dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
        self.best_val_dice = 0
        self.best_val_epoch = 0
        self.validation_step_outputs = []

        self.mode = mode
        self.roi_size = roi_size
        self.spatial_size = spatial_size

        self.common_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(
                keys=["image", "label"],
                # pixdim=(1.5, 1.5, 2.0),
                pixdim=(1.1, 1.1, 1.40),
                mode=("bilinear", "nearest"),
            ),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-1024,
                a_max=1024,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            Resized(keys=["image", "label"], spatial_size=self.spatial_size),
            EnsureTyped(keys=["image", "label"]),
        ]
        )


    def forward(self, x):
        return self._model(x)

    def prepare_data(self, prepare_val_data=True, prepare_test_data=True):
        # # set up the correct data path
        if self.mode == 'whole':
            pattern = os.path.join(data_dir, '**/*_CT_HR_label_airways.nii.gz')
            train_labels = sorted(glob.glob(pattern, recursive=True))

            pattern = os.path.join(data_dir, '**/*_CT_HR.nii.gz')
            train_images = sorted(glob.glob(pattern, recursive=True))

        elif self.mode == '1Q':
            pattern = os.path.join('nonoverlapping_labels', '**/quadrant_1_*.nii.gz')
            train_labels = sorted(glob.glob(pattern, recursive=True))

            pattern = os.path.join('nonoverlapping_quadrants', '**/quadrant_1_*_CT_HR.nii.gz')
            train_images = sorted(glob.glob(pattern, recursive=True))

        elif self.mode == '2Q':
            pattern = os.path.join('nonoverlapping_labels', '**/quadrant_1_*.nii.gz')
            train_labels = sorted(glob.glob(pattern, recursive=True))

            pattern = os.path.join('nonoverlapping_quadrants', '**/quadrant_1_*_CT_HR.nii.gz')
            train_images = sorted(glob.glob(pattern, recursive=True))

        elif self.mode == 'left_bottom':
            pattern = os.path.join('dataset/airways_patched_4', '**/*left_bottom_*.nii.gz')
            train_labels = sorted(glob.glob(pattern, recursive=True))

            pattern = os.path.join('dataset/scan_patched_4', '**/*left_bottom_*.nii.gz')
            train_images = sorted(glob.glob(pattern, recursive=True))

        elif self.mode == 'left_uppper':
            pattern = os.path.join('dataset/airways_patched_4', '**/*left_upper_*.nii.gz')
            train_labels = sorted(glob.glob(pattern, recursive=True))

            pattern = os.path.join('dataset/scan_patched_4', '**/*left_bottom_*.nii.gz')
            train_images = sorted(glob.glob(pattern, recursive=True))

        elif self.mode == 'right_bottom':
            pattern = os.path.join('dataset/airways_patched_4', '**/*right_bottom_*.nii.gz')
            train_labels = sorted(glob.glob(pattern, recursive=True))

            pattern = os.path.join('dataset/scan_patched_4', '**/*left_bottom_*.nii.gz')
            train_images = sorted(glob.glob(pattern, recursive=True))

        elif self.mode == 'right_upper':
            pattern = os.path.join('dataset/airways_patched_4', '**/*right_upper_*.nii.gz')
            train_labels = sorted(glob.glob(pattern, recursive=True))

            pattern = os.path.join('dataset/scan_patched_4', '**/*left_bottom_*.nii.gz')
            train_images = sorted(glob.glob(pattern, recursive=True))


        data_dicts = [
            {"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)
        ]
        train_files, val_files = data_dicts[:-9], data_dicts[-9:]

        # set deterministic training for reproducibility
        set_determinism(seed=0)

        # define the data transforms
        train_transforms = Compose(
            [
                LoadImaged(keys=["image", "label"]),
                EnsureChannelFirstd(keys=["image", "label"]),
                Orientationd(keys=["image", "label"], axcodes="RAS"),
                Spacingd(
                    keys=["image", "label"],
                    # pixdim=(1.5, 1.5, 2.0),
                    pixdim=(1.1, 1.1, 1.40),
                    mode=("bilinear", "nearest"),
                ),
                ScaleIntensityRanged(
                    keys=["image"],
                    a_min=-1024,
                    a_max=1024,
                    b_min=0.0,
                    b_max=1.0,
                    clip=True,
                ),
                CropForegroundd(keys=["image", "label"], source_key="image"),
                # randomly crop out patch samples from
                # big image based on pos / neg ratio
                # the image centers of negative samples
                # must be in valid image area
                RandCropByPosNegLabeld(
                    keys=["image", "label"],
                    label_key="label",
                    spatial_size=(64, 64, 64),
                    pos=1,
                    neg=1,
                    num_samples=4,
                    image_key="image",
                    image_threshold=0,
                ),

                # user can also add other random transforms
                #                 RandAffined(
                #                     keys=['image', 'label'],
                #                     mode=('bilinear', 'nearest'),
                #                     prob=1.0,
                #                     spatial_size=(96, 96, 96),
                #                     rotate_range=(0, 0, np.pi/15),
                #                     scale_range=(0.1, 0.1, 0.1)),
                # Adding the data augmentation transforms with a probability of 50%
                # RandFlipd(
                #     keys=["image", "label"],
                #     spatial_axis=[0],  # Horizontal flip
                #     prob=0.5
                # ),
                # RandFlipd(
                #     keys=["image", "label"],
                #     spatial_axis=[1],  # Vertical flip
                #     prob=0.5
                # ),
                # RandAffined(
                #     keys=["image", "label"],
                #     prob=0.5,
                #     rotate_range=(np.deg2rad(20), np.deg2rad(20), np.deg2rad(20)),
                #     translate_range=(0.2, 0.2, 0.2),
                #     scale_range=(0.5, 1.5),
                #     mode=('bilinear', 'nearest')
                # ),
                RandAdjustContrastd(
                    keys=["image"],
                    gamma=(0.5, 2.0),
                    prob=0.0
                ),
                EnsureTyped(keys=["image", "label"]),
            ]
        )
        val_transforms = Compose(
            [
                LoadImaged(keys=["image", "label"]),
                EnsureChannelFirstd(keys=["image", "label"]),
                Orientationd(keys=["image", "label"], axcodes="RAS"),
                Spacingd(
                    keys=["image", "label"],
                    # pixdim=(1.5, 1.5, 2.0),
                    pixdim=(1.1, 1.1, 1.40),
                    mode=("bilinear", "nearest"),
                ),
                ScaleIntensityRanged(
                    keys=["image"],
                    a_min=-1024,
                    a_max=1024,
                    b_min=0.0,
                    b_max=1.0,
                    clip=True,
                ),
                CropForegroundd(keys=["image", "label"], source_key="image"),
            ]
        )
                    

        # we use cached datasets - these are 10x faster than regular datasets
        if prepare_test_data:
            self.train_ds = CacheDataset(
                data=train_files,
                transform=train_transforms,
                cache_rate=0.4,
                num_workers=4,
            )
        if prepare_val_data:
            self.val_ds = CacheDataset(
                data=val_files,
                transform=val_transforms,
                cache_rate=0.4,
                num_workers=4,
            )


    def train_dataloader(self):
        train_loader = DataLoader(
            self.train_ds,
            batch_size=1,
            shuffle=True,
            num_workers=4,
            collate_fn=list_data_collate,
        )
        return train_loader

    def val_dataloader(self):
        val_loader = DataLoader(self.val_ds, batch_size=1, num_workers=4)
        return val_loader

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self._model.parameters(), 1e-4)
        return optimizer

    def training_step(self, batch, batch_idx):
        images, labels = batch["image"], batch["label"]
        output = self.forward(images)
        loss = self.loss_function(output, labels)
        tensorboard_logs = {"train_loss": loss.item()}
        return {"loss": loss, "log": tensorboard_logs}

    
    def validation_step(self, batch, batch_idx):
        images, labels = batch["image"], batch["label"]
        roi_size = self.roi_size
        sw_batch_size = 4
        outputs = sliding_window_inference(images, roi_size, sw_batch_size, self)
        loss = self.loss_function(outputs, labels)
        outputs = [self.post_pred(i) for i in decollate_batch(outputs)]
        labels = [self.post_label(i) for i in decollate_batch(labels)]
        self.dice_metric(y_pred=outputs, y=labels)
        d = {"val_loss": loss, "val_number": len(outputs)}
        self.validation_step_outputs.append(d)
        return d
    

    
    def perform_inference(self, model, data):
        # Perform inference using the model
        with torch.no_grad():
            data = torch.DoubleTensor(data)  # Convert data to type Double
            model_output = model(data.unsqueeze(0))
        return model_output

    def on_validation_epoch_end(self):
        val_loss, num_items = 0, 0
        for output in self.validation_step_outputs:
            val_loss += output["val_loss"].sum().item()
            num_items += output["val_number"]
        mean_val_dice = self.dice_metric.aggregate().item()
        self.dice_metric.reset()
        mean_val_loss = torch.tensor(val_loss / num_items)
        tensorboard_logs = {
            "val_dice": mean_val_dice,
            "val_loss": mean_val_loss,
        }
        if mean_val_dice > self.best_val_dice:
            self.best_val_dice = mean_val_dice
            self.best_val_epoch = self.current_epoch
        print(
            f"current epoch: {self.current_epoch} "
            f"current mean dice: {mean_val_dice:.4f}"
            f"\nbest mean dice: {self.best_val_dice:.4f} "
            f"at epoch: {self.best_val_epoch}"
        )
        self.validation_step_outputs.clear()  # free memory
        self.log('val_dice', mean_val_dice, on_step=False, on_epoch=True, prog_bar=True, logger=True) # log

        return {"log": tensorboard_logs}
    
    
    def dice_score(self, prediction_tensor, label_tensor):
        # Compute Dice score
        dice_metric = DiceMetric(include_background=True, reduction="mean")
        dice_metric(y_pred=prediction_tensor, y=label_tensor)
        dice_score = dice_metric.aggregate().item()
        dice_metric.reset()

        print(dice_score)

In [9]:
from monai.losses import DiceCELoss

class UnetrNet(pytorch_lightning.LightningModule):
    def __init__(self, mode = 'whole'):
        super().__init__()

        self._model = UNETR(
            in_channels=1,
            out_channels=2,
            img_size=(96, 96, 96),
            feature_size=16,
            hidden_size=768,
            mlp_dim=3072,
            num_heads=12,
            proj_type="perceptron",
            norm_name="instance",
            res_block=True,
            conv_block=True,
            dropout_rate=0.0,
        ).to(device)

        self.loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
        self.post_pred = AsDiscrete(argmax=True, to_onehot=2)
        self.post_label = AsDiscrete(to_onehot=2)
        self.dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
        self.best_val_dice = 0
        self.best_val_epoch = 0
        self.max_epochs = 600
        self.check_val = 30
        self.warmup_epochs = 20
        self.metric_values = []
        self.epoch_loss_values = []
        self.validation_step_outputs = []
        self.mode = mode

    def forward(self, x):
        return self._model(x)

    def prepare_data(self, prepare_val_data=True, prepare_test_data=True):
        # # set up the correct data path
        if self.mode == 'whole':
            pattern = os.path.join(data_dir, '**/*_CT_HR_label_airways.nii.gz')
            train_labels = sorted(glob.glob(pattern, recursive=True))

            pattern = os.path.join(data_dir, '**/*_CT_HR.nii.gz')
            train_images = sorted(glob.glob(pattern, recursive=True))

        elif self.mode == '1Q':
            pattern = os.path.join('nonoverlapping_labels', '**/quadrant_1_*.nii.gz')
            train_labels = sorted(glob.glob(pattern, recursive=True))

            pattern = os.path.join('nonoverlapping_quadrants', '**/quadrant_1_*_CT_HR.nii.gz')
            train_images = sorted(glob.glob(pattern, recursive=True))

        elif self.mode == '2Q':
            pattern = os.path.join('nonoverlapping_labels', '**/quadrant_1_*.nii.gz')
            train_labels = sorted(glob.glob(pattern, recursive=True))

            pattern = os.path.join('nonoverlapping_quadrants', '**/quadrant_1_*_CT_HR.nii.gz')
            train_images = sorted(glob.glob(pattern, recursive=True))

        elif self.mode == 'left_bottom':
            pattern = os.path.join('dataset/airways_patched_4', '**/*left_bottom_*.nii.gz')
            train_labels = sorted(glob.glob(pattern, recursive=True))

            pattern = os.path.join('dataset/scan_patched_4', '**/*left_bottom_*.nii.gz')
            train_images = sorted(glob.glob(pattern, recursive=True))

        elif self.mode == 'left_uppper':
            pattern = os.path.join('dataset/airways_patched_4', '**/*left_upper_*.nii.gz')
            train_labels = sorted(glob.glob(pattern, recursive=True))

            pattern = os.path.join('dataset/scan_patched_4', '**/*left_bottom_*.nii.gz')
            train_images = sorted(glob.glob(pattern, recursive=True))

        elif self.mode == 'right_bottom':
            pattern = os.path.join('dataset/airways_patched_4', '**/*right_bottom_*.nii.gz')
            train_labels = sorted(glob.glob(pattern, recursive=True))

            pattern = os.path.join('dataset/scan_patched_4', '**/*left_bottom_*.nii.gz')
            train_images = sorted(glob.glob(pattern, recursive=True))

        elif self.mode == 'right_upper':
            pattern = os.path.join('dataset/airways_patched_4', '**/*right_upper_*.nii.gz')
            train_labels = sorted(glob.glob(pattern, recursive=True))

            pattern = os.path.join('dataset/scan_patched_4', '**/*left_bottom_*.nii.gz')
            train_images = sorted(glob.glob(pattern, recursive=True))


        data_dicts = [
            {"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)
        ]
        train_files, val_files = data_dicts[:-9], data_dicts[-9:]

        # set deterministic training for reproducibility
        set_determinism(seed=0)

        # define the data transforms
        train_transforms = Compose(
            [
                LoadImaged(keys=["image", "label"]),
                EnsureChannelFirstd(keys=["image", "label"]),
                Orientationd(keys=["image", "label"], axcodes="RAS"),
                Spacingd(
                    keys=["image", "label"],
                    pixdim=(1.5, 1.5, 2.0),
                    # pixdim=(1.1, 1.1, 1.40),
                    mode=("bilinear", "nearest"),
                ),
                ScaleIntensityRanged(
                    keys=["image"],
                    a_min=-1024,
                    a_max=1024,
                    b_min=0.0,
                    b_max=1.0,
                    clip=True,
                ),
                CropForegroundd(keys=["image", "label"], source_key="image"),
                # randomly crop out patch samples from
                # big image based on pos / neg ratio
                # the image centers of negative samples
                # must be in valid image area
                RandCropByPosNegLabeld(
                    keys=["image", "label"],
                    label_key="label",
                    spatial_size=(64, 64, 64),
                    pos=1,
                    neg=1,
                    num_samples=4,
                    image_key="image",
                    image_threshold=0,
                ),
                RandAdjustContrastd(
                    keys=["image"],
                    gamma=(0.5, 2.0),
                    prob=0.0
                ),
                EnsureTyped(keys=["image", "label"]),
            ]
        )
        val_transforms = Compose(
            [
                LoadImaged(keys=["image", "label"]),
                EnsureChannelFirstd(keys=["image", "label"]),
                Orientationd(keys=["image", "label"], axcodes="RAS"),
                Spacingd(
                    keys=["image", "label"],
                    pixdim=(1.5, 1.5, 2.0),
                    # pixdim=(1.1, 1.1, 1.40),
                    mode=("bilinear", "nearest"),
                ),
                ScaleIntensityRanged(
                    keys=["image"],
                    a_min=-1024,
                    a_max=1024,
                    b_min=0.0,
                    b_max=1.0,
                    clip=True,
                ),
                CropForegroundd(keys=["image", "label"], source_key="image"),
            ]
        )
                    

        # we use cached datasets - these are 10x faster than regular datasets
        if prepare_test_data:
            self.train_ds = CacheDataset(
                data=train_files,
                transform=train_transforms,
                cache_rate=0.4,
                num_workers=4,
            )
        if prepare_val_data:
            self.val_ds = CacheDataset(
                data=val_files,
                transform=val_transforms,
                cache_rate=0.4,
                num_workers=4,
            )

    def train_dataloader(self):
        train_loader = DataLoader(
            self.train_ds,
            batch_size=1,
            shuffle=True,
            num_workers=8,
            pin_memory=True,
            collate_fn=list_data_collate,
        )
        return train_loader

    def val_dataloader(self):
        val_loader = DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)
        return val_loader

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self._model.parameters(), lr=1e-4, weight_decay=1e-5)
        return optimizer

    def training_step(self, batch, batch_idx):
        images, labels = (batch["image"].cuda(), batch["label"].cuda())
        output = self.forward(images)
        loss = self.loss_function(output, labels)
        tensorboard_logs = {"train_loss": loss.item()}
        return {"loss": loss, "log": tensorboard_logs}

    def on_train_epoch_end(self, outputs):
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.epoch_loss_values.append(avg_loss.detach().cpu().numpy())

    def validation_step(self, batch, batch_idx):
        images, labels = batch["image"], batch["label"]
        roi_size = (96, 96, 96)
        sw_batch_size = 2
        outputs = sliding_window_inference(images, roi_size, sw_batch_size, self.forward)
        loss = self.loss_function(outputs, labels)
        outputs = [self.post_pred(i) for i in decollate_batch(outputs)]
        labels = [self.post_label(i) for i in decollate_batch(labels)]
        self.dice_metric(y_pred=outputs, y=labels)
        d = {"val_loss": loss, "val_number": len(outputs)}
        self.validation_step_outputs.append(d)
        return d

    def on_validation_epoch_end(self):
        val_loss, num_items = 0, 0
        for output in self.validation_step_outputs:
            val_loss += output["val_loss"].sum().item()
            num_items += output["val_number"]
        mean_val_dice = self.dice_metric.aggregate().item()
        self.dice_metric.reset()
        mean_val_loss = torch.tensor(val_loss / num_items)
        tensorboard_logs = {
            "val_dice": mean_val_dice,
            "val_loss": mean_val_loss,
        }
        if mean_val_dice > self.best_val_dice:
            self.best_val_dice = mean_val_dice
            self.best_val_epoch = self.current_epoch
        print(
            f"current epoch: {self.current_epoch} "
            f"current mean dice: {mean_val_dice:.4f}"
            f"\nbest mean dice: {self.best_val_dice:.4f} "
            f"at epoch: {self.best_val_epoch}"
        )
        self.metric_values.append(mean_val_dice)
        self.validation_step_outputs.clear()  # free memory
        return {"log": tensorboard_logs}

In [5]:

# Net_left_upper       = UNetClass(mode = 'left_upper',       roi_size=(160, 160, 160),   spatial_size=(160, 160, 160))
# Net_left_bottom      = UNetClass(mode = 'left_bottom',      roi_size=(160, 160, 160),   spatial_size=(160, 160, 160))
# Net_right_upper      = UNetClass(mode = 'right_upper',      roi_size=(160, 160, 160),   spatial_size=(160, 160, 160))
# Net_right_bottom     = UNetClass(mode = 'right_bottom',     roi_size=(160, 160, 160),   spatial_size=(160, 160, 160))
# NetWhole             = UNetClass(mode = 'whole',            roi_size=(128, 128, 144),   spatial_size=(128, 128, 144))
NetWhole             = UNetClass(mode = 'whole',            roi_size=(128, 128, 128),   spatial_size=(128, 128, 128))
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NetWhole = UnetrNet(mode = 'whole')
# Net1Q                = UNetClass(mode = '1Q',               roi_size=(160*2, 160, 160), spatial_size=(160*2, 160, 160))
# Net2Q                = UNetClass(mode = '2Q',               roi_size=(160*2, 160, 160), spatial_size=(160*2, 160, 160))

# NetWhole    = UNetClass(mode = 'whole', roi_size=(192, 192, 212),   spatial_size=(192, 192, 212))

monai.transforms.croppad.dictionary CropForegroundd.__init__:allow_smaller: Current default value of argument `allow_smaller=True` has been deprecated since version 1.2. It will be changed to `allow_smaller=False` in version 1.5.


## Run the training

In [6]:
checkpoint_callback = ModelCheckpoint(
    monitor='val_dice',
    dirpath=os.path.join(data_dir, 'checkpoints'),  # Directory to save checkpoints
    filename='whole_bigger_pixdim_UNETR',  # Filename prefix for saving checkpoints
    save_top_k=1,  # Save only the best checkpoint
    mode='max',  # `min` for minimizing the metric, `max` for maximizing
    verbose=True,  # Log a message when saving the best checkpoint
)

In [7]:
# initialise the LightningModule
# net = Net1Q
# net = Net2Q
# net = Net_left_bottom
net = NetWhole
# set up loggers and checkpoints
log_dir = os.path.join(root_dir, "logs")
tb_logger = pytorch_lightning.loggers.TensorBoardLogger(save_dir=log_dir)

# initialise Lightning's trainer.
trainer = pytorch_lightning.Trainer(
    devices=[0],
    max_epochs=600,
    logger=tb_logger,
    enable_checkpointing=True,
    callbacks=[checkpoint_callback],
    num_sanity_val_steps=1,
    log_every_n_steps=16,
)

# Mixed precision trainer
trainer = pytorch_lightning.Trainer(
    devices=[0],
    max_epochs=600,
    logger=tb_logger,
    enable_checkpointing=True,
    callbacks=[checkpoint_callback],
    num_sanity_val_steps=1,
    log_every_n_steps=16,
    precision=16,  # Mixed precision
    # plugins=[MixedPrecisionPlugin()]
)

# train
trainer.fit(net)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
16 is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Loading dataset:   0%|          | 0/7 [00:00<?, ?it/s]

In [None]:
print(f"train completed, best_metric: {net.best_val_dice:.4f} " f"at epoch {net.best_val_epoch}")

train completed, best_metric: 0.7536 at epoch 540


In [None]:
import torch
from monai.networks.nets import UNet


# Load the model weights from the checkpoint file
checkpoint_path = 'best-checkpoint.ckpt'
model = Net1Q.load_from_checkpoint('1Q_clipped_resized128_128_144_roibest_metric: 0.2797 at epoch 495.ckpt', mode = '1Q', roi_size=(160*2, 160, 160), spatial_size=(160*2, 160, 160))

# Set the model to evaluation mode
model.eval()


## Model Ensembling

## View training in tensorboard

Please uncomment the following cell to load tensorboard results.

In [None]:
%load_ext tensorboard
%tensorboard --logdir=$log_dir

## Load model and create prediction files

In [None]:
net = Net1Q.load_from_checkpoint('1Q_clipped_resized128_128_144_roibest_metric: 0.2797 at epoch 495.ckpt',                  mode = '1Q',    roi_size=(160*2, 160, 160), spatial_size=(160*2, 160, 160))
net = Net2Q.load_from_checkpoint('2Q_clipped_resized320_160_160.ckpt',                                                      mode = '2Q',    roi_size=(160*2, 160, 160), spatial_size=(160*2, 160, 160))
net = NetWhole.load_from_checkpoint('whole_clipped_resized160_best_metric: 0.8211 at epoch 562.ckpt',                       mode = 'whole', roi_size=(128, 128, 144),   spatial_size=(128, 128, 144))

net.prepare_data(prepare_test_data=False)

Loading dataset: 100%|██████████| 9/9 [01:05<00:00,  7.26s/it]


In [None]:
# test validation dataset labels
for i, val_data in enumerate(net.val_dataloader()):
    label = val_data['label'].cpu().numpy()[0, 0, :, :, :]
    nib.save(nib.Nifti1Image(label.astype(float), nib.load('AeroPath/1/1_CT_HR_label_airways.nii.gz').affine), f'labels_spacingd_10less/{i}.nii.gz')

In [None]:
# model = Net.load_from_checkpoint('checkpoints/best-checkpoint_whole_64_0.8023 at epoch: 427.ckpt')

# model.eval()
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model.to(device)

# if net is None:
#     net = Net.load_from_checkpoint('checkpoints/best-checkpoint_whole_64_0.8023 at epoch: 427.ckpt')
#     net.prepare_data()

net.eval()
device = torch.device("cuda:0")
net.to(device)
with torch.no_grad():
    for i, val_data in enumerate(net.val_dataloader()):
        # roi_size = (64, 64, 64)
        roi_size = net.roi_size
        # roi_size = (128, 128, 144)
        sw_batch_size = 4
        val_outputs = sliding_window_inference(val_data["image"].to(device), roi_size, sw_batch_size, net)
        # plot the slice [:, :, 80]
        # plt.figure("check", (18, 6))
        # plt.subplot(1, 3, 1)
        # plt.title(f"image {i}")
        # plt.imshow(val_data["image"][0, 0, :, :, 80], cmap="gray")
        # plt.subplot(1, 3, 2)
        # plt.title(f"label {i}")
        # plt.imshow(val_data["label"][0, 0, :, :, 80])
        # plt.subplot(1, 3, 3)
        # plt.title(f"output {i}")
        pred = torch.argmax(val_outputs, dim=1).detach().cpu()
        # plt.imshow(pred[0, :, :, 80])
        plt.show()


        pred_np = pred.cpu().numpy()[0, :, :, :]

        # dir_name = 'whole_resized_roi160'
        # dir_name = '1Q_resized_roi160'
        # # dir_name = '1Q'
        # # dir_name = '2Q'

        dir_name = net.mode

        nib.save(nib.Nifti1Image(pred_np.astype(float), nib.load('AeroPath/1/1_CT_HR_label_airways.nii.gz').affine), f'pred/{dir_name}/{i}.nii.gz')

        # label = val_data['label'].cpu().numpy()[0, 0, :, :, :]
        # nib.save(nib.Nifti1Image(label.astype(float), nib.load('AeroPath/1/1_CT_HR_label_airways.nii.gz').affine), f'labels_resized/{i}.nii.gz')

        # print(f"pred shape: {pred_np.shape}, label shape: {label.shape}")


## Interpolate, Concat and ensemble predictions

In [None]:
import numpy as np
import torch
import torch.nn.functional as F
import nibabel as nib
from monai.transforms import Compose, AsDiscrete
from monai.metrics import DiceMetric

def interpolate_predictions(predictions, target_shape):
    # Ensure predictions tensor has 5 dimensions: (N, C, D, H, W)
    predictions = torch.tensor(predictions).unsqueeze(0).unsqueeze(0)  # Add batch and channel dimensions

    # Interpolate the predictions using nearest neighbor interpolation
    interpolated_predictions = F.interpolate(predictions, size=target_shape, mode='nearest')
    
    # Remove batch and channel dimensions after interpolation
    interpolated_predictions = interpolated_predictions.squeeze(0).squeeze(0).numpy()
    
    return interpolated_predictions

# Function to calculate Dice score
def calc_dice_score(pred, label):
    post_pred = Compose([AsDiscrete(argmax=False, to_onehot=2)])
    post_label = Compose([AsDiscrete(to_onehot=2)])

    # Ensure the inputs have the correct dimensions
    pred = torch.tensor(pred).unsqueeze(0).unsqueeze(0)
    label = torch.tensor(label).unsqueeze(0).unsqueeze(0)

    prediction_tensor = post_pred(pred)
    label_tensor = post_label(label)

    # Compute Dice score
    dice_metric = DiceMetric(include_background=False, reduction="mean")
    dice_metric(y_pred=prediction_tensor, y=label_tensor)

    dice_score = dice_metric.aggregate().item()
    dice_metric.reset()

    return dice_score


labels = natsorted(glob.glob('labels/*', recursive=True))


# Example usage with NIfTI files

for idx, _ in enumerate(labels):
    label = nib.load(f'labels_resized/{idx}.nii.gz').get_fdata()
    # whole = nib.load(f'pred/whole_resized_roi160/{idx}.nii.gz').get_fdata()
    whole = nib.load(f'pred/whole/{idx}.nii.gz').get_fdata()

    pred_1Q = nib.load(f'pred/1Q_resized_roi160/{idx}.nii.gz').get_fdata()
    pred_2Q = nib.load(f'pred/2Q_resized_roi160/{idx}.nii.gz').get_fdata()


    # Interpolate predictions to allow for concatenation

    inter_pred_1Q = interpolate_predictions(pred_1Q, (*whole.shape[:-1], pred_1Q.shape[-1]))
    inter_pred_2Q = interpolate_predictions(pred_2Q, (*whole.shape[:-1], pred_2Q.shape[-1]))


    # Merge interpolated predictions along the last axis (z-axis)
    merged = np.concatenate((inter_pred_1Q, inter_pred_2Q), axis=2)

    # Interpolate predictions to match whole.shape

    merged = interpolate_predictions(merged, whole.shape)

    # Save the merged predictions as a NIfTI file
    merged_nifti = nib.Nifti1Image(merged, nib.load('AeroPath/1/1_CT_HR.nii.gz').affine)
    nib.save(merged_nifti, 'pred_test_merged.nii.gz')
    # print(f'Merged shape: {merged.shape}')

    # Perform ensembling using the maximum values
    ensembled = np.maximum(merged, whole)
    
    ensembled[:5, :, :] = 0
    ensembled[-5:, :, :] = 0
    ensembled[:, :5, :] = 0
    ensembled[:, -5:, :] = 0
    ensembled[:, :, :5] = 0
    ensembled[:, :, -5:] = 0

    # Save the ensembled predictions as a NIfTI file
    ensembled_nifti = nib.Nifti1Image(ensembled, nib.load('AeroPath/1/1_CT_HR.nii.gz').affine)
    nib.save(ensembled_nifti, f'pred/ensembled/{idx}.nii.gz')

    # Compute Dice score
    # print(f'whole: {whole.shape}, label: {label.shape}')
    dice_whole = calc_dice_score(whole, label)
    print(f'Dice Score for whole prediction: {dice_whole:.4f}')



    dice_ensembled = calc_dice_score(ensembled, label)
    print(f'Dice Score for ensembled prediction: {dice_ensembled:.4f}')

    print('*'*50)

# Dice Score for whole prediction: 0.8711
# Dice Score for ensembled prediction: 0.8846
# **************************************************
# Dice Score for whole prediction: 0.7663
# Dice Score for ensembled prediction: 0.7592
# **************************************************
# Dice Score for whole prediction: 0.7318
# Dice Score for ensembled prediction: 0.7432
# **************************************************
# Dice Score for whole prediction: 0.8730
# Dice Score for ensembled prediction: 0.8764
# **************************************************
# Dice Score for whole prediction: 0.8862
# Dice Score for ensembled prediction: 0.8930
# **************************************************
# Dice Score for whole prediction: 0.7021
# Dice Score for ensembled prediction: 0.7184
# **************************************************
# Dice Score for whole prediction: 0.7940
# Dice Score for ensembled prediction: 0.8396
# **************************************************
# Dice Score for whole prediction: 0.6629
# Dice Score for ensembled prediction: 0.6776
# **************************************************
# Dice Score for whole prediction: 0.8969
# Dice Score for ensembled prediction: 0.9105
# **************************************************

Dice Score for whole prediction: 0.8429
Dice Score for ensembled prediction: 0.8514
**************************************************
Dice Score for whole prediction: 0.7267
Dice Score for ensembled prediction: 0.7223
**************************************************
Dice Score for whole prediction: 0.7134
Dice Score for ensembled prediction: 0.7225
**************************************************
Dice Score for whole prediction: 0.8573
Dice Score for ensembled prediction: 0.8615
**************************************************
Dice Score for whole prediction: 0.8257
Dice Score for ensembled prediction: 0.8436
**************************************************
Dice Score for whole prediction: 0.7070
Dice Score for ensembled prediction: 0.7301
**************************************************
Dice Score for whole prediction: 0.7620
Dice Score for ensembled prediction: 0.7973
**************************************************
Dice Score for whole prediction: 0.6476
Dice Score for 

In [None]:
ensembled_nifti = nib.Nifti1Image(ensembled, nib.load('AeroPath/1/1_CT_HR.nii.gz').affine)
nib.save(ensembled_nifti, f'pred/ensembled/{idx}.nii.gz')

# Compute Dice score
dice_whole = calc_dice_score(whole, label)
print(f'Dice Score for whole prediction: {dice_whole:.4f}')

dice_ensembled = calc_dice_score(ensembled, label)
print(f'Dice Score for ensembled prediction: {dice_ensembled:.4f}')

Dice Score for whole prediction: 0.9051
Dice Score for ensembled prediction: 0.8460


In [None]:
def calc_dice_score(pred, label):
    post_pred = Compose([AsDiscrete(argmax=False, to_onehot=2)])
    post_label = Compose([AsDiscrete(to_onehot=2)])

    # Ensure the inputs have the correct dimensions
    pred = torch.tensor(pred).unsqueeze(0).unsqueeze(0)
    label = torch.tensor(label).unsqueeze(0).unsqueeze(0)

    prediction_tensor = post_pred(pred)
    label_tensor = post_label(label)

    # Compute Dice score
    dice_metric = DiceMetric(include_background=False, reduction="mean")
    dice_metric(y_pred=prediction_tensor, y=label_tensor)

    dice_score = dice_metric.aggregate().item()
    dice_metric.reset()

    return dice_score

label = val_data['label'].cpu().numpy()[0, 0, :, :, :]
dice_whole = calc_dice_score(pred, label)
print(f'Dice Score for whole prediction: {dice_whole:.4f}')

  pred = torch.tensor(pred).unsqueeze(0).unsqueeze(0)


Dice Score for whole prediction: 0.4997


## Cleanup data directory

Remove directory if a temporary was used.

In [None]:
if directory is None:
    shutil.rmtree(root_dir)