# Search AiiDA Database for Deliverable Capacities

In [None]:
from __future__ import print_function
import collections


from aiida import load_dbenv, is_dbenv_loaded
from aiida.backends import settings
if not is_dbenv_loaded():
    load_dbenv(profile=settings.AIIDADB_PROFILE)

from aiida.orm.querybuilder import QueryBuilder
from aiida.orm.data.structure import StructureData
from aiida.orm.calculation.work import WorkCalculation
from aiida.orm.calculation.job import JobCalculation
from aiida.orm.data.parameter import ParameterData

import ase.io
import sys
import numpy as np
import ipywidgets as ipw
from IPython.display import display, clear_output, Image

%matplotlib notebook
import matplotlib.pyplot as plt

In [None]:
############################   START OF PREPROCESSING   ###############################

In [None]:
PREPROCESS_VERSION = 0.06


def preprocess_newbies():
    """Preprocess finished DcMethane workchains not yet at PREPROCESS_VERSION.

    Selects ``WorkCalculation`` nodes whose process label is ``DcMethane``
    and that either have no ``preprocess_version`` extra or an older one.
    It then runs :func:`preprocess_one` on each sealed node and records the
    outcome in these extras:

    - ``preprocess_successful``
    - ``preprocess_error`` (on failure only)
    - ``preprocess_version``

    Nodes that are not sealed yet are skipped without being marked, so a
    later call picks them up again.

    The ``obsolete`` extra is preserved and defaults to False.
    ``preprocess_one`` deletes every extra before repopulating them, so the
    flag is read before the call and written back afterwards.
    """
    qb = QueryBuilder()
    qb.append(WorkCalculation, filters={
        'attributes._process_label': 'DcMethane',
        'or': [
            {'extras': {'!has_key': 'preprocess_version'}},
            {'extras.preprocess_version': {'<': PREPROCESS_VERSION}},
        ],
    })

    for row in qb.all():  # iterall() would interfere with set_extra()
        n = row[0]
        if not n.is_sealed:
            print("Skipping underway workchain PK %d" % n.pk)
            continue
        # preprocess_one() wipes *all* extras, so the flag must be captured
        # here and restored afterwards; setting it before the call is a no-op.
        obsolete = n.get_extras().get('obsolete', False)
        try:
            preprocess_one(n)
            n.set_extra('preprocess_successful', True)
            n.set_extra('preprocess_version', PREPROCESS_VERSION)
            print("Preprocessed PK %d" % n.pk)
        except Exception as e:  # record any failure on the node and move on
            n.set_extra('preprocess_successful', False)
            n.set_extra('preprocess_error', str(e))
            n.set_extra('preprocess_version', PREPROCESS_VERSION)
            print("Failed to preprocess PK %d: %s" % (n.pk, e))
        n.set_extra('obsolete', obsolete)

In [None]:
def preprocess_one(workcalc):
    """Copy the results of one DcMethane workchain into its extras.

    All existing extras on ``workcalc`` are deleted first. Then:

    - the id of the input ``structure`` is stored as ``structure_id``;
    - values and units from the ``result`` output node are copied into
      extras of the same name;
    - the deliverable capacity is stored after dividing by the unit-cell
      volume, with units ``'v STP/v'``.

    Errors raised by the node accessors (e.g. a missing result attribute)
    propagate to the caller. Extras written before the failure remain set.
    """
    # Start from a clean slate so no stale extras survive a re-run.
    for stale in workcalc.get_extras():
        workcalc.del_extra(stale)

    workcalc.set_extra('structure_id', workcalc.inp.structure.id)

    result = workcalc.out.result

    def _mirror(*names):
        # Copy result attributes verbatim into extras, in the given order.
        for name in names:
            workcalc.set_extra(name, result.get_attr(name))

    _mirror('accessible_surface_area', 'accessible_surface_area_units',
            'density', 'density_units')

    # The factor equals capacity * 22.4139757476e27 / 6.02214086e23 / volume.
    # That is molar gas volume (L/mol) over Avogadro's number, which
    # converts molecules per cell into v STP/v.
    # NOTE(review): this assumes capacity is molecules/cell and volume is
    # in A^3 -- confirm against the DcMethane output units.
    workcalc.set_extra(
        'deliverable_capacity',
        result.get_attr('deliverable_capacity') * 22.4139757476 * 1e4
        / 6.02214086 / result.get_attr('unitcell_volume'))
    workcalc.set_extra('deliverable_capacity_units', 'v STP/v')

    _mirror('largest_included_sphere', 'largest_included_sphere_units',
            'pore_accesible_volume', 'pore_accesible_volume_units',
            'unitcell_volume', 'unitcell_volume_units')


In [None]:
# Query the database and (re)populate extras on DcMethane workchains that are
# new or were preprocessed with an older PREPROCESS_VERSION.
preprocess_newbies()

In [None]:
def search():
    """Query the AiiDA database and plot the matching DcMethane workchains.

    Reads the following module-level widgets:

    - the slider ranges (``sliders_dict``);
    - the X/Y/Color dropdowns;
    - the plotting framework (``btn_mode``).

    Workchains are filtered on the extras written by ``preprocess_one``, and
    the projected values are passed to the selected plotting backend.
    Status messages are written to ``plot_info``.
    """
    plot_info.value = "Searching..."

    filters = {'attributes._process_label': 'DcMethane'}
    # A range slider selects a closed interval, so include the upper bound
    # too; otherwise values sitting exactly at the slider maximum vanish.
    for key, slider in sliders_dict.items():
        low, high = slider.value
        filters['extras.' + key] = {'and': [{'>=': low}, {'<=': high}]}

    projections = [
        'extras.' + inp_x.value,
        'extras.' + inp_y.value,
        'extras.' + inp_clr.value,
        'extras.structure_id',
    ]
    if len(set(projections)) != len(projections):
        plot_info.value = "Please select different quantities for X, Y and Color"
        return

    qb = QueryBuilder()
    qb.append(WorkCalculation, filters=filters, project=projections)
    nresults = qb.count()
    if nresults == 0:
        plot_info.value = "No results found."
        return

    plot_info.value = "{} results found. Plotting...".format(nresults)

    rows = qb.all()
    # Build real lists: the plotting helpers index into and re-iterate
    # these sequences, and map() is a one-shot iterator on Python 3.
    x = [float(row[0]) for row in rows]
    y = [float(row[1]) for row in rows]
    clrs = [float(row[2]) for row in rows]
    # Despite the name these are the 'structure_id' extras, stringified for
    # use in link URLs and hover tooltips.
    uuids = [str(row[3]) for row in rows]

    title = "{} vs {}".format(inp_y.label, inp_x.label)
    xlabel = "{} [{}]".format(inp_x.label, quantities[inp_x.value]['unit'])
    ylabel = "{} [{}]".format(inp_y.label, quantities[inp_y.value]['unit'])
    clr_label = "{} [{}]".format(inp_clr.label, quantities[inp_clr.value]['unit'])

    # Compare by value: `is` on strings only works by accident of interning.
    mode = btn_mode.value
    if mode == 'plotly':
        plot_plotly(x, y, uuids, clrs, title=title, xlabel=xlabel, ylabel=ylabel,
                    clr_label=clr_label)
    elif mode == 'plotly_links':
        plot_plotly(x, y, uuids, clrs, title=title, xlabel=xlabel, ylabel=ylabel,
                    clr_label=clr_label, with_links=True)
    elif mode == 'bokeh':
        plot_bokeh(x, y, uuids, clrs, title=title, xlabel=xlabel, ylabel=ylabel,
                   clr_label=clr_label)
    else:
        plot_matplotlib(x, y, uuids, clrs, title=title,
                        xlabel=xlabel, ylabel=ylabel, clr_label=clr_label)

    plot_info.value = "Plotted {} results.".format(nresults)
   

def discrete_cmap(colors):
    """Return a discrete matplotlib colormap built from ``colors``.

    :param colors: list of colors (e.g. RGB tuples) understood by matplotlib;
        the colormap has exactly ``len(colors)`` entries.
    :returns: a ``LinearSegmentedColormap`` named ``"discrete_<N>"``.
    """
    # from_list is a static constructor. Calling it directly avoids looking
    # up an unrelated base map via plt.cm.get_cmap, which newer matplotlib
    # releases have removed.
    from matplotlib.colors import LinearSegmentedColormap

    N = len(colors)
    cmap_name = "discrete_" + str(N)
    return LinearSegmentedColormap.from_list(cmap_name, colors, N)

def plot_matplotlib(x, y, uuids, clrs, title=None, xlabel=None, ylabel=None, clr_label=None):
    """Draw a matplotlib scatter plot of ``y`` against ``x``, colored by ``clrs``.

    :param x: x coordinates of the points
    :param y: y coordinates of the points
    :param uuids: unused; accepted so all plotting backends share one signature
    :param clrs: per-point values mapped through the colormap
    :param title: figure title
    :param xlabel: x-axis label
    :param ylabel: y-axis label
    :param clr_label: colorbar label
    """
    # Open a fresh figure on every call. With the interactive notebook
    # backend, pyplot would otherwise keep drawing into the figure from the
    # previous call, whose output area search() has just cleared, so the new
    # plot would not show up (and colorbars would pile up).
    plt.figure()
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)

    # NOTE(review): 'bond_type' is not a key of `quantities`, and
    # `bondtype_colors` / `bondtypes` are not defined in this notebook. This
    # branch looks like a leftover and would raise NameError if it were reached.
    if inp_clr.value == 'bond_type':
        cmap = discrete_cmap(bondtype_colors)
    else:
        cmap = 'rainbow'

    plt.scatter(x, y, c=clrs, cmap=cmap, s=5, linewidth=0)
    colorbar = plt.colorbar(label=clr_label)

    if inp_clr.value == 'bond_type':
        N = len(bondtypes)
        colorbar.set_ticks(range(N))
        colorbar.set_ticklabels(bondtypes)
        # center ticks
        plt.clim(-0.5, N - 0.5)


# See https://plot.ly/python/matplotlib-colorscales/    
def matplotlib_to_plotly(cmap, pl_entries=255):
    """Convert a matplotlib colormap to a plotly colorscale.

    :param cmap: matplotlib colormap instance. Any callable that maps a float
        in [0, 1] to an RGBA tuple with components in [0, 1] works.
    :param pl_entries: number of evenly spaced samples; must be at least 2.
    :returns: list of ``[position, 'rgb(R, G, B)']`` pairs. Positions run
        from exactly 0.0 to exactly 1.0.
    :raises ValueError: if ``pl_entries`` is smaller than 2.
    """
    if pl_entries < 2:
        raise ValueError("pl_entries must be at least 2, got {}".format(pl_entries))

    pl_colorscale = []
    # linspace ends exactly at 1.0, which a plotly colorscale must reach.
    # Accumulated k*h products may land slightly off it due to rounding.
    for pos in np.linspace(0.0, 1.0, pl_entries):
        rgb = np.array(cmap(pos)[:3]) * 255
        # int() truncates like the former uint8 cast, but always formats as
        # a plain number (str of numpy scalars is version dependent).
        r, g, b = (int(c) for c in rgb)
        pl_colorscale.append([float(pos), 'rgb({}, {}, {})'.format(r, g, b)])

    return pl_colorscale
    
    
def plot_plotly(x, y, uuids, colors, title=None, xlabel=None, ylabel=None, clr_label=None, with_links=False):
    """Render an interactive plotly scatter plot inline in the notebook.

    :param x: x coordinates of the points
    :param y: y coordinates of the points
    :param uuids: per-point identifiers appended to ``rest_url`` for links
    :param colors: per-point values mapped through the colorscale
    :param title: plot title
    :param xlabel: x-axis label
    :param ylabel: y-axis label
    :param clr_label: colorbar title
    :param with_links: if True, overlay each point with a hyperlink annotation
    """
    import plotly.offline as py
    import plotly.graph_objs as go
    py.init_notebook_mode(connected=True)

    # to be changed to rest url for Rocio's DB
    rest_url = 'http://localhost:8000/explore/sssp/details'

    # NOTE(review): 'bond_type' is not a key of `quantities`, and
    # `bondtype_colors` / `bondtypes` are not defined in this notebook.
    if inp_clr.value == 'bond_type':
        colorscale = matplotlib_to_plotly(discrete_cmap(bondtype_colors))

        colorbar = go.ColorBar(
            tickmode='array',
            tickvals=np.linspace(0.5, len(bondtypes) - 1.5, len(bondtypes)),
            ticktext=bondtypes,
            title=clr_label,
            titleside='right',
        )
    else:
        colorscale = 'Jet'
        colorbar = go.ColorBar(title=clr_label, titleside='right')

    marker = dict(size=10, line=dict(width=2), color=colors, colorscale=colorscale, colorbar=colorbar)
    trace = go.Scatter(x=x, y=y, mode='markers', marker=marker)

    layout = dict(title=title, xaxis=dict(title=xlabel), yaxis=dict(title=ylabel), hovermode='closest')
    if with_links:
        # Workaround for clickable points: one nearly transparent annotation
        # per point carrying a hyperlink (for proper click handling use
        # plot.ly dash). Built only on request because it adds one
        # annotation per data point.
        layout['annotations'] = [
            dict(x=xi, y=yi, text='<a href="{}/{}">o</a>'.format(rest_url, uid),
                 showarrow=False, font=dict(color='#ffffff'), opacity=0.1)
            for xi, yi, uid in zip(x, y, uuids)
        ]

    fig = go.Figure(data=[trace], layout=layout)
    py.iplot(fig, filename='jupyter/basic-scatter')
    
    
def plot_bokeh(x, y, uuids, colors, title=None, xlabel=None, ylabel=None, clr_label=None, with_links=False):
    """Render an interactive bokeh scatter plot inline in the notebook.

    Points are colored on a linear Viridis scale spanning
    ``min(colors)``..``max(colors)``, so ``colors`` must be non-empty.
    Hovering a point shows its ``uuids`` entry. Tapping a point opens
    ``<rest_url>/<uuid>``.

    :param x: x coordinates of the points
    :param y: y coordinates of the points
    :param uuids: per-point identifiers shown on hover and used in tap URLs
    :param colors: per-point values mapped through the color mapper
    :param title: plot title
    :param xlabel: x-axis label
    :param ylabel: y-axis label
    :param clr_label: colorbar title
    :param with_links: ignored; accepted for signature parity with
        ``plot_plotly`` (tap-to-open is always enabled here)
    """
    import bokeh.plotting as bpl
    import bokeh.models as bmd
    from bokeh.palettes import Viridis256
    from bokeh.io import output_notebook, show
    from bokeh.models import HoverTool

    rest_url = 'http://localhost:8000/explore/sssp/details'

    output_notebook(hide_banner=True)

    cmap = bmd.LinearColorMapper(palette=Viridis256, low=min(colors), high=max(colors))
    cbar = bmd.ColorBar(color_mapper=cmap, title=clr_label, location=(0, 0))

    hover = HoverTool(tooltips=[
        ("structure id", "@uuid"),
    ])

    fig = bpl.figure(
        toolbar_location=None,
        title=title,
        x_axis_label=xlabel,
        y_axis_label=ylabel,
        tools=['tap', 'zoom_in', 'zoom_out', 'pan', hover],
        output_backend='webgl',
    )
    source = bmd.ColumnDataSource(data=dict(x=x, y=y, uuid=uuids, color=colors))

    fig.circle('x', 'y', size=10, source=source, fill_color={'field': 'color', 'transform': cmap})
    fig.add_layout(cbar, 'right')

    # "@uuid" is substituted by bokeh with the tapped point's `uuid` column.
    taptool = fig.select(type=bmd.TapTool)
    url = "{}/@uuid".format(rest_url)
    taptool.callback = bmd.OpenURL(url=url)

    show(fig)

In [None]:
# search UI
style = {"description_width": "220px"}
layout = ipw.Layout(width="90%")

# Quantities that can be filtered and plotted. Keys are the extras written by
# preprocess_one(), including their spelling (e.g. 'pore_accesible_volume'),
# so they must not be "fixed" here. 'label' is what the UI shows, 'range'
# bounds the slider, and 'unit' is only used in labels.
quantities = collections.OrderedDict([
    ('density', dict(label='Density', range=[0.0, 20.0], unit='g/cm^3')),
    ('deliverable_capacity', dict(label='Deliverable capacity', range=[0.0, 500.0], unit='v STP/v')),
    ('largest_included_sphere', dict(label='Largest included sphere', range=[0.0, 30.0], unit='A')),
    ('accessible_surface_area', dict(label='Accessible surface area', range=[0.0, 12000.0], unit='m^2/g')),
    ('pore_accesible_volume', dict(label='Pore accessible volume', range=[0.0, 12000.0], unit='A^3')),
])


def get_slider(desc, range, default=None):
    """Return a FloatRangeSlider described by ``desc`` covering ``range``.

    :param desc: slider description text
    :param range: ``[min, max]`` of the slider. The name shadows the builtin
        but is kept so keyword callers keep working.
    :param default: initial ``[low, high]`` selection; the full range if None
    """
    if default is None:
        default = range
    return ipw.FloatRangeSlider(description=desc, min=range[0], max=range[1],
                                value=default, step=0.05, layout=layout, style=style)


sliders_dict = collections.OrderedDict()
for k, v in quantities.items():
    desc = "{} [{}]".format(v['label'], v['unit'])
    v.setdefault('default', None)
    sliders_dict[k] = get_slider(desc, v['range'], v['default'])

sliders = list(sliders_dict.values())

In [None]:
import logging
logging.basicConfig()


def on_click(b):
    """Button handler: clear ``plot_out`` and re-run the search into it."""
    with plot_out:
        clear_output()
        search()


button = ipw.Button(description="Plot")
button.on_click(on_click)

# (label, value) pairs keep the dropdowns in the order of `quantities`;
# a plain dict would have arbitrary iteration order on Python 2.
plot_options = [(v['label'], k) for k, v in quantities.items()]

inp_x = ipw.Dropdown(
    options=plot_options,
    value='density',
    description='X:',
)

inp_y = ipw.Dropdown(
    options=plot_options,
    value='deliverable_capacity',
    description='Y:',
)

inp_clr = ipw.Dropdown(
    options=plot_options,
    value='pore_accesible_volume',
    description='Color:',
)


btn_mode = ipw.Dropdown(
    options=[
        ('matplotlib (fast)', 'matplotlib'),
        ('plot.ly (default)', 'plotly'),
        ('plot.ly with links (slow)', 'plotly_links'),
        ('bokeh', 'bokeh'),
    ],
    value='plotly',
    description='Framework:',
)


properties = ipw.HBox([inp_x, inp_y, inp_clr])

plot_out = ipw.Output()
plot_info = ipw.HTML("")
app = ipw.VBox(sliders + [properties, ipw.HBox([button, btn_mode]), plot_info, plot_out])
display(app)