In [None]:
%matplotlib inline

In [None]:
#__ = plt.style.use("./diffstar.mplstyle")
import numpy as np
from matplotlib import pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
from scipy.optimize import curve_fit
# Hex codes of matplotlib's default "tab" colours, under short names for plots.
mred, morange, mgreen, mblue, mpurple = (
    "#d62728",
    "#ff7f0e",
    "#2ca02c",
    "#1f77b4",
    "#9467bd",
)

# Global figure style: 22pt serif text rendered through LaTeX; amsmath is
# loaded in the preamble because figure labels use \dfrac.
plt.rcParams.update(
    {
        "font.family": "serif",
        "font.size": 22,
        "text.usetex": True,
        "text.latex.preamble": r"\usepackage{amsmath}",
    }
)

In [None]:
from jax import numpy as jnp
from jax import jit as jjit
from jax import vmap
import os
import h5py
from collections import OrderedDict, namedtuple



In [None]:
# Name fragments used to build the keys of the fitted SFH-PDF parameters.
names = ["ulgm", "ulgy", "ul", "utau", "uqt", "uqs", "udrop", "urej"]

# Parallel lists, indexed by `sim_index`: the directory-name suffix of each
# simulation and the tag used in the generated constant names.
sim_name_pathname_list = [
    "smdpl",
    "smdpl_dr1",
    "tng",
    "galacticus_in_situ",
    "galacticus_in_plus_ex_situ",
]

sim_name_tuplename_list = [
    "SMDPL",
    "SMDPL_DR1",
    "TNG",
    "GALACTICUS_IN",
    "GALACTICUS_INPLUSEX",
]
assert len(sim_name_pathname_list) == len(sim_name_tuplename_list), "sim lists out of sync"

sim_index = 4  # which simulation of the lists above to process
sigslope = 1  # truthy: use the *_sigslope model variant; falsy: the plain one

sim_name_pathname = sim_name_pathname_list[sim_index]
sim_name_tuplename = sim_name_tuplename_list[sim_index]
sigslope_name = "_sigslope" if sigslope else ""

# Root directory of the fit outputs; set DIFFMAH_TPEAK_DATA to override the
# author's local default.
path_all = os.environ.get(
    "DIFFMAH_TPEAK_DATA", "/Users/alarcon/Documents/diffmah_data/tpeak/"
)
test_path = f"test_diffstarfits_params_sepms_satfrac{sigslope_name}_{sim_name_pathname}"

params_path = os.path.join(path_all, test_path, "bestfit_diffstarpop_params_mstar_ssfr_cen_sat.npz")
# Explicit error (not `assert`) so the check survives `python -O` and names the path.
if not os.path.exists(params_path):
    raise FileNotFoundError(f"Best-fit parameter file not found: {params_path}")
# The context manager closes the .npz archive; indexing it reads the array into memory.
with np.load(params_path) as params_file:
    all_u_params = params_file["diffstarpop_u_params"]

In [None]:
if sigslope:
    from diffstarpop.loss_kernels.namedtuple_utils_tpeak_sepms_satfrac_sigslope import (
        tuple_to_array,
        register_tuple_new_diffstarpop_tpeak,
        array_to_tuple_new_diffstarpop_tpeak,
    )
    from diffstarpop.kernels.defaults_tpeak_line_sepms_satfrac_sigslope import (
        DEFAULT_DIFFSTARPOP_U_PARAMS,
        DEFAULT_DIFFSTARPOP_PARAMS,
        get_bounded_diffstarpop_params,
        get_unbounded_diffstarpop_params,
    )
    from diffstarpop.kernels.sfh_pdf_tpeak_line_sepms_satfrac_sigslope import SFH_PDF_QUENCH_BOUNDS_PDICT
else:
    from diffstarpop.loss_kernels.namedtuple_utils_tpeak_sepms_satfrac import (
        tuple_to_array,
        register_tuple_new_diffstarpop_tpeak,
        array_to_tuple_new_diffstarpop_tpeak,
    )
    from diffstarpop.kernels.defaults_tpeak_line_sepms_satfrac import (
        DEFAULT_DIFFSTARPOP_U_PARAMS,
        DEFAULT_DIFFSTARPOP_PARAMS,
        get_bounded_diffstarpop_params,
        get_unbounded_diffstarpop_params,
    )
    from diffstarpop.kernels.sfh_pdf_tpeak_line_sepms_satfrac import SFH_PDF_QUENCH_BOUNDS_PDICT
# One-field container type for the unbounded best-fit parameters; its single
# field name comes from the keys of `unbound_params_dict`.
unbound_params_dict = OrderedDict(diffstarpop_u_params=DEFAULT_DIFFSTARPOP_U_PARAMS)
UnboundParams = namedtuple("UnboundParams", [*unbound_params_dict])

# Unpack the flat fitted array into that container (layout handled by the
# library helper; presumably mirrors DEFAULT_DIFFSTARPOP_U_PARAMS — TODO confirm).
bestfit_u_tuple = array_to_tuple_new_diffstarpop_tpeak(all_u_params, UnboundParams)

# Unbounded -> bounded, then keep only the centrals' SFH-PDF block as a dict.
# NOTE(review): any fitted satquench params in `diffstarpop_params` are not
# exported; the script writers emit DEFAULT_SATQUENCHPOP_PARAMS — confirm intended.
diffstarpop_params = get_bounded_diffstarpop_params(bestfit_u_tuple.diffstarpop_u_params)
sfh_pdf_cens_params = diffstarpop_params.sfh_pdf_cens_params._asdict()


In [None]:
# The script writers emit every parameter with 3 decimals (`:.3f`). A value that
# rounds onto (or past) one of its SFH_PDF_QUENCH_BOUNDS_PDICT bounds would be
# written exactly at the bound, where get_unbounded_diffstarpop_params presumably
# diverges (TODO confirm), so push it back inside the bounds. The push must be at
# least one unit of the last written decimal: a purely relative 0.1% nudge is
# erased by the rounding whenever |val| < 1 (e.g. 0.1 -> 0.1001 -> "0.100").
PARAM_DECIMALS = 3  # must match the `:.3f` format used by the writers
PARAM_STEP = 10.0 ** -PARAM_DECIMALS

for key in list(sfh_pdf_cens_params):
    # Compare what will actually be written, as a plain float (works for 0-d
    # numpy/jax scalars too).
    val = round(float(sfh_pdf_cens_params[key]), PARAM_DECIMALS)
    lower_bound, upper_bound = SFH_PDF_QUENCH_BOUNDS_PDICT[key]
    nudge = max(PARAM_STEP, PARAM_STEP * abs(val))
    # `<=` / `>=` also catch values that rounding pushed beyond a bound.
    if val <= lower_bound:
        sfh_pdf_cens_params[key] = val + nudge
        print(key, "Val clashes with lower bound!")
    elif val >= upper_bound:
        sfh_pdf_cens_params[key] = val - nudge
        print(key, "Val clashes with upper bound!")
    # Values that would be written as 0.000 are replaced by 10**-3 (the reason
    # is not documented here); test the rounded value, not the raw float.
    if round(float(sfh_pdf_cens_params[key]), PARAM_DECIMALS) == 0.0:
        sfh_pdf_cens_params[key] = PARAM_STEP


In [None]:
def write_diffstarpopofits_script_line_sepms_satfrac(
    params=None, sim_pathname=None, sim_tuplename=None, output_dir=None
):
    """Write a Python module hard-coding the best-fit params (no-sigslope model).

    The generated file defines the ``SFH_PDF_QUENCH_*_PDICT`` OrderedDicts, the
    merged ``QseqParams`` / ``QseqUParams`` namedtuples and the
    ``DIFFSTARPOP_FITS_<tag>_DIFFSTARPOP_PARAMS`` / ``..._U_PARAMS`` constants,
    importing ``get_unbounded_diffstarpop_params`` from
    ``..defaults_tpeak_line_sepms_satfrac``. Every value is written with
    3 decimals.

    Parameters
    ----------
    params : mapping, optional
        Parameter name -> scalar value. Defaults to the notebook global
        ``sfh_pdf_cens_params``.
    sim_pathname : str, optional
        Suffix of the output filename. Defaults to ``sim_name_pathname``.
    sim_tuplename : str, optional
        Tag inserted into the generated constant names. Defaults to
        ``sim_name_tuplename``.
    output_dir : str, optional
        Directory for the output file; the current working directory if omitted.

    Returns
    -------
    str
        Path of the written file.

    Raises
    ------
    KeyError
        If ``params`` lacks an expected key. All values are looked up before
        the file is opened, so no truncated module is left behind.
    """
    # Fall back to the notebook globals so existing zero-argument calls still work.
    if params is None:
        params = sfh_pdf_cens_params
    if sim_pathname is None:
        sim_pathname = sim_name_pathname
    if sim_tuplename is None:
        sim_tuplename = sim_name_tuplename

    # Same fragments as the notebook's `names` list: the first four get separate
    # main-sequence (mseq) / quenched-sequence (qseq) keys, the last four do not.
    seq_names = ("ulgm", "ulgy", "ul", "utau")
    other_names = ("uqt", "uqs", "udrop", "urej")
    int_slp = ("int", "slp")

    mu_keys = [
        f"mean_{n}_{seq}_{s}" for seq in ("mseq", "qseq") for n in seq_names for s in int_slp
    ] + [f"mean_{n}_{s}" for n in other_names for s in int_slp]
    cov_ms_keys = [
        f"std_{n}_{seq}_{s}" for seq in ("mseq", "qseq") for n in seq_names for s in int_slp
    ]
    cov_q_keys = [f"std_{n}_{s}" for n in other_names for s in int_slp]
    frac_keys = [
        f"frac_quench_{pop}_{p}" for pop in ("cen", "sat") for p in ("x0", "k", "ylo", "yhi")
    ]

    def _pdict_source(dict_name, keys):
        # Render one `NAME = OrderedDict([...])` literal, values at 3 decimals;
        # float() also accepts 0-d numpy/jax arrays.
        rows = "".join(f"    ('{key}', {float(params[key]):.3f}),\n" for key in keys)
        return f"{dict_name} = OrderedDict([\n{rows}])\n\n"

    header = (
        "from collections import OrderedDict, namedtuple\n\n"
        "import typing\n"
        "from jax import numpy as jnp\n\n"
        "from ..satquenchpop_model import (\n"
        "    DEFAULT_SATQUENCHPOP_PARAMS,\n"
        ")\n"
        "from ..defaults_tpeak_line_sepms_satfrac import get_unbounded_diffstarpop_params\n\n"
    )

    tag = f"DIFFSTARPOP_FITS_{sim_tuplename}_DIFFSTARPOP"
    footer = (
        "SFH_PDF_QUENCH_PDICT = SFH_PDF_FRAC_QUENCH_PDICT.copy()\n"
        "SFH_PDF_QUENCH_PDICT.update(SFH_PDF_QUENCH_MU_PDICT)\n"
        "SFH_PDF_QUENCH_PDICT.update(SFH_PDF_QUENCH_COV_MS_BLOCK_PDICT)\n"
        "SFH_PDF_QUENCH_PDICT.update(SFH_PDF_QUENCH_COV_Q_BLOCK_PDICT)\n\n"
        "QseqParams = namedtuple('QseqParams', list(SFH_PDF_QUENCH_PDICT.keys()))\n"
        "SFH_PDF_QUENCH_PARAMS = QseqParams(**SFH_PDF_QUENCH_PDICT)\n"
        "_UPNAMES = ['u_' + key for key in QseqParams._fields]\n"
        "QseqUParams = namedtuple('QseqUParams', _UPNAMES)\n\n"
        "\n# Define a namedtuple container for the params of each component\n"
        "class DiffstarPopParams(typing.NamedTuple):\n"
        "    sfh_pdf_cens_params: jnp.array\n"
        "    satquench_params: jnp.array\n\n"
        f"{tag}_PARAMS = DiffstarPopParams(\n"
        "    SFH_PDF_QUENCH_PARAMS, DEFAULT_SATQUENCHPOP_PARAMS\n"
        ")\n\n"
        f"_U_PNAMES = ['u_' + key for key in {tag}_PARAMS._fields]\n"
        "DiffstarPopUParams = namedtuple('DiffstarPopUParams', _U_PNAMES)\n\n"
        f"{tag}_U_PARAMS = get_unbounded_diffstarpop_params(\n"
        f"    {tag}_PARAMS\n"
        ")\n"
    )

    # Build the whole module in memory first: a missing key raises before the
    # output file is opened (and truncated).
    source = "".join(
        [
            header,
            _pdict_source("SFH_PDF_QUENCH_MU_PDICT", mu_keys),
            _pdict_source("SFH_PDF_QUENCH_COV_MS_BLOCK_PDICT", cov_ms_keys),
            _pdict_source("SFH_PDF_QUENCH_COV_Q_BLOCK_PDICT", cov_q_keys),
            _pdict_source("SFH_PDF_FRAC_QUENCH_PDICT", frac_keys),
            footer,
        ]
    )

    filename = f"params_diffstarpopfits_line_sepms_satfrac_{sim_pathname}.py"
    path = filename if output_dir is None else os.path.join(output_dir, filename)
    with open(path, "w") as f:
        f.write(source)
    return path


In [None]:
def write_diffstarpopofits_script_line_sepms_satfrac_sigslope(
    params=None, sim_pathname=None, sim_tuplename=None, output_dir=None
):
    """Write a Python module hard-coding the best-fit params (sigslope model).

    The generated file defines the ``SFH_PDF_QUENCH_*_PDICT`` OrderedDicts, the
    merged ``QseqParams`` / ``QseqUParams`` namedtuples and the
    ``DIFFSTARPOP_FITS_<tag>_DIFFSTARPOP_PARAMS`` / ``..._U_PARAMS`` constants,
    importing ``get_unbounded_diffstarpop_params`` from
    ``..defaults_tpeak_line_sepms_satfrac_sigslope``. Every value is written
    with 3 decimals.

    Parameters
    ----------
    params : mapping, optional
        Parameter name -> scalar value. Defaults to the notebook global
        ``sfh_pdf_cens_params``.
    sim_pathname : str, optional
        Suffix of the output filename. Defaults to ``sim_name_pathname``.
    sim_tuplename : str, optional
        Tag inserted into the generated constant names. Defaults to
        ``sim_name_tuplename``.
    output_dir : str, optional
        Directory for the output file; the current working directory if omitted.

    Returns
    -------
    str
        Path of the written file.

    Raises
    ------
    KeyError
        If ``params`` lacks an expected key. All values are looked up before
        the file is opened, so no truncated module is left behind.
    """
    # Fall back to the notebook globals so existing zero-argument calls still work.
    if params is None:
        params = sfh_pdf_cens_params
    if sim_pathname is None:
        sim_pathname = sim_name_pathname
    if sim_tuplename is None:
        sim_tuplename = sim_name_tuplename

    # Same fragments as the notebook's `names` list: the first four get separate
    # main-sequence (mseq) / quenched-sequence (qseq) keys, the last four do not.
    seq_names = ("ulgm", "ulgy", "ul", "utau")
    other_names = ("uqt", "uqs", "udrop", "urej")
    int_slp = ("int", "slp")

    # In this variant the mean of `ulgm` on each sequence uses the
    # xtp/ytp/lo/hi keys instead of int/slp.
    mu_keys = []
    for seq in ("mseq", "qseq"):
        mu_keys += [f"mean_{seq_names[0]}_{seq}_{p}" for p in ("xtp", "ytp", "lo", "hi")]
        mu_keys += [f"mean_{n}_{seq}_{s}" for n in seq_names[1:] for s in int_slp]
    mu_keys += [f"mean_{n}_{s}" for n in other_names for s in int_slp]
    cov_ms_keys = [
        f"std_{n}_{seq}_{s}" for seq in ("mseq", "qseq") for n in seq_names for s in int_slp
    ]
    cov_q_keys = [f"std_{n}_{s}" for n in other_names for s in int_slp]
    frac_keys = [
        f"frac_quench_{pop}_{p}" for pop in ("cen", "sat") for p in ("x0", "k", "ylo", "yhi")
    ]

    def _pdict_source(dict_name, keys):
        # Render one `NAME = OrderedDict([...])` literal, values at 3 decimals;
        # float() also accepts 0-d numpy/jax arrays.
        rows = "".join(f"    ('{key}', {float(params[key]):.3f}),\n" for key in keys)
        return f"{dict_name} = OrderedDict([\n{rows}])\n\n"

    header = (
        "from collections import OrderedDict, namedtuple\n\n"
        "import typing\n"
        "from jax import numpy as jnp\n\n"
        "from ..satquenchpop_model import (\n"
        "    DEFAULT_SATQUENCHPOP_PARAMS,\n"
        ")\n"
        "from ..defaults_tpeak_line_sepms_satfrac_sigslope import get_unbounded_diffstarpop_params\n\n"
    )

    tag = f"DIFFSTARPOP_FITS_{sim_tuplename}_DIFFSTARPOP"
    footer = (
        "SFH_PDF_QUENCH_PDICT = SFH_PDF_FRAC_QUENCH_PDICT.copy()\n"
        "SFH_PDF_QUENCH_PDICT.update(SFH_PDF_QUENCH_MU_PDICT)\n"
        "SFH_PDF_QUENCH_PDICT.update(SFH_PDF_QUENCH_COV_MS_BLOCK_PDICT)\n"
        "SFH_PDF_QUENCH_PDICT.update(SFH_PDF_QUENCH_COV_Q_BLOCK_PDICT)\n\n"
        "QseqParams = namedtuple('QseqParams', list(SFH_PDF_QUENCH_PDICT.keys()))\n"
        "SFH_PDF_QUENCH_PARAMS = QseqParams(**SFH_PDF_QUENCH_PDICT)\n"
        "_UPNAMES = ['u_' + key for key in QseqParams._fields]\n"
        "QseqUParams = namedtuple('QseqUParams', _UPNAMES)\n\n"
        "\n# Define a namedtuple container for the params of each component\n"
        "class DiffstarPopParams(typing.NamedTuple):\n"
        "    sfh_pdf_cens_params: jnp.array\n"
        "    satquench_params: jnp.array\n\n"
        f"{tag}_PARAMS = DiffstarPopParams(\n"
        "    SFH_PDF_QUENCH_PARAMS, DEFAULT_SATQUENCHPOP_PARAMS\n"
        ")\n\n"
        f"_U_PNAMES = ['u_' + key for key in {tag}_PARAMS._fields]\n"
        "DiffstarPopUParams = namedtuple('DiffstarPopUParams', _U_PNAMES)\n\n"
        f"{tag}_U_PARAMS = get_unbounded_diffstarpop_params(\n"
        f"    {tag}_PARAMS\n"
        ")\n"
    )

    # Build the whole module in memory first: a missing key raises before the
    # output file is opened (and truncated).
    source = "".join(
        [
            header,
            _pdict_source("SFH_PDF_QUENCH_MU_PDICT", mu_keys),
            _pdict_source("SFH_PDF_QUENCH_COV_MS_BLOCK_PDICT", cov_ms_keys),
            _pdict_source("SFH_PDF_QUENCH_COV_Q_BLOCK_PDICT", cov_q_keys),
            _pdict_source("SFH_PDF_FRAC_QUENCH_PDICT", frac_keys),
            footer,
        ]
    )

    filename = f"params_diffstarpopfits_line_sepms_satfrac_sigslope_{sim_pathname}.py"
    path = filename if output_dir is None else os.path.join(output_dir, filename)
    with open(path, "w") as f:
        f.write(source)
    return path

In [None]:
# Pick the module writer matching the model variant selected by `sigslope`
# (the same flag chose which diffstarpop kernels were imported).
_write_params_module = (
    write_diffstarpopofits_script_line_sepms_satfrac_sigslope
    if sigslope
    else write_diffstarpopofits_script_line_sepms_satfrac
)
_write_params_module()