In [None]:
%load_ext autoreload
%autoreload 2
import os

from sbifitter import SBI_Fitter

# Locate the project's data folders two directory levels above this file:
# `grids/` holds the HDF5 grids loaded below and `models/` holds saved models.
# `__file__` only exists when this code runs as a script/module; a Jupyter
# kernel does not define it (NameError), so fall back to the kernel's current
# working directory, which is typically the notebook's own folder.
try:
    file_path = os.path.dirname(os.path.realpath(__file__))
except NameError:
    file_path = os.path.realpath(os.getcwd())

project_root = os.path.dirname(os.path.dirname(file_path))
grid_folder = os.path.join(project_root, "grids")
output_folder = os.path.join(project_root, "models")

# Grid file loaded by the first fitter below.
grid_path = os.path.join(
    grid_folder, "grid_Pop_II_LogNormal_SFH_5_z_12_logN_5.0_BPASS_Chab_v1.hdf5"
)

In [None]:
# First experiment: build a fitter from `grid_path` on the CPU, then derive its
# feature array from the raw photometry (empty extra_features list,
# normalize_method=None). The trailing `;` stops Jupyter from echoing the
# return value of the last call.
first_run_name = "BPASS_Chab_LogNorm_5_z_12_optimize"
first_run_device = "cpu"

fitter = SBI_Fitter.init_from_hdf5(
    first_run_name, grid_path, return_output=False, device=first_run_device
)
fitter.create_feature_array_from_raw_photometry(
    extra_features=[], normalize_method=None
);

In [None]:
# Hyperparameter search for the first fitter. Each entry appears to be a
# [low, high] range explored by the study (TODO: confirm against optimize_sbi).
# Key order is kept as-is in case the optimizer suggests values in dict order.
first_study_search_space = {
    "learning_rate": [1e-5, 1e-2],
    "hidden_features": [12, 200],
    "num_components": [2, 16],
    "training_batch_size": [32, 128],
    "num_transforms": [1, 4],
    "stop_after_epochs": [10, 30],
}

# NOTE(review): study_name ends in "optimize2" while the fitter was created as
# "..._optimize" in the previous cell; confirm that this is intentional.
# NOTE(review): n_jobs=7 presumably runs trials in parallel; adjust to the
# machine's core count.
fitter.optimize_sbi(
    study_name="BPASS_Chab_LogNorm_5_z_12_optimize2",
    suggested_hyperparameters=first_study_search_space,
    n_trials=40,
    n_jobs=7,
    verbose=True,
)

In [None]:
# Call sample_posterior() with its default arguments (per its name, presumably
# drawing samples from the fitter's current posterior); any return value is
# shown as the cell output.
# NOTE(review): this runs before the model pickle is loaded in the next cell,
# so it presumably uses whatever model optimize_sbi left on the fitter; confirm.
fitter.sample_posterior()

In [None]:
# Load a saved model (a pickle, per the method name) from the `models/` output
# folder. os.path.join keeps the path portable instead of hard-coding "/".
# NOTE(review): the file name has no ".pkl" extension and does not match the
# study names used above; confirm this is the intended file and that
# load_model_from_pkl expects the bare name.
fitter.load_model_from_pkl(
    os.path.join(output_folder, "BPASS_Chab_LogNorm_5_z_12_phot_grid2_redshift")
)

### Another example: a two-objective optimization (maximize `log_prob`, minimize TARP) with an NSF model from the `lampe` backend

In [None]:
# Second experiment: a different grid file (logN 4.0, "Calzetti" in its name)
# loaded into a new fitter. This rebinds `grid_path` and `fitter`, so the
# first experiment's fitter is no longer reachable from here on.
grid_path = os.path.join(
    grid_folder,
    "grid_Pop_II_LogNormal_SFH_5_z_12_logN_4.0_BPASS_Chab_Calzetti_v1.hdf5",
)

fitter = SBI_Fitter.init_from_hdf5(
    "BPASS_Chab_LogNorm_5_z_12_optimize_v2",
    grid_path,
    return_output=False,
    # Presumably a torch device string; needs a CUDA GPU, otherwise use "cpu".
    device="cuda:0",
)

In [None]:
# Build the second fitter's feature array from the raw photometry, with an
# empty extra_features list and normalize_method=None (same settings as the
# first experiment). The trailing `;` suppresses the cell's output.
fitter.create_feature_array_from_raw_photometry(extra_features=[], normalize_method=None);

In [None]:
# Hyperparameters explored by the second study; each entry appears to be a
# [low, high] range (TODO: confirm against optimize_sbi). Key order is kept.
second_study_search_space = {
    "learning_rate": [1e-5, 1e-3],
    "hidden_features": [12, 200],
    "num_transforms": [2, 16],
    "training_batch_size": [32, 128],
    "stop_after_epochs": [10, 30],
}

# Held constant across every trial: an NSF model from the lampe backend with
# n_nets=1 (presumably a single network rather than an ensemble).
second_study_fixed_hyperparameters = {
    "model_type": "nsf",
    "backend": "lampe",
    "n_nets": 1,
}

# Two-objective study: score_metrics and direction are parallel lists,
# presumably pairing log_prob with "maximize" and tarp with "minimize".
# persistent_storage=True presumably keeps the study after the kernel exits.
fitter.optimize_sbi(
    study_name="BPASS_Chab_LogNorm_5_z_12_optimize_v2",
    suggested_hyperparameters=second_study_search_space,
    n_trials=40,
    n_jobs=1,
    verbose=True,
    score_metrics=["log_prob", "tarp"],
    direction=["maximize", "minimize"],
    persistent_storage=True,
    timeout_minutes_trial_sampling=15,
    fixed_hyperparameters=second_study_fixed_hyperparameters,
)