# Visualization of Segmentation Results

In this notebook I illustrate various ways to display segmentation results as contour overlays on the original images, so that they can be easily incorporated into a manuscript.

An important point to remember when working with biomedical images is that segmentation algorithms are most often evaluated by comparing their results to reference data.
 
In the medical domain, reference data is commonly obtained via manual segmentation by an expert. When you are resource-limited, the reference data may be defined by a single expert.
   
In this notebook I show how to display overlays of the reference segmentation and of the automatic (algorithm) segmentation results. Once we have a reference, we evaluate the algorithm's performance using multiple criteria, as usually there is no single evaluation measure that conveys all of the relevant information.
The data used in this notebook consist of manually segmented knee articular cartilage and bone from a single clinical MRI scan. A larger dataset (four scans) 
is freely available from the [Osteoarthritis Initiative (OAI)](https://nda.nih.gov/oai/) repository. The relevant publication is: Ebrahimkhani, S., et al. "Automated segmentation of knee articular cartilage: Joint deep and hand-crafted learning-based framework using diffeomorphic mapping." Neurocomputing 467 (2022): 36-55.

### Importing the Python packages

In [None]:
__author__='Somayeh Eb.'

import pandas as pd
import pdb
import numpy as np
import sys
import os
from os.path import isfile,join
from os import listdir

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

import scipy.io as sio

from medpy.io import load
import medpy.metric as mdp
import medpy.metric.binary as metrics
import SimpleITK as sitk
import nibabel as nib

import eval_volumes as evol
import eval_segm as es

from sklearn.preprocessing import LabelEncoder
from enum import Enum
# import gui
from ipywidgets import interact, fixed


# Header labels for the per-subject evaluation metrics. Not read by the cells shown
# here; presumably consumed by a later results table (TODO confirm).
title_bar = [
    'SubjectID', 'MeanDSC', 'MeanIoU', 'MeanPixelAcc.', 'ASD(mm)', 'RMSD(mm)',
    'VOE(%)', 'VD(%)', 'Sensitivity(%)', 'Specificity(%)', 'MSD(mm)',
]

# Root directory containing one sub-folder per subject.
root_dir = '/home/esomayeh/workplace/OAI-Z_samples/'

# Always write output to a separate directory, we don't want to pollute the source directory.
OUTPUT_DIR = './overlay_outputs_oai-z/'
 

def display_with_overlay(slice_number, image, segs, window_min, window_max, title=None, output_name='results.png'):
    """
    Display an MRI slice with segmentation contours overlaid onto it, and save the figure.

    The contours are the edges of the labeled regions in ``segs``. Both volumes are
    sliced along their first SimpleITK index axis (``image[slice_number, :, :]``), so
    ``slice_number`` must be smaller than ``image.GetSize()[0]``.

    Parameters
    ----------
    slice_number : int
        Index, along the first SimpleITK image axis, of the slice to display.
    image : SimpleITK.Image
        Grayscale 3D volume.
    segs : SimpleITK.Image
        3D label volume indexed the same way as ``image``; its labels are drawn as contours.
    window_min, window_max : float
        Intensity window applied to ``image`` before mapping it to the 8-bit display range.
    title : str, optional
        Figure title; nothing is drawn when it is falsy.
    output_name : str
        Path of the image file the figure is saved to.
    """
    img = image[slice_number, :, :]
    msk = segs[slice_number, :, :]

    # Map the windowed intensities onto [0, 255]: 255 is the UInt8 maximum, and any
    # larger output value would overflow the cast below and show bright tissue as dark.
    display_img = sitk.Cast(
        sitk.IntensityWindowing(img,
                                windowMinimum=window_min,
                                windowMaximum=window_max,
                                outputMinimum=0.0,
                                outputMaximum=255.0),
        sitk.sitkUInt8)
    overlay_img = sitk.LabelMapContourOverlay(sitk.Cast(msk, sitk.sitkLabelUInt8),
                                              display_img,
                                              opacity=1,
                                              contourThickness=[2, 2])

    # We assume the original slice is isotropic, otherwise the display would be distorted.
    # A fresh figure per call keeps slices (and titles) from piling up on one set of axes.
    fig = plt.figure()
    plt.imshow(sitk.GetArrayViewFromImage(overlay_img))
    if title:
        plt.title(title)
    plt.axis('off')
    plt.savefig(output_name)
    plt.show()
    # This is called once per slice per volume; release the figure so open figures
    # do not accumulate in memory.
    plt.close(fig)

In [None]:
# Important note: set obj_label manually to the label of the object of interest in the
# ground truth, and change the file names in the loop below accordingly.
# NOTE(review): this cell does not read obj_label and computes only a per-slice DSC
# (no sensitivity/specificity); presumably later cells use obj_label - TODO confirm.
obj_label = 1    # *** fb = 1, fc = 2, tb = 3, tc = 4

# Subject folders under root_dir, sorted so that the hard-coded subject indices
# used by the loop always refer to the same subjects.
subjects = sorted(listdir(root_dir))
N = len(subjects)

subject_id = []
# pdb.set_trace()
# Suffix shared by the image / label / prediction file names of every subject.
FILE_TAG = 'tb'
# Labels zeroed in every label volume (1 = femur bone, 3 = tibia bone), leaving cartilage only.
BONE_LABELS = (1, 3)
# Index along the last NumPy axis at which the DSC is reported. GetArrayFromImage
# reverses the axis order, so this is the same axis display_with_overlay slices.
DSC_SLICE = 23


def _image_from_array(arr, template):
    """Wrap ``arr`` in a SimpleITK image carrying ``template``'s origin and spacing.

    The direction cosines are left at the GetImageFromArray default (identity), i.e.
    any orientation stored in ``template`` is intentionally dropped.
    """
    image = sitk.GetImageFromArray(arr)
    image.SetOrigin(template.GetOrigin())
    image.SetSpacing(template.GetSpacing())
    return image


def _strip_bone_labels(label_itkimage):
    """Return ``(array, image)`` for ``label_itkimage`` with every BONE_LABELS voxel set to 0."""
    arr = sitk.GetArrayFromImage(label_itkimage)
    arr[np.isin(arr, BONE_LABELS)] = 0
    return arr, _image_from_array(arr, label_itkimage)


def _keep_label(arr, label):
    """Return a copy of ``arr`` in which every voxel not equal to ``label`` is 0."""
    out = arr.copy()
    out[out != label] = 0
    return out


for s in range(3, 5):
    sub_dir = os.path.join(root_dir, subjects[s], 'orig_transformed', '')
    img_file_dir = f'{sub_dir}{subjects[s]}_image_{FILE_TAG}.nii.gz'
    gt_file_dir = f'{sub_dir}{subjects[s]}_label_{FILE_TAG}.nii.gz'
    res_file_dir = f'{sub_dir}{subjects[s]}_pred_{FILE_TAG}.nii.gz'
    st2res_file_dir = f'{sub_dir}{subjects[s]}_pred2_{FILE_TAG}.nii.gz'

    # medpy copies of the volumes and the voxel spacing; not used in this cell but kept
    # as globals, presumably for later metric cells (TODO confirm).
    img, img_header = load(img_file_dir)
    i_spacing = img_header.get_voxel_spacing()
    lbl, lbl_header = load(gt_file_dir)
    res_vol, res_header = load(res_file_dir)

    gray_itkimage = sitk.ReadImage(img_file_dir, sitk.sitkFloat32)
    grayimage = sitk.GetArrayFromImage(gray_itkimage)

    # First prediction (saved as "3dunet"), second prediction (saved as "jdhcl") and
    # ground truth, each with the bone labels removed.
    segmentation_itkimage = sitk.ReadImage(res_file_dir, sitk.sitkUInt8)
    seg, segmentation = _strip_bone_labels(segmentation_itkimage)

    seg2_itkimage = sitk.ReadImage(st2res_file_dir, sitk.sitkUInt8)
    seg2, segmentation2 = _strip_bone_labels(seg2_itkimage)

    reference_itkimage = sitk.ReadImage(gt_file_dir, sitk.sitkUInt8)
    ref, reference = _strip_bone_labels(reference_itkimage)

    if gray_itkimage.GetSize() != segmentation_itkimage.GetSize():
        raise ValueError(f'{subjects[s]}: image size {gray_itkimage.GetSize()} differs from '
                         f'prediction size {segmentation_itkimage.GetSize()}')
    # Gray image with the prediction's origin/spacing and identity direction.
    grayimage_redirected = _image_from_array(grayimage, segmentation_itkimage)

    # display_with_overlay slices along the FIRST SimpleITK axis (image[k, :, :]), so the
    # slice count must be that axis's size, not GetSize()[2].
    n_slices = grayimage_redirected.GetSize()[0]
    subject_folder = OUTPUT_DIR + subjects[s]
    os.makedirs(subject_folder, exist_ok=True)
    for slice_number in range(n_slices):
        slice_prefix = f'{subject_folder}/sl-{slice_number}'
        output_segname = slice_prefix + '-3dunet.png'
        display_with_overlay(slice_number, grayimage_redirected, segmentation, 0, 1, output_name=output_segname)
        output_segname2 = slice_prefix + '-jdhcl.png'
        display_with_overlay(slice_number, grayimage_redirected, segmentation2, 0, 1, output_name=output_segname2)
        output_gtname = slice_prefix + '-gt.png'
        display_with_overlay(slice_number, grayimage_redirected, reference, 0, 1, output_name=output_gtname)

    slice_number = DSC_SLICE
    labels = [2, 4]  # cartilage labels (2 = femoral, 4 = tibial)
    mean_dsc = np.zeros(len(labels), dtype=float)   # first prediction ("3dunet")
    mean_dsc2 = np.zeros(len(labels), dtype=float)  # second prediction ("jdhcl")
    for idx, label in enumerate(labels):
        segm = _keep_label(seg, label)
        segm_st2 = _keep_label(seg2, label)
        gt = _keep_label(ref, label)
        mean_dsc[idx] = es.mean_DSC(segm[:, :, slice_number], gt[:, :, slice_number])
        mean_dsc2[idx] = es.mean_DSC(segm_st2[:, :, slice_number], gt[:, :, slice_number])
    print(subjects[s], ', slice ', slice_number, ', DSC: ', mean_dsc, ' Average DSC: ', np.mean(mean_dsc))
    print(subjects[s], ', slice ', slice_number, ', JDHCL DSC: ', mean_dsc2,
          ' Average JDHCL DSC: ', np.mean(mean_dsc2))