In [1]:
%load_ext autoreload
%autoreload 2

import os

import numpy as np
import torch
from synthesizer.emission_models import (
    TotalEmission,
)
from synthesizer.emission_models.attenuation import Calzetti2000
from synthesizer.grid import Grid
from synthesizer.instruments import FilterCollection, Instrument
from synthesizer.parametric import SFH, ZDist
from unyt import Myr

from sbifitter import GalaxySimulator, SBI_Fitter

# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# `__file__` is not defined inside a Jupyter kernel, so fall back to the
# current working directory (typically the notebook's own folder).
try:
    file_path = os.path.dirname(os.path.realpath(__file__))
except NameError:
    file_path = os.getcwd()

# Repository root: two levels above this notebook's directory.
repo_root = os.path.dirname(os.path.dirname(file_path))
grid_folder = os.path.join(repo_root, "grids")
output_folder = os.path.join(repo_root, "models")

NameError: name 'Tuple' is not defined

### Setup Grid

In [None]:
# `os.path` has no `environ` attribute; the environment lives on `os.environ`.
# When SYNTHESIZER_GRID_DIR is unset, fall back to `grid_folder` (the `grids`
# directory set up in the first cell).
grid_dir = os.environ.get("SYNTHESIZER_GRID_DIR", grid_folder)
grid_name = "bpass-2.2.1-bin_chabrier03-0.1,300.0_cloudy-c23.01-sps.hdf5"

grid = Grid(
    grid_name,
    grid_dir=grid_dir,
)

### Setup instrument

In [None]:
# NIRCam filters used for the mock photometry, listed in order of filter number.
nircam_bands = (
    "F090W", "F115W", "F150W", "F162M", "F182M", "F200W", "F210M",
    "F250M", "F277W", "F300M", "F335M", "F356W", "F410M", "F444W",
)
filter_codes = [f"JWST/NIRCam.{band}" for band in nircam_bands]

# Bundle the filters into a single JWST instrument for the simulator.
filterset = FilterCollection(filter_codes)
instrument = Instrument("JWST", filters=filterset)

### Setup model

In [None]:
# Model classes (not instances) for the star formation history and the
# metallicity distribution; GalaxySimulator receives them as
# `sfh_model` / `zdist_model`.
sfh, zdist = SFH.LogNormal, ZDist.DeltaConstant

### Setup Emission Model

In [None]:
# Emission model on the grid: escape fractions `fesc` and `fesc_ly_alpha`
# both set to 0.1, a Calzetti (2000) dust curve, and no dust emission model.
emission_model_config = dict(
    grid=grid,
    fesc=0.1,
    fesc_ly_alpha=0.1,
    dust_curve=Calzetti2000(),
    dust_emission_model=None,
)
emission_model = TotalEmission(**emission_model_config)

# Declare that the stellar emitter carries a parameter named 'tau_v'.
emitter_params = dict(stellar=["tau_v"])

### Setup Photometry Simulator

Note: the default is to return photometry. Use `output_type` to choose the output:
- `'photo_fnu'`: observed-frame photometry (used here).
- `'photo_lnu'`: rest-frame photometry.
- `'fnu'`: observed-frame spectra.
- `'lnu'`: rest-frame spectra.

In [None]:
# Forward model: SFH + metallicity distribution on the grid, observed through
# the NIRCam instrument and returned as `photo_fnu` photometry in AB
# magnitudes. Ages are interpreted in Myr; no normalisation is applied.
simulator_config = dict(
    sfh_model=sfh,
    zdist_model=zdist,
    grid=grid,
    instrument=instrument,
    emission_model=emission_model,
    emission_model_key="total",
    emitter_params=emitter_params,
    param_units=dict(peak_age=Myr, max_age=Myr),
    normalize_method=None,  # calculate_muv is a possible alternative
    output_type="photo_fnu",
    out_flux_unit="ABmag",
)
simulator = GalaxySimulator(**simulator_config)

Now let's define an input dictionary to test it with.

In [None]:
# A single test point covering every parameter the simulator understands
# (ages in Myr, per `param_units`).
params = dict(
    redshift=6,
    log_mass=9.0,
    tau=0.4,
    log10metallicity=-0.5,
    peak_age=100,
    max_age=800,
    tau_v=0.4,
)

How fast is it?

In [None]:
# Benchmark a single simulator call on the `params` dict defined above.
%timeit simulator(params=params)

We wrap the simulator in a small function that does the following:
- accepts the parameters as a dict, list, array or Tensor;
- maps them onto the named inputs and runs the simulator;
- returns the result as a float32 Tensor with a leading batch dimension of size 1.

The list below fixes the names and order of the inputs.

In [None]:
# Names (and order) of the free parameters. Array-like inputs to
# `run_simulator` are mapped onto these names position by position.
inputs = (
    "redshift log_mass log10metallicity tau_v peak_age max_age tau"
).split()


def run_simulator(
    params, return_type="tensor", *, _simulator=simulator, _inputs=tuple(inputs)
):
    """Run the galaxy simulator for a single set of parameters.

    Parameters
    ----------
    params : dict, list, tuple, np.ndarray or torch.Tensor
        Parameter values. A dict is filtered down to the names in ``inputs``.
        A sequence/array/tensor must hold exactly one value per name in
        ``inputs`` (in that order); extra singleton axes such as
        ``(1, len(inputs))`` are allowed.
    return_type : str, optional
        ``"tensor"`` (default) returns the simulator output as a float32
        tensor on ``device`` with a leading axis of size 1 added. Any other
        value returns the raw simulator output.
    _simulator, _inputs
        Bound when this function is defined, so later cells that rebind the
        global ``simulator`` / ``inputs`` do not silently change it.

    Raises
    ------
    ValueError
        If an array-like input does not hold exactly one value per input name
        (for example a batch of several parameter sets).
    """
    if isinstance(params, torch.Tensor):
        # detach() so tensors that require grad can still be converted.
        params = params.detach().cpu().numpy()

    if isinstance(params, dict):
        params = {name: params[name] for name in _inputs}
    elif isinstance(params, (list, tuple, np.ndarray)):
        # reshape(-1) rather than squeeze(): a single-parameter input stays
        # 1-D instead of collapsing to a 0-d array.
        values = np.asarray(params).reshape(-1)
        if values.size != len(_inputs):
            raise ValueError(
                f"Expected {len(_inputs)} parameter values for {list(_inputs)}, "
                f"got input of shape {np.shape(params)}; batched inputs are not supported."
            )
        params = dict(zip(_inputs, values))

    phot = _simulator(params)
    if return_type == "tensor":
        return torch.tensor(phot[np.newaxis, :], dtype=torch.float32).to(device)
    return phot


run_simulator(params)

### Model

We should be able to either train purely online, or start with a grid and then continue
with online training. First we will test pure online training.

In [None]:
# One photometry column per filter, plus a trailing "norm" column.
# NOTE(review): the simulator is built with normalize_method=None; confirm it
# actually produces a "norm" output.
photometry_names = simulator.instrument.filters.filter_codes + ["norm"]

fitter = SBI_Fitter(
    name="online_test3",
    simulator=run_simulator,
    parameter_names=inputs,
    raw_photometry_names=photometry_names,
)

Now we set the prior range `(low, high)` for each free parameter.

In [None]:
# (low, high) prior range per free parameter; ages in Myr.
priors = dict(
    redshift=(5.0, 10.0),
    log_mass=(7.0, 10.0),
    log10metallicity=(-3.0, 0.3),
    tau_v=(0.0, 1.5),
    peak_age=(0, 500),
    max_age=(500, 1000),
    tau=(0.3, 1.5),
)

In [None]:
# Pure online SNPE training over 7 rounds. Whether num_simulations counts per
# round or in total depends on SBI_Fitter — TODO confirm.
online_training_config = dict(
    engine="SNPE",
    learning_type="online",
    override_prior_ranges=priors,
    num_simulations=10_000,
    num_online_rounds=7,
)
fitter.run_single_sbi(**online_training_config)

In [None]:
# Inspect the loss curves from the online training run above.
fitter.plot_loss()

In [None]:
# NOTE(review): this validates a hard-coded posterior from an earlier run
# ("online_test", timestamp 20250512_174838), not the "online_test3" fitter
# trained above, and the file may not exist on a fresh run — confirm the
# intended path. The .pkl is presumably unpickled, so only load trusted files.
fitter.run_validation_from_file(
    f"{output_folder}/online_test/online_test_20250512_174838_posterior.pkl"
)

Optuna hyperparameter optimization

In [None]:
name = "online_optuna"

# Optuna search space: each value is a two-element [low, high] list.
# Presumably optimize_sbi samples each hyperparameter within these bounds
# (int vs. float, linear vs. log scale) — TODO confirm against SBI_Fitter.
suggested_hyperparameters = dict(
    learning_rate=[1e-6, 1e-3],
    hidden_features=[12, 200],
    num_components=[2, 16],
    training_batch_size=[32, 128],
    num_transforms=[1, 4],
    stop_after_epochs=[10, 30],
    clip_max_norm=[0.1, 5.0],
    validation_fraction=[0.1, 0.3],
    num_online_rounds=[5, 10],
    num_simulations=[8000, 30000],
)

# Settings held constant across every Optuna trial.
fixed_hyperparameters = dict(
    engine="SNPE",
    learning_type="online",
    override_prior_ranges=priors,
)

In [None]:
# 20 Optuna trials, 6 in parallel.
# NOTE(review): parallel trials share the single `simulator` object and
# `device`; confirm that is safe for concurrent use.
optuna_config = dict(
    study_name=name,
    n_jobs=6,
    n_trials=20,
    suggested_hyperparameters=suggested_hyperparameters,
    fixed_hyperparameters=fixed_hyperparameters,
)
fitter.optimize_sbi(**optuna_config)

In [None]:
# Load the saved "online_test3" model from the models folder into `fitter`;
# the trailing ';' suppresses the cell's output.
fitter.load_model_from_pkl(f"{output_folder}/online_test3");

In [None]:
# Coverage diagnostics for the loaded model with sample_method="vi"
# (presumably variational-inference sampling — confirm in SBI_Fitter).
fitter.plot_coverage(sample_method="vi")

Now let's try a much simpler model in which only metallicity and dust attenuation (`tau_v`) are free; all other parameters are held fixed.


In [None]:
# Only metallicity and dust attenuation (`tau_v`) are free; every other
# parameter is pinned through `fixed_params`.
inputs = "log10metallicity tau_v".split()

fixed_params = dict(
    redshift=7,
    log_mass=9.0,
    peak_age=100,
    max_age=800,
    tau=0.7,
)

priors = dict(
    log10metallicity=(-3.0, -1.3),
    tau_v=(0.0, 2),
)

# The earlier simulator configuration, with the non-free parameters pinned
# via `fixed_params`.
simulator_config = dict(
    sfh_model=sfh,
    zdist_model=zdist,
    grid=grid,
    instrument=instrument,
    emission_model=emission_model,
    emission_model_key="total",
    emitter_params=emitter_params,
    param_units=dict(peak_age=Myr, max_age=Myr),
    normalize_method=None,  # calculate_muv is a possible alternative
    output_type="photo_fnu",
    out_flux_unit="ABmag",
    fixed_params=fixed_params,
)
simulator = GalaxySimulator(**simulator_config)


def run_simulator(
    params, return_type="tensor", *, _simulator=simulator, _inputs=tuple(inputs)
):
    """Run the galaxy simulator for a single set of parameters.

    Parameters
    ----------
    params : dict, list, tuple, np.ndarray or torch.Tensor
        Parameter values. A dict is filtered down to the names in ``inputs``.
        A sequence/array/tensor must hold exactly one value per name in
        ``inputs`` (in that order); extra singleton axes such as
        ``(1, len(inputs))`` are allowed.
    return_type : str, optional
        ``"tensor"`` (default) returns the simulator output as a float32
        tensor on ``device`` with a leading axis of size 1 added. Any other
        value returns the raw simulator output.
    _simulator, _inputs
        Bound when this function is defined, so later cells that rebind the
        global ``simulator`` / ``inputs`` do not silently change it.

    Raises
    ------
    ValueError
        If an array-like input does not hold exactly one value per input name
        (for example a batch of several parameter sets).
    """
    if isinstance(params, torch.Tensor):
        # detach() so tensors that require grad can still be converted.
        params = params.detach().cpu().numpy()

    if isinstance(params, dict):
        params = {name: params[name] for name in _inputs}
    elif isinstance(params, (list, tuple, np.ndarray)):
        # reshape(-1) rather than squeeze(): a single-parameter input stays
        # 1-D instead of collapsing to a 0-d array.
        values = np.asarray(params).reshape(-1)
        if values.size != len(_inputs):
            raise ValueError(
                f"Expected {len(_inputs)} parameter values for {list(_inputs)}, "
                f"got input of shape {np.shape(params)}; batched inputs are not supported."
            )
        params = dict(zip(_inputs, values))

    phot = _simulator(params)
    if return_type == "tensor":
        return torch.tensor(phot[np.newaxis, :], dtype=torch.float32).to(device)
    return phot


run_simulator(params)

# One photometry column per filter, plus a trailing "norm" column.
photometry_names = simulator.instrument.filters.filter_codes + ["norm"]

fitter = SBI_Fitter(
    name="online_dust_zmet_test",
    simulator=run_simulator,
    parameter_names=inputs,
    raw_photometry_names=photometry_names,
)

In [None]:
# A single online SNPE round for the two-parameter model.
online_training_config = dict(
    engine="SNPE",
    learning_type="online",
    override_prior_ranges=priors,
    num_simulations=10_000,
    num_online_rounds=1,
)
fitter.run_single_sbi(**online_training_config)

In [None]:
# Loss curves for the metallicity + dust run.
fitter.plot_loss()

In [None]:
# Coverage diagnostics for the metallicity + dust posterior (default sample_method).
fitter.plot_coverage()

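What about fitting only the star formation history? Here only the SFH parameters (`peak_age`, `max_age`, `tau`) are free; redshift, mass, metallicity and dust are held fixed.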

In [None]:
# Only the SFH shape parameters are free; redshift, mass, metallicity and
# dust are pinned through `fixed_params`. Ages are in Myr.
inputs = "peak_age max_age tau".split()

fixed_params = dict(
    redshift=7,
    log_mass=9.0,
    log10metallicity=-2,
    tau_v=0.3,
)

priors = dict(
    peak_age=(0, 500),
    max_age=(500, 1000),
    tau=(0.1, 2),
)

# The earlier simulator configuration, with the non-SFH parameters pinned
# via `fixed_params`.
simulator_config = dict(
    sfh_model=sfh,
    zdist_model=zdist,
    grid=grid,
    instrument=instrument,
    emission_model=emission_model,
    emission_model_key="total",
    emitter_params=emitter_params,
    param_units=dict(peak_age=Myr, max_age=Myr),
    normalize_method=None,  # calculate_muv is a possible alternative
    output_type="photo_fnu",
    out_flux_unit="ABmag",
    fixed_params=fixed_params,
)
simulator = GalaxySimulator(**simulator_config)


def run_simulator(
    params, return_type="tensor", *, _simulator=simulator, _inputs=tuple(inputs)
):
    """Run the galaxy simulator for a single set of parameters.

    Parameters
    ----------
    params : dict, list, tuple, np.ndarray or torch.Tensor
        Parameter values. A dict is filtered down to the names in ``inputs``.
        A sequence/array/tensor must hold exactly one value per name in
        ``inputs`` (in that order); extra singleton axes such as
        ``(1, len(inputs))`` are allowed.
    return_type : str, optional
        ``"tensor"`` (default) returns the simulator output as a float32
        tensor on ``device`` with a leading axis of size 1 added. Any other
        value returns the raw simulator output.
    _simulator, _inputs
        Bound when this function is defined, so later cells that rebind the
        global ``simulator`` / ``inputs`` do not silently change it.

    Raises
    ------
    ValueError
        If an array-like input does not hold exactly one value per input name
        (for example a batch of several parameter sets).
    """
    if isinstance(params, torch.Tensor):
        # detach() so tensors that require grad can still be converted.
        params = params.detach().cpu().numpy()

    if isinstance(params, dict):
        params = {name: params[name] for name in _inputs}
    elif isinstance(params, (list, tuple, np.ndarray)):
        # reshape(-1) rather than squeeze(): a single-parameter input stays
        # 1-D instead of collapsing to a 0-d array.
        values = np.asarray(params).reshape(-1)
        if values.size != len(_inputs):
            raise ValueError(
                f"Expected {len(_inputs)} parameter values for {list(_inputs)}, "
                f"got input of shape {np.shape(params)}; batched inputs are not supported."
            )
        params = dict(zip(_inputs, values))

    phot = _simulator(params)
    if return_type == "tensor":
        return torch.tensor(phot[np.newaxis, :], dtype=torch.float32).to(device)
    return phot


# One photometry column per filter, plus a trailing "norm" column.
photometry_names = simulator.instrument.filters.filter_codes + ["norm"]

fitter = SBI_Fitter(
    name="online_sfh_lognorm_test",
    simulator=run_simulator,
    parameter_names=inputs,
    raw_photometry_names=photometry_names,
)

# A single online SNPE round for the SFH-only model.
online_training_config = dict(
    engine="SNPE",
    learning_type="online",
    override_prior_ranges=priors,
    num_simulations=10_000,
    num_online_rounds=1,
)
fitter.run_single_sbi(**online_training_config)

# Diagnostics for the SFH-only run: loss curves, then coverage (default sample_method).
fitter.plot_loss()

fitter.plot_coverage()