# Custom Training Loop

Synference implements a custom training loop (in the `custom_runner.py` submodule) that offers more flexibility than the built-in training loops of `sbi` and `LtU-ILI`. It supports advanced training strategies, such as alternative optimizers (e.g. AdamW), and several quality-of-life features, such as caching the model during training so that progress is not lost if a run is interrupted.
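Synference's own caching mechanism is not described on this page. As a generic illustration of how caching protects against interruptions, the sketch below saves the training state after every epoch and resumes from the last saved epoch on restart. The helper names (`save_checkpoint`, `load_checkpoint`, `train`) are hypothetical. `pickle` stands in for whatever serializer a real model would use (for PyTorch models, typically `torch.save`), and each write goes to a temporary file followed by `os.replace`, so an interruption mid-write cannot corrupt the previous checkpoint.

```python
import os
import pickle
import tempfile
from typing import Any, Dict, Optional


def save_checkpoint(state: Dict[str, Any], path: str) -> None:
    """Atomically write `state` to `path` (hypothetical helper)."""
    directory = os.path.dirname(os.path.abspath(path))
    # Write to a temporary file in the same directory first...
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(state, f)
        # ...then swap it into place; os.replace is atomic on the same filesystem.
        os.replace(tmp_path, path)
    finally:
        # Remove the temporary file if anything failed before the swap.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def load_checkpoint(path: str) -> Optional[Dict[str, Any]]:
    """Return the saved state, or None if no checkpoint exists yet."""
    if not os.path.exists(path):
        return None
    with open(path, "rb") as f:
        return pickle.load(f)


def train(n_epochs: int, path: str) -> Dict[str, Any]:
    """Toy training loop that resumes from the last completed epoch."""
    state = load_checkpoint(path) or {"epoch": 0, "losses": []}
    for epoch in range(state["epoch"], n_epochs):
        loss = 1.0 / (epoch + 1)  # placeholder for one real epoch of training
        state["losses"].append(loss)
        state["epoch"] = epoch + 1
        save_checkpoint(state, path)  # progress is saved after every epoch
    return state
```

If a call such as `train(5, path)` is interrupted, calling it again continues from the first epoch that was not yet saved instead of starting over.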

Crucially, the custom training loop is also integrated directly with Optuna for hyperparameter optimization, so hyperparameter searches can be run as part of training. The loop reports intermediate training progress to Optuna, which lets users monitor the performance of different hyperparameter configurations in real time and lets Optuna prune unpromising trials based on those intermediate results.
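To make pruning based on intermediate results concrete, here is a simplified, dependency-free sketch of the median stopping rule applied by Optuna's `MedianPruner` (the pruner selected with `type: Median` in the optimization configuration on this page). It assumes a metric that is maximized, such as `log_prob`, and reuses the `n_startup_trials`, `n_warmup_steps` and `n_min_trials` names and values from that configuration; `interval_steps` is ignored for brevity. It illustrates the rule and is not Optuna's or Synference's code.

```python
import statistics
from typing import List


def should_prune_median(
    current: List[float],
    completed: List[List[float]],
    n_startup_trials: int = 10,
    n_warmup_steps: int = 30,
    n_min_trials: int = 10,
) -> bool:
    """Simplified median stopping rule for a metric that is maximized.

    current: values reported so far by the running trial, one per step.
    completed: intermediate-value curves of trials that finished normally.
    """
    if len(completed) < n_startup_trials:
        return False  # too few finished trials to compare against yet
    step = len(current) - 1
    if step < n_warmup_steps:
        return False  # let every trial train for a while before judging it
    # Values that finished trials reported at this same step.
    peers = [curve[step] for curve in completed if len(curve) > step]
    if len(peers) < n_min_trials:
        return False  # not enough reference values at this step
    # Prune if even the trial's best value so far is below the median.
    return max(current) < statistics.median(peers)
```

For example, with ten finished trials whose curves sit flat at 0, 1, ..., 9, the median at every step is 4.5, so a running trial whose best value so far is 1.0 is pruned once it is past the 30 warm-up steps, while a trial at 8.0 keeps training.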

The interface to the custom training loop is still in beta; it will be stabilized and brought into line with the rest of Synference's API in future releases.

Currently, the custom training loop is configured via a YAML file that specifies the training parameters, optimizer settings, and other options. An example configuration file is provided below:

```yaml
train_args:
  skip_optimization: True
  validation_fraction: 0.1
  fixed_params:
    model_choice: "nsf"
    optimizer_choice: "Adam"
    learning_rate: 0.0007460108070908076
    training_batch_size: 79
    stop_after_epochs: 57
    clip_max_norm: 6.656577606872957
    nsf_hidden_features: 30
    nsf_num_transforms: 14
```

Because `skip_optimization` is `True`, this configuration does not run hyperparameter optimization; it trains a single model with the hyperparameters given under `fixed_params`. To perform hyperparameter optimization instead, use a configuration such as the following:

```yaml
train_args:
  skip_optimization: False
  validation_fraction: 0.1
  optuna:
    n_trials: 50
    build_final_model: False
    objective:
      metric: 'log_prob'
    study:
      study_name: ""
      storage: ""
      direction: 'maximize'
      load_if_exists: True
    pruner:
      type: Median
      n_startup_trials: 10
      n_warmup_steps: 30
      interval_steps: 10
      n_min_trials: 10
      #max_resource: 1000
      #reduction_factor: 3
      #min_resource: 10
      #bootstrap_count: 10
    search_space:
      model_choice: ["mdn", "nsf"] # Must be a list
      optimizer_choice: ["AdamW", "Adam"] # Must be a list
      learning_rate:
        type: "float"
        low: 1.0e-6 # decimal point keeps YAML 1.1 parsers (e.g. PyYAML) from reading this as a string
        high: 5.0e-2
        log: True
      training_batch_size:
        type: "int"
        low: 32
        high: 256
      stop_after_epochs:
        type: "int"
        low: 10
        high: 60
      clip_max_norm:
        type: "float"
        low: 0.1
        high: 10.0
      models:
        nsf: 
          hidden_features:
            type: "int"
            low: 10
            high: 100
          num_transforms:
            type: "int"
            low: 3
            high: 128
        mdn: 
          hidden_features:
            type: "int"
            low: 10
            high: 200
          num_components:
            type: "int"
            low: 10
            high: 600
```

With `skip_optimization: False`, this configuration runs 50 trials (`n_trials`) of hyperparameter optimization over the specified search space, using Optuna's median pruner to stop unpromising trials early. The objective metric is `'log_prob'` and the study direction is `'maximize'`, so the optimization aims to maximize the log probability of the held-out validation data (`validation_fraction: 0.1`) under the trained model.
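The `search_space` section maps naturally onto Optuna's `trial.suggest_categorical`, `trial.suggest_float` and `trial.suggest_int` methods. The sketch below shows one plausible translation, assuming that a bare list means a categorical choice and that the hyperparameters under `models` are sampled only for the architecture picked by `model_choice`, prefixed with the model name (mirroring keys such as `nsf_hidden_features` in the fixed-parameter example). The function names are hypothetical and Synference's actual parser may differ. The `trial` argument is duck-typed, so an `optuna.trial.Trial` can be passed in directly.

```python
from typing import Any, Dict


def suggest_from_space(trial: Any, space: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Sample one value for each entry of a flat search-space dict."""
    params: Dict[str, Any] = {}
    for name, spec in space.items():
        key = f"{prefix}{name}"
        kind = spec.get("type") if isinstance(spec, dict) else None
        if isinstance(spec, list):
            # A bare list such as ["mdn", "nsf"] is a categorical choice.
            params[key] = trial.suggest_categorical(key, spec)
        elif kind == "float":
            # float() also accepts strings like "1e-6", which YAML 1.1 loaders may produce.
            params[key] = trial.suggest_float(
                key, float(spec["low"]), float(spec["high"]), log=bool(spec.get("log", False))
            )
        elif kind == "int":
            params[key] = trial.suggest_int(
                key, int(spec["low"]), int(spec["high"]), log=bool(spec.get("log", False))
            )
        else:
            raise ValueError(f"Unsupported search-space entry for {key!r}: {spec!r}")
    return params


def sample_hyperparameters(trial: Any, search_space: Dict[str, Any]) -> Dict[str, Any]:
    """Sample shared hyperparameters, then those of the chosen model only."""
    shared = {k: v for k, v in search_space.items() if k != "models"}
    params = suggest_from_space(trial, shared)
    chosen = params.get("model_choice")
    model_spaces = search_space.get("models", {})
    if chosen in model_spaces:
        # Conditional search space: e.g. num_components is only sampled for "mdn".
        params.update(suggest_from_space(trial, model_spaces[chosen], prefix=f"{chosen}_"))
    return params
```

Sampling the model-specific hyperparameters only after `model_choice` is known is Optuna's define-by-run style of conditional search space: an `mdn` trial never records an `nsf_num_transforms` value, and vice versa.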

If you parallelize the optimization across multiple nodes, you will probably want to use an external SQL database that supports concurrent connections, such as PostgreSQL or MySQL. Provide its connection string either in the `storage` field of the `study` section of the configuration file or directly via the `sql_db_path` argument of the `run_single_sbi()` method:

```python
fitter.run_single_sbi(
    ...,  # other arguments elided
    custom_config_yaml="path/to/custom_config.yaml",
    sql_db_path="mysql+pymysql://root:password@url:port/study_name",
)
```
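The connection string follows SQLAlchemy's `dialect+driver://user:password@host:port/database` URL format, which Optuna's relational-database storage uses. Credentials containing characters such as `@`, `:` or `/` must be percent-encoded, which SQLAlchemy's documentation recommends doing with `urllib.parse.quote_plus`. A small hypothetical helper (not part of Synference) that builds such a string:

```python
from urllib.parse import quote_plus


def make_storage_url(dialect_driver: str, user: str, password: str,
                     host: str, port: int, database: str) -> str:
    """Build an SQLAlchemy-style URL for Optuna storage (hypothetical helper)."""
    # Percent-encode credentials so characters like '@' or ':' don't break URL parsing.
    return (
        f"{dialect_driver}://{quote_plus(user)}:{quote_plus(password)}"
        f"@{host}:{port}/{database}"
    )
```

The result can be passed as `sql_db_path` or placed in the `storage` field; every worker should use the same URL (and study name) so that all of them contribute trials to one shared study.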

We recommend the Optuna dashboard for monitoring the progress of hyperparameter optimization. Launch it by passing the same storage URL used for the study:

```bash
optuna-dashboard your_sql_link
```