In [None]:
%matplotlib inline
import re
import os
import math
import torch
import fnmatch
import numpy as np
import pandas as pd
from os.path import exists
from textwrap import dedent
import ruamel.yaml as ruyaml
from scipy.stats import norm

from itertools import product
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from collections import defaultdict
from bokeh.io import output_file, save
from matplotlib.backends.backend_pdf import PdfPages

from partnn.io_cfg import data_dir
from partnn.tch_utils import torch_cdf
from partnn.aerometrics import get_relerr
from n08_utils import squeeze_colnames, get_aggdf, get_dfidxs, draw_matplotlib
from n08_utils import build_dashboard, get_dashdata, adjust_mtrcnames, BStrapAgg
from n20_utils import plot_mpl, load_histdata, tag_axis, print_axheader, mark_rect
from partnn.io_utils import parse_refs, deep2hie, hie2deep, get_subdict, get_subdictrnmd, get_ovatgrps
from partnn.io_utils import hie2deep, downcast_df, decomp_df, parse_refs, resio, sort_keys, load_h5data
from partnn.io_utils import save_h5data, load_h5data, get_h5du, drop_unqcols, eval_formula, results_dir, save_h5datav2

In [None]:
# Root directory of this notebook: the plotting config is read from here and 
# the figure directories are formatted relative to it.
workdir = './08_plotting'
os.makedirs(workdir, exist_ok=True)

# Sub-directory for the supplementary outputs
suppdir = f'{workdir}/supplement'
os.makedirs(suppdir, exist_ok=True)

# Aerosol Diagnostic Plots

In [None]:
def get_aerocsts():
    """
    Builds the fixed aerosol constants used by the rest of the notebook.

    Returns
    -------
    aero_csts: (dict) the channel/bin counts, the histogram edges, the 
        wave-length grid, the species names, and the `ycol2dims` mapping.

    ycol2dims: (dict) maps each stored output column to its trailing 
        `(n_channels, n_length)` dimensions. This is the very same object 
        that is stored under `aero_csts['ycol2dims']`.
    """
    n_chem, n_bins, n_wave = 15, 19, 8

    chem_species = [
        'SO4', 'NO3', 'Cl', 'NH4', 'ARO1', 'ARO2', 'ALK1', 'OLE1', 
        'API1', 'Na', 'OIN', 'OC', 'BC', 'MOC', 'H2O']

    # The diameter histogram: `n_bins` log-uniform bins (`n_bins + 1` edges)
    diam_low, diam_high = 10**(-8.75), 1e-4
    logdiam_edges = np.linspace(np.log(diam_low), np.log(diam_high), 
        n_bins + 1, endpoint=True)
    d_histbins = np.exp(logdiam_edges)
    assert d_histbins.shape == (n_bins + 1,)

    # The log-epsilon interval: `n_epshist` log-uniform bins
    eps_histmin, eps_histmax, n_epshist = 1e-3, 1e-1, 100
    logeps_edges = np.linspace(math.log(eps_histmin), math.log(eps_histmax), 
        n_epshist + 1)
    eps_histbins = np.exp(logeps_edges)
    assert eps_histbins.shape == (n_epshist + 1,)

    # The temperature interval: `n_tmprtr` uniform bins
    tmprtr_inpmin, tmprtr_inpmax, n_tmprtr = -40, 0, 100
    temprtr_bins = np.linspace(tmprtr_inpmin, tmprtr_inpmax, n_tmprtr + 1)
    assert temprtr_bins.shape == (n_tmprtr + 1,)

    # The wave-length grid: `n_wave` uniformly spaced points (end-points included)
    len_wvmin, len_wvmax = 300e-9, 1000e-9
    len_wv = np.linspace(len_wvmin, len_wvmax, n_wave)
    assert len_wv.shape == (n_wave,)

    # The trailing `(n_channels, n_length)` dimensions of each output column
    ycol2dims = {
        'm_chmprthst': (n_chem, n_bins), 
        'n_prthst': (1, n_bins), 
        'ccn_cdf': (1, n_epshist), 
        'qs_prt': (n_wave, n_bins), 
        'qscs_prt': (n_wave, n_bins), 
        'qs_pop': (n_wave, 1), 
        'qa_prt': (n_wave, n_bins), 
        'qacs_prt': (n_wave, n_bins), 
        'qa_pop': (n_wave, 1), 
        'frznfrac_tmp': (1, n_tmprtr), 
        'logfrznfrac_tmp': (1, n_tmprtr)}

    # The dictionary of aerosol constants
    aero_csts = {
        'n_chem': n_chem, 'n_bins': n_bins, 'n_wave': n_wave, 
        'eps_histmin': eps_histmin, 'eps_histmax': eps_histmax, 
        'n_epshist': n_epshist, 
        'tmprtr_inpmin': tmprtr_inpmin, 'tmprtr_inpmax': tmprtr_inpmax, 
        'n_tmprtr': n_tmprtr, 
        'len_wvmin': len_wvmin, 'len_wvmax': len_wvmax, 'len_wv': len_wv, 
        'diam_low': diam_low, 'diam_high': diam_high, 
        'd_histbins': d_histbins, 'eps_histbins': eps_histbins, 
        'temprtr_bins': temprtr_bins, 'chem_species': chem_species, 
        'ycol2dims': ycol2dims}

    return aero_csts, ycol2dims

# Notebook-level aerosol constants and the per-column trailing dimensions
aero_csts, ycol2dims = get_aerocsts()

# The sample index used for the intro figures
i_sampintro = 5391


In [None]:
def get_binidx(x, x_min, x_max, n_x, tnsfrm='none'):
    """
    Snaps `x` to the nearest node of a grid that splits `[x_min, x_max]` 
    into `n_x` equal intervals (equal in `log(x)` when `tnsfrm='log'`).

    Parameters
    ----------
    x: (float) the query value; must satisfy `x_min <= x <= x_max`.

    x_min, x_max: (float) the grid end-points.

    n_x: (int) the number of grid intervals.

    tnsfrm: (str) 'none' for a linear grid, or 'log' for a log-uniform grid.

    Returns
    -------
    x_snap: (float) the grid value at node `i_x`.

    i_x: (int) the rounded node index; asserted to lie in `[0, n_x)`.
    """
    assert (x_min <= x <= x_max), x
    assert tnsfrm in ('log', 'none')

    # Moving the query and the end-points to the (possibly log) grid space
    in_logspace = (tnsfrm == 'log')
    if in_logspace:
        x_t, lo_t, hi_t = math.log(x), math.log(x_min), math.log(x_max)
    else:
        x_t, lo_t, hi_t = x, x_min, x_max

    i_x = round((x_t - lo_t) * n_x / (hi_t - lo_t))
    assert (0 <= i_x < n_x), f'x={x}, i_x={i_x}'

    # Interpolating the end-points to recover the snapped grid value
    frac = (i_x / n_x)
    assert (0 <= frac <= 1), frac
    x_snapped = frac * hi_t + (1 - frac) * lo_t
    if in_logspace:
        x_snapped = math.exp(x_snapped)
    return x_snapped, i_x

@torch.no_grad()
def np_flatten(input, start_dim, end_dim):
    """
    Flattens a range of dimensions of a numpy array by routing it 
    through `torch.Tensor.flatten`.

    Parameters
    ----------
    input: (np.ndarray) the array to flatten.

    start_dim, end_dim: (int) the first and last dimensions to merge, 
        forwarded as-is to `torch.Tensor.flatten`.

    Returns
    -------
    (np.ndarray) the flattened array.
    """
    as_tensor = torch.from_numpy(input)
    flat_tensor = as_tensor.flatten(start_dim, end_dim)
    return flat_tensor.cpu().numpy()

def cnvrt_physunits(v_ydataraw: dict, aero_csts: dict):
    """
    Takes the freshly loaded data out of the storage HDF files, and converts the 
    physical units of the data (e.g., from kg to ug).

    Besides the unit conversions, this function

      * clips the negative `m_chmprthst` and `n_prthst` values to zero,
      * derives the `m_acsm`, `n_smps`, `m_prthst`, and `m_chmprt` columns,
      * truncates the per-bin columns at the 10um diameter edge, and
      * recomputes `qs_pop`/`qa_pop` from the truncated `qscs_prt`/`qacs_prt`.

    The arrays inside `v_ydataraw` are never written to; the in-place scalings 
    below only act on the copies produced by the clipping step.

    Parameters
    ----------
    v_ydataraw: (dict) a dictionary of the loaded data arrays, keyed 
        as `'<ycol>/<variant>'`.

    aero_csts: (dict) a dictionary of the physical aerosol constants.

    Returns
    -------
    v_ydata: (dict) The same data with the physical units converted, plus 
        the derived columns listed above.

    aero_cstscnv: (dict) A copy of the aerosol constants with the diameter 
        edges and wave-lengths converted to `um`, the bin count updated 
        for the truncation, and the ACSM/SMPS constants added.
    """

    n_bins = aero_csts['n_bins']
    n_wave = aero_csts['n_wave']
    chem_species = aero_csts['chem_species']

    d_histbins = aero_csts['d_histbins']
    assert d_histbins.shape == (n_bins + 1,)
    len_wv = aero_csts['len_wv']
    assert len_wv.shape == (n_wave,)

    # Finding `log10len_bin`, i.e., the width of one diameter bin in log10 units
    logdiams = np.log(d_histbins)
    logdiam_low, logdiam_high = logdiams[0], logdiams[-1]
    log10len_bin = (logdiam_high - logdiam_low) / (n_bins * np.log(10))

    # Making sure the `d_histbins` is log-uniformly distributed.
    logdiams2 = np.linspace(logdiam_low, logdiam_high, n_bins + 1, endpoint=True)
    d_histbins2 = np.exp(logdiams2)
    assert d_histbins2.shape == (n_bins + 1,)
    assert np.allclose(d_histbins, d_histbins2)

    # Converting the diameter data unit from `m` to `um`
    d_histbinsum = d_histbins * 1e6
    len_wvum = len_wv * 1e6

    ########### Data Cleaning and Unit Conversions ############
    # A shallow copy: entries are re-bound below, the raw dict stays intact
    v_ydata = dict(v_ydataraw)
    
    # The variants available for each column.
    # NOTE(review): presumably `get_subdict` returns the entries under the 
    # `<ycol>/` prefix keyed by the variant name -- confirm against `io_utils`.
    ycols = ['m_chmprthst', 'n_prthst', 'qs_pop', 'qa_pop', 
        'qs_prt', 'qa_prt', 'qscs_prt', 'qacs_prt']
    ycol2vrnts = {ycol: list(get_subdict(v_ydata, ycol)) for ycol in ycols}
    # The derived mass columns exist for the same variants as `m_chmprthst`
    ycol2vrnts['m_prthst'] = ycol2vrnts['m_chmprthst']
    ycol2vrnts['m_chmprt'] = ycol2vrnts['m_chmprthst']

    # Zeroing out negative particle mass or count value
    # (`np.clip` returns new arrays, so the raw arrays are left untouched)
    for vrnt in ycol2vrnts['m_chmprthst']:
        v_ydata[f'm_chmprthst/{vrnt}'] = np.clip(v_ydata[f'm_chmprthst/{vrnt}'], 0, None)
    for vrnt in ycol2vrnts['n_prthst']:
        v_ydata[f'n_prthst/{vrnt}'] = np.clip(v_ydata[f'n_prthst/{vrnt}'], 0, None)

    # The ACSM mass measurements
    #   1. The `1e9` rate is for conversion from `kg` to `ug`
    #   2. The organic species are summed into a single `OA` channel
    inorg_species = ['SO4', 'NO3', 'NH4']
    orgnc_species = ['OC', 'MOC', 'ARO1', 'ARO2', 'ALK1', 'OLE1', 'API1']
    i_ychmsinorg = [chem_species.index(chm) for chm in inorg_species]
    i_ychmsorgnc = [chem_species.index(chm) for chm in orgnc_species]
    acsm_species = inorg_species + ['OA']
    for vrnt in ycol2vrnts['m_chmprthst']:
        # Summing over the diameter bins (before any density scaling)
        m_chmprtvrnt = v_ydata[f'm_chmprthst/{vrnt}'].sum(axis=-1)
        m_inorgvrnt = m_chmprtvrnt[..., i_ychmsinorg]
        m_orgncvrnt = m_chmprtvrnt[..., i_ychmsorgnc].sum(axis=-1, keepdims=True)
        m_acsmvrnt = np.concatenate([m_inorgvrnt, m_orgncvrnt], axis=-1)
        v_ydata[f'm_acsm/{vrnt}'] = m_acsmvrnt[..., None] * 1e9

    # The SMPS particle concentration measurements
    # The `1e-6` rate is for conversion from `#/m^3` to `#/cm^3`
    # Restricting the smps diameter bins to 10nm-560nm.
    i1_smpsdiam = np.abs(d_histbinsum[:-1] - 0.010).argmin()
    i2_smpsdiam = np.abs(d_histbinsum[:-1] - 0.560).argmin()
    assert abs(d_histbinsum[i1_smpsdiam] - 0.010) < 0.001, dedent(f'''
        Could not find a 10nm diameter: 
            d_histbinsum: {d_histbinsum}''')
    assert abs(d_histbinsum[i2_smpsdiam] - 0.560) < 0.010, dedent(f'''
        Could not find a 560nm diameter: 
            d_histbinsum: {d_histbinsum}''')
    n_binssmps = i2_smpsdiam - i1_smpsdiam
    d_binssmpsum = d_histbinsum[i1_smpsdiam: i2_smpsdiam + 1]
    for vrnt in ycol2vrnts['n_prthst']:
        v_ydata[f'n_smps/{vrnt}'] = v_ydata[f'n_prthst/{vrnt}'][..., i1_smpsdiam: i2_smpsdiam] * 1e-6

    # Converting the mass data unit:
    #   1. The `1e9` rate is for conversion from `kg` to `ug`
    #   2. The `1/log10len_bin` rate is for diameter histogram density scaling.
    for vrnt in ycol2vrnts['m_chmprthst']:
        v_ydata[f'm_chmprthst/{vrnt}'] *= (1e9 / log10len_bin)

    # Converting the number concentration data unit:
    #   1. The `1e-6` rate is for conversion from `#/m^3` to `#/cm^3`
    #   2. The `1/log10len_bin` rate is for diameter histogram density scaling.
    for vrnt in ycol2vrnts['n_prthst']:
        v_ydata[f'n_prthst/{vrnt}'] *= (1e-6 / log10len_bin)

    # Converting the cross-section unit from `1/m` to `1/Mm`
    for ycol in ['qs_pop', 'qscs_prt', 'qa_pop', 'qacs_prt']:
        for vrnt in ycol2vrnts[ycol]:
            v_ydata[f'{ycol}/{vrnt}'] = v_ydata[f'{ycol}/{vrnt}'] * 1e6

    # Computing the total mass data for each particle
    # (summing the already-scaled histogram over the species axis)
    for vrnt in ycol2vrnts['m_prthst']:
        v_ydata[f'm_prthst/{vrnt}'] = v_ydata[f'm_chmprthst/{vrnt}'].sum(axis=-2, keepdims=True)
    
    # Computing the species mass data for each particle
    # (summing the already-scaled histogram over all the diameter bins)
    for vrnt in ycol2vrnts['m_chmprt']:
        v_ydata[f'm_chmprt/{vrnt}'] = v_ydata[f'm_chmprthst/{vrnt}'].sum(axis=-1, keepdims=True)

    # Restricting the diameter bins to a max of 10um.
    i1_diam, i2_diam = 0, np.abs(d_histbinsum[:-1] - 10).argmin()
    assert abs(d_histbinsum[i2_diam] - 10) < 0.1, dedent(f'''
        Could not find a 10um diameter: 
            d_histbins: {d_histbins}
            d_histbinsum: {d_histbinsum}''')

    for ycol in ['m_chmprthst', 'm_prthst', 'n_prthst', 'qs_prt', 'qa_prt', 'qscs_prt', 'qacs_prt']:
        for vrnt in ycol2vrnts[ycol]:
            v_ydata[f'{ycol}/{vrnt}'] = v_ydata[f'{ycol}/{vrnt}'][..., i1_diam : i2_diam]
    # The population-level values are re-derived from the truncated 
    # cross-sections, replacing the loaded `qs_pop`/`qa_pop` values.
    for vrnt in ycol2vrnts['qs_pop']:
        v_ydata[f'qs_pop/{vrnt}'] = v_ydata[f'qscs_prt/{vrnt}'].sum(axis=-1, keepdims=True)
    for vrnt in ycol2vrnts['qa_pop']:
        v_ydata[f'qa_pop/{vrnt}'] = v_ydata[f'qacs_prt/{vrnt}'].sum(axis=-1, keepdims=True)
    
    d_histbinsum2 = d_histbinsum[i1_diam : i2_diam + 1] 
    n_binsum2 = len(d_histbinsum2) - 1
    chmrnm = {'OIN': 'Dust', 'OC': 'POA'}
    chem_species2 = [chmrnm.get(chm, chm) for chm in chem_species]

    # Note that `len_wvmin`/`len_wvmax` are carried over unchanged (in `m`)
    aero_cstscnv = dict(aero_csts)
    aero_cstscnv['n_bins'] = n_binsum2
    aero_cstscnv['d_histbins'] = d_histbinsum2
    aero_cstscnv['len_wv'] = len_wvum
    aero_cstscnv['i1_diam'] = i1_diam
    aero_cstscnv['i2_diam'] = i2_diam
    aero_cstscnv['chem_species'] = chem_species2
    aero_cstscnv['n_binssmps'] = n_binssmps
    aero_cstscnv['d_binssmps'] = d_binssmpsum
    aero_cstscnv['acsm_species'] = acsm_species
    aero_cstscnv['n_chemacsm'] = len(acsm_species)

    return v_ydata, aero_cstscnv

def calc_aerometrics(v_ydata: dict, aero_cstscnv: dict, avg_errs: bool, vrnt_trg: str = 'rcnst'):
    """
    Calculates the aerosol error metrics from the loaded and converted data.

    Parameters
    ----------
    v_ydata: (dict) a dictionary of the loaded data arrays with the 
        physical units converted. Each array is asserted to have the shape 
        `(n_seeds, n_snrt, n_rcns, n_chnls, n_len)`.

    aero_cstscnv: (dict) a dictionary of the physical aerosol constants 
        with the physical units converted

    avg_errs: (bool) if True, each `<col>/err` entry is averaged over the 
        samples and has the shape `(n_seeds,)`. Otherwise, the per-sample 
        errors of shape `(n_seeds, n_snrt * n_rcns)` are kept.

    vrnt_trg: (str) The target variant. Either 'rcnst' or 'znrm'.

    Returns
    -------
    v_mtrcdata: (dict) The aerosol error metrics, with the underlying data. 
        For each metric column, it holds the `<col>/<vrnt_trg>`, `<col>/err`, 
        `<col>/orig`, `<col>/origraw`, and `<col>/origcdf` entries (the last 
        two are only non-None when `vrnt_trg == 'znrm'`), along with the 
        snapped epsilon/temperature/wave-length values and their indices.
    """
    
    eps_histmin = aero_cstscnv['eps_histmin']
    eps_histmax = aero_cstscnv['eps_histmax']
    n_epshist = aero_cstscnv['n_epshist']
    tmprtr_inpmin = aero_cstscnv['tmprtr_inpmin']
    tmprtr_inpmax = aero_cstscnv['tmprtr_inpmax']
    n_tmprtr = aero_cstscnv['n_tmprtr']
    len_wvmin = aero_cstscnv['len_wvmin']
    len_wvmax = aero_cstscnv['len_wvmax']
    n_wave = aero_cstscnv['n_wave']
    n_binsum = aero_cstscnv['n_bins']
    n_chem = aero_cstscnv['n_chem']
    n_chemacsm = aero_cstscnv['n_chemacsm']
    n_binssmps = aero_cstscnv['n_binssmps']

    # Defining the specific super-saturation epsilons
    eps1, i_eps1 = get_binidx(0.001, eps_histmin, eps_histmax, n_epshist, 'log')
    eps2, i_eps2 = get_binidx(0.003, eps_histmin, eps_histmax, n_epshist, 'log')
    eps3, i_eps3 = get_binidx(0.006, eps_histmin, eps_histmax, n_epshist, 'log')

    # Defining the specific freezing temperatures
    tmp1, i_tmp1 = get_binidx(-25, tmprtr_inpmin, tmprtr_inpmax, n_tmprtr)
    tmp2, i_tmp2 = get_binidx(-17, tmprtr_inpmin, tmprtr_inpmax, n_tmprtr)
    tmp3, i_tmp3 = get_binidx(-10, tmprtr_inpmin, tmprtr_inpmax, n_tmprtr)

    # Defining the specific wave-length
    #   The wave-length grid consists of `n_wave` points including both 
    #   end-points (see `np.linspace(len_wvmin, len_wvmax, n_wave)`), and 
    #   therefore has `n_wave - 1` intervals. Passing `n_wave` here would 
    #   snap 500nm to a 475nm value that is not on the grid.
    wvl1, i_wvl1 = get_binidx(500e-9, len_wvmin, len_wvmax, n_wave - 1)

    # Computing the plotting data
    v_mtrcdata = dict()

    mtrcs_spec = [
        #             ycol          ycol2,     n_chnls,         n_len,   e_type,   i_elow,      i_ehigh, rel_rdcdims, y_tnsstr
        (        'ccn_cdf',          None,           1,     n_epshist,    'rel',   i_eps1,   i_eps3 + 1,    [-1, -2],      'y'),
        (         'qs_pop',          None,      n_wave,             1,    'rel',        0,         None,    [-1, -2],      'y'),
        (         'qa_pop',          None,      n_wave,             1,    'rel',        0,         None,    [-1, -2],      'y'),
        (         'qs_pop',   'logqs_pop',      n_wave,             1, 'logrel',        0,         None,    [-1, -2],      'y'),
        (         'qa_pop',   'logqa_pop',      n_wave,             1, 'logrel',        0,         None,    [-1, -2],      'y'),
        (         'qs_prt',          None,      n_wave,      n_binsum,     None,        0,         None,        None,      'y'),
        (         'qa_prt',          None,      n_wave,      n_binsum,     None,        0,         None,        None,      'y'),
        (       'qscs_prt',          None,      n_wave,      n_binsum,     None,        0,         None,        None,      'y'),
        (       'qacs_prt',          None,      n_wave,      n_binsum,     None,        0,         None,        None,      'y'),
        ('logfrznfrac_tmp',          None,           1,      n_tmprtr,    'rel',   i_tmp1,   i_tmp3 + 1,    [-1, -2],      'y'),
        (   'frznfrac_tmp',          None,           1,      n_tmprtr,    'rel',   i_tmp1,   i_tmp3 + 1,    [-1, -2],      'y'),
        (    'm_chmprthst',          None,      n_chem,      n_binsum,    'rel',        0,         None,    [-1, -2],      'y'),
        (       'n_prthst',          None,           1,      n_binsum,    'rel',        0,         None,    [-1, -2],      'y'),
        (       'm_prthst',          None,           1,      n_binsum,    'rel',        0,         None,    [-1, -2],      'y'),
        (    'm_chmprthst', 'm_perchmhst',      n_chem,      n_binsum,    'rel',        0,         None,        [-1],      'y'),
        (         'm_acsm',          None,  n_chemacsm,             1,    'rel',        0,         None,    [-1, -2],      'y'),
        (         'n_smps',          None,           1,    n_binssmps,    'rel',        0,         None,    [-1, -2],      'y'),
        (       'm_chmprt',          None,      n_chem,             1,    'rel',        0,         None,    [-1, -2],      'y')]

    # The original data's cdf is only needed for the 'znrm' target variant
    need_yorigcdf = (vrnt_trg == 'znrm')

    for ycol, ycol2, n_chnls, n_len, e_type, i_elow, i_ehigh, rel_rdcdims, y_tnsstr in mtrcs_spec:
        # `ycol` is the source column, and `ycol2` is the output metric name
        ycol2 = ycol if ycol2 is None else ycol2

        y_rcnst = torch.from_numpy(v_ydata[f'{ycol}/{vrnt_trg}'])
        n_seeds, n_snrt, n_rcns = y_rcnst.shape[:-2]
        assert y_rcnst.shape == (n_seeds, n_snrt, n_rcns, n_chnls, n_len)

        y_rcnst2 = y_rcnst.flatten(1, 2)
        assert y_rcnst2.shape == (n_seeds, n_snrt * n_rcns, n_chnls, n_len)

        # `y_tnsstr` only comes from the hard-coded `mtrcs_spec` table above, 
        # so this `eval` never sees external input.
        tnsfm_y = eval(f'lambda y: {y_tnsstr}')
        v_mtrcdata[f'{ycol2}/{vrnt_trg}'] = tnsfm_y(y_rcnst2).detach().cpu().numpy()

        # Without the original data, no error or cdf can be computed
        if f'{ycol}/orig' not in v_ydata:
            v_mtrcdata[f'{ycol2}/err'] = None    
            v_mtrcdata[f'{ycol2}/orig'] = None
            v_mtrcdata[f'{ycol2}/origraw'] = None
            v_mtrcdata[f'{ycol2}/origcdf'] = None
            continue

        y_orig1 = torch.from_numpy(v_ydata[f'{ycol}/orig'])
        n_rcns1 = y_orig1.shape[2]
        assert y_orig1.shape == (n_seeds, n_snrt, n_rcns1, n_chnls, n_len)

        # Repeating the single original sample against all the reconstructions
        assert n_rcns1 == 1
        y_orig2 = y_orig1.expand(n_seeds, n_snrt, n_rcns, n_chnls, n_len)
        assert y_orig2.shape == (n_seeds, n_snrt, n_rcns, n_chnls, n_len)
        
        y_orig3 = y_orig2.flatten(1, 2)
        assert y_orig3.shape == (n_seeds, n_snrt * n_rcns, n_chnls, n_len)

        # Restricting the error computation to the `[i_elow, i_ehigh)` window
        y_orig4 = y_orig3[..., i_elow: i_ehigh]
        y_rcnst3 = y_rcnst2[..., i_elow: i_ehigh]
        
        # Computing the per-sample errors
        if e_type == 'mae':
            y_err = (y_orig4 - y_rcnst3).abs().mean(dim=[-1, -2])
        elif e_type in ('rel', 'logrel'):
            if e_type == 'logrel':
                y_orig5 = y_orig4.clamp(min=1e-1).log()
                y_rcnst4 = y_rcnst3.clamp(min=1e-1).log()
            else:
                y_orig5, y_rcnst4 = y_orig4, y_rcnst3
            
            y_err = get_relerr(y_orig5, y_rcnst4, rdcdims=rel_rdcdims) / 2.0
            # Averaging over any dimension left unreduced by `rel_rdcdims`
            y_err = y_err.unsqueeze(-1).flatten(2, -1).mean(-1)
        else:
            assert e_type is None, f'undefined e_type = {e_type}'
            y_err = None

        # Optionally averaging the errors over the samples
        if y_err is None:
            y_err2 = None
        else:
            assert y_err.shape == (n_seeds, n_snrt * n_rcns)
            if avg_errs:
                y_err2 = y_err.mean(dim=-1).detach().cpu().numpy()
                assert y_err2.shape == (n_seeds,)
            else:
                y_err2 = y_err.detach().cpu().numpy()
                assert y_err2.shape == (n_seeds, n_snrt * n_rcns)
        
        # Computing the cdf if necessary
        if need_yorigcdf:
            y_rcnstsrtd = y_rcnst.sort(dim=2).values
            assert y_rcnstsrtd.shape == (n_seeds, n_snrt, n_rcns, n_chnls, n_len)

            y_origcdf1 = torch_cdf(y_orig1, y_rcnstsrtd, dim=2, 
                frame='data', domain=[-float('inf'), float('inf')])
            assert y_origcdf1.shape == (n_seeds, n_snrt, n_rcns1, n_chnls, n_len)

            y_origraw = y_orig1.squeeze(dim=2)
            assert y_origraw.shape == (n_seeds, n_snrt, n_chnls, n_len)

            y_origcdf2 = y_origcdf1.squeeze(dim=2)
            assert y_origcdf2.shape == (n_seeds, n_snrt, n_chnls, n_len)
        else:
            y_origraw, y_origcdf2 = None, None
        
        v_mtrcdata[f'{ycol2}/err'] = y_err2
        v_mtrcdata[f'{ycol2}/orig'] = tnsfm_y(y_orig3).detach().cpu().numpy()
        v_mtrcdata[f'{ycol2}/origraw'] = tnsfm_y(y_origraw).detach().cpu().numpy() if need_yorigcdf else None
        v_mtrcdata[f'{ycol2}/origcdf'] = tnsfm_y(y_origcdf2).detach().cpu().numpy() if need_yorigcdf else None

    v_mtrcdata.update(dict(
        eps1=eps1, i_eps1=i_eps1, eps2=eps2, i_eps2=i_eps2, eps3=eps3, i_eps3=i_eps3,
        tmp1=tmp1, i_tmp1=i_tmp1, tmp2=tmp2, i_tmp2=i_tmp2, tmp3=tmp3, i_tmp3=i_tmp3,
        wvl1=wvl1, i_wvl1=i_wvl1))
    
    return v_mtrcdata


## Collecting the Data For Aerosol Diagnostic Plots

In [None]:
# Reading the experiment to fpidx specification configs
#   The `YAML(typ='rt')` instance is the round-trip loader; the function-style 
#   `ruyaml.load(fp, ruyaml.RoundTripLoader)` API is deprecated and is 
#   rejected by the newer ruamel.yaml releases.
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.YAML(typ='rt').load(fp)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)

# `exprm_infos` is a flat dictionary keyed as `<expfam>.<expid>:<arch>/<option>`
exprm_infos = dict()
expfam_cfgs = v_mplcfgs['expspec']
for expfam, expfam_cfg in hie2deep(expfam_cfgs, maxdepth=1).items():
    # Popping the options that need special treatment. Whatever is 
    # left in `expfam_cfg` gets piped as-is to every experiment.
    exp_id2fpidx = get_subdict(expfam_cfg, 'fpidx', pop=True)
    splt2vrnts = get_subdict(expfam_cfg, 'vrnts', pop=True)
    resdir = expfam_cfg.pop('resdir').format(results_dir=results_dir)
    figdir = expfam_cfg.pop('fig/dir').format(workdir=workdir, suppdir=suppdir)
    n_seeds = expfam_cfg.pop('n_seeds')
    nicknmfrmla = expfam_cfg.pop('nicknm/frmla')
    nicknmfrmla = 'None' if nicknmfrmla is None else nicknmfrmla
    i_figcfg = get_subdict(expfam_cfg, 'fig/idx', pop=True)
    n_rcnsspec = get_subdict(expfam_cfg, 'n_rcns', pop=True)
    # Note: the piped options are captured before `n_epoch` and `n_snrt` 
    # are popped, so those two (when given) are also part of `optns_piped`.
    optns_piped = deep2hie(expfam_cfg)
    n_epoch = expfam_cfg.pop('n_epoch', 2)
    n_snrt = expfam_cfg.pop('n_snrt', 1000)
    # A single integer applies to all the split/variant pairs
    if isinstance(n_rcnsspec, int):
        n_rcnsspec = {f'{splt}/{vrnt}': n_rcnsspec 
            for splt, vrnts in splt2vrnts.items() for vrnt in vrnts}

    exprm_faminfos = dict()
    for exp_idarch, fpidx in exp_id2fpidx.items():
        expid, arch = exp_idarch.split('/')
        exprmnt = f'{expfam}.{expid}'
        namevars = {'expid': expid, 'expfam': expfam, 'arch': arch}
        nicknm = eval_formula(nicknmfrmla, namevars)
        exprm_faminfos[f'{exprmnt}:{arch}/fpidx'] = fpidx
        exprm_faminfos[f'{exprmnt}:{arch}/splt2vrnts'] = splt2vrnts
        exprm_faminfos[f'{exprmnt}:{arch}/resdir'] = resdir
        exprm_faminfos[f'{exprmnt}:{arch}/n_seeds'] = n_seeds
        exprm_faminfos[f'{exprmnt}:{arch}/n_epoch'] = n_epoch
        exprm_faminfos[f'{exprmnt}:{arch}/n_snrt'] = n_snrt
        exprm_faminfos[f'{exprmnt}:{arch}/nicknm'] = nicknm
        exprm_faminfos[f'{exprmnt}:{arch}/figdir'] = figdir
        # Unspecified split/variant pairs default to a single reconstruction
        for split, vrnts in splt2vrnts.items():
            for vrnt in vrnts:
                n_rcns = n_rcnsspec.get(f'{split}/{vrnt}', 1)
                exprm_faminfos[f'{exprmnt}:{arch}/n_rcns/{split}/{vrnt}'] = n_rcns
        for key, val in optns_piped.items():
            exprm_faminfos[f'{exprmnt}:{arch}/{key}'] = val

    # Assigning consecutive figure indices to the experiments of each figure 
    # type. A `None` start index continues from where the last type ended.
    i_fig = None
    for fig_type, i_fig0 in i_figcfg.items():
        i_fig = i_fig0 if (i_fig0 is not None) else i_fig
        assert i_fig is not None, 'first one must be specified'
        for exprm_id in list(hie2deep(exprm_faminfos, sep='/', maxdepth=1)):
            (exprmnt, arch) = exprm_id.split(':')
            exprm_faminfos[f'{exprmnt}:{arch}/i_fig/{fig_type}'] = i_fig
            i_fig += 1
        
    exprm_infos.update(exprm_faminfos)

# Only keeping the experiments that are needed for the plots below
exprm_infos = {key: val for key, val in exprm_infos.items() 
    if any(fnmatch.fnmatch(key, pat) for pat in 
        ['trad.*', 'cond.cont.acsmsmps.*', 'cond.cont.trilbl.*', 
        'cond.mqnt.indzy.nrmtrg.himb.frac.null*',
        'cond.mqnt.indzy.nrmtrg.himb.frac.pow1*'])}


In [None]:
# `viz_datas` is a flat dictionary keyed as `<exprmnt>:<arch>:<split>/<item>`
viz_datas = dict()

for exprm_id, exprm_info in hie2deep(exprm_infos, maxdepth=1).items():
    (exprmnt, arch) = exprm_id.split(':')
    fpidx = exprm_info['fpidx']
    splt2vrnts = exprm_info['splt2vrnts']
    resdir = exprm_info['resdir']
    n_epoch = exprm_info['n_epoch']
    n_seeds = exprm_info['n_seeds']
    n_snrt = exprm_info['n_snrt']
    nicknm = exprm_info['nicknm']
    figdir = exprm_info['figdir']
    # Only the last stored epoch is visualized
    i_epoch = -1
    
    # Best-effort loading: an unreadable experiment is reported (along 
    # with the cause of the failure) and skipped.
    try:
        rio = resio(fpidx=fpidx, resdir=resdir)
    except Exception as exc:
        print(f'There was an error opening {fpidx}. I will move on.')
        print(f'    {type(exc).__name__}: {exc}')
        continue
    
    for split, vrnts in splt2vrnts.items():
        print(f'Loading the data for {exprmnt}:{arch}:{split}...')

        ############## Collecting the plotting data ###############
        ycols = ['m_chmprthst', 'n_prthst', 'ccn_cdf', 'qs_prt', 'qscs_prt', 'qs_pop',
            'qa_prt', 'qacs_prt', 'qa_pop', 'frznfrac_tmp', 'logfrznfrac_tmp']

        v_ydataraw = dict()
        for ycol, vrnt in product(ycols, vrnts):
            n_ychnls, n_ylen = ycol2dims[ycol]
            n_rcns = exprm_info[f'n_rcns/{split}/{vrnt}']   
            y_nparr = rio(f'var/eval/raw/yaero:x:{ycol}/{split}/{vrnt}/pnts/data')
            # The epochs and the seeds are stacked along the first axis
            assert y_nparr.shape == (n_epoch * n_seeds, n_snrt, n_rcns, n_ychnls, n_ylen)
            y_nparr2 = y_nparr.reshape(n_epoch, n_seeds, n_snrt, n_rcns, n_ychnls, n_ylen)
            assert y_nparr2.shape == (n_epoch, n_seeds, n_snrt, n_rcns, n_ychnls, n_ylen)
            v_ydataraw[f'{ycol}/{vrnt}'] = y_nparr2[i_epoch]

        ########### Data Cleaning and Unit Conversions ############
        v_ydata, aero_cstscnv = cnvrt_physunits(v_ydataraw, aero_csts)
        n_chem = aero_cstscnv['n_chem']
        n_bins = aero_cstscnv['n_bins']
        chem_species = aero_cstscnv['chem_species']
        
        ############### Selecting Examples to Show ################
        # Each case defines a per-sample score `e_samps`, whose quantiles 
        # are used below for picking the examples to show.
        if set(vrnts) == {'orig', 'rcnst'}:
            # Scoring the samples by their reconstruction error
            n_clctn, n_prtrt = 1, 5
            with torch.no_grad():
                m_orig = torch.from_numpy(v_ydata['m_chmprthst/orig'])
                assert m_orig.shape == (n_seeds, n_snrt, 1, n_chem, n_bins)
                m_rcnst = torch.from_numpy(v_ydata['m_chmprthst/rcnst'])
                assert m_rcnst.shape == (n_seeds, n_snrt, 1, n_chem, n_bins)
                # NOTE(review): `squeeze(1)` targets the `n_snrt` axis and is a 
                # no-op unless `n_snrt == 1`; the singleton axis is dim 2. The 
                # shape assertion below holds either way -- confirm the intent.
                e_samps = get_relerr(m_orig.squeeze(1), m_rcnst.squeeze(1)
                    ).ravel().detach().cpu().numpy()
                assert e_samps.shape == (n_seeds * n_snrt,)
                fltr_hih2o, m_fltrh2o = True, m_orig
        elif set(vrnts) == {'orig', 'rcnst', 'znrm', 'yknn'}:
            # Scoring the samples by their total (summed) mass histogram
            n_clctn, n_prtrt = 1, 20
            with torch.no_grad():
                m_orig = torch.from_numpy(v_ydata['m_chmprthst/orig'])
                assert m_orig.shape == (n_seeds, n_snrt, 1, n_chem, n_bins)
                e_samps = m_orig.sum(dim=[-1, -2, -3]).ravel().detach().cpu().numpy()
                assert e_samps.shape == (n_seeds * n_snrt,)
                fltr_hih2o, m_fltrh2o = True, m_orig
        elif set(vrnts) == {'genr', 'genr0', 'genr1', 'genr2', 'genr3', 
            'gaus', 'gaus0', 'gaus1', 'gaus2', 'gaus3'}:
            # Scoring the samples randomly (with a fixed seed)
            n_clctn, n_prtrt = 20, 5
            with torch.no_grad():
                m_gauss1 = np.stack([v_ydata[f'm_chmprthst/{vrnt}'] 
                    for vrnt in ['gaus0', 'gaus1', 'gaus2', 'gaus3']], axis=-1)
                assert m_gauss1.shape == (n_seeds, n_snrt, 1, n_chem, n_bins, 4)
                m_gauss2 = m_gauss1 / (1e-30 + m_gauss1.sum(axis=(-3, -2), keepdims=True))
                assert m_gauss2.shape == (n_seeds, n_snrt, 1, n_chem, n_bins, 4)
                m_gauss3 = m_gauss2.max(axis=-1)
                assert m_gauss3.shape == (n_seeds, n_snrt, 1, n_chem, n_bins)
                np_random = np.random.RandomState(12345)
                e_samps = np_random.rand(n_seeds * n_snrt)
                assert e_samps.shape == (n_seeds * n_snrt,)
                fltr_hih2o, m_fltrh2o = True, m_gauss3
        elif set(vrnts) == {'genr'}:
            # Scoring the samples by their dust mass fraction
            n_clctn, n_prtrt = 5, 4
            with torch.no_grad():
                m_genr = v_ydata['m_chmprthst/genr']
                assert m_genr.shape == (n_seeds, n_snrt, 1, n_chem, n_bins)
                i_oin = chem_species.index('Dust')
                e_samps1 = m_genr[:, :, 0, i_oin, :].sum(axis=-1) / m_genr.sum(axis=(-3, -2, -1))
                assert e_samps1.shape == (n_seeds, n_snrt)
                e_samps = e_samps1.ravel()
                assert e_samps.shape == (n_seeds * n_snrt,)
                fltr_hih2o, m_fltrh2o = True, m_genr
        else:
            raise ValueError(f'case undefined: {vrnts}')

        # Removing the samples with high water content
        #   The samples with a water mass fraction above 10% get a nan score.
        if fltr_hih2o:
            i_water = chem_species.index('H2O')
            assert m_fltrh2o.shape == (n_seeds, n_snrt, 1, n_chem, n_bins)
            m_allchm = m_fltrh2o[:, :, 0, :, :].sum(axis=(-1, -2)).ravel()
            assert m_allchm.shape == (n_seeds * n_snrt,)
            m_water = m_fltrh2o[:, :, 0, i_water, :].sum(axis=-1).ravel()
            assert m_water.shape == (n_seeds * n_snrt,)
            m_waterfrac = m_water / m_allchm
            assert m_waterfrac.shape == (n_seeds * n_snrt,)
            e_samps[m_waterfrac > 0.1] = np.nan
        
        e_sampargsrt1 = np.argsort(e_samps)
        assert e_sampargsrt1.shape == (n_seeds * n_snrt,)

        # Filtering out the nan values in `e_samps`
        #   The sorted order of the remaining samples is preserved.
        i_nanesamp = np.where(np.isnan(e_samps))[0]
        i_nanesampset = set(i_nanesamp)
        e_sampargsrt2 = e_sampargsrt1[~np.isnan(e_samps[e_sampargsrt1])]
        n_samps2 = n_seeds * n_snrt - len(i_nanesampset)
        assert e_sampargsrt2.shape == (n_samps2,)

        # Selecting `n_prtrt` evenly spaced quantiles between 10% and 90%. 
        # Each extra collection is shifted by 100 more ranks in the sorted order.
        q_sel1 = np.linspace(0.1, 0.9, n_prtrt)
        assert q_sel1.shape == (n_prtrt,)
        q_sel = q_sel1[None, :] + 100 * np.arange(n_clctn)[:, None] / (n_samps2)
        assert q_sel.shape == (n_clctn, n_prtrt)
        iq_sel = (q_sel.ravel() * n_samps2).astype(int)
        iq_sel = np.clip(iq_sel, 0, n_samps2 - 1)
        assert iq_sel.shape == (n_clctn * n_prtrt,)
        i_sel = e_sampargsrt2[iq_sel].reshape(n_clctn, n_prtrt)
        assert i_sel.shape == (n_clctn, n_prtrt)

        viz_datas[f'{exprmnt}:{arch}:{split}/v_ydata'] = v_ydata
        viz_datas[f'{exprmnt}:{arch}:{split}/v_ydataraw'] = v_ydataraw
        viz_datas[f'{exprmnt}:{arch}:{split}/q_sel'] = q_sel
        viz_datas[f'{exprmnt}:{arch}:{split}/i_sel'] = i_sel
        viz_datas[f'{exprmnt}:{arch}:{split}/e_samps'] = e_samps
        viz_datas[f'{exprmnt}:{arch}:{split}/aero_csts'] = aero_cstscnv
        for key, val in aero_cstscnv.items():
            viz_datas[f'{exprmnt}:{arch}:{split}/{key}'] = val


## Conditional Tri-Label Generation

### Generating Conditional Tables

In [None]:
viz_datasdeep = hie2deep(viz_datas, maxdepth=1)

# Collecting one error data-frame (a row per seed) for each 
# conditional experiment on the test split
df_errslst = []
for viz_id, viz_data in viz_datasdeep.items():
    exprmnt, arch, split = viz_id.split(':')
    v_ydata = viz_data['v_ydata']
    aero_cstscnv = viz_data['aero_csts']

    if (not exprmnt.startswith('cond.cont.')) or (split not in ('test',)):
        continue

    exprm_info = get_subdict(exprm_infos, f'{exprmnt}:{arch}')
    n_seeds = exprm_info['n_seeds']
    n_snrt = exprm_info['n_snrt']
    nicknm = exprm_info['nicknm']
    figdir = exprm_info['figdir']
    n_rcns = exprm_info['n_rcns/test/znrm']

    ######## Computing the Aerosol Diagnostic Metrics #########
    v_mtrcdata1 = calc_aerometrics(v_ydata, aero_cstscnv, avg_errs=False, vrnt_trg='znrm')

    eps1, eps2, eps3 = v_mtrcdata1['eps1'], v_mtrcdata1['eps2'], v_mtrcdata1['eps3']
    i_eps1, i_eps2, i_eps3 = v_mtrcdata1['i_eps1'], v_mtrcdata1['i_eps2'], v_mtrcdata1['i_eps3']
    tmp1, tmp2, tmp3 = v_mtrcdata1['tmp1'], v_mtrcdata1['tmp2'], v_mtrcdata1['tmp3']
    i_tmp1, i_tmp2, i_tmp3 = v_mtrcdata1['i_tmp1'], v_mtrcdata1['i_tmp2'], v_mtrcdata1['i_tmp3']
    wvl1, i_wvl1 = v_mtrcdata1['wvl1'], v_mtrcdata1['i_wvl1']

    ######## Framing the Error Metrics as a DataFrame #########
    # The scalar entries are broadcast against the `rng_seed` column
    dfdict = dict(viz_id=viz_id, exprmnt=exprmnt, arch=arch, 
        split=split, epoch=-1, rng_seed=np.arange(n_seeds))

    for ycol, y_samps1 in v_mtrcdata1.items():
        if not ycol.endswith('/err'):
            continue
        
        # Averaging the per-sample errors within each seed. The metrics 
        # without an error stay `None` and are dropped after the concat.
        if y_samps1 is None:
            y_samps2 = None
        else:
            assert y_samps1.shape == (n_seeds, n_snrt * n_rcns)
            y_samps2 = y_samps1.mean(axis=-1)
            assert y_samps2.shape == (n_seeds,)

        dfdict[ycol] = y_samps2

    df_vizid = pd.DataFrame(dfdict)
    df_errslst.append(df_vizid)

df_errs = pd.concat(df_errslst, axis=0, ignore_index=True)
df_errs = df_errs.dropna(axis=1, how='all')
ycols = [col for col in df_errs.columns if col.endswith('/err')]

######## Bootstrap Aggregating the Results #########
aggcfg = dict(type='bootstrap', n_boot=40, q=[2.5, 97.5], stat='mean', device='cpu')
hpcols = ['viz_id', 'exprmnt', 'arch', 'split']
stcols = [col for col in df_errs.columns if col not in hpcols]
agg_data = get_aggdf(df_errs[hpcols], df_errs[stcols], xcol='epoch', 
    huecol='viz_id', rngcol='rng_seed', aggcfg=aggcfg)
hpdf_agg, stdf_agg = agg_data['hpdf'], agg_data['stdf']
df_agg = pd.concat([hpdf_agg, stdf_agg], axis=1)
df_agg = df_agg.drop(columns=['exprmnt', 'arch', 'split', 'epoch'])

In [None]:
######## Creating a Formatted Table for Latex #########
df_agglst2 = []
for i_row, row_dict in df_agg.iterrows():
    row_dict2 = dict(viz_id=row_dict['viz_id'])
    for ycol in ycols:
        # The aggregated statistics are shown in percentages
        y_mean = row_dict[f'{ycol}/mean'] * 100
        y_low = row_dict[f'{ycol}/low'] * 100
        y_high = row_dict[f'{ycol}/high'] * 100
        # Two significant digits below 1%, and three otherwise
        y_meanstr = f'{y_mean:.2g}' if y_mean < 1 else f'{y_mean:.3g}'
        y_lowstr = f'{y_low:.2g}' if y_low < 1 else f'{y_low:.3g}'
        y_highstr = f'{y_high:.2g}' if y_high < 1 else f'{y_high:.3g}'
        # A raw f-string keeps the latex `\%` without an invalid `\%` escape
        row_dict2[ycol] = rf'${y_meanstr}\% [{y_lowstr}\%,{y_highstr}\%]$'
    df_agglst2.append(row_dict2)
df_agg2 = pd.DataFrame(df_agglst2)

# Transposing the table
df_agg3 = df_agg2.set_index('viz_id').T
df_agg3.columns.name = None

err_rnmngs = {
    'ccn_cdf/err': 'CCN Spectrum',
    'logqa_pop/err': 'Vol Scat Coef',
    'logqs_pop/err': 'Vol Abs Coef',
    'logfrznfrac_tmp/err': 'Frozen Fraction',
    'm_acsm/err': 'ACSM Readings',
    'n_smps/err': 'SMPS Readings',
    'n_prthst/err': 'Number Dist',
    'm_prthst/err': 'Total Mass',
    'm_chmprt/err': 'Species Mass'}

col_rnmngs0 = {
    'index': 'Gen Ambiguity',
    'cond.cont.trilbl.depzy:mlp:test': 'HDM (Trad)',
    'cond.cont.trilbl.indzy:mlp:test': 'HDM (Ours)',
    'cond.cont.acsmsmps.duo.depzy:mlp:test': 'LDM (Trad)',
    'cond.cont.acsmsmps.duo.indzy:mlp:test': 'LDM (Ours)'}

col_rnmngs1 = {
    'index': r'\makecell[c]{Generative\\Ambiguity}',
    'cond.cont.trilbl.indzy:mlp:test': r'\makecell[c]{High-Dim Measurements\\Mean [95\% CI]}',
    'cond.cont.acsmsmps.duo.indzy:mlp:test': r'\makecell[c]{Low-Dim Measurements\\Mean [95\% CI]}'}

col_rnmngs2 = {
    'index': '\makecell[c]{Generative\\Ambiguity}',
    'cond.cont.trilbl.depzy:mlp:test': r'\makecell[c]{Traditional CVAE\\Mean [95\% CI]}',
    'cond.cont.trilbl.indzy:mlp:test': r'\makecell[c]{Wasserstein-Regularized CVAE (Ours)\\Mean [95\% CI]}'}

col_rnmngs3 = {
    'index': '\makecell[c]{Generative\\Ambiguity}',
    'cond.cont.acsmsmps.duo.depzy:mlp:test': r'\makecell[c]{Traditional CVAE\\Mean [95\% CI]}',
    'cond.cont.acsmsmps.duo.indzy:mlp:test': r'\makecell[c]{Wasserstein-Regularized CVAE (Ours)\\Mean [95\% CI]}'}


df_agg4 = df_agg3.copy(deep=True)
# Selecting a subset of the errors and Reordering them
df_agg5 = df_agg4.loc[list(err_rnmngs)].reset_index()
# Renaming the errors
df_agg6 = df_agg5.replace(err_rnmngs)
# Selecting a subset of the columns and Reordering them and renaming the columns
df_agg7 = df_agg6.loc[:, list(col_rnmngs0)].rename(columns=col_rnmngs0)
df_agg8 = df_agg6.loc[:, list(col_rnmngs1)].rename(columns=col_rnmngs1)
df_agg9 = df_agg6.loc[:, list(col_rnmngs2)].rename(columns=col_rnmngs2)
df_agg10 = df_agg6.loc[:, list(col_rnmngs3)].rename(columns=col_rnmngs3)

df_agg10

In [None]:
# Building the column spec from the actual number of table columns. A hard-coded 
# five-column `{lllll}` pattern never matches the three-column `df_agg10` table, 
# which silently leaves the default column spec in place. The columns share 85% 
# of the text width, which reproduces `p{0.17\textwidth}` for five columns.
n_tblcols = max(df_agg10.shape[1], 1)
col_width = f'{0.85 / n_tblcols:.2f}'
col_fmt = ('|p{' + col_width + r'\textwidth}') * n_tblcols + '|'
tbl_tex = df_agg10.to_latex(index=False, column_format=col_fmt)

# Replacing the booktabs rules with plain horizontal lines
tbl_tex = tbl_tex.replace(r'\toprule', r'\hline')
tbl_tex = tbl_tex.replace(r'\midrule', r'\hline')
tbl_tex = tbl_tex.replace(r'\bottomrule', r'\hline')
# Adding a horizontal line after each data row (all of which end with `$ \\`)
tbl_tex = tbl_tex.replace(r'$ \\', r'$ \\\hline')
print(tbl_tex)

### Anecdotal Mass and Number Plots

In [None]:
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)

# Writing one `*_xanec.pdf` file per row of `i_sel`, with one page per selected 
# sample. Each page shows the original speciated mass histogram (panel a) next 
# to the original particle number histogram (panel b).
for (viz_id, viz_data) in hie2deep(viz_datas, maxdepth=1).items():
    exprmnt, arch, split = viz_id.split(':')
    v_ydata = viz_data['v_ydata']
    i_sel = viz_data['i_sel']
    e_samps = viz_data['e_samps']
    n_binsum = viz_data['n_bins']
    d_histbinsum = viz_data['d_histbins']
    n_files, n_page = i_sel.shape

    # Only the test split of the `cond.cont.trilbl.*` experiments is plotted. The 
    # prefix carries the trailing dot to stay consistent with the other anecdotal 
    # plotting cells of this notebook.
    if (not exprmnt.startswith('cond.cont.trilbl.')) or (split not in ('test',)):
        continue

    exprm_info = get_subdict(exprm_infos, f'{exprmnt}:{arch}')
    n_seeds = exprm_info['n_seeds']
    n_snrt = exprm_info['n_snrt']
    nicknm = exprm_info['nicknm']
    figdir = exprm_info['figdir']
    i_savefig = exprm_info['i_fig/xorig']

    n_figrows, n_figcols = 1, 2

    for i_file in range(n_files):
        figs_dict = dict()
        for i_page in range(n_page):
            v_mpldatas = dict()
            # The index of the plotted sample along the flattened 
            # (n_seeds * n_snrt) axis
            i_samp = i_sel[i_file, i_page]

            # The speciated mass data
            # NOTE(review): `n_chem` and `chem_species` are not assigned in this 
            # cell; they must already exist in the notebook namespace.
            i_row, i_col = 0, 0
            n_rcns, i_rcns = 1, 0
            v_ynparr = v_ydata['m_chmprthst/orig']
            assert v_ynparr.shape == (n_seeds, n_snrt, n_rcns, n_chem, n_binsum)
            v_ynparr2 = v_ynparr.reshape(n_seeds * n_snrt, n_rcns, n_chem, n_binsum)
            assert v_ynparr2.shape == (n_seeds * n_snrt, n_rcns, n_chem, n_binsum)
            ax_idx = i_row * n_figcols + i_col
            for i_chem, chem in enumerate(chem_species):
                v_mpldatas[f'm_chmprthst.cndinp/{ax_idx}:{chem}:x'] = d_histbinsum
                v_mpldatas[f'm_chmprthst.cndinp/{ax_idx}:{chem}:y'] = v_ynparr2[i_samp, i_rcns, i_chem]

            # The particle count data; each spec is (column index, x values, y column)
            ax_ixyspec = []
            ax_ixyspec.append([1, d_histbinsum, 'n_prthst'])

            for i_row in range(n_figrows):
                for i_col, xvals, ycol in ax_ixyspec:
                    ax_idx = i_row * n_figcols + i_col
                    n_rcns, i_rcns = 1, 0
                    v_ynparr = v_ydata[f'{ycol}/orig']
                    assert v_ynparr.shape[:3] == (n_seeds, n_snrt, n_rcns)
                    v_ynparr2 = v_ynparr.reshape(n_seeds * n_snrt, n_rcns, *v_ynparr.shape[3:])
                    assert v_ynparr2.shape[:2] == (n_seeds * n_snrt, n_rcns)
                    v_mpldatas[f'{ycol}.cndinp/{ax_idx}:orig:x'] = xvals
                    v_mpldatas[f'{ycol}.cndinp/{ax_idx}:orig:y'] = np.squeeze(v_ynparr2[i_samp, i_rcns])

            # Grouping the flat `<ax_id>/<ax_idx>:<series>:<field>` keys by `<ax_id>`
            v_mpldatashie = hie2deep(v_mpldatas, maxdepth=1)
            ########## Calling Matplotlib to Plot the Diagnostics ##########
            fig, axes = None, None
            for ax_id in v_mpldatashie:
                fig, axes = plot_mpl(data=v_mpldatashie[ax_id], fig=fig, 
                    axes=axes, mplopts=v_mplcfgs[ax_id])

            # Adding the top left panel tag for axis identification
            with plt.rc_context(v_mplcfgs['rc_context']):
                for i_figrow in range(axes.shape[0]):
                    for i_figcol in range(axes.shape[1]):
                        ax = axes[i_figrow, i_figcol]
                        tag_text = f'({"ab"[i_figcol]})'
                        tag_axis(ax, tag_text, fontsize=11, pad=(0.3, 0.9))

            figs_dict[f'{i_page}/label'] = fig

        # Writing all the pages of this file into a single multi-page pdf
        os.makedirs(figdir, exist_ok=True)
        _nicknm = '' if nicknm is None else f'_{nicknm}'
        pdfpath = f'{figdir}/{i_savefig:02d}{_nicknm}_xanec.pdf'
        with PdfPages(pdfpath) as pdf:
            for i_page, fig_clctn in figs_dict.items():
                pdf.savefig(figure=fig_clctn, bbox_inches='tight', pad_inches=1/72)
        print(f'Finished writing {pdfpath}')

fig

### Anecdotal Input Label Plots

In [None]:
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)

# Writing one `*_yanec.pdf` file per row of `i_sel`, with one page per selected 
# sample. Each page shows the three original measurement panels: the particle 
# number histogram (c), the total mass histogram (d), and the species mass (e).
for (viz_id, viz_data) in hie2deep(viz_datas, maxdepth=1).items():
    exprmnt, arch, split = viz_id.split(':')
    v_ydata = viz_data['v_ydata']
    i_sel = viz_data['i_sel']
    n_chem = viz_data['n_chem']
    n_binsum = viz_data['n_bins']
    e_samps = viz_data['e_samps']
    d_histbinsum = viz_data['d_histbins']
    n_files, n_page = i_sel.shape

    # Only the test split of the `cond.cont.trilbl.*` experiments is plotted
    if (not exprmnt.startswith('cond.cont.trilbl.')) or (split not in ('test',)):
        continue

    exprm_info = get_subdict(exprm_infos, f'{exprmnt}:{arch}')
    n_seeds = exprm_info['n_seeds']
    n_snrt = exprm_info['n_snrt']
    nicknm = exprm_info['nicknm']
    figdir = exprm_info['figdir']
    i_savefig = exprm_info['i_fig/label']

    n_figrows, n_figcols = 1, 3

    for i_file in range(n_files):
        figs_dict = dict()
        for i_page in range(n_page):
            v_mpldatas = dict()
            # The index of the plotted sample along the flattened 
            # (n_seeds * n_snrt) axis
            i_samp = i_sel[i_file, i_page]

            # The particle count, total mass, and species mass data; each spec 
            # is (column index, x values, y column)
            i_chems = np.arange(n_chem + 1)
            ax_ixyspec = []
            ax_ixyspec.append([0, d_histbinsum, 'n_prthst'])
            ax_ixyspec.append([1, d_histbinsum, 'm_prthst'])
            ax_ixyspec.append([2, i_chems,      'm_chmprt'])

            for i_row in range(n_figrows):
                for i_col, xvals, ycol in ax_ixyspec:
                    ax_idx = i_row * n_figcols + i_col
                    n_rcns, i_rcns = 1, 0
                    v_ynparr = v_ydata[f'{ycol}/orig']
                    assert v_ynparr.shape[:3] == (n_seeds, n_snrt, n_rcns)
                    v_ynparr2 = v_ynparr.reshape(n_seeds * n_snrt, n_rcns, *v_ynparr.shape[3:])
                    assert v_ynparr2.shape[:2] == (n_seeds * n_snrt, n_rcns)
                    v_mpldatas[f'{ycol}.cndlbl/{ax_idx}:orig:x'] = xvals
                    v_mpldatas[f'{ycol}.cndlbl/{ax_idx}:orig:y'] = np.squeeze(v_ynparr2[i_samp, i_rcns])

            # Grouping the flat `<ax_id>/<ax_idx>:<series>:<field>` keys by `<ax_id>`
            v_mpldatashie = hie2deep(v_mpldatas, maxdepth=1)
            ########## Calling Matplotlib to Plot the Diagnostics ##########
            fig, axes = None, None
            for ax_id in v_mpldatashie:
                fig, axes = plot_mpl(data=v_mpldatashie[ax_id], fig=fig, 
                    axes=axes, mplopts=v_mplcfgs[ax_id])

            # Adding the top left panel tag for axis identification
            with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
                for i_figrow in range(axes.shape[0]):
                    for i_figcol in range(axes.shape[1]):
                        ax = axes[i_figrow, i_figcol]
                        tag_text = f'({"cde"[i_figcol]})'
                        tag_axis(ax, tag_text, fontsize=11, pad=(0.3, 0.9))
                for i_figcol, ax in enumerate(axes[0]):
                    print_axheader(ax, f'Measurement {i_figcol+1}', 'top', fontsize=14, fontweight='bold')

            figs_dict[f'{i_page}/label'] = fig

        # Making sure the output directory exists even when this cell is the 
        # first one to write into it.
        os.makedirs(figdir, exist_ok=True)
        _nicknm = '' if nicknm is None else f'_{nicknm}'
        pdfpath = f'{figdir}/{i_savefig:02d}{_nicknm}_yanec.pdf'
        with PdfPages(pdfpath) as pdf:
            for i_page, fig_clctn in figs_dict.items():
                # Each collected figure becomes one page of the pdf
                pdf.savefig(figure=fig_clctn, bbox_inches='tight', pad_inches=1/72)
        print(f'Finished writing {pdfpath}')

fig

### Anecdotal Generated Sample and Label Plots

In [None]:
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)

# Writing one `*_anec.pdf` file per row of `i_sel`, with one page per selected 
# sample. Each figure column is one generated (`znrm`) reconstruction of the 
# sample; the rows are the speciated mass, particle count, total mass, and 
# species mass panels, with the original data overlaid on the last three.
for (viz_id, viz_data) in hie2deep(viz_datas, maxdepth=1).items():
    exprmnt, arch, split = viz_id.split(':')
    v_ydata = viz_data['v_ydata']
    d_histbinsum = viz_data['d_histbins']
    n_binsum = viz_data['n_bins']
    len_wvum = viz_data['len_wv']
    i_sel = viz_data['i_sel']
    e_samps = viz_data['e_samps']
    eps_histmin = viz_data['eps_histmin']
    eps_histmax = viz_data['eps_histmax']
    eps_histbins = viz_data['eps_histbins']
    tmprtr_inpmin = viz_data['tmprtr_inpmin']
    tmprtr_inpmax = viz_data['tmprtr_inpmax']
    temprtr_bins = viz_data['temprtr_bins']

    # NOTE(review): `aero_cstscnv` is not assigned in this cell; it is whatever 
    # an earlier cell left in the notebook namespace -- confirm it matches the 
    # current `viz_data`.
    n_binssmps = aero_cstscnv['n_binssmps']
    d_binssmpsum = aero_cstscnv['d_binssmps']
    acsm_species = aero_cstscnv['acsm_species']
    n_chemacsm = aero_cstscnv['n_chemacsm']

    # (These two repeat the assignments made a few lines above.)
    i_sel = viz_data['i_sel']
    e_samps = viz_data['e_samps']
    n_files, n_page = i_sel.shape

    # Only the test split of the `cond.cont.trilbl.*` experiments is plotted
    if (not exprmnt.startswith('cond.cont.trilbl.')) or (split not in ('test',)):
        continue

    exprm_info = get_subdict(exprm_infos, f'{exprmnt}:{arch}')
    n_seeds = exprm_info['n_seeds']
    n_snrt = exprm_info['n_snrt']
    nicknm = exprm_info['nicknm']
    figdir = exprm_info['figdir']
    i_savefig = exprm_info['i_fig/gendiag']

    # Each spec is (variant, reconstruction index, title). The trailing `[1:]` 
    # drops the leading `orig` entry, leaving the five `znrm` reconstructions 
    # as the figure columns.
    # vrnt_specs = exprm_info.get('vrnt_specs', None)
    vrnt_specs = [
        ('orig', 0, 'Original'), ('znrm', 0, 'Norm Lat'), 
        ('znrm', 1, 'Norm Lat'), ('znrm', 2, 'Norm Lat'), 
        ('znrm', 3, 'Norm Lat'), ('znrm', 4, 'Norm Lat')][1:]
    # The row count of 4 is a placeholder; it is immediately replaced by the 
    # value from the plot config, while the column count is pushed into the config.
    n_figrows, n_figcols = 4, len(vrnt_specs)
    n_figrows = v_mplcfgs['m_chmprthst.cndgen']['plt.subplots/nrows']
    v_mplcfgs['m_chmprthst.cndgen']['plt.subplots/ncols'] = n_figcols
    assert len(vrnt_specs) == n_figcols

    # Padding the x-axis limits slightly beyond the data range
    v_mplcfgs['ccn_cdf.cndgen']['xlim'] = [eps_histmin * 0.9, eps_histmax * 1.1]
    v_mplcfgs['frznfrac_tmp.cndgen']['xlim'] = [tmprtr_inpmin - 1, tmprtr_inpmax + 1]
    
    for i_file in range(n_files):
        figs_dict = dict()
        for i_page in range(n_page):
            v_mpldatas = dict()
            # The index of the plotted sample along the flattened 
            # (n_seeds * n_snrt) axis
            i_samp = i_sel[i_file, i_page]
            
            # The speciated mass data
            # NOTE(review): `n_chem` and `chem_species` are not assigned in this 
            # cell; they must already exist in the notebook namespace.
            i_row = 0
            for i_col in range(n_figcols):
                vrnt, i_rcns, ttl = vrnt_specs[i_col]
                n_rcns = exprm_info[f'n_rcns/test/{vrnt}']
                v_ynparr = v_ydata[f'm_chmprthst/{vrnt}']
                assert v_ynparr.shape == (n_seeds, n_snrt, n_rcns, n_chem, n_binsum)
                v_ynparr2 = v_ynparr.reshape(n_seeds * n_snrt, n_rcns, n_chem, n_binsum)
                assert v_ynparr2.shape == (n_seeds * n_snrt, n_rcns, n_chem, n_binsum)
                ax_idx = i_row * n_figcols + i_col
                for i_chem, chem in enumerate(chem_species):
                    v_mpldatas[f'm_chmprthst.cndgen/{ax_idx}:{chem}:x'] = d_histbinsum
                    v_mpldatas[f'm_chmprthst.cndgen/{ax_idx}:{chem}:y'] = v_ynparr2[i_samp, i_rcns, i_chem]

            # The total mass, particle count, ccn, frozen fraction, and optical cross-section data
            # Each spec is (row index, x values, y column, overlay the original?). 
            # Unlike the input/label cells, the first entry here is a ROW index.
            i_chems = np.arange(n_chem + 1)
            ax_ixyspec = []
            ax_ixyspec.append([1, d_histbinsum, 'n_prthst'     , True ])
            ax_ixyspec.append([2, d_histbinsum, 'm_prthst'     , True ])
            ax_ixyspec.append([3, i_chems,      'm_chmprt'     , True ])
            # ax_ixyspec.append([4, eps_histbins, 'ccn_cdf'     , False])
            # ax_ixyspec.append([5, temprtr_bins, 'frznfrac_tmp', False])
            # ax_ixyspec.append([6, len_wvum,     'qs_pop'      , False])
            # ax_ixyspec.append([7, len_wvum,     'qa_pop'      , False])
            
            for i_col in range(n_figcols):
                for i_row, xvals, ycol, plot_orig in ax_ixyspec:
                    ax_idx = i_row * n_figcols + i_col
                    # The generated reconstruction of this column (`rcnst` series)
                    vrnt, i_rcns, ttl = vrnt_specs[i_col]
                    n_rcns = exprm_info[f'n_rcns/test/{vrnt}']
                    v_ynparr = v_ydata[f'{ycol}/{vrnt}']
                    assert v_ynparr.shape[:3] == (n_seeds, n_snrt, n_rcns)
                    v_ynparr2 = v_ynparr.reshape(n_seeds * n_snrt, n_rcns, *v_ynparr.shape[3:])
                    assert v_ynparr2.shape[:2] == (n_seeds * n_snrt, n_rcns)
                    v_mpldatas[f'{ycol}.cndgen/{ax_idx}:rcnst:x'] = xvals
                    v_mpldatas[f'{ycol}.cndgen/{ax_idx}:rcnst:y'] = np.squeeze(v_ynparr2[i_samp, i_rcns])

                    # The original data of the same sample (`orig` series), which 
                    # has a single entry along the reconstruction axis
                    if not plot_orig: continue
                    n_rcns, i_rcns = 1, 0
                    v_ynparr = v_ydata[f'{ycol}/orig']
                    assert v_ynparr.shape[:3] == (n_seeds, n_snrt, n_rcns)
                    v_ynparr2 = v_ynparr.reshape(n_seeds * n_snrt, n_rcns, *v_ynparr.shape[3:])
                    assert v_ynparr2.shape[:2] == (n_seeds * n_snrt, n_rcns)
                    v_mpldatas[f'{ycol}.cndgen/{ax_idx}:orig:x'] = xvals
                    v_mpldatas[f'{ycol}.cndgen/{ax_idx}:orig:y'] = np.squeeze(v_ynparr2[i_samp, i_rcns])
            
            # Grouping the flat `<ax_id>/<ax_idx>:<series>:<field>` keys by `<ax_id>`
            v_mpldatashie = hie2deep(v_mpldatas, maxdepth=1)
            ########## Calling Matplotlib to Plot the Diagnostics ########## 
            fig, axes = None, None
            for ax_id in v_mpldatashie:
                fig, axes = plot_mpl(data=v_mpldatashie[ax_id], fig=fig, 
                    axes=axes, mplopts=v_mplcfgs[ax_id])

            # Adding the top left panel tag for axis identification. The row 
            # letter comes from "fghi", so at most four figure rows are supported; 
            # the subscript is the 1-based column (generated sample) number.
            with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
                for i_figrow in range(axes.shape[0]):
                    for i_figcol in range(axes.shape[1]):
                        ax = axes[i_figrow, i_figcol]
                        tag_text = f'({"fghi"[i_figrow]}$_{{{i_figcol + 1}}}$)'
                        tag_axis(ax, tag_text, fontsize=11, pad=(0.3, 0.6))
                for i_figcol, ax in enumerate(axes[0]):
                    txthdr = f'Sample {i_figcol + 1}'
                    print_axheader(ax, txthdr, 'top', fontsize=14, fontweight='bold')

            figs_dict[f'{i_page}/diag'] = fig

        # Writing all the pages of this file into a single multi-page pdf
        # NOTE(review): `figdir` is not created here; this relies on an earlier 
        # cell (or a previous run) having made the directory.
        _nicknm = '' if nicknm is None else f'_{nicknm}'
        pdfpath = f'{figdir}/{i_savefig:02d}{_nicknm}_anec.pdf'
        with PdfPages(pdfpath) as pdf:
            for i_page, fig_clctn in figs_dict.items():
                pdf.savefig(figure=fig_clctn, bbox_inches='tight', pad_inches=1/72)
        print(f'Finished writing {pdfpath}')

fig

### Anecdotal Aerosol Diagnostic Plots

In [None]:
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)
# Whether to add a top row of speciated mass panels above the diagnostics row
show_mchmprt = False

# Writing one `*_anec.pdf` file per row of `i_sel`, with one page per selected 
# sample. Each page shows the spread of the generated (`znrm`) diagnostics 
# (ccn, scattering, absorption, frozen fraction) against the original ones.
for (viz_id, viz_data) in hie2deep(viz_datas, maxdepth=1).items():
    exprmnt, arch, split = viz_id.split(':')
    v_ydata = viz_data['v_ydata']
    d_histbinsum = viz_data['d_histbins']
    n_binsum = viz_data['n_bins']
    len_wvum = viz_data['len_wv']
    i_sel = viz_data['i_sel']
    e_samps = viz_data['e_samps']
    eps_histmin = viz_data['eps_histmin']
    eps_histmax = viz_data['eps_histmax']
    eps_histbins = viz_data['eps_histbins']
    tmprtr_inpmin = viz_data['tmprtr_inpmin']
    tmprtr_inpmax = viz_data['tmprtr_inpmax']
    temprtr_bins = viz_data['temprtr_bins']

    # NOTE(review): `aero_cstscnv` is not assigned in this cell; it is whatever 
    # an earlier cell left in the notebook namespace -- confirm it matches the 
    # current `viz_data`.
    n_binssmps = aero_cstscnv['n_binssmps']
    d_binssmpsum = aero_cstscnv['d_binssmps']
    acsm_species = aero_cstscnv['acsm_species']
    n_chemacsm = aero_cstscnv['n_chemacsm']

    # (These two repeat the assignments made a few lines above.)
    i_sel = viz_data['i_sel']
    e_samps = viz_data['e_samps']
    n_files, n_page = i_sel.shape

    # Only the test split of the `cond.cont.trilbl.*` experiments is plotted
    if (not exprmnt.startswith('cond.cont.trilbl.')) or (split not in ('test',)):
        continue

    exprm_info = get_subdict(exprm_infos, f'{exprmnt}:{arch}')
    n_seeds = exprm_info['n_seeds']
    n_snrt = exprm_info['n_snrt']
    nicknm = exprm_info['nicknm']
    figdir = exprm_info['figdir']
    i_savefig = exprm_info['i_fig/vardiag']
    
    # One row of diagnostics, plus one more row when `show_mchmprt` is set (the 
    # boolean adds as 0 or 1). The grid shape is pushed into every panel config.
    n_figrows, n_figcols = 1 + show_mchmprt, 4
    for ycol in ['m_chmprthst', 'n_prthst', 'ccn_cdf', 'qs_pop', 'qa_pop', 'frznfrac_tmp']:
        v_mplcfgs[f'{ycol}.cndvar']['plt.subplots/nrows'] = n_figrows
        v_mplcfgs[f'{ycol}.cndvar']['plt.subplots/ncols'] = n_figcols

    # Padding the x-axis limits slightly beyond the data range
    v_mplcfgs['ccn_cdf.cndvar']['xlim'] = [eps_histmin * 0.9, eps_histmax * 1.1]
    v_mplcfgs['frznfrac_tmp.cndvar']['xlim'] = [tmprtr_inpmin - 1, tmprtr_inpmax + 1]
    
    for i_file in range(n_files):
        figs_dict = dict()
        for i_page in range(n_page):

            v_mpldatas = dict()
            # The index of the plotted sample along the flattened 
            # (n_seeds * n_snrt) axis
            i_samp = i_sel[i_file, i_page]

            if show_mchmprt:
                # The speciated mass data
                # Column `i_col` shows the `i_col`-th generated reconstruction.
                # NOTE(review): `n_chem` and `chem_species` are not assigned in 
                # this cell; they must already exist in the notebook namespace.
                i_row, vrnt = 0, 'znrm'
                n_rcns = exprm_info[f'n_rcns/test/{vrnt}']
                for i_col in range(n_figcols):
                    v_ynparr = v_ydata[f'm_chmprthst/{vrnt}']
                    assert v_ynparr.shape == (n_seeds, n_snrt, n_rcns, n_chem, n_binsum)
                    v_ynparr2 = v_ynparr.reshape(n_seeds * n_snrt, n_rcns, n_chem, n_binsum)
                    assert v_ynparr2.shape == (n_seeds * n_snrt, n_rcns, n_chem, n_binsum)
                    ax_idx, i_rcns = i_row * n_figcols + i_col, i_col
                    for i_chem, chem in enumerate(chem_species):
                        v_mpldatas[f'm_chmprthst.cndvar/{ax_idx}:{chem}:x'] = d_histbinsum
                        v_mpldatas[f'm_chmprthst.cndvar/{ax_idx}:{chem}:y'] = v_ynparr2[i_samp, i_rcns, i_chem]

            # The total mass, particle count, ccn, frozen fraction, and optical cross-section data
            # Each spec is (column index, x values, y column, plot_orig). The 
            # `plot_orig` entry is unpacked below but not used in this cell, and 
            # neither is `i_acsmchems`.
            i_acsmchems = np.arange(n_chemacsm + 1)
            ax_ixyspec = []
            # ax_ixyspec.append([0, d_histbinsum, 'n_prthst'    , False])
            ax_ixyspec.append([0, eps_histbins, 'ccn_cdf'     , False])
            ax_ixyspec.append([1, len_wvum,     'qs_pop'      , False])
            ax_ixyspec.append([2, len_wvum,     'qa_pop'      , False])
            ax_ixyspec.append([3, temprtr_bins, 'frznfrac_tmp', False])
            
            i_row = 1 if show_mchmprt else 0
            for i_col, xvals, ycol, plot_orig in ax_ixyspec:
                ax_idx = i_row * n_figcols + i_col
                # The generated variant is drawn with a band, the original as a 
                # single curve.
                for vrnt, plot_ci in [('znrm', True), ('orig', False)]:
                    v_ynparr = v_ydata[f'{ycol}/{vrnt}']
                    n_rcns = exprm_info[f'n_rcns/test/{vrnt}']
                    assert v_ynparr.shape[:3] == (n_seeds, n_snrt, n_rcns)
                    v_ynparr2 = v_ynparr.reshape(n_seeds * n_snrt, n_rcns, *v_ynparr.shape[3:])
                    assert v_ynparr2.shape[:2] == (n_seeds * n_snrt, n_rcns)
                    # Flattening all the trailing data dimensions into one
                    d_ydata = math.prod(v_ynparr2.shape[2:])
                    v_ynparr3 = v_ynparr2.reshape(n_seeds * n_snrt, n_rcns, d_ydata)
                    assert v_ynparr3.shape == (n_seeds * n_snrt, n_rcns, d_ydata)
                    v_ynparr4 = v_ynparr3[i_samp]
                    assert v_ynparr4.shape == (n_rcns, d_ydata)
                    # Summarizing across the reconstructions. Despite its name, 
                    # `v_ymean` is the MEDIAN, and the low/high bounds are the 
                    # minimum (q=0) and maximum (q=1), i.e., the band is the 
                    # full range rather than a confidence interval.
                    v_ymean = np.median(v_ynparr4, axis=0)
                    assert v_ymean.shape == (d_ydata,)
                    v_ylow = np.quantile(v_ynparr4, q=0.0, axis=0)
                    assert v_ylow.shape == (d_ydata,)
                    v_yhigh = np.quantile(v_ynparr4, q=1.0, axis=0)
                    assert v_yhigh.shape == (d_ydata,)
                    v_mpldatas[f'{ycol}.cndvar/{ax_idx}:{vrnt}:x'] = xvals
                    if plot_ci:
                        v_mpldatas[f'{ycol}.cndvar/{ax_idx}:{vrnt}:y/mean'] = v_ymean
                        v_mpldatas[f'{ycol}.cndvar/{ax_idx}:{vrnt}:y/low'] = v_ylow
                        v_mpldatas[f'{ycol}.cndvar/{ax_idx}:{vrnt}:y/high'] = v_yhigh
                    else:
                        v_mpldatas[f'{ycol}.cndvar/{ax_idx}:{vrnt}:y'] = v_ymean
            
            # Grouping the flat `<ax_id>/<ax_idx>:<series>:<field>` keys by `<ax_id>`
            v_mpldatashie = hie2deep(v_mpldatas, maxdepth=1)
            ########## Calling Matplotlib to Plot the Diagnostics ########## 
            fig, axes = None, None
            for ax_id in v_mpldatashie:
                fig, axes = plot_mpl(data=v_mpldatashie[ax_id], fig=fig, 
                    axes=axes, mplopts=v_mplcfgs[ax_id])

            # Adding the top left panel tag for axis identification. With the 
            # speciated mass row shown, the tags are a_1..a_4 on top and b, c, ... 
            # below; otherwise the single row is tagged j, k, l, ...
            with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
                for i_figrow in range(axes.shape[0]):
                    for i_figcol in range(axes.shape[1]):
                        ax = axes[i_figrow, i_figcol]
                        if show_mchmprt and (i_figrow == 0):
                            tag_text = f'(a$_{{{i_figcol + 1}}}$)'
                        elif show_mchmprt and (i_figrow > 0):
                            tag_text = f'({"bcdef"[i_figcol]})'
                        else:
                            tag_text = f'({"jklmn"[i_figcol]})'
                        tag_axis(ax, tag_text, fontsize=11, pad=(0.3, 0.6))
                if show_mchmprt:
                    for i_figrow, ax in enumerate(axes[:, 0]):
                        txthdr = ('Speciated Mass', 'Diagnostics')[i_figrow] 
                        print_axheader(ax, txthdr, 'left', fontsize=14, pad=10, fontweight='bold')

            figs_dict[f'{i_page}/vardiag'] = fig
    
        # Writing all the pages of this file into a single multi-page pdf
        # NOTE(review): the `_anec.pdf` suffix is the same one the generated 
        # sample cell uses; the two outputs only differ through `i_savefig`, so 
        # confirm that `i_fig/vardiag` and `i_fig/gendiag` never coincide. Also, 
        # `figdir` is not created here and must already exist.
        _nicknm = '' if nicknm is None else f'_{nicknm}'
        pdfpath = f'{figdir}/{i_savefig:02d}{_nicknm}_anec.pdf'
        with PdfPages(pdfpath) as pdf:
            for i_page, fig_clctn in figs_dict.items():
                pdf.savefig(figure=fig_clctn, bbox_inches='tight', pad_inches=1/72)
        print(f'Finished writing {pdfpath}')

fig

### Conditional Collective Summary Plots

In [None]:
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)
colorspec = v_mplcfgs['colorspec']

# Writing one `*_smry.png` summary figure per experiment: three error histograms 
# (axes 0-2) followed by six original-vs-generated scatter plots (axes 3-8).
for viz_id, viz_data in hie2deep(viz_datas, maxdepth=1).items():
    exprmnt, arch, split = viz_id.split(':')
    v_ydata = viz_data['v_ydata']
    aero_cstscnv = viz_data['aero_csts']

    # Only the test split of the `cond.cont.trilbl.*` experiments is plotted
    if (not exprmnt.startswith('cond.cont.trilbl.')) or (split not in ('test',)):
        continue

    exprm_info = get_subdict(exprm_infos, f'{exprmnt}:{arch}')
    n_seeds = exprm_info['n_seeds']
    n_snrt = exprm_info['n_snrt']
    nicknm = exprm_info['nicknm']
    figdir = exprm_info['figdir']
    i_savefig = exprm_info['i_fig/diagsmry']

    ######## Computing the Aerosol Diagnostic Metrics #########
    v_mtrcdata1 = calc_aerometrics(v_ydata, aero_cstscnv, avg_errs=False, vrnt_trg='znrm')

    # The selected evaluation points and their indices along the data axes
    eps1, eps2, eps3 = v_mtrcdata1['eps1'], v_mtrcdata1['eps2'], v_mtrcdata1['eps3']
    i_eps1, i_eps2, i_eps3 = v_mtrcdata1['i_eps1'], v_mtrcdata1['i_eps2'], v_mtrcdata1['i_eps3']
    tmp1, tmp2, tmp3 = v_mtrcdata1['tmp1'], v_mtrcdata1['tmp2'], v_mtrcdata1['tmp3']
    i_tmp1, i_tmp2, i_tmp3 = v_mtrcdata1['i_tmp1'], v_mtrcdata1['i_tmp2'], v_mtrcdata1['i_tmp3']
    wvl1, i_wvl1 = v_mtrcdata1['wvl1'], v_mtrcdata1['i_wvl1']

    # Assembling the plotting data; the keys follow the 
    # `<ax_id>/<ax_idx>:<series>:<field>` pattern.
    v_mpldatas1 = dict()
    v_mpldatas1['nprthst_errhist/0:n_smpserr:y'] = v_mtrcdata1['n_prthst/err'].ravel()
    v_mpldatas1['mprthst_errhist/1:m_acsmerr:y'] = v_mtrcdata1['m_prthst/err'].ravel()
    v_mpldatas1['mchmprt_errhist/2:m_acsmerr:y'] = v_mtrcdata1['m_chmprt/err'].ravel()
    for ax_idx, ax_id, i_eps in [(3, 'ccn_sctcnd1', i_eps1), (6, 'ccn_sctcnd3', i_eps3)]:
        v_mpldatas1[f'{ax_id}/{ax_idx}:ccn_cdf:x'] = v_mtrcdata1['ccn_cdf/orig'][..., i_eps].ravel()
        v_mpldatas1[f'{ax_id}/{ax_idx}:ccn_cdf:y'] = v_mtrcdata1['ccn_cdf/znrm'][..., i_eps].ravel()
    v_mpldatas1['qs_popsctcnd/4:qs_pop:x'] = v_mtrcdata1['qs_pop/orig'][:, :, i_wvl1, :].ravel()
    v_mpldatas1['qs_popsctcnd/4:qs_pop:y'] = v_mtrcdata1['qs_pop/znrm'][:, :, i_wvl1, :].ravel()
    v_mpldatas1['qa_popsctcnd/7:qa_pop:x'] = v_mtrcdata1['qa_pop/orig'][:, :, i_wvl1, :].ravel()
    v_mpldatas1['qa_popsctcnd/7:qa_pop:y'] = v_mtrcdata1['qa_pop/znrm'][:, :, i_wvl1, :].ravel()
    for ax_idx, ax_id, i_tmp in [(5, 'frznfrac_sctcnd1', i_tmp1), (8, 'frznfrac_sctcnd3', i_tmp3)]:
        v_mpldatas1[f'{ax_id}/{ax_idx}:frznfrac_tmp:x'] = v_mtrcdata1['frznfrac_tmp/orig'][..., i_tmp].ravel()
        v_mpldatas1[f'{ax_id}/{ax_idx}:frznfrac_tmp:y'] = v_mtrcdata1['frznfrac_tmp/znrm'][..., i_tmp].ravel()

    # Making the matplotlib calls
    v_mpldatas2 = hie2deep(v_mpldatas1, maxdepth=1)
    fig, axes = None, None
    for ax_idx, ax_id in enumerate(v_mpldatas2):
        fig, axes = plot_mpl(data=v_mpldatas2[ax_id], fig=fig, axes=axes, 
            mplopts=v_mplcfgs[ax_id])
    axes1d = axes.ravel()

    # Adding the text boxes. Raw f-strings keep the latex backslashes verbatim; 
    # the non-raw spellings (`\%`, `\lambda`, `\mu`, `\circ`) are unrecognized 
    # escape sequences that newer python versions warn about.
    for ax_idx, textstr in [
        (3,  rf'$s={100*eps1:0.2g}\%$'), 
        (6,  rf'$s={100*eps3:0.2g}\%$'), 
        (4,  rf'${{\rm \lambda={wvl1*1e6:.1f}\ \mu m}}$'), 
        (7,  rf'${{\rm \lambda={wvl1*1e6:.1f}\ \mu m}}$'), 
        (5,  rf'$T={{\rm {tmp1:.0f}^{{\circ}}\ C}}$'), 
        (8,  rf'$T={{\rm {tmp3:.0f}^{{\circ}}\ C}}$')]:
        ax = axes1d[ax_idx]
        ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=11,
            verticalalignment='top', usetex=True, bbox={'facecolor': 'none', 
            'edgecolor': 'none'})

    # Adding the x=y line to scatter plots, spanning the union of both axis ranges
    for ax_idx in range(3, 9):
        ax = axes1d[ax_idx]
        x_lo = min(ax.get_xlim()[0], ax.get_ylim()[0])
        x_hi = max(ax.get_xlim()[1], ax.get_ylim()[1])
        x_id = np.linspace(x_lo, x_hi, 100)
        ax.plot(x_id, x_id, lw=1, ls='--', color=colorspec['blue'])
    
    # Adding the top left panel tag for axis identification
    with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
        for i_figrow in range(axes.shape[0]):
            for i_figcol, ax in enumerate(axes[i_figrow]):
                ax_idx = i_figrow * axes.shape[1] + i_figcol
                tag_text = f'({"abcdefghi"[ax_idx]})'
                tag_axis(ax, tag_text, fontsize=11, pad=(0.3, 0.6))
    
    # Only the png rendition is written; its path is derived from the pdf name. 
    # The output directory is created first so that the cell does not depend on 
    # another cell having made it.
    os.makedirs(figdir, exist_ok=True)
    _nicknm = '' if nicknm is None else f'_{nicknm}'
    pdfpath = f'{figdir}/{i_savefig:02d}{_nicknm}_smry.pdf'
    pngpath = pdfpath[:-4] + '.png'
    fig.savefig(pngpath, dpi=200, bbox_inches='tight')
    print(f'Finished writing {pngpath}')

fig

### Conditional Diagnostic Calibration Plots

In [None]:
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)
colorspec = v_mplcfgs['colorspec']

# Writing one `*_smry.png` calibration figure per experiment: six scatter plots 
# of the `origraw` (x) against the `origcdf` (y) metric data, with one column 
# each for cloud condensation, optical properties, and ice nucleation.
for viz_id, viz_data in hie2deep(viz_datas, maxdepth=1).items():
    exprmnt, arch, split = viz_id.split(':')
    v_ydata = viz_data['v_ydata']
    aero_cstscnv = viz_data['aero_csts']

    # Only the test split of the `cond.cont.trilbl.*` experiments is plotted
    if (not exprmnt.startswith('cond.cont.trilbl.')) or (split not in ('test',)):
        continue

    exprm_info = get_subdict(exprm_infos, f'{exprmnt}:{arch}')
    n_seeds = exprm_info['n_seeds']
    n_snrt = exprm_info['n_snrt']
    nicknm = exprm_info['nicknm']
    figdir = exprm_info['figdir']
    i_savefig = exprm_info['i_fig/diagsmryqnt']

    ######## Computing the Aerosol Diagnostic Metrics #########
    v_mtrcdata1 = calc_aerometrics(v_ydata, aero_cstscnv, avg_errs=False, vrnt_trg='znrm')

    # The selected evaluation points and their indices along the data axes
    eps1, eps2, eps3 = v_mtrcdata1['eps1'], v_mtrcdata1['eps2'], v_mtrcdata1['eps3']
    i_eps1, i_eps2, i_eps3 = v_mtrcdata1['i_eps1'], v_mtrcdata1['i_eps2'], v_mtrcdata1['i_eps3']
    tmp1, tmp2, tmp3 = v_mtrcdata1['tmp1'], v_mtrcdata1['tmp2'], v_mtrcdata1['tmp3']
    i_tmp1, i_tmp2, i_tmp3 = v_mtrcdata1['i_tmp1'], v_mtrcdata1['i_tmp2'], v_mtrcdata1['i_tmp3']
    wvl1, i_wvl1 = v_mtrcdata1['wvl1'], v_mtrcdata1['i_wvl1']

    # Assembling the plotting data; the keys follow the 
    # `<ax_id>/<ax_idx>:<series>:<field>` pattern.
    v_mpldatas1 = dict()
    # v_mpldatas1['nsmps_errhist/0:n_smpserr:y'] = v_mtrcdata1['n_smps/err'].ravel()
    # v_mpldatas1['macsm_errhist/1:m_acsmerr:y'] = v_mtrcdata1['m_acsm/err'].ravel()
    for ax_idx, ax_id, i_eps in [(0, 'ccn_sctcndqnt1', i_eps1), (3, 'ccn_sctcndqnt3', i_eps3)]:
        v_mpldatas1[f'{ax_id}/{ax_idx}:ccn_cdf:x'] = v_mtrcdata1['ccn_cdf/origraw'][..., i_eps].ravel()
        v_mpldatas1[f'{ax_id}/{ax_idx}:ccn_cdf:y'] = v_mtrcdata1['ccn_cdf/origcdf'][..., i_eps].ravel()
    v_mpldatas1['qs_popsctcndqnt/1:qs_pop:x'] = v_mtrcdata1['qs_pop/origraw'][:, :, i_wvl1, :].ravel()
    v_mpldatas1['qs_popsctcndqnt/1:qs_pop:y'] = v_mtrcdata1['qs_pop/origcdf'][:, :, i_wvl1, :].ravel()
    v_mpldatas1['qa_popsctcndqnt/4:qa_pop:x'] = v_mtrcdata1['qa_pop/origraw'][:, :, i_wvl1, :].ravel()
    v_mpldatas1['qa_popsctcndqnt/4:qa_pop:y'] = v_mtrcdata1['qa_pop/origcdf'][:, :, i_wvl1, :].ravel()
    for ax_idx, ax_id, i_tmp in [(2, 'frznfrac_sctcndqnt1', i_tmp1), (5, 'frznfrac_sctcndqnt3', i_tmp3)]:
        v_mpldatas1[f'{ax_id}/{ax_idx}:frznfrac_tmp:x'] = v_mtrcdata1['frznfrac_tmp/origraw'][..., i_tmp].ravel()
        v_mpldatas1[f'{ax_id}/{ax_idx}:frznfrac_tmp:y'] = v_mtrcdata1['frznfrac_tmp/origcdf'][..., i_tmp].ravel()

    # Making the matplotlib calls
    v_mpldatas2 = hie2deep(v_mpldatas1, maxdepth=1)
    fig, axes = None, None
    for ax_idx, ax_id in enumerate(v_mpldatas2):
        fig, axes = plot_mpl(data=v_mpldatas2[ax_id], fig=fig, axes=axes, 
            mplopts=v_mplcfgs[ax_id])
    axes1d = axes.ravel()

    # Adding the text boxes. Raw f-strings keep the latex backslashes verbatim; 
    # the non-raw spellings (`\%`, `\lambda`, `\mu`, `\circ`) are unrecognized 
    # escape sequences that newer python versions warn about.
    for ax_idx, textstr in [
        (0,  rf'$s={100*eps1:0.2g}\%$'), 
        (3,  rf'$s={100*eps3:0.2g}\%$'), 
        (1,  rf'${{\rm \lambda={wvl1*1e6:.1f}\ \mu m}}$'), 
        (4,  rf'${{\rm \lambda={wvl1*1e6:.1f}\ \mu m}}$'), 
        (2,  rf'$T={{\rm {tmp1:.0f}^{{\circ}}\ C}}$'), 
        (5,  rf'$T={{\rm {tmp3:.0f}^{{\circ}}\ C}}$')]:
        ax = axes1d[ax_idx]
        ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=11,
            verticalalignment='top', usetex=True, bbox={'facecolor': 'none', 
            'edgecolor': 'none'})
    
    # Adding the top left panel tag for axis identification: the letter encodes 
    # the column and the subscript the 1-based row.
    with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
        for i_figrow in range(axes.shape[0]):
            for i_figcol, ax in enumerate(axes[i_figrow]):
                ax_idx = i_figrow * axes.shape[1] + i_figcol
                tag_text = f'({"abc"[i_figcol]}$_{{{i_figrow+1}}}$)'
                tag_axis(ax, tag_text, fontsize=11, pad=(0.3, 0.6))
        
        for i_figcol, col_ttl in enumerate(['Cloud Condensation', 'Optical Properties', 'Ice Nucleation']):
            print_axheader(axes[0, i_figcol], col_ttl, 'top', pad=8, fontsize=14, fontweight='bold')
    
    # Only the png rendition is written; its path is derived from the pdf name. 
    # The output directory is created first so that the cell does not depend on 
    # another cell having made it.
    os.makedirs(figdir, exist_ok=True)
    _nicknm = '' if nicknm is None else f'_{nicknm}'
    pdfpath = f'{figdir}/{i_savefig:02d}{_nicknm}_smry.pdf'
    pngpath = pdfpath[:-4] + '.png'
    fig.savefig(pngpath, dpi=200, bbox_inches='tight')
    print(f'Finished writing {pngpath}')

fig

### Non-Conditional and Conditional Average Error Tables

In [None]:
# Builds `df_errs` (the '/err' metrics of every selected test-split experiment, framed
# against `rng_seed`) and bootstrap-aggregates it per `viz_id` into `df_agg`.
viz_datasdeep = hie2deep(viz_datas, maxdepth=1)

df_errslst = []
for viz_id, viz_data in viz_datasdeep.items():
    exprmnt, arch, split = viz_id.split(':')
    v_ydata = viz_data['v_ydata']
    aero_cstscnv = viz_data['aero_csts']

    # Only the test split of the 'trad.*' and 'cond.cont.*' experiments is tabulated.
    # The target variant handed to `calc_aerometrics` depends on the experiment family.
    if exprmnt.startswith('trad.'):
        vrnt_trg = 'rcnst'
    elif exprmnt.startswith('cond.cont.'):
        vrnt_trg = 'znrm'
    else:
        continue
    if split != 'test':
        continue

    exprm_info = get_subdict(exprm_infos, f'{exprmnt}:{arch}')
    n_seeds = exprm_info['n_seeds']

    ######## Computing the Aerosol Diagnostic Metrics #########
    v_mtrcdata1 = calc_aerometrics(v_ydata, aero_cstscnv, avg_errs=True, vrnt_trg=vrnt_trg)

    ######## Framing the Error Metrics as a DataFrame #########
    # The scalar entries are broadcast against the `rng_seed` column; only the 
    # metrics whose name ends in '/err' are kept as data columns.
    dfdict = dict(viz_id=viz_id, exprmnt=exprmnt, arch=arch, 
        split=split, epoch=-1, rng_seed=np.arange(n_seeds))
    dfdict.update({ycol: y_samps for ycol, y_samps in v_mtrcdata1.items() 
        if ycol.endswith('/err')})
    df_vizid = pd.DataFrame(dfdict)
    df_errslst.append(df_vizid)

df_errs = pd.concat(df_errslst, axis=0, ignore_index=True)
# Dropping the error columns that are missing for every experiment
df_errs = df_errs.dropna(axis=1, how='all')
ycols = [col for col in df_errs.columns if col.endswith('/err')]

######## Bootstrap Aggregating the Results #########
df_agglst = []
for viz_id, df_errsvid in df_errs.groupby('viz_id', sort=False, observed=True):
    # The 2.5 and 97.5 percentiles of the bootstrapped mean are requested here
    aggcfg = dict(type='bootstrap', n_boot=40, q=[2.5, 97.5], stat='mean', device='cpu')
    hpcols = ['viz_id', 'exprmnt', 'arch', 'split']
    stcols = [col for col in df_errsvid.columns if col not in hpcols]
    agg_data = get_aggdf(df_errsvid[hpcols], df_errsvid[stcols], xcol='epoch', 
        huecol='viz_id', rngcol='rng_seed', aggcfg=aggcfg)
    hpdf_agg, stdf_agg = agg_data['hpdf'], agg_data['stdf']
    df_aggvid = pd.concat([hpdf_agg, stdf_agg], axis=1)
    df_aggvid = df_aggvid.drop(columns=['exprmnt', 'arch', 'split', 'epoch'])
    df_agglst.append(df_aggvid)

df_agg = pd.concat(df_agglst, axis=0, ignore_index=True)

In [None]:
######## Creating a Formatted Table for Latex #########
def _fmt_pct(val):
    """Formats a percentage with 2 significant digits below 1 and 3 otherwise."""
    return f'{val:.2g}' if val < 1 else f'{val:.3g}'

# Each cell is rendered as `$mean\% [low\%,high\%]$`, with the statistics of 
# `df_agg` scaled by 100 to be expressed in percents.
df_agglst2 = []
for i_row, row_dict in df_agg.iterrows():
    row_dict2 = dict(viz_id=row_dict['viz_id'])
    for ycol in ycols:
        y_meanstr = _fmt_pct(row_dict[f'{ycol}/mean'] * 100)
        y_lowstr = _fmt_pct(row_dict[f'{ycol}/low'] * 100)
        y_highstr = _fmt_pct(row_dict[f'{ycol}/high'] * 100)
        # The backslash is escaped explicitly; the emitted text is still `\%`.
        row_dict2[ycol] = f'${y_meanstr}\\% [{y_lowstr}\\%,{y_highstr}\\%]$'
    df_agglst2.append(row_dict2)
df_agg2 = pd.DataFrame(df_agglst2)

# Transposing the table: one row per error metric, one column per viz_id
df_agg3 = df_agg2.set_index('viz_id').T
df_agg3.columns.name = None

# The displayed name of each error metric; the key order sets the row order.
err_rnmngs = {
    'ccn_cdf/err': 'CCN Spectrum',
    'logqs_pop/err': 'Vol Scat Coef',
    'logqa_pop/err': 'Vol Abs Coef',
    'logfrznfrac_tmp/err': 'Frozen Fraction',
    'm_acsm/err': 'ACSM Readings',
    'n_smps/err': 'SMPS Readings',
    'n_prthst/err': 'Number Dist',
    'm_prthst/err': 'Total Mass',
    'm_chmprt/err': 'Species Mass',
    'm_chmprthst/err': 'Speciated Mass'}

# Each of the following mappings selects, orders, and renames the columns of 
# one output table. The 'index' key refers to the error-name column.
col_rnmngs1 = {
    'index': 'Reconst Error',
    'trad.gen:mlp:test': 'MLP',
    'trad.gen:cnn:test': 'CNN'}

col_rnmngs2 = {
    'index': 'Gen Ambiguity',
    'cond.cont.trilbl.depzy:mlp:test': 'HDM (Trad)',
    'cond.cont.trilbl.indzy:mlp:test': 'HDM (Ours)',
    'cond.cont.acsmsmps.duo.depzy:mlp:test': 'LDM (Trad)',
    'cond.cont.acsmsmps.duo.indzy:mlp:test': 'LDM (Ours)'}

col_rnmngs3 = {
    'index': r'\makecell[c]{Generative\\Ambiguity}',
    'cond.cont.trilbl.indzy:mlp:test': r'\makecell[c]{High-Dim Measurements\\Mean [95\% CI]}',
    'cond.cont.acsmsmps.duo.indzy:mlp:test': r'\makecell[c]{Low-Dim Measurements\\Mean [95\% CI]}'}

# The headers below must be raw strings: in a regular string literal, the `\\` 
# collapses into a single backslash and latex would see an undefined 
# `\Ambiguity` macro instead of a line break inside the `\makecell`.
col_rnmngs4 = {
    'index': r'\makecell[c]{Generative\\Ambiguity}',
    'cond.cont.trilbl.depzy:mlp:test': r'\makecell[c]{Traditional CVAE\\Mean [95\% CI]}',
    'cond.cont.trilbl.indzy:mlp:test': r'\makecell[c]{Wasserstein-Regularized CVAE (Ours)\\Mean [95\% CI]}'}

col_rnmngs5 = {
    'index': r'\makecell[c]{Generative\\Ambiguity}',
    'cond.cont.acsmsmps.duo.depzy:mlp:test': r'\makecell[c]{Traditional CVAE\\Mean [95\% CI]}',
    'cond.cont.acsmsmps.duo.indzy:mlp:test': r'\makecell[c]{Wasserstein-Regularized CVAE (Ours)\\Mean [95\% CI]}'}


df_agg4 = df_agg3.copy(deep=True)
# Selecting a subset of the errors and Reordering them
df_agg5 = df_agg4.loc[list(err_rnmngs)].reset_index()
# Renaming the errors
df_agg6 = df_agg5.replace(err_rnmngs)
# Selecting a subset of the columns and Reordering them and renaming the columns
df_agg7 = df_agg6.loc[:, list(col_rnmngs1)].rename(columns=col_rnmngs1)
df_agg8 = df_agg6.loc[:, list(col_rnmngs2)].rename(columns=col_rnmngs2)
df_agg9 = df_agg6.loc[:, list(col_rnmngs3)].rename(columns=col_rnmngs3)
df_agg10 = df_agg6.loc[:, list(col_rnmngs4)].rename(columns=col_rnmngs4)
df_agg11 = df_agg6.loc[:, list(col_rnmngs5)].rename(columns=col_rnmngs5)

df_agg7

In [None]:
# Converting the table to latex and swapping the booktabs rules for a fully 
# bordered layout with fixed-width paragraph columns.
tbl_tex = df_agg7.to_latex(index=False)
# The column spec is derived from the actual number of columns; a hard-coded 
# five-column spec would silently fail to match tables of any other width.
n_tblcols = df_agg7.shape[1]
tbl_tex = tbl_tex.replace(r'\begin{tabular}{' + 'l' * n_tblcols + '}', 
    r'\begin{tabular}{' + r'|p{0.17\textwidth}' * n_tblcols + r'|}')
tbl_tex = tbl_tex.replace(r'\toprule', r'\hline')
tbl_tex = tbl_tex.replace(r'\midrule', r'\hline')
tbl_tex = tbl_tex.replace(r'\bottomrule', r'\hline')
# Adding a horizontal rule after every data row (the data cells end with `$`)
tbl_tex = tbl_tex.replace(r'$ \\', r'$ \\\hline')
print(tbl_tex)

## Conditional ACSM/SMPS Generation

### Anecdotal Mass and Number Plots

In [None]:
# Plots, for each selected sample of the 'cond.cont.acsmsmps.*' test experiments, 
# the original speciated mass and particle count histograms side by side, and 
# writes one PDF page per sample.
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)

for (viz_id, viz_data) in hie2deep(viz_datas, maxdepth=1).items():
    exprmnt, arch, split = viz_id.split(':')
    v_ydata = viz_data['v_ydata']
    i_sel = viz_data['i_sel']
    n_binsum = viz_data['n_bins']
    d_histbinsum = viz_data['d_histbins']
    n_files, n_page = i_sel.shape

    if (not exprmnt.startswith('cond.cont.acsmsmps.')) or (split not in ('test',)):
        continue

    exprm_info = get_subdict(exprm_infos, f'{exprmnt}:{arch}')
    n_seeds = exprm_info['n_seeds']
    n_snrt = exprm_info['n_snrt']
    nicknm = exprm_info['nicknm']
    figdir = exprm_info['figdir']
    i_savefig = exprm_info['i_fig/xorig']

    n_figrows, n_figcols = 1, 2
    
    for i_file in range(n_files):
        figs_dict = dict()
        for i_page in range(n_page):
            v_mpldatas = dict()
            i_samp = i_sel[i_file, i_page]

            # The speciated mass data
            # NOTE(review): `n_chem` and `chem_species` are notebook globals 
            # defined outside of this cell.
            i_row, i_col = 0, 0
            n_rcns, i_rcns = 1, 0
            v_ynparr = v_ydata['m_chmprthst/orig']
            assert v_ynparr.shape == (n_seeds, n_snrt, n_rcns, n_chem, n_binsum)
            v_ynparr2 = v_ynparr.reshape(n_seeds * n_snrt, n_rcns, n_chem, n_binsum)
            assert v_ynparr2.shape == (n_seeds * n_snrt, n_rcns, n_chem, n_binsum)
            ax_idx = i_row * n_figcols + i_col
            for i_chem, chem in enumerate(chem_species):
                v_mpldatas[f'm_chmprthst.cndinp/{ax_idx}:{chem}:x'] = d_histbinsum
                v_mpldatas[f'm_chmprthst.cndinp/{ax_idx}:{chem}:y'] = v_ynparr2[i_samp, i_rcns, i_chem]

            # The particle count data
            ax_ixyspec = []
            ax_ixyspec.append([1, d_histbinsum, 'n_prthst'])
            
            for i_row in range(n_figrows):
                for i_col, xvals, ycol in ax_ixyspec:
                    ax_idx = i_row * n_figcols + i_col
                    n_rcns, i_rcns = 1, 0
                    v_ynparr = v_ydata[f'{ycol}/orig']
                    assert v_ynparr.shape[:3] == (n_seeds, n_snrt, n_rcns)
                    v_ynparr2 = v_ynparr.reshape(n_seeds * n_snrt, n_rcns, *v_ynparr.shape[3:])
                    assert v_ynparr2.shape[:2] == (n_seeds * n_snrt, n_rcns)
                    v_mpldatas[f'{ycol}.cndinp/{ax_idx}:orig:x'] = xvals
                    v_mpldatas[f'{ycol}.cndinp/{ax_idx}:orig:y'] = np.squeeze(v_ynparr2[i_samp, i_rcns])
            
            v_mpldatashie = hie2deep(v_mpldatas, maxdepth=1)
            ########## Calling Matplotlib to Plot the Diagnostics ########## 
            fig, axes = None, None
            for ax_id in v_mpldatashie:
                fig, axes = plot_mpl(data=v_mpldatashie[ax_id], fig=fig, 
                    axes=axes, mplopts=v_mplcfgs[ax_id])

            # Adding the top left april tag for axis identification
            with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
                for i_figrow in range(axes.shape[0]):
                    for i_figcol in range(axes.shape[1]):
                        ax = axes[i_figrow, i_figcol]
                        tag_text = f'({"ab"[i_figcol]})'
                        tag_axis(ax, tag_text, fontsize=11, pad=(0.3, 0.9))
                # for i_figcol, ax in enumerate(axes[0]):
                #     print_axheader(ax, f'Input Label {i_figcol+1}', 'top', fontsize=14, fontweight='bold')

            figs_dict[f'{i_page}/label'] = fig

        # NOTE(review): the path does not depend on `i_file`, so each file 
        # overwrites the PDF of the previous one when `n_files > 1`.
        _nicknm = '' if nicknm is None else f'_{nicknm}'
        pdfpath = f'{figdir}/{i_savefig:02d}{_nicknm}_xanec.pdf'
        with PdfPages(pdfpath) as pdf:
            for i_page, fig_clctn in figs_dict.items():
                pdf.savefig(figure=fig_clctn, bbox_inches='tight', pad_inches=1/72)
        print(f'Finished writing {pdfpath}')

        # Releasing the saved figures from pyplot so that they do not pile up 
        # in memory; the `fig` object itself can still be displayed below.
        for fig_clctn in figs_dict.values():
            plt.close(fig_clctn)

fig

### Anecdotal Input Label Plots

In [None]:
# Plots, for each selected sample of the 'cond.cont.acsmsmps.*' test experiments, 
# the original SMPS and ACSM readings side by side, and writes one PDF page 
# per sample.
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)

for (viz_id, viz_data) in hie2deep(viz_datas, maxdepth=1).items():
    exprmnt, arch, split = viz_id.split(':')
    v_ydata = viz_data['v_ydata']
    i_sel = viz_data['i_sel']
    # The aerosol constants are read from this experiment's own data rather 
    # than from whatever `aero_cstscnv` a previously executed cell left behind.
    aero_cstscnv = viz_data['aero_csts']
    d_binssmpsum = aero_cstscnv['d_binssmps']
    n_chemacsm = aero_cstscnv['n_chemacsm']
    n_files, n_page = i_sel.shape

    if (not exprmnt.startswith('cond.cont.acsmsmps.')) or (split not in ('test',)):
        continue

    exprm_info = get_subdict(exprm_infos, f'{exprmnt}:{arch}')
    n_seeds = exprm_info['n_seeds']
    n_snrt = exprm_info['n_snrt']
    nicknm = exprm_info['nicknm']
    figdir = exprm_info['figdir']
    i_savefig = exprm_info['i_fig/label']

    n_figrows, n_figcols = 1, 2
    
    for i_file in range(n_files):
        figs_dict = dict()
        for i_page in range(n_page):
            v_mpldatas = dict()
            i_samp = i_sel[i_file, i_page]

            # The SMPS and ACSM readings: (column index, x values, data key)
            i_acsmchems = np.arange(n_chemacsm + 1)
            ax_ixyspec = []
            ax_ixyspec.append([0, d_binssmpsum, 'n_smps'])
            ax_ixyspec.append([1, i_acsmchems,  'm_acsm'])
            
            for i_row in range(n_figrows):
                for i_col, xvals, ycol in ax_ixyspec:
                    ax_idx = i_row * n_figcols + i_col
                    n_rcns, i_rcns = 1, 0
                    v_ynparr = v_ydata[f'{ycol}/orig']
                    assert v_ynparr.shape[:3] == (n_seeds, n_snrt, n_rcns)
                    v_ynparr2 = v_ynparr.reshape(n_seeds * n_snrt, n_rcns, *v_ynparr.shape[3:])
                    assert v_ynparr2.shape[:2] == (n_seeds * n_snrt, n_rcns)
                    v_mpldatas[f'{ycol}.cndlbl/{ax_idx}:orig:x'] = xvals
                    v_mpldatas[f'{ycol}.cndlbl/{ax_idx}:orig:y'] = np.squeeze(v_ynparr2[i_samp, i_rcns])
            
            v_mpldatashie = hie2deep(v_mpldatas, maxdepth=1)
            ########## Calling Matplotlib to Plot the Diagnostics ########## 
            fig, axes = None, None
            for ax_id in v_mpldatashie:
                fig, axes = plot_mpl(data=v_mpldatashie[ax_id], fig=fig, 
                    axes=axes, mplopts=v_mplcfgs[ax_id])

            # Adding the top left april tag for axis identification
            with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
                for i_figrow in range(axes.shape[0]):
                    for i_figcol in range(axes.shape[1]):
                        ax = axes[i_figrow, i_figcol]
                        tag_text = f'({"cd"[i_figcol]})'
                        tag_axis(ax, tag_text, fontsize=11, pad=(0.3, 0.9))
                for i_figcol, ax in enumerate(axes[0]):
                    print_axheader(ax, f'Measurement {i_figcol+1}', 'top', fontsize=14, fontweight='bold')

            figs_dict[f'{i_page}/label'] = fig

        # NOTE(review): the path does not depend on `i_file`, so each file 
        # overwrites the PDF of the previous one when `n_files > 1`.
        _nicknm = '' if nicknm is None else f'_{nicknm}'
        pdfpath = f'{figdir}/{i_savefig:02d}{_nicknm}_yanec.pdf'
        with PdfPages(pdfpath) as pdf:
            for i_page, fig_clctn in figs_dict.items():
                pdf.savefig(figure=fig_clctn, bbox_inches='tight', pad_inches=1/72)
        print(f'Finished writing {pdfpath}')

        # Releasing the saved figures from pyplot so that they do not pile up 
        # in memory; the `fig` object itself can still be displayed below.
        for fig_clctn in figs_dict.values():
            plt.close(fig_clctn)

fig

### Anecdotal Generated Sample and Label Plots

In [None]:
# Plots, for each selected sample of the 'cond.cont.acsmsmps.*' test experiments, 
# a grid with one column per generated ('znrm') sample: the speciated mass, the 
# particle counts, and the SMPS/ACSM readings overlaid with the original ones.
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)

for (viz_id, viz_data) in hie2deep(viz_datas, maxdepth=1).items():
    exprmnt, arch, split = viz_id.split(':')
    v_ydata = viz_data['v_ydata']
    d_histbinsum = viz_data['d_histbins']
    n_binsum = viz_data['n_bins']
    len_wvum = viz_data['len_wv']
    i_sel = viz_data['i_sel']
    eps_histmin = viz_data['eps_histmin']
    eps_histmax = viz_data['eps_histmax']
    eps_histbins = viz_data['eps_histbins']
    tmprtr_inpmin = viz_data['tmprtr_inpmin']
    tmprtr_inpmax = viz_data['tmprtr_inpmax']
    temprtr_bins = viz_data['temprtr_bins']

    # The aerosol constants are read from this experiment's own data rather 
    # than from whatever `aero_cstscnv` a previously executed cell left behind.
    aero_cstscnv = viz_data['aero_csts']
    d_binssmpsum = aero_cstscnv['d_binssmps']
    n_chemacsm = aero_cstscnv['n_chemacsm']

    n_files, n_page = i_sel.shape

    if (not exprmnt.startswith('cond.cont.acsmsmps.')) or (split not in ('test',)):
        continue

    exprm_info = get_subdict(exprm_infos, f'{exprmnt}:{arch}')
    n_seeds = exprm_info['n_seeds']
    n_snrt = exprm_info['n_snrt']
    nicknm = exprm_info['nicknm']
    figdir = exprm_info['figdir']
    i_savefig = exprm_info['i_fig/gendiag']

    # Each figure column shows one variant: (variant name, sample index, title)
    # vrnt_specs = exprm_info.get('vrnt_specs', None)
    vrnt_specs = [('znrm', i_rcns, 'Norm Lat') for i_rcns in range(5)]
    n_figcols = len(vrnt_specs)
    n_figrows = v_mplcfgs['m_chmprthst.cndgen']['plt.subplots/nrows']
    v_mplcfgs['m_chmprthst.cndgen']['plt.subplots/ncols'] = n_figcols

    v_mplcfgs['ccn_cdf.cndgen']['xlim'] = [eps_histmin * 0.9, eps_histmax * 1.1]
    v_mplcfgs['frznfrac_tmp.cndgen']['xlim'] = [tmprtr_inpmin - 1, tmprtr_inpmax + 1]
    
    for i_file in range(n_files):
        figs_dict = dict()
        for i_page in range(n_page):
            v_mpldatas = dict()
            i_samp = i_sel[i_file, i_page]
            
            # The speciated mass data
            # NOTE(review): `n_chem` and `chem_species` are notebook globals 
            # defined outside of this cell.
            i_row = 0
            for i_col in range(n_figcols):
                vrnt, i_rcns, ttl = vrnt_specs[i_col]
                n_rcns = exprm_info[f'n_rcns/test/{vrnt}']
                v_ynparr = v_ydata[f'm_chmprthst/{vrnt}']
                assert v_ynparr.shape == (n_seeds, n_snrt, n_rcns, n_chem, n_binsum)
                v_ynparr2 = v_ynparr.reshape(n_seeds * n_snrt, n_rcns, n_chem, n_binsum)
                assert v_ynparr2.shape == (n_seeds * n_snrt, n_rcns, n_chem, n_binsum)
                ax_idx = i_row * n_figcols + i_col
                for i_chem, chem in enumerate(chem_species):
                    v_mpldatas[f'm_chmprthst.cndgen/{ax_idx}:{chem}:x'] = d_histbinsum
                    v_mpldatas[f'm_chmprthst.cndgen/{ax_idx}:{chem}:y'] = v_ynparr2[i_samp, i_rcns, i_chem]

            # The remaining rows: (row index, x values, data key, overlay the original)
            i_acsmchems = np.arange(n_chemacsm + 1)
            ax_ixyspec = []
            ax_ixyspec.append([1, d_histbinsum, 'n_prthst'    , False])
            ax_ixyspec.append([2, d_binssmpsum, 'n_smps'      , True ])
            ax_ixyspec.append([3, i_acsmchems,  'm_acsm'      , True ])
            # ax_ixyspec.append([4, eps_histbins, 'ccn_cdf'     , False])
            # ax_ixyspec.append([5, temprtr_bins, 'frznfrac_tmp', False])
            # ax_ixyspec.append([6, len_wvum,     'qs_pop'      , False])
            # ax_ixyspec.append([7, len_wvum,     'qa_pop'      , False])
            
            for i_col in range(n_figcols):
                for i_row, xvals, ycol, plot_orig in ax_ixyspec:
                    ax_idx = i_row * n_figcols + i_col
                    vrnt, i_rcns, ttl = vrnt_specs[i_col]
                    n_rcns = exprm_info[f'n_rcns/test/{vrnt}']
                    v_ynparr = v_ydata[f'{ycol}/{vrnt}']
                    assert v_ynparr.shape[:3] == (n_seeds, n_snrt, n_rcns)
                    v_ynparr2 = v_ynparr.reshape(n_seeds * n_snrt, n_rcns, *v_ynparr.shape[3:])
                    assert v_ynparr2.shape[:2] == (n_seeds * n_snrt, n_rcns)
                    v_mpldatas[f'{ycol}.cndgen/{ax_idx}:rcnst:x'] = xvals
                    v_mpldatas[f'{ycol}.cndgen/{ax_idx}:rcnst:y'] = np.squeeze(v_ynparr2[i_samp, i_rcns])

                    if not plot_orig: continue
                    n_rcns, i_rcns = 1, 0
                    v_ynparr = v_ydata[f'{ycol}/orig']
                    assert v_ynparr.shape[:3] == (n_seeds, n_snrt, n_rcns)
                    v_ynparr2 = v_ynparr.reshape(n_seeds * n_snrt, n_rcns, *v_ynparr.shape[3:])
                    assert v_ynparr2.shape[:2] == (n_seeds * n_snrt, n_rcns)
                    v_mpldatas[f'{ycol}.cndgen/{ax_idx}:orig:x'] = xvals
                    v_mpldatas[f'{ycol}.cndgen/{ax_idx}:orig:y'] = np.squeeze(v_ynparr2[i_samp, i_rcns])
            
            v_mpldatashie = hie2deep(v_mpldatas, maxdepth=1)
            ########## Calling Matplotlib to Plot the Diagnostics ########## 
            fig, axes = None, None
            for ax_id in v_mpldatashie:
                fig, axes = plot_mpl(data=v_mpldatashie[ax_id], fig=fig, 
                    axes=axes, mplopts=v_mplcfgs[ax_id])

            # Adding the column headers; no per-axis tags are drawn in this figure
            with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
                for i_figcol, ax in enumerate(axes[0]):
                    txthdr = f'Sample {i_figcol + 1}'
                    print_axheader(ax, txthdr, 'top', fontsize=14, fontweight='bold')

            figs_dict[f'{i_page}/diag'] = fig
    
        # NOTE(review): the path does not depend on `i_file`, so each file 
        # overwrites the PDF of the previous one when `n_files > 1`.
        _nicknm = '' if nicknm is None else f'_{nicknm}'
        pdfpath = f'{figdir}/{i_savefig:02d}{_nicknm}_anec.pdf'
        with PdfPages(pdfpath) as pdf:
            for i_page, fig_clctn in figs_dict.items():
                pdf.savefig(figure=fig_clctn, bbox_inches='tight', pad_inches=1/72)
        print(f'Finished writing {pdfpath}')

        # Releasing the saved figures from pyplot so that they do not pile up 
        # in memory; the `fig` object itself can still be displayed below.
        for fig_clctn in figs_dict.values():
            plt.close(fig_clctn)

fig

### Anecdotal Aerosol Diagnostic Plots

In [None]:
# Plots, for each selected sample of the 'cond.cont.acsmsmps.*' test experiments, 
# the CCN, optical, and frozen fraction diagnostics of the generated ('znrm') 
# samples as a median curve with a min-max band, overlaid with the original.
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)
# Whether to add a first row with the generated speciated mass histograms
show_mchmprt = False

for (viz_id, viz_data) in hie2deep(viz_datas, maxdepth=1).items():
    exprmnt, arch, split = viz_id.split(':')
    v_ydata = viz_data['v_ydata']
    d_histbinsum = viz_data['d_histbins']
    n_binsum = viz_data['n_bins']
    len_wvum = viz_data['len_wv']
    i_sel = viz_data['i_sel']
    eps_histmin = viz_data['eps_histmin']
    eps_histmax = viz_data['eps_histmax']
    eps_histbins = viz_data['eps_histbins']
    tmprtr_inpmin = viz_data['tmprtr_inpmin']
    tmprtr_inpmax = viz_data['tmprtr_inpmax']
    temprtr_bins = viz_data['temprtr_bins']
    n_files, n_page = i_sel.shape

    if (not exprmnt.startswith('cond.cont.acsmsmps.')) or (split not in ('test',)):
        continue

    exprm_info = get_subdict(exprm_infos, f'{exprmnt}:{arch}')
    n_seeds = exprm_info['n_seeds']
    n_snrt = exprm_info['n_snrt']
    nicknm = exprm_info['nicknm']
    figdir = exprm_info['figdir']
    i_savefig = exprm_info['i_fig/vardiag']
    
    n_figrows, n_figcols = 1 + show_mchmprt, 4
    for ycol in ['m_chmprthst', 'n_prthst', 'ccn_cdf', 'qs_pop', 'qa_pop', 'frznfrac_tmp']:
        v_mplcfgs[f'{ycol}.cndvar']['plt.subplots/nrows'] = n_figrows
        v_mplcfgs[f'{ycol}.cndvar']['plt.subplots/ncols'] = n_figcols

    v_mplcfgs['ccn_cdf.cndvar']['xlim'] = [eps_histmin * 0.9, eps_histmax * 1.1]
    v_mplcfgs['frznfrac_tmp.cndvar']['xlim'] = [tmprtr_inpmin - 1, tmprtr_inpmax + 1]
    
    for i_file in range(n_files):
        figs_dict = dict()
        for i_page in range(n_page):

            v_mpldatas = dict()
            i_samp = i_sel[i_file, i_page]

            if show_mchmprt:
                # The speciated mass data; column `i_col` shows generated sample `i_col`
                # NOTE(review): `n_chem` and `chem_species` are notebook globals 
                # defined outside of this cell.
                i_row, vrnt = 0, 'znrm'
                n_rcns = exprm_info[f'n_rcns/test/{vrnt}']
                for i_col in range(n_figcols):
                    v_ynparr = v_ydata[f'm_chmprthst/{vrnt}']
                    assert v_ynparr.shape == (n_seeds, n_snrt, n_rcns, n_chem, n_binsum)
                    v_ynparr2 = v_ynparr.reshape(n_seeds * n_snrt, n_rcns, n_chem, n_binsum)
                    assert v_ynparr2.shape == (n_seeds * n_snrt, n_rcns, n_chem, n_binsum)
                    ax_idx, i_rcns = i_row * n_figcols + i_col, i_col
                    for i_chem, chem in enumerate(chem_species):
                        v_mpldatas[f'm_chmprthst.cndvar/{ax_idx}:{chem}:x'] = d_histbinsum
                        v_mpldatas[f'm_chmprthst.cndvar/{ax_idx}:{chem}:y'] = v_ynparr2[i_samp, i_rcns, i_chem]

            # The ccn, optical cross-section, and frozen fraction data:
            # (column index, x values, data key, unused flag)
            ax_ixyspec = []
            # ax_ixyspec.append([0, d_histbinsum, 'n_prthst'    , False])
            ax_ixyspec.append([0, eps_histbins, 'ccn_cdf'     , False])
            ax_ixyspec.append([1, len_wvum,     'qs_pop'      , False])
            ax_ixyspec.append([2, len_wvum,     'qa_pop'      , False])
            ax_ixyspec.append([3, temprtr_bins, 'frznfrac_tmp', False])
            
            i_row = 1 if show_mchmprt else 0
            for i_col, xvals, ycol, plot_orig in ax_ixyspec:
                ax_idx = i_row * n_figcols + i_col
                for vrnt, plot_ci in [('znrm', True), ('orig', False)]:
                    v_ynparr = v_ydata[f'{ycol}/{vrnt}']
                    n_rcns = exprm_info[f'n_rcns/test/{vrnt}']
                    assert v_ynparr.shape[:3] == (n_seeds, n_snrt, n_rcns)
                    v_ynparr2 = v_ynparr.reshape(n_seeds * n_snrt, n_rcns, *v_ynparr.shape[3:])
                    assert v_ynparr2.shape[:2] == (n_seeds * n_snrt, n_rcns)
                    d_ydata = math.prod(v_ynparr2.shape[2:])
                    v_ynparr3 = v_ynparr2.reshape(n_seeds * n_snrt, n_rcns, d_ydata)
                    assert v_ynparr3.shape == (n_seeds * n_snrt, n_rcns, d_ydata)
                    v_ynparr4 = v_ynparr3[i_samp]
                    assert v_ynparr4.shape == (n_rcns, d_ydata)
                    # The central curve is the median over the `n_rcns` samples, 
                    # and the band spans their minimum and maximum.
                    v_ymean = np.median(v_ynparr4, axis=0)
                    assert v_ymean.shape == (d_ydata,)
                    v_ylow = np.quantile(v_ynparr4, q=0.0, axis=0)
                    assert v_ylow.shape == (d_ydata,)
                    v_yhigh = np.quantile(v_ynparr4, q=1.0, axis=0)
                    assert v_yhigh.shape == (d_ydata,)
                    v_mpldatas[f'{ycol}.cndvar/{ax_idx}:{vrnt}:x'] = xvals
                    if plot_ci:
                        v_mpldatas[f'{ycol}.cndvar/{ax_idx}:{vrnt}:y/mean'] = v_ymean
                        v_mpldatas[f'{ycol}.cndvar/{ax_idx}:{vrnt}:y/low'] = v_ylow
                        v_mpldatas[f'{ycol}.cndvar/{ax_idx}:{vrnt}:y/high'] = v_yhigh
                    else:
                        v_mpldatas[f'{ycol}.cndvar/{ax_idx}:{vrnt}:y'] = v_ymean
            
            v_mpldatashie = hie2deep(v_mpldatas, maxdepth=1)
            ########## Calling Matplotlib to Plot the Diagnostics ########## 
            fig, axes = None, None
            for ax_id in v_mpldatashie:
                fig, axes = plot_mpl(data=v_mpldatashie[ax_id], fig=fig, 
                    axes=axes, mplopts=v_mplcfgs[ax_id])

            # Adding the top left april tag for axis identification
            with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
                for i_figrow in range(axes.shape[0]):
                    for i_figcol in range(axes.shape[1]):
                        ax = axes[i_figrow, i_figcol]
                        if show_mchmprt and (i_figrow == 0):
                            tag_text = f'(a$_{{{i_figcol + 1}}}$)'
                        elif show_mchmprt and (i_figrow > 0):
                            tag_text = f'({"bcdef"[i_figcol]})'
                        else:
                            tag_text = f'({"jklm"[i_figcol]})'
                        tag_axis(ax, tag_text, fontsize=11, pad=(0.3, 0.6))
                if show_mchmprt:
                    for i_figrow, ax in enumerate(axes[:, 0]):
                        txthdr = ('Speciated Mass', 'Diagnostics')[i_figrow] 
                        print_axheader(ax, txthdr, 'left', fontsize=14, pad=10, fontweight='bold')

            figs_dict[f'{i_page}/vardiag'] = fig

        # NOTE(review): the path does not depend on `i_file`, so each file 
        # overwrites the PDF of the previous one when `n_files > 1`.
        _nicknm = '' if nicknm is None else f'_{nicknm}'
        pdfpath = f'{figdir}/{i_savefig:02d}{_nicknm}_anec.pdf'
        with PdfPages(pdfpath) as pdf:
            for i_page, fig_clctn in figs_dict.items():
                pdf.savefig(figure=fig_clctn, bbox_inches='tight', pad_inches=1/72)
        print(f'Finished writing {pdfpath}')

        # Releasing the saved figures from pyplot so that they do not pile up 
        # in memory; the `fig` object itself can still be displayed below.
        for fig_clctn in figs_dict.values():
            plt.close(fig_clctn)

fig

### Conditional Collective Summary Plots

In [None]:
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)
colorspec = v_mplcfgs['colorspec']

for viz_id, viz_data in hie2deep(viz_datas, maxdepth=1).items():
    exprmnt, arch, split = viz_id.split(':')
    v_ydata = viz_data['v_ydata']
    aero_cstscnv = viz_data['aero_csts']

    if (not exprmnt.startswith('cond.cont.acsmsmps.')) or (split not in ('test',)):
        continue

    exprm_info = get_subdict(exprm_infos, f'{exprmnt}:{arch}')
    n_seeds = exprm_info['n_seeds']
    n_snrt = exprm_info['n_snrt']
    nicknm = exprm_info['nicknm']
    figdir = exprm_info['figdir']
    i_savefig = exprm_info['i_fig/diagsmry']

    ######## Computing the Aerosol Diagnostic Metrics #########
    v_mtrcdata1 = calc_aerometrics(v_ydata, aero_cstscnv, avg_errs=False, vrnt_trg='znrm')

    eps1, eps2, eps3 = v_mtrcdata1['eps1'], v_mtrcdata1['eps2'], v_mtrcdata1['eps3']
    i_eps1, i_eps2, i_eps3 = v_mtrcdata1['i_eps1'], v_mtrcdata1['i_eps2'], v_mtrcdata1['i_eps3']
    tmp1, tmp2, tmp3 = v_mtrcdata1['tmp1'], v_mtrcdata1['tmp2'], v_mtrcdata1['tmp3']
    i_tmp1, i_tmp2, i_tmp3 = v_mtrcdata1['i_tmp1'], v_mtrcdata1['i_tmp2'], v_mtrcdata1['i_tmp3']
    wvl1, i_wvl1 = v_mtrcdata1['wvl1'], v_mtrcdata1['i_wvl1']

    # Assembling the plotting data
    v_mpldatas1 = dict()
    v_mpldatas1['nsmps_errhist/0:n_smpserr:y'] = v_mtrcdata1['n_smps/err'].ravel()
    v_mpldatas1['macsm_errhist/1:m_acsmerr:y'] = v_mtrcdata1['m_acsm/err'].ravel()
    for ax_idx, ax_id, i_eps in [(3, 'ccn_sctcnd1', i_eps1), (6, 'ccn_sctcnd3', i_eps3)]:
        v_mpldatas1[f'{ax_id}/{ax_idx}:ccn_cdf:x'] = v_mtrcdata1['ccn_cdf/orig'][..., i_eps].ravel()
        v_mpldatas1[f'{ax_id}/{ax_idx}:ccn_cdf:y'] = v_mtrcdata1['ccn_cdf/znrm'][..., i_eps].ravel()
    v_mpldatas1['qs_popsctcnd/4:qs_pop:x'] = v_mtrcdata1['qs_pop/orig'][:, :, i_wvl1, :].ravel()
    v_mpldatas1['qs_popsctcnd/4:qs_pop:y'] = v_mtrcdata1['qs_pop/znrm'][:, :, i_wvl1, :].ravel()
    v_mpldatas1['qa_popsctcnd/7:qa_pop:x'] = v_mtrcdata1['qa_pop/orig'][:, :, i_wvl1, :].ravel()
    v_mpldatas1['qa_popsctcnd/7:qa_pop:y'] = v_mtrcdata1['qa_pop/znrm'][:, :, i_wvl1, :].ravel()
    for ax_idx, ax_id, i_tmp in [(5, 'frznfrac_sctcnd1', i_tmp1), (8, 'frznfrac_sctcnd3', i_tmp3)]:
        v_mpldatas1[f'{ax_id}/{ax_idx}:frznfrac_tmp:x'] = v_mtrcdata1['frznfrac_tmp/orig'][..., i_tmp].ravel()
        v_mpldatas1[f'{ax_id}/{ax_idx}:frznfrac_tmp:y'] = v_mtrcdata1['frznfrac_tmp/znrm'][..., i_tmp].ravel()

    # Making the matplotlib calls
    v_mpldatas2 = hie2deep(v_mpldatas1, maxdepth=1)
    fig, axes = None, None
    for ax_idx, ax_id in enumerate(v_mpldatas2):
        fig, axes = plot_mpl(data=v_mpldatas2[ax_id], fig=fig, axes=axes, 
            mplopts=v_mplcfgs[ax_id])
    axes1d = axes.ravel()

    # Adding the text boxes
    for ax_idx, textstr in [
        (3,  f'$s={100*eps1:0.2g}\%$'), 
        (6,  f'$s={100*eps3:0.2g}\%$'), 
        (4,  f'${{\\rm \lambda={wvl1*1e6:.1f}\\ \mu m}}$'), 
        (7,  f'${{\\rm \lambda={wvl1*1e6:.1f}\\ \mu m}}$'), 
        (5,  f'$T={{\\rm {tmp1:.0f}^{{\circ}}\\ C}}$'), 
        (8,  f'$T={{\\rm {tmp3:.0f}^{{\circ}}\\ C}}$')]:
        ax = axes1d[ax_idx]
        ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=11,
            verticalalignment='top', usetex=True, bbox={'facecolor': 'none', 
            'edgecolor': 'none'})

    # Adding the x=y line to scatter plots
    for ax_idx in range(3, 9):
        ax = axes1d[ax_idx]
        x_lo = min(ax.get_xlim()[0], ax.get_ylim()[0])
        x_hi = max(ax.get_xlim()[1], ax.get_ylim()[1])
        x_id = np.linspace(x_lo, x_hi, 100)
        ax.plot(x_id, x_id, lw=1, ls='--', color=colorspec['blue'])
    
    # Adding the top left april tag for axis identification
    with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
        for i_figrow in range(axes.shape[0]):
            for i_figcol, ax in enumerate(axes[i_figrow]):
                ax_idx = i_figrow * axes.shape[1] + i_figcol
                tag_text = f'({"abcdefghi"[ax_idx]})'
                tag_axis(ax, tag_text, fontsize=11, pad=(0.3, 0.6))
    
    # Removing the empty axis at the end of the first row
    axes[0, 2].remove()
    
    _nicknm = '' if nicknm is None else f'_{nicknm}'
    pdfpath = f'{figdir}/{i_savefig:02d}{_nicknm}_smry.pdf'
    pngpath = pdfpath[:-4] + '.png'
    fig.savefig(pngpath, dpi=200, bbox_inches='tight')
    print(f'Finished writing {pngpath}')

fig

### Conditional Diagnostic Calibration Plots

In [None]:
# Conditional diagnostic calibration figure.
# For every 'cond.cont.acsmsmps.*' experiment on the 'test' split, six panels
# (axis indices 0-5, three columns: CCN, optics, ice nucleation) are filled with
# the '<metric>/origraw' values on x and the '<metric>/origcdf' values on y, and
# the figure is written as a PNG into the experiment's figure directory.
# NOTE(review): `viz_datas`, `exprm_infos` and `calc_aerometrics` are not defined
# in this cell; presumably earlier cells provide them -- confirm on a fresh kernel.
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)
colorspec = v_mplcfgs['colorspec']

for viz_id, viz_data in hie2deep(viz_datas, maxdepth=1).items():
    # The visualization ids have the form '<experiment>:<architecture>:<split>'
    exprmnt, arch, split = viz_id.split(':')
    v_ydata = viz_data['v_ydata']
    aero_cstscnv = viz_data['aero_csts']

    if (not exprmnt.startswith('cond.cont.acsmsmps.')) or (split not in ('test',)):
        continue

    exprm_info = get_subdict(exprm_infos, f'{exprmnt}:{arch}')
    n_seeds = exprm_info['n_seeds']
    n_snrt = exprm_info['n_snrt']
    nicknm = exprm_info['nicknm']
    figdir = exprm_info['figdir']
    i_savefig = exprm_info['i_fig/diagsmryqnt']

    ######## Computing the Aerosol Diagnostic Metrics #########
    v_mtrcdata1 = calc_aerometrics(v_ydata, aero_cstscnv, avg_errs=False, vrnt_trg='znrm')

    # The reference supersaturations, temperatures and wavelength (plus their
    # indices along the last metric axis) that the panels are evaluated at.
    eps1, eps2, eps3 = v_mtrcdata1['eps1'], v_mtrcdata1['eps2'], v_mtrcdata1['eps3']
    i_eps1, i_eps2, i_eps3 = v_mtrcdata1['i_eps1'], v_mtrcdata1['i_eps2'], v_mtrcdata1['i_eps3']
    tmp1, tmp2, tmp3 = v_mtrcdata1['tmp1'], v_mtrcdata1['tmp2'], v_mtrcdata1['tmp3']
    i_tmp1, i_tmp2, i_tmp3 = v_mtrcdata1['i_tmp1'], v_mtrcdata1['i_tmp2'], v_mtrcdata1['i_tmp3']
    wvl1, i_wvl1 = v_mtrcdata1['wvl1'], v_mtrcdata1['i_wvl1']

    # Assembling the plotting data. The keys follow '<ax_id>/<ax_idx>:<series>:<x|y>';
    # the insertion order below decides the order of the `plot_mpl` calls.
    v_mpldatas1 = dict()
    # v_mpldatas1['nsmps_errhist/0:n_smpserr:y'] = v_mtrcdata1['n_smps/err'].ravel()
    # v_mpldatas1['macsm_errhist/1:m_acsmerr:y'] = v_mtrcdata1['m_acsm/err'].ravel()
    for ax_idx, ax_id, i_eps in [(0, 'ccn_sctcndqnt1', i_eps1), (3, 'ccn_sctcndqnt3', i_eps3)]:
        v_mpldatas1[f'{ax_id}/{ax_idx}:ccn_cdf:x'] = v_mtrcdata1['ccn_cdf/origraw'][..., i_eps].ravel()
        v_mpldatas1[f'{ax_id}/{ax_idx}:ccn_cdf:y'] = v_mtrcdata1['ccn_cdf/origcdf'][..., i_eps].ravel()
    v_mpldatas1['qs_popsctcndqnt/1:qs_pop:x'] = v_mtrcdata1['qs_pop/origraw'][:, :, i_wvl1, :].ravel()
    v_mpldatas1['qs_popsctcndqnt/1:qs_pop:y'] = v_mtrcdata1['qs_pop/origcdf'][:, :, i_wvl1, :].ravel()
    v_mpldatas1['qa_popsctcndqnt/4:qa_pop:x'] = v_mtrcdata1['qa_pop/origraw'][:, :, i_wvl1, :].ravel()
    v_mpldatas1['qa_popsctcndqnt/4:qa_pop:y'] = v_mtrcdata1['qa_pop/origcdf'][:, :, i_wvl1, :].ravel()
    for ax_idx, ax_id, i_tmp in [(2, 'frznfrac_sctcndqnt1', i_tmp1), (5, 'frznfrac_sctcndqnt3', i_tmp3)]:
        v_mpldatas1[f'{ax_id}/{ax_idx}:frznfrac_tmp:x'] = v_mtrcdata1['frznfrac_tmp/origraw'][..., i_tmp].ravel()
        v_mpldatas1[f'{ax_id}/{ax_idx}:frznfrac_tmp:y'] = v_mtrcdata1['frznfrac_tmp/origcdf'][..., i_tmp].ravel()

    # Making the matplotlib calls (the figure is created by the first call and
    # then threaded through the remaining ones)
    v_mpldatas2 = hie2deep(v_mpldatas1, maxdepth=1)
    fig, axes = None, None
    for ax_idx, ax_id in enumerate(v_mpldatas2):
        fig, axes = plot_mpl(data=v_mpldatas2[ax_id], fig=fig, axes=axes, 
            mplopts=v_mplcfgs[ax_id])
    axes1d = axes.ravel()

    # Adding the text boxes. Every LaTeX backslash is written as '\\' so that the
    # literals contain no invalid escape sequences (a SyntaxWarning on Python
    # 3.12+); the strings handed to matplotlib are unchanged.
    for ax_idx, textstr in [
        (0,  f'$s={100*eps1:0.2g}\\%$'), 
        (3,  f'$s={100*eps3:0.2g}\\%$'), 
        (1,  f'${{\\rm \\lambda={wvl1*1e6:.1f}\\ \\mu m}}$'), 
        (4,  f'${{\\rm \\lambda={wvl1*1e6:.1f}\\ \\mu m}}$'), 
        (2,  f'$T={{\\rm {tmp1:.0f}^{{\\circ}}\\ C}}$'), 
        (5,  f'$T={{\\rm {tmp3:.0f}^{{\\circ}}\\ C}}$')]:
        ax = axes1d[ax_idx]
        ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=11,
            verticalalignment='top', usetex=True, bbox={'facecolor': 'none', 
            'edgecolor': 'none'})
    
    # Adding the top-left panel tags, e.g. (a_1), (b_1), ..., for axis identification
    with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
        for i_figrow in range(axes.shape[0]):
            for i_figcol, ax in enumerate(axes[i_figrow]):
                ax_idx = i_figrow * axes.shape[1] + i_figcol
                tag_text = f'({"abc"[i_figcol]}$_{{{i_figrow+1}}}$)'
                tag_axis(ax, tag_text, fontsize=11, pad=(0.3, 0.6))
        
        # Column titles above the first row
        for i_figcol, col_ttl in enumerate(['Cloud Condensation', 'Optical Properties', 'Ice Nucleation']):
            print_axheader(axes[0, i_figcol], col_ttl, 'top', pad=8, fontsize=14, fontweight='bold')
    
    # Only the PNG is written; the '.pdf' path is just the stem the PNG name is
    # derived from.
    _nicknm = '' if nicknm is None else f'_{nicknm}'
    pdfpath = f'{figdir}/{i_savefig:02d}{_nicknm}_smry.pdf'
    pngpath = pdfpath[:-4] + '.png'
    fig.savefig(pngpath, dpi=200, bbox_inches='tight')
    print(f'Finished writing {pngpath}')

fig

## Non-Conditional Trainings

### Paper Plot: Data Introduction

In [None]:
# Paper figure: data introduction.
# For each matching entry of `viz_datas`, three panels are drawn: the speciated
# mass of one selected sample as a bar plot and as a heatmap, and the heatmap of
# the mean speciated mass of the original data set. The result is saved as a PDF.
# NOTE(review): `viz_datas`, `aero_csts`, `cnvrt_physunits`, `chem_species`,
# `np_flatten` and `i_sampintro` are not defined in this cell; presumably they
# come from earlier cells -- confirm on a fresh kernel.
i_fig = 84
plt.ioff()

# Loading the original data
data_path = f'{data_dir}/02_masshist/03_bwchisamp.nc'
data_hist = load_histdata(data_path)
m_chmprthst2 = data_hist['m_chmprthst']
v_ydata2, aero_cstscnv2 = cnvrt_physunits({'m_chmprthst/orig': m_chmprthst2}, aero_csts)
n_snr, n_t = data_hist['n_snr'], data_hist['n_t']
n_chem, n_bins = aero_cstscnv2['n_chem'], aero_cstscnv2['n_bins']
m_chmprthst3 = v_ydata2['m_chmprthst/orig']
assert m_chmprthst3.shape == (n_snr, n_t, n_chem, n_bins)
# Mean over the two leading axes (n_snr, n_t), leaving one value per
# (chemical species, diameter bin)
m_chmprthst4 = m_chmprthst3.mean(axis=(0, 1))
assert m_chmprthst4.shape == (n_chem, n_bins)

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)
# Both heatmaps get one y-tick per species, placed at the middle (i + 0.5) of
# the [i, i + 1] band that the species occupies
for ax_id in ['paper.intro.data.hmap.specific', 'paper.intro.data.hmap.global']:
    v_mplcfgs[ax_id]['yticks/ticks'] = (np.arange(n_chem) + 0.5).tolist()
    v_mplcfgs[ax_id]['yticks/labels'] = chem_species
    v_mplcfgs[ax_id]['ylim'] = [0, n_chem]

for viz_id, viz_data in hie2deep(viz_datas, maxdepth=1).items():
    # The visualization ids have the form '<experiment>:<architecture>:<split>'
    exprmnt, arch, split = viz_id.split(':')
    i_sel = viz_data['i_sel']
    v_ydata = viz_data['v_ydata']
    # NOTE: these overwrite the `n_chem` / `n_bins` read from `aero_cstscnv2` above
    n_chem = viz_data['n_chem']
    n_bins = viz_data['n_bins']
    d_histbins2 = viz_data['d_histbins']

    # Only the training split of the 'trad.*' MLP experiments is plotted
    if (not exprmnt.startswith('trad.')) or (split not in ('train',)) or (arch not in ('mlp',)):
        continue

    # Use the `i_sampintro` override when it is set; otherwise take the third
    # entry of the first row of `i_sel`
    i_samp = i_sel[0, 2] if i_sampintro is None else i_sampintro
    # NOTE: this rebinds `v_ydata2` (previously the converted original data) to
    # the selected sample of every array in `v_ydata`.
    # presumably np_flatten(val, 0, 2) merges the leading axes into one sample
    # axis -- TODO confirm against its definition
    v_ydata2 = {key: np_flatten(val, 0, 2)[i_samp] for key, val in v_ydata.items()}

    m_chmprtsamp = v_ydata2[f'm_chmprthst/orig']

    assert d_histbins2.shape == (n_bins + 1,)
    assert m_chmprtsamp.shape == (n_chem, n_bins)
    assert m_chmprthst4.shape == (n_chem, n_bins)

    # Plot data keyed as '<ax_id>/<ax_idx>:<series>:<x|y|c>'
    v_mpldatas = dict()
    # The speciated mass stack bar plot data (one series per chemical species)
    ax_idx = 0
    for i_chem, chem in enumerate(chem_species):
        v_mpldatas[f'paper.intro.data.bar/{ax_idx}:{chem}:x'] = d_histbins2
        v_mpldatas[f'paper.intro.data.bar/{ax_idx}:{chem}:y'] = m_chmprtsamp[i_chem]

    # The speciated mass heatmap data of the selected sample; x and y hold the
    # n_bins + 1 and n_chem + 1 cell edges
    ax_idx, ax_id = 1, 'paper.intro.data.hmap.specific'
    v_mpldatas[f'{ax_id}/{ax_idx}:mass:x'] = d_histbins2
    v_mpldatas[f'{ax_id}/{ax_idx}:mass:y'] = np.arange(n_chem + 1)
    v_mpldatas[f'{ax_id}/{ax_idx}:mass:c'] = m_chmprtsamp

    # The speciated mass heatmap data of the data-set mean (`m_chmprthst4`)
    ax_idx, ax_id = 2, 'paper.intro.data.hmap.global'
    v_mpldatas[f'{ax_id}/{ax_idx}:mass:x'] = d_histbins2
    v_mpldatas[f'{ax_id}/{ax_idx}:mass:y'] = np.arange(n_chem + 1)
    v_mpldatas[f'{ax_id}/{ax_idx}:mass:c'] = m_chmprthst4
    
    v_mpldatashie = hie2deep(v_mpldatas, maxdepth=1)

    ########## Calling Matplotlib to Plot the Diagnostics ########## 
    # The figure is created by the first call and threaded through the others
    fig, axes = None, None
    for ax_idx, ax_id in enumerate(v_mpldatashie):
        fig, axes = plot_mpl(data=v_mpldatashie[ax_id], fig=fig, 
            axes=axes, mplopts=v_mplcfgs[ax_id])
    
    # Adding the (a)/(b)/(c) panel tags
    with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
        for ax, tag_txt in [(axes[0, 0], '(a)'), 
            (axes[0, 1], '(b)'), (axes[0, 2], '(c)')]:
            tag_axis(ax, tag_txt, fontsize=11, pad=[0.3, 0.8])
    
    pdfpath = f'{workdir}/{i_fig:02d}_intro_data.pdf'
    fig.savefig(pdfpath, bbox_inches='tight')
    print(f'Finished writing {pdfpath}')
    
    # Each matching entry gets its own figure number
    i_fig += 1

fig

### Paper Plot: Example Diagnostic

In [None]:
# Paper figure: example diagnostics of a single aerosol population.
# For each matching entry of `viz_datas` and each selected sample, a main
# diagnostic figure and two per-chemical-species figures are built; all figures
# of one entry are written as the pages of one PDF.
# NOTE(review): `viz_datas`, `chem_species`, `np_flatten` and `i_sampintro` are
# not defined in this cell; presumably they come from earlier cells -- confirm.
i_fig = 85
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)

for viz_id, viz_data in hie2deep(viz_datas, maxdepth=1).items():
    # The visualization ids have the form '<experiment>:<architecture>:<split>'
    exprmnt, arch, split = viz_id.split(':')
    v_ydata = viz_data['v_ydata']
    d_histbins = viz_data['d_histbins']
    eps_histbins = viz_data['eps_histbins']
    eps_histmin = viz_data['eps_histmin']
    eps_histmax = viz_data['eps_histmax']
    tmprtr_inpmin = viz_data['tmprtr_inpmin']
    tmprtr_inpmax = viz_data['tmprtr_inpmax']
    temprtr_bins = viz_data['temprtr_bins']
    len_wv = viz_data['len_wv']
    i_sel = viz_data['i_sel']

    # Only the training split of the 'trad.*' MLP experiments is plotted
    if (not exprmnt.startswith('trad.')) or (split not in ('train',)) or (arch not in ('mlp',)):
        continue

    # The x-limits are widened slightly beyond the histogram / temperature range
    v_mplcfgs['paper.intro.diag.ccn_cdf']['xlim'] = [eps_histmin * 0.9, eps_histmax * 1.1]
    v_mplcfgs['paper.intro.diag.frznfrac_tmp']['xlim'] = [tmprtr_inpmin - 1, tmprtr_inpmax + 1]
    
    assert i_sel.shape[0] == 1
    # Plot only the `i_sampintro` override when it is set; otherwise every
    # sample in the single row of `i_sel`
    i_samplst = i_sel[0, :] if i_sampintro is None else [i_sampintro]

    # Maps '<sample position>/<plot id>' to the figure; insertion order is the
    # page order of the PDF written below
    figs_dict = dict()
    for ii_samp, i_samp in enumerate(i_samplst):
        # presumably np_flatten(val, 0, 2) merges the leading axes into one
        # sample axis -- TODO confirm against its definition
        v_ydata2 = {key: np_flatten(val, 0, 2)[i_samp] for key, val in v_ydata.items()}

        # Plot data keyed as '<ax_id>/<ax_idx>:<series>:<x|y>'
        v_mpldatas = dict()
        # The speciated mass data
        ax_idx = 0
        for i_chem, chem in enumerate(chem_species):
            v_mpldatas[f'paper.intro.diag.m_chmprthst/{ax_idx}:{chem}:x'] = d_histbins
            v_mpldatas[f'paper.intro.diag.m_chmprthst/{ax_idx}:{chem}:y'] = v_ydata2[f'm_chmprthst/orig'][i_chem]
        
        # The total mass, particle count, ccn, frozen fraction, and optical cross-section data
        # Each entry is [axis index, x values, metric name]; axis index 1 is
        # not filled here and that axis is removed further below
        ax_ixyspec = []
        ax_ixyspec.append([2, d_histbins,   'm_prthst'    ])
        ax_ixyspec.append([3, d_histbins,   'n_prthst'    ])
        ax_ixyspec.append([4, eps_histbins, 'ccn_cdf'     ])
        ax_ixyspec.append([5, len_wv,       'qs_pop'      ])
        ax_ixyspec.append([6, len_wv,       'qa_pop'      ])
        ax_ixyspec.append([7, temprtr_bins, 'frznfrac_tmp'])
        
        for ax_idx, xvals, ycol in ax_ixyspec:
            v_mpldatas[f'paper.intro.diag.{ycol}/{ax_idx}:orig:x'] = xvals
            v_mpldatas[f'paper.intro.diag.{ycol}/{ax_idx}:orig:y'] = np.squeeze(v_ydata2[f'{ycol}/orig'])
        
        v_mpldatashie = hie2deep(v_mpldatas, maxdepth=1)

        ########## Calling Matplotlib to Plot the Diagnostics ########## 
        # The figure is created by the first call and threaded through the others
        fig, axes = None, None
        for ax_idx, ax_id in enumerate(v_mpldatashie):
            fig, axes = plot_mpl(data=v_mpldatashie[ax_id], fig=fig, 
                axes=axes, mplopts=v_mplcfgs[ax_id])
        
        # Sharing the y-axis limits between the three axes
        # (the first three axes of the first row get the union of their limits)
        axes_yshrd = axes[0, :3]
        ax_ylimlo = min(ax.get_ylim()[0] for ax in axes_yshrd)
        ax_ylimhi = max(ax.get_ylim()[1] for ax in axes_yshrd)
        for ax in axes_yshrd:
            ax.set_ylim(ax_ylimlo, ax_ylimhi)

        # Adding the top-left panel tags (a)-(h) for axis identification
        with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
            for ax_idx, tag_char in enumerate('abcdefgh'):
                ax = axes.ravel()[ax_idx]
                tag_axis(ax, f'({tag_char})', fontsize=11, pad=(0.3, 0.6))

        # Removing the reconstructed mass axis
        # (it is tagged above before being removed, so the remaining tags skip '(b)')
        axes[0, 1].remove()
    
        figs_dict[f'{ii_samp}/diag'] = fig

        ########## The Per Wave-Length Optical Plots ########## 
        # NOTE: the `[:0]` slice leaves an empty tuple, so the body of this loop
        # never runs; the per-wavelength plots are effectively disabled.
        for ycol in ('qscs_prt', 'qacs_prt')[:0]:
            v_mpldatas2 = dict()
            for i_wvl, wvl in enumerate(len_wv):
                v_mpldatas2[f'{i_wvl}:orig:x'] = d_histbins
                v_mpldatas2[f'{i_wvl}:orig:y'] = v_ydata2[f'{ycol}/orig'][i_wvl]
            mploptid = f'paper.intro.diag.{ycol}'
            fig2, axes2 = plot_mpl(data=v_mpldatas2, fig=None, axes=None, 
                mplopts=v_mplcfgs[mploptid])
            for i_wvl, (wvl, ax) in enumerate(zip(len_wv, axes2.ravel())):
                props = dict(facecolor='none', edgecolor='none')
                textstr = f'$\\lambda={len_wv[i_wvl]:.1f}\\ {{\\rm \\mu m}}$'
                ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=10,
                    usetex=True, verticalalignment='top', bbox=props)
            figs_dict[f'{ii_samp}/{mploptid}'] = fig2

        ########## The Per-Chemical Species Plots ########## 
        # One axis per chemical species; the same data is rendered twice, once
        # with each of the two plot configurations below
        v_mpldatas3 = dict()
        for i_chem, chem in enumerate(chem_species):
            v_mpldatas3[f'{i_chem}:orig:x'] = d_histbins
            v_mpldatas3[f'{i_chem}:orig:y'] = v_ydata2[f'm_chmprthst/orig'][i_chem]
        for mploptid in ('paper.intro.diag.m_chmhst1', 'paper.intro.diag.m_chmhst2'):
            fig3, axes3 = plot_mpl(data=v_mpldatas3, fig=None, axes=None, 
                mplopts=v_mplcfgs[mploptid])
            # Labelling every axis with the name of its chemical species
            for i_chem, (chem, ax) in enumerate(zip(chem_species, axes3.ravel())):
                props = dict(facecolor='none', edgecolor='none')
                ax.text(0.05, 0.95, chem, transform=ax.transAxes, fontsize=10,
                    usetex=True, verticalalignment='top', bbox=props)
            # Adding the top-left panel tags: (i_k) for the first configuration
            # and (j_k) for the second one
            with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
                for ax_idx, ax in enumerate(axes3.ravel()):
                    tag_txt = {'m_chmhst1': f'(i$_{{{ax_idx+1}}}$)', 
                        'm_chmhst2': f'(j$_{{{ax_idx+1}}}$)'}[mploptid.split('.')[-1]]
                    tag_axis(ax, tag_txt, fontsize=11, pad=(0.3, 0.6))

            figs_dict[f'{ii_samp}/{mploptid}'] = fig3
    
    # Writing every collected figure as one page of the PDF
    # (the loop key is the '<sample position>/<plot id>' string, not an index)
    pdfpath = f'{workdir}/{i_fig:02d}_intro_diag.pdf'
    with PdfPages(pdfpath) as pdf:
        for ii_samp, fig3 in figs_dict.items():
            pdf.savefig(figure=fig3, bbox_inches='tight')
    print(f'Finished writing {pdfpath}')
    
    i_fig += 1

fig

### The Single Aerosol Population Original vs. Reconstruction Diagnostic Plots

In [None]:
# Original vs. reconstruction diagnostics of single aerosol populations.
# For every 'trad.*' experiment on the 'train' and 'test' splits, each selected
# sample gets a main diagnostic figure (original and reconstruction overlaid)
# and two per-chemical-species figures; all figures of one entry are written as
# the pages of one PDF.
# NOTE(review): `viz_datas`, `chem_species` and `np_flatten` are not defined in
# this cell; presumably they come from earlier cells -- confirm.
i_fig = 38
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)

for viz_id, viz_data in hie2deep(viz_datas, maxdepth=1).items():
    # The visualization ids have the form '<experiment>:<architecture>:<split>'
    exprmnt, arch, split = viz_id.split(':')
    v_ydata = viz_data['v_ydata']
    d_histbins = viz_data['d_histbins']
    eps_histbins = viz_data['eps_histbins']
    eps_histmin = viz_data['eps_histmin']
    eps_histmax = viz_data['eps_histmax']
    tmprtr_inpmin = viz_data['tmprtr_inpmin']
    tmprtr_inpmax = viz_data['tmprtr_inpmax']
    temprtr_bins = viz_data['temprtr_bins']
    len_wv = viz_data['len_wv']
    i_sel = viz_data['i_sel']
    e_samps = viz_data['e_samps']

    if (not exprmnt.startswith('trad.')) or (split not in ('train', 'test')):
        continue

    # The x-limits are derived from the entry being plotted. They used to be set
    # before the loop, which silently relied on `eps_histmin` & co. having been
    # leaked into the kernel namespace by a previously executed cell.
    v_mplcfgs['ccn_cdf']['xlim'] = [eps_histmin * 0.9, eps_histmax * 1.1]
    v_mplcfgs['frznfrac_tmp']['xlim'] = [tmprtr_inpmin - 1, tmprtr_inpmax + 1]
    
    assert i_sel.shape[0] == 1

    # Maps '<sample position>/<plot id>' to the figure; insertion order is the
    # page order of the PDF written below
    figs_dict = dict()
    for ii_samp, i_samp in enumerate(i_sel[0]):
        # presumably np_flatten(val, 0, 2) merges the leading axes into one
        # sample axis -- TODO confirm against its definition
        v_ydata2 = {key: np_flatten(val, 0, 2)[i_samp] for key, val in v_ydata.items()}
        
        # Plot data keyed as '<ax_id>/<ax_idx>:<series>:<x|y>'
        v_mpldatas = dict()
        # The speciated mass data (axis 0: original, axis 1: reconstruction)
        for ax_idx, vrnt in enumerate(['orig', 'rcnst']):
            for i_chem, chem in enumerate(chem_species):
                v_mpldatas[f'm_chmprthst.{vrnt}/{ax_idx}:{chem}:x'] = d_histbins
                v_mpldatas[f'm_chmprthst.{vrnt}/{ax_idx}:{chem}:y'] = v_ydata2[f'm_chmprthst/{vrnt}'][i_chem]
        
        # The total mass, particle count, ccn, frozen fraction, and optical cross-section data
        # Each entry is [axis index, x values, metric name]
        ax_ixyspec = []
        ax_ixyspec.append([2, d_histbins,   'm_prthst'    ])
        ax_ixyspec.append([3, d_histbins,   'n_prthst'    ])
        ax_ixyspec.append([4, eps_histbins, 'ccn_cdf'     ])
        ax_ixyspec.append([5, len_wv,       'qs_pop'      ])
        ax_ixyspec.append([6, len_wv,       'qa_pop'      ])
        ax_ixyspec.append([7, temprtr_bins, 'frznfrac_tmp'])
        
        for ax_idx, xvals, ycol in ax_ixyspec:
            for vrnt in ['orig', 'rcnst']:
                v_mpldatas[f'{ycol}/{ax_idx}:{vrnt}:x'] = xvals
                v_mpldatas[f'{ycol}/{ax_idx}:{vrnt}:y'] = np.squeeze(v_ydata2[f'{ycol}/{vrnt}'])
        
        v_mpldatashie = hie2deep(v_mpldatas, maxdepth=1)

        ########## Calling Matplotlib to Plot the Diagnostics ########## 
        # The figure is created by the first call and threaded through the others
        fig, axes = None, None
        for ax_idx, ax_id in enumerate(v_mpldatashie):
            fig, axes = plot_mpl(data=v_mpldatashie[ax_id], fig=fig, 
                axes=axes, mplopts=v_mplcfgs[ax_id])
        
        # Sharing the y-axis limits between the three axes
        # (the first three axes of the first row get the union of their limits)
        axes_yshrd = axes[0, :3]
        ax_ylimlo = min(ax.get_ylim()[0] for ax in axes_yshrd)
        ax_ylimhi = max(ax.get_ylim()[1] for ax in axes_yshrd)
        for ax in axes_yshrd:
            ax.set_ylim(ax_ylimlo, ax_ylimhi)

        # Adding the small identification text box (the sample's `e_samps`
        # value, printed on the reconstruction axis)
        axes1d = axes.ravel()
        axtexts = [(1, f'SMRE={e_samps[i_samp].item():.2f}')]        
        for ax_idx, textstr in axtexts:
            ax = axes1d[ax_idx]
            props = dict(facecolor='none', edgecolor='none')
            ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=10,
                usetex=True, verticalalignment='top', bbox=props)

        # Adding the original and reconstruction labels
        with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
            for ax, text_str in [(axes[0, 0], 'Original'), (axes[0, 1], 'Reconstruction')]:
                print_axheader(ax, text_str, 'top', fontsize=10)

        # Adding the top-left panel tags (a)-(h) for axis identification
        with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
            for ax_idx, tag_char in enumerate('abcdefgh'):
                ax = axes.ravel()[ax_idx]
                tag_axis(ax, f'({tag_char})', fontsize=11, pad=(0.3, 0.6))
    
        figs_dict[f'{ii_samp}/diag'] = fig

        ########## The Per Wave-Length Optical Plots ########## 
        # NOTE: the `[:0]` slice leaves an empty tuple, so the body of this loop
        # never runs; the per-wavelength plots are effectively disabled.
        for ycol in ('qscs_prt', 'qacs_prt')[:0]:
            v_mpldatas2 = dict()
            for i_wvl, wvl in enumerate(len_wv):
                for vrnt in ['orig', 'rcnst']:
                    v_mpldatas2[f'{i_wvl}:{vrnt}:x'] = d_histbins
                    v_mpldatas2[f'{i_wvl}:{vrnt}:y'] = v_ydata2[f'{ycol}/{vrnt}'][i_wvl]
            mploptid = ycol
            fig2, axes2 = plot_mpl(data=v_mpldatas2, fig=None, axes=None, 
                mplopts=v_mplcfgs[mploptid])
            for i_wvl, (wvl, ax) in enumerate(zip(len_wv, axes2.ravel())):
                props = dict(facecolor='none', edgecolor='none')
                textstr = f'$\\lambda={len_wv[i_wvl]:.1f}\\ {{\\rm \\mu m}}$'
                ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=10,
                    usetex=True, verticalalignment='top', bbox=props)
            figs_dict[f'{ii_samp}/{mploptid}'] = fig2

        ########## The Per-Chemical Species Plots ########## 
        # One axis per chemical species; the same data is rendered twice, once
        # with each of the two plot configurations below
        v_mpldatas3 = dict()
        for i_chem, chem in enumerate(chem_species):
            for vrnt in ['orig', 'rcnst']:
                v_mpldatas3[f'{i_chem}:{vrnt}:x'] = d_histbins
                v_mpldatas3[f'{i_chem}:{vrnt}:y'] = v_ydata2[f'm_chmprthst/{vrnt}'][i_chem]
        for mploptid in ('m_chmhst1', 'm_chmhst2'):
            fig3, axes3 = plot_mpl(data=v_mpldatas3, fig=None, axes=None, 
                mplopts=v_mplcfgs[mploptid])
            # Labelling every axis with the name of its chemical species
            for i_chem, (chem, ax) in enumerate(zip(chem_species, axes3.ravel())):
                props = dict(facecolor='none', edgecolor='none')
                ax.text(0.05, 0.95, chem, transform=ax.transAxes, fontsize=10,
                    usetex=True, verticalalignment='top', bbox=props)
            
            # Adding the top-left panel tags: (i_k) for the first configuration
            # and (j_k) for the second one
            with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
                for ax_idx, ax in enumerate(axes3.ravel()):
                    tag_txt = {'m_chmhst1': f'(i$_{{{ax_idx+1}}}$)', 
                        'm_chmhst2': f'(j$_{{{ax_idx+1}}}$)'}[mploptid]
                    tag_axis(ax, tag_txt, fontsize=11, pad=(0.3, 0.6))
            figs_dict[f'{ii_samp}/{mploptid}'] = fig3

    # Writing every collected figure as one page of the PDF
    # (the loop key is the '<sample position>/<plot id>' string, not an index)
    pdfpath = f'{workdir}/{i_fig:02d}_{arch}_{split}_anec.pdf'
    with PdfPages(pdfpath) as pdf:
        for ii_samp, fig3 in figs_dict.items():
            pdf.savefig(figure=fig3, bbox_inches='tight')
    print(f'Finished writing {pdfpath}')
    
    i_fig += 1

fig

### The Generated Aerosol Population Diagnostic Plots

In [None]:
# Diagnostics of generated aerosol populations.
# For every 'trad.*' experiment on the 'normal' split, each row of `i_sel` (a
# "collection" of `n_prtr` samples) becomes one figure with one column per
# sample; all figures of one entry are written as the pages of one PDF.
# NOTE(review): `viz_datas`, `chem_species` and `np_flatten` are not defined in
# this cell; presumably they come from earlier cells -- confirm.
i_fig = 42
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)

for viz_id, viz_data in hie2deep(viz_datas, maxdepth=1).items():
    # The visualization ids have the form '<experiment>:<architecture>:<split>'
    exprmnt, arch, split = viz_id.split(':')
    v_ydata = viz_data['v_ydata']
    d_histbins = viz_data['d_histbins']
    eps_histbins = viz_data['eps_histbins']
    eps_histmin = viz_data['eps_histmin']
    eps_histmax = viz_data['eps_histmax']
    tmprtr_inpmin = viz_data['tmprtr_inpmin']
    tmprtr_inpmax = viz_data['tmprtr_inpmax']
    temprtr_bins = viz_data['temprtr_bins']
    len_wv = viz_data['len_wv']
    i_sel = viz_data['i_sel']
    n_clctn, n_prtr = i_sel.shape
    e_samps = viz_data['e_samps']

    if (not exprmnt.startswith('trad.')) or (split not in ('normal',)):
        continue

    # The x-limits are widened slightly beyond the histogram / temperature range
    v_mplcfgs['ccn_cdf.genr']['xlim'] = [eps_histmin * 0.9, eps_histmax * 1.1]
    v_mplcfgs['frznfrac_tmp.genr']['xlim'] = [tmprtr_inpmin - 1, tmprtr_inpmax + 1]

    # `samp_cntr` numbers the column headers consecutively across all figures
    figs_dict, samp_cntr = dict(), 1
    for i_clctn in range(n_clctn):
        # Plot data keyed as '<ax_id>/<ax_idx>:<series>:<x|y>'
        v_mpldatas = dict()
        i_selclctn = i_sel[i_clctn]
        # The speciated mass data (first row: axis index == sample column)
        for ii_samp, i_samp in enumerate(i_selclctn):
            ax_idx = ii_samp
            for i_chem, chem in enumerate(chem_species):
                v_mpldatas[f'm_chmprthst.genr/{ax_idx}:{chem}:x'] = d_histbins
                y_vals = np_flatten(v_ydata['m_chmprthst/genr'], 0, 2)[i_samp, i_chem]
                v_mpldatas[f'm_chmprthst.genr/{ax_idx}:{chem}:y'] = y_vals
        
        # The total mass, particle count, ccn, frozen fraction, and optical cross-section data
        # Each entry is [row index, x values, metric name]; the flat axis index
        # is `row * n_prtr + column`
        ax_ixyspec = []
        ax_ixyspec.append([1, d_histbins,   'n_prthst'    ])
        ax_ixyspec.append([2, eps_histbins, 'ccn_cdf'     ])
        ax_ixyspec.append([3, len_wv,       'qs_pop'      ])
        ax_ixyspec.append([4, len_wv,       'qa_pop'      ])
        ax_ixyspec.append([5, temprtr_bins, 'frznfrac_tmp'])
        for ii_samp, i_samp in enumerate(i_selclctn):
            for i_row, xvals, ycol in ax_ixyspec:
                ax_idx = i_row * n_prtr + ii_samp
                v_mpldatas[f'{ycol}.genr/{ax_idx}:genr:x'] = xvals
                y_vals = np_flatten(v_ydata[f'{ycol}/genr'], 0, 2)[i_samp]
                v_mpldatas[f'{ycol}.genr/{ax_idx}:genr:y'] = np.squeeze(y_vals)
        
        v_mpldatashie = hie2deep(v_mpldatas, maxdepth=1)
        ########## Calling Matplotlib to Plot the Diagnostics ########## 
        # The figure is created by the first call and threaded through the others
        fig, axes = None, None
        for ax_id in v_mpldatashie:
            fig, axes = plot_mpl(data=v_mpldatashie[ax_id], fig=fig, 
                axes=axes, mplopts=v_mplcfgs[ax_id])
    
        # The figure object is stored now and annotated in place below
        figs_dict[f'{i_clctn}/diag'] = fig

        # Adding the small identification text box on the first-row axes.
        # The LaTeX percent sign is written as '\\%' so that the literal holds no
        # invalid escape sequence (a SyntaxWarning on Python 3.12+); the string
        # handed to matplotlib is unchanged.
        axes1d = axes.ravel()
        for ii_samp, i_samp in enumerate(i_selclctn):
            ax = axes[0, ii_samp]
            textstr = f'{100 * e_samps[i_samp].item():.1f}' + '\\% Dust'
            props = dict(facecolor='none', edgecolor='none')
            ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=10,
                usetex=True, verticalalignment='top', bbox=props)
        
        # Adding the top-left panel tags: the letter is the row and the
        # subscript is the column, e.g. (a_1), (a_2), (b_1), ...
        with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
            for i_figrow in range(axes.shape[0]):
                for i_figcol in range(axes.shape[1]):
                    ax = axes[i_figrow, i_figcol]
                    tag_text = f'({"abcdefgh"[i_figrow]}$_{{{i_figcol + 1}}}$)'
                    tag_axis(ax, tag_text, fontsize=11, pad=(0.3, 0.6))
            for i_figcol in range(axes.shape[1]):
                print_axheader(axes[0, i_figcol], f'Sample {samp_cntr}', 'top', 
                    fontsize=14, fontweight='bold')
                samp_cntr += 1

    # Writing one PDF page per collection
    pdfpath = f'{workdir}/{i_fig:02d}_{arch}_{split}_anec.pdf'
    with PdfPages(pdfpath) as pdf:
        for i_clctn, fig_clctn in figs_dict.items():
            pdf.savefig(figure=fig_clctn, bbox_inches='tight')
    print(f'Finished writing {pdfpath}')
    
    i_fig += 1

fig

### The Sample Selection Criterion Histogram

In [None]:
# Histograms of the sample selection criterion (`e_samps`) on the train and test
# splits of every 'trad.*' experiment, with the selected samples marked.
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)

for viz_id, viz_data in hie2deep(viz_datas, sep=':', mindepth=2, maxdepth=2).items():
    exprmnt, arch = viz_id.split(':')
    if exprmnt.startswith('trad.'):
        # One histogram axis per split: axis 0 is train, axis 1 is test
        v_mpldata = {
            f'{i_ax}:sm_relerr:y': viz_data[f'{splt}/e_samps']
            for i_ax, splt in enumerate(('train', 'test'))}
        fig, axes = plot_mpl(data=v_mpldata, fig=None, axes=None,
            mplopts=v_mplcfgs['sm_relerrhist'])

        for ax, split in zip((axes[0, 0], axes[0, 1]), ('train', 'test')):
            i_sel = viz_data[f'{split}/i_sel']
            e_samps = viz_data[f'{split}/e_samps']

            # A dashed vertical line at the criterion value of each selected sample
            for i_samp in i_sel[0]:
                ax.axvline(e_samps[i_samp], lw=1, color='black', ls='--')

            props = {'facecolor': 'none', 'edgecolor': 'none'}

            # The last two selected samples are labelled with the percentage of
            # samples whose criterion does not exceed theirs
            for i_samp in i_sel[0, -2:]:
                ii_prcntl = (e_samps <= e_samps[i_samp]).mean()
                textstr = f'$q_{{{ii_prcntl*100:.2g}}}$'
                ax.text(e_samps[i_samp] + 0.01, 3.95, textstr,
                    transform=ax.transData, fontsize=12, usetex=True,
                    verticalalignment='top', bbox=props)

            # The architecture / split caption of the axis
            ax.text(0.71, 0.95, f'{arch.upper()} {split.capitalize()}',
                transform=ax.transAxes, fontsize=10, usetex=True,
                verticalalignment='top', bbox=props)

fig

### The Collective Aerosol Diagnostics (Summary Plots)

In [None]:
# Collective aerosol diagnostics (summary plots).
# For every 'trad.*' experiment on the 'train' and 'test' splits, twelve panels
# (axis indices 0-11, three columns: CCN, optics, ice nucleation) are filled:
# error histograms in axes 0, 1, 2 and 4, and original-vs-reconstruction data
# with an x=y reference line in the remaining axes. The figure is saved as a PNG.
# NOTE(review): `viz_datas` and `calc_aerometrics` are not defined in this cell;
# presumably they come from earlier cells -- confirm on a fresh kernel.
i_fig = 44
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)
colorspec = v_mplcfgs['colorspec']

for viz_id, viz_data in hie2deep(viz_datas, maxdepth=1).items():
    # The visualization ids have the form '<experiment>:<architecture>:<split>'
    exprmnt, arch, split = viz_id.split(':')
    v_ydata = viz_data['v_ydata']
    aero_cstscnv = viz_data['aero_csts']

    if (not exprmnt.startswith('trad.')) or (split not in ('train', 'test')):
        continue

    ######## Computing the Aerosol Diagnostic Metrics #########
    v_mtrcdata1 = calc_aerometrics(v_ydata, aero_cstscnv, avg_errs=False)

    # The reference supersaturations, temperatures and wavelength (plus their
    # indices along the metric axes) that the panels are evaluated at.
    eps1, eps2, eps3 = v_mtrcdata1['eps1'], v_mtrcdata1['eps2'], v_mtrcdata1['eps3']
    i_eps1, i_eps2, i_eps3 = v_mtrcdata1['i_eps1'], v_mtrcdata1['i_eps2'], v_mtrcdata1['i_eps3']
    tmp1, tmp2, tmp3 = v_mtrcdata1['tmp1'], v_mtrcdata1['tmp2'], v_mtrcdata1['tmp3']
    i_tmp1, i_tmp2, i_tmp3 = v_mtrcdata1['i_tmp1'], v_mtrcdata1['i_tmp2'], v_mtrcdata1['i_tmp3']
    wvl1, i_wvl1 = v_mtrcdata1['wvl1'], v_mtrcdata1['i_wvl1']

    # Assembling the plotting data. The keys follow '<ax_id>/<ax_idx>:<series>:<x|y>';
    # the insertion order below decides the order of the `plot_mpl` calls.
    # In the two loops, `i // 3` is the row of axis `i`, which doubles as the
    # numeric suffix of the scatter config id.
    v_mpldatas1 = dict()
    v_mpldatas1['ccn_errhist/0:ccn_cdferr:y'] = v_mtrcdata1['ccn_cdf/err'].ravel()
    for i, i_eps in [(3, i_eps1), (6, i_eps2), (9, i_eps3)]:
        v_mpldatas1[f'ccn_sct{i//3}/{i}:ccn_cdf:x'] = v_mtrcdata1['ccn_cdf/orig'][..., i_eps].ravel()
        v_mpldatas1[f'ccn_sct{i//3}/{i}:ccn_cdf:y'] = v_mtrcdata1['ccn_cdf/rcnst'][..., i_eps].ravel()
    v_mpldatas1['qs_errhist/1:qs_pop:y'] = v_mtrcdata1['logqs_pop/err'].ravel()
    v_mpldatas1['qa_errhist/4:qa_pop:y'] = v_mtrcdata1['logqa_pop/err'].ravel()
    v_mpldatas1['qs_popsct/7:qs_pop:x'] = v_mtrcdata1['qs_pop/orig'][:, :, i_wvl1, :].ravel()
    v_mpldatas1['qs_popsct/7:qs_pop:y'] = v_mtrcdata1['qs_pop/rcnst'][:, :, i_wvl1, :].ravel()
    v_mpldatas1['qa_popsct/10:qa_pop:x'] = v_mtrcdata1['qa_pop/orig'][:, :, i_wvl1, :].ravel()
    v_mpldatas1['qa_popsct/10:qa_pop:y'] = v_mtrcdata1['qa_pop/rcnst'][:, :, i_wvl1, :].ravel()
    v_mpldatas1['logfrznfrac_errhist/2:logfrznfrac_tmp:y'] = v_mtrcdata1['logfrznfrac_tmp/err'].ravel()
    for i, i_tmp in [(5, i_tmp1), (8, i_tmp2), (11, i_tmp3)]:
        v_mpldatas1[f'frznfrac_sct{i//3}/{i}:frznfrac_tmp:x'] = v_mtrcdata1['frznfrac_tmp/orig'][..., i_tmp].ravel()
        v_mpldatas1[f'frznfrac_sct{i//3}/{i}:frznfrac_tmp:y'] = v_mtrcdata1['frznfrac_tmp/rcnst'][..., i_tmp].ravel()

    # Making the matplotlib calls (the figure is created by the first call and
    # then threaded through the remaining ones)
    v_mpldatas2 = hie2deep(v_mpldatas1, maxdepth=1)
    fig, axes = None, None
    for ax_idx, ax_id in enumerate(v_mpldatas2):
        fig, axes = plot_mpl(data=v_mpldatas2[ax_id], fig=fig, axes=axes, 
            mplopts=v_mplcfgs[ax_id])
    axes1d = axes.ravel()

    # Adding the text boxes. Every LaTeX backslash is written as '\\' so that the
    # literals contain no invalid escape sequences (a SyntaxWarning on Python
    # 3.12+); the strings handed to matplotlib are unchanged.
    for ax_idx, textstr in [
        (3,  f'$s={100*eps1:0.2g}\\%$'), 
        (6,  f'$s={100*eps2:0.2g}\\%$'), 
        (9,  f'$s={100*eps3:0.2g}\\%$'),
        (7,  f'${{\\rm \\lambda={wvl1*1e6:.1f}\\ \\mu m}}$'), 
        (10, f'${{\\rm \\lambda={wvl1*1e6:.1f}\\ \\mu m}}$'), 
        (5,  f'$T={{\\rm {tmp1:.0f}^{{\\circ}}\\ C}}$'), 
        (8,  f'$T={{\\rm {tmp2:.0f}^{{\\circ}}\\ C}}$'), 
        (11, f'$T={{\\rm {tmp3:.0f}^{{\\circ}}\\ C}}$')]:
        ax = axes1d[ax_idx]
        ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=11,
            verticalalignment='top', usetex=True, bbox={'facecolor': 'none', 
            'edgecolor': 'none'})

    # Adding the x=y line to scatter plots; the line spans the union of the
    # current x and y limits of the axis
    for ax_idx in [3, 6, 9, 7, 10, 5, 8, 11]:
        ax = axes1d[ax_idx]
        x_lo = min(ax.get_xlim()[0], ax.get_ylim()[0])
        x_hi = max(ax.get_xlim()[1], ax.get_ylim()[1])
        x_id = np.linspace(x_lo, x_hi, 100)
        ax.plot(x_id, x_id, lw=1, ls='--', color=colorspec['blue'])

    # Adding the top-left panel tags: the letter is the column and the
    # subscript is the row, e.g. (a_1), (b_1), (c_1), (a_2), ...
    with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
        for i_figrow in range(axes.shape[0]):
            for i_figcol, ax in enumerate(axes[i_figrow]):
                tag_text = f'({"abc"[i_figcol]}$_{{{i_figrow + 1}}}$)'
                tag_axis(ax, tag_text, fontsize=11, pad=(0.3, 0.6))

        # Column titles above the first row
        for i_figcol, col_ttl in enumerate(['Cloud Condensation', 'Optical Properties', 'Ice Nucleation']):
            print_axheader(axes[0, i_figcol], col_ttl, 'top', pad=8, fontsize=14, fontweight='bold')

    # Only the PNG is written; the '.pdf' path is just the stem the PNG name is
    # derived from.
    pdfpath = f'{workdir}/{i_fig:02d}_{arch}_{split}_smry.pdf'
    pngpath = pdfpath[:-4] + '.png'
    fig.savefig(pngpath, dpi=200, bbox_inches='tight')
    print(f'Finished writing {pngpath}')
    i_fig += 1
 
fig

### The Generative Diagnostic Calibration Histograms

In [None]:
i_fig = 48
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)
colorspec = v_mplcfgs['colorspec']

# The splits overlaid in every histogram, mapped to the data variant that
# `calc_aerometrics` is asked to target for that split.
split2vrnt = {'test': 'orig', 'normal': 'genr'}

for viz_id1, viz_data1 in hie2deep(viz_datas, sep=':', mindepth=2).items():
    exprmnt, arch = viz_id1.split(':')

    # Only the `trad.*` experiments are plotted in this figure. The split 
    # filtering is done by the inner loop below; testing `split` at this point 
    # would read a stale value left behind by a previous iteration or cell.
    if not exprmnt.startswith('trad.'):
        continue
    
    v_mtrcdata1deep = dict()
    for split, viz_data2 in hie2deep(viz_data1, sep='/', maxdepth=1).items():
        if split not in split2vrnt:
            continue
        vrnt_trg = split2vrnt[split]

        v_ydata = viz_data2['v_ydata']
        aero_cstscnv = viz_data2['aero_csts']

        ######## Computing the Aerosol Diagnostic Metrics #########
        v_mtrcdata2 = calc_aerometrics(v_ydata, aero_cstscnv, avg_errs=False, vrnt_trg=vrnt_trg)

        # The reference values (and their indices) reported by `calc_aerometrics`; 
        # they are used below for slicing the metric arrays and for the text boxes.
        eps1, eps2, eps3 = v_mtrcdata2['eps1'], v_mtrcdata2['eps2'], v_mtrcdata2['eps3']
        i_eps1, i_eps2, i_eps3 = v_mtrcdata2['i_eps1'], v_mtrcdata2['i_eps2'], v_mtrcdata2['i_eps3']
        tmp1, tmp2, tmp3 = v_mtrcdata2['tmp1'], v_mtrcdata2['tmp2'], v_mtrcdata2['tmp3']
        i_tmp1, i_tmp2, i_tmp3 = v_mtrcdata2['i_tmp1'], v_mtrcdata2['i_tmp2'], v_mtrcdata2['i_tmp3']
        wvl1, i_wvl1 = v_mtrcdata2['wvl1'], v_mtrcdata2['i_wvl1']

        v_mtrcdata1deep[split] = v_mtrcdata2

    # Every panel overlays both splits, so a group lacking either one cannot be plotted.
    splits_missing = [splt for splt in split2vrnt if splt not in v_mtrcdata1deep]
    if splits_missing:
        print(f'Skipping {viz_id1}: no data for the {splits_missing} split(s)')
        continue
    v_mtrcdata1 = deep2hie(v_mtrcdata1deep)

    # Assembling the plotting data. The keys are laid out as 
    # `<plot config id>/<axis index>:<split>:y`.
    v_mpldatas1 = dict()

    v_mpldatas1['ccn_hist1/0:test:y'] = v_mtrcdata1['test/ccn_cdf/orig'][..., i_eps1].ravel()
    v_mpldatas1['ccn_hist1/0:normal:y'] = v_mtrcdata1['normal/ccn_cdf/genr'][..., i_eps1].ravel()
    v_mpldatas1['ccn_hist2/3:test:y'] = v_mtrcdata1['test/ccn_cdf/orig'][..., i_eps3].ravel()
    v_mpldatas1['ccn_hist2/3:normal:y'] = v_mtrcdata1['normal/ccn_cdf/genr'][..., i_eps3].ravel()

    v_mpldatas1['qs_hist/1:test:y'] = v_mtrcdata1['test/qs_pop/orig'][:, :, i_wvl1, :].ravel()
    v_mpldatas1['qs_hist/1:normal:y'] = v_mtrcdata1['normal/qs_pop/genr'][:, :, i_wvl1, :].ravel()
    v_mpldatas1['qa_hist/4:test:y'] = v_mtrcdata1['test/qa_pop/orig'][:, :, i_wvl1, :].ravel()
    v_mpldatas1['qa_hist/4:normal:y'] = v_mtrcdata1['normal/qa_pop/genr'][:, :, i_wvl1, :].ravel()

    v_mpldatas1['frznfrac_hist1/2:test:y'] = v_mtrcdata1['test/frznfrac_tmp/orig'][..., i_tmp1].ravel()
    v_mpldatas1['frznfrac_hist1/2:normal:y'] = v_mtrcdata1['normal/frznfrac_tmp/genr'][..., i_tmp1].ravel()
    v_mpldatas1['frznfrac_hist2/5:test:y'] = v_mtrcdata1['test/frznfrac_tmp/orig'][..., i_tmp3].ravel()
    v_mpldatas1['frznfrac_hist2/5:normal:y'] = v_mtrcdata1['normal/frznfrac_tmp/genr'][..., i_tmp3].ravel()

    # Making the matplotlib calls
    v_mpldatas2 = hie2deep(v_mpldatas1, maxdepth=1)
    fig, axes = None, None
    for ax_id in v_mpldatas2:
        fig, axes = plot_mpl(data=v_mpldatas2[ax_id], fig=fig, axes=axes, 
            mplopts=v_mplcfgs[ax_id])
    axes1d = axes.ravel()

    # Adding the text boxes. The LaTeX backslashes are escaped explicitly; the 
    # rendered strings are the same as with the (invalid) single-backslash escapes.
    for ax_idx, textstr in [
        (0,  f'$s={100*eps1:0.2g}\\%$'), 
        (3,  f'$s={100*eps3:0.2g}\\%$'),
        (1,  f'${{\\rm \\lambda={wvl1*1e6:.1f}\\ \\mu m}}$'), 
        (4,  f'${{\\rm \\lambda={wvl1*1e6:.1f}\\ \\mu m}}$'), 
        (2,  f'$T={{\\rm {tmp1:.0f}^{{\\circ}}\\ C}}$'), 
        (5,  f'$T={{\\rm {tmp3:.0f}^{{\\circ}}\\ C}}$')]:
        ax = axes1d[ax_idx]
        ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=11,
            verticalalignment='top', usetex=True, bbox={'facecolor': 'none', 
            'edgecolor': 'none'})

    # Adding the top left april tag for axis identification
    with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
        for i_figrow in range(axes.shape[0]):
            for i_figcol, ax in enumerate(axes[i_figrow]):
                tag_text = f'({"abc"[i_figcol]}$_{{{i_figrow + 1}}}$)'
                tag_axis(ax, tag_text, fontsize=11, pad=(0.3, 0.6))

        for i_figcol, col_ttl in enumerate(['Cloud Condensation', 'Optical Properties', 'Ice Nucleation']):
            print_axheader(axes[0, i_figcol], col_ttl, 'top', pad=8, fontsize=14, fontweight='bold')

    # Arrow annotations naming the two overlaid histograms on the first panel. 
    # The options are written as flat `a/b` keys and nested by `hie2deep` before use.
    akws = {'xycoords': 'axes fraction', 'textcoords': 'axes fraction', 
        'fontsize': 9, 'bbox/pad': 90, 'bbox/facecolor': 'none', 'bbox/edgecolor': 'none', 
        'arrowprops/arrowstyle': '->', 'arrowprops/connectionstyle': 'arc3,rad=0.1'}
    
    annot_kws = dict()
    annot_kws['test'] = {'text': 'Test Data', 'xytext': [0.35, 0.72], 
        'xy': [0.16, 0.72], 'arrowprops/relpos': [0.0, 0.5], **akws}
    annot_kws['normal'] = {'text': 'Generated', 'xytext': [0.45, 0.53], 
        'xy': [0.20, 0.50], 'arrowprops/relpos': [0.0, 0.5], **akws}
    for annot_kw in annot_kws.values():
        axes[0, 0].annotate(**hie2deep(annot_kw))

    pdfpath = f'{workdir}/{i_fig:02d}_{arch}_gencal.pdf'
    fig.savefig(pdfpath, dpi=200, bbox_inches='tight')
    print(f'Finished writing {pdfpath}')
    i_fig += 1

fig

### Collecting the Data For MDS Scatter Plots

In [None]:
vizsct_datas = dict()

# The result-file index of each architecture, and the key prefixes of the 
# MDS evaluations that are collected.
arch2fpidx = {'mlp': '11_mlphist.0.0', 'cnn': '12_cnnhist.0.0'}
mds_prefixes = tuple(f'var/eval/mds/{eid}:' for eid in ['identity', 'polyhstbal'])

for arch, fpidx in arch2fpidx.items():
    ############## Collecting the plotting data ###############
    rio = resio(fpidx=fpidx, resdir='./20_vaehist/results', driver='sec2')
    rio_dtypes = rio.dtypes()
    # Keeping only the part of each key after its first colon
    rio_keys = [dtkey.split(':', 1)[-1] for dtkey in rio_dtypes]
    n_seeds = 12

    v_mdsdataraw = dict()
    for key in rio_keys:
        if not key.startswith(mds_prefixes):
            continue
        val = rio(key)
        n_pnts, n_rcns, d_repr = val.shape[1:]
        assert n_rcns == 1
        assert val.shape == (n_epoch * n_seeds, n_pnts, n_rcns, d_repr)
        # The leading axis is unfolded into (epoch, seed), the singleton 
        # reconstruction axis is dropped, and only the last epoch is kept.
        val_lastepoch = val.reshape(n_epoch, n_seeds, n_pnts, d_repr)[-1]
        assert val_lastepoch.shape == (n_seeds, n_pnts, d_repr)
        v_mdsdataraw[key] = val_lastepoch

    # Example of the collected keys and the shapes of their arrays:
    #   'var/eval/mds/identity:1x1xtsne:x:m_chmprthst/{split}/{vrnt}/pnts/data': (12, 5000, 2)
    #   'var/eval/mds/identity:1x1ztsne:z:z/{split}/{vrnt}/pnts/data':           (12, 5000, 2)
    #   'var/eval/mds/identity:1x1zpcalr:z:z/{split}/{vrnt}/pnts/data':          (12, 5000, 2)
    #       where {split}/{vrnt} is one of
    #       normal/genr, test/orig, test/rcnst, train/orig, train/rcnst
    #   'var/eval/mds/identity:1x1zpcalr:z:z/{test|train}/orig/mu/data':         (12, 5000, 2)
    #   'var/eval/mds/identity:1x1zpcalr:z:z/{test|train}/orig/sig/data':        (12, 5000, 2)
    #   'var/eval/mds/identity:1x1zpcalr:z:z/{test|train}/orig/phi/data':        (12, 5000, 1)

    vizsct_datas[f'{arch}/v_mdsdataraw'] = v_mdsdataraw


### The Collective MDS Scatter Plot Points

In [None]:
i_fig = 50
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)

# The common prefix of all the raw MDS keys used below
mds_root = 'var/eval/mds/identity'

figs_dict = dict()
for arch, viz_data in hie2deep(vizsct_datas, maxdepth=1).items():
    v_mdsdataraw = viz_data['v_mdsdataraw']

    ########################### Collecting the plotting data ############################
    # The plotting keys are laid out as `<plot config id>/<axis index>:<split>/<vrnt>:<vrepr>`.
    # NOTE(review): the `{...}` fields in the rename patterns are presumably wildcards
    # that `get_subdictrnmd` matches and carries over to the new key -- confirm.
    v_mpldatas = dict()

    # Axes 0-2: the latent t-SNE points of each split on its own axis
    for ax_idx, (split, vrnt) in enumerate([('train', 'orig'), ('test', 'orig'), ('normal', 'genr')]):
        pats_rnm2 = {
            f'{mds_root}:1x1ztsne:z:z/{split}/{vrnt}/pnts/data': 
            f'ztsne/{ax_idx}:{split}/{vrnt}:pnts'}
        ax_data = get_subdictrnmd(v_mdsdataraw, pats_rnm2)
        v_mpldatas.update(ax_data)

    # Axis 3: all the latent t-SNE entries gathered so far, re-keyed onto one axis
    ax_idx = 3
    pats_rnm1 = {
        'ztsne/{ax_idx}:{split}/{vrnt}:{vrepr}': 
        f'ztsne/{ax_idx}:' + '{split}/{vrnt}:{vrepr}'}
    ax_data = get_subdictrnmd(v_mpldatas, pats_rnm1)
    v_mpldatas.update(ax_data)

    # Axes 4-6: every latent PCA representation of each split on its own axis
    for ax_idx, (split, vrnt) in enumerate([('train', 'orig'), ('test', 'orig'), ('normal', 'genr')], start=4):
        pats_rnm4 = {
            f'{mds_root}:1x1zpcalr:z:z/{split}/{vrnt}/{{vrepr}}/data':
            f'zpca/{ax_idx}:{split}/{vrnt}:{{vrepr}}'}
        ax_data = get_subdictrnmd(v_mdsdataraw, pats_rnm4)
        v_mpldatas.update(ax_data)

    # Axis 7: all the latent PCA entries gathered so far, re-keyed onto one axis
    ax_idx = 7
    pats_rnm3 = {
        'zpca/{ax_idx}:{split}/{vrnt}:{vrepr}': 
        f'zpca/{ax_idx}:' + '{split}/{vrnt}:{vrepr}'}
    ax_data = get_subdictrnmd(v_mpldatas, pats_rnm3)
    v_mpldatas.update(ax_data)

    # Axes 8-10: the mass t-SNE points of every variant of each split on its own axis
    for ax_idx, split in enumerate(['train', 'test', 'normal'], start=8):
        pats_rnm6 = {
            f'{mds_root}:1x1xtsne:x:m_chmprthst/{split}/{{vrnt}}/pnts/data': 
            f'xtsne/{ax_idx}:{split}/{{vrnt}}:pnts'}
        ax_data = get_subdictrnmd(v_mdsdataraw, pats_rnm6)
        v_mpldatas.update(ax_data)

    # Axis 11: all the mass t-SNE entries gathered so far, re-keyed onto one axis
    ax_idx = 11
    pats_rnm5 = {
        'xtsne/{ax_idx}:{split}/{vrnt}:{vrepr}': 
        f'xtsne/{ax_idx}:' + '{split}/{vrnt}:{vrepr}'}
    ax_data = get_subdictrnmd(v_mpldatas, pats_rnm5)
    v_mpldatas.update(ax_data)

    # Only the first seed is plotted
    v_mpldatas2 = {pltkey: pltval[0] for pltkey, pltval in v_mpldatas.items()}

    # Grouping the entries by their plot config id
    v_mpldatas3 = hie2deep(v_mpldatas2, maxdepth=1)

    ################### Calling Matplotlib to Plot the Scatter Points ################### 
    fig, axes = None, None
    for ax_id in v_mpldatas3:
        fig, axes = plot_mpl(data=v_mpldatas3[ax_id], fig=fig, 
            axes=axes, mplopts=v_mplcfgs[ax_id])

    # Tagging every axis at its top left corner (letter per column, subscript per row), 
    # and printing the row and column headers
    n_figrows, n_figcols = axes.shape
    with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
        for ax_idx, ax in enumerate(axes.ravel()):
            i_figrow, i_figcol = divmod(ax_idx, n_figcols)
            tag_text = f'({"abcd"[i_figcol]}$_{{{i_figrow + 1}}}$)'
            tag_axis(ax, tag_text, fontsize=11, pad=(0.3, 0.6))
    
        for i_figrow, row_ttl in enumerate(['Latent t-SNE', 'Latent PCA', 'Mass t-SNE']):
            print_axheader(axes[i_figrow, 0], row_ttl, 'left', 
                fontsize=14, fontweight='bold')
        
        for i_figcol, col_ttl in enumerate(['Train', 'Test', 'Generated', 'All']):
            print_axheader(axes[0, i_figcol], col_ttl, 'top', 
                fontsize=14, fontweight='bold')

    # The figure must carry exactly one legend; its handles are made fully opaque
    [leg] = fig.legends
    for handle in leg.legend_handles:
        handle.set_alpha(1.0)

    # Writing the figure both as a PDF and as a 200-dpi PNG
    figpath1 = f'{workdir}/{i_fig:02d}_{arch}_mds.pdf'
    figpath2 = f'{workdir}/{i_fig:02d}_{arch}_mds.png'
    for figpath, save_kws in [(figpath1, {}), (figpath2, {'dpi': 200})]:
        fig.savefig(figpath, bbox_inches='tight', **save_kws)
        print(f'Finished writing {figpath}')
    i_fig += 1

fig

### OVAT Parameter Sweep Plots

In [None]:
# When False, the result files are not enumerated here and `fpidx_infos1` stays empty
recmpl_data = False

fpidx_infos1, resdir, rioall = dict(), None, None
if recmpl_data:
    resdir = f'{results_dir}/02_adhoc'
    # (architecture, result-file index pattern, number of seeds)
    arch_specs = [('mlp', '11_mlphist.*.*', 16), ('cnn', '12_cnnhist.*.*', 10)]
    for arch, fpidxall, n_seeds in arch_specs:
        rioarch = resio(fpidx=fpidxall, resdir=resdir, driver='sec2')
        # One info record per result file resolved from the pattern
        for fpidx in rioarch.rslv_fpidx(fpidxall):
            fpidx_infos1[fpidx] = dict(arch=arch, n_seeds=n_seeds, 
                n_snrt=1000, n_epoch=2, n_rcns=1)

In [None]:
split, vrnts = 'test', ['orig', 'rcnst']

stdfs_dict = defaultdict(list)
fpidx_infos2 = dict()
for i_fpidx, (fpidx, fpidx_info) in enumerate(fpidx_infos1.items()):
    # Progress log, with a line break (and a flush) after every fifth entry
    is_eol = (i_fpidx + 1) % 5 == 0
    print(f'Working on {fpidx}.   ', end=('\n' if is_eol else ''), flush=is_eol)
    rio2 = resio(fpidx=fpidx, resdir=resdir, driver='sec2')

    n_seeds, n_epoch, i_epoch = fpidx_info['n_seeds'], fpidx_info['n_epoch'], -1
    n_snrt, n_rcns = fpidx_info['n_snrt'], fpidx_info['n_rcns']

    ##########################################################################################
    ############################# Loading the Data From the Disk #############################
    ##########################################################################################
    ycols = ['m_chmprthst', 'n_prthst', 'ccn_cdf', 'qs_prt', 'qscs_prt', 'qs_pop',
        'qa_prt', 'qacs_prt', 'qa_pop', 'frznfrac_tmp', 'logfrznfrac_tmp']

    v_ydataraw = dict()
    for ycol in ycols:
        for vrnt in vrnts:
            n_ychnls, n_ylen = ycol2dims[ycol]
            y_nparr = rio2(f'var/eval/raw/yaero:x:{ycol}/{split}/{vrnt}/pnts/data')
            # The stored leading axis is unfolded into (epoch, seed) 
            shape_flat = (n_epoch * n_seeds, n_snrt, n_rcns, n_ychnls, n_ylen)
            shape_full = (n_epoch, n_seeds, n_snrt, n_rcns, n_ychnls, n_ylen)
            assert y_nparr.shape == shape_flat
            y_nparr2 = y_nparr.reshape(*shape_full)
            assert y_nparr2.shape == shape_full
            # Only the epoch selected by `i_epoch` (the last one) is kept
            v_ydataraw[f'{ycol}/{vrnt}'] = y_nparr2[i_epoch]

    ##########################################################################################
    ######################## Computing the Aerosol Diagnostic Metrics ########################
    ##########################################################################################

    ########### Data Cleaning and Unit Conversions ############
    v_ydata, aero_cstscnv = cnvrt_physunits(v_ydataraw, aero_csts)

    ######## Computing the Aerosol Diagnostic Metrics #########
    v_mtrcdata = calc_aerometrics(v_ydata, aero_cstscnv, avg_errs=True)

    ##########################################################################################
    ################################### Compiling the Data ###################################
    ##########################################################################################
    # The error metrics (keys ending with `/err`) that were actually computed
    df_cols = {
        mtrc_key: mtrc_val 
        for mtrc_key, mtrc_val in v_mtrcdata.items() 
        if mtrc_key.endswith('/err') and (mtrc_val is not None)}

    # The training statistics of the last epoch: one row per seed
    stdf1 = rio2('stat')
    epoch_max = stdf1['epoch'].max()
    idx_lastep = stdf1['epoch'] == epoch_max
    stdf1 = stdf1[idx_lastep].reset_index(drop=True)
    assert stdf1.shape[0] == n_seeds

    # The hyper-parameters: one row per seed
    hpdf1 = pd.DataFrame(rio2.load_key('hp'))
    hpdf1.insert(0, 'fpidx', fpidx)
    assert hpdf1.shape[0] == n_seeds

    # The aerosol diagnostic errors: one row per seed
    stdf2 = pd.DataFrame(df_cols)
    stdf2.insert(0, 'fpidx', fpidx)
    assert stdf2.shape[0] == n_seeds

    # Placing the training statistics and the diagnostic errors side by side
    stdf3 = pd.concat([stdf1, stdf2], axis=1)
    fpidx_info2 = fpidx_info.copy()
    fpidx_info2.update({'hpdf': hpdf1, 'stdf': stdf3})
    fpidx_infos2[fpidx] = fpidx_info2


In [None]:
cache_path = f'{workdir}/91_pltdata.h5'

if recmpl_data:
    # Flat collection of the per-file data frames, keyed as `<arch>:<hp|stat>/<fpidx>`
    arch2dfs = dict()
    for fpidx, fpidx_info in fpidx_infos2.items():
        arch = fpidx_info['arch']
        hpdf3 = fpidx_info['hpdf'].copy(deep=True)
        if 'fpidx' not in hpdf3.columns:
            hpdf3.insert(0, 'fpidx', fpidx)
        arch2dfs.update({
            f'{arch}:hp/{fpidx}': hpdf3, 
            f'{arch}:stat/{fpidx}': fpidx_info['stdf']})

    # Stacking the frames of each `<arch>:<hp|stat>` group into a single frame
    save_data = dict()
    for archkey, fpidx2dfs in hie2deep(arch2dfs).items():
        arch, key = archkey.split(':')
        key_dfslst = [*fpidx2dfs.values()]
        key_df = pd.concat(key_dfslst, axis=0, ignore_index=True)
        save_data['/'.join([arch, key])] = key_df

    save_h5datav2(save_data, cache_path)

In [None]:
load_data = load_h5data(cache_path)
hpdf_aggs, stdf_aggs, ovat_grpngs = dict(), dict(), dict()

# The suffixes of the aggregated statistic columns
stat_sfxs = ['mean', 'low', 'high']

for arch in ['mlp', 'cnn']:
    hpdf_arch1 = load_data[f'{arch}/hp']
    hpdf_arch1.insert(2, 'arch', arch)
    stdf_arch1 = load_data[f'{arch}/stat']

    # The hp and stat rows are expected to be aligned row by row already; 
    # this is asserted here rather than enforced by re-indexing.
    assert (hpdf_arch1['fpidx'] == stdf_arch1['fpidx']).all()

    hpdf_arch2 = drop_unqcols(hpdf_arch1)
    hpdf_arch2.insert(1, 'fpidxgrp', hpdf_arch2['fpidx'])
    stdf_arch2 = stdf_arch1.drop(columns='fpidx')

    # Aggregating the data
    aggcfg = {'type': 'bootstrap', 'n_boot': 1000, 'q': [5, 95], 
        'stat': 'mean', 'device': 'cpu'}
    agg_data = get_aggdf(hpdf_arch2, stdf_arch2, xcol='epoch', 
        huecol='fpidxgrp', rngcol='rng_seed', aggcfg=aggcfg)
    hpdf_archagg, stdf_archagg = agg_data['hpdf'], agg_data['stdf']

    # Mapping each swept hyper-parameter to its group of result files. The first
    # remaining column after `drop_unqcols` is taken as the swept hyper-parameter.
    ovat_grps = get_ovatgrps(hpdf_archagg)
    ovat_grpng = dict()
    for ovat_grp in ovat_grps:
        hpdf_grp1 = hpdf_archagg[hpdf_archagg['fpidx'].isin(ovat_grp)]
        hpdf_grp2 = drop_unqcols(hpdf_grp1)
        ovat_cols = hpdf_grp2.drop(columns=['fpidx', 'fpidxgrp']).columns.tolist()
        ovat_col = ovat_cols[0]
        ovat_grpng[ovat_col] = ovat_grp

    # Creating combined performance metrics
    u_nodes = ['m_chmprtnrm', 'm_prthstmag', 'n_prthstnrm']

    # The net reconstruction error: the sum of the squared errors over the `u_nodes`
    ycol_fmt = 'perf/reconstruction/polyhst/test/orig/polyhst/test/rcnst/same/0/x/{u_node}/sqerr/mean/{stat}'
    for stat in stat_sfxs:
        ycols_agg = [ycol_fmt.format(u_node=u_node, stat=stat) for u_node in u_nodes]
        stdf_archagg[f'rcnst_net/err/{stat}'] = stdf_archagg[ycols_agg].to_numpy().sum(axis=1)

    # The optical error: the average of the `qs_pop` and `qa_pop` errors
    for stat in stat_sfxs:
        opt_errs = stdf_archagg[[f'qs_pop/err/{stat}', f'qa_pop/err/{stat}']]
        stdf_archagg[f'qsa_pop/err/{stat}'] = opt_errs.to_numpy().mean(axis=1)

    # The realism error: the sum of the `uslcwass:2` metrics over the `u_nodes`
    ycol_fmt = 'perf/urealism/urealsim/normal/genr/urealsim/test/orig/same/0/x/{u_node}/uslcwass:2/mean/{stat}'
    for stat in stat_sfxs:
        ycols_agg = [ycol_fmt.format(u_node=u_node, stat=stat) for u_node in u_nodes]
        stdf_archagg[f'realism/err/{stat}'] = stdf_archagg[ycols_agg].to_numpy().sum(axis=1)
    
    hpdf_aggs[arch] = hpdf_archagg
    stdf_aggs[arch] = stdf_archagg
    ovat_grpngs[arch] = ovat_grpng

In [None]:
i_fig = 92
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)

# Each inner list is one PDF page; its entries are the swept 
# hyper-parameters shown along the figure columns.
x_pgcols = [
    ['nn/enc/d_ltnt', 'cri/kl/sig/w', 'cri/kl/mu/w'],
    ['nn/pp/magexp', 'nn/pp/nrmexp', 'nn/pp/cntexp']]
# The metrics shown along the figure rows; their `/mean`, `/low`, 
# and `/high` columns are plotted.
y_cols = ['rcnst_net/err', 'realism/err', 'ccn_cdf/err', 
    'qsa_pop/err', 'logfrznfrac_tmp/err']

# The hyper-parameter values marked with a rectangle on every axis.
# NOTE(review): these are read from the first aggregated MLP row, while the 
# rectangles below are labelled 'cnn' -- confirm that this is intended.
hp_dflts = hpdf_aggs['mlp'].iloc[0].to_dict()

figs_dict = dict()
for i_figpage, x_cols in enumerate(x_pgcols):
    fig, axes = None, None
    n_figrows, n_figcols = len(y_cols), len(x_cols)
    for i_figrow, y_col in enumerate(y_cols):
        for i_figcol, x_col in enumerate(x_cols):
            ax_idx = i_figrow * n_figcols + i_figcol

            v_mpldatas = dict()
            for arch in hpdf_aggs:
                hpdf_agg = hpdf_aggs[arch]
                stdf_agg = stdf_aggs[arch]
                ovat_grpng = ovat_grpngs[arch]

                # Collecting the data; architectures without a sweep over `x_col` are skipped
                if x_col not in ovat_grpng:
                    continue
                fpidxs_swp = ovat_grpng[x_col]
                idx_swp = hpdf_agg['fpidx'].isin(fpidxs_swp)
                hpdf_swp = hpdf_agg[idx_swp]
                statdf_swp = stdf_agg[idx_swp]
                # Sorting by the swept value, and applying the same row order to the statistics
                hpdf_swp = hpdf_swp.sort_values(by=x_col)
                statdf_swp = statdf_swp.loc[hpdf_swp.index]
                hpdf_swp = hpdf_swp.reset_index(drop=True)
                statdf_swp = statdf_swp.reset_index(drop=True)
                hpstdf_swp = pd.concat([hpdf_swp, statdf_swp], axis=1)

                v_mpldatas[f'{ax_idx}:{arch}:x'] = hpstdf_swp[x_col].values
                v_mpldatas[f'{ax_idx}:{arch}:y/mean'] = hpstdf_swp[f'{y_col}/mean'].values
                v_mpldatas[f'{ax_idx}:{arch}:y/low'] = hpstdf_swp[f'{y_col}/low'].values
                v_mpldatas[f'{ax_idx}:{arch}:y/high'] = hpstdf_swp[f'{y_col}/high'].values

            ########## Calling Matplotlib to Plot the Diagnostics ########## 
            # The figure-wide options are overlaid with the x- and y-axis specific ones
            v_mplcfg = v_mplcfgs['paper.hpstudy.figure'].copy()
            v_mplcfg.update(v_mplcfgs[f'paper.hpstudy.xaxis.{x_col.replace("/", ".")}'])
            v_mplcfg.update(v_mplcfgs[f'paper.hpstudy.yaxis.{y_col.replace("/", ".")}'])

            fig, axes = plot_mpl(data=v_mpldatas, fig=fig, 
                axes=axes, mplopts=v_mplcfg)

    # Arrow annotations naming the two architectures. The options are written 
    # as flat `a/b` keys and nested by `hie2deep` before use.
    akws = {'xycoords': 'axes fraction', 'textcoords': 'axes fraction', 
        'fontsize': 9, 'bbox/pad': 90, 'bbox/facecolor': 'none', 'bbox/edgecolor': 'none', 
        'arrowprops/arrowstyle': '->', 'arrowprops/connectionstyle': 'arc3,rad=0.3'}
    annot_kws = dict()
    annot_kws['mlp'] = {'text': 'MLP', 'xytext': [0.55, 0.4], 
        'xy': [0.4, 0.25], 'arrowprops/relpos': [0.0, 0.5], **akws}
    annot_kws['cnn'] = {'text': 'CNN', 'xytext': [0.03, 0.03], 
        'xy': [0.4, 0.16], 'arrowprops/relpos': [1.0, 0.5], **akws}
    annot_kws = hie2deep(deep2hie(annot_kws))

    # The architecture annotations are only drawn on the first page
    if i_figpage == 0:
        with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
            for lbl, antkws in annot_kws.items():
                axes[1, 0].annotate(**antkws)

    # Adding the top left april tag for axis identification
    with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
        for i_figrow in range(n_figrows):
            for i_figcol, tag_char in enumerate('abc'):
                ax = axes[i_figrow, i_figcol]
                tag_axis(ax, f'({tag_char}$_{{{i_figrow+1}}}$)', 
                    fontsize=10, pad=(0.3, 0.6))
    
    # Adding the base hp identification rectangle boxes
    with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
        for i_figrow in range(axes.shape[0]):
            for i_figcol, ax in enumerate(axes[i_figrow]):
                x_col = x_cols[i_figcol]
                mark_rect(ax, x_boxdata=hp_dflts[x_col], y_boxdata=None, 
                    w_boxfrac=0.1, h_boxfrac=0.15, label='cnn', linewidth=1)

    figs_dict[f'{i_figpage}'] = fig

# Writing all the pages into a single PDF
pdfpath = f'{workdir}/{i_fig:02d}_hpstudy.pdf'
with PdfPages(pdfpath) as pdf:
    # Only the figures are iterated over: using `i_fig` as the loop variable 
    # for the page keys would overwrite the figure counter with a string.
    for fig3 in figs_dict.values():
        pdf.savefig(figure=fig3, bbox_inches='tight')
print(f'Finished writing {pdfpath}')

fig

## OIN Mass Quantile Conditional Training

### The Generated OIN Histograms - Polished for Paper

In [None]:
plt.ioff()
# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)

v_mpldatasfrac, v_mpldatasabs = dict(), dict()
# Experiment nickname -> (axis index, axis title)
vrnt_specs = {'null': (0, 'Traditional CVAE'), 'pow1': (1, 'Wasserstein-Regularized CVAE (Ours)')}
i_savefigs = []

# Per-seed fraction of samples falling inside the targeted mass-fraction 
# interval, keyed by `<viz_id>/<vrnt>`
acc_vrnts = dict()
for viz_id, viz_data in hie2deep(viz_datas, maxdepth=1).items():
    exprmnt, arch, split = viz_id.split(':')
    # Only the generated ('normal') split of the `cond.mqnt.*` experiments is used
    if not (split == 'normal' and exprmnt.startswith('cond.mqnt.')):
        continue
    
    exprm_info = get_subdict(exprm_infos, f'{exprmnt}:{arch}')
    n_seeds = exprm_info['n_seeds']
    nicknm = exprm_info['nicknm']
    figdir = exprm_info['figdir']
    i_savefig = exprm_info['i_fig/histcmprd']
    i_savefigs.append(i_savefig)

    if nicknm not in vrnt_specs:
        continue
    
    chem_species = viz_data['chem_species']
    n_rcns, i_oin = 1, chem_species.index('Dust')
    v_ydata = viz_data['v_ydata']
    i_ax, ax_ttl = vrnt_specs[nicknm]

    # (conditioning variant, lower mass-fraction cutoff, upper mass-fraction cutoff)
    for vrnt, mfraclo, mfrachi in [
        ('gaus0', 0.00, 0.04),
        ('gaus1', 0.04, 0.55), 
        ('gaus2', 0.55, 0.84), 
        ('gaus3', 0.84, 1.00)]:
        v_ynparr = v_ydata[f'm_chmprthst/{vrnt}']
        n_bins2 = v_ynparr.shape[-1]
        assert v_ynparr.shape == (n_seeds, n_snrt, n_rcns, n_chem, n_bins2)

        # The absolute 'Dust' mass of the generated samples: summed over the 
        # last (bin) axis, taken at the first reconstruction
        m_oinabs1 = v_ynparr.sum(-1)[:, :, 0, i_oin]
        assert m_oinabs1.shape == (n_seeds, n_snrt)

        m_oinabs2 = m_oinabs1[0]
        assert m_oinabs2.shape == (n_snrt,)

        # The 'Dust' mass fraction of the generated samples: the per-species 
        # mass divided by the total mass over all species and bins
        m_chmfracs = v_ynparr.sum(-1, keepdims=True) / v_ynparr.sum((-1, -2), keepdims=True)
        assert m_chmfracs.shape == (n_seeds, n_snrt, n_rcns, n_chem, 1)
        m_oinfrac1 = m_chmfracs[:, :, 0, i_oin, 0]
        assert m_oinfrac1.shape == (n_seeds, n_snrt)
        m_oinfrac2 = m_oinfrac1[0]
        assert m_oinfrac2.shape == (n_snrt,)

        # The histograms pool the samples of all seeds
        v_mpldatasabs[f'cond.mabshist.cmprd/{i_ax}:{vrnt}:y'] = m_oinabs1.ravel()
        v_mpldatasfrac[f'cond.mfrachist.cmprd/{i_ax}:{vrnt}:y'] = m_oinfrac1.ravel()

        # The per-seed share of samples inside the half-open [mfraclo, mfrachi) interval
        is_inrange = np.logical_and(m_oinfrac1 >= mfraclo, m_oinfrac1 < mfrachi)
        acc_vrnt = is_inrange.mean(axis=1)
        assert acc_vrnt.shape == (n_seeds,)

        acc_vrnts[f'{viz_id}/{vrnt}'] = acc_vrnt
        
v_mpldatasabshie = hie2deep(v_mpldatasabs, maxdepth=1)
v_mpldatasfrachie = hie2deep(v_mpldatasfrac, maxdepth=1)

########## Calling Matplotlib to Plot the Diagnostics ########## 
# Only the mass-fraction histograms are drawn
fig, axes = None, None
for ax_id in v_mpldatasfrachie:
    fig, axes = plot_mpl(data=v_mpldatasfrachie[ax_id], fig=fig, 
        axes=axes, mplopts=v_mplcfgs[ax_id])

# Tagging every axis at its top left corner, and printing the axis titles
with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
    for i_figrow, i_figcol in product(range(axes.shape[0]), range(axes.shape[1])):
        i_figax = i_figrow * axes.shape[1] + i_figcol
        ax = axes[i_figrow, i_figcol]
        tag_text = f'({"ab"[i_figcol]})'
        tag_axis(ax, tag_text, fontsize=11, pad=(0.3, 0.9))

    for nicknm, (i_ax, ax_ttl) in vrnt_specs.items():
        ax = axes.ravel()[i_ax]
        print_axheader(ax, ax_ttl, 'top', fontsize=13, fontweight='bold')

# Marking the interior mass-fraction cutoffs with dashed vertical lines
for ax, x_cutoff in product(axes.ravel(), [0.04, 0.55, 0.84]):
    ax.axvline(x=x_cutoff, lw=1, ls='--', c='black')

assert len(i_savefigs) > 0
i_savefig = min(i_savefigs)
pdfpath = f'{figdir}/{i_savefig:02d}_frachistcmprd.pdf'
# NOTE(review): the save call is disabled, yet the message below still 
# reports the file as written -- confirm which of the two is intended.
# fig.savefig(pdfpath, bbox_inches='tight')
print(f'Finished writing {pdfpath}')

fig

In [None]:
# The weight given to the accuracy of each conditioning variant.
# NOTE(review): presumably the sampling share of each variant -- confirm.
vrnt_weights = [('gaus0', 0.4), ('gaus1', 0.2), ('gaus2', 0.2), ('gaus3', 0.2)]

for viz_id, acc_vizvrnts in hie2deep(acc_vrnts).items():
    # The weighted per-seed accuracy over all the variants
    acc_vizid = 0
    for vrnt, w_vrnt in vrnt_weights:
        acc_vizid = acc_vizid + acc_vizvrnts[vrnt] * w_vrnt
    assert acc_vizid.shape == (n_seeds,)

    # The mean over the seeds, and a 2.5x-scaled standard error of that mean.
    # NOTE(review): confirm the intended multiplier for the reported interval.
    acc_vizidmean = np.mean(acc_vizid)
    acc_vizidse = 2.5 * np.std(acc_vizid) / np.sqrt(n_seeds)

    pct_mean, pct_se = acc_vizidmean * 100, acc_vizidse * 100
    print(f'Accuracy of {viz_id}: {pct_mean:.2f} +/- {pct_se:.2f}')


### The Generated Aerosol Population Diagnostic Plots - Polished for Paper

In [None]:
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)
v_mplcfgs['m_chmprthst.genr.mqnt']['plt.subplots/nrows'] = 2
v_mplcfgs['m_chmprthst.genr.mqnt']['plt.subplots/ncols'] = 4
n_figrows = v_mplcfgs['m_chmprthst.genr.mqnt']['plt.subplots/nrows']
n_figcols = v_mplcfgs['m_chmprthst.genr.mqnt']['plt.subplots/ncols']

vrnt2lbl = {'gaus': 'Random', 'gaus0': 'Low Dust', 'gaus1': 'Med-Low Dust', 
    'gaus2': 'Med-High Dust', 'gaus3': 'High Dust'}
vrnts = list(vrnt2lbl)[1:]

for viz_id, viz_data in hie2deep(viz_datas, maxdepth=1).items():
    exprmnt, arch, split = viz_id.split(':')
    v_ydata = viz_data['v_ydata']
    n_binsum = viz_data['n_bins']
    d_histbinsum = viz_data['d_histbins']
    len_wvum = viz_data['len_wv']
    i_sel = viz_data['i_sel']
    e_samps = viz_data['e_samps']
    eps_histmin = viz_data['eps_histmin']
    eps_histmax = viz_data['eps_histmax']
    tmprtr_inpmin = viz_data['tmprtr_inpmin']
    tmprtr_inpmax = viz_data['tmprtr_inpmax']
    eps_histbins = viz_data['eps_histbins']
    temprtr_bins = viz_data['temprtr_bins']
    chem_species = viz_data['chem_species']

    if (split != 'normal') or not exprmnt.startswith('cond.mqnt.'):
        continue

    exprm_info = get_subdict(exprm_infos, f'{exprmnt}:{arch}')
    n_seeds = exprm_info['n_seeds']
    nicknm = exprm_info['nicknm']
    figdir = exprm_info['figdir']
    i_savefig = exprm_info['i_fig/massnum']

    # v_mplcfgs['ccn_cdf.genr.mqnt']['xlim'] = [eps_histmin * 0.9, eps_histmax * 1.1]
    # v_mplcfgs['frznfrac_tmp.genr.mqnt']['xlim'] = [tmprtr_inpmin - 1, tmprtr_inpmax + 1]
    
    i_sel = i_sel[:, :n_figcols]
    n_page, n_pgsamps = i_sel.shape
    assert n_pgsamps == n_figcols

    figs_dict = dict()
    for i_page in range(n_page):
        v_mpldatas = dict()
        i_samps = i_sel[i_page]
        # The speciated mass data
        for i_figcol, i_samp in enumerate(i_samps):
            ax_idx = i_figcol
            vrnt = vrnts[i_figcol]
            for i_chem, chem in enumerate(chem_species):
                v_mpldatas[f'm_chmprthst.genr.mqnt/{ax_idx}:{chem}:x'] = d_histbinsum
                i_rcns, n_rcns = 0, 1
                v_ynparr1 = v_ydata[f'm_chmprthst/{vrnt}']
                assert v_ynparr1.shape == (n_seeds, n_snrt, n_rcns, n_chem, n_binsum)
                v_ynparr2 = v_ynparr1.reshape(n_seeds * n_snrt, n_rcns, n_chem, n_binsum)
                assert v_ynparr2.shape == (n_seeds * n_snrt, n_rcns, n_chem, n_binsum)
                v_mpldatas[f'm_chmprthst.genr.mqnt/{ax_idx}:{chem}:y'] = v_ynparr2[i_samp, i_rcns, i_chem]
        
        # The total mass, particle count, ccn, frozen fraction, and optical cross-section data
        ax_ixyspec = []
        ax_ixyspec.append([1, d_histbinsum, 'n_prthst'    ])
        # ax_ixyspec.append([2, eps_histbins, 'ccn_cdf'     ])
        # ax_ixyspec.append([3, temprtr_bins, 'frznfrac_tmp'])
        # ax_ixyspec.append([4, len_wvum,     'qs_pop'      ])
        # ax_ixyspec.append([5, len_wvum,     'qa_pop'      ])
        for i_figcol, i_samp in enumerate(i_samps):
            for i_row, xvals, ycol in ax_ixyspec:
                ax_idx = i_row * n_figcols + i_figcol
                vrnt = vrnts[i_figcol]
                v_mpldatas[f'{ycol}.genr.mqnt/{ax_idx}:genr:x'] = xvals
                i_rcns, n_rcns = 0, 1
                v_ynparr1 = v_ydata[f'{ycol}/{vrnt}']
                assert v_ynparr1.shape[:3] == (n_seeds, n_snrt, n_rcns)
                v_ynparr2 = v_ynparr1.reshape(n_seeds * n_snrt, n_rcns, *v_ynparr1.shape[3:])
                assert v_ynparr2.shape[:2] == (n_seeds * n_snrt, n_rcns)
                v_mpldatas[f'{ycol}.genr.mqnt/{ax_idx}:genr:y'] = np.squeeze(v_ynparr2[i_samp, i_rcns])
        
        v_mpldatashie = hie2deep(v_mpldatas, maxdepth=1)
        ########## Calling Matplotlib to Plot the Diagnostics ########## 
        fig, axes = None, None
        for ax_id in v_mpldatashie:
            fig, axes = plot_mpl(data=v_mpldatashie[ax_id], fig=fig, 
                axes=axes, mplopts=v_mplcfgs[ax_id])
    
        figs_dict[f'{i_page}/diag'] = fig

        # Adding the top left april tag for axis identification
        with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
            for i_figrow in range(axes.shape[0]):
                for i_figcol in range(axes.shape[1]):
                    ax = axes[i_figrow, i_figcol]
                    tag_text = f'({"ab"[i_figrow]}$_{{{i_figcol + 1}}}$)'
                    tag_axis(ax, tag_text, fontsize=11, pad=(0.3, 0.9))

            for i_figcol in range(axes.shape[1]):
                vrnt = vrnts[i_figcol]
                ax_ttl = vrnt2lbl[vrnt]
                print_axheader(axes[0, i_figcol], ax_ttl, 'top', 
                    fontsize=14, fontweight='bold', pad=3)

        # Adding the small identification text box
        i_figrow, i_oin = 0, chem_species.index('Dust')
        for i_figcol, ax in enumerate(axes[i_figrow]):
            i_samp = i_samps[i_figcol]
            vrnt = vrnts[i_figcol]
            v_ynparr1 = v_ydata[f'm_chmprthst/{vrnt}']
            assert v_ynparr1.shape == (n_seeds, n_snrt, 1, n_chem, n_binsum)
            v_ynparr2 = v_ynparr1.reshape(n_seeds * n_snrt, n_chem, n_binsum)
            assert v_ynparr2.shape == (n_seeds * n_snrt, n_chem, n_binsum)
            v_ynparr3 = v_ynparr2[i_samp]
            assert v_ynparr3.shape == (n_chem, n_binsum)
            v_ynparr4 = v_ynparr3.sum(axis=-1)
            assert v_ynparr4.shape == (n_chem,)
            v_ynparr5 = v_ynparr4 / v_ynparr4.sum()
            assert v_ynparr5.shape == (n_chem,)
            textstr = f'{100*v_ynparr5[i_oin].item():.1f}' + '\% Dust'
            props = dict(facecolor='none', edgecolor='none')
            ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=10,
                usetex=True, verticalalignment='top', bbox=props)

    _nicknm = '' if nicknm is None else f'_{nicknm}'
    os.makedirs(figdir, exist_ok=True)
    pdfpath = f'{figdir}/{i_savefig:02d}{_nicknm}_anec.pdf'
    with PdfPages(pdfpath) as pdf:
        for i_page, fig_clctn in figs_dict.items():
            pdf.savefig(figure=fig_clctn, bbox_inches='tight')
    print(f'Finished writing {pdfpath}')

fig

### The Generated OIN Histograms (Legacy)

In [None]:
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)

# One histogram figure is written per selected experiment. The plotting keys are 
# tagged with index 0 (mass fractions) and index 1 (absolute masses).
for viz_id, viz_data in hie2deep(viz_datas, maxdepth=1).items():
    exprmnt, arch, split = viz_id.split(':')
    # Only the 'normal' split of the 'cond.mqnt.*' experiments is plotted here
    if (split != 'normal') or not exprmnt.startswith('cond.mqnt.'):
        continue
    
    exprm_info = get_subdict(exprm_infos, f'{exprmnt}:{arch}')
    n_seeds = exprm_info['n_seeds']
    nicknm = exprm_info['nicknm']
    figdir = exprm_info['figdir']
    i_savefig = exprm_info['i_fig/hist']

    # This lookup does not depend on the variant, so it is done once per experiment
    v_ydata = viz_data['v_ydata']

    v_mpldatas = dict()
    n_rcns, i_oin = 1, chem_species.index('Dust')
    for vrnt in ['gaus0', 'gaus1', 'gaus2', 'gaus3']:
        v_ynparr = v_ydata[f'm_chmprthst/{vrnt}']
        n_bins2 = v_ynparr.shape[-1]
        assert v_ynparr.shape == (n_seeds, n_snrt, n_rcns, n_chem, n_bins2)

        # The absolute mass of the `i_oin` species ('Dust'), summed over the last axis
        m_oinabs1 = v_ynparr.sum(-1)[:, :, 0, i_oin]
        assert m_oinabs1.shape == (n_seeds, n_snrt)

        # The first-seed slice (only shape-checked; the plots use all seeds)
        m_oinabs2 = m_oinabs1[0]
        assert m_oinabs2.shape == (n_snrt,)

        # The mass fraction of each species: the per-species sum over the last 
        # axis divided by the sum over the last two axes.
        m_chmfracs = v_ynparr.sum(-1, keepdims=True) / v_ynparr.sum((-1, -2), keepdims=True)
        assert m_chmfracs.shape == (n_seeds, n_snrt, n_rcns, n_chem, 1)
        m_oinfrac1 = m_chmfracs[:, :, 0, i_oin, 0]
        assert m_oinfrac1.shape == (n_seeds, n_snrt)
        # The first-seed slice (only shape-checked; the plots use all seeds)
        m_oinfrac2 = m_oinfrac1[0]
        assert m_oinfrac2.shape == (n_snrt,)

        # All seeds are pooled together into a single flat sample per variant
        v_mpldatas[f'cond.mfrachist.legacy/0:{vrnt}:y'] = m_oinfrac1.ravel()
        v_mpldatas[f'cond.mabshist/1:{vrnt}:y'] = m_oinabs1.ravel()
        
    v_mpldatashie = hie2deep(v_mpldatas, maxdepth=1)
    ########## Calling Matplotlib to Plot the Diagnostics ########## 
    fig, axes = None, None
    for ax_id in v_mpldatashie:
        fig, axes = plot_mpl(data=v_mpldatashie[ax_id], fig=fig, 
            axes=axes, mplopts=v_mplcfgs[ax_id])

    _nicknm = '' if nicknm is None else f'_{nicknm}'
    # Making sure the output directory exists, as the other figure-writing cells do
    os.makedirs(figdir, exist_ok=True)
    pdfpath = f'{figdir}/{i_savefig:02d}{_nicknm}_hist.pdf'
    fig.savefig(pdfpath, bbox_inches='tight')
    print(f'Finished writing {pdfpath}')

# Displaying the last figure that was produced
fig

### The Generated Aerosol Population Diagnostic Plots (Legacy)

In [None]:
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)
n_figrows = v_mplcfgs['m_chmprthst.genr']['plt.subplots/nrows']
# One figure column per variant listed in `vrnt2lbl` below
v_mplcfgs['m_chmprthst.genr']['plt.subplots/ncols'] = 5
n_figcols = v_mplcfgs['m_chmprthst.genr']['plt.subplots/ncols']

# The variant keys (in column order) and their display labels
vrnt2lbl = {'gaus': 'Random', 'gaus0': 'Low OIN', 'gaus1': 'Med-Low OIN', 
    'gaus2': 'Med-High OIN', 'gaus3': 'High OIN'}
vrnts = list(vrnt2lbl)

for viz_id, viz_data in hie2deep(viz_datas, maxdepth=1).items():
    exprmnt, arch, split = viz_id.split(':')
    # Filtering first, so that the skipped entries neither need to carry all 
    # the keys read below nor modify the plotting configs.
    if (split != 'normal') or not exprmnt.startswith('cond.mqnt.'):
        continue

    v_ydata = viz_data['v_ydata']
    n_binsum = viz_data['n_bins']
    d_histbinsum = viz_data['d_histbins']
    len_wvum = viz_data['len_wv']
    i_sel = viz_data['i_sel']
    e_samps = viz_data['e_samps']
    eps_histmin = viz_data['eps_histmin']
    eps_histmax = viz_data['eps_histmax']
    tmprtr_inpmin = viz_data['tmprtr_inpmin']
    tmprtr_inpmax = viz_data['tmprtr_inpmax']
    eps_histbins = viz_data['eps_histbins']
    temprtr_bins = viz_data['temprtr_bins']

    exprm_info = get_subdict(exprm_infos, f'{exprmnt}:{arch}')
    n_seeds = exprm_info['n_seeds']
    nicknm = exprm_info['nicknm']
    figdir = exprm_info['figdir']
    i_savefig = exprm_info['i_fig/diag']

    # Padding the x-axis limits slightly beyond the data range
    v_mplcfgs['ccn_cdf.genr']['xlim'] = [eps_histmin * 0.9, eps_histmax * 1.1]
    v_mplcfgs['frznfrac_tmp.genr']['xlim'] = [tmprtr_inpmin - 1, tmprtr_inpmax + 1]
    
    # Each row of `i_sel` makes one page, with one sample index per figure column
    n_page, n_pgsamps = i_sel.shape
    assert n_pgsamps == n_figcols

    figs_dict = dict()
    for i_page in range(n_page):
        v_mpldatas = dict()
        i_samps = i_sel[i_page]
        # The speciated mass data
        for i_figcol, i_samp in enumerate(i_samps):
            ax_idx = i_figcol
            vrnt = vrnts[i_figcol]
            # The array lookup and reshape are the same for all the species
            i_rcns, n_rcns = 0, 1
            v_ynparr1 = v_ydata[f'm_chmprthst/{vrnt}']
            assert v_ynparr1.shape == (n_seeds, n_snrt, n_rcns, n_chem, n_binsum)
            # Merging the seed and sample axes so that `i_samp` indexes both
            v_ynparr2 = v_ynparr1.reshape(n_seeds * n_snrt, n_rcns, n_chem, n_binsum)
            assert v_ynparr2.shape == (n_seeds * n_snrt, n_rcns, n_chem, n_binsum)
            for i_chem, chem in enumerate(chem_species):
                v_mpldatas[f'm_chmprthst.genr/{ax_idx}:{chem}:x'] = d_histbinsum
                v_mpldatas[f'm_chmprthst.genr/{ax_idx}:{chem}:y'] = v_ynparr2[i_samp, i_rcns, i_chem]
        
        # The total mass, particle count, ccn, frozen fraction, and optical cross-section data
        # Each entry is a (figure row, x values, y data name) triplet
        ax_ixyspec = []
        ax_ixyspec.append([1, d_histbinsum, 'n_prthst'    ])
        ax_ixyspec.append([2, eps_histbins, 'ccn_cdf'     ])
        ax_ixyspec.append([3, temprtr_bins, 'frznfrac_tmp'])
        ax_ixyspec.append([4, len_wvum,     'qs_pop'      ])
        ax_ixyspec.append([5, len_wvum,     'qa_pop'      ])
        for i_figcol, i_samp in enumerate(i_samps):
            for i_row, xvals, ycol in ax_ixyspec:
                # The row-major flat index of the axis
                ax_idx = i_row * n_figcols + i_figcol
                vrnt = vrnts[i_figcol]
                v_mpldatas[f'{ycol}.genr/{ax_idx}:genr:x'] = xvals
                i_rcns, n_rcns = 0, 1
                v_ynparr1 = v_ydata[f'{ycol}/{vrnt}']
                assert v_ynparr1.shape[:3] == (n_seeds, n_snrt, n_rcns)
                v_ynparr2 = v_ynparr1.reshape(n_seeds * n_snrt, n_rcns, *v_ynparr1.shape[3:])
                assert v_ynparr2.shape[:2] == (n_seeds * n_snrt, n_rcns)
                v_mpldatas[f'{ycol}.genr/{ax_idx}:genr:y'] = np.squeeze(v_ynparr2[i_samp, i_rcns])
        
        v_mpldatashie = hie2deep(v_mpldatas, maxdepth=1)
        ########## Calling Matplotlib to Plot the Diagnostics ########## 
        fig, axes = None, None
        for ax_id in v_mpldatashie:
            fig, axes = plot_mpl(data=v_mpldatashie[ax_id], fig=fig, 
                axes=axes, mplopts=v_mplcfgs[ax_id])
    
        figs_dict[f'{i_page}/diag'] = fig

        # Adding the small identification text box to the top-row axes
        for i_figcol, i_samp in enumerate(i_samps):
            ax = axes[0, i_figcol]
            vrnt = vrnts[i_figcol]
            textstr = vrnt2lbl[vrnt]
            props = dict(facecolor='none', edgecolor='none')
            ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=10,
                usetex=True, verticalalignment='top', bbox=props)

    # Writing all the pages of this experiment into a single multi-page PDF
    _nicknm = '' if nicknm is None else f'_{nicknm}'
    os.makedirs(figdir, exist_ok=True)
    pdfpath = f'{figdir}/{i_savefig:02d}{_nicknm}_anec.pdf'
    with PdfPages(pdfpath) as pdf:
        for i_page, fig_clctn in figs_dict.items():
            pdf.savefig(figure=fig_clctn, bbox_inches='tight')
    print(f'Finished writing {pdfpath}')

# Displaying the last figure that was produced
fig

### Collecting the Data For MDS Scatter Plots

In [None]:
# The collected data, keyed as '{exprmnt}:{arch}/...'
vizsct_datas = dict()

for exprm_id, exprm_info in hie2deep(exprm_infos).items():
    (exprmnt, arch) = exprm_id.split(':')
    fpidx = exprm_info['fpidx']
    resdir = exprm_info['resdir']
    n_seeds = exprm_info['n_seeds']
    nicknm = exprm_info['nicknm']
    figdir = exprm_info['figdir']
    
    # Experiments without a results index have nothing to load
    if fpidx is None:
        continue
    
    try:
        rio = resio(fpidx=fpidx, resdir=resdir)
        rio_dtypes = rio.dtypes()
    except Exception as exc:
        # This is a deliberate best-effort skip, but the cause is reported 
        # so that a failed load is not mistaken for missing data.
        print(f'There was an error opening {fpidx} ({type(exc).__name__}: {exc}). I will move on.')
        continue

    # Dropping everything up to the first ':' of each key
    rio_keys = [key.split(':', 1)[-1] for key in rio_dtypes]

    v_mdsdataraw = dict()
    for key in rio_keys:
        # Only keeping the MDS embeddings and the raw `y` labels of the two evaluation ids
        if not any(
            key.startswith(f'var/eval/mds/{eid}:') or 
            key.startswith(f'var/eval/raw/{eid}:y:y/') 
            for eid in ['identity', 'polyhstbal']):
            continue
        val = rio(key)
        n_pnts, n_rcns, *d_repr = val.shape[1:]
        assert n_rcns == 1
        assert val.shape == (n_epoch * n_seeds, n_pnts, n_rcns, *d_repr)
        # Splitting the leading axis into (epoch, seed) and keeping the last epoch.
        # The singleton `n_rcns` axis is dropped by this reshape.
        val_lastepoch = val.reshape(n_epoch, n_seeds, n_pnts, *d_repr)[-1]
        assert val_lastepoch.shape == (n_seeds, n_pnts, *d_repr)
        v_mdsdataraw[key] = val_lastepoch

    # Example:
    #   v_mdsdataraw = {
    #       'var/eval/mds/identity:1x1xtsne:x:m_chmprthst/normal/genr/pnts/data': np.randn(10, 5000, 2),
    #       'var/eval/mds/identity:1x1xtsne:x:m_chmprthst/test/orig/pnts/data': np.randn(10, 5000, 2),
    #       'var/eval/mds/identity:1x1xtsne:x:m_chmprthst/test/rcnst/pnts/data': np.randn(10, 5000, 2),
    #       'var/eval/mds/identity:1x1xtsne:x:m_chmprthst/train/orig/pnts/data': np.randn(10, 5000, 2),
    #       'var/eval/mds/identity:1x1xtsne:x:m_chmprthst/train/rcnst/pnts/data': np.randn(10, 5000, 2),
    #       'var/eval/mds/identity:1x1zpcalr:z:z/normal/genr/pnts/data': np.randn(10, 5000, 2),
    #       'var/eval/mds/identity:1x1zpcalr:z:z/test/orig/mu/data': np.randn(10, 5000, 2),
    #       'var/eval/mds/identity:1x1zpcalr:z:z/test/orig/phi/data': np.randn(10, 5000, 1),
    #       'var/eval/mds/identity:1x1zpcalr:z:z/test/orig/pnts/data': np.randn(10, 5000, 2),
    #       'var/eval/mds/identity:1x1zpcalr:z:z/test/orig/sig/data': np.randn(10, 5000, 2),
    #       'var/eval/mds/identity:1x1zpcalr:z:z/test/rcnst/pnts/data': np.randn(10, 5000, 2),
    #       'var/eval/mds/identity:1x1zpcalr:z:z/train/orig/mu/data': np.randn(10, 5000, 2),
    #       'var/eval/mds/identity:1x1zpcalr:z:z/train/orig/phi/data': np.randn(10, 5000, 1),
    #       'var/eval/mds/identity:1x1zpcalr:z:z/train/orig/pnts/data': np.randn(10, 5000, 2),
    #       'var/eval/mds/identity:1x1zpcalr:z:z/train/orig/sig/data': np.randn(10, 5000, 2),
    #       'var/eval/mds/identity:1x1zpcalr:z:z/train/rcnst/pnts/data': np.randn(10, 5000, 2),
    #       'var/eval/mds/identity:1x1ztsne:z:z/normal/genr/pnts/data': np.randn(10, 5000, 2),
    #       'var/eval/mds/identity:1x1ztsne:z:z/test/orig/pnts/data': np.randn(10, 5000, 2),
    #       'var/eval/mds/identity:1x1ztsne:z:z/test/rcnst/pnts/data': np.randn(10, 5000, 2),
    #       'var/eval/mds/identity:1x1ztsne:z:z/train/orig/pnts/data': np.randn(10, 5000, 2),
    #       'var/eval/mds/identity:1x1ztsne:z:z/train/rcnst/pnts/data': np.randn(10, 5000, 2),
    #       'var/eval/raw/identity:y:y/normal/genr/pnts/data': np.randn(10, 5000, 1, 4),
    #       'var/eval/raw/identity:y:y/test/orig/pnts/data': np.randn(10, 5000, 1, 4),
    #       'var/eval/raw/identity:y:y/test/rcnst/pnts/data': np.randn(10, 5000, 1, 4),
    #       'var/eval/raw/identity:y:y/train/orig/pnts/data': np.randn(10, 5000, 1, 4),
    #       'var/eval/raw/identity:y:y/train/rcnst/pnts/data': np.randn(10, 5000, 1, 4),
    #   }

    vizsct_datas[f'{exprmnt}:{arch}/v_mdsdataraw'] = v_mdsdataraw

    # Picking the raw label arrays and re-keying them as '{split}/{vrnt}'
    pats_rnm1 = {
        'var/eval/raw/identity:y:y/{split}/{vrnt}/pnts/data': 
        '{split}/{vrnt}'}
    y_data = get_subdictrnmd(v_mdsdataraw, pats_rnm1)

    for splitvrnt in y_data:
        split, vrnt = splitvrnt.split('/')
        # The one-hot class labels
        y_ohpnts1 = y_data[f'{split}/{vrnt}']
        n_pnts, n_class = y_ohpnts1.shape[1], y_ohpnts1.shape[-1]
        assert y_ohpnts1.shape == (n_seeds, n_pnts, 1, n_class)

        # The one-hot class labels for each point
        y_ohpnts = y_ohpnts1.reshape(n_seeds, n_pnts, n_class)
        assert y_ohpnts.shape == (n_seeds, n_pnts, n_class)

        # Making sure the one-hot labels are actually one-hot
        assert np.logical_or(y_ohpnts == 0, y_ohpnts == 1).all()
        assert (y_ohpnts.sum(-1) == 1).all()

        # Converting the one-hot labels to class indices for every point
        y_pntscls = y_ohpnts.argmax(-1)
        assert y_pntscls.shape == (n_seeds, n_pnts)

        # The fraction of points in each class
        y_clsfrac = (y_pntscls[..., None] == np.arange(n_class)[None, None]).mean(-2)
        assert y_clsfrac.shape == (n_seeds, n_class)
        assert np.allclose(y_clsfrac.sum(-1), 1)

        vizsct_datas[f'{exprmnt}:{arch}/{split}/{vrnt}/n_seeds'] = n_seeds
        vizsct_datas[f'{exprmnt}:{arch}/{split}/{vrnt}/n_pnts'] = n_pnts
        vizsct_datas[f'{exprmnt}:{arch}/{split}/{vrnt}/n_class'] = n_class
        vizsct_datas[f'{exprmnt}:{arch}/{split}/{vrnt}/y_pntscls'] = y_pntscls
        vizsct_datas[f'{exprmnt}:{arch}/{split}/{vrnt}/y_clsfrac'] = y_clsfrac


### The Class-Specific MDS Scatter Plots

In [None]:
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)
n_figrows = v_mplcfgs['cond.ztsne']['plt.subplots/nrows']
n_figcols = v_mplcfgs['cond.ztsne']['plt.subplots/ncols']

for exprmntarch, viz_data in hie2deep(vizsct_datas, maxdepth=1).items():
    exprmnt, arch = exprmntarch.split(':')
    v_mdsdataraw = viz_data['v_mdsdataraw']
    # Only the 'cond.mqnt.*' experiments are plotted here
    if not exprmnt.startswith('cond.mqnt.'):
        continue

    exprm_info = get_subdict(exprm_infos, f'{exprmnt}:{arch}')
    n_seeds = exprm_info['n_seeds']
    nicknm = exprm_info['nicknm']
    figdir = exprm_info['figdir']
    i_savefig = exprm_info['i_fig/mdscls']

    # One figure (page) per (split, variant) pair
    figs_dict = dict()
    for i_page, (split, vrnt) in enumerate([('train', 'orig'), ('test', 'orig'), ('normal', 'gaus')]):
        n_pnts = viz_data[f'{split}/{vrnt}/n_pnts']
        n_class = viz_data[f'{split}/{vrnt}/n_class']
        y_pntscls = viz_data[f'{split}/{vrnt}/y_pntscls']
        y_clsfrac = viz_data[f'{split}/{vrnt}/y_clsfrac']
        # One column per class plus one column for all the classes combined
        assert n_figcols == (n_class + 1)
        
        ########################### Collecting the plotting data ############################
        ax_idx = 0
        v_mpldatas1 = dict()
        for mplid, vid in [
            ('cond.ztsne', '1x1ztsne:z:z'), 
            ('cond.zpca', '1x1zpcalr:z:z'), 
            ('cond.xtsne', '1x1xtsne:x:m_chmprthst')]:
            ###### The individual-class data of this embedding ######
            for i_class in range(n_class):
                pats_rnm2 = {
                    f'var/eval/mds/identity:{vid}/{split}/{vrnt}/{{vrepr}}/data': 
                    f'{mplid}/{ax_idx}:{split}/class/{i_class}:{{vrepr}}'}
                # Only the first seed is plotted
                i_seed = 0
                ax_data1 = get_subdictrnmd(v_mdsdataraw, pats_rnm2)
                # Keeping the first-seed points that belong to this class
                ax_data = {key: val[i_seed, y_pntscls[i_seed] == i_class] 
                    for key, val in ax_data1.items()}
                v_mpldatas1.update(ax_data)
                ax_idx += 1

            ###### The combined-class data of this embedding ######
            # Re-keying the entries collected above for this embedding under the 
            # current axis index.
            pats_rnm1 = {
                f'{mplid}/{{ax_idx}}:{{split}}/class/{{i_class}}:{{vrepr}}': 
                f'{mplid}/{ax_idx}:{{split}}/class/{{i_class}}:{{vrepr}}'}
            ax_data = get_subdictrnmd(v_mpldatas1, pats_rnm1)
            v_mpldatas1.update(ax_data)
            ax_idx += 1

        # Splitting the various plotting configs
        v_mpldatas2 = hie2deep(v_mpldatas1, maxdepth=1)

        ################### Calling Matplotlib to Plot the Scatter Points ################### 
        fig, axes = None, None
        for ax_id in v_mpldatas2:
            fig, axes = plot_mpl(data=v_mpldatas2[ax_id], fig=fig, 
                axes=axes, mplopts=v_mplcfgs[ax_id])

        # Adding the top left april tag for axis identification
        n_figrows, n_figcols = axes.shape
        with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
            for ax_idx, ax in enumerate(axes.ravel()):
                i_figrow, i_figcol = ax_idx // n_figcols, ax_idx % n_figcols
                tag_text = f'({"abcde"[i_figcol]}$_{{{i_figrow + 1}}}$)'
                tag_axis(ax, tag_text, fontsize=11, pad=(0.3, 0.6))
        
            for i_figrow, row_ttl in enumerate(['Latent t-SNE', 'Latent PCA', 'Mass t-SNE']):
                print_axheader(axes[i_figrow, 0], row_ttl, 'left', 
                    fontsize=13, fontweight='bold')
            
            lblnames = ['Low Dust', 'Med-Low Dust', 'Med-High Dust', 'High Dust', 'All']
            split_frml = {'train': 'Train', 'test': 'Test', 'normal': 'Generated'}[split]
            for i_figcol, col_ttl in enumerate(lblnames):
                print_axheader(axes[0, i_figcol], f'{col_ttl}\n({split_frml})', 'top', 
                    fontsize=12, fontweight='bold')
        
        figs_dict[f'{i_page}/{split}/{vrnt}'] = fig

    _nicknm = '' if nicknm is None else f'_{nicknm}'
    # Making sure the output directory exists before writing into it
    os.makedirs(figdir, exist_ok=True)
    pdfpath = f'{figdir}/{i_savefig:02d}{_nicknm}_mds.pdf'
    with PdfPages(pdfpath) as pdf:
        for fig_id, fig in figs_dict.items():
            pdf.savefig(figure=fig, bbox_inches='tight')
    print(f'Finished writing {pdfpath}')

    # Also writing each page as a separate PNG file named after its split
    for fig_id, fig in figs_dict.items():
        i_page, split, vrnt = fig_id.split('/')
        pngpath = pdfpath.replace('.pdf', f'_{split}.png')
        fig.savefig(pngpath, dpi=200, bbox_inches='tight')
        # Reporting every written file rather than only the last one
        print(f'Finished writing {pngpath}')

# Displaying the last figure that was produced
fig

### The Split-Specific MDS Scatter Plots

In [None]:
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)
    
figs_dict = dict()
for exprmntarch, viz_data in hie2deep(vizsct_datas, maxdepth=1).items():
    exprmnt, arch = exprmntarch.split(':')
    v_mdsdataraw = viz_data['v_mdsdataraw']

    exprm_info = get_subdict(exprm_infos, f'{exprmnt}:{arch}')
    n_seeds = exprm_info['n_seeds']
    nicknm = exprm_info['nicknm']
    figdir = exprm_info['figdir']
    i_savefig = exprm_info['i_fig/mdssplt']
    
    ########################### Collecting the plotting data ############################
    v_mpldatas = dict()
    # The generated-sample variant is stored under either 'gaus' or 'genr'
    nrmvrnt = 'gaus' if ('var/eval/mds/identity:1x1ztsne:z:z/normal/gaus/pnts/data' in v_mdsdataraw) else 'genr'

    ###### The Z TSNE individual-split data ######
    for ax_idx, split, vrnt in [(0, 'train', 'orig'), (1, 'test', 'orig'), (2, 'normal', nrmvrnt)]:
        pats_rnm2 = {
            f'var/eval/mds/identity:1x1ztsne:z:z/{split}/{vrnt}/pnts/data': 
            f'ztsne/{ax_idx}:{split}/{vrnt}:pnts'}
        ax_data = get_subdictrnmd(v_mdsdataraw, pats_rnm2)
        v_mpldatas.update(ax_data)

    ###### The Z TSNE combined-splits data ######
    # Re-keying the entries collected above under axis index 3
    ax_idx = 3
    pats_rnm1 = {
        'ztsne/{ax_idx}:{split}/{vrnt}:{vrepr}': 
        f'ztsne/{ax_idx}:{{split}}/{{vrnt}}:{{vrepr}}'}
    ax_data = get_subdictrnmd(v_mpldatas, pats_rnm1)
    v_mpldatas.update(ax_data)

    ###### The Z PCA individual-split data ######
    for ax_idx, split, vrnt in [(4, 'train', 'orig'), (5, 'test', 'orig'), (6, 'normal', nrmvrnt)]:
        pats_rnm4 = {
            f'var/eval/mds/identity:1x1zpcalr:z:z/{split}/{vrnt}/{{vrepr}}/data':
            f'zpca/{ax_idx}:{split}/{vrnt}:{{vrepr}}'}
        ax_data = get_subdictrnmd(v_mdsdataraw, pats_rnm4)
        v_mpldatas.update(ax_data)

    ###### The Z PCA combined-splits data ######
    ax_idx = 7
    pats_rnm3 = {
        'zpca/{ax_idx}:{split}/{vrnt}:{vrepr}': 
        f'zpca/{ax_idx}:{{split}}/{{vrnt}}:{{vrepr}}'}
    ax_data = get_subdictrnmd(v_mpldatas, pats_rnm3)
    v_mpldatas.update(ax_data)

    ##### The X TSNE individual-split data #####
    # The original and reconstructed points share the same axis for train and test
    for ax_idx, split, vrnt in [(8, 'train', 'orig'), (8, 'train', 'rcnst'), 
        (9, 'test', 'orig'), (9, 'test', 'rcnst'), (10, 'normal', nrmvrnt)]:
        pats_rnm6 = {
            f'var/eval/mds/identity:1x1xtsne:x:m_chmprthst/{split}/{vrnt}/pnts/data': 
            f'xtsne/{ax_idx}:{split}/{vrnt}:pnts'}
        ax_data = get_subdictrnmd(v_mdsdataraw, pats_rnm6)
        v_mpldatas.update(ax_data)

    ###### The X TSNE combined-splits data ######
    ax_idx = 11
    pats_rnm5 = {
        'xtsne/{ax_idx}:{split}/{vrnt}:{vrepr}': 
        f'xtsne/{ax_idx}:{{split}}/{{vrnt}}:{{vrepr}}'}
    ax_data = get_subdictrnmd(v_mpldatas, pats_rnm5)
    v_mpldatas.update(ax_data)

    # Restricting the data to the first seed
    v_mpldatas2 = {key: val[0] for key, val in v_mpldatas.items()}

    # Splitting the various plotting configs
    v_mpldatas3 = hie2deep(v_mpldatas2, maxdepth=1)

    ################### Calling Matplotlib to Plot the Scatter Points ################### 
    fig, axes = None, None
    for ax_id in v_mpldatas3:
        fig, axes = plot_mpl(data=v_mpldatas3[ax_id], fig=fig, 
            axes=axes, mplopts=v_mplcfgs[ax_id])

    # Making sure the legend labels have full alpha!
    # Note: this unpacking requires the figure to have exactly one legend.
    leg, = fig.legends
    for handle in leg.legend_handles:
        handle.set_alpha(1.0)

    # Adding the top left april tag for axis identification
    n_figrows, n_figcols = axes.shape
    with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
        for ax_idx, ax in enumerate(axes.ravel()):
            i_figrow, i_figcol = ax_idx // n_figcols, ax_idx % n_figcols
            tag_text = f'({"abcd"[i_figcol]}$_{{{i_figrow + 1}}}$)'
            tag_axis(ax, tag_text, fontsize=11, pad=(0.3, 0.6))
    
        for i_figrow, row_ttl in enumerate(['Latent t-SNE', 'Latent PCA', 'Mass t-SNE']):
            print_axheader(axes[i_figrow, 0], row_ttl, 'left', 
                fontsize=14, fontweight='bold')
        
        for i_figcol, col_ttl in enumerate(['Train', 'Test', 'Generated', 'All']):
            print_axheader(axes[0, i_figcol], col_ttl, 'top', 
                fontsize=14, fontweight='bold')

    # A stray `break` used to sit here, which made the saving code below 
    # unreachable. Every experiment's figure is now written to disk.
    _nicknm = '' if nicknm is None else f'_{nicknm}'
    os.makedirs(figdir, exist_ok=True)
    figpath1 = f'{figdir}/{i_savefig:02d}{_nicknm}_mds.pdf'
    fig.savefig(figpath1, bbox_inches='tight')
    print(f'Finished writing {figpath1}')

    figpath2 = figpath1.replace('.pdf', '.png')
    fig.savefig(figpath2, dpi=200, bbox_inches='tight')
    print(f'Finished writing {figpath2}')

# Displaying the last figure that was produced
fig

## Continuous-Label Conditional Training

### The Generated Aerosol Population Diagnostic Plots

In [None]:
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)

for (viz_id, viz_data) in hie2deep(viz_datas, maxdepth=1).items():
    exprmnt, arch, split = viz_id.split(':')
    v_ydata = viz_data['v_ydata']
    d_histbinsum = viz_data['d_histbins']
    n_binsum = viz_data['n_bins']
    len_wvum = viz_data['len_wv']
    i_sel = viz_data['i_sel']
    e_samps = viz_data['e_samps']
    eps_histmin = viz_data['eps_histmin']
    eps_histmax = viz_data['eps_histmax']
    eps_histbins = viz_data['eps_histbins']
    tmprtr_inpmin = viz_data['tmprtr_inpmin']
    tmprtr_inpmax = viz_data['tmprtr_inpmax']
    temprtr_bins = viz_data['temprtr_bins']
    # `i_sel[i_file, i_page]` is the sample index shown on a page of an output file
    n_files, n_page = i_sel.shape

    exprm_info = get_subdict(exprm_infos, f'{exprmnt}:{arch}')
    n_seeds = exprm_info['n_seeds']
    n_snrt = exprm_info['n_snrt']
    nicknm = exprm_info['nicknm']
    figdir = exprm_info['figdir']
    i_savefig = exprm_info['i_fig/diag']

    # Only the 'test' split of the 'cond.cont.*' experiments is plotted here
    if (not exprmnt.startswith('cond.cont.')) or (split not in ('test',)):
        continue

    # Each figure column is a (variant, n_rcns, i_rcns, title) specification
    vrnt_specs = exprm_info.get('vrnt_specs', None)
    if vrnt_specs is None:
        vrnt_specs = [
            ('orig', 1, 0, 'Original'), ('rcnst', 1, 0, 'Reconst' ), 
            ('yknn', 5, 0, '1st LNN' ), ('yknn',  5, 1, '1st LNN' ),
            ('znrm', 5, 0, 'Norm Lat'), ('znrm',  5, 1, 'Norm Lat')]
    n_figcols = len(vrnt_specs)
    n_figrows = v_mplcfgs['m_chmprthst.cndrcn']['plt.subplots/nrows']
    v_mplcfgs['m_chmprthst.cndrcn']['plt.subplots/ncols'] = n_figcols
    assert len(vrnt_specs) == n_figcols

    # Padding the x-axis limits slightly beyond the data range
    v_mplcfgs['ccn_cdf.cndrcn']['xlim'] = [eps_histmin * 0.9, eps_histmax * 1.1]
    v_mplcfgs['frznfrac_tmp.cndrcn']['xlim'] = [tmprtr_inpmin - 1, tmprtr_inpmax + 1]
    
    for i_file in range(n_files):
        figs_dict = dict()
        for i_page in range(n_page):
            v_mpldatas = dict()
            i_samp = i_sel[i_file, i_page]
            
            # The speciated mass data
            i_row = 0
            for i_col in range(n_figcols):
                vrnt, n_rcns, i_rcns, ttl = vrnt_specs[i_col]
                for i_chem, chem in enumerate(chem_species):
                    ax_idx = i_row * n_figcols + i_col
                    v_mpldatas[f'm_chmprthst.cndrcn/{ax_idx}:{chem}:x'] = d_histbinsum
                    v_ynparr = v_ydata[f'm_chmprthst/{vrnt}']
                    assert v_ynparr.shape == (n_seeds, n_snrt, n_rcns, n_chem, n_binsum)
                    # Merging the seed and sample axes so that `i_samp` indexes both
                    v_ynparr2 = v_ynparr.reshape(n_seeds * n_snrt, n_rcns, n_chem, n_binsum)
                    assert v_ynparr2.shape == (n_seeds * n_snrt, n_rcns, n_chem, n_binsum)
                    v_mpldatas[f'm_chmprthst.cndrcn/{ax_idx}:{chem}:y'] = v_ynparr2[i_samp, i_rcns, i_chem]
            
            # The total mass, particle count, ccn, frozen fraction, and optical cross-section data
            # Each entry is a (figure row, x values, y data name) triplet
            ax_ixyspec = []
            ax_ixyspec.append([1, d_histbinsum, 'n_prthst'    ])
            ax_ixyspec.append([2, eps_histbins, 'ccn_cdf'     ])
            ax_ixyspec.append([3, temprtr_bins, 'frznfrac_tmp'])
            ax_ixyspec.append([4, len_wvum,     'qs_pop'      ])
            ax_ixyspec.append([5, len_wvum,     'qa_pop'      ])
            for i_col in range(n_figcols):
                for i_row, xvals, ycol in ax_ixyspec:
                    # The row-major flat index of the axis
                    ax_idx = i_row * n_figcols + i_col
                    vrnt, n_rcns, i_rcns, ttl = vrnt_specs[i_col]
                    v_mpldatas[f'{ycol}.cndrcn/{ax_idx}:cndrcn:x'] = xvals
                    v_ynparr = v_ydata[f'{ycol}/{vrnt}']
                    assert v_ynparr.shape[:3] == (n_seeds, n_snrt, n_rcns)
                    v_ynparr2 = v_ynparr.reshape(n_seeds * n_snrt, n_rcns, *v_ynparr.shape[3:])
                    assert v_ynparr2.shape[:2] == (n_seeds * n_snrt, n_rcns)
                    v_mpldatas[f'{ycol}.cndrcn/{ax_idx}:cndrcn:y'] = np.squeeze(v_ynparr2[i_samp, i_rcns])
            
            v_mpldatashie = hie2deep(v_mpldatas, maxdepth=1)
            ########## Calling Matplotlib to Plot the Diagnostics ########## 
            fig, axes = None, None
            for ax_id in v_mpldatashie:
                fig, axes = plot_mpl(data=v_mpldatashie[ax_id], fig=fig, 
                    axes=axes, mplopts=v_mplcfgs[ax_id])

            figs_dict[f'{i_page}/diag'] = fig

            # Adding the small identification text box to the top-row axes
            for i_col in range(n_figcols):
                ax = axes[0, i_col]
                vrnt, n_rcns, i_rcns, textstr = vrnt_specs[i_col]
                props = dict(facecolor='none', edgecolor='none')
                ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=10,
                    usetex=True, verticalalignment='top', bbox=props)

        _nicknm = '' if nicknm is None else f'_{nicknm}'
        # Every file needs its own path; otherwise each iteration would overwrite 
        # the PDF of the previous one. The original name is kept for a single file.
        _filetag = '' if n_files == 1 else f'_{i_file:02d}'
        os.makedirs(figdir, exist_ok=True)
        pdfpath = f'{figdir}/{i_savefig:02d}{_nicknm}_anec{_filetag}.pdf'
        with PdfPages(pdfpath) as pdf:
            for fig_id, fig_clctn in figs_dict.items():
                pdf.savefig(figure=fig_clctn, bbox_inches='tight')
                # Releasing the figure once saved, so that the figures do not 
                # pile up over the files and pages.
                plt.close(fig_clctn)
        print(f'Finished writing {pdfpath}')


### Label Compliance vs Label Dimensionality

In [None]:
resdir = f'{results_dir}/02_adhoc'
rio = resio(
    fpidx='02_adhoc/22_mlpcnd.*.*', 
    resdir=results_dir, 
    driver='sec2')

# The statistics, restricted to the rows of the last recorded epoch
stdf1 = rio('stat')
epoch_max = stdf1['epoch'].max()
idx_lastep = (stdf1['epoch'] == epoch_max)
stdf2 = stdf1[idx_lastep].reset_index(drop=True)

# The hyper-parameters of the same rows. The `fpidx` column comes from the 
# accompanying info table, and `fpidxgrp` starts out as a copy of it.
hpdf1, hpdf_info = rio('hp', ret_info=True)
hpdf2 = hpdf1[idx_lastep].reset_index(drop=True)
hpinfo_lastep = hpdf_info[idx_lastep].reset_index(drop=True)
hpdf2.insert(0, 'fpidx', hpinfo_lastep['fpidx'])
hpdf2.insert(1, 'fpidxgrp', hpdf2['fpidx'])

In [None]:
# Bootstrap aggregation settings: the mean statistic with the 5th and 95th 
# percentiles, using 1000 bootstrap draws on the CPU.
aggcfg = {'type': 'bootstrap', 'n_boot': 1000, 'q': [5, 95], 
    'stat': 'mean', 'device': 'cpu'}

# Aggregating the last-epoch data, grouped by the `fpidxgrp` column
agg_data = get_aggdf(hpdf2, stdf2, xcol='epoch', huecol='fpidxgrp', 
    rngcol='rng_seed', aggcfg=aggcfg)
hpdf_agg = agg_data['hpdf']
stdf_agg = agg_data['stdf']
hpdf_agg = drop_unqcols(hpdf_agg)

In [None]:
# Short aliases for the long compliance metric columns. The `{stat}` 
# placeholder is filled with each of the aggregated statistic names below.
ycol_rnmngs = {
    'perf/compliance1/polyhst/normal/genr/polyhst/normal/genr.inp/same/0/y/y/abserr/mean/{stat}': 'cmplnc/encd/{stat}',
    'perf/compliance2/tricmplnc/test/znrm/tricmplnc/test/orig/same/0/x/yvec/abserr/mean/{stat}':  'cmplnc/znrm/{stat}',
    'perf/compliance2/tricmplnc/test/yknn/tricmplnc/test/orig/same/0/x/yvec/abserr/mean/{stat}':  'cmplnc/yknn/{stat}',
    'perf/compliance2/tricmplnc/test/rcnst/tricmplnc/test/orig/same/0/x/yvec/abserr/mean/{stat}': 'cmplnc/rcnst/{stat}'}

# Copying every (metric, statistic) column under its short alias
for (ycolinp_fmt, ycolout_fmt), stat in product(ycol_rnmngs.items(), ['mean', 'low', 'high']):
    ycol_agginp = ycolinp_fmt.format(stat=stat)
    ycol_aggout = ycolout_fmt.format(stat=stat)
    stdf_agg[ycol_aggout] = stdf_agg[ycol_agginp]

# Placing the hyper-parameter and statistics columns side by side
hpstdf_agg = pd.concat([hpdf_agg, stdf_agg], axis=1)

In [None]:
i_fig = 10
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)

# One curve per `huecol` value, drawn against the `x_col` hyper-parameter.
huecol = 'cri/indzy/w'
x_col = 'nn/enc/n_lenlbl'
pltdf = hpstdf_agg.sort_values(by=[huecol, x_col])

# `plot_mpl` is called once per metric, and it keeps receiving (and returning) 
# the same `fig` and `axes` objects. The plot config of each metric is looked 
# up with the '/' in the metric name replaced by '.'.
fig, axes = None, None
for i_figcol, y_col in enumerate(['cmplnc/encd', 'cmplnc/znrm', 'cmplnc/yknn', 'cmplnc/rcnst']):
    v_mpldatas = dict()
    for hueval, huedf in pltdf.groupby(huecol):
        v_mpldatas[f'{i_figcol}:{hueval}:x'] = huedf[x_col].values
        v_mpldatas[f'{i_figcol}:{hueval}:y/mean'] = huedf[f'{y_col}/mean'].values
        v_mpldatas[f'{i_figcol}:{hueval}:y/low'] = huedf[f'{y_col}/low'].values
        v_mpldatas[f'{i_figcol}:{hueval}:y/high'] = huedf[f'{y_col}/high'].values
    
    v_mplcfg = v_mplcfgs[y_col.replace('/', '.')].copy()
    fig, axes = plot_mpl(data=v_mpldatas, fig=fig, 
        axes=axes, mplopts=v_mplcfg)

# The annotation keyword arguments shared by both arrows. The '/'-separated 
# keys are turned into nested dictionaries by the `deep2hie`/`hie2deep` 
# round-trip below.
akws = {'xycoords': 'axes fraction', 'textcoords': 'axes fraction', 
    'fontsize': 10, 'bbox/pad': 90, 'bbox/facecolor': 'none', 'bbox/edgecolor': 'none', 
    'arrowprops/arrowstyle': '->', 'arrowprops/connectionstyle': 'arc3,rad=0.2'}
annot_kws = dict()
annot_kws['depzy'] = {'text': 'High Dependence', 'xytext': [0.4, 0.88], 
    'xy': [0.2, 0.82], 'arrowprops/relpos': [0.0, 0.5], **akws}
annot_kws['indzy'] = {'text': 'Low Dependence', 'xytext': [0.05, 0.05], 
    'xy': [0.85, 0.14], 'arrowprops/relpos': [1.0, 0.5], **akws}
annot_kws = hie2deep(deep2hie(annot_kws))

# Both annotations are placed on the top-left axis only.
with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
    for lbl, antkws in annot_kws.items():
        axes[0, 0].annotate(**antkws)

# NOTE(review): the `07_nrmtrgcont` sub-directory is not created in this cell; 
# `savefig` fails if it does not already exist -- confirm it is made elsewhere.
pdfpath = f'{suppdir}/07_nrmtrgcont/{i_fig:02d}_lbl_cmplnc_mlp.pdf'
fig.savefig(pdfpath, bbox_inches='tight')
print(f'Finished writing {pdfpath}')

fig

# PP Simulation Plots

## Loading the Simulation Data

In [None]:
# The expected leading dimensions of every simulation array (asserted below).
n_seeds, n_snrtsim = 10, 1000
sim_h5path = './18_ccnmieinp/11_sim.h5'

############## Collecting the plotting data ###############
ycols = ['m_chmprthst', 'n_prthst', 'ccn_cdf', 'qs_prt', 'qscs_prt', 'qs_pop',
    'qa_prt', 'qacs_prt', 'qa_pop', 'frznfrac_tmp', 'logfrznfrac_tmp']

# `ycol2dims`, `aero_csts`, `cnvrt_physunits` and `calc_aerometrics` are not 
# defined in this cell; they must already exist in the notebook namespace.
v_ydataraws = dict()
sim_data = load_h5data(sim_h5path)
for ycol in ycols:
    # `ycol2dims` gives the (channels, length) trailing dimensions of `ycol`.
    n_ychnls, n_ylen = ycol2dims[ycol]
    y_orig1 = sim_data[f'var/eval/yaero/orig/{ycol}']
    assert y_orig1.shape == (n_seeds, n_snrtsim, n_ychnls, n_ylen)
    for pptype in ['tuned', 'plain']:
        y_rcnst1 = sim_data[f'var/eval/yaero/{pptype}/{ycol}']
        assert y_rcnst1.shape == (n_seeds, n_snrtsim, n_ychnls, n_ylen)
        # The same original array is stored under both pp types. A singleton 
        # axis is inserted at position 2 of the original and reconstructed arrays.
        v_ydataraws[f'{pptype}/{ycol}/orig'] = y_orig1[:, :, None, :, :]
        v_ydataraws[f'{pptype}/{ycol}/rcnst'] = y_rcnst1[:, :, None, :, :]

v_mtrcdatas, v_ydatas = dict(), dict()
for pptype, v_ydataraw in hie2deep(v_ydataraws, maxdepth=1).items():
    ########### Data Cleaning and Unit Conversions ############
    # `aero_cstscnv` is overwritten at every iteration; the value from the 
    # last pp type is the one left in the notebook namespace.
    v_ydata, aero_cstscnv = cnvrt_physunits(v_ydataraw, aero_csts)
    ######## Computing the Aerosol Diagnostic Metrics #########
    v_mtrcdata = calc_aerometrics(v_ydata, aero_cstscnv, avg_errs=True)
    
    v_mtrcdatas[pptype] = v_mtrcdata
    v_ydatas[pptype] = v_ydata

# Converting the per-pptype dictionaries with `deep2hie` (presumably flattening 
# them into '/'-separated keys such as 'tuned/<ycol>/orig' -- TODO confirm).
v_mtrcdatas = deep2hie(v_mtrcdatas)
v_ydatas = deep2hie(v_ydatas)

## Computing Confidence Intervals with Statistical Bootstrapping

In [None]:
# A constant x-value; `get_aggdf` below is asked to use it as the `xcol`.
noise_amp = 0.3

# Building one stat row and one hp row per (pptype, seed) pair. Both lists 
# are appended to in lock-step, so the two dataframes are row-aligned.
hpdf_lst, statdf_lst = [], []
for pptype, v_mtrcdata in hie2deep(v_mtrcdatas).items():
    for i_seed in range(n_seeds):
        strowdict = dict()
        strowdict['rng_seed'] = i_seed
        strowdict['noise_amp'] = noise_amp
        for ycol, val in v_mtrcdata.items():
            # Only the metrics carrying a non-None 'err' entry are kept; the 
            # entry is indexed by the seed.
            if isinstance(val, dict) and (val['err'] is not None):
                strowdict[f'{ycol}.err'] = val['err'][i_seed]
        statdf_lst.append(strowdict)
        hprowdict = {'pptype': pptype, 'fpidxgrp': pptype}
        hpdf_lst.append(hprowdict)

hpdf = pd.DataFrame(hpdf_lst)
statdf = pd.DataFrame(statdf_lst)

# Aggregating the data
aggcfg = dict(type='bootstrap', n_boot=1000, q=[5, 95], stat='mean', device='cpu')
agg_data = get_aggdf(hpdf, statdf, xcol='noise_amp', 
    huecol='fpidxgrp', rngcol='rng_seed', aggcfg=aggcfg)
aggdf1 = pd.concat([agg_data['hpdf'], agg_data['stdf']], axis=1)
aggdf1 = aggdf1.drop(columns='fpidxgrp')
hpcols = ['pptype', 'noise_amp']

# Example: 
#   aggdf1 = 
#       pptype  noise_amp  ccn_cdf.err/mean  ccn_cdf.err/low  ccn_cdf.err/high  ...
#     0  tuned        0.3          0.067326         0.066693          0.067982   
#     1  plain        0.3          0.401532         0.398868          0.403894 

# The metric names, de-duplicated while preserving the column order. 
# Note that this re-binds the notebook-level `ycols` name.
ycols = list({col.removesuffix('.err/mean'): None 
    for col in aggdf1.columns if col.endswith('.err/mean')})
# Every column must either be an hp column or an `.err/<stat>` column of 
# one of the metrics collected above.
assert all((col in hpcols) or (col.split('.err/')[0] in ycols) 
    for col in aggdf1.columns)

# Converting the wide table into a long one: a melt is performed for each 
# statistic, and the three results are joined on (pptype, noise_amp, metric).
aggdf_melts = []
for stat in ['mean', 'low', 'high']:
    aggdf3 = aggdf1.melt(id_vars=hpcols, 
        value_vars=[f'{ycol}.err/{stat}' for ycol in ycols],
        var_name='metric', value_name=f'err/{stat}')
    aggdf3['metric'] = aggdf3['metric'].str.removesuffix(f'.err/{stat}')
    aggdf3 = aggdf3.set_index(hpcols + ['metric'])
    aggdf_melts.append(aggdf3)

aggdf4 = pd.concat(aggdf_melts, axis=1).reset_index()

# Example:
#   aggdf4 = 
#         pptype  noise_amp           metric  err/mean   err/low  err/high
#      0   tuned        0.3          ccn_cdf  0.067326  0.066693  0.067982
#      1   plain        0.3          ccn_cdf  0.401532  0.398868  0.403894
#      2   tuned        0.3           qs_pop  0.076846  0.075991  0.077730
#      3   plain        0.3           qs_pop  0.913600  0.908511  0.919038
#      4   tuned        0.3           qa_pop  0.165132  0.158923  0.170963
#      5   plain        0.3           qa_pop  0.953014  0.946191  0.959708
#      6   tuned        0.3  logfrznfrac_tmp  0.033548  0.030919  0.035911
#      7   plain        0.3  logfrznfrac_tmp  0.140155  0.135256  0.144967
#      8   tuned        0.3     frznfrac_tmp  0.341163  0.319734  0.361636
#      9   plain        0.3     frznfrac_tmp  0.813503  0.804350  0.822792
#      10  tuned        0.3      m_chmprthst  0.184169  0.180745  0.186895
#      11  plain        0.3      m_chmprthst  0.805887  0.799019  0.813259
#      12  tuned        0.3         n_prthst  0.056184  0.055867  0.056504
#      13  plain        0.3         n_prthst  0.094233  0.093702  0.094785
#      14  tuned        0.3         m_prthst  0.168582  0.165300  0.171351
#      15  plain        0.3         m_prthst  0.854394  0.849895  0.859212
#      16  tuned        0.3      m_perchmhst  0.489110  0.482695  0.495706
#      17  plain        0.3      m_perchmhst  0.946817  0.945679  0.947944

## The Tuned vs. Plain Error Metrics

### Horizontal Grouped Bar Plots

In [None]:
i_fig = 86
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)

# Metric name -> display label, for the left (`ycols1`) and right (`ycols2`) panels.
# NOTE(review): the vertical grouped bar plot cell keys the scattering and 
# absorption metrics as `qs_pop`/`qa_pop`. Confirm that `logqs_pop`/`logqa_pop` 
# exist in `aggdf4['metric']`; otherwise `isin` below silently drops them.
ycols1 = {
    'ccn_cdf': 'CCN Spectrum\nRelative Error', 
    'logqs_pop': 'Scattering\nLog-Rel Error', 
    'logqa_pop': 'Absorption\nLog-Rel Error', 
    'logfrznfrac_tmp': 'Frozen Fraction\nLog-Rel Error'}
ycols2 = { 
    'm_chmprthst': 'Speciated Mass\nRelative Error', 
    'n_prthst': 'Number\nRelative Error', 
    'm_prthst': 'Total Mass\nRelative Error', 
    'm_perchmhst': 'Species Bulk Mass\nRelative Error'}

# One figure (and one PDF) per axis scale.
for y_scale in ('log', 'lin'):
    with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
        fig, axes = plt.subplots(1, 2, figsize=[4.0 * 2, 2.5 * 1], 
            dpi=100, sharex=False, sharey=False, squeeze=False)
        axes = np.array(axes)

        # Note that the loop variable re-binds the notebook-level `ycols` name.
        for ax, ycols, y_prtid in [(axes[0, 0], ycols1, 'p1'), (axes[0, 1], ycols2, 'p2')]:
            aggdf5 = aggdf4[aggdf4['metric'].isin(ycols)].reset_index(drop=True)
            # Replacing the raw pptype and metric values with their display labels.
            aggdf6 = aggdf5.replace({'tuned': 'Tuned', 'plain': 'Plain', **ycols}).copy(deep=True)
            plt_cfg = {'fig': fig, 'ax': ax, **v_mplcfgs[f'paper.ppsim.grperr.horz.{y_scale}.{y_prtid}']}
            fig, ax = draw_matplotlib(plt_cfg, aggdf6)
            # The lower limit is passed second, so the y-axis is inverted 
            # (presumably to list the four metric groups from top to bottom).
            ax.set_ylim(3.6, -0.6)

        fig.subplots_adjust(wspace=0.5)
        for ax, tag_txt in zip(axes.ravel(), 'ab'):
            tag_axis(ax, f'$\\rm({tag_txt})$', fontsize=11, pad=[0.3, 0.5])

    pdfpath = f'{workdir}/{i_fig:02d}_ppsim_err{y_scale}_horz.pdf'
    fig.savefig(pdfpath, bbox_inches='tight')
    print(f'Finished writing {pdfpath}')

# Only the figure of the last scale ('lin') is displayed as the cell output.
fig

### Vertical Grouped Bar Plots

In [None]:
i_fig = 86
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)

# Metric name -> display label. This re-binds the notebook-level `ycols` name.
ycols = {
    'ccn_cdf': 'CCN Spectrum\nRelative Error', 
    'qs_pop': 'Scattering\nLog-Rel Error', 
    'qa_pop': 'Absorption\nLog-Rel Error', 
    'logfrznfrac_tmp': 'Frozen Fraction\nLog-Rel Error',
    'm_chmprthst': 'Speciated Mass\nRelative Error', 
    'n_prthst': 'Number\nRelative Error', 
    'm_prthst': 'Total Mass\nRelative Error', 
    'm_perchmhst': 'Per-Chem Mass\nRelative Error'}

# Keeping the listed metrics only, and replacing the raw pptype and metric 
# values with their display labels.
aggdf5 = aggdf4[aggdf4['metric'].isin(ycols)].reset_index(drop=True)
aggdf6 = aggdf5.replace({'tuned': 'Tuned', 'plain': 'Plain', **ycols}).copy(deep=True)

# One single-axis figure (and one PDF) per axis scale.
for y_scale in ('log', 'lin'):
    with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
        fig, axes = plt.subplots(1, 1, figsize=[9.6, 2.5], 
            dpi=100, sharex=False, sharey=False, squeeze=False)
        axes = np.array(axes)
        ax = axes[0, 0]
        plt_cfg = {'fig': fig, 'ax': ax, **v_mplcfgs[f'paper.ppsim.grperr.vert.{y_scale}']}
        fig, ax = draw_matplotlib(plt_cfg, aggdf6)
        
        # Presumably the eight metric groups sit at x = 0, ..., 7 -- TODO confirm.
        ax.set_xlim(-0.6, 7.6)
        fig.subplots_adjust(wspace=0.3)

    pdfpath = f'{workdir}/{i_fig:02d}_ppsim_err{y_scale}_vert.pdf'
    fig.savefig(pdfpath, bbox_inches='tight')
    print(f'Finished writing {pdfpath}')

# Only the figure of the last scale ('lin') is displayed as the cell output.
fig

### Non-Grouped Vertical Bar Plots

In [None]:
i_fig = 86
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)

# One axis per metric, in this order. This re-binds the notebook-level `ycols`.
ycols = [
    'ccn_cdf', 'qs_pop', 'qa_pop', 'logfrznfrac_tmp',
    'm_chmprthst', 'n_prthst', 'm_prthst', 'm_perchmhst']

with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
    fig, axes = plt.subplots(1, 8, figsize=[1.8 * 8, 3.0 * 1], 
        dpi=100, sharex=False, sharey=False, squeeze=False)
    axes = np.array(axes)

    # The wide (un-melted) table is used here, since each axis has its own config.
    aggdf2 = aggdf1.replace({'tuned': 'Tuned', 'plain': 'Plain'}).copy(deep=True)
    for ycol, ax in zip(ycols, axes.ravel()):
        plt_cfg = {'fig': fig, 'ax': ax, **v_mplcfgs[f'paper.ppsim.{ycol}.err']}
        fig, ax = draw_matplotlib(plt_cfg, aggdf2)
        # Shrinking the x-range to 70% of its width around the same center 
        # (`cc` is the center and `dd` is the half-width of the current range).
        aa, bb = ax.get_xlim()
        cc, dd = (aa + bb) / 2, (bb - aa) / 2
        ax.set_xlim(cc - dd * 0.7, cc + dd * 0.7)

    fig.subplots_adjust(wspace=0.6)

    for ax, tag_txt in zip(axes.ravel(), 'abcdefgh'):
        tag_axis(ax, f'({tag_txt})', fontsize=11, pad=[0.3, 0.5])

pdfpath = f'{workdir}/{i_fig:02d}_ppsim_errlog_vert2.pdf'
fig.savefig(pdfpath, bbox_inches='tight')
print(f'Finished writing {pdfpath}')

fig

## The Tuned vs. Plain Anecdotal Examples

In [None]:
i_fig = 87
plt.ioff()

# `aero_cstscnv` is not defined in this cell; it must already exist in the 
# notebook namespace.
chem_species = aero_cstscnv['chem_species']
eps_histmin = aero_cstscnv['eps_histmin']
eps_histmax = aero_cstscnv['eps_histmax']
tmprtr_inpmin = aero_cstscnv['tmprtr_inpmin']
tmprtr_inpmax = aero_cstscnv['tmprtr_inpmax']
d_histbinsum = aero_cstscnv['d_histbins']

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)
v_mplcfgs['ccn_cdf']['xlim'] = [eps_histmin * 0.9, eps_histmax * 1.1]
v_mplcfgs['frznfrac_tmp']['xlim'] = [tmprtr_inpmin - 1, tmprtr_inpmax + 1]

# Each row of `i_sel` makes one figure page, and each entry in that row is 
# the index of a simulation sample shown as one figure row.
i_sel = np.array([[165, 58, 473]])
vrnts = ['orig', 'tuned', 'plain']

# Re-keying the data so that the three plotted variants sit side by side. 
# The 'orig' variant is taken from the 'tuned' entries.
pats_rnm = {
    'tuned/{x_node}/orig': '{x_node}/orig',
    'tuned/{x_node}/rcnst': '{x_node}/tuned',
    'plain/{x_node}/rcnst': '{x_node}/plain'}
v_ydata2 = get_subdictrnmd(v_ydatas, pats_rnm)

figs_dict = dict()
i_seed = 0
n_figrows, n_figcols = i_sel.shape[1], 4
for i_figpage, i_samps in enumerate(i_sel):
    v_mpldatas = dict()
    for i_figrow, i_samp in enumerate(i_samps):
        # Picking a single (seed, sample) item; the index 0 is applied to the 
        # third axis (presumably a singleton axis -- TODO confirm).
        v_ydata3 = {key: val[i_seed, i_samp, 0] for key, val in v_ydata2.items()}
        # The speciated mass data
        for i_figcol, vrnt in enumerate(vrnts):
            ax_idx = i_figrow * n_figcols + i_figcol
            for i_chem, chem in enumerate(chem_species):
                v_mpldatas[f'paper.ppsim.m_chmprthst.{vrnt}/{ax_idx}:{chem}:x'] = d_histbinsum
                v_mpldatas[f'paper.ppsim.m_chmprthst.{vrnt}/{ax_idx}:{chem}:y'] = v_ydata3[f'm_chmprthst/{vrnt}'][i_chem]
        
        # The particle count data        
        i_figcol = 3
        ax_idx = i_figrow * n_figcols + i_figcol
        for vrnt in vrnts:
            v_mpldatas[f'paper.ppsim.n_prthst/{ax_idx}:{vrnt}:x'] = d_histbinsum
            v_mpldatas[f'paper.ppsim.n_prthst/{ax_idx}:{vrnt}:y'] = np.squeeze(v_ydata3[f'n_prthst/{vrnt}'])
    
    # Grouping the plot data by the plot config name (the part before the '/').
    v_mpldatashie = hie2deep(v_mpldatas, maxdepth=1)

    ########## Calling Matplotlib to Plot the Diagnostics ########## 
    fig, axes = None, None
    for ax_idx, ax_id in enumerate(v_mpldatashie):
        fig, axes = plot_mpl(data=v_mpldatashie[ax_id], fig=fig, 
            axes=axes, mplopts=v_mplcfgs[ax_id])
    
    # Sharing the y-axis limits between the mass distribution axes.
    # NOTE(review): the original comment said "three axes", yet the `:2` slice 
    # only covers the first two columns ('orig' and 'tuned'). Confirm whether 
    # the 'plain' column (index 2) is excluded on purpose.
    for i_figrow in range(n_figrows):
        axes_yshrd = axes[i_figrow, :2]
        ax_ylimlo = min(ax.get_ylim()[0] for ax in axes_yshrd)
        ax_ylimhi = max(ax.get_ylim()[1] for ax in axes_yshrd)
        for ax in axes_yshrd:
            ax.set_ylim(ax_ylimlo, ax_ylimhi)

    # Adding the top left tag for axis identification, e.g., "(a_1)".
    with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
        for i_figrow in range(n_figrows):
            for i_figcol, tag_char in enumerate('abcd'):
                ax = axes[i_figrow, i_figcol]
                tag_axis(ax, f'({tag_char}$_{{{i_figrow+1}}}$)', 
                    fontsize=11, pad=(0.3, 0.6))

        # The column headers are only printed above the first row.
        for i_figcol, col_ttl in enumerate(['Input\n(Mass Distribution)', 
            'Simulated Reconst\n$\\textbf{with}$ Preprocessing\n(Mass Distribution)', 
            'Simulated Reconst\n$\\textbf{without}$ Preprocessing\n(Mass Distribution)', 
            'Number Distribution\nReconst Simulations']):
            print_axheader(axes[0, i_figcol], col_ttl, 'top', pad=4,
                fontsize=13.5, fontweight='bold')

        # The row headers are only printed to the left of the first column.
        for i_figrow in range(n_figrows):
            row_ttl = f'Sample {i_figrow + 1}'
            print_axheader(axes[i_figrow, 0], row_ttl, 'left', 
                fontsize=14, fontweight='bold')

    figs_dict[f'{i_figpage}/diag'] = fig

# Writing one PDF page per collected figure.
pdfpath = f'{workdir}/{i_fig:02d}_ppsim_anec.pdf'
with PdfPages(pdfpath) as pdf:
    for ii_samp, fig3 in figs_dict.items():
        pdf.savefig(figure=fig3, bbox_inches='tight')
print(f'Finished writing {pdfpath}')

i_fig += 1

fig

## The Tuned vs. Plain Q-Q Plots

In [None]:
i_fig = 88
plt.ioff()

# Reading the plot configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.load(fp, ruyaml.RoundTripLoader)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)

# The probability levels at which the empirical and the theoretical 
# quantiles are compared.
n_q = 3999
q = np.linspace(0.0015, 0.9985, n_q)
assert q.shape == (n_q,)

# The standard normal quantiles (i.e., the x-values of every Q-Q panel).
nrml_qnts = norm.ppf(q)
assert nrml_qnts.shape == (n_q,)

# The 'axes.autolimit_mode' entry is removed from a copy of the rc settings, 
# so the loaded plot config itself is left untouched.
rc_context = v_mplcfgs['rc_context'].copy()
rc_context.pop('axes.autolimit_mode')
with plt.rc_context(rc_context) as pltrcctx:
    n_figrows, n_figcols = 2, 3
    fig, axes = plt.subplots(n_figrows, n_figcols, figsize=[2.6 * n_figcols, 2.4 * n_figrows], 
        dpi=140, sharex=True, sharey=False, squeeze=False)
    axes = np.array(axes)

    # Each tuple is (data column, short name, panel description, y-axis lower 
    # limit, y-axis upper limit, axis tag). The all-None entry marks the grid 
    # slot that is left empty.
    # NOTE(review): the bottom row is tagged (d), (c), (e) from left to right; 
    # confirm that this ordering is intended.
    for ax_idx, (y_col, y_name, y_desc, y_low, y_high, tag_txt) in enumerate([
        ('m_chmprthst', '$m$', 'Speciated Mass\nDistribution ($m$)', -1, 3, '(a)'), 
        ('n_prthst', '$n$', 'Number\nDistribution ($n$)', -1, 4, '(b)'), 
        (None, None, None, None, None, None), 
        ('m_prthstmag', '$u_1$', 'Speciated Mass\nMagnitudes ($u_1$)', -4, 6, '(d)'), 
        ('m_chmprtnrm', '$u_2$', 'Speciated Mass\nProportions ($u_2$)', -5, 5, '(c)'), 
        ('n_prthstnrm', '$u_3$', 'Normalized Number\nDistribution ($u_3$)', -4, 6, '(e)')]):
        i_figrow, i_figcol = ax_idx // n_figcols, ax_idx % n_figcols
        ax = axes[i_figrow, i_figcol]

        if y_col is None:
            ax.remove()
            continue

        y_nparr = sim_data[f'var/eval/raw/tuned/{y_col}']
        n_ychnls, n_ylen = y_nparr.shape[-2:]
        assert y_nparr.shape == (n_seeds, n_snrtsim, n_ychnls, n_ylen)

        # Standardizing with a single global mean and standard deviation, 
        # i.e., all the seeds, samples, channels and bins are pooled together.
        y_nparr2 = (y_nparr - y_nparr.mean()) / y_nparr.std()
        assert y_nparr2.shape == (n_seeds, n_snrtsim, n_ychnls, n_ylen)

        y_qnts = np.quantile(y_nparr2.ravel(), q)
        assert y_qnts.shape == (n_q,)

        ax.scatter(nrml_qnts, y_qnts, marker='o', s=3, color='blue')
        y_lim1 = ax.get_ylim()

        # The reference line: the normal quantiles rescaled with the standard 
        # deviation and the mean of the plotted empirical quantiles.
        y_line = nrml_qnts * y_qnts.std() + y_qnts.mean()
        ax.plot(nrml_qnts, y_line, ls='--', color='black', lw=1)

        # Adding the small identification text box
        props = dict(facecolor='none', edgecolor='none')
        ax.text(0.03, 0.97, y_desc, transform=ax.transAxes, fontsize=9,
            usetex=True, verticalalignment='top', bbox=props)

        # A raw string is needed here: `\m` is not a valid escape sequence in 
        # a regular string literal (a SyntaxWarning on Python 3.12+).
        if i_figrow == (n_figrows - 1):
            ax.set_xlabel(r'$\mathcal{N}(0,1)$ Quantiles')
        ax.set_ylabel(f'Standardized {y_name} Quantiles')

        ax.set_xlim(-3, 3)
        ax.set_ylim(y_low, y_high)
        ax.set_xticks([-3, -2, -1, 0, 1, 2, 3])

        tag_axis(ax, tag_txt, fontsize=11, pad=(0.3, 0.6))

    for ax, row_ttl in [(axes[0, 0], 'Without Preprocessing'), (axes[1, 0], 'With Preprocessing')]:
        print_axheader(ax, row_ttl, 'left', fontsize=14, pad=25, fontweight='bold')
            
    fig.subplots_adjust(wspace=0.35, hspace=0.3)

pdfpath = f'{workdir}/{i_fig:02d}_ppsim_qqplot.pdf'
fig.savefig(pdfpath, bbox_inches='tight')
print(f'Finished writing {pdfpath}')

fig

# Experiment 8: Testing the New Pre-Processing Framework

Related Configs: 

  * `configs/02_adhoc/11_mlphist.yml`

  * `configs/02_adhoc/12_cnnhist.yml`

Goals: 

  * Following up on Experiment 7 with both MLP and CNN architectures.

Issues:

  * TBD

In [None]:
def get_exp8dashdata(hpdf, statdf, tabprfx):
    """Builds the aggregated Experiment-8 dashboard data, one entry per tab.

    The ablation groups in `hpdf['fpidxgrp']` are compared against the first 
    group (the "main" one) to find which hp column each of them varies. The 
    groups are then distributed into four tabs (general, pp-magnitude, 
    pp-normalization, and pp-count), and each tab is aggregated over the 
    seeds with `get_aggdf`.

    Parameters
    ----------
    hpdf: (pd.DataFrame) the hyper-parameters table; it must have an 
        `fpidxgrp` column and be row-aligned with `statdf`.

    statdf: (pd.DataFrame) the statistics table.

    tabprfx: (str) the tab title prefix; it is lower-cased before use.

    Returns
    -------
    data: (list) a list of `(tab_title, hpdf_agg, stdf_agg, stcols)` tuples.
    """
    # Shortening the stat column names by removing the singular levels
    stdf = adjust_mtrcnames(statdf)
    long_names = stdf.columns.tolist()
    short_names = squeeze_colnames(long_names, mindepth=2)
    stdf = stdf.rename(columns=dict(zip(long_names, short_names)))

    # Leaving the quantile-like columns out to save on space, and then 
    # downcasting the numerical types of the remaining ones.
    excl_pats = ('q10', 'q25', 'q5', 'q75', 'q90', 'q95', 'median', '/kl:')
    kept_cols = []
    for col in stdf.columns:
        if any(pat in col for pat in excl_pats):
            continue
        kept_cols.append(col)
    stdf = downcast_df(stdf[kept_cols])

    # One representative hp row (the first occurrence) per fpidxgrp
    first_rows = hpdf['fpidxgrp'].drop_duplicates().index
    hpdf_unq = hpdf.loc[first_rows].reset_index(drop=True)
    grp_names = hpdf_unq['fpidxgrp']
    main_fpidx = grp_names.iloc[0]

    # `fpicols` maps each non-main fpidxgrp to the hp column of its ovat 
    # ablation; 'nn/pp/magdim' takes precedence whenever it is a candidate.
    fpicols = dict()
    for fpidx in grp_names.iloc[1:]:
        pair_df = hpdf_unq[grp_names.isin([fpidx, main_fpidx])]
        diff_df = drop_unqcols(pair_df).drop(
            columns=['fpidx', 'fpidxgrp'], errors='ignore')
        diff_cols = diff_df.columns.tolist()
        if 'nn/pp/magdim' in diff_cols:
            fpicols[fpidx] = 'nn/pp/magdim'
        else:
            fpicols[fpidx] = diff_cols[0]

    aggcfg = dict(type='bootstrap', n_boot=40, q=[5, 95], stat='mean', device='cpu')

    # The (tab suffix, ablation column filter) specification of each tab
    ovat_magcols = ['nn/pp/magdim', 'nn/pp/magpnrm', 'nn/pp/magexp', 
        'nn/pp/magshft', 'nn/pp/magscl', 'nn/pp/magscleps']
    ovat_nrmcols = ['nn/pp/nrmeps', 'nn/pp/nrmexp', 'nn/pp/nrmshft', 
        'nn/pp/nrmscl', 'nn/pp/nrmscleps', 'nn/pp/nrmscldim']
    ovat_cntcols = ['nn/pp/cnteps', 'nn/pp/cntexp', 'nn/pp/cntshft', 
        'nn/pp/cntscl', 'nn/pp/cntscleps', 'nn/pp/cntscldim']
    tab_specs = [
        ('gnrl', lambda ovatcol: 'nn/pp' not in ovatcol),
        ('ppmag', lambda ovatcol: ovatcol in ovat_magcols),
        ('ppnrm', lambda ovatcol: ovatcol in ovat_nrmcols),
        ('ppcnt', lambda ovatcol: ovatcol in ovat_cntcols)]

    # The main group is the leading member of every tab.
    prfx = tabprfx.lower()
    tabs = []
    for sfx, is_member in tab_specs:
        members = [fpidx for fpidx, ovatcol in fpicols.items() if is_member(ovatcol)]
        tabs.append((f'{prfx}{sfx}', [main_fpidx] + members))

    # Every ablation group must have ended up in some tab.
    plotted = set()
    for _, tab_fpidxs in tabs:
        plotted.update(tab_fpidxs)
    unused = set(fpicols) - plotted
    assert len(unused) == 0, f'unused ablations: {unused}'

    # Aggregating the data of each tab
    data = []
    for tab_ttl, tab_fpidxs in tabs:
        tab_idx = hpdf['fpidxgrp'].isin(tab_fpidxs)
        agg_data = get_aggdf(
            hpdf.loc[tab_idx, :].reset_index(drop=True), 
            stdf.loc[tab_idx, :].reset_index(drop=True), 
            xcol='epoch', huecol='fpidxgrp', rngcol='rng_seed', aggcfg=aggcfg)
        data.append((tab_ttl, agg_data['hpdf'], agg_data['stdf'], agg_data['stcols']))

    return data

def rnm_cols(colnames: list, renamecfg: dict):
    """Renames column names in two stages: as a whole, and then part by part.

    Parameters
    ----------
    colnames: (str | list | tuple) a single column name, or a sequence of 
        column names.

    renamecfg: (dict) the renaming config with two optional entries:
        * 'columns': a whole-name -> new-name mapping (applied first).
        * 'colrplc': a part -> new-part mapping, applied to each of the 
          '/'-separated parts of the names from the first stage.
        Setting either entry to `None` disables that stage.

    Returns
    -------
    The renamed name (when `colnames` is a string), or the renamed sequence.
    """
    cfg = dict(renamecfg)
    whole_map = cfg.get('columns', dict())
    part_map = cfg.get('colrplc', dict())

    is_single = isinstance(colnames, str)
    names = [colnames] if is_single else colnames
    assert isinstance(names, (list, tuple))
    assert all(isinstance(x, str) for x in names)

    # Stage 1: whole-name renaming
    if whole_map is None:
        names = names[:]
    else:
        assert isinstance(whole_map, dict)
        names = [whole_map.get(name, name) for name in names]

    # Stage 2: renaming each of the '/'-separated parts
    if part_map is None:
        names = names[:]
    else:
        names = ['/'.join([part_map.get(part, part) for part in name.split('/')])
            for name in names]

    return names[0] if is_single else names

def append_relcols(hpstdf, n_lvl, i_lvl, name_ref, name_prd, name_rel, rel_type):
    """Computes symmetric relative-difference columns between paired 
    "reference" and "prediction" columns.

    The columns are paired through their '/'-separated names: a column with 
    `name_ref` at level `i_lvl` is paired with the column that has `name_prd` 
    at that level and is otherwise identically named. The relative difference 
    is `2 * |ref - prd| / (|ref| + |prd|)`.

    Note that, despite the name, this function does not modify `hpstdf`; the 
    new columns are returned as a separate dataframe.

    Parameters
    ----------
    hpstdf: (pd.DataFrame) the input table. For `rel_type` in 
        ('mean+ci', 'mean'), it must also have the `fpidxgrp`, `noise_amp`, 
        and `rng_seed` columns.

    n_lvl: (int) the number of '/'-separated levels a column name must have 
        to be considered.

    i_lvl: (int) the index of the level holding `name_ref`/`name_prd`. The 
        slicing in `drop_val` only behaves as intended for a non-negative index.

    name_ref: (str) the level value identifying the reference columns.

    name_prd: (str) the level value identifying the prediction columns.

    name_rel: (str) the level value inserted at position `i_lvl` into the 
        names of the output columns.

    rel_type: (str) what the relative difference is computed between:
        * 'samp': the raw rows.
        * 'mean': the per-hp means over the seeds (repeated for every seed).
        * 'mean+ci': sorted bootstrap samples of the per-hp means.

    Returns
    -------
    stdf_rel: (pd.DataFrame) the relative-difference columns, with the same 
        index as the reference sub-frame.

    Raises
    ------
    AssertionError: if no reference column exists, or if the reference and 
        the prediction column sets do not match.
    ValueError: if `rel_type` is not one of the options above.
    """
    cols_ref = [col for col in hpstdf.columns 
        if (col.count('/') == (n_lvl - 1)) and col.split('/')[i_lvl] == name_ref] 
    assert len(cols_ref) > 0, 'perhaps n_lvl or i_lvl is mis-specified since no ref cols exist'
    cols_prd = [col for col in hpstdf.columns 
        if (col.count('/') == (n_lvl - 1)) and col.split('/')[i_lvl] == name_prd] 

    # Removing the `i_lvl` level, so that paired columns share the same name.
    drop_val = lambda col: '/'.join(col.split('/')[:i_lvl] + col.split('/')[i_lvl+1:])
    stdf_ref = hpstdf[cols_ref].rename(columns={col: drop_val(col) for col in cols_ref})
    stdf_prd = hpstdf[cols_prd].rename(columns={col: drop_val(col) for col in cols_prd})

    assert set(stdf_ref.columns) == set(stdf_prd.columns), dedent(f'''
        The column sets between the ref and prd dataframes do not match:
            missing columns: {set(stdf_ref.columns) - set(stdf_prd.columns)}
            extra columns: {set(stdf_prd.columns) - set(stdf_ref.columns)}''')
    # Re-ordering the prediction columns to follow the reference columns.
    stdf_prd = stdf_prd[stdf_ref.columns]

    assert all(stdf_ref.columns == stdf_prd.columns)

    ref_array, prd_array = stdf_ref.values, stdf_prd.values

    if rel_type == 'mean+ci':
        # Making sure the dataframe can be i_seed-factorized.
        df1, df2 = decomp_df(hpstdf[['fpidxgrp', 'noise_amp', 'rng_seed']], 
            [['fpidxgrp', 'noise_amp'], ['rng_seed']], validate=True)
        n_hps, n_seeds = df1.shape[0], df2.shape[0]

        n_cols = ref_array.shape[1]
        assert ref_array.shape == (n_hps * n_seeds, n_cols)
        assert stdf_prd.shape == (n_hps * n_seeds, n_cols)

        # Drawing the bootstrap re-sampling indices with a fixed seed. The 
        # ref and prd indices are drawn independently (ref first, then prd).
        n_boot = 10 * n_seeds
        np_random = np.random.RandomState(seed=12345)
        i_btstrpref = (n_seeds * np_random.rand(n_boot * n_seeds)).astype(int)
        i_btstrpprd = (n_seeds * np_random.rand(n_boot * n_seeds)).astype(int)
        with torch.no_grad():
            # Getting a bunch of bootstrap samples of the reference mean.
            # NOTE(review): this reshape assumes the rows are ordered hp-major 
            # and seed-minor; presumably `decomp_df(..., validate=True)` 
            # guarantees this -- TODO confirm.
            ref_array2 = torch.from_numpy(ref_array).reshape(n_hps, n_seeds, n_cols)
            assert ref_array2.shape == (n_hps, n_seeds, n_cols)

            ref_array3 = ref_array2[:, i_btstrpref, :].reshape(n_hps, n_boot, n_seeds, n_cols)
            assert ref_array3.shape == (n_hps, n_boot, n_seeds, n_cols)

            ref_array4 = ref_array3.mean(axis=-2)
            assert ref_array4.shape == (n_hps, n_boot, n_cols)

            # Sorting the bootstrap means (independently for each column) 
            # and keeping every 10th one, so that `n_seeds` values remain.
            ref_array5 = ref_array4.sort(axis=-2).values
            assert ref_array5.shape == (n_hps, n_boot, n_cols)

            ref_meansamps = ref_array5[:, ::10, :]
            assert ref_meansamps.shape == (n_hps, n_seeds, n_cols)

            # Getting a bunch of bootstrap samples of the prediction mean.
            prd_array2 = torch.from_numpy(prd_array).reshape(n_hps, n_seeds, n_cols)
            assert prd_array2.shape == (n_hps, n_seeds, n_cols)

            prd_array3 = prd_array2[:, i_btstrpprd, :].reshape(n_hps, n_boot, n_seeds, n_cols)
            assert prd_array3.shape == (n_hps, n_boot, n_seeds, n_cols)

            prd_array4 = prd_array3.mean(axis=-2)
            assert prd_array4.shape == (n_hps, n_boot, n_cols)

            prd_array5 = prd_array4.sort(axis=-2).values
            assert prd_array5.shape == (n_hps, n_boot, n_cols)

            prd_meansamps = prd_array5[:, ::10, :]
            assert prd_meansamps.shape == (n_hps, n_seeds, n_cols)

            # Flattening back to the original (rows, columns) layout.
            ref_samps = ref_meansamps.detach().cpu().numpy().reshape(n_hps * n_seeds, n_cols)
            prd_samps = prd_meansamps.detach().cpu().numpy().reshape(n_hps * n_seeds, n_cols)
    elif rel_type == 'mean':
        # Making sure the dataframe can be i_seed-factorized.
        df1, df2 = decomp_df(hpstdf[['fpidxgrp', 'noise_amp', 'rng_seed']], 
            [['fpidxgrp', 'noise_amp'], ['rng_seed']], validate=True)
        n_hps, n_seeds = df1.shape[0], df2.shape[0]

        n_cols = ref_array.shape[1]
        assert ref_array.shape == (n_hps * n_seeds, n_cols)
        assert stdf_prd.shape == (n_hps * n_seeds, n_cols)

        # Getting a bunch of identical samples of the reference mean.
        ref_array2 = ref_array.reshape(n_hps, n_seeds, n_cols)
        assert ref_array2.shape == (n_hps, n_seeds, n_cols)

        # The `[0] * n_seeds` index repeats the kept mean axis `n_seeds` times.
        ref_array3 = ref_array2.mean(axis=1, keepdims=True)[:, [0] * n_seeds, :]
        assert ref_array3.shape == (n_hps, n_seeds, n_cols)

        ref_array4 = ref_array3.reshape(n_hps * n_seeds, n_cols)
        assert ref_array4.shape == (n_hps * n_seeds, n_cols)

        # Getting a bunch of identical samples of the prediction mean.
        prd_array2 = prd_array.reshape(n_hps, n_seeds, n_cols)
        assert prd_array2.shape == (n_hps, n_seeds, n_cols)

        prd_array3 = prd_array2.mean(axis=1, keepdims=True)[:, [0] * n_seeds, :]
        assert prd_array3.shape == (n_hps, n_seeds, n_cols)

        prd_array4 = prd_array3.reshape(n_hps * n_seeds, n_cols)
        assert prd_array4.shape == (n_hps * n_seeds, n_cols)

        ref_samps, prd_samps = ref_array4, prd_array4
    elif rel_type == 'samp':
        # The raw rows are compared as they are.
        ref_samps, prd_samps = ref_array, prd_array
    else:
        raise ValueError(f'undefined rel_type={rel_type}')

    # The symmetric relative difference. No guard exists against a zero 
    # denominator (i.e., when both the ref and prd values are zero).
    rel_array = 2.0 * np.abs(ref_samps - prd_samps) / (np.abs(ref_samps) + np.abs(prd_samps))
    # Inserting `name_rel` at the position the ref/prd level used to occupy.
    cols_rel = ['/'.join(col.split('/')[:i_lvl] + [name_rel] + col.split('/')[i_lvl:])
        for col in stdf_prd.columns.tolist()]
    stdf_rel = pd.DataFrame(rel_array, columns=cols_rel, index=stdf_ref.index)

    return stdf_rel

def rnm_fmt(pat_inp, pat_out, query):
    """Renames `query` from the `pat_inp` format to the `pat_out` format.

    Each `{name}` placeholder in `pat_inp` captures a (possibly empty) piece 
    of `query`, and the captured pieces are then substituted into `pat_out`. 
    Everything outside the placeholders is matched literally.

    A placeholder that appears more than once in `pat_inp` must capture the 
    same text at all of its occurrences. (Such patterns used to raise an 
    `re.error`, due to the group name being defined twice.)

    Parameters
    ----------
    pat_inp: (str) the input pattern, e.g., 'tuned/{x_node}/rcnst'.

    pat_out: (str) the output format string, e.g., '{x_node}/tuned'.

    query: (str) the string to be renamed.

    Returns
    -------
    The renamed string, or `None` if `query` does not match `pat_inp`. The 
    match is only anchored at the start of `query` (i.e., `re.match`); any 
    trailing text after the matched pattern is ignored.
    """
    seen_names = set()

    def _placeholder2group(mtch):
        # The first occurrence defines the named group, and any later 
        # occurrence becomes a back-reference to it.
        name = mtch.group(1)
        if name in seen_names:
            return f'(?P={name})'
        seen_names.add(name)
        return f'(?P<{name}>.*)'

    # `re.escape` turns "{name}" into "\{name\}", which is what the 
    # substitution pattern below looks for.
    pattern = re.sub(r'\\\{(\w+)\\\}', _placeholder2group, re.escape(pat_inp))
    matchres = re.match(pattern, query)
    return None if matchres is None else pat_out.format(**matchres.groupdict())

def get_colrnmngs(df, rnm_patdict):
    """Collect the column renamings implied by a set of renaming templates.

    Each column of `df` is tried against every `input -> output` template 
    of `rnm_patdict` using `rnm_fmt`. Columns matching no template are left 
    out of the result.

    Parameters
    ----------
    df: (pd.DataFrame) the data frame whose columns should be renamed.
    rnm_patdict: (dict) maps input templates to output templates.

    Returns
    -------
    (dict) maps the matched column names to their new names.

    Raises
    ------
    AssertionError: when two templates disagree on the new name of a column.
    """
    renamings = dict()
    for col in df.columns:
        for pat_inp, pat_out in rnm_patdict.items():
            new_name = rnm_fmt(pat_inp, pat_out, col)
            if new_name is None:
                continue
            # The first matching template decides the name; any later 
            # matching template is only allowed to agree with it.
            old_name = renamings.setdefault(col, new_name)
            assert old_name == new_name, dedent(f'''
                    Conflicting renamings for "{col}":
                        Name 1: "{old_name}"
                        Name 2: "{new_name}"''')

    return renamings

In [None]:
# The plotting data is rebuilt only when its cache file is missing.
pltcache_path = f'./{workdir}/32_pltdata.h5'
remake_pltdata = not exists(pltcache_path)

if remake_pltdata:
    # Loading the hyper-parameter and statistics tables of the MLP and CNN 
    # summaries, one architecture at a time.
    smry_tables = dict()
    for arch, smrypath in [('mlp', '../summary/08_mlphist.h5'), 
        ('cnn', '../summary/09_cnnhist.h5')]:
        get_h5du(smrypath, verbose=True, detailed=False)
        data = load_h5data(smrypath)
        smry_tables[arch] = (data['hp'], data['stat'])

    hpdf_mlp, statdf_mlp = smry_tables['mlp']
    hpdf_cnn, statdf_cnn = smry_tables['cnn']

In [None]:
if remake_pltdata:
    # A zero eps_sig* value cannot be shown on a log scale, so the 0.0 
    # category is relabeled as 1e-6 in both architectures.
    eps_cols = ['nn/pp/cntscleps', 'nn/pp/magscleps', 'nn/pp/nrmscleps']
    for hpdf_arch in (hpdf_mlp, hpdf_cnn):
        for col in eps_cols:
            hpdf_arch[col] = hpdf_arch[col].cat.rename_categories({0.0: 1e-6})

    # The following columns confuse the OVAT detection for `nn/pp/magdim`. 
    # They are overwritten with the first row's value wherever the magnitude 
    # dimension is not 'n_bins'.
    prblmtc_cols = ['nn/enc/n_chnlsmag', 'nn/enc/n_lenmag', 
        'nn/dec/n_chnlsmag', 'nn/dec/n_lenmag']
    for hpdf_arch in (hpdf_mlp, hpdf_cnn):
        not_nbins = hpdf_arch['nn/pp/magdim'] != 'n_bins'
        for col in prblmtc_cols:
            hpdf_arch.loc[not_nbins, col] = hpdf_arch[col].iloc[0]

    # Concatenating the MLP and CNN dashboard tabs into a single list
    data_mlp = get_exp8dashdata(hpdf_mlp, statdf_mlp, tabprfx='MLP')
    data_cnn = get_exp8dashdata(hpdf_cnn, statdf_cnn, tabprfx='CNN')
    data_real = data_mlp + data_cnn

## Matplotlib Curves

### Loading the Real Training Data

In [None]:
if remake_pltdata:
    # Reading the ablation dashboard config of the real training data
    ymlpath_real = f'{workdir}/29_vaehist.yml'
    with open(ymlpath_real, 'r') as fp:
        # The module-level `load(fp, Loader)` function was removed in 
        # ruamel.yaml 0.18. The `YAML` class is the round-trip loading API 
        # that is available in both the older and the newer releases.
        dash_cfgreal = ruyaml.YAML(typ='rt').load(fp)

    # The renaming section is taken out of the dashboard config
    rnm_cfgreal = dash_cfgreal.pop('rename')
    # Overriding the summary selection and y columns with the net loss
    dash_cfgreal['data/global']['smry/selcol'] = 'loss/net/mean'
    dash_cfgreal['bokeh/global']['ycol'] = 'loss/net'

    # Only the 'data' entry of the returned dash data is kept
    dashdata_real = get_dashdata(data_real, dash_cfgreal, write_yml=False)['data']

### Loading the Simulated Data

In [None]:
if remake_pltdata:
    # Loading the cached hyper-parameter and statistics tables of the simulations
    cachepath = './18_ccnmieinp/04_sim.h5'
    save_data = load_h5data(cachepath)
    hpdf, statdf = save_data['hpdf'], save_data['statdf']

    # A zero eps_sig* value cannot be shown on a log scale; it is replaced with 1e-6
    for eps_col in ('cntscleps', 'magscleps', 'nrmscleps'):
        hpdf.loc[hpdf[eps_col] == 0.0, eps_col] = 1e-6

    # Each unique combination of all the hp columns gets its own group id
    hpdf['fpidxgrp'] = hpdf.groupby(hpdf.columns.tolist()).ngroup()
    hpdf['fpidx'] = hpdf['fpidxgrp']
    hpdf = hpdf.astype('category')

    # Appending the relative difference columns
    hpstdf = pd.concat([hpdf, statdf], axis=1)
    rel_specs = [
        dict(i_lvl=2, name_ref='train', name_prd='test', name_rel='trn2tst'),
        dict(i_lvl=2, name_ref='train', name_prd='comb', name_rel='trn2cmb'),
        dict(i_lvl=3, name_ref='data', name_prd='normal', name_rel='data2nrml')]
    statdf1, statdf2, statdf3 = [
        append_relcols(hpstdf, n_lvl=4, rel_type='mean+ci', **rel_spec)
        for rel_spec in rel_specs]
    statdf = pd.concat([statdf, statdf1, statdf2, statdf3], axis=1)

In [None]:
if remake_pltdata:
    # Bootstrap aggregation settings handed to `get_aggdf`
    aggcfg = dict(type='bootstrap', n_boot=40, q=[5, 95], stat='mean', device='cpu')
    # Aggregating the data
    agg_data = get_aggdf(hpdf, statdf, xcol='noise_amp', 
        huecol='fpidxgrp', rngcol='rng_seed', aggcfg=aggcfg)
    hpdf_tabagg, stdf_tabagg = agg_data['hpdf'], agg_data['stdf']
    stcols_tab = agg_data['stcols']
    # A single-tab list with the (title, hpdf, stdf, stcols) tuple layout 
    # used for the other `get_dashdata` inputs in this notebook.
    data_sim = [('simpp', hpdf_tabagg, stdf_tabagg, stcols_tab)]

In [None]:
if remake_pltdata:
    # Reading the dashboard config of the simulated data
    ymlpath_sim = './18_ccnmieinp/05_simdash.yml'
    with open(ymlpath_sim, 'r') as fp:
        # The module-level `load(fp, Loader)` function was removed in 
        # ruamel.yaml 0.18. The `YAML` class is the round-trip loading API 
        # that is available in both the older and the newer releases.
        dash_cfgsim = ruyaml.YAML(typ='rt').load(fp)
    # The renaming section is taken out of the dashboard config
    rnm_cfgsim = dash_cfgsim.pop('rename')
    # Only the 'data' entry of the returned dash data is kept
    dashdata_sim = get_dashdata(data_sim, dash_cfgsim, write_yml=False)['data']

### Getting the Real Training Data

In [None]:
# The ablated pre-processing hyper-parameters. Each name serves both as an 
# ablation id and as the x-axis column of the corresponding plots.
allxcols = [
    'magdim', 'magpnrm', 'magexp', 'magshft', 'magscl', 'magscleps', 
    'nrmeps', 'nrmexp', 'nrmshft', 'nrmscl', 'nrmscleps', 'nrmscldim', 
    'cnteps', 'cntexp', 'cntshft', 'cntscl', 'cntscleps', 'cntscldim']
    
if remake_pltdata:
    # Maps each ablation to its renamed real-data frame (MLP and CNN rows stacked)
    aggdf_ablsreal = dict()
    for xcol in allxcols:        
        # Column renaming templates; only the columns matching one of 
        # these templates survive in the final data frame.
        col_patsreal = {
            'perf/aerosol/test/{ymtrc}/mean/{stat}': '{ymtrc}/{stat}',
            'nn/pp/{varbl}': '{varbl}', 'model': 'model', 
            'noise_amp': 'noise_amp', xcol: xcol}

        aggdfs_lst = []
        for tababl_id, tababl_data in dashdata_real.items():
            tab_id, abl_id = tababl_id.split('/')
            # Only the tabs holding the current ablation are of interest
            if (abl_id != xcol):
                continue
            # Example:
            #   tababl_id = 'mlpppmag/magexp'
            #   tab_id, abl_id = 'mlpppmag', 'magexp'
            #   tabname = 'MLP General'
            #   ablname = 'Mag Exponent'
            #   ablhpcols = ['nn/pp/magexp']
            tabname, ablname, abldf, ablhpcols, stcols, bkfigcfg, mpldict = tababl_data
            abldf2 = abldf.copy(deep=True)
            # The first three characters of the tab id name the model (e.g., 'mlp')
            abldf2.insert(2, 'model', tab_id[:3])
            # The real training rows are assigned a zero noise amplitude
            abldf2.insert(3, 'noise_amp', 0.0)
            abldf2 = abldf2.drop(columns=['fpidx', 'fpidxgrp', 'epoch'])
            aggdfs_lst.append(abldf2)
        
        assert len(aggdfs_lst) == 2, dedent(f'''
            The "mlp" and "cnn" tab data must be present for the 
            "{xcol}" ablation. However, "{len(aggdfs_lst)}" tab 
            data were available.''')

        aggdf1_real = pd.concat(aggdfs_lst, axis=0, ignore_index=True)
        # Removing categorical data types
        for col, col_dtype in dict(aggdf1_real.dtypes).items():
            if col_dtype == 'category':
                aggdf1_real[col] = aggdf1_real[col].tolist()

        # Keeping and renaming only the columns matched by the templates
        cols_rnmsreal = get_colrnmngs(aggdf1_real, col_patsreal)
        aggdf_real = aggdf1_real[list(cols_rnmsreal)].rename(columns=cols_rnmsreal)
        aggdf_ablsreal[xcol] = aggdf_real

### Getting the Simulated Data

In [None]:
if remake_pltdata:
    # Maps each ablation to its renamed simulated-data frame
    aggdf_ablssim = dict()
    for xcol in allxcols:
        # Column renaming templates; only the columns matching one of 
        # these templates survive in the final data frame.
        col_patssim = {
            '{ymtrc}/one/test/data/{stat}': '{ymtrc}/{stat}',
            'model': 'model', 'noise_amp': 'noise_amp', xcol: xcol}

        # NOTE(review): index 2 is presumably the ablation data frame, 
        # mirroring the tuple layout unpacked for the real data -- confirm.
        aggdf_sim1 = dashdata_sim[f'simpp/{xcol}'][2].copy(deep=True)
        # Tagging the simulated rows so they can be told apart from the real models
        aggdf_sim1.insert(0, 'model', 'sim')
        aggdf_sim1 = aggdf_sim1.drop(columns=['fpidx', 'fpidxgrp'])

        # Removing categorical data types
        for col, col_dtype in dict(aggdf_sim1.dtypes).items():
            if col_dtype == 'category':
                aggdf_sim1[col] = aggdf_sim1[col].tolist()

        # Keeping and renaming only the columns matched by the templates
        cols_rnmssim = get_colrnmngs(aggdf_sim1, col_patssim)
        aggdf_sim = aggdf_sim1[list(cols_rnmssim)].rename(columns=cols_rnmssim)

        aggdf_ablssim[xcol] = aggdf_sim

### Combining and Plotting the Data 

In [None]:
if remake_pltdata:
    # Stacking the real and simulated rows of each ablation
    aggdf_abls = dict()
    for xcol in allxcols:
        aggdf_real = aggdf_ablsreal[xcol]
        aggdf_sim = aggdf_ablssim[xcol]

        # Both frames must carry the exact same set of columns before stacking
        assert set(aggdf_real.columns) == set(aggdf_sim.columns), dedent(f'''
            Extra columns: {set(aggdf_sim.columns) - set(aggdf_real.columns)}
            Missing columns: {set(aggdf_real.columns) - set(aggdf_sim.columns)}''')

        aggdf = pd.concat([aggdf_real, aggdf_sim], axis=0, ignore_index=True)
        aggdf_abls[xcol] = aggdf
    
    # Caching the plotting data, so that the next run can skip the rebuild
    save_h5data(aggdf_abls, pltcache_path, driver='core')
else:
    # The cache exists; the plotting data is read back from it
    aggdf_abls = load_h5data(pltcache_path)

### Matplotlib Ablation Booklets

In [None]:
plt.ioff()
# The y metrics to plot; each gets one panel in every ablation figure
ycols = ['ccn_err', 'opt_err', 'loginp_err', 'm_chmrelerr', 'n_relerr']

# Maps each ablation to its figure
abl2fig = dict()
for xcol in allxcols:
    aggdf = aggdf_abls[xcol]

    n_ycols = len(ycols)
    fig, axes = plt.subplots(1, n_ycols, figsize=[2.6 * n_ycols, 3.0], 
      dpi=100, sharex=True, sharey=False)

    # The config is re-read for every figure, so each figure starts from 
    # freshly parsed config objects.
    # NOTE(review): the read could be hoisted out of the loop if 
    # `draw_matplotlib` never mutates its config -- confirm.
    with open(f'./{workdir}/31_ppmpl.yml', 'r') as fp:
        # The module-level `load(fp, Loader)` function was removed in 
        # ruamel.yaml 0.18. The `YAML` class is the round-trip loading API 
        # that is available in both the older and the newer releases.
        mpl_cfgdict1 = ruyaml.YAML(typ='rt').load(fp)
    mpl_cfgdict2 = parse_refs(mpl_cfgdict1, trnsfrmtn='hie', pathsep=' -> ')
    mpl_cfgdict3 = hie2deep(mpl_cfgdict2)['huegrpd']
    mplglbls = mpl_cfgdict3['mplglbls']
    yspecs = mpl_cfgdict3['yspecs']
    xspecs = mpl_cfgdict3['xspecs']
    valrplcs = mpl_cfgdict3['valrplcs']

    aggdf = aggdf.replace(valrplcs)
    for ycol, ax in zip(ycols, axes):
        plt_cfg = dict(xcol=xcol, ycol=ycol, fig=fig, ax=ax)
        yspec = yspecs[ycol] if ycol in yspecs else dict()
        xspec = xspecs[xcol] if xcol in xspecs else dict()
        # Later entries win: the globals are overridden by the panel 
        # settings, then by the y specs, and finally by the x specs.
        plt_cfg = {**mplglbls, **plt_cfg, **yspec, **xspec}
        fig, ax = draw_matplotlib(plt_cfg, aggdf)

    fig.subplots_adjust(wspace=0.35)
    # The figure legend is built from the entries of the last panel
    leg_handles, leg_labels = ax.get_legend_handles_labels()
    fig.legend(leg_handles, leg_labels, loc='lower center', bbox_to_anchor=[0.5, -0.17], ncol=7)
    abl2fig[xcol] = fig

fig

In [None]:
# Writing the ablation booklet; every ablation figure is saved as a 
# page of the same PDF file.
with PdfPages(f'./{workdir}/33_ppabls.pdf') as pdf:
    for xcol, fig in abl2fig.items():
        pdf.savefig(figure=fig, bbox_inches='tight')

### Matplotlib Trendy Ablations

In [None]:
plt.ioff()
# The hand-picked (x column, y metric) pairs; one panel is drawn per pair
xy_combos = [
    ('magexp', 'm_chmrelerr'), ('magdim', 'ccn_err'), ('nrmexp', 'loginp_err'), ('cntshft', 'n_relerr'),   ('nrmexp', 'ccn_err'), 
    ('magdim', 'm_perchmrelerr'), ('magexp', 'opt_err'), ('magscl', 'ccn_err'), ('magexp', 'm_perchmrelerr'), ('magpnrm', 'opt_err'),   
    ('magscleps', 'ccn_err'),  ('cntscl', 'loginp_err'), ('cntexp', 'ccn_err'), ('cntscldim', 'm_chmrelerr'),('magscleps', 'n_relerr')]

n_figrows, n_figcols = 3, 5
fig, axes = plt.subplots(n_figrows, n_figcols, 
    figsize=[2.6 * n_figcols, 3.0 * n_figrows], 
    dpi=100, sharex=False, sharey=False)
axes = np.array(axes).ravel()

for i_ax, (xcol, ycol) in enumerate(xy_combos):
    aggdf = aggdf_abls[xcol]
    ax = axes[i_ax]

    # The config is re-read for every panel, so each panel starts from 
    # freshly parsed config objects.
    # NOTE(review): the read could be hoisted out of the loop if 
    # `draw_matplotlib` never mutates its config -- confirm.
    with open(f'./{workdir}/31_ppmpl.yml', 'r') as fp:
        # The module-level `load(fp, Loader)` function was removed in 
        # ruamel.yaml 0.18. The `YAML` class is the round-trip loading API 
        # that is available in both the older and the newer releases.
        mpl_cfgdict1 = ruyaml.YAML(typ='rt').load(fp)
    mpl_cfgdict2 = parse_refs(mpl_cfgdict1, trnsfrmtn='hie', pathsep=' -> ')
    mpl_cfgdict3 = hie2deep(mpl_cfgdict2)['main']
    mplglbls = mpl_cfgdict3['mplglbls']
    yspecs = mpl_cfgdict3['yspecs']
    xspecs = mpl_cfgdict3['xspecs']
    valrplcs = mpl_cfgdict3['valrplcs']

    aggdf = aggdf.replace(valrplcs)
    plt_cfg = dict(xcol=xcol, ycol=ycol, fig=fig, ax=ax)
    yspec = yspecs[ycol] if ycol in yspecs else dict()
    xspec = xspecs[xcol] if xcol in xspecs else dict()
    # Later entries win: the globals are overridden by the panel 
    # settings, then by the y specs, and finally by the x specs.
    plt_cfg = {**mplglbls, **plt_cfg, **yspec, **xspec}
    fig, ax = draw_matplotlib(plt_cfg, aggdf)

fig.subplots_adjust(wspace=0.35, hspace=0.25)
# The figure legend is built from the entries of the last panel
leg_handles, leg_labels = ax.get_legend_handles_labels()
fig.legend(leg_handles, leg_labels, loc='lower center', bbox_to_anchor=[0.5, 0.01], ncol=7)

fig.savefig(f'./{workdir}/36_ppabls.pdf', bbox_inches='tight')
fig

In [None]:
plt.ioff()
# The hand-picked (x column, y metric) pairs; one panel is drawn per pair
xy_combos = [
    ('magdim', 'ccn_err'),        ('magexp', 'm_chmrelerr'), ('cntshft', 'n_relerr'),   ('nrmexp', 'ccn_err'),        ('magpnrm', 'opt_err'),   
    ('magdim', 'm_perchmrelerr'), ('magexp', 'opt_err'),     ('magscl', 'ccn_err'),     ('magexp', 'm_perchmrelerr'), ('cntscldim', 'm_chmrelerr')]

# Serif fonts and LaTeX text rendering for the finalized figure
rc_cntxt = {
    'figure.max_open_warning': 0, 'font.family': 'serif', 'text.usetex': True,
    'text.latex.preamble': r'\usepackage{amsmath} \usepackage{amssymb}'}

with plt.style.context('default') as psc, plt.rc_context(rc_cntxt) as pltrcctx:
    n_figrows, n_figcols = 2, 5
    fig, axes = plt.subplots(n_figrows, n_figcols, 
        figsize=[1.8 * n_figcols, 2.8 * n_figrows], 
        dpi=100, sharex=False, sharey=True)
    axes = np.array(axes).ravel()

    for i_ax, (xcol, ycol) in enumerate(xy_combos):
        aggdf = aggdf_abls[xcol].copy(deep=True)
        ax = axes[i_ax]

        # The config is re-read for every panel, so each panel starts from 
        # freshly parsed config objects.
        # NOTE(review): the read could be hoisted out of the loop if 
        # `draw_matplotlib` never mutates its config -- confirm.
        with open(f'./{workdir}/31_ppmpl.yml', 'r') as fp:
            # The module-level `load(fp, Loader)` function was removed in 
            # ruamel.yaml 0.18. The `YAML` class is the round-trip loading 
            # API that is available in both the older and newer releases.
            mpl_cfgdict1 = ruyaml.YAML(typ='rt').load(fp)
        mpl_cfgdict2 = parse_refs(mpl_cfgdict1, trnsfrmtn='hie', pathsep=' -> ')
        mpl_cfgdict3 = hie2deep(mpl_cfgdict2)['main']
        mplglbls = mpl_cfgdict3['mplglbls2']
        yspecs = mpl_cfgdict3['yspecs2']
        xspecs = mpl_cfgdict3['xspecs2']
        valrplcs = mpl_cfgdict3['valrplcs2']

        aggdf = aggdf.replace(valrplcs)
        # Halving the listed error metrics along with their interval bounds.
        # NOTE(review): the reason for the factor of two is not visible in 
        # this cell; presumably it undoes the 2x in the symmetric relative 
        # error definition -- confirm.
        for ycol2, stat in product(
            ['ccn_err', 'm_chmrelerr', 'm_perchmrelerr', 'm_relerr', 
            'n_relerr', 'opt_err', 'qa_relerr', 'qs_relerr'],
            ['mean', 'low', 'high']):
            aggdf[f'{ycol2}/{stat}'] = aggdf[f'{ycol2}/{stat}'] / 2

        plt_cfg = dict(xcol=xcol, ycol=ycol, fig=fig, ax=ax)
        yspec = yspecs[ycol] if ycol in yspecs else dict()
        xspec = xspecs[xcol] if xcol in xspecs else dict()
        # Later entries win: the globals are overridden by the panel 
        # settings, then by the y specs, and finally by the x specs.
        plt_cfg = {**mplglbls, **plt_cfg, **yspec, **xspec}
        fig, ax = draw_matplotlib(plt_cfg, aggdf)

    # Tagging the panels as (a), (b), ..., (j)
    for i_ax, ax in enumerate(axes):
        tag_text = f'({"abcdefghij"[i_ax]})'
        tag_axis(ax, tag_text, fontsize=11)

    fig.subplots_adjust(wspace=0.25, hspace=0.35)
    # The figure legend is built from the entries of the last panel
    leg_handles, leg_labels = ax.get_legend_handles_labels()
    fig.legend(leg_handles, leg_labels, loc='lower center', 
        bbox_to_anchor=[0.5, -0.035], edgecolor='black', ncol=7)

fig.savefig(f'./{workdir}/36_ppablsfnlzd.pdf', bbox_inches='tight')
fig

## Bokeh Dashboards

In [None]:
# Loading the MLP summary; its disk usage is reported before the tables are read
smrypath = '../summary/08_mlphist.h5'
get_h5du(smrypath, verbose=True, detailed=False)
data = load_h5data(smrypath)
hpdf_mlp, statdf_mlp = data['hp'], data['stat']

# Loading the CNN summary in the same manner
smrypath = '../summary/09_cnnhist.h5'
get_h5du(smrypath, verbose=True, detailed=False)
data = load_h5data(smrypath)
hpdf_cnn, statdf_cnn = data['hp'], data['stat']

### Training Curves

In [None]:
# Preparing the dash data of the training curves from their config file.
# NOTE(review): on a top-to-bottom run, `data` is the raw dict loaded from 
# '../summary/09_cnnhist.h5' in the preceding cell, whereas the other 
# `get_dashdata` calls of this notebook are given a list of 
# (title, hpdf, stdf, stcols) tuples -- confirm this input is intended.
ymlpath = f'{workdir}/27_vaehist.yml'
dashdata = get_dashdata(data, ymlpath, write_yml=False)

In [None]:
# Building the dashboard layout and saving it as an HTML file with bokeh
fulllayout = build_dashboard(None, **dashdata)
output_file(f'{workdir}/28_vaehist.html')
save(fulllayout, title=dashdata['header'])

### Ablation Curves

In [None]:
# Preparing the dash data of the ablation curves from their config file.
# NOTE(review): on a top-to-bottom run, `data` is still the raw dict loaded 
# from '../summary/09_cnnhist.h5', whereas the other `get_dashdata` calls of 
# this notebook are given a list of (title, hpdf, stdf, stcols) tuples -- 
# confirm this input is intended.
ymlpath = f'{workdir}/29_vaehist.yml'
dashdata = get_dashdata(data, ymlpath, write_yml=False)

In [None]:
# Building the dashboard layout and saving it as an HTML file with bokeh
fulllayout = build_dashboard(None, **dashdata)
output_file(f'{workdir}/30_vaehist.html')
save(fulllayout, title=dashdata['header'])

# Experiment 1: Hyper-Parameter Ablations on the Histogram VAE

Related Configs: 

  * `configs/02_adhoc/01_mlphist.yml`

  * `configs/02_adhoc/02_cnnhist.yml`

Goals: 

  * Quick and dirty OVAT experiment of most hyper-parameters in the MLP and CNN architectures.

Issues:

  * The evaluation metrics' pre-processing is not the same as the training. This was a poor decision.

  * Nothing other than reconstruction is shown here.

In [None]:
# Loading the hyper-parameter and statistics tables of Experiment 1
smrypath = '../summary/01_vaehist.h5'
get_h5du(smrypath, verbose=True, detailed=False)
data = load_h5data(smrypath)
hpdf = data['hp']
statdf = data['stat']

In [None]:
# Bootstrap aggregation settings handed to `get_aggdf`
aggcfg = dict(type='bootstrap', n_boot=40, q=[5, 95], stat='mean', device='cpu')
# One dashboard tab per encoder architecture
arch2ttl = {'mlpenc': 'MLP Architecture', 'cnnenc': 'CNN Architecture'}

data = []
for arch, tab_ttl in arch2ttl.items():
    # Selecting the rows of the current architecture
    tab_idx = (hpdf['nn/enc/type'] == arch)
    hpdf_tab = hpdf.loc[tab_idx, :].reset_index(drop=True)
    stdf_tab = statdf.loc[tab_idx, :].reset_index(drop=True)

    # Aggregating the selected rows over the random seeds
    agg_data = get_aggdf(hpdf_tab, stdf_tab, xcol='epoch', 
        huecol='fpidxgrp', rngcol='rng_seed', aggcfg=aggcfg)
    hpdf_tabagg, stdf_tabagg = agg_data['hpdf'], agg_data['stdf']
    stcols_tab = agg_data['stcols']

    data.append((tab_ttl, hpdf_tabagg, stdf_tabagg, stcols_tab))

## Training Curves

In [None]:
# Preparing the dash data of the training curves from their config file
ymlpath = f'{workdir}/01_vaehist.yml'
dashdata = get_dashdata(data, ymlpath, write_yml=False)

In [None]:
# Building the dashboard layout and saving it as an HTML file with bokeh
fulllayout = build_dashboard(None, **dashdata)
output_file(f'{workdir}/02_vaehist.html')
save(fulllayout, title=dashdata['header'])

## Ablation Curves

In [None]:
# Preparing the dash data of the ablation curves from their config file
ymlpath = f'{workdir}/03_vaehist.yml'
dashdata = get_dashdata(data, ymlpath, write_yml=False)

In [None]:
# Building the dashboard layout and saving it as an HTML file with bokeh
fulllayout = build_dashboard(None, **dashdata)
output_file(f'{workdir}/04_vaehist.html')
save(fulllayout, title=dashdata['header'])

# Experiment 2: KL Weight Study on the Histogram VAE

Related config: `configs/02_adhoc/04_klstudy.yml`

Goals: 

  * An initial experiment for studying the effect of KL mu and sigma weights on the newly implemented performance metrics.
    
  * The performance metric categories were congestion, proximity, realism, and curvature.

Issues:

  * The realism metric was too noisy; the sliced wasserstein was only using 10 slices.

  * In the next run, I decided to do some importance sampling 
        
    * I used the test data to define a set of principal components, and emphasized those components in the slices.

  * The frequency of the evaluation was too high; the metric calculation time was more than quadruple the training time.

  * The config was changed while the runs were being taken; the `eval/underscale` samples exist in some runs and not in others. 
    
    * None of the `eval/underscale` samples are important for any reason.

    * However, this screwed up the ovat group and column detection.

  * One of the runs could not fit 25 seeds into GPU RAM, so it was split into two runs with 24 and 1 seeds each.

In [None]:
# Loading the hyper-parameter and statistics tables of Experiment 2
smrypath = '../summary/02_klstudy.h5'
get_h5du(smrypath, verbose=True, detailed=False)
data = load_h5data(smrypath)
hpdf = data['hp']
statdf = data['stat']

In [None]:
# Keeping only the rows with an RNG seed of at most 2300 in both tables.
# NOTE(review): presumably this drops the split-off single-seed run 
# mentioned in the issues of this experiment -- confirm.
i1 = statdf['rng_seed'] <= 2300
hpdf = hpdf.loc[i1].reset_index(drop=True)
statdf = statdf.loc[i1].reset_index(drop=True)

# Dropping the `eval/underscale*` hyper-parameter columns
hpdf = hpdf[[col for col in hpdf.columns if not col.startswith('eval/underscale')]]

In [None]:
# Squeezing the stat column names by removing singular levels  
stcol_longnames = statdf.columns.tolist()
stcol_sqzdnames = squeeze_colnames(stcol_longnames, mindepth=2)
# Renaming every long column name to its squeezed counterpart
statdf = statdf.rename(columns=dict(zip(stcol_longnames, stcol_sqzdnames)))

In [None]:
# The bootstrap aggregation uses the GPU when one is available, and falls 
# back to the CPU (like the other experiments) instead of failing otherwise.
agg_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
aggcfg = dict(type='bootstrap', n_boot=40, q=[5, 95], stat='mean', device=agg_device)
data = []
# One dashboard tab per latent dimension
for ltnt_dim in (2, 10):
    tab_idx = (hpdf['nn/ltnt/dim'] == ltnt_dim)
    hpdf_tab = hpdf.loc[tab_idx, :].reset_index(drop=True)
    stdf_tab = statdf.loc[tab_idx, :].reset_index(drop=True)
    tab_ttl = f'{ltnt_dim}-Dimensional'

    # Aggregating the data
    agg_data = get_aggdf(hpdf_tab, stdf_tab, xcol='epoch', 
        huecol='fpidxgrp', rngcol='rng_seed', aggcfg=aggcfg)
    hpdf_tabagg, stdf_tabagg = agg_data['hpdf'], agg_data['stdf']
    stcols_tab = agg_data['stcols']
    
    data.append((tab_ttl, hpdf_tabagg, stdf_tabagg, stcols_tab))

## Training Curves

In [None]:
# Preparing the dash data of the training curves from their config file
ymlpath = f'{workdir}/05_klstudy.yml'
dashdata = get_dashdata(data, ymlpath, write_yml=False)

In [None]:
# Building the dashboard layout and saving it as an HTML file with bokeh
fulllayout = build_dashboard(None, **dashdata)
output_file(f'{workdir}/06_klstudy.html')
save(fulllayout, title=dashdata['header'])

## Ablation Curves

In [None]:
# Preparing the dash data of the ablation curves from their config file.
# NOTE(review): '03_vaehist.yml' is the ablation config of Experiment 1; 
# confirm that reusing it for the KL study data is intended.
ymlpath = f'{workdir}/03_vaehist.yml'
dashdata = get_dashdata(data, ymlpath, write_yml=False)

In [None]:
# Building the dashboard layout and saving it as an HTML file with bokeh.
# NOTE(review): '04_vaehist.html' is also the output path of the Experiment 1 
# ablation dashboard, so a full run overwrites that file here -- confirm.
fulllayout = build_dashboard(None, **dashdata)
output_file(f'{workdir}/04_vaehist.html')
save(fulllayout, title=dashdata['header'])

# Experiment 3: KL Weight Study on the Histogram VAE

Related config: `configs/02_adhoc/05_klstudy.yml`

Goals: 

  * The second round of studying the effect of KL mu and sigma weights on the newly implemented performance metrics.
    
  * See Experiment 2 for previous issues.

Issues:

  * The realism performance was peaking at both KL weights of 0.001. 
  
    * This happens to be the central set of values for the OVAT-style experiment.

      * We did an OVAT sweep on the KL mu and sigma weights in each latent dimension.

    * I got suspicious that we hit the jackpot without any tuning.
      
      * This peak performance may have been an artifact of the mu and sigma weights being identical in this particular run.
    
    * To check for this, I performed another round of experiments where the mu and sigma KL weights were identical all the time.

In [None]:
# Loading the hyper-parameter and statistics tables of Experiment 3
smrypath = '../summary/03_klstudy.h5'
get_h5du(smrypath, verbose=True, detailed=False)
data = load_h5data(smrypath)
hpdf = data['hp']
statdf = data['stat']

In [None]:
# Squeezing the stat column names by removing singular levels  
stcol_longnames = statdf.columns.tolist()
stcol_sqzdnames = squeeze_colnames(stcol_longnames, mindepth=2)
# Renaming every long column name to its squeezed counterpart
statdf = statdf.rename(columns=dict(zip(stcol_longnames, stcol_sqzdnames)))

In [None]:
# Bootstrap aggregation settings handed to `get_aggdf`
aggcfg = dict(type='bootstrap', n_boot=40, q=[5, 95], stat='mean', device='cpu')
data = []

# Each tab sweeps one of the KL weights, and only includes the rows where 
# the other KL weight is at its 0.001 base value.
for tab_ttl, inc_spec, sort_cols in [
    ('KL Weight (Mu)', {'cri/kl/sig/w': 0.001}, ['cri/kl/mu/w', 'nn/ltnt/dim']), 
    ('KL Weight (Sigma)', {'cri/kl/mu/w': 0.001}, ['cri/kl/sig/w', 'nn/ltnt/dim'])]:
    
    # Selecting the rows of the current tab
    tab_idx = get_dfidxs(hpdf, inc_spec)
    hpdf_tab = hpdf.loc[tab_idx, :].reset_index(drop=True)
    stdf_tab = statdf.loc[tab_idx, :].reset_index(drop=True)

    # Sorting both tables by the swept weight and then the latent dimension
    tab_idx2 = hpdf_tab.sort_values(sort_cols, kind='stable').index
    hpdf_tab = hpdf_tab.loc[tab_idx2, :].reset_index(drop=True)
    stdf_tab = stdf_tab.loc[tab_idx2, :].reset_index(drop=True)

    # Aggregating the data
    agg_data = get_aggdf(hpdf_tab, stdf_tab, xcol='epoch', 
        huecol='fpidxgrp', rngcol='rng_seed', aggcfg=aggcfg)
    hpdf_tabagg, stdf_tabagg = agg_data['hpdf'], agg_data['stdf']
    stcols_tab = agg_data['stcols']

    data.append((tab_ttl, hpdf_tabagg, stdf_tabagg, stcols_tab))

## Training Curves

In [None]:
# Preparing the dash data of the training curves from their config file
ymlpath = f'{workdir}/07_klstudy.yml'
dashdata = get_dashdata(data, ymlpath, write_yml=False)

In [None]:
# Building the dashboard layout and saving it as an HTML file with bokeh
fulllayout = build_dashboard(None, **dashdata)
output_file(f'{workdir}/08_klstudy.html')
save(fulllayout, title=dashdata['header'])

## Ablation Curves

In [None]:
# Bootstrap aggregation settings handed to `get_aggdf`
aggcfg = dict(type='bootstrap', n_boot=40, q=[5, 95], stat='mean', device='cpu')
# All rows are aggregated together here, with no per-tab selection
hpdf_tab, statdf_tab = hpdf, statdf
# NOTE(review): the latent dimension column is renamed before aggregation; 
# presumably the ablation config refers to it as 'nn/latent/dim' -- confirm.
hpdf_tab = hpdf_tab.rename(columns={'nn/ltnt/dim': 'nn/latent/dim'})
agg_data = get_aggdf(hpdf_tab, statdf_tab, xcol='epoch', 
    huecol='fpidxgrp', rngcol='rng_seed', aggcfg=aggcfg)
data2 = [('All Dimensions', agg_data['hpdf'], agg_data['stdf'], agg_data['stcols'])]

In [None]:
# Preparing the dash data of the ablation curves from their config file
ymlpath = f'{workdir}/09_klstudy.yml'
dashdata = get_dashdata(data2, ymlpath, write_yml=False)

In [None]:
# Building the dashboard layout and saving it as an HTML file with bokeh
fulllayout = build_dashboard(None, **dashdata)
output_file(f'{workdir}/10_klstudy.html')
save(fulllayout, title=dashdata['header'])

See the following plot to get a sense of the TV and Hellinger values.

In [None]:
# The figure and its axes are created under the dark background style
with plt.style.context('dark_background'):
    fig, axes = plt.subplots(2, 3, figsize=(3*3.0, 2*2.4), sharex=True, sharey=True, dpi=100)
    axes = np.array(axes).ravel()

# Comparing a standard normal against unit-variance normals with shifted 
# means. The integrals are approximated with sums over the `x` grid.
df_rowdicts = []
x = np.linspace(-5, 12, 10000)
for ax_idx, loc in enumerate((1, 2, 3, 4, 5, 6)):
    p1 = norm.pdf(x, loc=0)
    p2 = norm.pdf(x, loc=loc)
    # `np.ptp(x)` replaces the `x.ptp()` method, which was removed in NumPy 2.0
    # Total variation: half of the integrated absolute density difference
    tv_trg = 0.5 * np.abs(p1 - p2).sum() * np.ptp(x) / x.size
    # Squared Hellinger: one minus the integrated geometric mean of the densities
    h2_trg = 1 - np.sqrt(p1 * p2).sum() * np.ptp(x) / x.size
    h1_trg = np.sqrt(h2_trg)
    row_dict = {'Delta Mu': loc, 'TV': tv_trg, 
        'Hellinger': h1_trg, 'Hellinger-Squared': h2_trg}
    df_rowdicts.append(row_dict)

    ax = axes[ax_idx]
    ax.plot(x, p1, color='blue', lw=3)
    ax.plot(x, p2, color='red', lw=3)

    # The backslashes of the math text are escaped explicitly, since `\D` 
    # and `\m` are invalid escape sequences in a regular string literal.
    ax.text(0.95, 0.70, f'$\\Delta\\mu={loc}$\nTV={tv_trg:0.2f}\nH1={h1_trg:0.2f}',
        verticalalignment='bottom', horizontalalignment='right',
        transform=ax.transAxes, color='white', fontsize=10)

fig.set_tight_layout(True)

# Displaying the distances as a table
pd.DataFrame(df_rowdicts).round(3)

# Experiment 4: KL Weight Study on the Histogram VAE

Related Configs:

  * `configs/02_adhoc/06_klstudy.yml`

  * `configs/02_adhoc/07_klstudy.yml`

Goals: 

  * The third round of studying the effect of KL mu and sigma weights on the newly implemented performance metrics.
    
  * I used identical KL mu and sigma weights, and studied the unified KL weight effect.

In [None]:
# Loading the hyper-parameter and statistics tables of Experiment 4
smrypath = '../summary/04_klstudy.h5'
get_h5du(smrypath, verbose=True, detailed=False)
data = load_h5data(smrypath)
hpdf = data['hp']
statdf = data['stat']

# Squeezing the stat column names by removing singular levels  
stcol_longnames = statdf.columns.tolist()
stcol_sqzdnames = squeeze_colnames(stcol_longnames, mindepth=2)
# Renaming every long column name to its squeezed counterpart
statdf = statdf.rename(columns=dict(zip(stcol_longnames, stcol_sqzdnames)))

In [None]:
# Bootstrap aggregation settings handed to `get_aggdf`
aggcfg = dict(type='bootstrap', n_boot=40, q=[5, 95], stat='mean', device='cpu')

tab_ttl = 'KL Weight'
# Sorting both tables by the latent dimension and then the unified KL weight
tab_idx = hpdf.sort_values(['nn/ltnt/dim', 'cri/kl/w'], kind='stable').index
hpdf_tab = hpdf.loc[tab_idx, :].reset_index(drop=True)
stdf_tab = statdf.loc[tab_idx, :].reset_index(drop=True)

# Aggregating the data
agg_data = get_aggdf(hpdf_tab, stdf_tab, xcol='epoch', 
    huecol='fpidxgrp', rngcol='rng_seed', aggcfg=aggcfg)
hpdf_tabagg, stdf_tabagg = agg_data['hpdf'], agg_data['stdf']
stcols_tab = agg_data['stcols']

# A single-tab list for `get_dashdata`
data = [(tab_ttl, hpdf_tabagg, stdf_tabagg, stcols_tab)]

## Training Curves

In [None]:
# Preparing the dash data of the training curves from their config file
ymlpath = f'{workdir}/11_klstudy.yml'
dashdata = get_dashdata(data, ymlpath, write_yml=False)

In [None]:
# Building the dashboard layout and saving it as an HTML file with bokeh
fulllayout = build_dashboard(None, **dashdata)
output_file(f'{workdir}/12_klstudy.html')
save(fulllayout, title=dashdata['header'])

## Ablation Curves

In [None]:
# Bootstrap aggregation settings handed to `get_aggdf`
aggcfg = dict(type='bootstrap', n_boot=40, q=[5, 95], stat='mean', device='cpu')
# All rows are aggregated together here, with no per-tab selection
hpdf_tab, statdf_tab = hpdf, statdf
# NOTE(review): the latent dimension column is renamed before aggregation; 
# presumably the ablation config refers to it as 'nn/latent/dim' -- confirm.
hpdf_tab = hpdf_tab.rename(columns={'nn/ltnt/dim': 'nn/latent/dim'})
agg_data = get_aggdf(hpdf_tab, statdf_tab, xcol='epoch', 
    huecol='fpidxgrp', rngcol='rng_seed', aggcfg=aggcfg)
data2 = [('All Dimensions', agg_data['hpdf'], agg_data['stdf'], agg_data['stcols'])]

In [None]:
# Preparing the dash data of the ablation curves from their config file
ymlpath = f'{workdir}/13_klstudy.yml'
dashdata = get_dashdata(data2, ymlpath, write_yml=False)

In [None]:
# Building the dashboard layout and saving it as an HTML file with bokeh
fulllayout = build_dashboard(None, **dashdata)
output_file(f'{workdir}/14_klstudy.html')
save(fulllayout, title=dashdata['header'])

# Experiment 5: KL Weight Study on the Histogram VAE

Related config: `configs/02_adhoc/08_klstudy.yml`

Goals: 

  * The fourth round of studying the effect of KL mu and sigma weights on the newly implemented performance metrics.
    
  * See Experiment 4 for previous issues.

Issues:

  * This is a similar config to `configs/02_adhoc/05_klstudy.yml`

    * The KL mu and sigma weights are swept independently around a base value of 0.001.
    
  * There was a bug in TV calculation in the previous run. 
  
    * That's the main reason I took this run. 
  
  * I made a few other adjustments as well
    
    * The metric evaluation reductions were changed to `mean` instead of `sum`.
    
    * The Hellinger's TV approximation is now capped at `1.0`.
    
    * A legend labeling bug in the plots was fixed in this revision.
    
    * Noisy reconstruction evaluation and metrics were added here.
    
    * `10th` and `100th` neighbors were added to congestion metrics.

In [None]:
# Loading the hyper-parameter and statistics tables of Experiment 5
smrypath = '../summary/05_klstudy.h5'
get_h5du(smrypath, verbose=True, detailed=False)
data = load_h5data(smrypath)
hpdf = data['hp']
statdf = data['stat']

In [None]:
# The metric names are adjusted with `adjust_mtrcnames` first, and then the 
# stat column names are squeezed by removing singular levels.
statdf = adjust_mtrcnames(statdf)
stcol_longnames = statdf.columns.tolist()
stcol_sqzdnames = squeeze_colnames(stcol_longnames, mindepth=2)
# Renaming every long column name to its squeezed counterpart
statdf = statdf.rename(columns=dict(zip(stcol_longnames, stcol_sqzdnames)))

In [None]:
# Bootstrap aggregation settings handed to `get_aggdf`
aggcfg = dict(type='bootstrap', n_boot=40, q=[5, 95], stat='mean', device='cpu')
data = []

# Each tab sweeps one of the KL weights, and only includes the rows where 
# the other KL weight is at its 0.001 base value.
for tab_ttl, inc_spec, sort_cols in [
    ('KL Weight (Mu)', {'cri/kl/sig/w': 0.001}, ['cri/kl/mu/w', 'nn/ltnt/dim']), 
    ('KL Weight (Sigma)', {'cri/kl/mu/w': 0.001}, ['cri/kl/sig/w', 'nn/ltnt/dim'])]:
    
    # Selecting the rows of the current tab
    tab_idx = get_dfidxs(hpdf, inc_spec)
    hpdf_tab = hpdf.loc[tab_idx, :].reset_index(drop=True)
    stdf_tab = statdf.loc[tab_idx, :].reset_index(drop=True)

    # Sorting both tables by the swept weight and then the latent dimension
    tab_idx2 = hpdf_tab.sort_values(sort_cols, kind='stable').index
    hpdf_tab = hpdf_tab.loc[tab_idx2, :].reset_index(drop=True)
    stdf_tab = stdf_tab.loc[tab_idx2, :].reset_index(drop=True)

    # Aggregating the data
    agg_data = get_aggdf(hpdf_tab, stdf_tab, xcol='epoch', 
        huecol='fpidxgrp', rngcol='rng_seed', aggcfg=aggcfg)
    hpdf_tabagg, stdf_tabagg = agg_data['hpdf'], agg_data['stdf']
    stcols_tab = agg_data['stcols']

    data.append((tab_ttl, hpdf_tabagg, stdf_tabagg, stcols_tab))

## Training Curves

In [None]:
# Preparing the dash data of the training curves from their config file
ymlpath = f'{workdir}/15_klstudy.yml'
dashdata = get_dashdata(data, ymlpath, write_yml=False)

In [None]:
# Building the dashboard layout and saving it as an HTML file with bokeh
fulllayout = build_dashboard(None, **dashdata)
output_file(f'{workdir}/16_klstudy.html')
save(fulllayout, title=dashdata['header'])

## Ablation Curves

In [None]:
# Bootstrap aggregation settings handed to `get_aggdf`
aggcfg = dict(type='bootstrap', n_boot=40, q=[5, 95], stat='mean', device='cpu')
# All rows are aggregated together here, with no per-tab selection
hpdf_tab, statdf_tab = hpdf, statdf
# NOTE(review): the latent dimension column is renamed before aggregation; 
# presumably the ablation config refers to it as 'nn/latent/dim' -- confirm.
hpdf_tab = hpdf_tab.rename(columns={'nn/ltnt/dim': 'nn/latent/dim'})
agg_data = get_aggdf(hpdf_tab, statdf_tab, xcol='epoch', 
    huecol='fpidxgrp', rngcol='rng_seed', aggcfg=aggcfg)
data2 = [('All Dimensions', agg_data['hpdf'], agg_data['stdf'], agg_data['stcols'])]

In [None]:
# Preparing the dash data of the ablation curves from their config file
ymlpath = f'{workdir}/17_klstudy.yml'
dashdata = get_dashdata(data2, ymlpath, write_yml=False)

In [None]:
# Building the dashboard layout and saving it as an HTML file with bokeh
fulllayout = build_dashboard(None, **dashdata)
output_file(f'{workdir}/18_klstudy.html')
save(fulllayout, title=dashdata['header'])

See the following plot to get a sense of the TV and Hellinger values.

In [None]:
# The figure and its axes are created under the dark background style
with plt.style.context('dark_background'):
    fig, axes = plt.subplots(2, 3, figsize=(3*3.0, 2*2.4), sharex=True, sharey=True, dpi=100)
    axes = np.array(axes).ravel()

# Comparing a standard normal against unit-variance normals with shifted 
# means. The integrals are approximated with sums over the `x` grid.
df_rowdicts = []
x = np.linspace(-5, 12, 10000)
for ax_idx, loc in enumerate((1, 2, 3, 4, 5, 6)):
    p1 = norm.pdf(x, loc=0)
    p2 = norm.pdf(x, loc=loc)
    # `np.ptp(x)` replaces the `x.ptp()` method, which was removed in NumPy 2.0
    # Total variation: half of the integrated absolute density difference
    tv_trg = 0.5 * np.abs(p1 - p2).sum() * np.ptp(x) / x.size
    # Squared Hellinger: one minus the integrated geometric mean of the densities
    h2_trg = 1 - np.sqrt(p1 * p2).sum() * np.ptp(x) / x.size
    h1_trg = np.sqrt(h2_trg)
    row_dict = {'Delta Mu': loc, 'TV': tv_trg, 
        'Hellinger': h1_trg, 'Hellinger-Squared': h2_trg}
    df_rowdicts.append(row_dict)

    ax = axes[ax_idx]
    ax.plot(x, p1, color='blue', lw=3)
    ax.plot(x, p2, color='red', lw=3)

    # The backslashes of the math text are escaped explicitly, since `\D` 
    # and `\m` are invalid escape sequences in a regular string literal.
    ax.text(0.95, 0.70, f'$\\Delta\\mu={loc}$\nTV={tv_trg:0.2f}\nH1={h1_trg:0.2f}',
        verticalalignment='bottom', horizontalalignment='right',
        transform=ax.transAxes, color='white', fontsize=10)

fig.set_tight_layout(True)

# Displaying the distances as a table
pd.DataFrame(df_rowdicts).round(3)

## Matplotlib Realism vs. KL Weight Curves

In [None]:
# The number prefixed to the output file name
i_fig = 89
plt.ioff()

# Reading the experiment to fpidx specification configs
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    # The module-level `load(fp, Loader)` function was removed in 
    # ruamel.yaml 0.18. The `YAML` class is the round-trip loading API 
    # that is available in both the older and the newer releases.
    v_mplcfgsraw = ruyaml.YAML(typ='rt').load(fp)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)

# Joining the hp and stat tables of the first ('KL Weight (Mu)') and second 
# ('KL Weight (Sigma)') tabs of `data`, and converting every column to plain 
# python values.
pltdfmu = pd.concat(data[0][1: 3], axis=1)
for col in pltdfmu.columns:
    pltdfmu[col] = pltdfmu[col].tolist()
pltdfsig = pd.concat(data[1][1: 3], axis=1)
for col in pltdfsig.columns:
    pltdfsig[col] = pltdfsig[col].tolist()

with plt.rc_context(v_mplcfgs['rc_context']) as pltrcctx:
    n_figrows, n_figcols = 2, 3

    fig, axes = plt.subplots(n_figrows, n_figcols, 
        figsize=[2.6 * n_figcols, 2.2 * n_figrows], 
        dpi=140, sharex=False, sharey=True, squeeze=False)
    axes = np.array(axes)

    # Rows are the KL sigma and mu studies; columns are the alpha values
    for i_figrow, kl_wmuorsig, pltdf in [(0, 'klwsig', pltdfsig), (1, 'klwmu', pltdfmu)]:
        for i_figcol, svd_alpha in [(0, 1), (1, 0), (2, -1)]:
            ax = axes[i_figrow, i_figcol]
            plt_cfg = {'fig': fig, 'ax': ax, **v_mplcfgs[f'paper.realism.{kl_wmuorsig}.ldim']}
            plt_cfg['ycol'] = f'perf/realism/wslcwass2:{svd_alpha:.1f}/abserr/median'
            plt_cfg['ylabel'] = 'Realism Error' if (i_figcol == 0) else None
            # Only the top-left panel keeps its text annotations; they are 
            # filtered out of the config of all the other panels.
            if (i_figrow, i_figcol) != (0, 0):
                plt_cfg = {key: val for key, val in deep2hie(plt_cfg).items() 
                    if 'ax.annotate/text' not in key}
                plt_cfg = hie2deep(plt_cfg)

            fig, ax = draw_matplotlib(plt_cfg, pltdf)
            tag_axis(ax, f'({"abc"[i_figcol]}$_{{{i_figrow + 1}}}$)', fontsize=12, pad=(0.3, 0.6))
    
    # Raw strings are used for the math text, since `\s` and `\m` are 
    # invalid escape sequences in a regular string literal.
    for i_figrow, row_ttl in enumerate([r'KL-$\sigma$ Study', r'KL-$\mu$ Study']):
        print_axheader(axes[i_figrow, 0], row_ttl, 'left', 
            fontsize=12, fontweight='bold')
    
    for i_figcol, svd_alpha in [(0, 1), (1, 0), (2, -1)]:
        col_ttl = f'$\\alpha={svd_alpha}$'
        print_axheader(axes[0, i_figcol], col_ttl, 'top', 
            fontsize=14, fontweight='bold')

    fig.subplots_adjust(wspace=0.1, hspace=0.5)

pdfpath = f'{workdir}/{i_fig:02d}_realism_klw.pdf'
fig.savefig(pdfpath, bbox_inches='tight')
print(f'Finished writing {pdfpath}')

fig

## Matplotlib Realism vs. Latent Dimension Curves

In [None]:
i_fig = 90
plt.ioff()

# Reading the matplotlib plotting configs. `YAML(typ='rt')` is the supported
# round-trip loading API; the module-level `ruyaml.load(fp, RoundTripLoader)`
# call is deprecated and raises an error in `ruamel.yaml>=0.18`.
with open(f'{workdir}/37_aerompl.yml', 'r') as fp:
    v_mplcfgsraw = ruyaml.YAML(typ='rt').load(fp)
v_mplcfgs = hie2deep(parse_refs(v_mplcfgsraw), maxdepth=1)

# NOTE: `data` is not built in this cell and must come from an earlier one;
# `data[0]` feeds the KL-mu row and `data[1]` feeds the KL-sigma row.
# Entries 1 and 2 of each tuple are concatenated column-wise (presumably the
# aggregated hp and stat frames -- TODO confirm against the producing cell).
pltdfmu = pd.concat(data[0][1: 3], axis=1)
pltdfsig = pd.concat(data[1][1: 3], axis=1)
for pltdf in (pltdfmu, pltdfsig):
    # Each column is re-assigned from a plain python list of its values
    for col in pltdf.columns:
        pltdf[col] = pltdf[col].tolist()

# One figure row per KL study and one figure column per alpha value.
# The backslashes are escaped so that no invalid escape sequences (`\s`, `\m`)
# are left in the TeX labels; the runtime strings are unchanged.
row_specs = [('klwsig', pltdfsig, 'KL-$\\sigma$ Study'),
             ('klwmu', pltdfmu, 'KL-$\\mu$ Study')]
svd_alphas = [1, 0, -1]

with plt.rc_context(v_mplcfgs['rc_context']):
    n_figrows, n_figcols = len(row_specs), len(svd_alphas)

    # With `squeeze=False`, `axes` is always a 2D array of shape
    # (n_figrows, n_figcols), so no further conversion is needed.
    fig, axes = plt.subplots(n_figrows, n_figcols,
        figsize=[2.6 * n_figcols, 2.2 * n_figrows],
        dpi=140, sharex=True, sharey=True, squeeze=False)

    for i_figrow, (kl_wmuorsig, pltdf, row_ttl) in enumerate(row_specs):
        for i_figcol, svd_alpha in enumerate(svd_alphas):
            ax = axes[i_figrow, i_figcol]
            plt_cfg = {'fig': fig, 'ax': ax, **v_mplcfgs[f'paper.realism.ldim.{kl_wmuorsig}']}
            plt_cfg['ycol'] = f'perf/realism/wslcwass2:{svd_alpha:.1f}/abserr/median'
            # Only the left-most column carries the y-axis label, and only
            # the bottom row carries the x-axis label.
            plt_cfg['ylabel'] = 'Realism Error' if (i_figcol == 0) else None
            plt_cfg['xlabel'] = 'Latent Dimension' if (i_figrow == n_figrows - 1) else None
            if i_figcol != 2:
                # Only the panels of the last column keep the config entries
                # whose flattened key contains 'ax.annotate/text'.
                plt_cfg = {key: val for key, val in deep2hie(plt_cfg).items()
                    if 'ax.annotate/text' not in key}
                plt_cfg = hie2deep(plt_cfg)
            fig, ax = draw_matplotlib(plt_cfg, pltdf)

            tag_axis(ax, f'({"abc"[i_figcol]}$_{{{i_figrow + 1}}}$)', fontsize=12, pad=(0.3, 0.6))

    # Row headers on the left of the first column
    for i_figrow, (kl_wmuorsig, pltdf, row_ttl) in enumerate(row_specs):
        print_axheader(axes[i_figrow, 0], row_ttl, 'left',
            fontsize=12, fontweight='bold')

    # Column headers on top of the first row
    for i_figcol, svd_alpha in enumerate(svd_alphas):
        col_ttl = f'$\\alpha={svd_alpha}$'
        print_axheader(axes[0, i_figcol], col_ttl, 'top',
            fontsize=14, fontweight='bold')

    fig.subplots_adjust(wspace=0.1, hspace=0.25)

pdfpath = f'{workdir}/{i_fig:02d}_realism_dltnt.pdf'
fig.savefig(pdfpath, bbox_inches='tight')
print(f'Finished writing {pdfpath}')

fig

# Experiment 6: KL Weight Study on the Histogram VAE

Related Configs:

  * `configs/02_adhoc/09_klstudy.yml`

Goals: 

  * This is a similar config to `configs/02_adhoc/06_klstudy.yml` and `configs/02_adhoc/07_klstudy.yml`

    * The KL mu and sigma weights are swept over identically.
    
  * There was a bug in the TV calculation in the previous run, which is why this run was made.

  * I made a few other adjustments as well
    
    * The metric evaluation reductions were changed to `mean` instead of `sum`.
    
    * The Hellinger's TV approximation is now capped at `1.0`.
    
    * A legend labeling bug in the plots was fixed in this revision.
    
    * Noisy reconstruction evaluation and metrics were added here.
    
    * `10th` and `100th` neighbors were added to congestion metrics.

In [None]:
# Loading the hyper-parameter and statistics frames of the KL study summary
smrypath = '../summary/06_klstudy.h5'
get_h5du(smrypath, verbose=True, detailed=False)
data = load_h5data(smrypath)
hpdf, statdf = (data[key] for key in ('hp', 'stat'))

# Adjusting the metric names, and then squeezing the stat column names
# by removing singular levels (long name -> squeezed name mapping)
statdf = adjust_mtrcnames(statdf)
stcol_longnames = statdf.columns.tolist()
stcol_sqzdnames = squeeze_colnames(stcol_longnames, mindepth=2)
stcol_renames = {longname: sqzdname for longname, sqzdname
    in zip(stcol_longnames, stcol_sqzdnames)}
statdf = statdf.rename(columns=stcol_renames)

In [None]:
# Bootstrap aggregation settings passed to `get_aggdf`
aggcfg = {'type': 'bootstrap', 'n_boot': 40, 'q': [5, 95],
    'stat': 'mean', 'device': 'cpu'}

# A single tab whose rows are ordered by the latent dim and then the KL weight
tab_ttl = 'KL Weight'
sort_cols = ['nn/ltnt/dim', 'cri/kl/w']
tab_idx = hpdf.sort_values(sort_cols, kind='stable').index
hpdf_tab = hpdf.loc[tab_idx].reset_index(drop=True)
stdf_tab = statdf.loc[tab_idx].reset_index(drop=True)

# Aggregating the data with `epoch` as the x column, `fpidxgrp` as the hue
# column, and `rng_seed` as the rng column.
agg_data = get_aggdf(hpdf_tab, stdf_tab, xcol='epoch',
    huecol='fpidxgrp', rngcol='rng_seed', aggcfg=aggcfg)
hpdf_tabagg = agg_data['hpdf']
stdf_tabagg = agg_data['stdf']
stcols_tab = agg_data['stcols']

# The dashboard input is a list of (title, hp frame, stat frame, stat columns) tabs
data = [(tab_ttl, hpdf_tabagg, stdf_tabagg, stcols_tab)]

## Training Curves

In [None]:
# Preparing the training-curve dashboard inputs from the aggregated `data` tabs
# and the `19_klstudy.yml` spec.
# NOTE(review): `write_yml=False` presumably keeps the existing yaml file
# untouched -- confirm against `get_dashdata`.
ymlpath = f'{workdir}/19_klstudy.yml'
dashdata = get_dashdata(data, ymlpath, write_yml=False)

In [None]:
# Building the dashboard layout from `dashdata`, pointing bokeh's output to
# the html path, and saving the layout with the dashboard header as its title.
fulllayout = build_dashboard(None, **dashdata)
output_file(f'{workdir}/20_klstudy.html')
save(fulllayout, title=dashdata['header'])

## Ablation Curves

In [None]:
# Bootstrap aggregation settings passed to `get_aggdf`
aggcfg = {'type': 'bootstrap', 'n_boot': 40, 'q': [5, 95],
    'stat': 'mean', 'device': 'cpu'}

# All rows are used here; only the latent dim column is renamed in the hp frame.
statdf_tab = statdf
hpdf_tab = hpdf.rename(columns={'nn/ltnt/dim': 'nn/latent/dim'})

agg_data = get_aggdf(hpdf_tab, statdf_tab, xcol='epoch',
    huecol='fpidxgrp', rngcol='rng_seed', aggcfg=aggcfg)

# A single (title, hp frame, stat frame, stat columns) tab for the dashboard
data2 = [('All Dimensions', *(agg_data[key] for key in ('hpdf', 'stdf', 'stcols')))]

In [None]:
# Preparing the ablation-curve dashboard inputs from the `data2` tab
# and the `21_klstudy.yml` spec.
# NOTE(review): `write_yml=False` presumably keeps the existing yaml file
# untouched -- confirm against `get_dashdata`.
ymlpath = f'{workdir}/21_klstudy.yml'
dashdata = get_dashdata(data2, ymlpath, write_yml=False)

In [None]:
# Building the dashboard layout from `dashdata`, pointing bokeh's output to
# the html path, and saving the layout with the dashboard header as its title.
fulllayout = build_dashboard(None, **dashdata)
output_file(f'{workdir}/22_klstudy.html')
save(fulllayout, title=dashdata['header'])

# Experiment 7: Testing the New Pre-Processing Framework

Related Configs: 

  * `configs/02_adhoc/10_pppoly.yml`

Goals: 

  * Just getting a VAE training with the new PP framework.

  * Quick and dirty OVAT experiment of most hyper-parameters in the MLP architectures.

Issues:

  * I made a mistake in preparing the training configs:

    1. The realism metric was defined on the pre-processed mass variables.

    2. For this reason, it cannot be trusted as an apples-to-apples comparison.

    3. The future training configs will contain separate x-realism and u-realism metrics.

  * My main observation was that, in terms of the realism metric, the `nn/pp/nrmscl` made a lot of impact:

    1. When it was turned on, the realism was 0.15; when it was turned off, the realism was 0.02.

    2. That being said, the aerosol metrics did not make much of a deal about this particular hyper-parameter.

    3. Since the realism metric for this training was defined on the pre-processed variables, this observation cannot be trusted.

  * Another issue was that the `nn/pp/magdim` ablations were absent:

    1. The VAE trainings failed with `NaN` values when `nn/pp/magdim` was either `n_chem` or `one`.

    2. This may be related to the mis-specification of the `nn/pp/nrmscl` hyper-parameter again.

  * The next issue with the data was that in the training curves of `notebooks/08_plotting/24_vaehist.html`:

    1. The `performane/reconstruction/noisy/test` metric had a wild range.
    
    2. This range specifically expanded in the early epochs for the `nn/pp/nrmexp` (normalization exponent) hyper-parameter ablations.

  * Another poor outcome was the proximity metric:

    1. The KL weight ablations caused a lot of unbounded behavior for the proximity metric.

    2. Other hyper-parameters were also sporadically causing this metric to go wild.

Incident Report:

  * One issue with the results was that the `nn/pp/magdim` ablations were absent:

    1. The VAE trainings failed with `NaN` values generated at the first epoch when `nn/pp/magdim` was either `n_chem` or `one`.

    2. This may be related to the mis-specification of the `nn/pp/nrmscl` hyper-parameter again.

    3. Upon further investigation, I realized that with `nn/pp/magdim=n_chem`,

      1. Some of the `sig08` pp parameters were on the order of 3e-7.

      2. This was happening since some of the leftmost bins in the normalized mass tensor could have been zero all the time.

      3. In the test split, a value of 0.4 was getting divided by this 3e-7 element, producing a value of roughly 600_000.

      4. This value of 600_000 was getting fed to the encoder!

      5. The encoder was using ReLU. This resulted in very large log-sigma predictions by the encoder.

      6. Therefore, the z sigma value for this data point was becoming inf.

      7. When sampling z, this infinite sigma was multiplied by a sigscale of 0, producing a NaN!

  * The following measures were taken to address this issue in the next rounds:

    1. I added an `nrmscleps` hyper-parameter to the pre-processor in the next revision.

        1. Similarly, the `magscleps` and `cntscleps` hyper-parameters were also added.

        2. I've set the default for these to zero (i.e., identical to before).

        3. If these hyper-parameters had shown promise after *full trainings*, I would have adjusted their defaults.

        4. However, as I saw in later experiments, they made marginal improvements to test performance metrics.

    2. I also changed the exp function into a piecewise exp-linear function to get the z sigma.
    
        1. See the related `nn/ltnt/sig/tnsfm` hyper-parameter.

In [None]:
# Loading the hyper-parameter and statistics frames of the pppoly summary
smrypath = '../summary/07_pppoly.h5'
get_h5du(smrypath, verbose=True, detailed=False)
data = load_h5data(smrypath)
hpdf, statdf = (data[key] for key in ('hp', 'stat'))

In [None]:
# Adjusting the metric names, and then squeezing the stat column names
# by removing singular levels (long name -> squeezed name mapping)
statdf = adjust_mtrcnames(statdf)
stcol_longnames = statdf.columns.tolist()
stcol_sqzdnames = squeeze_colnames(stcol_longnames, mindepth=2)
stcol_renames = {longname: sqzdname for longname, sqzdname
    in zip(stcol_longnames, stcol_sqzdnames)}
statdf = statdf.rename(columns=stcol_renames)

# Dropping the quantile data to save on space :)
# A column is kept only if none of the patterns is a substring of its name.
drop_pats = ['q10', 'q25', 'q5', 'q75', 'q90', 'q95', 'median', '/kl:']
keepcols = [col for col in statdf.columns
    if all(pat not in col for pat in drop_pats)]
statdf = statdf[keepcols]

# Downcasting numerical types to save on space
statdf = downcast_df(statdf)

In [None]:
# `fpicols` will be an `fpidxgrp` to hp column mapping; each fpidxgrp 
# is part of an ovat ablation defined by a single column.

# Keeping the first hp row of each `fpidxgrp`
ii_drop = hpdf['fpidxgrp'].drop_duplicates().index
hpdf2 = hpdf.loc[ii_drop].reset_index(drop=True)

fpicols = dict()
fpidxgrps = hpdf2['fpidxgrp']
# The first group is treated as the main (reference) group, and every other
# group is compared against it.
main_fpidx = fpidxgrps.iloc[0]
for fpidx in fpidxgrps.iloc[1:]:
    # The two-row frame of the main group and the current group
    hpdf3 = hpdf2[fpidxgrps.isin([fpidx, main_fpidx])]
    # NOTE(review): `drop_unqcols` presumably drops the columns that hold a
    # single value across these rows -- confirm against its definition.
    hpdf4 = drop_unqcols(hpdf3)
    hpdf5 = hpdf4.drop(columns=['fpidx', 'fpidxgrp'], errors='ignore')
    cols = hpdf5.columns.tolist()
    if len(cols) == 0:
        # Failing with the offending group names rather than a bare
        # "list index out of range" message.
        raise IndexError(f'No hp column separates fpidxgrp "{fpidx}" '
            f'from the main fpidxgrp "{main_fpidx}".')
    # The first remaining column is taken as the ablated column of the group
    fpicols[fpidx] = cols[0]

In [None]:
# Bootstrap aggregation settings passed to `get_aggdf`
aggcfg = {'type': 'bootstrap', 'n_boot': 40, 'q': [5, 95],
    'stat': 'mean', 'device': 'cpu'}


def _tab_fpidxs(keep_ovatcol):
    """Lists the main fpidxgrp followed by the groups whose ovat column passes `keep_ovatcol`."""
    return [main_fpidx] + [fpidx for fpidx, ovatcol in fpicols.items()
        if keep_ovatcol(ovatcol)]


# The ovat columns shown in each of the pre-processing tabs
pp_magcols = ('nn/pp/magpnrm', 'nn/pp/magexp', 'nn/pp/magshft', 'nn/pp/magscl')
pp_nrmcols = ('nn/pp/eps', 'nn/pp/nrmexp', 'nn/pp/nrmshft', 'nn/pp/nrmscl', 'nn/pp/nrmscldim')
pp_cntcols = ('nn/pp/cnteps', 'nn/pp/cntexp', 'nn/pp/cntshft', 'nn/pp/cntscldim')

# Every ablation whose ovat column does not contain 'nn/pp'
tab_ttl1 = 'General Ablations'
tab_fpidxs1 = _tab_fpidxs(lambda ovatcol: 'nn/pp' not in ovatcol)

tab_ttl2 = 'PP Magnitude Ablations'
tab_fpidxs2 = _tab_fpidxs(lambda ovatcol: ovatcol in pp_magcols)

tab_ttl3 = 'PP Normalization Ablations'
tab_fpidxs3 = _tab_fpidxs(lambda ovatcol: ovatcol in pp_nrmcols)

tab_ttl4 = 'PP Count Ablations'
tab_fpidxs4 = _tab_fpidxs(lambda ovatcol: ovatcol in pp_cntcols)

# Aggregating the data, one (title, hp frame, stat frame, stat columns) tuple per tab
tab_specs = [(tab_ttl1, tab_fpidxs1), (tab_ttl2, tab_fpidxs2),
    (tab_ttl3, tab_fpidxs3), (tab_ttl4, tab_fpidxs4)]
data = []
for tab_ttl, tab_fpidxs in tab_specs:
    # Boolean row mask of the runs belonging to this tab
    tab_idx = hpdf['fpidxgrp'].isin(tab_fpidxs)
    hpdf_tab = hpdf.loc[tab_idx].reset_index(drop=True)
    stdf_tab = statdf.loc[tab_idx].reset_index(drop=True)
    agg_data = get_aggdf(hpdf_tab, stdf_tab, xcol='epoch',
        huecol='fpidxgrp', rngcol='rng_seed', aggcfg=aggcfg)
    hpdf_tabagg = agg_data['hpdf']
    stdf_tabagg = agg_data['stdf']
    stcols_tab = agg_data['stcols']
    data.append((tab_ttl, hpdf_tabagg, stdf_tabagg, stcols_tab))

## Training Curves

In [None]:
# Preparing the training-curve dashboard inputs from the aggregated `data` tabs
# and the `23_vaehist.yml` spec.
# NOTE(review): `write_yml=False` presumably keeps the existing yaml file
# untouched -- confirm against `get_dashdata`.
ymlpath = f'{workdir}/23_vaehist.yml'
dashdata = get_dashdata(data, ymlpath, write_yml=False)

In [None]:
# Building the dashboard layout from `dashdata`, pointing bokeh's output to
# the html path, and saving the layout with the dashboard header as its title.
fulllayout = build_dashboard(None, **dashdata)
output_file(f'{workdir}/24_vaehist.html')
save(fulllayout, title=dashdata['header'])

## Ablation Curves

In [None]:
# Preparing the ablation-curve dashboard inputs; the same aggregated `data`
# tabs as the training curves are used, but with the `25_vaehist.yml` spec.
# NOTE(review): `write_yml=False` presumably keeps the existing yaml file
# untouched -- confirm against `get_dashdata`.
ymlpath = f'{workdir}/25_vaehist.yml'
dashdata = get_dashdata(data, ymlpath, write_yml=False)

In [None]:
# Building the dashboard layout from `dashdata`, pointing bokeh's output to
# the html path, and saving the layout with the dashboard header as its title.
fulllayout = build_dashboard(None, **dashdata)
output_file(f'{workdir}/26_vaehist.html')
save(fulllayout, title=dashdata['header'])