## Successive Halving with Optuna

In this notebook, we'll carry out [successive halving](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.pruners.SuccessiveHalvingPruner.html) with Optuna.

In [1]:
import numpy as np

import optuna

from sklearn.datasets import load_breast_cancer
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

import xgboost as xgb

In [2]:
# load dataset and prepare data

data, target = load_breast_cancer(return_X_y=True)

X_train, X_test, y_train, y_test = train_test_split(
    data, target, test_size=0.25)

dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)

## Define the objective function

See the [XGBoost pruning integration](https://optuna-integration.readthedocs.io/en/stable/reference/generated/optuna_integration.XGBoostPruningCallback.html#optuna_integration.XGBoostPruningCallback) for details on the callback used here.

The code below is based on https://github.com/optuna/optuna-examples/blob/main/xgboost/xgboost_integration.py

In [3]:
def objective(trial):

    # hyperparameter space
    param = {
        "verbosity": 0,
        "objective": "binary:logistic",
        "eval_metric": "auc",
        "booster": trial.suggest_categorical("booster", ["gbtree", "gblinear", "dart"]),
        "lambda": trial.suggest_float("lambda", 1e-8, 1.0, log=True),
        "alpha": trial.suggest_float("alpha", 1e-8, 1.0, log=True),
    }

    # conditional space: some hyperparams depend on other hyperparams
    if param["booster"] == "gbtree" or param["booster"] == "dart":
        param["max_depth"] = trial.suggest_int("max_depth", 1, 9)
        param["eta"] = trial.suggest_float("eta", 1e-8, 1.0, log=True)
        param["gamma"] = trial.suggest_float("gamma", 1e-8, 1.0, log=True)
        param["grow_policy"] = trial.suggest_categorical("grow_policy", ["depthwise", "lossguide"])
    
    if param["booster"] == "dart":
        param["sample_type"] = trial.suggest_categorical("sample_type", ["uniform", "weighted"])
        param["normalize_type"] = trial.suggest_categorical("normalize_type", ["tree", "forest"])
        param["rate_drop"] = trial.suggest_float("rate_drop", 1e-8, 1.0, log=True)
        param["skip_drop"] = trial.suggest_float("skip_drop", 1e-8, 1.0, log=True)

    # Add a callback for pruning.
    # The pruner bases its stopping decisions on the validation AUC reported
    # after each boosting round (the objective itself returns accuracy).
    pruning_callback = optuna.integration.XGBoostPruningCallback(trial, "validation-auc")
    
    # train the model; xgb.train runs 10 boosting rounds by default
    bst = xgb.train(param, dtrain, evals=[(dtest, "validation")], callbacks=[pruning_callback])
    
    # evaluate
    preds = bst.predict(dtest)
    pred_labels = np.rint(preds)
    accuracy = accuracy_score(y_test, pred_labels)
    
    return accuracy

In the following code, we'll train **30 initial configurations**.

Unlike scikit-learn's successive halving, Optuna does not wait for a round to finish, pick the winning configurations, and train them again with more resources in the next round.

Instead, it stops configurations that do not look promising while they are training. As a result, fewer models are trained to completion.

The initial 30 configurations are sampled at random from the hyperparameter space.

In [4]:
study = optuna.create_study(
    
    # a way to sample hyperparameters to create the configurations
    sampler=optuna.samplers.RandomSampler(),
    
    # successive halving
    pruner=optuna.pruners.SuccessiveHalvingPruner(
        # minimum number of validation rounds to wait before a trial can be stopped
        min_resource=1,
        
        # at each rung, roughly the top 1/reduction_factor of trials is promoted
        reduction_factor=3,
        
        # Minimum number of trials that need to complete a rung before any trial
        # is considered for promotion
        bootstrap_count=0,
    ),
    
    direction="maximize",    
)


study.optimize(
    objective, 
    
    # the number of initial configurations
    n_trials=30, 
)

[I 2024-09-21 13:17:59,962] A new study created in memory with name: no-name-bbaf80d1-4d73-43df-8ae2-00b40261c1f7


[0]	validation-auc:0.96795
[1]	validation-auc:0.96795
[2]	validation-auc:0.96795
[3]	validation-auc:0.96795
[4]	validation-auc:0.96795
[5]	validation-auc:0.96795
[6]	validation-auc:0.96795
[7]	validation-auc:0.96795
[8]	validation-auc:0.96795
[9]	validation-auc:0.96795


[I 2024-09-21 13:18:00,081] Trial 0 finished with value: 0.5944055944055944 and parameters: {'booster': 'dart', 'lambda': 3.5647914181848e-07, 'alpha': 0.00025407909433436827, 'max_depth': 3, 'eta': 3.5280216255704914e-06, 'gamma': 7.739700007568005e-07, 'grow_policy': 'depthwise', 'sample_type': 'uniform', 'normalize_type': 'forest', 'rate_drop': 2.769831656837777e-06, 'skip_drop': 0.3128486233462904}. Best is trial 0 with value: 0.5944055944055944.


[0]	validation-auc:0.87302


[I 2024-09-21 13:18:00,100] Trial 1 pruned. Trial was pruned at iteration 1.


[0]	validation-auc:0.96957
[1]	validation-auc:0.97201
[2]	validation-auc:0.97282
[3]	validation-auc:0.97282
[4]	validation-auc:0.97424
[5]	validation-auc:0.97566
[6]	validation-auc:0.97627
[7]	validation-auc:0.97688
[8]	validation-auc:0.97769
[9]	validation-auc:0.97789


[I 2024-09-21 13:18:00,150] Trial 2 finished with value: 0.8951048951048951 and parameters: {'booster': 'gblinear', 'lambda': 0.0001439336484499298, 'alpha': 0.0011518192933196165}. Best is trial 2 with value: 0.8951048951048951.


[0]	validation-auc:0.94807
[1]	validation-auc:0.95477


[I 2024-09-21 13:18:00,170] Trial 3 pruned. Trial was pruned at iteration 1.


[0]	validation-auc:0.96531


[I 2024-09-21 13:18:00,207] Trial 4 pruned. Trial was pruned at iteration 1.


[0]	validation-auc:0.94807


[I 2024-09-21 13:18:00,246] Trial 5 pruned. Trial was pruned at iteration 1.


[0]	validation-auc:0.97627
[1]	validation-auc:0.97525
[2]	validation-auc:0.97546
[3]	validation-auc:0.97667
[4]	validation-auc:0.97890
[5]	validation-auc:0.97992
[6]	validation-auc:0.97931
[7]	validation-auc:0.97890
[8]	validation-auc:0.97951
[9]	validation-auc:0.97992


[I 2024-09-21 13:18:00,304] Trial 6 finished with value: 0.9020979020979021 and parameters: {'booster': 'gblinear', 'lambda': 0.0005983439058114121, 'alpha': 4.161898803285618e-05}. Best is trial 6 with value: 0.9020979020979021.


[0]	validation-auc:0.95051


[I 2024-09-21 13:18:00,315] Trial 7 pruned. Trial was pruned at iteration 1.


[0]	validation-auc:0.93418


[I 2024-09-21 13:18:00,343] Trial 8 pruned. Trial was pruned at iteration 1.


[0]	validation-auc:0.96369


[I 2024-09-21 13:18:00,350] Trial 9 pruned. Trial was pruned at iteration 1.


[0]	validation-auc:0.95051


[I 2024-09-21 13:18:00,373] Trial 10 pruned. Trial was pruned at iteration 1.


[0]	validation-auc:0.96592
[1]	validation-auc:0.96471


[I 2024-09-21 13:18:00,381] Trial 11 pruned. Trial was pruned at iteration 1.


[0]	validation-auc:0.98073
[1]	validation-auc:0.98073
[2]	validation-auc:0.97911
[3]	validation-auc:0.98093
[4]	validation-auc:0.98174
[5]	validation-auc:0.98215
[6]	validation-auc:0.98337
[7]	validation-auc:0.98357
[8]	validation-auc:0.98458
[9]	validation-auc:0.98519


[I 2024-09-21 13:18:00,437] Trial 12 finished with value: 0.9230769230769231 and parameters: {'booster': 'gblinear', 'lambda': 2.4238806760008624e-08, 'alpha': 3.074667207707852e-07}. Best is trial 12 with value: 0.9230769230769231.


[0]	validation-auc:0.97120
[1]	validation-auc:0.97708
[2]	validation-auc:0.97809
[3]	validation-auc:0.97992


[I 2024-09-21 13:18:00,450] Trial 13 pruned. Trial was pruned at iteration 3.


[0]	validation-auc:0.97951
[1]	validation-auc:0.97769
[2]	validation-auc:0.97667
[3]	validation-auc:0.97606


[I 2024-09-21 13:18:00,477] Trial 14 pruned. Trial was pruned at iteration 3.


[0]	validation-auc:0.95903


[I 2024-09-21 13:18:00,507] Trial 15 pruned. Trial was pruned at iteration 1.


[0]	validation-auc:0.95051


[I 2024-09-21 13:18:00,523] Trial 16 pruned. Trial was pruned at iteration 1.


[0]	validation-auc:0.50000


[I 2024-09-21 13:18:00,544] Trial 17 pruned. Trial was pruned at iteration 1.


[0]	validation-auc:0.94807


[I 2024-09-21 13:18:00,576] Trial 18 pruned. Trial was pruned at iteration 1.


[0]	validation-auc:0.96359
[1]	validation-auc:0.95375


[I 2024-09-21 13:18:00,602] Trial 19 pruned. Trial was pruned at iteration 1.


[0]	validation-auc:0.94787
[1]	validation-auc:0.95183


[I 2024-09-21 13:18:00,632] Trial 20 pruned. Trial was pruned at iteration 1.


[0]	validation-auc:0.95051


[I 2024-09-21 13:18:00,643] Trial 21 pruned. Trial was pruned at iteration 1.


[0]	validation-auc:0.95051


[I 2024-09-21 13:18:00,652] Trial 22 pruned. Trial was pruned at iteration 1.


[0]	validation-auc:0.92667
[1]	validation-auc:0.93083


[I 2024-09-21 13:18:00,680] Trial 23 pruned. Trial was pruned at iteration 1.


[0]	validation-auc:0.95477
[1]	validation-auc:0.96775
[2]	validation-auc:0.96673


[I 2024-09-21 13:18:00,716] Trial 24 pruned. Trial was pruned at iteration 3.


[0]	validation-auc:0.95872


[I 2024-09-21 13:18:00,740] Trial 25 pruned. Trial was pruned at iteration 1.


[0]	validation-auc:0.93854
[1]	validation-auc:0.94736


[I 2024-09-21 13:18:00,766] Trial 26 pruned. Trial was pruned at iteration 1.


[0]	validation-auc:0.95020
[1]	validation-auc:0.96917
[2]	validation-auc:0.98185
[3]	validation-auc:0.99097
[4]	validation-auc:0.99290
[5]	validation-auc:0.99371
[6]	validation-auc:0.99412
[7]	validation-auc:0.99493
[8]	validation-auc:0.99493
[9]	validation-auc:0.99452


[I 2024-09-21 13:18:00,849] Trial 27 finished with value: 0.951048951048951 and parameters: {'booster': 'dart', 'lambda': 0.5211794366482341, 'alpha': 5.9082205314673086e-06, 'max_depth': 7, 'eta': 0.38723109273768735, 'gamma': 0.04109685372397155, 'grow_policy': 'lossguide', 'sample_type': 'weighted', 'normalize_type': 'tree', 'rate_drop': 0.00013703669392688166, 'skip_drop': 4.418141527593644e-05}. Best is trial 27 with value: 0.951048951048951.


[0]	validation-auc:0.97323
[1]	validation-auc:0.97525
[2]	validation-auc:0.97647


[I 2024-09-21 13:18:00,866] Trial 28 pruned. Trial was pruned at iteration 3.


[0]	validation-auc:0.94807


[I 2024-09-21 13:18:00,885] Trial 29 pruned. Trial was pruned at iteration 1.


Here, `min_resource` sets the minimum number of iterations (validation rounds) that a trial must run before the pruner can stop it.

If `min_resource=4`, no model is stopped until it has gone through 4 rounds of validation.

If `min_resource="auto"`, Optuna picks the value with a heuristic based on the number of steps the first trial needed to complete; 1 is the minimum possible. With a value of 1, as in the run above, a model can be stopped in the first round if its score is below that of previous models.

Change `min_resource` and see how the pruning changes.
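To see where the pruner makes its decisions, it can help to list the rung steps. The sketch below assumes the rule given in the Optuna documentation for `SuccessiveHalvingPruner`, where rung `k` is completed at step `min_resource * reduction_factor ** (min_early_stopping_rate + k)`. The helper `rung_steps` is hypothetical, written only for this illustration, and `last_step=9` assumes the 10 boosting rounds (steps 0 to 9) seen in the log above.

```python
def rung_steps(min_resource, reduction_factor, last_step, min_early_stopping_rate=0):
    """List the steps at which a rung is completed, up to last_step.

    Hypothetical helper (not part of Optuna); assumes reduction_factor >= 2.
    """
    steps = []
    rung = 0
    while True:
        step = min_resource * reduction_factor ** (min_early_stopping_rate + rung)
        if step > last_step:
            break
        steps.append(step)
        rung += 1
    return steps


# settings used in this notebook: min_resource=1, reduction_factor=3
steps_min1 = rung_steps(min_resource=1, reduction_factor=3, last_step=9)

# the alternative discussed above: min_resource=4
steps_min4 = rung_steps(min_resource=4, reduction_factor=3, last_step=9)

print(steps_min1)  # [1, 3, 9]
print(steps_min4)  # [4]
```

With `min_resource=1`, the decisions fall at steps 1, 3 and 9, which is consistent with the "pruned at iteration 1" and "pruned at iteration 3" messages in the log. With `min_resource=4`, only a single rung fits into 10 boosting rounds.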

In [5]:
# the best hyperparameters

study.best_params

{'booster': 'dart',
 'lambda': 0.5211794366482341,
 'alpha': 5.9082205314673086e-06,
 'max_depth': 7,
 'eta': 0.38723109273768735,
 'gamma': 0.04109685372397155,
 'grow_policy': 'lossguide',
 'sample_type': 'weighted',
 'normalize_type': 'tree',
 'rate_drop': 0.00013703669392688166,
 'skip_drop': 4.418141527593644e-05}

In [6]:
# the best performance value

study.best_value

0.951048951048951

In [7]:
r = study.trials_dataframe()

r

Unnamed: 0,number,value,datetime_start,datetime_complete,duration,params_alpha,params_booster,params_eta,params_gamma,params_grow_policy,params_lambda,params_max_depth,params_normalize_type,params_rate_drop,params_sample_type,params_skip_drop,system_attrs_completed_rung_0,system_attrs_completed_rung_1,system_attrs_completed_rung_2,state
0,0,0.594406,2024-09-21 13:17:59.966848,2024-09-21 13:18:00.080785,0 days 00:00:00.113937,0.0002540791,dart,3.528022e-06,7.7397e-07,depthwise,3.564791e-07,3.0,forest,2.769832e-06,uniform,0.3128486,0.967951,0.967951,0.967951,COMPLETE
1,1,0.873022,2024-09-21 13:18:00.082784,2024-09-21 13:18:00.099773,0 days 00:00:00.016989,0.0003747947,gbtree,0.0001576745,4.153013e-05,depthwise,0.01013827,1.0,,,,,0.873022,,,PRUNED
2,2,0.895105,2024-09-21 13:18:00.105774,2024-09-21 13:18:00.149745,0 days 00:00:00.043971,0.001151819,gblinear,,,,0.0001439336,,,,,,0.972008,0.972819,0.97789,COMPLETE
3,3,0.954767,2024-09-21 13:18:00.151742,2024-09-21 13:18:00.170734,0 days 00:00:00.018992,0.9382316,gblinear,,,,4.147224e-08,,,,,,0.954767,,,PRUNED
4,4,0.965923,2024-09-21 13:18:00.173730,2024-09-21 13:18:00.206713,0 days 00:00:00.032983,0.001205494,dart,0.0197931,0.001187112,lossguide,2.597402e-08,5.0,forest,2.908401e-08,weighted,0.3256841,0.965923,,,PRUNED
5,5,0.950811,2024-09-21 13:18:00.208709,2024-09-21 13:18:00.245689,0 days 00:00:00.036980,4.497565e-07,gbtree,1.867661e-05,3.592467e-06,lossguide,8.265671e-06,7.0,,,,,0.950811,,,PRUNED
6,6,0.902098,2024-09-21 13:18:00.247691,2024-09-21 13:18:00.302674,0 days 00:00:00.054983,4.161899e-05,gblinear,,,,0.0005983439,,,,,,0.975254,0.976673,0.979919,COMPLETE
7,7,0.950507,2024-09-21 13:18:00.305655,2024-09-21 13:18:00.315649,0 days 00:00:00.009994,0.05439605,gbtree,7.858191e-06,1.856325e-05,depthwise,0.008969816,2.0,,,,,0.950507,,,PRUNED
8,8,0.934178,2024-09-21 13:18:00.316650,2024-09-21 13:18:00.343635,0 days 00:00:00.026985,0.00586637,dart,0.001685504,0.9141614,depthwise,5.390612e-06,5.0,forest,0.01476308,weighted,0.002939456,0.934178,,,PRUNED
9,9,0.957404,2024-09-21 13:18:00.344632,2024-09-21 13:18:00.350629,0 days 00:00:00.005997,0.000172296,gblinear,,,,0.004183236,,,,,,0.957404,,,PRUNED


In [8]:
# a "rung" is each round of successive halving

# count how many configurations completed each rung
rung_cols = [col for col in r.columns if "rung" in col]

r[rung_cols].notnull().sum()

system_attrs_completed_rung_0    30
system_attrs_completed_rung_1     9
system_attrs_completed_rung_2     5
dtype: int64
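For comparison, the following sketch computes how many configurations a synchronous successive-halving schedule would keep at each rung with the same settings. It is an illustration only: `synchronous_schedule` is a hypothetical helper, and rounding down with floor division is an assumption (implementations differ in how they round).

```python
def synchronous_schedule(n_configs, reduction_factor, n_rungs):
    """Number of configurations alive at each rung when every rung
    keeps the top 1/reduction_factor of the previous one.

    Hypothetical helper; floor division is an assumed rounding rule.
    """
    counts = [n_configs]
    for _ in range(n_rungs - 1):
        counts.append(max(1, counts[-1] // reduction_factor))
    return counts


expected = synchronous_schedule(n_configs=30, reduction_factor=3, n_rungs=3)
observed = [30, 9, 5]  # counts reported by the cell above

print(expected)  # [30, 10, 3]
print(observed)  # [30, 9, 5]
```

The observed counts are close to the synchronous schedule at the second rung and somewhat higher at the third.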

In [9]:
# final state of the configurations that reached the last
# round ("rung") of successive halving; here, all of them completed

r[~r["system_attrs_completed_rung_2"].isnull()]["state"]

0     COMPLETE
2     COMPLETE
6     COMPLETE
12    COMPLETE
27    COMPLETE
Name: state, dtype: object

In [10]:
# completely trained configurations

r[r["state"]=="COMPLETE"]["state"].count()

5

As expected, we started with 30 configurations, and 9 of them (roughly a third) reached the second rung. Of those, 5 reached the third rung and were trained to completion.

The fractions are not exactly a third because this is asynchronous successive halving (ASHA). In ASHA, some suboptimal configurations are promoted to the next rung, because the pruner does not wait for all configurations to finish a rung before selecting the top third.
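The asynchronous decision can be sketched in a few lines. The code below is a simplified assumption of the promotion rule, not Optuna's actual implementation: a trial is promoted if its score is among the top `max(1, n // reduction_factor)` scores seen so far at that rung, counting itself. The scores are the `system_attrs_completed_rung_0` values of the first ten trials in the `trials_dataframe` output above; `is_promoted` is a hypothetical helper.

```python
def is_promoted(value, seen_values, reduction_factor=3):
    """Decide promotion using only the scores seen so far at this rung.

    seen_values must already include value. Higher is better.
    Simplified sketch, not Optuna's implementation.
    """
    n_keep = max(1, len(seen_values) // reduction_factor)
    threshold = sorted(seen_values, reverse=True)[n_keep - 1]
    return value >= threshold


# rung 0 scores of trials 0 to 9, in the order the trials ran
rung0_scores = [
    0.967951, 0.873022, 0.972008, 0.954767, 0.965923,
    0.950811, 0.975254, 0.950507, 0.934178, 0.957404,
]

seen = []
promoted_trials = []
for trial_number, score in enumerate(rung0_scores):
    seen.append(score)
    if is_promoted(score, seen):
        promoted_trials.append(trial_number)

print(promoted_trials)  # [0, 2, 6]
```

Under this rule, trials 0, 2 and 6 are promoted, which agrees with the trials marked `COMPLETE` among those ten rows. Trial 0 is promoted simply because it ran first and had nothing to compete against, even though its final accuracy (0.594) turned out to be the lowest of the completed trials.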