# Catalogue Fitting

Synference provides a dedicated method for fitting catalogues of sources: `fit_catalogue`.

This is designed to be a flexible and fast way to fit your full catalogue.

It can handle:
1. Transforming observations to match the expected model features.
2. Prior predictive checks/outlier detection to remove sources that are unlikely to be well modelled.
3. Optional imputation of missing data.
4. Rapid posterior inference using the trained model.
5. Optional SED recovery (see [this notebook](sed_recovery.ipynb) for more details).
6. Returning structured results.
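As an illustration of step 1, one common transformation converts fluxes in Jy into AB magnitudes and propagates the 1-sigma errors to first order. This sketches the general technique only. Whether synference applies this exact conversion depends on the features your model was trained on, and `jy_to_ab_mag` is a hypothetical helper, not part of the synference API.

```python
import numpy as np

AB_ZERO_POINT_JY = 3631.0  # AB magnitude zero point, in Jy


def jy_to_ab_mag(flux_jy, flux_err_jy):
    """Hypothetical helper: convert fluxes and 1-sigma errors in Jy to AB mags.

    Non-positive or non-finite fluxes have no defined magnitude and are
    returned as NaN.
    """
    flux = np.asarray(flux_jy, dtype=float)
    err = np.asarray(flux_err_jy, dtype=float)
    good = np.isfinite(flux) & (flux > 0)
    mag = np.full(flux.shape, np.nan)
    mag_err = np.full(flux.shape, np.nan)
    mag[good] = -2.5 * np.log10(flux[good] / AB_ZERO_POINT_JY)
    # First-order error propagation: sigma_m = (2.5 / ln 10) * sigma_f / f
    mag_err[good] = 2.5 / np.log(10) * err[good] / flux[good]
    return mag, mag_err


mags, mag_errs = jy_to_ab_mag([3631.0, 3.631e-6, np.nan], [363.1, 3.631e-7, 1e-7])
print(mags)  # roughly 0.0, 22.5 and nan
```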

In [None]:
import os

import numpy as np
from IPython.display import display
from synthesizer import get_grids_dir
from unyt import Jy

from synference import SBI_Fitter, load_unc_model_from_hdf5, test_data_dir

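# Directory where synthesizer stores its grids
print(get_grids_dir())

# Download the synthesizer test grids if they are not already present
available_grids = os.listdir(get_grids_dir())
if "test_grid.hdf5" not in available_grids:
    cmd = f"synthesizer-download --test-grids --destination {get_grids_dir()}"
    os.system(cmd)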

library_path = os.path.join(get_grids_dir(), "test_grid.hdf5")

In [None]:
?SBI_Fitter.fit_catalogue

We'll use a subset of the JADES spectroscopic catalogue used in the synference paper for this example.

In [None]:
from astropy.table import Table

cat = Table.read(f"{test_data_dir}/jades_spec_catalogue_subset.fits")

As in the SED recovery example, the first step is to load the trained model and the noise model used to train it. We need the noise model to transform our observations to match the model features.

In [None]:
library_path = (
    f"{test_data_dir}/grid_BPASS_Chab_DenseBasis_SFH_0.01_z_14_logN_2.7_Calzetti_v3_multinode.hdf5"  # noqa: E501
)

fitter = SBI_Fitter.load_saved_model(
    model_file=f"{test_data_dir}", library_path=library_path, device="cpu"
)

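# Load the empirical noise models used when training the model
nm_path = f"{test_data_dir}/BPASS_DenseBasis_v4_final_nsf_0_params_empirical_noise_models.h5"
noise_models = load_unc_model_from_hdf5(nm_path)

# Attach them to the fitter so observations can be transformed to match the model features
fitter.feature_array_flags["empirical_noise_models"] = noise_models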

Now we need to describe our catalogue to synference. We create a `conversion_dict` dictionary that maps column names in the table to the model's feature names (it is passed to `fit_catalogue` as `columns_to_feature_names`).
**Flux columns should map to the SVO format filter names, e.g. JWST/NIRCam.F070W, or just F070W if it is unambiguous.**
**Flux error columns follow the same convention with 'unc_' prefixed, e.g. unc_JWST/NIRCam.F070W or unc_F070W.**
**Any other features (e.g. redshift) need an entry linking the catalogue column to the feature name.**
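Conceptually, applying such a mapping means finding the catalogue column that feeds each feature the model expects, then stacking those columns in the model's feature order. The sketch below shows that idea with a plain dict of arrays standing in for the table. `build_feature_array` and the toy values are hypothetical, and synference's internal implementation may differ (for example in unit handling and noise-model transforms).

```python
import numpy as np


def build_feature_array(catalogue, columns_to_feature_names, feature_names):
    """Hypothetical helper: stack catalogue columns into an (n_sources, n_features) array.

    `catalogue` maps column name -> 1D array, `columns_to_feature_names` maps a
    catalogue column to a model feature name, and `feature_names` fixes the
    column order the model expects.
    """
    # Invert the mapping so each model feature can be traced back to its column.
    feature_to_column = {feat: col for col, feat in columns_to_feature_names.items()}
    missing = [f for f in feature_names if f not in feature_to_column]
    if missing:
        raise KeyError(f"No catalogue column mapped to: {missing}")
    columns = [
        np.asarray(catalogue[feature_to_column[f]], dtype=float) for f in feature_names
    ]
    return np.column_stack(columns)


# Toy catalogue: two sources, fluxes in Jy (values are made up).
toy_catalogue = {
    "F200W": np.array([1.2e-8, 3.4e-8]),
    "unc_F200W": np.array([1e-9, 2e-9]),
    "redshift": np.array([5.1, 6.3]),
}
toy_mapping = {
    "F200W": "JWST/NIRCam.F200W",
    "unc_F200W": "unc_JWST/NIRCam.F200W",
    "redshift": "redshift",
}
toy_features = ["JWST/NIRCam.F200W", "unc_JWST/NIRCam.F200W", "redshift"]

toy_X = build_feature_array(toy_catalogue, toy_mapping, toy_features)
print(toy_X.shape)  # (2, 3)
```

Indexing with `catalogue[name]` also works on an `astropy.table.Table`, so the same lookup pattern applies to a real catalogue.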


In this case our catalogue is almost in the correct format, so we just need to convert the band names to include the facility and instrument. We do this by taking the band names from the trained model (`fitter.feature_names`), stripping the facility/instrument prefix to get the bare filter name (e.g. `F200W`), and mapping each bare name back to its full name with a small helper function.

We then create our conversion dictionary to map from e.g. 'unc_F200W' to 'unc_JWST/NIRCam.F200W'.

In [None]:
def band_to_instrument(band):
    """Band name to instrument mapping."""
    if band in ["F435W", "F606W", "F775W", "F814W", "F850LP"]:
        return f"HST/ACS_WFC.{band}"
    return f"JWST/NIRCam.{band}"


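# Bare filter names (e.g. "F200W") from the model's flux features,
# skipping the uncertainty and redshift features
bands = [
    i.split(".")[-1]
    for i in fitter.feature_names
    if not (i.startswith("unc_") or i.startswith("redshift"))
]

print(bands)

# Map catalogue columns to model feature names: fluxes, flux uncertainties,
# and redshift (which has the same name in both)
conversion_dict = {band: band_to_instrument(band) for band in bands}
conversion_dict.update({f"unc_{band}": f"unc_{band_to_instrument(band)}" for band in bands})
conversion_dict["redshift"] = "redshift"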

conversion_dict
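Before fitting, it can be worth checking that the conversion dictionary covers every feature the model expects and does not point at names the model does not know, for example a filter assigned to the wrong instrument. A minimal sketch, using a short hard-coded list as a stand-in for `fitter.feature_names` (the real list comes from your trained model). `check_mapping` is a hypothetical helper.

```python
def band_to_instrument(band):
    """Same band-to-instrument mapping as in the cell above."""
    if band in ["F435W", "F606W", "F775W", "F814W", "F850LP"]:
        return f"HST/ACS_WFC.{band}"
    return f"JWST/NIRCam.{band}"


def check_mapping(columns_to_feature_names, model_feature_names):
    """Compare a column -> feature mapping against the model's feature list.

    Returns the model features with no column mapped to them, and the mapped
    names that the model does not recognise.
    """
    targets = set(columns_to_feature_names.values())
    unmapped = [f for f in model_feature_names if f not in targets]
    unknown = sorted(targets - set(model_feature_names))
    return unmapped, unknown


# Illustrative stand-in for fitter.feature_names.
toy_feature_names = [
    "HST/ACS_WFC.F606W",
    "JWST/NIRCam.F200W",
    "unc_HST/ACS_WFC.F606W",
    "unc_JWST/NIRCam.F200W",
    "redshift",
]
toy_bands = [
    f.split(".")[-1] for f in toy_feature_names if not f.startswith(("unc_", "redshift"))
]
toy_conversion = {b: band_to_instrument(b) for b in toy_bands}
toy_conversion.update({f"unc_{b}": f"unc_{band_to_instrument(b)}" for b in toy_bands})
toy_conversion["redshift"] = "redshift"

print(check_mapping(toy_conversion, toy_feature_names))  # ([], [])
```

With the real objects, `check_mapping(conversion_dict, fitter.feature_names)` returning two empty lists would suggest every feature is accounted for.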

Now we can run the inference. We specify the catalogue table, the conversion dictionary, the input flux units (Jy), and then we have some choices.

We can choose to:

1. Run predictive checks to remove outliers. 
2. Impute or remove missing data.
3. Run posterior inference, setting the number of posterior samples to draw.    
4. Recover SEDs for each source.
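Choice 2 depends on first finding which sources have missing values. Since the calls below use `missing_data_flag=np.nan`, remember that NaN never compares equal to itself (`np.nan == np.nan` is `False`), so a naive equality test would find nothing. A sketch of a per-source mask that handles both a NaN flag and a numeric sentinel such as -99. `missing_rows` is a hypothetical helper, not synference's own implementation.

```python
import numpy as np


def missing_rows(features, missing_data_flag=np.nan):
    """Hypothetical helper: boolean mask of rows with at least one missing value."""
    features = np.asarray(features, dtype=float)
    if isinstance(missing_data_flag, float) and np.isnan(missing_data_flag):
        # NaN != NaN, so equality cannot detect it; use np.isnan instead.
        is_missing = np.isnan(features)
    else:
        is_missing = features == missing_data_flag
    return is_missing.any(axis=1)


toy = np.array([[1.0, 2.0], [np.nan, 3.0], [-99.0, 4.0]])
print(missing_rows(toy))       # row 1 flagged
print(missing_rows(toy, -99))  # row 2 flagged
```

Dropping flagged sources is then `toy[~missing_rows(toy)]`. Imputation instead keeps those rows and fills in the gaps.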

In [None]:
# Rebuild the simulator from the model library, pointing it at the downloaded test grid
fitter.recreate_simulator_from_library(
    override_library_path=library_path, override_grid_path="test_grid.hdf5"
)

post_tab = fitter.fit_catalogue(
    cat,
    columns_to_feature_names=conversion_dict,
    flux_units=Jy,
    check_out_of_distribution=False,
    recover_SEDs=False,
    missing_data_flag=np.nan,
    num_samples=300,
    append_to_input=False,
)

We can look at the output table, which is also an `astropy.table.Table`.

In [None]:
post_tab
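Columns such as `log_mass_50` and `log_sfr_50` look like percentile summaries of each parameter's posterior, with `_50` the median. Assuming that naming convention (check `post_tab.colnames` for the suffixes actually present; the `_16`/`_84` columns here are an assumption), the sketch below shows how such summaries can be computed from 300 posterior draws. The draws are random stand-ins, not real posteriors.

```python
import numpy as np

rng = np.random.default_rng(42)
param_names = ["log_mass", "log_sfr"]
# Stand-in posterior: 300 draws (matching num_samples=300) for two parameters.
samples = rng.normal(loc=[9.0, 0.5], scale=[0.2, 0.3], size=(300, 2))

percentiles = [16, 50, 84]
summary = {}
for j, name in enumerate(param_names):
    # One value per requested percentile, over the draws of parameter j.
    values = np.percentile(samples[:, j], percentiles)
    for p, v in zip(percentiles, values):
        summary[f"{name}_{p}"] = float(v)

print(sorted(summary))
```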

And we can plot the star-forming main sequence from the fitted catalogue.

In [None]:
import matplotlib.pyplot as plt

plt.scatter(post_tab["log_mass_50"], post_tab["log_sfr_50"])
plt.xlabel("Stellar Mass (log M_sun)")
plt.ylabel("SFR (log M_sun/yr)")

You may notice that some rows contain NaNs. These are sources with missing data that we chose not to impute. The SBI inference method cannot handle missing inputs, so no posterior is computed for them.

We could set `missing_data_mcmc` to `True` to impute those missing values instead.

We can also recover and plot the SEDs, which we'll do for the first few sources in the catalogue.

In [None]:
post_tab, data = fitter.fit_catalogue(
    cat[:4],
    columns_to_feature_names=conversion_dict,
    flux_units=Jy,
    check_out_of_distribution=False,
    recover_SEDs=True,
    plot_SEDs=True,
    missing_data_flag=np.nan,
    num_samples=300,
    append_to_input=False,
)

If you recover the SEDs, a nested dictionary of SED posteriors is also returned. The top-level key is the source index or ID; each entry then contains the keys 'wav', 'fnu_quantiles', 'phot_wav', 'phot_fnu_draws' and 'fig'.

In [None]:
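# Keys are source indices or IDs, so take the first key rather than assuming 1 exists
first_key = next(iter(data))
data[first_key].keys()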

Here are the SED plots:

In [None]:
for key in data.keys():
    display(data[key]["fig"])
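To see at a glance what each entry of the returned SED dictionary holds, a small recursive helper can list every key with the type and shape of its value. A sketch with a hypothetical `describe` helper and a toy dictionary whose array shapes are made up. On the real output you could run `print("\n".join(describe(data)))`.

```python
import numpy as np


def describe(obj, indent=0):
    """Hypothetical helper: list a nested dict's keys with leaf types and shapes."""
    lines = []
    pad = "  " * indent
    for key, value in obj.items():
        if isinstance(value, dict):
            lines.append(f"{pad}{key}:")
            lines.extend(describe(value, indent + 1))
        elif hasattr(value, "shape"):
            # Arrays (including unit-carrying arrays) report their shape.
            lines.append(f"{pad}{key}: {type(value).__name__} shape={tuple(value.shape)}")
        else:
            lines.append(f"{pad}{key}: {type(value).__name__}")
    return lines


# Toy stand-in for the returned dictionary (shapes are illustrative only).
toy_data = {
    0: {
        "wav": np.linspace(0.1, 5.0, 100),
        "fnu_quantiles": np.zeros((3, 100)),
        "fig": None,
    }
}
print("\n".join(describe(toy_data)))
```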

We can check the feature array used for inference by setting `return_feature_array=True`.

In [None]:
feature_array, mask = fitter.fit_catalogue(
    cat,
    columns_to_feature_names=conversion_dict,
    flux_units=Jy,
    check_out_of_distribution=False,
    return_feature_array=True,
    missing_data_flag=np.nan,
    num_samples=300,
    append_to_input=False,
)

feature_array

We can also run prior predictive checks to identify outliers by setting `check_out_of_distribution=True`. Any of the methods available in PyOD (https://pyod.readthedocs.io/en/latest/) can be used; pass them to the `outlier_methods` argument as a list of strings.
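For intuition about what an out-of-distribution check does, the sketch below flags catalogue sources that sit far from the training features. It uses a Mahalanobis-distance threshold set at a quantile of the training-set distances. This is a simple stand-in for illustration, not one of the PyOD detectors that synference actually uses, and the 0.99 quantile and toy data are assumptions.

```python
import numpy as np


def mahalanobis_outliers(train, test, quantile=0.99):
    """Flag test rows whose Mahalanobis distance from the training set
    exceeds the given quantile of the training-set distances."""
    train = np.asarray(train, dtype=float)
    test = np.asarray(test, dtype=float)
    mean = train.mean(axis=0)
    inv_cov = np.linalg.inv(np.cov(train, rowvar=False))

    def distance(x):
        delta = x - mean
        # Row-wise sqrt(delta^T Sigma^-1 delta)
        return np.sqrt(np.einsum("ij,jk,ik->i", delta, inv_cov, delta))

    threshold = np.quantile(distance(train), quantile)
    return distance(test) > threshold


rng = np.random.default_rng(0)
toy_train = rng.normal(size=(5000, 3))  # stand-in for training features
toy_test = np.array([[0.1, -0.2, 0.3], [8.0, 8.0, 8.0]])
print(mahalanobis_outliers(toy_train, toy_test))  # first kept, second flagged
```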