# Atlas generation and statistical analysis

### Initialisation

Set up the environment and load the data:
- read in modules
- define some useful code snippets
- read in the atlas set

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import functools
import pathlib

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import pandas as pd
import seaborn as sns
import SimpleITK as sitk
from matplotlib import colors
from palettable.cmocean.sequential import Ice_3
from platipy.imaging import ImageVisualiser
from platipy.imaging.label.iar import median_absolute_deviation
from platipy.imaging.label.utils import get_com
from platipy.imaging.registration.deformable import fast_symmetric_forces_demons_registration
from platipy.imaging.registration.linear import alignment_registration, linear_registration
from platipy.imaging.registration.utils import apply_transform, smooth_and_resample, convert_mask_to_reg_structure
from scipy.stats import gaussian_kde
from scipy.stats import norm, binom, lognorm

from birt_utils import smooth_edges, read_atlas_data, calculate_sig_map, kl_divergence


%matplotlib notebook

In [None]:
"""
Snippets
"""


def f_conv(x):
    """Return a NumPy array view of the SimpleITK image `x`."""
    return sitk.GetArrayViewFromImage(x)


def f_conv_r(x):
    """Return a SimpleITK image created from the NumPy array `x`."""
    return sitk.GetImageFromArray(x)

In [None]:
"""
Get the list of patient ids
"""
input_dir = pathlib.Path("../1_processing/ATLAS_DATA_REGISTERED/")

# The case id is the entry name with its first six characters removed
case_id_list = sorted(entry.name[6:] for entry in input_dir.glob("*MRHIST*"))
print(len(case_id_list), case_id_list)

In [None]:
"""
Simplify the images/labels that we propagate
"""

# Tumour probability maps, one per grade
labels_linear = [
    f"TUMOUR_PROBABILITY_GRADE_{grade}"
    for grade in ("2+2", "3+2", "3+3", "3+4", "4+3", "4+4", "4+5", "5+4", "5+5")
]

# Contours and label maps
labels_nn = ["CONTOUR_PROSTATE", "CONTOUR_PZ", "CONTOUR_URETHRA", "LABEL_HISTOLOGY", "LABEL_SAMPLING"]

images_linear = ["MRI_T2W_2D", "CELL_DENSITY_MAP"]

images_nn = ["HISTOLOGY"]

# Every name, in the order labels (linear, nn) then images (linear, nn)
data_names = [*labels_linear, *labels_nn, *images_linear, *images_nn]

In [None]:
"""
Read in the atlas contours
"""

# Reference prostate, PZ and urethra images, read in that order
reference_WG, reference_PZ, reference_U = (
    sitk.ReadImage(reference_path)
    for reference_path in (
        "../2_output/ATLAS_PRODUCTS/REFERENCE_PROSTATE.nii.gz",
        "../2_output/ATLAS_PRODUCTS/REFERENCE_PZ.nii.gz",
        "../2_output/ATLAS_PRODUCTS/REFERENCE_URETHRA_SPLINE_1.5MM.nii.gz",
    )
)

In [None]:
"""
Read in data

All the data (only final stage: DIR_PZ)
"""

# Request every label and image name for the single "DIR_PZ" registration type
atlas_set = read_atlas_data(
    case_id_list=case_id_list,
    reg_types=["DIR_PZ"],
    label_names=[*labels_linear, *labels_nn],
    image_names=[*images_linear, *images_nn],
    input_dir=input_dir,
)

### Processing

In this section we compute the basic components of the atlas and visualise using platipy.

The most important object is `atlas_products`.
This is a dictionary containing the relevant components of the atlas:
- Mean MRI (useful for visualisation): `"MEAN_MRI"`
- Sampling frequency: `"SAMPLING_FREQUENCY"`



In [None]:
"""
Initialise
"""

# Dictionary that collects the atlas components; entries are added by the cells below
atlas_products = {}

In [None]:
"""
Define the cut
It's worth playing around with this parameter to explore different parts of the atlas
"""

# Passed to ImageVisualiser as `cut=cut` for the 'ortho' views and as `cut=cut[0]` for the axis='z' views
cut = (26, 62, 61)

In [None]:
"""
MEAN MRI
"""
# Cast every case's T2w MRI to double precision, then average over the cases
mri_float_images = [
    sitk.Cast(atlas_set[case_id]["DIR_PZ"]["MRI_T2W_2D"], sitk.sitkFloat64)
    for case_id in case_id_list
]
mri_mean = sum(mri_float_images) / len(case_id_list)

# Mask the average with the reference prostate image
mri_mean = sitk.Mask(mri_mean, reference_WG)

atlas_products["MEAN_MRI"] = mri_mean

In [None]:
"""
MEAN MRI
Visualise
"""

# Orthogonal views of the mean MRI at the chosen cut
vis = ImageVisualiser(
    atlas_products["MEAN_MRI"],
    axis='ortho',
    cut=cut,
    window=[190,200],
    figure_size_in=6,
)

# Display limits are taken from the reference prostate label (expansion of 5)
vis.set_limits_from_label(reference_WG, expansion=5)

# Reference structures drawn as contours
contours = {
    "Prostate": reference_WG,
    "PZ": reference_PZ,
    "Urethra": reference_U,
}
vis.add_contour(contours, colormap=plt.cm.autumn)

fig = vis.show()

fig.savefig("../../3_deliverables/Figures/working/atlas_ortho_slices/mri/mri_mean_ortho.png", dpi=300)

In [None]:
"""
SAMPLING FREQUENCY

Calculate sampling frequency (per voxel)
This is used to normalise the atlas results
"""

# A case contributes to a voxel where its LABEL_SAMPLING value exceeds 0.25.
# Each comparison mask is cast to float BEFORE summing: the comparison yields 8-bit
# images, and accumulating those directly would wrap around beyond 255 cases.
sampling_frequency = sum(
    sitk.Cast(atlas_set[case_id]['DIR_PZ']['LABEL_SAMPLING'] > 0.25, sitk.sitkFloat64)
    for case_id in case_id_list
)

atlas_products["SAMPLING_FREQUENCY"] = sampling_frequency

# For cell density we take a stricter cut-off
# (LABEL_SAMPLING above 0.8 and CELL_DENSITY_MAP above 2500); the masks themselves are kept
# because later cells use them to mask the per-case images
valid_cd_samples = [
    (atlas_set[case_id]['DIR_PZ']['LABEL_SAMPLING']>0.8) & (atlas_set[case_id]['DIR_PZ']['CELL_DENSITY_MAP']>2500)
    for case_id in case_id_list
]
atlas_products["VALID_CD_SAMPLES"] = valid_cd_samples

# Number of valid cell density samples per voxel (same float-first summation as above)
sampling_frequency_cd = sum(sitk.Cast(valid_mask, sitk.sitkFloat64) for valid_mask in valid_cd_samples)
atlas_products["SAMPLING_FREQUENCY_CD"] = sampling_frequency_cd


sitk.WriteImage(sampling_frequency, "../2_output/ATLAS_PRODUCTS/ATLAS_SAMPLING_FREQUENCY.nii.gz")
sitk.WriteImage(sampling_frequency_cd, "../2_output/ATLAS_PRODUCTS/ATLAS_SAMPLING_FREQUENCY_CD.nii.gz")

In [None]:
"""
SAMPLING FREQUENCY

Visualisation
"""

# Constant-valued image (0.1 everywhere) used as the backdrop for the overlay
background = mri_mean*0 + 0.1
vis = ImageVisualiser(background, axis='ortho', cut=cut, window=[0,1], figure_size_in=6)

# Only voxels with a sampling frequency above 1 are kept in the overlay
frequency = atlas_products["SAMPLING_FREQUENCY"]
frequency_overlay = sitk.Mask(frequency, frequency>1)

vis.add_scalar_overlay(
    frequency_overlay,
    name='Sampling Frequency',
    colormap=plt.cm.RdBu,
    min_value=0,
    max_value=65,
    discrete_levels=13,
    alpha=1,
)

vis.set_limits_from_label(reference_WG, expansion=5)

contours = {
    "Prostate": reference_WG,
    "Urethra": reference_U,
    "PZ": reference_PZ,
}
vis.add_contour(contours, colormap=plt.cm.spring_r, show_legend=False)

fig = vis.show()

fig.savefig("../../3_deliverables/Figures/working/atlas_ortho_slices/sampling_frequency/atlas_sampling_frequency.png", dpi=500)

### CELL DENSITY

In this section we generate the atlas models:
- Cell density: `"CELL_DENSITY"`
    - mean: `"MEAN"`
    - mean (log): `"LOG_MEAN"`
    - variance: `"VARIANCE"`
    - variance (log): `"LOG_VARIANCE"`

In [None]:
"""
Initialisation
"""

# Sub-dictionary for the cell density products ("MEAN", "VARIANCE", "LOG_MEAN", "LOG_VARIANCE")
atlas_products["CELL_DENSITY"] = {}

In [None]:
"""
Cell density calculation

Linear model: voxel-wise mean and variance of the cell density maps.
Log model: voxel-wise mean and sample variance of log(cell density), restricted to
the valid cell density samples of each case.
"""
cell_density_list = [atlas_set[case_id]['DIR_PZ']['CELL_DENSITY_MAP'] for case_id in case_id_list]

# Calculate the mean (sum over cases normalised by the per-voxel sampling frequency)
cell_density_mean = sum(cell_density_list) / atlas_products["SAMPLING_FREQUENCY"]
cell_density_mean_smedge = smooth_edges(cell_density_mean, reference_WG, dilate_voxels=5, smooth_mm=2)

# Calculate the variance as the voxel-wise NumPy variance across all cases.
# An earlier image-based sample variance was computed here and then immediately
# overwritten by this result; that dead computation has been removed.
# NOTE(review): np.var uses ddof=0 and includes every case at every voxel, i.e. it is
# not normalised by the sampling frequency like the mean above - confirm this is intended.
cell_density_var = f_conv_r(np.var([f_conv(image) for image in cell_density_list], axis=0))
cell_density_var.CopyInformation(cell_density_mean)

cell_density_var_smedge = smooth_edges(cell_density_var, reference_WG, dilate_voxels=5, smooth_mm=2)

atlas_products["CELL_DENSITY"]["MEAN"] = cell_density_mean_smedge
atlas_products["CELL_DENSITY"]["VARIANCE"] = cell_density_var_smedge

# Log CD computation

valid_masks = atlas_products["VALID_CD_SAMPLES"]

# log(cell density) per case, masked by that case's valid-sample mask
log_cell_density_list = [
    sitk.Cast(sitk.Mask(sitk.Log(image), valid_mask), sitk.sitkFloat64)
    for image, valid_mask in zip(cell_density_list, valid_masks)
]

# Number of valid samples per voxel; used for BOTH the mean and the variance so that
# the two statistics are normalised by the same sample count
n_valid_cd = sum([sitk.Cast(valid_mask, sitk.sitkFloat64) for valid_mask in valid_masks])

# Calculate the mean
log_cell_density_mean = sum(log_cell_density_list) / n_valid_cd
log_cell_density_mean_smedge = smooth_edges(log_cell_density_mean, reference_WG, dilate_voxels=5, smooth_mm=2)

# Calculate the variance.
# Only valid samples contribute squared deviations, so the denominator is
# (valid sample count - 1) rather than (overall sampling frequency - 1).
log_cell_density_var = sum(
    [
        sitk.Mask((log_image - log_cell_density_mean)**2, valid_mask)
        for log_image, valid_mask in zip(log_cell_density_list, valid_masks)
    ]
) / (n_valid_cd - 1) # sample variance

# Replace non-finite values by zero
arr_cd_log_var = sitk.GetArrayFromImage(log_cell_density_var)
arr_cd_log_var[~np.isfinite(arr_cd_log_var)] = 0
log_cell_density_var = f_conv_r(arr_cd_log_var)
log_cell_density_var.CopyInformation(log_cell_density_mean)

log_cell_density_var_smedge = smooth_edges(log_cell_density_var, reference_WG, dilate_voxels=5, smooth_mm=2)

atlas_products["CELL_DENSITY"]["LOG_MEAN"] = log_cell_density_mean_smedge
atlas_products["CELL_DENSITY"]["LOG_VARIANCE"] = log_cell_density_var_smedge


# save the products (the variances are written as standard deviations)

sitk.WriteImage(atlas_products["CELL_DENSITY"]["MEAN"], "../2_output/ATLAS_PRODUCTS/ATLAS_CELL_DENSITY_MEAN.nii.gz")
sitk.WriteImage(sitk.Sqrt(atlas_products["CELL_DENSITY"]["VARIANCE"]), "../2_output/ATLAS_PRODUCTS/ATLAS_CELL_DENSITY_STDDEV.nii.gz")
sitk.WriteImage(atlas_products["CELL_DENSITY"]["LOG_MEAN"], "../2_output/ATLAS_PRODUCTS/ATLAS_CELL_DENSITY_LOG_MEAN.nii.gz")
sitk.WriteImage(sitk.Sqrt(atlas_products["CELL_DENSITY"]["LOG_VARIANCE"]), "../2_output/ATLAS_PRODUCTS/ATLAS_CELL_DENSITY_LOG_STDDEV.nii.gz")

In [None]:
"""
CELL DENSITY MEAN
Visualisation
"""

# Constant-valued backdrop (0.1 everywhere) on the mean MRI grid
background = atlas_products["MEAN_MRI"]*0 + 0.1
vis = ImageVisualiser(background, axis='ortho', cut=cut, window=[0,1], figure_size_in=6)

# Mean cell density shown as a colour overlay
vis.add_scalar_overlay(
    atlas_products["CELL_DENSITY"]["MEAN"],
    name=r'Cell Density [mm$^{-3}$]',
    colormap=plt.cm.jet,
    min_value=0,
    max_value=210000,
)

vis.set_limits_from_label(reference_WG, expansion=5)

contours = {
    "Prostate": reference_WG,
    "Urethra": reference_U,
    "PZ": reference_PZ,
}
vis.add_contour(contours, colormap=plt.cm.spring_r, show_legend=False)

fig = vis.show()

fig.savefig("../../3_deliverables/Figures/working/atlas_ortho_slices/cell_density/atlas_cell_density_0_mean.png", dpi=500)

In [None]:
"""
CELL DENSITY STD DEV
Visualisation
"""

background = mri_mean*0 + 0.1
vis = ImageVisualiser(background, axis='ortho', cut=cut, window=[0,1], figure_size_in=6)

vis.set_limits_from_label(reference_WG, expansion=5)

# Standard deviation = square root of the stored variance
cell_density_stddev = sitk.Sqrt(atlas_products["CELL_DENSITY"]["VARIANCE"])

vis.add_scalar_overlay(
    cell_density_stddev,
    name='Standard Deviation',
    colormap=sns.color_palette("mako", as_cmap=True),
    min_value=0,
    max_value=1e5,
    discrete_levels=10,
)

contours = {
    "Prostate": reference_WG,
    "Urethra": reference_U,
    "PZ": reference_PZ,
}
vis.add_contour(contours, colormap=plt.cm.spring_r, show_legend=False)
fig = vis.show()

fig.savefig("../../3_deliverables/Figures/working/atlas_ortho_slices/cell_density/atlas_cell_density_0_stddev.png", dpi=500)

In [None]:
"""
CELL DENSITY CV
COEFFICIENT OF VARIATION
Visualisation
"""

# Coefficient of variation in percent: 100 * standard deviation / mean.
# Computed once so that the figure and the saved image are the same quantity.
cell_density_cv = 100*sitk.Sqrt(atlas_products["CELL_DENSITY"]["VARIANCE"])/atlas_products["CELL_DENSITY"]["MEAN"]

vis = ImageVisualiser(mri_mean*0 + 0.1, axis='ortho', cut=cut, window=[0,1], figure_size_in=6)

vis.set_limits_from_label(reference_WG, expansion=5)

vis.add_scalar_overlay(
    cell_density_cv,
    min_value=0,
    max_value=100,
    discrete_levels=10,
    name='Cell Density CV\nCoefficient of Variation [%]',
    colormap=sns.color_palette("mako", as_cmap=True)
)

vis.add_contour({
    "Prostate":reference_WG,
    "Urethra":reference_U,
    "PZ":reference_PZ}, colormap=plt.cm.spring_r, show_legend=False)
fig = vis.show()

# Save the coefficient of variation itself (previously the smoothed variance image
# was written under this COEFF_VAR file name)
sitk.WriteImage(cell_density_cv, "../2_output/ATLAS_PRODUCTS/ATLAS_CELL_DENSITY_COEFF_VAR.nii.gz")

fig.savefig("../../3_deliverables/Figures/working/atlas_ortho_slices/cell_density/atlas_cell_density_0_CV.png", dpi=500)

In [None]:
"""
LOG CELL DENSITY MEAN
Visualisation
"""

background = atlas_products["MEAN_MRI"]*0 + 0.1
vis = ImageVisualiser(background, axis='ortho', cut=cut, window=[0,1], figure_size_in=6)

# The overlay values are natural-log cell densities, hence the 9.9 - 12.25 range
vis.add_scalar_overlay(
    atlas_products["CELL_DENSITY"]["LOG_MEAN"],
    name=r'(Log) Cell Density [mm$^{-3}$]',
    colormap=plt.cm.jet,
    min_value=9.9,
    max_value=12.25,
)

vis.set_limits_from_label(reference_WG, expansion=5)

contours = {
    "Prostate": reference_WG,
    "Urethra": reference_U,
    "PZ": reference_PZ,
}
vis.add_contour(contours, colormap=plt.cm.spring_r, show_legend=False)

fig = vis.show()


# Change colorbar axis: ticks are placed at the natural log of each linear value and
# labelled with the linear value itself
ax = fig.axes[4]
cbar = fig.axes[0].images[1].colorbar

# Tick positions present before relabelling (kept for reference)
original_ticks = list(ax.get_yticks())

new_tick_labels = [20000, 30000, 45000, 70000, 100000, 150000, 200000]
new_ticks = np.log(new_tick_labels)

cbar.set_ticks(new_ticks)
cbar.set_ticklabels(new_tick_labels)

fig.savefig("../../3_deliverables/Figures/working/atlas_ortho_slices/cell_density/atlas_cell_density_1_log_mean.png", dpi=500)

In [None]:
"""
LOG CELL DENSITY STD DEV
Visualisation
"""

# We have to get fancy here: the STD DEV doesn't make sense by itself in LOG space
# Calculate the mean plus 1 std. dev (log space)
mean_plus_1SD = atlas_products["CELL_DENSITY"]["LOG_MEAN"] + sitk.Sqrt(atlas_products["CELL_DENSITY"]["LOG_VARIANCE"])
# Now convert to linear space
mean_plus_1SD_linear = sitk.Exp(mean_plus_1SD)
# Now get the difference from the mean, with BOTH terms in linear space
# (previously the log-space value was subtracted from the linear-space mean)
stddev = mean_plus_1SD_linear - sitk.Exp(atlas_products["CELL_DENSITY"]["LOG_MEAN"])
# NOTE(review): `stddev` is not plotted below; the overlay shows the log-space
# standard deviation directly.

vis = ImageVisualiser(mri_mean*0 + 0.1, axis='ortho', cut=cut, window=[0,1], figure_size_in=6)

vis.set_limits_from_label(reference_WG, expansion=5)

vis.add_scalar_overlay(
    sitk.Sqrt(atlas_products["CELL_DENSITY"]["LOG_VARIANCE"]),
    min_value=0,
    max_value=2,
    discrete_levels=10,
    name='Standard Deviation',
    colormap=sns.color_palette("mako", as_cmap=True)
)

vis.add_contour({
    "Prostate":reference_WG,
    "Urethra":reference_U,
    "PZ":reference_PZ}, colormap=plt.cm.spring_r, show_legend=False)
fig = vis.show()

fig.savefig("../../3_deliverables/Figures/working/atlas_ortho_slices/cell_density/atlas_cell_density_1_log_stddev.png", dpi=500)

In [None]:
"""
LOG CELL DENSITY CV
COEFFICIENT OF VARIATION
Visualisation
"""

# Coefficient of variation of the log model in percent: 100 * log std. dev. / log mean.
# Computed once so that the figure and the saved image are the same quantity.
log_cell_density_cv = 100*sitk.Sqrt(atlas_products["CELL_DENSITY"]["LOG_VARIANCE"])/atlas_products["CELL_DENSITY"]["LOG_MEAN"]

vis = ImageVisualiser(mri_mean*0 + 0.1, axis='ortho', cut=cut, window=[0,1], figure_size_in=6)

vis.set_limits_from_label(reference_WG, expansion=5)

vis.add_scalar_overlay(
    log_cell_density_cv,
    min_value=0,
    max_value=10,
    discrete_levels=10,
    name='Cell Density CV\nCoefficient of Variation [%]',
    colormap=sns.color_palette("mako", as_cmap=True)
)

vis.add_contour({
    "Prostate":reference_WG,
    "Urethra":reference_U,
    "PZ":reference_PZ}, colormap=plt.cm.spring_r, show_legend=False)
fig = vis.show()

# Save the log-model coefficient of variation (previously the smoothed LINEAR
# variance image was written under this LOG_COEFF_VAR file name)
sitk.WriteImage(log_cell_density_cv, "../2_output/ATLAS_PRODUCTS/ATLAS_CELL_DENSITY_LOG_COEFF_VAR.nii.gz")

fig.savefig("../../3_deliverables/Figures/working/atlas_ortho_slices/cell_density/atlas_cell_density_1_log_CV.png", dpi=500)

In [None]:
"""
CELL DENSITY VARIATION
SAMPLING LINEAR MODEL
"""

cd_mean_image = atlas_products["CELL_DENSITY"]["MEAN"]
cd_stddev_image_source = atlas_products["CELL_DENSITY"]["VARIANCE"]

# One axial figure per offset from the mean, expressed in standard deviations
for n_sigma in [-1,-0.5,0,0.5,1]:

    cd_sample = cd_mean_image + n_sigma*sitk.Sqrt(cd_stddev_image_source)

    vis = ImageVisualiser(mri_mean*0 + 0.1, axis='z', cut=cut[0], window=[0,1], figure_size_in=6)

    vis.add_scalar_overlay(
        cd_sample,
        name=r'Cell Density [mm$^{-3}$]',
        colormap=plt.cm.jet,
        min_value=0,
        max_value=310000,
        show_colorbar=False,
    )

    vis.set_limits_from_label(reference_WG, expansion=5)

    contours = {
        "Prostate": reference_WG,
        "Urethra": reference_U,
        "PZ": reference_PZ,
    }
    vis.add_contour(contours, colormap=plt.cm.spring_r, show_legend=False)

    fig = vis.show()

    fig.savefig(f"../../3_deliverables/Figures/working/atlas_ortho_slices/cell_density/atlas_cell_density_2_variation_mean_{n_sigma}sigma.png", dpi=500)

# save one ortho for the colormap
# (uses the last sample produced by the loop above and keeps the colorbar visible)
vis = ImageVisualiser(mri_mean*0 + 0.1, window=[0,1], figure_size_in=6)

vis.add_scalar_overlay(
    cd_sample,
    name=r'Cell Density [mm$^{-3}$]',
    colormap=plt.cm.jet,
    min_value=0,
    max_value=310000,
)

vis.set_limits_from_label(reference_WG, expansion=5)

contours = {
    "Prostate": reference_WG,
    "Urethra": reference_U,
    "PZ": reference_PZ,
}
vis.add_contour(contours, colormap=plt.cm.spring_r, show_legend=False)

fig = vis.show()
fig.savefig("../../3_deliverables/Figures/working/atlas_ortho_slices/cell_density/atlas_cell_density_2_variation_ortho_for_colormap.png", dpi=500)



In [None]:
"""
LOG CELL DENSITY VARIATION
SAMPLING LINEAR MODEL
"""

log_mean_image = atlas_products["CELL_DENSITY"]["LOG_MEAN"]
log_variance_image = atlas_products["CELL_DENSITY"]["LOG_VARIANCE"]

# One axial figure per offset from the log mean, expressed in log-space standard deviations
for n_sigma in [-1,-0.5,0,0.5,1]:

    # Offset in log space, then map back to linear cell density
    cd_sample = sitk.Exp(log_mean_image + n_sigma*sitk.Sqrt(log_variance_image))

    vis = ImageVisualiser(mri_mean*0 + 0.1, axis='z', cut=cut[0], window=[0,1], figure_size_in=6)

    vis.add_scalar_overlay(
        sitk.Mask(cd_sample, reference_WG),
        name=r'Cell Density [mm$^{-3}$]',
        colormap=plt.cm.jet,
        min_value=0,
        max_value=310000,
        show_colorbar=False,
    )

    vis.set_limits_from_label(reference_WG, expansion=5)

    contours = {
        "Prostate": reference_WG,
        "Urethra": reference_U,
        "PZ": reference_PZ,
    }
    vis.add_contour(contours, colormap=plt.cm.spring_r, show_legend=False)

    fig = vis.show()

    fig.savefig(f"../../3_deliverables/Figures/working/atlas_ortho_slices/cell_density/atlas_cell_density_3_log_variation_mean_{n_sigma}sigma.png", dpi=500)

# save one ortho for the colormap
# (uses the last sample produced by the loop above and keeps the colorbar visible)
vis = ImageVisualiser(mri_mean*0 + 0.1, window=[0,1], figure_size_in=6)

vis.add_scalar_overlay(
    sitk.Mask(cd_sample, reference_WG),
    name=r'Cell Density [mm$^{-3}$]',
    colormap=plt.cm.jet,
    min_value=0,
    max_value=310000,
)

vis.set_limits_from_label(reference_WG, expansion=5)

contours = {
    "Prostate": reference_WG,
    "Urethra": reference_U,
    "PZ": reference_PZ,
}
vis.add_contour(contours, colormap=plt.cm.spring_r, show_legend=False)

fig = vis.show()
fig.savefig("../../3_deliverables/Figures/working/atlas_ortho_slices/cell_density/atlas_cell_density_3_log_variation_ortho_for_colormap.png", dpi=500)



In [None]:
"""
Statistical model validation
TODO: plot mu vs sigma^2 (not implemented yet)
"""



## TUMOUR PROBABILITY


The next section generates the tumour probability component of the atlas:
 - Combine all the tumour grades 
 - Smooth each individual tumour probability map

In [None]:
# Sub-dictionary for the tumour probability products ("MEAN", "VARIANCE", "STDDEV")
atlas_products["TUMOUR_PROBABILITY"] = {}

In [None]:
"""
Here we combined total probability maps

For every case the per-grade tumour probability maps are added together and
stored under 'TUMOUR_PROBABILITY_TOTAL'.
"""

sigma = 0 # 0.2 also an option

for case_id in case_id_list:

    case_data = atlas_set[case_id]["DIR_PZ"]

    # Start from the first grade map and add the others out-of-place.
    # `summation_image` initially refers to the image object stored in atlas_set, so an
    # augmented assignment (+=) could overwrite that first grade map wherever the
    # image type implements in-place addition; `a = a + b` always creates a new image.
    summation_image = case_data[labels_linear[0]]

    for label_name in labels_linear[1:]:

        summation_image = summation_image + case_data[label_name]

    # Optional Gaussian smoothing of the combined map (disabled while sigma is 0)
    if sigma>0:
        summation_image = sitk.SmoothingRecursiveGaussian(summation_image, sigma=(sigma,sigma,sigma), normalizeAcrossScale=True)

    case_data['TUMOUR_PROBABILITY_TOTAL'] = summation_image

In [None]:
"""
Now we compute the prevalence atlas
"""

# Mean: sum of the per-case total maps normalised by the per-voxel sampling frequency
tumour_total_images = [atlas_set[case_id]["DIR_PZ"]['TUMOUR_PROBABILITY_TOTAL'] for case_id in case_id_list]
probability_map_tumour_mean = sum(tumour_total_images)/atlas_products["SAMPLING_FREQUENCY"]

# Replace non-finite values by zero and rebuild the image on the sampling frequency grid
arr_tp_mean = sitk.GetArrayFromImage(probability_map_tumour_mean)
arr_tp_mean[~np.isfinite(arr_tp_mean)] = 0
probability_map_tumour_mean = f_conv_r(arr_tp_mean)
probability_map_tumour_mean.CopyInformation(atlas_products["SAMPLING_FREQUENCY"])

probability_map_tumour_mean_smedge = smooth_edges(probability_map_tumour_mean, reference_WG, dilate_voxels=2, smooth_mm=1)

atlas_products["TUMOUR_PROBABILITY"]["MEAN"] = probability_map_tumour_mean_smedge

# Variance from the binomial approximation: p(1-p)/n, with p the smoothed mean
# and n the per-voxel sampling frequency
arr_tp = f_conv(atlas_products["TUMOUR_PROBABILITY"]["MEAN"])
arr_sf = f_conv(atlas_products["SAMPLING_FREQUENCY"])

arr = arr_tp*(1-arr_tp)/arr_sf
arr[~np.isfinite(arr)] = 0
probability_map_tumour_var = f_conv_r(arr)

probability_map_tumour_var.CopyInformation(probability_map_tumour_mean)
probability_map_tumour_var_smedge = smooth_edges(probability_map_tumour_var, reference_WG, dilate_voxels=2, smooth_mm=1)

tumour_probability_products = atlas_products["TUMOUR_PROBABILITY"]
tumour_probability_products["VARIANCE"] = probability_map_tumour_var_smedge
tumour_probability_products["STDDEV"] = sitk.Sqrt(probability_map_tumour_var_smedge)


# save the products

sitk.WriteImage(tumour_probability_products["MEAN"], "../2_output/ATLAS_PRODUCTS/ATLAS_TUMOUR_PROBABILITY_MEAN.nii.gz")
sitk.WriteImage(tumour_probability_products["VARIANCE"], "../2_output/ATLAS_PRODUCTS/ATLAS_TUMOUR_PROBABILITY_VARIANCE.nii.gz")
sitk.WriteImage(tumour_probability_products["STDDEV"], "../2_output/ATLAS_PRODUCTS/ATLAS_TUMOUR_PROBABILITY_STDDEV.nii.gz")

In [None]:
"""
TUMOUR PROBABILITY
Visualisation
"""

background = mri_mean*0 + 0.1
vis = ImageVisualiser(background, axis='ortho', cut=cut, window=[0,1], figure_size_in=6)

vis.set_limits_from_label(reference_WG, expansion=5)

# Mean tumour probability as a ten-level colour overlay
vis.add_scalar_overlay(
    atlas_products["TUMOUR_PROBABILITY"]["MEAN"],
    name='Tumour Probability',
    colormap=plt.cm.magma,
    min_value=0,
    max_value=0.4,
    discrete_levels=10,
)

contours = {
    "Prostate": reference_WG,
    "Urethra": reference_U,
    "PZ": reference_PZ,
}
vis.add_contour(contours, colormap=plt.cm.spring_r, show_legend=False)

fig = vis.show()
fig.savefig("../../3_deliverables/Figures/working/atlas_ortho_slices/tumour_probability/atlas_tumour_probability_0_mean.png", dpi=500)

In [None]:
"""
Create some nice colormaps for variation
"""

# 256 colours from the first 90% of the terrain colormap, then 256 from the last 10%
colors_large = plt.cm.terrain(np.linspace(0, 0.9, 256))
colors_small = plt.cm.terrain(np.linspace(0.9, 1, 256))

# Stack both colour tables into one and turn it into a colormap
all_colors = np.vstack((colors_large, colors_small))
colormap_var = colors.LinearSegmentedColormap.from_list('sig_map', all_colors)

# Two-slope normalisation: 0 to 35 on one side of the centre, 35 to 40 on the other
divnorm_var = colors.TwoSlopeNorm(vmin=0, vcenter=35, vmax=40)

In [None]:
"""
TUMOUR PROBABILITY CV
Visualisation
"""

tp_products = atlas_products["TUMOUR_PROBABILITY"]

# Coefficient of variation in percent: 100 * standard deviation / mean
tumour_probability_cv = 100*tp_products["STDDEV"]/tp_products["MEAN"]

# Replace non-finite values by zero and rebuild the image on the grid of the mean map
arr_tp_cv = sitk.GetArrayFromImage(tumour_probability_cv)
arr_tp_cv[~np.isfinite(arr_tp_cv)] = 0
tumour_probability_cv = f_conv_r(arr_tp_cv)
tumour_probability_cv.CopyInformation(tp_products["MEAN"])


background = mri_mean*0 + 0.1
vis = ImageVisualiser(background, axis='ortho', cut=cut, window=[0,1], figure_size_in=6)

vis.set_limits_from_label(reference_WG, expansion=5)

vis.add_scalar_overlay(
    tumour_probability_cv,
    name='Tumour Probability CV\nCoefficient of Variation [%]',
    colormap=sns.color_palette("mako", as_cmap=True),
    max_value=100,
    discrete_levels=10,
)

contours = {
    "Prostate": reference_WG,
    "Urethra": reference_U,
    "PZ": reference_PZ,
}
vis.add_contour(contours, colormap=plt.cm.spring_r, show_legend=False)

fig = vis.show()

fig.savefig("../../3_deliverables/Figures/working/atlas_ortho_slices/tumour_probability/atlas_tumour_probability_0_CV.png", dpi=500)

In [None]:
"""
TUMOUR PROBABILITY STD DEV
Visualisation
"""

background = mri_mean*0 + 0.1
vis = ImageVisualiser(background, axis='ortho', cut=cut, window=[0,1], figure_size_in=6)

vis.set_limits_from_label(reference_WG, expansion=5)

# Stored standard deviation of the tumour probability as a ten-level overlay
vis.add_scalar_overlay(
    atlas_products["TUMOUR_PROBABILITY"]["STDDEV"],
    name='Tumour Probability Std. Dev.',
    colormap=sns.color_palette("mako", as_cmap=True),
    max_value=0.1,
    discrete_levels=10,
)

contours = {
    "Prostate": reference_WG,
    "Urethra": reference_U,
    "PZ": reference_PZ,
}
vis.add_contour(contours, colormap=plt.cm.spring_r, show_legend=False)

fig = vis.show()

fig.savefig("../../3_deliverables/Figures/working/atlas_ortho_slices/tumour_probability/atlas_tumour_probability_0_stddev.png", dpi=500)

In [None]:
"""
TUMOUR PROBABILITY VARIATION
SAMPLING NORMAL MODEL
"""

# One axial figure per offset from the mean, expressed in standard deviations
for i in [-1,-0.5,0,0.5,1]:

    tp_sample = atlas_products["TUMOUR_PROBABILITY"]["MEAN"] + i*sitk.Sqrt(atlas_products["TUMOUR_PROBABILITY"]["VARIANCE"])

    vis = ImageVisualiser(mri_mean*0 + 0.1, axis='z', cut=cut[0], window=[0,1], figure_size_in=6)

    # The colorbar is suppressed on the per-offset figures
    vis.add_scalar_overlay(sitk.Mask(tp_sample, reference_WG), min_value=0, max_value=0.4, name='Tumour Probability', discrete_levels=10, colormap=plt.cm.magma, show_colorbar=False)

    vis.set_limits_from_label(reference_WG, expansion=5)

    vis.add_contour({
        "Prostate":reference_WG,
        "Urethra":reference_U,
        "PZ":reference_PZ}, colormap=plt.cm.spring_r, show_legend=False)

    fig = vis.show()

    fig.savefig(f"../../3_deliverables/Figures/working/atlas_ortho_slices/tumour_probability/atlas_tumour_probability_1_variation_mean_{i}sigma.png", dpi=500)

# save one ortho for the colormap
# This figure exists to provide the colorbar, so the colorbar is left enabled here
# (it was previously switched off, unlike the equivalent cell density figures).
vis = ImageVisualiser(mri_mean*0 + 0.1, window=[0,1], figure_size_in=6)

vis.add_scalar_overlay(sitk.Mask(tp_sample, reference_WG), min_value=0, max_value=0.4, name='Tumour Probability', discrete_levels=10, colormap=plt.cm.magma)

vis.set_limits_from_label(reference_WG, expansion=5)

vis.add_contour({
    "Prostate":reference_WG,
    "Urethra":reference_U,
    "PZ":reference_PZ}, colormap=plt.cm.spring_r, show_legend=False)

fig = vis.show()
fig.savefig("../../3_deliverables/Figures/working/atlas_ortho_slices/tumour_probability/atlas_tumour_probability_1_variation_ortho_for_colormap.png", dpi=500)

In [None]:
"""
Sample visualisation

Tumour cell density = mean tumour probability x mean cell density (both edge-smoothed).
"""


tumour_cell_density = probability_map_tumour_mean_smedge * cell_density_mean_smedge

vis = ImageVisualiser(mri_mean*0 + 0.1, axis='ortho', cut=cut, window=[0,1], figure_size_in=6)

vis.set_limits_from_label(reference_WG, expansion=5)

vis.add_scalar_overlay(tumour_cell_density, min_value=0, max_value=70000, mid_ticks=True, discrete_levels=15,  name='Tumour Cell Density [mm'+r'$^{-3}$'+']', colormap=plt.cm.afmhot)
vis.add_contour({
    "Prostate":reference_WG,
    "Urethra":reference_U,
    "PZ":reference_PZ}, colormap=plt.cm.spring_r, show_legend=False
)

fig = vis.show()

# Write into the same ATLAS_PRODUCTS directory as every other atlas product; the
# lower-case directory name used before only resolves on case-insensitive file systems.
sitk.WriteImage(tumour_cell_density, "../2_output/ATLAS_PRODUCTS/atlas_tumour_cell_density.nii.gz")

fig.savefig(f"../../3_deliverables/Figures/atlas_mean_tumour_cell_density.png", dpi=500)

### Analysis of statistical models

In [None]:
"""
Create voxel-level plots of the distribution of parameters

For tumour probability - we approximate a normal distribution from a binomial distribution
sigma ~ sqrt( p(1-p) / n )

For the cell density - we directly compute a normal distribution from statistics

"""

# Edge length (in voxels) of the marker cube and its starting corner
d = 4
sample_location = (26,67,47)
ix,iy,iz = sample_location

# Zero array on the mean MRI grid with the d x d x d cube starting at
# sample_location set to one; the array axes are indexed in sample_location order
test_arr = sitk.GetArrayFromImage(mri_mean*0)
sample_region = tuple(slice(start, start+d) for start in sample_location)
test_arr[sample_region] = 1

# Back to an image carrying the geometry of the mean MRI
test_img = sitk.GetImageFromArray(test_arr)
test_img.CopyInformation(atlas_products["MEAN_MRI"])

In [None]:
"""
TUMOUR PROBABILITY WITH SAMPLE
"""

# Views are cut at the sample location rather than the global cut
background = mri_mean*0 + 0.1
vis = ImageVisualiser(background, axis='ortho', cut=sample_location, window=[0,1], figure_size_in=6)

vis.set_limits_from_label(reference_WG, expansion=5)

vis.add_scalar_overlay(
    atlas_products["TUMOUR_PROBABILITY"]["MEAN"],
    name='Tumour Probability',
    colormap=plt.cm.magma,
    min_value=0,
    max_value=0.4,
    discrete_levels=10,
)

# Fully opaque grey overlay marking the sample region
vis.add_scalar_overlay(
    test_img,
    colormap=plt.cm.gray,
    min_value=0,
    max_value=1,
    alpha=1,
    show_colorbar=False,
)

fig = vis.show()

fig.savefig("../../3_deliverables/Figures/working/atlas_ortho_slices/tumour_probability/single_voxel_sample_tp_single.jpeg", dpi=400)

In [None]:
"""
TUMOUR PROBABILITY
HISTOGRAM OF VALUES

Compares the binomial model at one voxel with its normal approximation,
sigma = sqrt(p(1-p)/n).
"""
# Number of observations assumed for the binomial model.
# NOTE(review): hard-coded; in practice this should be the sampling frequency at the
# sample voxel - confirm against the number of atlas cases.
n_obs = 63

x = np.linspace(0,1,101)
# The image is indexed in the reverse order of sample_location
mu = probability_map_tumour_mean[sample_location[::-1]]
sigma = np.sqrt(mu*(1-mu)/n_obs)
y = norm.pdf(x, mu, sigma)

binom_tp = binom(n=n_obs, p=mu)
# All possible counts 0..n_obs inclusive (the k = n_obs outcome was previously dropped)
k_values = np.arange(0, n_obs+1)
x_int = k_values/float(n_obs)
dx = x_int[1] - x_int[0]


fig, ax = plt.subplots(1,1,figsize=(6*0.8,3*0.8))

# Probability mass scaled by n_obs (= 1/dx) so the bars are comparable to the density curve
bars = ax.bar(x_int, n_obs*binom_tp.pmf(k_values), width=dx, ec="white", zorder=1, fc="#107030")

ax.plot(x,y, c="#000033", lw=2, zorder=2, label="Normal approximation")
ax.set_xlim(0,1)
ax.set_ylim(0,14)

ax.set_yticks((0,2,4,6,8,10,12,14))

ax.set_xlabel("Tumour probability (rate of occurance)")
ax.set_ylabel("Relative Likelihood")

# The bars carry no label, so a matching patch is added to the legend by hand
handles, _ = ax.get_legend_handles_labels()
handles.append(mpatches.Patch(color="#107030", label="Empirical (binomial model)"))

ax.legend(handles=handles)
ax.set_axisbelow(True)

fig.tight_layout()

fig.show()

fig.savefig("../../3_deliverables/Figures/working/atlas_statistics/histogram_tumour_probability_single_2_voxel.jpeg", dpi=400)

In [None]:
"""
CELL DENSITY WITH SAMPLE
"""
# Views are cut at the sample location rather than the global cut
background = mri_mean*0 + 0.1
vis = ImageVisualiser(background, axis='ortho', cut=sample_location, window=[0,1], figure_size_in=6)

vis.set_limits_from_label(reference_WG, expansion=5)

vis.add_scalar_overlay(
    atlas_products["CELL_DENSITY"]["MEAN"],
    name=r'Cell Density [mm$^{-3}$]',
    colormap=plt.cm.jet,
    min_value=0,
    max_value=210000,
)

# Fully opaque grey overlay marking the sample region
vis.add_scalar_overlay(
    test_img,
    colormap=plt.cm.gray,
    min_value=0,
    max_value=1,
    alpha=1,
    show_colorbar=False,
)

fig = vis.show()

fig.savefig("../../3_deliverables/Figures/working/atlas_ortho_slices/cell_density/single_voxel_sample_cd.jpeg", dpi=400)

In [None]:
"""
CELL DENSITY (LINEAR MODEL)
HISTOGRAM OF VALUES

Histogram, Gaussian KDE and fitted normal model of the cell density values of all
cases at the sample voxel.
"""
# The images are indexed in the reverse order of sample_location
cell_density_vals = np.array([atlas_set[i]["DIR_PZ"]["CELL_DENSITY_MAP"][sample_location[::-1]] for i in case_id_list])

x = np.linspace(0,2*cell_density_vals.max(),501)
mu = cell_density_vals.mean()
sigma = cell_density_vals.std()
y = norm.pdf(x, mu, sigma)

# Fitted once and reused for both the divergence and the plot
# (the KDE was previously fitted a second time with identical input)
kde_cell_density = gaussian_kde(cell_density_vals)

kl_div = kl_divergence(norm(mu, sigma), kde_cell_density, x)

fig, ax = plt.subplots(1,1,figsize=(6*0.8,4*0.7))

_,_,hist_ = ax.hist(cell_density_vals, bins=np.linspace(0,2*cell_density_vals.max(),25), density=True, ec="white", fc="#107030", zorder=1, label=r"Empirical (histogram)")
kde_, = ax.plot(x, kde_cell_density(x), c="#000033", lw=2, ls="--", label="Gaussian KDE")

norm_, = ax.plot(x,y, c="#000033", lw=2, zorder=2, label="Normal Model\n"+r"$D_{KL}=$"+f"{kl_div:.3f} bits")


# NOTE(review): this label uses cm^-3 while the atlas figures label cell density in
# mm^-3 - confirm which unit is correct.
ax.set_xlabel("Cell Density [cm"+r"$^{-3}$"+"] (Linear Scale)")
ax.set_ylabel("Relative Likelihood")

ax.set_xlim(0,2*cell_density_vals.max())

ax.legend(handles=[hist_.patches[0],kde_,norm_])
ax.set_axisbelow(True)

ax.ticklabel_format(style='sci', axis='y', scilimits=(0,0), useMathText=True)

fig.tight_layout()

fig.show()

fig.savefig("../../3_deliverables/Figures/working/atlas_statistics/histogram_cell_density_single_voxel_2_linear.jpeg", dpi=400)

In [None]:
"""
CELL DENSITY (LOG MODEL)
HISTOGRAM OF VALUES

Histogram, Gaussian KDE and fitted normal model of the natural log of the cell
density values of all cases at the sample voxel (i.e. a log-normal model).
"""
# The images are indexed in the reverse order of sample_location
cell_density_vals = np.log([atlas_set[i]["DIR_PZ"]["CELL_DENSITY_MAP"][sample_location[::-1]] for i in case_id_list])
x = np.linspace(0.8*cell_density_vals.min(),1.15*cell_density_vals.max(),501)
mu = cell_density_vals.mean()
sigma = cell_density_vals.std()
y = norm.pdf(x, mu, sigma)

# Fitted once and reused for both the divergence and the plot
# (the KDE was previously fitted a second time with identical input)
kde_cell_density = gaussian_kde(cell_density_vals)

kl_div = kl_divergence(norm(mu, sigma), kde_cell_density, x)

fig, ax = plt.subplots(1,1,figsize=(6*0.8,4*0.7))

_,_,hist_ = ax.hist(cell_density_vals, bins=np.linspace(0.8*cell_density_vals.min(),1.15*cell_density_vals.max(),25), density=True, ec="white", fc="#107030", zorder=1, label=r"Empirical (histogram)")
kde_, = ax.plot(x, kde_cell_density(x), c="#000033", lw=2, ls="--", label="Gaussian KDE")

norm_, = ax.plot(x,y, c="#000033", lw=2, zorder=2, label="Log-Normal Model\n"+r"$D_{KL}=$"+f"{kl_div:.3f} bits")


# NOTE(review): this label uses cm^-3 while the atlas figures label cell density in
# mm^-3 - confirm which unit is correct.
ax.set_xlabel("Cell Density [cm"+r"$^{-3}$"+"] (Log Scale)")
ax.set_ylabel("Relative Likelihood")

ax.set_xlim(0.8*cell_density_vals.min(),1.15*cell_density_vals.max())

ax.legend(handles=[hist_.patches[0],kde_,norm_])
ax.set_axisbelow(True)

# Ticks are placed at the natural log of each value and labelled with the linear value
xticks = [1000,2500, 5000,12000,25000,50000,100000,250000,500000,1000000,2500000]
ax.set_xticks(np.log(xticks))
ax.set_xticklabels(xticks, rotation=45)

ax.ticklabel_format(style='sci', axis='y', scilimits=(0,0), useMathText=True)

fig.tight_layout()

fig.show()

fig.savefig("../../3_deliverables/Figures/working/atlas_statistics/histogram_cell_density_single_voxel_2_log.jpeg", dpi=400)

In [None]:
"""
Create the 3D map of KL divergence

For every non-zero voxel of `reference_WG`, the cell density samples across all cases
are compared against (a) a normal model and (b) a log-normal model, using a Gaussian KDE
of the same samples. Results are stored in `kl_values_normal` / `kl_values_lognormal`
with exactly one entry per voxel, in `np.where` order. An entry of 0 means "no estimate".
"""

min_samples = 5
# NOTE(review): this was annotated "4 sigma equivalent" although the value is 5 - confirm which is intended
outlier_threshold = 5

# Indices of the non-zero voxels of the reference mask; computed once and reused for every volume
wg_voxels = np.where(f_conv(reference_WG))

# Rows: cases (or validity maps), columns: voxels inside the reference mask
cell_density_vals = np.array([f_conv(atlas_set[case_id]["DIR_PZ"]["CELL_DENSITY_MAP"])[wg_voxels] for case_id in case_id_list])
log_cell_density_vals = np.log(cell_density_vals)
condition_vals = np.array([f_conv(i)[wg_voxels] for i in atlas_products["VALID_CD_SAMPLES"]])

kl_values_normal = []

# Samples flagged as invalid are replaced by the sentinel -1 and dropped by the `> 0` filter below
for enum, cd_vals in enumerate(np.where(condition_vals, cell_density_vals, -1).T):

    # Progress report first, so that it is not skipped by the early `continue`s
    if enum%10000==0:
        print(enum,end=", ")

    vals = cd_vals[cd_vals>0]

    if len(vals)<=min_samples:
        kl_values_normal.append(0)
        continue

    # Zero spread would give NaN test statistics (0/0), drop every sample and crash the KDE
    spread = vals.std()
    if spread==0:
        kl_values_normal.append(0)
        continue

    # Remove outliers (z-score)
    test_statistics = np.abs((vals-vals.mean())/spread)
    vals = vals[test_statistics <= outlier_threshold]

    # gaussian_kde needs at least two samples with non-zero variance
    if len(vals)<2 or vals.std()==0:
        kl_values_normal.append(0)
        continue

    # Normal model
    normal_model = norm(vals.mean(),vals.std())
    kde_model = gaussian_kde(vals)

    x=np.linspace(0.5*vals.min(),1.5*vals.max(),100)

    kl_values_normal.append(kl_divergence(normal_model, kde_model, x))


kl_values_lognormal = []

# NOTE: the `> 0` filter is applied to the LOG values here, so besides the -1 sentinel it also
# drops samples whose raw value is <= 1 (including log(0) = -inf)
for enum, cd_vals in enumerate(np.where(condition_vals, log_cell_density_vals, -1).T):

    if enum%10000==0:
        print(enum,end=", ")

    vals = cd_vals[cd_vals>0]

    if len(vals)<=min_samples:
        kl_values_lognormal.append(0)
        continue

    # Remove outliers (median absolute deviation, MAD)
    median = np.median(vals)
    mad = np.median(np.abs(vals-median))

    # MAD is zero when more than half of the samples are identical; the test statistic is undefined
    if mad==0:
        kl_values_lognormal.append(0)
        continue

    test_statistics = np.abs((vals-median)/mad)
    # For normally distributed data sigma ~ 1.48*MAD, hence the scaled threshold
    vals = vals[test_statistics <= 1.48*outlier_threshold]

    # gaussian_kde needs at least two samples with non-zero variance
    if len(vals)<2 or vals.std()==0:
        kl_values_lognormal.append(0)
        continue

    # Log-Normal model (a normal model fitted to the log values)
    log_normal_model = norm(vals.mean(),vals.std())
    kde_model = gaussian_kde(vals)

    x=np.linspace(0.8*vals.min(),1.2*vals.max(),100)

    kl_values_lognormal.append(kl_divergence(log_normal_model, kde_model, x))

In [None]:
"""
MODEL FITTING ANALYSIS
HISTOGRAM FOR NORMAL MODEL
"""

# Keep only finite, strictly positive divergences (0 is the "no estimate" placeholder)
kl_arr = np.array(kl_values_normal)
g_vals = kl_arr[np.isfinite(kl_arr)]
g_vals = g_vals[g_vals > 0]

# Summary statistics shown in the legend
mean = g_vals.mean()
sigma = g_vals.std()

legend_label = "\n".join([
    "Normal Model",
    r"Mean $\pm$ Std. Dev.",
    f"{mean:.3f} " + r"$\pm$" + f" {sigma:.3f} bits",
])

fig, ax = plt.subplots(figsize=(5, 3))

ax.hist(
    g_vals,
    bins=np.linspace(0, 0.5, 60),
    edgecolor="white",
    facecolor="#410C78",
    label=legend_label,
)

ax.set_xlim(0, 0.5)
ax.set_xlabel("KL Divergence [Relative Entropy, bits]")
ax.set_ylabel("Frequency (Number of Voxels) ")

ax.set_axisbelow(True)
ax.legend(title="")

fig.tight_layout()
fig.show()

fig.savefig("../../3_deliverables/Figures/working/atlas_statistics/histogram_kl_divergence_normal.jpeg", dpi=500)

In [None]:
"""
MODEL FITTING ANALYSIS
HISTOGRAM FOR LOG-NORMAL MODEL
"""

# Keep only finite, strictly positive divergences (0 is the "no estimate" placeholder)
kl_arr = np.array(kl_values_lognormal)
g_vals = kl_arr[np.isfinite(kl_arr)]
g_vals = g_vals[g_vals > 0]

# Summary statistics shown in the legend
mean = g_vals.mean()
sigma = g_vals.std()

legend_label = "\n".join([
    "Log-Normal Model",
    r"Mean $\pm$ Std. Dev.",
    f"{mean:.3f} " + r"$\pm$" + f" {sigma:.3f} bits",
])

fig, ax = plt.subplots(figsize=(5, 3))

ax.hist(
    g_vals,
    bins=np.linspace(0, 0.5, 60),
    edgecolor="white",
    facecolor="#410C78",
    label=legend_label,
)

ax.set_xlim(0, 0.5)
ax.set_xlabel("KL Divergence [Relative Entropy, bits]")
ax.set_ylabel("Frequency (Number of Voxels) ")

ax.set_axisbelow(True)
ax.legend(title="")

fig.tight_layout()
fig.show()

fig.savefig("../../3_deliverables/Figures/working/atlas_statistics/histogram_kl_divergence_log_normal.jpeg", dpi=500)

In [None]:
"""
Build the KL divergence image for the log-normal model and plot orthogonal slices
"""

# `kl_values_lognormal` is expected to hold one value per non-zero voxel of `reference_WG`,
# in the order returned by np.where
wg_mask_arr = f_conv(reference_WG)
voxel_index = np.where(wg_mask_arr)
kl_per_voxel = np.asarray(kl_values_lognormal, dtype=float)

# Pair indices and values up to the shorter of the two
n_voxels = min(len(kl_per_voxel), len(voxel_index[0]))
kl_per_voxel = kl_per_voxel[:n_voxels]

# Only strictly positive values are written; 0 (no estimate) and NaN leave the background at 0
keep = kl_per_voxel > 0

template_arr = 0.0*wg_mask_arr
template_arr[tuple(axis_index[:n_voxels][keep] for axis_index in voxel_index)] = kl_per_voxel[keep]

kl_img_lognormal = f_conv_r(template_arr)
# The array round-trip loses the image metadata, so copy it from `cell_density_mean`
kl_img_lognormal.CopyInformation(cell_density_mean)

# Orthogonal-slice figure: uniform background image with the KL divergence as a scalar overlay

vis = ImageVisualiser(mri_mean*0 + 0.1, axis='ortho',cut=cut, window=[0,1], figure_size_in=6)

vis.set_limits_from_label(reference_WG, expansion=5)

# vis.add_scalar_overlay(cell_density_mean_smedge, min_value=0.01, name='Cell Density [mm'+r'$^{-3}$'+']', colormap=plt.cm.jet, max_value=35200)

vis.add_scalar_overlay(kl_img_lognormal, show_colorbar=True, discrete_levels=10, min_value=0, max_value=0.25, colormap = plt.cm.viridis, alpha=1, name="KL Divergence\n [Relative Entropy, bits]")

#vis.add_contour({"Prostate":reference_WG, " ":reference_PZ,}, show_legend=False, colormap=plt.cm.spring_r)

fig = vis.show()

fig.savefig("../../3_deliverables/Figures/working/atlas_ortho_slices/statistics/kl_divergence_lognormal.jpeg", dpi=400)

In [None]:
"""
Build the KL divergence image for the normal model and plot orthogonal slices
"""

# `kl_values_normal` is expected to hold one value per non-zero voxel of `reference_WG`,
# in the order returned by np.where
wg_mask_arr = f_conv(reference_WG)
voxel_index = np.where(wg_mask_arr)
kl_per_voxel = np.asarray(kl_values_normal, dtype=float)

# Pair indices and values up to the shorter of the two
n_voxels = min(len(kl_per_voxel), len(voxel_index[0]))
kl_per_voxel = kl_per_voxel[:n_voxels]

# Only strictly positive values are written; 0 (no estimate) and NaN leave the background at 0
keep = kl_per_voxel > 0

template_arr = 0.0*wg_mask_arr
template_arr[tuple(axis_index[:n_voxels][keep] for axis_index in voxel_index)] = kl_per_voxel[keep]

kl_img_normal = f_conv_r(template_arr)
# The array round-trip loses the image metadata, so copy it from `cell_density_mean`
kl_img_normal.CopyInformation(cell_density_mean)

# Orthogonal-slice figure: uniform background image with the KL divergence as a scalar overlay

vis = ImageVisualiser(mri_mean*0 + 0.1, axis='ortho',cut=cut, window=[0,1], figure_size_in=6)

vis.set_limits_from_label(reference_WG, expansion=5)

# vis.add_scalar_overlay(cell_density_mean_smedge, min_value=0.01, name='Cell Density [mm'+r'$^{-3}$'+']', colormap=plt.cm.jet, max_value=35200)

vis.add_scalar_overlay(kl_img_normal, show_colorbar=True, discrete_levels=10, min_value=0, max_value=0.25, colormap = plt.cm.viridis, alpha=1, name="KL Divergence\n [Relative Entropy, bits]")

#vis.add_contour({"Prostate":reference_WG, " ":reference_PZ,}, show_legend=False, colormap=plt.cm.spring_r)

fig = vis.show()

fig.savefig("../../3_deliverables/Figures/working/atlas_ortho_slices/statistics/kl_divergence_normal.jpeg", dpi=400)