In [None]:
%load_ext autoreload
%autoreload 2

import os

import numpy as np

from synference import Simformer_Fitter, load_unc_model_from_hdf5

# Load and create models

In [None]:
# Root of the synference checkout holding the model grids and noise models.
# Set the SYNFERENCE_ROOT environment variable to run this on another machine.
SYNFERENCE_ROOT = os.environ.get("SYNFERENCE_ROOT", "/home/tharvey/work/synference")
LIBRARY_DIR = f"{SYNFERENCE_ROOT}/libraries"

# Each grid lives in LIBRARY_DIR as "grid_<name>.hdf5"; deriving the path from
# the name keeps the two from getting out of sync.
grid_pop3_name = "Yggdrasil_Chab_Burst_SFH_0.001_z_14_logN_4.4_v1"
grid_pop3_path = f"{LIBRARY_DIR}/grid_{grid_pop3_name}.hdf5"
fitter_pop3 = Simformer_Fitter.init_from_hdf5(grid_pop3_name, grid_pop3_path)

bpass_name = "BPASS_Chab_DenseBasis_SFH_0.01_z_14_logN_5.0_Calzetti_v3_multinode"
bpass_path = f"{LIBRARY_DIR}/grid_{bpass_name}.hdf5"
fitter_bpass = Simformer_Fitter.init_from_hdf5(bpass_name, bpass_path)

# Empirical noise models, keyed by bare filter name (e.g. "F435W").
noise_model_path = f"{SYNFERENCE_ROOT}/priv/data/JOF_psfmatched_asinh_noise_model.h5"
noise_models = load_unc_model_from_hdf5(noise_model_path)

# Drop the HST bands and prefix the remaining filters with "JWST/NIRCam."
# so the keys follow the grids' "Observatory/Instrument.Filter" band names.
hst_filters = ["F435W", "F606W", "F775W", "F814W", "F850LP"]
nm = {
    f"JWST/NIRCam.{band}": model
    for band, model in noise_models.items()
    if band not in hst_filters
}

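Asinh softening parameter $b$ of each band's noise model, expressed as an AB magnitude ($m_{AB} = -2.5\log_{10}(b/\mathrm{nJy}) + 31.4$).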

In [None]:
# AB magnitude of each band's asinh softening parameter b (b given in nJy, so
# the AB zero point is 31.4). Labelled by band and shown via rich display.
{
    band: float(-2.5 * np.log10(model.b.to_value("nJy")) + 31.4)
    for band, model in nm.items()
}

# Create feature arrays

In [None]:
# Photometric bands present in the model grids that are dropped before the
# feature arrays are built. NOTE(review): presumably this leaves only the
# JWST/NIRCam bands that have an empirical noise model in `nm` — confirm
# against the grids' filter lists.
photometry_to_remove = [
    *(f"CTIO/DECam.{band}" for band in ("u", "g", "r", "i", "z", "Y")),
    *(f"LSST/LSST.{band}" for band in ("u", "g", "r", "i", "z", "Y")),
    *(f"PAN-STARRS/PS1.{band}" for band in ("g", "r", "i", "w", "z", "y")),
    *(f"Paranal/VISTA.{band}" for band in ("Z", "Y", "J", "H", "Ks")),
    *(f"Subaru/HSC.{band}" for band in ("g", "r", "i", "z", "Y")),
    *(f"CFHT/MegaCam.{band}" for band in ("u", "g", "r", "i", "z")),
    "Euclid/VIS.vis",
    *(f"Euclid/NISP.{band}" for band in ("Y", "J", "H")),
    "HST/ACS_WFC.F475W",
    "HST/WFC3_IR.F105W",
    "JWST/NIRCam.F070W",
    "HST/WFC3_IR.F110W",
    "HST/WFC3_IR.F125W",
    "JWST/NIRCam.F140M",
    "HST/WFC3_IR.F140W",
    "HST/WFC3_IR.F160W",
    *(f"JWST/NIRCam.{band}" for band in ("F360M", "F430M", "F460M", "F480M")),
    *(
        f"JWST/MIRI.{band}"
        for band in (
            "F560W",
            "F770W",
            "F1000W",
            "F1130W",
            "F1280W",
            "F1500W",
            "F1800W",
            "F2100W",
            "F2550W",
        )
    ),
    *(f"Spitzer/IRAC.{band}" for band in ("I1", "I2", "I3", "I4")),
    # The HST bands excluded from `nm` are removed here too. Deriving them from
    # `hst_filters` keeps the two lists in sync (all of them are ACS/WFC bands).
    *(f"HST/ACS_WFC.{band}" for band in hst_filters),
]

# Settings shared by both fitters; only the parameter transformations differ.
fa_input = dict(
    extra_features=["redshift"],
    empirical_noise_models=nm,
    include_errors_in_feature_array=True,
    scatter_fluxes=True,
    normed_flux_units="asinh",
    photometry_to_remove=photometry_to_remove,
)


fitter_pop3.create_feature_array_from_raw_photometry(
    **fa_input, parameter_transformations={"sfh": np.log10}
)
# NOTE(review): np.log10 maps Av == 0 to -inf; confirm every Av value in the
# BPASS grid is strictly positive.
fitter_bpass.create_feature_array_from_raw_photometry(
    **fa_input, parameter_transformations={"Av": np.log10}
);

# Plot feature and parameter arrays for comparison

In [None]:
# The same diagnostic histograms for both grids, BPASS first, so the feature
# and parameter distributions can be compared directly. Calls inside the loop
# body are statements, so their return values are not echoed (like the `;`).
for grid_fitter in (fitter_bpass, fitter_pop3):
    grid_fitter.plot_histogram_feature_array()
    grid_fitter.plot_histogram_parameter_array()

# Train Simformer models

In [None]:
# Train the PopIII Simformer without the evaluation step; `;` hides the return repr.
fitter_pop3.run_single_sbi(evaluate_model=False, name_append="v1");

In [None]:
# Train the BPASS Simformer without the evaluation step. The trailing `;`
# suppresses the return value's repr, matching the PopIII training cell.
fitter_bpass.run_single_sbi(name_append="v1", evaluate_model=False);

In [None]:
import harmonic as hm


class HarmonicEvidence:
    """Compute Bayesian evidence with harmonic's targeted harmonic mean estimator.

    The samples are added to ``hm.Chains``, split into blocks and then divided
    by ``hm.utils.split_data``: the first part trains a RealNVP flow
    (``hm.model.RealNVPModel``) and the second part is added to the
    ``hm.Evidence`` object that uses that flow.

    Args:
        n_scaled_layers (int, optional): ``n_scaled_layers`` of the RealNVP
            model. Defaults to 2.
        n_unscaled_layers (int, optional): ``n_unscaled_layers`` of the
            RealNVP model. Defaults to 4.
        temperature (float, optional): ``temperature`` of the RealNVP model.
            Defaults to 0.8.
        epochs (int, optional): Number of epochs used to fit the flow.
            Defaults to 20.
        n_blocks (int, optional): Number of blocks passed to
            ``Chains.split_into_blocks``. Defaults to 100.
        training_proportion (float, optional): ``training_proportion`` passed
            to ``hm.utils.split_data``. Defaults to 0.5.
        verbose (bool, optional): ``verbose`` flag of the flow's ``fit``.
            Defaults to True.
    """

    def __init__(
        self,
        n_scaled_layers=2,
        n_unscaled_layers=4,
        temperature=0.8,
        epochs=20,
        n_blocks=100,
        training_proportion=0.5,
        verbose=True,
    ):
        """Store the estimator settings; no evidence is computed yet."""
        self.n_scaled_layers = n_scaled_layers
        self.n_unscaled_layers = n_unscaled_layers
        self.temperature = temperature
        self.epochs = epochs
        self.n_blocks = n_blocks
        self.training_proportion = training_proportion
        self.verbose = verbose
        # harmonic Evidence object; stays None until an estimate has been made.
        self.evidence = None

    def from_nde(self, npe, nle, x, shape=(10000,), **sample_kwargs):
        """Estimate the evidence from a posterior sampler and a likelihood estimator.

        Args:
            npe (_type_): A Neural Posterior Estimator object with a
                .sample() method, called as ``npe.sample(shape, x, ...)``.
            nle (_type_): A Neural Likelihood Estimator object with a
                .potential() method; ``nle.potential(samples, x)`` supplies
                the log-posterior value of each sample.
            x (_type_): The observation to condition on.
            shape (tuple, optional): The shape of the samples to draw from the
                posterior. Defaults to (10000,).
            **sample_kwargs: Additional keyword arguments to pass to the
                posterior's .sample() method.
        """
        samples = npe.sample(shape, x, **sample_kwargs)
        lnprob = nle.potential(samples, x)
        self.from_samples(samples, lnprob)

    def from_samples(self, samples, lnprob):
        """Estimate the evidence from samples and their log-posterior values.

        TODO: Add support for multiple chains.

        Args:
            samples (array-like): An (N, D) array of samples, where N is the
                number of samples and D their dimensionality.
            lnprob (array-like): N log-posterior values, one per sample, as an
                (N,) or (N, 1) array. They are passed to
                ``hm.Chains.add_chain``, which expects the unnormalised log
                posterior (log-likelihood + log-prior).

        Raises:
            ValueError: If ``samples`` is not 2-D or ``lnprob`` does not hold
                exactly one value per sample.
        """
        samples = np.asarray(samples)
        # Flatten e.g. an (N, 1) column of log-probabilities to shape (N,).
        lnprob = np.asarray(lnprob).reshape(-1)
        if samples.ndim != 2:
            raise ValueError("Samples shape must be (N, D).")
        if samples.shape[0] != lnprob.shape[0]:
            raise ValueError("Input samples and lnprob must be same length.")
        self._calculate_evidence(samples, lnprob)

    def _calculate_evidence(self, samples, lnprob):
        """Train the flow and set up harmonic's evidence object.

        Args:
            samples (array-like): An (N, D) array of samples, where N is the
                number of samples and D their dimensionality.
            lnprob (array-like): An (N,) array of log-posterior values for
                each sample.
        """
        samples = np.asarray(samples)
        lnprob = np.asarray(lnprob)
        ndim = samples.shape[-1]

        # Treat the samples as a single chain, cut it into blocks and split
        # them, so the flow is never trained on the samples used for the
        # evidence estimate.
        chains = hm.Chains(ndim)
        chains.add_chain(samples, lnprob)
        chains.split_into_blocks(self.n_blocks)
        chains_train, chains_infer = hm.utils.split_data(
            chains, training_proportion=self.training_proportion
        )

        model = hm.model.RealNVPModel(
            ndim,
            n_scaled_layers=self.n_scaled_layers,
            n_unscaled_layers=self.n_unscaled_layers,
            standardize=True,
            temperature=self.temperature,
        )
        model.fit(chains_train.samples, epochs=self.epochs, verbose=self.verbose)

        # Keep harmonic's Evidence object; the getters below query it.
        self.evidence = hm.Evidence(chains_infer.nchains, model)
        self.evidence.add_chains(chains_infer)

    def _require_computed(self, *others):
        """Raise ValueError unless this object (and each of ``others``) holds an estimate."""
        if self.evidence is None:
            raise ValueError("Evidence has not been computed.")
        for other in others:
            if other.evidence is None:
                raise ValueError("Evidence for model 2 has not been computed.")

    def get_evidence(self):
        """Return ``hm.Evidence.compute_evidence()`` for the fitted estimator.

        NOTE(review): harmonic documents this as an ``(evidence, std)`` pair;
        confirm for the installed version.

        Raises:
            ValueError: If no evidence has been computed yet.
        """
        self._require_computed()
        return self.evidence.compute_evidence()

    def get_ln_evidence(self):
        """Return ``hm.Evidence.compute_ln_evidence()`` (natural-log evidence).

        Raises:
            ValueError: If no evidence has been computed yet.
        """
        self._require_computed()
        return self.evidence.compute_ln_evidence()

    def get_bayes_factor(self, ev2):
        """Return the Bayes factor Z_self / Z_ev2 via ``hm.evidence.compute_bayes_factor``.

        Args:
            ev2 (HarmonicEvidence): The evidence of the competing model.

        Raises:
            ValueError: If either evidence has not been computed yet.
        """
        self._require_computed(ev2)
        return hm.evidence.compute_bayes_factor(self.evidence, ev2.evidence)

    def get_ln_bayes_factor(self, ev2):
        """Return the log Bayes factor via ``hm.evidence.compute_ln_bayes_factor``.

        Args:
            ev2 (HarmonicEvidence): The evidence of the competing model.

        Raises:
            ValueError: If either evidence has not been computed yet.
        """
        self._require_computed(ev2)
        return hm.evidence.compute_ln_bayes_factor(self.evidence, ev2.evidence)


def _harmonic_evidence_for(fitter, obs):
    """Fit a HarmonicEvidence to posterior samples drawn from ``fitter`` for ``obs``.

    Args:
        fitter (Simformer_Fitter): Trained fitter providing ``sample_posterior``
            and ``log_prob``.
        obs (array-like): The observation to condition on.

    Returns:
        HarmonicEvidence: An evidence object with its estimate computed.
    """
    samples = fitter.sample_posterior(obs)
    # NOTE(review): harmonic needs the *unnormalised* log posterior
    # (log-likelihood + log-prior). If ``log_prob`` returns a normalised
    # posterior density, every evidence estimate tends to 1 and the Bayes
    # factor carries no information. Confirm what ``log_prob`` returns.
    lnprob = fitter.log_prob(X_test=obs, theta=samples)
    evidence = HarmonicEvidence()
    evidence.from_samples(samples, lnprob)
    return evidence


def get_bayes_factor(fitter1, fitter2, obs, shape=(10000,)):
    """Get the Bayes factor between two fitters for a given observation.

    Args:
        fitter1 (Simformer_Fitter): The first fitter (numerator model).
        fitter2 (Simformer_Fitter): The second fitter (denominator model).
        obs (array-like): The observation to compute the Bayes factor for.
        shape (tuple, optional): Intended number of posterior samples.
            Defaults to (10000,). NOTE(review): currently NOT forwarded to
            ``sample_posterior``, which therefore uses its own default. Wire it
            through once that method's sample-count keyword is confirmed.

    Returns:
        The result of ``hm.evidence.compute_bayes_factor`` for Z_1 / Z_2.
        harmonic documents this as a ``(bf12, bf12_std)`` pair; confirm for
        the installed version.
    """
    evidence1 = _harmonic_evidence_for(fitter1, obs)
    evidence2 = _harmonic_evidence_for(fitter2, obs)
    return evidence1.get_bayes_factor(evidence2)

In [None]:
# Display the BPASS fitter's `posteriors` attribute (presumably populated by
# run_single_sbi above — confirm).
fitter_bpass.posteriors

In [None]:
# Pick a random observation from each fitter's validation set (._X_test).
# A seeded generator makes the chosen observations reproducible across
# Restart & Run All.
OBS_SEED = 42
obs_rng = np.random.default_rng(OBS_SEED)

bpass_obs = fitter_bpass._X_test
pop3_obs = fitter_pop3._X_test
# Generator.integers excludes the upper bound, like np.random.randint; int()
# gives a plain Python index.
random_index_bpass = int(obs_rng.integers(0, len(bpass_obs)))
random_index_pop3 = int(obs_rng.integers(0, len(pop3_obs)))

bpass_obs_i = bpass_obs[random_index_bpass]
pop3_obs_i = pop3_obs[random_index_pop3]

pop3_obs_i, bpass_obs_i

In [None]:
# Mock example data: one observation drawn from each grid's own test set.
# NOTE(review): get_bayes_factor does not currently forward `shape` to
# sample_posterior, so shape=(1000,) has no effect yet.
for obs_label, obs_example in (("BPASS", bpass_obs_i), ("PopIII", pop3_obs_i)):
    print(f"Bayes factor (BPASS/PopIII) for {obs_label}-like observation:")
    print(get_bayes_factor(fitter_bpass, fitter_pop3, obs_example, shape=(1000,)))