### Model Testing Figures for Paper

In [None]:
%load_ext autoreload
%autoreload 2

import os

# Put the TeX Live binaries first on PATH so LaTeX can be found when rendering
# figures. NOTE(review): presumably needed because paper.style enables usetex — confirm.
# Guarded so re-running this cell does not keep prepending the same directory.
TEXLIVE_BIN = "/usr/local/texlive/2025/bin/x86_64-linux"
if TEXLIVE_BIN not in os.environ.get("PATH", "").split(os.pathsep):
    os.environ["PATH"] = TEXLIVE_BIN + os.pathsep + os.environ.get("PATH", "")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

from synference import SBI_Fitter

plt.style.use("paper.style")

# Which trained posterior to evaluate: the model directory name and the run
# suffix that together identify the saved files.
model_name = "DenseBasis_v3_final_custom_small"
extra = "nsf_zfix"

posterior_path = (
    f"/home/tharvey/work/synference/models/{model_name}/{model_name}_{extra}_posterior.pkl"
)
grid_path = "/home/tharvey/work/synference/grids/grid_BPASS_Chab_DenseBasis_SFH_0.01_z_14_logN_5.0_Calzetti_v3_multinode.hdf5"

fitter = SBI_Fitter.load_saved_model(posterior_path, grid_path=grid_path)

Figure 2: Calibration and recovery.

In [None]:
# The fitter's own built-in coverage plot for the loaded posterior.
fitter.plot_coverage()

## Calculate TARP coverage (Tests of Accuracy with Random Points)

In [None]:
import tarp

# Posterior draws for the test set. Cells below index them as
# (draw, test object, parameter).
sample_path = (
    f"/home/tharvey/work/synference/models/{model_name}/plots/{extra}/posterior_samples.npy"
)
samples = np.load(sample_path)
true = fitter._y_test

# TARP expected-coverage curve with normalised parameters and 100 bootstrap
# realisations; `out` is unpacked as (ecp, alpha) in the figure cell.
out = tarp.get_tarp_coverage(
    samples,
    true,
    bootstrap=True,
    num_bootstrap=100,
    norm=True,
)


## Load metrics

In [None]:
import json

# Per-parameter summary metrics for this run; the figure cell reads the
# "R_squared" and "RMSE" lists from `met`.
metrics = f"/home/tharvey/work/synference/models/{model_name}/{model_name}_{extra}_metrics.json"
with open(metrics) as metrics_file:
    met = json.load(metrics_file)

In [None]:
# LaTeX labels for every fitted parameter, in the same order as
# fitter.fitted_parameter_names: the SBC legend and the coloured corner plot
# pick labels by column position, so the two lists must line up.
param_labels = [
    r"$\log M_*$",
    r"$\log Z/Z_\odot$",
    r"$\rm \log_{10} \ A_V$",
    r"$\log$ SFR",
    "$t_{25}$",
    "$t_{50}$",
    "$t_{75}$",
    r"$\log$ Age$_M$",
    r"$\log$ SFR$_{10}$",
    r"$\log M_{*,surv}$",
    "$\\beta$",
]

# Guard the positional pairing: a model with a different parameter set would
# otherwise silently mislabel the plots.
assert len(param_labels) == len(fitter.fitted_parameter_names), (
    f"{len(param_labels)} labels for {len(fitter.fitted_parameter_names)} fitted parameters"
)

# Last expression, so the names are displayed for checking the label order.
fitter.fitted_parameter_names

In [None]:
# Figure 2: parameter-recovery panels (top row) and TARP / SBC calibration panels
# (bottom row). Uses `fitter`, `samples`, `true`, `out`, `met` and `param_labels`
# from the cells above.
fig = plt.figure(figsize=(9, 9), dpi=150)

# NOTE(review): the negative hspace presumably pulls the rows together to close the
# gap left by the equal-aspect panels — confirm the final layout by eye.
gs = fig.add_gridspec(2, 6, height_ratios=[1, 1.44], hspace=-0.15, wspace=0.65)

ax_top_l = fig.add_subplot(gs[0, :2])
ax_top_r = fig.add_subplot(gs[0, 2:4])
ax_top_c = fig.add_subplot(gs[0, 4:6])
ax1 = fig.add_subplot(gs[1, :3])  # TARP
ax2 = fig.add_subplot(gs[1, 3:])  # SBC
axs = [ax_top_l, ax_top_r, ax_top_c]

# --- TARP: mean expected-coverage curve over axis 0 of `ecp` (the bootstrap
# realisations), with 1-sigma and 2-sigma bands.
ax1.plot([0, 1], [0, 1], "k--", label="Ideal", alpha=0.5)
ax1.set_xlabel("Predicted Percentile")
ax1.set_ylabel("Empirical Percentile")
ax1.text(0.05, 0.90, "TARP", transform=ax1.transAxes, fontsize=16)

ecp, alpha = out[0], out[1]
ecp_mean = np.mean(ecp, axis=0)
ecp_std = np.std(ecp, axis=0)
ax1.plot(alpha, ecp_mean, label="TARP", color="b")
ax1.fill_between(alpha, ecp_mean - ecp_std, ecp_mean + ecp_std, alpha=0.5, color="b")
ax1.fill_between(alpha, ecp_mean - 2 * ecp_std, ecp_mean + 2 * ecp_std, alpha=0.2, color="b")
ax1.set(adjustable="box", aspect="equal")

# --- SBC: empirical CDF of each parameter's normalised rank (fraction of posterior
# draws below the truth). A calibrated posterior follows the diagonal.
trues = fitter._y_test
n_draws = samples.shape[0]
ndata, npars = trues.shape
ranks = (samples < trues[None, ...]).sum(axis=0)  # (ndata, npars), values in [0, n_draws]
cdf = np.linspace(0, 1, ndata)

ax2.plot([0, 1], [0, 1], "k--", label="Ideal", alpha=0.5)
for i in range(npars):
    # Normalise by the number of draws rather than by the largest observed rank:
    # the latter stretches the curve whenever no truth lies above every draw, and
    # divides by zero if all ranks are 0.
    normed_ranks = np.sort(ranks[:, i]) / n_draws
    ax2.plot(normed_ranks, cdf, lw=2, label=param_labels[i])

ax2.set_xlabel("Predicted Percentile")
ax2.set_ylabel("Expected Coverage")
ax2.text(0.05, 0.90, "SBC", transform=ax2.transAxes, fontsize=16)
ax2.set(adjustable="box", aspect="equal")
ax2.set_xlim(0, 1)
ax2.set_ylim(0, 1)
ax2.legend(fontsize=8)

# --- Recovery: every posterior draw against its object's true value, as a hexbin.
ax_params = ["log_mass", "log10_floor_sfr_10", "log10_Av"]
param_latex = [
    r"\log(\rm M/M_\odot)",
    r"\log(\rm SFR_{10 \rm Myr}/ M_\odot yr^{-1})",
    r"\log(\rm A_V / \rm mag)",
]
fitted_names = list(fitter.fitted_parameter_names)
indices = [fitted_names.index(p) for p in ax_params]

for i, (ax, idx, latex) in enumerate(zip(axs, indices, param_latex)):
    # The metric lists are in fitted-parameter order, like the sample columns.
    r2 = met["R_squared"][idx]
    rmse = met["RMSE"][idx]

    # Pair each draw with its object's truth; both flatten to n_draws * ndata values.
    true_full = np.broadcast_to(true[:, idx], samples.shape[:2]).ravel()
    yy = samples[:, :, idx].ravel()

    # NOTE(review): "cmr.*" colormaps are registered by cmasher, which this notebook
    # never imports directly — presumably another import pulls it in; confirm.
    ax.hexbin(true_full, yy, gridsize=50, cmap="cmr.sunburst_r", mincnt=1)
    ax.set(adjustable="box", aspect="equal")

    ax.set_xlabel("True Value")
    if i == 0:
        ax.set_ylabel("Inferred Value")
    lo, hi = true_full.min(), true_full.max()
    ax.plot([lo, hi], [lo, hi], "k--", alpha=0.5)
    ax.text(0.05, 0.90, f"${latex}$", transform=ax.transAxes, fontsize=11)
    ax.text(0.05, 0.80, f"RMSE = {rmse:.2f}", transform=ax.transAxes, fontsize=10)
    ax.text(0.05, 0.70, f"$R^2={r2:.2f}$", transform=ax.transAxes, fontsize=10)

fig.savefig(
    f"/home/tharvey/work/synference/models/{model_name}/plots/{model_name}_{extra}_validation.png",
    dpi=300,
    bbox_inches="tight",
)
os.makedirs("plots", exist_ok=True)  # the local output directory may not exist yet
fig.savefig(f"plots/{model_name}_{extra}_validation.png", dpi=300, bbox_inches="tight")

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Larger default text sizes for the figures that follow.
for rc_key, rc_size in {"font.size": 12, "axes.labelsize": 14, "axes.titlesize": 14}.items():
    plt.rcParams[rc_key] = rc_size


def calculate_quantile_error_per_param(
    true_params: np.ndarray, posteriors: np.ndarray
) -> np.ndarray:
    """Distance of each true value's posterior quantile from the median.

    For every observation and parameter, the quantile q of the truth is the
    fraction of posterior draws strictly below it. The returned error is
    |q - 0.5|: 0 means the truth sits at the posterior median, and 0.5 means no
    draw is below it or every draw is below it.

    Args:
        true_params: True values, shape (n_points, n_dim).
        posteriors: Posterior draws, shape (n_samples, n_points, n_dim), with
            the draw axis first.

    Returns:
        Array of shape (n_points, n_dim).
    """
    # Insert a leading draw axis so the truths broadcast against every draw.
    truths = true_params[None, :, :]
    fraction_below = (posteriors < truths).mean(axis=0)
    return np.abs(fraction_below - 0.5)


def plot_colored_corner(
    data: np.ndarray,
    errors: np.ndarray,
    param_names: list[str] | None = None,
    elabel: str = "Prediction Error (|quantile - 0.5|)",
    params_to_remove: list[str] | None = None,
) -> None:
    """Draw a corner plot of ``data`` whose 2D panels are coloured by ``errors``.

    The diagonal shows a histogram (with KDE) of each column of ``data``. Each
    lower-triangle panel is a hexbin of one column against another, coloured by
    the mean of ``errors`` over the points in each cell. The upper triangle is
    hidden. The figure is shown with ``plt.show()``.

    Args:
        data: Array of shape (n_points, n_dim), one column per parameter.
        errors: One value per row of ``data``, used as the hexbin ``C`` values.
            The colour range is clipped to its 5th-95th percentiles.
        param_names: Axis labels, one per column of ``data``. Defaults to
            "Param 0", "Param 1", ...
        elabel: Colour-bar label.
        params_to_remove: Entries of ``param_names`` whose columns are dropped
            before plotting. Requires ``param_names``.

    Raises:
        ValueError: If ``params_to_remove`` is given without ``param_names``.
    """
    # Drop the requested columns. `errors` has one value per row, so it is unaffected.
    if params_to_remove:
        if param_names is None:
            raise ValueError("params_to_remove needs param_names to identify the columns to drop")
        keep = [k for k, name in enumerate(param_names) if name not in params_to_remove]
        data = data[:, keep]
        param_names = [param_names[k] for k in keep]

    n_dim = data.shape[1]
    if param_names is None:
        param_names = [f"Param {k}" for k in range(n_dim)]

    # One colour scale for every panel, so compute it once.
    vmin, vmax = np.percentile(errors, [5, 95])

    # squeeze=False keeps `axes` two-dimensional even when n_dim == 1.
    fig, axes = plt.subplots(n_dim, n_dim, figsize=(12, 12), squeeze=False)
    fig.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1, hspace=0.1, wspace=0.1)

    mappable = None  # last hexbin drawn; feeds the shared colour bar
    for i in range(n_dim):
        for j in range(n_dim):
            ax = axes[i, j]

            if i == j:
                # Diagonal: distribution of the values in column i.
                sns.histplot(
                    data[:, i],
                    ax=ax,
                    bins=20,
                    kde=True,
                    element="step",
                    color="sandybrown",
                    alpha=0.7,
                )
                ax.set_ylabel("")
            elif i < j:
                ax.axis("off")  # upper triangle is left empty
            else:
                # Lower triangle: mean of `errors` per hexagonal cell (not a point density).
                # NOTE(review): "cmr.ember" is a cmasher colormap; cmasher must be
                # imported somewhere for it to be registered — confirm.
                mappable = ax.hexbin(
                    data[:, j],
                    data[:, i],
                    C=errors,
                    reduce_C_function=np.mean,
                    gridsize=25,
                    cmap="cmr.ember",
                    vmin=vmin,
                    vmax=vmax,
                    mincnt=1,
                )

            # Labels on the outer panels only.
            if j == 0 and i > 0:
                ax.set_ylabel(param_names[i])
            if i == n_dim - 1:
                ax.set_xlabel(param_names[j])

            # Hide ticks on inner panels.
            if i < n_dim - 1:
                ax.tick_params(axis="x", which="both", bottom=False, top=False, labelbottom=False)
            if j > 0:
                ax.tick_params(axis="y", which="both", left=False, right=False, labelleft=False)

    # With a single parameter there is no 2D panel, so there is nothing to build a colour bar from.
    if mappable is not None:
        cbar_ax = fig.add_axes([0.35, 0.7, 0.55, 0.02])
        cbar = fig.colorbar(mappable, cax=cbar_ax, orientation="horizontal")
        cbar.set_label(elabel)

    plt.show()


# Corner plot of the true test-set parameters, with each 2D panel coloured by a
# per-object diagnostic. Uses `fitter`, `param_labels` and the two functions above.
sample_path = (
    f"/home/tharvey/work/synference/models/{model_name}/plots/{extra}/posterior_samples.npy"
)
true = fitter._y_test

samples = np.load(sample_path)
log_prob_path = f"/home/tharvey/work/synference/models/{model_name}/plots/{extra}/true_logprobs.npy"
# plot_colored_corner expects one value per test object.
# NOTE(review): presumably the posterior log-density at the true parameters — confirm.
log_probs = np.load(log_prob_path)

# Colour by log P(theta | x) by default. The quantile-error alternative is only
# computed when requested. It compares every draw with the truth, which allocates a
# boolean array as large as `samples`, and it is not needed for the default colouring.
COLOR_BY_QUANTILE_ERROR = False
if COLOR_BY_QUANTILE_ERROR:
    # Mean over parameters of |quantile - 0.5|, one value per test object.
    errors = np.mean(calculate_quantile_error_per_param(true, samples), axis=1)
    color_values = errors
    color_label = "Prediction Error (|quantile - 0.5|)"
else:
    color_values = log_probs
    color_label = "$\\log P(\\theta | x)$"

params_to_remove = ["$\\beta$"]  # matched against param_labels
plot_colored_corner(
    true,
    color_values,
    param_names=param_labels,
    elabel=color_label,
    params_to_remove=params_to_remove,
)

## Pick a few examples and compare posteriors with nested sampling and the truth

In [None]:
# Feature vector of one candidate test object. Not displayed, because it is not
# the cell's last expression.
fitter._X_test[3435]

# Candidate test-set indices for the posterior comparison below; the comparison
# cell currently uses index = 7372.
indexes = 99, 3435, 16, 9985, 7372

In [None]:
# Display the feature vector of candidate test object 99.
fitter._X_test[99]

In [None]:
# Test-set object to examine (one of the candidate `indexes`).
index = 7372

sample = fitter._X_test[index]
true_params = fitter._y_test[index]
# 1000 draws from the trained SBI posterior for this observation.
predicted_params = fitter.sample_posterior(sample, num_samples=1000)

# Rebuild the priors on the fitter before running the likelihood-based sampler.
# NOTE(review): this overwrites a private attribute; presumably the loaded model
# lacks priors usable by fit_observation_using_sampler — confirm.
fitter._prior = fitter.create_priors()

# Fit the same observation with nautilus nested sampling (n_live=500) for comparison.
# remove_params presumably drops these two parameters from the sampled set — confirm.
sampler = fitter.fit_observation_using_sampler(
    observation=sample,
    sampler="nautilus",
    sampler_kwargs={"n_live": 500},
    remove_params=["log10_mass_weighted_age", "log10_floor_sfr_10"],
)
# Points, log-weights and log-likelihoods. The weights are not equal, so the next
# cell uses exp(log_w) as sample weights.
# NOTE(review): the next cell assumes points[:, :-2] holds the comparison parameters
# in fitted-parameter order — confirm the column order returned here.
points, log_w, log_l = sampler.posterior()

# Next: overlay both posteriors (contours only) and the true values in one corner plot.

In [None]:
import corner

from matplotlib.colors import same_color

# Colours used consistently for the two posteriors throughout this figure.
SBI_COLOR = "#CD3700"
NS_COLOR = "#03A89E"

comparison_params = [
    "log_mass",
    "log10metallicity",
    "log10_Av",
    "log_sfr",
    "sfh_quantile_25",
    "sfh_quantile_50",
    "sfh_quantile_75",
]

display_names = [
    r"$\log M_*$",
    r"$\log Z/Z_\odot$",
    # The parameter is log10_Av, so the label must show the logarithm
    # (the same label param_labels uses for it).
    r"$\rm \log_{10} \ A_V$",
    r"$\log$ SFR",
    "$t_{25}$",
    "$t_{50}$",
    "$t_{75}$",
]
ndim = len(comparison_params)

# Select the comparison parameters from the SBI columns. The mask keeps the
# fitted-parameter order, so check it agrees with comparison_params/display_names.
fitted_names = list(fitter.fitted_parameter_names)
mask_idx = np.isin(fitted_names, comparison_params)
assert [p for p in fitted_names if p in comparison_params] == comparison_params, (
    "comparison_params must follow the order of fitter.fitted_parameter_names"
)
sbi_samples = predicted_params[:, mask_idx]
true_sub = np.asarray(true_params)[mask_idx]

# Nested-sampling posterior. NOTE(review): assumes the columns of `points` follow the
# fitted-parameter order without the two removed parameters, so that dropping the
# last two columns leaves exactly comparison_params — confirm against the sampler set-up.
ns_samples = points[:, :-2]
assert ns_samples.shape[1] == ndim, "unexpected number of nested-sampling columns"
# Relative weights (max 1); subtracting the max log-weight avoids overflow.
ns_weights = np.exp(log_w - log_w.max())


def _corner_style():
    """Return fresh keyword arguments shared by both corner calls.

    New (nested) dicts are built for every call because corner may write defaults,
    such as colours, into the dicts it receives. Those defaults would otherwise leak
    from the first call into the second.
    """
    return dict(
        labels=list(display_names),
        plot_datapoints=False,
        fill_contours=False,
        contour_kwargs={"linewidths": 1.5, "linestyles": "solid"},
        contourf_kwargs={"alpha": 0.3},  # very light fill for contours
        hist_kwargs={"alpha": 0.7, "histtype": "stepfilled", "edgecolor": "black", "linewidth": 1},
        # Same contour levels for both posteriors so their contours are comparable.
        levels=(0.68, 0.95),
        plot_density=False,
        smooth=1.0,
        label_kwargs={"fontsize": 22},
        title_kwargs={"fontsize": 22},
        quantiles=[0.16, 0.5, 0.84],
        show_titles=False,
        # NOTE(review): confirm corner honours tick_label_kwargs; tick label sizes
        # are also set explicitly below.
        tick_label_kwargs={"fontsize": 16},
    )


def _weighted_quantiles(values, weights, qs):
    """Quantiles ``qs`` (fractions in [0, 1]) of ``values`` under ``weights``,
    by linear interpolation of the normalised cumulative weight."""
    order = np.argsort(values)
    cum = np.cumsum(weights[order])
    cum /= cum[-1]
    return np.interp(qs, cum, values[order])


def _has_color(artist, color):
    """Whether ``artist`` is drawn in ``color``.

    Line2D exposes ``get_color``. Scatter-type collections may only expose
    ``get_facecolor``.
    """
    getter = getattr(artist, "get_color", None) or getattr(artist, "get_facecolor", None)
    if getter is None:
        return False
    try:
        return bool(same_color(getter(), color))
    except ValueError:
        # e.g. a collection holding several (or zero) face colours
        return False


def _relabel(artist, phot_only):
    """Rename a recover_SED legend label to say which posterior (SBI/NS) it shows."""
    label = artist.get_label()
    if label == "_nolegend_" or "_child" in label:
        return  # matplotlib auto-labels are not shown in the legend
    if _has_color(artist, SBI_COLOR):
        method = "SBI"
    elif _has_color(artist, NS_COLOR):
        method = "NS"
    else:
        return
    # Match photometry first so photometry lines are not relabelled as SEDs.
    if label == "Posterior Phot.":
        artist.set_label(f"Predicted {method} Phot.")
    elif not phot_only:
        artist.set_label(f"Predicted {method} SED")


figure = corner.corner(ns_samples, weights=ns_weights, color=NS_COLOR, **_corner_style())
corner.corner(sbi_samples, color=SBI_COLOR, fig=figure, **_corner_style())

# Per-parameter 16/50/84th percentiles, shape (3, ndim).
sbi_q = np.percentile(sbi_samples, [16, 50, 84], axis=0)

# Mark true values and add colour-coded SBI / NS summaries above each diagonal panel.
axes = np.array(figure.axes).reshape((ndim, ndim))
for i in range(ndim):
    for j in range(ndim):
        ax = axes[i, j]
        if i == j:
            ax.axvline(true_sub[i], color="k", linestyle="dotted")
            p16, p50, p84 = sbi_q[:, i]
            p16_ns, p50_ns, p84_ns = _weighted_quantiles(
                ns_samples[:, i], ns_weights, [0.16, 0.50, 0.84]
            )
            ax.set_title("")
            # Two text lines stand in for a title so SBI and NS can be coloured differently.
            sbi = rf"SBI: ${p50:.2f}_{{-{p50 - p16:.2f}}}^{{+{p84 - p50:.2f}}}$"
            ns = rf"NS: ${p50_ns:.2f}_{{-{p50_ns - p16_ns:.2f}}}^{{+{p84_ns - p50_ns:.2f}}}$"
            for y_pos, text, text_color in ((1.3, sbi, SBI_COLOR), (1.1, ns, NS_COLOR)):
                ax.text(
                    0.5,
                    y_pos,
                    text,
                    color=text_color,
                    fontsize=22,
                    ha="center",
                    va="center",
                    transform=ax.transAxes,
                    fontweight="bold",
                )
        elif i > j:
            ax.axvline(true_sub[j], color="k", linestyle="dotted")
            ax.axhline(true_sub[i], color="k", linestyle="dotted")
            ax.plot(true_sub[j], true_sub[i], "ks")

        # Larger tick labels on the bottom row and left column.
        if i == ndim - 1 or j == 0:
            ax.tick_params(axis="both", which="major", labelsize=16)

# Floating panels in the empty upper triangle: SED (upper right) and SFH inset.
new_ax = figure.add_axes([0.50, 0.65, 0.45, 0.35])
sfh_ax = figure.add_axes([0.75, 0.42, 0.22, 0.18])

# SEDs implied by the SBI posterior samples.
results = fitter.recover_SED(
    sample,
    samples=predicted_params,
    plot=True,
    true_parameters=true_params,
    sample_color=SBI_COLOR,
    fig=figure,
    ax=new_ax,
    plot_histograms=False,
    plot_sfh=True,
    ax_sfh=sfh_ax,
    num_samples=50,
)

# recover_SED is not designed to take weights, so turn the weighted nested-sampling
# posterior into an equally weighted one by resampling with replacement. Seeded so
# the figure is reproducible.
ns_rng = np.random.default_rng(42)
resample_idx = ns_rng.choice(
    len(ns_weights), size=len(ns_weights), replace=True, p=ns_weights / ns_weights.sum()
)
points_dict = {name: points[resample_idx, k] for k, name in enumerate(comparison_params)}
results_ns = fitter.recover_SED(
    sample,
    samples=points_dict,
    plot=True,
    sample_color=NS_COLOR,
    fig=figure,
    ax=new_ax,
    plot_histograms=False,
    plot_sfh=True,
    ax_sfh=sfh_ax,
    num_samples=50,
)

new_ax.set_ylim(25, 22)

# Larger text in the SED and SFH panels (the legend may not exist yet).
texts = [
    new_ax.title,
    new_ax.xaxis.label,
    new_ax.yaxis.label,
    *new_ax.get_xticklabels(),
    *new_ax.get_yticklabels(),
    sfh_ax.title,
    sfh_ax.xaxis.label,
    sfh_ax.yaxis.label,
    *sfh_ax.get_xticklabels(),
    *sfh_ax.get_yticklabels(),
]
sed_legend = new_ax.get_legend()
if sed_legend is not None:
    texts.extend(sed_legend.get_texts())
for text_artist in texts:
    text_artist.set_fontsize(16)

# Larger markers for everything drawn in the SED panel.
for line in new_ax.get_lines():
    line.set_markersize(8)

# Say which posterior each SED / photometry entry comes from.
for line in new_ax.get_lines():
    _relabel(line, phot_only=False)
for coll in new_ax.collections:
    _relabel(coll, phot_only=True)

new_ax.legend(loc="lower right", fontsize=16)

os.makedirs("plots", exist_ok=True)  # the local output directory may not exist yet
figure.savefig(
    f"plots/posterior_comparison_{index}.png",
    dpi=300,
    bbox_inches="tight",
)

### Five SED recoveries compared with the true SEDs (different redshifts, masses, dust, etc.)

In [None]:
# TODO: fill `test_samples` with the chosen observations. The list is empty, so
# this cell currently does nothing.
# NOTE(review): other cells pass the observation to recover_SED positionally —
# confirm that it accepts an `X_test` keyword.
test_samples = []
for test_sample in test_samples:
    fitter.recover_SED(X_test=test_sample)

### SFH Recovery Examples

In [None]:
# Distinct entries of the simulator's `unused_params` attribute.
np.unique(fitter.simulator.unused_params)

In [None]:
# SED recoveries for up to 100 randomly chosen test objects.
# - The generator is seeded so the selection is reproducible.
# - replace=False avoids plotting the same object twice.
# - Loop names are local to this cell, so the globals `sample`, `true_params` and
#   `predicted_params` used by the comparison figure are not overwritten.
example_rng = np.random.default_rng(0)
n_test = len(fitter._X_test)
random_idx = example_rng.choice(n_test, size=min(100, n_test), replace=False)

for idx in random_idx:
    recovered = fitter.recover_SED(
        fitter._X_test[idx], num_samples=50, plot=True, true_parameters=fitter._y_test[idx]
    )

In [None]:
# Select test objects with strong dust (log10 A_V > 0.5, i.e. A_V > ~3.2 mag) at
# low redshift (z < 0.5), and plot their SED recoveries.
Av_idx = list(fitter.fitted_parameter_names).index("log10_Av")
z_idx = list(fitter.feature_names).index("redshift")
mask = (fitter._y_test[:, Av_idx] > 0.5) & (fitter._X_test[:, z_idx] < 0.5)

# Apply the mask once instead of re-slicing the full test arrays on every iteration.
dusty_X = fitter._X_test[mask]
dusty_y = fitter._y_test[mask]
for s, true_s in zip(dusty_X, dusty_y):
    fitter.recover_SED(s, num_samples=50, plot=True, true_parameters=true_s)