## optuna を tensorflow.keras で使用してみる
- Reference
  - https://github.com/optuna/optuna/blob/master/examples/pruning/tfkeras_integration.py

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds

In [2]:
import optuna
from optuna.integration import TFKerasPruningCallback

In [14]:
BATCHSIZE = 16 
CLASSES = 10
EPOCHS = 3
N_TRAIN_EXAMPLES = 100
STEPS_PER_EPOCH = int(N_TRAIN_EXAMPLES / BATCHSIZE / 2)
VALIDATION_STEPS = 2 
print(STEPS_PER_EPOCH)

3


In [15]:
def train_dataset():
    ds = tfds.load('mnist', split=tfds.Split.TRAIN, shuffle_files=True)
    ds = ds.map(lambda x: (tf.cast(x['image'], tf.float32) / 255.0, x['label']))
    ds = ds.repeat().shuffle(1024).batch(BATCHSIZE)
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

    return ds

In [16]:
def eval_dataset():
    ds = tfds.load('mnist', split=tfds.Split.TEST, shuffle_files=False)
    ds = ds.map(lambda x: (tf.cast(x['image'], tf.float32) / 255.0, x['label']))
    ds = ds.repeat().batch(BATCHSIZE)
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

    return ds

In [17]:
def create_model(trial):
    lr = trial.suggest_loguniform('lr', 1e-4, 1e-1)
    momentum = trial.suggest_uniform('momentum', 0.0, 1.0)
    units = trial.suggest_categorical('units', [32, 64, 128, 256, 512])

    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(units=units, activation=tf.nn.relu))
    model.add(tf.keras.layers.Dense(CLASSES, activation=tf.nn.softmax))

    model.compile(
        optimizer=tf.keras.optimizers.SGD(lr=lr, momentum=momentum, nesterov=True),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )

    return model

In [18]:
def objective(trial):
    tf.keras.backend.clear_session()

    if tf.__version__ >= "2":
        monitor = "val_accuracy"
    else:
        monitor = "val_acc"

    model = create_model(trial)

    ds_train = train_dataset()
    ds_eval = eval_dataset()

    callbacks = [
        tf.keras.callbacks.EarlyStopping(patience=3),
        TFKerasPruningCallback(trial, monitor),
    ]

    history = model.fit(
        ds_train,
        epochs=EPOCHS,
        steps_per_epoch=STEPS_PER_EPOCH,
        validation_data=ds_eval,
        validation_steps=VALIDATION_STEPS,
        callbacks=callbacks,
    )

    return history.history[monitor][-1]

In [19]:
def show_result(study):
    pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
    complete_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]

    print('Study statistics: ')
    print('  Number of finished trials: ', len(study.trials))
    print('  Number of pruned trials: ', len(pruned_trials))
    print('  Number of complete trials: ', len(complete_trials))

    print('Best trial:')
    trial = study.best_trial

    print('  Value: ', trial.value)

    print('  Params: ')
    for key, value in trial.params.items():
        print('    {}: {}'.format(key, value))

In [20]:
def main():
    study = optuna.create_study(
        direction='maximize', pruner=optuna.pruners.MedianPruner(n_startup_trials=2)
    )

    study.optimize(objective, n_trials=5, timeout=60)

    show_result(study)

In [21]:
main()

Train for 3 steps, validate for 2 steps
Epoch 1/3
Epoch 2/3






[W 2020-06-20 18:01:16,530] Setting status of trial#0 as TrialState.FAIL because of the following error: InternalError()
Traceback (most recent call last):
  File "/Users/tayutaedomo/project/ML/tensorflow-sandbox/venv/lib/python3.7/site-packages/optuna/study.py", line 734, in _run_trial
    result = func(trial)
  File "<ipython-input-18-6c3070c1f72c>", line 25, in objective
    callbacks=callbacks,
  File "/Users/tayutaedomo/project/ML/tensorflow-sandbox/venv/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py", line 728, in fit
    use_multiprocessing=use_multiprocessing)
  File "/Users/tayutaedomo/project/ML/tensorflow-sandbox/venv/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 337, in fit
    eval_data_iter = iter(validation_dataset)
  File "/Users/tayutaedomo/project/ML/tensorflow-sandbox/venv/lib/python3.7/site-packages/tensorflow_core/python/data/ops/dataset_ops.py", line 332, in __iter__
    return iterator_ops.Itera

InternalError: Cache should only be read after it has been completed. [Op:MakeIterator]