## Here we compare the Bagpipes results for ~4000 JADES GOODS galaxies to synference results.

In [None]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np
from astropy.table import Table, hstack
from unyt import Jy

from synference import SBI_Fitter, create_uncertainty_models_from_EPOCHS_cat

plt.style.use("paper.style")

### Load data and trained model

In [None]:
# Input files: the Bagpipes fitting results, the matched JADES photometric
# catalogue, and the master catalogue used further down to build the noise model.
bagpipes_results = "/home/tharvey/work/synference/priv/data/cats/sfh_dense_basis_dust_Calzetti.fits"
jades_catalog = "/home/tharvey/work/synference/priv/data/cats/JADES_DR3_GS_Matched_Specz_total_flux_good_filt.fits"  # noqa: E501
data_err_file = "/home/tharvey/work/JADES-DR3-GS_MASTER_Sel-F277W+F356W+F444W_v13.fits"

bagpipes = Table.read(bagpipes_results)
jades = Table.read(jades_catalog)

# The two tables are joined side by side below, so they must have the same number
# of rows. Raise explicitly rather than `assert`, which is stripped under `python -O`.
# NOTE(review): only the row count is checked; the rows are presumably already in
# the same order in both files — confirm.
if len(bagpipes) != len(jades):
    raise ValueError(
        f"Row count mismatch: Bagpipes table has {len(bagpipes)} rows but the JADES "
        f"catalogue has {len(jades)}; they cannot be stacked row by row."
    )

# combine the tables
cat = hstack([jades, bagpipes])

# Identify which saved posterior to load: <model_name>/<model_name>_<extra>_posterior.pkl
model_name = "DenseBasis_v3_final_custom_small"
extra = "nsf_zfix"


fitter = SBI_Fitter.load_saved_model(
    f"/home/tharvey/work/synference/models/{model_name}/{model_name}_{extra}_posterior.pkl",
    grid_path="/home/tharvey/work/synference/grids/grid_BPASS_Chab_DenseBasis_SFH_0.01_z_14_logN_5.0_Calzetti_v3_multinode.hdf5",
)

### Create noise model

In [None]:
# HST/ACS filters come first, followed by the JWST/NIRCam filters.
hst_bands = ["F435W", "F606W", "F775W", "F814W", "F850LP"]
bands = hst_bands + [
    "F090W",
    "F115W",
    "F150W",
    "F200W",
    "F277W",
    "F335M",
    "F356W",
    "F410M",
    "F444W",
]

# Fully qualified "<facility>/<instrument>.<FILTER>" name for every band.
new_band_names = [
    ("HST/ACS_WFC." if band in hst_bands else "JWST/NIRCam.") + band.upper() for band in bands
]

# Build one uncertainty model per band from the master catalogue, without
# plotting or saving anything.
empirical_noise_models = create_uncertainty_models_from_EPOCHS_cat(
    data_err_file,
    bands,
    new_band_names,
    hdu="OBJECTS",
    model_class="asinh",
    min_flux_error=0,
    plot=False,
    save=False,
)

# Attach the noise models to the loaded fitter.
fitter.feature_array_flags["empirical_noise_models"] = empirical_noise_models

### Describe the catalogue columns to SBI_Fitter

In [None]:
def flux_cols_syntax(band):
    """Return the name of the total-flux column for ``band`` ("FLUX_TOTAL_<band>")."""
    return "FLUX_TOTAL_{}".format(band)


def magerr_cols_syntax(band):
    """Return the name of the total-flux error column for ``band`` ("FLUXERR_TOTAL_<band>").

    Despite the function name, the column it names is a flux error, not a magnitude error.
    """
    return "FLUXERR_TOTAL_{}".format(band)


# Every feature the loaded model expects, in the model's own order.
new_band_names = fitter.feature_names

# Full names of the flux features only; the "unc_*" and "redshift" features are
# handled separately below.
flux_feature_names = [
    name for name in new_band_names if not (name.startswith("unc_") or name == "redshift")
]

# Short filter name (the text after the last "."), which is what the catalogue
# column names are built from.
bands = [name.split(".")[-1] for name in flux_feature_names]
new_table = Table()

# Do this but just keep in Jy
# NOTE(review): dividing by 1e9 presumably converts nJy to Jy — confirm the
# catalogue's flux units.

for band in bands:
    flux = cat[flux_cols_syntax(band)] / 1e9
    err = cat[magerr_cols_syntax(band)] / 1e9
    new_table[band] = flux
    new_table[f"unc_{band}"] = err


new_table["redshift"] = cat["specz"]

# Map table column -> model feature name. Each short name is paired with the
# flux feature it was derived from; zipping against the unfiltered feature list
# would silently misalign as soon as an "unc_*" or "redshift" feature precedes
# a flux feature.
conversion_dict = dict(zip(bands, flux_feature_names))
conversion_dict.update(
    {f"unc_{band}": f"unc_{feature}" for band, feature in zip(bands, flux_feature_names)}
)

conversion_dict["redshift"] = "redshift"

# Do Inference

In [None]:
# Run the fitter over every row of new_table. conversion_dict maps the table's
# column names onto the model's feature names; fluxes are declared to be in Jy.
# The out-of-distribution check is switched on for this run.
post_tab = fitter.fit_catalogue(
    new_table,
    columns_to_feature_names=conversion_dict,
    flux_units=Jy,
    return_feature_array=False,
    check_out_of_distribution=True,
)

In [None]:
# Keep only sources whose Bagpipes input redshift is at least 0.01.
mask = cat["input_redshift"] >= 0.01
print(f"Number of galaxies at z>=0.01: {np.sum(mask)}")
cat = cat[mask]
post_tab = post_tab[mask]
# new_table is fitted again later and the row mask from that fit is applied to
# `cat`, so it must receive the same cut or the two no longer line up row by row.
new_table = new_table[mask]

### Outlier Detection

In [None]:
from synference.utils import detect_outliers_pyod

# Second pass over the same table, this time asking for the feature array back
# and skipping the out-of-distribution check.
# NOTE(review): the result is used below as a pair — presumably
# (feature array, boolean mask of rows that were dropped) — confirm against
# SBI_Fitter.fit_catalogue. Also confirm new_table and cat are still
# row-aligned at this point, since observations[1] is later used to index cat.
observations = fitter.fit_catalogue(
    new_table,
    columns_to_feature_names=conversion_dict,
    flux_units=Jy,
    return_feature_array=True,
    check_out_of_distribution=False,
)

# Score the observed features against the fitter's own feature array using
# several detection methods, with a contamination setting of 1%.
output = detect_outliers_pyod(
    fitter.feature_array,
    observations[0],
    contamination=0.01,
    methods=["iforest", "feature_bagging", "ecod", "knn", "lof", "gmm", "mcd", "kde"],
)

In [None]:
# Display the combined catalogue.
cat

In [None]:
# Start with nothing flagged.
cat["outlier"] = 0

# Index the column, not the table: `cat[boolean_mask]` builds a new Table, so a
# column assigned on it never reaches `cat`. Writing through the column updates
# `cat` in place for the rows selected by ~observations[1].
cat["outlier"][~observations[1]] = output

In [None]:
# Sum of the "outlier" column.
cat["outlier"].sum()

In [None]:
# Sub-table of the rows selected by ~observations[1], with the detector output
# stored alongside. Boolean-mask indexing gives a new table, so `cat` itself is
# not modified here.
cat_unmasked = cat[~observations[1]]
cat_unmasked["outlier"] = output

In [None]:
# Select the flagged rows once and reuse the selection for both coordinates.
outliers = cat_unmasked[cat_unmasked["outlier"] == 1]

# Whole sample in the background, flagged sources highlighted in red.
plt.plot(cat["input_redshift"], cat["stellar_mass_50"], ".", alpha=0.1)
plt.plot(
    outliers["input_redshift"],
    outliers["stellar_mass_50"],
    "r.",
    label="Outliers",
)
plt.xlabel("Redshift")
plt.ylabel(r"Bagpipes $\log_{10}(M_*/M_\odot)$")
# The "Outliers" label only shows up once a legend is drawn.
plt.legend()

In [None]:
# Quick look at the fitter output: log_mass_50 against log10_floor_sfr_10_50.
plt.scatter(post_tab["log_mass_50"], post_tab["log10_floor_sfr_10_50"])

In [None]:
fig, (ax1, ax2) = plt.subplots(
    2, 1, figsize=(5, 6), gridspec_kw={"height_ratios": [3, 1]}, sharex=True
)

# Asymmetric (lower, upper) error bars built from the _16/_50/_84 columns of the
# quantity that is actually plotted on each axis.
# NOTE(review): assumes the Bagpipes table has stellar_mass_16 / stellar_mass_84
# columns, like the other Bagpipes quantities used in this notebook — confirm.
bagpipes_mass = cat["stellar_mass_50"]
bagpipes_mass_err = [
    cat["stellar_mass_50"] - cat["stellar_mass_16"],
    cat["stellar_mass_84"] - cat["stellar_mass_50"],
]
synference_mass = post_tab["log_surviving_mass_50"]
synference_mass_err = [
    post_tab["log_surviving_mass_50"] - post_tab["log_surviving_mass_16"],
    post_tab["log_surviving_mass_84"] - post_tab["log_surviving_mass_50"],
]

# Top panel: Synference against Bagpipes with a 1:1 reference line.
ax1.errorbar(
    bagpipes_mass,
    synference_mass,
    xerr=bagpipes_mass_err,
    yerr=synference_mass_err,
    fmt="o",
    alpha=0.5,
    label="Data points with error bars",
    markersize=4,
    linewidth=1,
    markeredgecolor="black",
)
ax1.plot([6, 12], [6, 12], "r--", label="1:1 Line")
ax1.set_ylabel(r"Synference $\log_{10}(M_*/M_\odot)$", fontsize=16)
ax1.grid()
ax1.yaxis.set_tick_params(labelsize=14)

# Bottom panel: residuals of the same quantity shown in the top panel.
residuals = synference_mass - bagpipes_mass
ax2.errorbar(
    bagpipes_mass,
    residuals,
    yerr=synference_mass_err,
    fmt="o",
    alpha=0.5,
    label="Residuals with error bars",
    markersize=4,
    linewidth=1,
    markeredgecolor="black",
)
ax2.axhline(0, color="r", linestyle="--")
ax2.set_xlabel(r"Bagpipes $\log_{10}(M_*/M_\odot)$", fontsize=16)
ax2.set_ylabel("Residuals", fontsize=16)
ax2.grid()
ax2.yaxis.set_tick_params(labelsize=14)
ax2.xaxis.set_tick_params(labelsize=14)

# NOTE: no density contours are drawn on this figure.


plt.tight_layout()
plt.show()
fig.savefig("plots/bagpipes_vs_synference_stellar_mass.png", dpi=300)

In [None]:
# Bagpipes and Synference stellar masses, restricted to rows where both are finite.
x = cat["stellar_mass_50"]
y = post_tab["log_surviving_mass_50"]
mask = np.isfinite(x) & np.isfinite(y)
x = x[mask]
y = y[mask]
# Calculate the point density

density = np.histogram2d(x, y, bins=30)[0]

fig, ax1 = plt.subplots(1, 1, figsize=(5, 5))
# histogram2d bins x along the first axis, whereas contour reads rows as y and
# columns as x, so the histogram is transposed before plotting.
# NOTE(review): "cmr.*" colormap names come from the cmasher package, which is
# not imported in the visible part of this notebook — confirm it is imported.
plt.contour(
    density.T,
    extent=[x.min(), x.max(), y.min(), y.max()],
    cmap="cmr.torch_r",
    alpha=0.8,
    linewidths=1,
    levels=5,
)
# Add black 1:1 line
plt.plot([7, 11], [7, 11], "k--", label="1:1 Line", linewidth=1)
plt.xlabel(r"Bagpipes $\log_{10}(M_*/M_\odot)$", fontsize=16)
plt.ylabel(r"Synference $\log_{10}(M_*/M_\odot)$", fontsize=16)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import gaussian_kde

# No sample data is generated here: x and y are whatever the previous cell left
# in scope.


# Calculate density using KDE
xy = np.vstack([x, y])
kde = gaussian_kde(xy)
density = kde(xy)

# Define density threshold
threshold = np.percentile(density, 20)  # 20th percentile of the per-point density

fig, ax = plt.subplots(figsize=(10, 8))

# Plot contours for high-density regions
# The KDE is evaluated on a 50 x 50 grid extending 1 unit beyond the data on every side.
xi, yi = np.mgrid[x.min() - 1 : x.max() + 1 : 50j, y.min() - 1 : y.max() + 1 : 50j]
zi = kde(np.vstack([xi.flatten(), yi.flatten()]))
zi = zi.reshape(xi.shape)

contours = ax.contour(xi, yi, zi, levels=8, colors="blue", alpha=0.6, linewidths=1.5, zorder=4)
ax.contourf(xi, yi, zi, levels=8, alpha=0.3, cmap="Blues", zorder=3)

# Plot individual points only in low-density regions
# (points whose density is below the threshold computed above)
low_density_mask = density < threshold
ax.scatter(x[low_density_mask], y[low_density_mask], c="red", s=30, alpha=0.7, zorder=2)

ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_title("Scatter Plot with Density Contours")
plt.show()

In [None]:
def plot_comparison(x, y, xerr, yerr, xlabel, ylabel, filename, contours=True):
    """Plot ``y`` against ``x`` with error bars, above a panel of residuals ``y - x``.

    With ``contours=True`` rows where ``x`` or ``y`` is non-finite are dropped,
    KDE density contours are drawn in both panels, and only the points in the
    lowest 20% of KDE density are plotted individually. With ``contours=False``
    every point is plotted and no filtering is applied.

    Args:
        x: Values for the horizontal axis.
        y: Values for the vertical axis, same length as ``x``.
        xerr: Pair ``[lower, upper]`` of error arrays for ``x``.
        yerr: Pair ``[lower, upper]`` of error arrays for ``y``.
        xlabel: Label of the shared x axis (set on the residual panel).
        ylabel: Label of the top panel's y axis.
        filename: Path the figure is saved to (at 300 dpi).
        contours: Whether to draw density contours and thin out the points.
    """
    fig, (ax1, ax2) = plt.subplots(
        2, 1, figsize=(5, 6), gridspec_kw={"height_ratios": [3, 1]}, sharex=True
    )

    if contours:
        # The KDE cannot take NaN/inf, so keep only rows finite in both x and y.
        mask = np.isfinite(x) & np.isfinite(y)
        x = x[mask]
        y = y[mask]
        xerr = [xerr[0][mask], xerr[1][mask]]
        yerr = [yerr[0][mask], yerr[1][mask]]

        # Density at every data point; the sparsest 20% are drawn individually.
        xy = np.vstack([x, y])
        kde = gaussian_kde(xy)
        density = kde(xy)
        low_density_mask = density < np.percentile(density, 20)

        # Evaluate the KDE on a 50 x 50 grid padded by 1 unit around the data.
        xi, yi = np.mgrid[x.min() - 1 : x.max() + 1 : 50j, y.min() - 1 : y.max() + 1 : 50j]
        zi = kde(np.vstack([xi.flatten(), yi.flatten()])).reshape(xi.shape)

        # Seven evenly spaced levels above the grid minimum, shared by both panels.
        levels = np.linspace(zi.min(), zi.max(), 8)[1:]
        ax1.contour(xi, yi, zi, levels=levels, colors="blue", alpha=0.6, linewidths=1.5, zorder=4)
        ax1.contourf(xi, yi, zi, levels=levels, alpha=0.3, cmap="Blues", zorder=3)

        # Residual panel: a separate KDE in (x, y - x), evaluated on the same grid
        # mapped into residual coordinates.
        residual_kde = gaussian_kde(np.vstack([x, y - x]))
        zi_residual = residual_kde(np.vstack([xi.flatten(), (yi - xi).flatten()]))
        zi_residual = zi_residual.reshape(xi.shape)
        ax2.contour(
            xi,
            yi - xi,
            zi_residual,
            levels=levels,
            colors="blue",
            alpha=0.6,
            linewidths=1.5,
            zorder=4,
        )
        ax2.contourf(xi, yi - xi, zi_residual, levels=levels, alpha=0.3, cmap="Blues", zorder=3)

        # From here on only the low-density points are plotted individually.
        x = x[low_density_mask]
        y = y[low_density_mask]
        xerr = [xerr[0][low_density_mask], xerr[1][low_density_mask]]
        yerr = [yerr[0][low_density_mask], yerr[1][low_density_mask]]

    # Error bars and markers are drawn separately so the markers can sit on top
    # (zorder=10) with their own alpha.
    ax1.errorbar(x, y, xerr=xerr, yerr=yerr, alpha=0.2, linestyle="none", marker="none")
    ax1.scatter(x, y, alpha=0.5, s=4, edgecolor="black", zorder=10, linewidths=0.5)

    # 1:1 reference line spanning the range of the plotted x values.
    ax1.plot([np.nanmin(x), np.nanmax(x)], [np.nanmin(x), np.nanmax(x)], "r--", label="1:1 Line")
    ax1.set_ylabel(ylabel, fontsize=16)
    ax1.grid()
    ax1.yaxis.set_tick_params(labelsize=14)

    residuals = y - x
    # No `fmt` here: it would conflict with marker="none"; this matches the top panel.
    ax2.errorbar(x, residuals, yerr=yerr, alpha=0.2, marker="none", linestyle="none")
    ax2.scatter(x, residuals, alpha=0.5, s=4, edgecolor="black", zorder=10, linewidths=0.5)
    ax2.axhline(0, color="r", linestyle="--")
    ax2.set_xlabel(xlabel, fontsize=16)
    ax2.set_ylabel("Residuals", fontsize=16)
    ax2.grid()
    ax2.yaxis.set_tick_params(labelsize=14)
    ax2.xaxis.set_tick_params(labelsize=14)

    ax2.set_ylim(-1.5, 1.5)

    plt.tight_layout()
    plt.show()
    fig.savefig(filename, dpi=300)


# Stellar mass. Each axis gets the uncertainty of the quantity plotted on it:
# Bagpipes percentiles for x, Synference surviving-mass percentiles for y.
# NOTE(review): assumes the Bagpipes table has stellar_mass_16 / stellar_mass_84
# columns, like the other Bagpipes quantities below — confirm.
plot_comparison(
    cat["stellar_mass_50"],
    post_tab["log_surviving_mass_50"],
    xerr=[
        cat["stellar_mass_50"] - cat["stellar_mass_16"],
        cat["stellar_mass_84"] - cat["stellar_mass_50"],
    ],
    yerr=[
        post_tab["log_surviving_mass_50"] - post_tab["log_surviving_mass_16"],
        post_tab["log_surviving_mass_84"] - post_tab["log_surviving_mass_50"],
    ],
    xlabel=r"Bagpipes $\log_{10}(M_*/M_\odot)$",
    ylabel=r"Synference $\log_{10}(M_*/M_\odot)$",
    filename="plots/bagpipes_vs_synference_stellar_mass.png",
)

# Star-formation rate (10 Myr). The Bagpipes columns are logged here before comparing.
# The x label is a plain raw string: in an rf-string "{10}" would be formatted
# to "10" and the subscript braces lost.
plot_comparison(
    np.log10(cat["sfr_10myr_50"]),
    post_tab["log10_floor_sfr_10_50"],
    xerr=[
        np.log10(cat["sfr_10myr_50"]) - np.log10(cat["sfr_10myr_16"]),
        np.log10(cat["sfr_10myr_84"]) - np.log10(cat["sfr_10myr_50"]),
    ],
    yerr=[
        post_tab["log10_floor_sfr_10_50"] - post_tab["log10_floor_sfr_10_16"],
        post_tab["log10_floor_sfr_10_84"] - post_tab["log10_floor_sfr_10_50"],
    ],
    xlabel=r"Bagpipes SFR (10 Myr) [$\log_{10} M_\odot$/yr]",
    ylabel=r"Synference SFR (10 Myr) [$\log_{10} M_\odot$/yr]",
    filename="plots/bagpipes_vs_synference_sfr_10.png",
)
# Mass-weighted age. The Bagpipes value is scaled by 1e3 before the log
# (presumably Gyr -> Myr, to match the axis label); the error bars are
# differences of logs, so the scale factor cancels there.
plot_comparison(
    np.log10(cat["mass_weighted_age_50"] * 1e3),
    post_tab["log10_mass_weighted_age_50"],
    xerr=[
        np.log10(cat["mass_weighted_age_50"]) - np.log10(cat["mass_weighted_age_16"]),
        np.log10(cat["mass_weighted_age_84"]) - np.log10(cat["mass_weighted_age_50"]),
    ],
    yerr=[
        post_tab["log10_mass_weighted_age_50"] - post_tab["log10_mass_weighted_age_16"],
        post_tab["log10_mass_weighted_age_84"] - post_tab["log10_mass_weighted_age_50"],
    ],
    xlabel=r"Bagpipes Mass-weighted Age [$\log_{10} Myr$]",
    ylabel=r"Synference Mass-weighted Age [$\log_{10} Myr$]",
    filename="plots/bagpipes_vs_synference_mass_weighted_age.png",
)

# Dust attenuation, compared in log10(Av).
plot_comparison(
    np.log10(cat["dust:Av_50"]),
    post_tab["log10_Av_50"],
    xerr=[
        np.log10(cat["dust:Av_50"]) - np.log10(cat["dust:Av_16"]),
        np.log10(cat["dust:Av_84"]) - np.log10(cat["dust:Av_50"]),
    ],
    yerr=[
        post_tab["log10_Av_50"] - post_tab["log10_Av_16"],
        post_tab["log10_Av_84"] - post_tab["log10_Av_50"],
    ],
    xlabel=r"Bagpipes $\log_{10}(A_V)$",
    ylabel=r"Synference $\log_{10}(A_V)$",
    filename="plots/bagpipes_vs_synference_Av.png",
)

In [None]:
# In three redshift bins (0 <= z < 1, 2 <= z < 3, 3 <= z < 6), plot the Bagpipes
# and Synference SFMS side by side.
# NOTE(review): sources with 1 <= z < 2 fall into no bin and are not plotted —
# confirm this gap is intended.

fig, axs = plt.subplots(1, 2, figsize=(8, 4), sharey=True, dpi=200)

# Bagpipes on the left (axs[0])
# Synference on the right (axs[1])

z_bins = [(0, 1), (2, 3), (3, 6)]
colors = ["blue", "green", "red"]
labels = ["$0 < z < 1$", "$2 < z < 3$", "$3 < z < 6$"]
for (z_min, z_max), color, label in zip(z_bins, colors, labels):
    # Both panels select sources by the same "specz" column.
    mask = (cat["specz"] >= z_min) & (cat["specz"] < z_max)
    # NOTE(review): marker size is s=5 here but s=4 in the Synference panel.
    axs[0].scatter(
        cat["stellar_mass_50"][mask],
        np.log10(cat["sfr_10myr_50"][mask]),
        color=color,
        alpha=0.5,
        s=5,
        label=label,
        edgecolor="black",
        linewidth=0.5,
    )
    axs[1].scatter(
        post_tab["log_surviving_mass_50"][mask],
        post_tab["log10_floor_sfr_10_50"][mask],
        color=color,
        alpha=0.5,
        s=4,
        label=label,
        edgecolor="black",
        linewidth=0.5,
    )

# The y axis is shared, so only the left panel carries a y label.
axs[0].set_xlabel(r"$\log_{10}(M_*/M_\odot)$", fontsize=14)
axs[0].set_ylabel(r"$\log_{10}(SFR_{10 Myr}/M_\odot yr^{-1})$", fontsize=14)
axs[0].text(
    0.05, 0.95, "Bagpipes", transform=axs[0].transAxes, fontsize=16, verticalalignment="top"
)
axs[0].grid()

axs[1].set_xlabel(r"$\log_{10}(M_*/M_\odot)$", fontsize=14)
axs[1].text(
    0.05, 0.95, "Synference", transform=axs[1].transAxes, fontsize=16, verticalalignment="top"
)
axs[1].grid()
# One legend (on the right panel) serves both, since the bins and colours are identical.
axs[1].legend(fontsize=10)

plt.tight_layout()
fig.savefig("plots/bagpipes_vs_synference_sfms.png", dpi=300)

In [None]:
# Print the "#ID" of the Bagpipes row with the smallest stellar_mass_50.
print(bagpipes["#ID"][bagpipes["stellar_mass_50"].argmin()])

In [None]:
# Number of Bagpipes rows with input_redshift <= 0.01.
len(bagpipes[bagpipes["input_redshift"] <= 0.01])

In [None]:
# How many entries are masked in each column of new_table?
# Columns without a `mask` attribute are skipped.

for col in new_table.colnames:
    a = new_table[col]
    if hasattr(a, "mask"):
        # Summing the boolean mask counts the masked entries.
        print(col, np.sum(a.mask))