## Inference with an SBI Model

You have a trained SBI model; now what? This tutorial shows how to use it for inference on new data. We will cover the following topics:

1. Loading the trained model
2. Preparing new data for inference
3. Running inference with the model
4. Analyzing the results

Let's get started!

# Loading the Trained Model

First, we need to load the trained model. We can do this using the `SBI_Fitter` class from the `sbifitter` package.

In [None]:
import numpy as np

from sbifitter import SBI_Fitter

In [None]:
fitter = SBI_Fitter.load_saved_model("test/test_test_1_posterior.pkl", device="cpu")

# Observational Data for Inference

Now that we have a trained model, we need some observational data to perform inference on. The data must have the same features, in the same order, as the data used to train the model. This model was trained on wide-band NIRCam photometry, so we will use some example photometry for inference.

Below is an example of real observational data for five galaxies in the JADES GOODS-South field, each with photometry in eight different bands. We have also provided the spectroscopic redshifts for these galaxies, which can be used to fix the redshift during inference.

In [None]:
wavs = [0.7, 0.9, 1.15, 1.5, 2.0, 2.77, 3.56, 4.44]

observed_data = [
    [25.69, 25.42, 25.46, 25.30, 24.98, 25.50, 25.42, 25.49],
    [21.25, 21.13, 21.07, 20.95, 20.91, 20.95, 21.42, 21.67],
    [22.36, 22.20, 21.68, 21.21, 20.94, 20.75, 20.45, 20.31],
    [24.95, 21.98, 21.23, 20.82, 20.50, 20.33, 20.28, 20.43],
    [24.26, 23.89, 23.59, 23.22, 23.04, 22.82, 22.83, 22.67],
]

specz = [2.37, 0.43, 1.82, 1.12, 2.80]

observed_data = np.array(observed_data)

# plot the observed data
import matplotlib.pyplot as plt

plt.figure(figsize=(8, 6))
for i in range(len(observed_data)):
    plt.plot(
        wavs, observed_data[i], marker="o", label=f"Object {i + 1}, z={specz[i]}", linestyle="None"
    )
plt.gca().invert_yaxis()
plt.xlabel("Wavelength (microns)", fontsize=14)
plt.ylabel("Magnitude", fontsize=14)
plt.title("Observed Data", fontsize=16)
plt.legend(fontsize=12)
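
Because the model expects the same features it was trained on, it can help to validate the photometry before sampling. The sketch below is one way to do that, not part of `sbifitter`: `check_photometry` and `n_bands` are hypothetical names introduced here, and the band count of eight is taken from the `wavs` list above.

```python
import numpy as np


def check_photometry(data, n_features):
    """Return `data` as a 2D float array, or raise if it is unusable for inference."""
    arr = np.atleast_2d(np.asarray(data, dtype=float))
    if arr.shape[1] != n_features:
        raise ValueError(
            f"Expected {n_features} bands per object, got {arr.shape[1]}"
        )
    if not np.all(np.isfinite(arr)):
        raise ValueError("Photometry contains NaN or infinite values")
    return arr


# Hypothetical stand-in for the number of features the model was trained on
n_bands = 8

# The first two objects from the example photometry
example = [
    [25.69, 25.42, 25.46, 25.30, 24.98, 25.50, 25.42, 25.49],
    [21.25, 21.13, 21.07, 20.95, 20.91, 20.95, 21.42, 21.67],
]
checked = check_photometry(example, n_bands)
print(checked.shape)
```

With a loaded fitter, the expected count could plausibly be taken from `len(fitter.raw_observation_names)`, assuming that attribute lists exactly one name per input feature.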

In [None]:
import corner

# Draw posterior samples for each object, then plot them as a corner plot
for i in range(observed_data.shape[0]):
    samples = fitter.sample_posterior(X_test=observed_data[i])

    corner.corner(
        samples,
        labels=fitter.fitted_parameter_names,
        show_titles=True,
        title_fmt=".3f",
        title_kwargs={"fontsize": 12},
        label_kwargs={"fontsize": 14},
        quantiles=[0.16, 0.5, 0.84],
        truth_color="r",
    )
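
The `quantiles=[0.16, 0.5, 0.84]` argument marks the median and the central 68% interval on each panel. The same summary can be computed numerically, which is handy for tables. This is a minimal sketch that assumes the samples can be converted to a NumPy array of shape `(n_samples, n_parameters)`; it runs on synthetic draws, and `summarize_samples`, `param_a`, and `param_b` are hypothetical names.

```python
import numpy as np


def summarize_samples(samples, names):
    """Median and lower/upper 1-sigma offsets for each parameter column."""
    samples = np.asarray(samples, dtype=float)
    lo, med, hi = np.percentile(samples, [16, 50, 84], axis=0)
    return {
        name: (m, m - l, h - m)
        for name, l, m, h in zip(names, lo, med, hi)
    }


# Synthetic stand-in for posterior samples: 5000 draws of two parameters
rng = np.random.default_rng(0)
fake_samples = rng.normal(loc=[9.5, 1.2], scale=[0.3, 0.1], size=(5000, 2))

summary = summarize_samples(fake_samples, ["param_a", "param_b"])
for name, (med, minus, plus) in summary.items():
    print(f"{name}: {med:.2f} -{minus:.2f} +{plus:.2f}")
```

In the loop above, the real inputs would be `samples` and `fitter.fitted_parameter_names` in place of the synthetic array and placeholder names.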

### SED Recovery

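From our posterior samples we can recover the spectral energy distribution (SED) of each object by passing the samples back through the simulator.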

In [None]:
from astropy.cosmology import Planck18 as cosmo
from unyt import Myr

# Rebuild the simulator from the model grid
fitter.recreate_simulator_from_grid()

# Set max_age to the age of the Universe at each sampled redshift
fitter.simulator.param_transforms["max_age"] = (
    "max_age",
    lambda x: cosmo.age(x["redshift"]).to_value("Myr") * Myr,
)

In [None]:
for i in range(observed_data.shape[0]):
    samples = fitter.sample_posterior(X_test=observed_data[i])
    fitter.recover_SED(X_test=observed_data[i], samples=samples)

In [None]:
from astropy.table import Table

tab = Table.read(
    "/home/tharvey/work/sbifitter/priv/data/cats/JADES_DR3_GS_Matched_Specz_total_flux_good_filt.fits"
)
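# Filter names are the last dot-separated part of each observation name
filters = [name.split(".")[-1] for name in fitter.raw_observation_names]

col_names = [f"FLUX_TOTAL_{f}" for f in filters]

# Keep only objects with a finite F070W flux
finite = np.isfinite(tab["FLUX_TOTAL_F070W"])
fluxes = tab[col_names][finite]

# Pick five objects at random
random_indices = np.random.choice(len(fluxes), size=5, replace=False)
subset = fluxes[random_indices]

array = np.array(subset.as_array().data.tolist())

# Convert fluxes to magnitudes; 31.4 is the AB zero point if the catalogue fluxes are in nJy
phot = -2.5 * np.log10(array) + 31.4

np.set_printoptions(precision=2, floatmode="fixed", suppress=True)
print(phot)

# Spectroscopic redshifts for the same five objects
specz = tab["specz"][finite][random_indices].data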
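
The magnitude conversion above uses a zero point of 31.4, which is the AB zero point for flux densities in nJy (2.5 log10 of 3631 Jy expressed in nJy). A plain `log10` yields NaN or infinity for negative or zero fluxes, so a small helper that handles these explicitly may be useful. This sketch assumes the catalogue fluxes are in nJy; `njy_to_ab_mag` is a hypothetical name, not part of `sbifitter`.

```python
import numpy as np

# AB zero point for nJy fluxes, rounded to match the conversion above
AB_ZEROPOINT_NJY = 31.4


def njy_to_ab_mag(flux_njy):
    """Convert flux densities in nJy to AB magnitudes; non-positive fluxes become NaN."""
    flux = np.asarray(flux_njy, dtype=float)
    # Silence warnings from log10 of zero or negative values; they are masked below
    with np.errstate(divide="ignore", invalid="ignore"):
        mag = -2.5 * np.log10(flux) + AB_ZEROPOINT_NJY
    return np.where(flux > 0, mag, np.nan)


mags = njy_to_ab_mag([1.0, 100.0, -5.0])
print(mags)
```

Objects with NaN magnitudes would need to be dropped or handled separately before being passed to the model, since the trained network expects a finite value in every band.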