# Tutorial 8: Hyperparameter Optimization

To automatically tune the hyperparameters of a `synthcity` plugin so that it generates more realistic data, we can use hyperparameter optimization (HPO) algorithms such as Tree-structured Parzen estimators (TPE), Bayesian optimization, and genetic programming. In this tutorial we use `optuna`, a popular HPO library whose default sampler implements TPE, to tune the hyperparameters of the `tvae` plugin on the scikit-learn diabetes dataset.

This tutorial requires the third-party library `plotly`, which `optuna.visualization` uses to draw its plots. It is not a synthcity dependency because this tutorial is the only place it is needed, so in addition to installing synthcity you will need to run `pip install plotly`.

In [None]:
!pip install synthcity
!pip install plotly
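# Optional: remove torchaudio/torchdata if they are preinstalled (e.g. on Colab),
# as they can conflict with the torch version that synthcity installs.
!pip uninstall -y torchaudio torchdata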

In [None]:
# stdlib
import sys
import warnings

# third party
import optuna
from sklearn.datasets import load_diabetes

# synthcity absolute
import synthcity.logger as log
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import GenericDataLoader

log.add(sink=sys.stderr, level="INFO")
warnings.filterwarnings("ignore")

## Load the dataset

In [None]:
X, y = load_diabetes(return_X_y=True, as_frame=True)
X["target"] = y
X

In [None]:
loader = GenericDataLoader(
    X,
    target_column="target",
    sensitive_columns=["sex"],
)
train_loader, test_loader = loader.train(), loader.test()
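
`loader.train()` and `loader.test()` return two disjoint subsets of the rows. As an illustration only, the sketch below performs a comparable split with scikit-learn's `train_test_split`; the 80/20 ratio and the fixed `random_state` are our assumptions, not necessarily what `GenericDataLoader` does internally.

```python
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split

X, y = load_diabetes(return_X_y=True, as_frame=True)
X["target"] = y  # 10 feature columns + the target

# Illustrative 80/20 split; the ratio is an assumption, not synthcity's code.
train_df, test_df = train_test_split(X, train_size=0.8, random_state=0)
print(X.shape, len(train_df), len(test_df))
```

Keeping the test rows out of training matters later, when the optimized plugin is scored against data it has never seen.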

## Load the plugin class

In [None]:
PLUGIN = "tvae"
plugin_cls = type(Plugins().get(PLUGIN))
plugin_cls

## Display the hyperparameter space

In [None]:
plugin_cls.hyperparameter_space()

## Use a trial to suggest a set of hyperparameters

In [None]:
from synthcity.utils.optuna_sample import suggest_all

trial = optuna.create_study().ask()
params = suggest_all(trial, plugin_cls.hyperparameter_space())
params['n_iter'] = 100  # speed up
params

## Evaluate the plugin with the suggested hyperparameters

In [None]:
from synthcity.benchmark import Benchmarks

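# Fit once directly to check that the suggested hyperparameters train without errors;
# Benchmarks.evaluate below fits its own copy of the plugin.
plugin = plugin_cls(**params).fit(train_loader)
report = Benchmarks.evaluate(
    [("trial", PLUGIN, params)],
    train_loader,  # Benchmarks.evaluate will split out a validation set
    repeats=1,
    metrics={"detection": ["detection_mlp"]},  # DELETE THIS LINE FOR ALL METRICS
    task_type="regression",  # the diabetes target is continuous
)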
report['trial']
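
`report['trial']` is a table with one row per metric. The objective defined in the next section collapses it into a single number by keeping the rows whose `direction` is `"minimize"` and averaging their `mean` column. The sketch below shows that reduction on a small hand-made DataFrame; the metric names and values are invented, and only the `mean` and `direction` columns mirror the real report.

```python
import pandas as pd

# Hypothetical benchmark rows: names and numbers are made up.
report = pd.DataFrame(
    {
        "mean": [0.62, 0.55, 0.91],
        "direction": ["minimize", "minimize", "maximize"],
    },
    index=["metric_a", "metric_b", "metric_c"],
)

# Keep only lower-is-better metrics, then average their means.
score = report.query('direction == "minimize"')["mean"].mean()
print(score)
```

Averaging only the "minimize" metrics keeps the objective's direction consistent; mixing in higher-is-better metrics would make their improvements look like regressions.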

## Create an Optuna study and optimize the hyperparameters

In [None]:
def objective(trial: optuna.Trial):
    hp_space = plugin_cls.hyperparameter_space()
    for dist in hp_space:
        if dist.name == "n_iter":
            dist.high = 100  # speed up for now: cap the number of training iterations
    params = suggest_all(trial, hp_space)
    ID = f"trial_{trial.number}"
    try:
        report = Benchmarks.evaluate(
            [(ID, PLUGIN, params)],
            train_loader,
            repeats=1,
            metrics={"detection": ["detection_mlp"]},  # DELETE THIS LINE FOR ALL METRICS
            task_type="regression",  # the diabetes target is continuous
        )
    except Exception as e:  # invalid set of params
        print(f"{type(e).__name__}: {e}")
        print(params)
        raise optuna.TrialPruned()
    # average score across all metrics with direction="minimize"
    score = report[ID].query('direction == "minimize"')['mean'].mean()
    return score

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=2)  # increase n_trials for a real search
study.best_params

## Visualize the study

In [None]:
from optuna.visualization import plot_contour
from optuna.visualization import plot_edf
from optuna.visualization import plot_optimization_history
from optuna.visualization import plot_parallel_coordinate
from optuna.visualization import plot_param_importances
from optuna.visualization import plot_slice

plot_optimization_history(study)
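
For a minimization study, the optimization-history plot shows each trial's objective value together with the best value found so far. A minimal sketch of that "best so far" curve, computed on made-up objective values:

```python
from itertools import accumulate
from typing import List


def best_so_far(values: List[float]) -> List[float]:
    """Running minimum of objective values, in trial order."""
    return list(accumulate(values, min))


history = [0.71, 0.64, 0.68, 0.59, 0.61]  # made-up objective values
print(best_so_far(history))
```

A curve that flattens early suggests extra trials are unlikely to help much; one still dropping at the end suggests raising `n_trials`.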

In [None]:
# Visualize high-dimensional parameter relationships. 
plot_parallel_coordinate(study)

In [None]:
# Visualize hyperparameter relationships.
fig = plot_contour(study, params=['batch_size', 'lr', 'encoder_dropout', 'decoder_dropout'])
fig.update_layout(width=800, height=800)

In [None]:
# Visualize individual hyperparameters as slice plot.
plot_slice(study)

In [None]:
# Visualize parameter importances.
plot_param_importances(study)

In [None]:
# Use hyperparameter importance to learn which hyperparameters affect the trial duration.
plot_param_importances(
    study, target=lambda t: t.duration.total_seconds(), target_name="duration"
)

In [None]:
# Visualize empirical distribution function of the objective.
plot_edf(study)
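
The empirical distribution function (EDF) plotted above gives, for each objective value `x`, the fraction of trials whose objective is at most `x`. A minimal sketch of that computation on made-up values:

```python
from bisect import bisect_right
from typing import List


def edf(values: List[float], x: float) -> float:
    """Fraction of objective values that are <= x."""
    ordered = sorted(values)
    return bisect_right(ordered, x) / len(ordered)


objective_values = [0.62, 0.55, 0.91, 0.55]  # made-up values
print([edf(objective_values, x) for x in (0.5, 0.55, 0.7, 1.0)])
```

For a minimization study, a curve that rises steeply at low objective values means many trials reached good scores.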

## Test performance of the optimized plugin

In [None]:
best_params = study.best_params
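report = Benchmarks.evaluate(
    [("test", PLUGIN, best_params)],
    train_loader,
    test_loader,  # evaluate on the held-out test set this time
    repeats=1,
    metrics={"detection": ["detection_mlp", "detection_xgb"]},  # DELETE THIS LINE FOR ALL METRICS
    task_type="regression",  # the diabetes target is continuous
)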
Benchmarks.print(report)

## Congratulations!

Congratulations on completing this notebook tutorial! If you enjoyed it and would like to join the movement towards machine learning and AI for medicine, you can do so in the following ways!

### Star [Synthcity](https://github.com/vanderschaarlab/synthcity) on GitHub

- The easiest way to help our community is just by starring the repos! This helps raise awareness of the tools we're building.

### Check out other projects from vanderschaarlab
- [HyperImpute](https://github.com/vanderschaarlab/hyperimpute)
- [AutoPrognosis](https://github.com/vanderschaarlab/autoprognosis)
