<a href="https://colab.research.google.com/github/peterjsadowski/keras_tutorial/blob/master/4_keras_mnist_SHERPA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
"""
SHERPA is a Python library for hyperparameter tuning of machine learning models.
Copyright (C) 2018  Lars Hertel, Peter Sadowski, and Julian Collado.

This file is part of SHERPA.

SHERPA is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

SHERPA is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with SHERPA.  If not, see <http://www.gnu.org/licenses/>.

INSTALLATION: 
pip install parameter-sherpa
pip install flask==0.12.2 # Newer version of flask leads to error: 'io.UnsupportedOperation: not writable'
"""
# Install SHERPA and the pinned Flask release listed in the INSTALLATION notes
# of the docstring above. Per those notes, newer Flask releases fail with
# 'io.UnsupportedOperation: not writable'.
!pip install parameter-sherpa
!pip install flask==0.12.2

from __future__ import print_function
import sherpa
import time
import keras
from keras.models import Sequential
from keras.layers import Dense
from keras.datasets import mnist



Using TensorFlow backend.


In [2]:
# Number of target classes used for the one-hot encoding below.
num_classes = 10

# Fetch the MNIST train/test split through the Keras dataset helper.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Flatten every image into a 784-element row, cast to float32 and divide by
# 255 (presumably mapping raw pixel intensities into [0, 1] -- confirm).
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255

# Report how many samples each split contains.
print(len(x_train), 'train samples')
print(len(x_test), 'test samples')

# One-hot encode the integer class labels of both splits.
y_train, y_test = (keras.utils.to_categorical(labels, num_classes)
                   for labels in (y_train, y_test))

60000 train samples
10000 test samples


In [17]:
# Set up the Sherpa study object to explore learning rate and momentum.
#
# Search space:
#   - 'lrinit':   initial learning rate, searched on a log scale over
#                 [0.001, 0.1]. The bounds are given as [low, high]; the
#                 original listed them in descending order.
#   - 'momentum': SGD momentum, searched on the default scale over [0, 0.99].
parameters = [sherpa.Continuous('lrinit', [0.001, 0.1], 'log'),
              sherpa.Continuous('momentum', [0., 0.99])]

# Bayesian optimization, capped at 10 trials.
algorithm = sherpa.algorithms.BayesianOptimization(max_num_trials=10)

# lower_is_better=True: the objective reported to this study is treated as a
# quantity to minimize (e.g. a loss).
study = sherpa.Study(parameters=parameters,
                     algorithm=algorithm,
                     lower_is_better=True)

INFO:sherpa.core:
-------------------------------------------------------
SHERPA Dashboard running on http://172.28.0.2:8880
-------------------------------------------------------


In [18]:
# Training settings shared by every trial.
# NOTE: this was previously misspelled `dbatch_size`, leaving `batch_size`
# (used in model.fit below) undefined on a fresh kernel.
batch_size = 32
epochs = 1

# Each iteration of the study yields one trial carrying a set of
# hyperparameters in trial.parameters.
for trial in study:
    print("Trial {}:\t{}".format(trial.id, trial.parameters))

    # Build a fresh model per trial: 784 inputs -> 30 ReLU units -> 10-way
    # softmax output.
    model = Sequential()
    model.add(Dense(units=30, activation='relu', input_dim=784))
    model.add(Dense(units=10, activation='softmax'))

    # SGD configured with this trial's learning rate and momentum.
    optimizer = keras.optimizers.SGD(lr=trial.parameters['lrinit'],
                                     momentum=trial.parameters['momentum'])
    model.compile(loss='categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['accuracy'])

    # The Sherpa Keras callback is attached with 'val_loss' as the objective,
    # so the loss on validation_data is what the study optimizes.
    # NOTE(review): the test split doubles as validation data here.
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              verbose=0,
              callbacks=[study.keras_callback(trial, objective_name='val_loss')],
              validation_data=(x_test, y_test))

    # Tell the study that no further observations will be added to this trial.
    study.finalize(trial=trial)

# Print the best result recorded by the study.
print(study.get_best_result())

Trial 1:	{'lrinit': 0.021544346900318846, 'momentum': 0.32999999999999996}
Trial 2:	{'lrinit': 0.021544346900318846, 'momentum': 0.6599999999999999}
Trial 3:	{'lrinit': 0.004641588833612782, 'momentum': 0.32999999999999996}
Trial 4:	{'lrinit': 0.004641588833612782, 'momentum': 0.6599999999999999}
Trial 5:	{'lrinit': 0.1, 'momentum': 0.99}
Trial 6:	{'lrinit': 0.013236310084867153, 'momentum': 0.537857299945472}
Trial 7:	{'lrinit': 0.07655490914789083, 'momentum': 0.0}
Trial 8:	{'lrinit': 0.017150394330641344, 'momentum': 0.0}
Trial 9:	{'lrinit': 0.09605670116060591, 'momentum': 0.2722226589033094}
{'Trial-ID': 9, 'Iteration': 0, 'lrinit': 0.09605670116060591, 'momentum': 0.2722226589033094, 'Objective': 0.19574639337956906}
