# Tutorial 8: Hyperparameter Optimization

To automatically tune hyperparameters in a `synthcity` plugin to generate more realistic data, we use hyperparameter optimization (HPO) algorithms such as Tree-structured Parzen estimators (TPE), Bayesian optimization, and genetic programming. In this tutorial we will use `optuna`, a very popular HPO library implementing TPE, to tune the hyperparameters of the `nflow` plugin to synthesize the diabetes dataset.

In [1]:
# stdlib
import sys
import warnings

# third party
import optuna
from sklearn.datasets import load_diabetes

# synthcity absolute
import synthcity.logger as log
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import GenericDataLoader

log.add(sink=sys.stderr, level="INFO")
warnings.filterwarnings("ignore")

    The default C++ compiler could not be found on your system.
    You need to either define the CXX environment variable or a symlink to the g++ command.
    For example if g++-8 is the command you can do
      import os
      os.environ['CXX'] = 'g++-8'
    


## Load the dataset

In [2]:
X, y = load_diabetes(return_X_y=True, as_frame=True)
X["target"] = y
X

Unnamed: 0,age,sex,bmi,bp,s1,s2,s3,s4,s5,s6,target
0,0.038076,0.050680,0.061696,0.021872,-0.044223,-0.034821,-0.043401,-0.002592,0.019907,-0.017646,151.0
1,-0.001882,-0.044642,-0.051474,-0.026328,-0.008449,-0.019163,0.074412,-0.039493,-0.068332,-0.092204,75.0
2,0.085299,0.050680,0.044451,-0.005670,-0.045599,-0.034194,-0.032356,-0.002592,0.002861,-0.025930,141.0
3,-0.089063,-0.044642,-0.011595,-0.036656,0.012191,0.024991,-0.036038,0.034309,0.022688,-0.009362,206.0
4,0.005383,-0.044642,-0.036385,0.021872,0.003935,0.015596,0.008142,-0.002592,-0.031988,-0.046641,135.0
...,...,...,...,...,...,...,...,...,...,...,...
437,0.041708,0.050680,0.019662,0.059744,-0.005697,-0.002566,-0.028674,-0.002592,0.031193,0.007207,178.0
438,-0.005515,0.050680,-0.015906,-0.067642,0.049341,0.079165,-0.028674,0.034309,-0.018114,0.044485,104.0
439,0.041708,0.050680,-0.015906,0.017293,-0.037344,-0.013840,-0.024993,-0.011080,-0.046883,0.015491,132.0
440,-0.045472,-0.044642,0.039062,0.001215,0.016318,0.015283,-0.028674,0.026560,0.044529,-0.025930,220.0


In [3]:
loader = GenericDataLoader(
    X,
    target_column="target",
    sensitive_columns=["sex"],
)
train_loader, test_loader = loader.train(), loader.test()

## Load the plugin class

In [4]:
PLUGIN = "nflow"
plugin_cls = type(Plugins().get(PLUGIN))
plugin_cls

[2023-04-08T20:56:15.722354+0200][28420][CRITICAL] module disabled: D:\Personal\Work\synthcity\src\synthcity\plugins\generic\plugin_goggle.py
[2023-04-08T20:56:15.722354+0200][28420][CRITICAL] module disabled: D:\Personal\Work\synthcity\src\synthcity\plugins\generic\plugin_goggle.py


synthcity.plugins.generic.plugin_nflow.NormalizingFlowsPlugin

## Display the hyperparameter space

In [5]:
plugin_cls.hyperparameter_space()

[IntegerDistribution(name='n_iter', data=None, random_state=0, marginal_distribution=None, low=100, high=5000, step=100),
 IntegerDistribution(name='n_layers_hidden', data=None, random_state=0, marginal_distribution=None, low=1, high=10, step=1),
 IntegerDistribution(name='n_units_hidden', data=None, random_state=0, marginal_distribution=None, low=10, high=100, step=1),
 CategoricalDistribution(name='batch_size', data=None, random_state=0, marginal_distribution=None, choices=[32, 64, 128, 256, 512]),
 FloatDistribution(name='dropout', data=None, random_state=0, marginal_distribution=None, low=0.0, high=0.2),
 CategoricalDistribution(name='batch_norm', data=None, random_state=0, marginal_distribution=None, choices=[False, True]),
 CategoricalDistribution(name='lr', data=None, random_state=0, marginal_distribution=None, choices=[0.0001, 0.0002, 0.001]),
 CategoricalDistribution(name='linear_transform_type', data=None, random_state=0, marginal_distribution=None, choices=['lu', 'permutatio

## Use a trial to suggest a set of hyperparameters

In [8]:
from synthcity.utils.optuna_sample import suggest_all

trial = optuna.create_study().ask()
params = suggest_all(trial, plugin_cls.hyperparameter_space())
params['n_iter'] = 100  # speed up
params

{'n_iter': 100,
 'n_layers_hidden': 1,
 'n_units_hidden': 87,
 'batch_size': 256,
 'dropout': 0.15424246144819787,
 'batch_norm': False,
 'lr': 0.001,
 'linear_transform_type': 'svd',
 'base_transform_type': 'rq-autoregressive'}

## Evaluate the plugin with the suggested hyperparameters

In [9]:
from synthcity.benchmark import Benchmarks

plugin = plugin_cls(**params).fit(train_loader)
report = Benchmarks.evaluate(
    [("trial", PLUGIN, params)],
    train_loader,  # Benchmarks.evaluate will split out a validation set
    repeats=1,
    metrics={"detection": ["detection_mlp"]},  # DELETE THIS LINE FOR ALL METRICS
)
report['trial']

100%|██████████| 100/100 [00:38<00:00,  2.56it/s]
[2023-04-08T20:57:54.561757+0200][28420][CRITICAL] module disabled: D:\Personal\Work\synthcity\src\synthcity\plugins\generic\plugin_goggle.py
100%|██████████| 100/100 [00:30<00:00,  3.24it/s]


Unnamed: 0,min,max,mean,stddev,median,iqr,rounds,errors,durations,direction
detection.detection_mlp.mean,0.5,0.5,0.5,0.0,0.5,0.0,1,0,5.95,minimize


## Create an Optuna study and optimize the hyperparameters

In [14]:
def objective(trial: optuna.Trial):
    hp_space = Plugins().get(PLUGIN).hyperparameter_space()
    hp_space[0].high = 100  # speed up for now
    params = suggest_all(trial, hp_space)
    ID = f"trial_{trial.number}"
    try:
        report = Benchmarks.evaluate(
            [(ID, PLUGIN, params)],
            train_loader,
            repeats=1,
            metrics={"detection": ["detection_mlp"]},  # DELETE THIS LINE FOR ALL METRICS
        )
    except Exception as e:  # invalid set of params
        print(f"{type(e).__name__}: {e}")
        print(params)
        raise optuna.TrialPruned()
    score = report[ID].query('direction == "minimize"')['mean'].mean()
    # average score across all metrics with direction="minimize"
    return score

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=10)
study.best_params

[2023-04-08T21:26:16.278633+0200][28420][CRITICAL] module disabled: D:\Personal\Work\synthcity\src\synthcity\plugins\generic\plugin_goggle.py
[2023-04-08T21:26:16.301778+0200][28420][CRITICAL] module disabled: D:\Personal\Work\synthcity\src\synthcity\plugins\generic\plugin_goggle.py
100%|██████████| 100/100 [00:39<00:00,  2.56it/s]
[2023-04-08T21:26:59.665597+0200][28420][CRITICAL] module disabled: D:\Personal\Work\synthcity\src\synthcity\plugins\generic\plugin_goggle.py
[2023-04-08T21:26:59.684496+0200][28420][CRITICAL] module disabled: D:\Personal\Work\synthcity\src\synthcity\plugins\generic\plugin_goggle.py
100%|██████████| 100/100 [00:26<00:00,  3.74it/s]
[2023-04-08T21:27:30.475951+0200][28420][CRITICAL] module disabled: D:\Personal\Work\synthcity\src\synthcity\plugins\generic\plugin_goggle.py
[2023-04-08T21:27:30.495645+0200][28420][CRITICAL] module disabled: D:\Personal\Work\synthcity\src\synthcity\plugins\generic\plugin_goggle.py
100%|██████████| 100/100 [00:04<00:00, 21.72it/s

ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 42])
{'n_iter': 100, 'n_layers_hidden': 8, 'n_units_hidden': 42, 'batch_size': 32, 'dropout': 0.17540363704980177, 'batch_norm': True, 'lr': 0.001, 'linear_transform_type': 'permutation', 'base_transform_type': 'affine-coupling', 'workspace': Path('workspace'), 'random_state': 0}


 49%|████▉     | 49/100 [00:27<00:29,  1.75it/s]
[2023-04-08T21:32:10.659104+0200][28420][CRITICAL] module disabled: D:\Personal\Work\synthcity\src\synthcity\plugins\generic\plugin_goggle.py
[2023-04-08T21:32:10.686494+0200][28420][CRITICAL] module disabled: D:\Personal\Work\synthcity\src\synthcity\plugins\generic\plugin_goggle.py


ValueError: Input contains NaN.
{'n_iter': 100, 'n_layers_hidden': 9, 'n_units_hidden': 25, 'batch_size': 32, 'dropout': 0.015375942265720366, 'batch_norm': False, 'lr': 0.0002, 'linear_transform_type': 'lu', 'base_transform_type': 'affine-autoregressive', 'workspace': Path('workspace'), 'random_state': 0}


100%|██████████| 100/100 [01:44<00:00,  1.05s/it]
[2023-04-08T21:34:01.698043+0200][28420][CRITICAL] module disabled: D:\Personal\Work\synthcity\src\synthcity\plugins\generic\plugin_goggle.py
[2023-04-08T21:34:01.716717+0200][28420][CRITICAL] module disabled: D:\Personal\Work\synthcity\src\synthcity\plugins\generic\plugin_goggle.py
  0%|          | 0/100 [00:00<?, ?it/s]
[2023-04-08T21:34:06.336252+0200][28420][CRITICAL] module disabled: D:\Personal\Work\synthcity\src\synthcity\plugins\generic\plugin_goggle.py
[2023-04-08T21:34:06.398065+0200][28420][CRITICAL] module disabled: D:\Personal\Work\synthcity\src\synthcity\plugins\generic\plugin_goggle.py


ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 81])
{'n_iter': 100, 'n_layers_hidden': 3, 'n_units_hidden': 81, 'batch_size': 32, 'dropout': 0.06339754684650767, 'batch_norm': True, 'lr': 0.0002, 'linear_transform_type': 'svd', 'base_transform_type': 'affine-autoregressive', 'workspace': Path('workspace'), 'random_state': 0}


100%|██████████| 100/100 [00:22<00:00,  4.51it/s]
[2023-04-08T21:34:38.240859+0200][28420][CRITICAL] module disabled: D:\Personal\Work\synthcity\src\synthcity\plugins\generic\plugin_goggle.py
[2023-04-08T21:34:38.271527+0200][28420][CRITICAL] module disabled: D:\Personal\Work\synthcity\src\synthcity\plugins\generic\plugin_goggle.py
100%|██████████| 100/100 [00:56<00:00,  1.77it/s]


{'n_iter': 100,
 'n_layers_hidden': 5,
 'n_units_hidden': 67,
 'batch_size': 512,
 'dropout': 0.0654201225133338,
 'batch_norm': True,
 'lr': 0.001,
 'linear_transform_type': 'svd',
 'base_transform_type': 'affine-autoregressive'}

## Visualize the study

In [15]:
from optuna.visualization import plot_contour
from optuna.visualization import plot_edf
from optuna.visualization import plot_optimization_history
from optuna.visualization import plot_parallel_coordinate
from optuna.visualization import plot_param_importances
from optuna.visualization import plot_slice

plot_optimization_history(study)

In [16]:
# Visualize high-dimensional parameter relationships. 
plot_parallel_coordinate(study)

In [17]:
# Visualize hyperparameter relationships.
fig = plot_contour(study, params=['batch_size', 'dropout', 'n_layers_hidden', 'n_units_hidden'])
fig.update_layout(width=800, height=800)

In [18]:
# Visualize individual hyperparameters as slice plot.
plot_slice(study)

In [19]:
# Visualize parameter importances.
plot_param_importances(study)

In [20]:
# Learn which hyperparameters are affecting the trial duration with hyperparameter importance.
optuna.visualization.plot_param_importances(
    study, target=lambda t: t.duration.total_seconds(), target_name="duration"
)

In [21]:
# Visualize empirical distribution function of the objective.
plot_edf(study)

## Test performance of the optimized plugin

In [22]:
best_params = study.best_params
report = Benchmarks.evaluate(
    [("test", PLUGIN, best_params)],
    train_loader,
    test_loader,
    repeats=1,
    metrics={"detection": ["detection_mlp", "detection_xgb"]},  # DELETE THIS LINE FOR ALL METRICS
)
Benchmarks.print(report)

[2023-04-08T21:36:39.947037+0200][28420][CRITICAL] module disabled: D:\Personal\Work\synthcity\src\synthcity\plugins\generic\plugin_goggle.py
100%|██████████| 100/100 [00:20<00:00,  4.87it/s]



[4m[1mPlugin : test[0m[0m


Unnamed: 0,min,max,mean,stddev,median,iqr,rounds,errors,durations
detection.detection_xgb.mean,0.988506,0.988506,0.988506,0.0,0.988506,0.0,1,0,0.18
detection.detection_mlp.mean,0.70364,0.70364,0.70364,0.0,0.70364,0.0,1,0,3.22





## Congratulations!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement towards Machine learning and AI for medicine, you can do so in the following ways!

### Star [Synthcity](https://github.com/vanderschaarlab/synthcity) on GitHub

- The easiest way to help our community is just by starring the Repos! This helps raise awareness of the tools we're building.


### Checkout other projects from vanderschaarlab
- [HyperImpute](https://github.com/vanderschaarlab/hyperimpute)
- [AutoPrognosis](https://github.com/vanderschaarlab/autoprognosis)
