In [None]:
# Requires k3d, vtk and ipyfilechooser, in addition to numpy, h5py, ipywidgets and PyNX
import os
import numpy as np
import h5py as h5
import ipywidgets as widgets
import ipyfilechooser
import k3d
from IPython.core.display import display, update_display, HTML

from pynx.utils.plot_utils import complex2rgbalin

# CSS override making the `.container` elements span the full page width
_FULL_WIDTH_CSS = "<style>.container { width:100% !important; }</style>"
display(HTML(_FULL_WIDTH_CSS))


In [None]:
class WidgetCDI(widgets.Box):
    """Interactive k3d viewer for the 3D object stored in a CXI result file.

    The modulus of the array at ``entry_1/data_1/data`` is displayed either as a
    volume rendering or as an iso-surface. Controls are provided for the display
    mode, grid, background grey level, opacity, colour map, displayed range and
    six axis-aligned clipping planes. A file chooser allows loading another
    CXI/HDF5 file.

    :param cxi_filename: path of the CXI file displayed at start-up. The data
        must be a 3D array, ordered (z, y, x).
    :param data_dir: initial directory of the file chooser. Defaults to the
        directory containing ``cxi_filename``.
    """

    # HDF5 path of the object array inside a CXI result file
    _data_path = 'entry_1/data_1/data'
    # Colour maps offered in the dropdown (attribute names of k3d.colormaps.matplotlib_color_maps)
    _colormap_names = ['Cool', 'Gray', 'Gray_r', 'Hot', 'Hsv', 'Inferno', 'Jet', 'Plasma', 'Rainbow', 'Viridis']

    def __init__(self, cxi_filename, data_dir=None):
        super().__init__()

        d = self._load_data(cxi_filename)
        nz, ny, nx = d.shape  # also rejects non-3D data
        self.d = d
        vmin, vmax = self._abs_range(d)
        print(d.shape, vmin, vmax)

        self.toggle_mode = widgets.ToggleButtons(options=['Volume', 'IsoSurface'])

        # `tooltip` (singular) is the ToggleButton trait; `tooltips` belongs to ToggleButtons
        self.toggle_grid = widgets.ToggleButton(value=True, description='Grid', tooltip='Display grid')
        hbox_toggle = widgets.HBox([self.toggle_grid])

        self.vol_alpha_coef = widgets.IntSlider(value=100, min=10, max=1000, step=10, description='Alpha',
                                                disabled=False, continuous_update=True, orientation='horizontal',
                                                readout=True)
        self.background = widgets.IntSlider(value=255, min=0, max=255, step=5, description='Background',
                                            disabled=False, continuous_update=True, orientation='horizontal',
                                            readout=True)
        self.colormap = widgets.Dropdown(options=self._colormap_names, value='Jet', description='Colors:')
        self.vol_range = widgets.FloatRangeSlider(value=self._default_range(vmin, vmax),
                                                  min=vmin,
                                                  max=vmax,
                                                  step=(vmax - vmin) / 100,
                                                  description='Range:',
                                                  disabled=False,
                                                  continuous_update=False,
                                                  orientation='horizontal',
                                                  readout=True)

        clip_label = widgets.Label(value='Clipping planes:')
        self.clipx, self.clipx_slider = self._make_clip_control('+X', nx)
        self.clipmx, self.clipmx_slider = self._make_clip_control('-X', nx)
        self.clipy, self.clipy_slider = self._make_clip_control('+Y', ny)
        self.clipmy, self.clipmy_slider = self._make_clip_control('-Y', ny)
        self.clipz, self.clipz_slider = self._make_clip_control('+Z', nz)
        self.clipmz, self.clipmz_slider = self._make_clip_control('-Z', nz)
        clip_rows = [widgets.HBox([toggle, slider]) for toggle, slider, _, _ in self._clip_controls()]

        # Start browsing next to the loaded file rather than in a machine-specific folder
        if data_dir is None:
            data_dir = os.path.dirname(os.path.abspath(cxi_filename))
        self.fc = ipyfilechooser.FileChooser(data_dir, filter_pattern=['*.cxi', '*.h5'])
        self.fc.register_callback(self.change_file)

        self.vbox = widgets.VBox([self.toggle_mode, hbox_toggle, self.background, self.vol_alpha_coef,
                                  self.colormap, self.vol_range, clip_label] + clip_rows + [self.fc])

        # Only react to value changes: observing all traits also re-ran update()
        # for unrelated trait events (e.g. the selection widgets' index/label)
        for control in (self.toggle_mode, self.toggle_grid, self.vol_alpha_coef, self.background,
                        self.colormap, self.vol_range):
            control.observe(self.update, names='value')
        for toggle, slider, _, _ in self._clip_controls():
            toggle.observe(self.update, names='value')
            slider.observe(self.update, names='value')

        self.output_view = widgets.Output(layout={'border': '1px solid black', 'width': '1600px'})

        self.hbox = widgets.HBox([self.output_view, self.vbox])

        self.children = [self.hbox]
        self.plot = None
        self.create_plot()

    @classmethod
    def _load_data(cls, filename):
        """Read the object array from a CXI file, closing the file afterwards."""
        with h5.File(filename, mode='r') as f:
            return f[cls._data_path][()]

    @staticmethod
    def _abs_range(d):
        """Return (min, max) of the modulus of `d`."""
        d_abs = abs(d)
        return d_abs.min(), d_abs.max()

    @staticmethod
    def _default_range(vmin, vmax):
        """Initial displayed range: from 20% to 80% of the [vmin, vmax] interval."""
        span = vmax - vmin
        return [vmin + span / 5, vmin + span * 0.8]

    @staticmethod
    def _make_clip_control(label, size):
        """Create the (toggle, slider) pair for one clipping plane along an axis of length `size`."""
        toggle = widgets.ToggleButton(description=label, disabled=False, button_style='')
        slider = widgets.FloatSlider(value=size / 2, min=0, max=size, step=size / 100, description='',
                                     disabled=False, continuous_update=True, orientation='horizontal',
                                     readout=True)
        return toggle, slider

    def _clip_controls(self):
        """Return (toggle, slider, sign, axis) for the six clipping planes, in display order.

        `axis` is 0, 1, 2 for x, y, z and `sign` is the sign of the plane normal
        along that axis.
        """
        return [(self.clipx, self.clipx_slider, -1, 0),
                (self.clipmx, self.clipmx_slider, 1, 0),
                (self.clipy, self.clipy_slider, -1, 1),
                (self.clipmy, self.clipmy_slider, 1, 1),
                (self.clipz, self.clipz_slider, -1, 2),
                (self.clipmz, self.clipmz_slider, 1, 2)]

    def _clipping_planes(self):
        """Build the list of [a, b, c, d] clipping planes for the enabled clipping controls."""
        planes = []
        for toggle, slider, sign, axis in self._clip_controls():
            if toggle.value:
                normal = [0, 0, 0]
                normal[axis] = sign
                planes.append(normal + [-sign * slider.value])
        return planes

    def _color_map(self):
        """Return the k3d matplotlib colour map selected in the dropdown."""
        return getattr(k3d.colormaps.matplotlib_color_maps, self.colormap.value)

    @staticmethod
    def _last_color(color_map):
        """Return the last colour of a k3d colour map as a 0xRRGGBB integer.

        k3d colour maps are flat sequences of (position, r, g, b) quadruplets with
        channels in [0, 1]. Each channel is rounded to 0-255 *before* being shifted
        into place. Rounding after the shift let a fractional channel spill into
        the next one, e.g. r=0.5 gave 0x7F8000 instead of 0x800000.
        """
        _, r, g, b = np.asarray(color_map, dtype=float).reshape(-1, 4)[-1]
        r8, g8, b8 = (int(round(min(max(c, 0.0), 1.0) * 255)) for c in (r, g, b))
        return (r8 << 16) | (g8 << 8) | b8

    def create_plot(self):
        """(Re)build the k3d plot from ``self.d`` and show it in the output area.

        The modulus is added both as a volume and as a hidden iso-surface, and the
        current control values are used for the initial display.
        """
        if self.plot is not None:
            self.plot.close()
        d = self.d
        nz, ny, nx = d.shape
        d_abs = abs(d)
        vmax = d_abs.max()
        self.plot = k3d.plot(camera_auto_fit=True, height=900,
                             grid_auto_fit=True, grid_visible=self.toggle_grid.value)
        self.plot.lighting = 1
        self.plot.background_color = self.background.value * (1 + 256 + 256 ** 2)

        self.volume = k3d.volume(
            d_abs,
            alpha_coef=self.vol_alpha_coef.value,
            shadow_delay=50,
            color_range=self.vol_range.value,
            color_map=self._color_map(),
            bounds=[0, nx, 0, ny, 0, nz],
        )
        self.plot += self.volume

        self.plt_iso = k3d.marching_cubes(d_abs, compression_level=9,
                                          xmin=0, xmax=nx,
                                          ymin=0, ymax=ny,
                                          zmin=0, zmax=nz,
                                          level=float(vmax),
                                          flat_shading=False,
                                          opacity=1)
        self.plt_iso.visible = False
        self.plot += self.plt_iso

        # Drop the previous (closed) plot instead of appending below it
        self.output_view.clear_output()
        with self.output_view:
            self.plot.display()

    def update(self, k=None):
        """Synchronise the plot with the current control values.

        :param k: the change notification passed by traitlets ``observe`` (unused),
            or None when called directly.
        """
        self.plot.background_color = self.background.value * (1 + 256 + 256 ** 2)
        self.volume.color_map = self._color_map()
        self.plot.grid_visible = self.toggle_grid.value
        if self.toggle_mode.value == 'Volume':
            self.volume.alpha_coef = self.vol_alpha_coef.value
            self.volume.visible = True
            self.plt_iso.visible = False
            self.volume.color_range = self.vol_range.value
        else:
            self.volume.visible = False
            if not self.plt_iso.visible:
                # Entering iso-surface mode: start fully opaque
                self.vol_alpha_coef.value = 1000
            # hold_sync bundles these trait changes into a single message to the frontend
            with self.plt_iso.hold_sync():
                # Colour the surface with the last entry of the colour map
                self.plt_iso.color = self._last_color(self.volume.color_map)
                self.plt_iso.opacity = self.vol_alpha_coef.value / 1000
                self.plt_iso.level = self.vol_range.value[0]
            # Make visible last, to avoid extra updates
            self.plt_iso.visible = True
        self.plot.clipping_planes = self._clipping_planes()

    def _update_controls(self, vmin, vmax, nx, ny, nz):
        """Adapt the range slider and the clipping-plane sliders to newly loaded data."""
        rng = self.vol_range
        # Update the bounds in an order that never makes min > max, which would raise a TraitError
        if vmax > rng.max:
            rng.max = vmax
            rng.min = vmin
        else:
            rng.min = vmin
            rng.max = vmax
        rng.value = self._default_range(vmin, vmax)
        rng.step = (vmax - vmin) / 100

        sizes = (nx, ny, nz)
        for _, slider, _, axis in self._clip_controls():
            size = sizes[axis]
            # Clamp the value first so that lowering max never conflicts with it
            slider.value = min(slider.value, size)
            slider.max = size
            slider.step = size / 100

    def change_file(self, v):
        """Load the file selected in the file chooser and rebuild the plot.

        On failure the previously loaded data is kept and re-displayed.

        :param v: the object passed by the FileChooser callback (unused).
        """
        print('Loading:', self.fc.selected)
        self.volume.visible = False
        self.plt_iso.visible = False

        try:
            d = self._load_data(self.fc.selected)
            nz, ny, nx = d.shape
            vmin, vmax = self._abs_range(d)
            print(d.shape, vmin, vmax)
        except Exception as ex:
            print("Failed to load file - is this a result CXI result file from a 3D CDI analysis ?")
            print(ex)
        else:
            # Only replace the data once it has been validated as a 3D array
            self.d = d
            self._update_controls(vmin, vmax, nx, ny, nz)
        self.toggle_mode.value = 'Volume'  # Safer than iso-surface
        self.create_plot()
        self.update()

# CXI result file shown at start-up. Set the PYNX_CXI_FILE environment variable to
# open another file without editing the notebook, e.g.
# /Users/vincent/data/2020-Coccolithes/RetFen4036_cand1_3D-2020-02-04T08-39-44_Run0002_LLKf003.8944_LLK004.0703_SupportThreshold0.11422.cxi
CXI_FILENAME = os.environ.get('PYNX_CXI_FILE', '/Users/vincent/data/201702-BraggCDI-Pt/latest.cxi')
w = WidgetCDI(CXI_FILENAME)
display(w)

In [None]:
# TODO:
# - Swap the X and Z axes to follow the usual convention? Is it needed?
# - Animation / rotation (w.plot.start_auto_play? w.plot.camera_animation?)
# - Use a range slider for the clipping planes - cleaner
# - Alpha using a log slider?
# - Color the iso-surface with the phase?
# - Color the volume with the phase?
# - Clean up: don't put everything in __init__ and update()!