In [None]:
%load_ext autoreload
%autoreload 2

import os

import matplotlib.pyplot as plt
from astropy.table import Table

from sbifitter import SBI_Fitter

# Resolve the directory this notebook/script lives in. ``__file__`` is only
# defined when the code runs as a script; a notebook kernel does not set it,
# so fall back to the current working directory in that case.
try:
    file_path = os.path.dirname(os.path.realpath(__file__))
except NameError:
    file_path = os.path.realpath(os.getcwd())

# Both folders sit two directory levels above ``file_path``.
_project_root = os.path.dirname(os.path.dirname(file_path))
grid_folder = os.path.join(_project_root, "grids")
output_folder = os.path.join(_project_root, "models")

In [None]:
# Location of the HDF5 grid file inside ``grid_folder``.
grid_path = f"""{grid_folder}/grid_BPASS_DelayedExponential_SFH_0.01_z_12_logN_5.7_Chab_CF00_v1.hdf5"""  # noqa

# Build the fitter from the grid file.
# NOTE(review): the first positional argument is presumably a name/label for
# this fitter — confirm against ``SBI_Fitter.init_from_hdf5``.
fitter = SBI_Fitter.init_from_hdf5(
    "BPASS_Chab_LogNorm_5_z_12_phot_grid2", grid_path, return_output=False
)

# NOTE(review): the first assignment is immediately overwritten by the second,
# so only the "nsf_zfix" model name is ever used below.
model_name = "BPASS_Chab_DelayedExpSFH_0.01_z_12_CF00_v1sbipp_settings_posterior"
model_name = "BPASS_Chab_DelayedExpSFH_0.01_z_12_CF00_v1sbipp_settings_nsf_zfix_posterior"

# Load the trained model from its ``.pkl`` file; the trailing ``;`` suppresses
# the cell output.
# NOTE(review): ``output_folder`` already ends in "models", so this path
# contains ".../models/models/..." — confirm that is the intended layout.
# NOTE(review): presumably this unpickles the file — only load trusted files.
fitter.load_model_from_pkl(
    f"{output_folder}/models/BPASS_Chab_DelayedExpSFH_0.01_z_12_CF00_v1/{model_name}.pkl"
);

In [None]:
# Shortcut: set the fitter's feature-array flags explicitly by assigning the
# whole dictionary in one go.
# NOTE(review): presumably these values need to match the settings the loaded
# model was trained with — confirm before changing any of them.

fitter.feature_array_flags = dict(
    normalize_method=None,
    extra_features=[],
    normed_flux_units="AB",
    normalization_unit="AB",
    scatter_fluxes=True,
    empirical_noise_models=True,
    depths=None,
    include_errors_in_feature_array=True,
    min_flux_pc_error=0.9,
    simulate_missing_fluxes=False,
    missing_flux_value=99.0,
    missing_flux_fraction=0.9,
    missing_flux_options=None,
    include_flags_in_feature_array=False,
    override_phot_grid=None,
    override_phot_grid_units=None,
    norm_mag_limit=40,
    remove_nan_inf=None,
    parameters_to_remove=[],
    photometry_to_remove=None,
    drop_dropouts=False,
    drop_dropout_fraction=1.0,
    # The first 9 feature names are used as photometry and the remainder as
    # their errors — assumes exactly 9 photometric bands; TODO confirm.
    raw_photometry_names=fitter.feature_names[:9],
    error_names=fitter.feature_names[9:],
    flag_names=[],
    norm_name=None,
)

# Test on real galaxies

First, load the real catalogue and set up a dictionary that maps the catalogue's column names to the model's feature names.

In [None]:
# Path to the observed catalogue (FITS). The long file name is split across
# two adjacent string literals, which Python joins with nothing in between; a
# triple-quoted string would embed the line break and indentation in the path,
# so the file would not be found.
# NOTE(review): this is an absolute, machine-specific path — adjust it locally.
file = (
    "/home/tharvey/Downloads/JADES-Deep-GS_MASTER_Sel-f277W+f356W+f444W"
    "_v9_loc_depth_masked_10pc_EAZY_matched_selection_ext_src_UV.fits"
)
table = Table.read(file)

# Empty table that later cells fill with the columns passed to the fitter.
new_table = Table()


def mag_cols_syntax(band: str) -> str:
    """Build the name of the table column holding magnitudes for a band.

    Parameters
    ----------
    band : str
        Band label as it appears inside the catalogue column names.

    Returns
    -------
    str
        The column name ``MAG_APER_<band>_aper_corr``.
    """
    column_name = f"MAG_APER_{band}_aper_corr"
    return column_name


def magerr_cols_syntax(band: str) -> str:
    """Build the name of the table column holding magnitude errors for a band.

    Parameters
    ----------
    band : str
        Band label as it appears inside the catalogue column names.

    Returns
    -------
    str
        The column name ``MAGERR_APER_<band>_u1_loc_depth``.
    """
    column_name = f"MAGERR_APER_{band}_u1_loc_depth"
    return column_name


# Row selection: keep rows flagged by the selection column whose first
# ``sigma_*`` entry exceeds 10 in each of F444W, F277W and F356W.
passes_selection = table["selected_gal_all_criteria_delta_chi2_4_fsps_larson_no_bd"]
snr_f444w_ok = table["sigma_f444W"][:, 0] > 10
snr_f277w_ok = table["sigma_f277W"][:, 0] > 10
snr_f356w_ok = table["sigma_f356W"][:, 0] > 10
table = table[passes_selection & snr_f444w_ok & snr_f277w_ok & snr_f356w_ok]

# Band labels are the last "_"-separated token of every "loc_depth*" column.
bands = [
    colname.split("_")[-1]
    for colname in table.colnames
    if colname.startswith("loc_depth")
]

# The first band is mapped to the HST/ACS F606W filter and every other band to
# a NIRCam filter.
# NOTE(review): assumes ``bands[0]`` is F606W — confirm the column order.
new_band_names = ["HST/ACS_WFC.F606W"]
new_band_names.extend(f"JWST/NIRCam.{band.upper()}" for band in bands[1:])

# Copy the first element of each magnitude / magnitude-error column.
for band in bands:
    mag_column = mag_cols_syntax(band)
    magerr_column = magerr_cols_syntax(band)
    new_table[mag_column] = table[mag_column][:, 0]
    new_table[magerr_column] = table[magerr_column][:, 0]

new_table["z_eazy"] = table["zbest_fsps_larson"]

# Catalogue column name -> feature name: all magnitude columns first, then all
# error columns (the latter prefixed with "unc_").
band_pairs = list(zip(bands, new_band_names))
conversion_dict = {
    mag_cols_syntax(old_band): new_band for old_band, new_band in band_pairs
}
conversion_dict.update(
    {
        magerr_cols_syntax(old_band): f"unc_{new_band}"
        for old_band, new_band in band_pairs
    }
)

`obs_features` is the catalogue converted to match the model's features. `obs_mask` is a 1D boolean array indicating which rows were removed due to missing data.

In [None]:
# Convert the catalogue into features using the column -> feature-name mapping;
# the input photometry is declared as AB via ``flux_units``.
obs_features, obs_mask = fitter.create_features_from_observations(
    new_table, columns_to_feature_names=conversion_dict, flux_units="AB"
)

We can check whether any observations are out of distribution relative to our training data.

In [None]:
from sbifitter import test_out_of_distribution

# Compare the observed features against the fitter's feature array using a
# sigma threshold of 5.
# NOTE(review): the meaning of the two return values is not visible here —
# give ``a`` and ``b`` descriptive names once confirmed (``a`` is reassigned
# in a later cell).
a, b = test_out_of_distribution(fitter.feature_array, obs_features, sigma_threshold=5.0)

Alternatively, we can let the code do this internally, and it will add the columns to the table for us.

In [None]:
# Fit the whole catalogue in one call, reusing the same column -> feature-name
# mapping and AB units as above, with "direct" sampling and a 10-second
# timeout per row.
output = fitter.fit_catalogue(
    new_table,
    columns_to_feature_names=conversion_dict,
    flux_units="AB",
    sample_method="direct",
    timeout_seconds_per_row=10,
)

In [None]:
# Drop rows whose median redshift is exactly 0.0.
# NOTE(review): presumably these are rows that were not fitted — confirm.
output_masked = output[output["redshift_50"] != 0.0]

# Compare the catalogue redshift with the SBI median redshift, using the
# explicit Axes interface.
fig, ax = plt.subplots()
ax.scatter(output_masked["z_eazy"], output_masked["redshift_50"])

# One-to-one reference line spanning the current x-axis range.
ax.plot(ax.get_xlim(), ax.get_xlim(), ls="--", color="k")

ax.set_xlabel("z_eazy")
ax.set_ylabel("z_sbi");  # trailing ';' hides the Text repr in the cell output

In [None]:
import torch

# Move the observed features onto the fitter's device as float32 and draw
# samples for all rows in a single batched call.
# NOTE(review): presumably ``sample_shape=(1000,)`` means 1000 samples per
# observation — confirm; the result is only displayed, not stored.
x = torch.tensor(obs_features, dtype=torch.float32, device=fitter.device)
fitter.posteriors.sample_batched(x=x, sample_shape=(1000,))

In [None]:
# Print every fitted parameter name next to the ``low`` and ``high`` values of
# the prior distribution at the same index.
prior_low = fitter._prior.dist.low
prior_high = fitter._prior.dist.high

for idx, param_name in enumerate(fitter.simple_fitted_parameter_names):
    print(param_name, prior_low[idx].item(), prior_high[idx].item())

In [None]:
import torch

# Scratch cell: the bare tuples and expressions below are evaluated and then
# discarded; only ``a`` is kept.
# NOTE(review): presumably these are max / min / mean values copied from
# elsewhere for reference — confirm, and consider removing them.

# max (discarded)
(
    7.1637e00,
    8.4432e00,
    3.2980e-02,
    6.2265e-02,
    1.6504e03,
)
5.2127e03, -1.4724e00

# min (discarded)
(
    9.8384e-01,
    7.4450e00,
    -5.1909e-01,
    -3.3764e-01,
    1.9143e02,
)
5.1072e02, -3.3405e00

# mean — the only values kept; converted to a float32 tensor on the fitter's
# device below
a = [
    6.1481e00,
    8.1235e00,
    -7.4607e-02,
    -5.8678e-02,
    9.8795e02,
    7.5202e02,
    -2.5445e00,
]
a = torch.tensor(a, dtype=torch.float32, device=fitter.device)

# Shows the help text only; ``check`` is not actually called on ``a`` here.
help(fitter._prior.support.check)  # (a)

In [None]:
# Display the grid path stored on the fitter.
fitter.grid_path