In [1]:
from IPython.display import display
import ipywidgets as widgets
import tbi

from glob import glob
import nilearn.plotting as plotting
import pandas as pd
from os.path import join
from functools import partial
import os.path

class ImageViewer:
    """Interactive three-slider viewer for one NIfTI image.

    One ``IntSlider`` per axis in ``axes`` is built, spanning the range of the
    cut coordinates that nilearn's ``find_cut_slices`` proposes for that axis.
    Moving any slider re-plots the image (``display_mode='ortho'``) at the
    three current slider values inside the viewer's own ``Output`` widget.
    """

    axes = ('x', 'y', 'z')

    # Number of initial ``plot_slice`` callbacks that are ignored.  Kept from
    # the original implementation (which skipped the first two calls);
    # presumably these are set-up callbacks from the interactive widgets —
    # TODO confirm against the ipywidgets version in use.
    _skip_calls = 2

    def __init__(self, filename):
        """:param filename: path of the image file to load with nibabel."""
        self.filename = filename
        self._calls = 0

    @classmethod
    def min_max(cls, arrays):
        """Return the value range shared by all *arrays*.

        The result is ``(largest minimum, smallest maximum)``, i.e. the
        intersection of the individual ``[min, max]`` ranges.  When the ranges
        do not overlap, the returned minimum is greater than the maximum.

        :param arrays: non-empty sequence of objects with ``min()``/``max()``
            methods (e.g. numpy arrays).
        :raises IndexError: if *arrays* is empty.
        """
        amin, amax = arrays[0].min(), arrays[0].max()
        for arr in arrays[1:]:
            amin = max(amin, arr.min())
            amax = min(amax, arr.max())
        return amin, amax

    @classmethod
    def create_slider(cls, amin, amax, direction):
        """Return a step-1 integer slider over ``[amin, amax]`` labelled *direction*.

        ``continuous_update=False`` reports the value only once the slider is
        released, so the image is not re-plotted while dragging.
        """
        return widgets.IntSlider(min=amin,
                                 max=amax,
                                 step=1,
                                 continuous_update=False,
                                 description=direction)

    @classmethod
    def create_layout(cls, height='200px', border='4px solid blue', overflow_y='scroll'):
        """Return the ``widgets.Layout`` used for the plot output area."""
        return widgets.Layout(height=height, border=border, overflow_y=overflow_y)

    def display_image(self):
        """Load the image, build one slider per axis and display the viewer."""
        from nilearn.plotting.find_cuts import find_cut_slices
        import nibabel as nib

        self.img = nib.load(self.filename)
        self._calls = 0
        # Create the plot target before the interactive widgets, so a callback
        # fired while they are being built can never hit a missing ``self.out``.
        self.out = widgets.Output(layout=self.create_layout())

        self.sliders = []
        for axis in self.axes:
            cuts = find_cut_slices(self.img, direction=axis)
            self.sliders.append(self.create_slider(cuts.min(), cuts.max(), axis))

        children = [widgets.interactive(self.plot_slice, view=slider)
                    for slider in self.sliders]
        box = widgets.VBox(children=[
            widgets.HBox(children=children),
            widgets.HBox(children=[create_space_box(), self.out]),
        ])
        display(box)

    def plot_slice(self, view):
        """Slider callback: re-plot the image at the current slider positions.

        :param view: value of the slider that triggered the call; unused,
            because all three slider values are read instead.
        """
        self._calls += 1
        if self._calls <= self._skip_calls:
            return

        # Render only into this viewer's own Output widget.  The module-level
        # ``out`` is not entered: it can be None, and ``self.out`` already
        # captures everything produced here.
        with self.out:
            cut_coords = tuple(slider.value for slider in self.sliders)
            plotting.plot_img(self.img, display_mode='ortho', cut_coords=cut_coords)
            self.out.clear_output()
            plotting.show()

def display_image(nii_file, use_viewer=True):
    """Show *nii_file*, either interactively or as one static plot.

    :param nii_file: path of the NIfTI image to display.
    :param use_viewer: if true (the default), build an ``ImageViewer`` with one
        slider per axis; otherwise draw a single static nilearn plot.
    """
    print('Plotting {0}'.format(nii_file))
    if not use_viewer:
        plotting.plot_img(nii_file)
        plotting.show()
        return
    ImageViewer(nii_file).display_image()
        
                
class WorkflowUtil:
    """Class-level configuration of the workflow's paths.

    Holds the root output directory and the reference images (MNI brain,
    atlas, template). It derives each pipeline stage's input pattern and
    output directory from them. Every accessor is a classmethod that reads the
    current class attributes, so reassigning e.g. ``_outdir`` changes every
    derived path. ``_cache`` maps accessor callables to the text fields that
    display their value.
    """
    _outdir = '/data/illustration/py-out4'
    _mni = 'illustration_data/MNI152_T1_1mm_brain.nii.gz'
    _atlas = 'illustration_data/New_atlas_cort_asym_sub.nii.gz'
    _template = 'illustration_data/T_template0.nii.gz'

    # Input pattern for the DICOM conversion stage.
    dcmfiles = '/data/tbi_registration/test_data/ID_*eba6ca7-7473dee7c1'
    # value provider (callable) -> list of text fields bound to it
    _cache = {}

    # -- global settings -------------------------------------------------

    @classmethod
    def outdir(cls):
        """Return the root output directory."""
        return cls._outdir

    @classmethod
    def atlas(cls):
        """Return the atlas image path."""
        return cls._atlas

    @classmethod
    def template(cls):
        """Return the template image path."""
        return cls._template

    @classmethod
    def mni(cls):
        """Return the MNI reference image path."""
        return cls._mni

    @classmethod
    def to_cache(cls, key, textfield):
        """Record that *textfield* shows the value produced by calling *key*."""
        cls._cache.setdefault(key, []).append(textfield)

    # -- per-stage paths --------------------------------------------------

    @classmethod
    def convert_input_pattern(cls):
        """Return the input pattern of the conversion stage."""
        return cls.dcmfiles

    @classmethod
    def convert_output_dir(cls):
        """Return ``<outdir>/convert``."""
        return os.path.join(cls.outdir(), 'convert')

    @classmethod
    def preprocessing_input_pattern(cls, pattern='*.nii.gz'):
        """Return *pattern* inside the conversion output directory."""
        return os.path.join(cls.convert_output_dir(), pattern)

    @classmethod
    def preprocessing_output_dir(cls):
        """Return ``<outdir>/preprocessing``."""
        return os.path.join(cls.outdir(), 'preprocessing')

    @classmethod
    def skull_strip_input_pattern(cls, pattern='*_normalizedWarped.nii.gz'):
        """Return *pattern* inside the preprocessing output directory."""
        return os.path.join(cls.preprocessing_output_dir(), pattern)

    @classmethod
    def skull_strip_output_dir(cls):
        """Return ``<outdir>/brains``."""
        return os.path.join(cls.outdir(), 'brains')

    @classmethod
    def segmentation_input_pattern(cls, pattern='*.nii.gz'):
        """Return *pattern* inside the skull-strip output directory."""
        return os.path.join(cls.skull_strip_output_dir(), pattern)

    @classmethod
    def segmentation_output_dir(cls):
        """Return ``<outdir>/segmentation``."""
        return os.path.join(cls.outdir(), 'segmentation')

    @classmethod
    def geo_input_pattern(cls, pattern='SEG/*/*.nii.gz'):
        """Return *pattern* inside the segmentation output directory."""
        return os.path.join(cls.segmentation_output_dir(), pattern)

    @classmethod
    def geo_output_dir(cls):
        """Return ``<outdir>/label_geometry_measures``."""
        return os.path.join(cls.outdir(), 'label_geometry_measures')

    @classmethod
    def stat_input_pattern(cls, pattern='REGIS/Affine2SyN/*affine2Syn1Warp.nii.gz'):
        """Return *pattern* inside the segmentation output directory."""
        return os.path.join(cls.segmentation_output_dir(), pattern)

    @classmethod
    def stat_output_dir(cls):
        """Return ``<outdir>/image_intensity_stat``."""
        return os.path.join(cls.outdir(), 'image_intensity_stat')

def create_html(text, layout=None):
    """Return a box holding an HTML header for *text* followed by a spacer.

    :param text: HTML markup to show.
    :param layout: optional ``widgets.Layout`` for the HTML widget. When it is
        omitted, a fresh 45px-high, 90%-wide layout is created for each call.
    """
    # The old default was one Layout instance created at definition time and
    # shared by every header, so mutating it would restyle all headers.
    # The former ``size='20'`` argument was dropped: it is not a Layout trait.
    if layout is None:
        layout = widgets.Layout(height='45px', width='90%')
    space_box = widgets.Box(layout=widgets.Layout(height='25px', width='90%'))
    return widgets.Box([widgets.HTML(text, layout=layout), space_box])
    
def create_textfield(value, layout=None):
    """Create a text input, optionally bound to a value provider.

    :param value: a plain string, or a zero-argument callable returning the
        string. A callable is evaluated now and also registered via
        ``WorkflowUtil.to_cache``, so ``apply_inputs`` can refresh the field.
    :param layout: optional ``widgets.Layout``. When it is omitted, a fresh
        60%-wide layout is created for each field.
    :returns: the ``widgets.Text`` instance.
    """
    # Build the default per call, instead of one Layout shared by every field.
    if layout is None:
        layout = widgets.Layout(width="60%")
    textfield = widgets.Text(layout=layout)

    if callable(value):
        textfield.value = value()
        WorkflowUtil.to_cache(value, textfield)
    else:
        textfield.value = value

    return textfield

def create_checkbox(description):
    """Return an unchecked, enabled, non-indented checkbox labelled *description*."""
    checkbox = widgets.Checkbox(description=description,
                                value=False,
                                indent=False,
                                disabled=False)
    return checkbox
    
def create_label(value, layout=None):
    """Return a ``widgets.Label`` showing *value*.

    :param value: label text.
    :param layout: optional ``widgets.Layout``. When it is omitted, a fresh
        18%-wide layout is created for each label; the old default was a
        single Layout instance shared by every label.
    """
    if layout is None:
        layout = widgets.Layout(width='18%')
    return widgets.Label(value=value, layout=layout)

def create_checkbox_box(description):
    """Build a form row made of an empty label column and a checkbox.

    :returns: ``(row_box, checkbox)`` so callers can read the checkbox state.
    """
    checkbox = create_checkbox(description)
    row = widgets.Box([create_label(''), checkbox])
    return row, checkbox
    
def create_textfield_box(label, value):
    """Build a labelled text-input row.

    :param label: text of the left-hand label.
    :param value: initial string or value provider, as accepted by
        ``create_textfield``.
    :returns: ``(row_box, textfield)``.
    """
    textfield = create_textfield(value)
    row = widgets.Box([create_label(label), textfield])
    return row, textfield

def create_space_box():
    """Return a row that holds only an empty label, used as filler."""
    spacer_label = create_label('')
    return widgets.Box([spacer_label])

def create_button_box(*buttons):
    """Return a row of *buttons*, indented by an empty label column."""
    row_children = [create_label('')] + list(buttons)
    return widgets.Box(children=row_children)

def create_panel(*boxes):
    """Stack form rows into a panel, with the button row at the bottom.

    The last element of *boxes* is the button row; the others are content
    rows. The content is padded with spacer rows up to five rows. The button
    row is then framed by a 4px gap above and a 3px gap below.

    :raises IndexError: if no boxes are given.
    """
    button_row = boxes[-1]
    rows = list(boxes[:-1])
    rows.extend(create_space_box() for _ in range(5 - len(rows)))
    rows += [
        widgets.Box(layout=widgets.Layout(height='4px')),
        button_row,
        widgets.Box(layout=widgets.Layout(height='3px')),
    ]
    return widgets.VBox(children=rows)

def apply_inputs(output_dir_textfield, template_textfield, atlas_textfield, mni_textfield, b):
    """'Apply' handler of the Inputs tab.

    It copies the four field values into ``WorkflowUtil``'s class attributes.
    It then re-evaluates every provider registered in ``WorkflowUtil._cache``
    and writes the result into each text field bound to that provider.

    :param b: the clicked button (unused; supplied by ``on_click``).
    """
    out.clear_output()

    with out:
        WorkflowUtil._outdir = output_dir_textfield.value
        WorkflowUtil._template = template_textfield.value
        WorkflowUtil._atlas = atlas_textfield.value
        WorkflowUtil._mni = mni_textfield.value

        for provider, bound_fields in WorkflowUtil._cache.items():
            for field in bound_fields:
                field.value = provider()
        
def run_convert(pattern_textfield, use_dcm2niix_checkbox, output_dir_textfield, b):
    """'Run' handler of the Convert tab: call ``tbi.convert`` with the form values.

    The argument list is ``[--use-dcm2niix] <input pattern> <output dir>``.
    """
    out.clear_output()

    with out:
        flags = ['--use-dcm2niix'] if use_dcm2niix_checkbox.value else []
        tbi.convert(flags + [pattern_textfield.value, output_dir_textfield.value])
    
def show_convert(output_dir_textfield, b):
    """'Show' handler of the Convert tab: display every ``*.nii.gz`` in the output dir.

    The files are shown in sorted order. ``glob`` alone returns them in an
    arbitrary, filesystem-dependent order.

    :param b: the clicked button (unused; supplied by ``on_click``).
    """
    out.clear_output()
    with out:
        for nii_file in sorted(glob(join(output_dir_textfield.value, "*.nii.gz"))):
            display_image(nii_file)

def run_preprocessing(mni_textfield,
                      pattern_textfield,
                      output_dir_textfield,
                      run_button):
    """'Run' handler of the Preprocessing tab: call ``tbi.preprocessing``.

    The argument list is ``-m <MNI file> <input pattern> <output dir>``.
    """
    out.clear_output()
    with out:
        mni_file = mni_textfield.value
        input_pattern = pattern_textfield.value
        destination = output_dir_textfield.value
        tbi.preprocessing(['-m', mni_file, input_pattern, destination])

            
def show_preprocessing(output_dir_textfield, b):
    """'Show' handler of the Preprocessing tab: display each ``*.nii.gz`` output.

    The files are shown in sorted order, so repeated clicks list them the same
    way (``glob`` order is filesystem-dependent).

    :param b: the clicked button (unused; supplied by ``on_click``).
    """
    out.clear_output()
    with out:
        for nii_file in sorted(glob(join(output_dir_textfield.value, "*.nii.gz"))):
            display_image(nii_file)
        
            
def run_skull_strip(pattern_textfield, output_dir_textfield, b):
    """'Run' handler of the Skull strip tab: call ``tbi.skull_strip``.

    The argument list is ``<input pattern> <output dir>``.
    """
    out.clear_output()

    with out:
        cli_args = [pattern_textfield.value, output_dir_textfield.value]
        tbi.skull_strip(cli_args)

def show_skull_strip(output_dir_textfield, b):
    """'Show' handler of the Skull strip tab: display each ``*.nii.gz`` output.

    The files are shown in sorted order (``glob`` order is filesystem-dependent).

    :param b: the clicked button (unused; supplied by ``on_click``).
    """
    out.clear_output()
    with out:
        for nii_file in sorted(glob(join(output_dir_textfield.value, "*.nii.gz"))):
            display_image(nii_file)
            
def run_segmentation(template_textfield,
                     atlas_textfield,
                     pattern_textfield,
                     output_dir_textfield,
                     run_button):
    """'Run' handler of the Segmentation tab: call ``tbi.segmentation``.

    The argument list is ``-t <template> -a <atlas> <input pattern> <output dir>``.
    """
    out.clear_output()
    with out:
        tbi.segmentation(['-t', template_textfield.value,
                          '-a', atlas_textfield.value,
                          pattern_textfield.value,
                          output_dir_textfield.value])

            
def show_segmentation(output_dir_textfield, b):
    """'Show' handler of the Segmentation tab: display ``SEG/*/*.nii.gz`` outputs.

    The files are shown in sorted order (``glob`` order is filesystem-dependent).

    :param b: the clicked button (unused; supplied by ``on_click``).
    """
    out.clear_output()
    with out:
        for nii_file in sorted(glob(join(output_dir_textfield.value, 'SEG/*/*.nii.gz'))):
            display_image(nii_file)
            
def run_label_geometry_measures(pattern_textfield, output_dir_textfield, b):
    """'Run' handler of the LabelGeoMeasures tab: call ``tbi.label_geometry_measures``.

    The argument list is ``<input pattern> <output dir>``.
    """
    out.clear_output()
    with out:
        source_pattern = pattern_textfield.value
        destination = output_dir_textfield.value
        tbi.label_geometry_measures([source_pattern, destination])

def show_label_geometry_measures(output_dir_textfield, b):
    """'Show' handler of the LabelGeoMeasures tab: display every CSV as a DataFrame.

    The CSV files are read and shown in sorted filename order (``glob`` order
    is filesystem-dependent).

    :param b: the clicked button (unused; supplied by ``on_click``).
    """
    out.clear_output()
    with out:
        label_geometry_measures_dir = output_dir_textfield.value
        for csv_file in sorted(glob(join(label_geometry_measures_dir, "*.csv"))):
            display(pd.read_csv(csv_file))
    
def run_intensity_stat(atlas_textfield,
                       pattern_textfield,
                       output_dir_textfield,
                       run_button):
    """'Run' handler of the ImageIntensityStats tab: call ``tbi.image_intensity_stat``.

    The argument list is ``-a <atlas> <input pattern> <output dir>``.
    """
    out.clear_output()
    with out:
        atlas_file = atlas_textfield.value
        input_pattern = pattern_textfield.value
        destination = output_dir_textfield.value
        tbi.image_intensity_stat(['-a', atlas_file, input_pattern, destination])

            
def show_intensity_stat(output_dir_textfield, b):
    """'Show' handler of the ImageIntensityStats tab: display every CSV as a DataFrame.

    The CSV files are read and shown in sorted filename order (``glob`` order
    is filesystem-dependent).

    :param b: the clicked button (unused; supplied by ``on_click``).
    """
    out.clear_output()
    with out:
        for csv_file in sorted(glob(join(output_dir_textfield.value, "*.csv"))):
            display(pd.read_csv(csv_file))

def create_input_box():
    """Build the Inputs tab: output directory, template, atlas and MNI files.

    The fields are pre-filled with the current ``WorkflowUtil`` values. These
    are plain strings rather than providers, so Apply does not overwrite them.
    Clicking Apply calls ``apply_inputs`` with the four fields.
    """
    html = create_html("""Workflow Inputs.""")

    # 'Output' was misspelled 'Ouput' in the user-facing label.
    box1, outdir_textfield = create_textfield_box('Output Directory:', WorkflowUtil.outdir())
    box2, template_textfield = create_textfield_box('Template File:', WorkflowUtil.template())
    box3, atlas_textfield = create_textfield_box('Atlas File:', WorkflowUtil.atlas())
    box4, mni_textfield = create_textfield_box('MNI File:', WorkflowUtil.mni())

    apply_button = widgets.Button(description="Apply", button_style='success')
    state = (outdir_textfield, template_textfield, atlas_textfield, mni_textfield)
    apply_button.on_click(partial(apply_inputs, *state))
    button_box = create_button_box(apply_button)

    return create_panel(html, box1, box2, box3, box4, button_box)

def create_convert_box():
    """Build the Convert tab (DICOM to NIfTI via ``tbi.convert``).

    Fields: input pattern, a dcm2niix toggle and the output directory. The
    pattern and output-directory fields are created from ``WorkflowUtil``
    providers, so they are registered for refreshing by ``apply_inputs``.
    """
    header = create_html("""Convert DCM files to nii.""")

    pattern_row, pattern_field = create_textfield_box('Input Pattern:', WorkflowUtil.convert_input_pattern)
    dcm2niix_row, dcm2niix_checkbox = create_checkbox_box('Use dcm2niix')
    outdir_row, outdir_field = create_textfield_box('Output Directory:', WorkflowUtil.convert_output_dir)

    run_button = widgets.Button(description="Run", button_style='success')
    run_button.on_click(partial(run_convert, pattern_field, dcm2niix_checkbox, outdir_field))

    show_button = widgets.Button(description="Show", button_style='info')
    show_button.on_click(partial(show_convert, outdir_field))

    return create_panel(header, pattern_row, dcm2niix_row, outdir_row,
                        create_button_box(run_button, show_button))

def create_preprocessing_box():
    """Build the Preprocessing tab (``tbi.preprocessing``).

    Fields: input pattern, MNI file and output directory, all created from
    ``WorkflowUtil`` providers.
    """
    header = create_html("""Preprocessing""")

    pattern_row, pattern_field = create_textfield_box('Input Pattern:', WorkflowUtil.preprocessing_input_pattern)
    mni_row, mni_field = create_textfield_box('MNI File:', WorkflowUtil.mni)
    outdir_row, outdir_field = create_textfield_box('Output Directory:', WorkflowUtil.preprocessing_output_dir)

    run_button = widgets.Button(description='Run', button_style='success')
    run_button.on_click(partial(run_preprocessing, mni_field, pattern_field, outdir_field))

    show_button = widgets.Button(description='Show', button_style='info')
    show_button.on_click(partial(show_preprocessing, outdir_field))

    return create_panel(header, pattern_row, mni_row, outdir_row,
                        create_button_box(run_button, show_button))

def create_skull_strip_box():
    """Build the SkullStrip tab (``tbi.skull_strip``).

    Fields: input pattern and output directory, created from ``WorkflowUtil``
    providers.
    """
    header = create_html("""Skull strip""")

    pattern_row, pattern_field = create_textfield_box('Input Pattern:', WorkflowUtil.skull_strip_input_pattern)
    outdir_row, outdir_field = create_textfield_box('Output Directory:', WorkflowUtil.skull_strip_output_dir)

    run_button = widgets.Button(description="Run", button_style='success')
    run_button.on_click(partial(run_skull_strip, pattern_field, outdir_field))

    show_button = widgets.Button(description="Show", button_style='info')
    show_button.on_click(partial(show_skull_strip, outdir_field))

    return create_panel(header, pattern_row, outdir_row,
                        create_button_box(run_button, show_button))

def create_segmentation_box():
    """Build the Segmentation tab (``tbi.segmentation``).

    Fields: input pattern (by default the skull-strip outputs), template file,
    atlas file and output directory, all created from ``WorkflowUtil`` providers.
    """
    # The header used to repeat the Image Intensity Stat tab's description
    # (copy-paste slip); describe this tab instead.
    html = create_html("""Segmentation of the skull-stripped brains using the template and atlas files.""")

    box1, pattern_textfield = create_textfield_box('Input Pattern:', WorkflowUtil.segmentation_input_pattern)
    box2, template_textfield = create_textfield_box('Template File:', WorkflowUtil.template)
    box3, atlas_textfield = create_textfield_box('Atlas File:', WorkflowUtil.atlas)
    box4, outdir_textfield = create_textfield_box('Output Directory:', WorkflowUtil.segmentation_output_dir)

    run_button = widgets.Button(description='Run', button_style='success')
    show_button = widgets.Button(description='Show', button_style='info')
    run_button.on_click(partial(run_segmentation,
                                template_textfield,
                                atlas_textfield,
                                pattern_textfield,
                                outdir_textfield))
    show_button.on_click(partial(show_segmentation, outdir_textfield))
    button_box = create_button_box(run_button, show_button)

    return create_panel(html, box1, box2, box3, box4, button_box)

def create_geo_box():
    """Build the LabelGeoMeasures tab (``tbi.label_geometry_measures``).

    Fields: input pattern (segmentation SEG files by default) and output
    directory, created from ``WorkflowUtil`` providers.
    """
    header = create_html("""Label Geometry Measures uses output from segmentation SEG files.""")
    pattern_row, pattern_field = create_textfield_box('Input Pattern:', WorkflowUtil.geo_input_pattern)
    outdir_row, outdir_field = create_textfield_box('Output Directory:', WorkflowUtil.geo_output_dir)

    run_button = widgets.Button(description="Run", button_style='success')
    run_button.on_click(partial(run_label_geometry_measures, pattern_field, outdir_field))

    show_button = widgets.Button(description="Show", button_style='info')
    show_button.on_click(partial(show_label_geometry_measures, outdir_field))

    return create_panel(header, pattern_row, outdir_row,
                        create_button_box(run_button, show_button))

def create_stat_box():
    """Build the ImageIntensityStats tab (``tbi.image_intensity_stat``).

    Fields: input pattern (REGIS/Affine2SyN warp files by default), atlas file
    and output directory, all created from ``WorkflowUtil`` providers.
    """
    header = create_html("""Image Intensity Stat Jac uses results from REGIS/Affine2SyN""")

    pattern_row, pattern_field = create_textfield_box('Input Pattern:', WorkflowUtil.stat_input_pattern)
    atlas_row, atlas_field = create_textfield_box('Atlas File:', WorkflowUtil.atlas)
    outdir_row, outdir_field = create_textfield_box('Output Directory:', WorkflowUtil.stat_output_dir)

    run_button = widgets.Button(description='Run', button_style='success')
    run_button.on_click(partial(run_intensity_stat, atlas_field, pattern_field, outdir_field))

    show_button = widgets.Button(description='Show', button_style='info')
    show_button.on_click(partial(show_intensity_stat, outdir_field))

    return create_panel(header, pattern_row, atlas_row, outdir_row,
                        create_button_box(run_button, show_button))

# One scrollable Output widget per tab, keyed by tab index.
outputs = {i: widgets.Output(layout=widgets.Layout(height='300px', border='2px solid black',
                                                   overflow_y='scroll')) for i in range(0, 7)}

# Output widget of the currently selected tab; every button handler writes
# into it. It starts at tab 0, the first tab shown. The previous ``None``
# made the Inputs tab's Apply button fail (``None.clear_output``) until the
# user had switched tabs at least once, because the tab observer only fires
# on a selection *change*.
out = outputs[0]

def print_on_select(widget):
    """Tab observer: make the newly selected tab's Output widget current.

    :param widget: the change dict passed by ``observe``; its ``'new'`` entry
        is the index of the tab just selected.
    """
    global out
    out = outputs[widget['new']]

def boxit(box, i):
    """Stack *box* above the Output widget of tab *i*."""
    return widgets.VBox(children=[box, outputs[i]])

# Each tab = its form panel stacked above that tab's Output widget.
box1 = boxit(create_input_box(), 0)
box2 = boxit(create_convert_box(), 1)
box3 = boxit(create_preprocessing_box(), 2)
box4 = boxit(create_skull_strip_box(), 3)
box5 = boxit(create_segmentation_box(), 4)
box6 = boxit(create_geo_box(), 5)
box7 = boxit(create_stat_box(), 6)

tab = widgets.Tab(children=(box1, box2, box3, box4, box5, box6, box7))

# Titles in the same order as the children above. 'ImageIntensityStats'
# fixes the former misspelling 'ImageIntesityStats'.
_TAB_TITLES = ('Inputs', 'Convert', 'Preprocessing', 'SkullStrip',
               'Segmentation', 'LabelGeoMeasures', 'ImageIntensityStats')
for _index, _title in enumerate(_TAB_TITLES):
    tab.set_title(_index, _title)

# Re-point the global ``out`` at the selected tab's Output widget.
tab.observe(print_on_select, names='selected_index')
display(tab)

Tab(children=(VBox(children=(VBox(children=(Box(children=(HTML(value='Workflow Inputs.', layout=Layout(height=…