In [None]:
# https://github.com/fchollet/keras/blob/master/examples/mnist_mlp.py
'''Trains a simple deep NN on the MNIST dataset.
Gets to 98.40% test accuracy after 20 epochs
(there is *a lot* of margin for parameter tuning).
2 seconds per epoch on a K520 GPU.

Andras: 
2 seconds per epoch on CPU. (AMD Ryzen 1800X)
'''

from __future__ import print_function

import tensorflow as tf
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.optimizers import RMSprop


# Training configuration used by every model.fit call in this notebook.
batch_size = 4096
num_classes = 10
epochs = 20

# the data, shuffled and split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Flatten each sample into one feature vector, cast to float32 and divide by
# 255. The sample count is read from the array itself and -1 lets NumPy infer
# the flattened length, so no dataset size is hard-coded here.
x_train = x_train.reshape(x_train.shape[0], -1).astype('float32') / 255
x_test = x_test.reshape(x_test.shape[0], -1).astype('float32') / 255
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# Fully connected classifier: two 512-unit ReLU layers, each followed by
# dropout with rate 0.2, ending in a softmax layer with one unit per class.
model = Sequential([
    Dense(512, activation='relu', input_shape=(784,)),
    Dropout(0.2),
    Dense(512, activation='relu'),
    Dropout(0.2),
    Dense(num_classes, activation='softmax'),
])

model.summary()

model.compile(
    optimizer=RMSprop(),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Warm-up hack: kick off a fit from the main thread and abort it right away.
# According to the original author this triggers some backend initialisation
# that does not happen correctly when the first fit runs inside a thread
# (possibly a global/local namespace issue with the backend).
def _abort_fit(batch, logs):
    """Ask Keras to stop training as soon as a batch begins."""
    model.stop_training = True


killer_callback = keras.callbacks.LambdaCallback(on_batch_begin=_abort_fit)
_ = model.fit(x_train, y_train,
              validation_data=(x_test, y_test),
              batch_size=batch_size,
              epochs=epochs,
              callbacks=[killer_callback],
              verbose=0)

In [None]:
# Suggested reading:
# https://keras.io/getting-started/faq/#how-can-i-interrupt-training-when-the-validation-loss-isnt-decreasing-anymore
# https://keras.io/callbacks

import threading
from IPython.display import display
import ipywidgets as widgets
from ipywidgets import VBox, HBox
from keras.callbacks import LambdaCallback
import time

# State shared between the button handler and the training callbacks.
initial_epoch = 0        # epoch index handed to the next model.fit call
stop_training = False    # True while a stop has been requested via the button

# Top row: overall progress bar with a status line next to it.
text = widgets.HTML(value='Starting ...')
progress = widgets.FloatProgress(value=0.0, min=0.0, max=1.0)

# Second row: progress within the current epoch with a batch counter.
batch_counter = widgets.HTML()
batch_progress = widgets.FloatProgress(value=0.0, min=0.0, max=1.0)

# Toggle button; it starts out pressed (value=True).
button = widgets.ToggleButton(
    value=True,
    description="STOP! Hammer Time",
    tooltip='Kill the training in progress',
    button_style='danger',
)

gui = VBox(children=[
    HBox(children=[progress, text]),
    HBox(children=[batch_progress, batch_counter]),
    button,
])

def button_clicked(b):
    """React to a change of the toggle button's ``value``.

    Released (``value`` is False): raise the module-level ``stop_training``
    flag and switch the button to its "waiting" look. The flag is acted upon
    by ``on_epoch_end``, so the stop happens when the current epoch finishes.

    Pressed (``value`` is True): clear the flag, start ``model_fit`` on a new
    thread beginning at ``initial_epoch`` and restore the "stop" look.

    NOTE(review): nothing here checks whether an earlier training thread is
    still running before a new one is started.

    Args:
        b: Change notification passed by ``observe``; unused, the current
           state is read from ``button.value``.
    """
    global stop_training

    if not button.value:
        stop_training = True
        button.description = f'Waiting for epoch {initial_epoch}'
        button.button_style = 'warning'
        button.tooltip = 'Do not click or else kittens die. You have been warned ...'
        return

    stop_training = False
    worker = threading.Thread(target=model_fit, args=(initial_epoch,))
    worker.start()
    button.button_style = 'danger'
    button.description = 'STOP! Hammer Time'
    button.tooltip = 'Kill the training in progress'


button.observe(button_clicked, 'value')

# Number of batches per epoch, rounded up because the last batch may be
# partial. Plain `/` gave a fractional count, which truncated to one batch too
# few whenever the sample count is not a multiple of batch_size.
n_batches = -(-x_train.shape[0] // batch_size)

#     on_epoch_begin and on_epoch_end expect two positional arguments: epoch, logs
#     on_batch_begin and on_batch_end expect two positional arguments: batch, logs
#     on_train_begin and on_train_end expect one positional argument: logs

def on_train_begin(logs):
    """Initialise the overall progress bar and status text when a fit starts.

    The bar is placed a third of an epoch past ``initial_epoch`` so that a run
    resumed through the button does not briefly snap back to the beginning.
    For a fresh run (``initial_epoch == 0``) the value is exactly the previous
    ``1 / epochs / 3``.

    Args:
        logs: Training logs passed by the callback machinery; unused.
    """
    progress.value = initial_epoch / epochs + 1 / epochs / 3
    text.value = f'Epoch {initial_epoch} running ...'


def on_train_end(logs):
    """Clear the module-level ``stop_training`` flag once a fit has ended.

    Args:
        logs: Training logs passed by the callback machinery; unused.
    """
    global stop_training
    stop_training = False


def on_epoch_begin(epoch, logs):
    """Reset the per-batch widgets and nudge the overall bar for a new epoch.

    Args:
        epoch: Zero-based index of the epoch that is starting.
        logs: Training logs passed by the callback machinery; unused.
    """
    batch_progress.value = 0
    total_batches = int(n_batches)
    batch_counter.value = 'Batch 0/{}'.format(total_batches)
    # Sit a tenth of an epoch past the boundary so the bar visibly moves.
    progress.value = (epoch + 0.1) / epochs


def on_batch_begin(batch, logs):
    """Show which batch is starting and how far through the epoch it is.

    Args:
        batch: Index of the batch that is about to run.
        logs: Training logs passed by the callback machinery; unused.
    """
    total_batches = int(n_batches)
    batch_counter.value = 'Batch {}/{}'.format(batch, total_batches)
    batch_progress.value = batch / n_batches
    
    
def on_epoch_end(epoch, logs):
    """Publish the finished epoch's result and honour a pending stop request.

    Marks the batch widgets as complete, reports the epoch's loss, advances
    the overall bar and records ``epoch + 1`` in the module-level
    ``initial_epoch`` so that a later restart resumes from there. If
    ``stop_training`` has been set, the model is told to stop and the button
    is turned into a "Restart" button.

    Args:
        epoch: Zero-based index of the epoch that just finished.
        logs: Training logs for the epoch; must contain ``"loss"``.
    """
    global initial_epoch

    total_batches = int(n_batches)
    next_epoch = epoch + 1

    batch_counter.value = f'Batch {total_batches}/{total_batches}'
    batch_progress.value = 1.0
    text.value = f'Epoch {epoch}/{epochs - 1}, Loss = {logs["loss"]:.4f}'
    progress.value = next_epoch / epochs
    initial_epoch = next_epoch

    if not stop_training:
        return

    # A stop was requested through the button: end the fit after this epoch.
    model.stop_training = True
    button.description = 'Restart'
    button.button_style = 'info'
    button.tooltip = 'Click to restart training'


# Single Keras callback that routes each training event to the matching
# widget-updating hook defined above.
progressbar_callback = LambdaCallback(
    on_epoch_begin=on_epoch_begin,
    on_epoch_end=on_epoch_end,
    on_batch_begin=on_batch_begin,
    on_train_begin=on_train_begin,
    on_train_end=on_train_end,
)

In [None]:
def model_fit(initial_epoch):
    """Train ``model`` starting at ``initial_epoch``, updating the progress GUI.

    Used as a thread target, where the return value is ignored; when called
    directly, the result of ``model.fit`` is now handed back instead of being
    bound to an unused local and dropped.

    Args:
        initial_epoch: Epoch index at which training starts or resumes.

    Returns:
        Whatever ``model.fit`` returns for this run.
    """
    return model.fit(x_train, y_train,
                     batch_size=batch_size,
                     epochs=epochs,
                     verbose=0,
                     validation_data=(x_test, y_test),
                     callbacks=[progressbar_callback],
                     initial_epoch=initial_epoch)


# Show the control panel, then run the training on a separate thread
# (presumably so the notebook can keep handling button clicks while it runs).
display(gui)
thread = threading.Thread(target=model_fit, args=(initial_epoch,))
thread.start()

In [None]:
# Evaluate on the held-out test set.
# NOTE(review): training is started on a background thread in an earlier cell
# and nothing here waits for it, so run this cell only after training is done.
score = model.evaluate(x_test, y_test, verbose=0)
test_loss, test_accuracy = score[0], score[1]
print('Test loss:', test_loss)
print('Test accuracy:', test_accuracy)