# MONAI Sliding Inference and Post-processing
---

MONAI provides reference implementations of widely used inference patterns and evaluation metrics, so you can run model inference and assess model quality with well-tested building blocks.

MONAI also includes post-processing transforms for handling model outputs, for example removing segmentation noise or extracting the contour of a segmentation result.

To help you understand more about MONAI sliding window inference and post-processing, this guide answers five key questions:

1. **What is sliding inference?**
2. **What are post-processing transforms?**
3. **How do I use sliding inference on real data?**
4. **How do I use post-processing on my sliding inference results?**
5. **How can I use tensorboard to visualize the results?**


Let's get started by importing our dependencies. We'll also use the Jupyter `%load_ext` command to load TensorBoard so we can visualize our results later.

In [None]:
%load_ext tensorboard

import monai
monai.config.print_config()

## **1. What is sliding inference?**

#### A toy model for inference
For model inference on large volumes, the sliding window approach is a popular choice: it delivers good performance while keeping memory requirements flexible. It also supports `overlap` and blending `mode` settings that control how overlapping windows are combined, which can improve the quality of the stitched output.

A typical process is:

- Select continuous windows on the original image.
- Iteratively run batched window inferences until all windows are analyzed.
- Aggregate the inference outputs to a single segmentation map.
- Save the results to file or compute some evaluation metrics.

<img src="sliding_window.png" style="width: 700px;"/>
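The window-selection step can be sketched in plain Python. This is a minimal illustration, not MONAI's implementation: `window_starts` is a hypothetical helper, and the rule used here (step equal to `roi * (1 - overlap)`, with the last window clamped to the image border) is an assumption made for the example.

```python
import math
from itertools import product


def window_starts(size, roi, overlap):
    """Start offsets of windows of length `roi` that cover an axis of length `size`."""
    if roi >= size:
        return [0]
    # Assumed rule: neighbouring windows are shifted by roi * (1 - overlap).
    step = max(int(roi * (1 - overlap)), 1)
    count = math.ceil((size - roi) / step) + 1
    # Clamp the final window so that it ends exactly at the image border.
    return [min(i * step, size - roi) for i in range(count)]


# A 200x200 image, 40x40 windows, 50% overlap.
starts = window_starts(200, 40, 0.5)
windows = list(product(starts, starts))
print(starts)
print(len(starts), len(windows))  # 9 windows per axis, 81 in total
```

Under these assumptions, a larger overlap produces more windows and therefore more model calls, which is the cost paid for smoother stitching.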

The [sliding_window_inference](https://docs.monai.io/en/latest/inferers.html?highlight=sliding#sliding-window-inference) function requires a callable that takes a batch of image windows as input.

Here we construct a toy model. It has a single model parameter, `self.pred`. The inference outcome is just `input + self.pred`.

Every time the model is called, it also increases `self.pred` by one. This is to demonstrate that the model can be "stateful", and also so that we can conveniently visualize the inference outputs.

In [None]:
class ToyModel:
    # A simple model generates the output by adding an integer `pred` to input.
    # each call of this instance increases the integer by 1.
    pred = 0
    def __call__(self, input):
        self.pred = self.pred + 1
        return input + self.pred
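To make the "stateful" behaviour concrete, the sketch below repeats the toy model so that it runs on its own and calls it on plain integers instead of tensors (an assumption made only to keep the example dependency-free).

```python
class ToyModel:
    # Same toy model as above: adds a running counter to its input.
    pred = 0

    def __call__(self, input):
        self.pred = self.pred + 1
        return input + self.pred


model = ToyModel()
first = model(0)   # counter becomes 1
second = model(0)  # counter becomes 2
print(first, second)

# `self.pred = ...` stores the counter on the instance, so the class attribute
# stays at 0 and a fresh ToyModel() starts counting from 1 again.
fresh = ToyModel()
print(ToyModel.pred, fresh(0))
```

This is why the cells that follow pass a new `ToyModel()` to each `sliding_window_inference` call: every run then numbers its windows starting from 1.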

#### Run the inference using sliding window

We'll create a 200x200-pixel image of zeros, pass it to `sliding_window_inference` with a 40x40 window, and display the stitched output.

In [None]:
import torch
import matplotlib.pyplot as plt
from monai.inferers import sliding_window_inference

input_tensor = torch.zeros(1, 1, 200, 200)
output_tensor = sliding_window_inference(
    inputs=input_tensor, 
    predictor=ToyModel(), 
    roi_size=(40, 40), 
    sw_batch_size=1, 
    overlap=0.5, 
    mode="constant")
plt.imshow(output_tensor[0, 0])
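To help read the plot: with `mode="constant"`, every window prediction gets equal weight, so a pixel covered by several windows ends up with the average of their predictions. The 1D sketch below illustrates that idea in plain Python. It is not MONAI's code, and the names `CountingModel` and `stitch_constant` are hypothetical.

```python
class CountingModel:
    """Stand-in for the toy model: adds a running counter to every element."""

    def __init__(self):
        self.pred = 0

    def __call__(self, window):
        self.pred += 1
        return [value + self.pred for value in window]


def stitch_constant(size, roi, step, predictor):
    """Average overlapping window predictions with equal weights (1D)."""
    total = [0.0] * size
    count = [0] * size
    for start in range(0, size - roi + 1, step):
        window = [0.0] * roi  # all-zero input, as in the notebook
        for offset, value in enumerate(predictor(window)):
            total[start + offset] += value
            count[start + offset] += 1
    return [t / c for t, c in zip(total, count)]


# 10 pixels, windows of 4 pixels, 50% overlap (step of 2).
output = stitch_constant(size=10, roi=4, step=2, predictor=CountingModel())
print(output)
```

The windows predict 1, 2, 3, and 4 in turn. Pixels covered by a single window keep that value, while pixels in an overlap take the mean of the two windows that cover them, which produces the stepped pattern visible in the image.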

#### Gaussian weighted windows
For a given input image window, convolutional neural networks often predict the central region more accurately than the border regions, usually because of the receptive field of the stacked convolutions.

Therefore, it is worth considering a "Gaussian weighted" prediction that emphasizes the central region of each window when we stitch the windows into a complete inference output.

Simply changing the inference `mode` to `"gaussian"` makes the sliding window module use this weighted stitching.

In [None]:
input_tensor = torch.zeros(1, 1, 200, 200)
output_tensor_1 = sliding_window_inference(
    inputs=input_tensor, 
    predictor=ToyModel(), 
    roi_size=(40, 40), 
    sw_batch_size=1, 
    overlap=0.5, 
    mode="gaussian")
plt.imshow(output_tensor_1[0, 0])

Compared with the previous inference, the overlapping windows are now stitched together with fewer border artifacts.

In [None]:
plt.subplots(1, 2)
plt.subplot(1, 2, 1); plt.imshow(output_tensor[0, 0])
plt.subplot(1, 2, 2); plt.imshow(output_tensor_1[0, 0])
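The difference between the two plots can be reproduced in 1D with a small weighted-average sketch. This is an illustration under stated assumptions rather than MONAI's implementation: `gaussian_weights` and `stitch_weighted` are hypothetical helpers, and the sigma of one eighth of the window length is chosen only for the example.

```python
import math


def gaussian_weights(roi, sigma_scale=0.125):
    """Per-position weights that peak at the window centre (illustrative sigma)."""
    center = (roi - 1) / 2
    sigma = sigma_scale * roi
    return [math.exp(-((i - center) ** 2) / (2 * sigma ** 2)) for i in range(roi)]


def stitch_weighted(size, roi, step, values, weights):
    """Blend constant-valued window predictions using per-position weights (1D)."""
    total = [0.0] * size
    norm = [0.0] * size
    for value, start in zip(values, range(0, size - roi + 1, step)):
        for offset, weight in enumerate(weights):
            total[start + offset] += weight * value
            norm[start + offset] += weight
    return [t / n for t, n in zip(total, norm)]


window_values = [1.0, 2.0, 3.0]  # one prediction value per window
flat = stitch_weighted(16, 8, 4, window_values, [1.0] * 8)
smooth = stitch_weighted(16, 8, 4, window_values, gaussian_weights(8))

# Position 4 is near the centre of window 0 and on the border of window 1.
print(flat[4], round(smooth[4], 4))
```

With equal weights, position 4 is the plain average of the two windows. With Gaussian weights it stays close to the value of the window whose centre is nearest, so the border prediction of the neighbouring window contributes very little.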

## **2. What are post-processing transforms?**

MONAI also provides post-processing transforms for handling the model outputs. Currently, the transforms include:

- Adding an activation layer (Sigmoid, Softmax, etc.).
- Converting to discrete values (Argmax, One-Hot, thresholding, etc.), as shown in figure (b) below.
- Splitting multi-channel data into multiple single channels.
- Removing segmentation noise based on connected component analysis, as shown in figure (c) below.
- Extracting the contour of a segmentation result, which can be mapped onto the original image to evaluate the model, as shown in figures (d) and (e) below.

After applying the post-processing transforms, it’s easier to compute metrics, save the model output to files, or visualize the data in TensorBoard.

<img src="post_transforms.png" style="width: 700px;"/>
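The first two transform types (activation, then discretization) can be illustrated on a tiny two-channel example in plain Python. This is a conceptual sketch, not the MONAI transforms themselves; the logit values and the helper `softmax` are made up for illustration.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical network output: 2 channels (background, foreground) x 3 pixels.
logits = [
    [2.0, 0.5, -1.0],  # channel 0
    [0.0, 1.5, 3.0],   # channel 1
]

# Activation: turn per-pixel logits into class probabilities.
probs = [softmax(list(pixel)) for pixel in zip(*logits)]

# Discretization: argmax over the channel axis gives one label per pixel.
labels = [max(range(len(p)), key=p.__getitem__) for p in probs]
print(labels)
```

The result is a single-channel label map, which is the form that connected component analysis and contour extraction operate on.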

## **3. How do I use sliding inference on real data?**


This section will set up and load a [SegResNet](https://docs.monai.io/en/latest/networks.html?highlight=segresnet#segresnet) model, run sliding window inference, and post-process the model output volumes:
- Argmax to get a discrete prediction map
- Remove small isolated predicted regions
- Convert the segmentation regions into contours

We'll start by importing all of our dependencies.

In [None]:
import os
import glob

from monai.apps import download_and_extract
from monai.utils import set_determinism
from monai.data import CacheDataset, DataLoader
from monai.networks.nets import SegResNet
from monai.transforms import (
    AddChanneld,
    AsDiscrete,
    Compose,
    CropForegroundd,
    KeepLargestConnectedComponent,
    LabelToContour,
    LoadImaged,
    Orientationd,
    ScaleIntensityRanged,
    Spacingd,
    ToTensord,
)

#### Download image and labels

We'll use the Task09_Spleen data. The .tar archive is downloaded and extracted by `DecathlonDataset` (with `download=True`) in the next section, so the data is ready to use there.

#### Set up the validation data, preprocessing transforms, and data loader

We'll put the images and labels into a data dictionary, create a sequence of transforms with `Compose`, build a cached validation dataset with `DecathlonDataset`, and then load the data using `DataLoader`.

In [None]:
# Assumption: the data lives under the current directory, matching the
# DecathlonDataset call below, which extracts to <root_dir>/Task09_Spleen.
root_dir = "./"
data_dir = os.path.join(root_dir, "Task09_Spleen")

# Manual file listing, kept for reference. It is empty until the dataset has
# been downloaded; DecathlonDataset below builds its own validation split.
images = sorted(glob.glob(os.path.join(data_dir, "imagesTr", "*.nii.gz")))
labels = sorted(glob.glob(os.path.join(data_dir, "labelsTr", "*.nii.gz")))
data_dicts = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(images, labels)
]
val_files = data_dicts[-9:]

val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        ScaleIntensityRanged(
            keys=["image"], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        ToTensord(keys=["image", "label"]),
    ]
)
# Download (if needed), preprocess, and cache the validation section.
val_ds = monai.apps.DecathlonDataset(
    root_dir=root_dir,
    task="Task09_Spleen",
    section="validation",
    download=True,
    transform=val_transforms,
    cache_rate=1.0,
    num_workers=4,
)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)
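The intensity scaling step maps the window `[a_min, a_max]` linearly onto `[b_min, b_max]` and, with `clip=True`, clamps everything outside that window. The plain-Python sketch below shows the arithmetic for a single value. `scale_intensity_range` is a hypothetical helper written for illustration, not the MONAI transform.

```python
def scale_intensity_range(value, a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True):
    """Linearly map [a_min, a_max] to [b_min, b_max], optionally clamping the result."""
    scaled = (value - a_min) / (a_max - a_min) * (b_max - b_min) + b_min
    if clip:
        scaled = min(max(scaled, b_min), b_max)
    return scaled


# Sample intensities: far below the window, its edges, its midpoint, far above it.
for intensity in (-1000, -57, 53.5, 164, 1000):
    print(intensity, scale_intensity_range(intensity))
```

With the values used in this notebook, everything at or below -57 becomes 0.0 and everything at or above 164 becomes 1.0, so the network sees a fixed soft-tissue intensity window.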

#### Set up the model

We want to utilize the GPU, so we'll check to see if it's available and set it as the primary device; otherwise we'll use the CPU.

We'll then instantiate the model with the selected device.

In [None]:
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = SegResNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
).to(device)

In [None]:
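model_path = os.path.join(root_dir, "segresnet_model_epoch30.pth")
# map_location lets weights saved on a GPU load on a CPU-only machine as well.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
print(f"Loaded model from {model_path}.")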

#### Run the sliding window inference

In [None]:
val_data = next(iter(val_loader))
val_data = val_data["image"].to(device)

roi_size = (88, 88, 88)
sw_batch_size = 1
with torch.no_grad():
  val_output = sliding_window_inference(
      val_data, roi_size, sw_batch_size=sw_batch_size, predictor=model, mode="gaussian", overlap=0.2)
print(val_output.shape, val_output.device)

slice_idx = 80
# Channel 1 of the model output is the foreground prediction.
plt.title(f"output channel 1 -- slice {slice_idx}")
plt.imshow(val_output.detach().cpu()[0, 1, :, :, slice_idx], cmap="gray")

## **4. How do I use post-processing on my sliding inference results?**

#### Post-processing: argmax over the output probabilities into a discrete map

In [None]:
argmax = AsDiscrete(argmax=True)(val_output)
print(argmax.shape)

slice_idx = 80
plt.subplots(1, 2)
plt.subplot(1, 2, 1)
plt.title(f"image -- slice {slice_idx}")
plt.imshow(val_data.detach().cpu()[0, 0, :, :, 80], cmap="gray")

plt.subplot(1, 2, 2)
plt.title(f"argmax -- slice {slice_idx}")
plt.imshow(argmax.detach().cpu()[0, 0, :, :, 80])

#### Post-processing: connected component analysis to select the largest segmentation region

In [None]:
largest = KeepLargestConnectedComponent(applied_labels=[1])(argmax)
print(largest.shape)

slice_idx = 80
plt.subplots(1, 2)
plt.subplot(1, 2, 1)
plt.title(f"image -- slice {slice_idx}")
plt.imshow(val_data.detach().cpu()[0, 0, :, :, 80], cmap="gray")

plt.subplot(1, 2, 2)
plt.title(f"largest component -- slice {slice_idx}")
plt.imshow(largest.detach().cpu()[0, 0, :, :, 80])
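Conceptually, this step labels each connected foreground region and keeps only the biggest one. The 2D sketch below shows the idea with a breadth-first search in plain Python. It is an illustration only: `keep_largest_component` is a hypothetical helper, and the 4-neighbour connectivity is an assumption that may differ from what the MONAI transform uses.

```python
from collections import deque


def keep_largest_component(mask):
    """Return a copy of a 2D 0/1 mask with only its largest 4-connected region."""
    rows, cols = len(mask), len(mask[0])
    seen = [[False] * cols for _ in range(rows)]
    best = []
    for r in range(rows):
        for c in range(cols):
            if mask[r][c] != 1 or seen[r][c]:
                continue
            # Flood-fill one region starting from this unvisited foreground pixel.
            component = []
            queue = deque([(r, c)])
            seen[r][c] = True
            while queue:
                y, x = queue.popleft()
                component.append((y, x))
                for dy, dx in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                    ny, nx = y + dy, x + dx
                    if (0 <= ny < rows and 0 <= nx < cols
                            and mask[ny][nx] == 1 and not seen[ny][nx]):
                        seen[ny][nx] = True
                        queue.append((ny, nx))
            if len(component) > len(best):
                best = component
    cleaned = [[0] * cols for _ in range(rows)]
    for y, x in best:
        cleaned[y][x] = 1
    return cleaned


mask = [
    [1, 1, 0, 0, 0],
    [1, 1, 0, 0, 1],
    [0, 0, 0, 0, 0],
    [0, 1, 0, 0, 0],
]
cleaned = keep_largest_component(mask)
for row in cleaned:
    print(row)
```

The two isolated single pixels are removed and the 2x2 block survives, which mirrors how small spurious predictions disappear from the segmentation.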

#### Post-processing: convert the region into a contour map

In [None]:
contour = LabelToContour()(largest)
print(contour.shape)

slice_idx = 80
plt.subplots(1, 2)
plt.subplot(1, 2, 1)
plt.title(f"image -- slice {slice_idx}")
plt.imshow(val_data.detach().cpu()[0, 0, :, :, 80], cmap="gray")

plt.subplot(1, 2, 2)
plt.title(f"contour -- slice {slice_idx}")
plt.imshow(contour.detach().cpu()[0, 0, :, :, 80], cmap="Greens")
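A contour map keeps only the foreground pixels that lie on the edge of a region. The sketch below uses a simple neighbour test in plain Python to show the idea. It is an assumption-laden illustration: `contour_of` is a hypothetical helper, and the edge-detection method inside `LabelToContour` may differ from this test.

```python
def contour_of(mask):
    """Mark foreground pixels that touch background (or the image border)."""
    rows, cols = len(mask), len(mask[0])

    def is_background(y, x):
        # Pixels outside the image are treated as background.
        return not (0 <= y < rows and 0 <= x < cols) or mask[y][x] == 0

    neighbours = ((1, 0), (-1, 0), (0, 1), (0, -1))
    return [
        [
            1 if mask[y][x] == 1
            and any(is_background(y + dy, x + dx) for dy, dx in neighbours)
            else 0
            for x in range(cols)
        ]
        for y in range(rows)
    ]


mask = [
    [0, 0, 0, 0, 0],
    [0, 1, 1, 1, 0],
    [0, 1, 1, 1, 0],
    [0, 1, 1, 1, 0],
    [0, 0, 0, 0, 0],
]
edges = contour_of(mask)
for row in edges:
    print(row)
```

Only the interior pixel of the 3x3 block is cleared, leaving a one-pixel-wide outline that can be overlaid on the input image.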

#### Visualize the contour over the original input

In [None]:
map_image = contour + val_data

slice_idx = 80
plt.subplots(1, 2)
plt.subplot(1, 2, 1)
plt.title(f"image -- slice {slice_idx}")
plt.imshow(val_data.detach().cpu()[0, 0, :, :, 80], cmap="gray")

plt.subplot(1, 2, 2)
plt.title(f"contour -- slice {slice_idx}")
plt.imshow(map_image.detach().cpu()[0, 0, :, :, 80], cmap="gray")

For more details about the post-processing transforms, please visit:
https://docs.monai.io/en/latest/transforms.html#post-processing-dict

## **5. How can I use tensorboard to visualize the results?**

Visualize the results as an animation with TensorBoard (if the `%tensorboard` command is unavailable in your environment, please start TensorBoard manually in a separate terminal):

In [None]:
from monai.visualize import plot_2d_or_3d_image
from torch.utils.tensorboard import SummaryWriter

with SummaryWriter(log_dir=root_dir) as writer:
    plot_2d_or_3d_image(map_image, step=0, writer=writer, tag="segmentation")
    plot_2d_or_3d_image(val_output, step=0, max_channels=2, writer=writer, tag="Probability")

In [None]:
%tensorboard --logdir=$root_dir

## **Summary**

We've covered MONAI Sliding Inference and Post-Processing. Here are some key highlights:

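- Sliding window inference splits a large image into overlapping windows, runs the model on each window, and stitches the outputs together. This keeps memory use manageable for large volumes, and Gaussian blending can reduce border artifacts.
- Post-processing transforms (activation, discretization, connected component analysis, contour extraction) make the model output easier to evaluate, save, and visualize.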

## **Next Steps**

Start exploring MONAI on your own! There are lots of great tutorials that can help guide you along the way. You can find them on our [GitHub Organization Page](https://github.com/Project-MONAI/tutorials). We also have all of the videos from our first-ever MONAI Bootcamp available on our [YouTube Channel](https://www.youtube.com/c/ProjectMONAI).