# TF/Keras Learning Rate & Schedulers

Neural networks are typically trained with gradient descent, using an algorithm called back-propagation to compute the gradients. The model parameters start from some random initial guesses, the data is fed forward, and the error of the outcome is propagated backwards through the model; the size of each resulting parameter update is controlled by a parameter called the **learning rate**.

The learning rate is an important hyper-parameter for model training. For instance, if it is too low, the training of the model can take too long; if it is too large, the model may not converge. Furthermore, the model can converge to a sub-optimal solution (a local minimum).

Learning rate schedulers provide a dynamic solution to these issues. They are just some functions of a training epoch (or step) that multiply the learning rate. Here, I will plot the readily available learning rate schedulers in Tensorflow/Keras as well as one popular scheduler that is probably not available in Tensorflow/Keras. Seeing these functions is extremely helpful for training. Most of the plots are shown in both linear and logarithmic scales. 

Interested readers may benefit from playing with the code by changing the parameters of schedulers to see how they affect the learning rate.

In [None]:
import tensorflow as tf
import tensorflow_addons as tfa
import matplotlib.pyplot as plt

%matplotlib inline
%config InlineBackend.figure_format = 'svg'

## Cosine Decay

$\text{Scheduler} = \dfrac{1}{2} \left(1 - \alpha\right) \left[1 + \cos\left(\dfrac{\pi \cdot \text{Step}}{\text{Decay Steps}}\right)\right] + \alpha$

In [None]:
def plot_scheduler(step, schedulers):
    """Plot learning-rate schedules on a linear and a logarithmic y-axis.

    Args:
        step: Number of steps to plot; every schedule is evaluated once on
            ``range(step)``.
        schedulers: A single schedule or a list of schedules. Each one must
            be callable on a sequence of steps and expose a ``name``
            attribute, which is used as its legend label.
    """
    if not isinstance(schedulers, list):
        schedulers = [schedulers]

    steps = range(step)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=[10, 3])
    for scheduler in schedulers:
        # Evaluate once and draw the same values in both panels. Calling the
        # schedule separately per panel would draw two different samples for
        # schedules that add random noise (e.g. the noisy linear cosine decay
        # plotted with this helper).
        rates = scheduler(steps)
        # Fall back to the class name so unnamed schedules (name=None) still
        # get a legend entry.
        label = scheduler.name or type(scheduler).__name__
        ax1.plot(steps, rates, label=label)
        ax2.plot(steps, rates, label=label)

    # Axis decoration does not depend on the schedule, so do it once.
    ax2.set_yscale('log')
    for ax in (ax1, ax2):
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Learning Rate')
        if schedulers:
            # Skip the legend when nothing was plotted to avoid an
            # empty-legend warning.
            ax.legend()
    plt.show()

In [None]:
# Three cosine-decay schedules that all decay over 50 steps. They are plotted
# over 60 steps, i.e. 10 steps beyond `decay_steps`.

# Schedule 1: initial rate 1e-2, alpha 1e-2.
init_lr, decay_steps, alpha = 1e-2, 50, 1e-2
cos_dec1 = tf.keras.experimental.CosineDecay(
    initial_learning_rate=init_lr,
    decay_steps=decay_steps,
    alpha=alpha,
    name='Cosine Decay 1',
)

# Schedule 2: same initial rate, but alpha = 0.
init_lr, decay_steps, alpha = 1e-2, 50, 0
cos_dec2 = tf.keras.experimental.CosineDecay(
    initial_learning_rate=init_lr,
    decay_steps=decay_steps,
    alpha=alpha,
    name='Cosine Decay 2',
)

# Schedule 3: ten times smaller initial rate, alpha 1e-2.
init_lr, decay_steps, alpha = 1e-3, 50, 1e-2
cos_dec3 = tf.keras.experimental.CosineDecay(
    initial_learning_rate=init_lr,
    decay_steps=decay_steps,
    alpha=alpha,
    name='Cosine Decay 3',
)

plot_scheduler(step=60, schedulers=[cos_dec1, cos_dec2, cos_dec3])

## Linear Cosine Decay

In [None]:
# Three linear-cosine-decay schedules with 1, 2 and 4 periods, each decaying
# over 50 steps and plotted over 100 steps.

# Schedule 1: one period, initial rate 1e-2.
init_lr, decay_steps = 1e-2, 50
alpha, beta = 0, 1e-3
num_periods = 1
lin_cos_dec1 = tf.keras.experimental.LinearCosineDecay(
    initial_learning_rate=init_lr,
    decay_steps=decay_steps,
    num_periods=num_periods,
    alpha=alpha,
    beta=beta,
    name='LinCosDec 1',
)

# Schedule 2: two periods, initial rate 1e-3.
init_lr, decay_steps = 1e-3, 50
alpha, beta = 0, 1e-3
num_periods = 2
lin_cos_dec2 = tf.keras.experimental.LinearCosineDecay(
    initial_learning_rate=init_lr,
    decay_steps=decay_steps,
    num_periods=num_periods,
    alpha=alpha,
    beta=beta,
    name='LinCosDec 2',
)

# Schedule 3: four periods, initial rate 4e-3, much smaller alpha and beta.
init_lr, decay_steps = 4e-3, 50
alpha, beta = 1e-5, 1e-8
num_periods = 4
lin_cos_dec3 = tf.keras.experimental.LinearCosineDecay(
    initial_learning_rate=init_lr,
    decay_steps=decay_steps,
    num_periods=num_periods,
    alpha=alpha,
    beta=beta,
    name='LinCosDec 3',
)

plot_scheduler(step=100, schedulers=[lin_cos_dec1, lin_cos_dec2, lin_cos_dec3])

## Noisy Linear Cosine Decay

In [None]:
# Noisy linear cosine decay over 50 steps, plotted over 100 steps.
init_lr = 1e-2
decay_steps = 50

# The schedule is given an explicit name because `plot_scheduler` uses
# `scheduler.name` as the legend label; with name=None the legend would have
# no entry for this curve.
noisy_lin_cos_dec = tf.keras.experimental.NoisyLinearCosineDecay(
    initial_learning_rate=init_lr,
    decay_steps=decay_steps,
    initial_variance=0.002,
    variance_decay=0.001,
    num_periods=2,
    alpha=0.0,
    beta=0.01,
    name='Noisy LinCosDec',
)

plot_scheduler(100, noisy_lin_cos_dec)

## Cyclical Learning Rate

In [None]:
# Cyclical schedule between 1e-3 and 1e-2 with a step size of 25; the scale
# function is constant, so every cycle is scaled by the same factor of 1.
clr = tfa.optimizers.CyclicalLearningRate(
    initial_learning_rate=1e-3,
    maximal_learning_rate=1e-2,
    step_size=25,
    scale_fn=lambda x: 1,
    scale_mode='cycle',
    name='CyclicalLearningRate',
)

# Evaluate the schedule on the first 100 steps and plot it on a linear axis.
plt.plot(
    range(100),
    clr(range(100)).numpy(),
)
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.show()

## Exponential Cyclical Learning Rate

In [None]:
def plot_scheduler2(step, schedulers):
    """Plot schedules that are evaluated one step at a time.

    Unlike ``plot_scheduler``, each schedule is called once per step with a
    scalar and the result is converted with ``.numpy()``.

    Args:
        step: Number of steps to plot (``range(step)``).
        schedulers: A single schedule or a list of schedules. Each one must
            be callable on a scalar step, return an object with a
            ``.numpy()`` method, and expose a ``name`` attribute used as
            its legend label.
    """
    if not isinstance(schedulers, list):
        schedulers = [schedulers]

    x = range(step)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=[10, 3])
    for scheduler in schedulers:
        y = [scheduler(i).numpy() for i in x]
        # Fall back to the class name so unnamed schedules (name=None) still
        # get a legend entry.
        label = scheduler.name or type(scheduler).__name__
        ax1.plot(x, y, label=label)
        ax2.plot(x, y, label=label)

    # Axis decoration does not depend on the schedule, so do it once instead
    # of on every loop iteration.
    ax2.set_yscale('log')
    for ax in (ax1, ax2):
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Learning Rate')
        if schedulers:
            # Skip the legend when nothing was plotted to avoid an
            # empty-legend warning.
            ax.legend()
    plt.show()

In [None]:
# Exponential cyclical schedule between 1e-3 and 1e-2 with a step size of 5,
# scaled per iteration with gamma = 0.8.
exp_clr = tfa.optimizers.ExponentialCyclicalLearningRate(
    initial_learning_rate=1e-3, maximal_learning_rate=1e-2,
    step_size=5,
    gamma=0.8, scale_mode='iterations',
    name='ExponentialCyclicalLearningRate',
)

# Plot the first 100 steps on linear and logarithmic axes.
plot_scheduler2(step=100, schedulers=exp_clr)

## Custom Learning Scheduler

In [None]:
def plot_scheduler3(step, schedulers):
    """Plot plain Python schedule functions on linear and log y-axes.

    Args:
        step: Number of epochs to plot (``range(step)``).
        schedulers: A single callable or a list of callables. Each one is
            called with an integer epoch and its return value is plotted
            as the learning rate. No legend is drawn.
    """
    callables = schedulers if isinstance(schedulers, list) else [schedulers]

    _, (linear_ax, log_ax) = plt.subplots(1, 2, figsize=[10, 3])
    epochs = range(step)
    for schedule_fn in callables:
        rates = [schedule_fn(epoch) for epoch in epochs]
        # Draw the same curve on both panels; only the second one is
        # switched to a logarithmic scale.
        for axis in (linear_ax, log_ax):
            axis.plot(epochs, rates)
            axis.set_xlabel('Epoch')
            axis.set_ylabel('Learning Rate')
        log_ax.set_yscale('log')
    plt.show()

In [None]:
# This is directly copied from a notebook of Chris Deotte.

# Shape of the custom schedule: linear ramp-up, plateau, then step decay.
LR_START = 1e-5          # learning rate at epoch 0
LR_MAX = 1e-2            # learning rate reached after the ramp-up
LR_RAMPUP_EPOCHS = 2     # number of epochs spent ramping up
LR_SUSTAIN_EPOCHS = 1    # number of epochs held at LR_MAX
LR_STEP_DECAY = 0.75     # multiplicative factor applied at each decay step


def lrfn(epoch):
    """Return the learning rate for the given epoch.

    The rate rises linearly from ``LR_START`` to ``LR_MAX`` during the first
    ``LR_RAMPUP_EPOCHS`` epochs, stays at ``LR_MAX`` for
    ``LR_SUSTAIN_EPOCHS`` epochs, and is then multiplied by
    ``LR_STEP_DECAY`` once every two epochs.
    """
    if epoch < LR_RAMPUP_EPOCHS:
        slope = (LR_MAX - LR_START) / LR_RAMPUP_EPOCHS
        return slope * epoch + LR_START

    if epoch < LR_RAMPUP_EPOCHS + LR_SUSTAIN_EPOCHS:
        return LR_MAX

    # Epochs elapsed since the plateau ended; floor division by 2 makes the
    # rate drop only on every second epoch.
    epochs_decaying = epoch - LR_RAMPUP_EPOCHS - LR_SUSTAIN_EPOCHS
    return LR_MAX * LR_STEP_DECAY ** (epochs_decaying // 2)

# Plot the custom schedule for the first 100 epochs.
plot_scheduler3(step=100, schedulers=lrfn)