In [1]:
# Render matplotlib figures inline below each cell.
%matplotlib inline

In [2]:
#__ = plt.style.use("./diffstar.mplstyle")
import numpy as np
from matplotlib import pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
from scipy.optimize import curve_fit
# Named hex colours available for explicit use in plots.
mred, morange, mgreen, mblue, mpurple = (
    u"#d62728",
    u"#ff7f0e",
    u"#2ca02c",
    u"#1f77b4",
    u"#9467bd",
)

# Global figure style: serif 16pt text rendered through LaTeX, with amsmath in
# the preamble (needed for \dfrac in labels), and a 6x4 inch default figure.
plt.rcParams.update({
    "font.family": "serif",
    "font.size": 16,
    "text.usetex": True,
    "text.latex.preamble": r'\usepackage{amsmath}',
    "figure.figsize": (6, 4),
})


In [3]:
from jax import (
    numpy as jnp,
    jit as jjit,
    random as jran,
    grad,
    vmap,
)
import os
import h5py
from functools import partial
from collections import OrderedDict, namedtuple
import warnings
import getdist
from getdist import plots, MCSamples

# Short labels for the nine unbounded parameters, presumably ordered as the
# five main-sequence parameters followed by the four quenching parameters.
names = "ulgm ulgy ul uh utau uqt uqs udrop urej".split()



In [None]:
from diffstar.kernels.main_sequence_kernels_tpeak import (
    _get_bounded_sfr_params,
    # _get_unbounded_sfr_params,
    # _get_unbounded_sfr_params_vmap,
)
from diffstar.kernels.quenching_kernels import (
    _get_bounded_q_params,
    # _get_unbounded_q_params,
    # _get_unbounded_q_params_vmap,
)

from diffstar.kernels.main_sequence_kernels import MS_PARAM_BOUNDS_PDICT, MS_BOUNDING_SIGMOID_PDICT
from diffstar.kernels.quenching_kernels import Q_PARAM_BOUNDS_PDICT, Q_BOUNDING_SIGMOID_PDICT
from diffstar.utils import _inverse_sigmoid, _sigmoid
from diffstar.defaults import TODAY, LGT0

from diffmah.diffmah_kernels import DiffmahParams, mah_halopop


In [None]:


@jjit
def _get_unbounded_sfr_params(
    lgmcrit,
    lgy_at_mcrit,
    indx_lo,
    indx_hi,
    tau_dep,
):
    """Map bounded main-sequence parameters to their unbounded counterparts.

    Each value is first clipped to lie 1e-5 inside its
    ``MS_PARAM_BOUNDS_PDICT`` interval (presumably so the inverse sigmoid stays
    finite at the edges), then passed through ``_inverse_sigmoid`` with that
    parameter's ``MS_BOUNDING_SIGMOID_PDICT`` entry.

    NOTE(review): the bound/sigmoid dicts come from ``main_sequence_kernels``
    while the forward transform is imported from
    ``main_sequence_kernels_tpeak`` — confirm they agree.

    Returns
    -------
    tuple
        (u_lgmcrit, u_lgy_at_mcrit, u_indx_lo, u_indx_hi, u_tau_dep)
    """
    eps = 1e-5
    ordered = (
        ("lgmcrit", lgmcrit),
        ("lgy_at_mcrit", lgy_at_mcrit),
        ("indx_lo", indx_lo),
        ("indx_hi", indx_hi),
        ("tau_dep", tau_dep),
    )
    unbounded = []
    for key, value in ordered:
        bounds = MS_PARAM_BOUNDS_PDICT[key]
        inside = jnp.clip(value, bounds[0] + eps, bounds[1] - eps)
        unbounded.append(_inverse_sigmoid(inside, *MS_BOUNDING_SIGMOID_PDICT[key]))
    return tuple(unbounded)

@jjit
def _get_unbounded_sfr_params_galpop_kern(u_ms_params):
    """Single-galaxy kernel: one row of 5 MS parameters -> array of 5 unbounded values.

    NOTE(review): despite the ``u_`` prefix, the input holds *bounded*
    parameters; they are forwarded to ``_get_unbounded_sfr_params``.
    """
    lgmcrit, lgy_at_mcrit, indx_lo, indx_hi, tau_dep = u_ms_params
    return jnp.array(
        _get_unbounded_sfr_params(lgmcrit, lgy_at_mcrit, indx_lo, indx_hi, tau_dep)
    )

# Vectorised over the leading axis: each row is one galaxy's 5 MS parameters.
_get_unbounded_sfr_params_vmap = jjit(
    vmap(_get_unbounded_sfr_params_galpop_kern, in_axes=(0,))
)

@jjit
def _get_unbounded_q_params(lg_qt, qlglgdt, lg_drop, lg_rejuv):
    """Map bounded quenching parameters to their unbounded counterparts.

    lg_qt, qlglgdt and lg_drop are clipped 1e-5 inside their
    ``Q_PARAM_BOUNDS_PDICT`` intervals and inverse-sigmoid transformed.
    lg_rejuv is clipped between the (already clipped) lg_drop + 1e-5 and its
    own upper bound, then transformed by ``_get_unbounded_qrejuv``.

    Returns
    -------
    tuple
        (u_lg_qt, u_qlglgdt, u_lg_drop, u_lg_rejuv)
    """
    eps = 1e-5

    def _clip_to_bounds(value, key, lower=None):
        # Clip to the key's interval; `lower` overrides the lower edge.
        bounds = Q_PARAM_BOUNDS_PDICT[key]
        lo = bounds[0] + eps if lower is None else lower
        return jnp.clip(value, lo, bounds[1] - eps)

    lg_qt = _clip_to_bounds(lg_qt, "lg_qt")
    qlglgdt = _clip_to_bounds(qlglgdt, "qlglgdt")
    lg_drop = _clip_to_bounds(lg_drop, "lg_drop")
    # Rejuvenation is bounded below by the drop value, not by a fixed constant.
    lg_rejuv = _clip_to_bounds(lg_rejuv, "lg_rejuv", lower=lg_drop + eps)

    return (
        _inverse_sigmoid(lg_qt, *Q_BOUNDING_SIGMOID_PDICT["lg_qt"]),
        _inverse_sigmoid(qlglgdt, *Q_BOUNDING_SIGMOID_PDICT["qlglgdt"]),
        _inverse_sigmoid(lg_drop, *Q_BOUNDING_SIGMOID_PDICT["lg_drop"]),
        _get_unbounded_qrejuv(lg_rejuv, lg_drop),
    )


@jjit
def _get_unbounded_qrejuv(lg_rejuv, lg_drop):
    """Inverse-sigmoid transform of lg_rejuv with a data-dependent third argument.

    Uses ``Q_BOUNDING_SIGMOID_PDICT["lg_rejuv"]`` but replaces its element at
    index 2 (presumably the lower asymptote) with ``lg_drop``.
    """
    sig_pars = Q_BOUNDING_SIGMOID_PDICT["lg_rejuv"]
    return _inverse_sigmoid(lg_rejuv, sig_pars[0], sig_pars[1], lg_drop, sig_pars[3])

@jjit
def _get_unbounded_q_params_galpop_kern(u_q_params):
    """Single-galaxy kernel: one row of 4 Q parameters -> array of 4 unbounded values.

    NOTE(review): despite the ``u_`` prefix, the input holds *bounded* values.
    """
    lg_qt, qlglgdt, lg_drop, lg_rejuv = u_q_params
    return jnp.array(_get_unbounded_q_params(lg_qt, qlglgdt, lg_drop, lg_rejuv))


# Vectorised over the leading axis: each row is one galaxy's 4 Q parameters.
_get_unbounded_q_params_vmap = jjit(
    vmap(_get_unbounded_q_params_galpop_kern, in_axes=(0,))
)

In [None]:
# Root of the t_peak data products. Set TPEAK_PATH in the environment to
# override this machine-specific default instead of editing the notebook.
tpeak_path = os.environ.get("TPEAK_PATH", "/Users/alarcon/Documents/diffmah_data/tpeak/")

sim_name_list = [
    "smdpl",
    "smdpl_DR1",
    "tng",
    "galacticus_in_situ",
    "galacticus_in_plus_ex_situ"
]

sim_name_tuplename_list = [
    "SMDPL",
    "SMDPL_DR1",
    "TNG",
    "GALACTICUS_IN",
    "GALACTICUS_INPLUSEX",
]

# Index into the two parallel lists above.
sim_index = 0

sim_name = sim_name_list[sim_index]
sim_name_tuplename = sim_name_tuplename_list[sim_index]

print(f"Using simulation: {sim_name}")

# --- Per-halo samples of the selected simulation ---------------------------
indir = sim_name + "_pdf_target_data"
# NOTE(review): the file name is "smdpl_..." for every sim_index; confirm the
# other simulations' directories really use the same file name.
fname = os.path.join(tpeak_path, indir, "smdpl_smhm_samples_haloes.h5")
with h5py.File(fname, "r") as hdf:
    logmh_id = hdf["logmh_id"][:]
    # Previously "logmh_id" was read a second time here (cf. tobs_id/tobs_val).
    if "logmh_val" in hdf:
        logmh_val = hdf["logmh_val"][:]
    else:
        warnings.warn(f"{fname} has no 'logmh_val' dataset; falling back to 'logmh_id'")
        logmh_val = hdf["logmh_id"][:]
    mah_params_samp = hdf["mah_params_samp"][:]
    ms_params_samp = hdf["ms_params_samp"][:]
    q_params_samp = hdf["q_params_samp"][:]
    upid_samp = hdf["upid_samp"][:]
    tobs_id = hdf["tobs_id"][:]
    tobs_val = hdf["tobs_val"][:]
    redshift_val = hdf["redshift_val"][:]

# One row per galaxy: 5 main-sequence and 4 quenching parameters (bounded),
# plus their unbounded transforms used by the fits below.
ms_fit_params = ms_params_samp.T.copy()
q_fit_params = q_params_samp.T.copy()
u_ms_fit_params = _get_unbounded_sfr_params_vmap(ms_fit_params)
u_q_fit_params = _get_unbounded_q_params_vmap(q_fit_params)
assert not np.isnan(u_ms_fit_params).any()
assert not np.isnan(u_q_fit_params).any()

# Halo mass histories on 50 log-spaced times from 10**-1 to 10**LGT0; the last
# column (the present-day value) is used for all mass binning below.
tarr_logm0 = np.logspace(-1, LGT0, 50)
mah_pars_ntuple = DiffmahParams(*mah_params_samp)
dmhdt_fit, log_mah_fit = mah_halopop(mah_pars_ntuple, tarr_logm0, LGT0)
logmp0_data = log_mah_fit[:, -1]

# --- Auxiliary SMDPL DR1 catalogue -----------------------------------------
# Only its field names are inspected here (the `data.p50` uses further down
# are commented out). The context manager closes the .npz file handle.
filename = "SMDPL_dr1_no_merging_upidh_tpeak_mahsfh_difffits_tpeak_small_100subvols.npz"
with np.load(os.path.join(tpeak_path, filename)) as _data:
    data = namedtuple("data", _data.files)(*[_data[key] for key in _data.files])
print(data._fields)


In [None]:
# Present-day halo-mass distribution (symlog counts) and its relation to t_peak.
fig, ax = plt.subplots()
ax.hist(logmp0_data, np.linspace(10, 15.5, 100), histtype='step')
ax.set_yscale('symlog', linthresh=5.0)
ax.set_xticks(np.arange(10, 16.0, 1.0))
ax.set_xlabel("logmp0")
ax.set_ylabel("Counts")
plt.show()

fig, ax = plt.subplots()
ax.scatter(logmp0_data, mah_pars_ntuple.t_peak, s=1)
ax.set_xlabel("logmp0")
ax.set_ylabel(r"$t_{\rm peak}\,  [{\rm Gyr}]$")
plt.show()

In [None]:
# Cumulative distribution of t_peak in bins of present-day halo mass.
mpeak_bins = np.arange(10.75, 15.1, 0.5)
mpeak_binsc = mpeak_bins[:-1] + np.diff(mpeak_bins)/2
_bins = np.linspace(0, TODAY+0.01, 30)
_binsc = _bins[:-1] + np.diff(_bins)/2
cmap = plt.get_cmap("viridis")(np.linspace(0,1,len(mpeak_bins)-1))
for i in range(len(mpeak_bins)-1):
    sel = (logmp0_data > mpeak_bins[i]) & (logmp0_data < mpeak_bins[i+1])
    vals = np.cumsum(np.histogram(mah_pars_ntuple.t_peak[sel], _bins)[0].astype(float))
    if vals[-1] == 0:
        # No halo of this mass bin falls in the t_peak range: normalising would
        # be 0/0, leaving an invisible NaN curve with a legend entry.
        continue
    vals /= vals[-1]
    plt.plot(_binsc, vals, color=cmap[i], label=r"$M_0=%.1f$"%mpeak_binsc[i])
plt.legend(fontsize=14)
plt.xlabel(r"$t_{\rm peak}\,  [{\rm Gyr}]$")
plt.ylabel("CDF")
plt.ylim(0,1)
plt.xlim(0,14)
plt.xticks(np.arange(0,14.1, 2))
plt.show()



In [None]:
# Left: CDF of the quenching time lgqt per halo-mass bin. Right: shift of that
# CDF for haloes with below-/above-median t_peak within the same mass bin.
# (A third "p50" panel used to be drawn empty: its plotting code was disabled,
# with the `data.p50` lines commented out, so the panel has been dropped.)
lgqt = q_fit_params[:, 0]

mpeak_bins = np.arange(10.75, 15.1, 0.5)
mpeak_binsc = mpeak_bins[:-1] + np.diff(mpeak_bins)/2

_bins = np.linspace(0, 2.0, 30)
_binsc = _bins[:-1] + np.diff(_bins)/2
cmap = plt.get_cmap("viridis")(np.linspace(0,1,len(mpeak_bins)-1))

fig, ax = plt.subplots(1, 2, figsize=(11, 4))
for i in range(len(mpeak_bins)-1):
    sel = (logmp0_data > mpeak_bins[i]) & (logmp0_data < mpeak_bins[i+1])
    if not sel.any():
        continue  # empty mass bin: nothing to plot and no legend entry
    lgqt_sel = lgqt[sel]
    tpeak_sel = mah_pars_ntuple.t_peak[sel]
    med_tpeak = np.median(tpeak_sel)

    # CDFs of the whole bin, then of its low- and high-t_peak halves.
    cdfs = []
    for subset in (slice(None), tpeak_sel < med_tpeak, tpeak_sel > med_tpeak):
        cum = np.cumsum(np.histogram(lgqt_sel[subset], _bins)[0].astype(float))
        cdfs.append(cum / cum[-1])
    cdf_all, cdf_lo, cdf_hi = cdfs

    ax[0].plot(_binsc, cdf_all, color=cmap[i], label=r"$%.1f$"%mpeak_binsc[i])
    ax[1].plot(_binsc, cdf_lo - cdf_all, color=cmap[i], ls='--')
    ax[1].plot(_binsc, cdf_hi - cdf_all, color=cmap[i], ls=':')

ax[0].legend(fontsize=14, title=r"$M_0$", title_fontsize=13, loc=4)
ax[1].legend(
    handles=[
        Line2D([0], [0], color='k', ls='--', label=r"Low $t_{\rm peak}$"),
        Line2D([0], [0], color='k', ls=':', label=r"High $t_{\rm peak}$"),
    ],
    loc=0, ncol=1, fontsize=14,
)

for panel in ax:
    panel.set_xlabel("lgqt")
    # lgqt = log10(TODAY) separates galaxies quenched by today from the rest.
    panel.axvline(np.log10(TODAY), ls=':', color='k')
ax[0].set_ylabel("CDF")
ax[1].set_ylabel(r"$\Delta$CDF")
ax[1].set_ylim(-0.2, 0.2)
fig.subplots_adjust(wspace=0.25)

plt.show()



In [None]:
# Column 4 of the main-sequence parameters is tau_dep (the 5th argument of
# _get_unbounded_sfr_params). It was previously stored in `lgqt` and every
# axis was labelled "lgqt".
# Left: CDF of tau_dep per halo-mass bin. Right: shift of that CDF for haloes
# with below-/above-median t_peak in the same mass bin.
tau_dep_all = ms_fit_params[:, 4]

mpeak_bins = np.arange(10.75, 15.1, 0.5)
mpeak_binsc = mpeak_bins[:-1] + np.diff(mpeak_bins)/2

_bins = np.linspace(tau_dep_all.min(), tau_dep_all.max(), 30)
_binsc = _bins[:-1] + np.diff(_bins)/2
cmap = plt.get_cmap("viridis")(np.linspace(0,1,len(mpeak_bins)-1))

fig, ax = plt.subplots(1, 2, figsize=(11, 4))
for i in range(len(mpeak_bins)-1):
    sel = (logmp0_data > mpeak_bins[i]) & (logmp0_data < mpeak_bins[i+1])
    if not sel.any():
        continue  # empty mass bin: nothing to plot and no legend entry
    tau_sel = tau_dep_all[sel]
    tpeak_sel = mah_pars_ntuple.t_peak[sel]
    med_tpeak = np.median(tpeak_sel)

    # CDFs of the whole bin, then of its low- and high-t_peak halves.
    cdfs = []
    for subset in (slice(None), tpeak_sel < med_tpeak, tpeak_sel > med_tpeak):
        cum = np.cumsum(np.histogram(tau_sel[subset], _bins)[0].astype(float))
        cdfs.append(cum / cum[-1])
    cdf_all, cdf_lo, cdf_hi = cdfs

    ax[0].plot(_binsc, cdf_all, color=cmap[i], label=r"$%.1f$"%mpeak_binsc[i])
    ax[1].plot(_binsc, cdf_lo - cdf_all, color=cmap[i], ls='--')
    ax[1].plot(_binsc, cdf_hi - cdf_all, color=cmap[i], ls=':')

ax[0].legend(fontsize=14, title=r"$M_0$", title_fontsize=13, loc=4)
ax[1].legend(
    handles=[
        Line2D([0], [0], color='k', ls='--', label=r"Low $t_{\rm peak}$"),
        Line2D([0], [0], color='k', ls=':', label=r"High $t_{\rm peak}$"),
    ],
    loc=0, ncol=1, fontsize=14,
)

for panel in ax:
    panel.set_xlabel(r"$\tau_{\rm dep}$")
ax[0].set_ylabel("CDF")
ax[1].set_ylabel(r"$\Delta$CDF")
ax[1].set_ylim(-0.2, 0.2)
fig.subplots_adjust(wspace=0.25)

plt.show()


In [None]:
# CDF of lgqt per halo-mass bin (solid), overlaid with the CDFs of the
# below-median (dashed) and above-median (dotted) t_peak halves of each bin.
lgqt = q_fit_params[:, 0]

mpeak_bins = np.arange(10.75, 15.1, 0.5)
mpeak_binsc = mpeak_bins[:-1] + np.diff(mpeak_bins)/2

_bins = np.linspace(0, 2.0, 30)
_binsc = _bins[:-1] + np.diff(_bins)/2
cmap = plt.get_cmap("viridis")(np.linspace(0,1,len(mpeak_bins)-1))

fig, ax = plt.subplots(1,1)
for i in range(len(mpeak_bins)-1):
    sel = (logmp0_data > mpeak_bins[i]) & (logmp0_data < mpeak_bins[i+1])
    if not sel.any():
        continue  # empty mass bin: nothing to plot and no legend entry
    lgqt_sel = lgqt[sel]
    tpeak_sel = mah_pars_ntuple.t_peak[sel]
    med_tpeak = np.median(tpeak_sel)

    cdfs = []
    for subset in (slice(None), tpeak_sel < med_tpeak, tpeak_sel > med_tpeak):
        cum = np.cumsum(np.histogram(lgqt_sel[subset], _bins)[0].astype(float))
        cdfs.append(cum / cum[-1])
    cdf_all, cdf_lo, cdf_hi = cdfs

    ax.plot(_binsc, cdf_all, color=cmap[i], label=r"$M_0=%.1f$"%mpeak_binsc[i])
    ax.plot(_binsc, cdf_lo, color=cmap[i], ls='--')
    ax.plot(_binsc, cdf_hi, color=cmap[i], ls=':')

# Two legends on one axis: mass bins (outside, top) and line styles (below it).
# A second ax.legend() call replaces the first, so re-add it as an artist.
legend1 = ax.legend(fontsize=14, bbox_to_anchor=(1.,1.))
legend_elements = [
    Line2D([0], [0], color='k', ls='--', label=r"Low $t_{\rm peak}$"),
    Line2D([0], [0], color='k', ls=':', label=r"High $t_{\rm peak}$"),
]
legend2 = ax.legend(handles=legend_elements, loc=2, ncol=1, bbox_to_anchor=(1.,0.25), fontsize=14, )
ax.add_artist(legend1)

# lgqt = log10(TODAY) separates galaxies quenched by today from the rest.
ax.axvline(np.log10(TODAY), ls=':', color='k')
ax.set_xlabel("lgqt")
ax.set_ylabel("CDF")
plt.show()



In [None]:
# Sigmoid parameters for every Diffstar parameter: MS dict entries first, then
# Q, in dict order. param_names follows the same order; underscores become
# hyphens because these names go into plot titles while text.usetex is on.
sigmoid_param_bounds = [
    *MS_BOUNDING_SIGMOID_PDICT.values(),
    *Q_BOUNDING_SIGMOID_PDICT.values(),
]
param_names = [
    key.replace("_", "-")
    for key in (*MS_BOUNDING_SIGMOID_PDICT, *Q_BOUNDING_SIGMOID_PDICT)
]


In [None]:
from diffsky.diffndhist import tw_ndhist
from scipy.optimize import minimize
import time

@partial(jjit, static_argnames=["n_histories"])
def _fun(
    params, 
    loss_data,
    n_histories, 
):
    """Squared-error loss between a model histogram and a target histogram.

    ``params = (mu, sig)`` define a normal distribution in unbounded space.
    ``n_histories`` draws (from ``ran_key``) are mapped to bounded space with
    ``_sigmoid(..., *sigmoid_params)``, binned with ``tw_ndhist`` on the
    ``[ndbins_lo, ndbins_hi]`` edges with smoothing ``ndsig``, normalised to
    unit sum, and compared with ``target`` via the sum of squared differences.

    Parameters
    ----------
    params : (mu, sig)
    loss_data : tuple
        (ran_key, ndsig, ndbins_lo, ndbins_hi, sigmoid_params, target)
    n_histories : int
        Number of draws; static under jit.

    NOTE(review): later cells redefine ``_fun`` as a linear model, so re-running
    the fitting cells after those cells silently uses the wrong function.
    """
    ran_key, ndsig, ndbins_lo, ndbins_hi, sigmoid_params, target = loss_data
    mu, sig = params
    normal_draws = jran.normal(ran_key, shape=(n_histories,))
    bounded = _sigmoid(mu + sig * normal_draws, *sigmoid_params)
    model_counts = tw_ndhist(
        bounded.reshape(n_histories, 1), ndsig, ndbins_lo, ndbins_hi
    )
    model_pdf = model_counts / jnp.sum(model_counts)
    return jnp.sum((model_pdf - target) ** 2)

# Gradient with respect to `params` only; n_histories stays static.
_fun_deriv = jjit(grad(_fun, argnums=(0)), static_argnames=["n_histories"])


def draw_samples(params, ran_key, n_histories, sigmoid_params):
    """Draw bounded samples from the model defined by ``params = (mu, sig)``.

    Uses the same construction (and, for the same key and ``n_histories``,
    the same normal draws) as the loss ``_fun``.
    """
    mu, sig = params
    unbounded = mu + sig * jran.normal(ran_key, shape=(n_histories,))
    return _sigmoid(unbounded, *sigmoid_params)

def _fun_deriv_np(params, data, n_histories):
    """Gradient of the loss as a float64 NumPy array, for scipy's ``jac`` argument."""
    gradient = _fun_deriv(params, data, n_histories)
    return np.array(gradient, dtype=float)

In [None]:
# Per-parameter value windows (bounded space). In the fitting loops below they
# select which galaxies enter a fit and define its 50 histogram bin edges.
# Indexed like `sigmoid_param_bounds` / `param_names` (MS entries, then Q).
# NOTE(review): the row labels below assume the dict order lgmcrit,
# lgy_at_mcrit, indx_lo, indx_hi, tau_dep, lg_qt, qlglgdt, lg_drop, lg_rejuv
# (the argument order of the unbounding functions) — confirm.
fit_limits = (
    [10.5, 13.0],    # lgmcrit (presumably)
    [-2.0, -0.01],   # lgy_at_mcrit (presumably)
    [0.01, 3.0],     # indx_lo (presumably)
    [-10.0, 10.0],   # indx_hi (presumably; not in _free_params)
    [0.1, 19.0],     # tau_dep (presumably)
    [0.3, 1.5],      # lg_qt (presumably)
    [-2.0, -0.05],   # qlglgdt (presumably)
    [-2.9, -0.1],    # lg_drop (presumably)
    [-2.9, -0.01],   # lg_rejuv (presumably)
)
from scipy.signal import peak_widths

def get_init_params(params):
    """Starting (mu, sigma) for the Gaussian-in-unbounded-space fit of `params`.

    mu is the sample median. sigma comes from the half-width at half maximum
    of a 100-bin histogram spanning median +/- 5 x (half the central 68.27%
    interval), rescaled by 0.68 / (fraction of samples within that half-width
    of the median).

    Degenerate samples fall back to simpler estimates instead of yielding NaN
    (previously 0/0). If the central 68.27% interval has zero width, the
    sample standard deviation is returned. If no sample lies within the
    half-width, the percentile-based width is returned.

    Parameters
    ----------
    params : array_like
        1-D sample of (unbounded) parameter values.

    Returns
    -------
    (median, sigma) : tuple of scalars

    Raises
    ------
    ValueError
        If `params` is empty.
    """
    params = np.asarray(params)
    if params.size == 0:
        raise ValueError("get_init_params requires at least one sample")

    median = np.median(params)
    p_lo, p_hi = np.percentile(params, [50 - 68.27 / 2.0, 50 + 68.27 / 2.0])
    sigma_pct = (p_hi - p_lo) / 2.0
    if not sigma_pct > 0:
        # Zero-width histogram bins would make the FWHM estimate meaningless.
        return median, np.std(params)

    bins = np.linspace(median - 5 * sigma_pct, median + 5 * sigma_pct, 100)
    counts = np.histogram(params, bins)[0]
    peak_ind = np.atleast_1d(np.argmax(counts))
    # peak_widths returns widths in bin units; convert to a half-width in x.
    half_width = np.diff(bins)[0] * peak_widths(counts, peak_ind, rel_height=0.5)[0][0] / 2.0
    frac_within = np.mean((median - half_width < params) & (params < median + half_width))
    if frac_within == 0:
        return median, sigma_pct
    return median, half_width * 0.68 / frac_within


# Quenched galaxies: per-mass-bin fits of the parameter distributions

In [None]:
# Bounded and unbounded parameters side by side: columns 0-4 are the five MS
# parameters, columns 5-8 the four Q parameters.
fit_params = np.concatenate((ms_fit_params, q_fit_params), axis=1)
u_fit_params = np.concatenate((u_ms_fit_params, u_q_fit_params), axis=1)

# Present-day halo-mass bin edges and centres for the per-bin fits.
massvals = np.arange(11.5, 13.99, 0.1)
massvalsc = 0.5*(massvals[1:]+massvals[:-1])
n_massbins = len(massvalsc)

# Columns of fit_params that are fitted; column 3 is skipped.
_free_params = np.array([0, 1, 2, 4, 5, 6, 7, 8])

# Per-mass-bin results: quenched fits cover all 8 free columns, MS fits the
# first 4. The correlation arrays are placeholders (overwritten later on).
medians_Q, sigmas_Q = np.zeros((n_massbins, 8)), np.zeros((n_massbins, 8))
medians_MS, sigmas_MS = np.zeros((n_massbins, 4)), np.zeros((n_massbins, 4))
corrs_Q = np.ones((n_massbins, 8, 8))
corrs_MS = np.ones((n_massbins, 4, 4))

fquench_cen = np.zeros(n_massbins)
fquench_sat = np.zeros(n_massbins)
fquench_sat = np.zeros(len(massvalsc))

In [None]:
# Per-mass-bin fit of each free parameter's distribution for quenched galaxies,
# modelled as a normal in unbounded space mapped through the sigmoid. The same
# key is reused for every fit (common random numbers across bins).
ran_key = jran.PRNGKey(0)
n_histories = int(1e4)

# Quenched = quenching time at or before today; loop-invariant, computed once.
quenched = (10**q_fit_params[:,0] <= TODAY)

for j_id in range(8):
    j = _free_params[j_id]
    sigmoid_params = sigmoid_param_bounds[j]

    # Histogram set-up depends only on the parameter, not on the mass bin.
    bins = np.linspace(fit_limits[j][0], fit_limits[j][1], 50)
    bins_LO = bins[:-1].reshape((len(bins)-1), 1)
    bins_HI = bins[1:].reshape((len(bins)-1), 1)
    ndsig = np.ones((n_histories, 1)) * np.diff(bins)[0]
    bins_plot = np.linspace(sigmoid_param_bounds[j][2], sigmoid_param_bounds[j][3], 50)
    binsc_plot = 0.5*(bins_plot[:-1] + bins_plot[1:])

    for i in range(len(massvals)-1):
        masslow, masshigh = massvals[i], massvals[i+1]
        mask_mass = (logmp0_data>masslow)&((logmp0_data<masshigh))
        in_sample = mask_mass & quenched

        _val_unbound = u_fit_params[in_sample, j].copy()
        _val_bound = fit_params[in_sample, j].copy()

        # Fit only galaxies inside the parameter's fit window.
        msk = (_val_bound > fit_limits[j][0]) & (_val_bound < fit_limits[j][1])
        assert msk.sum() > 0, "no quenched galaxies in the fit window of %s at m0 = %.2f" % (
            param_names[j], massvalsc[i])
        _val_unbound = _val_unbound[msk]
        _val_bound = _val_bound[msk]

        target_counts = np.histogram(_val_bound, bins, density=0)[0]
        target_counts = target_counts / np.sum(target_counts)

        # Start from the robust width estimate, inflated by 2.
        init_params = np.array(get_init_params(_val_unbound))
        init_params[1] *= 2.0
        loss_data = (
            ran_key, ndsig, bins_LO, bins_HI, sigmoid_params, target_counts
        )
        res = minimize(
            _fun, x0=init_params, method="L-BFGS-B", jac=_fun_deriv_np,
            args=(loss_data, n_histories, )
            )
        if not res.success:
            warnings.warn("Q fit of %s at m0 = %.2f did not converge: %s"
                          % (param_names[j], massvalsc[i], res.message))
        p_best = res.x
        best_samples = draw_samples(p_best, ran_key, n_histories, sigmoid_params)

        # Visual check over the full sigmoid range, using every quenched galaxy
        # of the bin (not only those inside the fit window).
        _val_bound = fit_params[in_sample, j].copy()
        target_counts_plot = np.histogram(_val_bound, bins_plot, density=0)[0]
        target_counts_plot = target_counts_plot / np.sum(target_counts_plot)

        best_counts_plot = np.histogram(best_samples, bins_plot, density=0)[0]
        best_counts_plot = best_counts_plot / np.sum(best_counts_plot)

        plt.plot(binsc_plot, target_counts_plot)
        plt.plot(binsc_plot, best_counts_plot, ls='--')
        plt.title("%s, m0 = %.2f"%(param_names[j], massvalsc[i]))
        plt.show()

        medians_Q[i,j_id] = p_best[0]
        sigmas_Q[i,j_id] = p_best[1]


# Main-sequence galaxies: per-mass-bin fits of the parameter distributions

In [None]:
# Per-mass-bin fit of the first four free (MS) parameters for galaxies not
# quenched by today. Bins with fewer than 30 galaxies in the fit window are
# recorded as NaN. The same key is reused for every fit.
ran_key = jran.PRNGKey(0)
n_histories = int(1e4)

# Quenched = quenching time at or before today; loop-invariant, computed once.
quenched = (10**q_fit_params[:,0] <= TODAY)

for j_id in range(4):
    j = _free_params[j_id]
    sigmoid_params = sigmoid_param_bounds[j]

    # Histogram set-up depends only on the parameter, not on the mass bin.
    bins = np.linspace(fit_limits[j][0], fit_limits[j][1], 50)
    bins_LO = bins[:-1].reshape((len(bins)-1), 1)
    bins_HI = bins[1:].reshape((len(bins)-1), 1)
    ndsig = np.ones((n_histories, 1)) * np.diff(bins)[0]
    bins_plot = np.linspace(sigmoid_param_bounds[j][2], sigmoid_param_bounds[j][3], 50)
    binsc_plot = 0.5*(bins_plot[:-1] + bins_plot[1:])

    for i in range(len(massvals)-1):
        masslow, masshigh = massvals[i], massvals[i+1]
        mask_mass = (logmp0_data>masslow)&((logmp0_data<masshigh))
        in_sample = mask_mass & (~quenched)

        _val_unbound = u_fit_params[in_sample, j].copy()
        _val_bound = fit_params[in_sample, j].copy()

        msk = (_val_bound > fit_limits[j][0]) & (_val_bound < fit_limits[j][1])
        if msk.sum() < 30:
            # Too few galaxies for a stable fit: mark this bin as missing.
            medians_MS[i,j_id] = np.nan
            sigmas_MS[i,j_id] = np.nan
            continue
        _val_unbound = _val_unbound[msk]
        _val_bound = _val_bound[msk]

        target_counts = np.histogram(_val_bound, bins, density=0)[0]
        target_counts = target_counts / np.sum(target_counts)

        # Start from the robust width estimate, inflated by 2.
        init_params = np.array(get_init_params(_val_unbound))
        init_params[1] *= 2.0
        loss_data = (
            ran_key, ndsig, bins_LO, bins_HI, sigmoid_params, target_counts
        )
        res = minimize(
            _fun, x0=init_params, method="L-BFGS-B", jac=_fun_deriv_np,
            args=(loss_data, n_histories, )
            )
        if not res.success:
            warnings.warn("MS fit of %s at m0 = %.2f did not converge: %s"
                          % (param_names[j], massvalsc[i], res.message))
        p_best = res.x
        best_samples = draw_samples(p_best, ran_key, n_histories, sigmoid_params)

        # Visual check over the full sigmoid range, using every MS galaxy of
        # the bin (not only those inside the fit window).
        _val_bound = fit_params[in_sample, j].copy()
        target_counts_plot = np.histogram(_val_bound, bins_plot, density=0)[0]
        target_counts_plot = target_counts_plot / np.sum(target_counts_plot)

        best_counts_plot = np.histogram(best_samples, bins_plot, density=0)[0]
        best_counts_plot = best_counts_plot / np.sum(best_counts_plot)

        plt.plot(binsc_plot, target_counts_plot)
        plt.plot(binsc_plot, best_counts_plot, ls='--')
        plt.title("%s, m0 = %.2f"%(param_names[j], massvalsc[i]))
        plt.show()

        medians_MS[i,j_id] = p_best[0]
        sigmas_MS[i,j_id] = p_best[1]


In [None]:
# Identity correlations: every parameter was fitted independently above, so
# only the per-bin widths enter the covariance matrices.
corrs_Q = np.array([np.eye(8) for i in range(len(massvalsc))])
corrs_MS = np.array([np.eye(4) for i in range(len(massvalsc))])

# cov[m, i, j] = sigma[m, i] * corr[m, i, j] * sigma[m, j]
covs_Q = jnp.einsum("mi,mij,mj->mij", sigmas_Q, corrs_Q, sigmas_Q)
covs_MS = jnp.einsum("mi,mij,mj->mij", sigmas_MS, corrs_MS, sigmas_MS)


def _cholesky_or_nan(cov):
    """Cholesky factor of `cov`, or an all-NaN matrix when it cannot be formed.

    The MS fitting loop stores NaN widths for sparse mass bins, giving NaN
    covariances here. A NaN or non-positive-definite matrix cannot be
    factorised (np.linalg.cholesky raises LinAlgError on the latter), which
    previously aborted the whole cell. Returning NaN keeps such bins maskable
    with np.isfinite, as the later fitting cells already do.
    """
    cov = np.asarray(cov)
    if not np.all(np.isfinite(cov)):
        return np.full(cov.shape, np.nan, dtype=cov.dtype)
    try:
        return np.linalg.cholesky(cov)
    except np.linalg.LinAlgError:
        return np.full(cov.shape, np.nan, dtype=cov.dtype)


chols_Q = np.array([_cholesky_or_nan(x) for x in covs_Q])
chols_MS = np.array([_cholesky_or_nan(x) for x in covs_MS])

## Quenched fits

In [None]:
# Steepness of the sigmoid that maps unbounded frac_quench parameters into bounds.
BOUNDING_K = 0.1

# (lo, hi) bounds of the quenched-fraction model parameters, in model order.
SFH_PDF_FRAC_QUENCH_BOUNDS_PDICT = OrderedDict([
    ("frac_quench_x0", (10.0, 13.0)),
    ("frac_quench_k", (0.01, 5.0)),
    ("frac_quench_ylo", (0.0, 1.0)),
    ("frac_quench_yhi", (0.0, 1.0)),
])

@jjit
def _get_p_from_u_p_scalar(u_p, bounds):
    """Map an unbounded value into ``bounds = (lo, hi)``.

    Uses ``_sigmoid`` centred on the interval midpoint with steepness
    ``BOUNDING_K`` and asymptotes ``lo`` and ``hi``.
    """
    lo, hi = bounds
    return _sigmoid(u_p, (lo + hi) / 2.0, BOUNDING_K, lo, hi)


# Quenched fraction per halo-mass bin, separately for centrals (upid == -1)
# and satellites (upid != -1).
quenched = (10**q_fit_params[:,0] <= TODAY)
is_central = (upid_samp == -1)
for i in range(len(massvals)-1):
    masslow, masshigh = massvals[i], massvals[i+1]
    in_bin = (logmp0_data > masslow) & (logmp0_data < masshigh)
    for population, fq_out in ((is_central, fquench_cen), (~is_central, fquench_sat)):
        mask_mass = in_bin & population
        fq_out[i] = quenched[mask_mass].sum()/len(quenched[mask_mass])
    
def _fun_fquench(x, u_x0, u_k, u_ymin, u_ymax):
    """Quenched-fraction model in log halo mass, parametrised in unbounded space.

    The four unbounded parameters are mapped to bounded (x0, k, ymin, ymax)
    and the model is ``_sigmoid(x, x0, k, ymin, ymax)``, which keeps
    curve_fit inside the allowed parameter ranges.
    """
    x0, k, ymin, ymax = bound_frac_quench_params(u_x0, u_k, u_ymin, u_ymax)
    return _sigmoid(x, x0, k, ymin, ymax)

def bound_frac_quench_params(u_x0, u_k, u_ymin, u_ymax):
    """Map unbounded frac_quench parameters to bounded (x0, k, ymin, ymax)."""
    keys = ("frac_quench_x0", "frac_quench_k", "frac_quench_ylo", "frac_quench_yhi")
    u_values = (u_x0, u_k, u_ymin, u_ymax)
    return tuple(
        _get_p_from_u_p_scalar(u_val, SFH_PDF_FRAC_QUENCH_BOUNDS_PDICT[key])
        for u_val, key in zip(u_values, keys)
    )

# Fit the sigmoid quenched-fraction model (in unbounded space) to the
# centrals' and satellites' per-bin quenched fractions.
median_lims_fquench_cen = np.array(
    curve_fit(_fun_fquench, massvalsc, fquench_cen, p0=[20, 4.0, -100., 70.])[0],
    dtype=float,
)
median_lims_fquench_sat = np.array(
    curve_fit(_fun_fquench, massvalsc, fquench_sat, p0=[20, 4.0, -100., 70.])[0],
    dtype=float,
)

# Measured fractions (solid) and fitted model (dashed, same colour).
for fq_obs, u_lims, label in (
    (fquench_cen, median_lims_fquench_cen, 'Centrals'),
    (fquench_sat, median_lims_fquench_sat, 'Satellites'),
):
    (line,) = plt.plot(massvalsc, fq_obs, label=label)
    plt.plot(massvalsc, _fun_fquench(massvalsc, *u_lims), ls='--', color=line.get_color())
plt.ylim(0,1)
plt.xlabel("logmp0")
plt.ylabel("frac_quench")
plt.legend()
plt.title(f"{sim_name}")
plt.show()

# Convert to bounded (x0, k, ylo, yhi) and keep each value slightly inside its
# bounds: x0 in [10.2, 12.8], k in [0.2, 4.8], ylo/yhi in [0.002, 0.998].
median_lims_fquench_cen = np.array(bound_frac_quench_params(*median_lims_fquench_cen))
median_lims_fquench_sat = np.array(bound_frac_quench_params(*median_lims_fquench_sat))
for lims in (median_lims_fquench_cen, median_lims_fquench_sat):
    lims[2:] = np.clip(lims[2:], 0.002, 0.998)
    lims[0] = np.clip(lims[0], 10.2, 12.8)
    lims[1] = np.clip(lims[1], 0.2, 4.8)

# The central quenched fraction weights the linear median fits below.
fquench = fquench_cen

In [None]:
# Display the bounded, clipped (x0, k, ylo, yhi) for centrals and satellites.
median_lims_fquench_cen, median_lims_fquench_sat

In [None]:
@jjit
def smoothly_clipped_line(x, x0, y0, m, y_lo, y_hi):
    """Line through (x0, y0) with slope m, smoothly clipped near [y_lo, y_hi].

    The line ``y = y0 + m * (x - x0)`` is fed through two nested ``_sigmoid``
    transitions (assuming ``_sigmoid(x, xc, k, a, b)`` moves from a to b as x
    increases past xc for k > 0). The lower plateau is ``y_lo_bound`` and the
    upper one ``y_hi_bound``. Since eps_lo / eps_hi are 1% of the x-distance
    from x0 to where the line reaches each bound, these plateaus equal
    ``y_lo + (y0 - y_lo) / 100`` and ``y_hi - (y_hi - y0) / 100``.

    NOTE(review): divides by ``m`` and by ``eps_lo``, so ``m == 0`` or
    ``y0 == y_lo`` is undefined. For ``m < 0`` eps_lo, eps_hi and CLIPPING_K
    change sign, which appears to flip both transitions consistently — confirm.
    NOTE(review): both transitions use the CLIPPING_K derived from ``eps_lo``;
    confirm the upper transition is not meant to use ``eps_hi``.
    NOTE(review): not called by any cell visible in this notebook section.
    """
    x_lo = (y_lo - y0) / m + x0  # value of x at which y=y_lo
    x_hi = (y_hi - y0) / m + x0  # value of x at which y=y_hi

    dx_lo = x0 - x_lo
    dx_hi = x_hi - x0

    eps_lo = dx_lo / 100.0  # tiny step in x above x_lo
    eps_hi = dx_hi / 100.0  # tiny step in x below x_hi

    xc_lo = x_lo + eps_lo  # value of x at which y=y_lo+ε
    xc_hi = x_hi - eps_hi  # value of x at which y=y_hi-ε

    y_lo_bound = y0 + m * (xc_lo - x0)  # value of y=y_lo+ε
    y_hi_bound = y0 + m * (xc_hi - x0)  # value of y=y_hi-ε

    CLIPPING_K = 20.0 / eps_lo  # Steep transition speed for this Δx (also reused for the upper transition)
    y_unclipped = y0 + m * (x - x0)
    # Lower clip: below xc_lo the value is y_lo_bound, above it the bare line.
    y_clipped_from_below = _sigmoid(x, xc_lo, CLIPPING_K, y_lo_bound, y_unclipped)
    # Upper clip: below xc_hi the lower-clipped line, above it y_hi_bound.
    y_clipped = _sigmoid(x, xc_hi, CLIPPING_K, y_clipped_from_below, y_hi_bound)
    return y_clipped

In [None]:
def _fun(x, int, slp):
    """Linear model in log halo mass with pivot at 12.5: ``slp * (x - 12.5) + int``.

    NOTE(review): this redefines the histogram loss ``_fun`` used by the fitting
    cells above, and ``int`` shadows the builtin. Both names are kept because
    later cells call ``_fun(x, *lims)`` with this signature.
    """
    return slp * (x - 12.5) + int


# Extra per-column selections applied before the linear fit of the quenched
# medians; every other column uses all finite mass bins. Each column is now
# fitted once (previously columns 0, 2, 3 and 5 were fitted twice).
# NOTE(review): column 5 holds the unbounded mean of the qlglgdt ("uqs") fit,
# while the log10(TODAY) threshold looks like a bounded lg_qt value (column 4)
# — confirm the intended column and space.
_extra_masks_Q = {
    2: massvalsc > 12.0,
    3: massvalsc < 12.5,
    5: medians_Q[:, 5] < np.log10(TODAY) + 0.1,
}

median_lims_Q = np.zeros((2,8))
for i in range(8):
    _mask = np.isfinite(medians_Q[:, i]) & _extra_masks_Q.get(i, True)
    _y = medians_Q[_mask, i]
    # p0 = [intercept, slope] only seeds the solver (the model is linear).
    # sigma = 1/fquench gives more weight to bins with many quenched galaxies.
    _res = curve_fit(_fun, massvalsc[_mask], _y, p0=[_y[0], _y[-1]], sigma=1.0/fquench[_mask])
    median_lims_Q[:, i] = _res[0]


In [None]:
# Labels of the 8 free parameters (columns of medians_Q).
names = ["ulgm", "ulgy", "ul", "utau", "uqt", "uqs", "udrop", "urej"]

# Per-bin medians of the quenched-sample fits (solid) and their linear fits in
# mass (dashed), one figure per parameter. The old per-index if/else branches
# were identical and have been merged.
for i in range(8):
    (line,) = plt.plot(massvalsc, medians_Q[:, i])
    plt.plot(massvalsc, _fun(massvalsc, *median_lims_Q[:, i]), ls='--', color=line.get_color())
    plt.title(names[i])
    plt.xlabel("logmp0")
    plt.ylabel("median (unbounded)")
    plt.show()

In [None]:
def _fun(x, int, slp):
    """Linear model in log halo mass with pivot at 12.5: ``slp * (x - 12.5) + int``."""
    return slp * (x - 12.5) + int


# Linear fits (vs. mass) of the quenched-sample covariance: the square root of
# each diagonal entry (the widths) and every lower-triangle entry.
# NOTE(review): the off-diagonal entries stored here are covariances, not
# Cholesky elements, despite the array name. With the identity correlations
# set earlier they are zero either way.
chol_lims_Q = np.zeros((2,8,8))
for i, j in zip(*np.tril_indices(8)):
    _targ = np.sqrt(covs_Q[:, i, i]) if i == j else covs_Q[:, i, j]
    _mask = np.isfinite(_targ)
    _targ = _targ[_mask]
    _res = curve_fit(_fun, massvalsc[_mask], _targ,
                     p0=[_targ[0], _targ[-1]], sigma=1.0/fquench[_mask])
    chol_lims_Q[:, i, j] = _res[0]

names = ["ulgm", "ulgy", "ul", "utau", "uqt", "uqs", "udrop", "urej"]

for i in range(8):
    (line,) = plt.plot(massvalsc, np.sqrt(covs_Q[:, i, i]))
    plt.plot(massvalsc, _fun(massvalsc, *chol_lims_Q[:, i, i]), ls='--', color=line.get_color())
    plt.title("STD(%s)"%names[i])
    plt.show()
# Off-diagonal plots are omitted (all zero with identity correlations). They
# used to sit in a trailing string literal, which Jupyter echoed as output.



## Main sequence fits

In [None]:
def _fun(x, int, slp):
    """Straight line in ``x`` (logmp0) pivoting at x = 12.5; ``int`` is the value there."""
    return slp * (x - 12.5) + int


# Upper logmp0 cut applied to the fit of each main-sequence parameter column
# (index 3, utau in `names`, uses a tighter cut than the other three).
MS_FIT_LOGMP0_MAX = (12.8, 12.8, 12.8, 12.5)

median_lims_MS = np.zeros((2, 8))

# Main-sequence fits weight bins with sigma = fquench, i.e. bins with a smaller
# quenched fraction count more. Row 0 of median_lims_MS is the intercept, row 1 the slope.
for i, logmp0_max in enumerate(MS_FIT_LOGMP0_MAX):
    _mask = np.isfinite(medians_MS[:, i]) & (massvalsc < logmp0_max)
    _targ = medians_MS[_mask, i]
    _res = curve_fit(
        _fun, massvalsc[_mask], _targ, p0=[_targ[0], _targ[-1]], sigma=fquench[_mask]
    )
    median_lims_MS[:, i] = _res[0]


In [None]:
names = ["ulgm", "ulgy", "ul", "utau", "uqt", "uqs", "udrop", "urej"]

# One figure per main-sequence parameter: median (solid) and its line fit (dashed).
for i in range(0, 4):
    (line,) = plt.plot(massvalsc, medians_MS[:, i])
    line_fit = _fun(massvalsc, *median_lims_MS[:, i])
    plt.plot(massvalsc, line_fit, ls='--', color=line.get_color())
    plt.title(names[i])
    plt.show()
    

In [None]:
# Overlay, per main-sequence parameter, the quenched (Q) and main-sequence (MS)
# medians (solid) with their straight-line fits (dashed, same colour).
for i in range(0, 4):
    (line,) = plt.plot(massvalsc, medians_Q[:, i], label='Quenched')
    plt.plot(massvalsc, _fun(massvalsc, *median_lims_Q[:, i]), ls='--', color=line.get_color())

    (line,) = plt.plot(massvalsc, medians_MS[:, i], label='Main Seq.')
    plt.plot(massvalsc, _fun(massvalsc, *median_lims_MS[:, i]), ls='--', color=line.get_color())

    plt.title(names[i])
    plt.xlabel("logmp0")
    plt.legend()
    plt.show()

In [None]:
chol_lims_MS = np.zeros((2, 4, 4))


def _fit_line_mseq(values):
    """Fit ``_fun`` to the finite entries of ``values`` (one value per ``massvalsc`` bin).

    Uses sigma = fquench, so bins with a smaller quenched fraction carry more
    weight. The initial guess is (first, last) finite value. Returns ``(int, slp)``.
    """
    mask = np.isfinite(values)
    targ = values[mask]
    popt, _pcov = curve_fit(
        _fun, massvalsc[mask], targ, p0=[targ[0], targ[-1]], sigma=fquench[mask]
    )
    return popt


# Diagonal: line fit to sqrt(cov_ii); strictly lower triangle (j < i): raw cov_ij.
for i in range(4):
    chol_lims_MS[:, i, i] = _fit_line_mseq(np.sqrt(covs_MS[:, i, i]))
    for j in range(i):
        chol_lims_MS[:, i, j] = _fit_line_mseq(covs_MS[:, i, j])

names = ["ulgm", "ulgy", "ul", "utau", "uqt", "uqs", "udrop", "urej"]

for i in range(4):
    (line,) = plt.plot(massvalsc, np.sqrt(covs_MS[:, i, i]))
    plt.plot(massvalsc, _fun(massvalsc, *chol_lims_MS[:, i, i]), ls='--', color=line.get_color())
    plt.title("STD(%s)" % names[i])
    plt.show()

for i in range(4):
    for j in range(i):
        (line,) = plt.plot(massvalsc, covs_MS[:, i, j])
        plt.plot(massvalsc, _fun(massvalsc, *chol_lims_MS[:, i, j]), ls='--', color=line.get_color())
        plt.title("Chol(%s,%s)" % (names[i], names[j]))
        plt.show()

In [None]:
# Standard deviation of each main-sequence parameter: quenched population first,
# then main-sequence population, each with its dashed line fit, in one figure.
for i in range(4):
    std_q = np.sqrt(covs_Q[:, i, i])
    (line,) = plt.plot(massvalsc, std_q)
    plt.plot(massvalsc, _fun(massvalsc, *chol_lims_Q[:, i, i]), ls='--', color=line.get_color())
    plt.title("STD(%s)" % names[i])

    std_ms = np.sqrt(covs_MS[:, i, i])
    (line,) = plt.plot(massvalsc, std_ms)
    plt.plot(massvalsc, _fun(massvalsc, *chol_lims_MS[:, i, i]), ls='--', color=line.get_color())

    plt.show()
    

In [None]:
from diffstarpop.kernels.sfh_pdf_tpeak_line_sepms_satfrac_sigslope import SFH_PDF_QUENCH_MU_BOUNDS_PDICT

# Private handle on the *sigslope* bounds. A later cell re-imports
# SFH_PDF_QUENCH_MU_BOUNDS_PDICT from sfh_pdf_tpeak_line_sepms_satfrac, which
# would otherwise change the dict the functions below read at call time.
SIGSLOPE_MU_BOUNDS_PDICT = SFH_PDF_QUENCH_MU_BOUNDS_PDICT

BOUNDING_K = 0.1
# Order of the sigmoid-slope parameters, for curve_fit and for the bounds keys.
SIGSLOPE_PNAMES = ("xtp", "ytp", "lo", "hi")
# Initial guess for the unbounded (u_xtp, u_ytp, u_lo, u_hi) fit parameters.
SIGSLOPE_U_P0 = [12.5, 12.2, 1.0, 0.1]
# Fixed (not fitted) steepness of the slope transition.
SIGSLOPE_SLOPE_K = 3.0


@jjit
def _get_p_from_u_p_scalar(u_p, bounds):
    """Map an unbounded value ``u_p`` into ``bounds = (lo, hi)``.

    Calls ``_sigmoid`` centred on the interval midpoint with steepness BOUNDING_K.
    """
    lo, hi = bounds
    p0 = 0.5 * (lo + hi)
    p = _sigmoid(u_p, p0, BOUNDING_K, lo, hi)
    return p


@jjit
def _sig_slope(x, xtp, ytp, x0, slope_k, lo, hi):
    """Curve through (xtp, ytp) whose local slope is ``_sigmoid(x, x0, slope_k, lo, hi)``."""
    slope = _sigmoid(x, x0, slope_k, lo, hi)
    return ytp + slope * (x - xtp)


def return_sigslope_bound_params(u_xtp, u_ytp, u_lo, u_hi):
    """Convert unbounded (u_xtp, u_ytp, u_lo, u_hi) into bounded (xtp, ytp, lo, hi).

    All four use the ``mean_ulgm_mseq_*`` bounds; the same bounds are used for
    both the main-sequence and the quenched-sequence fits in this cell.
    """
    u_params = (u_xtp, u_ytp, u_lo, u_hi)
    return tuple(
        _get_p_from_u_p_scalar(u_p, SIGSLOPE_MU_BOUNDS_PDICT[f"mean_ulgm_mseq_{pname}"])
        for u_p, pname in zip(u_params, SIGSLOPE_PNAMES)
    )


def _fun_Mcrit(x, u_xtp, u_ytp, u_lo, u_hi):
    """Sigmoid-slope model of mean ulgm vs. logmp0, in unbounded parameters.

    The slope transition is centred on the turning point (x0 = xtp).
    """
    xtp, ytp, lo, hi = return_sigslope_bound_params(u_xtp, u_ytp, u_lo, u_hi)
    return _sig_slope(x, xtp, ytp, xtp, SIGSLOPE_SLOPE_K, lo, hi)


i = 0  # ulgm column

# Quenched population: sigma = 1/fquench weights bins with many quenched halos.
_mask = np.isfinite(medians_Q[:, i])
_ulgm_sigslope_qseq = curve_fit(
    _fun_Mcrit, massvalsc[_mask], medians_Q[_mask, i],
    p0=SIGSLOPE_U_P0, sigma=1.0 / fquench[_mask],
)

# Main-sequence population: sigma = fquench weights the opposite way.
_mask = np.isfinite(medians_MS[:, i])
_ulgm_sigslope_mseq = curve_fit(
    _fun_Mcrit, massvalsc[_mask], medians_MS[_mask, i],
    p0=SIGSLOPE_U_P0, sigma=fquench[_mask],
)

line = plt.plot(massvalsc, medians_Q[:, i], label='Quenched')[0]
plt.plot(massvalsc, _fun_Mcrit(massvalsc, *_ulgm_sigslope_qseq[0]), ls='--', color=line.get_color())
line = plt.plot(massvalsc, medians_MS[:, i], label='Main Seq.')[0]
plt.plot(massvalsc, _fun_Mcrit(massvalsc, *_ulgm_sigslope_mseq[0]), ls='--', color=line.get_color(), label='Diffstarpop')
plt.title(sim_name)
plt.ylabel("mean %s" % names[i])
plt.xlabel("logmp0")
plt.legend()
plt.show()

# From here on the two arrays hold *bounded* (xtp, ytp, lo, hi) values.
_ulgm_sigslope_qseq = np.array(return_sigslope_bound_params(*_ulgm_sigslope_qseq[0]))
_ulgm_sigslope_mseq = np.array(return_sigslope_bound_params(*_ulgm_sigslope_mseq[0]))

# Keep every bounded value at least 10% of the bound width away from either edge.
for k, pname in enumerate(SIGSLOPE_PNAMES):
    bound_min, bound_max = SIGSLOPE_MU_BOUNDS_PDICT[f"mean_ulgm_mseq_{pname}"]
    margin = 0.1 * abs(bound_max - bound_min)
    lower, upper = bound_min + margin, bound_max - margin
    _ulgm_sigslope_qseq[k] = np.clip(_ulgm_sigslope_qseq[k], lower, upper)
    _ulgm_sigslope_mseq[k] = np.clip(_ulgm_sigslope_mseq[k], lower, upper)

print(_ulgm_sigslope_qseq)
print(_ulgm_sigslope_mseq)



In [None]:
from diffstarpop.kernels.sfh_pdf_tpeak_line_sepms_satfrac import (
    SFH_PDF_QUENCH_MU_BOUNDS_PDICT,
    SFH_PDF_QUENCH_COV_MS_BLOCK_BOUNDS_PDICT,
    SFH_PDF_QUENCH_COV_Q_BLOCK_BOUNDS_PDICT
)

median_lims_Q = np.round(median_lims_Q, 3)
median_lims_MS = np.round(median_lims_MS, 3)
chol_lims_Q = np.round(chol_lims_Q, 3)
chol_lims_MS = np.round(chol_lims_MS, 3)


def clip_with_10pct_margin(value, bound_min, bound_max):
    """Clip ``value`` into [bound_min, bound_max] shrunk by 10% of the interval width per side.

    The margin is relative to the interval width (as in the sigslope cell), not to
    ``abs(bound_min)`` / ``abs(bound_max)``. The latter inverts the interval for
    positive bounds such as (11, 13): lower 12.1 > upper 11.7, and ``np.clip`` then
    returns the upper value for every input.
    """
    margin = 0.1 * abs(bound_max - bound_min)
    return np.clip(value, bound_min + margin, bound_max - margin)


def _clip_away_from_zero(value, bound_min, bound_max):
    """``clip_with_10pct_margin``, then nudge an exact 0.0 result to 0.01."""
    clipped = clip_with_10pct_margin(value, bound_min, bound_max)
    if clipped == 0.0:
        clipped += 1e-2
    return clipped


# "mseq" keys belong to the main-sequence population (fitted from medians_MS /
# covs_MS) and "qseq" keys to the quenched population (medians_Q / covs_Q), as in
# the sigslope cell, where mean_ulgm_mseq_* is fitted to medians_MS.
MS_BLOCK_MU_BY_SUFFIX = {"mseq": median_lims_MS, "qseq": median_lims_Q}
MS_BLOCK_CHOL_BY_SUFFIX = {"mseq": chol_lims_MS, "qseq": chol_lims_Q}
KINDS = ("int", "slp")  # row 0 = intercept, row 1 = slope of each line fit

# Means of the main-sequence-block parameters (separate mseq / qseq lines).
for i in range(4):
    for suffix, mu_arr in MS_BLOCK_MU_BY_SUFFIX.items():
        for idx, kind in enumerate(KINDS):
            bmin, bmax = SFH_PDF_QUENCH_MU_BOUNDS_PDICT[f"mean_{names[i]}_{suffix}_{kind}"]
            mu_arr[idx, i] = _clip_away_from_zero(mu_arr[idx, i], bmin, bmax)

# Means of the quenching parameters (quenched fits only).
for i in range(4, 8):
    for idx, kind in enumerate(KINDS):
        bmin, bmax = SFH_PDF_QUENCH_MU_BOUNDS_PDICT[f"mean_{names[i]}_{kind}"]
        median_lims_Q[idx, i] = _clip_away_from_zero(median_lims_Q[idx, i], bmin, bmax)

# Standard deviations (diagonal line fits); keys absent from the bounds dicts are left untouched.
for i in range(4):
    for suffix, chol_arr in MS_BLOCK_CHOL_BY_SUFFIX.items():
        for idx, kind in enumerate(KINDS):
            key = f"std_{names[i]}_{suffix}_{kind}"
            if key in SFH_PDF_QUENCH_COV_MS_BLOCK_BOUNDS_PDICT:
                bmin, bmax = SFH_PDF_QUENCH_COV_MS_BLOCK_BOUNDS_PDICT[key]
                chol_arr[idx, i, i] = _clip_away_from_zero(chol_arr[idx, i, i], bmin, bmax)

for i in range(4, 8):
    for idx, kind in enumerate(KINDS):
        key = f"std_{names[i]}_{kind}"
        if key in SFH_PDF_QUENCH_COV_Q_BLOCK_BOUNDS_PDICT:
            bmin, bmax = SFH_PDF_QUENCH_COV_Q_BLOCK_BOUNDS_PDICT[key]
            chol_lims_Q[idx, i, i] = _clip_away_from_zero(chol_lims_Q[idx, i, i], bmin, bmax)

In [None]:
# Display row 1 of median_lims_Q: the fitted slope (`slp` of `_fun`) of the
# quenched-population median line for every parameter column.
median_lims_Q[1]

# Generate the `params_diffstarfits_*.py` parameter modules

### sepms_satfrac

In [None]:
# Write the sepms_satfrac parameter module for this simulation.
# "mseq" entries come from the main-sequence fits (median_lims_MS / chol_lims_MS)
# and "qseq" entries from the quenched fits (median_lims_Q / chol_lims_Q), matching
# the sigslope cell, where mean_ulgm_mseq_* is fitted to medians_MS.
FQUENCH_KEYS = ("x0", "k", "ylo", "yhi")

with open(f"params_diffstarfits_line_sepms_satfrac_{sim_name}.py", "w") as f:
    f.write("from collections import OrderedDict, namedtuple\n\n")
    f.write("import typing\n")
    f.write("from jax import numpy as jnp\n\n")
    f.write("from ..satquenchpop_model import (\n")
    f.write("    DEFAULT_SATQUENCHPOP_PARAMS,\n")
    f.write(")\n")
    f.write("from ..defaults_tpeak_line_sepms_satfrac import get_unbounded_diffstarpop_params\n\n")

    # Means: MS-block parameters for each population, then the quenching parameters.
    f.write("SFH_PDF_QUENCH_MU_PDICT = OrderedDict([\n")
    for suffix, mu_lims in (("mseq", median_lims_MS), ("qseq", median_lims_Q)):
        for i in range(4):
            f.write(f"    ('mean_{names[i]}_{suffix}_int', {mu_lims[0, i]:.3f}),\n")
            f.write(f"    ('mean_{names[i]}_{suffix}_slp', {mu_lims[1, i]:.3f}),\n")
    for i in range(4, 8):
        f.write(f"    ('mean_{names[i]}_int', {median_lims_Q[0, i]:.3f}),\n")
        f.write(f"    ('mean_{names[i]}_slp', {median_lims_Q[1, i]:.3f}),\n")
    f.write("])\n\n")

    # Standard deviations of the MS-block parameters, per population.
    f.write("SFH_PDF_QUENCH_COV_MS_BLOCK_PDICT = OrderedDict([\n")
    for suffix, chol_lims in (("mseq", chol_lims_MS), ("qseq", chol_lims_Q)):
        for i in range(4):
            f.write(f"    ('std_{names[i]}_{suffix}_int', {chol_lims[0, i, i]:.3f}),\n")
            f.write(f"    ('std_{names[i]}_{suffix}_slp', {chol_lims[1, i, i]:.3f}),\n")
    f.write("])\n\n")

    # Standard deviations of the quenching parameters (quenched fits only).
    f.write("SFH_PDF_QUENCH_COV_Q_BLOCK_PDICT = OrderedDict([\n")
    for i in range(4, 8):
        f.write(f"    ('std_{names[i]}_int', {chol_lims_Q[0, i, i]:.3f}),\n")
        f.write(f"    ('std_{names[i]}_slp', {chol_lims_Q[1, i, i]:.3f}),\n")
    f.write("])\n\n")

    # Quenched-fraction sigmoids for centrals, then satellites.
    f.write("SFH_PDF_FRAC_QUENCH_PDICT = OrderedDict([\n")
    for pop, fq_lims in (("cen", median_lims_fquench_cen), ("sat", median_lims_fquench_sat)):
        for k, key in enumerate(FQUENCH_KEYS):
            f.write(f"    ('frac_quench_{pop}_{key}', {fq_lims[k]:.3f}),\n")
    f.write("])\n\n")

    # Final combination logic
    f.write("SFH_PDF_QUENCH_PDICT = SFH_PDF_FRAC_QUENCH_PDICT.copy()\n")
    f.write("SFH_PDF_QUENCH_PDICT.update(SFH_PDF_QUENCH_MU_PDICT)\n")
    f.write("SFH_PDF_QUENCH_PDICT.update(SFH_PDF_QUENCH_COV_MS_BLOCK_PDICT)\n")
    f.write("SFH_PDF_QUENCH_PDICT.update(SFH_PDF_QUENCH_COV_Q_BLOCK_PDICT)\n\n")

    f.write("QseqParams = namedtuple('QseqParams', list(SFH_PDF_QUENCH_PDICT.keys()))\n")
    f.write("SFH_PDF_QUENCH_PARAMS = QseqParams(**SFH_PDF_QUENCH_PDICT)\n")

    f.write("_UPNAMES = ['u_' + key for key in QseqParams._fields]\n")
    f.write("QseqUParams = namedtuple('QseqUParams', _UPNAMES)\n\n")

    f.write("\n# Define a namedtuple container for the params of each component\n")
    f.write("class DiffstarPopParams(typing.NamedTuple):\n")
    f.write("    sfh_pdf_cens_params: jnp.array\n")
    f.write("    satquench_params: jnp.array\n\n")

    f.write(f"DIFFSTARFITS_{sim_name_tuplename}_DIFFSTARPOP_PARAMS = DiffstarPopParams(\n")
    f.write("    SFH_PDF_QUENCH_PARAMS, DEFAULT_SATQUENCHPOP_PARAMS\n")
    f.write(")\n\n")

    f.write(f"_U_PNAMES = ['u_' + key for key in DIFFSTARFITS_{sim_name_tuplename}_DIFFSTARPOP_PARAMS._fields]\n")
    f.write("DiffstarPopUParams = namedtuple('DiffstarPopUParams', _U_PNAMES)\n\n")

    f.write(f"DIFFSTARFITS_{sim_name_tuplename}_DIFFSTARPOP_U_PARAMS = get_unbounded_diffstarpop_params(\n")
    f.write(f"    DIFFSTARFITS_{sim_name_tuplename}_DIFFSTARPOP_PARAMS\n")
    f.write(")\n")


### sepms_satfrac_sigslope

In [None]:
# Write the sepms_satfrac_sigslope parameter module for this simulation.
# mean ulgm uses the bounded sigmoid-slope fits (_ulgm_sigslope_*); the other
# MS-block parameters use straight lines. "mseq" entries come from the
# main-sequence fits and "qseq" entries from the quenched fits, so that within
# one population every parameter comes from the same data.
SIGSLOPE_KEYS = ("xtp", "ytp", "lo", "hi")
FQUENCH_KEYS = ("x0", "k", "ylo", "yhi")

with open(f"params_diffstarfits_line_sepms_satfrac_sigslope_{sim_name}.py", "w") as f:
    f.write("from collections import OrderedDict, namedtuple\n\n")

    f.write("import typing\n")
    f.write("from jax import numpy as jnp\n\n")
    f.write("from ..satquenchpop_model import (\n")
    f.write("    DEFAULT_SATQUENCHPOP_PARAMS,\n")
    f.write(")\n")
    f.write("from ..defaults_tpeak_line_sepms_satfrac_sigslope import get_unbounded_diffstarpop_params\n\n")

    # Means: per population, ulgm sigslope params then line params of ulgy/ul/utau;
    # finally the quenching parameters (quenched fits only).
    f.write("SFH_PDF_QUENCH_MU_PDICT = OrderedDict([\n")
    for suffix, sigslope, mu_lims in (
        ("mseq", _ulgm_sigslope_mseq, median_lims_MS),
        ("qseq", _ulgm_sigslope_qseq, median_lims_Q),
    ):
        for k, pname in enumerate(SIGSLOPE_KEYS):
            f.write(f"    ('mean_{names[0]}_{suffix}_{pname}', {sigslope[k]:.3f}),\n")
        for i in range(1, 4):
            f.write(f"    ('mean_{names[i]}_{suffix}_int', {mu_lims[0, i]:.3f}),\n")
            f.write(f"    ('mean_{names[i]}_{suffix}_slp', {mu_lims[1, i]:.3f}),\n")
    for i in range(4, 8):
        f.write(f"    ('mean_{names[i]}_int', {median_lims_Q[0, i]:.3f}),\n")
        f.write(f"    ('mean_{names[i]}_slp', {median_lims_Q[1, i]:.3f}),\n")
    f.write("])\n\n")

    # Standard deviations of the MS-block parameters, per population.
    f.write("SFH_PDF_QUENCH_COV_MS_BLOCK_PDICT = OrderedDict([\n")
    for suffix, chol_lims in (("mseq", chol_lims_MS), ("qseq", chol_lims_Q)):
        for i in range(4):
            f.write(f"    ('std_{names[i]}_{suffix}_int', {chol_lims[0, i, i]:.3f}),\n")
            f.write(f"    ('std_{names[i]}_{suffix}_slp', {chol_lims[1, i, i]:.3f}),\n")
    f.write("])\n\n")

    # Standard deviations of the quenching parameters (quenched fits only).
    f.write("SFH_PDF_QUENCH_COV_Q_BLOCK_PDICT = OrderedDict([\n")
    for i in range(4, 8):
        f.write(f"    ('std_{names[i]}_int', {chol_lims_Q[0, i, i]:.3f}),\n")
        f.write(f"    ('std_{names[i]}_slp', {chol_lims_Q[1, i, i]:.3f}),\n")
    f.write("])\n\n")

    # Quenched-fraction sigmoids for centrals, then satellites.
    f.write("SFH_PDF_FRAC_QUENCH_PDICT = OrderedDict([\n")
    for pop, fq_lims in (("cen", median_lims_fquench_cen), ("sat", median_lims_fquench_sat)):
        for k, key in enumerate(FQUENCH_KEYS):
            f.write(f"    ('frac_quench_{pop}_{key}', {fq_lims[k]:.3f}),\n")
    f.write("])\n\n")

    # Final combination logic
    f.write("SFH_PDF_QUENCH_PDICT = SFH_PDF_FRAC_QUENCH_PDICT.copy()\n")
    f.write("SFH_PDF_QUENCH_PDICT.update(SFH_PDF_QUENCH_MU_PDICT)\n")
    f.write("SFH_PDF_QUENCH_PDICT.update(SFH_PDF_QUENCH_COV_MS_BLOCK_PDICT)\n")
    f.write("SFH_PDF_QUENCH_PDICT.update(SFH_PDF_QUENCH_COV_Q_BLOCK_PDICT)\n\n")

    f.write("QseqParams = namedtuple('QseqParams', list(SFH_PDF_QUENCH_PDICT.keys()))\n")
    f.write("SFH_PDF_QUENCH_PARAMS = QseqParams(**SFH_PDF_QUENCH_PDICT)\n")

    f.write("_UPNAMES = ['u_' + key for key in QseqParams._fields]\n")
    f.write("QseqUParams = namedtuple('QseqUParams', _UPNAMES)\n\n")

    f.write("\n# Define a namedtuple container for the params of each component\n")
    f.write("class DiffstarPopParams(typing.NamedTuple):\n")
    f.write("    sfh_pdf_cens_params: jnp.array\n")
    f.write("    satquench_params: jnp.array\n\n")

    f.write(f"DIFFSTARFITS_{sim_name_tuplename}_DIFFSTARPOP_PARAMS = DiffstarPopParams(\n")
    f.write("    SFH_PDF_QUENCH_PARAMS, DEFAULT_SATQUENCHPOP_PARAMS\n")
    f.write(")\n\n")

    f.write(f"_U_PNAMES = ['u_' + key for key in DIFFSTARFITS_{sim_name_tuplename}_DIFFSTARPOP_PARAMS._fields]\n")
    f.write("DiffstarPopUParams = namedtuple('DiffstarPopUParams', _U_PNAMES)\n\n")

    f.write(f"DIFFSTARFITS_{sim_name_tuplename}_DIFFSTARPOP_U_PARAMS = get_unbounded_diffstarpop_params(\n")
    f.write(f"    DIFFSTARFITS_{sim_name_tuplename}_DIFFSTARPOP_PARAMS\n")
    f.write(")\n")



# Default Diffstarpop

In [None]:
# Template for the default model: one (int, slp) mean line per parameter,
# taken from the quenched-population fits, rounded to 2 decimals.
print("SFH_PDF_QUENCH_MU_PDICT = OrderedDict(")
for col, pname in enumerate(names):
    mu_int = median_lims_Q[0, col]
    mu_slp = median_lims_Q[1, col]
    print(f"    mean_{pname}_int={mu_int:.2f},")
    print(f"    mean_{pname}_slp={mu_slp:.2f},")
print(")")


In [None]:
# Bounds on the intercept of each parameter's mean line, in `names` order
# (ulgm, ulgy, ul, utau, uqt, uqs, udrop, urej).
bounds_int = [
    (11.0, 13.0),
    (-1.0, 3.5),
    (-3.0, 5.0),
    (-25.0, 50.0),
    (0.0, 2.0),
    (-5.0, 2.0),
    (-3.0, 2.0),
    (-10.0, 2.0),
]
# All eight slopes share one symmetric bound.
bounds_slp = [(-20.0, 20.0) for _unused in range(8)]

print("SFH_PDF_QUENCH_MU_BOUNDS_PDICT = OrderedDict(")
for pname, (int_lo, int_hi), (slp_lo, slp_hi) in zip(names, bounds_int, bounds_slp):
    print(f"    mean_{pname}_int=({int_lo:.1f}, {int_hi:.1f}),")
    print(f"    mean_{pname}_slp=({slp_lo:.1f}, {slp_hi:.1f}),")
print(")")

In [None]:
# Template for the MS-block covariance: fitted std lines for the first four
# parameters, plus a placeholder 0.01 for every lower-triangle correlation line.
print("SFH_PDF_QUENCH_COV_MS_BLOCK_PDICT = OrderedDict(")
for row in range(4):
    print(f"    std_{names[row]}_int={chol_lims_Q[0, row, row]:.2f},")
    print(f"    std_{names[row]}_slp={chol_lims_Q[1, row, row]:.2f},")
for row in range(4):
    for col in range(row):
        for kind in ("int", "slp"):
            print(f"    rho_{names[row]}_{names[col]}_{kind}=0.01,")
print(")")


In [None]:
# Bounds on the intercept of each parameter's std line, in `names` order.
bounds_std_int = [
    (0.01, 1.0),
    (0.01, 1.0),
    (0.01, 1.0),
    (1.0, 12.0),
    (0.01, 0.5),
    (0.01, 1.0),
    (0.01, 2.0),
    (0.01, 2.0),
]

# Bounds on the slope of each parameter's std line (index 3, utau, is wider).
bounds_std_slp = [
    (-1.0, 1.0),
    (-1.0, 1.0),
    (-1.0, 1.0),
    (-3.0, 3.0),
    (-1.0, 1.0),
    (-1.0, 1.0),
    (-1.0, 1.0),
    (-1.0, 1.0),
]

# Shared bounds for every correlation line's intercept and slope.
bounds_rho_int = (-20, 20)
bounds_rho_slp = (-20, 20)

print("SFH_PDF_QUENCH_COV_MS_BLOCK_BOUNDS_PDICT = OrderedDict(")
for row in range(4):
    std_int_lo, std_int_hi = bounds_std_int[row]
    std_slp_lo, std_slp_hi = bounds_std_slp[row]
    print(f"    std_{names[row]}_int=({std_int_lo:.2f}, {std_int_hi:.1f}),")
    print(f"    std_{names[row]}_slp=({std_slp_lo:.2f}, {std_slp_hi:.1f}),")
rho_int_lo, rho_int_hi = bounds_rho_int
rho_slp_lo, rho_slp_hi = bounds_rho_slp
for row in range(4):
    for col in range(row):
        print(f"    rho_{names[row]}_{names[col]}_int=({rho_int_lo:.1f}, {rho_int_hi:.1f}),")
        print(f"    rho_{names[row]}_{names[col]}_slp=({rho_slp_lo:.1f}, {rho_slp_hi:.1f}),")
print(")")


In [None]:
# Template for the quenching-block covariance (parameters 4..7): fitted std
# lines plus a placeholder 0.01 for every lower-triangle correlation line.
q_block = range(4, 8)
print("SFH_PDF_QUENCH_COV_Q_BLOCK_PDICT = OrderedDict(")
for row in q_block:
    print(f"    std_{names[row]}_int={chol_lims_Q[0, row, row]:.2f},")
    print(f"    std_{names[row]}_slp={chol_lims_Q[1, row, row]:.2f},")
for row in q_block:
    for col in q_block:
        if col < row:
            for kind in ("int", "slp"):
                print(f"    rho_{names[row]}_{names[col]}_{kind}=0.01,")
print(")")

In [None]:
# Bounds template for the quenching-block covariance (parameters 4..7).
q_block = range(4, 8)
print("SFH_PDF_QUENCH_COV_Q_BLOCK_BOUNDS_PDICT = OrderedDict(")
for row in q_block:
    std_int_lo, std_int_hi = bounds_std_int[row]
    std_slp_lo, std_slp_hi = bounds_std_slp[row]
    print(f"    std_{names[row]}_int=({std_int_lo:.2f}, {std_int_hi:.1f}),")
    print(f"    std_{names[row]}_slp=({std_slp_lo:.2f}, {std_slp_hi:.1f}),")
rho_int_lo, rho_int_hi = bounds_rho_int
rho_slp_lo, rho_slp_hi = bounds_rho_slp
for row in q_block:
    for col in q_block:
        if col < row:
            print(f"    rho_{names[row]}_{names[col]}_int=({rho_int_lo:.1f}, {rho_int_hi:.1f}),")
            print(f"    rho_{names[row]}_{names[col]}_slp=({rho_slp_lo:.1f}, {rho_slp_hi:.1f}),")
print(")")

In [None]:
# Template for the frac_quench_* sigmoid parameters of a model with a single
# quenched fraction (no cen/sat split).
# NOTE(review): `median_lims_fquenc` is not defined in this part of the notebook
# (the params writers use median_lims_fquench_cen / median_lims_fquench_sat);
# confirm an earlier cell defines it, otherwise this cell raises NameError.
for k, pname in enumerate(("x0", "k", "ylo", "yhi")):
    print(f"    frac_quench_{pname}={median_lims_fquenc[k]:.2f},")


## Bounding dictionaries

In [None]:
# Bounds on each parameter's mean, reusing the intercept bounds.
print("BOUNDING_MEAN_VALS_PDICT = OrderedDict(")
for col in range(8):
    mean_lo, mean_hi = bounds_int[col]
    print(f"    mean_{names[col]}=({mean_lo:.1f}, {mean_hi:.1f}),")
print(")")



In [None]:
# Bounds on each parameter's std, reusing the std-intercept bounds.
print("BOUNDING_STD_VALS_PDICT = OrderedDict(")
for col in range(8):
    std_lo, std_hi = bounds_std_int[col]
    print(f"    std_{names[col]}=({std_lo:.2f}, {std_hi:.1f}),")
print(")")

In [None]:
# Every lower-triangle correlation within the MS block (0..3) and within the
# quenching block (4..7) shares RHO_BOUNDS.
print("BOUNDING_RHO_VALS_PDICT = OrderedDict(")
for block in (range(0, 4), range(4, 8)):
    for row in block:
        for col in block:
            if col < row:
                print(f"    rho_{names[row]}_{names[col]}=RHO_BOUNDS,")
print(")")


In [None]:
# Emit the source of `_get_mean_u_params`: for each of the 8 parameters in
# `names`, one `line_model(logmp0, int, slp, *BOUNDING_VALS.mean_<name>)` call,
# followed by a return of all eight values.
print("@jjit")
print("def _get_mean_u_params(params, logmp0):")
for i in range(8):
    print(
f"""
    {names[i]} = line_model(
        logmp0,
        params.mean_{names[i]}_int,
        params.mean_{names[i]}_slp,
        *BOUNDING_VALS.mean_{names[i]},
    )"""
    )
print("    return (ulgm, ulgy, ul, utau, uqt, uqs, udrop, urej)")

    

In [None]:
# Emit the source of `_get_cov_params_qseq_ms_block`: a `line_model` call for
# each std of the first four parameters, one for each lower-triangle rho among
# them, then a hard-coded return of (diags, off_diags).
print("@jjit")
print("def _get_cov_params_qseq_ms_block(params, logmp0):")
for i in range(4):
    print(
f"""
    std_{names[i]} = line_model(
        logmp0,
        params.std_{names[i]}_int,
        params.std_{names[i]}_slp,
        *BOUNDING_VALS.std_{names[i]},
    )"""
    )
# Lower triangle only (i > j), matching the off_diags order in the return below.
for i in range(4):
    for j in range(4):
        if i>j:
             print(
f"""
    rho_{names[i]}_{names[j]} = line_model(
        logmp0,
        params.rho_{names[i]}_{names[j]}_int,
        params.rho_{names[i]}_{names[j]}_slp,
        *BOUNDING_VALS.rho_{names[i]}_{names[j]},
    )"""
            )
             
print(
"""
    diags = std_ulgm, std_ulgy, std_ul, std_utau
    off_diags = (
        rho_ulgy_ulgm,
        rho_ul_ulgm,
        rho_ul_ulgy,
        rho_utau_ulgm,
        rho_utau_ulgy,
        rho_utau_ul,
    )
    return diags, off_diags
"""
)


In [None]:
# Emit the source of `_get_cov_params_qseq_q_block`: a `line_model` call for
# each std of parameters 4..7, one for each lower-triangle rho among them, then
# a hard-coded return of (diags, off_diags).
print("@jjit")
print("def _get_cov_params_qseq_q_block(params, logmp0):")
for i in range(4,8):
    print(
f"""
    std_{names[i]} = line_model(
        logmp0,
        params.std_{names[i]}_int,
        params.std_{names[i]}_slp,
        *BOUNDING_VALS.std_{names[i]},
    )"""
    )
# Lower triangle only (i > j), matching the off_diags order in the return below.
for i in range(4,8):
    for j in range(4,8):
        if i>j:
             print(
f"""
    rho_{names[i]}_{names[j]} = line_model(
        logmp0,
        params.rho_{names[i]}_{names[j]}_int,
        params.rho_{names[i]}_{names[j]}_slp,
        *BOUNDING_VALS.rho_{names[i]}_{names[j]},
    )"""
            )
             
print(
"""
    diags = std_uqt, std_uqs, std_udrop, std_urej
    off_diags = (
        rho_uqs_uqt,
        rho_udrop_uqt,
        rho_udrop_uqs,
        rho_urej_uqt,
        rho_urej_uqs,
        rho_urej_udrop,
    )

    return diags, off_diags
"""
)

# Sepms satfrac Diffstarpop

In [None]:
# Template for the sepms_satfrac model. MS-block means are printed separately
# for the main-sequence population (mseq, from median_lims_MS) and the quenched
# population (qseq, from median_lims_Q); quenching-parameter means come from the
# quenched fits only. Both populations previously printed the same Q values.
print("SFH_PDF_QUENCH_MU_PDICT = OrderedDict(")
for suffix, mu_lims in (("mseq", median_lims_MS), ("qseq", median_lims_Q)):
    for i in range(4):
        print(f"    mean_{names[i]}_{suffix}_int={mu_lims[0, i]:.2f},")
        print(f"    mean_{names[i]}_{suffix}_slp={mu_lims[1, i]:.2f},")

for i in range(4, 8):
    print(f"    mean_{names[i]}_int={median_lims_Q[0, i]:.2f},")
    print(f"    mean_{names[i]}_slp={median_lims_Q[1, i]:.2f},")
print(")")

In [None]:
# Template for the sepms_satfrac covariance dicts. MS-block stds are printed
# per population (mseq from chol_lims_MS, qseq from chol_lims_Q); the
# quenching-block stds come from the quenched fits only.
print("SFH_PDF_QUENCH_COV_MS_BLOCK_PDICT = OrderedDict(")
for suffix, chol_lims in (("mseq", chol_lims_MS), ("qseq", chol_lims_Q)):
    for i in range(4):
        print(f"    std_{names[i]}_{suffix}_int={chol_lims[0, i, i]:.2f},")
        print(f"    std_{names[i]}_{suffix}_slp={chol_lims[1, i, i]:.2f},")
print(")")


print("SFH_PDF_QUENCH_COV_Q_BLOCK_PDICT = OrderedDict(")
for i in range(4, 8):
    print(f"    std_{names[i]}_int={chol_lims_Q[0, i, i]:.2f},")
    print(f"    std_{names[i]}_slp={chol_lims_Q[1, i, i]:.2f},")
print(")")


## Sigslope

In [None]:
# Bounded sigmoid-slope parameters of the mean-ulgm fits, i.e. the values the
# sigslope params writer uses: mseq from the main-sequence fit, qseq from the
# quenched fit. (The old `_ulgm_sigslope_res` is not produced by the sigslope
# cell, and it printed one fit under both prefixes.)
i = 0
for suffix, sigslope in (("mseq", _ulgm_sigslope_mseq), ("qseq", _ulgm_sigslope_qseq)):
    for k, pname in enumerate(("xtp", "ytp", "lo", "hi")):
        print(f"    mean_{names[i]}_{suffix}_{pname}={sigslope[k]:.3f},")

In [None]:
# Bounds template for the sigmoid-slope ulgm parameters; the mseq and qseq
# entries use identical intervals.
print("""
    mean_ulgm_mseq_xtp=(11.0, 14.0),
    mean_ulgm_mseq_ytp=(11.0, 14.0),
    mean_ulgm_mseq_lo=(-1.0, 5.0),
    mean_ulgm_mseq_hi=(-5.0, 1.0),
      
    mean_ulgm_qseq_xtp=(11.0, 14.0),
    mean_ulgm_qseq_ytp=(11.0, 14.0),
    mean_ulgm_qseq_lo=(-1.0, 5.0),
    mean_ulgm_qseq_hi=(-5.0, 1.0),
            
""")
