In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pathlib

import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk
from skimage import measure

from platipy.imaging import ImageVisualiser
from platipy.imaging.label.utils import get_com
from platipy.imaging.registration.deformable import fast_symmetric_forces_demons_registration
from platipy.imaging.registration.linear import alignment_registration, linear_registration
from platipy.imaging.registration.utils import (
    apply_transform,
    convert_mask_to_distance_map,
    convert_mask_to_reg_structure,
    smooth_and_resample,
)
from platipy.imaging.utils.vessel import vessel_spline_generation

# import colorcet as cc

%matplotlib notebook

In [None]:
# Set parameters
# Case ids are taken from the entries of input_dir whose names contain "MRHIST": the id is the
# part of the name after the first len("MRHIST") characters (assumes names are MRHIST<id>).
input_dir = pathlib.Path("../1_processing/ATLAS_DATA_PROCESSED/")
case_id_list = sorted(entry.name[len("MRHIST"):] for entry in input_dir.glob("*MRHIST*"))
print(len(case_id_list), case_id_list)

In [None]:
"""
Simplify the images/labels that we propagate
"""

labels_linear = [
    "TUMOUR_PROBABILITY_GRADE_2+2",
    "TUMOUR_PROBABILITY_GRADE_3+2",
    "TUMOUR_PROBABILITY_GRADE_3+3",
    "TUMOUR_PROBABILITY_GRADE_3+4",
    "TUMOUR_PROBABILITY_GRADE_4+3",
    "TUMOUR_PROBABILITY_GRADE_4+4",
    "TUMOUR_PROBABILITY_GRADE_4+5",
    "TUMOUR_PROBABILITY_GRADE_5+4",
    "TUMOUR_PROBABILITY_GRADE_5+5",
]

labels_nn = [
    "CONTOUR_PROSTATE",
    "CONTOUR_PZ",
    "CONTOUR_URETHRA",
    "LABEL_HISTOLOGY",
    "LABEL_SAMPLING"
]

images_bspline = [
    "MRI_T2W_2D",
]

images_linear = [
    "CELL_DENSITY_MAP",
]

images_nn = [
    "HISTOLOGY"
]

data_names = labels_linear + labels_nn + images_linear + images_nn

In [None]:
"""
Read in data
"""
atlas_set = {}

atlas_id_list = case_id_list[:]

for atlas_id in atlas_id_list:
    
    print(atlas_id, end=" | ")
    
    atlas_set[atlas_id] = {}

    atlas_set[atlas_id]["RESAMPLED"] = {}

    for image_name in images_linear + images_nn + images_bspline:
        atlas_set[atlas_id]["RESAMPLED"][image_name]   = sitk.ReadImage( (input_dir / f"MRHIST{atlas_id}" / f"IMAGES" / f"MRHIST{atlas_id}_{image_name}.nii.gz").as_posix() )

    for label_name in labels_linear + labels_nn:
        atlas_set[atlas_id]["RESAMPLED"][label_name]   = sitk.ReadImage( (input_dir / f"MRHIST{atlas_id}" / f"LABELS" / f"MRHIST{atlas_id}_{label_name}.nii.gz").as_posix() )

In [None]:
"""
Find the mean volume prostate to serve as a reference
"""

f_vol = lambda x: np.product(x.GetSpacing()) * np.sum(sitk.GetArrayViewFromImage(x)>=1) / 1000

volume_dict = {
    atlas_id: f_vol(atlas_set[atlas_id]['RESAMPLED']['CONTOUR_PROSTATE'])
    for atlas_id in atlas_id_list
}

In [None]:
fig, ax = plt.subplots(1, 1)

volumes = list(volume_dict.values())
counts, bins, _ = ax.hist(volumes, bins=18, histtype="stepfilled", color="#99bbff")

bin_centres = (bins[1:] + bins[:-1]) / 2
mean_vol = np.mean(volumes)
median_vol = np.median(volumes)

# Each statistic is drawn as a vertical line reaching the height of the bin whose centre is nearest
for stat_name, stat_value in (("Median", median_vol), ("Mean", mean_vol)):
    line_height = counts[np.argmin(np.abs(stat_value - bin_centres))]
    ax.plot(
        [stat_value, stat_value],
        [0, line_height],
        label=f"{stat_name} = {stat_value:.2f} cm" + r"$^3$",
        lw=3,
    )

ax.set_xlabel("Prostate Volume [cm" + r"$^3$" + "]")
ax.grid()
ax.set_axisbelow(True)
ax.legend()

fig.show()

# Case ids ordered by increasing prostate volume
sorted_id, sorted_vol = zip(*sorted(volume_dict.items(), key=lambda item: item[1]))

volume_gap = np.abs(np.array(sorted_vol) - median_vol)
closest_id = sorted_id[np.argmin(volume_gap)]
print("Closest match: ", closest_id)

In [None]:
# One line per case, smallest prostate first
for case_id, volume in zip(sorted_id, sorted_vol):
    print(f"MRHIST {case_id} prostate volume = {volume:.2f} cc")

In [None]:
"""
Set up exemplar
We need to copy the information into the "ALIGNED" data for registration to reference volume (once computed)
"""

exemplar = atlas_set[closest_id]["RESAMPLED"]

atlas_set[closest_id]["ALIGNED"] = {}

for label_name in data_names:
    atlas_set[closest_id]["ALIGNED"][label_name] = exemplar[label_name]

print("Closest match to median prostate volume: MRHIST",closest_id)

In [None]:
"""
Visualise the 'median' prostate
This is used as an exemplar to which every other case is registered
"""

vis = ImageVisualiser(exemplar["MRI_T2W_2D"], cut=get_com(exemplar['CONTOUR_PROSTATE']), window=[0,800], figure_size_in=6)

vis.set_limits_from_label(exemplar['CONTOUR_PROSTATE'], expansion=10)

vis.add_contour({'Prostate (Exemplar)':exemplar['CONTOUR_PROSTATE']}, colormap=plt.cm.summer_r)

fig = vis.show()

In [None]:
# Register stage: 1
# TRANSLATION-ONLY ALIGNMENT
# Each case's prostate mask is aligned to the exemplar prostate mask (reg_method="Translation"),
# then every image/label is resampled onto the exemplar MRI grid and written to disk.

exemplar_image = exemplar["MRI_T2W_2D"]
exemplar_mask = exemplar["CONTOUR_PROSTATE"]
exemplar_com = get_com(exemplar_mask)

# Interpolator for each group of data items (grouping defined with the name lists above)
interpolator_groups = (
    (images_bspline, sitk.sitkBSpline),
    (labels_linear + images_linear, sitk.sitkLinear),
    (labels_nn + images_nn, sitk.sitkNearestNeighbor),
)

# savefig does not create missing directories
figure_dir = pathlib.Path("../1_processing/FIGURES_REGISTRATION")
figure_dir.mkdir(exist_ok=True, parents=True)

# The exemplar is the reference space, so it is not registered to itself
atlas_id_list = [case_id for case_id in case_id_list if case_id != closest_id]

for atlas_id in atlas_id_list:

    atlas_mask = atlas_set[atlas_id]["RESAMPLED"]["CONTOUR_PROSTATE"]

    _, initial_tfm = linear_registration(
        exemplar_mask,
        atlas_mask,
        reg_method="Translation",
        shrink_factors=[6, 3],
        smooth_sigmas=[2, 1],
        sampling_rate=0.75,
        final_interp=sitk.sitkLinear,
        metric="mean_squares",
        optimiser="gradient_descent_line_search",
        number_of_iterations=25,
    )

    # Apply transform
    atlas_set[atlas_id]["ALIGNED"] = {}
    for names, interpolator in interpolator_groups:
        for name in names:
            atlas_set[atlas_id]["ALIGNED"][name] = apply_transform(
                atlas_set[atlas_id]["RESAMPLED"][name],
                reference_image=exemplar_image,
                transform=initial_tfm,
                default_value=0,
                interpolator=interpolator,
            )

    # Save resampled data and the transform
    p = pathlib.Path(f"../1_processing/ATLAS_DATA_REGISTERED/MRHIST{atlas_id}")
    for sub_dir in ("IMAGES_ALIGNED", "LABELS_ALIGNED", "TRANSFORMS"):
        (p / sub_dir).mkdir(exist_ok=True, parents=True)

    for label_name in labels_linear + labels_nn:
        sitk.WriteImage(atlas_set[atlas_id]["ALIGNED"][label_name], str( p / "LABELS_ALIGNED" / f"MRHIST{atlas_id}_{label_name}.nii.gz") )

    for image_name in images_bspline + images_linear + images_nn:
        sitk.WriteImage(atlas_set[atlas_id]["ALIGNED"][image_name], str( p / "IMAGES_ALIGNED" / f"MRHIST{atlas_id}_{image_name}.nii.gz") )

    sitk.WriteTransform(initial_tfm, str( p / "TRANSFORMS" / f"MRHIST{atlas_id}_INITIAL_TRANSLATION_TO_MRHIST{closest_id}.tfm") )

    # Visualise and save figure
    vis = ImageVisualiser(atlas_set[atlas_id]["ALIGNED"]["MRI_T2W_2D"], cut=exemplar_com, window=[0,800])

    vis.add_contour({
        'Prostate (Exemplar)':exemplar_mask,
        'PZ (Exemplar)':exemplar["CONTOUR_PZ"],
        f'Prostate (MRHIST {atlas_id})':atlas_set[atlas_id]["ALIGNED"]["CONTOUR_PROSTATE"],
        f'PZ (MRHIST {atlas_id})':atlas_set[atlas_id]["ALIGNED"]["CONTOUR_PZ"],
    },
    colormap=plt.cm.jet)

    fig = vis.show()
    fig.savefig(f"../1_processing/FIGURES_REGISTRATION/MRHIST{atlas_id}_0_INITIAL.png", dpi=300)

    # Close every iteration so figures do not accumulate across all cases
    plt.close('all')

plt.close('all')

In [None]:
"""
Save data for the exemplar patient
"""
p = pathlib.Path(f"../1_processing/ATLAS_DATA_REGISTERED/MRHIST{closest_id}")
(p / "IMAGES_ALIGNED").mkdir(exist_ok=True, parents=True)
(p / "LABELS_ALIGNED").mkdir(exist_ok=True, parents=True)
(p / "TRANSFORMS").mkdir(exist_ok=True, parents=True)

for label_name in labels_linear + labels_nn:
    sitk.WriteImage(exemplar[label_name], str( p / "LABELS_ALIGNED" / f"MRHIST{closest_id}_{label_name}.nii.gz") )

for image_name in images_bspline + images_linear + images_nn:
    sitk.WriteImage(exemplar[image_name], str( p / "IMAGES_ALIGNED" / f"MRHIST{closest_id}_{image_name}.nii.gz") )

In [None]:
"""
Visualise a sample of the ALIGNED images
"""

example_WG = {i:atlas_set[i]['ALIGNED']['CONTOUR_PROSTATE'] for i in np.random.choice(case_id_list, 30, replace=False)}

In [None]:
# Sampled aligned prostate contours over the exemplar MRI, on the z-slice through the exemplar prostate COM
exemplar_com_slice = get_com(exemplar["CONTOUR_PROSTATE"])[0]

vis = ImageVisualiser(exemplar_image, cut=exemplar_com_slice, axis="z", window=[0, 800], figure_size_in=5)
vis.add_contour(example_WG, colormap=plt.cm.summer_r)
vis.set_limits_from_label(sum(example_WG.values()), expansion=2)

fig = vis.show()
fig.savefig("../../3_deliverables/Figures/reg_step_1_align_to_median.png", dpi=400)

In [None]:
# Voxel-wise count of stage-1 aligned prostate / PZ masks, and the fraction of cases covering each voxel
n_cases = len(case_id_list)

prostate_count_map = sum(
    sitk.Cast(atlas_set[i]["ALIGNED"]["CONTOUR_PROSTATE"], sitk.sitkFloat64) for i in case_id_list
)
pz_count_map = sum(
    sitk.Cast(atlas_set[i]["ALIGNED"]["CONTOUR_PZ"], sitk.sitkFloat64) for i in case_id_list
)

probability_map_prostate = prostate_count_map / n_cases
probability_map_pz = pz_count_map / n_cases

# Save a figure: prostate probability map, per-voxel count overlay, and the > 0.5 reference prostate outline
vis = ImageVisualiser(probability_map_prostate, cut=get_com(exemplar["CONTOUR_PROSTATE"]), window=[0, 1], figure_size_in=8)
vis.add_scalar_overlay(
    prostate_count_map,
    colormap=plt.cm.gist_earth,
    discrete_levels=12,
    max_value=60,
    min_value=0,
    name="Number of prostates (aligned)",
    alpha=0.4,
)

vis.add_contour({"Prostate (Reference)": probability_map_prostate > 0.5}, colormap=plt.cm.spring_r)
vis.set_limits_from_label(sum(example_WG.values()), expansion=2)
fig = vis.show()
# fig.savefig("../../3_deliverables/Figures/reg_step_2_reference_prostate.png", dpi=400)

In [None]:
# Register stage: 1b
# SIMILARITY REGISTRATION
# Each case's RESAMPLED prostate mask is registered to the reference prostate
# (probability_map_prostate > 0.5) with a similarity transform. Every image/label is resampled onto
# the exemplar MRI grid ("SCALED") and written to disk together with the transform.

reference_mask = probability_map_prostate>0.5

# Loop-invariant values, computed once
reference_pz = probability_map_pz > 0.5
exemplar_com = get_com(exemplar["CONTOUR_PROSTATE"])

# Interpolator for each group of data items (grouping defined with the name lists above)
interpolator_groups = (
    (images_bspline, sitk.sitkBSpline),
    (labels_linear + images_linear, sitk.sitkLinear),
    (labels_nn + images_nn, sitk.sitkNearestNeighbor),
)

# savefig does not create missing directories
pathlib.Path("../1_processing/FIGURES_REGISTRATION").mkdir(exist_ok=True, parents=True)

atlas_id_list = case_id_list[:]

for atlas_id in atlas_id_list:

    atlas_mask = atlas_set[atlas_id]["RESAMPLED"]["CONTOUR_PROSTATE"]

    _, similarity_tfm = linear_registration(
        reference_mask,
        atlas_mask,
        reg_method="Similarity",
        shrink_factors=[6, 3],
        smooth_sigmas=[1, 1],
        sampling_rate=0.85,
        final_interp=sitk.sitkLinear,
        metric="mean_squares",
        optimiser="gradient_descent_line_search",
        number_of_iterations=25,
    )

    # Apply transform
    atlas_set[atlas_id]["SCALED"] = {}
    for names, interpolator in interpolator_groups:
        for name in names:
            atlas_set[atlas_id]["SCALED"][name] = apply_transform(
                atlas_set[atlas_id]["RESAMPLED"][name],
                reference_image=exemplar_image,
                transform=similarity_tfm,
                default_value=0,
                interpolator=interpolator,
            )

    p = pathlib.Path(f"../1_processing/ATLAS_DATA_REGISTERED/MRHIST{atlas_id}")
    for sub_dir in ("IMAGES_SCALED", "LABELS_SCALED", "TRANSFORMS"):
        (p / sub_dir).mkdir(exist_ok=True, parents=True)

    for label_name in labels_linear + labels_nn:
        sitk.WriteImage(atlas_set[atlas_id]["SCALED"][label_name], str( p / "LABELS_SCALED" / f"MRHIST{atlas_id}_{label_name}.nii.gz") )

    for image_name in images_bspline + images_linear + images_nn:
        sitk.WriteImage(atlas_set[atlas_id]["SCALED"][image_name], str( p / "IMAGES_SCALED" / f"MRHIST{atlas_id}_{image_name}.nii.gz") )

    # Write THIS case's similarity transform. Writing initial_tfm here would save a stale stage-1
    # translation (left over from the last case of the stage-1 loop) for every case.
    sitk.WriteTransform(similarity_tfm, str( p / "TRANSFORMS" / f"MRHIST{atlas_id}_SIMILARITY_TO_REFERENCE.tfm") )

    # Visualise and save figure
    vis = ImageVisualiser(atlas_set[atlas_id]["SCALED"]["MRI_T2W_2D"], cut=exemplar_com, window=[0,800])

    vis.add_contour({
        'Prostate (Reference)':reference_mask,
        'PZ (Reference)':reference_pz,
        f'Prostate (MRHIST {atlas_id})':atlas_set[atlas_id]["SCALED"]["CONTOUR_PROSTATE"],
        f'PZ (MRHIST {atlas_id})':atlas_set[atlas_id]["SCALED"]["CONTOUR_PZ"],
    },
    colormap=plt.cm.jet)

    fig = vis.show()
    fig.savefig(f"../1_processing/FIGURES_REGISTRATION/MRHIST{atlas_id}_1_SCALING.png", dpi=300)

    plt.close('all')

plt.close('all')

In [None]:
"""
Memory saver
"""

for atlas_id in case_id_list:
    atlas_set[atlas_id]["RESAMPLED"] = None
    atlas_set[atlas_id]["ALIGNED"] = None

In [None]:
"""
Read in here
"""

process_dir = pathlib.Path("../../2_workspace/1_processing/ATLAS_DATA_REGISTERED/")

atlas_set = {}

atlas_id_list = case_id_list[:]

for atlas_id in atlas_id_list:
    
    print(atlas_id, end=" | ")
    
    atlas_set[atlas_id] = {}

    atlas_set[atlas_id]["SCALED"] = {}

    for image_name in images_linear + images_nn + images_bspline:
        atlas_set[atlas_id]["SCALED"][image_name]   = sitk.ReadImage( (process_dir / f"MRHIST{atlas_id}" / f"IMAGES_SCALED" / f"MRHIST{atlas_id}_{image_name}.nii.gz").as_posix() )

    for label_name in labels_linear + labels_nn:
        atlas_set[atlas_id]["SCALED"][label_name]   = sitk.ReadImage( (process_dir / f"MRHIST{atlas_id}" / f"LABELS_SCALED" / f"MRHIST{atlas_id}_{label_name}.nii.gz").as_posix() )
        
        
probability_map_prostate = sitk.ReadImage("../../2_workspace/2_output/ATLAS_PRODUCTS/REFERENCE_PROSTATE_PROBABILITY.nii.gz")
probability_map_pz = sitk.ReadImage("../../2_workspace/2_output/ATLAS_PRODUCTS/REFERENCE_PZ_PROBABILITY.nii.gz")
exemplar = atlas_set["054"]["SCALED"]

In [None]:
# Register stage: 2
# DEMONS REGISTRATION - prostate-guided
# The registration structure of each case's SCALED prostate mask is deformably registered to that of
# the reference prostate (probability_map_prostate > 0.5). The resulting transform is applied to
# every image/label ("DIR"), and the data plus the displacement field are written to disk.

reference_mask = probability_map_prostate>0.5
reference_reg_structure = convert_mask_to_reg_structure(reference_mask)

# Loop-invariant values, computed once
reference_pz = probability_map_pz > 0.5
exemplar_com = get_com(exemplar["CONTOUR_PROSTATE"])

# Interpolator for each group of data items (grouping defined with the name lists above)
interpolator_groups = (
    (images_bspline, sitk.sitkBSpline),
    (labels_linear + images_linear, sitk.sitkLinear),
    (labels_nn + images_nn, sitk.sitkNearestNeighbor),
)

# savefig does not create missing directories
pathlib.Path("../1_processing/FIGURES_REGISTRATION").mkdir(exist_ok=True, parents=True)

atlas_id_list = case_id_list[:]

for atlas_id in atlas_id_list:

    scaled = atlas_set[atlas_id]["SCALED"]
    atlas_reg_structure = convert_mask_to_reg_structure(scaled["CONTOUR_PROSTATE"])

    _, deform_tfm, deform_field = fast_symmetric_forces_demons_registration(
        reference_reg_structure,
        atlas_reg_structure,
        resolution_staging     = [3.2, 1.6, 0.8],
        iteration_staging      = [20,20,20],
        smoothing_sigma_factor = 0,
        isotropic_resample     = True,
        default_value          = 0,
        ncores                 = 8,
        verbose                = False
    )

    # Apply transform
    atlas_set[atlas_id]["DIR"] = {}
    for names, interpolator in interpolator_groups:
        for name in names:
            if name == "HISTOLOGY":
                # Multi-component image: channels 0-2 are warped separately with LINEAR
                # interpolation and recomposed into one vector image
                channels = [
                    apply_transform(
                        sitk.VectorIndexSelectionCast(scaled["HISTOLOGY"], channel),
                        transform=deform_tfm,
                        default_value=0,
                        interpolator=sitk.sitkLinear,
                    )
                    for channel in range(3)
                ]
                atlas_set[atlas_id]["DIR"]["HISTOLOGY"] = sitk.Compose(*channels)
            else:
                atlas_set[atlas_id]["DIR"][name] = apply_transform(
                    scaled[name],
                    transform=deform_tfm,
                    default_value=0,
                    interpolator=interpolator,
                )

    p = pathlib.Path(f"../1_processing/ATLAS_DATA_REGISTERED/MRHIST{atlas_id}")
    for sub_dir in ("IMAGES_DIR", "LABELS_DIR", "TRANSFORMS"):
        (p / sub_dir).mkdir(exist_ok=True, parents=True)

    for label_name in labels_linear + labels_nn:
        sitk.WriteImage(atlas_set[atlas_id]["DIR"][label_name], str( p / "LABELS_DIR" / f"MRHIST{atlas_id}_{label_name}.nii.gz") )

    for image_name in images_bspline + images_linear + images_nn:
        sitk.WriteImage(atlas_set[atlas_id]["DIR"][image_name], str( p / "IMAGES_DIR" / f"MRHIST{atlas_id}_{image_name}.nii.gz") )

    sitk.WriteImage(deform_field, str( p / "TRANSFORMS" / f"MRHIST{atlas_id}_DVF_PROSTATE_ANATOMY_GUIDED.nii.gz") )

    # Visualise and save figure
    vis = ImageVisualiser(atlas_set[atlas_id]["DIR"]["MRI_T2W_2D"], cut=exemplar_com, window=[0,800])

    vis.add_contour({
        'Prostate (Reference)':reference_mask,
        'PZ (Reference)':reference_pz,
        f'Prostate (MRHIST {atlas_id})':atlas_set[atlas_id]["DIR"]["CONTOUR_PROSTATE"],
        f'PZ (MRHIST {atlas_id})':atlas_set[atlas_id]["DIR"]["CONTOUR_PZ"],
    },
    colormap=plt.cm.jet)

    fig = vis.show()
    fig.savefig(f"../1_processing/FIGURES_REGISTRATION/MRHIST{atlas_id}_2_DIR.png", dpi=300)

    # Close every iteration so figures do not accumulate across all cases
    plt.close("all")

    # Memory saver
    atlas_set[atlas_id]["SCALED"] = None

plt.close("all")

In [None]:
"""
Memory saver
"""

for atlas_id in case_id_list:
    atlas_set[atlas_id]["SCALED"] = None

In [None]:
# Sample of up to 30 prostate-guided (DIR) PZ contours for visualisation. A seeded generator keeps the
# sample reproducible, and the cap avoids a ValueError when there are fewer than 30 cases.
example_pz_rng = np.random.default_rng(0)
example_PZ = {
    i: atlas_set[i]["DIR"]["CONTOUR_PZ"]
    for i in example_pz_rng.choice(atlas_id_list, min(30, len(atlas_id_list)), replace=False)
}

In [None]:
"""
Calculate reference PZ from DIR-registered cases
"""

probability_map_pz_dir = sum([sitk.Cast(atlas_set[i]["DIR"]["CONTOUR_PZ"], sitk.sitkFloat64) for i in atlas_id_list]) / len(atlas_id_list)

In [None]:
vis = ImageVisualiser(atlas_set["054"]["DIR"]["MRI_T2W_2D"], cut=get_com(exemplar['CONTOUR_PROSTATE'])[0], axis='z', window=[0,800], figure_size_in=5)
vis.add_contour(example_PZ, colormap=plt.cm.summer_r)
vis.set_limits_from_label(sum(example_PZ.values()), expansion=10)
fig = vis.show()
# fig.savefig("../../3_deliverables/Figures/reg_step_3a_reference_pz_indiv_contours.png", dpi=400)

In [None]:
# Reference PZ: voxels where more than 35% of the DIR-registered cases have PZ, restricted to the
# reference prostate (probability_map_prostate > 0.5).
# This is built from probability_map_pz_dir (computed above from the DIR cases), not from
# probability_map_pz. After the "Read in here" cell, probability_map_pz is the map loaded from disk,
# and using it would make the saved REFERENCE_PZ depend on a previous run's output.
reference_prostate = probability_map_prostate > 0.5
probability_map_pz_dir_masked = sitk.Mask(probability_map_pz_dir > 0.35, reference_prostate)

# Save a figure
pz_dir_count_map = sum(sitk.Cast(atlas_set[i]["DIR"]["CONTOUR_PZ"], sitk.sitkFloat64) for i in case_id_list)

vis = ImageVisualiser(probability_map_pz_dir, cut=get_com(exemplar["CONTOUR_PROSTATE"]), window=[0, 1], figure_size_in=8)
vis.add_scalar_overlay(
    pz_dir_count_map,
    colormap=plt.cm.gist_earth,
    discrete_levels=12,
    max_value=60,
    min_value=0,
    name="Number of PZ (aligned)",
    alpha=0.4,
)

vis.add_contour(
    {
        "Prostate (Reference)": reference_prostate,
        "PZ (Reference)": probability_map_pz_dir_masked,
    },
    colormap=plt.cm.spring_r,
)
# Limit the view to every voxel covered by at least one prostate in probability_map_prostate.
# sum(example_WG.values()) was used before, but example_WG only exists if the stage-1 cells ran in
# this kernel session, not after resuming from the "Read in here" cell.
vis.set_limits_from_label(probability_map_prostate > 0, expansion=2)
fig = vis.show()
# fig.savefig("../../3_deliverables/Figures/reg_step_3b_reference_pz_prob_map.png", dpi=400)

In [None]:
from platipy.imaging.registration.utils import convert_mask_to_distance_map

In [None]:
def constrain_dvf(dvf, boundary, dx_mm=3, max_distance_mm=100):
    """Taper a displacement vector field (DVF) to zero at the edge of a mask.

    Every DVF component is multiplied by a weight derived from the distance map of `boundary`
    (computed by platipy's ``convert_mask_to_distance_map``). With distance d:

    * 0 <= d <= dx_mm:              weight = d / dx_mm (linear ramp from 0 to 1)
    * dx_mm < d <= max_distance_mm: weight = 1
    * otherwise (d < 0 or beyond):  weight = 0

    NOTE(review): which side of the mask edge has positive distance is decided by
    ``convert_mask_to_distance_map``. It is assumed to be positive inside the mask, so the DVF is
    kept inside and zeroed outside. TODO confirm.

    Args:
        dvf: displacement field image (vector pixels); it is cast to sitkVectorFloat64.
        boundary: mask whose edge the DVF is tapered towards.
        dx_mm: width of the linear ramp, in distance-map units (presumably mm - confirm).
        max_distance_mm: distances above this get weight 0 (previously hard-coded as 100).

    Returns:
        Vector image with the same number of components as `dvf`, each multiplied by the weight.
    """
    # float64 so the ramp and plateau terms below always share a pixel type
    d_map = sitk.Cast(convert_mask_to_distance_map(boundary), sitk.sitkFloat64)

    # Linear ramp: Threshold keeps values in [0, dx_mm] and sets everything else to 0
    ramp = sitk.Threshold(d_map, lower=0, upper=dx_mm, outsideValue=0) / dx_mm

    # Plateau of 1 beyond the ramp. The lower bound is exclusive so voxels at exactly dx_mm are
    # counted once (an inclusive bound here would give them weight 1 + 1 = 2).
    plateau = sitk.Cast((d_map > dx_mm) & (d_map <= max_distance_mm), sitk.sitkFloat64)

    weight = ramp + plateau

    dvf = sitk.Cast(dvf, sitk.sitkVectorFloat64)

    weighted_components = [
        sitk.VectorIndexSelectionCast(dvf, component) * weight
        for component in range(dvf.GetNumberOfComponentsPerPixel())
    ]

    return sitk.Compose(*weighted_components)

In [None]:
# Register stage: 3
# DEMONS REGISTRATION - PZ-guided
# For each slice that has enough PZ, each case's PZ is registered to the reference PZ with 2D demons.
# The in-plane displacements are weighted by a smoothed reference prostate mask and assembled into a
# 3D field with zero through-plane component. That field is smoothed, tapered at the reference
# prostate edge by constrain_dvf, and applied to the prostate-guided ("DIR") data to give "DIR_PZ".

reference_mask = probability_map_pz_dir_masked
reference_reg_structure = convert_mask_to_reg_structure(reference_mask)

reference_prostate = probability_map_prostate > 0.5

# Smoothed reference prostate used to weight the in-plane displacements. It does not depend on the
# case or slice, so it is computed once. The mask is cast to float BEFORE smoothing:
# SmoothingRecursiveGaussian keeps the input pixel type, so smoothing the uint8 mask directly would
# truncate the soft edge back to 0/1.
reference_prostate_smooth_3d = sitk.SmoothingRecursiveGaussian(
    sitk.Cast(reference_prostate, sitk.sitkFloat64), sigma=(2, 2, 2)
)

exemplar_com = get_com(exemplar["CONTOUR_PROSTATE"])

# Interpolator for each group of data items (grouping defined with the name lists above)
interpolator_groups = (
    (images_bspline, sitk.sitkBSpline),
    (labels_linear + images_linear, sitk.sitkLinear),
    (labels_nn + images_nn, sitk.sitkNearestNeighbor),
)

# savefig does not create missing directories
pathlib.Path("../1_processing/FIGURES_REGISTRATION").mkdir(exist_ok=True, parents=True)

atlas_id_list = case_id_list[:]

for atlas_id in atlas_id_list:

    dir_data = atlas_set[atlas_id]["DIR"]
    atlas_mask = dir_data["CONTOUR_PZ"]
    atlas_reg_structure = convert_mask_to_reg_structure(atlas_mask)

    combined_displacement_field = sitk.Image(
        atlas_mask.GetWidth(),
        atlas_mask.GetHeight(),
        atlas_mask.GetDepth(),
        sitk.sitkVectorFloat64,
    )

    for image_slice in range(probability_map_pz_dir.GetSize()[2]):

        # Only register slices whose registration-structure sum exceeds 25
        if sitk.GetArrayViewFromImage(atlas_reg_structure[:, :, image_slice]).sum() <= 25:
            continue

        _, _, deform_field = fast_symmetric_forces_demons_registration(
            reference_reg_structure[:,:,image_slice],
            atlas_reg_structure[:,:,image_slice],
            resolution_staging     = [3.2, 1.6, 0.8],
            iteration_staging      = [25,25,25],
            smoothing_sigma_factor = 0,
            isotropic_resample     = True,
            default_value          = 0,
            ncores                 = 8,
        )

        slice_weight = reference_prostate_smooth_3d[:, :, image_slice]

        combined_displacement_field[:, :, image_slice] = sitk.Compose(
            slice_weight * sitk.VectorIndexSelectionCast(deform_field, 0),
            slice_weight * sitk.VectorIndexSelectionCast(deform_field, 1),
            # Zero through-plane component, built on the atlas slice grid so the metadata matches
            sitk.Cast(0 * atlas_mask[:, :, image_slice], sitk.sitkFloat64),
        )

    combined_displacement_field.CopyInformation(atlas_mask)

    combined_displacement_field = sitk.SmoothingRecursiveGaussian(combined_displacement_field, sigma=(1, 1, 1))

    # Regularise on the border of the entire prostate; this keeps the PZ-guided DIR from
    # undoing the prostate boundary alignment of the previous stage
    combined_displacement_field = constrain_dvf(combined_displacement_field, reference_prostate)

    combined_tfm = sitk.DisplacementFieldTransform(sitk.Cast(combined_displacement_field, sitk.sitkVectorFloat64))

    # Apply transform
    atlas_set[atlas_id]["DIR_PZ"] = {}
    for names, interpolator in interpolator_groups:
        for name in names:
            if name == "HISTOLOGY":
                # Multi-component image: channels 0-2 are warped separately with LINEAR
                # interpolation and recomposed into one vector image
                channels = [
                    apply_transform(
                        sitk.VectorIndexSelectionCast(dir_data["HISTOLOGY"], channel),
                        transform=combined_tfm,
                        default_value=0,
                        interpolator=sitk.sitkLinear,
                    )
                    for channel in range(3)
                ]
                atlas_set[atlas_id]["DIR_PZ"]["HISTOLOGY"] = sitk.Compose(*channels)
            else:
                atlas_set[atlas_id]["DIR_PZ"][name] = apply_transform(
                    dir_data[name],
                    transform=combined_tfm,
                    default_value=0,
                    interpolator=interpolator,
                )

    p = pathlib.Path(f"../1_processing/ATLAS_DATA_REGISTERED/MRHIST{atlas_id}")
    for sub_dir in ("IMAGES_DIR_PZ", "LABELS_DIR_PZ", "TRANSFORMS"):
        (p / sub_dir).mkdir(exist_ok=True, parents=True)

    for label_name in labels_linear + labels_nn:
        sitk.WriteImage(atlas_set[atlas_id]["DIR_PZ"][label_name], str( p / "LABELS_DIR_PZ" / f"MRHIST{atlas_id}_{label_name}.nii.gz") )

    for image_name in images_bspline + images_linear + images_nn:
        sitk.WriteImage(atlas_set[atlas_id]["DIR_PZ"][image_name], str( p / "IMAGES_DIR_PZ" / f"MRHIST{atlas_id}_{image_name}.nii.gz") )

    sitk.WriteImage(combined_displacement_field, str( p / "TRANSFORMS" / f"MRHIST{atlas_id}_DVF_PZ_ANATOMY_GUIDED.nii.gz") )

    # Visualise and save figure
    vis = ImageVisualiser(atlas_set[atlas_id]["DIR_PZ"]["MRI_T2W_2D"], cut=exemplar_com, window=[0,800])

    vis.add_contour({
        'Prostate (Reference)':reference_prostate,
        'PZ (Reference)':reference_mask,
        f'Prostate (MRHIST {atlas_id})':atlas_set[atlas_id]["DIR_PZ"]["CONTOUR_PROSTATE"],
        f'PZ (MRHIST {atlas_id})':atlas_set[atlas_id]["DIR_PZ"]["CONTOUR_PZ"],
    },
    colormap=plt.cm.jet)

    fig = vis.show()
    fig.savefig(f"../1_processing/FIGURES_REGISTRATION/MRHIST{atlas_id}_3_DIR_PZ.png", dpi=300)

    # Close every iteration so figures do not accumulate across all cases
    plt.close("all")

    # Memory saver
    atlas_set[atlas_id]["DIR"] = None

plt.close("all")

In [None]:
"""
Calculate urethra
"""

probability_map_urethra_dir_2 = sum([sitk.Cast(atlas_set[i]["DIR_PZ"]["CONTOUR_URETHRA"], sitk.sitkFloat64) for i in atlas_id_list]) / len(atlas_id_list)

In [None]:
"""
Save atlas products
"""

#sitk.WriteImage(probability_map_prostate, "../2_output/ATLAS_PRODUCTS/REFERENCE_PROSTATE_PROBABILITY.nii.gz")
sitk.WriteImage(probability_map_pz_dir, "../2_output/ATLAS_PRODUCTS/REFERENCE_PZ_PROBABILITY.nii.gz")
sitk.WriteImage(probability_map_urethra_dir_2, "../2_output/ATLAS_PRODUCTS/REFERENCE_URETHRA_PROBABILITY.nii.gz")

#sitk.WriteImage(probability_map_prostate>0.5, "../2_output/ATLAS_PRODUCTS/REFERENCE_PROSTATE.nii.gz")
sitk.WriteImage(probability_map_pz_dir_masked, "../2_output/ATLAS_PRODUCTS/REFERENCE_PZ.nii.gz")

In [None]:
"""
Compute urethra spline using nominal 1.5mm radius
- we also produce smaller and larger prostates to be used as a measure of uncertainty
"""

for vessel_radius in [0.8, 1, 1.5, 2, 2.2]:

    spline_U = vessel_spline_generation(
        reference_image = probability_map_prostate,
        atlas_set = atlas_set,
        vessel_name_list = ["CONTOUR_URETHRA"],
        vessel_radius_mm_dict = {"CONTOUR_URETHRA":vessel_radius},
        stop_condition_type_dict = {"CONTOUR_URETHRA":"area"},
        stop_condition_value_dict = {"CONTOUR_URETHRA":4},
        scan_direction_dict = {"CONTOUR_URETHRA":"z"},
        atlas_label = "DIR_PZ"
    )

    sitk.WriteImage(spline_U["CONTOUR_URETHRA"], f"../2_output/ATLAS_PRODUCTS/REFERENCE_URETHRA_SPLINE_{vessel_radius}MM.nii.gz")