In [None]:
# Re-import edited modules (e.g. the synference package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np
from astropy.table import Table, hstack, join
from unyt import Jy

from synference import SBI_Fitter, create_uncertainty_models_from_EPOCHS_cat

# Project matplotlib style sheet. NOTE(review): relative path, so the notebook must be run
# from the directory containing "paper.style" — confirm.
plt.style.use("paper.style")

In [None]:
# Trained SBI posterior, plus the photometry grid it was trained on.
fitter = SBI_Fitter.load_saved_model(
    "/home/tharvey/work/ltu-ili_testing/models/DenseBasis_v3_just_redshift/DenseBasis_v3_just_redshift_nsf_posterior.pkl",
    grid_path="/home/tharvey/work/ltu-ili_testing/libraries/grid_BPASS_Chab_DenseBasis_SFH_0.01_z_14_logN_5.0_Calzetti_v3_multinode.hdf5",
)

# HST/ACS filters come first; every remaining band is a JWST/NIRCam filter.
hst_bands = ["F435W", "F606W", "F775W", "F814W", "F850LP"]
bands = [*hst_bands, "F090W", "F115W", "F150W", "F200W", "F277W", "F335M", "F356W", "F410M", "F444W"]
# Fully-qualified filter names, e.g. "HST/ACS_WFC.F435W" or "JWST/NIRCam.F444W".
new_band_names = [
    ("HST/ACS_WFC" if band in hst_bands else "JWST/NIRCam") + "." + band.upper() for band in bands
]

bagpipes_results = "/home/tharvey/work/synference/priv/data/cats/sfh_dense_basis_dust_Calzetti.fits"
jades_catalog = "/home/tharvey/work/synference/priv/data/cats/JADES_DR3_GS_Matched_Specz_total_flux_good_filt.fits"  # noqa: E501
data_err_file = "/home/tharvey/work/JADES-DR3-GS_MASTER_Sel-F277W+F356W+F444W_v13.fits"

bagpipes = Table.read(bagpipes_results)
jades = Table.read(jades_catalog)

# hstack pairs rows purely by position, so the two catalogues must be row-aligned;
# only the row counts are checked here.
assert len(bagpipes) == len(jades)
cat = hstack([jades, bagpipes])

# Empirical flux-uncertainty models derived from the OBJECTS HDU of the master catalogue.
empirical_noise_models = create_uncertainty_models_from_EPOCHS_cat(
    data_err_file,
    bands,
    new_band_names,
    hdu="OBJECTS",
    model_class="asinh",
    min_flux_error=0,
    plot=False,
    save=False,
)

fitter.feature_array_flags["empirical_noise_models"] = empirical_noise_models

In [None]:
# Parameter-deviation diagnostic. The list is passed straight through as `snr_bins`
# (presumably S/N bin edges — confirm in SBI_Fitter.plot_parameter_deviations).
snr_bins = [5, 10, np.inf]
fig = fitter.plot_parameter_deviations(snr_bins=snr_bins)
fig.savefig(
    "/home/tharvey/work/ltu-ili_testing/models/DenseBasis_v3_just_redshift/plots/nsf/parameter_deviations.png"
)

In [None]:
def flux_cols_syntax(band):
    """Build the total-flux catalogue column name for a filter.

    Parameters
    ----------
    band : str
        Short filter name, e.g. ``"F444W"``.

    Returns
    -------
    str
        Column name of the form ``"FLUX_TOTAL_<band>"``.
    """
    return "FLUX_TOTAL_{}".format(band)


def magerr_cols_syntax(band):
    """Build the total-flux *uncertainty* catalogue column name for a filter.

    Despite the function name, this returns the flux-error column (``FLUXERR_TOTAL_*``),
    not a magnitude-error column.

    Parameters
    ----------
    band : str
        Short filter name, e.g. ``"F444W"``.

    Returns
    -------
    str
        Column name of the form ``"FLUXERR_TOTAL_<band>"``.
    """
    return f"FLUXERR_TOTAL_{band}"


# Photometric features of the trained model: every feature name that is neither an
# uncertainty column ("unc_*") nor the "redshift" feature. Both `new_band_names` and `bands`
# are derived from this one filtered sequence, so index i of each always refers to the same
# filter. (Previously `new_band_names` held the unfiltered feature list, and the zip below
# paired it positionally with the filtered `bands`.)
new_band_names = [
    name
    for name in fitter.feature_names
    if not (name.startswith("unc_") or name == "redshift")
]
# Short filter names ("HST/ACS_WFC.F435W" -> "F435W"), used to look up catalogue columns.
bands = [name.split(".")[-1] for name in new_band_names]

new_table = Table()

# Fluxes and errors are divided by 1e9 and fed to fit_catalogue with flux_units=Jy.
# NOTE(review): this assumes the FLUX_TOTAL_* / FLUXERR_TOTAL_* columns are in nJy — confirm.
for band in bands:
    new_table[band] = cat[flux_cols_syntax(band)] / 1e9
    new_table[f"unc_{band}"] = cat[magerr_cols_syntax(band)] / 1e9

new_table["redshift"] = cat["specz"]

# Map the short catalogue column names onto the model's feature names.
conversion_dict = dict(zip(bands, new_band_names))
conversion_dict.update(
    {f"unc_{band}": f"unc_{new_band}" for band, new_band in zip(bands, new_band_names)}
)

In [None]:
# Column renaming and flux units for the input table. The out-of-distribution check and
# the feature-array return are both switched off.
fit_options = dict(
    columns_to_feature_names=conversion_dict,
    flux_units=Jy,
    return_feature_array=False,
    check_out_of_distribution=False,
)
post_tab = fitter.fit_catalogue(new_table, **fit_options)

In [None]:
plt.rcParams.update({"font.size": 12})

color = "teal"
edgecolor = "black"
# |Δz|/(1+z) threshold used for the shaded bands and for the outlier fraction.
outlier_threshold = 0.15

# Normalised residual Δz/(1+z_spec) of the posterior median redshift, computed once.
# Masked entries (if any) become NaN, so every statistic below skips them consistently.
delta_z = np.ma.asarray(
    (post_tab["redshift_50"] - cat["specz"]) / (1 + cat["specz"]), dtype=float
).filled(np.nan)
finite = np.isfinite(delta_z)

fig = plt.figure(figsize=(5, 5))
gs = fig.add_gridspec(2, 1, height_ratios=[3, 1], hspace=0.1, wspace=0.3)
# Named `ax_main` rather than `max`, which shadowed the builtin max() for later cells.
ax_main = fig.add_subplot(gs[0, 0])
ax = [ax_main, fig.add_subplot(gs[1, 0], sharex=ax_main)]

# Top panel: posterior median vs spectroscopic redshift, with error bars spanning
# redshift_16 .. redshift_84.
ax[0].scatter(
    cat["specz"],
    post_tab["redshift_50"],
    s=6,
    alpha=0.7,
    color=color,
    edgecolor=edgecolor,
    linewidth=0.2,
    zorder=5,
)
ax[0].errorbar(
    cat["specz"],
    post_tab["redshift_50"],
    yerr=[
        post_tab["redshift_50"] - post_tab["redshift_16"],
        post_tab["redshift_84"] - post_tab["redshift_50"],
    ],
    marker="none",
    linestyle="none",
    ecolor="gray",
    alpha=0.3,
    zorder=-1,
)
ax[0].plot([0, 10], [0, 10], color="k", linestyle="--")
# Shade the |Δz| < 0.15(1+z) band around the one-to-one line.
z = np.linspace(0, 10, 100)
ax[0].fill_between(
    z,
    z - outlier_threshold * (1 + z),
    z + outlier_threshold * (1 + z),
    color="grey",
    alpha=0.1,
)
ax[0].set_ylabel("Inferred Redshift")
# Only show x tick labels on the bottom panel.
plt.setp(ax[0].get_xticklabels(), visible=False)

# Bottom panel: normalised residual vs spectroscopic redshift.
ax[1].scatter(
    cat["specz"],
    delta_z,
    s=6,
    alpha=0.7,
    color=color,
    edgecolor=edgecolor,
    linewidth=0.2,
    zorder=5,
)
ax[1].fill_between(z, -outlier_threshold, outlier_threshold, color="grey", alpha=0.1)
ax[1].set_xlabel("Spectroscopic Redshift")
ax[1].set_ylabel(r"$\Delta z / (1 + z)$")
ax[1].set_xlim(0, 10)
ax[1].set_ylim(-1, 1)
ax[1].axhline(0, color="k", linestyle="--")

# Inset histogram of the residuals, placed in figure coordinates (not via the gridspec).
fax = fig.add_axes([0.68, 0.45, 0.21, 0.21])
fax.axvline(0, color="k", linestyle="--")
fax.axvspan(-outlier_threshold, outlier_threshold, color="grey", alpha=0.1)
fax.hist(delta_z[finite], bins=50, range=(-0.6, 0.6), histtype="stepfilled", color=color)
fax.set_xlabel(r"$\Delta z / (1 + z)$")
fax.set_yticks([])

# NMAD = 1.48 * median(|Δz - median(Δz)|); bias = median(Δz).
nmad = 1.48 * np.nanmedian(np.abs(delta_z - np.nanmedian(delta_z)))
bias = np.nanmedian(delta_z)
# Outlier fraction over galaxies with a finite residual only; NaN/masked rows
# no longer inflate the denominator.
n_finite = np.count_nonzero(finite)
outlier_frac = (
    np.count_nonzero(np.abs(delta_z[finite]) > outlier_threshold) / n_finite
    if n_finite
    else np.nan
)
ax[0].text(
    0.05,
    0.95,
    f"Outlier Fraction: {outlier_frac:.3f}\nNMAD: {nmad:.3f}\nBias: {bias:.3f}",
    transform=ax[0].transAxes,
    verticalalignment="top",
)

fig.savefig("plots/redshift_inference.png", dpi=300)

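## EAZY Comparison

For reference, compare the spectroscopic redshifts with the EAZY best-fit photometric redshifts (`zbest_fsps_larson_zfree`), read from the `EAZY_FSPS_LARSON` extension of the master catalogue.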

In [None]:
full_cat = Table.read(data_err_file, hdu="EAZY_FSPS_LARSON")
# Join into a new table instead of reassigning `cat`. Otherwise, re-running this cell would
# join the EAZY columns a second time. Also, an inner join can drop or reorder rows, which
# would break the positional pairing of `cat` with `post_tab` used by the SBI plot cell.
eazy_cat = join(cat, full_cat, keys="UNIQUE_ID", join_type="inner")

# Flatten every column to 1-D floats with masked entries as NaN. Some EAZY columns may carry
# an extra length-1 axis; flattening all of them (not just the percentiles) avoids an
# accidental (N,) - (N,1) -> (N,N) broadcast.
spec_z = np.ravel(np.ma.asarray(eazy_cat["specz"], dtype=float).filled(np.nan))
phot_z = np.ravel(np.ma.asarray(eazy_cat["zbest_fsps_larson_zfree"], dtype=float).filled(np.nan))
phot_z_16 = np.ravel(
    np.ma.asarray(eazy_cat["zbest_16_fsps_larson_zfree"], dtype=float).filled(np.nan)
)
phot_z_84 = np.ravel(
    np.ma.asarray(eazy_cat["zbest_84_fsps_larson_zfree"], dtype=float).filled(np.nan)
)

plt.rcParams.update({"font.size": 12})

color = "teal"
edgecolor = "black"
# |Δz|/(1+z) threshold used for the shaded bands and for the outlier fraction.
outlier_threshold = 0.15

# Normalised residual Δz/(1+z_spec) of the EAZY best-fit redshift.
delta_z = (phot_z - spec_z) / (1 + spec_z)
finite = np.isfinite(delta_z)

fig = plt.figure(figsize=(5, 5))
gs = fig.add_gridspec(2, 1, height_ratios=[3, 1], hspace=0.1, wspace=0.3)
# Named `ax_main` rather than `max`, which shadowed the builtin max() for later cells.
ax_main = fig.add_subplot(gs[0, 0])
ax = [ax_main, fig.add_subplot(gs[1, 0], sharex=ax_main)]

ax[0].scatter(
    spec_z,
    phot_z,
    s=6,
    alpha=0.7,
    color=color,
    edgecolor=edgecolor,
    linewidth=0.2,
    zorder=5,
)
# Distances from zbest to its 16th/84th-percentile columns. zbest may fall outside
# [z16, z84]. matplotlib needs non-negative errors, so a negative distance is clipped to 0;
# abs() would instead mirror the interval onto the wrong side of the point.
yerr = np.clip([phot_z - phot_z_16, phot_z_84 - phot_z], 0, None)
ax[0].errorbar(
    spec_z,
    phot_z,
    yerr=yerr,
    marker="none",
    linestyle="none",
    ecolor="gray",
    alpha=0.3,
    zorder=-1,
)
ax[0].plot([0, 10], [0, 10], color="k", linestyle="--")
# Shade the |Δz| < 0.15(1+z) band around the one-to-one line.
z = np.linspace(0, 10, 100)
ax[0].fill_between(
    z,
    z - outlier_threshold * (1 + z),
    z + outlier_threshold * (1 + z),
    color="grey",
    alpha=0.1,
)
ax[0].set_ylabel("Inferred Redshift")
ax[0].set_xlim(0, 10)
ax[0].set_ylim(0, 10)
# Only show x tick labels on the bottom panel.
plt.setp(ax[0].get_xticklabels(), visible=False)

# Bottom panel: normalised residual vs spectroscopic redshift.
ax[1].scatter(
    spec_z,
    delta_z,
    s=6,
    alpha=0.7,
    color=color,
    edgecolor=edgecolor,
    linewidth=0.2,
    zorder=5,
)
ax[1].fill_between(z, -outlier_threshold, outlier_threshold, color="grey", alpha=0.1)
ax[1].set_xlabel("Spectroscopic Redshift")
ax[1].set_ylabel(r"$\Delta z / (1 + z)$")
ax[1].set_xlim(0, 10)
ax[1].set_ylim(-1, 1)
ax[1].axhline(0, color="k", linestyle="--")

# Inset histogram of the residuals, placed in figure coordinates (not via the gridspec).
fax = fig.add_axes([0.68, 0.45, 0.21, 0.21])
fax.axvline(0, color="k", linestyle="--")
fax.axvspan(-outlier_threshold, outlier_threshold, color="grey", alpha=0.1)
fax.hist(delta_z[finite], bins=50, range=(-0.6, 0.6), histtype="stepfilled", color=color)
fax.set_xlabel(r"$\Delta z / (1 + z)$")
fax.set_yticks([])

# NMAD = 1.48 * median(|Δz - median(Δz)|); bias = median(Δz).
nmad = 1.48 * np.nanmedian(np.abs(delta_z - np.nanmedian(delta_z)))
bias = np.nanmedian(delta_z)
# Outlier fraction over galaxies with a finite residual only.
n_finite = np.count_nonzero(finite)
outlier_frac = (
    np.count_nonzero(np.abs(delta_z[finite]) > outlier_threshold) / n_finite
    if n_finite
    else np.nan
)
ax[0].text(
    0.05,
    0.95,
    f"Outlier Fraction: {outlier_frac:.3f}\nNMAD: {nmad:.3f}\nBias: {bias:.3f}",
    transform=ax[0].transAxes,
    verticalalignment="top",
)

# Separate file name: the SBI figure above is saved to plots/redshift_inference.png,
# and reusing that path overwrote it.
fig.savefig("plots/redshift_inference_eazy.png", dpi=300)

In [None]:
from pathlib import Path

import h5py

# Per-galaxy fit outputs; presumably the Bagpipes run matching the
# sfh_dense_basis_dust_Calzetti catalogue loaded earlier — confirm.
folder = "/home/tharvey/work/synference/priv/JADES-DR3-GS-pipes/"
galaxy_id = "20"
# Join with pathlib: `folder` already ends in "/", so the previous f-string produced "//".
h5_path = Path(folder) / "sfh_dense_basis_dust_Calzetti" / f"{galaxy_id}_JADES-DR3-GS-North.h5"
with h5py.File(h5_path, "r") as f:
    print(f)
    # The file repr alone says nothing about its contents; list the top-level entries too.
    print(list(f.keys()))