In [None]:
# Reload imported modules before each cell runs, so edits to local helper
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pathlib

import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk

from platipy.imaging import ImageVisualiser
from platipy.imaging.label.utils import get_com
from platipy.imaging.registration.utils import smooth_and_resample
from platipy.imaging.utils.vessel import vessel_spline_generation

# import colorcet as cc

%matplotlib notebook

In [None]:
from birt_utils import (
    interpolate_image,
    interpolate_histology_lesion_probability,
    generate_sampling_label
)

In [None]:
# Processing parameters

# Size (in mm) of the morphological closing applied to the whole-prostate contour;
# it is divided by the MRI voxel spacing to obtain a per-axis kernel radius.
contour_fill_hole_mm = 5

# Entries in the input folder are named "MRHIST<id>"; the case id is whatever
# follows that prefix.
case_prefix = "MRHIST"
input_dir = pathlib.Path("../../1_data/atlas_data/")
case_id_list = sorted(
    entry.name[len(case_prefix):] for entry in input_dir.glob("*MRHIST*")
)
print(len(case_id_list), case_id_list)

In [None]:
"""
Simplify the images/labels that we propagate
"""

# Names are grouped by the interpolator used when the data is resampled.

# Tumour probability maps, one per annotated grade (linear interpolation)
labels_linear = [
    f"TUMOUR_PROBABILITY_GRADE_{grade}"
    for grade in ("2+2", "3+2", "3+3", "3+4", "4+3", "4+4", "4+5", "5+4", "5+5")
]

# Contours and discrete labels (nearest-neighbour interpolation)
labels_nn = ["CONTOUR_PROSTATE", "CONTOUR_PZ", "CONTOUR_URETHRA", "LABEL_HISTOLOGY", "LABEL_SAMPLING"]

# Images, by interpolator
images_bspline = ["MRI_T2W_2D"]
images_linear = ["CELL_DENSITY_MAP"]
images_nn = ["HISTOLOGY"]

# NOTE(review): images_bspline (the MRI) is not part of data_names - presumably
# intentional, confirm against wherever data_names is consumed.
data_names = [*labels_linear, *labels_nn, *images_linear, *images_nn]

In [None]:
# Gather the distinct voxel values found in each case's LABEL_HISTOLOGY image
vals = []

for atlas_id in case_id_list:

    label_path = input_dir / f"MRHIST{atlas_id}" / "LABELS" / f"MRHIST{atlas_id}_LABEL_HISTOLOGY.nii.gz"
    im = sitk.ReadImage(label_path.as_posix())
    new_vals = np.unique(sitk.GetArrayViewFromImage(im))

    print(atlas_id, new_vals)

    # Keep every per-case value; duplicates across cases are retained here
    vals.extend(new_vals)

In [None]:
# Distinct LABEL_HISTOLOGY values across all cases (displayed as the cell output)
np.unique(vals)

In [None]:
"""
Read in data
"""
# Voxel value used in LABEL_HISTOLOGY for each annotated tumour grade
hist_value_2p2 = 64
hist_value_3p2 = 96
hist_value_3p3 = 128
hist_value_3p4 = 160
hist_value_4p3 = 192
hist_value_4p4 = 224
hist_value_4p5 = 234
hist_value_5p4 = 244
hist_value_5p5 = 255

# Grade suffix -> LABEL_HISTOLOGY value, in the order the probability maps are stored
grade_to_hist_value = {
    "2+2": hist_value_2p2,
    "3+2": hist_value_3p2,
    "3+3": hist_value_3p3,
    "3+4": hist_value_3p4,
    "4+3": hist_value_4p3,
    "4+4": hist_value_4p4,
    "4+5": hist_value_4p5,
    "5+4": hist_value_5p4,
    "5+5": hist_value_5p5,
}

atlas_set = {}

for atlas_id in case_id_list:
    print(atlas_id, end=" | ")

    case_name = f"MRHIST{atlas_id}"
    case_dir = input_dir / case_name

    def read_case_file(subfolder, data_name):
        """Read <case_dir>/<subfolder>/<case_name>_<data_name>.nii.gz for the current case."""
        return sitk.ReadImage((case_dir / subfolder / f"{case_name}_{data_name}.nii.gz").as_posix())

    original = {}
    atlas_set[atlas_id] = {"ORIGINAL": original}

    # Read MRI; every other image/label is resampled onto its grid
    original['MRI_T2W_2D'] = read_case_file("IMAGES", "MRI_T2W_2D")

    # Resampling functions (identity transform, MRI as the reference grid)
    def g_nn(x):
        return sitk.Resample(x, original['MRI_T2W_2D'], sitk.Transform(), sitk.sitkNearestNeighbor)

    def g_linear(x):
        return sitk.Resample(x, original['MRI_T2W_2D'], sitk.Transform(), sitk.sitkLinear)

    # Read cell density and histology
    original['CELL_DENSITY_MAP'] = g_linear(read_case_file("IMAGES", "CELL_DENSITY_MAP"))
    original['HISTOLOGY'] = g_nn(read_case_file("IMAGES", "HISTOLOGY"))

    # Read whole prostate contour (binarised)
    original['CONTOUR_PROSTATE'] = g_nn(read_case_file("LABELS", "CONTOUR_PROSTATE")) > 0

    # Fill holes: closing radius per axis = closing size in mm / voxel spacing
    contour_fillhole_img = [
        int(contour_fill_hole_mm / spacing) for spacing in original['MRI_T2W_2D'].GetSpacing()
    ]
    original['CONTOUR_PROSTATE'] = sitk.BinaryMorphologicalClosing(
        original['CONTOUR_PROSTATE'], contour_fillhole_img
    )

    # Masking function (uses the closed prostate contour)
    def mask_to_prostate(x):
        return sitk.Mask(x, original['CONTOUR_PROSTATE'])

    # Read in PZ and urethra contours (and mask); note the PZ file carries an _INTERP suffix
    original['CONTOUR_PZ'] = mask_to_prostate(g_nn(read_case_file("LABELS", "CONTOUR_PZ_INTERP")))
    original['CONTOUR_URETHRA'] = mask_to_prostate(g_nn(read_case_file("LABELS", "CONTOUR_URETHRA")))

    # Read in histology labels (tumour annotation)
    original['LABEL_HISTOLOGY'] = g_nn(read_case_file("LABELS", "LABEL_HISTOLOGY"))

    # Extract out individual labels: one masked probability map per grade
    for grade, hist_value in grade_to_hist_value.items():
        lesion_mask = original['LABEL_HISTOLOGY'] == hist_value
        original[f"TUMOUR_PROBABILITY_GRADE_{grade}"] = mask_to_prostate(
            interpolate_histology_lesion_probability(lesion_mask)
        )

    # Generate sampling label
    original["LABEL_SAMPLING"] = mask_to_prostate(generate_sampling_label(original['HISTOLOGY']))

In [None]:
"""
Interpolate missing cell density/histology images
"""

for atlas_id in atlas_set:
    original = atlas_set[atlas_id]["ORIGINAL"]

    # Cell density: cast to float32, fill grayscale holes, then interpolate
    cell_density_float = sitk.Cast(original['CELL_DENSITY_MAP'], sitk.sitkFloat32)
    original['CELL_DENSITY_MAP'] = interpolate_image(sitk.GrayscaleFillhole(cell_density_float))

    # Histology: cast to a float32 vector image, then interpolate
    histology_float = sitk.Cast(original['HISTOLOGY'], sitk.sitkVectorFloat32)
    original['HISTOLOGY'] = interpolate_image(histology_float)

In [None]:
"""
Resample to 0.8mm (isotropic) voxel size
"""

def _make_resampler(interpolator):
    """Return a one-argument function resampling an image to 0.8mm isotropic voxels."""
    def resample(x):
        return smooth_and_resample(x, isotropic_voxel_size_mm = 0.8, interpolator=interpolator)
    return resample


f_nn = _make_resampler(sitk.sitkNearestNeighbor)
f_linear = _make_resampler(sitk.sitkLinear)
f_bspline = _make_resampler(sitk.sitkBSpline)

# (resampler, data names) pairs, listed in the order entries are added to "RESAMPLED"
resampling_plan = [
    (f_linear, labels_linear + images_linear),
    (f_bspline, images_bspline),
    (f_nn, labels_nn + images_nn),
]

for atlas_id in atlas_set:

    original = atlas_set[atlas_id]['ORIGINAL']
    resampled = {}
    atlas_set[atlas_id]['RESAMPLED'] = resampled

    for resample_fn, names in resampling_plan:
        for label_name in names:
            resampled[label_name] = resample_fn(original[label_name])

    # Memory saver: the native-resolution data is dropped once resampled
    atlas_set[atlas_id]["ORIGINAL"] = None

In [None]:
"""
Write atlas data
"""

# All integrity-check figures share one folder, so create it once rather than per case
figure_dir = pathlib.Path("../1_processing/FIGURES_PROCESSING")
figure_dir.mkdir(exist_ok=True, parents=True)

for atlas_id in list(atlas_set.keys()):

    print(atlas_id, end=" | ")

    resampled = atlas_set[atlas_id]["RESAMPLED"]
    case_name = f"MRHIST{atlas_id}"

    output_dir = pathlib.Path(f"../1_processing/ATLAS_DATA_PROCESSED/{case_name}")
    (output_dir / "IMAGES").mkdir(exist_ok=True, parents=True)
    (output_dir / "LABELS").mkdir(exist_ok=True, parents=True)

    # Converted cell density, computed once and used for both the written image and the figure.
    # NOTE(review): the 8000/255 scale and 1.5 exponent are kept exactly as they were;
    # presumably they map the stored values to cells per mm^3 (the figure label uses mm^-3) - confirm.
    cell_density = (8000/255 * resampled['CELL_DENSITY_MAP'])**1.5

    # Write labels
    for label_name in labels_linear + labels_nn:
        sitk.WriteImage(resampled[label_name], str(output_dir / "LABELS" / f"{case_name}_{label_name}.nii.gz"))

    # Write images; the cell density map is written in its converted form
    for img_name in images_bspline + images_linear + images_nn:
        image_to_write = cell_density if "CELL_DENSITY" in img_name else resampled[img_name]
        sitk.WriteImage(image_to_write, str(output_dir / "IMAGES" / f"{case_name}_{img_name}.nii.gz"))

    # Generate some figures to check data integrity

    mri = resampled['MRI_T2W_2D']
    # The first three figures are all cut through the centre of mass of the PZ contour
    pz_com = get_com(resampled['CONTOUR_PZ'])

    # 1. Contour check
    vis = ImageVisualiser(mri, cut=pz_com, figure_size_in=6, window=[0,1200])
    vis.add_contour({
        'WG':resampled['CONTOUR_PROSTATE'],
        'PZ':resampled['CONTOUR_PZ'],
        'U':resampled['CONTOUR_URETHRA'],
    }, colormap=plt.cm.cool)
    fig = vis.show()
    fig.savefig(figure_dir / f"{case_name}_0_CONTOURS.jpeg", dpi = 300)

    # 2. CD check
    vis = ImageVisualiser(mri, cut=pz_com, figure_size_in=6, window=[0,1200])
    vis.add_scalar_overlay(cell_density, min_value=0, max_value=200000, name='Cell density [mm'+r'$^{-3}$'+']', colormap=plt.cm.gnuplot2, alpha=1)
    fig = vis.show()
    fig.savefig(figure_dir / f"{case_name}_1_CELLDENSITY.jpeg", dpi = 300)

    # 3. Histology
    # NOTE(review): both thresholds include 0.5, and "<=0.5" also selects zero-valued
    # voxels; kept as-is - confirm against the values generate_sampling_label produces.
    vis = ImageVisualiser(resampled['HISTOLOGY'], cut=pz_com, figure_size_in=6)
    vis.add_contour({
        'SAMPLE (HALF)':resampled['LABEL_SAMPLING']>=0.5,
        'SAMPLE (FULL)':resampled['LABEL_SAMPLING']<=0.5,
    }, colormap=plt.cm.cool)
    fig = vis.show()
    fig.savefig(figure_dir / f"{case_name}_2_HISTOLOGY.jpeg", dpi = 300)

    # 4. Histology annotations (median projection; contours keyed by grade suffix, e.g. "3+4")
    vis = ImageVisualiser(mri, figure_size_in=6, window=[0,1], projection="median")
    ctr_dict = {
        label[-3:]:resampled[label]
        for label in labels_linear
    }
    vis.add_contour(ctr_dict, colormap=plt.cm.jet)
    fig = vis.show()
    fig.savefig(figure_dir / f"{case_name}_3_ANNOTATIONS.jpeg", dpi = 300)

    # Close all figures from this case so they do not accumulate across the loop
    plt.close("all")