In [None]:
%load_ext autoreload
%autoreload 2

import os

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from scoresbibm.utils.data_utils import (
    load_model,
)

# Fixed seed so that every sampling call in this notebook is reproducible.
rng_seed_for_training = 42
master_rng_key = jax.random.PRNGKey(rng_seed_for_training)

# `__file__` only exists when this code runs as a script/module. Inside a
# Jupyter kernel it is undefined, so fall back to the current working
# directory (normally the directory the notebook was opened from).
try:
    file_path = os.path.dirname(os.path.realpath(__file__))
except NameError:
    file_path = os.getcwd()

# The "grids" and "models" folders are resolved two directory levels above `file_path`.
_project_root = os.path.dirname(os.path.dirname(file_path))
grid_folder = os.path.join(_project_root, "grids")
output_folder = os.path.join(_project_root, "models")

In [None]:
def load_full_model(dir_path, model_id, simulator=None):
    """Load a saved model and re-attach its metadata, task, prior and edge mask.

    The bare model is read with ``load_model``. The companion metadata file
    ``{dir_path}/models/data_{model_id}.joblib`` is then loaded and merged into
    the model's attributes, the task named in
    ``model.edge_mask_fn_params["task"]`` is rebuilt, a ``GalaxyPrior`` is
    attached to that task, and the edge-mask function is recreated.

    Args:
        dir_path: Directory handed to ``load_model``; the metadata file is
            looked up in its ``models`` sub-directory.
        model_id: Identifier of the saved model (also used in the metadata
            file name).
        simulator: Optional simulator to attach as ``model.simulator``. When
            omitted, a notice is printed and no simulator is attached.

    Returns:
        The loaded model. If a ``FileNotFoundError`` occurs while restoring
        the metadata, a warning is emitted and the model is returned with
        whatever had been restored up to that point (best effort).
    """
    import warnings

    from joblib import load
    from scoresbibm.tasks import get_task
    from scoresbibm.utils.edge_masks import get_edge_mask_fn
    from simformer import GalaxyPrior

    model = load_model(
        dir_path=dir_path,
        model_id=model_id,
    )

    try:
        # NOTE(review): joblib files are pickle-based; only load metadata from trusted sources.
        meta = load(f"{dir_path}/models/data_{model_id}.joblib")

        # `meta` is merged into both the model and the task, so it is expected
        # to be a mapping of attribute names to values.
        model.__dict__.update(meta)
        task_name = model.edge_mask_fn_params.get("task")
        task = get_task(task_name)
        task.__dict__.update(meta)

        task.prior_dist = GalaxyPrior(
            prior_ranges=task.prior_dict, param_order=task.param_names_ordered
        )

        model.edge_mask_fn = get_edge_mask_fn(model.edge_mask_fn_params["name"], task)

        if simulator is not None:
            model.simulator = simulator
        else:
            print(
                "No simulator provided. Please provide a simulator to use with the model."
            )
    except FileNotFoundError as exc:
        # Still best effort, but no longer silent: the caller needs to know
        # that the returned model may lack metadata, task and edge mask.
        warnings.warn(
            f"Could not fully restore model '{model_id}' from '{dir_path}': {exc}. "
            "Returning the model without the complete metadata.",
            stacklevel=2,
        )

    return model

In [None]:
# Restore the saved model from the "simformer" sub-folder of the output
# directory. No simulator is passed here.
model_location = {
    "dir_path": f"{output_folder}/simformer/",
    "model_id": "simformer_galaxy_score_model_test_v2",
}
reloaded_model = load_full_model(**model_location)

In [None]:
# One observation: 14 photometric values (x) and the 7 reference parameter
# values (theta) that are later drawn as `truths` in the corner plots.
x = [
    236.549, 83.717, 37.53, 36.727, 36.302, 36.072, 35.91,
    35.483, 35.272, 35.114, 34.698, 34.554, 34.172, 33.674,
]
theta = [11.181, 8.2, -2.267, 1.957, 452.194, 823.456, 0.531]
theta_dim = len(theta)
x_dim = len(x)

# Add a leading batch axis of size 1 -> shape (1, len(x)).
x_o = jnp.asarray(x, dtype=jnp.float32)[None, :]

# False for every parameter slot (to be sampled), True for every
# observation slot (conditioned on).
posterior_condition_mask = jnp.concatenate(
    [
        jnp.zeros(theta_dim, dtype=jnp.bool_),
        jnp.ones(reloaded_model._x_dim, dtype=jnp.bool_),
    ]
)

sampling_kwargs = dict(
    num_samples=1000,
    x_o=x_o,
    condition_mask=posterior_condition_mask,
    rng=master_rng_key,
)
samples = reloaded_model.sample_batched(**sampling_kwargs)

In [None]:
import corner
import numpy as np

# Convert the sampler output to a 2-D (n_draws, n_params) NumPy array.
# NOTE(review): the wide-band cell indexes `[0]` on the output of the same
# `sample_batched` call, so the result appears to carry a leading batch axis;
# corner expects 1-D or 2-D input. Flattening the leading axes is a no-op
# when the array is already 2-D, which also keeps this cell re-runnable.
samples = np.asarray(samples, dtype=np.float32)
samples = samples.reshape(-1, samples.shape[-1])

# Corner plot of the posterior conditioned on all photometric bands, with the
# reference parameter values overplotted as truths.
corner.corner(
    samples,
    labels=reloaded_model.param_names_ordered,
    show_titles=True,
    truths=theta,
    quantiles=[0.16, 0.5, 0.84],
    title_kwargs={"fontsize": 12},
);

In [None]:
# Filters in the same order as the photometry values in `x` below.
filter_codes = [
    "JWST/NIRCam.F090W",
    "JWST/NIRCam.F115W",
    "JWST/NIRCam.F150W",
    "JWST/NIRCam.F162M",
    "JWST/NIRCam.F182M",
    "JWST/NIRCam.F200W",
    "JWST/NIRCam.F210M",
    "JWST/NIRCam.F250M",
    "JWST/NIRCam.F277W",
    "JWST/NIRCam.F300M",
    "JWST/NIRCam.F335M",
    "JWST/NIRCam.F356W",
    "JWST/NIRCam.F410M",
    "JWST/NIRCam.F444W",
]

x = np.array(
    [
        236.549,
        83.717,
        37.53,
        36.727,
        36.302,
        36.072,
        35.91,
        35.483,
        35.272,
        35.114,
        34.698,
        34.554,
        34.172,
        33.674,
    ]
)
theta = [11.181, 8.2, -2.267, 1.957, 452.194, 823.456, 0.531]

# `mask` is 1 for filters treated as missing (every code that does not end in
# "W", i.e. the medium bands) and 0 for the wide bands that stay observed.
mask = [0 if code.endswith("W") else 1 for code in filter_codes]
mask_bool = np.array(mask, dtype=np.bool_)
observed_bool = ~mask_bool

# Keep only the observed (wide-band) photometry and add the batch axis.
x = x[observed_bool]
x = jnp.array([x], dtype=jnp.float32)

# Sample the parameters AND the missing photometry: condition only on the
# observed wide-band slots. In the full-photometry cell a 1 in the condition
# mask marks a slot whose value is supplied through `x_o`, so the entries
# that are 1 here must be the wide bands held in `x`, not the missing bands.
posterior_condition_mask = jnp.array(
    [0] * theta_dim + [int(is_observed) for is_observed in observed_bool],
    dtype=jnp.bool_,
)
assert int(posterior_condition_mask.sum()) == x.shape[-1], (
    "number of conditioned slots must match the number of observed values"
)

samples_nomedium = reloaded_model.sample_batched(
    num_samples=1000,
    x_o=x,
    condition_mask=posterior_condition_mask,
    rng=master_rng_key,
)

# Drop the leading batch axis (a single observation was passed).
samples_nomedium = np.array(samples_nomedium[0], dtype=np.float32)

# The first `theta_dim` columns are the parameters; the remaining columns are
# the sampled photometry.
# NOTE(review): assumes the sampler returns the unconditioned photometry
# slots in filter order after the parameters -- confirm against sample_batched.
theta_samples = samples_nomedium[:, :theta_dim]
phot_samples = samples_nomedium[:, theta_dim:]

In [None]:
# Inspect the photometry array that was handed to the sampler as `x_o` in the previous cell.
x

Corner plot comparing both posteriors: conditioned on all NIRCam bands vs. on the wide bands only

In [None]:
samples_nomedium = np.array(samples_nomedium, dtype=np.float32)


def _corner_style():
    """Return a fresh copy of the keyword arguments shared by both overlays."""
    return dict(
        labels=reloaded_model.param_names_ordered,
        show_titles=True,
        truths=theta,
        quantiles=[0.16, 0.5, 0.84],
        title_kwargs={"fontsize": 12},
    )


# Posterior obtained with every band observed (blue).
fig = corner.corner(samples, color="blue", label="Full samples", **_corner_style())

# Overlay the parameter draws from the run without the medium bands (orange)
# on the same figure.
corner.corner(theta_samples, fig=fig, color="orange", **_corner_style())

# Proxy line handles so the legend shows one entry per posterior.
legend_entries = [
    plt.Line2D([0], [0], color=line_color, lw=4, label=line_label)
    for line_color, line_label in (
        ("blue", "All NIRCam bands"),
        ("orange", "Only NIRCam widebands"),
    )
]
plt.legend(
    handles=legend_entries,
    bbox_to_anchor=(-1, 4),
    loc="upper left",
    fontsize=12,
)
plt.show()

In [None]:
# Full photometry vector again (the earlier cell overwrote `x` with the
# observed subset), in the same order as `filter_codes`.
x = np.array(
    [
        236.549,
        83.717,
        37.53,
        36.727,
        36.302,
        36.072,
        35.91,
        35.483,
        35.272,
        35.114,
        34.698,
        34.554,
        34.172,
        33.674,
    ]
)

# Reference values and names of the filters flagged as missing by `mask_bool`.
missing_phot = np.array(x)[mask_bool]

print(missing_phot)

missing_phot_names = np.array(filter_codes)[mask_bool]

# squeeze=False keeps `axes` an array even when only one filter is missing,
# so the loop below works for any number of panels.
fig, axes = plt.subplots(1, len(missing_phot), figsize=(20, 4), squeeze=False)
for i, ax in enumerate(axes.ravel()):
    # NOTE(review): assumes column i of `phot_samples` corresponds to the i-th
    # missing filter -- confirm against the sampler's output ordering.
    samples_i = phot_samples[:, i]
    ax.hist(
        samples_i,
        bins=50,
        density=True,
        alpha=0.5,
        label=f"Posterior samples for {missing_phot_names[i]}",
    )
    ax.set_title(f"{missing_phot_names[i]}")
    # Dashed red line: the reference value for this filter.
    ax.axvline(x=missing_phot[i], color="red", linestyle="--", label="True value")
    ax.set_xlabel("Flux (ABmag)")

In [None]:
# Check the dimensions of the draws obtained from the run without the medium bands.
samples_nomedium.shape