In [None]:
import os
# Sets Bokeh's BOKEH_VALIDATE_DOC environment variable to "0". This lives in
# its own cell ahead of the cell that imports Bokeh.
# NOTE(review): presumably the variable has to be set before Bokeh is first
#   imported/used for it to take effect -- confirm against the Bokeh settings docs.
os.environ['BOKEH_VALIDATE_DOC'] = "0"

In [None]:
import os
import yaml
import math
import fnmatch
import numpy as np
import pandas as pd
import panel as pn
from copy import deepcopy
from textwrap import dedent
from string import Formatter
from ruamel import yaml as ruyaml
from itertools import chain, product
from collections import defaultdict, Counter

import bokeh
# Don't remove the stack import line; for some buggy 
# reason Bokeh stacked bars don't work properly without it!
from bokeh.io import curdoc
from bokeh.transform import stack
from bokeh.plotting import figure
from bokeh.transform import linear_cmap
from bokeh.models import RadioButtonGroup
from bokeh.models import CustomJSTickFormatter
from bokeh.models import TabPanel, Tabs, Tooltip
from bokeh.models import Select, Div, BoxZoomTool
from bokeh.models import Slider, CategoricalSlider
from bokeh.layouts import layout, row, column, gridplot, grid
from bokeh.models import ColumnDataSource, CDSView, IndexFilter, HoverTool, CustomJSHover, BasicTicker

from partnn.io_cfg import cache_dir
from partnn.io_utils import eval_formula, deep2hie
from partnn.io_utils import resio, drop_unqcols, hie2deep, filter_df, CartDF 
from partnn.io_utils import decomp_df, save_h5data, load_h5data, parse_refs, get_subdict
from bokeh.core.properties import without_property_validation

In [None]:
# Output locations of this notebook: the main working directory and its
# `supplement` sub-directory. Both are created when missing.
workdir = './10_bokeh'
suppdir = workdir + '/supplement'
os.makedirs(workdir, exist_ok=True)
os.makedirs(suppdir, exist_ok=True)

# The CLI-invoked dashboard theme should be light to be able to differentiate the colors and blobs
doc_theme, background_color, header_color = 'light_minimal', 'white', 'black'

In [None]:
# The default theme inside the jupyter notebook should be dark (less differentiation between blobs).
# Note: running this cell replaces the light theme values set in the previous cell.
pn.config.theme = 'dark'
pn.extension('katex', theme='dark')
doc_theme, background_color, header_color = 'dark_minimal', 'black', 'white'

In [None]:
class Renamer:
    """Two-way lookup between original names and their renamed counterparts.

    Parameters
    ----------
    rnm_dict : dict, optional
        Mapping of original name -> new name. When omitted, an empty
        mapping is used and every lookup returns its input unchanged.

    Attributes are `dict` (the forward mapping, stored as given) and
    `invdict` (the reversed mapping). Construction asserts that no two
    original names share the same target name.
    """

    def __init__(self, rnm_dict=None):
        if rnm_dict is None:
            rnm_dict = dict()
        self.dict = rnm_dict
        # Building the reverse mapping; clashing targets collapse into one entry
        self.invdict = dict()
        for orig_name, new_name in self.dict.items():
            self.invdict[new_name] = orig_name
        assert len(self.dict) == len(self.invdict), dedent(f'''
            Conflicting target names found:
                Renaming Dictionary: {self.dict}''')

    def encode(self, key):
        """Return the renamed version of `key`, or `key` itself if it has no rename."""
        return self.dict[key] if key in self.dict else key

    def decode(self, key):
        """Return the original name behind `key`, or `key` itself if it is not a rename."""
        return self.invdict[key] if key in self.invdict else key

class ColorManager:
    """Assigns a color to every label and remembers the assignment.

    Parameters
    ----------
    dflt_colors : dict
        Mapping of `fnmatch`-style label patterns to a color (either a key
        of `color_spec` or a literal color value).
    color_spec : dict
        Mapping of color names to color values; its values also form the
        palette used for labels that match no default pattern.
    """

    def __init__(self, dflt_colors, color_spec):
        self.dflt_colors = dflt_colors
        self.color_spec = color_spec
        self.color_palette = list(color_spec.values())
        # label -> color assignments made so far
        self.dict = dict()

    def __call__(self, lbl_name):
        """Return the color of `lbl_name`, assigning one on first use."""
        assigned = self.dict.get(lbl_name, None)
        if assigned is not None:
            return assigned

        # First choice: the first default pattern matching the label
        for pattern, clr_name in self.dflt_colors.items():
            if not fnmatch.fnmatch(lbl_name, pattern):
                continue
            assigned = self.color_spec.get(clr_name, clr_name)
            self.dict[lbl_name] = assigned
            break
        if assigned is not None:
            return assigned

        # Fallback: the palette color that has been handed out the least
        usage = Counter(self.dict.values())
        assigned = min(self.color_palette, key=lambda clr: usage[clr])
        self.dict[lbl_name] = assigned
        return assigned

def scifrmt(val, latex=True):
    """Format a number in one-decimal scientific notation.

    Parameters
    ----------
    val : number
        The value to format (anything supporting the `1.1e` format spec).
    latex : bool
        When True, the output is a `$$...$$` LaTeX expression of the form
        `a \\times 10^{b}`; otherwise the plain `1.1e`-style text is returned.

    Returns
    -------
    str
        The formatted value. A leading zero of a negative exponent is
        dropped (e.g., `1.0e-05` becomes `1.0e-5`). Non-finite values have
        no exponent part; in LaTeX mode they are rendered as `\\infty`,
        `-\\infty`, or `\\mathrm{nan}` instead of raising.
    """
    val2 = f'{val:1.1e}'.replace("e-0", "e-")
    if not latex:
        return val2

    aa, sep, bb = val2.partition('e')
    if not sep:
        # 'inf', '-inf' and 'nan' carry no exponent to split on
        specials = {'inf': r'\infty', '-inf': r'-\infty'}
        body = specials.get(val2, fr'\mathrm{{{val2}}}')
        return f'$${body}$$'
    return fr'$${aa}\times 10^{{{bb}}}$$'

def get_ychemgrp(i_fig, fig_name, n_framerows, n_framecols):
    """Return the chemical name encoded as the last '/'-separated part of `fig_name`.

    The remaining arguments are accepted but not used. The extracted name
    is asserted to be one of the known chemical names (or 'every').
    """
    chem_name = fig_name.rsplit('/', 1)[-1]
    assert chem_name in ('SO4', 'NO3', 'Cl', 'NH4', 'ARO1', 'ARO2', 'ALK1', 
        'OLE1', 'API1', 'Na', 'OIN', 'OC', 'BC', 'MOC', 'H2O', 'every')
    return chem_name

# JavaScript source shared by the Bokeh tick and hover formatters created
# further below. It formats the variable `tick` in engineering notation:
#   * NaN, zero and non-numeric ticks are returned unchanged.
#   * The base-10 exponent is rounded down to a multiple of three and used
#     to look up a suffix in `pstfxs`; exponents below `minpstfx` or above
#     `maxpstfx` fall back to the smallest/largest available suffix.
#   * When `precision` is null, the number of decimals is 2, 1 or 0 depending
#     on whether the mantissa is below 10, below 100, or larger.
# The names `pstfxs`, `precision`, `minpstfx` and `maxpstfx` are expected to
# be supplied by the formatter's `args`.
code_engfmttr = '''
    if (Number.isNaN(tick)) {
        console.log("The tick value is NaN: " + tick);
        return tick;
    } else if (tick == 0) {
        // console.log("The tick value is zero: " + tick);
        return tick;
    } else if (!(!isNaN(tick))) {
        console.log("The tick value is not a number: " + tick);
        return tick;
    }
    
    let tick_abs = Math.abs(tick);
    let tick_log10 = Math.log10(tick_abs);
    let tick_log10d3 = 3 * Math.floor(tick_log10 / 3);

    if (Number.isNaN(tick_log10)) {
        console.log("The abs tick value is non zero yet the log abs is nan: tick=" + tick + ", logabs: " + tick_log10);
        return tick;
    }

    console.assert(tick_log10 >= tick_log10d3, 
        "tick_log10=" + tick_log10 + " is not greater than tick_log10d3=" + tick_log10d3);

    if (tick_log10d3 in pstfxs) {
        var tick_pstfx = pstfxs[tick_log10d3];
        var tick_log10_rmng = tick_log10 - tick_log10d3;
    } else if (tick_log10d3 < minpstfx) {
        var tick_pstfx = pstfxs[minpstfx];
        var tick_log10_rmng = tick_log10 - minpstfx;
    } else if (tick_log10d3 > maxpstfx) {
        var tick_pstfx = pstfxs[maxpstfx];
        var tick_log10_rmng = tick_log10 - maxpstfx;
    } else {
        throw new Error("Odd case happened! " + tick_log10d3);
    }

    let tick_prfx = Math.pow(10, tick_log10_rmng);
    let precision2 = (precision === null) ? (tick_prfx < 10) ? 2 : ((tick_prfx < 100) ? 1 : 0) : precision;
    let tick_eng = " " + parseFloat(Math.sign(tick) * tick_prfx).toFixed(precision2) + tick_pstfx;
    return tick_eng;
'''

# Instantiating the Pre-Processors

In [None]:
import json
import torch
import netCDF4
from ruamel import yaml as ruyaml
from partnn.io_cfg import data_dir
from partnn.tch_utils import BatchRNG
from partnn.notebooks.n09_utils import make_pp, get_snrtspltidxs
from partnn.io_utils import get_subdict, preproc_cfgdict, parse_refs

# Reading the experiment config; the parser is chosen from the file extension.
# NOTE(review): the module-level `safe_load` of ruamel.yaml is deprecated in
#   recent releases -- confirm it is available in the pinned version.
json_cfgpath = '../configs/02_adhoc/08_klstudy.yml'
if not json_cfgpath.endswith(('.json', '.yml')):
    raise RuntimeError(f'unknown config extension: {json_cfgpath}')
with open(json_cfgpath, 'r') as fp:
    if json_cfgpath.endswith('.json'):
        json_cfgdict = json.load(fp, object_pairs_hook=dict)
    else:
        json_cfgdict = dict(ruyaml.safe_load(fp))

# Overriding the I/O options for this notebook
json_cfgdict['io/config_id'] = '08_klstudy_00'
json_cfgdict['io/results_dir'] = './09_vaehist/results'
json_cfgdict['io/storage_dir'] = './09_vaehist/storage'
if torch.cuda.is_available():
    json_cfgdict['io/tch/device'] = 'cuda:0'
else:
    json_cfgdict['io/tch/device'] = 'cpu'

# Applying the looping processes, and keeping the first config only
all_cfgdicts1 = preproc_cfgdict(json_cfgdict)[:1]

# Parsing all the "/ref"-, "/def"-, and "/pop"-ending keys
all_cfgdicts = [
    parse_refs(cfgdict, trnsfrmtn='hie', pathsep=' -> ')
    for cfgdict in all_cfgdicts1
]

# Dropping the trailing references section
all_refdicts = [
    get_subdict(cfgdict, prefix='refs', pop=True)
    for cfgdict in all_cfgdicts
]

# The single config used by the rest of the notebook
cfg_dict_input = all_cfgdicts[0]

In [None]:
# Working on a (shallow) copy so that `cfg_dict_input` keeps all of its keys
# while the options are popped one by one below.
cfg_dict = cfg_dict_input.copy()
rng_seed_list = cfg_dict.pop('rng_seed/list')
pp_mdlscfg = get_subdict(cfg_dict, 'pp/modules', pop=True)
ppcri_type = cfg_dict.pop('cri/rcnst/pp/type')
#########################################################
###################### Data Options #####################
#########################################################
# Note: `data_cfg` is taken with `pop=False`, i.e., before the individual
#   'data/...' keys are popped in the next lines.
data_cfg = get_subdict(cfg_dict, 'data', pop=False)
data_tree = cfg_dict.pop('data/tree')
data_chems = cfg_dict.pop('data/chems', None)
# The starting (default 0) and ending (default None) histogram bin indices
data_binsi1 =  cfg_dict.pop('data/bins/idx/start', 0)
data_binsi2 =  cfg_dict.pop('data/bins/idx/end', None)
split_cfg = get_subdict(cfg_dict, 'split', pop=True)

# The torch device and dtype; only the dtype names listed in `name2dtype`
# are supported (any other name raises a KeyError).
device_name = cfg_dict.pop('io/tch/device')
dtype_name = cfg_dict.pop('io/tch/dtype')
name2dtype = dict(float64=torch.double,
    float32=torch.float32,
    float16=torch.float16)
tch_device = torch.device(device_name)
tch_dtype = name2dtype[dtype_name]

# NOTE(review): presumably the evaluation batch size -- confirm at its use site.
eval_bs = 4096

In [None]:
#########################################################
#################### Loading the Data ###################
#########################################################
# The data is read from a single netCDF file located under `data_dir`.
data_path = f'{data_dir}/{data_tree}'
assert data_path.endswith('.nc')
with netCDF4.Dataset(data_path, "r") as fp:
    ################# The Array Dimensions #################
    # The number of scenarios
    n_snr = fp.dimensions['n_snr'].size
    # The number of time steps in each scenario
    n_t = fp.dimensions['n_t'].size
    # The number of diameter histogram bins
    k_bins = fp.dimensions['k_bins'].size
    # The number of chemical species
    n_chem = fp.dimensions['n_chem'].size

    #################### The Data Arrays ####################
    # The number of particles within each diameter bin
    n_prthst = fp.variables['n_prthst'][:]
    assert n_prthst.shape == (n_snr, n_t, k_bins)

    # The mass within each diameter bin (the mass histogram)
    m_prthst = fp.variables['m_prthst'][:]
    assert m_prthst.shape == (n_snr, n_t, k_bins)

    # The mass of each chemical species within the diameter bins.
    # The last two axes are swapped so that the bins come last.
    m_chmprthst = np.moveaxis(fp.variables['m_chmprthst'][:], -1, -2)
    assert m_chmprthst.shape == (n_snr, n_t, n_chem, k_bins)

    ####################### Meta-Data #######################
    # The Chemical Species (hard-coded here rather than read from the file)
    chem_species_str  = 'SO4,NO3,Cl,NH4,MSA,ARO1,ARO2,ALK1,OLE1,API1'
    chem_species_str += ',API2,LIM1,LIM2,CO3,Na,Ca,OIN,OC,BC,MOC,H2O'
    chem_species = chem_species_str.split(',')

    # The histogram bin edges: `k_bins + 1` points that are equally spaced
    # in log10 of the diameter between `diam_low` and `diam_high`.
    # Note: `logdiams` holds base-10 logarithms; multiplying by
    #   `diam_logbase` (= ln 10) converts them back to natural logarithms.
    diam_low, diam_high, diam_logbase = 1e-9, 1e-4, np.log(10)
    logdiam_low, logdiam_high = np.log(diam_low) / diam_logbase, np.log(diam_high) / diam_logbase
    logdiams = np.linspace(logdiam_low, logdiam_high, k_bins + 1, endpoint=True)
    x_histbins = np.exp(logdiams * diam_logbase)
    assert x_histbins.shape == (k_bins + 1,)

# Revising the chemical species if necessary: when `data_chems` lists a subset,
# only those species are kept (in the order given by `data_chems`).
if data_chems not in (None, 'all'):    
    i_keepchems = [chem_species.index(chem) for chem in data_chems]
    chem_species = [chem_species[i_chm] for i_chm in i_keepchems]
    chem_species_str = ','.join(chem_species)

    n_chem = len(chem_species)
    m_chmprthst = m_chmprthst[:, :, i_keepchems, :]
    assert m_chmprthst.shape == (n_snr, n_t, n_chem, k_bins)

    # The total mass is re-computed as the sum over the kept species
    m_prthst = m_chmprthst.sum(axis=-2)
    assert m_prthst.shape == (n_snr, n_t, k_bins)

    # Note: the particle counts in `n_prthst` are left untouched here.
    # # number of particles must be re-calculated from the full data
    # n_prthst = np.full_like(n_prthst, -1)
    # assert n_prthst.shape == (n_snr, n_t, k_bins)

# Revising the histogram bins if necessary: only the bins with indices in
# `[data_binsi1, data_binsi2)` are kept, and the bin edges are updated to match.
data_binsi2 = k_bins if data_binsi2 is None else data_binsi2
if (data_binsi1, data_binsi2) != (0, k_bins):
    i_keepbins = list(range(data_binsi1, data_binsi2))
    k_bins = len(i_keepbins)

    # The number of particles within each diameter bin
    n_prthst = n_prthst[:, :, i_keepbins]
    assert n_prthst.shape == (n_snr, n_t, k_bins)

    # The mass within each diameter bin (the mass histogram)
    m_prthst = m_prthst[:, :, i_keepbins]
    assert m_prthst.shape == (n_snr, n_t, k_bins)

    # The mass of each chemical species within the diameter bins
    m_chmprthst = m_chmprthst[:, :, :, i_keepbins]
    assert m_chmprthst.shape == (n_snr, n_t, n_chem, k_bins)

    # The histogram bin ranges of value: keeping the edges of the kept bins
    # (one more edge than bins).
    logdiams = logdiams[i_keepbins + [i_keepbins[-1] + 1]]
    assert logdiams.shape == (k_bins + 1,)
    logdiam_low, logdiam_high = logdiams[0], logdiams[-1]
    # `logdiams` are base-10 logarithms, so they must be scaled by
    # `diam_logbase` (= ln 10) before exponentiating, exactly like
    # `x_histbins` below. `diam_logbase` itself is unchanged.
    diam_low = np.exp(logdiam_low * diam_logbase)
    diam_high = np.exp(logdiam_high * diam_logbase)
    x_histbins = np.exp(logdiams * diam_logbase)
    assert x_histbins.shape == (k_bins + 1,)
    
#######################################
#   Making the Data Variables Dict    #
#######################################
# `datanp_dict` is an input variable (str) to array (np.ndarray) mapping.
# The scenario and time axes are flattened into a single leading axis, and
# the total count/mass arrays get a singleton channel axis.
datanp_dict = dict(
    n_prthst=n_prthst.reshape(n_snr * n_t, 1, k_bins),
    m_prthst=m_prthst.reshape(n_snr * n_t, 1, k_bins),
    m_chmprthst=m_chmprthst.reshape(n_snr * n_t, n_chem, k_bins))

# `data_dict` is an input variable (str) to tensor (torch.tensor) mapping.
data_dict = {key: torch.from_numpy(varnp).to(device=tch_device, dtype=tch_dtype)
            for key, varnp in datanp_dict.items()}

# Preparing the data hash for later caching.
# Note: every variable maps to the same `data_cfg` object.
hash_data = {key: data_cfg for key in data_dict}

# `node2chnlnames` is an input variable to labels (list) mapping.
node2chnlnames = {'n_prthst': ['total number'],
    'm_prthst': ['total mass'],
    'm_prthstlog': ['log total mass'],
    'm_chmprthst': chem_species_str.split(','),
    'm_chmprtnrm': chem_species_str.split(','),
    'm_chmhst': chem_species_str.split(',')}

# The per-sample dimensions of each variable (the leading axis is dropped).
# Example:
#   data_dims == {
#       'n_prthst': (1, k_bins),
#       'm_prthst': (1, k_bins),
#       'm_chmprthst': (n_chem, k_bins)
#   }
data_dims = {key: tuple(varnp.shape[1:]) for key, varnp in datanp_dict.items()}

# The shape variables
shapevars = dict(n_snr=n_snr, n_t=n_t, n_chem=n_chem, k_bins=k_bins)

In [None]:
#########################################################
########### Constructing the Batch RNG Object ###########
#########################################################
# One random stream per seed in `rng_seed_list`
n_seeds = len(rng_seed_list)
rng_seeds = np.array(rng_seed_list)
rng = BatchRNG(shape=(n_seeds,), lib='torch',
    device=tch_device, dtype=tch_dtype,
    unif_cache_cols=1_000_000,
    norm_cache_cols=5_000_000)
rng.seed(np.broadcast_to(rng_seeds, rng.shape))

#########################################################
######## Instantiating Train/Test Split Indices #########
#########################################################
split_vars = split_cfg['vars']
# NOTE(review): `pop_opts=True` presumably removes the consumed options from
#   `split_cfg` -- confirm in `get_snrtspltidxs`.
split_idxdict = get_snrtspltidxs(split_cfg, n_snr, n_t, n_seeds,
    rng, tch_device, pop_opts=True)
# The training indices per seed; the test set is whatever remains of the
# `n_snr * n_t` samples.
trn_spltidxs = split_idxdict['split_idxs']
n_trn = trn_spltidxs.shape[-1]
n_tst = n_snr * n_t - n_trn
assert trn_spltidxs.shape == (n_seeds, n_trn)
tst_spltidxs = split_idxdict['negsplit_idxs']
assert tst_spltidxs.shape == (n_seeds, n_tst)

# The Dashboard Config

In [None]:
dash_type = None

# The chemical names shown in the dashboard.
# Note: this overrides the `n_chem` computed while loading the data.
allchem_names = 'SO4,NO3,Cl,NH4,ARO1,ARO2,ALK1,OLE1,API1,Na,OIN,OC,BC,MOC,H2O'.split(',')
n_chem = len(allchem_names)
# A smaller subset of the chemicals (currently all of them)
chem_names = allchem_names
chem_name2idxs = dict(zip(chem_names, map(allchem_names.index, chem_names)))

# The number of seeds, epochs, time steps and histogram bins in the experiments.
# Note: these hard-coded values override the ones derived from the data above.
n_seeds = 16
n_epochs = 2
n_t = 25
n_bins = 20

# The histogram bin edges: equally spaced in log10 of the diameter
diam_low = 1e-9
diam_high = 1e-4
diam_logbase = np.log(10)
logdiam_low = np.log(diam_low) / diam_logbase
logdiam_high = np.log(diam_high) / diam_logbase
logdiams = np.linspace(logdiam_low, logdiam_high, n_bins + 1, endpoint=True)
x_histbins = np.exp(logdiams * diam_logbase)
assert logdiams.shape == (n_bins + 1,)
assert x_histbins.shape == (n_bins + 1,)

# Dropping the first histogram bin (`logdiams` keeps all of its edges)
x_histbins = x_histbins[1:]
n_bins = n_bins - 1
assert x_histbins.shape == (n_bins + 1,)

In [None]:
# Reading the dashboard config from the working directory
dash_cfgpath = f'{workdir}/07_frames.yml'
with open(dash_cfgpath, "r") as fp:
    dash_cfgdict1 = dict(ruyaml.safe_load(fp))

# NOTE(review): `deep2hie` presumably flattens the nested config into
#   '/'-separated keys -- confirm in `partnn.io_utils`.
dash_cfgdict2 = deep2hie(dash_cfgdict1)
    
# Parsing all the "/ref"-, "/def"-, and "/pop"-ending keys
dash_cfgdict3 = parse_refs(dash_cfgdict2, trnsfrmtn='hie', pathsep=' -> ')

# The sections are popped out of this copy one at a time, here and in the
# following cells.
dash_cfgdict = dash_cfgdict3.copy() 
allframe_infoshie = get_subdict(dash_cfgdict, prefix='frames', pop=True)
# String values that mention any of the `tnsvars` names are treated as
# formulas and replaced by the result of `eval_formula`.
# Note: the check is a plain substring test on the value.
tnsvars = {'chem_names': chem_names, 'x_histbins': x_histbins, 'get_ychemgrp': get_ychemgrp}
for key, val in list(allframe_infoshie.items()):
    if isinstance(val, str) and any(var in val for var in tnsvars):
        allframe_infoshie[key] = eval_formula(val, tnsvars, catch=True)
allframe_infos = hie2deep(allframe_infoshie, sep='/')

# Keeping only the frames whose names match none of the excluded patterns.
# pop: [chembars, heatmap1, heatmap2, heatmap3, heatmap4, heatmap5, sctrplts]
frame_infos = {frame_name: frame_info for frame_name, frame_info in allframe_infos.items() 
    # if frame_info['type'] in ('bar',)
    if not any (fnmatch.fnmatch(frame_name, pat) for pat in 
        ['heatmap*', 'sctrplts'])}

# The available colors for the dashboard to choose from
color_spec = get_subdict(dash_cfgdict, prefix='color/spec', pop=True)

# Default colors for different labels
dflt_colors = get_subdict(dash_cfgdict, prefix='color/defaults', pop=True)

# The column name and values renaming
rnm_dict = get_subdict(dash_cfgdict, prefix='renames', pop=True)

In [None]:
# The layout direction; only 'col' and 'row' are handled below
layout_dirctn = 'col'

# The controls section of the dashboard config. String values mentioning
# `list` or `range` are evaluated as formulas (plain substring test).
ctrls_cfghie = get_subdict(dash_cfgdict, prefix='ctrls', pop=True)
tnsvars = {'list': list, 'range': range}
for key, val in list(ctrls_cfghie.items()):
    if isinstance(val, str) and any(var in val for var in tnsvars):
        ctrls_cfghie[key] = eval_formula(val, tnsvars, catch=True)
ctrls_cfg = hie2deep(ctrls_cfghie, sep='/')

# The options and the widget type of each control column
ctrl_opts = {col: colspec['opts'] for col, colspec in ctrls_cfg['spec'].items()}
ctrl_types = {col: colspec['type'] for col, colspec in ctrls_cfg['spec'].items()}

# The number of control rows and columns; the grid must be able to hold
# all of the controls.
n_ctrlrows, n_ctrlcols = ctrls_cfg['nrows'], ctrls_cfg['ncols']
assert len(ctrl_types) <= (n_ctrlrows * n_ctrlcols)
# The control margins
m_ctrltop, m_ctrlright, m_ctrlbottom, m_ctrlleft = (5, 15, 5, 15)
# The control sizes. In the 'col' layout, a total width of 1700 (minus the
# horizontal margins between the controls) is shared equally among the columns.
ctrl_height = 50
if layout_dirctn == 'col':
    ctrl_width = (1700 - (m_ctrlright + m_ctrlleft) * (n_ctrlcols - 1)) // (n_ctrlcols)
elif layout_dirctn == 'row':
    ctrl_width = 400
else:
    raise ValueError(f'undefined layout_dirctn={layout_dirctn}')

# Categorical slider formatter
strfmt_catslider = '{val}'

# The keyword arguments shared by the slider widgets
slider_kwargs = dict(width=ctrl_width, height=ctrl_height, min_height=ctrl_height, 
    margin=(m_ctrltop, m_ctrlright, m_ctrlbottom, m_ctrlleft), sizing_mode='fixed')

In [None]:
# Validating the control layout and padding it into a full
# `n_ctrlrows` x `n_ctrlcols` grid (empty slots are `None`).
# Note: `ctrl_layout` is the very list stored in `ctrls_cfg`, so the
#   padding below also shows up there.
ctrl_layout = ctrls_cfg['layout']

# The controls that still have to be placed in the layout
ctrl_optscp = set(ctrl_opts)
for i_rowctrl, ctrl_layrow in enumerate(ctrl_layout):
    for i_colctrl, ctrl_col in enumerate(ctrl_layrow):
        assert i_rowctrl < n_ctrlrows, dedent(f'''
            Control Layout Row Overflow:
                ctrl_col: {ctrl_col}
                Row Index: {i_rowctrl}
                Number of Rows: {n_ctrlrows}''')
        assert i_colctrl < n_ctrlcols, dedent(f'''
            Control Layout Col Overflow:
                ctrl_col: {ctrl_col}
                Col Index: {i_colctrl}
                Number of Cols: {n_ctrlcols}''')
        if ctrl_col is None:
            continue
        assert ctrl_col in ctrl_optscp, dedent(f'''
            The "{ctrl_col}" ctrl exists in the layout but 
            is either (1) not defined in the options, or 
            (2) defined multiple times in the layout:
                i_rowctrl: {i_rowctrl}
                i_colctrl: {i_colctrl}''')
        ctrl_optscp.remove(ctrl_col)

    # Appending some None values to complete the row
    ctrl_layrow += [None] * (n_ctrlcols - len(ctrl_layrow))

# Appending the missing rows. Each padding row is built separately;
# `[[None] * n] * k` would repeat one shared list `k` times, so that
# changing one padded row would change all of them.
ctrl_layout += [[None] * n_ctrlcols for _ in range(n_ctrlrows - len(ctrl_layout))]

assert len(ctrl_optscp) == 0, dedent(f'''
    Some controls were not specified in the control layout:
        unspecified controls: {ctrl_optscp}''')

In [None]:
# The set of columns needed to specify the data starting and ending indices
# Note: You can substitute `fpidx` with `list(ctrl_types)` as each set 
#   of hyper-parameters is supposed to define a unique `fpidx`.
# Note: These columns may be over-specified, and the same data could be 
#   shared across some of them. If you need to update or tailor them to 
#   your specific needs, feel free to use your judgement.
glyph_spcfyngcols = list(ctrl_opts) + ['v_node', 'v_varname']

# The glyph to settings mapping. String values mentioning any of the
# `tnsvars` names are evaluated as formulas (plain substring test).
glyph_infoshie = get_subdict(dash_cfgdict, prefix='glyphs', pop=True)
tnsvars = {'glyph_spcfyngcols': glyph_spcfyngcols, 'chem_names': chem_names, 
    'x_histbins': x_histbins, 'chem_name2idxs': chem_name2idxs, 'n_bins': n_bins, 'n_chem': n_chem}
for key, val in list(glyph_infoshie.items()):
    if isinstance(val, str) and any(var in val for var in tnsvars):
        glyph_infoshie[key] = eval_formula(val, tnsvars, catch=True)
glyph_infos = hie2deep(glyph_infoshie, sep='/')

# The mapping of each glyph type to the required `v_repr`s in the data
glyph_type2vreprs = {'sct': ['pnts'], 'blb': ['mu', 'sig', 'phi'], 
    'hm': ['pnts'], 'vbs': ['pnts'], 'vas': ['pnts']}

# The general inclusion/exclusion dictionaries (both disabled here)
gnrl_incdict = None
gnrl_excdict = None

# The maximum number of streamed rows to the client before we flush the entire data so far
max_strmdrows = 1_000_000

In [None]:
# The data-frame cartesian components: each key below is one component
# with its list of possible values.
hpcdf0_data = dict()
hpcdf0_data['v_node'] = ['m_chmprthst', 'm01', 'm02', 'm03', 'm04', 'm05', 'm06', 'm06__NRMLZD', 
    'm07', 'm08', 'm09', 'n_prthst', 'n01', 'n02', 'n03', 'n04']
hpcdf0_data['chems'] = ['every'] + chem_names
hpcdf0_data['v_varname'] = ['orig', 'rcnst']
hpcdf0_data['v_repr'] = ['pnts']
# The control columns and their options are components as well
hpcdf0_data.update(ctrl_opts)

# NOTE(review): `CartDF` presumably represents the cartesian product of its
#   components without materializing it -- confirm in `partnn.io_utils`.
hpcdf0 = CartDF({key: {key: vals} for key, vals in hpcdf0_data.items()})
hpcdf4 = hpcdf0.copy(deep=True)

In [None]:
# Adding the Glyph Name: for every frame, the relevant part of `hpcdf4` is
# selected, checked for uniqueness, tagged with a `glyph_name` column, and
# stored under `frame_info['hpcdf']`.
for frame_name, frame_info in frame_infos.items():
    frame_type = frame_info['type']
    hpcdf4frm = hpcdf4.select(frame_info.get('inc', None), frame_info.get('exc', None), 
        copy='shallow', reset_index=False, has_wildcards=True)

    # The figure division columns inside the frame
    fig_vizcols = frame_info['fig_vizcols']
    # The hue division columns inside the frame
    hue_vizcols = frame_info['hue_vizcols']

    idcols = list(ctrl_types) + list(hue_vizcols) + list(fig_vizcols)
    assert set(idcols).issubset(set(hpcdf4frm.columns)), dedent(f'''
        Missing columns in the frame database:
            {set(idcols) - set(hpcdf4frm.columns)}''')
    
    # Making sure the idcols define a unique row within each component
    for compid, compdf in hpcdf4frm.data.items():
        idcols_comp = [col for col in idcols if col in compdf.columns]
        hpcdf4frm_dups = compdf[idcols_comp].duplicated()
        if hpcdf4frm_dups.any():
            # Raising an informative error about duplication groups in 
            # Pandas is unnecessarily complicated...
            dup_groups = compdf.groupby(idcols_comp, sort=False, observed=True)\
                .filter(lambda x: len(x) > 1)\
                .groupby(idcols_comp, group_keys=False, sort=False, observed=True)
            raise ValueError(dedent(f'''
                The control columns do not define a unique fpidx all the time. 
                You may have under-specified them:
                    Group ID columns: {idcols_comp}
                    Sample duplicated rows: {dup_groups.first()}'''))

    # Adding the glyph_name column to hpdf
    # Note: If you need some data to be used in multiple glyphs, make sure 
    #       to duplicate its row with various `glyph_name` values.
    vrepr2glyph = frame_info['vrepr2glyph']
    vrepr_comp = hpcdf4frm.get_comp('v_repr')
    # Note: `hpcdf_framedata` is `hpcdf4frm.data` itself (not a copy), so the
    #   component replacement below modifies `hpcdf4frm` as well.
    hpcdf_framedata = hpcdf4frm.data
    vrepr_compdf = hpcdf_framedata[vrepr_comp].copy(deep=True)
    vrepr_compdf['glyph_name'] = [vrepr2glyph[v_repr] for v_repr in vrepr_compdf['v_repr']]
    vrepr_compdf['glyph_name'] = vrepr_compdf['glyph_name'].astype('category')
    hpcdf_framedata[vrepr_comp] = vrepr_compdf
    # NOTE(review): the next line repeats the assignment made three lines above
    #   on the same data-frame object, which replaces the categorical column
    #   with a plain one again. Either this line or the `astype('category')`
    #   cast looks unintended -- confirm which dtype the downstream code expects.
    vrepr_compdf['glyph_name'] = [vrepr2glyph[v_repr] for v_repr in vrepr_compdf['v_repr']]
    hpcdf_frame = CartDF(hpcdf_framedata)
    
    frame_info['hpcdf'] = hpcdf_frame

# Bokeh Figures

In [None]:
# Applying the renaming business: `renamer` maps raw names to display names
renamer = Renamer(rnm_dict)
# Instantiating the color manager (note the spelling of the variable name)
colormanger = ColorManager(dflt_colors, color_spec)

# The bokeh figures collection: figure name -> figure information dictionary.
# It is reset every time this cell runs.
fig_infos = dict()

def get_figtype(fig_name, frame_type, dash_type):
    """Return the figure type, which determines the axis labels, etc.

    Currently the figure type is simply the type of the frame holding the
    figure; `fig_name` and `dash_type` are accepted but not used.
    """
    fig_type = frame_type
    return fig_type

# Registering the figures of every frame in `fig_infos`, resolving the frame
# grid dimensions, and deciding on each figure's size and color bar.
for frame_name, frame_info in frame_infos.items():
    # The frame type
    frame_type = frame_info['type']
    # The figure division columns inside the frame
    fig_vizcols = frame_info['fig_vizcols']
    # The figure name formatter string
    fig_namfmt = frame_info['fig_namfmt']
    # The figure title formatter string
    fig_ttlfmt = frame_info['fig_ttlfmt']
    # The frame cartesian dataframe
    hpcdf_frame = frame_info['hpcdf']

    frame_fignames = []
    for i_fig, (fig_vizvals, hpcdf_fig) in enumerate(hpcdf_frame.groupby(fig_vizcols)):
        # Example:
        #   fpidx = 02_adhoc/09_klstudy.0.0
        #   eid = 'identity'
        #   vid in ('1x1xtnse', '1x1zpcalr', '1x1ztsne', ...)
        #   v_space in ('x', 'z')
        #   v_node in ('m_chmhist', ...)
        #   v_split in ('train', 'test', 'normal', ...)
        #   v_varname in ('orig', 'rcnst', 'genr', 'ltnt', ...)
        #   v_repr in ('pnts', 'mu', 'sig', 'phi')
        #   i_seed in (0, 1, ..., n_seeds - 1)
        #   i_epoch in (0, 1, ..., n_epochs - 1)
        #   i_time in (0, 1, ..., n_t - 1)
        fig_vizspec = dict(zip(fig_vizcols, fig_vizvals))
        fig_vizspec['frame_name'] = frame_name
        
        # The figure id/name (built from the raw values)
        fig_name = fig_namfmt.format(**fig_vizspec)
        # The formal figure title (built from the renamed values)
        fig_title = fig_ttlfmt.format(**{col: renamer.encode(val) 
            for col, val in fig_vizspec.items()})

        assert fig_name not in fig_infos, dedent(f'''
            The "fig_namfmt" is under-specific. I found two figures with clashing names:
                fig_namfmt = {fig_namfmt}
                fig_vizcols = {fig_vizcols}
                fig_vizvals = {fig_vizvals}''')

        # Getting the figure type (e.g., 'scatter', 'bar', etc.)
        fig_type = get_figtype(fig_name, frame_type, dash_type)

        fig_infos[fig_name] = {'title': fig_title, 'type': fig_type, 
            'hpcdf': hpcdf_fig, 'frame': frame_name}
        
        frame_fignames.append(fig_name)

    # Registering the figure names associated with this frame
    frame_info['fig_names'] = frame_fignames

    # Adjusting the number of frame rows and columns: the unspecified
    # dimension is derived from the specified one and the number of figures.
    n_framerows1, n_framecols1 = frame_info['nrows'], frame_info['ncols']
    n_framefigs = len(frame_fignames)
    if n_framerows1 is None:
        assert n_framecols1 is not None
        n_framecols2 = min(n_framecols1, n_framefigs)
        n_framerows2 = math.ceil(n_framefigs / n_framecols1)
    elif n_framecols1 is None:
        # Note: the configured `n_framerows1` is the input here; the
        #   resolved `n_framerows2` does not exist yet at this point.
        n_framerows2 = min(n_framerows1, n_framefigs)
        n_framecols2 = math.ceil(n_framefigs / n_framerows1)
    else:
        # Both dimensions are given. This is also the state of `frame_info`
        # when this cell is re-run, since the resolved values are written
        # back below. The grid is kept as is provided that it can hold all
        # of the figures.
        if n_framerows1 * n_framecols1 < n_framefigs:
            raise ValueError(f'the {n_framerows1}x{n_framecols1} frame grid '
                f'cannot hold {n_framefigs} figures in frame "{frame_name}"')
        n_framerows2, n_framecols2 = n_framerows1, n_framecols1
    frame_info['nrows'], frame_info['ncols'] = n_framerows2, n_framecols2
    n_framerows, n_framecols = frame_info['nrows'], frame_info['ncols']

    # Determining the figure width and height and color-bar's existence
    for i_fig, fig_name in enumerate(frame_fignames):
        fig_info = fig_infos[fig_name]
        fig_type = fig_info['type']
        fig_title = fig_info['title']
        # A deep copy, so that popping options does not alter the frame config
        fig_kwargs = deepcopy(frame_info['figkwa'])
        # Whether the figure has color bar or not: only the heat-maps in
        # the last column of the frame grid get one.
        fig_hascb = (fig_type == 'hmap') and ((i_fig + 1) % n_framecols == 0)
        fig_width = fig_kwargs.pop('width', 600)
        fig_height = fig_kwargs.pop('height', 600)
        # Figures with a color bar are 20% wider to make room for it
        fig_width = int(fig_width * (1 + 0.2 * fig_hascb))
        # Writing down these decisions
        fig_info['width'] = fig_width
        fig_info['height'] = fig_height
        fig_info['has_cb'] = fig_hascb
        fig_info['kwargs'] = fig_kwargs

In [None]:
# Creating the Bokeh figures: first, the engineering-notation formatters.
# Each quantity has a mapping of base-10 exponents (as strings) to unit
# suffixes, from which the arguments of the JavaScript formatter are built.
def _make_engargs(pstfxs, precision):
    """Bundle the suffix table, precision, and smallest/largest exponent keys."""
    return {'pstfxs': pstfxs, 'precision': precision,
        'minpstfx': min(pstfxs, key=float), 'maxpstfx': max(pstfxs, key=float)}

diam_engpstfxs = {'0': 'm', '-3': 'mm', '-6': 'um', '-9': 'nm', '-12': 'pm', '-15': 'fm', '-18': 'am'}
diam_engargs = _make_engargs(diam_engpstfxs, 0)
mass_engpstfxs = {'0': ' kg', '-3': ' g', '-6': ' mg', '-9': ' ug', '-12': ' ng', '-15': ' pg', '-18': ' fg', '-21': ' ag'}
mass_engargs = _make_engargs(mass_engpstfxs, None)
count_engpstfxs = {'0': ' ', '3': ' K', '6': ' M', '9': ' B', '12': ' T', '15': ' Q'}
count_engargs = _make_engargs(count_engpstfxs, None)

# The hover formatters run the same code after copying `value` into `tick`
_code_enghover = 'var tick = value;\n' + code_engfmttr

diam_tickfrmter = CustomJSTickFormatter(code=code_engfmttr, args=diam_engargs)
diam_hoverfrmtr = CustomJSHover(code=_code_enghover, args=diam_engargs)
mass_tickfrmter = CustomJSTickFormatter(code=code_engfmttr, args=mass_engargs)
mass_hoverfrmtr = CustomJSHover(code=_code_enghover, args=mass_engargs)
count_tickfrmter = CustomJSTickFormatter(code=code_engfmttr, args=count_engargs)
count_hoverfrmtr = CustomJSHover(code=_code_enghover, args=count_engargs)

# The formatter lookups by quantity name
name2tckfrmtr = {'mass': mass_tickfrmter, 'diam': diam_tickfrmter, 'count': count_tickfrmter}
name2hvrfrmtr = {'mass': mass_hoverfrmtr, 'diam': diam_hoverfrmtr, 'count': count_hoverfrmtr}

# Creating one Bokeh figure per entry of `fig_infos`, configuring its tools,
# hover tooltips, axis labels and tick formatters, and restricting its data
# to the glyphs that are relevant for the figure type.
for fig_name, fig_info in fig_infos.items():
    fig_title = fig_info['title']
    fig_type = fig_info['type']
    fig_width = fig_info['width']
    fig_height = fig_info['height']
    hpcdf_fig = fig_info['hpcdf']
    # A copy, so that the options popped below stay available in `fig_info`
    fig_kwargs = fig_info['kwargs'].copy()
    # Getting the x and y axis labels (these three options are mandatory)
    x_axis_label = fig_kwargs.pop('x_label')
    y_axis_label = fig_kwargs.pop('y_label')
    hue_label = fig_kwargs.pop('hue_label')
    # The optional formatter names are keys of `name2tckfrmtr`/`name2hvrfrmtr`
    x_tickfrmtrnm = fig_kwargs.pop('x_tick_frmtr', None)
    y_tickfrmtrnm = fig_kwargs.pop('y_tick_frmtr', None)
    hue_tickfrmtrnm = fig_kwargs.pop('hue_tick_frmtr', None)
    x_range_padding = fig_kwargs.pop('x_range_padding', None)

    # Creating the main figure
    zoomtool = BoxZoomTool()
    figtools = [zoomtool, 'reset,pan,wheel_zoom,save']

    # Note: `tools` refers to the `figtools` list itself, so the hover tool
    #   appended to it below is part of the figure's tools as well (unless
    #   the user-provided keyword arguments override `tools`).
    bkfig_kwargs = dict(width=fig_width, height=fig_height, title=fig_title, 
        tooltips=None, output_backend="webgl",
        sizing_mode="stretch_both", tools=figtools)

    # Applying the user-provided figure keyword arguments
    bkfig_kwargs.update(fig_kwargs)
    
    # Configuring the tooltips and hover tools 
    if fig_type == 'bar':
        # The hovered glyph's name is shown as the hue, and the value is
        # read from the column carrying that name.
        tooltips = [(hue_label, '$name')]
        formatters = dict()
        if x_tickfrmtrnm is not None:
            tooltips.append((x_axis_label, "@x{custom}"))
            formatters['@x'] = name2hvrfrmtr[x_tickfrmtrnm]
        else:
            tooltips.append((x_axis_label, "@x"))
        tooltips.append((y_axis_label, "@$name"))
        hovertool = HoverTool(tooltips=tooltips, formatters=formatters, point_policy='follow_mouse')
        figtools += [hovertool]
    elif fig_type == 'hmap':
        tooltips = [(y_axis_label, '@y')]
        formatters = dict()
        if x_tickfrmtrnm is not None:
            tooltips.append((x_axis_label, "@x{custom}"))
            formatters['@x'] = name2hvrfrmtr[x_tickfrmtrnm]
        else:
            tooltips.append((x_axis_label, "@x"))
        if hue_tickfrmtrnm is not None:
            tooltips.append((hue_label, "@v{custom}"))
            formatters['@v'] = name2hvrfrmtr[hue_tickfrmtrnm]
        else:
            tooltips.append((hue_label, "@v"))
        hovertool = HoverTool(tooltips=tooltips, formatters=formatters)
        figtools += [hovertool]
    
    fig = figure(**bkfig_kwargs)
    
    # Setting the axis labels, and hiding the grid lines
    fig.xaxis.axis_label = x_axis_label
    fig.yaxis.axis_label = y_axis_label
    fig.xgrid.grid_line_color = None
    fig.ygrid.grid_line_color = None
    if x_range_padding is not None:
        fig.x_range.range_padding = x_range_padding

    # # The following restricts the data to relevant `v_repr`.
    # # Note: I have not tested the following, and it may even 
    #   be better than the next alternative.
    # vreprs_rlvnt = set()
    # for glyph_name in hpcdf_fig.unqs('glyph_name'):
    #     glyph_type = glyph_infos[glyph_name]['type']
    #     vreprs_rlvnt.update(set(glyph_type2vreprs[glyph_type]))
    # incdict_fig = {'v_repr': list(vreprs_rlvnt)}
    
    # The following restricts the data to relevant `glyph_type`s 
    #   based on the figure type. If this seems like a poor idea, 
    #   give the previous lines a shot!
    glyphs_rlvnt = []
    for glyph_name in hpcdf_fig.unqs('glyph_name'):
        # Here you can customize which data points appear in each figure and glyph combination
        assert fig_type in ('scatter', 'bar', 'hmap'), 'the next line should be updated'
        glyph_type = glyph_infos[glyph_name]['type']

        # The allowed (figure type, glyph type) combinations
        if (fig_type, glyph_type) in (('scatter', 'sct'), ('scatter', 'blb'), 
            ('bar', 'vas'), ('bar', 'vbs'), ('hmap', 'hm')):
            glyphs_rlvnt.append(glyph_name)
    # Note: `incdict_fig` is not used below in this cell; the `select` call
    #   builds the same dictionary inline.
    incdict_fig = {'glyph_name': glyphs_rlvnt}

    hpcdf_fig = hpcdf_fig.select(incdict={'glyph_name': glyphs_rlvnt}, 
        excdict=None, copy=False, reset_index=False, has_wildcards=False)
    fig_info['hpcdf'] = hpcdf_fig

    # Adding tick formatters
    # Note: the same formatter object is shared by all figures using it.
    if x_tickfrmtrnm is not None:
        fig.xaxis.formatter = name2tckfrmtr[x_tickfrmtrnm]
    if y_tickfrmtrnm is not None:
        fig.yaxis.formatter = name2tckfrmtr[y_tickfrmtrnm]

    fig_info['fig'] = fig

# The Visual Attributes Controllers
# The sliders below are only needed when at least one figure is a scatter plot
need_sctvisualctrls = 'scatter' in (info['type'] for info in fig_infos.values())

if need_sctvisualctrls:
    # Controls the fill/line transparency of the blob (ellipse) glyphs
    alpha_slider = Slider(title='Blob Alpha', value=dflt_blbalpha, 
        start=0.0, end=0.1, step=0.005, **slider_kwargs)

    # Controls the marker size of the scatter point glyphs
    size_slider = Slider(title='Points Size', value=dflt_pntssize, 
        start=0, end=10, step=1, **slider_kwargs)

# Bokeh Data Sources and Glyphs

In [None]:
# Creating Bokeh sources, views, and glyphs
#
# `bk_srcshie` maps `(glyph_name, *glyphkeyvals)` to an info dict holding the 
# ColumnDataSource under `'src'` plus the `'len'` and `'book'` bookkeeping 
# fields (initialized to `0` and an empty dict here; presumably filled in by 
# the data-loading callback -- TODO confirm). The key carries no figure name, 
# so figures that have the same glyph key share a single source.
bk_srcshie = dict()
# `bk_glyphviewshie` maps `(fig_name, glyph_name, *glyphkeyvals)` to a dict 
# holding the figure-specific glyph renderer(s) under `'glyph'` and their 
# CDSView under `'view'`.
bk_glyphviewshie = dict()

for fig_name, fig_info in fig_infos.items():
    # The Bokeh figure object
    fig = fig_info['fig']
    # The formal figure title (not used further down in this cell)
    fig_title = fig_info['title']
    hpcdf_fig = fig_info['hpcdf']
    for glyph_name, glyph_cdf in hpcdf_fig.groupby('glyph_name', sort=False, observed=True):
        glyph_info = glyph_infos[glyph_name]
        # Example:
        #    glyph_name = 'sctr'
        #    glyph_type = 'sct'
        #    glyph_bkcols = ['x', 'y']
        #    glyph_keydims = ['lbl_name']
        
        # The glyph column names inside the bokeh sources 
        # (e.g., glyph_bkcols = ['x', 'y'])
        glyph_bkcols = glyph_info['bkcols']
        # The glyph key dimensions (e.g., the 'sct' glyph needs to be `lbl_name`-specific).
        glyph_keydims = glyph_info['keydims']
        # The glyph type
        glyph_type = glyph_info['type']
        # The glyph's Cartesian dataframe, grouped by the glyph key dimensions
        glyph_cdfgb = glyph_cdf.groupby(glyph_keydims, sort=False, observed=True)

        for glyphkeyvals, figglyph_cdf in glyph_cdfgb:
            # Example: 
            #   glyph_keydims = ['lbl_name']
            #   glyphkeyvals = ('train/orig',)
            #   glyph_vars = {'lbl_name': 'train/orig'}
            glyph_bksrckey = (glyph_name, *glyphkeyvals)
            glyph_vars = dict(zip(glyph_keydims, glyphkeyvals))

            # Creating a new ColumnDataSource if it doesn't exist in bk_srcshie. 
            # The source starts with one empty column per entry of `glyph_bkcols`.
            if glyph_bksrckey not in bk_srcshie:
                bk_src = ColumnDataSource(data={col: [] for col in glyph_bkcols})
                # Storing the Bokeh ColumnDataSource for this particular glyph
                bk_srcinfo = {'src': bk_src, 'len': 0, 'book': dict()}
                bk_srcshie[glyph_bksrckey] = bk_srcinfo
            else:
                # Another figure already created the source for this glyph key; reusing it.
                bk_srcinfo = bk_srcshie[glyph_bksrckey]
                bk_src = bk_srcinfo['src']
            
            # Storing the Bokeh CDSView for this particular glyph. The view 
            # is always created fresh and starts out selecting no rows at all.
            bk_view = CDSView(filter=IndexFilter(indices=[]))

            if 'lbl_name' in glyph_vars:
                lbl_name = glyph_vars['lbl_name']
                # The label's color (e.g., `lbl_color = '#001c7f'`)
                lbl_color = colormanger(lbl_name)
                # The labels's formal name (e.g., `lbl_frml = 'Train Original'`)
                lbl_frml = renamer.encode(lbl_name)
            else:
                # Glyphs that are not label-specific have no label color or legend name
                lbl_name, lbl_frml, lbl_color = None, None, None

            # Determining the stacker column names for the stacked glyph types
            if glyph_type in ('vas', 'vbs'):
                glyph_ysubcol = glyph_info.get('y_subcol', None)
                if glyph_ysubcol is not None:
                    # `bar_names` is the same as the `chem_names` for typical `vas`.
                    # The sub-column must hold a single unique value for this figure/glyph:
                    # either 'every' (use all of `y_names`) or a '_'-joined list of names.
                    fig_chemslst = figglyph_cdf.unqs(glyph_ysubcol)
                    assert len(fig_chemslst) == 1
                    fig_chems = fig_chemslst[0]
                    bar_names = list(glyph_info['y_names']) if fig_chems == 'every' else fig_chems.split('_')
                else:
                    bar_names = list(glyph_info['y_names'])
                # Every stacker, as well as `x`, must be a column of the bokeh source
                assert {'x', *bar_names}.issubset(set(glyph_bkcols))
            else:
                bar_names = None

            if glyph_type == 'sct':
                assert {'x', 'y'}.issubset(set(glyph_bkcols))
                # Creating the scatter points glyph
                bk_glyph = fig.scatter(x="x", y="y", 
                    source=bk_src, view=bk_view, color=lbl_color, legend_label=lbl_frml,
                    size=dflt_pntssize, fill_alpha=1.0, muted_alpha=0.01)
                
                # Linking the visual attribute controllers
                # NOTE(review): `size_slider` is only defined when some figure has the 
                #   'scatter' type; this relies on 'sct' glyphs only appearing in such figures.
                size_slider.js_link('value', bk_glyph.glyph, 'size')
            elif glyph_type == 'blb':
                assert {'x', 'y', 'w', 'h', 'a'}.issubset(set(glyph_bkcols))
                # Creating the blob glyph
                bk_glyph = fig.ellipse(x="x", y="y", width="w", height="h", angle="a", 
                    source=bk_src, view=bk_view, fill_color=lbl_color, legend_label=lbl_frml, 
                    fill_alpha=dflt_blbalpha, line_alpha=dflt_blbalpha, 
                    muted_alpha=0.001, line_color='black')

                # Linking the visual attribute controllers
                # (the same slider drives both the fill and the outline transparency)
                alpha_slider.js_link('value', bk_glyph.glyph, 'fill_alpha')
                alpha_slider.js_link('value', bk_glyph.glyph, 'line_alpha')
            elif glyph_type == 'vas':
                # Stacked areas: one color per stacker
                color_list = [colormanger(stckr) for stckr in bar_names]
                bk_glyph = fig.varea_stack(stackers=bar_names, x='x', fill_color=color_list, 
                    source=bk_src, view=bk_view, legend_label=bar_names, muted_alpha=0.001)
            elif glyph_type == 'vbs':
                # Stacked bars: one color per stacker
                color_list = [colormanger(stckr) for stckr in bar_names]
                bk_glyph = fig.vbar_stack(stackers=bar_names, x='x', fill_color=color_list, 
                    source=bk_src, view=bk_view, legend_label=bar_names, muted_alpha=0.001)
            elif glyph_type == 'hm':
                # NOTE(review): the color mapper below also reads the `v` column, 
                #   which this assertion does not check for.
                assert {'x', 'y', 'w'}.issubset(set(glyph_bkcols))
                # `bar_names` is the same as the `chem_names` for typical `vas`, `vbs`, and `hm`.
                # (it is not used by the rest of this branch)
                bar_names = list(glyph_info['y_names']) 
                # The value range that gets mapped onto the color palette
                color_low, color_high = glyph_info['v_range']
                lowhigh_dict = dict()
                lowhigh_dict['low'] = color_low
                lowhigh_dict['high'] = color_high
                # The palette goes from white, through the reversed `Reds[256]` palette; 
                # NaN values are painted white as well.
                cmap = linear_cmap("v", ['white'] + list(bokeh.palettes.Reds[256][::-1]), 
                    nan_color='white', **lowhigh_dict)
                bk_rectglyph = fig.rect(y='y', x='x', width='w', height=1, fill_color=cmap, 
                    legend_label=fig_name, source=bk_src, view=bk_view, line_color='black')
                    
                # Optionally attaching a color bar to the right of the heatmap
                fig_hascb = fig_info['has_cb']
                if fig_hascb:
                    cb_kwargs = dict(ticker=BasicTicker(desired_num_ticks=5),
                        label_standoff=6, border_line_color=None, padding=5)

                    # The color bar uses the hue tick formatter when a known one is named
                    fig_kwargs = fig_info['kwargs']
                    hue_tickfrmtrnm = fig_kwargs.get('hue_tick_frmtr', None)
                    if hue_tickfrmtrnm in name2tckfrmtr:
                        cb_kwargs['formatter'] = name2tckfrmtr[hue_tickfrmtrnm] 

                    bk_cbglyph = bk_rectglyph.construct_color_bar(**cb_kwargs)
                    fig.add_layout(bk_cbglyph, 'right')
                else:
                    bk_cbglyph = None
                
                # Unlike the other glyph types, the heatmap stores a (rect, color bar) pair
                bk_glyph = (bk_rectglyph, bk_cbglyph)
            else:
                raise ValueError(f'undefined glyph_type={glyph_type}')
            
            # Storing the Bokeh glyph for this particular figure and label and glyph
            bk_glyphinfo = {'glyph': bk_glyph, 'view': bk_view}
            bk_glyphviewshie[(fig_name, glyph_name, *glyphkeyvals)] = bk_glyphinfo

# Configuring the legend of each figure according to the figure type
for fig_name, fig_info in fig_infos.items():
    fig, fig_type = fig_info['fig'], fig_info['type']
    if fig_type in ('bar', 'hmap'):
        # Bar and heatmap figures do not show their legend
        fig.legend.visible = False
    elif fig_type == 'scatter':
        # Clicking on a legend item mutes (i.e., fades out) its glyphs
        fig.legend.click_policy = "mute"
        fig.legend.location = "top_right"

# Bokeh Controls

In [None]:
# First we need to compartmentalize the controls into the cartdf components
# `ctrl_compcols` maps each component id to its `{control column: control type}` dict.
ctrl_compcols = defaultdict(dict)
for ctrl_col, ctrl_type in ctrl_types.items():
    compid = hpcdf4.get_comp(ctrl_col)
    ctrl_compcols[compid][ctrl_col] = ctrl_type

# `ctrl_refs` maps each component id to `{'infos': [...], 'hpdf': ...}`, where 
# `infos` has one metadata dict per control column of that component.
ctrl_refs = dict()
for compid, control_col2types in ctrl_compcols.items():
    # Example:
    #    compid in ('hp', 'var', 'i_time', ...)
    #    compid = 'hp'
    #    control_col2types = {'cri/kl/mu/w': 'catslider', 'cri/kl/sig/w': 'catslider'}

    # `hpdf5` is a regular pd.DataFrame for this particular component of the `hpcdf`.
    hpdf5 = hpcdf4.data[compid]
    
    ctrl_compinfos = []
    for ctrl_col, ctrl_type in control_col2types.items():
        ctrl_ttl = renamer.encode(ctrl_col)
        # `.values` is a numpy array for regular columns, but an extension 
        # array (e.g., a `pd.Categorical`) for extension-typed columns.
        vals_allnp = hpdf5[ctrl_col].values
        vals_allunq = np.unique(vals_allnp) if isinstance(vals_allnp, np.ndarray) else vals_allnp.unique()
        vals_allunqfrml = [renamer.encode(val) for val in vals_allunq]
        
        # Inferring the control type (i.e., slider or menu)
        if ctrl_type is None:
            # `np.issubdtype` raises a TypeError on pandas extension dtypes 
            # (e.g., `CategoricalDtype`), so only genuine numpy dtypes are 
            # tested for being numeric; everything else falls back to a menu.
            vals_dtype = vals_allnp.dtype
            is_valsnumeric = isinstance(vals_dtype, np.dtype) and np.issubdtype(vals_dtype, np.number)
            ctrl_type = 'catslider' if is_valsnumeric else 'menu'
        
        # The categorical slider needs a mapping from integers to strings. 
        # I wish bokeh was not so strict about requiring any notion of 
        # categories to be strictly string.
        ctrl_enc, ctrl_dec = None, None
        if ctrl_type in ('catslider', 'radiobtngrp'):
            # Maps each formal value name to its widget-facing string and back
            ctrl_encdict = dict()
            # for val in vals_allnp.tolist():
            for val in vals_allunq.tolist():
                val_frml = renamer.encode(val)
                ctrl_encdict[val_frml] = strfmt_catslider.format(val=val_frml)
            ctrl_encoder = Renamer(ctrl_encdict)
            ctrl_enc = ctrl_encoder.dict
            ctrl_dec = ctrl_encoder.invdict

        # The first key in the control book options will define the default value.
        # Since all defaults come from the same (first) row, the default 
        # combination of control values is one that exists in the data.
        top_val = hpdf5[ctrl_col].iloc[0]
        topval_frml = renamer.encode(top_val)
        if ctrl_enc is not None:
            topval_frml = ctrl_enc[topval_frml]

        ctrl_info = {'col': ctrl_col, 'title': ctrl_ttl, 'type': ctrl_type, 
            'enc': ctrl_enc, 'dec': ctrl_dec, 'dflt': topval_frml}
        ctrl_compinfos.append(ctrl_info)

    ctrl_ref = {'infos': ctrl_compinfos, 'hpdf': hpdf5}
    ctrl_refs[compid] = ctrl_ref

In [None]:
# Building the control "book" of each component: a nested lookup that takes the 
# formal (i.e., displayed) control values, one nesting level per control column, 
# to the tuple of raw control values. The results are stored back into 
# `ctrl_ref` under the `'book'` and `'cols'` keys.
for compid, ctrl_ref in ctrl_refs.items():
    ctrl_infos = ctrl_ref['infos']
    hpdf5 = ctrl_ref['hpdf']
    
    # # Deprecated: The following is slow
    # # This is a deep mapping of control values to the scene id
    # # For example:
    # #    ctrl_cols =[    'vid', 'i_seed', 'cri/kl/mu/w', 'cri/kl/sig/w', 'nn/ltnt/dim']
    # #    ctrl_book   ['X TSNE']     ['0']     ['1e-06']       ['0.001']           ['9'] == 864
    # #    sceneid == 864 
    # ctrl_book = make_ctrlbook(ctrl_infos, hpdf5, renamer, i_ctrl=0)

    # Here's a more performant replacement
    ctrl_cols = [ctrl_info['col'] for ctrl_info in ctrl_infos]
    # hpdf5ctrl1 = hpdf5[ctrl_cols + ['sceneid']].drop_duplicates()
    # The unique combinations of the control values...
    hpdf5ctrl1 = hpdf5[ctrl_cols].drop_duplicates()
    # ... sorted by the control columns and re-indexed from zero.
    hpdf5ctrl2 = hpdf5ctrl1.sort_values(ctrl_cols).reset_index(drop=True)
    # `hpdf5ctrl3` will contain the original values of each column
    hpdf5ctrl3 = hpdf5ctrl2[ctrl_cols]
    # `hpdf5ctrl4` will contain the renamed values of each column
    hpdf5ctrl4 = hpdf5ctrl3.copy(deep=True)
    # Applying the renaming business
    for ctrl_info in ctrl_infos:
        ctrl_col = ctrl_info['col']
        ctrl_enc = ctrl_info['enc']
        is_colcat = (hpdf5ctrl4.dtypes[ctrl_col] == 'category')
        if not is_colcat:
            # Non-categorical columns get their formal values written into 
            # a temporary `<col>/frml` column, which replaces `<col>` below.
            hpdf5ctrl4[f'{ctrl_col}/frml'] = None
        # Only filled for categorical columns: raw category -> formal category
        val2frml = dict()
        # NOTE: `hpdf5ctrl4` gets a different column modified while this 
        #   groupby is being iterated; keep the statement order as is.
        for val, hpdf_val in hpdf5ctrl4.groupby(ctrl_col, sort=True, observed=True):
            val_frml = renamer.encode(val)
            if ctrl_enc is not None:
                # Slider-like controls show the encoded version of the formal name
                val_frml = ctrl_enc[val_frml]
            
            if is_colcat:
                val2frml[val] = val_frml
            else:
                hpdf5ctrl4.loc[hpdf_val.index, f'{ctrl_col}/frml'] = val_frml
        
        if len(val2frml) > 0:
            # Categorical columns are renamed at the category level in one shot
            hpdf5ctrl4[ctrl_col] = hpdf5ctrl4[ctrl_col].cat.rename_categories(val2frml)
        if not is_colcat:
            # Swapping in the formal values and dropping the temporary column
            hpdf5ctrl4[ctrl_col] = hpdf5ctrl4[f'{ctrl_col}/frml']
            hpdf5ctrl4 = hpdf5ctrl4.drop(columns=f'{ctrl_col}/frml')

    # ctrl_bookhie = dict(zip(hpdf5ctrl4.itertuples(index=False, name='Control'), hpdf5ctrl2['sceneid']))
    # A flat mapping from the tuple of formal values to the tuple of raw values 
    # of the same row (both frames share the same row order).
    ctrl_bookhie = dict(zip(hpdf5ctrl4.itertuples(index=False, name='FormalControl'), 
        hpdf5ctrl3.itertuples(index=False, name='RawControl')))
    # Converting the flat tuple-keyed mapping into its nested ("deep") form 
    # (one nesting level per control column, judging by how the book is 
    # walked when the controls get instantiated).
    ctrl_book = hie2deep(ctrl_bookhie)
    ctrl_ref['book'] = ctrl_book
    ctrl_ref['cols'] = ctrl_cols

In [None]:
# Instantiating one Bokeh widget (and its layout container) per control column
for compid, ctrl_ref in ctrl_refs.items():
    hpdf5, ctrl_book = ctrl_ref['hpdf'], ctrl_ref['book']
    ctrl_infos, ctrl_cols = ctrl_ref['infos'], ctrl_ref['cols']

    # Walking down the nested control book, one level per control, while 
    # following the default value of each control.
    ctrl_bookrcrsd = ctrl_book
    for ctrl_info in ctrl_infos:
        ctrl_col, ctrl_ttl, ctrl_type = ctrl_info['col'], ctrl_info['title'], ctrl_info['type']
        ctrl_enc, ctrl_dec, ctrl_dflt = ctrl_info['enc'], ctrl_info['dec'], ctrl_info['dflt']
        # The options available at this level, given the defaults of the earlier controls
        ctrl_vals = list(ctrl_bookrcrsd)

        if ctrl_type not in ('catslider', 'radiobtngrp', 'menu'):
            raise ValueError(f'undefined ctrl_type = {ctrl_type}')

        if ctrl_type == 'menu':
            # A drop-down menu, which is its own container
            options = list(zip(ctrl_vals, map(str, ctrl_vals)))
            control = Select(options=options, value=ctrl_dflt, 
                title=ctrl_ttl, width=ctrl_width, height=ctrl_height, 
                margin=(m_ctrltop, m_ctrlright, m_ctrlbottom, m_ctrlleft),
                sizing_mode='fixed')
            container = control
        elif ctrl_type == 'radiobtngrp':
            # The radio button group has no title of its own, so it 
            # gets wrapped in a column together with a title `Div`.
            control = RadioButtonGroup(labels=ctrl_vals, active=ctrl_vals.index(ctrl_dflt))
            title_div = Div(text=f'{ctrl_ttl}:', styles={'color': header_color})
            container = column([title_div, control], background=background_color, 
                width=ctrl_width, height=ctrl_height, min_height=ctrl_height, 
                margin=(0, m_ctrlright, 0, m_ctrlleft), sizing_mode='fixed')
        else:
            # A categorical slider, which is its own container
            control = CategoricalSlider(categories=ctrl_vals, value=ctrl_dflt, 
                title=ctrl_ttl, **slider_kwargs)
            container = control

        # Registering the widget and its container in the control info
        ctrl_info.update(control=control, container=container)

        # Descending into the sub-book of the default value for the next control
        ctrl_bookrcrsd = ctrl_bookrcrsd[ctrl_dflt]

# Sharing Axes in Figures

In [None]:
# Sharing the axis ranges and hiding the redundant axes within each frame
for frame_name, frame_info in frame_infos.items():
    frame_type = frame_info['type']
    n_framerows, n_framecols = frame_info['nrows'], frame_info['ncols']

    # Only the figure-bearing frame types are processed
    if frame_type not in ('bar', 'hmap', 'scatter'):
        continue

    first_fig = None
    sharex, sharey = frame_info['sharex'], frame_info['sharey']

    # Turning the sharing spec of each axis into a function which 
    # maps a figure to the id of its axis-sharing group in the frame.
    get_xfiggrp, get_yfiggrp = None, None
    for axname, shareax in (('x', sharex), ('y', sharey)):
        if shareax in ('all', True):
            # All figures belong to a single group
            def get_axfiggrp(*args, **kwargs):
                return 0
        elif shareax in ('none', False, None):
            # Each figure is a group of its own
            def get_axfiggrp(i_fig, *args, **kwargs):
                return i_fig
        elif shareax == 'row':
            # Figures on the same frame row form a group
            def get_axfiggrp(i_fig, fig_name, n_framerows, n_framecols):
                return i_fig // n_framecols
        elif shareax == 'col':
            # Figures on the same frame column form a group
            def get_axfiggrp(i_fig, fig_name, n_framerows, n_framecols):
                return i_fig % n_framecols
        elif callable(shareax):
            # A user-provided grouping function
            get_axfiggrp = shareax
        else:
            raise ValueError(f'undefined shareax={shareax} for {frame_name}')

        # `axname` can only be 'x' or 'y' here
        if axname == 'x':
            get_xfiggrp = get_axfiggrp
        else:
            get_yfiggrp = get_axfiggrp

    frame_figs = frame_info['fig_names']

    # group id -> figure names
    frame_xgrp2fignms, frame_ygrp2fignms = defaultdict(list), defaultdict(list)
    # (column index, x group id) -> [(row index, figure name), ...]
    frame_xgrpicol2fignmshie = defaultdict(list)
    # (row index, y group id) -> [(column index, figure name), ...]
    frame_ygrpirow2fignmshie = defaultdict(list)
    for i_fig, fig_name in enumerate(frame_figs):
        # The row and column indices of the figure within the frame
        i_frmrow, i_frmcol = divmod(i_fig, n_framecols)
        # The x and y axis group indices of the figure within the frame
        figgrp_kwargs = dict(i_fig=i_fig, fig_name=fig_name, 
            n_framerows=n_framerows, n_framecols=n_framecols)
        fig_frmgrpx = get_xfiggrp(**figgrp_kwargs)
        fig_frmgrpy = get_yfiggrp(**figgrp_kwargs)

        # Recording the placement and grouping in the figure info
        fig_info = fig_infos[fig_name]
        fig_info.update({'i_frmrow': i_frmrow, 'i_frmcol': i_frmcol, 
            'framegrp/x': fig_frmgrpx, 'framegrp/y': fig_frmgrpy})

        # Recording the same information in the inverse lookup tables
        frame_xgrp2fignms[fig_frmgrpx].append(fig_name)
        frame_ygrp2fignms[fig_frmgrpy].append(fig_name)
        frame_xgrpicol2fignmshie[(i_frmcol, fig_frmgrpx)].append((i_frmrow, fig_name))
        frame_ygrpirow2fignmshie[(i_frmrow, fig_frmgrpy)].append((i_frmcol, fig_name))

    # X axis visibility: walking each column's group members from the bottom up
    for (i_frmcol, fig_frmgrpx), fig_icolnames in frame_xgrpicol2fignmshie.items():
        # The bottom-most member always shows its x axis
        last_ifrmrow, btm_figname = fig_icolnames[-1]
        fig_infos[btm_figname]['fig'].xaxis.visible = True
        for i_frmrow, fig_name in reversed(fig_icolnames[:-1]):
            fig = fig_infos[fig_name]['fig']
            # A member hides its x axis only when the figure right 
            # below it belongs to the same group.
            fig.xaxis.visible = (i_frmrow + 1 != last_ifrmrow)
            last_ifrmrow = i_frmrow

    # Y axis visibility: walking each row's group members from left to right
    for (i_frmrow, fig_frmgrpy), fig_irownames in frame_ygrpirow2fignmshie.items():
        # The left-most member always shows its y axis
        last_ifrmcol, left_figname = fig_irownames[0]
        fig_infos[left_figname]['fig'].yaxis.visible = True
        for i_frmcol, fig_name in fig_irownames[1:]:
            fig = fig_infos[fig_name]['fig']
            # A member hides its y axis only when the figure right 
            # to its left belongs to the same group.
            fig.yaxis.visible = (i_frmcol - 1 != last_ifrmcol)
            last_ifrmcol = i_frmcol

    # The members of an x group all adopt the x range of the group's first figure
    for fig_frmgrpx, fig_names in frame_xgrp2fignms.items():
        first_fig = fig_infos[fig_names[0]]['fig']
        for fig_name in fig_names[1:]:
            fig = fig_infos[fig_name]['fig']
            fig.x_range = first_fig.x_range

    # The members of a y group all adopt the y range of the group's first figure
    for fig_frmgrpy, fig_names in frame_ygrp2fignms.items():
        first_fig = fig_infos[fig_names[0]]['fig']
        for fig_name in fig_names[1:]:
            fig = fig_infos[fig_name]['fig']
            fig.y_range = first_fig.y_range

    # Final layout tweaks applied to every figure of the frame
    for i_fig, fig_name in enumerate(frame_figs):
        fig = fig_infos[fig_name]['fig']
        fig.min_border, fig.sizing_mode = 10, 'stretch_both'

# The Callback Function

In [None]:
class RioHandler:
    def __init__(self, data_dict, hash_data, trn_spltidxs, n_seeds, 
        eval_bs, n_snr, n_t, tch_device, tch_dtype, 
        n_ppcache, n_origcache, n_rcnstcache, verbose):

        # The set of already instantiated and loaded PP objects
        self.pp_cache = dict()
        self.orig_cache = dict()
        self.rcnst_cache = dict()
        
        self.data_dict = data_dict
        self.hash_data = hash_data
        self.trn_spltidxs = trn_spltidxs
        self.n_seeds = n_seeds
        self.eval_bs = eval_bs
        self.n_snr = n_snr
        self.n_t = n_t
        self.tch_device = tch_device
        self.tch_dtype = tch_dtype
        self.n_ppcache = n_ppcache
        self.n_origcache = n_origcache
        self.n_rcnstcache = n_rcnstcache
        self.verbose = verbose
        
    @staticmethod
    def get_ppmdlcfg(config):
        """
        Generates the PP config according to the framework described in the 
        presentation slides.

        The config is made of three layer chains: the magnitude chain 
        (`m02` to `m05b`), the normalized image chain (`m06` to `m09b`), and 
        the particle counts chain (`n01` to `n04b`). Disabled shift/scale 
        steps are substituted with an identity layer (i.e., adding zero), so 
        the set of layer names does not depend on the options.

        Parameters
        ----------
        config: (dict) The dictionary of hyper-parameters. The `magdim`, 
            `magpnrm`, `magexp`, `magshft`, `magscl`, `nrmexp`, `nrmshft`, 
            `nrmscl`, `nrmscldim`, `cntexp`, `cntshft`, `cntscl`, and 
            `cntscldim` keys are required, whereas `eps` (default `1e-100`) 
            and `cnteps` (default `1000`) are optional. The input dictionary 
            itself is not modified.

        Returns
        -------
        pp_mdlcfg: (dict) The PP module config, with the `input`, `output`, 
            and `layers` keys. 

        Raises
        ------
        KeyError: when a required option is missing (including `config=None`), 
            or when an option has a value outside its lookup table.
        AssertionError: when `config` has unrecognized options.
        """
        # Working on a shallow copy so that the caller's dictionary stays intact
        configcp = dict() if config is None else dict(config)

        # The epsilon added to the input histogram
        eps = configcp.pop('eps', 1e-100)
        # The magnitude dimensionality; one of `'n_chem', 'k_bins', 'one'`.
        magdim = configcp.pop('magdim')
        # The magnitude normalization norm
        magpnrm = configcp.pop('magpnrm')
        # The magnitude exponent
        magexp = configcp.pop('magexp')
        # Whether to subtract mu3
        magshft = configcp.pop('magshft')
        # Whether to scale sig4
        magscl = configcp.pop('magscl')
        # The normalized image exponent
        nrmexp = configcp.pop('nrmexp')
        # Whether to subtract mu7
        nrmshft = configcp.pop('nrmshft')
        # Whether to scale sig8
        nrmscl = configcp.pop('nrmscl')
        # The mu7 and sig8 dimensionality; one of `'n_chem', 'k_bins', 'one', 'n_chem, k_bins'`.
        nrmscldim = configcp.pop('nrmscldim')

        # The epsilon added to the particle counts
        cnteps = configcp.pop('cnteps', 1000)
        # The counts exponent
        cntexp = configcp.pop('cntexp')
        # Whether to subtract mu2 for the counts
        cntshft = configcp.pop('cntshft')
        # Whether to scale by sig3 for the counts
        cntscl = configcp.pop('cntscl')
        # The mu2 and sig3 dimensionality; one of `'k_bins', 'one'`.
        cntscldim = configcp.pop('cntscldim')

        # Translating some of the options
        onoff_dict = {'on': True, 'off': False, True: True, False: False}
        magshft = onoff_dict[magshft]
        magscl = onoff_dict[magscl]
        nrmshft = onoff_dict[nrmshft]
        nrmscl = onoff_dict[nrmscl]
        cntshft = onoff_dict[cntshft]
        cntscl = onoff_dict[cntscl]
        magpnrm = {1: 1, 2: 2, 'l1': 1, 'l2': 2}[magpnrm]

        # Every option must have been consumed by now
        assert len(configcp) == 0, f'unused options: {configcp}'

        pp_mdlcfg = dict()

        # The input and output
        pp_mdlcfg.update(ruyaml.safe_load('''
            input:
                m_chmprthst: [n_seeds, n_mb, n_chem, k_bins]
                n_prthst:    [n_seeds, n_mb, 1,      k_bins]
            output: 
                m_chmprtnrm: m09b
                m_prthstmag: m05b
                n_prthstnrm: n04b
        '''))

        # Defining the layers
        # Note: The layers are inserted in the order they are defined below. 
        #   Most layers name no `input`; presumably they consume the layer 
        #   defined right before them -- TODO confirm against the PP module.
        pp_mdllayers = dict()
        #######################################
        ######## M1: Adding an Epsilon ########
        #######################################
        pp_mdllayers.update(ruyaml.safe_load(f'''
            # M1 = m0 + eps
            m01:
                type: add
                value: {eps}
                input: m_chmprthst
                shape: [n_seeds, n_mb, n_chem, k_bins]
        '''))

        ############################################################
        ################## Part I: The Magnitude ###################
        ############################################################

        #######################################
        #### M2: Taking the Lp-Norm of M1 #####
        #######################################
        # The magnitude reduction dimensions
        mag_rdcdims = {'n_chem': [-1], 'k_bins': [-2], 'one': [-1, -2]}[magdim]
        # The magnitude's number of channels and bins
        n_chnlsmag, n_binsmag = {'n_chem': ('n_chem', '1'), 'k_bins': ('1', 'k_bins'), 'one': ('1', '1')}[magdim]

        # Note: The backslashes are escaped below, since `\|` is an invalid escape 
        #   sequence in a regular string (a SyntaxWarning as of Python 3.12). The 
        #   resulting text is still the YAML comment `# M2 = \| M1 \|`.
        pp_mdllayers.update(ruyaml.safe_load(f'''
            # M2 = \\| M1 \\| 
            m02:
                type: pnorm
                dim: {mag_rdcdims}
                pnorm: {magpnrm}
                shape: [n_seeds, n_mb, {n_chnlsmag}, {n_binsmag}]
        '''))

        #######################################
        ## M3: Exponentiating the magnitude ###
        #######################################
        pp_mdllayers.update(ruyaml.safe_load(f'''
            # M3 = M2 ^ alpha
            m03:
                type: sgnpow
                exponent: {magexp}
                shape: [n_seeds, n_mb, {n_chnlsmag}, {n_binsmag}]
        '''))

        #######################################
        ###### M4: Subtracting the Mean #######
        #######################################
        if magshft:
            pp_mdllayers.update(ruyaml.safe_load(f'''
                # M4 = M3 - mu3
                m04:
                    type: shift
                    dim: {mag_rdcdims}
                    shape: [n_seeds, n_mb, {n_chnlsmag}, {n_binsmag}]
            '''))
        else:
            # The identity layer keeps the layer names option-independent
            pp_mdllayers.update(ruyaml.safe_load(f'''
                # M4 = M3
                m04:
                    type: add
                    value: 0
                    shape: [n_seeds, n_mb, {n_chnlsmag}, {n_binsmag}]
            '''))

        #######################################
        ####### M5: Dividing the Scale ########
        #######################################
        if magscl:
            pp_mdllayers.update(ruyaml.safe_load(f'''
                # M5 = M4 / sigma4
                m05:
                    type: scale
                    dim: {mag_rdcdims}
                    shape: [n_seeds, n_mb, {n_chnlsmag}, {n_binsmag}]
            '''))
        else:
            pp_mdllayers.update(ruyaml.safe_load(f'''
                # M5 = M4
                m05:
                    type: add
                    value: 0
                    shape: [n_seeds, n_mb, {n_chnlsmag}, {n_binsmag}]
            '''))

        # These two layers are always present, regardless of `magshft` and `magscl`
        pp_mdllayers.update(ruyaml.safe_load(f'''
            # Scalar shift and scale for properly-scaled noise injection
            m05a:
                type: shift
                dim: [-1, -2]
                shape: [n_seeds, n_mb, {n_chnlsmag}, {n_binsmag}]
            m05b:
                type: scale
                dim: [-1, -2]
                shape: [n_seeds, n_mb, {n_chnlsmag}, {n_binsmag}]
        '''))
        
        ############################################################
        ################ Part II: The Normalization ################
        ############################################################

        #######################################
        ######### M6: Normalization ###########
        #######################################
        # This chain branches off the epsilon-added input (i.e., `m01`)
        pp_mdllayers.update(ruyaml.safe_load(f'''
            # M6 = M1 / M2
            m06:
                type: normalize
                dim: {mag_rdcdims}
                pnorm: {magpnrm}
                input: m01
                shape: [n_seeds, n_mb, n_chem, k_bins]
        '''))

        #######################################
        ## M7: Exponentiating the normalztn ###
        #######################################
        pp_mdllayers.update(ruyaml.safe_load(f'''
            # M7 = M6 ^ alpha
            m07:
                type: sgnpow
                exponent: {nrmexp}
                shape: [n_seeds, n_mb, n_chem, k_bins]
        '''))

        #######################################
        ###### M8: Subtracting the Mean #######
        #######################################

        # The mu7 and sig8 reduction dimensions
        mu7sig8_rdcdims = {'n_chem': [-1], 'k_bins': [-2], 'one': [-1, -2], 'n_chem, k_bins': []}[nrmscldim]

        if nrmshft:
            pp_mdllayers.update(ruyaml.safe_load(f'''
                # M8 = M7 - mu7
                m08:
                    type: shift
                    dim: {mu7sig8_rdcdims}
                    shape: [n_seeds, n_mb, n_chem, k_bins]
            '''))
        else:
            pp_mdllayers.update(ruyaml.safe_load('''
                # M8 = M7
                m08:
                    type: add
                    value: 0
                    shape: [n_seeds, n_mb, n_chem, k_bins]
            '''))

        #######################################
        ####### M9: Dividing the Scale ########
        #######################################
        if nrmscl:
            pp_mdllayers.update(ruyaml.safe_load(f'''
                # M9 = M8 / sig8
                m09:
                    type: scale
                    dim: {mu7sig8_rdcdims}
                    shape: [n_seeds, n_mb, n_chem, k_bins]
            '''))
        else:
            pp_mdllayers.update(ruyaml.safe_load('''
                # M9 = M8
                m09:
                    type: add
                    value: 0
                    shape: [n_seeds, n_mb, n_chem, k_bins]
            '''))
        
        # These two layers are always present, regardless of `nrmshft` and `nrmscl`
        pp_mdllayers.update(ruyaml.safe_load('''
            # Scalar shift and scale for properly-scaled noise injection
            m09a:
                type: shift
                dim: [-1, -2]
                shape: [n_seeds, n_mb, n_chem, k_bins]
            m09b:
                type: scale
                dim: [-1, -2]
                shape: [n_seeds, n_mb, n_chem, k_bins]
        '''))

        ############################################################
        ################### Part III: The Count ####################
        ############################################################
        # The mu2 and sig3 reduction dimensions for the counts
        cnt_rdcdims = {'k_bins': [], 'one': [-1]}[cntscldim]

        #######################################
        ######## N1: Adding an Epsilon ########
        #######################################
        pp_mdllayers.update(ruyaml.safe_load(f'''
            # n1 = n0 + cnteps
            n01:
                type: add
                input: n_prthst
                value: {cnteps}
                shape: [n_seeds, n_mb, 1, k_bins]
        '''))

        #######################################
        #### N2: Exponentiating the counts ####
        #######################################
        pp_mdllayers.update(ruyaml.safe_load(f'''
            # N2 = N1 ^ alpha
            n02:
                type: sgnpow
                exponent: {cntexp}
                shape: [n_seeds, n_mb, 1, k_bins]
        '''))

        #######################################
        ###### N3: Subtracting the Mean #######
        #######################################
        if cntshft:
            pp_mdllayers.update(ruyaml.safe_load(f'''
                # N3 = N2 - mu2
                n03:
                    type: shift
                    dim: {cnt_rdcdims}
                    shape: [n_seeds, n_mb, 1, k_bins]
            '''))
        else:
            pp_mdllayers.update(ruyaml.safe_load('''
                # N3 = N2
                n03:
                    type: add
                    value: 0
                    shape: [n_seeds, n_mb, 1, k_bins]
            '''))

        #######################################
        ####### N4: Dividing the Scale ########
        #######################################
        if cntscl:
            pp_mdllayers.update(ruyaml.safe_load(f'''
                # N4 = N3 / sigma3
                n04:
                    type: scale
                    dim: {cnt_rdcdims}
                    shape: [n_seeds, n_mb, 1, k_bins]
            '''))
        else:
            pp_mdllayers.update(ruyaml.safe_load('''
                # N4 = N3
                n04:
                    type: add
                    value: 0
                    shape: [n_seeds, n_mb, 1, k_bins]
            '''))

        # These two layers are always present, regardless of `cntshft` and `cntscl`
        pp_mdllayers.update(ruyaml.safe_load('''
            # Scalar shift and scale for properly-scaled noise injection
            n04a:
                type: shift
                dim: [-1]
                shape: [n_seeds, n_mb, 1, k_bins]
            n04b:
                type: scale
                dim: [-1]
                shape: [n_seeds, n_mb, 1, k_bins]
        '''))

        pp_mdlcfg['layers'] = pp_mdllayers

        return pp_mdlcfg

    @torch.no_grad()
    def __call__(self, rowdict):
        """Compute the 2-D array to visualise for one row of plotting options.

        The method consumes a copy of `rowdict` key by key and asserts at the
        end that nothing is left over, so unknown options are rejected.

        Required keys of `rowdict`:
            * the pre-processor identifiers listed in `pp_idkeys` below,
            * `i_snr`, `i_time`, `i_seed`: scenario, time and seed indices,
            * `v_node`: the node whose array is returned,
            * `v_varname`: `'orig'` (forward pre-processed values) or `'rcnst'`
              (values obtained by adding noise to the forward outputs and
              running the pre-processor inverse),
            * `noise_mag`, `noise_nrm`, `noise_cnt`, `seed_noise`: the noise
              amplitudes and the noise seed (popped in every case, only used
              for `'rcnst'`).
        Optional keys `chems`, `v_repr` and `glyph_name` are accepted and ignored.

        Returns:
            A numpy array of shape `(n_channels, n_length)` taken at seed
            `i_seed`; it is transposed for the magnitude nodes when
            `magdim == 'n_chem'`.

        Raises:
            ValueError: if `v_varname` is neither `'orig'` nor `'rcnst'`.
            AssertionError: on shape mismatches or unused `rowdict` keys.
        """
        data_dict = self.data_dict
        hash_data = self.hash_data
        trn_spltidxs = self.trn_spltidxs
        n_seeds = self.n_seeds
        eval_bs = self.eval_bs
        n_snr = self.n_snr
        n_t = self.n_t
        tch_device = self.tch_device
        tch_dtype = self.tch_dtype
        verbose = self.verbose

        # The three caches and their size limits (a limit of `None` disables eviction).
        pp_cache = self.pp_cache
        orig_cache = self.orig_cache
        rcnst_cache = self.rcnst_cache
        n_ppcache = self.n_ppcache
        n_origcache = self.n_origcache
        n_rcnstcache = self.n_rcnstcache

        # Working on a shallow copy so that the caller's dictionary is not consumed.
        rowdictcp = dict(rowdict)
        if verbose:
            print(f'Working on {rowdictcp}', flush=True)

        ######### Creating the Pre-Processor
        # Example (illustrative only; it does not list the `cnt*` keys that
        # are part of `pp_idkeys` below):
        #   pp_idkeys = ('magdim', 'magpnrm', 'magexp', 'magshft', 'magscl', 
        #                'nrmexp', 'nrmshft', 'nrmscl', 'nrmscldim')
        #
        #   pp_idvals = ('k_bins', 1, 1.0, False, True, 1.0, False, True, 'n_chem')
        #
        #   ppcfg = {'magdim': 'k_bins', 'magpnrm': 1, 'magexp': 1.0, 'magshft': False, 
        #       'magscl': True, 'nrmexp': 1.0, 'nrmshft': False, 'nrmscl': True, 
        #       'nrmscldim': 'n_chem'}
        #
        #   v_node = 'm_chmprthst'
        #   v_varname = 'orig'
        #   chems = 'every'
        #   v_repr = 'pnts'
        #   noise_amp = 0.0
        #   noise_seed = 0
        #   i_seed = 0

        pp_idkeys = ('magdim', 'magpnrm', 'magexp', 'magshft', 'magscl', 
            'nrmexp', 'nrmshft', 'nrmscl', 'nrmscldim', 'cntexp', 
            'cntshft', 'cntscl', 'cntscldim')
        # The tuple of identifier values doubles as the pre-processor cache key.
        pp_idvals = tuple(rowdictcp.pop(key) for key in pp_idkeys)
        ppcfg = dict(zip(pp_idkeys, pp_idvals))
        # NOTE(review): the model config is rebuilt on every call although it is 
        # only consumed in the cache-miss branch below.
        pp_mdlcfg = self.get_ppmdlcfg(ppcfg)
        if pp_idvals in pp_cache:
            (cri_pp, u_dims, x_dims) = pp_cache[pp_idvals]
        else:
            cri_pp = make_pp('cstm', dict(), {'cstm': pp_mdlcfg}, tch_device, tch_dtype)
            cri_pp.infer(data_dict, trn_spltidxs, n_seeds, eval_bs, hash_data=hash_data, 
                cache_path=f'{cache_dir}/pptmp.tar')
            # NOTE(review): `data_dims` is neither a local nor an attribute of `self`; 
            # it is resolved from the enclosing (notebook) scope -- confirm it is defined 
            # before this method runs.
            u_dims = cri_pp.get_dims(data_dims, types='output', n_seeds=n_seeds, n_mb=1, n_dims=2)
            x_dims = cri_pp.get_dims(data_dims, types='input', n_seeds=n_seeds, n_mb=1, n_dims=2)
            pp_cache[pp_idvals] = (cri_pp, u_dims, x_dims)

        ######### Getting the data
        # The scenario / trajectory index
        i_snr = rowdictcp.pop('i_snr')
        # The time point index
        i_time = rowdictcp.pop('i_time')
        # The train/test split random seed index
        i_seed = rowdictcp.pop('i_seed')
        # Now, we collect these indices from the full data.

        # The forward outputs do not depend on `i_seed`'s selection or on the noise 
        # options, so they are cached per (pre-processor, scenario, time) only.
        key_origcache = (*pp_idvals, i_snr, i_time)
        if key_origcache not in orig_cache:
            x_mbs = dict()
            for x_node, (n_xchnls, n_xlen) in x_dims.items():
                x_tnsrall_ = data_dict[x_node]
                assert x_tnsrall_.shape == (n_snr * n_t, n_xchnls, n_xlen)
                # Un-flattening the leading axis into (scenario, time)
                x_tnsrall = x_tnsrall_.reshape(n_snr, n_t, n_xchnls, n_xlen)
                assert x_tnsrall.shape == (n_snr, n_t, n_xchnls, n_xlen)
                x_mb_ = x_tnsrall[i_snr, i_time]
                assert x_mb_.shape == (n_xchnls, n_xlen)
                # A mini-batch of size one, repeated (as a view) for every seed
                x_mb = x_mb_.reshape(1, 1, n_xchnls, n_xlen).expand(n_seeds, 1, n_xchnls, n_xlen)
                assert x_mb.shape == (n_seeds, 1, n_xchnls, n_xlen)
                x_mbs[x_node] = x_mb

            ######### Applying the forward loop
            u_mbs, u_shaper = cri_pp.forward(x_mbs, full=True)
            # Sanity-checking the shapes of the declared output nodes
            for u_node, (n_uchnls, n_ulen) in u_dims.items():
                u_mb = u_mbs[u_node]
                assert u_mb.shape == (n_seeds, 1, n_uchnls, n_ulen)

            # Storing the evaluations in the array cache dictionary
            u_mbsnp = dict()
            for u_node, u_mb in u_mbs.items():
                (n_uchnls, n_ulen) = u_mb.shape[-2:]
                assert u_mb.shape == (n_seeds, 1, n_uchnls, n_ulen)
                u_mbnp = u_mb.detach().cpu().numpy()
                assert u_mbnp.shape == (n_seeds, 1, n_uchnls, n_ulen)
                u_mbsnp[u_node] = u_mbnp
            
            # Same key as `key_origcache`
            orig_cache[(*pp_idvals, i_snr, i_time)] = (u_mbsnp, u_shaper)

        assert key_origcache in orig_cache, dedent(f'''
            There must be a mistake with the cache ids in saving/loading since 
            the data should have been available at this point in the cache:
                rowdict = {rowdict}''')

        u_mbsnp, u_shaper = orig_cache[key_origcache]

        # The visualization node
        v_node = rowdictcp.pop('v_node')
        # The variable name ('orig' or 'rcnst')
        v_varname = rowdictcp.pop('v_varname')
        # The learning noise amplitude for magnitude
        noise_mag = rowdictcp.pop('noise_mag')
        # The learning noise amplitude for the normalized image
        noise_nrm = rowdictcp.pop('noise_nrm')
        # The learning noise amplitude for the normalized counts
        noise_cnt = rowdictcp.pop('noise_cnt')
        # The noise seed
        seed_noise = rowdictcp.pop('seed_noise')

        if v_varname == 'orig':
            # Returning the forward value of the requested node for the requested seed
            u_mbnp = u_mbsnp[v_node]
            n_uchnls, n_ulen = u_mbnp.shape[-2:]
            assert u_mbnp.shape == (n_seeds, 1, n_uchnls, n_ulen)
            output = u_mbnp[i_seed, 0]
            assert output.shape == (n_uchnls, n_ulen)
        elif v_varname == 'rcnst':
            # Reconstructions additionally depend on the noise seed and amplitudes
            key_rcnstcache = (*pp_idvals, i_snr, i_time, seed_noise, noise_mag, noise_nrm, noise_cnt)
            if key_rcnstcache not in rcnst_cache:
                # Instantiating the random seed
                np_random = np.random.RandomState(seed=seed_noise)

                # Letting it run for a while!
                np_random.randn(10000)

                # Applying the noise injection business
                uhat_mbs = dict()
                # The noise amplitude of each output node; every node in `u_dims` 
                # must have an entry here, otherwise the lookup below raises a KeyError.
                uhat_noiseamps = {'m_prthstmag': noise_mag, 'm_chmprtnrm': noise_nrm, 'n_prthstnrm': noise_cnt}
                for u_node, (n_uchnls, n_ulen) in u_dims.items():
                    u_mbnp = u_mbsnp[u_node]
                    assert u_mbnp.shape == (n_seeds, 1, n_uchnls, n_ulen)
                    noise_amp = uhat_noiseamps[u_node]
                    # Standard normal noise scaled by the node's amplitude
                    u_noisemb = np_random.randn(n_seeds, 1, n_uchnls, n_ulen) * noise_amp
                    assert u_noisemb.shape == (n_seeds, 1, n_uchnls, n_ulen)
                    uhat_mbnp = u_mbnp + u_noisemb
                    assert uhat_mbnp.shape == (n_seeds, 1, n_uchnls, n_ulen)
                    uhat_mb = torch.from_numpy(uhat_mbnp).to(device=tch_device, dtype=tch_dtype)
                    assert uhat_mb.shape == (n_seeds, 1, n_uchnls, n_ulen)
                    uhat_mbs[u_node] = uhat_mb
                
                # Mapping the noisy outputs back to the input space
                xhat_mbs = cri_pp.inverse(uhat_mbs, u_shaper, strict=True, full=True)
                
                for x_node, (n_xchnls, n_xlen) in x_dims.items():
                    xhat_mb = xhat_mbs[x_node]
                    assert xhat_mb.shape == (n_seeds, 1, n_xchnls, n_xlen)

                xhat_mbsnp = {x_node: xhat_mb.detach().cpu().numpy() 
                    for x_node, xhat_mb in xhat_mbs.items()}

                rcnst_cache[key_rcnstcache] = xhat_mbsnp
            
            assert key_rcnstcache in rcnst_cache, dedent(f'''
                There must be a mistake with the cache ids in saving/loading since 
                the data should have been available at this point in the cache:
                    rowdict = {rowdict}''')
            
            xhat_mbsnp = rcnst_cache[key_rcnstcache]
            xhat_mbnp = xhat_mbsnp[v_node]
            n_xchnls, n_xlen = xhat_mbnp.shape[-2:]
            assert xhat_mbnp.shape == (n_seeds, 1, n_xchnls, n_xlen)
            output = xhat_mbnp[i_seed, 0]
            assert output.shape == (n_xchnls, n_xlen)
        else:
            raise ValueError(f'undefined v_varname={v_varname}')

        # Trimming the caches down to their limits. `next(iter(cache))` is the first key 
        # in iteration order (the oldest insertion if the caches are plain dicts -- 
        # TODO confirm the cache types).
        while (n_ppcache is not None) and (len(pp_cache) > n_ppcache):
            pp_cache.pop(next(iter(pp_cache)))

        while (n_origcache is not None) and (len(orig_cache) > n_origcache):
            orig_cache.pop(next(iter(orig_cache)))

        while (n_rcnstcache is not None) and (len(rcnst_cache) > n_rcnstcache):
            rcnst_cache.pop(next(iter(rcnst_cache)))

        # These options are allowed to be present, but play no role here.
        rowdictcp.pop('chems', None)
        rowdictcp.pop('v_repr', None)
        rowdictcp.pop('glyph_name', None)
        assert len(rowdictcp) == 0, f'unused options: {rowdictcp}'

        # The magnitude-related nodes are transposed when the magnitude was taken 
        # over the `n_chem` dimension.
        if (v_node in ('m_prthstmag', 'm02', 'm03', 'm04', 'm05', 'm05a', 'm05b')) and (ppcfg['magdim'] == 'n_chem'):
            output = output.T

        return output


In [None]:
class Changer:
    def __init__(self, ctrl_refs, rio_handler, glyph_type2vreprs,
        bk_srcshie, bk_glyphviewshie, glyph_infos, fig_infos, 
        max_strmdrows, verbose):
        self.ctrl_refs = ctrl_refs
        self.last_sceneid = None

        self.rio_handler = rio_handler
        self.bk_srcshie = bk_srcshie
        self.bk_glyphviewshie = bk_glyphviewshie

        # Applying a safe hie2deep, no matter glyph_keyvals is empty or not.
        self.bk_srcs = defaultdict(dict)
        for (glyph_name, *glyph_keyvals), bk_srcinfo in bk_srcshie.items():
            self.bk_srcs[glyph_name][tuple(glyph_keyvals)] = bk_srcinfo

        self.glyph_infos = glyph_infos
        self.glyph_type2vreprs = glyph_type2vreprs
        self.fig_infos = fig_infos
        self.max_strmdrows = max_strmdrows

        self.verbose = verbose
        self.clear_data()

    def clear_data(self):
        if self.verbose:
            print('Performing a data flush!', flush=True)
        for (glyph_name, *glyph_keyvals), bk_srcinfo in self.bk_srcshie.items():
            # Example: 
            #    glyph_name = 'sct'
            #    glyph_keydims = ['lbl_name']
            #    glyph_keyvals = ('train/orig',)
            #    bk_src = ColumnDataSource(...)

            # Getting the bokeh source and view and size
            bk_src = bk_srcinfo['src']
            # The mapping of each data id to the starting and ending 
            # row index within the Bokeh Data Source
            bk_srcbook = bk_srcinfo['book']

            # Clearing out any existing data in the bokeh sources
            bk_src.data = {col: [] for col in bk_src.column_names}

            # Zeroing out the length of the bokeh source
            bk_srcinfo['len'] = 0
            # Emptying the source book
            bk_srcbook.clear()
        
        # Clearing out any existing data in the bokeh cds views
        for (fig_name, glyph_name, *glyph_keyvals), bk_glyphinfo in self.bk_glyphviewshie.items():
            bk_glyph = bk_glyphinfo['glyph']
            bk_view = bk_glyphinfo['view']
            bk_view.filter.indices = []

    @property
    def nrows_bksrcs(self):
        return sum(bk_srcinfo['len'] for bk_srcinfo in self.bk_srcshie.values())

    @staticmethod
    def adjust_controls(ctrl_refs, verbose):
        #######################################################################
        ############### Reading the Controls and Adjusting Them ###############
        #######################################################################
        sceneid = dict()
        for compid, ctrl_ref in ctrl_refs.items():
            hpdf5 = ctrl_ref['hpdf']
            ctrl_book = ctrl_ref['book']
            ctrl_infos = ctrl_ref['infos']
            ctrl_cols = ctrl_ref['cols']
                
            ctrl_bookrcrsd = ctrl_book
            for ctrl_info in ctrl_infos:
                ctrl_col = ctrl_info['col']
                ctrl_ttl = ctrl_info['title']
                ctrl_type = ctrl_info['type']
                ctrl_enc = ctrl_info['enc']
                ctrl_dec = ctrl_info['dec']
                control = ctrl_info['control']

                # The user-selected value for the control
                if ctrl_type in ('catslider', 'slider', 'menu'):
                    ctrl_userval = control.value
                elif ctrl_type in ('radiobtngrp',):
                    ctrl_userval = control.labels[control.active]
                else:
                    raise ValueError(f'undefined ctrl_type={ctrl_type}')

                # The available set of values for the control
                ctrl_avlblevals = list(ctrl_bookrcrsd)

                if ctrl_userval in ctrl_bookrcrsd:
                    # Looks like this selected control value is available and 
                    # has a scene defined in our data.
                    ctrl_uservalnew = ctrl_userval                
                else:
                    # Looks like this selected control value is not available and 
                    # does not have a scene defined in our data.
                    ctrl_uservalnew = ctrl_avlblevals[0]

                # Updating the control attributes accordingly
                if ctrl_type == 'catslider':
                    if control.categories != ctrl_avlblevals:
                        control.categories = ctrl_avlblevals
                    if control.value != ctrl_uservalnew:
                        control.value = ctrl_uservalnew
                elif ctrl_type == 'radiobtngrp':
                    if control.labels != ctrl_avlblevals:
                        control.labels = ctrl_avlblevals
                    if ctrl_userval != ctrl_uservalnew:
                        control.active = ctrl_avlblevals.index(ctrl_uservalnew)
                elif ctrl_type == 'menu':
                    options = [(val, str(val)) for val in ctrl_avlblevals]
                    if control.options != options:
                        control.options = options
                    if control.value != ctrl_uservalnew:
                        control.value = ctrl_uservalnew
                else:
                    raise ValueError(f'undefined ctrl_type={ctrl_type}')

                # Recursing over the control book
                ctrl_bookrcrsd = ctrl_bookrcrsd[ctrl_uservalnew]
            
            # After we have recused over all the controls, the value of 
            # the control book should be the scene id.
            comp_sceneid = ctrl_bookrcrsd

            for ctrl_col, ctrl_val in zip(ctrl_cols, comp_sceneid):
                assert ctrl_col not in sceneid, dedent(f'''
                    I am not sure how a single column is controlled 
                    by two components:
                        Offending column: {ctrl_col}
                        First Value: {sceneid[ctrl_col]}
                        Second Value: {sceneid[ctrl_col]}''')
                sceneid[ctrl_col] = ctrl_val

        if verbose:
            print(f'Control: scene={sceneid}')
        
        return sceneid
    
    @staticmethod
    @without_property_validation
    def load_data(sceneid, bk_glyphviewshie, glyph_infos, glyph_type2vreprs, bk_srcs, verbose,
        fig_infos=None, rio_handler=None):
        """Make every glyph view show the data of the given scene, loading it if needed.

        For each `(figure, glyph, key values)` entry, the figure's frame is
        narrowed down to the rows of this glyph and scene. If the resulting
        data id is not yet in the source's book, the rows are passed one by one
        to `rio_handler`, converted with `Changer.get_bksrcdata` and streamed to
        the Bokeh source, and the new row range is recorded in the book. The
        view filter is finally pointed at the row range of the data id. A scene
        without any rows is recorded under the id `None` with an empty range.

        Args:
            sceneid: dict mapping the control columns to their selected values.
            bk_glyphviewshie: mapping from `(fig_name, glyph_name, *glyph_keyvals)`
                to a dict holding the glyph's `'view'`.
            glyph_infos: mapping from a glyph name to its info dict (`'type'`,
                `'keydims'`, `'idcols'`, `'idctrls'`, plus whatever
                `get_bksrcdata` needs).
            glyph_type2vreprs: mapping from a glyph type to its v_reprs.
            bk_srcs: `bk_srcs[glyph_name][tuple(glyph_keyvals)]` is a dict with
                the keys `'src'`, `'len'` and `'book'`.
            verbose: whether to print the streaming and view updates.
            fig_infos: mapping from a figure name to its info dict (`'hpcdf'`).
                When omitted, the module-level `fig_infos` is used, which is
                what this method relied upon before the argument existed.
            rio_handler: callable mapping a row dict to the loaded array. When
                omitted, the module-level `rio_handler` is used.

        Raises:
            NameError: if an omitted `fig_infos` / `rio_handler` is needed and
                no module-level object of that name exists.
            AssertionError: if the glyph and scene do not identify one data set.
        """
        def resolve(explicit, name):
            # Prefer the explicit argument. Otherwise fall back, only once the 
            # object is really needed, to the module-level name that this method 
            # used to read implicitly.
            if explicit is not None:
                return explicit
            try:
                return globals()[name]
            except KeyError:
                raise NameError(f"name '{name}' is not defined") from None

        #######################################################################
        ################# Reading the New Data From the Disk ##################
        #######################################################################
        # The columns that define a unique hdf key data to be loaded
        for (fig_name, glyph_name, *glyph_keyvals), bk_glyphinfo in bk_glyphviewshie.items():
            fig_info = resolve(fig_infos, 'fig_infos')[fig_name]
            hpcdf_fig = fig_info['hpcdf']

            # Example: 
            #    glyph_name = 'sct'
            #    glyph_bkcols = ['x', 'y']
            #    glyph_keydims = ['lbl_name']
            #    glyph_keyvals = ['train/orig']
            #    glyph_vars = {'lbl_name': 'train/orig', 'glyph_name': 'sct'}
            #    bk_src = ColumnDataSource(...)
            #    bk_view = CDSView(...)
            #    bk_srclen = 0
            #    glyph_idctrls = 'all'
            #    glyph_idcols = ['fpidx', 'eid', 'vid', 'v_space', 'v_node', 
            #       'v_split', 'v_varname', 'ppid', 'i_seed', 'i_epoch', 'i_time']
            glyph_info = glyph_infos[glyph_name]
            # The glyph type
            glyph_type = glyph_info['type']
            # The glyph type v_reprs
            glyph_vreprs = glyph_type2vreprs[glyph_type]

            # The glyph key dimensions (e.g., the 'sct' glyph needs to be `lbl_name`-specific).
            glyph_keydims = glyph_info['keydims']
            assert len(glyph_keydims) == len(glyph_keyvals)

            # Restricting the scene hpdf to this particular glyph and figure
            glyph_vars = dict(zip(glyph_keydims, glyph_keyvals))
            glyph_vars['glyph_name'] = glyph_name

            hpcdf_fg = hpcdf_fig.select(incdict=glyph_vars, excdict=None, 
                copy='shallow', reset_index=False, has_wildcards=False)

            # The set of columns identifying one set of data points from another
            glyph_idcols = glyph_info['idcols']

            # The subset of control columns responsible for identifying the current 
            # dataset for the glyph. Use 'all' to use all control columns.
            glyph_idctrls = glyph_info['idctrls']
            assert (glyph_idctrls == 'all') or isinstance(glyph_idctrls, (tuple, list))
            # Extracting the control values related to this particular glyph
            glyph_scnid = sceneid if (glyph_idctrls == 'all') else {col: sceneid[col] for col in glyph_idctrls}
            
            hpcdf_fgs = hpcdf_fg.select(incdict=glyph_scnid, excdict=None, 
                copy='shallow', reset_index=False, has_wildcards=False)

            if hpcdf_fgs.shape[0] == 0:
                # The `figure + glyph + scene` combination defines an empty data set
                glyph_idvals, hpcdf_fgs2 = None, None
            else:
                # Making sure the `figure + glyph + scene` combination define a single data set
                assert hpcdf_fgs.shape[0] == len(glyph_vreprs), dedent(f'''
                    The `figure + glyph + scene` combination did not define 
                    a unique dataset to load. We expected the number of dataframe 
                    rows to be identical to the number of glyph v_reprs.
                        figure name: {fig_name}
                        glyph name: {glyph_name}
                        glyph vars: {glyph_vars}
                        The narrowed down df size: {hpcdf_fgs.shape[0]}
                        The glyph v_reprs: {glyph_vreprs}
                        The glyph scene id: {glyph_scnid}''')

                hpcdf_fgsgb = list(hpcdf_fgs.groupby(glyph_idcols, sort=False, observed=True))
                assert len(hpcdf_fgsgb) == 1, dedent(f'''
                    Multiple datasets were found for this particular glyph and scene. The glyph 
                    id cols may be under-specified, or the glyph must be too special!
                        figure name: {fig_name}
                        glyph name: {glyph_name}
                        glyph vars: {glyph_vars}
                        The narrowed down df size: {hpcdf_fgs.shape[0]}
                        The glyph v_reprs: {glyph_vreprs}
                        The glyph scene id: {glyph_scnid}''')

                # `glyph_idvals` is a tuple of values. It will be used as a key to 
                # the `bk_srcbook` to identify the starting and ending indices.
                glyph_idvals, hpcdf_fgs2 = hpcdf_fgsgb[0]

            # Now, we will open up the bokeh sources business
            bk_srcinfo = bk_srcs[glyph_name][tuple(glyph_keyvals)]
            # Getting the bokeh source and view and size
            bk_src = bk_srcinfo['src']
            # The current size of the bokeh source
            n_bksrcrows = bk_srcinfo['len']
            # The mapping of each data id to the starting and ending 
            # row indices within the Bokeh Data Source.
            bk_srcbook = bk_srcinfo['book']

            if glyph_idvals not in bk_srcbook:
                # Example:
                #   data_repr = {'pnts': np.randn(n_pnts, 2)}
                #   data_repr = {
                #       'mu': np.randn(n_pnts, 2),
                #       'sig': np.randn(n_pnts, 2),
                #       'phi': np.randn(n_pnts, 2)}

                if hpcdf_fgs2 is not None:
                    data_repr = dict()
                    hpcdf_fgs3 = hpcdf_fgs2.dense()

                    for rowvals in hpcdf_fgs3.itertuples(index=False, name='row01'):
                        rowdict = dict(zip(hpcdf_fgs3.columns, rowvals))
                        # Calling rio_handler to grab the data
                        data_lblnp3 = resolve(rio_handler, 'rio_handler')(rowdict)
                        data_repr[rowdict['v_repr']] = data_lblnp3

                    # Converting the `data_repr` into a dictionary that fits the needs of Bokeh
                    bk_srcdata, n_pnts = Changer.get_bksrcdata(data_repr, glyph_info)
                else:
                    # There was no data to load!
                    assert glyph_idvals is None
                    bk_srcdata, n_pnts = None, 0
        
                if verbose:
                    print(f'Streaming: {(glyph_name, *glyph_keyvals)}')

                if n_pnts > 0:
                    # Streaming the new data to the figure source
                    bk_src.stream(bk_srcdata)
                
                # Updating the Bokeh source book to reflect the starting and ending row indices
                bk_srcbook[glyph_idvals] = (n_bksrcrows, n_bksrcrows + n_pnts)
                # Bumping up the current size of the bokeh source
                bk_srcinfo['len'] = n_bksrcrows + n_pnts

            # The starting and ending row indices of this data
            i1_view, i2_view = bk_srcbook[glyph_idvals]

            # Getting the bokeh view 
            bk_view = bk_glyphinfo['view']

            if verbose:
                print(f'View Update: {(fig_name, glyph_name, *glyph_keyvals)} -> {i1_view}: {i2_view}')

            # Updating the indices
            bkview_indices = list(range(i1_view, i2_view))
            bk_view.filter.indices = bkview_indices

    @staticmethod
    @without_property_validation
    def get_bksrcdata(data_repr, glyph_info):
        """Convert the loaded arrays of one data set into Bokeh source columns.

        Args:
            data_repr: dict from a v_repr name to its array, e.g. `{'pnts': ...}`
                or `{'mu': ..., 'sig': ..., 'phi': ...}`. The entries that get
                used are popped, so the caller's dict is emptied.
            glyph_info: the glyph description. `'type'` and `'bkcols'` are
                always read; the stacked-area, stacked-bar and heat-map types
                also read `'n_x'`, `'n_y'`, `'y_nm2idx'` and one of
                `'x_values'` / `'x_names'` (and `'y_names'` for heat-maps).

        Returns:
            `(bk_srcdata, n_pnts)`: the dict of equally long columns and their
            common length. Both are empty/zero when nothing could be produced.

        Raises:
            AssertionError: on unexpected shapes, on column names that differ
                from `glyph_info['bkcols']`, or when `data_repr` has left-overs.
        """
        kind = glyph_info['type']
        # The bokeh column names for this glyph (e.g., ['x', 'y'])
        bkcols = glyph_info['bkcols']

        # The columns to be streamed and their common number of rows
        columns, n_pnts = dict(), 0

        if (kind == 'sct') and ('pnts' in data_repr):
            ################## Scatter glyphs: one row per point ##################
            xy = data_repr.pop('pnts')
            n_pnts = xy.shape[0]
            assert xy.shape == (n_pnts, 2)

            columns['x'], columns['y'] = xy[:, 0], xy[:, 1]
            assert set(columns.keys()) == set(bkcols)

        elif (kind == 'blb') and all(key in data_repr for key in ('mu', 'sig', 'phi')):
            ################# Ellipse glyphs: one row per ellipse #################
            centers = data_repr.pop('mu')
            n_pnts = centers.shape[0]
            assert centers.shape == (n_pnts, 2)
            spreads = data_repr.pop('sig')
            assert spreads.shape == (n_pnts, 2)
            angles = data_repr.pop('phi')
            assert angles.shape == (n_pnts, 1)

            # The width and height columns are twice the `sig` values
            columns.update(
                x=centers[:, 0],
                y=centers[:, 1],
                w=(2 * spreads[:, 0]),
                h=(2 * spreads[:, 1]),
                a=angles[:, 0])
            assert set(columns.keys()) == set(bkcols)

        elif (kind == 'blb') and ('mu' in data_repr):
            # Only a part of the ellipse description is available; the mean 
            # is discarded since nothing can be drawn from it alone.
            data_repr.pop('mu')

        elif (kind == 'vas') and ('pnts' in data_repr):
            ############## Stacked area glyphs: a step-shaped outline #############
            n_bins, n_chem = glyph_info['n_x'], glyph_info['n_y']
            heights = data_repr.pop('pnts')
            assert heights.shape == (n_chem, n_bins)

            edges = glyph_info['x_values']
            assert len(edges) == n_bins + 1
            row_of = glyph_info['y_nm2idx']

            # Every bin edge appears twice in a row, and every bin height is 
            # held over two rows with a zero at both ends:
            #    x = [x_0, x_0, x_1, x_1, ..., x_n, x_n]
            #    y = [0.0, y_0, y_0, y_1, ..., y_{n-1}, 0.0]
            n_pnts = 2 * (n_bins + 1)
            edge_pairs = np.stack((edges, edges), axis=-1)
            assert edge_pairs.shape == (n_bins + 1, 2)
            columns['x'] = edge_pairs.ravel().tolist()
            assert len(columns['x']) == (n_bins + 1) * 2

            for y_name, i_chem in row_of.items():
                bin_heights = heights[i_chem]
                assert bin_heights.shape == (n_bins,)
                held = np.repeat(bin_heights, 2)
                assert held.shape == (n_bins * 2,)
                outline = [0.0, *held.tolist(), 0.0]
                assert len(outline) == ((n_bins + 1) * 2)
                columns[y_name] = outline

            assert set(columns.keys()) == set(bkcols)

        elif (kind == 'vbs') and ('pnts' in data_repr):
            ################ Stacked bar glyphs: one row per bin ##################
            n_bins, n_chem = glyph_info['n_x'], glyph_info['n_y']
            heights = data_repr.pop('pnts')
            assert heights.shape == (n_chem, n_bins)

            bin_names = glyph_info['x_names']
            assert len(bin_names) == n_bins
            row_of = glyph_info['y_nm2idx']

            n_pnts = n_bins
            columns['x'] = bin_names
            columns.update({y_name: heights[i_chem] for y_name, i_chem in row_of.items()})

            assert set(columns.keys()) == set(bkcols)

        elif (kind == 'hm') and ('pnts' in data_repr):
            ################# Heat-map glyphs: one row per cell ###################
            n_bins, n_chem = glyph_info['n_x'], glyph_info['n_y']
            values = data_repr.pop('pnts')
            assert values.shape == (n_chem, n_bins)

            edges = glyph_info['x_values']
            assert len(edges) == n_bins + 1
            row_of = glyph_info['y_nm2idx']
            y_names = glyph_info['y_names']
            n_y = len(y_names)

            # The horizontal centre and width of every bin
            centers = (edges[:-1] + edges[1:]) / 2.0
            assert centers.shape == (n_bins,)
            widths = (edges[1:] - edges[:-1])
            assert widths.shape == (n_bins,)

            # The rows of the array re-ordered to follow `y_names`
            grid = values[[row_of[y_name] for y_name in y_names]]
            assert grid.shape == (n_y, n_bins)

            # The cells are laid out name by name, with all bins of one name together
            x_flat = np.tile(centers, n_y)
            assert x_flat.shape == (n_y * n_bins,)
            w_flat = np.tile(widths, n_y)
            assert w_flat.shape == (n_y * n_bins,)
            y_flat = []
            for y_name in y_names:
                y_flat.extend([y_name] * n_bins)
            assert len(y_flat) == (n_y * n_bins)
            v_flat = grid.ravel()
            assert v_flat.shape == (n_y * n_bins,)

            n_pnts = n_y * n_bins
            columns.update(x=x_flat, y=y_flat, w=w_flat, v=v_flat)

            assert set(columns.keys()) == set(bkcols)

        # All columns must agree on the length, and every input must have been used
        assert all(len(vals) == n_pnts for key, vals in columns.items())
        assert len(data_repr) == 0, f'unused keys: {data_repr}'

        return columns, n_pnts
     
    @staticmethod
    def update_legends(fig_infos, bk_glyphviewshie, bk_srcshie):
        # Disabling legend labels without data
        for fig_name, fig_info in fig_infos.items():
            fig = fig_info['fig']
            fig_type = fig_info['type']
            if fig_type == 'scatter':
                for handle in fig.legend.items:
                    n_hndlpnts = sum(
                        len(bk_glyph.view.filter.indices) 
                        for bk_glyph in handle.renderers)
                        
                    handle.visible = (n_hndlpnts > 0)

        # Sharing the same color-bar range for all heatmaps within the same frame
        framefig_vrngs = dict()
        for (fig_name, glyph_name, *glyph_keyvals), bk_glyphinfo in bk_glyphviewshie.items():
            bk_glyph = bk_glyphinfo['glyph']
            bk_view = bk_glyphinfo['view']
            fig_info = fig_infos[fig_name]
            fig_frame = fig_info['frame']
            fig_type = fig_info['type']
            if fig_type == 'hmap':
                bk_src = bk_srcshie[(glyph_name, *glyph_keyvals)]['src']
                fig_vals = bk_src.data['v'][bk_view.filter.indices]
                if len(fig_vals) > 0:
                    framefig_vrngs[(fig_frame, fig_name)] = (fig_vals.min(), fig_vals.max())

        for frame_name, fig_vrngs in hie2deep(framefig_vrngs, kind='tuple').items():
            # Finding the minimum and maximum range of values in all figures
            if len(fig_vrngs) > 0:
                v_min = min(minmax[0] for minmax in fig_vrngs.values())
                v_max = max(minmax[1] for minmax in fig_vrngs.values())
            else:
                v_min, v_max = None, None

            # Setting the same range of values for all figures and color-bars
            for (fig_name, glyph_name, *glyph_keyvals), bk_glyphinfo in bk_glyphviewshie.items():
                if fig_name not in fig_vrngs:
                    continue
                fig_info = fig_infos[fig_name]
                fig_type = fig_info['type']
                assert fig_type == 'hmap'

                bk_glyph = bk_glyphinfo['glyph']
                bk_rectglyph, bk_cbglyph = bk_glyph
                bk_rectglyph.glyph.fill_color.transform.low = v_min
                bk_rectglyph.glyph.fill_color.transform.high = v_max

    def __call__(self, attr, old, new, sceneid=None):
        # Adjusting the controls and finding the right fpidx
        if sceneid is None:
            sceneid = self.adjust_controls(self.ctrl_refs, self.verbose)

        # Only updating the bokeh data sources when the controls actually changed
        if sceneid != self.last_sceneid:
            if self.verbose:
                print(f'Got a new scene: {sceneid}')
                
            # Clearing the client data if we have too much of it!
            if self.nrows_bksrcs >= self.max_strmdrows:
                self.clear_data()

            # Loading the data from the disk
            self.load_data(sceneid, self.bk_glyphviewshie, self.glyph_infos, 
                self.glyph_type2vreprs, self.bk_srcs, self.verbose)

            # Disabling and enabling the legend handles
            self.update_legends(self.fig_infos, self.bk_glyphviewshie, self.bk_srcshie)

            # Updating the latest control scene
            self.last_sceneid = sceneid


In [None]:
# Building the data handler that is passed to the changer below
rio_handler = RioHandler(
    data_dict=data_dict, 
    hash_data=hash_data, 
    trn_spltidxs=trn_spltidxs, 
    n_seeds=n_seeds, 
    eval_bs=eval_bs, 
    n_snr=n_snr, 
    n_t=n_t, 
    tch_device=tch_device, 
    tch_dtype=tch_dtype, 
    n_ppcache=None, 
    n_origcache=None, 
    n_rcnstcache=None, 
    verbose=False)

# Building the callback object that reacts to the control changes
changer = Changer(
    ctrl_refs=ctrl_refs, 
    glyph_type2vreprs=glyph_type2vreprs, 
    rio_handler=rio_handler, 
    glyph_infos=glyph_infos, 
    bk_srcshie=bk_srcshie, 
    bk_glyphviewshie=bk_glyphviewshie, 
    fig_infos=fig_infos, 
    max_strmdrows=max_strmdrows, 
    verbose=False)

# Invoking the callback once by hand, with placeholder `(attr, old, new)` 
# arguments, so that the scene is derived from the current control states
changer(None, None, None)

In [None]:
# Indexing the control infos by their (unique) control column name
ctrl_col2info = {}
for ctrl_ref in ctrl_refs.values():
    for ctrl_info in ctrl_ref['infos']:
        ctrl_col = ctrl_info['col']
        assert ctrl_col not in ctrl_col2info
        ctrl_col2info[ctrl_col] = ctrl_info

# # The simple old code without any control layouts
# controls_datalst = [ctrl_info['control'] 
#     for compid, ctrl_ref in ctrl_refs.items()
#     for ctrl_info in ctrl_ref['infos']]
# # Adding the control callbacks
# for control in controls_datalst:
#     control.on_change('value', changer)

# Walking the user-specified control layout row by row, collecting the control 
# containers in the layout order and attaching the changer callback to each control
controls_cntnrlst = []
for ctrl_layrow in ctrl_layout:
    for ctrl_col in ctrl_layrow:
        if ctrl_col is None:
            # An empty layout slot is kept as a `None` place-holder
            controls_cntnrlst.append(None)
            continue

        ctrl_info = ctrl_col2info[ctrl_col]
        container = ctrl_info['container']
        ctrl_type = ctrl_info['type']
        control = ctrl_info['control']

        # Radio button groups are watched through `active`; the rest through `value`
        if ctrl_type == 'radiobtngrp':
            watched_attr = 'active'
        else:
            watched_attr = 'value'
        control.on_change(watched_attr, changer)

        controls_cntnrlst.append(container)

# The scatter visual sliders are placed after the data selection controls
if need_sctvisualctrls:
    controls_cntnrlst.extend([alpha_slider, size_slider])

# Making the data selection control frame, and listing it before the other frames
frame_info = {'title': 'Data Selection', 'type': 'ctrl', 'nrows': n_ctrlrows, 'ncols': n_ctrlcols}
frame_infos = {'datasel': frame_info, **frame_infos}

# Attaching the bokeh children (controls or figures) to each frame
for frame_name, frame_info in frame_infos.items():
    frame_type = frame_info['type']
    if frame_type == 'ctrl':
        frame_info['bkchilds'] = controls_cntnrlst
    elif frame_type in ('bar', 'hmap', 'scatter'):
        frame_figs = []
        for fig_name in frame_info['fig_names']:
            frame_figs.append(fig_infos[fig_name]['fig'])
        frame_info['bkchilds'] = frame_figs
    else:
        raise ValueError(f'undefined frame_type={frame_type}')

In [None]:
# The direction along which the frames are stacked: either 'col' or 'row'
layout_dirctn = 'col'

# Building a titled layout (a heading followed by a grid of children) for each frame
for frame_name, frame_info in frame_infos.items():
    frame_type = frame_info['type']
    n_framerows = frame_info['nrows']
    n_framecols = frame_info['ncols']
    frame_bkchilds = frame_info['bkchilds']
    n_bkchilds = len(frame_bkchilds)

    # Arranging the frame children in a grid
    if frame_type == 'ctrl':
        grid_szngmodes = {'col': 'stretch_width', 'row': 'stretch_width'}
        sizing_mode = grid_szngmodes[layout_dirctn]
        grid_layout = grid(frame_bkchilds, ncols=n_framecols, 
            nrows=n_framerows, sizing_mode=sizing_mode)
    elif frame_type in ('bar', 'hmap', 'scatter'):
        grid_szngmodes = {'col': 'stretch_both', 'row': 'stretch_both'}
        sizing_mode = grid_szngmodes[layout_dirctn]
        grid_layout = gridplot(frame_bkchilds, ncols=n_framecols,
            merge_tools=True, sizing_mode=sizing_mode)
    else:
        raise ValueError(f'undefined frame_type={frame_type}')

    # Making the centered frame heading
    frame_title = frame_info['title']
    headszng_opts = {
        'col': dict(width=3000, height=ctrl_height, sizing_mode='fixed'),
        'row': dict(width=ctrl_width, height=ctrl_height, sizing_mode='fixed')}
    headszng_kwargs = headszng_opts[layout_dirctn]
    frame_heading = Div(text=f'<h1 style="text-align: center">{frame_title}</h1>',
        styles={'color': header_color}, disable_math=False, **headszng_kwargs)

    # Picking the sizing options of the whole frame
    if layout_dirctn == 'col':
        if frame_type == 'ctrl':
            sizing_mode = 'stretch_width'
        else:
            sizing_mode = "scale_height"
        frame_szngkwargs = dict(sizing_mode=sizing_mode, width=3000)
    elif layout_dirctn == 'row':
        if frame_type == 'ctrl':
            sizing_mode = 'fixed'
        else:
            sizing_mode = "stretch_both"
        frame_szngkwargs = dict(sizing_mode=sizing_mode, height=1500)
    else:
        raise ValueError(f'undefined layout_dirctn={layout_dirctn}')

    # Stacking the heading on top of the grid, and storing all the layout pieces
    frame_layout = column([frame_heading, grid_layout], background=background_color, 
        **frame_szngkwargs)
    frame_info['layouts'] = {'heading': frame_heading, 'grid': grid_layout, 'frame': frame_layout}

In [None]:
def _frame_layout(frame_name):
    """Returns the layout stored under `frame_infos[frame_name]['layouts']['frame']`.

    Going through this single accessor avoids repeating (and mistyping) 
    the nested dictionary keys for each of the frames below.
    """
    return frame_infos[frame_name]['layouts']['frame']

# Assembling the final dashboard layout out of the individual frame layouts
if layout_dirctn == 'col':
    ctrl_frame = _frame_layout('datasel')
    
    # The PP Normalization Bar Plots
    bar01_frame = _frame_layout('stckbar1')
    bar02_frame = _frame_layout('stckbar2')
    bar03_frame = _frame_layout('stckbar3')
    bar04_frame = _frame_layout('stckbar4')
    bar05_frame = _frame_layout('stckbar5')
    bar06_frame = _frame_layout('stckbar6')

    # The PP Magnitude Bar Plots with the diameter bins as the x axis
    bar07_frame = _frame_layout('magbins1')
    bar08_frame = _frame_layout('magbins2')
    bar09_frame = _frame_layout('magbins3')
    bar10_frame = _frame_layout('magbins4')

    # The PP Magnitude Bar Plots with the chemicals as the x axis
    bar11_frame = _frame_layout('magchems1')
    bar12_frame = _frame_layout('magchems2')
    bar13_frame = _frame_layout('magchems3')
    bar14_frame = _frame_layout('magchems4')

    # The PP Magnitude Bar Plots with a single column on the x axis
    bar15_frame = _frame_layout('magbar1')
    bar16_frame = _frame_layout('magbar2')
    bar17_frame = _frame_layout('magbar3')
    bar18_frame = _frame_layout('magbar4')
    
    # The PP Count Bar Plots
    bar19_frame = _frame_layout('cntbins1')
    bar20_frame = _frame_layout('cntbins2')
    bar21_frame = _frame_layout('cntbins3')
    bar22_frame = _frame_layout('cntbins4')

    # The PP Normalization Heatmaps (currently disabled)
    # hmap01_frame = frame_infos['heatmap1']['layouts']['frame']
    # hmap02_frame = frame_infos['heatmap2']['layouts']['frame']
    # hmap03_frame = frame_infos['heatmap3']['layouts']['frame']
    # # hmap04_frame = frame_infos['heatmap4']['layouts']['frame']
    # hmap05_frame = frame_infos['heatmap5']['layouts']['frame']
    chmbars_frame = _frame_layout('chembars')

    # The sizing options shared by all the regular rows of bar plot frames
    barrow_kwargs = dict(height=400, sizing_mode="stretch_width", 
        background=background_color)

    # The PP Normalization, Count, and Chemical Bars Frame Rows
    frame_row1a = row([bar01_frame, bar02_frame, bar03_frame], **barrow_kwargs)
    frame_row1b = row([bar04_frame, bar05_frame, bar06_frame], **barrow_kwargs)
    frame_row4 = row([bar19_frame, bar20_frame, bar21_frame, bar22_frame], 
        **barrow_kwargs)
    frame_row5 = row([chmbars_frame], height=1100, sizing_mode="stretch_width", 
        background=background_color)
    
    # # The PP Magnitude Frame Rows
    # frame_row2a = row([bar07_frame, bar08_frame], height=400, 
    #     sizing_mode="stretch_width", background=background_color)
    # frame_row3a = row([bar09_frame, bar10_frame], height=400, 
    #     sizing_mode="stretch_width", background=background_color)
    # frame_row2b = row([bar11_frame, bar12_frame], height=400, 
    #     sizing_mode="stretch_width", background=background_color)
    # frame_row3b = row([bar13_frame, bar14_frame], height=400, 
    #     sizing_mode="stretch_width", background=background_color)
    # frame_row2c = row([bar15_frame, bar16_frame], height=400, 
    #     sizing_mode="stretch_width", background=background_color)
    # frame_row3c = row([bar17_frame, bar18_frame], height=400, 
    #     sizing_mode="stretch_width", background=background_color)

    # tab_rows23 = Tabs(tabs=[
    #     TabPanel(child=column([frame_row2a, frame_row3a], 
    #         sizing_mode="stretch_both", background=background_color), 
    #         title=renamer.encode('k_bins')),
    #     TabPanel(child=column([frame_row2b, frame_row3b], 
    #         sizing_mode="stretch_both", background=background_color), 
    #         title=renamer.encode('n_chem')),
    #     TabPanel(child=column([frame_row2c, frame_row3c], 
    #         sizing_mode="stretch_both", background=background_color), 
    #         title=renamer.encode('one'))])

    # The PP Magnitude Frame Rows
    frame_row2a = row([bar07_frame, bar08_frame, bar09_frame, bar10_frame], 
        **barrow_kwargs)
    frame_row2b = row([bar11_frame, bar12_frame, bar13_frame, bar14_frame], 
        **barrow_kwargs)
    frame_row2c = row([bar15_frame, bar16_frame, bar17_frame, bar18_frame], 
        **barrow_kwargs)

    # One tab per magnitude dimension; the tab order is 'k_bins', 'n_chem', 'one'
    tab_rows23 = Tabs(tabs=[
        TabPanel(child=frame_row2a, title=renamer.encode('k_bins')),
        TabPanel(child=frame_row2b, title=renamer.encode('n_chem')),
        TabPanel(child=frame_row2c, title=renamer.encode('one'))])

    layout = column([ctrl_frame, frame_row1a, frame_row1b, tab_rows23, 
        frame_row4, frame_row5], sizing_mode="stretch_both", 
        background=background_color)
elif layout_dirctn == 'row':
    # The controls are placed side by side with the scatter plots frame
    ctrl_frame = _frame_layout('datasel')
    sctr_frame = _frame_layout('sctrplts')
    layout = row([ctrl_frame, sctr_frame],
        sizing_mode="stretch_both", background=background_color)
else:
    raise ValueError(f'undefined layout_dirctn={layout_dirctn}')

In [None]:
def rearrng_layout(attr, old, new):
    """Activates the tab of `tab_rows23` matching the last scene's `magdim`.

    The `(attr, old, new)` arguments follow the bokeh `on_change` callback 
    signature and are not used; the magnitude dimension is read from 
    `changer.last_sceneid['magdim']` instead. A `ValueError` is raised by 
    the list lookup when the dimension is none of the listed ones.
    """
    # The position of a dimension in this list is used as the tab index
    tab_magdims = ['k_bins', 'n_chem', 'one']
    tab_rows23.active = tab_magdims.index(changer.last_sceneid['magdim'])

# Attaching the tab switching callback to the first control of the 'magdim' component
ctrl_info = ctrl_refs['magdim']['infos'][0]
control = ctrl_info['control']
ctrl_type = ctrl_info['type']
if ctrl_type == 'radiobtngrp':
    # Radio button groups are watched through `active`
    watched_attr = 'active'
else:
    watched_attr = 'value'
control.on_change(watched_attr, rearrng_layout)

In [None]:
# Wrapping the assembled bokeh layout in a panel pane using the selected document 
# theme; leaving the pane as the last expression displays it as the cell output.
bokeh_pane = pn.pane.Bokeh(layout, theme=doc_theme)
bokeh_pane