## Setup environment

In [1]:
!python -c "import monai" || pip install -q "monai-weekly[nibabel]"
!python -c "import matplotlib" || pip install -q matplotlib
!pip install -q pytorch-lightning~=2.0
%matplotlib inline

/bin/bash: line 1: python: command not found


## Setup imports

In [2]:
import pytorch_lightning
from pytorch_lightning.callbacks import ModelCheckpoint
from monai.utils import set_determinism
from monai.transforms import (
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    ScaleIntensityRanged,
    Spacingd,
    EnsureType,
    EnsureTyped,
)
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, list_data_collate, decollate_batch, DataLoader
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
import nibabel as nib

print_config()

2024-06-20 19:53:07.073935: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


MONAI version: 1.3.0
Numpy version: 1.23.5
Pytorch version: 2.0.0+cu117
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 865972f7a791bf7b42efbcd87c8402bd865b329e
MONAI __file__: /home/<username>/.local/lib/python3.10/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.2.1
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
scipy version: 1.8.0
Pillow version: 9.0.1
Tensorboard version: 2.12.2
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.15.1+cu117
tqdm version: 4.65.0
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.7
pandas version: 1.5.3
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED 

## Setup data directory

You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  
This allows you to save results and reuse downloads.  
If not specified a temporary directory will be used.

In [3]:
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

/tmp/tmpuznlvoil


In [3]:
data_dir = '/home/gasyna/RiSA_S3/3D_segmentation/AeroPath'

## Download dataset

Downloads and extracts the dataset.
The dataset comes from http://medicaldecathlon.com/.

In [4]:
resource = "https://zenodo.org/records/10069289/files/AeroPath.zip?download=1"
md5 = "3fd5106c175c85d60eaece220f5dfd87"

compressed_file = os.path.join(root_dir, "AeroPath.zip")
if not os.path.exists(data_dir):
    download_and_extract(resource, compressed_file, root_dir, md5)

## Define the LightningModule

The LightningModule contains a refactoring of your training code. The following module is a refactoring of the code in `spleen_segmentation_3d.ipynb`:

In [None]:
checkpoint_callback = ModelCheckpoint(
    monitor='val_dice',
    dirpath=os.path.join(root_dir, 'checkpoints'),  # Directory to save checkpoints
    filename='best-checkpoint_whole_64',  # Filename prefix for saving checkpoints
    save_top_k=1,  # Save only the best checkpoint
    mode='max',  # `min` for minimizing the metric, `max` for maximizing
    verbose=True,  # Log a message when saving the best checkpoint
)

In [7]:
class Net(pytorch_lightning.LightningModule):
    def __init__(self):
        super().__init__()
        self._model = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm=Norm.BATCH,
        )
        self.loss_function = DiceLoss(to_onehot_y=True, softmax=True)
        self.post_pred = Compose([EnsureType("tensor", device="cpu"), AsDiscrete(argmax=True, to_onehot=2)])
        self.post_label = Compose([EnsureType("tensor", device="cpu"), AsDiscrete(to_onehot=2)])
        self.dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
        self.best_val_dice = 0
        self.best_val_epoch = 0
        self.validation_step_outputs = []


        self.common_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(
                keys=["image", "label"],
                pixdim=(1.5, 1.5, 2.0),
                mode=("bilinear", "nearest"),
            ),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-57,
                a_max=164,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            EnsureTyped(keys=["image", "label"]),
        ]
        )


    def forward(self, x):
        return self._model(x)

    def prepare_data(self):
        # # set up the correct data path
        # pattern = os.path.join(data_dir, '**/*_CT_HR_label_airways.nii.gz')
        # train_labels = sorted(glob.glob(pattern, recursive=True))

        # pattern = os.path.join(data_dir, '**/*_CT_HR.nii.gz')
        # train_images = sorted(glob.glob(pattern, recursive=True))

        pattern = os.path.join('overlapping_labels', '**/quadrant_1_*.nii.gz')
        train_labels = sorted(glob.glob(pattern, recursive=True))

        pattern = os.path.join('overlapping_quadrants', '**/quadrant_1_*_CT_HR.nii.gz')
        train_images = sorted(glob.glob(pattern, recursive=True))

        data_dicts = [
            {"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)
        ]
        train_files, val_files = data_dicts[:-9], data_dicts[-9:]

        # set deterministic training for reproducibility
        set_determinism(seed=0)

        # define the data transforms
        train_transforms = Compose(
            [
                LoadImaged(keys=["image", "label"]),
                EnsureChannelFirstd(keys=["image", "label"]),
                Orientationd(keys=["image", "label"], axcodes="RAS"),
                Spacingd(
                    keys=["image", "label"],
                    pixdim=(1.5, 1.5, 2.0),
                    mode=("bilinear", "nearest"),
                ),
                ScaleIntensityRanged(
                    keys=["image"],
                    a_min=-57,
                    a_max=164,
                    b_min=0.0,
                    b_max=1.0,
                    clip=True,
                ),
                CropForegroundd(keys=["image", "label"], source_key="image"),
                # randomly crop out patch samples from
                # big image based on pos / neg ratio
                # the image centers of negative samples
                # must be in valid image area

                RandCropByPosNegLabeld(
                    keys=["image", "label"],
                    label_key="label",
                    spatial_size=(64, 64, 64),
                    pos=1,
                    neg=1,
                    num_samples=4,
                    image_key="image",
                    image_threshold=0,
                ),

                # user can also add other random transforms
                #                 RandAffined(
                #                     keys=['image', 'label'],
                #                     mode=('bilinear', 'nearest'),
                #                     prob=1.0,
                #                     spatial_size=(96, 96, 96),
                #                     rotate_range=(0, 0, np.pi/15),
                #                     scale_range=(0.1, 0.1, 0.1)),
            ]
        )
        val_transforms = Compose(
            [
                LoadImaged(keys=["image", "label"]),
                EnsureChannelFirstd(keys=["image", "label"]),
                Orientationd(keys=["image", "label"], axcodes="RAS"),
                Spacingd(
                    keys=["image", "label"],
                    pixdim=(1.5, 1.5, 2.0),
                    mode=("bilinear", "nearest"),
                ),
                ScaleIntensityRanged(
                    keys=["image"],
                    a_min=-57,
                    a_max=164,
                    b_min=0.0,
                    b_max=1.0,
                    clip=True,
                ),
                CropForegroundd(keys=["image", "label"], source_key="image"),
            ]
        )
                    

        # we use cached datasets - these are 10x faster than regular datasets
        self.train_ds = CacheDataset(
            data=train_files,
            transform=train_transforms,
            cache_rate=1.0,
            num_workers=4,
        )
        self.val_ds = CacheDataset(
            data=val_files,
            transform=val_transforms,
            cache_rate=1.0,
            num_workers=4,
        )


    def train_dataloader(self):
        train_loader = DataLoader(
            self.train_ds,
            batch_size=2,
            shuffle=True,
            num_workers=4,
            collate_fn=list_data_collate,
        )
        return train_loader

    def val_dataloader(self):
        val_loader = DataLoader(self.val_ds, batch_size=1, num_workers=4)
        return val_loader

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self._model.parameters(), 1e-4)
        return optimizer

    def training_step(self, batch, batch_idx):
        images, labels = batch["image"], batch["label"]
        output = self.forward(images)
        loss = self.loss_function(output, labels)
        tensorboard_logs = {"train_loss": loss.item()}
        return {"loss": loss, "log": tensorboard_logs}

    
    def validation_step(self, batch, batch_idx):
        images, labels = batch["image"], batch["label"]
        roi_size = (160, 160, 160)
        sw_batch_size = 4
        outputs = sliding_window_inference(images, roi_size, sw_batch_size, self)
        loss = self.loss_function(outputs, labels)
        outputs = [self.post_pred(i) for i in decollate_batch(outputs)]
        labels = [self.post_label(i) for i in decollate_batch(labels)]
        self.dice_metric(y_pred=outputs, y=labels)
        d = {"val_loss": loss, "val_number": len(outputs)}
        self.validation_step_outputs.append(d)
        return d
    


    
    def perform_inference(self, model, data):
        # Perform inference using the model
        with torch.no_grad():
            data = torch.DoubleTensor(data)  # Convert data to type Double
            model_output = model(data.unsqueeze(0))
        return model_output

    def on_validation_epoch_end(self):
        val_loss, num_items = 0, 0
        for output in self.validation_step_outputs:
            val_loss += output["val_loss"].sum().item()
            num_items += output["val_number"]
        mean_val_dice = self.dice_metric.aggregate().item()
        self.dice_metric.reset()
        mean_val_loss = torch.tensor(val_loss / num_items)
        tensorboard_logs = {
            "val_dice": mean_val_dice,
            "val_loss": mean_val_loss,
        }
        if mean_val_dice > self.best_val_dice:
            self.best_val_dice = mean_val_dice
            self.best_val_epoch = self.current_epoch
        print(
            f"current epoch: {self.current_epoch} "
            f"current mean dice: {mean_val_dice:.4f}"
            f"\nbest mean dice: {self.best_val_dice:.4f} "
            f"at epoch: {self.best_val_epoch}"
        )
        self.validation_step_outputs.clear()  # free memory
        self.log('val_dice', mean_val_dice, on_step=False, on_epoch=True, prog_bar=True, logger=True) # log

        return {"log": tensorboard_logs}
    
    def infer_on_single_image(self, data):
        self._model.eval()
        with torch.no_grad():
            roi_size = (64, 64, 64)
            sw_batch_size = 4
            outputs = sliding_window_inference(data, roi_size, sw_batch_size, self._model)
            post_processed_outputs = [self.post_pred(i) for i in decollate_batch(outputs)]
        return post_processed_outputs
    

    def infer(self, file_path, label_path = None):
        self._model.eval()
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self._model.to(device)

        data_dict = {"image": file_path, "label": label_path}

        # Apply transformations
        data_dict = self.common_transforms(data_dict)
        data = data_dict["image"]
        label = data_dict["label"]

        # Convert data and label to tensors and add batch dimension

        data = torch.tensor(data).unsqueeze(0).to(device)
        label = torch.tensor(label).unsqueeze(0).to(device)


        # Run inference
        predictions = self.infer_on_single_image(data)

        # Process and save the predictions as needed
        for i, prediction in enumerate(predictions):
            # Save or process each prediction here

            prediction_np = prediction.cpu().numpy()
            # Save as NIfTI file, for testing only
            pred_img = nib.Nifti1Image(prediction_np, nib.load(file_path).affine)

        # Convert predictions and labels to binary format if necessary
        post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
        post_label = Compose([AsDiscrete(to_onehot=2)])

        # Apply the transformations directly to the tensors and ensure they are on the same device
        prediction_tensor = post_pred(prediction.to(device))
        label_tensor = post_label(label.to(device))

        dice_metric = DiceMetric(include_background=True, reduction="mean")
        dice_metric(y_pred=prediction_tensor, y=label_tensor)
        dice_score = dice_metric.aggregate().item()
        dice_metric.reset()

        print(f"Dice Score: {dice_score:.4f}")
        print("Inference complete.")

        return prediction_tensor.cpu(), label_tensor.cpu()
    
    def dice_score(self, prediction_tensor, label_tensor):
        # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
        # post_label = Compose([AsDiscrete(to_onehot=2)])

        # prediction_tensor = post_pred(prediction.to(device))
        # label_tensor = post_label(label.to(device))

        # Compute Dice score
        dice_metric = DiceMetric(include_background=True, reduction="mean")
        dice_metric(y_pred=prediction_tensor, y=label_tensor)
        dice_score = dice_metric.aggregate().item()
        dice_metric.reset()

        print(dice_score)




Testing Dice score on sample data

In [10]:
import numpy as np
from scipy.ndimage import zoom

model = Net.load_from_checkpoint('checkpoints/best-checkpoint_new_method_1Q_spatial_size_48_0.6382 at epoch 338.ckpt')
file_path = 'nonoverlapping_quadrants/7/quadrant_1_7_CT_HR.nii.gz'
label_path = 'nonoverlapping_labels/7/quadrant_1_7_CT_HR_label_airways.nii.gz'

pred_1Q, label_1Q = model.infer(file_path, label_path)

del model

model = Net.load_from_checkpoint('checkpoints/best-checkpoint_new_method_2Q_spatial_size_48_0.7714 at epoch 584.ckpt')
file_path = 'nonoverlapping_quadrants/7/quadrant_2_7_CT_HR.nii.gz'
label_path = 'nonoverlapping_labels/7/quadrant_2_7_CT_HR_label_airways.nii.gz'

pred_2Q, label_2Q = model.infer(file_path, label_path)

del model

model = Net.load_from_checkpoint('checkpoints/best-checkpoint_whole_48_0.7861 at epoch 411.ckpt')
file_path = 'AeroPath/7/7_CT_HR_label_lungs.nii.gz'
label_path = 'AeroPath/7/7_CT_HR_label_lungs.nii.gz'

pred_whole, label_whole = model.infer(file_path, label_path)



def interpolate_predictions(predictions, target_shape):
    # Extract the original shape of the predictions
    original_shape = predictions.shape
    # Compute the scaling factors for interpolation along each axis
    scale_factors = [t / o for t, o in zip(target_shape, original_shape)]
    # Interpolate the predictions using zoom
    interpolated_predictions = zoom(predictions, zoom=scale_factors, mode='nearest')
    return interpolated_predictions

pred_1Q_copy = pred_1Q
pred_2Q_copy = pred_2Q


# interpolate the predictions to the same shape as the whole image

pred_1Q = interpolate_predictions(pred_1Q, (*list(pred_whole.shape[:3]), pred_1Q.shape[3]))
pred_2Q = interpolate_predictions(pred_2Q, (*list(pred_whole.shape[:3]), pred_2Q.shape[3]))


# concatenate the predictions

pred_Q1_Q2 = np.concatenate((pred_1Q, pred_2Q), axis=3)


# use the maximum value to ensemble the predictions

ensemble_maximum = np.maximum(pred_Q1_Q2, pred_whole)



device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model.dice_score(torch.tensor(ensemble_maximum).unsqueeze(0).to(device), torch.tensor(label_whole).unsqueeze(0).to(device))



  data = torch.tensor(data).unsqueeze(0).to(device)
  label = torch.tensor(label).unsqueeze(0).to(device)


Dice Score: 0.5216
Inference complete.
Dice Score: 0.6703
Inference complete.
Dice Score: 0.9771
Inference complete.
0.5117281079292297


  model.dice_score(torch.tensor(ensemble_maximum).unsqueeze(0).to(device), torch.tensor(label_whole).unsqueeze(0).to(device))


In [11]:
pred_1Q = interpolate_predictions(pred_1Q, (*list(pred_whole.shape[:3]), pred_1Q.shape[3]))
pred_2Q = interpolate_predictions(pred_2Q, (*list(pred_whole.shape[:3]), pred_2Q.shape[3]))

pred_Q1_Q2 = np.concatenate((pred_1Q, pred_2Q), axis=3)

ensemble_maximum = np.maximum(pred_Q1_Q2, pred_whole)


## Create overlapping quadrants

In [28]:
from natsort import natsorted

def create_overlapping_quadrants(image_shape, overlap_percentage=0.0):
    x, y, z = image_shape
    overlap_z = int(z * overlap_percentage)
    if z < 3 * overlap_z:
        return [(0, z)]
    mid_z = z // 2
    return [
        (0, mid_z + overlap_z),
        (mid_z - overlap_z, z)
    ]

def load_nifti_file(file_path):
    return nib.load(file_path)

def save_nifti_file(data, affine, file_path):
    img = nib.Nifti1Image(data, affine)
    nib.save(img, file_path)

def process_file(file_path, output_dir):
    img = load_nifti_file(file_path)
    data = img.get_fdata()
    print("input shape: ", data.shape)
    affine = img.affine
    quadrants = create_overlapping_quadrants(data.shape)

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    for i, (start, end) in enumerate(quadrants):
        quadrant_data = data[:, :, start:end]
        quadrant_file_path = os.path.join(output_dir, f"quadrant_{i+1}_{os.path.basename(file_path)}")
        save_nifti_file(quadrant_data, affine, quadrant_file_path)
        print(f"Saved: {quadrant_file_path}")
        print(f'Quadrant {i} shape: {quadrant_data.shape}')


# Example Usage
pattern = os.path.join(data_dir, '**/*_CT_HR_label_airways.nii.gz')
train_labels = natsorted(glob.glob(pattern, recursive=True))

pattern = os.path.join(data_dir, '**/*_CT_HR.nii.gz')
train_images = natsorted(glob.glob(pattern, recursive=True))



for idx, (image_name, label_name) in enumerate(zip(train_images, train_labels)):
    output_dir = f'nonoverlapping_quadrants/{idx + 1}'
    process_file(image_name, output_dir)
    output_dir = f'nonoverlapping_labels/{idx + 1}'
    process_file(label_name, output_dir)
    
    print('')
    print('###########')
    print('')


input shape:  (512, 512, 767)
Saved: nonoverlapping_quadrants/1/quadrant_1_1_CT_HR.nii.gz
Quadrant 0 shape: (512, 512, 383)
Saved: nonoverlapping_quadrants/1/quadrant_2_1_CT_HR.nii.gz
Quadrant 1 shape: (512, 512, 384)
input shape:  (512, 512, 767)
Saved: nonoverlapping_labels/1/quadrant_1_1_CT_HR_label_airways.nii.gz
Quadrant 0 shape: (512, 512, 383)
Saved: nonoverlapping_labels/1/quadrant_2_1_CT_HR_label_airways.nii.gz
Quadrant 1 shape: (512, 512, 384)

###########

input shape:  (512, 512, 829)
Saved: nonoverlapping_quadrants/2/quadrant_1_2_CT_HR.nii.gz
Quadrant 0 shape: (512, 512, 414)
Saved: nonoverlapping_quadrants/2/quadrant_2_2_CT_HR.nii.gz
Quadrant 1 shape: (512, 512, 415)
input shape:  (512, 512, 829)
Saved: nonoverlapping_labels/2/quadrant_1_2_CT_HR_label_airways.nii.gz
Quadrant 0 shape: (512, 512, 414)
Saved: nonoverlapping_labels/2/quadrant_2_2_CT_HR_label_airways.nii.gz
Quadrant 1 shape: (512, 512, 415)

###########

input shape:  (512, 512, 714)
Saved: nonoverlapping_quad

In [31]:
pattern = os.path.join(data_dir, '**/*_CT_HR.nii.gz')
train_images = sorted(glob.glob(pattern, recursive=True))
train_images

['/home/gasyna/RiSA_S3/3D_segmentation/AeroPath/1/1_CT_HR.nii.gz',
 '/home/gasyna/RiSA_S3/3D_segmentation/AeroPath/10/10_CT_HR.nii.gz',
 '/home/gasyna/RiSA_S3/3D_segmentation/AeroPath/11/11_CT_HR.nii.gz',
 '/home/gasyna/RiSA_S3/3D_segmentation/AeroPath/12/12_CT_HR.nii.gz',
 '/home/gasyna/RiSA_S3/3D_segmentation/AeroPath/13/13_CT_HR.nii.gz',
 '/home/gasyna/RiSA_S3/3D_segmentation/AeroPath/14/14_CT_HR.nii.gz',
 '/home/gasyna/RiSA_S3/3D_segmentation/AeroPath/15/15_CT_HR.nii.gz',
 '/home/gasyna/RiSA_S3/3D_segmentation/AeroPath/16/16_CT_HR.nii.gz',
 '/home/gasyna/RiSA_S3/3D_segmentation/AeroPath/17/17_CT_HR.nii.gz',
 '/home/gasyna/RiSA_S3/3D_segmentation/AeroPath/18/18_CT_HR.nii.gz',
 '/home/gasyna/RiSA_S3/3D_segmentation/AeroPath/19/19_CT_HR.nii.gz',
 '/home/gasyna/RiSA_S3/3D_segmentation/AeroPath/2/2_CT_HR.nii.gz',
 '/home/gasyna/RiSA_S3/3D_segmentation/AeroPath/20/20_CT_HR.nii.gz',
 '/home/gasyna/RiSA_S3/3D_segmentation/AeroPath/21/21_CT_HR.nii.gz',
 '/home/gasyna/RiSA_S3/3D_segmentation

In [None]:

pattern = os.path.join('overlapping_labels', '**/quadrant_1_*.nii.gz')
train_labels = sorted(glob.glob(pattern, recursive=True))

pattern = os.path.join('overlapping_quadrants', '**/quadrant_1_*_CT_HR.nii.gz')
train_images = sorted(glob.glob(pattern, recursive=True))

In [7]:
# Check dimensions of random sample

img = 'nonoverlapping_quadrants/24/quadrant_2_7_CT_HR.nii.gz'
label = 'nonoverlapping_labels/24/quadrant_2_7_CT_HR_label_airways.nii.gz'
img = nib.load(img).get_fdata()
label = nib.load(label).get_fdata()

print(f'img shape: {img.shape}, label shape: {label.shape}')


# whole
# img shape: (512, 512, 723), label shape: (512, 512, 723)

# 1q
# img shape: (512, 512, 361), label shape: (512, 512, 361)

# 2q
# img shape: (512, 512, 362), label shape: (512, 512, 362)

img shape: (512, 512, 362), label shape: (512, 512, 362)


## Run the training

In [None]:
# initialise the LightningModule
net = Net()

# set up loggers and checkpoints
log_dir = os.path.join(root_dir, "logs")
tb_logger = pytorch_lightning.loggers.TensorBoardLogger(save_dir=log_dir)

# initialise Lightning's trainer.
trainer = pytorch_lightning.Trainer(
    devices=[0],
    max_epochs=600,
    logger=tb_logger,
    enable_checkpointing=True,
    callbacks=[checkpoint_callback],
    num_sanity_val_steps=1,
    log_every_n_steps=16,
)

# train
trainer.fit(net)

In [None]:
print(f"train completed, best_metric: {net.best_val_dice:.4f} " f"at epoch {net.best_val_epoch}")

In [None]:
import torch
from monai.networks.nets import UNet


# Load the model weights from the checkpoint file
checkpoint_path = 'best-checkpoint.ckpt'
model = Net.load_from_checkpoint(checkpoint_path)

# Set the model to evaluation mode
model.eval()


## Model Ensembling

## Dice score

In [19]:
import numpy as np
from monai.data.meta_tensor import MetaTensor

common_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-57,
            a_max=164,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        EnsureTyped(keys=["image", "label"]),
    ]
)


def infer(img): #TODO metoda Infer powinna być dodana do klasy Net()
    model.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    
def merge_predictions(pred1: MetaTensor, pred2: MetaTensor) -> MetaTensor:
    pred_merged = np.concatenate((pred1, pred2), axis=3)
    return pred_merged




torch.Size([2, 259, 259, 182])
(2, 259, 259, 364)


In [None]:
from natsort import natsorted


file_paths_1Q = natsorted(glob.glob('nonoverlapping_labels/*/quadrant_1_*_CT_HR_label_airways.nii.gz', recursive=True))
label_paths_1Q = natsorted(glob.glob('nonoverlapping_quadrants/*/quadrant_1_*_CT_HR.nii.gz', recursive=True))


file_paths_2Q = natsorted(glob.glob('nonoverlapping_labels/*/quadrant_2_*_CT_HR_label_airways.nii.gz', recursive=True))
label_paths_2Q = natsorted(glob.glob('nonoverlapping_quadrants/*/quadrant_2_*_CT_HR.nii.gz', recursive=True))

file_paths_whole = natsorted(glob.glob('AeroPath/*/*_CT_HR.nii.gz', recursive=True))
label_paths_whole = natsorted(glob.glob('AeroPath/*/*_CT_HR_label_airways.nii.gz', recursive=True))


In [12]:
# Model inference and Dice score calculation
# Load the trained model
# del model
# model = Net.load_from_checkpoint('checkpoints/best-checkpoint_new_method_1Q_spatial_size_48_0.6382 at epoch 338.ckpt')
# model = Net.load_from_checkpoint('checkpoints/best-checkpoint_new_method_2Q_spatial_size_48_0.7714 at epoch 584.ckpt')
moel = Net.load_from_checkpoint('checkpoints/best-checkpoint_whole_48_0.7861 at epoch 411.ckpt')
model.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Define the transformations for validation and inference
common_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-57,
            a_max=164,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        EnsureTyped(keys=["image", "label"]),
    ]
)






# Convert data and label to tensors and add batch dimension
for idx, (file_path, label_path) in enumerate(zip(file_paths_whole, label_paths_whole)):
    prediction_filename = f'predictions_whole/whole_{idx+1}_prediction.nii.gz'

    data_dict = {"image": file_path, "label": label_path}


    # Apply transformations
    data_dict = common_transforms(data_dict)
    data = data_dict["image"]
    label = data_dict["label"]

    data = torch.tensor(data).unsqueeze(0).to(device)
    label = torch.tensor(label).unsqueeze(0).to(device)

    data_nii = nib.load(file_path).get_fdata()
    label_nii = nib.load(label_path).get_fdata()

    print("input nii gz shape: ", data_nii.shape)
    print("input shape: ", data.shape)

    # Function to run inference
    def infer_on_single_image(model, data):
        model.eval()
        with torch.no_grad():
            roi_size = (48, 48, 48)
            sw_batch_size = 4
            outputs = sliding_window_inference(data, roi_size, sw_batch_size, model)
            post_processed_outputs = [model.post_pred(i) for i in decollate_batch(outputs)]
        return post_processed_outputs

    # Run inference
    predictions = infer_on_single_image(model, data)

    # Process and save the predictions as needed
    for i, prediction in enumerate(predictions):
        # Save or process each prediction here
        # For example, save as NIfTI file
        prediction_np = prediction.cpu().numpy()
        print(type(prediction_np))
        pred_img = nib.Nifti1Image(prediction_np, nib.load(file_path).affine)
        nib.save(pred_img, prediction_filename)

        # label_np = label.cpu().numpy()
        # label_img = nib.Nifti1Image(label_np, nib.load(label_path).affine)
        # nib.save(label_img, f'labels_whole/whole_{idx+1}_label.nii.gz')


while False:
    # Convert predictions and labels to binary format if necessary
    post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
    post_label = Compose([AsDiscrete(to_onehot=2)])

    # Apply the transformations directly to the tensors and ensure they are on the same device
    prediction_tensor = post_pred(prediction.to(device))
    label_tensor = post_label(label.to(device))

    print('pred type:', type(prediction_tensor))
    # Compute Dice score
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    dice_metric(y_pred=prediction_tensor, y=label_tensor)
    dice_score = dice_metric.aggregate().item()
    dice_metric.reset()

    print("input shape: ", data.shape)
    print("output shape: ", pred_img.shape)
    print(f"Dice Score: {dice_score:.4f}")
    print("Inference complete.")

    pred_nii = nib.load(prediction_filename).get_fdata()

    print("output nii gz shape: ", pred_nii.shape)


  data = torch.tensor(data).unsqueeze(0).to(device)
  label = torch.tensor(label).unsqueeze(0).to(device)


input nii gz shape:  (512, 512, 767)
input shape:  torch.Size([1, 1, 234, 234, 192])
<class 'numpy.ndarray'>
input nii gz shape:  (512, 512, 829)
input shape:  torch.Size([1, 1, 254, 213, 208])
<class 'numpy.ndarray'>
input nii gz shape:  (512, 512, 714)
input shape:  torch.Size([1, 1, 235, 214, 179])
<class 'numpy.ndarray'>
input nii gz shape:  (487, 487, 598)
input shape:  torch.Size([1, 1, 220, 209, 150])
<class 'numpy.ndarray'>
input nii gz shape:  (512, 512, 619)
input shape:  torch.Size([1, 1, 222, 187, 156])
<class 'numpy.ndarray'>
input nii gz shape:  (487, 441, 575)
input shape:  torch.Size([1, 1, 216, 168, 144])
<class 'numpy.ndarray'>
input nii gz shape:  (512, 512, 723)
input shape:  torch.Size([1, 1, 259, 194, 182])
<class 'numpy.ndarray'>
input nii gz shape:  (512, 512, 533)
input shape:  torch.Size([1, 1, 251, 221, 134])
<class 'numpy.ndarray'>
input nii gz shape:  (512, 512, 727)
input shape:  torch.Size([1, 1, 257, 230, 182])
<class 'numpy.ndarray'>
input nii gz shape:

In [12]:
pred_1Q_paths = natsorted(glob.glob('predictions1Q/*', recursive=True))
pred_2Q_paths = natsorted(glob.glob('predictions2Q/*', recursive=True))
pred_whole_paths = natsorted(glob.glob('predictions_whole/*', recursive=True))
label_paths = natsorted(glob.glob('nonoverlapping_labels/*/*_CT_HR_label_airways.nii.gz', recursive=True))

In [16]:
import numpy as np
from scipy.ndimage import zoom

# Define a function to interpolate predictions to match the dimensions of the WHOLE model
# def interpolate_predictions(predictions, target_shape):
#     # Extract the original shape of the predictions
#     original_shape = predictions.shape
#     # Compute the scaling factors for interpolation along each axis
#     scale_factors = [t / o for t, o in zip(target_shape, original_shape)]
#     # Interpolate the predictions using zoom
#     interpolated_predictions = zoom(predictions, zoom=scale_factors, mode='nearest')
#     return interpolated_predictions

import torch.nn.functional as F

def interpolate_predictions(predictions, target_shape):
    # Ensure predictions tensor has 5 dimensions
    predictions = torch.tensor(predictions)
    if predictions.dim() == 6:
        predictions = predictions.squeeze(0)  # Remove batch dimension if present
    
    # Interpolate the predictions using bilinear interpolation
    interpolated_predictions = F.interpolate(predictions, size=target_shape[2:], mode='nearest')
    
    return interpolated_predictions

common_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-57,
            a_max=164,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        EnsureTyped(keys=["image", "label"]),
    ]
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


for idx, (pred_1Q_path, pred_2Q_path, pred_whole_path, label_path) in enumerate(zip(pred_1Q_paths, pred_2Q_paths, pred_whole_paths, label_paths)):
    pred_1Q = nib.load(pred_1Q_path).get_fdata()
    pred_2Q = nib.load(pred_2Q_path).get_fdata()
    pred_whole = nib.load(pred_whole_path).get_fdata()
    label = nib.load(label_path).get_fdata()

    print(f'pred_1Q shape: {pred_1Q.shape}, pred_2Q shape: {pred_2Q.shape}, pred_whole shape: {pred_whole.shape}, label shape: {label.shape}')

    merged =np.concatenate((pred_1Q, pred_2Q), axis=3)
    # print(f'merged shape: {merged.shape}')

    interpolated_merged = interpolate_predictions(merged, pred_whole.shape)

    # print(f'interpolated_merged shape: {interpolated_merged.shape}')

    # Save the interpolated predictions
    interpolated_merged = interpolated_merged.cpu().numpy()
    merged = nib.Nifti1Image(interpolated_merged, nib.load(pred_whole_path).affine)
    nib.save(merged, f'predictions_merged/merged_{idx+1}_prediction.nii.gz')

    ensembled = np.maximum(interpolated_merged, pred_whole)
    ensembled_nifti = nib.Nifti1Image(ensembled, nib.load(pred_whole_path).affine)

    print('difference merged = ', (interpolated_merged != pred_whole).sum()/pred_whole.size, '%')
    print('difference ensembled = ', 100 * (ensembled != pred_whole).sum()/pred_whole.size, '%')

    nib.save(ensembled_nifti, f'predictions_ensembled/ensembled_{idx+1}_prediction.nii.gz')

    continue 

    #TODO FIX THE DICE SCORE CALCULATION (COMMON TRANSISTIONS)
    prediction = torch.tensor(ensembled).unsqueeze(0).to(device)

    data_dict = {"image": pred_whole_path, "label": label_path}
    data_dict = common_transforms(data_dict)
    data = data_dict["image"]
    label = data_dict["label"]

    label = torch.tensor(label).unsqueeze(0).to(device)


    post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
    post_label = Compose([AsDiscrete(to_onehot=2)])

    prediction_tensor = post_pred(prediction.to(device))
    label_tensor = post_label(label.to(device))

    # Compute Dice score
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    dice_metric(y_pred=prediction_tensor, y=label_tensor)
    dice_score = dice_metric.aggregate().item()
    dice_metric.reset()

    print('DICE: ', dice_score)
    print('idx: ', idx)



pred_1Q shape: (2, 234, 234, 96), pred_2Q shape: (2, 234, 234, 97), pred_whole shape: (2, 234, 234, 192), label shape: (512, 512, 383)
difference merged =  0.001886399055202474 %
difference ensembled =  0.000943199527601237 %
pred_1Q shape: (2, 254, 254, 104), pred_2Q shape: (2, 254, 254, 104), pred_whole shape: (2, 254, 213, 208), label shape: (512, 512, 384)
difference merged =  0.0014681136485783263 %
difference ensembled =  0.0007340568242891632 %
pred_1Q shape: (2, 235, 235, 90), pred_2Q shape: (2, 235, 235, 90), pred_whole shape: (2, 235, 214, 179), label shape: (512, 512, 414)
difference merged =  0.0009895677695066936 %
difference ensembled =  0.0004947838847533468 %
pred_1Q shape: (2, 220, 220, 76), pred_2Q shape: (2, 220, 220, 76), pred_whole shape: (2, 220, 209, 150), label shape: (512, 512, 415)
difference merged =  0.001497317674351167 %
difference ensembled =  0.0007486588371755836 %
pred_1Q shape: (2, 222, 222, 78), pred_2Q shape: (2, 222, 222, 78), pred_whole shape: (2,

In [50]:
pred_ensemble = nib.load('predictions_whole/whole_9_prediction.nii.gz').get_fdata()

label = nib.load('labels_whole/whole_9_label.nii.gz').get_fdata()

pred_ensemble.shape
label.shape

(1, 1, 257, 230, 182)

In [26]:
ensemble_paths = natsorted(glob.glob('predictions_ensembled/*', recursive=True))
whole_paths = natsorted(glob.glob('predictions_whole/*', recursive=True))
labels_paths = natsorted(glob.glob('labels_whole/*', recursive=True))


def calc_dice_score(outputs, labels, device="cuda"):
    # Convert outputs and labels to tensors and apply post-processing transformations
    outputs = [EnsureType("tensor", device=device)(output) for output in outputs]
    labels = [EnsureType("tensor", device=device)(label) for label in labels]
    outputs = [AsDiscrete(argmax=True, to_onehot=2)(output) for output in outputs]
    labels = [AsDiscrete(to_onehot=2)(label) for label in labels]

    # Initialize DiceMetric with appropriate settings
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    # Compute Dice score
    for output, label in zip(outputs, labels):
        dice_metric(y_pred=output, y=label)

    # Aggregate Dice scores
    dice_score = dice_metric.aggregate().item()

    return dice_score

for idx, (ensemble, whole, label) in enumerate(zip(ensemble_paths, whole_paths, labels_paths)):

    pred_ensemble = nib.load(ensemble).get_fdata()
    pred_whole = nib.load(whole).get_fdata()

    label = nib.load(label).get_fdata()




    dice_ensembled = calc_dice_score([pred_ensemble], [label])
    dice_whole = calc_dice_score([pred_whole], [label])

    print(f'ensembled: {dice_ensembled}, whole: {dice_whole}, difference: {100 * (pred_ensemble != pred_whole).sum() / pred_whole.size}%')


ensembled: 0.668695867061615, whole: 0.669117271900177, difference: 0.0943199527601237%
ensembled: 0.6395779848098755, whole: 0.6395779848098755, difference: 0.07340568242891632%
ensembled: 0.5913287997245789, whole: 0.5912396311759949, difference: 0.04947838847533468%
ensembled: 0.6233521103858948, whole: 0.6249682307243347, difference: 0.07486588371755835%
ensembled: 0.6002840995788574, whole: 0.5980216264724731, difference: 0.054908878438290204%
ensembled: 0.6032496690750122, whole: 0.6032496690750122, difference: 0.05280862666568685%
ensembled: 0.6036602854728699, whole: 0.6038090586662292, difference: 0.04505306419886685%
ensembled: 0.5743035674095154, whole: 0.5768281817436218, difference: 0.04349455692459445%
ensembled: 0.6304212808609009, whole: 0.6304212808609009, difference: 0.0712259319094034%
ensembled: 0.68426513671875, whole: 0.6845105290412903, difference: 0.10439876420952734%
ensembled: 0.6173056364059448, whole: 0.6173912286758423, difference: 0.06916975962655157%
ense

## Interpolate ensemble

In [12]:
pred_1Q_paths = natsorted(glob.glob('predictions1Q/*', recursive=True))
pred_2Q_paths = natsorted(glob.glob('predictions2Q/*', recursive=True))
pred_whole_paths = natsorted(glob.glob('predictions_whole/*', recursive=True))
label_paths = natsorted(glob.glob('AeroPath/*/*_CT_HR_label_airways.nii.gz', recursive=True))


In [37]:
pred_whole = nib.load(pred_whole_path).get_fdata()
label = nib.load(label_path).get_fdata()

## View training in tensorboard

Please uncomment the following cell to load tensorboard results.

In [None]:
%load_ext tensorboard
%tensorboard --logdir=$log_dir

## Check best model output with the input image and label

In [None]:
net.eval()
device = torch.device("cuda:0")
net.to(device)
with torch.no_grad():
    for i, val_data in enumerate(net.val_dataloader()):
        roi_size = (160, 160, 160)
        sw_batch_size = 4
        val_outputs = sliding_window_inference(val_data["image"].to(device), roi_size, sw_batch_size, net)
        # plot the slice [:, :, 80]
        plt.figure("check", (18, 6))
        plt.subplot(1, 3, 1)
        plt.title(f"image {i}")
        plt.imshow(val_data["image"][0, 0, :, :, 80], cmap="gray")
        plt.subplot(1, 3, 2)
        plt.title(f"label {i}")
        plt.imshow(val_data["label"][0, 0, :, :, 80])
        plt.subplot(1, 3, 3)
        plt.title(f"output {i}")
        plt.imshow(torch.argmax(val_outputs, dim=1).detach().cpu()[0, :, :, 80])
        plt.show()

## Cleanup data directory

Remove directory if a temporary was used.

In [None]:
if directory is None:
    shutil.rmtree(root_dir)