In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path

import cmasher
import matplotlib.patheffects as path_effects
import matplotlib.pyplot as plt
import numpy as np
from astropy.table import Table, hstack
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from scipy.stats import gaussian_kde
from unyt import Jy

from synference import SBI_Fitter, create_uncertainty_models_from_EPOCHS_cat

# Apply the "paper.style" matplotlib style sheet to all figures created below.
# NOTE(review): presumably a style file resolved relative to the working directory — confirm.
plt.style.use("paper.style")

# Test Different Approaches to missing data

In [None]:
# Load a previously saved SBI_Fitter model together with the grid file it refers to.
fitter = SBI_Fitter.load_saved_model(
    "/home/tharvey/work/ltu-ili_testing/models/DenseBasis_v3_just_redshift/DenseBasis_v3_just_redshift_nsf_posterior.pkl",
    grid_path="/home/tharvey/work/ltu-ili_testing/grids/grid_BPASS_Chab_DenseBasis_SFH_0.01_z_14_logN_5.0_Calzetti_v3_multinode.hdf5",
)

# Short filter names as they appear in the catalogue columns: HST/ACS first, then NIRCam.
hst_bands = ["F435W", "F606W", "F775W", "F814W", "F850LP"]
bands = hst_bands + [
    "F090W",
    "F115W",
    "F150W",
    "F200W",
    "F277W",
    "F335M",
    "F356W",
    "F410M",
    "F444W",
]

# Instrument-qualified filter names, in the same order as `bands`.
new_band_names = [
    ("HST/ACS_WFC." if name in hst_bands else "JWST/NIRCam.") + name.upper() for name in bands
]

# Input catalogues (absolute paths on the author's machine).
data_err_file = "/home/tharvey/work/JADES-DR3-GS_MASTER_Sel-F277W+F356W+F444W_v13.fits"
jades_catalog = "/home/tharvey/work/synference/priv/data/cats/JADES_DR3_GS_Matched_Specz_total_flux_good_filt.fits"  # noqa: E501
bagpipes_results = "/home/tharvey/work/synference/priv/data/cats/sfh_dense_basis_dust_Calzetti.fits"

jades = Table.read(jades_catalog)
bagpipes = Table.read(bagpipes_results)

# The two tables are joined column-wise below, so they must have the same number of rows.
assert len(bagpipes) == len(jades)

# Column-wise join: JADES columns first, then the bagpipes columns.
cat = hstack([jades, bagpipes])

# Build per-band uncertainty models from the error catalogue and attach them to the fitter.
empirical_noise_models = create_uncertainty_models_from_EPOCHS_cat(
    data_err_file,
    bands,
    new_band_names,
    hdu="OBJECTS",
    model_class="asinh",
    min_flux_error=0,
    plot=False,
    save=False,
)

fitter.feature_array_flags["empirical_noise_models"] = empirical_noise_models


def flux_cols_syntax(band):
    """Build the catalogue column name holding the total flux for ``band``.

    Returns the string ``"FLUX_TOTAL_"`` followed by the band name.
    """
    return "FLUX_TOTAL_{}".format(band)


def magerr_cols_syntax(band):
    """Build the catalogue column name holding the total-flux error for ``band``.

    Despite the ``magerr`` in the function name, the returned column is a flux
    error: ``"FLUXERR_TOTAL_"`` followed by the band name.
    """
    return "FLUXERR_TOTAL_{}".format(band)


# Feature names the loaded model expects (fluxes, "unc_*" uncertainties, possibly "redshift").
new_band_names = fitter.feature_names

# Keep only the flux features; "unc_*" and "redshift" entries are not photometric bands.
band_feature_names = [
    name for name in new_band_names if not (name.startswith("unc_") or name == "redshift")
]

# Short filter name (text after the last "."), which is what the catalogue columns use.
bands = [name.split(".")[-1] for name in band_feature_names]

# Map table column name -> model feature name.
# Each short band is paired with the exact feature it was derived from, so the mapping
# does not depend on the flux features appearing first in `fitter.feature_names`.
conversion_dict = {band: feature for band, feature in zip(bands, band_feature_names)}
conversion_dict.update(
    {f"unc_{band}": f"unc_{feature}" for band, feature in zip(bands, band_feature_names)}
)

# Copy fluxes and errors into a fresh table, scaled by 1e-9.
# NOTE(review): presumably converts nJy to Jy (the fit below passes flux_units=Jy) — confirm
# the catalogue units.
new_table = Table()
for band in bands:
    flux = cat[flux_cols_syntax(band)] / 1e9
    err = cat[magerr_cols_syntax(band)] / 1e9
    new_table[band] = flux
    new_table[f"unc_{band}"] = err

new_table["redshift"] = cat["specz"]

# Drop every row that has, for any band, a finite flux but a non-finite (NaN/inf/masked)
# error. Masked entries are filled with NaN first so they count as non-finite, matching
# the element-wise behaviour of the previous row-by-row loop.
mask = np.zeros(len(new_table), dtype=bool)
for band in bands:
    flux_finite = np.isfinite(np.asarray(np.ma.filled(new_table[band], np.nan), dtype=float))
    err_finite = np.isfinite(
        np.asarray(np.ma.filled(new_table[f"unc_{band}"], np.nan), dtype=float)
    )
    mask |= flux_finite & ~err_finite

new_table = new_table[~mask]

In [None]:
# Display the table-column -> model-feature-name mapping built in the previous cell.
conversion_dict

In [None]:
# Plain assignment: `new_table_miny` is a second name for the same Table object, not a copy.
new_table_miny = new_table

In [None]:
# Work on an independent copy of the first 100 rows.
test_new = new_table_miny[:100].copy()

# Keep the measured F444W fluxes before they are blanked out below.
f444w = test_new["F444W"].copy()

# Blank out F444W and its uncertainty for every row so the fit treats the band as missing.
for hidden_column in ("F444W", "unc_F444W"):
    test_new[hidden_column] = np.nan

fit_kwargs = dict(
    columns_to_feature_names=conversion_dict,
    flux_units=Jy,
    missing_data_flag=np.nan,
    missing_data_mcmc=True,
    check_out_of_distribution=False,
    return_feature_array=False,
)
post_tab = fitter.fit_catalogue(test_new, **fit_kwargs)

In [None]:
# Predicted F444W values from the fit where the band was blanked out.
ab__f444w_predict = post_tab["predicted_JWST/NIRCam.F444W"]

# Convert the held-back F444W fluxes (in Jy) to AB magnitudes: m_AB = -2.5 log10(f / Jy) + 8.90.
true_f444w = -2.5 * np.log10(f444w) + 8.90

plt.figure(figsize=(6, 6))
plt.scatter(true_f444w, ab__f444w_predict, s=5)

# One-to-one reference line.
plt.plot([20, 30], [20, 30], color="black", ls="--")

# Axis labels so the figure can be read on its own.
# NOTE(review): the predicted column is compared 1:1 with AB magnitudes, so it is presumably
# also in AB mag — confirm before adding the unit to the y label.
plt.xlabel("Measured F444W [AB mag]")
plt.ylabel("Predicted F444W")
plt.show()

In [None]:
# Fit every row of the cleaned catalogue; NaN entries mark missing data.
fit_kwargs = dict(
    columns_to_feature_names=conversion_dict,
    flux_units=Jy,
    missing_data_flag=np.nan,
    missing_data_mcmc=True,
    check_out_of_distribution=False,
    return_feature_array=False,
)
post_tab = fitter.fit_catalogue(new_table, **fit_kwargs)

In [None]:
# Display the result of the fit from the previous cell.
post_tab

In [None]:
# Per-row flag read from the fit output's "has_missing_data" column.
missing_data_mask = post_tab["has_missing_data"]

# Spec-z vs photo-z, with complete and incomplete rows drawn as separate series.
plt.figure(figsize=(6, 6))
for subset, subset_label in (
    (~missing_data_mask, "No missing data"),
    (missing_data_mask, "Missing data"),
):
    plt.scatter(
        post_tab["redshift"][subset],
        post_tab["redshift_50"][subset],
        label=subset_label,
        alpha=0.5,
    )

# One-to-one line and matching axis limits.
z_limits = [0, 15]
plt.plot(z_limits, z_limits, color="k", linestyle="--")
plt.xlim(z_limits[0], z_limits[1])
plt.ylim(z_limits[0], z_limits[1])
plt.xlabel("Spec-z")
plt.ylabel("Photo-z")

plt.legend()

Missing data: full model with SBI++ vs. model trained with missing-band flags

In [None]:
from synference import SBI_Fitter

# Model trained with missing bands (alternative name tried: "DenseBasis_v4_final_custom_small").
model_name = "BPASS_DenseBasis_v4_final_missing_bands"
extra = "nsf"
fitter_missing = SBI_Fitter.load_saved_model(
    f"/home/tharvey/work/synference/models/{model_name}/{model_name}_{extra}_posterior.pkl",
    grid_path="/home/tharvey/work/synference/grids/grid_BPASS_Chab_DenseBasis_SFH_0.01_z_14_logN_5.0_Calzetti_v4_multinode.hdf5",
)

# Full model (alternative name tried: "DenseBasis_v4_final_custom_small").
# NOTE(review): this model is loaded with the logN_6.0 grid while the one above uses
# logN_5.0 — confirm that the difference is intended.
model_name = "BPASS_DenseBasis_v4_final"
extra = "nsf_0"
fitter_full = SBI_Fitter.load_saved_model(
    f"/home/tharvey/work/synference/models/{model_name}/{model_name}_{extra}_posterior.pkl",
    grid_path="/home/tharvey/work/synference/grids/grid_BPASS_Chab_DenseBasis_SFH_0.01_z_14_logN_6.0_Calzetti_v4_multinode.hdf5",
)

# Short filter names as they appear in the catalogue columns: HST/ACS first, then NIRCam.
hst_bands = ["F435W", "F606W", "F775W", "F814W", "F850LP"]
bands = hst_bands + [
    "F090W",
    "F115W",
    "F150W",
    "F200W",
    "F277W",
    "F335M",
    "F356W",
    "F410M",
    "F444W",
]

# Instrument-qualified filter names, in the same order as `bands`.
new_band_names = [
    ("HST/ACS_WFC." if name in hst_bands else "JWST/NIRCam.") + name.upper() for name in bands
]

# Input catalogues (absolute paths on the author's machine).
data_err_file = "/home/tharvey/work/JADES-DR3-GS_MASTER_Sel-F277W+F356W+F444W_v13.fits"
jades_catalog = "/home/tharvey/work/synference/priv/data/cats/JADES_DR3_GS_Matched_Specz_total_flux_good_filt.fits"  # noqa: E501
bagpipes_results = "/home/tharvey/work/synference/priv/data/cats/sfh_dense_basis_dust_Calzetti.fits"

jades = Table.read(jades_catalog)
bagpipes = Table.read(bagpipes_results)

# The two tables are joined column-wise below, so they must have the same number of rows.
assert len(bagpipes) == len(jades)

# Column-wise join: JADES columns first, then the bagpipes columns.
cat = hstack([jades, bagpipes])

# Build per-band uncertainty models from the error catalogue; both fitters share them.
empirical_noise_models = create_uncertainty_models_from_EPOCHS_cat(
    data_err_file,
    bands,
    new_band_names,
    hdu="OBJECTS",
    model_class="asinh",
    min_flux_error=0,
    plot=False,
    save=False,
)

for loaded_fitter in (fitter_missing, fitter_full):
    loaded_fitter.feature_array_flags["empirical_noise_models"] = empirical_noise_models


def flux_cols_syntax(band):
    """Build the catalogue column name holding the total flux for ``band``.

    Returns the string ``"FLUX_TOTAL_"`` followed by the band name.
    """
    return "FLUX_TOTAL_{}".format(band)


def magerr_cols_syntax(band):
    """Build the catalogue column name holding the total-flux error for ``band``.

    Despite the ``magerr`` in the function name, the returned column is a flux
    error: ``"FLUXERR_TOTAL_"`` followed by the band name.
    """
    return "FLUXERR_TOTAL_{}".format(band)


# Feature names expected by the model trained with missing bands.
feature_names = fitter_missing.feature_names

# Keep only the flux features. "unc_*" (uncertainties), "flag_*" (missing-band flags) and
# "redshift" are not photometric bands. The previous condition had the "flag_" test outside
# the negation, so flag features were never excluded.
band_feature_names = [
    name
    for name in feature_names
    if not (name.startswith(("unc_", "flag_")) or name == "redshift")
]

# Short filter name (text after the last "."), which is what the catalogue columns use.
bands = [name.split(".")[-1] for name in band_feature_names]

# Copy fluxes and errors into a fresh table, scaled by 1e-9.
# NOTE(review): presumably converts nJy to Jy (the fits below pass flux_units=Jy) — confirm
# the catalogue units.
new_table = Table()
for band in bands:
    flux = cat[flux_cols_syntax(band)] / 1e9
    err = cat[magerr_cols_syntax(band)] / 1e9
    new_table[band] = flux
    new_table[f"unc_{band}"] = err

new_table["redshift"] = cat["specz"]

# Drop every row that has, for any band, a finite flux but a non-finite (NaN/inf/masked)
# error. Masked entries are filled with NaN first so they count as non-finite, matching
# the element-wise behaviour of the previous row-by-row loop.
mask = np.zeros(len(new_table), dtype=bool)
for band in bands:
    flux_finite = np.isfinite(np.asarray(np.ma.filled(new_table[band], np.nan), dtype=float))
    err_finite = np.isfinite(
        np.asarray(np.ma.filled(new_table[f"unc_{band}"], np.nan), dtype=float)
    )
    mask |= flux_finite & ~err_finite

new_table = new_table[~mask]

# Map table column name -> model feature name. Each short band is paired with the feature
# it was derived from, so the mapping cannot drift if the model's feature order differs
# from the hand-written band list.
conversion_dict = {band: feature for band, feature in zip(bands, band_feature_names)}
conversion_dict.update(
    {f"unc_{band}": f"unc_{feature}" for band, feature in zip(bands, band_feature_names)}
)
conversion_dict["redshift"] = "redshift"

# Add one boolean "flag_<feature name>" column per band marking NaN or masked fluxes, plus
# a per-row "missing_data_flag" that is True when any band is missing.
new_table_missing = new_table.copy()
missing_data = np.zeros(len(new_table_missing), dtype=bool)
for feature, band_simple in zip(band_feature_names, bands):
    band_values = np.asarray(np.ma.filled(new_table_missing[band_simple], np.nan), dtype=float)
    mask_nan = np.isnan(band_values)
    new_table_missing[f"flag_{feature}"] = mask_nan
    missing_data |= mask_nan

new_table_missing["missing_data_flag"] = missing_data


def demask_table(table):
    """Replace every masked column of ``table`` with a NaN-filled, unmasked version.

    A column counts as masked when it exposes a ``mask`` attribute. The table is
    modified in place and also returned for convenience.
    """
    masked_names = [name for name in table.colnames if hasattr(table[name], "mask")]
    for name in masked_names:
        table[name] = table[name].filled(np.nan)
    return table


# Fill masked entries with NaN so every column is a plain (unmasked) column.
new_table_missing = demask_table(new_table_missing)

# Keep only the rows flagged as having at least one missing band.
has_missing_band = new_table_missing["missing_data_flag"]
new_table_missing_only = new_table_missing[has_missing_band]

In [None]:
# Fit the rows with missing bands using the full model and also request the feature array.
fit_kwargs = dict(
    columns_to_feature_names=conversion_dict,
    flux_units=Jy,
    missing_data_flag=np.nan,
    missing_data_mcmc=True,
    check_out_of_distribution=False,
    return_feature_array=True,
)
post_tab = fitter_full.fit_catalogue(new_table_missing_only, **fit_kwargs)

# NOTE(review): with return_feature_array=True the result is indexed twice here; presumably
# this shows the first row of the first returned object — confirm against fit_catalogue.
post_tab[0][0]

In [None]:
# Drop the row at 0-based position 221 of the missing-data subset.
# NOTE(review): the earlier comment said "row 220" while the code removed index 221; the
# code's behaviour is kept — confirm which row was intended.
ROW_TO_DROP = 221

# Rebuild the subset from `new_table_missing` each time so that re-running this cell always
# drops the same single row instead of removing one more row per execution.
missing_rows = new_table_missing[new_table_missing["missing_data_flag"]]
new_table_missing_only = missing_rows[np.arange(len(missing_rows)) != ROW_TO_DROP]

In [None]:
from astropy.table import vstack

# Scratch table: the first 219 rows of the missing-data subset stacked on top of rows 222
# onwards of another table.
# NOTE(review): the second operand is `new_table_missing` (all rows), not
# `new_table_missing_only` — possibly a typo. `test_table` is not used in the visible cells.
test_table = vstack([new_table_missing_only[:219], new_table_missing[222:]])

In [None]:
# Fit the rows with missing bands using the full model (MCMC handling of missing data on).
fit_kwargs = dict(
    columns_to_feature_names=conversion_dict,
    flux_units=Jy,
    missing_data_flag=np.nan,
    missing_data_mcmc=True,
    check_out_of_distribution=False,
    return_feature_array=False,
    log_times=True,
)
post_tab = fitter_full.fit_catalogue(new_table_missing_only, **fit_kwargs)

In [None]:
# Save the fit results as CSV. overwrite=True lets the notebook be re-run: without it,
# astropy raises once the file already exists.
post_tab.write("../priv/data/cats/jades_dr3_missing_data_posteriors.csv", overwrite=True)

In [None]:
# Fit the same rows with the model trained on missing bands (MCMC handling switched off).
fit_kwargs = dict(
    columns_to_feature_names=conversion_dict,
    flux_units=Jy,
    missing_data_flag=np.nan,
    missing_data_mcmc=False,
    check_out_of_distribution=False,
    return_feature_array=False,
    log_times=True,
)
post_tab_missing = fitter_missing.fit_catalogue(new_table_missing_only, **fit_kwargs)

In [None]:
# Posterior-median column compared between the two approaches.
key = "log_mass_50"

# Medians from the full-model fit (x) against the missing-bands-model fit (y). The axes
# were already labelled for this comparison, but only the 1:1 line was being drawn.
plt.scatter(post_tab[key], post_tab_missing[key], alpha=0.5, s=5)

# One-to-one reference line spanning the range of the x values.
plt.plot(
    [np.nanmin(post_tab[key]), np.nanmax(post_tab[key])],
    [np.nanmin(post_tab[key]), np.nanmax(post_tab[key])],
    color="black",
    ls="--",
)
plt.xlabel("SBI++ Approach")
plt.ylabel("Flagging Approach")

plt.show()


def get_width(post_tab, key):
    """Return half the 16th-84th percentile range for the quantity named by ``key``.

    ``key`` is the median column name (for example ``"log_mass_50"``). The 84th and
    16th percentile column names are formed by swapping the *last* ``"50"`` in ``key``
    for ``"84"`` and ``"16"``, so a ``"50"`` elsewhere in the name is left untouched.
    If ``key`` contains no ``"50"``, both lookups use ``key`` itself.

    Args:
        post_tab: Mapping or table indexable by column name.
        key: Name of the median (``..._50``) column.

    Returns:
        ``(post_tab[<84 column>] - post_tab[<16 column>]) / 2``.

    Raises:
        KeyError: If a derived column name is missing from ``post_tab``.
    """
    head, found, tail = key.rpartition("50")
    if found:
        upper_key = f"{head}84{tail}"
        lower_key = f"{head}16{tail}"
    else:
        # No percentile marker in the name: nothing to substitute.
        upper_key = lower_key = key
    return (post_tab[upper_key] - post_tab[lower_key]) / 2


# Posterior widths of `key` from the two fits.
width_clipped = get_width(post_tab, key)
width_missing = get_width(post_tab_missing, key)

plt.figure(figsize=(6, 6), dpi=150)
plt.scatter(width_clipped, width_missing, alpha=0.5, s=5)

# One-to-one reference line spanning the range of the x values.
width_limits = [np.nanmin(width_clipped), np.nanmax(width_clipped)]
plt.plot(width_limits, width_limits, color="black", ls="--")

In [None]:
# Column-name stems of the quantities compared between the two fits.
properties = [
    "log_mass",
    "log10metallicity",
    "log10_Av",
    "log10_mass_weighted_age",
    "log10_floor_sfr_10",
    "log_surviving_mass",
    "beta",
]

# LaTeX legend labels, in the same order as `properties`.
param_latex = [
    r"\log(\rm M_\star/M_\odot)",
    r"\log(Z_\star/Z_\odot)",
    r"\log(\rm A_V / \rm mag)",
    r"\log(\rm Age_{MW} / \rm yr)",
    r"\log(\rm SFR_{10 \rm Myr}/ M_\odot yr^{-1})",
    # Opening parenthesis restored; the label previously ended with an unmatched ")".
    r"\log(\rm M_{\star,surv}/M_\odot)",
    r"\beta",
]

colors = cmasher.take_cmap_colors("cmr.guppy", len(properties), cmap_range=(0, 1.0))

fig, ax = plt.subplots(figsize=(6, 6), dpi=150)

# Inset showing the distribution of posterior-width ratios for each property.
axins = inset_axes(ax, width="40%", height="30%", loc="upper left", borderpad=2)
axins.axvline(1, color="black", ls="--")

# Loop-invariant binning and evaluation grids.
bins = np.linspace(0, 1, 10)
bin_centers = 0.5 * (bins[:-1] + bins[1:])
n_bins = len(bin_centers)
xkde = np.linspace(0, 3, 100)

# Binned median (with standard-deviation band) of the missing-bands-model medians as a
# function of the full-model medians, after rescaling each property to the range 0-1.
for i, prop in enumerate(properties):
    x = post_tab[prop + "_50"]
    y = post_tab_missing[prop + "_50"]
    pmin = np.nanmin(np.concatenate([x, y]))
    pmax = np.nanmax(np.concatenate([x, y]))

    # Rescale to 0-1 using the joint range of both fits.
    x_rescaled = (x - pmin) / (pmax - pmin)
    y_rescaled = (y - pmin) / (pmax - pmin)

    bin_medians = np.full(n_bins, np.nan)
    bin_stds = np.full(n_bins, np.nan)
    for j in range(n_bins):
        in_bin = (x_rescaled >= bins[j]) & (x_rescaled < bins[j + 1])
        if j == n_bins - 1:
            # Close the last bin on the right so the maximum value (rescaled to exactly 1)
            # is counted instead of falling outside every bin.
            in_bin = in_bin | (x_rescaled == bins[j + 1])
        if np.sum(in_bin) > 0:
            bin_medians[j] = np.nanmedian(y_rescaled[in_bin])
            bin_stds[j] = np.nanstd(y_rescaled[in_bin])

    ax.plot(
        bin_centers,
        bin_medians,
        label=f"${param_latex[i]}$",
        color=colors[i],
        path_effects=[path_effects.withStroke(linewidth=3, foreground="w")],
        lw=2,
    )

    ax.fill_between(
        bin_centers,
        bin_medians - bin_stds,
        bin_medians + bin_stds,
        color=colors[i],
        alpha=0.1,
        edgecolor=None,
    )

    # Ratio of posterior widths (missing-bands model / full model), keeping finite values
    # below 10, summarised with a Gaussian KDE.
    width_clipped = get_width(post_tab, f"{prop}_50")
    width_missing = get_width(post_tab_missing, f"{prop}_50")
    width_ratio = width_missing / width_clipped
    width_ratio = width_ratio[np.isfinite(width_ratio) & (width_ratio < 10)]
    if len(width_ratio) < 2:
        # gaussian_kde needs at least two samples; skip the inset curve for this property.
        continue

    kde = gaussian_kde(width_ratio)
    axins.plot(
        xkde,
        kde(xkde),
        color=colors[i],
        path_effects=[path_effects.withStroke(linewidth=3, foreground="w")],
    )

# Inset cosmetics only need to be set once.
axins.set_xlabel("Learnt / SBI++ Posterior FWHM")
axins.set_yticks([])

ax.plot([0, 1], [0, 1], color="black", ls="--", zorder=10, lw=1.6)
ax.set_xlabel("Normalized Parameter Value (SBI++)", fontsize=12)
ax.set_ylabel("Normalized Parameter Value (Learnt)", fontsize=12)
ax.legend(loc="lower right")
ax.grid(False)

# Create the output directory if needed so saving works on a fresh checkout.
output_path = Path("plots/missing_data_comparison.png")
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, dpi=300, bbox_inches="tight")