## Basic SBI Model Training

In this tutorial, we will walk through the process of training a simulation-based inference (SBI) model using the `synference` package. We will assume we already have a library of simulations and corresponding parameters.


First let's consider the training process more generally. The main steps involved in training an SBI model are:
1. **Prepare the Simulation Data**: Gather a set of simulations and their corresponding parameters.
2. **Choose a Model Architecture**: Select an appropriate neural network architecture for the SBI model.
3. **Define the Training Procedure**: Set up the training loop, loss function, and optimization algorithm.
4. **Train the Model**: Run the training process and monitor performance.
5. **Evaluate the Model**: Assess the trained model's performance on a validation set.

Now let's look at how to implement these steps using `synference`.


In [None]:
from synference import SBI_Fitter

From the output of the library generation tutorials, we should have an HDF5 file called `test_model_grid.hdf5` in our `grids/` directory. If you don't have this file, please refer to the [Library Generation](../library_gen/basic_library_generation.ipynb) tutorial.

We can use this file directly to instantiate an `SBI_Fitter`, the class that handles training and evaluating SBI models in `synference`.


In [None]:
fitter = SBI_Fitter.init_from_hdf5(model_name="test", hdf5_path="test_model_grid.hdf5")

# Feature and Parameter Arrays

Now this fitter has loaded the generated observations and parameters from the HDF5 file. Note that the data is not yet normalized or set up with the correct features for training. We will handle that in the next steps.

We can see the names of the observations.

In [None]:
print(fitter.raw_observation_names)

The names of the parameters:

In [None]:
print(fitter.parameter_names)

and any associated parameter units:

In [None]:
print(fitter.parameter_units)

The actual array itself is stored in the `parameter_array` attribute.

In [None]:
print(fitter.parameter_array)

The raw observations are stored in the same way, in the `raw_observation_grid` attribute:

In [None]:
print(fitter.raw_observation_grid)

The first step is to turn this raw grid of photometric observations into a set of features that can be used for training. This is done with the `fitter.create_feature_array` method.

This method handles the following tasks:
1. Normalizing the observations (e.g., converting magnitudes to fluxes, normalizing by a reference band, etc.)
2. Creating features from the observations (e.g., colors, ratios, etc.)
3. Removing photometric bands in the library from the feature array that are not present in the observations.
4. Handling missing data (e.g., setting features to NaN if any of the required bands are missing)
5. Adding additional features (e.g., redshift) from the parameter array to the feature array.
6. Adding realistic noise to the features based on a provided noise model (see the [Noise Models](../noise_modelling/noise_models.ipynb) tutorial for more details).
7. Adding photometric uncertainties to the feature array.


The default configuration of this method does not perform all of these steps, however. By default, all photometric bands are kept, no additional features are added, and no noise is added. The default normalization only converts the raw photometry array to AB magnitudes.
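
To make the AB magnitude conversion concrete, here is a minimal sketch that applies the standard AB zero point (3631 Jy) to a few fluxes. It assumes the fluxes are in nJy, which may not match the units stored in your library, and `flux_njy_to_ab_mag` is a hypothetical helper written for this illustration, not part of `synference`.

```python
import numpy as np

AB_ZERO_POINT_JY = 3631.0  # flux density of a 0 mag source in the AB system


def flux_njy_to_ab_mag(flux_njy):
    """Convert flux densities in nJy to AB magnitudes (hypothetical helper).

    Non-positive fluxes have no defined magnitude and are returned as NaN.
    """
    flux_jy = np.asarray(flux_njy, dtype=float) * 1e-9
    mags = np.full(flux_jy.shape, np.nan)
    positive = flux_jy > 0
    mags[positive] = -2.5 * np.log10(flux_jy[positive] / AB_ZERO_POINT_JY)
    return mags


# Assumed units: nJy. The last entry is negative, as can happen in noisy data.
example_fluxes = np.array([3631.0e9, 1.0, 100.0, -5.0])
example_mags = flux_njy_to_ab_mag(example_fluxes)
print(example_mags)
```

A flux of 1 nJy corresponds to roughly 31.4 AB mag, and each factor of 100 in flux is 5 magnitudes brighter. Magnitudes are undefined for non-positive fluxes, which is one reason the choice of normalization matters for noisy photometry.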

We call the method below and we can see it prints information about the features it creates.

In [None]:
fitter.create_feature_array();

We will proceed with the default configuration for now. More advanced configurations will be covered in later tutorials. Using different normalizations/units or adding additional features can have a significant impact on the performance of the trained SBI model.

Before we do any fitting, we can inspect the feature and parameter arrays to see the distribution of the data.

First, we can look at the feature array and see the distribution of the photometry given our model and feature array configuration. The figure below shows a histogram of each feature in the feature array.

In [None]:
fitter.plot_histogram_feature_array(bins=20);

Second, we can look at the parameter array and see the distribution of the parameters given our model and parameter configuration. The figure below shows a histogram of each parameter in the parameter array. We can see that the parameters are uniformly distributed, as expected from our library generation configuration.

In [None]:
fitter.plot_histogram_parameter_array();

# Training an SBI Model

SBI model training is handled with the `fitter.run_single_sbi` method. This method handles the following tasks:
1. Creating a prior from the parameter array.
2. Setting up the neural density estimator (NDE) for the SBI model.
3. Training the SBI model.
4. Saving the trained model to disk.
5. Plotting diagnostics of the trained model.


We will cover the various options for different SBI configurations in later tutorials. For now, we will proceed with the default configuration.

`synference` is built on top of the LtU-ILI package, which uses `sbi` and `lampe` for the underlying SBI functionality. The default NDE is a Masked Autoregressive Flow (MAF) from the `sbi` package. The default prior proposal is a uniform prior over the range of the parameters in the parameter array.
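
To illustrate what a uniform prior over the range of the parameters means, the sketch below builds box bounds from a toy parameter array and draws from them with NumPy. The toy array, its `(n_simulations, n_parameters)` layout, and the parameter ranges are all assumptions made for this example; the sketch does not reproduce how `synference` constructs its prior internally.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for a parameter array: 500 simulations, 3 parameters.
# Assumed layout: one row per simulation, one column per parameter.
toy_parameters = rng.uniform(
    low=[0.0, 8.0, -2.0], high=[5.0, 12.0, 0.5], size=(500, 3)
)

# Bounds of a box-uniform prior spanning the sampled range of each parameter.
prior_low = toy_parameters.min(axis=0)
prior_high = toy_parameters.max(axis=0)

# Draw from the prior: each parameter is independent and uniform in its range.
prior_draws = rng.uniform(prior_low, prior_high, size=(1000, 3))
print(prior_low, prior_high)
```

Because the bounds come from the simulated samples themselves, the prior has no support outside the region covered by the library.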

In [None]:
?fitter.run_single_sbi

The primary arguments to the `fitter.run_single_sbi` method are:
- `train_test_fraction`: The fraction of the data to use for training. The rest is held out for testing. The default is 0.8.
- `validation_fraction`: The fraction of the training data to use for validation during training. The default is 0.2.
- `backend`: The backend to use for training. Either `sbi` or `lampe`. The default is `sbi`.
- `hidden_features`: The number of hidden features in the NDE. The default is 50.
- `num_components/transforms`: The number of components or transforms in the NDE. The default is 4.
- `training_batch_size`: The batch size for training. The default is 64.
- `stop_after_epochs`: The number of epochs with no improvement to stop training. The default is 15.
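
To see how the two fractions interact, the following sketch works through the arithmetic of the split using the defaults quoted above. `split_indices` is a hypothetical helper, and the shuffling and rounding shown here are assumptions; the actual splitting inside `synference` and `sbi` may differ in detail.

```python
import numpy as np


def split_indices(n_simulations, train_test_fraction=0.8,
                  validation_fraction=0.2, seed=42):
    """Split simulation indices into train, validation and test sets.

    Hypothetical helper: the test set is held out first, then a fraction
    of the remaining training data is reserved for validation.
    """
    rng = np.random.default_rng(seed)
    shuffled = rng.permutation(n_simulations)
    n_train_total = int(round(train_test_fraction * n_simulations))
    train_total, test = shuffled[:n_train_total], shuffled[n_train_total:]
    n_validation = int(round(validation_fraction * n_train_total))
    validation, train = train_total[:n_validation], train_total[n_validation:]
    return train, validation, test


train_idx, validation_idx, test_idx = split_indices(10_000)
print(len(train_idx), len(validation_idx), len(test_idx))  # 6400 1600 2000
```

With these defaults, only 64% of the library is used for gradient updates, which is worth keeping in mind when deciding how many simulations to generate.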

There are other arguments to turn plotting, model saving, validation, etc. on or off. See the docstring for more details.

Now we will run the training, which prints a lot of output. We set `name_append` to 'test_1' so that the trained model is saved with a unique name. If left as the default, a timestamp will be used.

In [None]:
posterior_model, stats = fitter.run_single_sbi(
    name_append="test_1", random_seed=42, hidden_features=256, num_components=64
)


The first part of the output shows the data being split into training and test sets, followed by the creation of the prior from the parameter grid, with the range of each parameter.

The next part shows the creation of the neural density estimator (NDE), here a mixture density network (MDN). The number of mixture components is set by `num_components`, which we passed as 64 in the call above. The model is created using the `sbi` package, which is built on top of PyTorch.

Then the actual training happens: the training epochs increment until the model stops improving on the validation set. Training stops after 15 epochs with no improvement, which is the default value of `stop_after_epochs`.
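
The stopping rule can be sketched in a few lines of plain Python. This is a generic patience-based early stopping loop written for illustration, with `epochs_until_stop` as a hypothetical helper and a made-up loss curve; it is not the training loop used by `sbi`.

```python
def epochs_until_stop(validation_losses, stop_after_epochs=15):
    """Return the number of epochs run before early stopping triggers."""
    best_loss = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(validation_losses, start=1):
        if loss < best_loss:
            # New best validation loss: reset the patience counter.
            best_loss = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        if epochs_without_improvement >= stop_after_epochs:
            return epoch
    # Patience never ran out within the supplied loss curve.
    return len(validation_losses)


# Toy loss curve: improves for 10 epochs, then plateaus at a worse value.
toy_losses = [1.0 / epoch for epoch in range(1, 11)] + [0.2] * 40
stopped_at = epochs_until_stop(toy_losses, stop_after_epochs=15)
print(stopped_at)  # 25
```

In this toy curve the best loss is reached at epoch 10, so training stops 15 epochs later at epoch 25. A larger `stop_after_epochs` tolerates longer plateaus at the cost of extra training time.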

The model is pickled and saved to the output directory for this model, which is `models/test_model/` by default. A summary of the trained model is saved as a .json file in the same directory. The configuration of the fitter, including the feature and parameter configuration used for training, is also pickled and saved there. A model can be reloaded later using the `fitter.load_model_from_pkl` method. To save in a different format, change the `save_method` argument to e.g. `torch` or `hickle`.

Now we have a trained model. The validation diagnostics are then run, which include:
1. A posterior corner plot for a random observation from the test set.
2. A loss plot which shows the training and validation loss over epochs.
3. A coverage plot which shows how well the credible intervals of the posterior match the true parameters.
4. A ranks histogram which shows how well the posterior samples match the true parameters.
5. A log-probability plot which shows the log probability of the true parameters under the posterior.
6. A true vs. predicted plot which shows the true parameters vs. the maximum a posteriori (MAP) estimate from the posterior.


These plots are shown in the output, and also saved to the `plots/` directory in the output folder for this model.
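
To build intuition for the coverage and rank diagnostics, the sketch below computes both for a toy Gaussian problem where the exact posterior is known, so the result is calibrated by construction. Everything here (the toy model, the array shapes, the 68% interval) is an assumption for illustration, and the code does not reproduce how the plots above are generated.

```python
import numpy as np

rng = np.random.default_rng(1)
n_observations, n_samples = 2000, 500

# Toy conjugate-Gaussian problem: theta ~ N(0, 1), x = theta + N(0, 1),
# so the exact posterior is N(x / 2, variance 1 / 2).
theta_true = rng.normal(0.0, 1.0, size=n_observations)
x_obs = theta_true + rng.normal(0.0, 1.0, size=n_observations)
posterior_samples = rng.normal(
    loc=x_obs[:, None] / 2.0,
    scale=np.sqrt(0.5),
    size=(n_observations, n_samples),
)

# Rank of each true value among its posterior samples, scaled to [0, 1].
# For a calibrated posterior these ranks are uniformly distributed.
ranks = (posterior_samples < theta_true[:, None]).mean(axis=1)

# Fraction of true values inside the central 68% credible interval.
lower, upper = np.percentile(posterior_samples, [16, 84], axis=1)
coverage_68 = np.mean((theta_true >= lower) & (theta_true <= upper))
print(f"mean rank: {ranks.mean():.3f}, 68% coverage: {coverage_68:.3f}")
```

For a well-calibrated posterior the empirical coverage sits close to the nominal level and the rank histogram is flat. Coverage below the nominal level suggests overconfident posteriors, while coverage above it suggests underconfident ones.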

# Loading a Trained Model

We can load a trained model into an existing `SBI_Fitter` instance using the `fitter.load_model_from_pkl` method. This method takes the path to the pickled model file as an argument.

If only one model is present in the directory, we can simply provide the directory path and the method will find the model file automatically. If multiple models are present, we can provide the full path to the model file.

In [None]:
fitter.load_model_from_pkl("test/test_test_1_posterior.pkl");

Alternatively, we can create a new `SBI_Fitter` instance and load the model into that instance, using the class method `load_saved_model`. This method takes the path to the pickled model file as an argument, and returns a new `SBI_Fitter` instance with the model loaded.

In [None]:
new_fitter = SBI_Fitter.load_saved_model("test/test_test_1_posterior.pkl")

# Plotting model loss

We can plot the model loss using the `fitter.plot_loss` method. This method will create a plot of the training and validation loss over epochs, and save it to the `plots/` directory in the output folder for this model. By default, it will not overwrite existing plots, but you can change this with the `overwrite` argument.

In [None]:
fitter.plot_loss(overwrite=True);

# Plotting validation metrics

Although this happens automatically during training, we can also plot the validation metrics of a trained model using the `fitter.plot_diagnostics` method. You can provide your own validation set, or by default it will use the test set from the last training run. By default, it will not overwrite existing plots in the `plots/` directory, but you can change this with the `overwrite` argument.

In [None]:
fitter.plot_diagnostics();

# Getting model metrics

We can print and save metrics of the trained model using the `fitter.evaluate_model` method. This method will print the metrics to the console, and also save them to a .json file in the output directory for this model. The metrics include:
- TARP (Tests of Accuracy with Random Points)
- Log DPIT (Logarithmic Deviation of the Probability Integral Transform)
- Mean Log Probability
- Parameter-specific metrics (MSE, RMSE, Mean Absolute Error, Median Absolute Error, R-squared, Normalized RMSE)
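
As a reference for the parameter-specific metrics, the sketch below computes them with NumPy for one parameter from arrays of true and predicted values. `point_estimate_metrics` is a hypothetical helper, and normalizing the RMSE by the range of the true values is an assumption, since other conventions (mean or standard deviation) are also common and the source does not state which one is used.

```python
import numpy as np


def point_estimate_metrics(truth, predicted):
    """Point-estimate metrics for a single parameter (hypothetical helper)."""
    truth = np.asarray(truth, dtype=float)
    predicted = np.asarray(predicted, dtype=float)
    residuals = predicted - truth
    mse = np.mean(residuals**2)
    rmse = np.sqrt(mse)
    total_sum_squares = np.sum((truth - truth.mean()) ** 2)
    return {
        "mse": mse,
        "rmse": rmse,
        "mean_abs_error": np.mean(np.abs(residuals)),
        "median_abs_error": np.median(np.abs(residuals)),
        "r_squared": 1.0 - np.sum(residuals**2) / total_sum_squares,
        # Assumption: normalize by the range of the true values.
        "normalized_rmse": rmse / (truth.max() - truth.min()),
    }


toy_metrics = point_estimate_metrics([1.0, 2.0, 3.0, 4.0], [1.5, 2.0, 2.5, 4.0])
print(toy_metrics)
```

These metrics only assess a point estimate of each parameter. They say nothing about whether the posterior widths are correct, which is what the calibration metrics such as TARP address.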


In [None]:
fitter.evaluate_model();

# Posterior Samples

We can sample the posterior for a given observation using the `fitter.sample_posterior` method. This method takes an observation, or a set of observations, as an argument, and returns samples from the posterior distribution. If no observation is provided, it will draw posterior samples for all observations in the test set.

In [None]:
fitter.sample_posterior()
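
Posterior samples are usually reduced to a point estimate and a credible interval per parameter. The sketch below does this with percentiles on a toy array; the `(n_samples, n_parameters)` shape, the parameter names, and the values are assumptions for illustration, so check the shape of the array actually returned by `fitter.sample_posterior` before applying it.

```python
import numpy as np

rng = np.random.default_rng(7)

# Toy stand-in for posterior samples of one observation.
# Assumed shape: (n_samples, n_parameters).
toy_samples = rng.normal(loc=[2.0, 9.5], scale=[0.1, 0.3], size=(5000, 2))
toy_names = ["redshift", "log_mass"]  # hypothetical parameter names

# Median and central 68% credible interval for each parameter.
p16, p50, p84 = np.percentile(toy_samples, [16, 50, 84], axis=0)
for name, low, mid, high in zip(toy_names, p16, p50, p84):
    print(f"{name}: {mid:.2f} (+{high - mid:.2f} / -{mid - low:.2f})")
```

The median with 16th and 84th percentiles is a convenient summary, but it can hide multimodal or strongly skewed posteriors, so it is worth inspecting a corner plot as well.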

# Next Steps

In the next tutorials, we will cover more advanced configurations for training SBI models, including different feature and parameter configurations, different NDEs, and different prior proposals. We will also cover how to use the trained models for inference on real data.