Figure showing the sample of observed galaxies with spectroscopic redshifts.

In [None]:
%load_ext autoreload
%autoreload 2

import os

os.environ["PATH"] = "/usr/local/texlive/2025/bin/x86_64-linux:" + os.environ["PATH"]

import matplotlib.pyplot as plt
import numpy as np
from astropy.table import Table
from unyt import Jy

from synference import SBI_Fitter, create_uncertainty_models_from_EPOCHS_cat

plt.style.use("paper.style")

In [None]:
# JADES DR3 GOODS-S catalogue matched to spectroscopic redshifts (per the filename).
SPECZ_CATALOG_PATH = "/home/tharvey/work/synference/priv/data/cats/JADES_DR3_GS_Matched_Specz_total_flux_good_filt.fits"
tab = Table.read(SPECZ_CATALOG_PATH)

In [None]:
# Columns used by the sample figure below.
z = tab["specz"].data
source = tab["spec_source"]
# Total F277W flux -> AB magnitude. A zero point of 31.4 equals 2.5*log10(3631 Jy / 1 nJy),
# i.e. it assumes FLUX_TOTAL_* is in nJy — TODO confirm against the catalogue header.
mag = 31.4 - 2.5 * np.log10(tab["FLUX_TOTAL_F277W"])

In [None]:
import os

import cmasher
import matplotlib.gridspec as gridspec

# Two-panel figure of the spec-z sample: per-source redshift histograms (top) and
# F277W magnitude vs spectroscopic redshift (bottom), sharing the redshift axis.
# Reads `z`, `mag` and `source` from the previous cell.

# Display names for raw `spec_source` values; unknown values are shown unchanged.
source_names = {
    "3d-hst grism": "3D-HST Grism",
    "3d-hst spec": "3D-HST Spec",
    "NIRSpec_prism": "NIRSpec",
    "MUSE": "MUSE",
    "ESO Master Catalog": "ESO Master Catalog",
}

sources = np.unique(source)
colors = cmasher.take_cmap_colors("cmr.guppy", len(sources), cmap_range=(0.1, 0.9))
markers = ["o", "s", "D", "^", "v", "<", ">", "p", "*", "h", "H", "X"]
# Boolean row selection per source, computed once and reused by both panels.
source_masks = [np.squeeze(source == src) for src in sources]
# One set of bin edges for every source, so the overlaid histograms share bin widths
# and their heights are directly comparable (a per-source `bins=30` gave each source
# its own binning). Non-finite redshifts are excluded when choosing the range.
z_bins = np.histogram_bin_edges(z[np.isfinite(z)], bins=30)

fig = plt.figure(figsize=(5, 4))
gs = gridspec.GridSpec(2, 1, height_ratios=[1, 4], hspace=0.05)

# Histogram panel (top): x tick marks/labels hidden; the scatter panel carries them.
ax_hist = fig.add_subplot(gs[0])
ax_hist.set_ylabel("Count")
ax_hist.tick_params(axis="x", which="both", bottom=False, top=False, labelbottom=False)

# Scatter panel (bottom), sharing the redshift axis with the histogram.
ax_scatter = fig.add_subplot(gs[1], sharex=ax_hist)

for i, (src, mask) in enumerate(zip(sources, source_masks)):
    label = source_names.get(str(src), str(src))
    color = colors[i % len(colors)]
    ax_hist.hist(
        z[mask],
        bins=z_bins,
        alpha=0.7,
        label=label,
        color=color,
        histtype="stepfilled",
    )
    ax_scatter.scatter(
        z[mask],
        mag[mask],
        s=12,
        alpha=0.7,
        label=label,
        color=color,
        marker=markers[i % len(markers)],
        edgecolors="black",
        linewidths=0.2,
    )

ax_scatter.set_xlabel("Spectroscopic Redshift (z)")
ax_scatter.set_ylabel("F277W Flux (AB Mag)")
ax_scatter.legend(
    title=r"Spec-$z$ Source",
    loc="best",
    frameon=False,
    fancybox=False,
    ncol=1,
    edgecolor="black",
    fontsize=8,
)
ax_scatter.invert_yaxis()  # brighter (smaller) magnitudes at the top
ax_scatter.set_xticks(np.arange(0, 15, 1))

# Make sure the output directory exists on a fresh checkout.
os.makedirs("plots", exist_ok=True)
fig.savefig("plots/obs_sample.png", dpi=300, bbox_inches="tight")
plt.show()

Noise model: empirical asinh-magnitude uncertainty model for F444W.

In [None]:
# Trained posterior to load, and the catalogue whose fluxes/errors define the noise model.
model_name = "DenseBasis_v3_final_custom_small"
extra = "nsf_zfix"
data_err_file = "/home/tharvey/work/JADES-DR3-GS_MASTER_Sel-F277W+F356W+F444W_v13.fits"

posterior_path = f"/home/tharvey/work/synference/models/{model_name}/{model_name}_{extra}_posterior.pkl"
grid_path = "/home/tharvey/work/synference/libraries/grid_BPASS_Chab_DenseBasis_SFH_0.01_z_14_logN_5.0_Calzetti_v3_multinode.hdf5"
fitter = SBI_Fitter.load_saved_model(posterior_path, grid_path=grid_path)

# Only F444W is modelled here. A fuller HST + JWST band list, kept for reference:
#   bands = ["F435W", "F606W", "F775W", "F814W", "F850LP", "F090W", "F115W",
#            "F150W", "F200W", "F277W", "F335M", "F356W", "F410M", "F444W"]
# with the HST bands (F435W..F850LP) named "HST/ACS_WFC.<BAND>" and the rest
# "JWST/NIRCam.<BAND>".
bands = ["F444W"]
new_band_names = ["JWST/NIRCam." + band.upper() for band in bands]

# Build per-band "asinh" uncertainty models from the OBJECTS HDU of the catalogue.
empirical_noise_models = create_uncertainty_models_from_EPOCHS_cat(
    data_err_file,
    bands,
    new_band_names,
    model_class="asinh",
    hdu="OBJECTS",
    min_flux_error=0,
    plot=True,
    save=False,
)

In [None]:
import os

from astropy import units as u

# Two-panel figure for the F444W asinh noise model:
#   top    - true AB magnitude vs noisy asinh magnitude m' (percentile bands),
#   bottom - m' vs its uncertainty sigma', over a Monte-Carlo density background,
#            with real catalogue photometry overplotted.
# Local names are kept distinct from the catalogue `mag` defined earlier, so running
# this cell no longer breaks a re-run of the observational-sample figure.
noise_model = empirical_noise_models["JWST/NIRCam.F444W"]
# Seeds only the random draws made directly in this cell; any randomness inside
# apply_noise is not controlled here.
rng = np.random.default_rng(42)

fig, ax = plt.subplots(figsize=(4, 6), ncols=1, nrows=2, constrained_layout=True)

# 100 true AB magnitudes x 1000 noise realisations each, as fluxes in Jy.
ab = np.linspace(20, 35, 100)
ab_2d = np.repeat(ab[:, np.newaxis], 1000, axis=1)
ab_jy = 3631 * 10 ** (-0.4 * ab_2d) * Jy

# apply_noise yields (asinh mag, asinh mag error), so stacking the rows gives
# shape (n_ab, 2, n_realisations).
noisy = np.array([noise_model.apply_noise(ab_jy[i, :]) for i in range(ab_jy.shape[0])])
noisy_mag = noisy[:, 0]
noisy_err = noisy[:, 1]

# Percentiles across realisations, one value per AB grid point.
mag_p5, mag_p16, mag_p50, mag_p84, mag_p95 = np.percentile(
    noisy_mag, [5, 16, 50, 84, 95], axis=1
)
err_p5, err_p16, err_p50, err_p84, err_p95 = np.percentile(
    noisy_err, [5, 16, 50, 84, 95], axis=1
)

# --- Top panel: AB magnitude -> asinh magnitude ---
ax[0].fill_between(ab, mag_p16, mag_p84, color="lightgray", alpha=0.5, label="16-84th percentile")
ax[0].fill_between(ab, mag_p5, mag_p95, color="gray", alpha=0.5, label="5-95th percentile")
ax[0].plot(ab, mag_p50, color="black", linestyle="-", label="Median")
ax[0].set_xlabel("AB Magnitude [$m$]")
ax[0].set_ylabel("Asinh Magnitude [$m'$]")

# Softening parameter b as an AB magnitude (the +8.9 zero point assumes b is in Jy).
# The text is placed 0.9 mag lower than the dashed line.
ax[0].text(
    20,
    -2.5 * np.log10(noise_model.b) + 8.0,
    "Softening parameter",
    rotation=0,
    verticalalignment="center",
    horizontalalignment="left",
)
ax[0].axhline(
    -2.5 * np.log10(noise_model.b) + 8.9,
    color="black",
    linestyle="--",
    label="Softening parameter",
)

# --- Bottom panel: asinh magnitude vs its error ---
# x is the median m' at each grid point (not a single noisy draw), so the
# percentile bands are drawn along a smooth, ordered curve.
ax[1].fill_between(
    mag_p50, err_p16, err_p84, color="lightgray", alpha=0.5, label="16-84th percentile"
)
ax[1].fill_between(
    mag_p50, err_p5, err_p95, color="gray", alpha=0.5, label="5-95th percentile"
)
ax[1].plot(mag_p50, err_p50, color="black", linestyle="-", label=r"Median $\sigma(m)$")

# Background density of (m', sigma') for fluxes drawn uniformly in AB magnitude.
mc_ab = rng.uniform(20, 35, 10000)
mc_mag, mc_err = noise_model.apply_noise(3631 * 10 ** (-0.4 * mc_ab) * Jy)
counts, mag_edges, err_edges = np.histogram2d(mc_mag, mc_err, bins=100)
# Contour at bin centres; using the left edges would shift the contours by half a bin.
mag_centres = 0.5 * (mag_edges[:-1] + mag_edges[1:])
err_centres = 0.5 * (err_edges[:-1] + err_edges[1:])
ax[1].contourf(
    mag_centres,
    err_centres,
    counts.T,
    levels=10,
    cmap="Greys",
    alpha=0.3,
    norm="log",
)

ax[1].set_xlabel("Asinh Magnitude [$m'$]")
ax[1].set_ylabel(r"Asinh Magnitude Error [$\sigma'$]")

# Real catalogue photometry, converted with the same scalings as the model.
band = "F444W"
table = Table.read(data_err_file, hdu="OBJECTS")
flux = (u.Jy * table[f"FLUX_APER_{band}_aper_corr_Jy"]).to("Jy").value
# Assumes loc_depth is a 5-sigma AB depth (hence /5 for a 1-sigma error) — TODO confirm.
flux_err = (table[f"loc_depth_{band}"] * u.ABmag).to("Jy").value / 5
converted_mag, converted_mag_err = noise_model.apply_scalings(flux, flux_err)

# Overplot a random subset of at most 1000 objects; capping the sample size avoids
# a ValueError from choice(..., replace=False) on catalogues with fewer rows.
n_show = min(1000, len(converted_mag))
rand = rng.choice(len(converted_mag), size=n_show, replace=False)
ax[1].scatter(
    converted_mag[rand],
    converted_mag_err[rand],
    alpha=0.5,
    color="steelblue",
    s=1,
    zorder=1,
    label="Catalog photometry",
)
ax[1].legend(frameon=False, fancybox=False, edgecolor="black", fontsize=11)
ax[1].set_xlim(None, 30)

os.makedirs("plots", exist_ok=True)
fig.savefig("plots/asinh_noise_model.png", dpi=300, bbox_inches="tight")

In [None]:
# Re-import and re-apply the paper style so the speed-plot cell also runs on its own.
import numpy as np
from matplotlib import patches as mpatches
from matplotlib import pyplot as plt

plt.style.use("paper.style")


def _format_seconds(seconds):
    """Return a bar label for a timing in seconds.

    Values of at least one second get one decimal place; smaller values are shown
    to two significant figures (e.g. 0.0018 -> "0.0018s").
    """
    if seconds >= 1:
        return f"{seconds:.1f}s"
    return f"{seconds:.2g}s"


def _draw_timing_bar(ax, x_pos, percentiles, width, color, label, hatch, alpha):
    """Draw one timing bar with an asymmetric error bar and a value label on top.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    x_pos : float
        Centre of the bar.
    percentiles : sequence of 3 floats
        ``[p16, median, p84]`` in seconds; the bar height is the median and the
        error bar spans p16 to p84.
    width, color, label, hatch, alpha
        Passed through to ``Axes.bar``.

    Returns
    -------
    The ``BarContainer`` returned by ``Axes.bar``.
    """
    p16, median, p84 = percentiles
    bar = ax.bar(
        x_pos,
        median,
        width,
        label=label,
        color=color,
        yerr=[[median - p16], [p84 - median]],
        capsize=5,
        hatch=hatch,
        edgecolor="black",
        alpha=alpha,
    )
    ax.bar_label(bar, labels=[_format_seconds(median)], padding=3, fontsize=12)
    return bar


def create_speed_plot():
    """Bar plot comparing CPU and GPU inference times for different techniques.

    Tabulated techniques are drawn from ``data``. Parrot and Prospector are then
    added by hand as bars that are clipped by the y-limits. The figure is saved
    to ``plots/speed_plot.png`` and shown.
    """
    import os
    import warnings

    # Per-object inference time in seconds, stored as [p16, median, p84].
    data = {
        "MDN": {"CPU": [0.00355, 0.00365, 0.00384], "GPU": [0.0017, 0.0018, 0.0019]},
        "NSF": {"GPU": [0.04624, 0.0470, 0.04740], "CPU": [0.06871, 0.06895, 0.07299]},
        # All three percentiles are equal, so this bar has no error bar.
        "EAZY-py\n(Photo-$z$ only)": {"CPU": [0.018, 0.018, 0.018]},
        "Bagpipes": {"CPU": [55.052, 101.4, 139.924]},
    }

    fig, ax = plt.subplots(figsize=(7, 5), dpi=150)
    # Thicker black frame for visibility.
    for side in ["top", "bottom", "left", "right"]:
        ax.spines[side].set_linewidth(1.2)
        ax.spines[side].set_color("black")

    techniques = list(data.keys())
    x = np.arange(len(techniques))
    # Two spare colours at the end are used by the manual Parrot/Prospector bars.
    colors = plt.cm.RdYlBu(np.linspace(0, 1, len(techniques) + 2))

    bar_width = 0.35
    cpu_hatch = ""
    gpu_hatch = "///"

    for i, (technique, color) in enumerate(zip(techniques, colors)):
        timings = data[technique]
        has_gpu = "GPU" in timings
        has_cpu = "CPU" in timings
        # Side by side when both devices are present; a lone bar is centred on its
        # tick. Decided per technique, so a CPU-only entry no longer shifts the
        # CPU bars of every technique listed after it.
        half = bar_width / 2 if (has_gpu and has_cpu) else 0.0
        if has_gpu:
            _draw_timing_bar(
                ax, x[i] - half, timings["GPU"], bar_width, color,
                f"{technique} GPU", gpu_hatch, 0.70,
            )
        if has_cpu:
            _draw_timing_bar(
                ax, x[i] + half, timings["CPU"], bar_width, color,
                f"{technique} CPU", cpu_hatch, 0.65,
            )

    # --- Axes and labels ---
    ax.set_yscale("log")
    ax.set_ylabel("Inference Time per Object [seconds]", fontsize=14)
    ax.set_xticks(x)
    ax.set_xticklabels(techniques, fontsize=12, rotation=0, ha="center")
    ax.grid(axis="x", linestyle="", which="both")  # no vertical grid lines
    ax.grid(axis="y", linestyle="--", which="major", alpha=0.7)

    # Legend proxies: the hatch pattern (not the colour) encodes the device.
    cpu_patch = mpatches.Patch(
        facecolor="grey", hatch=cpu_hatch, edgecolor="black", label="CPU: Single Core"
    )
    gpu_patch = mpatches.Patch(
        facecolor="grey", hatch=gpu_hatch, edgecolor="black", label="GPU: Nvidia H100 95GB"
    )

    # Shade the Synference methods (the first two techniques, x = 0 and 1).
    ax.axvspan(-0.5, 1.5, color="lightblue", alpha=0.2, ymax=0.81)
    ax.text(
        0.5, 45, "Synference", horizontalalignment="center", fontsize=18, color="blue", alpha=0.7
    )

    fig.tight_layout()

    # Freeze the current y-limits so the tall manual bars below are clipped
    # rather than rescaling the axis.
    ax.set_ylim(ax.get_ylim())

    # Manual bars placed immediately after the tabulated techniques.
    parrot_x = len(techniques)
    prospector_x = parrot_x + 1

    ax.bar(
        parrot_x,
        30 * 60,
        bar_width,
        label="Parrot\n(Prospector\nemulator)",
        color=colors[-2],
        edgecolor="black",
        alpha=0.65,
        hatch=cpu_hatch,
    )
    # Value label placed inside the visible range because the bar top is clipped.
    ax.text(parrot_x, 100, "$\\sim 30$\nmin", ha="center", va="top", fontsize=12, color="black")

    ax.bar(
        prospector_x,
        100000,
        bar_width,
        label="Prospector CPU",
        color=colors[-1],
        edgecolor="black",
        alpha=0.65,
        hatch=cpu_hatch,
    )
    ax.text(prospector_x, 100, "$>$10\nhours", ha="center", va="top", fontsize=12, color="black")
    # Upward arrow: the Prospector bar continues off the top of the plot.
    ax.annotate(
        "",
        xy=(prospector_x, 260),
        xytext=(prospector_x, 120),
        arrowprops=dict(facecolor="black", arrowstyle="-|>"),
    )

    ax.set_xticks(np.arange(prospector_x + 1))
    ax.set_xticklabels(
        techniques + ["Parrot\n(Prospector\nemulator)", "Prospector"],
        fontsize=12,
        rotation=0,
        ha="center",
    )

    # Decorative logo in a floating, axis-less inset; skipped with a warning if the
    # file is missing, so the plot still renders on other machines.
    logo_path = "/home/tharvey/work/synference/docs/source/gfx/synference_logo.png"
    if os.path.exists(logo_path):
        logo = plt.imread(logo_path)
        logo_ax = fig.add_axes([0.12, 0.44, 0.25, 0.25], anchor="NE", zorder=10)
        logo_ax.imshow(logo, aspect=logo.shape[1] / logo.shape[0])
        logo_ax.axis("off")
    else:
        warnings.warn(f"Logo not found at {logo_path}; plotting without it.")

    ax.set_ylim(None, 2200)
    # 1/11.57 s per object = 11.57 objects/s, i.e. about 1e6 objects per day.
    euclid_line = ax.axhline(
        1 / 11.57, color="black", linestyle="--", zorder=-1, label="Euclid Average Detection Rate"
    )

    ax.legend(handles=[cpu_patch, gpu_patch, euclid_line], fontsize=12, loc="upper left")

    os.makedirs("plots", exist_ok=True)
    fig.savefig("plots/speed_plot.png", dpi=300, bbox_inches="tight")
    plt.show()


create_speed_plot()