Pruning Unpromising Trials
==========================

This feature automatically stops unpromising trials in the early stages of training (also known as automated early stopping). Optuna provides interfaces to concisely implement the pruning mechanism in iterative training algorithms.

Activating Pruners
------------------

To turn on the pruning feature, you need to call :func:`~optuna.trial.Trial.report` and :func:`~optuna.trial.Trial.should_prune` after each step of the iterative training. :func:`~optuna.trial.Trial.report` periodically reports intermediate objective values, and :func:`~optuna.trial.Trial.should_prune` decides whether to terminate a trial that does not meet a predefined condition.

"""filename: prune.py"""

import sklearn.datasets
import sklearn.linear_model
import sklearn.model_selection

import optuna

def objective(trial):
    iris = sklearn.datasets.load_iris()
    classes = list(set(iris.target))
    train_x, test_x, train_y, test_y = sklearn.model_selection.train_test_split(
        iris.data, iris.target, test_size=0.25, random_state=0)

    alpha = trial.suggest_loguniform('alpha', 1e-5, 1e-1)
    clf = sklearn.linear_model.SGDClassifier(alpha=alpha)

    for step in range(100):
        clf.partial_fit(train_x, train_y, classes=classes)

        # Report intermediate objective value.
        intermediate_value = 1.0 - clf.score(test_x, test_y)
        trial.report(intermediate_value, step)

        # Handle pruning based on the intermediate value.
        if trial.should_prune(step):
            raise optuna.structs.TrialPruned()

    return 1.0 - clf.score(test_x, test_y)

# Set up the median stopping rule as the pruning condition.
study = optuna.create_study(pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=20)
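
The median stopping rule prunes a trial if its intermediate result is worse than the median of the intermediate results of previous trials at the same step. If the earliest training steps are too noisy to judge, you can make the pruner more conservative with its ``n_startup_trials`` and ``n_warmup_steps`` parameters; a minimal sketch (the values below are illustrative, not recommendations):

# Do not prune until 5 trials have finished, and never prune a trial
# before it has reported 30 steps (illustrative values).
pruner = optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=30)
study = optuna.create_study(pruner=pruner)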

Executing prune.py produces output like the following:

$ python prune.py
[I 2018-11-21 17:27:57,836] Finished a trial resulted in value: 0.052631578947368474. Current best value is 0.052631578947368474 with parameters: {'alpha': 0.011428158279113485}.
[I 2018-11-21 17:27:57,963] Finished a trial resulted in value: 0.02631578947368418. Current best value is 0.02631578947368418 with parameters: {'alpha': 0.01862693201743629}.
[I 2018-11-21 17:27:58,164] Finished a trial resulted in value: 0.21052631578947367. Current best value is 0.02631578947368418 with parameters: {'alpha': 0.01862693201743629}.
[I 2018-11-21 17:27:58,333] Finished a trial resulted in value: 0.02631578947368418. Current best value is 0.02631578947368418 with parameters: {'alpha': 0.01862693201743629}.
[I 2018-11-21 17:27:58,617] Finished a trial resulted in value: 0.23684210526315785. Current best value is 0.02631578947368418 with parameters: {'alpha': 0.01862693201743629}.
[I 2018-11-21 17:27:58,642] Setting trial status as TrialState.PRUNED.
[I 2018-11-21 17:27:58,666] Setting trial status as TrialState.PRUNED.
[I 2018-11-21 17:27:58,675] Setting trial status as TrialState.PRUNED.
[I 2018-11-21 17:27:59,183] Finished a trial resulted in value: 0.39473684210526316. Current best value is 0.02631578947368418 with parameters: {'alpha': 0.01862693201743629}.
[I 2018-11-21 17:27:59,202] Setting trial status as TrialState.PRUNED.
...

We can see ``Setting trial status as TrialState.PRUNED`` in the log messages. This means that several trials were stopped before they finished all of their iterations.
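
You can also check this programmatically by inspecting the trial states stored on the study; a minimal sketch, assuming the ``study`` object from prune.py:

from optuna.structs import TrialState

# Count pruned and completed trials from the study's trial records.
pruned_trials = [t for t in study.trials if t.state == TrialState.PRUNED]
complete_trials = [t for t in study.trials if t.state == TrialState.COMPLETE]
print('Number of pruned trials:', len(pruned_trials))
print('Number of complete trials:', len(complete_trials))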

Integration Modules for Pruning
-------------------------------

To implement the pruning mechanism in a much simpler form, Optuna provides integration modules for the following libraries.

* XGBoost: :class:`optuna.integration.XGBoostPruningCallback`
* LightGBM: :class:`optuna.integration.LightGBMPruningCallback`
* Chainer: :class:`optuna.integration.ChainerPruningExtension`
* Keras: :class:`optuna.integration.KerasPruningCallback`

For example, :class:`~optuna.integration.XGBoostPruningCallback` introduces pruning without directly changing the logic of the training iteration. (See also the example for the entire script.)

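# `trial`, `param`, `dtrain`, `dtest`, and `n_round` are assumed to be
# defined inside an objective function, as in the linked example.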
pruning_callback = optuna.integration.XGBoostPruningCallback(trial, 'validation-error')
bst = xgb.train(param, dtrain, n_round, evals=[(dtest, 'validation')],
                callbacks=[pruning_callback])
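
The other callbacks follow the same pattern. As a sketch, here is the LightGBM counterpart using :class:`~optuna.integration.LightGBMPruningCallback`; the metric name ``'binary_error'`` and the ``param``/``dtrain``/``dval`` variables are illustrative assumptions, not part of the original example:

import lightgbm as lgb

# `trial`, `param`, `dtrain`, and `dval` are assumed to be defined inside
# an objective function; 'binary_error' is an illustrative metric name.
pruning_callback = optuna.integration.LightGBMPruningCallback(trial, 'binary_error')
gbm = lgb.train(param, dtrain, valid_sets=[dval], callbacks=[pruning_callback])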