In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds

import optuna
from optuna.integration import TFKerasPruningCallback

In [2]:
# Mini-batch size used by both the training and the evaluation pipeline.
BATCHSIZE = 128
# Width of the softmax output layer.
CLASSES = 10
# Upper bound on the number of epochs a single trial may run.
EPOCHS = 40
N_TRAIN_EXAMPLES = 3_000
# Each epoch deliberately covers only a tenth of the batches that
# N_TRAIN_EXAMPLES would fill, which keeps individual trials short.
STEPS_PER_EPOCH = int(N_TRAIN_EXAMPLES / (BATCHSIZE * 10))
# Number of evaluation batches consumed after every epoch.
VALIDATION_STEPS = 30

In [3]:
def train_dataset():
    """Build the input pipeline over the MNIST training split.

    Returns:
        A dataset of ``(image, label)`` batches of size ``BATCHSIZE``. Images are
        cast to float32 and divided by 255.0. The stream repeats indefinitely,
        is shuffled with a 1024-element buffer and is prefetched.
    """

    def _to_image_and_label(example):
        # Cast and rescale the image; the label is passed through untouched.
        image = tf.cast(example["image"], tf.float32) / 255.0
        return image, example["label"]

    source = tfds.load("mnist", split=tfds.Split.TRAIN, shuffle_files=True)
    batched = source.map(_to_image_and_label).repeat().shuffle(1024).batch(BATCHSIZE)
    return batched.prefetch(tf.data.experimental.AUTOTUNE)

def eval_dataset():
    """Build the input pipeline over the MNIST test split.

    Returns:
        A dataset of ``(image, label)`` batches of size ``BATCHSIZE``. Images are
        cast to float32 and divided by 255.0. The stream repeats indefinitely
        and is prefetched; unlike the training pipeline it is not shuffled.
    """

    def _to_image_and_label(example):
        # Cast and rescale the image; the label is passed through untouched.
        image = tf.cast(example["image"], tf.float32) / 255.0
        return image, example["label"]

    examples = tfds.load("mnist", split=tfds.Split.TEST, shuffle_files=False)
    pairs = examples.map(_to_image_and_label)
    batches = pairs.repeat().batch(BATCHSIZE)
    return batches.prefetch(tf.data.experimental.AUTOTUNE)

In [4]:
def create_model(trial):
    """Build and compile a one-hidden-layer classifier from sampled hyperparameters.

    Args:
        trial: Optuna trial used to sample ``lr`` (log-uniform in [1e-4, 1e-1]),
            ``momentum`` (uniform in [0.0, 1.0]) and ``units`` (one of
            32, 64, 128, 256, 512).

    Returns:
        A compiled ``tf.keras.Sequential`` model: Flatten -> Dense(units, relu)
        -> Dense(CLASSES, softmax), trained with Nesterov SGD on sparse
        categorical cross-entropy and reporting accuracy.
    """

    # Hyperparameters to be tuned by Optuna.
    lr = trial.suggest_float("lr", 1e-4, 1e-1, log=True)
    momentum = trial.suggest_float("momentum", 0.0, 1.0)
    units = trial.suggest_categorical("units", [32, 64, 128, 256, 512])

    # Compose neural network with one hidden layer.
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(units=units, activation=tf.nn.relu))
    model.add(tf.keras.layers.Dense(CLASSES, activation=tf.nn.softmax))

    # Compile model. `learning_rate` is the supported keyword; the legacy `lr`
    # alias is deprecated and rejected by newer Keras releases.
    model.compile(
        optimizer=tf.keras.optimizers.SGD(learning_rate=lr, momentum=momentum, nesterov=True),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )

    return model

In [5]:
def objective(trial):
    """Train one candidate model and return its last recorded validation accuracy.

    Args:
        trial: Optuna trial; passed to ``create_model`` to sample hyperparameters
            and to the pruning callback.

    Returns:
        The monitored validation-accuracy value of the last epoch that ran
        (training may end before ``EPOCHS`` because of early stopping).
    """
    # Clear clutter from previous TensorFlow graphs.
    tf.keras.backend.clear_session()

    # Metrics to be monitored by Optuna. The major version is compared as an
    # integer: comparing version strings is lexicographic, so e.g. "10.0"
    # would sort below "2" and select the wrong metric name.
    tf_major_version = int(tf.__version__.split(".")[0])
    if tf_major_version >= 2:
        monitor = "val_accuracy"
    else:
        monitor = "val_acc"

    # Create tf.keras model instance.
    model = create_model(trial)

    # Create dataset instance.
    ds_train = train_dataset()
    ds_eval = eval_dataset()

    # Create callbacks for early stopping and pruning.
    callbacks = [
        tf.keras.callbacks.EarlyStopping(patience=3),
        TFKerasPruningCallback(trial, monitor),
    ]

    # Train model. Both datasets repeat forever, so the step counts bound each epoch.
    history = model.fit(
        ds_train,
        epochs=EPOCHS,
        steps_per_epoch=STEPS_PER_EPOCH,
        validation_data=ds_eval,
        validation_steps=VALIDATION_STEPS,
        callbacks=callbacks,
    )

    return history.history[monitor][-1]

In [6]:
def show_result(study):
    """Print a plain-text summary of ``study``.

    The summary lists the total number of trials, how many of them are in the
    PRUNED and COMPLETE states, and then the value and parameters of
    ``study.best_trial``.

    Args:
        study: Object exposing ``trials`` (each with a ``state``) and
            ``best_trial`` (with ``value`` and ``params``).
    """

    trial_state = optuna.trial.TrialState
    states = [t.state for t in study.trials]
    n_pruned = states.count(trial_state.PRUNED)
    n_complete = states.count(trial_state.COMPLETE)

    print("Study statistics: ")
    print("  Number of finished trials: ", len(states))
    print("  Number of pruned trials: ", n_pruned)
    print("  Number of complete trials: ", n_complete)

    # The header is printed before best_trial is looked up, as in the original order.
    print("Best trial:")
    best = study.best_trial

    print("  Value: ", best.value)

    print("  Params: ")
    for name in best.params:
        print("    {}: {}".format(name, best.params[name]))

In [7]:
 # Create the study: maximize the objective's return value and prune with a
 # median pruner configured with n_startup_trials=2.
 study = optuna.create_study(
     study_name="optuna_demo_study_01",
     direction="maximize",
     pruner=optuna.pruners.MedianPruner(n_startup_trials=2),
 )

[32m[I 2021-01-25 11:37:14,176][0m A new study created in memory with name: optuna_demo_study_01[0m


In [8]:
# Run the search with a budget of 20 trials and a 600-second timeout.
study.optimize(
    objective,
    n_trials=20,
    timeout=600,
)

Epoch 1/40
Epoch 2/40
Epoch 3/40
Epoch 4/40
Epoch 5/40
Epoch 6/40
Epoch 7/40
Epoch 8/40
Epoch 9/40
Epoch 10/40
Epoch 11/40
Epoch 12/40
Epoch 13/40
Epoch 14/40
Epoch 15/40
Epoch 16/40
Epoch 17/40
Epoch 18/40
Epoch 19/40
Epoch 20/40
Epoch 21/40
Epoch 22/40
Epoch 23/40
Epoch 24/40
Epoch 25/40
Epoch 26/40
Epoch 27/40
Epoch 28/40
Epoch 29/40
Epoch 30/40
Epoch 31/40
Epoch 32/40
Epoch 33/40
Epoch 34/40
Epoch 35/40
Epoch 36/40
Epoch 37/40
Epoch 38/40
Epoch 39/40
Epoch 40/40


[32m[I 2021-01-25 11:37:25,448][0m Trial 0 finished with value: 0.22552083432674408 and parameters: {'lr': 0.002741400457556893, 'momentum': 0.15695769147814376, 'units': 32}. Best is trial 0 with value: 0.22552083432674408.[0m


Epoch 1/40
Epoch 2/40
Epoch 3/40
Epoch 4/40
Epoch 5/40
Epoch 6/40
Epoch 7/40
Epoch 8/40
Epoch 9/40
Epoch 10/40
Epoch 11/40
Epoch 12/40
Epoch 13/40
Epoch 14/40
Epoch 15/40
Epoch 16/40
Epoch 17/40
Epoch 18/40
Epoch 19/40
Epoch 20/40
Epoch 21/40
Epoch 22/40
Epoch 23/40
Epoch 24/40
Epoch 25/40
Epoch 26/40
Epoch 27/40
Epoch 28/40
Epoch 29/40
Epoch 30/40
Epoch 31/40
Epoch 32/40
Epoch 33/40
Epoch 34/40
Epoch 35/40
Epoch 36/40
Epoch 37/40
Epoch 38/40
Epoch 39/40
Epoch 40/40


[32m[I 2021-01-25 11:37:35,225][0m Trial 1 finished with value: 0.4559895694255829 and parameters: {'lr': 0.002416511777162246, 'momentum': 0.4067140265100664, 'units': 256}. Best is trial 1 with value: 0.4559895694255829.[0m


Epoch 1/40
Epoch 2/40
Epoch 3/40
Epoch 4/40
Epoch 5/40
Epoch 6/40
Epoch 7/40
Epoch 8/40
Epoch 9/40
Epoch 10/40
Epoch 11/40
Epoch 12/40
Epoch 13/40

[32m[I 2021-01-25 11:37:38,781][0m Trial 2 pruned. Trial was pruned at epoch 12.[0m


Epoch 1/40

[32m[I 2021-01-25 11:37:39,596][0m Trial 3 pruned. Trial was pruned at epoch 0.[0m


Epoch 1/40

[32m[I 2021-01-25 11:37:40,598][0m Trial 4 pruned. Trial was pruned at epoch 0.[0m


Epoch 1/40

[32m[I 2021-01-25 11:37:41,436][0m Trial 5 pruned. Trial was pruned at epoch 0.[0m


Epoch 1/40




Epoch 2/40
Epoch 3/40
Epoch 4/40
Epoch 5/40
Epoch 6/40
Epoch 7/40
Epoch 8/40
Epoch 9/40
Epoch 10/40
Epoch 11/40
Epoch 12/40
Epoch 13/40
Epoch 14/40
Epoch 15/40
Epoch 16/40
Epoch 17/40
Epoch 18/40
Epoch 19/40
Epoch 20/40
Epoch 21/40
Epoch 22/40
Epoch 23/40
Epoch 24/40
Epoch 25/40
Epoch 26/40
Epoch 27/40
Epoch 28/40
Epoch 29/40
Epoch 30/40
Epoch 31/40
Epoch 32/40
Epoch 33/40
Epoch 34/40
Epoch 35/40
Epoch 36/40
Epoch 37/40
Epoch 38/40
Epoch 39/40
Epoch 40/40


[32m[I 2021-01-25 11:37:51,235][0m Trial 6 finished with value: 0.8791666626930237 and parameters: {'lr': 0.0901486222345401, 'momentum': 0.26264387865128735, 'units': 512}. Best is trial 6 with value: 0.8791666626930237.[0m


Epoch 1/40
Epoch 2/40
Epoch 3/40
Epoch 4/40
Epoch 5/40
Epoch 6/40
Epoch 7/40
Epoch 8/40
Epoch 9/40
Epoch 10/40
Epoch 11/40
Epoch 12/40
Epoch 13/40
Epoch 14/40
Epoch 15/40
Epoch 16/40
Epoch 17/40
Epoch 18/40
Epoch 19/40
Epoch 20/40
Epoch 21/40
Epoch 22/40
Epoch 23/40
Epoch 24/40
Epoch 25/40
Epoch 26/40
Epoch 27/40
Epoch 28/40
Epoch 29/40
Epoch 30/40
Epoch 31/40
Epoch 32/40
Epoch 33/40
Epoch 34/40
Epoch 35/40
Epoch 36/40
Epoch 37/40
Epoch 38/40
Epoch 39/40
Epoch 40/40


[32m[I 2021-01-25 11:38:01,108][0m Trial 7 finished with value: 0.8645833134651184 and parameters: {'lr': 0.036391395623257314, 'momentum': 0.6022229317046285, 'units': 128}. Best is trial 6 with value: 0.8791666626930237.[0m


Epoch 1/40

[32m[I 2021-01-25 11:38:01,920][0m Trial 8 pruned. Trial was pruned at epoch 0.[0m


Epoch 1/40

[32m[I 2021-01-25 11:38:02,777][0m Trial 9 pruned. Trial was pruned at epoch 0.[0m


Epoch 1/40
Epoch 2/40
Epoch 3/40
Epoch 4/40
Epoch 5/40
Epoch 6/40
Epoch 7/40
Epoch 8/40
Epoch 9/40
Epoch 10/40
Epoch 11/40
Epoch 12/40
Epoch 13/40
Epoch 14/40
Epoch 15/40
Epoch 16/40
Epoch 17/40
Epoch 18/40
Epoch 19/40
Epoch 20/40
Epoch 21/40
Epoch 22/40
Epoch 23/40
Epoch 24/40
Epoch 25/40
Epoch 26/40
Epoch 27/40
Epoch 28/40
Epoch 29/40
Epoch 30/40
Epoch 31/40
Epoch 32/40
Epoch 33/40
Epoch 34/40
Epoch 35/40
Epoch 36/40
Epoch 37/40
Epoch 38/40
Epoch 39/40
Epoch 40/40


[32m[I 2021-01-25 11:38:12,401][0m Trial 10 finished with value: 0.8630208373069763 and parameters: {'lr': 0.07432228769287613, 'momentum': 0.060365671313782066, 'units': 512}. Best is trial 6 with value: 0.8791666626930237.[0m


Epoch 1/40

[32m[I 2021-01-25 11:38:13,418][0m Trial 11 pruned. Trial was pruned at epoch 0.[0m


Epoch 1/40

[32m[I 2021-01-25 11:38:14,267][0m Trial 12 pruned. Trial was pruned at epoch 0.[0m


Epoch 1/40
Epoch 2/40
Epoch 3/40
Epoch 4/40
Epoch 5/40
Epoch 6/40
Epoch 7/40
Epoch 8/40

[32m[I 2021-01-25 11:38:16,655][0m Trial 13 pruned. Trial was pruned at epoch 7.[0m


Epoch 1/40

[32m[I 2021-01-25 11:38:17,518][0m Trial 14 pruned. Trial was pruned at epoch 0.[0m


Epoch 1/40

[32m[I 2021-01-25 11:38:18,523][0m Trial 15 pruned. Trial was pruned at epoch 0.[0m


Epoch 1/40

[32m[I 2021-01-25 11:38:19,342][0m Trial 16 pruned. Trial was pruned at epoch 0.[0m


Epoch 1/40






[32m[I 2021-01-25 11:38:20,284][0m Trial 17 pruned. Trial was pruned at epoch 0.[0m


Epoch 1/40




Epoch 2/40
Epoch 3/40
Epoch 4/40
Epoch 5/40
Epoch 6/40
Epoch 7/40
Epoch 8/40
Epoch 9/40
Epoch 10/40
Epoch 11/40

[32m[I 2021-01-25 11:38:23,547][0m Trial 18 pruned. Trial was pruned at epoch 10.[0m


Epoch 1/40

[32m[I 2021-01-25 11:38:24,387][0m Trial 19 pruned. Trial was pruned at epoch 0.[0m


In [9]:
# Print the trial counts and the best trial's value and parameters.
show_result(study)

Study statistics: 
  Number of finished trials:  20
  Number of pruned trials:  15
  Number of complete trials:  5
Best trial:
  Value:  0.8791666626930237
  Params: 
    lr: 0.0901486222345401
    momentum: 0.26264387865128735
    units: 512


In [11]:
# Plot the optimization history of the study; the figure is the cell's output.
optuna.visualization.plot_optimization_history(study)