In [1]:
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
%matplotlib inline
plt.figure(dpi=700)
%config InlineBackend.figure_format = 'retina'
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

class MidpointNormalize(matplotlib.colors.Normalize):
    """Normalize data so that ``midpoint`` always maps to the colormap centre (0.5).

    Values are mapped piecewise-linearly: ``vmin -> normalized_min``,
    ``midpoint -> 0.5`` and ``vmax -> normalized_max``. The side of the range
    that is farther from ``midpoint`` spans a full half of the colormap; the
    nearer side is scaled proportionally. Equal distances from ``midpoint``
    therefore get equally strong colours on both sides.

    Assumes ``vmin <= midpoint <= vmax``, because ``np.interp`` requires
    increasing sample points.
    """

    def __init__(self, vmin, vmax, midpoint=0, clip=False):
        self.midpoint = midpoint
        super().__init__(vmin, vmax, clip)

    def __call__(self, value, clip=None):
        """Map ``value`` (scalar or array) into ``[0, 1]``.

        ``clip`` is accepted for API compatibility but ignored: ``np.interp``
        already clamps values outside ``[vmin, vmax]`` to the end points.
        If ``value`` is a masked array, its mask is kept on the result.
        """
        lower_span = abs(self.midpoint - self.vmin)
        upper_span = abs(self.vmax - self.midpoint)

        # A zero-width side (e.g. a slider setting vmax == midpoint == 0) used
        # to raise ZeroDivisionError. Use the limits of the formulas instead:
        # the ratio tends to infinity, so the bounds clamp to 0 and 1.
        if upper_span == 0:
            normalized_min = 0.0
        else:
            normalized_min = max(0, 1 / 2 * (1 - lower_span / upper_span))
        if lower_span == 0:
            normalized_max = 1.0
        else:
            normalized_max = min(1, 1 / 2 * (1 + upper_span / lower_span))
        normalized_mid = 0.5

        x = [self.vmin, self.midpoint, self.vmax]
        y = [normalized_min, normalized_mid, normalized_max]
        return np.ma.masked_array(np.interp(value, x, y), mask=np.ma.getmask(value))

def tutorial_Case_select(case):
    """Return the ``(heatmap_loc, img_loc)`` NIfTI paths for a tutorial case.

    Cases 1-4 pair a per-patient heatmap (``./data/MRI_explainable_patient_*``)
    with that patient's FreeSurfer brain-mask image. Cases 5 and 6 pair an
    averaged map (``TP_maps_avg`` / ``TN_maps_avg``) with the MNI T1 template.

    Parameters
    ----------
    case : int
        Case number, 1-6.

    Returns
    -------
    tuple of (str, str)
        ``(heatmap_loc, img_loc)`` as relative paths.

    Raises
    ------
    ValueError
        If ``case`` is not one of the known case numbers.
    """
    mni_t1 = './MNI_atlas_templates/MNI_T1.nii'
    cases = {
        1: ('./data/MRI_explainable_patient_116_S_0392_2008-06-06_13_32_15_0_S51194__gr_0_pr0.nii',
            './data/116_S_0392_FreeSurfer_Cross-Sectional_Processing_brainmask_2008-06-06_13_32_15.0_S51194_mri_brainmask.nii'),
        2: ('./data/MRI_explainable_patient_136_S_0194_2006-04-10_17_41_48_0_S13178__gr_0_pr0.nii',
            './data/136_S_0194_FreeSurfer_Cross-Sectional_Processing_brainmask_2006-04-10_17_41_48.0_S13178_mri_brainmask.nii'),
        3: ('./data/MRI_explainable_patient_114_S_0173_2009-03-02_09_37_22_0_S64286__gr_1_pr1.nii',
            './data/ADNI_114_S_0173_FreeSurfer_Cross-Sectional_Processing_brainmask_2009-03-02_09_37_22.0_S64286_mri_brainmask.nii'),
        4: ('./data/MRI_explainable_patient_141_S_1137_2006-12-19_16_13_22_0_S24301__gr_0_pr0.nii',
            './data/141_S_1137_FreeSurfer_Cross-Sectional_Processing_brainmask_2006-12-19_16_13_22.0_S24301_mri_brainmask.nii'),
        5: ('./data_gui/TP_maps_avg.nii', mni_t1),
        6: ('./data_gui/TN_maps_avg.nii', mni_t1),
    }
    # Previously an unknown case fell through every `if` and failed with an
    # UnboundLocalError on the return line; report it explicitly instead.
    if case not in cases:
        raise ValueError(f"Unknown tutorial case {case!r}; expected one of {sorted(cases)}")
    return cases[case]

def UAI_plots(slice1, slice2, vmin=-0.4, vmax=+0.4, alpha=0.3, heat_on=False):
    """Draw two orthogonal MRI slices side by side, optionally over the heatmap.

    Intended to be driven by ``ipywidgets.interact``. It creates a new figure
    and returns ``None``.

    The volumes are read from the module-level globals ``img_loc`` (MRI) and
    ``heatmap_loc`` (heatmap), e.g. as set by
    ``heatmap_loc, img_loc = tutorial_Case_select(case)``. Those must be
    assigned before this function is called.

    Parameters
    ----------
    slice1 : int
        Index along axis 1 shown in the left panel. It is also marked with a
        dashed yellow line on the right panel.
    slice2 : int
        Index along axis 0 shown in the right panel.
    vmin, vmax : float
        Heatmap colour range, normalized around 0 with ``MidpointNormalize``.
    alpha : float
        Opacity of the MRI drawn on top of the heatmap. Forced to 1 when
        ``heat_on`` is false.
    heat_on : bool
        Whether to draw the heatmap layer.
    """
    mri = nib.load(img_loc).get_fdata()
    # Only read the heatmap volume when it will actually be drawn.
    heatmap = nib.load(heatmap_loc).get_fdata() if heat_on else None

    if not heat_on:
        alpha = 1

    norm = MidpointNormalize(vmin=vmin, vmax=vmax, midpoint=0)
    cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", ["b", "black", "r"])

    _, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

    # Left panel: slice along axis 1. The heatmap is drawn first and the
    # (possibly semi-transparent) MRI on top of it.
    if heat_on:
        axes[0].imshow(heatmap[:, slice1, :], cmap=cmap, norm=norm)
    axes[0].imshow(mri[:, slice1, :], cmap='gray', alpha=alpha)

    # Right panel: slice along axis 0, with the left panel's slice1 marked.
    if heat_on:
        axes[1].imshow(heatmap[slice2, :, :], cmap=cmap, norm=norm)
    axes[1].axhline(slice1, ls='--', c='yellow')
    axes[1].imshow(mri[slice2, :, :], cmap='gray', alpha=alpha)

    axes[0].set_title('L', size=15)
    axes[0].set_xlabel('R', size=15)
    axes[1].set_title('S', size=15)
    axes[1].set_xlabel('I', size=15)

    for ax in axes:
        ax.set_xticks([])
        ax.set_yticks([])

<Figure size 4200x2800 with 0 Axes>

In [2]:
heatmap_loc, img_loc = tutorial_Case_select(2)

# Limit the slice sliders to indices that are valid in BOTH volumes:
# UAI_plots indexes axis 1 with slice1 and axis 0 with slice2. A fixed
# (0, 110) range raises IndexError for volumes with fewer than 111 voxels
# along those axes. `.shape` does not require reading the voxel data.
slice_bounds = np.minimum(nib.load(img_loc).shape[:3], nib.load(heatmap_loc).shape[:3])
interact(
    UAI_plots,
    slice1=(0, int(slice_bounds[1]) - 1),
    slice2=(0, int(slice_bounds[0]) - 1),
    vmin=(-1, 0, 0.1),
    vmax=(0, 1, 0.1),
    alpha=(0, 1, 0.1),
    heat_on=True,
);

interactive(children=(IntSlider(value=55, description='slice1', max=110), IntSlider(value=55, description='sli…

<function __main__.UAI_plots(slice1, slice2, vmin=-0.4, vmax=0.4, alpha=0.3, heat_on=False)>