## Generate superpixel-based pseudolabels


### Overview

This is the third step of data preparation.

Input: normalized images

Output: pseudolabel candidates for all the images

In [12]:
%reset
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import copy
import skimage

from skimage.segmentation import slic
from skimage.segmentation import mark_boundaries
from skimage.util import img_as_float
from skimage.measure import label 
import scipy.ndimage.morphology as snm
from skimage import io
import argparse
import numpy as np
import glob

import SimpleITK as sitk
import os
from tqdm.notebook import tqdm

def to01(x):
    """Min-max rescale an array to the range [0, 1].

    A constant array (max == min) returns zeros of the same shape instead of
    dividing by zero and producing NaNs.
    """
    x_min = x.min()
    x_range = x.max() - x_min
    if x_range == 0:
        # Same floating dtype the division would have produced.
        return np.zeros(np.shape(x), dtype=np.result_type(x, 1.0))
    return (x - x_min) / x_range



Once deleted, variables cannot be recovered. Proceed (y/[n])? n
Nothing done.
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


**Summary**

a. Generate a foreground mask of the patient so that no pseudolabels are created in the empty background regions

b. Generate superpixels as pseudolabels

**Pseudolabel configurations**

```python
# default setting of minimum superpixel sizes
segs = seg_func(img[ii, ...], min_size = 400, sigma = 1)
# you can also try other configs
segs = seg_func(img[ii, ...], min_size = 100, sigma = 0.8)
```


In [2]:
# Per-dataset settings: glob of normalized input volumes, output folder for the
# pseudolabels, and the intensity threshold used to build the foreground mask.
DATASET_CONFIG = {
    'SABS': {
        'img_bname': '/home/htang6/workspace/data/abdomen/superpixel/sabs_CT_normalized/image_*.nii.gz',
        'out_dir': '/home/htang6/workspace/data/abdomen/superpixel/sabs_CT_normalized',
        'fg_thresh': 1e-4,
    },
    'CHAOST2': {
        # NOTE(review): images are read from '../CHAOST2/...' but written to
        # './CHAOST2/...' -- confirm the differing relative roots are intended.
        'img_bname': '../CHAOST2/chaos_MR_T2_normalized/image_*.nii.gz',
        'out_dir': './CHAOST2/chaos_MR_T2_normalized',
        'fg_thresh': 1e-4 + 50,
    },
}

DOMAIN = 'SABS'
img_bname = DATASET_CONFIG[DOMAIN]['img_bname']
out_dir = DATASET_CONFIG[DOMAIN]['out_dir']
imgs = glob.glob(img_bname)


In [3]:
imgs

['/home/htang6/workspace/data/abdomen/superpixel/sabs_CT_normalized/image_113.nii.gz',
 '/home/htang6/workspace/data/abdomen/superpixel/sabs_CT_normalized/image_PA029.nii.gz',
 '/home/htang6/workspace/data/abdomen/superpixel/sabs_CT_normalized/image_149.nii.gz',
 '/home/htang6/workspace/data/abdomen/superpixel/sabs_CT_normalized/image_148.nii.gz',
 '/home/htang6/workspace/data/abdomen/superpixel/sabs_CT_normalized/image_124.nii.gz',
 '/home/htang6/workspace/data/abdomen/superpixel/sabs_CT_normalized/image_135.nii.gz',
 '/home/htang6/workspace/data/abdomen/superpixel/sabs_CT_normalized/image_PA042.nii.gz',
 '/home/htang6/workspace/data/abdomen/superpixel/sabs_CT_normalized/image_PA027.nii.gz',
 '/home/htang6/workspace/data/abdomen/superpixel/sabs_CT_normalized/image_086.nii.gz',
 '/home/htang6/workspace/data/abdomen/superpixel/sabs_CT_normalized/image_PA039.nii.gz',
 '/home/htang6/workspace/data/abdomen/superpixel/sabs_CT_normalized/image_144.nii.gz',
 '/home/htang6/workspace/data/abdom

In [4]:
# Order volumes by the id after the last '_' (plain string order, so '084' < '113' < 'PA029').
imgs = sorted(imgs, key=lambda path: path.rsplit('_', 1)[-1].split('.nii.gz')[0])

In [5]:
imgs

['/home/htang6/workspace/data/abdomen/superpixel/sabs_CT_normalized/image_084.nii.gz',
 '/home/htang6/workspace/data/abdomen/superpixel/sabs_CT_normalized/image_085.nii.gz',
 '/home/htang6/workspace/data/abdomen/superpixel/sabs_CT_normalized/image_086.nii.gz',
 '/home/htang6/workspace/data/abdomen/superpixel/sabs_CT_normalized/image_087.nii.gz',
 '/home/htang6/workspace/data/abdomen/superpixel/sabs_CT_normalized/image_088.nii.gz',
 '/home/htang6/workspace/data/abdomen/superpixel/sabs_CT_normalized/image_089.nii.gz',
 '/home/htang6/workspace/data/abdomen/superpixel/sabs_CT_normalized/image_090.nii.gz',
 '/home/htang6/workspace/data/abdomen/superpixel/sabs_CT_normalized/image_091.nii.gz',
 '/home/htang6/workspace/data/abdomen/superpixel/sabs_CT_normalized/image_092.nii.gz',
 '/home/htang6/workspace/data/abdomen/superpixel/sabs_CT_normalized/image_093.nii.gz',
 '/home/htang6/workspace/data/abdomen/superpixel/sabs_CT_normalized/image_094.nii.gz',
 '/home/htang6/workspace/data/abdomen/super

In [6]:
MODE = 'MIDDLE' # preset for the minimum size of pseudolabels. 'MIDDLE' is the default setting

# wrapper for processing a 3D image slice by slice in 2D
def superpix_vol(img, method = 'fezlen', **kwargs):
    """
    Run 2D superpixel segmentation on every slice of a volume.

    img: 3D array; slices are taken along axis 0 (assumed axis order z, x, y).
    method: only 'fezlen' (skimage.segmentation.felzenszwalb) is implemented.
    **kwargs: extra keyword arguments for the segmentation function. They
        override the MODE preset, e.g. ``min_size=100, sigma=0.8``. Previously
        they were accepted but silently ignored.
    Returns a float64 array with img.shape holding the per-slice labels.
    Raises NotImplementedError for an unknown method or MODE.
    """
    if method != 'fezlen':
        raise NotImplementedError
    seg_func = skimage.segmentation.felzenszwalb

    if MODE == 'MIDDLE':
        seg_kwargs = {'min_size': 400, 'sigma': 1}
    else:
        raise NotImplementedError
    # Caller-supplied parameters win over the preset.
    seg_kwargs.update(kwargs)

    out_vol = np.zeros(img.shape)
    for ii in range(img.shape[0]):
        out_vol[ii, ...] = seg_func(img[ii, ...], **seg_kwargs)

    return out_vol

# thresholding the intensity values to get a binary mask of the patient
def fg_mask2d(img_2d, thresh): # change this by your need
    """
    Foreground (patient) mask of one 2D slice.

    Pixels brighter than `thresh` are foreground. If none are, the all-zero
    float32 map is returned unchanged. Otherwise only the largest connected
    component is kept and its holes are filled, giving a boolean mask.
    """
    above = np.float32(img_2d > thresh)
    if above.max() < 0.999:
        return above

    components = label(above)
    assert components.max() != 0  # at least one connected component
    # bincount index 0 counts the background, so skip it and shift the argmax back by one.
    component_sizes = np.bincount(components.flat)[1:]
    largest = components == np.argmax(component_sizes) + 1
    return snm.binary_fill_holes(largest)

# remove superpixels within the empty regions
def superpix_masking(raw_seg2d, mask2d):
    """
    Drop superpixels outside a foreground mask and renumber the rest from 1.

    raw_seg2d: 2D superpixel label map; labels may start at 0.
    mask2d: 2D mask of the same shape; pixels where it is 0 become 0.
    Returns a float array of the same shape: 0 outside the mask, labels >= 1
    inside, assigned in increasing order of the original label. The original
    label 0 is numbered last. Labels whose pixels are all masked out still use
    up a number, so the numbering can have gaps.
    """
    # Explicit copy so the relabelling below never writes into the caller's array.
    seg = np.array(raw_seg2d, dtype=np.int32)
    orig_lbs = np.unique(seg)
    max_lb = orig_lbs.max()
    # After multiplying by the mask, label 0 would be indistinguishable from
    # "outside the mask", so move it to an unused id first.
    seg[seg == 0] = max_lb + 1
    # Visit every nonzero original label, then the relocated label 0. The old
    # code appended max_lb instead of max_lb + 1, which silently dropped the
    # superpixel that felzenszwalb labelled 0.
    lbvs = [lbv for lbv in orig_lbs if lbv != 0] + [max_lb + 1]
    seg = seg * mask2d
    out_seg2d = np.zeros(seg.shape)
    for lb_new, lbv in enumerate(lbvs, start=1):
        out_seg2d[seg == lbv] = lb_new

    return out_seg2d
            
def superpix_wrapper(img, verbose = False, fg_thresh = 1e-4):
    """
    Compute superpixels for a volume and keep only those inside the foreground.

    For every slice along axis 0, builds the foreground mask with fg_mask2d
    (threshold `fg_thresh`) and masks and renumbers the slice's superpixels
    with superpix_masking.
    Returns (fg_mask_vol, processed_seg_vol), two float arrays shaped like the
    superpix_vol output.
    """
    raw_seg = superpix_vol(img)
    vol_shape = raw_seg.shape
    fg_mask_vol = np.zeros(vol_shape)
    processed_seg_vol = np.zeros(vol_shape)
    for z in range(vol_shape[0]):
        if verbose:
            print("doing {} slice".format(z))
        slice_mask = fg_mask2d(img[z, ...], fg_thresh)
        fg_mask_vol[z] = slice_mask
        processed_seg_vol[z] = superpix_masking(raw_seg[z, ...], slice_mask)
    return fg_mask_vol, processed_seg_vol
        
# copy spacing and orientation info between sitk objects
def copy_info(src, dst):
    """
    Copy spacing, origin and direction from `src` onto `dst`.

    Only these three geometry properties are transferred. `dst` is modified in
    place and also returned for convenience.
    """
    spacing = src.GetSpacing()
    origin = src.GetOrigin()
    direction = src.GetDirection()
    dst.SetSpacing(spacing)
    dst.SetOrigin(origin)
    dst.SetDirection(direction)
    return dst


def strip_(img, lb):
    """
    Keep only the requested label value(s) of a label map.

    img: array-like label map; cast to int32 before comparison.
    lb: one label (int, float or numpy scalar; floats are truncated with int())
        or a list/tuple of labels.
    Returns a float array where pixels equal to a requested label keep that
    label's value and all other pixels are 0.
    Raises TypeError for any other `lb` type (a plain int used to hit a bare
    `Exception`).
    """
    img = np.int32(img)
    if isinstance(lb, (list, tuple)):
        out = np.zeros(img.shape)
        for _lb in lb:
            out += np.float32(img == int(_lb)) * float(_lb)
        return out
    if isinstance(lb, (int, float, np.integer, np.floating)):
        lb = int(lb)
        return np.float32(img == lb) * float(lb)
    raise TypeError('lb must be a number or a list/tuple of numbers, got {}'.format(type(lb).__name__))

In [18]:
# NOTE: relies on hidden kernel state. `show3Dimg` is defined in the plotting
# cell further down, and `out_seg_o` is left over from the last iteration of the
# pseudolabel loop. Run those cells first.
show3Dimg(sitk.GetArrayFromImage(out_seg_o))

interactive(children=(IntSlider(value=0, description='k', max=238), Output()), _dom_classes=('widget-interact'…

In [17]:
# NOTE: relies on hidden kernel state. `show3Dimg` is defined in the plotting
# cell further down, and `im_obj` is left over from the last iteration of the
# pseudolabel loop. Run those cells first.
show3Dimg(sitk.GetArrayFromImage(im_obj))

interactive(children=(IntSlider(value=0, description='k', max=238), Output()), _dom_classes=('widget-interact'…

In [16]:
# Generate pseudolabels for every image and save them.
# out_dir can differ from the input folder (see the CHAOST2 entry of
# DATASET_CONFIG), and sitk.WriteImage fails if the folder does not exist.
os.makedirs(out_dir, exist_ok=True)
for img_fid in tqdm(imgs, total=len(imgs)):
    # Volume id: the text between the last '_' and '.nii.gz' (e.g. '084', 'PA029').
    idx = os.path.basename(img_fid).split("_")[-1].split(".nii.gz")[0]
    im_obj = sitk.ReadImage(img_fid)

    out_fg, out_seg = superpix_wrapper(sitk.GetArrayFromImage(im_obj), fg_thresh=DATASET_CONFIG[DOMAIN]['fg_thresh'])
    # Give both outputs the spacing/origin/direction of the source image.
    out_fg_o = copy_info(im_obj, sitk.GetImageFromArray(out_fg))
    out_seg_o = copy_info(im_obj, sitk.GetImageFromArray(out_seg))

    seg_fid = os.path.join(out_dir, f'superpix-{MODE}_{idx}.nii.gz')
    msk_fid = os.path.join(out_dir, f'fgmask_{idx}.nii.gz')
    sitk.WriteImage(out_fg_o, msk_fid)
    sitk.WriteImage(out_seg_o, seg_fid)
    print(f'image with id {idx} has finished')


HBox(children=(FloatProgress(value=0.0, max=123.0), HTML(value='')))

image with id 084 has finished
image with id 085 has finished
image with id 086 has finished
image with id 087 has finished
image with id 088 has finished
image with id 089 has finished
image with id 090 has finished
image with id 091 has finished
image with id 092 has finished
image with id 093 has finished
image with id 094 has finished
image with id 095 has finished
image with id 096 has finished
image with id 097 has finished
image with id 098 has finished
image with id 099 has finished
image with id 100 has finished
image with id 101 has finished
image with id 102 has finished
image with id 103 has finished
image with id 104 has finished
image with id 105 has finished
image with id 106 has finished
image with id 107 has finished
image with id 108 has finished
image with id 109 has finished
image with id 110 has finished
image with id 111 has finished
image with id 112 has finished
image with id 113 has finished
image with id 114 has finished
image with id 115 has finished
image wi

In [8]:
import sys
sys.path.append("../")

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import patches
import random
import os
import IPython.html.widgets as w
import cv2
import matplotlib.animation as animation
from mpl_toolkits.axes_grid1 import make_axes_locatable
from IPython.display import clear_output
from matplotlib import colors
# from config import config
import matplotlib.gridspec as gridspec
from tqdm import tqdm
# from utils.util import merge_masks, merge_contours, get_contours_from_masks


# color for each OAR in the config['roi_names]
# color_dict = {
#     1: {'color': [128, 0, 128], 'name': 'Brachial Plexus'},
#     2: {'color': [187, 255, 187], 'name': 'Brain Stem'},
#     3: {'color': [0, 255, 0], 'name': 'ConstrictorNaris'},
#     4: {'color': [50, 55, 255], 'name': 'Ear-L'},
#     5: {'color': [0, 0, 255], 'name': 'Ear-R'},
#     6: {'color': [155, 128, 0], 'name': 'Eye-L'},
#     7: {'color': [0, 50, 0], 'name': 'Eye-R'},
#     8: {'color': [200, 155, 0], 'name': 'Hypophysis'},
#     9: {'color': [176, 224, 230], 'name': 'Larynx'},
#     10: {'color': [55, 220, 55], 'name': 'Lens L'},
#     11: {'color': [0, 150, 0], 'name': 'Lens R'},
#     12: {'color': [128, 64, 64], 'name': 'Mandible'},
#     13: {'color': [50, 150, 155], 'name': 'Optical Chiasm'},
#     14: {'color': [75, 200, 15], 'name': 'Optical Nerve L'},
#     15: {'color': [0, 255, 255], 'name': 'Optical Nerve R'},
#     16: {'color': [55, 175, 50], 'name': 'Oral Cavity'},
#     17: {'color': [255, 0, 255], 'name': 'Parotid L'},
#     18: {'color': [125, 64, 250], 'name': 'Parotid R'},
#     19: {'color': [255, 0, 128], 'name': 'SmgL'},
#     20: {'color': [0, 128, 128], 'name': 'SmgR'},
#     21: {'color': [255, 255, 0], 'name': 'Spinal Cord'},
#     22: {'color': [255, 128, 0], 'name': 'Sublingual Gland'},
#     23: {'color': [128, 64, 64], 'name': 'Temporal Lobe L'},
#     24: {'color': [90, 128, 0], 'name': 'Temporal Lobe R'},
#     25: {'color': [50, 100, 255], 'name': 'Thyroid'},
#     26: {'color': [255, 0, 100], 'name': 'TMJL'},
#     27: {'color': [200, 55, 50], 'name': 'TMJR'},
#     28: {'color': [255, 100, 57], 'name': 'Trachea'}

# }


# RGB color (0-255) and display name for each abdominal label id, starting at 1.
color_dict = {
    label_id: {'color': rgb, 'name': organ}
    for label_id, (rgb, organ) in enumerate([
        ([128, 0, 128], 'Large Bowel'),
        ([187, 255, 187], 'Duodenum'),
        ([0, 255, 0], 'Spinal Cord'),
        ([50, 55, 255], 'Liver'),
        ([0, 0, 255], 'Spleen'),
        ([155, 128, 0], 'Small Bowel'),
        ([0, 50, 0], 'Pancreas'),
        ([200, 155, 0], 'Kidney L'),
        ([176, 224, 230], 'Kidney R'),
        ([55, 220, 55], 'Stomach'),
        ([0, 150, 0], 'Gallbladder'),
    ], start=1)
}


# color for the contour outline in the comparison subplot
color_dict2 = {
    label_id: {'color': rgb, 'name': source}
    for label_id, (rgb, source) in enumerate([
        ([0, 0, 255], 'Ground Truth'),
        ([255, 255, 0], 'Model'),
    ], start=1)
}


def colorbar(mappable):
    """
    Attach a colorbar to the right of the axes holding `mappable`.

    The colorbar gets its own axes split off with make_axes_locatable
    (size "5%", pad 0.1). Returns the Colorbar created by the parent figure.
    """
    parent_ax = mappable.axes
    cbar_ax = make_axes_locatable(parent_ax).append_axes("right", size="5%", pad=0.1)
    return parent_ax.figure.colorbar(mappable, cax=cbar_ax)


def show3Dimg(image, *imgs, vmin=0, vmax=30):
    """
    Interactive slice viewer: `image` in gray plus optional volumes side by side.

    image: volume browsed along axis 0 with a slider.
    *imgs: extra volumes with the same number of slices; None entries are
           skipped. Previously they were counted out of the layout but still
           indexed, which crashed.
    vmin, vmax: color limits for the extra volumes (defaults 0 and 30, as before).
    """
    overlays = [extra for extra in imgs if extra is not None]
    n_img = 1 + len(overlays)

    def fz(k):
        plt.subplot(1, n_img, 1)
        colorbar(plt.imshow(image[k], cmap='gray'))

        for i, extra in enumerate(overlays):
            plt.subplot(1, n_img, 2 + i)
            colorbar(plt.imshow(extra[k], vmin=vmin, vmax=vmax))

        plt.show()

    w.interact(fz, k=w.IntSlider(min=0, max=image.shape[0] - 1, step=1, value=0))


def show3Dimg2(image, *masks):
    '''
    Plot label masks/contours over a CT volume with an interactive window.

    image: CT volume, one 2D slice per index along axis 0.
    *masks: label volumes indexed like `image`, usually [mask, contour]. Overlay
            i is drawn with alpha 0.5 * (i + 1): 0.5 for the mask and 1 for the
            contour. Label 0 is transparent.

    Sliders control the slice (z), window level and width; a checkbox toggles
    the overlays. Relies on the notebook globals `custom_cmap` and `patches1`
    (not defined in this cell) and on IPython's `display`.
    '''
    continuous_update = False
    params = {'z': 0, 'level': 0, 'width': 1000, 'show_mask': True}
    z_slider = w.IntSlider(min=0, max=image.shape[0] - 1, step=1, value=params['z'],
                           continuous_update=continuous_update, description="z")
    level_slider = w.IntSlider(min=-1024, max=1000, step=1, value=params['level'],
                               continuous_update=continuous_update, description="level")
    width_slider = w.IntSlider(min=-1024, max=2000, step=1, value=params['width'],
                               continuous_update=continuous_update, description="width")
    mask_checkbox = w.Checkbox(value=True, description='show mask', disabled=False)

    def plot_figure():
        z = params['z']
        level = params['level']
        width = params['width']

        plt.imshow(image[z], cmap='gray', vmin=level - width / 2, vmax=level + width / 2)

        if params['show_mask']:
            for i, mask in enumerate(masks):
                # Convert only the displayed slice, not the whole volume, on every redraw.
                mask_slice = mask[z].astype(np.float32)
                mask_slice[mask_slice == 0] = np.nan
                # NOTE(review): vmax=28 matches the commented-out 28-label head-and-neck
                # table, while color_dict has 11 labels (plot_compare_figure uses
                # vmax=11) -- confirm against custom_cmap.
                plt.imshow(mask_slice, cmap=custom_cmap, alpha=0.5 * (i + 1), vmin=1, vmax=28)

        plt.axis('off')
        plt.legend(handles=patches1, bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0. )
        plt.show()

    def make_handler(key):
        # Store the widget's new value under params[key], then redraw.
        def on_change(change):
            params[key] = change.new
            plot_figure()
        return on_change

    display(z_slider, level_slider, width_slider, mask_checkbox)

    for widget, key in ((z_slider, 'z'), (level_slider, 'level'),
                        (width_slider, 'width'), (mask_checkbox, 'show_mask')):
        widget.observe(make_handler(key), names='value')

    plot_figure()


def get_cmap(n, name='hsv'):
    '''Returns a function that maps each index in 0, 1, ..., n-1 to a distinct
    RGB color; the keyword argument name must be a standard mpl colormap name.

    Indices >= n map to the colormap's "over" color, which by default is the
    last color.
    '''
    # plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
    # pyplot.get_cmap(name, lut) has the same semantics and works on old and new versions.
    return plt.get_cmap(name, n)


def show_image_and_mask(img):
    """
    Interactive slider over the first axis of `img`, showing one slice at a time.

    img: [D,H,W] or [D,H,W,3]
    Color limits are the volume-wide img.min() and img.max() + 1.
    """

    def draw_slice(k):
        lo = img.min()
        hi = img.max() + 1
        plt.imshow(img[k], vmin=lo, vmax=hi)
        plt.show()

    slider = w.IntSlider(min=0, max=img.shape[0] - 1, step=1, value=0)
    w.interact(draw_slice, k=slider)
    

def draw_one_rect(img, box, color=(0, 0, 255), scale=3, text=''):
    """
    Draw one rectangle and an optional text label on a single slice.

    img: [H,W,3] image, drawn on with cv2.rectangle / cv2.putText and returned.
    box: [y, x, h, w] -- box center (y, x) and size (h, w) in pixels.
    color: color tuple passed to cv2, default (0, 0, 255)
    scale: the drawn box is `scale` times the given size, default 3
    text: drawn just below the bottom-right corner of the box
    """
    yc, xc, box_h, box_w = box
    H, W, _ = img.shape

    half_h = box_h * scale / 2
    half_w = box_w * scale / 2
    # Clip each edge independently. Deriving x1/y1 from the clamped x0/y0 used
    # to shift boxes near the top/left border instead of cutting them off.
    x0, x1 = int(max(0, xc - half_w)), int(min(W - 1, xc + half_w))
    y0, y1 = int(max(0, yc - half_h)), int(min(H - 1, yc + half_h))

    cv2.rectangle(img, (x0, y0), (x1, y1), color, 0, lineType=4)

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.2
    thickness = 0
    size = cv2.getTextSize(text, font, font_scale, thickness)[0]
    text_bottom_right = (x1, y1 + size[1])
    cv2.putText(img, text, text_bottom_right, font, font_scale, color, thickness, cv2.LINE_AA)
    return img


def draw_one_bbox(img, box, color, scale, text):
    """
    Draw one 3D box on every slice it spans.

    img: [D,H,W,3]
    box: numpy array [z, y, x, d, h, w]. (z, d) give the center slice and depth;
         (y, x, h, w) are passed to draw_one_rect for each slice in
         [z - d/2, z + d/2], clipped to the volume.
    color: color passed to draw_one_rect
    scale: how big the drawn box is relative to the given size
    text: label drawn next to each rectangle
    """
    depth, _, _, _ = img.shape
    zc, _, _, d, _, _ = box
    first = max(0, int(zc - d / 2))
    last = min(depth - 1, int(zc + d / 2))
    for k in range(first, last + 1):
        img[k] = draw_one_rect(img[k], box[[1, 2, 4, 5]], color=color, text=text, scale=scale)
    return img


def draw_bboxes(img, bboxes, color=(0, 128, 255), scale=2):
    """
    Draw every 3D box in `bboxes` on a CT volume, one colormap color per box.

    img: [D,H,W] or [D,H,W,3]; a 3D volume is first stacked into 3 channels.
    bboxes: boxes of 6 values [z, y, x, d, h, w], or 7 values
            [p, z, y, x, d, h, w] where p is written next to the box as '%.2f'.
    color: NOTE(review): currently unused -- each box gets color i of
           get_cmap(len(bboxes)) instead.
    scale: how big the drawn box is relative to the given size, default 2
    Raises NotImplementedError for boxes of any other length.
    """
    assert img.ndim in (3, 4)
    if img.ndim == 3:
        img = np.repeat(img[..., np.newaxis], 3, axis=3)

    palette = get_cmap(int(len(bboxes)))
    for i, box in enumerate(bboxes):
        rgb = list(palette(i))[:-1]  # drop alpha
        if len(box) == 6:
            img = draw_one_bbox(img, box, rgb, scale, '')
        elif len(box) == 7:
            img = draw_one_bbox(img, box[1:], rgb, scale, '%.2f' % (box[0]))
        else:
            raise NotImplementedError

    return img


def draw_ground_truth(img, bboxes, color=(255,0,0), scale=3):
    """Thin wrapper: draw_bboxes with a red default color and scale 3 (draw_bboxes currently ignores `color`)."""
    return draw_bboxes(img, bboxes, scale=scale, color=color)


def draw_rpn(img, bboxes, color=(0, 255, 0), scale=3):
    """Thin wrapper: draw_bboxes with a green default color and scale 3 (draw_bboxes currently ignores `color`)."""
    return draw_bboxes(img, bboxes, scale=scale, color=color)


def draw_rcnn(img, bboxes, color=(255, 255, 255)):
    """Thin wrapper: draw_bboxes with a white default color and draw_bboxes' default scale."""
    return draw_bboxes(img, bboxes, color=color)


def draw_points(img, points, alpha=0.5):
    """
    Blend a per-pixel label map onto a CT volume, one colormap color per label.

    img: [D,H,W] or [D,H,W,3]. A 3D volume is stacked into a new 3-channel
         array; a 4D input is modified in place.
    points: [D, H, W] label map; label 0 is left unchanged, labels 1..max are colored.
    alpha: blending weight of the label color (colormap colors are in [0, 1]).
    Returns the blended [D,H,W,3] array.
    """
    assert img.ndim == 3 or img.ndim == 4
    if img.ndim == 3:
        img = np.repeat(img[:, :, :, np.newaxis], 3, axis=3)

    num = int(points.max())
    if num < 1:
        return img  # no labels to draw

    # Use entries 1..num of a (num + 1)-color map. With get_cmap(num) and index i,
    # index num fell past the last entry and reused it, so labels num-1 and num
    # shared a color (and the cyclic 'hsv' map repeats red at both ends).
    palette = get_cmap(num + 1)
    for i in range(1, num + 1):
        sel = points == i
        img[sel] = img[sel] * (1 - alpha) + np.array(list(palette(i))[:-1]) * alpha

    return img


def draw_text(img, text, color=(1., 1., 1.)):
    """
    Write `text` near the top-left corner of every slice of `img`.

    img: sequence of 2D slices accepted by cv2.putText. The original note said
         [D, H, W, 4]; draw_gt/draw_pred pass the output of draw_points.
    text: the string to write
    color: cv2 color tuple; the default (1., 1., 1.) presumably assumes float
           images in [0, 1] -- confirm.
    Slices are drawn on with cv2.putText; `img` is returned.
    """
    face = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    thickness = 0
    _, text_h = cv2.getTextSize(text, face, font_scale, thickness)[0]
    # cv2.putText anchors at the text's bottom-left: x = 0 is the left edge,
    # and y is pushed down by the text height plus a 5 px margin.
    origin = (0, 5 + text_h)

    for slice_img in img:
        cv2.putText(slice_img, text, origin, face, font_scale, color, thickness, cv2.LINE_AA)

    return img


def _draw_labeled_overlay(img, mask, title):
    """Copy `img`, blend `mask` onto the copy with draw_points, and stamp `title` on every slice."""
    return draw_text(draw_points(img.copy(), mask), title)


def draw_gt(img, mask):
    """Ground-truth overlay of `mask` on a copy of `img`, labelled 'Ground Truth'."""
    return _draw_labeled_overlay(img, mask, 'Ground Truth')


def draw_pred(img, mask):
    """Prediction overlay of `mask` on a copy of `img`, labelled 'Prediction'."""
    return _draw_labeled_overlay(img, mask, 'Prediction')


def generate_image_anim(img, interval=200, save_path=None):
    """
    Given CT img, return an animation across axial slice
    img: [D,H,W] or [D,H,W,3]
    interval: delay between slices in milliseconds, default 200
    save_path: if given, the animation is also written there with anim.save,
               using Matplotlib's default movie writer; its frame rate follows `interval`.

    return: matplotlib.animation.Animation
    """
    fig = plt.figure()
    frames = [[plt.imshow(slice_img, animated=True)] for slice_img in img]
    anim = animation.ArtistAnimation(fig, frames, interval=interval, blit=True,
                                     repeat_delay=1000)
    if save_path:
        # An ffmpeg writer (fps=30, bitrate=1800) used to be built here but was
        # never passed to save(), and the lookup raised if ffmpeg was unavailable.
        anim.save(save_path)

    return anim


def _show_ct_panel(ax, ct_slice, level, width, interpolation):
    """Draw one windowed CT slice on `ax`, with grid and ticks hidden."""
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(ct_slice, cmap='gray', vmin=level - width / 2, vmax=level + width / 2,
              interpolation=interpolation)


def _overlay_label_slices(ax, volumes, z, cmap, vmin, vmax):
    """Overlay slice `z` of each label volume on `ax`; 0 is transparent, volume i uses alpha 0.5 * (i + 1)."""
    for i, volume in enumerate(volumes):
        label_slice = volume[z].astype(np.float32)
        label_slice[label_slice == 0] = np.nan
        ax.imshow(label_slice, cmap=cmap, alpha=0.5 * (i + 1), vmin=vmin, vmax=vmax)


def plot_compare_figure(image, gt, pred, params, save_dir, show_all_legend, fmt=('png',)):
    """
    Draw one z slice as four panels and save it into `save_dir`.

    Panels: (1) windowed CT with the zoom box outlined, (2) ground truth over the
    zoomed CT, (3) prediction over the zoomed CT, (4) ground-truth contour
    (value 1) vs predicted contour (value 2).

    image: CT volume indexed [z, y, x].
    gt, pred: [mask, contour] volumes indexed like `image`.
    params: dict with 'z', 'level', 'width', 'show_mask', 'start', 'end';
            start/end are [z, y, x] zoom-box corners whose z entry is ignored.
    save_dir: receives '<z>.png' and/or '<z>.pdf'.
    show_all_legend: list every organ, or only (sorted) labels present in slice
                     z of the cropped gt/pred masks.
    fmt: output formats to save; any of 'png', 'pdf'.

    Relies on the notebook globals custom_cmap, custom_cmap2, patches1 and
    patches2 (not defined in this cell).
    """
    interpolation = 'spline36'
    z = params['z']
    level = params['level']
    width = params['width']
    show_mask = params['show_mask']

    # Ignore the z component of the zoom-box corners; only (y, x) are used.
    start = params['start'][1:]
    end = params['end'][1:]

    nrow, ncol = 1, 4
    gs = gridspec.GridSpec(nrow, ncol,
             wspace=0.01, hspace=0.01,
             top=0.7, bottom=0.3,
             left=0.5/(ncol+1), right=1-0.5/(ncol+1))

    # Panel 1: full CT slice, window annotation and the zoom-box outline.
    ax = plt.subplot(gs[0, 0])
    _show_ct_panel(ax, image[z], level, width, interpolation)
    ax.text(0.95, 0.95, 'W: {}, L: {}'.format(width, level),
            verticalalignment='bottom', horizontalalignment='right',
            transform=ax.transAxes,
            color='white', fontsize=15)
    ax.add_patch(patches.Rectangle((start[1], start[0]), end[1] - start[1], end[0] - start[0],
                                   linewidth=1, edgecolor='white', facecolor='none'))

    # Crop to the zoom box. Views suffice: nothing below writes into these arrays.
    crop = (slice(None), slice(start[0], end[0]), slice(start[1], end[1]))
    image = image[crop]
    gt = [g[crop] for g in gt]
    pred = [p[crop] for p in pred]

    # Panels 2 and 3: ground truth and prediction overlays on the zoomed CT.
    for col, volumes in ((1, gt), (2, pred)):
        ax = plt.subplot(gs[0, col])
        _show_ct_panel(ax, image[z], level, width, interpolation)
        if show_mask:
            _overlay_label_slices(ax, volumes, z, custom_cmap, 1, 11)

    # Panel 4: gt contour as value 1 and predicted contour as value 2, matching color_dict2.
    ax = plt.subplot(gs[0, 3])
    _show_ct_panel(ax, image[z], level, width, interpolation)
    if show_mask:
        for contour, value in ((gt[1], 1), (pred[1], 2)):
            ctr = contour[z].copy()
            ctr[ctr > 0] = value
            ctr = ctr.astype(np.float32)
            ctr[ctr == 0] = np.nan
            ax.imshow(ctr, cmap=custom_cmap2, alpha=1, vmin=1, vmax=2)

    legend_properties = {} # {'weight': 'bold'}
    if show_all_legend:
        handles = patches1
    else:
        present = set(np.unique(gt[0][z])).union(np.unique(pred[0][z]))
        present.discard(0)  # background; remove() raised KeyError when it was absent
        # int(): float-valued masks would otherwise break list indexing.
        handles = [patches1[int(lb) - 1] for lb in sorted(present)]
    first_legend = plt.legend(handles=handles, bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0., prop=legend_properties)

    # A second plt.legend call replaces the first, so keep the organ legend as a separate artist.
    plt.gca().add_artist(first_legend)
    plt.legend(handles=patches2,  bbox_to_anchor=(1.01, 0.2), loc=2, borderaxespad=0., prop=legend_properties)

    for ext in ('png', 'pdf'):
        if ext in fmt:
            plt.savefig(os.path.join(save_dir, '{}.{}'.format(z, ext)), bbox_inches='tight')


def show3D_comparison(image, gt, pred, bbox, save_dir='paper_figs/', show_all_legend=True):
    '''
    Interactive four-view comparison of ground truth and prediction.

    Views: original CT with the zoom box, ground truth over the zoomed CT,
    prediction over the zoomed CT, and gt vs pred contour outlines. Every
    redraw goes through plot_compare_figure, which also saves the figure into
    `save_dir`.

    image: CT image of dimension 3
    gt: [ground-truth mask, ground-truth contour]
    pred: [predicted mask, predicted contour]
    bbox: [start, end] corners of the zoomed region used by views 2-4
    '''
    start, end = bbox
    params = {'z': 0, 'level': 0, 'width': 1000, 'show_mask': True, 'start': start, 'end': end}
    slider_opts = {'step': 1, 'continuous_update': False}
    z_slider = w.IntSlider(min=0, max=image.shape[0] - 1, value=params['z'],
                           description="z", **slider_opts)
    level_slider = w.IntSlider(min=-1024, max=1000, value=params['level'],
                               description="level", **slider_opts)
    width_slider = w.IntSlider(min=-1024, max=2000, value=params['width'],
                               description="width", **slider_opts)
    mask_checkbox = w.Checkbox(value=True, description='show mask', disabled=False)

    plt.rcParams['legend.markerscale'] = 0.2
    fig, axes = plt.subplots(1, 3)
    plt.subplots_adjust(hspace=0)
    for axis in axes:
        axis.set_axis_off()

    def redraw():
        plot_compare_figure(image, gt, pred, params, save_dir, show_all_legend)

    def bind(key):
        # Store the widget's new value under params[key], then redraw.
        def handler(change):
            params[key] = change.new
            redraw()
        return handler

    display(z_slider, level_slider, width_slider, mask_checkbox)

    for widget, key in ((z_slider, 'z'), (level_slider, 'level'),
                        (width_slider, 'width'), (mask_checkbox, 'show_mask')):
        widget.observe(bind(key), names='value')

    redraw()


def save_one_slice(image, masks, params, save_dir, show_all_legend):
    """
    Render slice params['z'] of `image` with label overlays and save it as '<z>.png'.

    image: CT volume, one slice per index along axis 0.
    masks: label volumes indexed like `image` (typically [merged_mask,
           merged_contour]); overlay i uses alpha 0.5 * (i + 1), 0 is transparent.
    params: dict with 'z', 'level' and 'width' (CT window).
    save_dir: folder that receives the PNG.
    show_all_legend: list every organ, or only (sorted) labels present in masks[0] at this slice.

    Relies on the notebook globals custom_cmap and patches1. The figure is
    closed even if drawing or saving fails.
    """
    interpolation = 'spline36'
    z = params['z']
    level = params['level']
    width = params['width']

    fig = plt.figure()
    try:
        plt.imshow(image[z], cmap='gray', vmin=level - width / 2, vmax=level + width / 2, interpolation=interpolation)

        for i, mask in enumerate(masks):
            # Convert only slice z. Converting the whole volume here made
            # generate_image_pngs quadratic in the number of slices.
            mask_slice = mask[z].astype(np.float32)
            mask_slice[mask_slice == 0] = np.nan
            # NOTE(review): vmax=28 matches the commented-out 28-label head-and-neck
            # table, while color_dict has 11 labels -- confirm against custom_cmap.
            plt.imshow(mask_slice, cmap=custom_cmap, alpha=0.5 * (i + 1), vmin=1, vmax=28)

        plt.axis('off')

        legend_properties = {} # {'weight': 'bold'}
        if show_all_legend:
            handles = patches1
        else:
            present = set(np.unique(masks[0][z]))
            present.discard(0)  # background; remove() raised KeyError when it was absent
            handles = [patches1[int(lb) - 1] for lb in sorted(present)]
        plt.legend(handles=handles, bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0., prop=legend_properties)

        plt.savefig(os.path.join(save_dir, '{}.png'.format(z)), bbox_inches='tight')
    finally:
        plt.close(fig)


def generate_image_pngs(raw_img, raw_mask, save_dir, show_all_legend=False):
    '''
    Generate image pngs with applied mask for each OAR, slice by slice,
    and save the pngs into folder save_dir (created if it does not exist).

    raw_img: CT volume, one slice per index along axis 0.
    raw_mask: per-organ masks in whatever form merge_masks /
              get_contours_from_masks expect. TODO confirm: those helpers come
              from utils.util, whose import is commented out in this cell.
    show_all_legend: forwarded to save_one_slice.

    Side effect: sets plt.rcParams['figure.figsize'] = (16, 12) globally.
    '''
    # savefig does not create missing folders.
    os.makedirs(save_dir, exist_ok=True)
    plt.rcParams['figure.figsize'] = (16, 12)
    params = {'z': 0, 'level': 35, 'width': 400}
    merged_mask = merge_masks(raw_mask)
    merged_ctr = merge_contours(get_contours_from_masks(raw_mask))

    for i in tqdm(range(len(raw_img)), desc='Total'):
        params['z'] = i
        save_one_slice(raw_img, [merged_mask, merged_ctr],
                       params, save_dir, show_all_legend=show_all_legend)

