Copyright (c) MONAI Consortium  
Licensed under the Apache License, Version 2.0 (the "License");  
you may not use this file except in compliance with the License.  
You may obtain a copy of the License at  
&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  
Unless required by applicable law or agreed to in writing, software  
distributed under the License is distributed on an "AS IS" BASIS,  
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
See the License for the specific language governing permissions and  
limitations under the License.

# Spleen 3D segmentation with MONAI

This tutorial demonstrates how MONAI can be used in conjunction with the [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) framework.

We demonstrate use of the following MONAI features:
1. Transforms for dictionary-format data.
1. Loading NIfTI images with metadata.
1. Adding a channel dimension to the data if it has none.
1. Scaling medical image intensity to an expected range.
1. Cropping out a batch of balanced patches based on the positive / negative label ratio.
1. Caching IO and transforms to accelerate training and validation.
1. Use of a 3D UNet model, Dice loss function, and mean Dice metric for a 3D segmentation task.
1. The sliding window inference method.
1. Deterministic training for reproducibility.
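
The balanced cropping in this pipeline is done by `RandCropByPosNegLabeld`. As a rough illustration of the idea (a simplified sketch, not MONAI's implementation), the hypothetical helper `sample_crop_centers` below draws each crop centre from the foreground voxels with probability `pos / (pos + neg)` and from the background otherwise. Unlike the real transform, it does not move centres so that the whole patch stays inside the image.

```python
import numpy as np


def sample_crop_centers(label, pos=1, neg=1, num_samples=4, seed=0):
    """Hypothetical, simplified pos/neg centre sampling (not MONAI's implementation).

    Each of the `num_samples` centres is drawn from the foreground voxels with
    probability pos / (pos + neg), otherwise from the background voxels.
    """
    rng = np.random.default_rng(seed)
    fg = np.argwhere(label > 0)   # (N_fg, ndim) foreground voxel coordinates
    bg = np.argwhere(label == 0)  # (N_bg, ndim) background voxel coordinates
    pos_ratio = pos / (pos + neg)
    centers = []
    for _ in range(num_samples):
        use_fg = rng.random() < pos_ratio and len(fg) > 0
        pool = fg if use_fg else bg
        centers.append(tuple(pool[rng.integers(len(pool))]))
    return centers


# Toy 3D label: a small foreground cube inside a 32^3 volume
label = np.zeros((32, 32, 32), dtype=np.uint8)
label[10:14, 10:14, 10:14] = 1
centers = sample_crop_centers(label, pos=1, neg=1, num_samples=4)
print(centers)
```

With `pos=1, neg=1` roughly half of the centres land on foreground, which is what keeps small structures represented in the training patches.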

The Spleen dataset can be downloaded from http://medicaldecathlon.com/.

![spleen](http://medicaldecathlon.com/img/spleen0.png)

Target: Spleen  
Modality: CT  
Size: 61 3D volumes (41 Training + 20 Testing)  
Source: Memorial Sloan Kettering Cancer Center  
Challenge: Large ranging foreground size

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/3d_segmentation/spleen_segmentation_3d_lightning.ipynb)

## Setup environment

In [None]:
!python -c "import monai" || pip install -q "monai-weekly[nibabel]"
!python -c "import matplotlib" || pip install -q matplotlib
!pip install -q pytorch-lightning~=2.0
%matplotlib inline

## Setup imports

In [2]:
import pytorch_lightning
from pytorch_lightning.callbacks import ModelCheckpoint
from monai.utils import set_determinism
from monai.transforms import (
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    ScaleIntensityRanged,
    Spacingd,
    EnsureType,
)
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, list_data_collate, decollate_batch, DataLoader
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
import nibabel as nib

print_config()

MONAI version: 1.3.0
Numpy version: 1.23.5
Pytorch version: 2.0.0+cu117
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 865972f7a791bf7b42efbcd87c8402bd865b329e
MONAI __file__: /home/<username>/.local/lib/python3.10/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.2.1
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
scipy version: 1.8.0
Pillow version: 9.0.1
Tensorboard version: 2.12.2
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.15.1+cu117
tqdm version: 4.65.0
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.7
pandas version: 1.5.3
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED 

## Setup data directory

You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  
This allows you to save results and reuse downloads.  
If it is not specified, a temporary directory will be used.

In [2]:
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

/tmp/tmppuho5ada


In [3]:
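# Location of the extracted AeroPath dataset (assumes the archive unpacks into an `AeroPath` folder under `root_dir`)
data_dir = os.path.join(root_dir, "AeroPath")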

## Download dataset

Downloads and extracts the dataset.
The AeroPath airway-segmentation dataset is downloaded from Zenodo (record 10069289); the download is skipped if `data_dir` already exists.

In [None]:
resource = "https://zenodo.org/records/10069289/files/AeroPath.zip?download=1"
md5 = "3fd5106c175c85d60eaece220f5dfd87"

compressed_file = os.path.join(root_dir, "AeroPath.zip")
if not os.path.exists(data_dir):
    download_and_extract(resource, compressed_file, root_dir, md5)

## Define the LightningModule

The LightningModule contains a refactoring of your training code. The following module is a refactoring of the code in `spleen_segmentation_3d.ipynb`:

In [4]:
checkpoint_callback = ModelCheckpoint(
    monitor='val_dice',
    dirpath=os.path.join(root_dir, 'checkpoints'),  # Directory to save checkpoints
    filename='best-checkpoint_whole_64',  # Filename prefix for saving checkpoints
    save_top_k=1,  # Save only the best checkpoint
    mode='max',  # `min` for minimizing the metric, `max` for maximizing
    verbose=True,  # Log a message when saving the best checkpoint
)

In [12]:
class Net(pytorch_lightning.LightningModule):
    def __init__(self):
        super().__init__()
        self._model = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm=Norm.BATCH,
        )
        self.loss_function = DiceLoss(to_onehot_y=True, softmax=True)
        self.post_pred = Compose([EnsureType("tensor", device="cpu"), AsDiscrete(argmax=True, to_onehot=2)])
        self.post_label = Compose([EnsureType("tensor", device="cpu"), AsDiscrete(to_onehot=2)])
        self.dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
        self.best_val_dice = 0
        self.best_val_epoch = 0
        self.validation_step_outputs = []

    def forward(self, x):
        return self._model(x)

    def prepare_data(self):
        # # set up the correct data path
        # pattern = os.path.join(data_dir, '**/*_CT_HR_label_airways.nii.gz')
        # train_labels = sorted(glob.glob(pattern, recursive=True))

        # pattern = os.path.join(data_dir, '**/*_CT_HR.nii.gz')
        # train_images = sorted(glob.glob(pattern, recursive=True))

        pattern = os.path.join('overlapping_labels', '**/quadrant_1_*.nii.gz')
        train_labels = sorted(glob.glob(pattern, recursive=True))

        pattern = os.path.join('overlapping_quadrants', '**/quadrant_1_*_CT_HR.nii.gz')
        train_images = sorted(glob.glob(pattern, recursive=True))

        data_dicts = [
            {"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)
        ]
        train_files, val_files = data_dicts[:-9], data_dicts[-9:]

        # set deterministic training for reproducibility
        set_determinism(seed=0)

        # define the data transforms
        train_transforms = Compose(
            [
                LoadImaged(keys=["image", "label"]),
                EnsureChannelFirstd(keys=["image", "label"]),
                Orientationd(keys=["image", "label"], axcodes="RAS"),
                Spacingd(
                    keys=["image", "label"],
                    pixdim=(1.5, 1.5, 2.0),
                    mode=("bilinear", "nearest"),
                ),
                ScaleIntensityRanged(
                    keys=["image"],
                    a_min=-57,
                    a_max=164,
                    b_min=0.0,
                    b_max=1.0,
                    clip=True,
                ),
                CropForegroundd(keys=["image", "label"], source_key="image"),
                # randomly crop out patch samples from
                # big image based on pos / neg ratio
                # the image centers of negative samples
                # must be in valid image area

                RandCropByPosNegLabeld(
                    keys=["image", "label"],
                    label_key="label",
                    spatial_size=(64, 64, 64),
                    pos=1,
                    neg=1,
                    num_samples=4,
                    image_key="image",
                    image_threshold=0,
                ),

                # user can also add other random transforms
                #                 RandAffined(
                #                     keys=['image', 'label'],
                #                     mode=('bilinear', 'nearest'),
                #                     prob=1.0,
                #                     spatial_size=(96, 96, 96),
                #                     rotate_range=(0, 0, np.pi/15),
                #                     scale_range=(0.1, 0.1, 0.1)),
            ]
        )
        val_transforms = Compose(
            [
                LoadImaged(keys=["image", "label"]),
                EnsureChannelFirstd(keys=["image", "label"]),
                Orientationd(keys=["image", "label"], axcodes="RAS"),
                Spacingd(
                    keys=["image", "label"],
                    pixdim=(1.5, 1.5, 2.0),
                    mode=("bilinear", "nearest"),
                ),
                ScaleIntensityRanged(
                    keys=["image"],
                    a_min=-57,
                    a_max=164,
                    b_min=0.0,
                    b_max=1.0,
                    clip=True,
                ),
                CropForegroundd(keys=["image", "label"], source_key="image"),
            ]
        )
                    

        # cached datasets keep the output of the deterministic transforms in memory,
        # so they are not recomputed every epoch, which substantially speeds up training
        self.train_ds = CacheDataset(
            data=train_files,
            transform=train_transforms,
            cache_rate=1.0,
            num_workers=4,
        )
        self.val_ds = CacheDataset(
            data=val_files,
            transform=val_transforms,
            cache_rate=1.0,
            num_workers=4,
        )

        self.infer_ds = CacheDataset(
            data=val_files,
            cache_rate=1.0,
            num_workers=4,
        )

    #         self.train_ds = monai.data.Dataset(
    #             data=train_files, transform=train_transforms)
    #         self.val_ds = monai.data.Dataset(
    #             data=val_files, transform=val_transforms)

    def train_dataloader(self):
        train_loader = DataLoader(
            self.train_ds,
            batch_size=2,
            shuffle=True,
            num_workers=4,
            collate_fn=list_data_collate,
        )
        return train_loader

    def val_dataloader(self):
        val_loader = DataLoader(self.val_ds, batch_size=1, num_workers=4)
        return val_loader

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self._model.parameters(), 1e-4)
        return optimizer

    def training_step(self, batch, batch_idx):
        images, labels = batch["image"], batch["label"]
        output = self.forward(images)
        loss = self.loss_function(output, labels)
        # Lightning 2.x no longer reads a "log" key from the returned dict, so log explicitly
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    
    def validation_step(self, batch, batch_idx):
        images, labels = batch["image"], batch["label"]
        roi_size = (160, 160, 160)
        sw_batch_size = 4
        outputs = sliding_window_inference(images, roi_size, sw_batch_size, self)
        loss = self.loss_function(outputs, labels)
        outputs = [self.post_pred(i) for i in decollate_batch(outputs)]
        labels = [self.post_label(i) for i in decollate_batch(labels)]
        self.dice_metric(y_pred=outputs, y=labels)
        d = {"val_loss": loss, "val_number": len(outputs)}
        self.validation_step_outputs.append(d)
        return d
    
    # def evaluate(self):
    #     images, labels = self.infer_ds["image"], self.infer_ds["label"]

    #     print(images)
    #     roi_size = (160, 160, 160)
    #     sw_batch_size = 4
    #     outputs = sliding_window_inference(images, roi_size, sw_batch_size, self)
    #     loss = self.loss_function(outputs, labels)
    #     outputs = [self.post_pred(i) for i in decollate_batch(outputs)]
    #     labels = [self.post_label(i) for i in decollate_batch(labels)]
    #     self.dice_metric(y_pred=outputs, y=labels)
    #     d = {"val_loss": loss, "val_number": len(outputs)}
    #     # self.validation_step_outputs.append(d)

    #     print(d)
    #     return outputs
    
    # def evaluate(self):
    #     infer_loader = DataLoader(self.infer_ds, batch_size=1, num_workers=4)
    #     outputs = []

    #     for batch in infer_loader:
    #         images, labels = batch["image"], batch["label"]
    #         roi_size = (160, 160, 160)
    #         sw_batch_size = 4
    #         output = sliding_window_inference(images, roi_size, sw_batch_size, self)
    #         output = [self.post_pred(i) for i in decollate_batch(output)]
    #         labels = [self.post_label(i) for i in decollate_batch(labels)]

    #         outputs.extend(output)
        
    #         loss = self.loss_function(outputs, labels)
    #         d = {"val_loss": loss, "val_number": len(outputs)}
    #         print(d)

    #     return outputs
    # def evaluate(self):
    #     # Get the images and labels from the inference dataset
    #     dataset_images = [item['image'] for item in self.infer_ds]
    #     dataset_labels = [item['label'] for item in self.infer_ds]

    #     # Convert images and labels to tensors
    #     images = torch.stack(dataset_images)
    #     labels = torch.stack(dataset_labels)

    #     # Perform inference
    #     roi_size = (160, 160, 160)
    #     sw_batch_size = 4
    #     outputs = sliding_window_inference(images, roi_size, sw_batch_size, self)
    #     loss = self.loss_function(outputs, labels)
    #     outputs = [self.post_pred(i) for i in decollate_batch(outputs)]
    #     labels = [self.post_label(i) for i in decollate_batch(labels)]
    #     self.dice_metric(y_pred=outputs, y=labels)
    #     d = {"val_loss": loss, "val_number": len(outputs)}
        
    #     print(d)
    #     return outputs
    
    def evaluate_single_model(self, model, file_paths):
        # Load and preprocess NIfTI files
        input_data = [self.load_and_preprocess_nifti(file_path) for file_path in file_paths]

        # Perform inference using the model
        model_output = [self.perform_inference(model, data) for data in input_data]

        return model_output

    def load_and_preprocess_nifti(self, file_path):
        img = nib.load(file_path)
        data = img.get_fdata()
        # The network weights are float32, so convert to float32 (not double) and add a
        # channel dimension: (H, W, D) -> (1, H, W, D)
        data = torch.as_tensor(data, dtype=torch.float32).unsqueeze(0)
        # Apply any further preprocessing here, e.g. the intensity scaling used in training
        return data

    def perform_inference(self, model, data, roi_size=(160, 160, 160), sw_batch_size=4):
        # Add a batch dimension, (1, H, W, D) -> (1, 1, H, W, D), and use sliding-window
        # inference as in `validation_step`, so volumes of any size can be processed
        device = next(model.parameters()).device
        with torch.no_grad():
            model_output = sliding_window_inference(data.unsqueeze(0).to(device), roi_size, sw_batch_size, model)
        return model_output

    def on_validation_epoch_end(self):
        val_loss, num_items = 0, 0
        for output in self.validation_step_outputs:
            val_loss += output["val_loss"].sum().item()
            num_items += output["val_number"]
        mean_val_dice = self.dice_metric.aggregate().item()
        self.dice_metric.reset()
        mean_val_loss = torch.tensor(val_loss / num_items)
        if mean_val_dice > self.best_val_dice:
            self.best_val_dice = mean_val_dice
            self.best_val_epoch = self.current_epoch
        print(
            f"current epoch: {self.current_epoch} "
            f"current mean dice: {mean_val_dice:.4f}"
            f"\nbest mean dice: {self.best_val_dice:.4f} "
            f"at epoch: {self.best_val_epoch}"
        )
        self.validation_step_outputs.clear()  # free memory
        # Lightning 2.x ignores the return value of this hook, so log the metrics explicitly;
        # `val_dice` is the quantity monitored by `checkpoint_callback`
        self.log("val_dice", mean_val_dice, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log("val_loss", mean_val_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
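
Both transform pipelines use `ScaleIntensityRanged` with `a_min=-57`, `a_max=164`, `b_min=0.0`, `b_max=1.0` and `clip=True`. The arithmetic is a linear map of `[a_min, a_max]` onto `[b_min, b_max]` followed by clipping; the NumPy sketch below restates it with a hypothetical helper `scale_intensity_range`. These window values come from the spleen tutorial. Since air is around -1000 HU, every voxel below -57 HU (airway lumen and most lung tissue included) is clipped to 0, so it may be worth checking whether this window suits airway segmentation.

```python
import numpy as np


def scale_intensity_range(x, a_min=-57.0, a_max=164.0, b_min=0.0, b_max=1.0, clip=True):
    """Linearly map [a_min, a_max] onto [b_min, b_max], optionally clipping the result.

    A NumPy restatement of the arithmetic, not MONAI's code.
    """
    x = np.asarray(x, dtype=np.float64)
    y = (x - a_min) / (a_max - a_min) * (b_max - b_min) + b_min
    if clip:
        y = np.clip(y, b_min, b_max)
    return y


# Example CT values in HU: air, window minimum, window midpoint, window maximum, dense bone
hu = np.array([-1000.0, -57.0, 53.5, 164.0, 1000.0])
print(scale_intensity_range(hu))
```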

In [None]:
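net = Net()

# Build the cached train/validation datasets and report their sizes
# (requires the quadrant files created by the splitting cell below)
net.prepare_data()
print(f"training samples: {len(net.train_ds)}, validation samples: {len(net.val_ds)}")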

In [None]:
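def create_overlapping_quadrants(image_shape, overlap_percentage=0.2):
    """Split the z-axis into two overlapping halves (despite the name, not four quadrants).

    Returns a list of (start, end) slice indices along z; the two halves overlap by
    2 * int(z * overlap_percentage) slices. If z < 3 * overlap_z, the whole range is returned.
    """
    x, y, z = image_shape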
    overlap_z = int(z * overlap_percentage)
    if z < 3 * overlap_z:
        return [(0, z)]
    mid_z = z // 2
    return [
        (0, mid_z + overlap_z),
        (mid_z - overlap_z, z)
    ]

def load_nifti_file(file_path):
    return nib.load(file_path)

def save_nifti_file(data, affine, file_path):
    img = nib.Nifti1Image(data, affine)
    nib.save(img, file_path)

def process_file(file_path, output_dir):
    img = load_nifti_file(file_path)
    data = img.get_fdata()
    affine = img.affine
    quadrants = create_overlapping_quadrants(data.shape)

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

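    for i, (start, end) in enumerate(quadrants):
        quadrant_data = data[:, :, start:end]
        # Shift the origin by `start` slices so the slab keeps its true world position
        quadrant_affine = affine.copy()
        quadrant_affine[:3, 3] = affine[:3, 2] * start + affine[:3, 3]
        quadrant_file_path = os.path.join(output_dir, f"quadrant_{i+1}_{os.path.basename(file_path)}")
        save_nifti_file(quadrant_data, quadrant_affine, quadrant_file_path)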
        print(f"Saved: {quadrant_file_path}")

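# Split every AeroPath CT volume and its airway label into overlapping slabs
pattern = os.path.join(data_dir, '**/*_CT_HR_label_airways.nii.gz')
train_labels = sorted(glob.glob(pattern, recursive=True))

pattern = os.path.join(data_dir, '**/*_CT_HR.nii.gz')
train_images = sorted(glob.glob(pattern, recursive=True))

for idx, (image_name, label_name) in enumerate(zip(train_images, train_labels)):
    output_dir = f'overlapping_quadrants/{idx}'
    process_file(image_name, output_dir)
    output_dir = f'overlapping_labels/{idx}'
    process_file(label_name, output_dir)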
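
`create_overlapping_quadrants` splits only the z-axis, into two halves that overlap by `2 * int(z * overlap_percentage)` slices. The sketch below reproduces that arithmetic with a hypothetical `slab_bounds` helper and shows why a slab that starts at slice `start` needs a shifted origin in its affine: without it, voxel `(i, j, k)` of the second slab would be placed where voxel `(i, j, k)` of the full volume is. The affine used here is made up for the example.

```python
import numpy as np


def slab_bounds(depth, overlap_percentage=0.2):
    """Same z-axis arithmetic as create_overlapping_quadrants (hypothetical helper)."""
    overlap_z = int(depth * overlap_percentage)
    if depth < 3 * overlap_z:
        return [(0, depth)]
    mid_z = depth // 2
    return [(0, mid_z + overlap_z), (mid_z - overlap_z, depth)]


def slab_affine(affine, start):
    """Shift the origin so voxel (i, j, k) of a slab cut at z=start maps to the
    same world position as voxel (i, j, k + start) of the full volume."""
    new_affine = np.array(affine, dtype=np.float64, copy=True)
    new_affine[:3, 3] = affine[:3, 2] * start + affine[:3, 3]
    return new_affine


bounds = slab_bounds(536)
print(bounds)  # [(0, 375), (161, 536)]

# Hypothetical affine: 0.7 x 0.7 x 0.5 mm voxels with an arbitrary origin
affine = np.diag([0.7, 0.7, 0.5, 1.0])
affine[:3, 3] = [-180.0, -180.0, -300.0]
start = bounds[1][0]
shifted = slab_affine(affine, start)

# The same physical voxel, addressed in the full volume and in the second slab
full_world = affine @ np.array([3.0, 4.0, start + 7.0, 1.0])
slab_world = shifted @ np.array([3.0, 4.0, 7.0, 1.0])
print(np.allclose(full_world, slab_world))  # True
```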


In [None]:
pattern = os.path.join('overlapping_labels', '**/quadrant_1_*.nii.gz')
train_labels = sorted(glob.glob(pattern, recursive=True))

pattern = os.path.join('overlapping_quadrants', '**/quadrant_1_*_CT_HR.nii.gz')
train_images = sorted(glob.glob(pattern, recursive=True))

In [None]:

for img, label in zip(train_images, train_labels):
    img = nib.load(img).get_fdata()
    label = nib.load(label).get_fdata()

    print(f'img shape: {img.shape}, label shape: {label.shape}')

In [None]:
import matplotlib.pyplot as plt

# quadrant1 = nib.load('quadrant1.nii.gz').get_fdata()
# quadrant2 = nib.load('quadrant2.nii.gz').get_fdata()
# quadrant3 = nib.load('quadrant3.nii.gz').get_fdata()
# quadrant4 = nib.load('quadrant4.nii.gz').get_fdata()


def plot_slice(data, title):
    plt.figure()
    plt.imshow(data[data.shape[0] // 2, :, :], cmap='gray')
    plt.title(title)
    plt.axis('off')
    plt.show()

# plot_slice(quadrant1, "Quadrant 1")
# plot_slice(quadrant2, "Quadrant 2")
# plot_slice(quadrant3, "Quadrant 3")
# plot_slice(quadrant4, "Quadrant 4")

In [None]:
new = nib.load('overlapping_quadrants/0/quadrant_1_1_CT_HR.nii.gz').get_fdata()

plot_slice(new, 'new')

In [None]:
new.shape

## Run the training

In [None]:
# initialise the LightningModule
net = Net()

# set up loggers and checkpoints
log_dir = os.path.join(root_dir, "logs")
tb_logger = pytorch_lightning.loggers.TensorBoardLogger(save_dir=log_dir)

# initialise Lightning's trainer.
trainer = pytorch_lightning.Trainer(
    devices=[0],
    max_epochs=600,
    logger=tb_logger,
    enable_checkpointing=True,
    callbacks=[checkpoint_callback],
    num_sanity_val_steps=1,
    log_every_n_steps=16,
)

# train
trainer.fit(net)

In [None]:
print(f"train completed, best_metric: {net.best_val_dice:.4f} " f"at epoch {net.best_val_epoch}")

In [None]:
# Load the best weights saved by `checkpoint_callback` during training
checkpoint_path = checkpoint_callback.best_model_path
model = Net.load_from_checkpoint(checkpoint_path)

# Set the model to evaluation mode
model.eval()


## Model Ensembling

In [15]:
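# Load a trained model from a checkpoint
model = Net.load_from_checkpoint('checkpoints/best-checkpoint_new_method_1Q_spatial_size_48_0.6382 at epoch 338.ckpt')
model.eval()


file_path = 'overlapping_quadrants/0/quadrant_1_1_CT_HR.nii.gz'
img = nib.load(file_path)
data = img.get_fdata()

# The network expects a float32 tensor shaped (batch, channel, H, W, D), but the raw
# volume is (H, W, D): add both dimensions (no intensity scaling is applied here)
data = torch.as_tensor(data, dtype=torch.float32)[None, None].to(model.device)

# A full 512 x 512 x 536 volume is likely too large for one forward pass, and its depth is not
# a multiple of the UNet's total stride (16), so run sliding-window inference instead
with torch.no_grad():
    y_hat = sliding_window_inference(data, (160, 160, 160), 4, model)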

### Model inference on a single volume

In [41]:
# Model inference

import torch
import nibabel as nib
from monai.transforms import (
    Spacingd, ScaleIntensityRanged, CropForegroundd, Orientationd, EnsureTyped, Compose
)
from monai.inferers import sliding_window_inference
from monai.data import decollate_batch

# Load the trained model
model = Net.load_from_checkpoint('checkpoints/best-checkpoint_new_method_1Q_spatial_size_96_0.6293 at epoch 340.ckpt')
model.eval()
model.to('cuda' if torch.cuda.is_available() else 'cpu')

# Define the transformations for new data
infer_transforms = Compose(
    [
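        # NOTE: `data` below is a plain array without an affine, so MONAI assumes an identity
        # affine and Orientationd does not reorient it (unlike in training, where LoadImaged
        # provided the real affine)
        Orientationd(keys="image", axcodes="RAS"),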
        # Spacingd(keys="image", pixdim=(1.5, 1.5, 2.0), mode=("bilinear")),
        ScaleIntensityRanged(keys="image", a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
        # CropForegroundd(keys="image", source_key="image"),
        EnsureTyped(keys="image"),
    ]
)

# Load and preprocess the new data
file_path = 'overlapping_quadrants/18/quadrant_1_26_CT_HR.nii.gz'
img = nib.load(file_path)
data = img.get_fdata()

# Manually add the channel dimension
data = data[None, ...]
data_raw = data
# Create a dictionary with the image
data_dict = {"image": data}

# Apply transformations
data_dict = infer_transforms(data_dict)

# Convert data to tensor and add batch dimension
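# `data_dict["image"]` is already a tensor after EnsureTyped; re-wrapping it with torch.tensor() copies it and warns
data = data_dict["image"].unsqueeze(0).to(model.device)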


# Function to run inference
def infer_on_single_image(model, data):
    model.eval()
    with torch.no_grad():
        roi_size = (160, 160, 160)
        sw_batch_size = 4
        outputs = sliding_window_inference(data, roi_size, sw_batch_size, model)
        post_processed_outputs = [model.post_pred(i) for i in decollate_batch(outputs)]
    return post_processed_outputs

# Run inference
predictions = infer_on_single_image(model, data)

# Process and save the predictions as needed
for i, prediction in enumerate(predictions):
    # Save or process each prediction here
    # For example, save as NIfTI file
    prediction_np = prediction.cpu().numpy()
    pred_img = nib.Nifti1Image(prediction_np, img.affine)
    # nib.save(pred_img, f"prediction_{i}.nii.gz")
    nib.save(pred_img, "prediction.nii.gz")

print("input raw shape: ", data_raw.shape)
print("input shape: ", data.shape)
print("output shape: ", pred_img.shape)
print("Inference complete.")




input raw shape:  (1, 512, 512, 185)
input shape:  torch.Size([1, 1, 512, 512, 185])
output shape:  (2, 512, 512, 185)
Inference complete.
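
`sliding_window_inference` tiles the volume with `roi_size` windows, runs the network on `sw_batch_size` windows at a time and stitches the predictions back together. The simplified sketch below (a hypothetical `window_starts` helper, assuming MONAI's default `overlap=0.25`) computes where the windows start along one axis; the last window is shifted back so it ends at the image border instead of running past it.

```python
import math


def window_starts(image_size, roi_size, overlap=0.25):
    """Start indices of sliding windows along one axis (simplified sketch).

    Windows advance by int(roi_size * (1 - overlap)) voxels; the last window is
    shifted back so that it ends exactly at the image border.
    """
    if image_size <= roi_size:
        return [0]  # a single (possibly padded) window covers the whole axis
    step = max(int(roi_size * (1 - overlap)), 1)
    n = math.ceil((image_size - roi_size) / step) + 1
    return [min(i * step, image_size - roi_size) for i in range(n)]


print(window_starts(512, 160))  # [0, 120, 240, 352]
print(window_starts(185, 160))  # [0, 25]
```

For the `(512, 512, 185)` input above, this would give 4 x 4 x 2 = 32 windows, processed 4 at a time with `sw_batch_size=4`.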


## Dice score
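
For binary masks, the Dice coefficient is `2 |P ∩ T| / (|P| + |T|)`, where `P` is the predicted foreground and `T` the ground-truth foreground; with `include_background=False`, `DiceMetric` reports this value for the foreground channel. A NumPy sketch with a hypothetical `binary_dice` helper:

```python
import numpy as np


def binary_dice(pred, target):
    """Dice coefficient 2|P ∩ T| / (|P| + |T|) for two binary masks of the same shape."""
    pred = np.asarray(pred).astype(bool)
    target = np.asarray(target).astype(bool)
    denom = pred.sum() + target.sum()
    if denom == 0:
        # Both masks empty: returning 1.0 is one convention; others return NaN or skip the case
        return 1.0
    intersection = np.logical_and(pred, target).sum()
    return float(2.0 * intersection / denom)


pred = np.array([1, 1, 0, 0])
target = np.array([1, 0, 1, 0])
print(binary_dice(pred, target))  # 0.5
```

Comparing the wrong channel (for example the background channel of a one-hot prediction) against a foreground label gives a number that looks plausible but is meaningless, so it helps to check which channel is being compared.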

In [31]:
label = nib.load('overlapping_labels/0/quadrant_1_1_CT_HR_label_airways.nii.gz')
# image = nib.load('overlapping_quadrants/0/quadrant_1_1_CT_HR.nii.gz')
image = nib.load('prediction_0.nii.gz')

label = label.get_fdata()
image = image.get_fdata()

image = image[1]  # the saved prediction is one-hot (2, H, W, D); channel 1 is the foreground mask

print(f'img shape: {image.shape}, label shape: {label.shape}')

img shape: (512, 512, 536), label shape: (512, 512, 536)


In [43]:
import torch
import nibabel as nib
from monai.metrics import DiceMetric
from monai.transforms import AsDiscrete

# Load the prediction and label NIfTI files
prediction_path = "prediction.nii.gz"  # Update with your actual prediction file path
label_path = "overlapping_labels/18/quadrant_1_26_CT_HR_label_airways.nii.gz"  # Update with your actual label file path
prediction_img = nib.load(prediction_path)
label_img = nib.load(label_path)

# The saved prediction is one-hot with shape (2, H, W, D): channel 0 is background and
# channel 1 is foreground, so channel 1 is the binary airway mask to compare with the label
prediction_data = prediction_img.get_fdata()[1]
label_data = label_img.get_fdata()

# Ensure the data has the same shape
assert prediction_data.shape == label_data.shape, "Prediction and label must have the same shape"

# Turn both binary masks into channel-first tensors (1, H, W, D), one-hot encode them to
# (2, H, W, D) and add a batch dimension -> (1, 2, H, W, D), the layout DiceMetric expects
to_onehot = AsDiscrete(to_onehot=2)
prediction_tensor = to_onehot(torch.as_tensor(prediction_data)[None]).unsqueeze(0)
label_tensor = to_onehot(torch.as_tensor(label_data)[None]).unsqueeze(0)

# Compute the Dice score on the foreground channel only
dice_metric = DiceMetric(include_background=False, reduction="mean")
dice_metric(y_pred=prediction_tensor, y=label_tensor)
dice_score = dice_metric.aggregate().item()
dice_metric.reset()

print(f"Dice Score: {dice_score:.4f}")

In [None]:
# Instantiate the Net class
model = Net()

# Load the checkpoint for the single model
checkpoint_path = 'checkpoints/best-checkpoint_new_method_1Q_spatial_size_48_0.6382 at epoch 338.ckpt'
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint['state_dict'])  # Lightning checkpoints store the module weights under 'state_dict'

# Set the model to evaluation mode
model.eval()

# Define the paths to the NIfTI files
file_paths = [
    'overlapping_quadrants/0/quadrant_1_1_CT_HR.nii.gz',
    # Add paths for other files if applicable
]

# Perform inference using the single model
model_output = model.evaluate_single_model(model, file_paths)


In [40]:
def load_and_preprocess_nifti(file_path):
    img = nib.load(file_path)
    data = img.get_fdata()
    # Apply any preprocessing steps here, such as normalization
    # (training used ScaleIntensityRanged with a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True)
    return data


# Reuse the `model` loaded from the checkpoint in the previous cell (a fresh `Net()` would be
# untrained). Its weights are float32, so the inputs are converted to float32 instead of
# casting the model to double.
model.eval()

# Load the NIfTI files listed in `file_paths`
input_data = [load_and_preprocess_nifti(file_path) for file_path in file_paths]

# Convert each (H, W, D) volume to a float32 tensor with batch and channel dimensions: (1, 1, H, W, D)
input_data = [torch.as_tensor(data, dtype=torch.float32)[None, None] for data in input_data]

with torch.no_grad():
    # Sliding-window inference keeps memory bounded for full-size volumes
    model_output = [
        sliding_window_inference(data.to(model.device), (160, 160, 160), 4, model) for data in input_data
    ]

In [None]:
import torch

# Define the paths to the checkpoints for each model
checkpoint_paths = [
    'checkpoints/best-checkpoint_new_method_1Q_spatial_size_48_0.6382 at epoch 338.ckpt',
    'checkpoints/best-checkpoint_new_method_2Q_spatial_size_48_0.7714 at epoch 584.ckpt'
    # Add paths for other models if applicable
]

# Create instances of the Net class for each model
models = []
for checkpoint_path in checkpoint_paths:
    # Initialize the model
    model = Net()

    # Load the checkpoint (Lightning stores the module weights under 'state_dict')
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint['state_dict'])

    # Set the model to evaluation mode
    model.eval()

    # Append the model to the list of models
    models.append(model)

# Now the models list contains instances of the Net class loaded with the trained weights from checkpoints


In [None]:
import nibabel as nib
import numpy as np
import torch


# Define a function to load and preprocess NIfTI files
def load_and_preprocess_nifti(file_path):
    img = nib.load(file_path)
    data = img.get_fdata()
    # Apply any preprocessing steps here, such as normalization
    # Remember to ensure the data shape matches the input shape expected by your models
    # Add batch and channel dimensions: (H, W, D) -> (1, 1, H, W, D)
    return torch.tensor(data, dtype=torch.float32)[None, None]

# Define a function to perform inference on NIfTI data using the ensemble of models
def perform_inference(models, input_data, roi_size=(160, 160, 160), sw_batch_size=4):
    model_outputs = []
    for model in models:
        model.eval()
        with torch.no_grad():
            # Sliding-window inference, as in validation, keeps memory bounded for full volumes
            output = sliding_window_inference(input_data.to(model.device), roi_size, sw_batch_size, model)
            model_outputs.append(output.cpu())
    # Aggregate the outputs (logits) of the different models by averaging
    aggregated_output = torch.mean(torch.stack(model_outputs), dim=0)
    return aggregated_output

# Define the paths to the NIfTI files and load them
file_paths = ["path_to_file1.nii.gz", "path_to_file2.nii.gz"]  # ... add further files as needed
input_data = [load_and_preprocess_nifti(file_path) for file_path in file_paths]

# Perform inference using the ensemble of models
ensemble_output = [perform_inference(models, data) for data in input_data]

# Post-process the output predictions if necessary
# For example, you can convert the logits to probabilities using softmax

# Save or visualize the results
# You can save the output as NIfTI files or visualize them using plotting libraries like matplotlib or mayavi
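
The cell above averages the raw logits of the models. A common alternative is to average each model's softmax probabilities and then take the arg-max class per voxel. The NumPy sketch below illustrates this with hypothetical helpers (`softmax`, `ensemble_segmentation`) and made-up logits for two models; it is one possible aggregation scheme, not the method used in the cell.

```python
import numpy as np


def softmax(logits, axis=0):
    """Numerically stable softmax along the class axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)


def ensemble_segmentation(logits_per_model):
    """Average per-model class probabilities, then take the arg-max class per voxel.

    `logits_per_model` is a list of arrays of shape (C, H, W, D), one per model.
    """
    probs = np.stack([softmax(logits, axis=0) for logits in logits_per_model])  # (M, C, H, W, D)
    mean_probs = probs.mean(axis=0)                                              # (C, H, W, D)
    return mean_probs.argmax(axis=0), mean_probs                                 # labels: (H, W, D)


# Two hypothetical models' logits for 2 classes on a 1 x 1 x 2 volume
m1 = np.array([[[[2.0, 0.0]]], [[[0.0, 2.0]]]])
m2 = np.array([[[[0.0, 0.0]]], [[[1.0, 0.0]]]])
labels, mean_probs = ensemble_segmentation([m1, m2])
print(labels)
```

Averaging probabilities keeps every model's contribution on the same 0 to 1 scale, whereas averaging logits lets a model with larger-magnitude outputs dominate.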

## View training in tensorboard

Run the following cell to load the TensorBoard results.

In [None]:
%load_ext tensorboard
%tensorboard --logdir=$log_dir

## Check best model output with the input image and label

In [None]:
net.eval()
device = torch.device("cuda:0")
net.to(device)
with torch.no_grad():
    for i, val_data in enumerate(net.val_dataloader()):
        roi_size = (160, 160, 160)
        sw_batch_size = 4
        val_outputs = sliding_window_inference(val_data["image"].to(device), roi_size, sw_batch_size, net)
        # plot the slice [:, :, 80]
        plt.figure("check", (18, 6))
        plt.subplot(1, 3, 1)
        plt.title(f"image {i}")
        plt.imshow(val_data["image"][0, 0, :, :, 80], cmap="gray")
        plt.subplot(1, 3, 2)
        plt.title(f"label {i}")
        plt.imshow(val_data["label"][0, 0, :, :, 80])
        plt.subplot(1, 3, 3)
        plt.title(f"output {i}")
        plt.imshow(torch.argmax(val_outputs, dim=1).detach().cpu()[0, :, :, 80])
        plt.show()

## Cleanup data directory

Remove the directory if a temporary one was used.

In [None]:
if directory is None:
    shutil.rmtree(root_dir)