In [None]:
from keras.datasets import mnist

# Load MNIST as (images, labels) pairs for the training and test splits.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

In [None]:
from keras.utils import to_categorical

# Give every image a trailing channel axis, matching the (28, 28, 1) input
# the model declares. Only the first three dimensions are kept from the
# current shape, so re-running this cell does not stack extra axes.
x_train = x_train.reshape(x_train.shape[:3] + (1,))
x_test = x_test.reshape(x_test.shape[:3] + (1,))

# One-hot encode the integer labels. The ndim guard keeps the cell idempotent
# (labels that are already one-hot are left alone), and num_classes is pinned
# to the 10 outputs of the network instead of being inferred from the data.
if y_train.ndim == 1:
    y_train = to_categorical(y_train, num_classes=10)
if y_test.ndim == 1:
    y_test = to_categorical(y_test, num_classes=10)

In [None]:
from keras.preprocessing.image import ImageDataGenerator

# Training pipeline: rescale pixel values by 1/255 and apply random
# horizontal/vertical shifts (shift range 2) as augmentation.
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    width_shift_range=2,
    height_shift_range=2,
)
# Validation pipeline: same rescaling, no augmentation.
val_datagen = ImageDataGenerator(rescale=1. / 255)

# Batch iterators over the in-memory arrays, 50 samples per batch.
train_gen = train_datagen.flow(x=x_train, y=y_train, batch_size=50)
val_gen = val_datagen.flow(x=x_test, y=y_test, batch_size=50)

In [None]:
from keras import models, layers
from keras.constraints import MaxNorm
from keras.initializers import TruncatedNormal

def build_teacher(max_norm=1):
    """Build the fully connected teacher network (uncompiled).

    Layer stack: Flatten -> Dropout(0.2) -> Dense(128, relu) -> Dropout(0.5)
    -> Dense(128, relu) -> Dropout(0.5) -> Dense(10) -> softmax.

    # Arguments
        max_norm: value handed to `MaxNorm`; the resulting constraint is
            applied to the kernels of the two hidden Dense layers only.

    # Returns
        A `Sequential` model taking inputs of shape (28, 28, 1).
    """
    # One constraint and one initializer instance are shared by the layers.
    norm_constraint = MaxNorm(max_norm)
    weight_init = TruncatedNormal(mean=0, stddev=0.01)

    # The list is evaluated top to bottom, so layers are created in the same
    # order in which they are stacked.
    stack = [
        layers.Flatten(input_shape=(28, 28, 1)),
        layers.Dropout(0.2),
        layers.Dense(128,
                     activation='relu',
                     kernel_constraint=norm_constraint,
                     kernel_initializer=weight_init),
        layers.Dropout(0.5),
        layers.Dense(128,
                     activation='relu',
                     kernel_constraint=norm_constraint,
                     kernel_initializer=weight_init),
        layers.Dropout(0.5),
        # Output layer: logits without a constraint, softmax applied separately.
        layers.Dense(10, kernel_initializer=weight_init),
        layers.Activation('softmax'),
    ]
    return models.Sequential(stack)

In [None]:
# Custom callback to modify momentum during training
from keras.callbacks import Callback
import keras.backend as K
import numpy as np

class MomentumScheduler(Callback):
    """Momentum scheduler.

    At the start of every epoch the optimizer's current momentum is read,
    passed through `schedule`, and the result is written back to the
    optimizer. At the end of every epoch the momentum is recorded in the
    epoch logs under the key 'momentum'.

    # Arguments
        schedule: a function that takes an epoch index as input
            (integer, indexed from 0) and current momentum
            and returns a new momentum as output (float).
            A one-argument `schedule(epoch)` is also accepted.
        verbose: int. 0: quiet, 1: update messages.

    # Raises
        ValueError: if the optimizer has no `momentum` attribute, or if
            `schedule` does not return a float.
    """

    def __init__(self, schedule, verbose=0):
        super(MomentumScheduler, self).__init__()
        self.schedule = schedule
        self.verbose = verbose

    def on_epoch_begin(self, epoch, logs=None):
        """Compute the momentum for `epoch` and write it to the optimizer."""
        if not hasattr(self.model.optimizer, 'momentum'):
            raise ValueError('Optimizer must have a "momentum" attribute.')
        momentum = float(K.get_value(self.model.optimizer.momentum))
        try:  # new API
            momentum = self.schedule(epoch, momentum)
        except TypeError:  # old API for backward compatibility
            momentum = self.schedule(epoch)
        if not isinstance(momentum, (float, np.float32, np.float64)):
            raise ValueError('The output of the "schedule" function '
                             'should be float.')
        K.set_value(self.model.optimizer.momentum, momentum)
        if self.verbose > 0:
            print('\nEpoch %05d: MomentumScheduler setting momentum '
                  'to %s.' % (epoch + 1, momentum))

    def on_epoch_end(self, epoch, logs=None):
        """Record the momentum used during `epoch` in the epoch logs."""
        # Only substitute a fresh dict when none was supplied; an empty dict
        # passed by the caller must be written to in place so the entry is
        # visible to the caller afterwards.
        if logs is None:
            logs = {}
        logs['momentum'] = K.get_value(self.model.optimizer.momentum)
        
        
        
def m_scheduler(epoch, momentum):
    """Ramp the momentum linearly towards 0.99 over the first 500 epochs.

    # Arguments
        epoch: epoch index (integer, indexed from 0).
        momentum: the optimizer's current momentum (float).

    # Returns
        The momentum to use for this epoch (float). During the first 500
        epochs this is the current momentum plus a fixed step of
        (0.99 - 0.5) / 500, never exceeding 0.99; afterwards it is 0.99.
    """
    max_momentum = 0.99
    increase = (max_momentum - 0.5) / 500
    if epoch < 500:
        # The step is added to whatever momentum the optimizer currently
        # holds, so cap the result: a starting value above 0.5, a repeated
        # run, or accumulated rounding must not push it past the maximum.
        return min(momentum + increase, max_momentum)
    return max_momentum

In [None]:
class LRScheduler(Callback):
    """LR scheduler that also exposes the current momentum to the schedule.

    At the start of every epoch the optimizer's current learning rate and
    momentum are read, passed through `schedule`, and the returned learning
    rate is written back to the optimizer. At the end of every epoch the
    learning rate is recorded in the epoch logs under the key 'lr'.

    # Arguments
        schedule: a function that takes an epoch index as input
            (integer, indexed from 0), the current lr and the current
            momentum, and returns a new lr as output (float).
            A one-argument `schedule(epoch)` is also accepted.
        verbose: int. 0: quiet, 1: update messages.

    # Raises
        ValueError: if the optimizer has no `lr` or `momentum` attribute,
            or if `schedule` does not return a float.
    """

    def __init__(self, schedule, verbose=0):
        super(LRScheduler, self).__init__()
        self.schedule = schedule
        self.verbose = verbose

    def on_epoch_begin(self, epoch, logs=None):
        """Compute the learning rate for `epoch` and write it to the optimizer."""
        optimizer = self.model.optimizer
        if not hasattr(optimizer, 'lr'):
            raise ValueError('Optimizer must have a "lr" attribute.')
        # The schedule is given the momentum as well, so fail with the same
        # explicit error when the optimizer does not provide one.
        if not hasattr(optimizer, 'momentum'):
            raise ValueError('Optimizer must have a "momentum" attribute.')
        lr = float(K.get_value(optimizer.lr))
        m = float(K.get_value(optimizer.momentum))
        try:  # new API
            lr = self.schedule(epoch, lr, m)
        except TypeError:  # old API for backward compatibility
            lr = self.schedule(epoch)
        if not isinstance(lr, (float, np.float32, np.float64)):
            raise ValueError('The output of the "schedule" function '
                             'should be float.')
        K.set_value(optimizer.lr, lr)
        if self.verbose > 0:
            # Report exactly the value that was written to the optimizer.
            print('\nEpoch %05d: LRScheduler setting lr '
                  'to %s.' % (epoch + 1, lr))

    def on_epoch_end(self, epoch, logs=None):
        """Record the learning rate used during `epoch` in the epoch logs."""
        # Only substitute a fresh dict when none was supplied; an empty dict
        # passed by the caller must be written to in place so the entry is
        # visible to the caller afterwards.
        if logs is None:
            logs = {}
        logs['lr'] = K.get_value(self.model.optimizer.lr)
        
        
def lr_scheduler(epoch, lr, m):
    """Rescale the learning rate by (1 - momentum) for the current epoch.

    For the first epoch the incoming `lr` is simply multiplied by (1 - m).
    For later epochs the factor assumed to have been applied previously,
    (1 - m_0), is divided out first and the new factor (1 - m) is applied.

    # Arguments
        epoch: epoch index (integer, indexed from 0).
        lr: the optimizer's current learning rate (float).
        m: the optimizer's current momentum (float).

    # Returns
        The learning rate to use for this epoch (float).
    """
    if epoch < 1:
        return lr * (1 - m)

    # Momentum assumed to be behind the scaling currently baked into `lr`:
    # a linear ramp from 0.5 during the first 500 epochs, 0.99 afterwards.
    # NOTE(review): whether `epoch - 1` and the 500 boundary line up with the
    # momentum schedule depends on the order in which the two scheduler
    # callbacks run - confirm against the callback list passed to training.
    if epoch >= 500:
        previous_momentum = 0.99
    else:
        previous_momentum = 0.5 + ((0.99 - 0.5) / 500) * (epoch - 1)

    unscaled_lr = lr / (1 - previous_momentum)
    return unscaled_lr * (1 - m)

In [None]:
# Epoch-level callbacks wrapping the two schedule functions; verbose=1 makes
# each of them print the value it sets at the start of every epoch.
change_m = MomentumScheduler(schedule=m_scheduler, verbose=1)
lrs = LRScheduler(schedule=lr_scheduler, verbose=1)

In [None]:
from keras.optimizers import SGD

# NOTE(review): `sgd` is created here, but the model below is compiled with
# optimizer='adam' and no callbacks are passed to fit_generator, so neither
# this optimizer nor the momentum/lr schedulers take part in training -
# confirm whether that is intentional.
sgd = SGD(lr=10, decay=0.002, momentum=0.5)

epochs = 30
batch_size = 50
# Number of batches needed to cover the training set once. Integer ceiling
# division keeps this an int, which is what Keras expects for a step count.
steps_per_epoch = (len(x_train) + batch_size - 1) // batch_size

max_norm = 15

teacher = build_teacher(max_norm)

teacher.compile(loss='categorical_crossentropy',
                optimizer='adam',
                metrics=['acc'])

history = teacher.fit_generator(train_gen,
                                epochs=epochs,
                                steps_per_epoch=steps_per_epoch,
                                validation_data=val_gen)

# res is [loss, acc] given the single 'acc' metric compiled above.
res = teacher.evaluate_generator(val_gen)
# Turn the accuracy into a count of misclassified test samples; rounding
# removes floating-point noise so a whole number is reported.
errors = int(round(len(y_test) * (1 - res[1])))
print('Errors: ', errors)

In [None]:
# Re-evaluate the trained teacher on the validation generator.
res = teacher.evaluate_generator(val_gen)
# res[1] is the accuracy; convert it into a count of misclassified test
# samples, rounded to remove floating-point noise.
errors = int(round(len(y_test) * (1 - res[1])))
print('Errors: ', errors)

In [None]:
# Display the raw result of the last evaluate_generator call
# (index 1 is the value used above as accuracy).
res

In [None]:
import matplotlib.pyplot as plt

# Per-epoch training and validation loss recorded during fitting.
history_dict = history.history
loss_values, val_loss_values = history_dict['loss'], history_dict['val_loss']

# 1-based epoch numbers for the x-axis.
# NOTE: this rebinds `epochs`, which the training cell defines as the integer
# epoch count - re-run that cell's setup before fitting again.
epochs = range(1, 1 + len(loss_values))

# Dots for training, solid line for validation.
plt.plot(epochs, loss_values, 'bo', label='Training loss')
plt.plot(epochs, val_loss_values, 'b', label='Validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training and validation loss')
plt.legend()

plt.show()

In [None]:
# Start from an empty figure so the loss curves are not drawn again.
plt.clf()
# Per-epoch training and validation accuracy recorded during fitting
# (keys follow the 'acc' metric name used at compile time).
acc_values = history_dict['acc']
val_acc_values = history_dict['val_acc']

# Dots for training, solid line for validation.
plt.plot(epochs, acc_values, 'bo', label='Training acc')
plt.plot(epochs, val_acc_values, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()

plt.show()

In [None]:
# Persist the trained teacher model to 'mnist_teacher.h5' in the working directory.
teacher.save('mnist_teacher.h5')

In [None]:
# Print a layer-by-layer summary of the teacher network.
teacher.summary()