## Import Packages, Define Functions and Global Variables

In [10]:
### TODO ###
# Fill in more tooltip info
# Get nice-looking chemical formulas to work

import numpy as np
import pandas as pd
from scipy.stats import gaussian_kde

from bokeh.plotting import *
from bokeh.models import ColumnDataSource, CDSView, CustomJS, Select, LegendItem, Legend, Title, Div
from bokeh.layouts import row, column, Spacer
from bokeh.models.filters import GroupFilter
from bokeh.palettes import Category10
from bokeh.models.tools import HoverTool
from bokeh.io import show


### Set GLOBAL VARIABLES #########################################################################
# CSV file holding the reduced feature set that gets plotted
REDUCED_PATH = "../data/processed/csv_version/IMT_Classification_Dataset_Reduced_Feature_Set_v10.csv"

# When True, a second (right-hand) scatter/KDE panel is built next to the first
TWO_PLOTS = False
# Comma-separated Bokeh tool names for the scatter plot(s)
SCATTER_TOOLS = "pan,wheel_zoom,box_zoom"
# The leading comma matters: without it the appended names fuse with "box_zoom"
# into the single bogus tool name "box_zoombox_select".
if TWO_PLOTS: SCATTER_TOOLS += ",box_select,lasso_select"

# Comma-separated Bokeh tool names for the KDE plots
KDE_TOOLS = "pan,wheel_zoom,box_zoom"
FIG_TITLE = "IMT Classification Dataset"
STD_FONT_SIZE = 14  #base font size in pt; other font sizes are scaled from it
LINE_ALPHA = 0.8
GLYPH_SIZE = 10
GLYPH_OUTLINE_WIDTH = 1.5
KDE_HEIGHT = 180  #height (along KDE value axis) of the KDE plots
BORDER_WIDTH = 50 #width of empty white space on sides of figure

output_notebook() # make Bokeh output display inline in the notebook
output_file("../plots/bokeh_visualization.html", title=FIG_TITLE) # also save the figure to a standalone HTML file

### Establish color/marker categories ############################################################
# Class label (as a string) -> human-readable class name
CATEG_DICT = dict(zip(('0', '1', '2'), ('Metal', 'Insulator', 'MIT')))
# One palette color per class, taken from the Category10 palette of matching size
CATEG_COLORS = Category10[len(CATEG_DICT)]
# One scatter marker name per class
CATEG_MARKERS = ['circle', 'triangle', 'plus']


### Establish columns used for plotting ##########################################################
# Plottable columns and the labels displayed for them: {'short_name': 'Pretty Name'}
COL_LENGTHEN = dict(
    gii='Global Instability Index',
    est_hubbard_u='Estimated Hubbard U (eV)',
    est_charge_trans='Estimated Charge Transfer Gap, Δ (eV)',
    ewald_energy_per_atom='Ewald Energy (eV/atom)',
    avg_dev_Electronegativity='Average Deviation in Electronegativity',
    range_MendeleevNumber='Mendeleev Number Range',
    avg_dev_CovalentRadius='Average Deviation in Covalent Radii (Å)',
    avg_mm_dists='Average Metal-Metal Distance (Å)',
    avg_mx_dists='Average Metal-Anion Distance (Å)',
    avg_xx_dists='Average Anion-Anion Distance (Å)',
)
# Reverse lookup: {'Pretty Name': 'short_name'}
COL_SHORTEN = {pretty: short for short, pretty in COL_LENGTHEN.items()}
# Short names and pretty names of the selectable axes, in matching order
AXIS_OPTIONS = COL_LENGTHEN.keys()
PRETTY_AXIS_OPTIONS = list(COL_LENGTHEN.values())
# Columns shown on the x and y axes when the figure is first drawn
INIT_X, INIT_Y = 'est_charge_trans', 'gii'
# Disabled axis-range table (only referenced by commented-out lines in the JS callbacks)
# RANGE_DICT = {
#     'temp':         (0,1100),
#     'magmom_dens':  (0,0.65),
#     'magmom':       (0, 10),
#     'v_per_atom':   (0,39)
# }


### Define useful functions ######################################################################

def format_scatter_plot(plot):
    """Apply the initial axis labels, fonts and border spacing to a scatter plot.

    The x and y axes are labelled with the pretty names of INIT_X and INIT_Y;
    all font sizes are derived from STD_FONT_SIZE. `plot` is modified in place.
    """
    label_font_size = f'{STD_FONT_SIZE}pt'
    tick_font_size = f'{int(STD_FONT_SIZE*0.9)}pt'
    ### Fonts
    for axis, init_col in ((plot.xaxis, INIT_X), (plot.yaxis, INIT_Y)):
        axis.axis_label = COL_LENGTHEN[init_col]
        axis.axis_label_text_font_style = "normal"
        axis.axis_label_text_font_size = label_font_size
        axis.major_label_text_font_size = tick_font_size
    ### Spacing
    plot.min_border = 10
    
def format_kde_plot(plot, vertical=False):
    """Apply sizing, axis visibility and fonts to a KDE side plot, in place.

    With vertical=False the plot is KDE_HEIGHT tall and stretches in width,
    the KDE value runs along the y axis and the x axis is hidden; the title
    font size is also set. With vertical=True the plot is KDE_HEIGHT wide and
    stretches in height, the KDE value runs along the x axis and the y axis
    is hidden.
    """
    ### Title, sizing, axes
    if vertical:
        kde_axis, hidden_axis = plot.xaxis, plot.yaxis
        plot.width = KDE_HEIGHT
        plot.sizing_mode = 'stretch_height'
    else:
        kde_axis, hidden_axis = plot.yaxis, plot.xaxis
        plot.title.text_font_size = f'{int(STD_FONT_SIZE*1.2)}pt'
        plot.height = KDE_HEIGHT
        plot.sizing_mode = 'stretch_width'
    kde_axis.minor_tick_line_color = None
    hidden_axis.visible = False
    kde_axis.axis_label = 'Renormalized KDE'
    ### Fonts
    label_font_size = f'{STD_FONT_SIZE}pt'
    tick_font_size = f'{int(STD_FONT_SIZE*0.9)}pt'
    for axis in (plot.xaxis, plot.yaxis):
        axis.axis_label_text_font_style = "normal"
        axis.axis_label_text_font_size = label_font_size
        axis.major_label_text_font_size = tick_font_size
    plot.min_border = 10

# def format_formula(formula): # Doesn't work: Bokeh apparently doesn't render the inserted <sub> tags as HTML (TODO confirm)
#     """Inserts HTML style subscript tags to all numbers in the given string"""
#     for char in formula:
#             if char.isdigit():
#                 formula = formula.replace(char, f'<sub>{char}</sub>')
#     return formula

## Import and Prepare Data

In [11]:
# Load the dataset and expose the 'Label' column as a string-valued
# 'classification' column, so its values match the string keys of CATEG_DICT.
df = pd.read_csv(REDUCED_PATH).rename(columns={'Label': 'classification'})
df['classification'] = df['classification'].astype(str)
# Human-readable class name for each row
df['pretty_classification'] = [CATEG_DICT[label] for label in df['classification']]
# df['pretty_formula'] = [format_formula(formula) for formula in df.Compound] #not working

# Data source shared by the scatter glyphs
source = ColumnDataSource(data=df)


### Precalculate KDEs for all axes, labels ########################################################
kde_dict = {} #{(col, classification, 'vals' or 'kde') : 1D array of evaluation points or KDE values}

def calc_kde(col_data, n_points=400):
    """Evaluate a max-normalized Gaussian KDE across the range of the data.

    Args:
        col_data: 1D array-like of numeric samples.
        n_points: Number of evenly spaced evaluation points spanning
            [min(col_data), max(col_data)]. Defaults to 400.

    Returns:
        Tuple ``(vals, kde)`` of 1D arrays of length ``n_points``: the
        evaluation points and the KDE at those points, rescaled so that its
        largest value is 1.

    Any exception raised by ``scipy.stats.gaussian_kde`` for unsuitable input
    propagates to the caller.
    """
    samples = np.asarray(col_data, dtype=float)
    kernel = gaussian_kde(samples)
    # Vectorized min/max rather than Python builtins looping over the samples
    vals = np.linspace(samples.min(), samples.max(), n_points)
    kde = kernel(vals)
    return (vals, kde / kde.max()) # max-normalization for easier display

# Evaluate one KDE per (axis column, class) pair. The dict keys are tuples;
# the plotting code refers to the resulting source columns by names of the
# form '<col>_<class>_vals' and '<col>_<class>_kde'.
for col in AXIS_OPTIONS:
    for categ in CATEG_DICT:
        in_class = df['classification'] == categ
        vals, kde = calc_kde(df.loc[in_class, col])
        kde_dict[(col, str(categ), 'vals')] = vals
        kde_dict[(col, str(categ), 'kde')] = kde
kde_df = pd.DataFrame.from_dict(kde_dict)
kde_source = ColumnDataSource(data=kde_df)

# kde_source.data
df.keys()

Index(['Compound', 'classification', 'struct_file_path', 'gii',
       'est_hubbard_u', 'est_charge_trans', 'ewald_energy_per_atom',
       'avg_dev_Electronegativity', 'range_MendeleevNumber',
       'avg_dev_CovalentRadius', 'avg_mm_dists', 'avg_mx_dists',
       'avg_xx_dists', 'pretty_classification'],
      dtype='object')

## Plot

In [12]:
### Initialize left scatter plot
# (class key, class label, color, marker) per class, reused by the three loops below
_left_categ_specs = list(zip(CATEG_DICT.keys(), CATEG_DICT.values(), CATEG_COLORS, CATEG_MARKERS))

scatter_plot = figure(tools=SCATTER_TOOLS, sizing_mode='stretch_both')
points_ls = []
for classif, _label, color, marker in _left_categ_specs:
    # Each class gets its own glyph renderer, restricted to that class's rows
    class_view = CDSView(source=source,
                         filters=[GroupFilter(column_name='classification', group=classif)])
    class_points = scatter_plot.scatter(
        x=INIT_X, y=INIT_Y, source=source, view=class_view,
        fill_alpha=0.01, line_alpha=LINE_ALPHA,  # legend_label=classif_label,
        color=color, marker=marker, size=GLYPH_SIZE, line_width=GLYPH_OUTLINE_WIDTH,
    )
    points_ls.append(class_points) #store for making changes in js_callback

### Initialize KDE plots
# Horizontal KDE: shares the scatter plot's x range
hkde_plot = figure(tools=KDE_TOOLS, x_range=scatter_plot.x_range, y_range=(0,1.01))
hkde_ls = []
for classif, _label, color, _marker in _left_categ_specs:
    class_line = hkde_plot.line(
        x=f'{INIT_X}_{classif}_vals', y=f'{INIT_X}_{classif}_kde',
        line_width=3, line_alpha=LINE_ALPHA, color=color,  #legend_label=classif_label,
        source=kde_source,
    )
    hkde_ls.append(class_line) #store for making changes in js_callback

# Vertical KDE: shares the scatter plot's y range, with the KDE value along x
vkde_plot = figure(tools=KDE_TOOLS, y_range=scatter_plot.y_range, x_range=(0,1.01))
vkde_ls = []
for classif, _label, color, _marker in _left_categ_specs:
    class_line = vkde_plot.line(
        y=f'{INIT_Y}_{classif}_vals', x=f'{INIT_Y}_{classif}_kde',
        line_width=3, line_alpha=LINE_ALPHA, color=color,  #legend_label=classif_label,
        source=kde_source,
    )
    vkde_ls.append(class_line) #store for making changes in js_callback

if TWO_PLOTS:
    # (class key, class label, color, marker) per class, reused by the loops below
    _right_categ_specs = list(zip(CATEG_DICT.keys(), CATEG_DICT.values(), CATEG_COLORS, CATEG_MARKERS))

    ### Initialize right scatter plot
    rscatter_plot = figure(tools=SCATTER_TOOLS, sizing_mode='stretch_both')
    rpoints_ls = []
    for classif, _label, color, marker in _right_categ_specs:
        # Each class gets its own glyph renderer, restricted to that class's rows
        class_view = CDSView(source=source,
                             filters=[GroupFilter(column_name='classification', group=classif)])
        class_points = rscatter_plot.scatter(
            x=INIT_X, y=INIT_Y, source=source, view=class_view,
            fill_alpha=0.01, line_alpha=LINE_ALPHA,  #legend_label=classif_label,
            color=color, marker=marker, size=GLYPH_SIZE, line_width=GLYPH_OUTLINE_WIDTH,
        )
        rpoints_ls.append(class_points) #store for making changes in js_callback

    ### Initialize right KDE plots
    # Horizontal KDE: shares the right scatter plot's x range
    rhkde_plot = figure(tools=KDE_TOOLS, x_range=rscatter_plot.x_range, y_range=(0,1.01))
    rhkde_ls = []
    for classif, _label, color, _marker in _right_categ_specs:
        class_line = rhkde_plot.line(
            x=f'{INIT_X}_{classif}_vals', y=f'{INIT_X}_{classif}_kde',
            line_width=3, line_alpha=LINE_ALPHA, color=color,  #legend_label=classif_label,
            source=kde_source,
        )
        rhkde_ls.append(class_line) #store for making changes in js_callback

    # Vertical KDE: shares the right scatter plot's y range, with the KDE value along x
    rvkde_plot = figure(tools=KDE_TOOLS, y_range=rscatter_plot.y_range, x_range=(0,1.01))
    rvkde_ls = []
    for classif, _label, color, _marker in _right_categ_specs:
        class_line = rvkde_plot.line(
            y=f'{INIT_Y}_{classif}_vals', x=f'{INIT_Y}_{classif}_kde',
            line_width=3, line_alpha=LINE_ALPHA, color=color,  #legend_label=classif_label,
            source=kde_source,
        )
        rvkde_ls.append(class_line) #store for making changes in js_callback
    

### Create universal (dummy) legend
# Per-class renderer lists that one legend entry should control; the
# right-hand lists only exist when the second panel was built.
_legend_renderer_lists = [points_ls, hkde_ls, vkde_ls]
if TWO_PLOTS:
    _legend_renderer_lists += [rpoints_ls, rhkde_ls, rvkde_ls]

# One legend item per class, tied to that class's renderer in every plot
legend_items = [LegendItem(label=class_label, renderers=list(class_renderers))
                for class_label, *class_renderers
                in zip(CATEG_DICT.values(), *_legend_renderer_lists)]

# Otherwise-empty figure whose only visible content is the legend
dum_fig = figure(height=60, min_width=300, toolbar_location=None, outline_line_alpha=0,
                 sizing_mode='stretch_width')
for hidden_part in (dum_fig.grid[0], dum_fig.ygrid[0], dum_fig.xaxis[0], dum_fig.yaxis[0]):
    hidden_part.visible = False
dum_fig.renderers += [renderer for renderer_ls in _legend_renderer_lists
                      for renderer in renderer_ls]
dum_fig.x_range.end = 100005
dum_fig.x_range.start = 100000
shared_legend = Legend(items=legend_items, click_policy='hide', location='top_center',
                       border_line_alpha=0, glyph_height=40, glyph_width=30,
                       label_text_font_size=f'{int(0.9*STD_FONT_SIZE)}pt',
                       orientation='horizontal', spacing=10)
dum_fig.add_layout(shared_legend)
                            
### Make JavaScript callbacks => interactivity
# The four strings below are JavaScript source snippets. They are concatenated
# in pairs (xaxis_code + hkde_axis_code, yaxis_code + vkde_axis_code) to form
# the code of the CustomJS callbacks, so the `column` variable declared in the
# first snippet of a pair is the one used by the second snippet.
# NOTE: the strings are used as-is (no str.format / f-string), so the doubled
# braces `{{` / `}}` reach JavaScript literally.

# X-axis selector: translate the selected pretty name (cb_obj.value) to its
# column name via col_dict, point every scatter glyph's x field at that column
# and update the axis label.
xaxis_code="""
        var col_name = cb_obj.value;
        var column = col_dict[col_name];
        // Change scatter plot axes
        points_ls.forEach(function(points) {{
            points.glyph.x.field = column;
        }});
        //x_range.start = range_dict[column][0];
        //x_range.end = range_dict[column][1];
        ax.axis_label = col_name;
        source.change.emit();
"""

# Y-axis selector: same as xaxis_code, but for the y field of the scatter glyphs.
yaxis_code="""
        var col_name = cb_obj.value;
        var column = col_dict[col_name];
        // Change scatter plot axes
        points_ls.forEach(function(points) {{
            points.glyph.y.field = column;
        }});
        //y_range.start = range_dict[column][0];
        //y_range.end = range_dict[column][1];
        ax.axis_label = col_name;
        source.change.emit();"""

# Horizontal KDE update: for each KDE line, read the class from the
# second-to-last '_'-separated piece of its current x field name and switch the
# line to the '<column>_<class>_vals' / '<column>_<class>_kde' fields.
# The `mag` variable is assigned but not used.
hkde_axis_code="""
        // Change horiz KDE axis
        hkde_ls.forEach(function(line) {{
            var current_col = line.glyph.x.field;
            var col_decomp = current_col.split('_');
            var mag = col_decomp[col_decomp.length - 3];
            var classif = col_decomp[col_decomp.length - 2];
            var new_x = `${column}_${classif}_vals`;
            var new_y = `${column}_${classif}_kde`;
            line.glyph.x.field = new_x;
            line.glyph.y.field = new_y;
        }});
        kde_source.change.emit();"""

# Vertical KDE update: same as hkde_axis_code with the roles of x and y
# swapped (values along y, KDE along x); the class is read from the current
# y field name.
vkde_axis_code="""
        // Change vertical KDE axis
        vkde_ls.forEach(function(line) {{
            var current_col = line.glyph.y.field;
            var col_decomp = current_col.split('_');
            var mag = col_decomp[col_decomp.length - 3];
            var classif = col_decomp[col_decomp.length - 2];
            var new_y = `${column}_${classif}_vals`;
            var new_x = `${column}_${classif}_kde`;
            line.glyph.x.field = new_x;
            line.glyph.y.field = new_y;
        }});
        kde_source.change.emit();"""

# Arguments passed to every axis callback, whichever panel and axis it drives
_common_js_args = dict(source=source, kde_source=kde_source,
                       col_dict=COL_SHORTEN, categ_dict=CATEG_DICT)
# Each callback runs the scatter-axis snippet followed by the matching KDE snippet
_xaxis_js = xaxis_code + hkde_axis_code
_yaxis_js = yaxis_code + vkde_axis_code

# Left panel
callbackx = CustomJS(
    args=dict(_common_js_args, points_ls=points_ls, ax=scatter_plot.xaxis[0],
              x_range=scatter_plot.x_range, hkde_ls=hkde_ls),
    code=_xaxis_js,
)
callbacky = CustomJS(
    args=dict(_common_js_args, points_ls=points_ls, ax=scatter_plot.yaxis[0],
              y_range=scatter_plot.y_range, vkde_ls=vkde_ls),
    code=_yaxis_js,
)

# Right panel (only built when two plots are shown)
if TWO_PLOTS:
    rcallbackx = CustomJS(
        args=dict(_common_js_args, points_ls=rpoints_ls, ax=rscatter_plot.xaxis[0],
                  x_range=rscatter_plot.x_range, hkde_ls=rhkde_ls),
        code=_xaxis_js,
    )
    rcallbacky = CustomJS(
        args=dict(_common_js_args, points_ls=rpoints_ls, ax=rscatter_plot.yaxis[0],
                  y_range=rscatter_plot.y_range, vkde_ls=rvkde_ls),
        code=_yaxis_js,
    )

### Axis data selector tools
def _make_axis_select(title, init_col, callback):
    """Build a dropdown of the pretty axis names that runs `callback` when its value changes."""
    selector = Select(title=title, value=COL_LENGTHEN[init_col], options=PRETTY_AXIS_OPTIONS,
                      min_width=10, sizing_mode='stretch_width')
    selector.js_on_change('value', callback)
    return selector

xaxis_select = _make_axis_select("X axis:", INIT_X, callbackx)
yaxis_select = _make_axis_select("Y axis:", INIT_Y, callbacky)

if TWO_PLOTS:
    rxaxis_select = _make_axis_select("X axis:", INIT_X, rcallbackx)
    ryaxis_select = _make_axis_select("Y axis:", INIT_Y, rcallbacky)

### Tooltips
# HTML template for the hover tooltip; each @name is a column of the scatter data source
_tooltip_html = """
    <div>
        <h3><center>@Compound</center></h3>
        <div><strong>Class:</strong>             @pretty_classification</div>
        <div><strong>GII:</strong>               @gii</div>
        <div><strong>U:</strong>                 @est_hubbard_u eV</div>
        <div><strong>Δ:</strong>                 @est_charge_trans eV</div>
    </div>
"""
hover = HoverTool(tooltips=_tooltip_html)

# The same tool instance is attached to every scatter plot that exists
scatter_plot.add_tools(hover)
if TWO_PLOTS:
    rscatter_plot.add_tools(hover)

### Layout
# Square filler placed beside the horizontal KDE plot
spacer = Spacer(height=KDE_HEIGHT, width=KDE_HEIGHT)
# Empty margins on the left and right of the figure
lborder = Spacer(width=BORDER_WIDTH, sizing_mode='stretch_height')
rborder = Spacer(width=BORDER_WIDTH, sizing_mode='stretch_height')

# Format each panel: scatter plot, horizontal KDE, vertical KDE
_panels = [(scatter_plot, hkde_plot, vkde_plot)]
if TWO_PLOTS:
    _panels.append((rscatter_plot, rhkde_plot, rvkde_plot))
for _scatter_fig, _hkde_fig, _vkde_fig in _panels:
    format_scatter_plot(_scatter_fig)
    format_kde_plot(_hkde_fig)
    format_kde_plot(_vkde_fig, vertical=True)

# Figure title rendered as an HTML heading
title = Div(text=f'<h1>{FIG_TITLE}</h1>', align='center', height_policy='min',
            margin=(0, 0, 0, 0))
# title = figure(height=40, sizing_mode='stretch_width')
# title.add_layout(Title(text=FIG_TITLE, align="center"))

### Layout 1
_stretch_w = dict(sizing_mode='stretch_width')

# Left panel: horizontal KDE on top, scatter plot with vertical KDE below
top = Row(hkde_plot, spacer, **_stretch_w)
middle = Row(scatter_plot, vkde_plot, **_stretch_w)
lplots = Column(top, middle, **_stretch_w)

if not TWO_PLOTS:
    # Single panel: legend sits between the two axis selectors, with side borders
    controls = Row(yaxis_select, dum_fig, xaxis_select, **_stretch_w)
    layout = Row(lborder, Column(title, lplots, controls, **_stretch_w), rborder)
else:
    # Two panels side by side, with one shared row of controls underneath
    rtop = Row(rhkde_plot, spacer, **_stretch_w)
    rmiddle = Row(rscatter_plot, rvkde_plot, **_stretch_w)
    rplots = Column(rtop, rmiddle, **_stretch_w)
    controls = Row(yaxis_select, xaxis_select, dum_fig, ryaxis_select, rxaxis_select,
                   **_stretch_w)
    layout = Column(title, Row(lplots, rplots, **_stretch_w), controls, **_stretch_w)

show(layout)