# DeepGrow 2D Inference Tutorial

Deepgrow is an AI-assisted annotation tool designed to speed up the annotation process through user interaction in the form of clicks. Deepgrow uses a guided predictive segmentation model, where the guidance is provided by the user as positive or negative clicks. Positive clicks are guidance indicators pointing towards the organ/region of interest, while negative clicks are guidance signals marking regions that should not be part of the segmentation/annotation. An overview of this process is shown in the figure below:

<img src="../../figures/image_deepgrow_scheme.png" alt='deepgrow scheme'>

Based on: Sakinis et al., "Interactive segmentation of medical images through
fully convolutional neural networks" (2019). https://arxiv.org/abs/1903.08205

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import jit

from monai.apps.deepgrow.transforms import (
    AddGuidanceFromPointsd,
    AddGuidanceSignald,
    Fetch2DSliced,
    ResizeGuidanced,
    RestoreLabeld,
    SpatialCropGuidanced,
)
from monai.transforms import (
    AsChannelFirstd,
    Spacingd,
    LoadImaged,
    AddChanneld,
    NormalizeIntensityd,
    ToTensord,
    ToNumpyd,
    Activationsd,
    AsDiscreted,
    Resized
)

# NOTE(review): not referenced by any code visible in this notebook — confirm whether it is still needed.
max_epochs = 1


def draw_points(guidance):
    """Overlay guidance clicks on the current matplotlib axes.

    Args:
        guidance: ``None`` (nothing is drawn) or a sequence of point lists.
            Points of the first list are drawn as red ``+`` markers and
            points of the second list as blue ``+`` markers; any further
            lists are ignored. Presumably these are the positive and
            negative clicks respectively — confirm against the caller.

    For every point, the last coordinate is used as x and the
    second-to-last coordinate as y.
    """
    if guidance is None:
        return
    colors = ['r+', 'b+']
    for color, points in zip(colors, guidance):
        for p in points:
            # markersize has to be passed as a keyword: extra positional
            # arguments to plot() are parsed as another x/y data group, so a
            # MATLAB-style "'MarkerSize', 30" pair never sets the marker size.
            plt.plot(p[-1], p[-2], color, markersize=30)


def show_image(image, label, guidance=None):
    """Display an image and, when available, its label side by side.

    Args:
        image: 2D array shown in grayscale in the left panel.
        label: optional array. When given, its non-zero entries are
            overlaid on the image in the left panel and the full label is
            shown in the right panel.
        guidance: optional click lists forwarded to ``draw_points`` for the
            left panel.
    """
    has_label = label is not None

    plt.figure(num="check", figsize=(12, 6))

    # Left panel: image, optional label overlay and the guidance clicks.
    plt.subplot(1, 2, 1)
    plt.title("image")
    plt.imshow(image, cmap="gray")
    if has_label:
        # Mask out the zero entries so the image stays visible there.
        overlay = np.ma.masked_where(label == 0, label)
        plt.imshow(overlay, cmap='jet', interpolation='none', alpha=0.7)
    draw_points(guidance)
    plt.colorbar()

    # Right panel: the label on its own.
    if has_label:
        plt.subplot(1, 2, 2)
        plt.title("label")
        plt.imshow(label)
        plt.colorbar()

    plt.show()


def print_data(data):
    """Print a short summary line for every entry of ``data``.

    Entries stored under ``'image_meta_dict'`` or ``'label_meta_dict'`` are
    expanded and printed one line per contained item. For every other key a
    single line is printed showing:

    * the value itself, when its exact type is one of int, float, bool, str,
      dict or tuple;
    * otherwise its ``shape`` attribute, when it has one;
    * otherwise its type.

    Args:
        data: mapping from key to value.
    """
    meta_keys = ('image_meta_dict', 'label_meta_dict')
    printable_types = (int, float, bool, str, dict, tuple)

    for key in data:
        value = data[key]

        if key in meta_keys:
            for meta_name in value:
                print(f'{key} Meta:: {meta_name} => {value[meta_name]}')
            continue

        # Exact type match on purpose: subclasses of these types fall
        # through to the shape / type branches below.
        if type(value) in printable_types:
            summary = value
        elif hasattr(value, 'shape'):
            summary = value.shape
        else:
            summary = type(value)
        print(f'Data key: {key} = {summary}')

Pre-processing hyper-parameters and transform compositions. The image is resampled to a 1.0 x 1.0 mm^2 resolution. The snippet below shows where the guidance signal is placed on the foreground (organ of interest).

In [None]:
# Pre Processing
# Runs the pre-transforms one at a time on a single sample and plots the
# current image slice with the guidance clicks after every step.
roi_size = [256, 256]
model_size = [128, 192, 192]
pixdim = (1.0, 1.0)
dimensions = 2

data = {
    'image': '_image.nii.gz',
    'foreground': [[66, 180, 105]],
    'background': []
}
# The third coordinate of the first foreground click selects the slice to plot.
slice_idx = original_slice_idx = data['foreground'][0][2]

pre_transforms = [
    LoadImaged(keys='image'),
    AsChannelFirstd(keys='image'),
    Spacingd(keys='image', pixdim=pixdim, mode='bilinear'),

    AddGuidanceFromPointsd(ref_image='image', guidance='guidance', foreground='foreground', background='background',
                           dimensions=dimensions),
    Fetch2DSliced(keys='image', guidance='guidance'),
    AddChanneld(keys='image'),

    SpatialCropGuidanced(keys='image', guidance='guidance', spatial_size=roi_size),
    Resized(keys='image', spatial_size=roi_size, mode='area'),
    ResizeGuidanced(guidance='guidance', ref_image='image'),
    NormalizeIntensityd(keys='image', subtrahend=208.0, divisor=388.0),
    AddGuidanceSignald(image='image', guidance='guidance'),
    ToTensord(keys='image')
]

# Kept for later cells: the image right after loading, and the selected
# slice right after the channel dimension has been added.
original_image = None
original_image_slice = None
for t in pre_transforms:
    tname = type(t).__name__

    data = t(data)
    image = data['image']
    label = data.get('label')
    guidance = data.get('guidance')

    print("{} => image shape: {}, label shape: {}".format(
        tname, image.shape, label.shape if label is not None else None))

    # Pick the 2D slice to plot. Compare the transform name with ==:
    # "tname in ('LoadImaged')" is a substring test against a plain string,
    # not a tuple membership test.
    if tname == 'Fetch2DSliced':
        pass  # plotted as-is, without any slicing
    elif tname == 'LoadImaged':
        image = image[:, :, slice_idx]  # slice index is on the last axis here
    else:
        image = image[slice_idx, :, :]  # slice index is on the first axis here
    label = None  # no label is plotted during pre-processing

    # Until a transform has produced 'guidance', fall back to the raw
    # foreground click. np.roll without an axis rolls the flattened array, so
    # the last coordinate (the slice index) moves to the front.
    guidance = guidance if guidance else [np.roll(data['foreground'], 1).tolist(), []]
    print('Guidance: {}'.format(guidance))

    show_image(image, label, guidance)
    if tname == 'Fetch2DSliced':
        # From here on every later step is indexed with 0 on its first axis.
        slice_idx = 0
    if tname == 'LoadImaged':
        original_image = data['image']
    if tname == 'AddChanneld':
        original_image_slice = data['image']

For a single click, a prediction is made by the Deepgrow model. The corresponding input image with the known ground truth is shown along with the predicted segmentation. They have been shown for multiple slices.

In [None]:
# Evaluation
# Runs the TorchScript model on the pre-processed sample, then applies the
# post-transforms one at a time and plots the prediction after every step.
model_path = '/workspace/Data/models/deepgrow_2d.ts'
model = jit.load(model_path)
model.cuda()
model.eval()

# [None] adds a leading batch dimension for the forward pass.
inputs = data['image'][None].cuda()
with torch.no_grad():
    outputs = model(inputs)
outputs = outputs[0]  # drop the batch dimension again
data['pred'] = outputs

post_transforms = [
    Activationsd(keys='pred', sigmoid=True),
    AsDiscreted(keys='pred', threshold_values=True, logit_thresh=0.5),
    ToNumpyd(keys='pred'),
    RestoreLabeld(keys='pred', ref_image='image', mode='nearest'),
]

for t in post_transforms:
    tname = type(t).__name__

    data = t(data)
    # After RestoreLabeld the prediction is plotted against the image as it
    # was right after loading; before that, against the model input.
    image = original_image if tname == 'RestoreLabeld' else data['image']
    label = data['pred']
    print("{} => image shape: {}, pred shape: {}".format(tname, image.shape, label.shape))

    # Compare with ==: "tname in 'RestoreLabeld'" is a substring test.
    if tname == 'RestoreLabeld':
        image = image[:, :, original_slice_idx]
        label = label[0, :, :].detach().cpu().numpy() if torch.is_tensor(label) else label[original_slice_idx]
    else:
        # Plot the first channel of both the model input and the prediction.
        image = image[0, :, :].detach().cpu().numpy() if torch.is_tensor(image) else image[0]
        label = label[0, :, :].detach().cpu().numpy() if torch.is_tensor(label) else label[0]

    print("PLOT:: {} => image shape: {}, pred shape: {}; min: {}, max: {}, sum: {}".format(
        tname, image.shape, label.shape, np.min(label), np.max(label), np.sum(label)))
    show_image(image, label)