In [None]:

import pandas as pd
import numpy as np
import pickle
import os
import json

from class_modalities.datasets import DataManager

from class_modalities.transforms import LoadNifti, Compose, Roi2Mask_probs, ResampleReshapeAlign, Sitk2Numpy, ScaleIntensityRanged
import SimpleITK as sitk

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

In [None]:
def display_info(img):
    """Print the spatial metadata of an image.

    `img` must expose ``GetOrigin``, ``GetSize``, ``GetSpacing`` and
    ``GetDirection`` (the notebook passes a SimpleITK image).
    Nothing is returned; the four values are written to stdout.
    """
    # (label, accessor name) pairs, printed in this order.
    fields = (
        ('\t Origin    :', 'GetOrigin'),
        ('\t Size      :', 'GetSize'),
        ('\t Spacing   :', 'GetSpacing'),
        ('\t Direction :', 'GetDirection'),
    )
    print('img information :')
    for label, accessor in fields:
        print(label, getattr(img, accessor)())

In [None]:
from mrcnn import visualize
from mrcnn.visualize import display_images

import colorsys
import random

seed_color = 0  # seed of the shuffle below; keeps the colour order reproducible


def random_colors(N, bright=True):
    """
    Return a list of N RGB tuples with evenly spaced hues.

    The colours are generated in HSV space (full saturation, value 1.0 when
    `bright` is true and 0.7 otherwise), converted to RGB, then shuffled with
    a generator seeded by the module-level `seed_color`, so repeated calls
    with the same N give the same order.
    """
    value = 1.0 if bright else 0.7
    palette = [colorsys.hsv_to_rgb(k / N, 1, value) for k in range(N)]
    random.Random(seed_color).shuffle(palette)
    return palette

def generate_bbox(img):
    """Compute one bounding box per instance mask.

    Parameters
    ----------
    img : np.ndarray
        Array indexed as ``img[:, :, i]`` for instance ``i``; non-zero
        entries are treated as foreground.

    Returns
    -------
    np.ndarray
        Integer array of shape ``(img.shape[2], 4)`` whose rows are
        ``[y1, x1, y2, x2]``. Both corners are inclusive (the min / max
        foreground index along axis 0 and axis 1).
        An instance with no foreground gets the all-zero box instead of
        raising, so an empty segmentation can still be passed on to the
        plotting code.
        NOTE(review): Mask R-CNN's ``display_instances`` is expected to skip
        all-zero boxes -- confirm against the installed mrcnn version.
    """
    n_instances = img.shape[2]
    bbox = np.zeros((n_instances, 4), dtype=np.intp)
    for i in range(n_instances):
        rows, cols = np.nonzero(img[:, :, i])
        if rows.size == 0:
            # Empty mask: keep the zero box rather than calling min() on
            # an empty sequence.
            continue
        bbox[i] = (rows.min(), cols.min(), rows.max(), cols.max())

    return bbox

def plot_mip_whole_scan(mip_pet, mip_mask_true, ax=None):
    """Overlay a 2-D mask on a 2-D PET projection with Mask R-CNN's viewer.

    Both inputs are expected to be already projected (e.g. by ``np.max``
    along one axis of the volume); this function does no projection itself.

    Parameters
    ----------
    mip_pet : np.ndarray
        2-D PET projection. Values above 5.0 are saturated for display.
        The array is NOT modified (a clipped copy is used), so the same
        projection can safely be reused for several overlays.
    mip_mask_true : np.ndarray
        2-D mask with the same shape as `mip_pet`; non-zero = foreground.
    ax : matplotlib axes, optional
        Forwarded unchanged to ``visualize.display_instances``.
    """
    display_max = 5.0  # upper bound of the display window (presumably SUV)

    # Saturate into a new array instead of writing into the caller's data,
    # then rescale to the 0-255 range.
    clipped = np.minimum(mip_pet, display_max)
    grey = (255 * clipped / display_max).astype(int)
    # Replicate the grey channel three times to get an RGB image.
    image = grey[:, :, None] * np.ones(3, dtype=int)[None, None, :]

    # display_instances expects masks stacked on the last axis: add a
    # single-instance axis.
    mask = np.expand_dims(mip_mask_true, axis=-1)

    # One box per instance, derived from the mask.
    bbox = generate_bbox(mask)

    # Every instance gets class id 1; the class names are left blank so no
    # caption text is drawn from them.
    class_ids, class_names = np.ones(bbox.shape[0], dtype=int), ["", ""]
    visualize.display_instances(image, bbox, mask, class_ids, class_names, show_bbox=False,
                                ax=ax)
    

In [None]:
from scipy.stats import truncnorm, uniform
from skimage import filters

def relative_seg(roi):
    """Soft threshold expressed as a fraction of the ROI maximum.

    Each value of `roi` is divided by ``np.max(roi)`` and mapped through the
    CDF of a normal distribution (mean 0.42, std 0.06) truncated to
    [0.33, 0.60]: fractions below 0.33 map to 0, fractions above 0.60 map
    to 1, with a smooth ramp in between.
    """
    mean, sigma = 0.42, 0.06
    low, high = 0.33, 0.60

    # scipy's truncnorm takes its bounds in standard-deviation units.
    a = (low - mean) / sigma
    b = (high - mean) / sigma

    fraction_of_max = roi / np.max(roi)
    return truncnorm.cdf(fraction_of_max, a, b, loc=mean, scale=sigma)

def absolute_seg(roi):
    """Soft threshold on the raw values of `roi`.

    Values are mapped through the CDF of a normal distribution (mean 2.5,
    std 0.5) truncated to [2.0, 4.0]: values below 2.0 map to 0, values
    above 4.0 map to 1, with a smooth ramp in between.
    """
    mean, sigma = 2.5, 0.5
    low, high = 2.0, 4.0

    # scipy's truncnorm takes its bounds in standard-deviation units.
    a = (low - mean) / sigma
    b = (high - mean) / sigma

    return truncnorm.cdf(roi, a, b, loc=mean, scale=sigma)


def otsu_seg(roi):
    """Binarise `roi` with Otsu's threshold.

    Returns an integer array of the same shape as `roi`: 1 where the value
    is strictly greater than the threshold computed by
    ``filters.threshold_otsu``, 0 elsewhere.
    """
    threshold = filters.threshold_otsu(roi)
    above = roi > threshold
    return np.where(above, 1, 0)


In [None]:
import os
import glob

def get_dataset(data_path='/media/sf_Deep_Oncopole/data/raw_data'):
    """List the PET / CT / mask file triplets found in `data_path`.

    A case is identified by its PET file ``<prefix>_nifti_PT.nii``; the CT
    and mask paths are derived from the same ``<prefix>`` rather than taken
    from separately sorted listings, so a missing or extra file can never
    shift the pairing between modalities.

    Parameters
    ----------
    data_path : str, optional
        Directory to scan. Defaults to the location used by this notebook.

    Returns
    -------
    list of dict
        One ``{'pet_img', 'ct_img', 'mask_img'}`` dict of paths per case,
        ordered by PET path. Empty when no PET file is found.

    Raises
    ------
    FileNotFoundError
        If a PET file has no matching CT or mask file.
    """
    pet_suffix = '_nifti_PT.nii'
    ct_suffix = '_nifti_CT.nii'
    mask_suffix = '_nifti_mask.nii'

    dataset = []
    for pet_file in sorted(glob.glob(os.path.join(data_path, '*' + pet_suffix))):
        prefix = pet_file[:-len(pet_suffix)]
        entry = {'pet_img': pet_file,
                 'ct_img': prefix + ct_suffix,
                 'mask_img': prefix + mask_suffix}

        missing = [path for path in entry.values() if not os.path.isfile(path)]
        if missing:
            raise FileNotFoundError(
                'incomplete case for {}: missing {}'.format(pet_file, ', '.join(missing)))

        dataset.append(entry)

    return dataset

# Case list used by every later cell (indexed by position).
dataset = get_dataset()

In [None]:
# Sanity check: number of cases found and the paths of the first one.
# dataset[0] raises IndexError when no file was found under the data path.
print(len(dataset))
print(dataset[0])

In [None]:
# Target grid for resampling, given in numpy axis order.
image_shape = (256, 128, 128)  # (z, y, x)
voxel_spacing = (4.0, 4.8, 4.8) # (z, y, x)

In [None]:
# Reverse the (z, y, x) tuples into (x, y, z) order.
# They are only consumed by the ResampleReshapeAlign step, which is currently
# commented out below, so the pipeline works on the native image grid.
target_shape = image_shape[::-1]  # (z, y, x) to (x, y, z)
target_voxel_spacing = voxel_spacing[::-1]

# Pipeline: load the three NIfTI files, derive three masks from the PET image
# and the ROI image (one per thresholding method), then convert to numpy.
transformers2 = Compose([  # read img + meta info
    LoadNifti(keys=("pet_img", "ct_img", "mask_img")),
    Roi2Mask_probs(keys=('pet_img', 'mask_img'),
                   method='absolute', new_key_name='mask_img_absolute'),
    Roi2Mask_probs(keys=('pet_img', 'mask_img'),
                   method='relative', new_key_name='mask_img_relative'),
    Roi2Mask_probs(keys=('pet_img', 'mask_img'),
                   method='otsu', new_key_name='mask_img_otsu'),
#     ResampleReshapeAlign(target_shape, target_voxel_spacing,
#                          keys=['pet_img', "ct_img",
#                                'mask_img_absolute', 'mask_img_relative', 'mask_img_otsu'],
#                          origin='head', origin_key='pet_img',
#                          interpolator={'pet_img': sitk.sitkLinear,
#                                        'ct_img': sitk.sitkLinear,
#                                        'mask_img': sitk.sitkLinear,
#                                        'mask_img_absolute': sitk.sitkLinear,
#                                        'mask_img_relative': sitk.sitkLinear,
#                                        'mask_img_otsu': sitk.sitkLinear},
#                          default_value={'pet_img': 0.0,
#                                         'ct_img': -1000.0,
#                                         'mask_img': 0,
#                                         'mask_img_absolute': 0.0,
#                                         'mask_img_relative': 0.0,
#                                         'mask_img_otsu': 0.0}),
    # 'mask_img' is deliberately absent from this list: later cells still read
    # it with sitk.GetArrayFromImage, so it presumably stays a SimpleITK image.
    Sitk2Numpy(keys=['pet_img', 'ct_img',
                     'mask_img_absolute', 'mask_img_relative', 'mask_img_otsu'])
])

In [None]:
# Run the pipeline on the first case. Despite its name, `img_path` is the
# whole dict of paths for that case, not a single path.
img_path = dataset[0]

result = transformers2(img_path)
print(result.keys())

In [None]:
# Origin / size / spacing / direction of the ROI image.
display_info(result['mask_img'])

In [None]:
# ROI image as a numpy array. The following cells index it as
# roi_img[idx_roi] and expect three index arrays from np.where, i.e. a 4-D
# array with one volume per ROI -- TODO confirm for 3-D masks (a later cell
# adds the leading axis explicitly).
roi_img = sitk.GetArrayFromImage(result['mask_img'])

In [None]:
# Distinct values stored in the ROI array.
np.unique(roi_img)

In [None]:
# Shape of the ROI array, to compare with the PET shape in the next cell.
roi_img.shape

In [None]:
# Shape of the PET array produced by the pipeline.
result['pet_img'].shape

In [None]:
# Indices of the non-zero voxels of the first ROI (one array per axis).
np.where(roi_img[0])

In [None]:
idx_roi = 0

# Indices of the non-zero voxels of the selected ROI, one array per axis.
roi_indices = np.where(roi_img[idx_roi])

# Bounding box of the ROI. The upper bounds are exclusive (max + 1) because
# they are used as slice stops: without the + 1 the last voxel plane of the
# ROI is cut off, and a one-voxel-thick ROI yields an empty crop.
# NOTE(review): the x / y / z names refer to array axes 0 / 1 / 2; elsewhere
# the notebook documents arrays as (z, y, x) -- confirm the naming.
x_min, x_max = roi_indices[0].min(), roi_indices[0].max() + 1
y_min, y_max = roi_indices[1].min(), roi_indices[1].max() + 1
z_min, z_max = roi_indices[2].min(), roi_indices[2].max() + 1

In [None]:
# Crop the PET volume and the three method masks to the ROI bounding box.
roi_crop = (slice(x_min, x_max), slice(y_min, y_max), slice(z_min, z_max))

pet_roi = result['pet_img'][roi_crop]
mask_rois = {
    key: result[key][roi_crop]
    for key in ['mask_img_absolute', 'mask_img_relative', 'mask_img_otsu']
}



In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

In [None]:
# Exploratory figure for the ROI selected above (idx_roi), laid out as
# 5 rows x 2 columns:
#   left column  : MIP of the cropped PET, of each method mask and of their mean
#   right column : PET value distribution inside the ROI, the response curve
#                  of each method and the mean of the three curves
# The cell reads the globals result, roi_img, idx_roi, pet_roi and mask_rois
# set by earlier cells.
keys = ['mask_img_absolute', 'mask_img_relative', 'mask_img_otsu']
axis = 2  # x_axis
n_rows = len(keys) + 2
# print(result['pet_img'].shape)

# fig = plt.figure(figsize=(12, 8))
fig, axes = plt.subplots(n_rows, 2, figsize=(16, 16))

# PLOT HISTOGRAM and probs

# PET values of the voxels belonging to the ROI (1-D array).
hist_roi = result['pet_img'][np.where(roi_img[idx_roi]>0)]

# NOTE(review): sns.distplot is deprecated in recent seaborn releases.
sns.distplot(hist_roi, ax=axes[0, 1])
axes[0, 1].set_title('PET hist')
# axes[0, 1].set_xlabel('SUV')


# Sort the values so each curve below is drawn left to right.
hist_roi = np.array(sorted(hist_roi))
# Thresholding function associated with each pipeline output key.
fc_dict = {'mask_img_absolute': absolute_seg, 
           'mask_img_relative': relative_seg,
           'mask_img_otsu': otsu_seg}

mean_probs = []
for ii, key in enumerate(keys):
    probs = fc_dict[key](hist_roi)
    mean_probs.append(probs)
    axes[ii+1, 1].plot(hist_roi, probs)
    axes[ii+1, 1].set_title(key)
    
# Last row: mean of the three response curves.
ii = len(keys)
mean_probs = np.mean(mean_probs, axis=0)
axes[ii+1, 1].plot(hist_roi, mean_probs)
axes[ii+1, 1].set_title('mean')






# PLOT image & mask


# ax1 = fig.add_subplot(n_rows, 1, 1)
im1 = axes[0, 0].imshow(np.max(pet_roi, axis=axis), cmap='hot')

# Give the colorbar its own thin axes so the image keeps its size.
divider = make_axes_locatable(axes[0, 0])
cax = divider.append_axes('right', size='5%', pad=0.05)
fig.colorbar(im1, cax=cax, orientation='vertical')
axes[0, 0].set_title('MIP PET')

for ii, key in enumerate(keys):
#     ax2 = fig.add_subplot(2, 1, 2)
    im2 = axes[ii+1, 0].imshow(np.max(mask_rois[key], axis=axis))

    divider = make_axes_locatable(axes[ii+1, 0])
    cax = divider.append_axes('right', size='5%', pad=0.05)
    fig.colorbar(im2, cax=cax, orientation='vertical')
    axes[ii+1, 0].set_title(key)
    
# Last row: voxel-wise mean of the three cropped masks.
ii = len(keys)
mean_mask = np.mean([mask_rois[key] for key in keys], axis=0)
im2 = axes[ii+1, 0].imshow(np.max(mean_mask, axis=axis))

divider = make_axes_locatable(axes[ii+1, 0])
cax = divider.append_axes('right', size='5%', pad=0.05)
fig.colorbar(im2, cax=cax, orientation='vertical')
axes[ii+1, 0].set_title('mean_mask')



plt.subplots_adjust(top=0.92, bottom=0.08, left=0.10, right=0.95,
                   hspace=0.25, wspace=0.35)
plt.show()

In [None]:
# Scratch check: np.where on an all-zero array returns empty index arrays,
# so this evaluates to 0 -- the condition plot_seg_hist uses to detect an
# empty ROI.
len(np.where(np.zeros(100))[0])

In [None]:
def _add_colorbar(fig, ax, im):
    """Attach a vertical colorbar for `im` in a thin axes to the right of `ax`."""
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    fig.colorbar(im, cax=cax, orientation='vertical')


def plot_seg_hist(result, idx_roi, axis=1, roi=None):
    """
    Plot, for one ROI, the PET projections, the soft masks of the three
    thresholding methods and the corresponding response curves.

    Figure layout (5 rows):
      * one column per projection axis: MIP of the PET restricted to the
        ROI, MIP of each method's map, MIP of the mean map;
      * last column: distribution of the PET values inside the ROI, the
        response curve of each method, and the mean curve.

    Parameters
    ----------
    result : dict
        Must contain 'pet_img', an array with the same shape as one ROI
        volume.
    idx_roi : int
        Index of the ROI along the first axis of the ROI array.
    axis : int or list/tuple of int
        Projection axis / axes, axis = (z, y, x).
    roi : np.ndarray, optional
        ROI array indexed as ``roi[idx_roi]``. When omitted, the
        module-level ``roi_img`` is used (original behaviour).

    Returns
    -------
    None. Prints 'empty R.O.I' and draws nothing when the ROI has no
    non-zero voxel.
    """
    axis_list = axis if isinstance(axis, (list, tuple)) else [axis]

    roi_stack = roi_img if roi is None else roi
    roi_volume = roi_stack[idx_roi]

    nonzero_idx = np.where(roi_volume)
    if len(nonzero_idx[0]) == 0:
        print('empty R.O.I')
        return None

    # Bounding box of the ROI; stops are exclusive, hence the + 1.
    crop = tuple(slice(idx.min(), idx.max() + 1) for idx in nonzero_idx[:3])
    inside = roi_volume > 0

    # PET restricted to the ROI: zero outside, then cropped to the box.
    pet_roi = result['pet_img'].copy()
    pet_roi[roi_volume == 0] = 0.0
    pet_roi = pet_roi[crop]

    keys = ['mask_img_absolute', 'mask_img_relative', 'mask_img_otsu']
    fc_dict = {'mask_img_absolute': absolute_seg,
               'mask_img_relative': relative_seg,
               'mask_img_otsu': otsu_seg}
    titles_dict = {'mask_img_absolute': 'threshold ~ 2.5 SUV',
                   'mask_img_relative': 'threshold ~ 42 % SUV max',
                   'mask_img_otsu': 'threshold otsu'}

    n_rows = len(keys) + 2  # PET + one row per method + mean of the methods
    n_cols = len(axis_list) + 1  # one MIP column per axis + curve column
    last_row = n_rows - 1
    curve_col = n_cols - 1

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 16))

    # ---- last column: value distribution and response curves ----
    roi_values = result['pet_img'][inside]

    # NOTE(review): sns.distplot is deprecated in recent seaborn releases.
    sns.distplot(roi_values, ax=axes[0, curve_col])
    axes[0, curve_col].set_title('PET hist')

    # Sorted values so each curve is drawn left to right.
    sorted_values = np.sort(roi_values)
    curves = []
    for row, key in enumerate(keys, start=1):
        probs = fc_dict[key](sorted_values)
        curves.append(probs)
        axes[row, curve_col].plot(sorted_values, probs)
        axes[row, curve_col].set_title(key)

    axes[last_row, curve_col].plot(sorted_values, np.mean(curves, axis=0))
    axes[last_row, curve_col].set_title('average')

    # ---- per-method maps: independent of the projection axis, so they
    # are built once here instead of once per axis ----
    prob_maps = []
    for key in keys:
        volume = np.zeros(roi_volume.shape)
        volume[inside] = fc_dict[key](roi_values)
        prob_maps.append(volume[crop])
    mean_map = np.mean(prob_maps, axis=0)

    # ---- one column of projections per requested axis ----
    for ncol, proj_axis in enumerate(axis_list):
        ax = axes[0, ncol]
        im = ax.imshow(np.max(pet_roi, axis=proj_axis), cmap='hot')
        _add_colorbar(fig, ax, im)
        ax.set_title('axis {} MIP \nPET'.format(proj_axis))

        for row, (key, volume) in enumerate(zip(keys, prob_maps), start=1):
            ax = axes[row, ncol]
            im = ax.imshow(np.max(volume, axis=proj_axis))
            _add_colorbar(fig, ax, im)
            ax.set_title(titles_dict[key])

        ax = axes[last_row, ncol]
        im = ax.imshow(np.max(mean_map, axis=proj_axis))
        _add_colorbar(fig, ax, im)
        ax.set_title('average_mask')

    plt.subplots_adjust(top=0.92, bottom=0.08, left=0.10, right=0.95,
                        hspace=0.25, wspace=0.35)
    plt.show()

In [None]:
# Case cursor. The next cell increments it before use, so starting at -1
# makes the first run load dataset[0].
idx_data = -1

In [None]:
# Advance to the next case and run the pipeline on it.
# NOTE: this cell is not idempotent -- every re-run moves one case further
# and overwrites `result`; reset idx_data in the cell above to start over.
idx_data += 1 # = 8
print(idx_data)
img_path = dataset[idx_data]

result = transformers2(img_path)
print(result.keys())

In [None]:
# ROI array of the current case. A 3-D array is given a leading axis of
# length 1 so that roi_img[idx_roi] is always a 3-D volume.
roi_img = sitk.GetArrayFromImage(result['mask_img'])
if len(roi_img.shape) == 3:
    roi_img = np.expand_dims(roi_img, axis=0)
print(roi_img.shape)

In [None]:
# One figure per ROI of the current case, projected along all three axes.
# plot_seg_hist reads the global roi_img set in the previous cell.
print('n_roi : ', roi_img.shape[0])
for idx_roi in range(roi_img.shape[0]):
    print(idx_roi)
    plot_seg_hist(result, idx_roi, axis=(2, 1, 0))
    
    

In [None]:
# Whole-scan overview: for each projection axis, one row of four panels
# (the three method masks and their average) overlaid on the PET MIP.
# NOTE(review): the original comment described the axes as (x, y, z) while
# other cells document arrays as (z, y, x) -- confirm which is right.
for axis in [1, 2]: #  index of the array axis that is projected out
    # Flip along axis 0 before projecting (presumably to display the scan
    # head-up -- confirm against the image orientation).
    mip_pet = result['pet_img'].copy()
    mip_pet = np.max(np.flip(mip_pet, axis=0), axis=axis)

    keys = ['mask_img_absolute', 'mask_img_relative', 'mask_img_otsu']
    title_dict = {'mask_img_absolute': '~2.5 SUV', 
             'mask_img_relative': '~42% SUV max', 
             'mask_img_otsu': 'otsu'}


    figsize=(16, 16)
    fig, axes = plt.subplots(1, len(keys)+1, figsize=figsize)
    # print(axes)
    masks = []
    for ii, key in enumerate(keys):
    #     print(ii, key)
        mask = result[key].copy()
        # Keep the un-rounded mask for the average computed below.
        masks.append(mask)
        # Round to the nearest integer before projecting.
        mask = np.round(mask)
        mask = np.max(np.flip(mask, axis=0), axis=axis)

        plot_mip_whole_scan(mip_pet, mask, ax=axes[ii])
        axes[ii].set_title(title_dict[key])
        
    # Voxel-wise mean of the three masks, rounded to the nearest integer.
    average_mask = np.round(np.mean(masks, axis=0))
    
    mask = np.max(np.flip(average_mask, axis=0), axis=axis)

    # `ii` still holds the last index of the loop above, so ii+1 is the
    # final (fourth) panel.
    plot_mip_whole_scan(mip_pet, mask, ax=axes[ii+1])
    axes[ii+1].set_title('average')

    
    plt.show()

In [None]:
# figsize=(16, 16)
# fig, axes = plt.subplots(1, 3, figsize=figsize)

# plot_mip_whole_scan(mip_pet, mip_mask_true)
# plot_seg(ax=axes[i])

# plt.show()