In [None]:
%load_ext autoreload
%autoreload 2
import os

import matplotlib.pyplot as plt
import numpy as np
from astropy import units as u
from astropy.cosmology import Planck18
from synthesizer.emission_models import (
    TotalEmission,
)
from synthesizer.emission_models.attenuation import Calzetti2000
from synthesizer.grid import Grid
from synthesizer.instruments import FilterCollection, Instrument
from synthesizer.parametric import SFH, ZDist
from unyt import Jy, Myr

from sbifitter import SBI_Fitter

# `__file__` is not defined inside a Jupyter kernel, so fall back to the current
# working directory (normally the notebook's own folder) when it is missing.
try:
    file_path = os.path.dirname(os.path.realpath(__file__))
except NameError:
    file_path = os.getcwd()

# Grids and trained models live in sibling folders two levels above this file.
project_root = os.path.dirname(os.path.dirname(file_path))
grid_folder = os.path.join(project_root, "grids")
output_folder = os.path.join(project_root, "models")

grid_path = os.path.join(
    grid_folder, "grid_Pop_II_LogNormal_SFH_5_z_12_logN_5.0_BPASS_Chab_v1.hdf5"
)

In [None]:
# Fitter for the MDN ensemble (lampe backend), built from the grid at `grid_path`.
mdn_fitter_name = (
    "Pop_II_LogNormal_SFH_5_z_12_logN_5.0_BPASS_Chab_v1_ensemble_redshift_nonorm_lampe"
)
fitter = SBI_Fitter.init_from_hdf5(mdn_fitter_name, grid_path, return_output=False)

# Use the raw photometry as features: no extra features and no normalisation.
fitter.create_feature_array_from_raw_photometry(extra_features=[], normalize_method=None);

Next, set up two small ensembles with the `lampe` backend: two MDN networks and two MAF networks, each pair using different numbers of hidden features (50 and 75).


Notes

- Learning rate 1e-5: ~8 hours to train all nets, typically >1000 epochs each.
- Learning rate 1e-3: ~1 hour to train all nets, typically ~200 epochs each.

In [None]:
# Two MDN networks (5 mixture components each) with 50 and 75 hidden features.
fitter.run_single_sbi(
    n_nets=2,
    model_type=["mdn"] * 2,
    hidden_features=[50, 75],
    num_components=5,
    num_transforms=5,
    learning_rate=1e-3,
    stop_after_epochs=20,
    backend="lampe",
    engine="NPE",
)

In [None]:
# Second ensemble: two MAF networks with 50/75 hidden features and 5/3 transforms.
maf_fitter_name = (
    "Pop_II_LogNormal_SFH_5_z_12_logN_5.0_BPASS_Chab_v1_ensemble_redshift_nonorm_lampe_maf"
)
fitter = SBI_Fitter.init_from_hdf5(maf_fitter_name, grid_path, return_output=False)
fitter.create_feature_array_from_raw_photometry(extra_features=[], normalize_method=None)

fitter.run_single_sbi(
    n_nets=2,
    model_type=["maf"] * 2,
    hidden_features=[50, 75],
    num_transforms=[5, 3],
    learning_rate=1e-3,
    stop_after_epochs=20,
    backend="lampe",
    engine="NPE",
)

Set up a simulator so we can try to recreate the SED.

In [None]:
from sbifitter import GalaxySimulator

# Folder holding the synthesizer grids, taken from the environment.
# (`os.environ` is a mapping; `os.path.environ` does not exist.)
grid_dir = os.environ.get("SYNTHESIZER_GRID_DIR")
if grid_dir is None:
    raise RuntimeError(
        "Set the SYNTHESIZER_GRID_DIR environment variable to the directory "
        "containing the synthesizer grids."
    )
grid_name = "bpass-2.2.1-bin_chabrier03-0.1,300.0_cloudy-c23.01-sps.hdf5"

grid = Grid(
    grid_name,
    grid_dir=grid_dir,
)

# JWST/NIRCam wide (W) and medium (M) band filters used for the photometry.
filter_codes = [
    "JWST/NIRCam.F070W",
    "JWST/NIRCam.F090W",
    "JWST/NIRCam.F115W",
    "JWST/NIRCam.F140M",
    "JWST/NIRCam.F150W",
    "JWST/NIRCam.F162M",
    "JWST/NIRCam.F182M",
    "JWST/NIRCam.F200W",
    "JWST/NIRCam.F210M",
    "JWST/NIRCam.F250M",
    "JWST/NIRCam.F277W",
    "JWST/NIRCam.F300M",
    "JWST/NIRCam.F335M",
    "JWST/NIRCam.F356W",
    "JWST/NIRCam.F360M",
    "JWST/NIRCam.F410M",
    "JWST/NIRCam.F430M",
    "JWST/NIRCam.F444W",
    "JWST/NIRCam.F460M",
    "JWST/NIRCam.F480M",
]
filterset = FilterCollection(filter_codes)
instrument = Instrument("JWST", filters=filterset)

# Parametric SFH and metallicity-distribution classes handed to the simulator.
sfh = SFH.LogNormal
zdist = ZDist.DeltaConstant

# (lower, upper) ranges per parameter.
# NOTE(review): `priors` is not passed to the simulator in this cell.
priors = {
    "redshift": (5.0, 10.0),
    "log_mass": (7.0, 10.0),
    "log10metallicity": (-3.0, 0.3),
    "tau_v": (0.0, 1.5),
    "peak_age": (0, 500),
    "max_age": (500, 1000),
    "tau": (0.3, 1.5),
}

emission_model = TotalEmission(
    grid=grid,
    fesc=0.1,
    fesc_ly_alpha=0.1,
    dust_curve=Calzetti2000(),
    dust_emission_model=None,
)

# This tells the emission model we will have a parameter called 'tau_v'.
emitter_params = {"stellar": ["tau_v"]}


simulator = GalaxySimulator(
    sfh_model=sfh,
    zdist_model=zdist,
    grid=grid,
    instrument=instrument,
    emission_model=emission_model,
    emission_model_key="total",
    emitter_params=emitter_params,
    # Age parameters are supplied to the simulator in Myr.
    param_units={"peak_age": Myr, "max_age": Myr},
    normalize_method=None,
    output_type=["photo_fnu", "fnu"],
    out_flux_unit="ABmag",
)

We're not fitting for `max_age`, but we need it to recreate the SFH, so here we write a function that calculates it from the fitted parameters (specifically the redshift). The same mechanism could be used to marginalize over unfitted parameters (e.g. metallicity), but we don't do that here.

In [None]:
def max_age_from_z(params, max_redshift=20, cosmo=Planck18):
    """Return the maximum stellar age in Myr for a galaxy at ``params["redshift"]``.

    The age is the cosmic time elapsed between ``max_redshift`` (assumed to be the
    earliest epoch at which stars can form) and the galaxy's redshift, evaluated
    with the astropy cosmology ``cosmo``.
    """
    elapsed = cosmo.age(params["redshift"]) - cosmo.age(max_redshift)
    return elapsed.to(u.Myr).value


# Parameters that are not fitted but are derived from the fitted ones when
# regenerating an SED; each maps a name to a function of the parameter dict.
marginalized_parameters = {"max_age": max_age_from_z}

In [None]:
# Draw one galaxy at random from the held-out test indices and copy its feature
# row, so later edits to `params` never touch `fitter.feature_array`.
index = np.random.choice(fitter._test_indices)
params = fitter.feature_array[index, :].copy()

params

In [None]:
%matplotlib inline

fig = fitter.recover_SED(
    X_test=params,
    simulator=simulator,
    marginalized_parameters=marginalized_parameters,
    true_parameters=fitter.fitted_parameter_array[index, :],
)

fig

In [None]:
# Repeat the SED recovery for several random test galaxies, each with its own
# plot_name keyed by the galaxy index.
n_examples = 10
for _ in range(n_examples):
    index = np.random.choice(fitter._test_indices)
    params = fitter.feature_array[index, :].copy()

    fig = fitter.recover_SED(
        X_test=params,
        simulator=simulator,
        marginalized_parameters=marginalized_parameters,
        true_parameters=fitter.fitted_parameter_array[index, :],
        plot_name=f"recovered_sed_{index}.png",
    )
    plt.show()
    # Release each figure so the loop does not accumulate open figures.
    plt.close(fig)

Train (or, with `load = True`, load) a model on noisy photometry with randomly drawn depths.

In [None]:
fitter = SBI_Fitter.init_from_hdf5(
    "Pop_II_LogNormal_SFH_5_z_12_logN_5.0_BPASS_Chab_v2_ensemble_redshift_nonorm_sbi_noisy",
    grid_path,
    return_output=False,
)

# Set `load = False` to retrain instead of loading the pickled model.
load = True
if load:
    fitter.load_model_from_pkl(
        f"{output_folder}/Pop_II_LogNormal_SFH_5_z_12_logN_5.0_BPASS_Chab_v2_ensemble_redshift_nonorm_sbi_noisy"  # noqa E501
    )
else:
    band_names = fitter.raw_photometry_names
    n_bands = len(band_names)
    is_wide = [name.endswith("W") for name in band_names]
    is_medium = [name.endswith("M") for name in band_names]

    # Draw 20 possible sets of depths (AB mag), one value per band: wide (W) bands
    # are drawn from N(30, 0.2), all other bands from N(29.5, 0.3).
    depth_centers = [30 if wide else 29.5 for wide in is_wide]
    depth_sigma = [0.2 if wide else 0.3 for wide in is_wide]
    depths = np.random.normal(
        loc=depth_centers,
        scale=depth_sigma,
        size=(20, n_bands),
    )
    # AB magnitude -> flux density: f = 10 ** ((m - 8.90) / -2.5) Jy.
    depths_jy = 10 ** ((depths - 8.90) / -2.5) * Jy

    # Five missing-band scenarios (1 = band missing, 0 = band present):
    #   0: no missing bands
    #   1: all medium (M) bands missing
    #   2: all M bands missing except F410M
    #   3: all M bands missing except F410M and F335M
    #   4: all M bands missing except F410M, and F070W missing as well
    # Filters are matched on their suffix so the masks work whether or not the
    # names carry an instrument prefix (e.g. "NIRCam.F410M" or "F410M").
    missing_bands = np.zeros((5, n_bands))
    missing_bands[1] = is_medium
    missing_bands[2] = [
        medium and not name.endswith("F410M")
        for medium, name in zip(is_medium, band_names)
    ]
    missing_bands[3] = [
        medium and not name.endswith(("F410M", "F335M"))
        for medium, name in zip(is_medium, band_names)
    ]
    missing_bands[4] = [
        (medium and not name.endswith("F410M")) or name.endswith("F070W")
        for medium, name in zip(is_medium, band_names)
    ]

    # NOTE(review): simulate_missing_fluxes=False here, so `missing_flux_options`
    # is presumably unused for this model — confirm that is intended.
    fitter.create_feature_array_from_raw_photometry(
        extra_features=[],
        normalize_method=None,
        scatter_fluxes=5,
        include_errors_in_feature_array=True,
        depths=depths_jy,
        simulate_missing_fluxes=False,
        include_flags_in_feature_array=False,
        missing_flux_options=missing_bands,
    )

    fitter.run_single_sbi(
        n_nets=1,
        backend="sbi",
        engine="NPE",
        stop_after_epochs=20,
        hidden_features=[75],
        learning_rate=1e-3,
        num_transforms=[3],
        model_type=["maf"],
    )

Quick test: build a LogNormal SFH by hand and inspect the spectrum it produces.

In [None]:
from synthesizer.parametric import SFH
from unyt import Angstrom, Msun, nJy, unyt_quantity

# Hand-picked LogNormal SFH parameters for a quick sanity check.
# NOTE(review): peak_age (~547.2 Myr) is slightly larger than max_age (~546.6 Myr);
# presumably deliberate, to probe that edge case — confirm.
lognormal_kwargs = {
    "tau": np.float32(1.1300783),
    "peak_age": unyt_quantity(547.21136, dtype=np.float32, units="Myr"),
    "max_age": unyt_quantity(546.58285752, "Myr"),
    "min_age": unyt_quantity(0, "yr"),
}

sfh = SFH.LogNormal(**lognormal_kwargs)

In [None]:
from synthesizer.emissions import plot_spectra
from synthesizer.parametric import Galaxy, Stars, ZDist

# Stars on the grid's age/metallicity axes using the hand-built SFH: 1e7 Msun
# formed, a DeltaConstant metallicity distribution and tau_v = 0.
# NOTE(review): log10metallicity=2e-2 corresponds to Z ~ 1.05 in linear units; if
# Z = 0.02 was intended, this presumably should be metallicity=2e-2 — confirm.
test_stars = Stars(
    grid.log10age,
    grid.metallicity,
    sf_hist=sfh,
    initial_mass=1e7 * Msun,
    metal_dist=ZDist.DeltaConstant(log10metallicity=2e-2),
    tau_v=0,
)
galaxy = Galaxy(redshift=10, stars=test_stars)

galaxy.stars.plot_sfzh()

In [None]:
# Generate the spectra with the emission model, then the observed spectra using
# the Planck18 cosmology.
galaxy.stars.get_spectra(emission_model)
galaxy.get_observed_spectra(cosmo=Planck18)

wavelength_range = (10000, 50_000) * Angstrom
flux_range = (0.01, 10) * nJy
plot_spectra(
    galaxy.stars.spectra["total"],
    quantity_to_plot="fnu",
    show=True,
    xlimits=wavelength_range,
    ylimits=flux_range,
)

In [None]:
total_spectrum = galaxy.stars.spectra["total"]  # the "total" emission component
total_spectrum.fnu

In [None]:
# Show how a fixed 1-sigma flux error becomes asymmetric magnitude errors as a
# function of source brightness. The noisy models use depths around 30 AB mag.
# Dividing the depth by 5 treats it as a 5-sigma limit.
depth_mag = 30.0
sigma_njy = (depth_mag * u.ABmag).to(u.nJy).value / 5  # 1-sigma flux error (nJy)
sigma_mag = (sigma_njy * u.nJy).to(u.ABmag)

print(f"1-sigma depth is {sigma_mag}")


def mag_err_lower(flux_njy, sigma=sigma_njy):
    """Magnitude error from the lower flux bound, 2.5 * log10(f / (f - sigma)).

    Non-finite (inf/NaN) where f <= sigma, i.e. at or below the 1-sigma level.
    """
    with np.errstate(divide="ignore", invalid="ignore"):
        return 2.5 * np.log10(flux_njy / (flux_njy - sigma))


def mag_err_upper(flux_njy, sigma=sigma_njy):
    """Magnitude error from the upper flux bound, 2.5 * log10(1 + sigma / f)."""
    return 2.5 * np.log10(1 + np.abs(sigma / flux_njy))


test_mags = np.linspace(26, 35, 1000)
test_fluxes = (test_mags * u.ABmag).to(u.nJy).value

fig, ax = plt.subplots(1, 1, figsize=(10, 5))
ax.plot(test_mags, mag_err_lower(test_fluxes), label="lower error")
ax.plot(test_mags, mag_err_upper(test_fluxes), label="upper error")
ax.axvline(sigma_mag.value, color="red", linestyle="--", label="1 sig error")
ax.axvline(depth_mag, color="blue", linestyle="--", label="5 sig depth")
# The lower error diverges as f -> sigma, so cap the y-range for readability.
ax.set_ylim(0, 3)
ax.set_xlabel("Flux (ABmag)")
ax.set_ylabel("Error (ABmag)")
ax.legend();

Let's test a model that can handle missing photometry. We randomly remove a fraction of the data points from the training set (`missing_flux_fraction=0.1`), and also add a binary feature for each band that flags whether the band is missing.

In [None]:
# Fitter trained with simulated missing fluxes plus one missing-data flag per band.
noisy_fitter = SBI_Fitter.init_from_hdf5(
    "Pop_II_LogNormal_SFH_5_z_12_logN_5.0_BPASS_Chab_v1_ensemble_redshift_nonorm_missing",
    grid_path,
    return_output=False,
)

# Drop fluxes at random (missing_flux_fraction=0.1) and append the flags as features.
noisy_fitter.create_feature_array_from_raw_photometry(
    extra_features=[],
    normalize_method=None,
    simulate_missing_fluxes=True,
    missing_flux_fraction=0.1,
    include_flags_in_feature_array=True,
)

# A single MAF network.
noisy_fitter.run_single_sbi(
    model_type="maf",
    n_nets=1,
    hidden_features=100,
    num_transforms=4,
    learning_rate=1e-3,
    stop_after_epochs=20,
    backend="sbi",
    engine="NPE",
)

In [None]:
# Random test galaxy with its full photometry (plus missing flags, all unset).
index = np.random.choice(noisy_fitter._test_indices)
params = noisy_fitter.feature_array[index, :].copy()
print(params)

true_params = noisy_fitter.fitted_parameter_array[index, :]
fig = noisy_fitter.recover_SED(
    X_test=params,
    simulator=simulator,
    marginalized_parameters=marginalized_parameters,
    true_parameters=true_params,
    plot_name=f"recovered_sed_{index}.png",
)

fig

Now remove some photometry: here we drop 6 filter measurements, setting each flux to the placeholder value 99 and raising the corresponding missing-data flag.

In [None]:
# Drop six filter measurements from the galaxy above. Work on a copy so `params`
# (the full photometry from the previous cell) is untouched and this cell can be
# re-run safely.
remove_idx = [5, 7, 9, 12, 14, 19]
n_bands = len(noisy_fitter.raw_photometry_names)
params_missing = params.copy()
for band_idx in remove_idx:
    # Assumes the features are laid out as [band fluxes..., per-band flags...], so a
    # band's flag sits n_bands after its flux — TODO confirm against feature_names.
    params_missing[band_idx] = 99  # placeholder value for a missing flux
    params_missing[band_idx + n_bands] = 1  # mark the band as missing

# Use a distinct plot name so the full-photometry figure is not overwritten.
fig = noisy_fitter.recover_SED(
    X_test=params_missing,
    simulator=simulator,
    marginalized_parameters=marginalized_parameters,
    true_parameters=noisy_fitter.fitted_parameter_array[index, :],
    plot_name=f"recovered_sed_{index}_missing.png",
)

fig

### Delayed Exponential SFH model with noise


In [None]:
# Grid and fitter for the DelayedExponential SFH model.
# NOTE(review): this grid is read from `output_folder` (models/) rather than
# `grid_folder` (grids/) — confirm the file really lives there.
grid_path = f"{output_folder}/grid_BPASS_DelayedExponential_SFH_0.01_z_12_logN_4.7_Chab_Calzetti_v1.hdf5"  # noqa E501

fitter = SBI_Fitter.init_from_hdf5(
    "BPASS_DelayedExponential_SFH_0.01_z_12_logN_4.7_Chab_Calzetti_v1_redshift_nonorm_sbi_noisy",
    grid_path,
    return_output=False,
)

# Uniform depth of 30 AB mag in every one of the nfilt bands, as a flux density in Jy.
nfilt = len(fitter.raw_photometry_names)
depths = 10 ** ((np.full(nfilt, 30.0) - 8.90) / -2.5) * Jy

fitter.create_feature_array_from_raw_photometry(
    extra_features=[],
    normalize_method=None,
    scatter_fluxes=5,
    include_errors_in_feature_array=True,
    depths=depths,
)

# A single MAF network.
fitter.run_single_sbi(
    model_type="maf",
    n_nets=1,
    hidden_features=75,
    num_transforms=3,
    learning_rate=1e-3,
    stop_after_epochs=20,
    backend="sbi",
    engine="NPE",
)

Model trained on noisy photometry with missing bands

In [None]:
# This is a LogNormal-SFH model, so point it at the LogNormal grid explicitly: the
# previous cell rebinds `grid_path` to the DelayedExponential grid.
# NOTE(review): this is the same grid the earlier LogNormal models use; confirm it
# is the grid the saved model was trained on.
lognormal_grid_path = os.path.join(
    grid_folder, "grid_Pop_II_LogNormal_SFH_5_z_12_logN_5.0_BPASS_Chab_v1.hdf5"
)
fitter = SBI_Fitter.init_from_hdf5(
    "Pop_II_LogNormal_SFH_5_z_12_logN_5.0_BPASS_Chab_v2_ensemble_redshift_nonorm_sbi_noisy_missing_lampe",
    lognormal_grid_path,
    return_output=False,
)

# Set `load = False` to retrain instead of loading the pickled model.
load = True
if load:
    fitter.load_model_from_pkl(
        f"{output_folder}/Pop_II_LogNormal_SFH_5_z_12_logN_5.0_BPASS_Chab_v2_ensemble_redshift_nonorm_sbi_noisy_missing_lampe"  # noqa E501
    )
else:
    band_names = fitter.raw_photometry_names
    n_bands = len(band_names)
    is_wide = [name.endswith("W") for name in band_names]
    is_medium = [name.endswith("M") for name in band_names]

    # Draw 20 possible sets of depths (AB mag), one value per band: wide (W) bands
    # are drawn from N(30, 1.0), all other bands from N(29.5, 1.3).
    depth_centers = [30 if wide else 29.5 for wide in is_wide]
    depth_sigma = [1.0 if wide else 1.3 for wide in is_wide]
    depths = np.random.normal(
        loc=depth_centers,
        scale=depth_sigma,
        size=(20, n_bands),
    )
    # AB magnitude -> flux density: f = 10 ** ((m - 8.90) / -2.5) Jy.
    depths_jy = 10 ** ((depths - 8.90) / -2.5) * Jy

    # Five missing-band scenarios (1 = band missing, 0 = band present):
    #   0: no missing bands
    #   1: all medium (M) bands missing
    #   2: all M bands missing except F410M
    #   3: all M bands missing except F410M and F335M
    #   4: all M bands missing except F410M, and F070W missing as well
    # Filters are matched on their suffix so the masks work whether or not the
    # names carry an instrument prefix (e.g. "NIRCam.F410M" or "F410M").
    missing_bands = np.zeros((5, n_bands))
    missing_bands[1] = is_medium
    missing_bands[2] = [
        medium and not name.endswith("F410M")
        for medium, name in zip(is_medium, band_names)
    ]
    missing_bands[3] = [
        medium and not name.endswith(("F410M", "F335M"))
        for medium, name in zip(is_medium, band_names)
    ]
    missing_bands[4] = [
        (medium and not name.endswith("F410M")) or name.endswith("F070W")
        for medium, name in zip(is_medium, band_names)
    ]

    fitter.create_feature_array_from_raw_photometry(
        extra_features=[],
        normalize_method=None,
        scatter_fluxes=5,
        include_errors_in_feature_array=True,
        depths=depths_jy,
        simulate_missing_fluxes=True,
        include_flags_in_feature_array=True,
        missing_flux_options=missing_bands,
    )

    fitter.run_single_sbi(
        n_nets=3,
        backend="lampe",
        engine="NPE",
        model_type="nsf",
        stop_after_epochs=20,
        hidden_features=[75, 50, 25],
        learning_rate=1e-4,
        num_transforms=[3, 4, 5],
    )

In [None]:
# Galaxy to inspect. A fixed index pins the galaxy used in the following cells;
# set it to None to draw a random galaxy from the held-out test set instead.
# NOTE(review): 87692 is not checked against fitter._test_indices — confirm it is
# a test (not training) galaxy.
fixed_index = 87692
if fixed_index is None:
    index = np.random.choice(fitter._test_indices)
else:
    index = fixed_index
params = fitter.feature_array[index, :].copy()

# Optionally mask every medium-band ('M') measurement: set its value to 99 and
# raise the matching flag feature.
mask_medium_bands = False
if mask_medium_bands:
    for feat_idx, name in enumerate(fitter.feature_names):
        if "M" in name and "flag" not in name:
            params[feat_idx] = 99
        elif "M" in name and "flag" in name:
            params[feat_idx] = 1

params

In [None]:
%matplotlib inline
fig = fitter.recover_SED(
    X_test=params,
    simulator=simulator,
    marginalized_parameters=marginalized_parameters,
    true_parameters=fitter.fitted_parameter_array[index, :],
    plot_name=f"recovered_sed_{index}.png",
    plot_closest_draw_to={"tau_v": 2},
    param_labels=[
        "Redshift",
        r"$\log_{10}\rm M_{\star}$",
        r"$\tau_V$ (mag)",
        r"$\tau$",
        r"$\rm Peak\ Age (Myr)$",
        r"$\log_{10} Z_{\star}$",
    ],
)
fig

In [None]:
# Simulate a galaxy directly from a hand-picked set of fitted-parameter values.
# NOTE(review): values are given in the order of fitter.simple_fitted_parameter_names
# (presumably redshift, log mass, tau_v, tau, peak age, log10 Z, matching the plot
# labels above) — confirm.
test_param_values = [7.046468, 11.061219, 1.931601, 0.27388036, 73.22359, -1.3983577]
param_names = fitter.simple_fitted_parameter_names
# zip() would silently drop trailing names or values if the lengths differed.
assert len(param_names) == len(test_param_values), (
    f"Expected {len(param_names)} parameter values, got {len(test_param_values)}"
)
# Named `sim_params` rather than `dict` so the builtin is not shadowed.
sim_params = {name: value for name, value in zip(param_names, test_param_values)}
sim_params["max_age"] = max_age_from_z(sim_params, max_redshift=20, cosmo=Planck18)

print(sim_params)
out = simulator(sim_params)

In [None]:
# LogNormal SFH grid with fesc = 0.0. Long names are built with implicit string
# concatenation so no newline or indentation ends up inside the path or name.
# NOTE(review): the grid is read from `output_folder` (models/) and is logN_4.7,
# while the model name says logN_5.0 — confirm both are intended.
grid_path = (
    f"{output_folder}/"
    "grid_Pop_II_LogNormal_SFH_5_z_12_logN_4.7_BPASS_Chab_Calzetti_v1_fesc0.0.hdf5"
)
model_name = (
    "Pop_II_LogNormal_SFH_5_z_12_"
    "logN_5.0_BPASS_Chab_v1_fesc0.0_ensemble_redshift_nonorm_sbi_lampe"
)
# Set `load = True` to load the pickled model instead of training.
load = False

fitter = SBI_Fitter.init_from_hdf5(model_name, grid_path, return_output=False)


if load:
    fitter.load_model_from_pkl(f"{output_folder}/{model_name}")
else:
    fitter.create_feature_array_from_raw_photometry(
        extra_features=[],
        normalize_method=None,
    )

    fitter.run_single_sbi(
        n_nets=3,
        backend="lampe",
        engine="NPE",
        model_type="nsf",
        stop_after_epochs=20,
        hidden_features=180,
        learning_rate=0.0005,
        num_transforms=5,
    )