# Model Testing Figures for Paper

In [None]:
import os

# Directory containing the TeX Live 2025 binaries. It is put at the front of PATH,
# presumably so that matplotlib's LaTeX text rendering can find `latex` -- TODO confirm.
TEXLIVE_BIN = "/usr/local/texlive/2025/bin/x86_64-linux"

# Rebuild PATH with TEXLIVE_BIN first and present exactly once, so re-running this cell
# does not keep prepending the same directory.
_current_path = os.environ.get("PATH", "")
_other_entries = [
    entry
    for entry in (_current_path.split(os.pathsep) if _current_path else [])
    if entry != TEXLIVE_BIN
]
os.environ["PATH"] = os.pathsep.join([TEXLIVE_BIN, *_other_entries])

In [None]:
import matplotlib.pyplot as plt
import numpy as np

from synference import SBI_Fitter

# Apply the matplotlib style sheet used for the paper figures ("paper.style" is resolved
# relative to the working directory or matplotlib's style library).
plt.style.use("paper.style")

# Identifiers of the trained model; together they form the file names used throughout
# this notebook ({model_name}_{extra}_posterior.pkl, ..._metrics.json, ..._validation.png).
model_name = "PBP_LogNorm"
extra = "mdn_noisy_zfix_v2"


# Load the saved fitter/posterior from disk.
# NOTE(review): the path is an absolute, user-specific location; the notebook will not run
# on another machine without editing it.
# NOTE(review): the file is a .pkl -- if load_saved_model unpickles it, only load files
# from a trusted source.
fitter = SBI_Fitter.load_saved_model(
    f"/home/tharvey/work/synference/models/{model_name}/{model_name}_{extra}_posterior.pkl"
)

Figure 2: Calibration and recovery.

In [None]:
# Run the fitter's own coverage plot for the loaded model.
fitter.plot_coverage()

## Calculate TARP

In [None]:
import tarp

# Posterior draws for the test set, loaded from disk.
# NOTE(review): the file name does not include `extra`, so confirm these draws were made
# with the same model variant that is loaded above.
sample_path = f"/home/tharvey/work/synference/models/{model_name}/plots/posterior_samples.npy"
# Presumably the true parameter values of the test set (private attribute of the fitter).
true = fitter._y_test

# Later cells index this array as samples[:, :, idx], i.e. three axes with the parameter
# axis last and the middle axis matching the rows of `true`.
samples = np.load(sample_path)


# TARP coverage with 100 bootstrap resamples. `out[0]` and `out[1]` are read later as the
# coverage curves (`ecp`) and the credibility levels (`alpha`).
out = tarp.get_tarp_coverage(
    samples,
    true,
    norm=True,
    bootstrap=True,
    num_bootstrap=100,
)


## Load metrics

In [None]:
import json

# JSON file of metrics stored under the same directory and file-name stem as the model.
metrics = f"/home/tharvey/work/synference/models/{model_name}/{model_name}_{extra}_metrics.json"

# Read the whole file and parse it into `met`.
with open(metrics) as f:
    met = json.loads(f.read())

In [None]:
# Validation figure: parameter recovery (top row, three panels) plus TARP and SBC
# calibration curves (bottom row). Uses `fitter`, `samples`, `true` and `out` from the
# cells above.

fig = plt.figure(figsize=(9, 9), dpi=100)

# Two rows on a six-column grid: three equal panels on top, two wider panels below.
gs = fig.add_gridspec(2, 6, height_ratios=[1, 1.44], hspace=-0.15, wspace=0.65)

ax_top_l = fig.add_subplot(gs[0, :2])
ax_top_r = fig.add_subplot(gs[0, 2:4])
ax_top_c = fig.add_subplot(gs[0, 4:6])
ax1 = fig.add_subplot(gs[1, :3])
ax2 = fig.add_subplot(gs[1, 3:])
axs = [ax_top_l, ax_top_r, ax_top_c]

# Parameters shown in the top row, with their LaTeX labels.
ax_params = ["log_mass", "sfr_10", "log10_Av"]
param_latex = [
    r"\log(\rm M/M_\odot)",
    r"\log(\rm SFR_{10 \rm Myr}/ M_\odot yr^{-1})",
    r"\log(\rm A_V)",
]
indices = [list(fitter.fitted_parameter_names).index(p) for p in ax_params]

# ---- TARP panel (ax1) ----
# TARP curves show expected coverage as a function of credibility level, so this panel
# carries those labels (they were previously attached to the SBC panel).
ax1.plot([0, 1], [0, 1], "k--", label="Ideal", alpha=0.5)
ax1.set_xlabel("Credibility Level")
ax1.set_ylabel("Expected Coverage")
ax1.text(0.05, 0.90, "TARP", transform=ax1.transAxes, fontsize=16)

# With bootstrap=True the coverage has one row per bootstrap resample; plot the mean
# curve with 1-sigma and 2-sigma bands.
ecp, alpha = out[0], out[1]
ecp_mean = np.mean(ecp, axis=0)
ecp_std = np.std(ecp, axis=0)
ax1.plot(alpha, ecp_mean, label="TARP", color="b")
ax1.fill_between(alpha, ecp_mean - ecp_std, ecp_mean + ecp_std, alpha=0.5, color="b")
ax1.fill_between(alpha, ecp_mean - 2 * ecp_std, ecp_mean + 2 * ecp_std, alpha=0.2, color="b")
ax1.set(adjustable="box", aspect="equal")

# ---- SBC panel (ax2) ----
ax2.plot([0, 1], [0, 1], "k--", label="Ideal", alpha=0.5)
ax2.set_xlabel("Predicted Percentile")
ax2.set_ylabel("Empirical Percentile")
ax2.text(0.05, 0.90, "SBC", transform=ax2.transAxes, fontsize=16)

trues = fitter._y_test
ndata, npars = trues.shape
n_draws = samples.shape[0]

# Rank of each true value within its posterior draws: the number of draws below the truth.
ranks = (samples < trues[None, ...]).sum(axis=0)

# Empirical CDF levels, one per test object.
cdf = np.linspace(0, 1, ndata)
for i in range(npars):
    # Normalise by the number of posterior draws (the largest possible rank), not by the
    # largest rank that happened to occur, so the curve is not stretched to end at 1.
    xr = np.sort(ranks[:, i]) / n_draws
    ax2.plot(xr, cdf, lw=2, label=fitter.fitted_parameter_names[i])

ax2.set(adjustable="box", aspect="equal")
ax2.set_xlim(0, 1)
ax2.set_ylim(0, 1)
ax2.legend()

# ---- Recovery panels (top row) ----
for i, ax in enumerate(axs):
    param = ax_params[i]
    idx = indices[i]

    # Bin every posterior draw against the true value of its object: repeat the truths
    # once per draw so both arrays flatten in the same order.
    true_full = np.repeat(true[:, idx][np.newaxis, :], samples.shape[0], axis=0).flatten()
    yy = samples[:, :, idx].flatten()

    if param == "sfr_10":
        # SFR is shown in log10 with a floor at -6. Zeros map to -inf before the clip,
        # so silence the expected log10 warnings.
        with np.errstate(divide="ignore", invalid="ignore"):
            true_full = np.clip(np.log10(true_full), -6, None)
            yy = np.clip(np.log10(yy), -6, None)

    # NOTE(review): "cmr.*" colormaps are registered by importing cmasher; this notebook
    # does not import it directly, so it presumably happens inside another import -- confirm.
    ax.hexbin(
        true_full,
        yy,
        gridsize=50,
        cmap="cmr.sunburst_r",
        mincnt=1,
    )
    ax.set(adjustable="box", aspect="equal")

    ax.set_xlabel("True Value")
    if i == 0:
        ax.set_ylabel("Inferred Value")

    # One-to-one line spanning the range of the true values.
    lo, hi = true_full.min(), true_full.max()
    ax.plot([lo, hi], [lo, hi], "k--", alpha=0.5)
    ax.text(0.05, 0.90, f"${param_latex[i]}$", transform=ax.transAxes, fontsize=11)

fig.savefig(
    f"/home/tharvey/work/synference/models/{model_name}/plots/{model_name}_{extra}_validation.png",
    dpi=300,
    bbox_inches="tight",
)
# The second copy goes to a directory relative to the notebook; create it if missing.
os.makedirs("plots", exist_ok=True)
fig.savefig(f"plots/{model_name}_{extra}_validation.png", dpi=300, bbox_inches="tight")

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


# Helpers for the corner plot coloured by posterior quantile error
def calculate_quantile_error_per_param(
    true_params: np.ndarray, posteriors: np.ndarray
) -> np.ndarray:
    """Return ``|q - 0.5|`` for every parameter of every observation.

    ``q`` is the fraction of posterior draws that lie strictly below the true value,
    i.e. the quantile of the truth within its own posterior.

    Args:
        true_params: True values, shape ``(n_points, n_dim)``.
        posteriors: Posterior draws, shape ``(n_samples, n_points, n_dim)``.

    Returns:
        Array of shape ``(n_points, n_dim)`` with values in ``[0, 0.5]``; 0 means the
        truth sits at the posterior median, 0.5 means it lies outside all draws.
    """
    # Add a leading axis so the truths broadcast against the draws axis.
    truth = true_params[np.newaxis, :, :]

    # Averaging the boolean comparison over the draws axis gives the fraction below truth.
    frac_below = np.mean(posteriors < truth, axis=0)

    return np.abs(frac_below - 0.5)


def plot_colored_corner(
    data: np.ndarray, errors: np.ndarray, param_names: list[str] | None = None
) -> None:
    """Draw a corner plot of ``data`` whose scatter points are coloured by ``errors``.

    Diagonal panels show a histogram (with KDE) of each column of ``data``. Lower-triangle
    panels show pairwise scatter plots coloured by ``errors`` on a fixed 0-0.5 colour
    scale. Upper-triangle panels are hidden.

    Args:
        data: Array of shape ``(n_points, n_dim)``; one column per parameter.
        errors: One value per row of ``data``, passed to ``scatter`` as the colour array.
        param_names: Axis labels, one per column of ``data``. Defaults to
            ``"Param 0", "Param 1", ...``.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    n_dim = data.shape[1]
    if param_names is None:
        param_names = [f"Param {i}" for i in range(n_dim)]

    # squeeze=False keeps `axes` two-dimensional even when n_dim == 1, so axes[i, j]
    # indexing below always works.
    fig, axes = plt.subplots(n_dim, n_dim, figsize=(12, 12), squeeze=False)

    # Adjust subplots to be tight
    fig.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1, hspace=0.1, wspace=0.1)

    # Handle of the last scatter drawn; stays None when there is no off-diagonal panel.
    sc = None

    for i in range(n_dim):
        for j in range(n_dim):
            ax = axes[i, j]

            if i == j:
                # Diagonal: 1D histogram of this parameter
                sns.histplot(data[:, i], ax=ax, bins=20, kde=True)
            elif i < j:
                # Upper triangle is hidden
                ax.axis("off")
            else:
                # Lower triangle: pairwise scatter coloured by the error metric
                sc = ax.scatter(
                    data[:, j],
                    data[:, i],
                    cmap="viridis",
                    s=5,
                    alpha=0.7,
                    vmin=0,
                    vmax=0.5,
                    c=errors,
                )

            # Labels for the outer plots only
            if j == 0 and i > 0:
                ax.set_ylabel(param_names[i])
            if i == n_dim - 1:
                ax.set_xlabel(param_names[j])

            # Remove ticks for inner plots to clean up the view
            if i < n_dim - 1:
                ax.tick_params(axis="x", which="both", bottom=False, top=False, labelbottom=False)
            if j > 0:
                ax.tick_params(axis="y", which="both", left=False, right=False, labelleft=False)

    # A colourbar needs a scatter to refer to; skip it when none was drawn (n_dim < 2).
    if sc is not None:
        cbar_ax = fig.add_axes([0.65, 0.7, 0.25, 0.02])
        cbar = fig.colorbar(sc, cax=cbar_ax, orientation="horizontal")
        cbar.set_label("Prediction Error (|quantile - 0.5|)")

    # Title this figure explicitly rather than whichever figure pyplot considers current.
    fig.suptitle("Corner Plot of True Parameters Colored by Prediction Error", fontsize=16, y=0.95)
    plt.show()


# Average |quantile - 0.5| over the parameter axis to get one error per test object,
# then use it to colour the corner plot of the true parameters.
errors = np.mean(calculate_quantile_error_per_param(true, samples), axis=1)
plot_colored_corner(true, errors, param_names=fitter.fitted_parameter_names.tolist())

In [None]:
# Display the `errors` array produced by the previous cell.
errors

## Pick a few examples and compare posteriors with nested sampling and the truth

In [None]:
import corner

# Index of the test-set object to examine; used for both the observation and its truth.
OBS_INDEX = 99
# Contour levels shared by both posteriors so the overlaid contours enclose the same
# probability mass and can be compared directly.
CONTOUR_LEVELS = (0.68, 0.95)

sample = fitter._X_test[OBS_INDEX]
true_params = fitter._y_test[OBS_INDEX]

# Posterior draws from the trained SBI model for this observation.
predicted_params = fitter.sample_posterior(sample, num_samples=1000)

# Reference posterior for the same observation from the nautilus sampler.
sampler = fitter.fit_observation_using_sampler(
    observation=sample, sampler="nautilus", sampler_kwargs={"n_live": 500}
)
points, log_w, log_l = sampler.posterior()

# Nested-sampling posterior: filled contours, weighted by the importance weights
# (shifted by the maximum log-weight before exponentiating to avoid overflow).
figure = corner.corner(
    points,
    weights=np.exp(log_w - log_w.max()),
    labels=fitter.fitted_parameter_names,
    color="C1",
    plot_datapoints=False,
    fill_contours=True,
    levels=CONTOUR_LEVELS,
    plot_density=True,
    smooth=1.0,
    label_kwargs={"fontsize": 16},
    title_kwargs={"fontsize": 14},
    quantiles=[0.16, 0.5, 0.84],
    show_titles=True,
)

# SBI posterior drawn on the same figure as unfilled contours at the same levels.
# NOTE(review): both calls use show_titles=True on the same axes, so the titles set here
# presumably replace those of the nested-sampling call -- confirm which set is wanted.
corner.corner(
    predicted_params,
    labels=fitter.fitted_parameter_names,
    color="C0",
    plot_datapoints=False,
    fill_contours=False,
    levels=CONTOUR_LEVELS,
    plot_density=True,
    smooth=1.0,
    fig=figure,
    quantiles=[0.16, 0.5, 0.84],
    show_titles=True,
)

# Mark the true values: a vertical line on each diagonal panel, and crosshairs plus a
# square marker on each lower-triangle panel.
ndim = true_params.shape[0]
axes = np.array(figure.axes).reshape((ndim, ndim))
for i in range(ndim):
    for j in range(ndim):
        ax = axes[i, j]
        if i == j:
            ax.axvline(true_params[i], color="k", linestyle="--")
        elif i > j:
            ax.axvline(true_params[j], color="k", linestyle="--")
            ax.axhline(true_params[i], color="k", linestyle="--")
            ax.plot(true_params[j], true_params[i], "ks")

figure.suptitle("Comparison of Posterior Samples: MDN vs Nested Sampling", fontsize=16)
figure.savefig(
    f"/home/tharvey/work/synference/models/{model_name}/plots/{model_name}_{extra}_posterior_comparison.png",
    dpi=300,
    bbox_inches="tight",
)

### 5 SED recoveries and true SEDs (different redshifts, masses, dust etc)

### SFH Recovery Examples