In [None]:
import holoviews as hv
from holoviews import dim, opts, streams
from holoviews.selection import link_selections
import hvplot.pandas
import pandas as pd
from itertools import combinations
import numpy as np
from tqdm import tqdm
import re
import panel as pn
import spatialpandas

# Load the HoloViews 'bokeh' plotting extension for this notebook session.
# NOTE(review): width=100 presumably sets the notebook display width to 100% —
# confirm against the hv.extension documentation.
hv.extension('bokeh',width=100)
# Explicitly mark 'bokeh' as the current backend in the HoloViews Store.
hv.Store.set_current_backend('bokeh')

In [None]:
def load_params(filename):
    """Read parameter names from a text file, one name per line.

    Each line is stripped of surrounding whitespace and split on the literal
    separator space-TAB-space; the text before the first separator is taken
    as the parameter name and anything after it is ignored.  A line without
    the separator is used as the name in full.

    Blank (whitespace-only) lines are skipped so that they do not produce an
    empty-string parameter name.

    Args:
        filename: path of the file to read.

    Returns:
        list[str]: the parameter names, in file order.
    """
    names = []
    with open(filename, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # A blank line would otherwise contribute '' as a name.
                continue
            names.append(re.split(' \t ', line)[0])
    return names


def load_data(filename, column_names):
    """Load a whitespace-delimited numeric text file into a DataFrame.

    The first two columns of every row are discarded and the remaining
    columns are labelled with ``column_names``.
    NOTE(review): the two dropped leading columns are presumably per-sample
    bookkeeping values rather than parameters — confirm against the file
    format.

    Args:
        filename: path handed to ``numpy.loadtxt``.
        column_names: labels for the kept columns; there must be exactly as
            many labels as the file has columns minus two.

    Returns:
        pandas.DataFrame: one row per line of the file.
    """
    # ndmin=2 keeps a single-row file as shape (1, n_cols).  Without it
    # loadtxt returns a 1-D array and the 2-D slice below raises IndexError.
    data = np.loadtxt(filename, ndmin=2)
    return pd.DataFrame(data[:, 2:], columns=column_names)


def plot_scatter_table(data, params, plots):
    """Build a Panel view of pairwise scatter/density plots linked to a table.

    For each of the first ``plots`` parameter pairs, a black point scatter is
    overlaid with a filled bivariate density.  The scatters and the table are
    registered with one shared ``link_selections`` instance, and a widget for
    that instance's ``selection_mode`` is placed above the plots.

    Args:
        data: a pandas DataFrame containing every column named in ``params``.
        params: list of column names of ``data`` to pair up and to show in
            the table.
        plots: the maximum number of pair plots to display; pairs are taken
            in the order produced by ``itertools.combinations(params, 2)``.

    Returns:
        panel.Column: the selection-mode widget stacked above a row holding
        the two-column plot grid and the table.
    """
    # All unordered parameter pairs, truncated to the requested number.
    pairs = list(combinations(params, 2))[:plots]

    # create linked selections
    ls = link_selections.instance()

    # Make the table as tall as the plot grid: two plots per row, 300 px per
    # row.  The row count is based on the pairs actually drawn, so asking for
    # more plots than there are pairs no longer produces an over-tall table.
    grid_rows = (len(pairs) + 1) // 2
    table = hv.Table(data[params]).opts(width=800, height=300 * grid_rows)

    layout = hv.Layout()

    for param_a, param_b in pairs:
        # vdims = [e for e in params if e not in (param_a, param_b)]
        # ^^^ uncomment the above and add ", vdims" to the argument of hv.Dataset if desired
        ds = hv.Dataset(data, [param_a, param_b])
        pts = hv.Points(ds).opts(
            opts.Points(color='black', size=2))
        bivar = hv.Bivariate(data[[param_a, param_b]].values, [param_a, param_b], []).opts(
            opts.Bivariate(bandwidth=0.5,
                           cut=0, cmap="blues",
                           levels=5,
                           colorbar=False,
                           show_legend=False,
                           filled=True,
                           toolbar='above',
                           width=350,
                           alpha=0.75))
        # Only the points go through the linked-selection instance; the
        # density is drawn on top of them without being linked.
        layout += (ls(pts) * bivar).opts(width=300, height=300)

    layout = layout.cols(2)
    # Use a separate name for the widget instead of rebinding the `params`
    # argument, which would hide the column list for the rest of the function.
    mode_selector = pn.Param(ls, parameters=['selection_mode'])
    scatter_table = pn.Column(mode_selector, pn.Row(layout, ls(table)))
    return scatter_table

In [None]:
if __name__=='__main__':
    # Read in data
    param_names = load_params('data/test_IDM_n_0/2022-05-04_75000_.paramnames')
    # Load the numbered data files 1..55 and concatenate them once at the end.
    # Growing `df` with pd.concat inside the loop re-copies every previously
    # loaded row on each iteration and needs an empty seed frame.
    frames = []
    for i in tqdm(range(1,56)):
        frames.append(load_data('data/test_IDM_n_0/2022-05-04_75000__{}.txt'.format(i), column_names=param_names))
    # ignore_index=True gives the combined frame a fresh 0..N-1 index, as the
    # per-iteration reset_index(drop=True) did before.
    df = pd.concat(frames, ignore_index=True)

In [None]:
# Columns of `df` to pair up in the scatter plots and to list in the table.
params = ['omega_b', 'omega_dmeff', 'n_s', 'tau_reio', 'sigma_dmeff', 'H0', 'A_s', 'sigma8']

# Crude downsampling for speed: keep every 1000th row, then renumber the
# index from zero. A better downsampling method may be implemented.
df_slice = df.iloc[::1000]
new_df = df_slice.reset_index(drop=True)

# Plotting at most 4 pairs at once is recommended to avoid performance issues.
viz = plot_scatter_table(data=new_df, params=params, plots=4)
viz