## Early Stopping
Early stopping is a method to avoid overfitting the training set. The general idea is interrupting the training process when its performance on the validation set starts dropping.

A simple way to implement Early Stopping in TensorFlow is to evaluate the model on a validation set at regular intervals (for example, every 50 steps), and save the best model if the current model outperforms the previous best one. Count the number of steps since the last best snapshot was saved, and interrupt training when this number reaches a limit.

Early Stopping presents better performance when using with other regularization techniques.

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

tf.reset_default_graph()
n_inputs = 784
n_hidden1 = 300
n_hidden2 = 100
n_outputs = 10

X = tf.placeholder(tf.float32, shape=(None, n_inputs), name="X")
y = tf.placeholder(tf.int64, shape=(None), name="y")

with tf.name_scope("dnn"):
    hidden1 = tf.layers.dense(X, n_hidden1, activation=tf.nn.elu, name="hidden1")
    hidden2 = tf.layers.dense(hidden1, n_hidden2, activation=tf.nn.elu, name="hidden2")
    logits = tf.layers.dense(hidden2, n_outputs, name="outputs")
    
with tf.name_scope("loss"):
    xen = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
    loss = tf.reduce_mean(xen, name="loss")

In [2]:
with tf.name_scope("train"):
    initial_learning_rate = 0.1
    decay_steps = 10000
    decay_rate = 1/10
    global_step = tf.Variable(0, trainable=False, name="global_step")
    learning_rate = tf.train.exponential_decay(initial_learning_rate, global_step, decay_steps, decay_rate)
    optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)
    training_op = optimizer.minimize(loss, global_step=global_step)

In [3]:
with tf.name_scope("eval"):
    correct = tf.nn.in_top_k(logits, y, 1)
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

In [4]:
init = tf.global_variables_initializer()

n_epochs = 500
batch_size = 100

best_epoch = None # Store the epoch with the highest accuracy score
best_accuracy = 0.97 # Best accuracy threshold

mnist = input_data.read_data_sets("/tmp/data/", validation_size=6000)

saver = tf.train.Saver()

with tf.Session() as sess:
    init.run()
    for epoch in range(n_epochs):
        for iteration in range(len(mnist.test.labels) // batch_size):
            X_batch, y_batch = mnist.train.next_batch(batch_size)
            sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
        
        # Evaluate the model every 50 epochs on the validation set
        # And store the best model if available
        if epoch % 50 == 0:
            acc_val = accuracy.eval(feed_dict={X: mnist.validation.images, y: mnist.validation.labels})
            print("Epoch:", epoch, "--", "Validation Accuracy:", acc_val)
            if acc_val > best_accuracy:
                best_accuracy = acc_val
                best_epoch = epoch
                saver.save(sess, "models/early_stopping/best_model.cpkt")
    print("Best Epoch:", best_epoch)

Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
Epoch: 0 -- Validation Accuracy: 0.916833
Epoch: 50 -- Validation Accuracy: 0.983833
Epoch: 100 -- Validation Accuracy: 0.984333
Epoch: 150 -- Validation Accuracy: 0.984167
Epoch: 200 -- Validation Accuracy: 0.984
Epoch: 250 -- Validation Accuracy: 0.984
Epoch: 300 -- Validation Accuracy: 0.984
Epoch: 350 -- Validation Accuracy: 0.984
Epoch: 400 -- Validation Accuracy: 0.984
Epoch: 450 -- Validation Accuracy: 0.984
Best Epoch: 100


In [5]:
from sklearn.metrics import accuracy_score

# Restore the best model and make predictions
with tf.Session() as sess:
    saver.restore(sess, "models/early_stopping/best_model.cpkt")
    X_new_scaled = mnist.test.images[:20]
    Z = logits.eval(feed_dict={X: X_new_scaled})
    y_pred = np.argmax(Z, axis=1)

    print("Predicted classes:", y_pred)
    print("Actual classes   :", mnist.test.labels[:20])
    print(accuracy_score(mnist.test.labels[:20], y_pred))

INFO:tensorflow:Restoring parameters from models/early_stopping/best_model.cpkt
Predicted classes: [7 2 1 0 4 1 4 9 5 9 0 6 9 0 1 5 9 7 3 4]
Actual classes   : [7 2 1 0 4 1 4 9 5 9 0 6 9 0 1 5 9 7 3 4]
1.0
