## Here we compare the Bagpipes results for ~4000 JADES GOODS galaxies to synference results.

In [None]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np
from astropy.table import Table, hstack
from unyt import Jy

from synference import SBI_Fitter, create_uncertainty_models_from_EPOCHS_cat

plt.style.use("paper.style")

### Load data and trained model

In [None]:
# Absolute paths to the Bagpipes results, the matched JADES photometry and the catalogue
# used further down to build the empirical noise models.
bagpipes_results = (
    "/home/tharvey/work/synference/priv/data/cats/v2_sfh_dense_basis_dust_Calzetti.fits"
)
jades_catalog = "/home/tharvey/work/synference/priv/data/cats/JADES_DR3_GS_Matched_Specz_total_flux_good_filt.fits"  # noqa: E501
data_err_file = "/home/tharvey/work/JADES-DR3-GS_MASTER_Sel-F277W+F356W+F444W_v13.fits"

bagpipes = Table.read(bagpipes_results)
jades = Table.read(jades_catalog)

# hstack pairs rows purely by position, so tables of different length must be rejected
# explicitly (an `assert` disappears under `python -O`).
# NOTE(review): equal length does not prove the two files share the same row order - confirm.
if len(bagpipes) != len(jades):
    raise ValueError(
        f"Row count mismatch: {len(bagpipes)} Bagpipes rows vs {len(jades)} JADES rows."
    )


# combine the tables
cat = hstack([jades, bagpipes])

model_name = "BPASS_DenseBasis_v4_final"  # "DenseBasis_v4_final_custom_small"
extra = "nsf_0"


# Load the trained posterior for this model together with the grid it was trained on.
fitter = SBI_Fitter.load_saved_model(
    f"/home/tharvey/work/synference/models/{model_name}/{model_name}_{extra}_posterior.pkl",
    grid_path="/home/tharvey/work/synference/grids/grid_BPASS_Chab_DenseBasis_SFH_0.01_z_14_logN_6.0_Calzetti_v4_multinode.hdf5",
)

Create noise model

In [None]:
# Display the "empirical_noise_models" entry currently stored in the fitter's flags.
fitter.feature_array_flags["empirical_noise_models"]

In [None]:
# The five HST filters come first, followed by the remaining filters of the catalogue.
hst_bands = ["F435W", "F606W", "F775W", "F814W", "F850LP"]
bands = hst_bands + [
    "F090W",
    "F115W",
    "F150W",
    "F200W",
    "F277W",
    "F335M",
    "F356W",
    "F410M",
    "F444W",
]

# Prefix each filter with its instrument: HST/ACS_WFC for the HST list, JWST/NIRCam otherwise.
new_band_names = [
    ("HST/ACS_WFC." if name in hst_bands else "JWST/NIRCam.") + name.upper() for name in bands
]

# Build one "asinh" uncertainty model per band from the OBJECTS extension of the error
# catalogue, without plotting or saving anything.
empirical_noise_models = create_uncertainty_models_from_EPOCHS_cat(
    data_err_file,
    bands,
    new_band_names,
    hdu="OBJECTS",
    model_class="asinh",
    min_flux_error=0,
    plot=False,
    save=False,
)

# Replace the noise models stored in the fitter with the freshly built ones.
fitter.feature_array_flags["empirical_noise_models"] = empirical_noise_models

Explain catalogue to SBI_Fitter

In [None]:
def flux_cols_syntax(band):
    """Build the name of the total-flux column for ``band`` (``FLUX_TOTAL_<band>``)."""
    column_name = "FLUX_TOTAL_{}".format(band)
    return column_name


def magerr_cols_syntax(band):
    """Build the name of the total-flux uncertainty column (``FLUXERR_TOTAL_<band>``).

    Despite the function name, the column returned is a flux error, not a magnitude error.
    """
    column_name = "FLUXERR_TOTAL_{}".format(band)
    return column_name


new_band_names = fitter.feature_names

# Keep only the photometric features: uncertainty ("unc_*") and redshift entries are handled
# separately. Pairing catalogue columns with *this* filtered list (rather than with the full
# feature list) keeps the mapping correct whatever order the fitter stores its features in.
feature_bands = [
    name for name in new_band_names if not (name.startswith("unc_") or name == "redshift")
]
# Catalogue columns are keyed by the bare filter name, i.e. the text after the last ".".
bands = [name.split(".")[-1] for name in feature_bands]
new_table = Table()

# Fluxes and errors are divided by 1e9 before being handed to the fitter with flux_units=Jy
# (presumably the catalogue is in nJy - TODO confirm).
for band in bands:
    flux = cat[flux_cols_syntax(band)] / 1e9
    err = cat[magerr_cols_syntax(band)] / 1e9
    new_table[band] = flux
    new_table[f"unc_{band}"] = err


new_table["redshift"] = cat["specz"]
new_table["id"] = cat["UNIQUE_ID"]

# Map the short column names used in new_table onto the fitter's feature names.
conversion_dict = dict(zip(bands, feature_bands))
conversion_dict.update(
    {f"unc_{band}": f"unc_{feature}" for band, feature in zip(bands, feature_bands)}
)

conversion_dict["redshift"] = "redshift"

# Flag rows where any band has a finite flux but no finite uncertainty. Masked entries are
# treated as "not finite", which reproduces the element-wise truthiness of the former loop.
mask = np.zeros(len(new_table), dtype=bool)
for band in bands:
    flux_ok = np.asarray(np.ma.filled(np.isfinite(new_table[band]), False), dtype=bool)
    err_ok = np.asarray(np.ma.filled(np.isfinite(new_table[f"unc_{band}"]), False), dtype=bool)
    mask |= flux_ok & ~err_ok

# Keep the unfiltered versions, then drop the flagged rows from both tables so they stay
# row-aligned.
orig_table = new_table.copy()
new_table = new_table[~mask]
orig_cat = cat.copy()
cat = cat[~mask]

# Do Inference

In [None]:
# Display the fitter's `fitted_parameter_names` attribute.
fitter.fitted_parameter_names

Inference on subset and make SED plots

In [None]:
# IDs of the example galaxies fitted in the SED cell below.
ids = [5680, 145705, 149130, 44]

# Boolean row selector: True where the table "id" is one of the requested IDs.
selected_rows = np.isin(new_table["id"], ids)
cat_to_plot = new_table[selected_rows]

In [None]:
# Display the "empirical_noise_models" entry of the fitter's flags once more.
fitter.feature_array_flags["empirical_noise_models"]

In [None]:
# Row indices that order new_table by descending redshift.
idx = np.argsort(new_table["redshift"])[::-1]

zorder_cat = new_table[idx]

# Fit only the hand-picked galaxies in cat_to_plot, requesting SED recovery and SED plots.
# The commented-out argument is the alternative input: the 60 highest-redshift rows.
post_tab = fitter.fit_catalogue(
    #
    cat_to_plot,
    # zorder_cat[:60],
    columns_to_feature_names=conversion_dict,
    flux_units=Jy,
    return_feature_array=False,
    check_out_of_distribution=False,
    recover_SEDs=True,
    plot_SEDs=True,
    missing_data_flag=np.nan,
    use_temp_samples=False,
    num_samples=300,
)

In [None]:
# With return_feature_array=True the result is unpacked into two values; only the second
# one (obs_mask) is kept here.
# NOTE(review): obs_mask is presumably a per-row mask over new_table - confirm against
# SBI_Fitter.fit_catalogue.
_, obs_mask = fitter.fit_catalogue(
    new_table,
    columns_to_feature_names=conversion_dict,
    flux_units=Jy,
    return_feature_array=True,
    check_out_of_distribution=False,
    recover_SEDs=False,
    plot_SEDs=False,
    missing_data_flag=np.nan,
    use_temp_samples=False,
    num_samples=300,
    # Only "iforest" is active; the other methods are left commented out.
    outlier_methods=[
        "iforest",
    ],  # "ecod", "knn", "lof", "gmm"],
)

Infer parameters for all galaxies in catalogue

In [None]:
# Fit the full filtered catalogue: 300 samples per object, no SED recovery or plotting,
# with the out-of-distribution check enabled using five outlier methods.
post_tab = fitter.fit_catalogue(
    new_table,
    columns_to_feature_names=conversion_dict,
    flux_units=Jy,
    return_feature_array=False,
    check_out_of_distribution=True,
    recover_SEDs=False,
    plot_SEDs=False,
    missing_data_flag=np.nan,
    use_temp_samples=False,
    num_samples=300,
    outlier_methods=["iforest", "ecod", "knn", "lof", "gmm"],
)

In [None]:
# Display the fitter's `parameter_names` attribute.
fitter.parameter_names

In [None]:
# cat and post_tab are compared row-by-row in every plot below, so both must be cut with
# one and the same mask. Two independently computed masks (one per table) would silently
# misalign the tables as soon as the two redshift columns disagree for a single row.
if len(post_tab) != len(cat):
    raise ValueError(
        f"post_tab ({len(post_tab)} rows) and cat ({len(cat)} rows) are not row-aligned."
    )

# Masked redshifts count as failing the cut.
mask = np.asarray(np.ma.filled(post_tab["redshift"] >= 0.01, False), dtype=bool)
print(f"Number of galaxies at z>0.01: {np.sum(mask)}")
cat_mask = mask
cat = cat[cat_mask]
post_tab = post_tab[mask]

### Outlier Detection

In [None]:
from synference.utils import detect_outliers_pyod

# Re-run the catalogue through the fitter with return_feature_array=True; the result is
# indexed below as observations[0] (passed on as the observed data) and observations[1].
observations = fitter.fit_catalogue(
    new_table,
    columns_to_feature_names=conversion_dict,
    flux_units=Jy,
    return_feature_array=True,
    check_out_of_distribution=False,
)

# Compare observations[0] against the fitter's own feature array using six pyod methods
# and a contamination of 0.01.
output = detect_outliers_pyod(
    fitter.feature_array,
    observations[0],
    contamination=0.01,
    methods=["iforest", "ecod", "knn", "lof", "gmm", "kde"],
)

In [None]:
# Default every galaxy to 0; rows flagged in observations[1] keep this value.
cat["outlier"] = 0

# Index the column first and the rows second: `cat[rows]["outlier"] = ...` builds a
# temporary copy of the selected rows, writes into that copy and leaves `cat` untouched.
# NOTE(review): observations[1] was computed from new_table - confirm it still has the same
# length as `cat` after the redshift cut above.
cat["outlier"][~observations[1]] = output

In [None]:
# Rows of cat where observations[1] is False, with the outlier labels attached as a column.
cat_unmasked = cat[~observations[1]]
cat_unmasked["outlier"] = output

In [None]:
# Select the flagged rows once instead of re-filtering the table for every use.
outliers = cat_unmasked[cat_unmasked["outlier"] == 1]

# All galaxies as faint points, flagged outliers highlighted in red on top.
plt.plot(cat["input_redshift"], cat["stellar_mass_50"], ".", alpha=0.1)
plt.plot(
    outliers["input_redshift"],
    outliers["stellar_mass_50"],
    "r.",
    label="Outliers",
)
plt.xlabel("Redshift")
plt.ylabel(r"Bagpipes $\log_{10}(M_*/M_\odot)$")
# Without a legend the "Outliers" label above is never shown.
plt.legend()

In [None]:
# Quick look: Synference median log mass against median log10_floor_sfr_10.
plt.scatter(post_tab["log_mass_50"], post_tab["log10_floor_sfr_10_50"])

In [None]:
fig, (ax1, ax2) = plt.subplots(
    2, 1, figsize=(5, 6), gridspec_kw={"height_ratios": [3, 1]}, sharex=True
)

# x is the Bagpipes stellar mass and y the Synference surviving mass. Each axis gets the
# 16th-84th percentile half-widths of *its own* quantity; the Bagpipes percentiles follow
# the same _16/_50/_84 column naming used for the other Bagpipes quantities in this notebook.
bagpipes_mass = cat["stellar_mass_50"]
synference_mass = post_tab["log_surviving_mass_50"]
bagpipes_mass_err = [
    cat["stellar_mass_50"] - cat["stellar_mass_16"],
    cat["stellar_mass_84"] - cat["stellar_mass_50"],
]
synference_mass_err = [
    post_tab["log_surviving_mass_50"] - post_tab["log_surviving_mass_16"],
    post_tab["log_surviving_mass_84"] - post_tab["log_surviving_mass_50"],
]

ax1.errorbar(
    bagpipes_mass,
    synference_mass,
    xerr=bagpipes_mass_err,
    yerr=synference_mass_err,
    fmt="o",
    alpha=0.5,
    label="Data points with error bars",
    markersize=4,
    linewidth=1,
    markeredgecolor="black",
)
ax1.plot([6, 12], [6, 12], "r--", label="1:1 Line")
ax1.set_ylabel(r"Synference $\log_{10}(M_*/M_\odot)$", fontsize=16)
ax1.grid()
ax1.yaxis.set_tick_params(labelsize=14)

# Residuals of the quantity shown in the top panel (surviving mass), not of log_mass.
residuals = synference_mass - bagpipes_mass
ax2.errorbar(
    bagpipes_mass,
    residuals,
    yerr=synference_mass_err,
    fmt="o",
    alpha=0.5,
    label="Residuals with error bars",
    markersize=4,
    linewidth=1,
    markeredgecolor="black",
)
ax2.axhline(0, color="r", linestyle="--")
ax2.set_xlabel(r"Bagpipes $\log_{10}(M_*/M_\odot)$", fontsize=16)
ax2.set_ylabel("Residuals", fontsize=16)
ax2.grid()
ax2.yaxis.set_tick_params(labelsize=14)
ax2.xaxis.set_tick_params(labelsize=14)


plt.tight_layout()
plt.show()
fig.savefig("plots/bagpipes_vs_synference_stellar_mass.png", dpi=300)

In [None]:
%matplotlib inline

x = cat["stellar_mass_50"]
y = post_tab["log_surviving_mass_50"]
mask = np.isfinite(x) & np.isfinite(y)
x = x[mask]
y = y[mask]
# Calculate the point density

# Pick one of the ones with deltamass > 0.75 dex and mass > 9.5

dmass = np.abs(post_tab["log_surviving_mass_50"] - cat["stellar_mass_50"])
mask = (dmass > 0.75) & (cat["stellar_mass_50"] > 9.5)

idx = np.where(mask)[0]
idx = np.random.choice(idx)

xi = cat["stellar_mass_50"][idx]
yi = post_tab["log_surviving_mass_50"][idx]

print(
    f"Selected galaxy ID: {cat['UNIQUE_ID'][idx]}, {cat['#ID'][idx]}, z={cat['input_redshift'][idx]:.2f}"  # noqa: E501
)


density = np.histogram2d(x, y, bins=30)[0]

fig, ax1 = plt.subplots(1, 1, figsize=(5, 5))
plt.contour(
    density,
    extent=[x.min(), x.max(), y.min(), y.max()],
    cmap="cmr.torch_r",
    alpha=0.8,
    linewidths=1,
    levels=5,
)

plt.scatter(
    xi,
    yi,
    c="blue",
    s=20,
    alpha=0.7,
    edgecolors="black",
    label="Selected",
)
# Add black 1:1 line
plt.plot([7, 11], [7, 11], "k--", label="1:1 Line", linewidth=1)
plt.xlabel(r"Bagpipes $\log_{10}(M_*/M_\odot)$", fontsize=16)
plt.ylabel(r"Synference $\log_{10}(M_*/M_\odot)$", fontsize=16)

In [None]:
%matplotlib inline

x = cat["sfr_50"]
y = post_tab["log_sfr_50"]
mask = np.isfinite(x) & np.isfinite(y)
x = x[mask]
y = y[mask]
# Calculate the point density
x = np.log10(x)
# Pick one of the ones with Synference log SFR < 1 and bagpipes log sfr >2

dsfr = np.abs(post_tab["log_sfr_50"] - np.log10(cat["sfr_50"]))
mask = (dsfr > 1.0) & (np.log10(cat["sfr_50"]) > 2.0)
idx = np.where(mask)[0]
idx = np.random.choice(idx)

xi = np.log10(cat["sfr_50"][idx])
yi = post_tab["log_sfr_50"][idx]

print(
    f"Selected galaxy ID: {cat['UNIQUE_ID'][idx]}, {cat['#ID'][idx]}, z={cat['input_redshift'][idx]:.2f}"  # noqa: E501
)


density = np.histogram2d(x, y, bins=30)[0]

fig, ax1 = plt.subplots(1, 1, figsize=(5, 5))
plt.contour(
    density,
    extent=[x.min(), x.max(), y.min(), y.max()],
    cmap="cmr.torch_r",
    alpha=0.8,
    linewidths=1,
    levels=5,
)

plt.scatter(
    xi,
    yi,
    c="blue",
    s=20,
    alpha=0.7,
    edgecolors="black",
    label="Selected",
)
# Add black 1:1 line
plt.xlabel(r"Bagpipes $\log_{10}(M_*/M_\odot)$", fontsize=16)
plt.ylabel(r"Synference $\log_{10}(M_*/M_\odot)$", fontsize=16)

plt.show()

plt.scatter(dsfr, cat["chisq_phot"], s=1)
plt.yscale("log")

mask = (dsfr > 3) & (cat["chisq_phot"] < 100)
idx = np.where(mask)[0]
print(cat["UNIQUE_ID"][idx])

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import gaussian_kde

# Prototype of the "contours plus sparse points" figure.
# x and y are not created in this cell: it reuses whatever arrays the previously executed
# cell left in the namespace.


# Calculate density using KDE
xy = np.vstack([x, y])
kde = gaussian_kde(xy)
density = kde(xy)

# Define density threshold
threshold = np.percentile(density, 20)  # 20th percentile of the per-point density

fig, ax = plt.subplots(figsize=(10, 8))

# Plot contours for high-density regions
# The KDE is evaluated on a 50x50 grid padded by 1 beyond the data range on every side.
xi, yi = np.mgrid[x.min() - 1 : x.max() + 1 : 50j, y.min() - 1 : y.max() + 1 : 50j]
zi = kde(np.vstack([xi.flatten(), yi.flatten()]))
zi = zi.reshape(xi.shape)

contours = ax.contour(xi, yi, zi, levels=8, colors="blue", alpha=0.6, linewidths=1.5, zorder=4)
ax.contourf(xi, yi, zi, levels=8, alpha=0.3, cmap="Blues", zorder=3)

# Plot individual points only in low-density regions
# (points whose density is below the threshold defined above)
low_density_mask = density < threshold
ax.scatter(x[low_density_mask], y[low_density_mask], c="red", s=30, alpha=0.7, zorder=2)

ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_title("Scatter Plot with Density Contours")
plt.show()

In [None]:
import cmasher

# Sub-range (0 to 0.7) of the "cmr.torch_r" colormap, used for the filled contours drawn by
# plot_comparison below.
cmap = cmasher.get_sub_cmap(cmap="cmr.torch_r", start=0, stop=0.7)


def _as_float_array(values):
    """Return ``values`` as a plain float ndarray, with masked entries replaced by NaN."""
    return np.asarray(np.ma.filled(values, np.nan), dtype=float)


def _binned_median(x, values, n_edges=10, min_count=3):
    """Return bin centres and the median of ``values`` in equal-width bins of ``x``.

    Bins holding ``min_count`` points or fewer yield NaN. The last bin is closed on the
    right so that the point with the largest ``x`` is not dropped.
    """
    finite_x = x[np.isfinite(x)]
    edges = np.linspace(finite_x.min(), finite_x.max(), n_edges)
    centres = 0.5 * (edges[1:] + edges[:-1])
    medians = np.full(len(centres), np.nan)
    last = len(centres) - 1
    for i, (lower, upper) in enumerate(zip(edges[:-1], edges[1:])):
        in_bin = (x >= lower) & ((x <= upper) if i == last else (x < upper))
        if np.sum(in_bin) > min_count:
            medians[i] = np.median(values[in_bin])
    return centres, medians


def plot_comparison(x, y, xerr, yerr, xlabel, ylabel, filename, contours=True):
    """Plot ``y`` against ``x`` with a residual (``y - x``) panel and save the figure.

    Parameters
    ----------
    x, y : array-like
        The two sets of measurements to compare; masked entries are treated as NaN.
    xerr, yerr : sequence of two array-likes, or None
        Lower and upper error-bar lengths for ``x`` and ``y``.
    xlabel, ylabel : str
        Axis labels for the x axis (bottom panel) and the main y axis.
    filename : str
        Path the figure is written to (300 dpi).
    contours : bool, optional
        If True, non-finite pairs are dropped, KDE contours are drawn in both panels and
        only the points in the lowest 20% of KDE density are drawn individually.
        If False, every point is drawn and no contours are shown.

    Returns
    -------
    fig, (ax1, ax2)
        The figure, the main axes and the residual axes.
    """
    # Imported locally so the function does not depend on other notebook cells having run.
    from matplotlib import patheffects
    from scipy.stats import gaussian_kde

    x = _as_float_array(x)
    y = _as_float_array(y)
    if xerr is not None:
        xerr = [_as_float_array(err) for err in xerr]
    if yerr is not None:
        yerr = [_as_float_array(err) for err in yerr]

    fig, (ax1, ax2) = plt.subplots(
        2,
        1,
        figsize=(5, 6),
        gridspec_kw={"height_ratios": [3, 1]},
        sharex=True,
        constrained_layout=True,
    )

    if contours:
        # The KDE cannot handle NaN/inf, so keep only pairs where both values are finite.
        finite = np.isfinite(x) & np.isfinite(y)
        x = x[finite]
        y = y[finite]
        if xerr is not None:
            xerr = [err[finite] for err in xerr]
        if yerr is not None:
            yerr = [err[finite] for err in yerr]

        # print median offset and std dev
        median_offset = np.median(y - x)
        std_offset = np.std(y - x)
        print(f"Median offset: {median_offset:.3f}, Std dev: {std_offset:.3f}")

        xy = np.vstack([x, y])
        kde = gaussian_kde(xy)
        density = kde(xy)

        # Points below the 20th percentile of density are drawn individually later on.
        low_density_mask = density < np.percentile(density, 20)

        # Evaluate the KDE on a 50x50 grid padded by 1 beyond the data range.
        xi, yi = np.mgrid[x.min() - 1 : x.max() + 1 : 50j, y.min() - 1 : y.max() + 1 : 50j]
        zi = kde(np.vstack([xi.ravel(), yi.ravel()])).reshape(xi.shape)

        # Main panel contours. The contour set is deliberately not bound to `contours`,
        # which must remain the boolean flag passed by the caller.
        levels = np.percentile(zi, [80, 95, 96, 99, 99.5, 100])
        ax1.contour(
            xi, yi, zi, levels=levels, colors="black", alpha=0.3, linewidths=1.0, zorder=4
        )
        ax1.contourf(xi, yi, zi, levels=levels, alpha=0.7, cmap=cmap, zorder=3)

        # Residual panel contours: KDE of (x, y - x) evaluated on the sheared grid.
        kde_resid = gaussian_kde(np.vstack([x, y - x]))
        zi_resid = kde_resid(np.vstack([xi.ravel(), (yi - xi).ravel()])).reshape(xi.shape)
        ax2.contour(
            xi,
            yi - xi,
            zi_resid,
            levels=levels,
            colors="black",
            alpha=0.3,
            linewidths=1.0,
            zorder=4,
        )
        ax2.contourf(xi, yi - xi, zi_resid, levels=levels, alpha=0.7, cmap=cmap, zorder=3)

        # Only the sparse points are drawn as individual markers.
        x_points = x[low_density_mask]
        y_points = y[low_density_mask]
        xerr_points = None if xerr is None else [err[low_density_mask] for err in xerr]
        yerr_points = None if yerr is None else [err[low_density_mask] for err in yerr]
    else:
        x_points, y_points = x, y
        xerr_points, yerr_points = xerr, yerr

    # Binned median of y over the full (not only the sparse) sample.
    bin_centers, bin_medians = _binned_median(x, y)
    ax1.plot(
        bin_centers,
        bin_medians,
        color="dodgerblue",
        linewidth=2,
        alpha=0.6,
        label="Rolling Median",
        zorder=20,
        path_effects=[patheffects.withStroke(linewidth=3, foreground="white")],
    )

    ax1.errorbar(
        x_points,
        y_points,
        xerr=xerr_points,
        yerr=yerr_points,
        alpha=0.2,
        linestyle="none",
        marker="none",
        color="grey",
    )
    ax1.scatter(
        x_points,
        y_points,
        alpha=0.5,
        s=8,
        edgecolor="black",
        zorder=10,
        linewidths=0.5,
        color="darkgrey",
    )

    # The 1:1 line and the x limits span the whole finite sample, not just the sparse points.
    finite_x = x[np.isfinite(x)]
    x_range = [finite_x.min(), finite_x.max()]
    ax1.plot(
        x_range,
        x_range,
        label="1:1 Line",
        linestyle="--",
        color="black",
        linewidth=2,
        zorder=15,
    )
    ax1.set_xlim(x_range)
    ax1.set_ylabel(ylabel, fontsize=16)
    ax1.grid()
    ax1.yaxis.set_tick_params(labelsize=14)

    residuals = y_points - x_points
    ax2.errorbar(
        x_points,
        residuals,
        yerr=yerr_points,
        alpha=0.2,
        marker="none",
        linestyle="none",
        color="grey",
    )
    ax2.scatter(
        x_points,
        residuals,
        alpha=0.5,
        s=8,
        edgecolor="black",
        zorder=10,
        linewidths=0.5,
        color="darkgrey",
    )
    ax2.axhline(0, color="black", linestyle="--", linewidth=2, zorder=15)
    ax2.set_xlabel(xlabel, fontsize=16)
    ax2.set_ylabel("Residuals", fontsize=16)
    ax2.grid()
    ax2.yaxis.set_tick_params(labelsize=14)
    ax2.xaxis.set_tick_params(labelsize=14)

    # Binned median of the residuals, again over the full sample.
    bin_centers, bin_medians = _binned_median(x, y - x)
    ax2.plot(
        bin_centers,
        bin_medians,
        color="dodgerblue",
        linewidth=2,
        alpha=0.6,
        label="Rolling Median",
        zorder=20,
        path_effects=[patheffects.withStroke(linewidth=3, foreground="white")],
    )

    ax2.set_ylim(-1.5, 1.5)
    ax1.legend()

    fig.savefig(filename, dpi=300)
    return fig, (ax1, ax2)


# Stellar mass: x is the Bagpipes value, so xerr comes from the Bagpipes percentiles
# (same _16/_50/_84 naming as the other Bagpipes columns used below); y is the Synference
# surviving mass, so yerr comes from the surviving-mass percentiles.
plot_comparison(
    cat["stellar_mass_50"],
    post_tab["log_surviving_mass_50"],
    xerr=[
        cat["stellar_mass_50"] - cat["stellar_mass_16"],
        cat["stellar_mass_84"] - cat["stellar_mass_50"],
    ],
    yerr=[
        post_tab["log_surviving_mass_50"] - post_tab["log_surviving_mass_16"],
        post_tab["log_surviving_mass_84"] - post_tab["log_surviving_mass_50"],
    ],
    xlabel=r"Bagpipes $\log_{10}(\rm M_\star/M_\odot)$",
    ylabel=r"Synference $\log_{10}(\rm M_\star/M_\odot)$",
    filename="plots/bagpipes_vs_synference_stellar_mass.png",
)

# Star-formation rate over 10 Myr; the Bagpipes values are linear and are logged here.
plot_comparison(
    np.log10(cat["sfr_10myr_50"]),
    post_tab["log10_floor_sfr_10_50"],
    xerr=[
        np.log10(cat["sfr_10myr_50"]) - np.log10(cat["sfr_10myr_16"]),
        np.log10(cat["sfr_10myr_84"]) - np.log10(cat["sfr_10myr_50"]),
    ],
    yerr=[
        post_tab["log10_floor_sfr_10_50"] - post_tab["log10_floor_sfr_10_16"],
        post_tab["log10_floor_sfr_10_84"] - post_tab["log10_floor_sfr_10_50"],
    ],
    xlabel=r"Bagpipes SFR$_{\rm 10}$ [$\log_{10}(\rm M_\odot$/yr)]",
    ylabel=r"Synference SFR$_{\rm 10}$ [$\log_{10}(\rm M_\odot$/yr)]",
    filename="plots/bagpipes_vs_synference_sfr_10.png",
)
# Mass-weighted age; the Bagpipes value is scaled by 1e3 before the log (presumably
# Gyr -> Myr to match the axis label - confirm). The error bars are differences of logs,
# so the scale factor cancels there.
plot_comparison(
    np.log10(cat["mass_weighted_age_50"] * 1e3),
    post_tab["log10_mass_weighted_age_50"],
    xerr=[
        np.log10(cat["mass_weighted_age_50"]) - np.log10(cat["mass_weighted_age_16"]),
        np.log10(cat["mass_weighted_age_84"]) - np.log10(cat["mass_weighted_age_50"]),
    ],
    yerr=[
        post_tab["log10_mass_weighted_age_50"] - post_tab["log10_mass_weighted_age_16"],
        post_tab["log10_mass_weighted_age_84"] - post_tab["log10_mass_weighted_age_50"],
    ],
    xlabel=r"Bagpipes Age$_{\rm MW}$ [$\log_{10}(\rm Myr)$]",
    ylabel=r"Synference Age$_{\rm MW}$ [$\log_{10}(\rm Myr)$]",
    filename="plots/bagpipes_vs_synference_mass_weighted_age.png",
)

# Dust attenuation A_V, compared in log space.
plot_comparison(
    np.log10(cat["dust:Av_50"]),
    post_tab["log10_Av_50"],
    xerr=[
        np.log10(cat["dust:Av_50"]) - np.log10(cat["dust:Av_16"]),
        np.log10(cat["dust:Av_84"]) - np.log10(cat["dust:Av_50"]),
    ],
    yerr=[
        post_tab["log10_Av_50"] - post_tab["log10_Av_16"],
        post_tab["log10_Av_84"] - post_tab["log10_Av_50"],
    ],
    xlabel=r"Bagpipes $\log_{10}(\rm A_V)$",
    ylabel=r"Synference $\log_{10}(\rm A_V)$",
    filename="plots/bagpipes_vs_synference_Av.png",
)

In [None]:
def _as_float_array(values):
    """Return ``values`` as a plain float ndarray, with masked entries replaced by NaN."""
    return np.asarray(np.ma.filled(values, np.nan), dtype=float)


def plot_comparison(x, y, xerr, yerr, xlabel, ylabel, filename, contours=True):
    """Plot ``y`` against ``x`` with a residual (``y - x``) panel, save and show the figure.

    Parameters
    ----------
    x, y : array-like
        The two sets of measurements to compare; masked entries are treated as NaN.
    xerr, yerr : sequence of two array-likes, or None
        Lower and upper error-bar lengths for ``x`` and ``y``.
    xlabel, ylabel : str
        Axis labels for the x axis (bottom panel) and the main y axis.
    filename : str
        Path the figure is written to (300 dpi, white background, tight bounding box).
    contours : bool, optional
        If True, non-finite pairs are dropped, KDE contours are drawn in both panels and
        only the points in the lowest 20% of KDE density are drawn individually.
        If False, every point is drawn and no contours are shown.

    Returns
    -------
    fig, (ax1, ax2)
        The figure, the main axes and the residual axes.

    Notes
    -----
    ``plt.rcParams`` is updated in place, so the style set here also applies to every
    figure created after this function has been called.
    """
    # Imported locally so the function does not depend on other notebook cells having run.
    from scipy.stats import gaussian_kde

    # Set publication-quality style (global: persists after the call)
    plt.rcParams.update(
        {
            "font.family": "serif",
            "font.size": 11,
            "axes.labelsize": 13,
            "axes.titlesize": 13,
            "xtick.labelsize": 11,
            "ytick.labelsize": 11,
            "legend.fontsize": 10,
            "figure.dpi": 100,
            "savefig.dpi": 300,
            "axes.linewidth": 1.2,
            "grid.alpha": 0.3,
            "grid.linewidth": 0.5,
        }
    )

    x = _as_float_array(x)
    y = _as_float_array(y)
    if xerr is not None:
        xerr = [_as_float_array(err) for err in xerr]
    if yerr is not None:
        yerr = [_as_float_array(err) for err in yerr]

    # sharex keeps the residual panel on exactly the same x range as the main panel.
    fig, (ax1, ax2) = plt.subplots(
        2,
        1,
        figsize=(6, 7),
        sharex=True,
        gridspec_kw={"height_ratios": [3, 1], "hspace": 0.05},
    )

    # Color scheme
    contour_color = "#2E5C8A"  # Professional blue
    point_color = "#4A5568"  # Slate gray
    error_color = "#A0AEC0"  # Light gray
    line_color = "#E53E3E"  # Muted red

    if contours:
        # The KDE cannot handle NaN/inf, so keep only pairs where both values are finite.
        mask = np.isfinite(x) & np.isfinite(y)
        x_filtered = x[mask]
        y_filtered = y[mask]
        xerr_filtered = None if xerr is None else [err[mask] for err in xerr]
        yerr_filtered = None if yerr is None else [err[mask] for err in yerr]

        xy = np.vstack([x_filtered, y_filtered])
        kde = gaussian_kde(xy)
        density = kde(xy)

        # Points below the 20th percentile of density are drawn individually.
        threshold = np.percentile(density, 20)
        low_density_mask = density < threshold

        # Evaluate the KDE on a 100x100 grid padded by 0.5 beyond the data range.
        xi, yi = np.mgrid[
            x_filtered.min() - 0.5 : x_filtered.max() + 0.5 : 100j,
            y_filtered.min() - 0.5 : y_filtered.max() + 0.5 : 100j,
        ]
        zi = kde(np.vstack([xi.ravel(), yi.ravel()])).reshape(xi.shape)

        # Main plot contours
        levels = np.percentile(zi, [20, 40, 60, 80, 95, 99, 99.9])
        ax1.contourf(xi, yi, zi, levels=levels, alpha=0.15, cmap="Blues", zorder=1)
        ax1.contour(
            xi, yi, zi, levels=levels, colors=contour_color, alpha=0.5, linewidths=1.0, zorder=2
        )

        # Residuals plot contours: KDE of (x, y - x) evaluated on the sheared grid.
        kde_resid = gaussian_kde(np.vstack([x_filtered, y_filtered - x_filtered]))
        yi_resid = yi - xi
        zi_resid = kde_resid(np.vstack([xi.ravel(), yi_resid.ravel()])).reshape(xi.shape)

        ax2.contourf(xi, yi_resid, zi_resid, levels=levels, alpha=0.15, cmap="Blues", zorder=1)
        ax2.contour(
            xi,
            yi_resid,
            zi_resid,
            levels=levels,
            colors=contour_color,
            alpha=0.5,
            linewidths=1.0,
            zorder=2,
        )

        # Only the sparse points are drawn as individual markers.
        x_plot = x_filtered[low_density_mask]
        y_plot = y_filtered[low_density_mask]
        xerr_plot = (
            None if xerr_filtered is None else [err[low_density_mask] for err in xerr_filtered]
        )
        yerr_plot = (
            None if yerr_filtered is None else [err[low_density_mask] for err in yerr_filtered]
        )
    else:
        x_plot, y_plot = x, y
        xerr_plot, yerr_plot = xerr, yerr

    # Main comparison plot
    ax1.errorbar(
        x_plot,
        y_plot,
        xerr=xerr_plot,
        yerr=yerr_plot,
        alpha=0.15,
        linestyle="none",
        marker="none",
        color=error_color,
        zorder=3,
        elinewidth=0.8,
    )
    ax1.scatter(
        x_plot,
        y_plot,
        alpha=0.6,
        s=12,
        edgecolor="white",
        linewidths=0.3,
        color=point_color,
        zorder=4,
    )

    # 1:1 reference line. Limits come from finite values only: +/-inf (e.g. log10 of zero)
    # would otherwise be handed to set_xlim/set_ylim, which reject non-finite limits.
    finite_x = x[np.isfinite(x)]
    xlim_range = [finite_x.min(), finite_x.max()]
    ax1.plot(
        xlim_range,
        xlim_range,
        color=line_color,
        linestyle="--",
        linewidth=1.5,
        alpha=0.8,
        label="1:1",
        zorder=5,
    )

    ax1.set_xlim(xlim_range)
    ax1.set_ylim(xlim_range)
    ax1.set_ylabel(ylabel)
    ax1.legend(frameon=True, fancybox=False, edgecolor="gray", framealpha=0.9, loc="upper left")
    ax1.grid(True, alpha=0.25, linestyle=":", linewidth=0.5)
    ax1.tick_params(direction="in", which="both", top=True, right=True)

    # Residuals plot
    residuals = y_plot - x_plot
    ax2.errorbar(
        x_plot,
        residuals,
        yerr=yerr_plot,
        alpha=0.15,
        marker="none",
        linestyle="none",
        color=error_color,
        zorder=3,
        elinewidth=0.8,
    )
    ax2.scatter(
        x_plot,
        residuals,
        alpha=0.6,
        s=12,
        edgecolor="white",
        linewidths=0.3,
        color=point_color,
        zorder=4,
    )
    ax2.axhline(0, color=line_color, linestyle="--", linewidth=1.5, alpha=0.8, zorder=5)

    ax2.set_xlabel(xlabel)
    ax2.set_ylabel("Residuals")
    ax2.set_ylim(-1.5, 1.5)
    ax2.grid(True, alpha=0.25, linestyle=":", linewidth=0.5)
    ax2.tick_params(direction="in", which="both", top=True, right=True)

    plt.tight_layout()
    fig.savefig(filename, dpi=300, bbox_inches="tight", facecolor="white")
    plt.show()

    return fig, (ax1, ax2)


# Usage examples

# Stellar mass: x is the Bagpipes value, so xerr comes from the Bagpipes percentiles
# (same _16/_50/_84 naming as the other Bagpipes columns used below); y is the Synference
# surviving mass, so yerr comes from the surviving-mass percentiles.
plot_comparison(
    cat["stellar_mass_50"],
    post_tab["log_surviving_mass_50"],
    xerr=[
        cat["stellar_mass_50"] - cat["stellar_mass_16"],
        cat["stellar_mass_84"] - cat["stellar_mass_50"],
    ],
    yerr=[
        post_tab["log_surviving_mass_50"] - post_tab["log_surviving_mass_16"],
        post_tab["log_surviving_mass_84"] - post_tab["log_surviving_mass_50"],
    ],
    xlabel=r"Bagpipes $\log_{10}(\mathrm{M}_*/\mathrm{M}_\odot)$",
    ylabel=r"Synference $\log_{10}(\mathrm{M}_*/\mathrm{M}_\odot)$",
    filename="plots/bagpipes_vs_synference_stellar_mass.png",
)

# Star-formation rate over 10 Myr; the Bagpipes values are linear and are logged here.
plot_comparison(
    np.log10(cat["sfr_10myr_50"]),
    post_tab["log10_floor_sfr_10_50"],
    xerr=[
        np.log10(cat["sfr_10myr_50"]) - np.log10(cat["sfr_10myr_16"]),
        np.log10(cat["sfr_10myr_84"]) - np.log10(cat["sfr_10myr_50"]),
    ],
    yerr=[
        post_tab["log10_floor_sfr_10_50"] - post_tab["log10_floor_sfr_10_16"],
        post_tab["log10_floor_sfr_10_84"] - post_tab["log10_floor_sfr_10_50"],
    ],
    xlabel=r"Bagpipes SFR (10 Myr) [$\log_{10} \mathrm{M}_\odot$/yr]",
    ylabel=r"Synference SFR (10 Myr) [$\log_{10} \mathrm{M}_\odot$/yr]",
    filename="plots/bagpipes_vs_synference_sfr_10.png",
)

# Mass-weighted age; the Bagpipes value is scaled by 1e3 before the log (presumably
# Gyr -> Myr to match the axis label - confirm). The error bars are differences of logs,
# so the scale factor cancels there.
plot_comparison(
    np.log10(cat["mass_weighted_age_50"] * 1e3),
    post_tab["log10_mass_weighted_age_50"],
    xerr=[
        np.log10(cat["mass_weighted_age_50"]) - np.log10(cat["mass_weighted_age_16"]),
        np.log10(cat["mass_weighted_age_84"]) - np.log10(cat["mass_weighted_age_50"]),
    ],
    yerr=[
        post_tab["log10_mass_weighted_age_50"] - post_tab["log10_mass_weighted_age_16"],
        post_tab["log10_mass_weighted_age_84"] - post_tab["log10_mass_weighted_age_50"],
    ],
    xlabel=r"Bagpipes Mass-weighted Age [$\log_{10}(\mathrm{Myr})$]",
    ylabel=r"Synference Mass-weighted Age [$\log_{10}(\mathrm{Myr})$]",
    filename="plots/bagpipes_vs_synference_mass_weighted_age.png",
)

# Dust attenuation A_V, compared in log space.
plot_comparison(
    np.log10(cat["dust:Av_50"]),
    post_tab["log10_Av_50"],
    xerr=[
        np.log10(cat["dust:Av_50"]) - np.log10(cat["dust:Av_16"]),
        np.log10(cat["dust:Av_84"]) - np.log10(cat["dust:Av_50"]),
    ],
    yerr=[
        post_tab["log10_Av_50"] - post_tab["log10_Av_16"],
        post_tab["log10_Av_84"] - post_tab["log10_Av_50"],
    ],
    xlabel=r"Bagpipes $\log_{10}(\mathrm{A}_\mathrm{V})$",
    ylabel=r"Synference $\log_{10}(\mathrm{A}_\mathrm{V})$",
    filename="plots/bagpipes_vs_synference_Av.png",
)

In [None]:
# Star-forming main sequence (log SFR over 10 Myr against log stellar mass) in three
# redshift bins - z = 0 - 1, 2 - 3, 3 - 6 - for Bagpipes and Synference side by side.
from statsmodels.nonparametric.smoothers_lowess import lowess

fig, axs = plt.subplots(1, 2, figsize=(8, 4), sharey=True, dpi=200)

# Bagpipes on the left (axs[0]), Synference on the right (axs[1])
bagpipes_ax, synference_ax = axs

z_bins = [(0, 1), (2, 3), (3, 6)]
# one colour and one legend label per redshift bin
colors = ["#377eb8", "#4daf4a", "#e41a1c"]
labels = ["$0 < z < 1$", "$2 < z < 3$", "$3 < z < 6$"]

for bin_index, (z_min, z_max) in enumerate(z_bins):
    color = colors[bin_index]
    label = labels[bin_index]
    mask = (cat["specz"] >= z_min) & (cat["specz"] < z_max)
    # A LOWESS trend is only drawn for bins holding more than 10 galaxies.
    draw_trend = np.sum(mask) > 10

    # (axes, log mass, log SFR, marker size) for each of the two codes
    panels = (
        (
            bagpipes_ax,
            cat["stellar_mass_50"][mask],
            np.log10(cat["sfr_10myr_50"][mask]),
            5,
        ),
        (
            synference_ax,
            post_tab["log_surviving_mass_50"][mask],
            post_tab["log10_floor_sfr_10_50"][mask],
            4,
        ),
    )
    for ax, log_mass, log_sfr, marker_size in panels:
        ax.scatter(
            log_mass,
            log_sfr,
            color=color,
            alpha=0.5,
            s=marker_size,
            label=label,
            edgecolor="black",
            linewidth=0.5,
        )
        if not draw_trend:
            continue
        # LOWESS fit of log SFR on log mass, returned sorted by mass
        trend = lowess(log_sfr, log_mass, frac=0.6, it=0, return_sorted=True)
        ax.plot(
            trend[:, 0],
            trend[:, 1],
            color=color,
            linewidth=2,
            path_effects=[plt.matplotlib.patheffects.withStroke(linewidth=3, foreground="white")],
        )

# Shared decoration: x label, code name in the top-left corner, grid
for ax, code_name in zip(axs, ("Bagpipes", "Synference")):
    ax.set_xlabel(r"$\log_{10}(M_*/M_\odot)$", fontsize=14)
    ax.text(0.05, 0.95, code_name, transform=ax.transAxes, fontsize=16, verticalalignment="top")
    ax.grid()

bagpipes_ax.set_ylabel(r"$\log_{10}(SFR_{10 Myr}/M_\odot yr^{-1})$", fontsize=14)
synference_ax.legend(fontsize=10)

plt.tight_layout()
fig.savefig("plots/bagpipes_vs_synference_sfms.png", dpi=300)

In [None]:
# Overlay the distributions of log SFR (10 Myr): log10 of the Bagpipes medians against the
# Synference medians, 50 bins each.
plt.hist(np.log10(cat["sfr_10myr_50"]), bins=50, histtype="stepfilled", alpha=0.7, label="Bagpipes")
plt.hist(
    post_tab["log10_floor_sfr_10_50"], bins=50, histtype="stepfilled", alpha=0.7, label="Synference"
)
plt.legend()

In [None]:
# For every column of new_table that carries a mask, print how many entries are masked.

for column_name in new_table.colnames:
    column = new_table[column_name]
    if not hasattr(column, "mask"):
        continue
    print(column_name, np.sum(column.mask))

## FSPS

In [None]:
bagpipes_results = "/home/tharvey/work/synference/priv/data/cats/sfh_dense_basis_dust_Calzetti.fits"
jades_catalog = "/home/tharvey/work/synference/priv/data/cats/JADES_DR3_GS_Matched_Specz_total_flux_good_filt.fits"  # noqa: E501
data_err_file = "/home/tharvey/work/JADES-DR3-GS_MASTER_Sel-F277W+F356W+F444W_v13.fits"

bagpipes = Table.read(bagpipes_results)
jades = Table.read(jades_catalog)

# hstack pairs rows purely by position, so tables of different length must be rejected
# explicitly (an `assert` disappears under `python -O`).
if len(bagpipes) != len(jades):
    raise ValueError(
        f"Row count mismatch: {len(bagpipes)} Bagpipes rows vs {len(jades)} JADES rows."
    )

# combine the tables
# NOTE(review): this rebinds `cat` to the full, unfiltered catalogue, so it is no longer
# row-aligned with the filtered `post_tab` / `new_table` built earlier - confirm intended.
cat = hstack([jades, bagpipes])

model_name = "FSPS_DenseBasis_v4_final"  # "DenseBasis_v4_final_custom_small"
extra = "nsf_0"


# Load the trained FSPS posterior together with the grid it was trained on.
fsps_fitter = SBI_Fitter.load_saved_model(
    f"/home/tharvey/work/synference/models/{model_name}/{model_name}_{extra}_posterior.pkl",
    grid_path="/home/tharvey/work/synference/grids/grid_FSPS_Chab_DenseBasis_SFH_0.01_z_14_logN_6.0_Calzetti_v4_multinode.hdf5",
)

# Reuse the empirical noise models built earlier in the notebook.
fsps_fitter.feature_array_flags["empirical_noise_models"] = empirical_noise_models

In [None]:
# Fit the same input table (new_table) with the FSPS model, without the
# out-of-distribution check.
post_tab_fsps = fsps_fitter.fit_catalogue(
    new_table,
    columns_to_feature_names=conversion_dict,
    flux_units=Jy,
    return_feature_array=False,
    check_out_of_distribution=False,
)

In [None]:
# Plot FSPS mass vs BPASS mass
# x = surviving mass from the first (BPASS) fitter, y = the same quantity from the FSPS
# fitter; error bars are the 16th-84th percentile half-widths of each.
# NOTE(review): post_tab was cut at redshift >= 0.01 earlier, whereas post_tab_fsps comes
# from the uncut new_table - confirm both tables have the same rows before comparing them
# element-wise.

plot_comparison(
    post_tab["log_surviving_mass_50"],
    post_tab_fsps["log_surviving_mass_50"],
    xerr=[
        post_tab["log_surviving_mass_50"] - post_tab["log_surviving_mass_16"],
        post_tab["log_surviving_mass_84"] - post_tab["log_surviving_mass_50"],
    ],
    yerr=[
        post_tab_fsps["log_surviving_mass_50"] - post_tab_fsps["log_surviving_mass_16"],
        post_tab_fsps["log_surviving_mass_84"] - post_tab_fsps["log_surviving_mass_50"],
    ],
    xlabel=r"BPASS Stellar Mass [$\log_{10}(\mathrm{M}_*/\mathrm{M}_\odot)$]",
    ylabel=r"FSPS Stellar Mass [$\log_{10}(\mathrm{M}_*/\mathrm{M}_\odot)$]",
    filename="plots/fsps_vs_bpass_stellar_mass.png",
)

In [None]:
import h5py

# Open one per-galaxy HDF5 output file read-only and print the file object.
# NOTE(review): `folder` already ends in "JADES-DR3-GS-pipes/" and the f-string
# appends "/JADES-DR3-GS-pipes/" again - confirm the nested directory is intended.
folder = "/home/tharvey/work/synference/priv/JADES-DR3-GS-pipes/"
galaxy_id = "20"
with h5py.File(
    f"{folder}/JADES-DR3-GS-pipes/sfh_dense_basis_dust_Calzetti/{galaxy_id}_JADES-DR3-GS-North.h5",
    "r",
) as f:
    # Prints the h5py File object itself, not its contents.
    print(f)

Export some of the validation dataset to Bagpipes format

In [None]:
from tqdm import tqdm

from synference.utils import asinh_err_to_f_jy, asinh_to_f_jy

# Pick the test-set rows that fall in the first (and only) bin of SNR_bins.
SNR_bins = [5, np.inf]
idxs = fitter.bin_noisy_testing_data(snr_bins=SNR_bins, return_indices=True)[0]

features = fitter._X_test[idxs]
params = fitter._y_test[idxs]

# Build the name array once instead of once per band.
feature_names = np.array(fitter.feature_names)

# Maps short band name -> fluxes and "unc_<band>" -> flux errors, one value
# per selected row.
out = {}

for band in tqdm(empirical_noise_models):
    flux_cols = np.where(feature_names == band)[0]
    err_col = np.where(feature_names == f"unc_{band}")[0][0]
    b_param = empirical_noise_models[band].b

    band_flux = []
    band_flux_err = []
    for row in idxs:
        asinh_flux = fitter._X_test[row, flux_cols]
        # Presumably converts the asinh-scaled feature back to a flux in Jy.
        flux_jy = asinh_to_f_jy(asinh_flux, f_b=b_param)
        asinh_err = fitter._X_test[row, err_col]
        flux_err_jy = asinh_err_to_f_jy(asinh_flux, asinh_err, f_b=b_param)
        band_flux.append(np.squeeze(flux_jy))
        band_flux_err.append(np.squeeze(flux_err_jy))

    # Keep only the part of the band name after the last "."
    short_name = band.split(".")[-1]
    out[short_name] = np.array(band_flux)
    out[f"unc_{short_name}"] = np.array(band_flux_err)

# The last feature column is stored as the redshift.
out["redshift"] = features[:, -1]

In [None]:
# Assemble the exported photometry into a table and write it to disk.
# (Table and np are imported in the first cell of the notebook, so they are
# not re-imported here.)
new_table = Table(data=out)
# IDs are the row numbers, so an ID doubles as a row index further down.
new_table["ID"] = range(len(new_table))

new_table.write(
    "/home/tharvey/work/synference/priv/data/cats/BPASS_DenseBasis_v4_5sig_val_cat.fits",
    overwrite=True,
)

subset_path = "/home/tharvey/work/synference/priv/data/cats/BPASS_DenseBasis_v4_5sig_val_cat_subset_1000.fits"  # noqa: E501

# Drawing a fresh random subset is switched off: the draw is unseeded, so
# regenerating would replace the subset file on disk with a different set of
# 1000 rows. Flip this flag only if a new subset is really wanted.
regenerate_subset = False
if regenerate_subset:
    # Pick a random subset of 1000 rows from new_table
    random_indices = np.random.choice(new_table["ID"], size=1000, replace=False)
    subset_table = new_table[random_indices]

    subset_table.write(subset_path, overwrite=True)

# Always work from the subset stored on disk.
subset_table = Table.read(subset_path)

In [None]:
# Collect the test-set targets (fitter._y_test) of the selected rows into a
# table with one column per fitted parameter.
rows = fitter._y_test[idxs]
param_names = fitter.fitted_parameter_names
parameters_table_data = {name: rows[:, i] for i, name in enumerate(param_names)}

post_tab = Table(data=parameters_table_data)
ids = range(len(post_tab))
post_tab["ID"] = ids

# Keep only the rows whose ID appears in the stored subset.
subset_params_mask = np.isin(ids, subset_table["ID"])
post_tab_sub = post_tab[subset_params_mask]

In [None]:
# Sort the subset table in place by ID, then display the ID column.
subset_table.sort("ID")
subset_table["ID"]

In [None]:
from astropy.table import join

# Read the Bagpipes results that will be compared against the true values.
comparison_pipes_path = "/home/tharvey/work/synference/priv/sfh_dense_basis_dust_Calzetti.fits"
comparison_pipes = Table.read(comparison_pipes_path)

# Convert the "#ID" column to integers so it can be matched against the
# integer "ID" column of post_tab_sub.
comparison_pipes["ID"] = list(map(int, comparison_pipes["#ID"]))

# Match the true parameters to the Bagpipes results on ID.
cat = join(post_tab_sub, comparison_pipes, keys="ID")

In [None]:
# Quick look: true log surviving mass (x) against the Bagpipes median stellar
# mass (y), with a dashed red one-to-one line.
plt.scatter(cat["log_surviving_mass"], cat["stellar_mass_50"], s=1)
plt.plot([4, 12], [4, 12], color="red", linestyle="--")

In [None]:
# Imported once for the whole cell rather than on every loop iteration.
from sklearn.metrics import mean_squared_error, r2_score

# One entry per compared quantity:
#   [true-value column, Bagpipes median column ("..._50"),
#    take log10 of the Bagpipes values?, additive offset applied after the
#    log (or None), axis-label text]
param_pairs = [
    [
        "log_surviving_mass",
        "stellar_mass_50",
        False,
        None,
        r"$\log_{10}(\mathrm{M}_*/\mathrm{M}_\odot)$",
    ],
    [
        "log10_floor_sfr_10",
        "sfr_10myr_50",
        True,
        None,
        r"$\log_{10}(\mathrm{SFR}_{10 Myr}/\mathrm{M}_\odot yr^{-1})$",
    ],
    [
        "log10_mass_weighted_age",
        "mass_weighted_age_50",
        True,
        3,
        r"$\log_{10}(\mathrm{Age}_{MW}/\mathrm{Myr})$",
    ],
    ["log10_Av", "dust:Av_50", True, None, r"$\log_{10}(\mathrm{A}_\mathrm{V})$"],
]

for syn_param, bp_param, unlog, offset, name in param_pairs:
    # The 16th/84th percentile columns share the name of the median column.
    lo_col = bp_param.replace("_50", "_16")
    hi_col = bp_param.replace("_50", "_84")

    if unlog:
        bp_param_vals = np.log10(cat[bp_param])
        lo = np.log10(cat[lo_col])
        hi = np.log10(cat[hi_col])
    else:
        bp_param_vals = cat[bp_param]
        lo = cat[lo_col]
        hi = cat[hi_col]

    if offset is not None:
        # Out-of-place addition: on the non-log branch these names refer to
        # columns of `cat`, and `+=` would silently modify the catalogue.
        bp_param_vals = bp_param_vals + offset
        lo = lo + offset
        hi = hi + offset

    # Compute RMSE and R^2 on finite pairs only (NaN / +-inf, e.g. from
    # log10 of a non-positive value, would make the metrics raise).
    true_vals = np.asarray(cat[syn_param], dtype=float)
    est_vals = np.asarray(bp_param_vals, dtype=float)
    finite = np.isfinite(true_vals) & np.isfinite(est_vals)
    if not finite.all():
        print(f"Excluding {np.sum(~finite)} non-finite values for {bp_param}")
    rmse = np.sqrt(mean_squared_error(true_vals[finite], est_vals[finite]))
    r2 = r2_score(true_vals[finite], est_vals[finite])
    print(f"Comparison for {syn_param} vs {bp_param}: RMSE = {rmse:.3f}, R^2 = {r2:.3f}")

    fig, axs = plot_comparison(
        cat[syn_param],
        bp_param_vals,
        yerr=[bp_param_vals - lo, hi - bp_param_vals],
        xerr=None,
        xlabel=f"True {name}",
        ylabel=f"Bagpipes {name}",
        filename=f"plots/bagpipes_vs_true_{bp_param}.png",
        contours=False,
    )

    # put rmse and r2 on the plot (mathtext instead of a raw superscript
    # character, which had been mangled by an encoding round-trip)
    axs[0].text(
        0.98,
        0.10,
        f"RMSE = {rmse:.3f}\n$R^2$ = {r2:.3f}",
        transform=axs[0].transAxes,
        fontsize=12,
        verticalalignment="top",
        horizontalalignment="right",
        bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
    )

    fig.savefig(f"plots/bagpipes_vs_true_{bp_param}_with_stats.png", dpi=300, bbox_inches="tight")

In [None]:
# Quick look: log10 of the Bagpipes median 10 Myr SFR (x) against the true
# log10_floor_sfr_10 (y), with a dashed red one-to-one line.
plt.scatter(np.log10(cat["sfr_10myr_50"]), cat["log10_floor_sfr_10"], s=1)
plt.plot([-4, 4], [-4, 4], color="red", linestyle="--")

In [None]:
import os

import numpy as np

# Load the per-object posterior samples, one sub-directory per parameter and
# one "<ID>.txt" file per object.
folder = "/home/tharvey/work/synference/priv/true_comp_posteriors"
# NOTE: os.listdir returns entries in arbitrary order; anything indexed per
# parameter downstream has to follow the order of `parameters`.
parameters = os.listdir(folder)

# Parameters whose samples are converted with log10 ...
log = ["dust:Av", "sfr_10myr", "mass_weighted_age", "iyer2019:metallicity", "iyer2019:sfr"]
# ... and additive offsets applied afterwards.
scale = {"mass_weighted_age": 3, "iyer2019:metallicity": -1.698970}

# Objects with fewer draws than this are resampled up to this number.
min_samples = 500

total_samples = []
for param in parameters:
    param_samples = []
    for gal_id in cat["ID"]:
        path = f"{folder}/{param}/{gal_id}.txt"
        if not os.path.exists(path):
            # Previously this was only printed and the read below failed
            # anyway; fail here with the explicit message instead.
            raise FileNotFoundError(f"Missing file: {path}")
        # atleast_1d: a file holding a single value yields a 0-d array,
        # on which len() would fail.
        samples = np.atleast_1d(np.genfromtxt(path))
        if len(samples) < min_samples:
            # resample from existing samples with replacement
            samples = np.random.choice(samples, size=min_samples, replace=True)

        if param in log:
            samples = np.log10(samples)

        if param in scale:
            samples = samples + scale[param]

        param_samples.append(samples)
    total_samples.append(param_samples)

# Nested lists are ordered (parameter, object, sample); reversing the axes
# gives (sample, object, parameter).
total_samples_array = np.array(total_samples)
total_samples_array = total_samples_array.transpose(2, 1, 0)

In [None]:
# Maps each posterior-sample directory name to
# [true-value column in `cat`, log flag, offset]; only the column name is
# used in this cell.
params = {
    "stellar_mass": ["log_surviving_mass", False, None],
    "sfr_10myr": ["log10_floor_sfr_10", True, None],
    "mass_weighted_age": ["log10_mass_weighted_age", True, 3],
    "dust:Av": ["log10_Av", True, None],
    "beta_C94": ["beta", False, None],
    "formed_mass": ["log_mass", False, None],
    "iyer2019:dirichletr1": ["sfh_quantile_25", False, None],
    "iyer2019:dirichletr2": ["sfh_quantile_50", False, None],
    "iyer2019:dirichletr3": ["sfh_quantile_75", False, None],
    "iyer2019:sfr": ["log_sfr", True, None],
    "iyer2019:metallicity": ["log10metallicity", True, None],
}


# Reverse lookup: true-value column name -> directory name.
invert_params = {spec[0]: key for key, spec in params.items()}

# Use this to construct truths array, in the same parameter order as the
# samples. The entries are indexed rather than unpacked into `unlog` / `scale`
# so that the notebook-level `scale` offsets dict is not overwritten.
truths = [cat[params[param][0]] for param in parameters]

# One row per object, one column per parameter.
truths_array = np.array(truths).T

In [None]:
import tarp

# Run the TARP coverage calculation on the posterior samples and the true
# values, with normalisation and 100 bootstrap realisations requested.
# NOTE: this rebinds `out`, which earlier held the exported photometry dict.
out = tarp.get_tarp_coverage(
    total_samples_array,
    truths_array,
    norm=True,
    bootstrap=True,
    num_bootstrap=100,
)

In [None]:
# Validation figure: six true-vs-inferred hexbin panels (top two rows) plus a
# TARP coverage panel and a rank-based (SBC) coverage panel (bottom row).
import copy

from sklearn.metrics import mean_squared_error, r2_score

fig = plt.figure(
    figsize=(9, 9),
    dpi=150,
)

gs = fig.add_gridspec(3, 6, height_ratios=[1, 1, 1.44])

ax_top_l = fig.add_subplot(gs[0, :2])
ax_top_r = fig.add_subplot(gs[0, 2:4])
ax_top_c = fig.add_subplot(gs[0, 4:6])
ax_middle_l = fig.add_subplot(gs[1, :2])
ax_middle_r = fig.add_subplot(gs[1, 2:4])
ax_middle_c = fig.add_subplot(gs[1, 4:6])
ax1 = fig.add_subplot(gs[2, :3])
ax2 = fig.add_subplot(gs[2, 3:])
# All six hexbin panels, in the order they are filled below.
axs = [ax_top_l, ax_top_r, ax_top_c, ax_middle_l, ax_middle_r, ax_middle_c]


# TARP and Coverage
# ax1 shows the TARP curve (expected coverage against credibility level) and
# ax2 the rank-based curves (empirical against predicted percentile); the two
# sets of axis labels had been swapped.
ax1.plot([0, 1], [0, 1], "k--", label="Ideal", alpha=0.5)
ax1.set_xlabel("Credibility Level")
ax1.set_ylabel("Expected Coverage")
ax1.text(0.05, 0.90, "TARP", transform=ax1.transAxes, fontsize=16)

ax2.plot([0, 1], [0, 1], "k--", label="Ideal", alpha=0.5)
ax2.set_xlabel("Predicted Percentile")
ax2.set_ylabel("Empirical Percentile")
ax2.text(0.05, 0.90, "SBC", transform=ax2.transAxes, fontsize=16)

# Parameters shown in the six hexbin panels and their in-panel titles.
ax_params = [
    "log_mass",
    "log10_floor_sfr_10",
    "log10_Av",
    "log10_mass_weighted_age",
    "beta",
    "log_surviving_mass",
]
param_latex = [
    r"\log(\rm M_\star/M_\odot)",
    r"\log(\rm SFR_{10 \rm Myr}/ M_\odot yr^{-1})",
    r"\log(\rm A_V / \rm mag)",
    # Myr, to agree with the +3 offset applied to the Bagpipes ages and with
    # the Myr label of the per-parameter comparison plots.
    r"\log(\rm Age_M / \rm Myr)",
    r"\beta",
    r"\log(\rm M_{\star,surv}/M_\odot)",
]

# Legend labels keyed by parameter (directory) name. The curves below are
# drawn in the order of `parameters`, which comes from os.listdir and is
# therefore arbitrary, so the labels are looked up by name instead of being
# hard-coded in one particular order.
label_by_param = {
    "stellar_mass": r"$\log M_{*,surv}$",
    "iyer2019:sfr": r"$\log$ SFR",
    "formed_mass": r"$\log M_*$",
    "beta_C94": "$\\beta$",
    "iyer2019:dirichletr2": "$t_{50}$",
    "dust:Av": r"$\rm \log_{10} \ A_V$",
    "iyer2019:dirichletr3": "$t_{75}$",
    "sfr_10myr": r"$\log$ SFR$_{10}$",
    "iyer2019:metallicity": r"$\log Z/Z_\odot$",
    "iyer2019:dirichletr1": "$t_{25}$",
    "mass_weighted_age": r"$\log$ Age$_M$",
}
param_labels = [label_by_param[param] for param in parameters]

# Position of each panel parameter along the last axis of the sample array.
indices = [parameters.index(invert_params[param]) for param in ax_params]

# TARP: mean curve over the bootstrap realisations with 1 and 2 sigma bands.
ecp, alpha = out[0], out[1]
ecp_mean = np.mean(ecp, axis=0)
ecp_std = np.std(ecp, axis=0)
ax1.plot(alpha, ecp_mean, label="TARP", color="b")
ax1.fill_between(alpha, ecp_mean - ecp_std, ecp_mean + ecp_std, alpha=0.5, color="b")
ax1.fill_between(alpha, ecp_mean - 2 * ecp_std, ecp_mean + 2 * ecp_std, alpha=0.2, color="b")


samples = copy.deepcopy(total_samples_array)
trues = copy.deepcopy(truths_array)

ndata, npars = trues.shape
# Rank of each true value among its posterior draws: the number of draws
# lying below the truth, per object and parameter.
ranks = (samples < trues[None, ...]).sum(axis=0)

# Percentile envelope of sorted uniform draws; computed but not drawn at present.
unicov = [np.sort(np.random.uniform(0, 1, ndata)) for j in range(200)]
unip = np.percentile(unicov, [5, 16, 84, 95], axis=0)

cdf = np.linspace(0, 1, len(ranks))
cmap = plt.get_cmap("RdYlBu")
colors = [cmap(i / npars) for i in range(npars)]

# One-to-one reference, drawn once rather than once per parameter.
ax2.plot(cdf, cdf, "k--")
for i in range(npars):
    xr = np.sort(ranks[:, i])
    # Normalise the sorted ranks by the largest rank.
    xr = xr / xr[-1]
    ax2.plot(xr, cdf, lw=2, label=param_labels[i], alpha=1, color=colors[i])
ax2.set(adjustable="box", aspect="equal")
ax1.set(adjustable="box", aspect="equal")

ax2.set_xlim(0, 1)
ax2.set_ylim(0, 1)
ax2.legend(fontsize=8, loc="lower right", framealpha=0.9, frameon=True)

for i, ax in enumerate(axs):
    param = ax_params[i]
    idx = indices[i]
    print(idx, param)

    # RMSE and R^2 of the posterior median against the true value.
    y_true = trues[:, idx]
    y_pred = np.median(samples[:, :, idx], axis=0)
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    r2 = r2_score(y_true, y_pred)

    # Plot every posterior draw: repeat each true value once per draw so it
    # lines up with the flattened samples.
    true_full = np.repeat(trues[:, idx][np.newaxis, :], samples.shape[0], axis=0).flatten()
    yy = samples[:, :, idx].flatten()

    pparams = {}
    limits = None
    if param == "beta":
        pparams = dict(extent=(-2.8, 2, -2.8, 2))
        # Fixed axis range for beta; previously set and then immediately
        # overwritten by the data-driven limits below.
        limits = (-2.8, 0)
    if param == "log10_floor_sfr_10":
        # Remove data below -3
        mask = (true_full > -3) & (yy > -3)
        true_full = true_full[mask]
        yy = yy[mask]
    if param == "log10_Av":
        # Remove data below -2
        mask = (true_full > -2) & (yy > -2)
        true_full = true_full[mask]
        yy = yy[mask]

    ax.hexbin(
        true_full,
        yy,
        gridsize=50,
        cmap="cmr.sunburst_r",
        mincnt=1,
        **pparams,
    )

    # Same range on both axes so the one-to-one line runs along the diagonal.
    if limits is None:
        limits = (true_full.min(), true_full.max())
    ax.set_xlim(*limits)
    ax.set_ylim(*limits)

    if i > 2:
        ax.set_xlabel("True Value")
    if i == 0 or i == 3:
        ax.set_ylabel("Inferred Value")
    ax.plot(
        [true_full.min(), true_full.max()], [true_full.min(), true_full.max()], "k--", alpha=0.5
    )
    ax.text(0.05, 0.90, f"${param_latex[i]}$", transform=ax.transAxes, fontsize=11)
    ax.text(0.05, 0.80, f"RMSE = {rmse:.2f}", transform=ax.transAxes, fontsize=10)
    ax.text(0.05, 0.70, f"$R^2={r2:.2f}$", transform=ax.transAxes, fontsize=10)
plt.tight_layout()
plt.subplots_adjust(hspace=0.2, wspace=0.30)


fig.savefig("plots/bagpipes_validation.png", dpi=300, bbox_inches="tight")