In [1]:
from ipywidgets import FloatSlider, ValueWidget, Layout
import ipywidgets as widgets

from bqplot import OrdinalScale, LinearScale, Bars, Axis, Figure

import numpy as np

from sidepanel import SidePanel

from regulus.utils import io
from regulus.topo import * 
from regulus.alg import *
from regulus.measures import *
from regulus.models import *
from regulus.tree import *

from ipyregulus import TreeWidget, TreeView, DataWidget
from ipyregulus.sensitivity_view import sensitivity_view
from ipyregulus.filters import * 


In [2]:
# Path of the regulus dataset explored in this notebook; change it to load another.
DATA_PATH = 'data/gauss4'
data = io.load(DATA_PATH)



### Helper functions

#### view

In [3]:
def update_view(view, f):
    """Return a callback that re-applies filter ``f`` to ``view``.

    Parameters
    ----------
    view : TreeView
        Object whose ``tree`` is filtered and whose ``show`` attribute
        receives the result.
    f : AttrFilter
        Filter handed to ``filter_tree``.

    Returns
    -------
    callable
        Zero-argument function; each call sets
        ``view.show = filter_tree(view.tree, f)`` and returns None.
    """
    def refresh():
        # view.tree is read at call time, so the current tree is filtered.
        filtered = filter_tree(view.tree, f)
        view.show = filtered
    return refresh

In [4]:
def view(data, attr='span', func=lambda x, v: v <= x, title=''):
    """Open a filterable TreeView of ``data`` inside a new SidePanel.

    Parameters
    ----------
    data : TreeWidget
        Tree source handed to ``TreeView`` (this notebook passes a TreeWidget).
    attr : str
        Attribute name given to both the TreeView and the AttrFilter.
    func : callable
        Predicate handed to ``AttrFilter``. The default computes ``v <= x``;
        presumably ``x`` is the filter threshold and ``v`` a node's value —
        confirm against AttrFilter.
    title : str
        Side-panel title.

    Returns
    -------
    tuple
        ``(tree_view, attr_filter, panel)``.
    """
    tree_view = TreeView(data, attr=attr)
    attr_filter = AttrFilter(attr=attr, func=func)
    # Presumably the Monitor calls the update_view callback when attr_filter
    # changes. NOTE(review): it is not returned, so only this local refers to
    # it here — confirm it stays alive after view() returns.
    monitor = Monitor(attr_filter, func=update_view(tree_view, attr_filter))
    panel = SidePanel(title=title)
    with panel:
        display(tree_view, attr_filter)
    return tree_view, attr_filter, panel

### Initial tree

In [5]:
# Wrap the loaded regulus tree in a widget and open a side panel that
# filters it on the 'fitness' attribute.
tw = TreeWidget(data.tree)
fitness = view(tw, title='fitness', attr='fitness')
# Start with nothing in `details` (the node ids the sensitivity cells read).
fitness[0].details = list()

### Test custom widget

In [None]:
# Scrollable container that on_selection_changed fills with one chart per node.
box_layout = Layout(overflow_y='scroll', overflow='scroll')
# The widget trait is `layout` (lower-case). A `Layout=` keyword is not a
# trait, so traitlets ignores it and the scrolling layout is never applied.
vbox = widgets.VBox([], layout=box_layout)

p2 = SidePanel(title='Sensitivity')
with p2:
    display(vbox)
    
def _sensitivity_figure(node_id, coefficients):
    """Build a bar chart of one node's linear-model coefficients.

    Parameters
    ----------
    node_id : int
        Node identifier; used only in the figure title.
    coefficients : sequence of float
        One coefficient per input dimension, labelled x0, x1, ...
        NOTE(review): assumed 1-D (a single-output ``coef_``) — confirm.

    Returns
    -------
    bqplot.Figure
        Bar chart whose y-range is symmetric about zero, capped at 200x200 px.
    """
    # Symmetric y-range so positive and negative bars are comparable. Fall back
    # to 1 when there are no coefficients (np.max would raise on an empty
    # array) or all are zero (a degenerate [0, 0] scale).
    y_mag = np.max(np.fabs(coefficients)) if len(coefficients) else 0
    if not y_mag:
        y_mag = 1.0
    x_ord = OrdinalScale()
    y_sc = LinearScale(min=-y_mag, max=y_mag)

    # Dark blue for negative coefficients, light blue otherwise.
    colors = ['#1f78b4' if coefficient < 0 else '#a6cee3'
              for coefficient in coefficients]
    x_labels = ['x{}'.format(i) for i in range(len(coefficients))]
    bar = Bars(x=x_labels,
               y=coefficients,
               scales={'x': x_ord, 'y': y_sc},
               type='stacked',
               colors=colors)

    ax_x = Axis(scale=x_ord, grid_lines='solid', label='Dimension')
    ax_y = Axis(scale=y_sc, orientation='vertical', grid_lines='solid',
                label='Linear Coefficient')

    fig = Figure(marks=[bar],
                 axes=[ax_x, ax_y],
                 title='Sensitivity for Node: {}'.format(node_id))
    fig.layout.max_width = '200px'
    fig.layout.max_height = '200px'
    return fig


def on_selection_changed(change):
    """Rebuild ``vbox`` with one sensitivity chart per selected node.

    Registered below as an observer of the fitness view's ``details`` trait.
    It can also be called by hand with ``{'new': [node_id, ...]}``.

    Parameters
    ----------
    change : dict
        Change notification; only ``change['new']`` (the node ids to chart)
        is read.
    """
    node_ids = sorted(change['new'] or ())
    children = []
    if node_ids:
        # Fetch the per-node linear models once rather than twice per node.
        linear_models = tw.tree.regulus.attr['linear'].values()
        for node_id in node_ids:
            children.append(_sensitivity_figure(node_id, linear_models[node_id].coef_))
            children.append(widgets.HTML(value="<hr>"))
    vbox.children = tuple(children)


fitness[0].observe(on_selection_changed, names='details')

In [None]:
print(fitness[0].details)
# Drive the observer by hand with a traitlets-style change dict built from
# the current `details` value.
change = dict(new=fitness[0].details)
on_selection_changed(change)

In [6]:
# Build the packaged sensitivity view from the fitness TreeView.
# (Presumably it tracks that view's selection — see ipyregulus.sensitivity_view.)
sensitivity = sensitivity_view(fitness[0])

# Show it in its own side panel. NOTE(review): this rebinds `p2`; if the
# prototype cell above was also run, two panels titled 'Sensitivity' likely exist.
p2 = SidePanel(title='Sensitivity')
with p2:
    display(sensitivity)