# Plot COF Properties

In [None]:
from __future__ import print_function
import collections

from aiida import load_dbenv, is_dbenv_loaded
from aiida.backends import settings
if not is_dbenv_loaded():
    load_dbenv(profile=settings.AIIDADB_PROFILE)

from aiida.orm.querybuilder import QueryBuilder
from aiida.orm.data.parameter import ParameterData

import ipywidgets as ipw
from IPython.display import display, clear_output
import numpy as np

%matplotlib notebook
import matplotlib.pyplot as plt

In [None]:
def search():
    """Query the AiiDA database and plot the matching results.

    Reads the current widget state (PK text field, range sliders, X/Y/colour
    dropdowns and the plotting framework selector), queries ``ParameterData``
    nodes through a ``QueryBuilder`` and passes the results to the selected
    plot function. Status and error messages are written to ``plot_info``.

    Returns early (after updating ``plot_info``) when the same quantity is
    selected more than once for X, Y and colour, or when the query yields no
    results.
    """
    plot_info.value = "Searching..."

    filters = {}
    pk_list = inp_pks.value.strip().split()
    if pk_list:
        filters['id'] = {'in': pk_list}

    # Restrict every quantity to the half-open interval [lower, upper)
    # currently selected on its slider.
    for label, slider in sliders_dict.items():
        bounds = slider.value
        filters['attributes.' + label] = {
            'and': [{'>=': bounds[0]}, {'<': bounds[1]}]
        }

    projections = [
        'attributes.' + inp_x.value,
        'attributes.' + inp_y.value,
        'attributes.' + inp_clr.value,
        'uuid',
    ]

    if len(set(projections)) != len(projections):
        plot_info.value = "Please select different quantities for X, Y and Color"
        return

    qb = QueryBuilder()
    qb.append(ParameterData, filters=filters, project=projections)

    # Fetch once and count locally instead of issuing a separate count query.
    results = qb.all()
    nresults = len(results)
    if nresults == 0:
        plot_info.value = "No results found."
        return

    plot_info.value = "{} results found. Plotting...".format(nresults)

    # Transpose rows into columns; materialise real lists so that the plot
    # functions can index them and take their length.
    x, y, clrs, uuids = zip(*results)
    x = [float(value) for value in x]
    y = [float(value) for value in y]
    uuids = [str(uuid) for uuid in uuids]

    title = "{} vs {}".format(inp_y.label, inp_x.label)
    xlabel = "{} [{}]".format(inp_x.label, quantities[inp_x.value]['unit'])
    ylabel = "{} [{}]".format(inp_y.label, quantities[inp_y.value]['unit'])

    # Colours: bond types are categorical and are encoded by their position
    # in ``bondtypes``; every other quantity is numeric.
    if inp_clr.value == 'bond_type':
        clrs = [bondtypes.index(clr) for clr in clrs]
        clr_label = "Bond type"
    else:
        clrs = [float(clr) for clr in clrs]
        clr_label = "{} [{}]".format(inp_clr.label, quantities[inp_clr.value]['unit'])

    plot_kwargs = dict(title=title, xlabel=xlabel, ylabel=ylabel, clr_label=clr_label)

    # Compare by value: ``is`` on strings only works by accident of interning.
    mode = btn_mode.value
    if mode == 'plotly':
        plot_plotly(x, y, uuids, clrs, **plot_kwargs)
    elif mode == 'plotly_links':
        plot_plotly(x, y, uuids, clrs, with_links=True, **plot_kwargs)
    elif mode == 'bokeh':
        plot_bokeh(x, y, uuids, clrs, **plot_kwargs)
    else:
        plot_matplotlib(x, y, uuids, clrs, **plot_kwargs)

    plot_info.value = "Plotted {} results.".format(nresults)
   

def discrete_cmap(colors):
    """Build a discrete matplotlib colormap with one level per given color.

    :param colors: list of colors, one per discrete level
    :return: colormap named ``discrete_<number of colors>``
    """
    n_levels = len(colors)
    # ``from_list`` is looked up on the cubehelix colormap object
    cubehelix = plt.cm.get_cmap('cubehelix')
    return cubehelix.from_list("discrete_{}".format(n_levels), colors, n_levels)

def plot_matplotlib(x, y, uuids, clrs, title=None, xlabel=None, ylabel=None, clr_label=None):
    """Draw a colored scatter plot with matplotlib's pyplot interface.

    :param x: x coordinates
    :param y: y coordinates
    :param uuids: node UUIDs (accepted for a uniform signature, not used here)
    :param clrs: per-point color values
    :param title: plot title
    :param xlabel: x axis label
    :param ylabel: y axis label
    :param clr_label: colorbar label
    """
    color_by_bond_type = inp_clr.value == 'bond_type'

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)

    colormap = discrete_cmap(bondtype_colors) if color_by_bond_type else 'rainbow'
    plt.scatter(x, y, c=clrs, cmap=colormap, s=5, linewidth=0)
    cbar = plt.colorbar(label=clr_label)

    if not color_by_bond_type:
        return

    # Categorical colorbar: one tick per bond type, labelled by name
    n_types = len(bondtypes)
    cbar.set_ticks(range(n_types))
    cbar.set_ticklabels(bondtypes)
    # shift the color limits by half a level so ticks sit mid-band
    plt.clim(-0.5, n_types - 0.5)


# See https://plot.ly/python/matplotlib-colorscales/
def matplotlib_to_plotly(cmap, pl_entries=255):
    """Convert matplotlib colormap to plotly colorscale.

    The colormap is sampled at ``pl_entries`` evenly spaced positions between
    0 and 1 (both included).

    :param cmap: matplotlib colormap instance (any callable mapping a float
        in [0, 1] to a sequence whose first three items are R, G, B in [0, 1])
    :param pl_entries: number of entries in the returned colorscale
    :return: list of ``[position, 'rgb(r, g, b)']`` pairs with 8-bit channels
    """
    pl_colorscale = []

    # linspace pins the last stop to exactly 1.0 and copes with
    # pl_entries == 1 without dividing by zero.
    for position in np.linspace(0.0, 1.0, pl_entries):
        position = float(position)
        channels = (np.asarray(cmap(position)[:3]) * 255).astype(np.uint8)
        # Plain Python ints keep the string independent of numpy's scalar repr.
        rgb = tuple(channels.tolist())
        pl_colorscale.append([position, 'rgb' + str(rgb)])

    return pl_colorscale
    
    
def plot_plotly(x, y, uuids, colors, title=None, xlabel=None, ylabel=None, clr_label=None, with_links=False):
    """Draw a colored scatter plot with plot.ly in offline notebook mode.

    :param x: x coordinates
    :param y: y coordinates
    :param uuids: node UUIDs, used to build the per-point links
    :param colors: per-point color values
    :param title: plot title
    :param xlabel: x axis label
    :param ylabel: y axis label
    :param clr_label: colorbar title
    :param with_links: if True, overlay one link annotation per point
    """
    import plotly.offline as py
    import plotly.graph_objs as go
    py.init_notebook_mode(connected=True)

    # to be changed to rest url for Rocio's DB
    rest_url = 'http://localhost:8000/explore/sssp/details'

    if inp_clr.value == 'bond_type':
        colorscale = matplotlib_to_plotly(discrete_cmap(bondtype_colors))

        colorbar = go.ColorBar(
            tickmode='array',
            tickvals=np.linspace(0.5, len(bondtypes)-1.5, len(bondtypes)),
            ticktext=bondtypes,
            title=clr_label,
            titleside='right',
        )
    else:
        colorscale = 'Jet'
        colorbar = go.ColorBar(title=clr_label, titleside='right')

    marker = dict(size=10, line=dict(width=2), color=colors, colorscale=colorscale, colorbar=colorbar)
    trace = go.Scatter(x=x, y=y, mode='markers', marker=marker)

    layout = dict(title=title, xaxis=dict(title=xlabel), yaxis=dict(title=ylabel), hovermode='closest')
    if with_links:
        # workaround for links - for proper handling of click events use plot.ly dash
        # Built only on request: one annotation per point is the slow part.
        layout['annotations'] = [
            dict(x=x_i, y=y_i, text='<a href="{}/{}">o</a>'.format(rest_url, uuid),
                 showarrow=False, font=dict(color='#ffffff'), opacity=0.1)
            for x_i, y_i, uuid in zip(x, y, uuids)
        ]

    fig = go.Figure(data=[trace], layout=layout)
    py.iplot(fig, filename='jupyter/basic-scatter')
    
    
def plot_bokeh(x, y, uuids, colors, title=None, xlabel=None, ylabel=None, clr_label=None, with_links=False):
    """Draw a colored scatter plot with bokeh; tapping a point opens its URL.

    :param x: x coordinates
    :param y: y coordinates
    :param uuids: node UUIDs, substituted into the URL opened on tap
    :param colors: per-point numeric color values
    :param title: plot title
    :param xlabel: x axis label
    :param ylabel: y axis label
    :param clr_label: colorbar title
    :param with_links: accepted for a uniform signature, not used here
    :raises NotImplementedError: if coloring by bond type is selected
    """
    import bokeh.plotting as bpl
    import bokeh.models as bmd
    from bokeh.palettes import Viridis256
    import matplotlib as mpl  # not referenced below
    from bokeh.io import output_notebook, show
    output_notebook(hide_banner=True)

    # to be changed to rest url for Rocio's DB
    rest_url = 'http://localhost:8000/explore/sssp/details'

    # The discrete bond-type color scale used for plot.ly has not been
    # ported to bokeh; report it in the UI and bail out.
    if inp_clr.value == 'bond_type':
        message = "Coloring by bond_type not yet implemented for bokeh. Exiting..."
        plot_info.value = message
        raise NotImplementedError(message)

    # Continuous color scale spanning the data range
    color_mapper = bmd.LinearColorMapper(palette=Viridis256, low=min(colors), high=max(colors))
    color_bar = bmd.ColorBar(color_mapper=color_mapper, title=clr_label, location=(0, 0))

    plot = bpl.figure(
        toolbar_location=None,
        title=title,
        x_axis_label=xlabel,
        y_axis_label=ylabel,
        tools=['tap', 'zoom_in', 'zoom_out', 'pan'],
        output_backend='webgl',
    )

    point_data = bmd.ColumnDataSource(data=dict(x=x, y=y, uuid=uuids, color=colors))
    fill = {'field': 'color', 'transform': color_mapper}
    plot.circle('x', 'y', size=10, source=point_data, fill_color=fill)
    plot.add_layout(color_bar, 'right')

    # '@uuid' is filled in from the 'uuid' column of the tapped point
    tap_tools = plot.select(type=bmd.TapTool)
    tap_tools.callback = bmd.OpenURL(url="{}/@uuid".format(rest_url))

    show(plot)

In [None]:
# search UI
# Shared widget styling: wide description column, widgets span 90% of the width
style = {"description_width":"220px"}
layout = ipw.Layout(width="90%")

# Quantities that can be filtered and plotted.
# Each key is appended to 'attributes.' when querying; each value holds:
#   label   - human-readable name shown in the UI
#   range   - [min, max] limits of the range slider
#   default - (optional) initial slider selection; full range if omitted
#   unit    - unit string appended to slider and axis labels
quantities = collections.OrderedDict([
    ('density', dict(label='Density', range=[10.0,1200.0], default=[200.0,600.0], unit='kg/m^3')),
    ('deliverable_capacity', dict(label='Deliverable capacity', range=[0.0,300.0], unit='v STP/v')),
    ('absolute_methane_uptake_high_P', dict(label='CH4 uptake High-P', range=[0.0,300.0], unit='mol/kg')),
    ('absolute_methane_uptake_low_P', dict(label='CH4 uptake Low-P', range=[0.0,300.0], unit='mol/kg')),
    ('heat_desorption_high_P', dict(label='CH4 heat of desorption High-P', range=[0.0,30.0], unit='kJ/mol')),
    ('heat_desorption_low_P', dict(label='CH4 heat of desorption Low-P', range=[0.0,30.0], unit='kJ/mol')),
    ('supercell_volume', dict(label='Supercell volume', range=[0.0,1000000.0], unit='A^3')),
    ('surface_area', dict(label='Geometric surface area', range=[0.0,12000.0], unit='m^2/g')),
])

# Bond type name -> hex color. The insertion order matters: a bond type's
# position in `bondtypes` is the integer used as its color value when plotting.
bondtype_dict = collections.OrderedDict([
    ('amide', "#1f77b4"), ('amine', "#d62728"), ('imine', "#ff7f0e"),
    ('CC', "#2ca02c"), ('mixed', "#778899"),
])
bondtypes = list(bondtype_dict.keys())
bondtype_colors = list(bondtype_dict.values())


# Optional whitespace-separated list of PKs; used as an 'id' filter when non-empty
inp_pks = ipw.Text(description='PKs', placeholder='e.g. 1006 1009 (space separated)', layout=layout, style=style)

def get_slider(desc, range, default=None):
    """Create a float range slider using the shared layout and style.

    :param desc: description shown next to the slider
    :param range: pair (min, max) giving the slider limits
    :param default: initial selection; the full range is selected when None
    :return: the ``ipw.FloatRangeSlider`` instance
    """
    initial = range if default is None else default
    lower, upper = range[0], range[1]
    return ipw.FloatRangeSlider(
        description=desc,
        min=lower,
        max=upper,
        value=initial,
        step=0.05,
        layout=layout,
        style=style,
    )

# One range slider per quantity, keyed by the quantity's attribute name
sliders_dict = collections.OrderedDict()
for k, v in quantities.items():
    desc = "{} [{}]".format(v['label'], v['unit'])
    # Quantities without an explicit default get None, i.e. the full range
    v.setdefault('default', None)

    slider = get_slider(desc, v['range'], v['default'])
    sliders_dict[k] = slider

sliders = list(sliders_dict.values())

In [None]:
import logging
# Set up the root logger with logging's default configuration
# NOTE(review): presumably so library log messages show up in the notebook - confirm
logging.basicConfig()

def on_click(b):
    """Button callback: clear the previous plot and run a new search.

    Both steps run inside the ``plot_out`` output widget context.

    :param b: argument passed by the button's click handler (unused)
    """
    with plot_out:
        clear_output()
        search()

button = ipw.Button(description="Plot")
button.on_click(on_click)

# Dropdown options map the displayed label to the quantity's attribute name.
# An OrderedDict keeps the entries in the order in which `quantities` is defined.
plot_options = collections.OrderedDict(
    (v['label'], k) for k, v in quantities.items()
)

inp_x = ipw.Dropdown(
    options=plot_options,
    value='density',
    description='X:',
)

inp_y = ipw.Dropdown(
    options=plot_options,
    value='deliverable_capacity',
    description='Y:',
)

# The color dropdown additionally offers the categorical bond type
color_options = collections.OrderedDict(plot_options)
color_options['Bond type'] = 'bond_type'

inp_clr = ipw.Dropdown(
    options=color_options,
    value='bond_type',
    description='Color:',
)


# Plotting framework selector; the values are the mode names checked in search()
btn_mode = ipw.Dropdown(
    options={
        'matplotlib (fast)': 'matplotlib',
        'plot.ly (default)': 'plotly',
        'plot.ly with links (slow)': 'plotly_links',
        'bokeh': 'bokeh',
    },
    value='plotly',
    description='Framework:',
)


properties = ipw.HBox([inp_x, inp_y, inp_clr])

# plot_out captures the plot output; plot_info shows status messages
plot_out = ipw.Output()
plot_info = ipw.HTML("")
app = ipw.VBox(sliders + [properties, ipw.HBox([button, btn_mode]), plot_info, plot_out])
display(app)