# Online Training

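A variation of SBI is Sequential Neural Posterior Estimation (SNPE), where the model is trained online, i.e., in multiple rounds. In the first round, parameters are drawn from the prior. In each subsequent round, parameters are drawn from the current posterior estimate for the target observation and passed through the simulator, and the model is updated with these new simulations. Later rounds therefore concentrate simulations in the region of parameter space that is consistent with the observation. This can make better use of the simulation budget, especially when the prior is broad and the posterior is narrow.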

However, the resulting model is no longer amortized: after training, it is specialized to the single observation that guided the rounds. A new model must therefore be trained for each new observation, which can become computationally expensive when many observations need to be analyzed.
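The round structure can be illustrated with a toy, fully self-contained sketch. This is not how synference implements SNPE. The prior, simulator and helper names are hypothetical, and the neural density estimator is replaced by a linear-Gaussian conditional density `q(theta | x) = N(a*x + b, s2)` fitted by least squares. After the first round, simulations come from the previous posterior estimate rather than the prior, so the density fitted in that round targets the posterior under the *proposal*. The sketch undoes this with an SNPE-A-style analytic correction (multiply by prior / proposal), which is exact here only because every distribution is Gaussian. Other SNPE variants handle this correction differently.

```python
import numpy as np

# Toy problem (hypothetical): broad Gaussian prior, narrow Gaussian noise,
# so the posterior is much narrower than the prior.
PRIOR_MEAN, PRIOR_VAR = 0.0, 10.0**2
NOISE_VAR = 0.5**2


def simulate(theta, rng):
    """Hypothetical simulator: x = theta + Gaussian noise."""
    return theta + rng.normal(0.0, np.sqrt(NOISE_VAR), size=theta.shape)


def fit_linear_gaussian(theta, x):
    """Fit q(theta | x) = N(a * x + b, s2) by least squares (Gaussian maximum likelihood)."""
    design = np.column_stack([x, np.ones_like(x)])
    (a, b), *_ = np.linalg.lstsq(design, theta, rcond=None)
    s2 = np.mean((theta - (a * x + b)) ** 2)
    return a, b, s2


def run_sequential(x_obs, num_rounds=4, num_simulations=1000, seed=0):
    """Multi-round posterior estimation for a single observation x_obs."""
    rng = np.random.default_rng(seed)
    prop_mean, prop_var = PRIOR_MEAN, PRIOR_VAR  # round 1 proposal = prior
    history = []
    for _ in range(num_rounds):
        theta = rng.normal(prop_mean, np.sqrt(prop_var), size=num_simulations)
        x = simulate(theta, rng)
        a, b, s2 = fit_linear_gaussian(theta, x)
        # Fitted density at x_obs: posterior under the proposal used this round.
        m_tilde, v_tilde = a * x_obs + b, s2
        # Analytic correction (Gaussian case): multiply by prior / proposal,
        # i.e. add the prior precision and subtract the proposal precision.
        post_prec = 1.0 / v_tilde + 1.0 / PRIOR_VAR - 1.0 / prop_var
        post_var = 1.0 / post_prec
        post_mean = post_var * (
            m_tilde / v_tilde + PRIOR_MEAN / PRIOR_VAR - prop_mean / prop_var
        )
        history.append((post_mean, post_var))
        # The corrected estimate becomes the proposal for the next round.
        prop_mean, prop_var = post_mean, post_var
    return history


x_obs = 3.0
for r, (m, v) in enumerate(run_sequential(x_obs), start=1):
    print(f"round {r}: mean={m:.3f}, sd={np.sqrt(v):.3f}")

# Analytic posterior of this conjugate toy model, for comparison.
exact_var = 1.0 / (1.0 / PRIOR_VAR + 1.0 / NOISE_VAR)
exact_mean = exact_var * (PRIOR_MEAN / PRIOR_VAR + x_obs / NOISE_VAR)
print(f"exact:   mean={exact_mean:.3f}, sd={np.sqrt(exact_var):.3f}")
```

On this toy problem, each round's estimate should land close to the analytic posterior printed at the end. Note that `run_sequential` takes `x_obs` as an input, so its result is only valid for that one observation. This is the loss of amortization described above.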

```python
from synference import SBI_Fitter, test_data_dir

fitter = SBI_Fitter.init_from_hdf5(
    model_name="test", hdf5_path=f"{test_data_dir}/example_model_library.hdf5"
)

fitter.create_feature_array();
```

Online training has to run new simulations in every round, so we first recreate the simulator from the grid data stored in the HDF5 file.

```python
fitter.recreate_simulator_from_library(
    override_library_path=f"{test_data_dir}/example_model_library.hdf5",
    override_grid_path="test_grid.hdf5",
);
```

Now we choose the observation on which to condition the rounds of online training. Here, we take one of the simulations from the model library (index 20) as a mock observation, and keep its true parameters for comparison.

```python
index = 20
sample = fitter.feature_array[index]  # features used as the mock observation
true_params = fitter.fitted_parameter_array[index]  # parameters that generated it

sample
```

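Now we can run our online SBI model. To do this, we set `learning_type` to `"online"`, specify the number of training rounds with `num_online_rounds`, and provide our chosen observation via `online_training_xobs`. We also set the number of simulations per round with `num_simulations`. The `engine` is set to `"SNPE"` to use Sequential Neural Posterior Estimation, but Sequential Neural Likelihood Estimation (SNLE) and Sequential Neural Ratio Estimation (SNRE) are also available for online training. Here we additionally override the prior range for `peak_age` with `override_prior_ranges`, and skip the built-in evaluation and plotting (`evaluate_model=False`, `plot=False`).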

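```python
fitter.run_single_sbi(
    online_training_xobs=sample,  # observation that guides each round
    learning_type="online",
    engine="SNPE",
    num_simulations=1000,  # simulations per round
    num_online_rounds=4,
    override_prior_ranges={"peak_age": (10, 1000)},  # custom prior range for peak_age
    evaluate_model=False,  # skip built-in evaluation
    plot=False,  # skip built-in plots
);
```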
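To make the cost trade-off concrete, assume, as stated above, that `num_simulations` counts simulations per round. The run above then spends `num_online_rounds × num_simulations = 4 × 1000 = 4000` simulations on this single observation. The sketch below compares that figure with a hypothetical amortized training set of 100,000 simulations. This number is purely illustrative and is not used anywhere else in this notebook. The helper names are also made up for the example.

```python
import math


def sequential_simulation_cost(num_online_rounds, num_simulations):
    """Total simulator calls for ONE observation, assuming num_simulations is per round."""
    return num_online_rounds * num_simulations


def break_even_observations(amortized_budget, num_online_rounds, num_simulations):
    """Smallest number of observations at which sequential training has used at least
    as many simulations as a single (hypothetical) amortized training set."""
    per_observation = sequential_simulation_cost(num_online_rounds, num_simulations)
    return math.ceil(amortized_budget / per_observation)


per_obs = sequential_simulation_cost(num_online_rounds=4, num_simulations=1000)
# Hypothetical amortized budget, for illustration only.
n_break_even = break_even_observations(100_000, num_online_rounds=4, num_simulations=1000)
print(per_obs, n_break_even)  # prints: 4000 25
```

This counts simulator calls only. The sequential approach also updates the network in every round, for every observation, whereas an amortized model is trained once and then reused.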

Finally, we check how the trained posterior performs on the observation it was conditioned on, by drawing posterior samples and plotting them.

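```python
# Draw samples from the posterior conditioned on our observation.
samples = fitter.sample_posterior(X_test=sample)

# Plot the posterior for the same observation; `y` is assumed to take the
# true parameters, which are shown for reference.
fitter.plot_posterior(
    X=sample,
    y=true_params,
    num_samples=1000,
)
```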
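A small numeric summary can complement the visual check. The helper below is a hypothetical sketch and is not part of synference. It assumes that `samples` can be converted to a NumPy array of shape `(num_samples, num_params)` and `true_params` to shape `(num_params,)`. For each parameter, it reports the median, a central credible interval, and whether the true value falls inside that interval. The demo uses synthetic, made-up samples. With the notebook's variables, you would call `summarize_posterior(np.asarray(samples), np.asarray(true_params))`.

```python
import numpy as np


def summarize_posterior(samples, true_params, level=0.68):
    """Per-parameter median, central credible interval, and truth coverage.

    Assumes samples has shape (num_samples, num_params) and
    true_params has shape (num_params,).
    """
    samples = np.asarray(samples, dtype=float)
    true_params = np.asarray(true_params, dtype=float)
    # Central interval: e.g. level=0.68 -> roughly the 16th and 84th percentiles.
    lower_q, upper_q = 50.0 * (1.0 - level), 50.0 * (1.0 + level)
    lower, median, upper = np.percentile(samples, [lower_q, 50.0, upper_q], axis=0)
    inside = (true_params >= lower) & (true_params <= upper)
    return median, lower, upper, inside


# Synthetic demo (made-up numbers): two parameters, truth inside the interval
# for the first and far outside it for the second.
rng = np.random.default_rng(1)
demo_samples = rng.normal(loc=[1.0, 5.0], scale=[0.1, 0.5], size=(2000, 2))
median, lower, upper, inside = summarize_posterior(demo_samples, [1.05, 7.0])
for i in range(len(median)):
    print(f"param {i}: median={median[i]:.3f}, "
          f"68% interval=[{lower[i]:.3f}, {upper[i]:.3f}], truth inside={inside[i]}")
```

Checking a single observation says little about calibration. However, a true value far outside the interval is a quick hint that the rounds did not converge for this observation.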