In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import numpy as np
import tensorflow as tf

from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model

import matplotlib.pyplot as plt

Load and prepare the [MNIST dataset](http://yann.lecun.com/exdb/mnist/).

In [None]:
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# For both splits: divide the images by 255.0, append a trailing channels
# axis, and store the result as float32.
x_train = (x_train / 255.0)[..., tf.newaxis].astype("float32")
x_test = (x_test / 255.0)[..., tf.newaxis].astype("float32")

Use `tf.data` to batch and shuffle the dataset:

In [None]:
# Training pipeline: one element per example, shuffled through a
# 10000-element buffer, then grouped into batches of 32.
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.shuffle(10000).batch(32)

# Evaluation pipeline: same batch size, no shuffling.
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = test_ds.batch(32)

In [None]:
# 32 is the batch size used when the training dataset is batched.
iters_per_epoch = np.ceil(x_train.shape[0] / 32)
print('Number of iters per epoch:', iters_per_epoch)

Build the `tf.keras` model using the Keras [model subclassing API](https://www.tensorflow.org/guide/keras#model_subclassing):

In [None]:
class MyModel(Model):
  """Small image classifier: Conv2D -> Flatten -> Dense(128) -> Dense(10)."""

  def __init__(self):
    super().__init__()
    # 32 filters, kernel size 3, ReLU activation.
    self.conv1 = Conv2D(32, 3, activation='relu')
    self.flatten = Flatten()
    self.d1 = Dense(128, activation='relu')
    # No activation on the last layer: it returns raw scores (logits).
    self.d2 = Dense(10)

  def call(self, x):
    """Run the forward pass and return the output of the final Dense layer."""
    feature_maps = self.conv1(x)
    flat = self.flatten(feature_maps)
    hidden = self.d1(flat)
    return self.d2(hidden)


# Create an instance of the model
model = MyModel()

Choose an optimizer and loss function for training: 

In [None]:
# `from_logits=True`: the loss receives raw, un-normalised scores rather than
# probabilities.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True,
)

In [None]:
initial_learning_rate = 0.001

# Staircase exponential decay: decay_rate=0.1 applied in discrete jumps every
# 2000 optimizer steps.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=initial_learning_rate,
    decay_steps=2000,
    decay_rate=0.1,
    staircase=True,
)

# Adam driven by the schedule above instead of a fixed learning rate.
optimizer = tf.keras.optimizers.Adam(
    learning_rate=lr_schedule,
    beta_1=0.9,
    beta_2=0.98,
    epsilon=1e-9,
)

Select metrics to measure the loss and the accuracy of the model. These metrics accumulate values over the batches of each epoch (they are reset at the start of every epoch), and the overall result is printed at the end of the epoch.

In [None]:
# Running means of the per-batch loss values.
train_loss = tf.keras.metrics.Mean(name='train_loss')
test_loss = tf.keras.metrics.Mean(name='test_loss')

# Running accuracy computed from integer labels and model predictions.
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
    name='train_accuracy')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
    name='test_accuracy')

Use `tf.GradientTape` to train the model:

In [None]:
@tf.function
def train_step(images, labels):
  """Run one optimisation step on a batch and update the training metrics.

  Args:
    images: batch of inputs passed to the global `model`.
    labels: matching targets passed to the global `loss_object`.
  """
  with tf.GradientTape() as tape:
    # training=True is only needed if there are layers with different
    # behavior during training versus inference (e.g. Dropout).
    logits = model(images, training=True)
    batch_loss = loss_object(labels, logits)

  # The variable list is read only after the forward pass has run, exactly as
  # in the original ordering.
  variables = model.trainable_variables
  grads = tape.gradient(batch_loss, variables)
  optimizer.apply_gradients(zip(grads, variables))

  # Fold this batch into the running training metrics.
  train_loss(batch_loss)
  train_accuracy(labels, logits)

Test the model:

In [None]:
@tf.function
def test_step(images, labels):
  """Evaluate one batch (no gradient update) and update the test metrics.

  Args:
    images: batch of inputs passed to the global `model`.
    labels: matching targets passed to the global `loss_object`.
  """
  # training=False is only needed if there are layers with different
  # behavior during training versus inference (e.g. Dropout).
  logits = model(images, training=False)
  batch_loss = loss_object(labels, logits)

  # Fold this batch into the running test metrics.
  test_loss(batch_loss)
  test_accuracy(labels, logits)

In [None]:
EPOCHS = 3

for epoch in range(EPOCHS):
  # Reset the metrics at the start of every epoch so that each printed line
  # reports that epoch only. `reset_state()` is the current Keras API; the
  # older `reset_states()` spelling is deprecated and absent from newer
  # Keras releases.
  for metric in (train_loss, train_accuracy, test_loss, test_accuracy):
    metric.reset_state()

  # One full pass over the training data, updating the weights batch by batch.
  for images, labels in train_ds:
    train_step(images, labels)

  # One full pass over the test data, metrics only.
  for test_images, test_labels in test_ds:
    test_step(test_images, test_labels)

  print(
    f'Epoch {epoch + 1}, '
    f'Loss: {train_loss.result()}, '
    f'Accuracy: {train_accuracy.result() * 100}, '
    f'Test Loss: {test_loss.result()}, '
    f'Test Accuracy: {test_accuracy.result() * 100}'
  )

The image classifier is now trained to ~98% accuracy on this dataset. To learn more, read the [TensorFlow tutorials](https://www.tensorflow.org/tutorials).

In [None]:
# Current learning rate: evaluate the schedule at the optimizer's step counter.
# This uses only public API; the private `optimizer._decayed_lr` helper is not
# available on every optimizer implementation.
lr_schedule(optimizer.iterations).numpy()

In [None]:
# Number of optimizer steps taken so far, read through the public
# `iterations` property rather than the private `_iterations` attribute.
optimizer.iterations.numpy()

### Custom schedules

In [None]:
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Warmup-then-inverse-square-root learning-rate schedule.

  lr(step) = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)

  Args:
    d_model: value whose inverse square root scales the whole schedule.
    warmup_steps: number of steps over which the rate grows linearly before
      the inverse-square-root decay takes over.
  """

  def __init__(self, d_model, warmup_steps=4000):
    super(CustomSchedule, self).__init__()

    # Keep the value as passed in so get_config() can return it unchanged;
    # `self.d_model` itself holds the float32 tensor used in the computation.
    self._d_model_config = d_model
    self.d_model = tf.cast(d_model, tf.float32)

    self.warmup_steps = warmup_steps

  def __call__(self, step):
    """Return the learning rate for `step` (scalar or tensor of steps)."""
    # An optimizer passes its integer iteration counter here, and
    # tf.math.rsqrt is not defined for integer dtypes, so cast first.
    step = tf.cast(step, tf.float32)
    arg1 = tf.math.rsqrt(step)
    arg2 = step * (self.warmup_steps ** -1.5)

    return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

  def get_config(self):
    """Return the constructor arguments needed to recreate this schedule."""
    return {
        "d_model": self._d_model_config,
        "warmup_steps": self.warmup_steps,
    }

In [None]:
# learning_rate = CustomSchedule(512)
# optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, 
#                                      epsilon=1e-9)

In [None]:
temp_learning_rate_schedule = CustomSchedule(512)

# Evaluate the schedule on steps 0..39999 in a single vectorised call.
schedule_steps = tf.range(40000, dtype=tf.float32)
plt.plot(temp_learning_rate_schedule(schedule_steps))
plt.ylabel("Learning Rate")
plt.xlabel("Train Step")

In [None]:
class CustomSchedule2(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Linear warmup, hold, then floored exponential decay.

  For step < decay_steps:
    lr = peak_lr * min(1, step / warmup_steps)
  otherwise:
    lr = peak_lr * max(decay_rate,
                       decay_rate ** ((step - decay_steps) / decay_period))

  Note that `decay_rate` is both the base of the exponential decay and the
  lower bound on the multiplier applied to `peak_lr`.

  Args:
    peak_lr: learning rate reached at the end of warmup.
    warmup_steps: number of steps of linear warmup.
    decay_rate: decay base and floor of the multiplier (see above).
    decay_period: number of steps over which the multiplier shrinks by one
      factor of `decay_rate`.
    decay_steps: step at which the decay phase begins.
  """

  def __init__(self, peak_lr=0.1, warmup_steps=100, decay_rate=0.1,
               decay_period=60, decay_steps=150):
    super(CustomSchedule2, self).__init__()

    self.peak_lr = peak_lr
    self.warmup_steps = warmup_steps
    self.decay_rate = decay_rate
    self.decay_period = decay_period
    self.decay_steps = decay_steps

  def __call__(self, step):
    """Return the learning rate for a single (scalar) `step`."""
    global_step = tf.cast(step, tf.float32)
    # tf.cond selects one branch from a scalar predicate, so `step` must be a
    # scalar here (this schedule is not vectorised over steps).
    return self.peak_lr * tf.cond(
        global_step < self.decay_steps,
        lambda: tf.math.minimum(1.0, global_step / self.warmup_steps),
        lambda: tf.math.maximum(
            self.decay_rate,
            self.decay_rate ** ((global_step - self.decay_steps) / self.decay_period)
        )
    )

  def get_config(self):
    """Return the constructor arguments needed to recreate this schedule."""
    return {
        "peak_lr": self.peak_lr,
        "warmup_steps": self.warmup_steps,
        "decay_rate": self.decay_rate,
        "decay_period": self.decay_period,
        "decay_steps": self.decay_steps,
    }

In [None]:
temp_learning_rate_schedule = CustomSchedule2()

# The schedule is evaluated one step at a time, for steps 0..299.
lr_values = [temp_learning_rate_schedule(step) for step in range(300)]
plt.plot(lr_values)
plt.ylabel("Learning Rate")
plt.xlabel("Train Step")

In [None]:
class TransformerSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Warmup / inverse-square-root schedule with an optional fixed peak.

    lr = rescale * d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)

    When `peak_lr` is None, `rescale` is 1. Otherwise `rescale` is
    peak_lr * sqrt(d_model) / warmup_steps^-0.5, which makes the value at
    step == warmup_steps equal to `peak_lr`.

    Args:
        d_model: value whose inverse square root scales the schedule.
        warmup_steps: number of linear-warmup steps.
        peak_lr: optional learning rate to reach at the end of warmup.
    """

    def __init__(self, d_model, warmup_steps=4000, peak_lr=None):
        super(TransformerSchedule, self).__init__()

        # Keep the value as passed in so get_config() returns a plain,
        # serialisable value; `self.d_model` is a float32 tensor for the maths.
        self._d_model_config = d_model
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps
        self.peak_lr = peak_lr
        self.rescale = 1. if peak_lr is None else peak_lr * tf.math.sqrt(self.d_model) / (self.warmup_steps ** -0.5)

    def __call__(self, step):
        """Return the learning rate for `step` (scalar or tensor of steps)."""
        # lr = rescale * (d_model^-0.5) * min(step^-0.5, step*(warm_up^-1.5))
        step = tf.cast(step, dtype=tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)
        lr = self.rescale * tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)
        return lr

    def get_config(self):
        """Return the constructor arguments needed to recreate this schedule."""
        return {
            # Not `self.d_model`: that attribute is a tf.Tensor, which cannot
            # be serialised as part of a config dict.
            "d_model": self._d_model_config,
            "warmup_steps": self.warmup_steps,
            "peak_lr": self.peak_lr
        }

In [None]:
# Schedule that peaks at 0.00125 after 4000 warmup steps.
temp_learning_rate_schedule = TransformerSchedule(
    d_model=144,
    warmup_steps=4000,
    peak_lr=0.00125,
)

In [None]:
# Evaluate the schedule once over a vector of steps instead of making 50000
# separate eager calls; the schedule casts `step` to float32 and uses only
# element-wise ops, so the values are the same and the cell runs much faster.
plot_steps = tf.range(50000, dtype=tf.float32)
plt.plot(temp_learning_rate_schedule(plot_steps))
plt.ylabel("Learning Rate")
plt.xlabel("Train Step")

In [None]:
initial_learning_rate = 0.1

# Staircase exponential decay: decay_rate=0.96 applied in discrete jumps every
# 100 steps.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=initial_learning_rate,
    decay_steps=100,
    decay_rate=0.96,
    staircase=True,
)

In [None]:
# Fresh Adam optimizer driven by the current `lr_schedule`.
optimizer = tf.keras.optimizers.Adam(
    learning_rate=lr_schedule,
    beta_1=0.9,
    beta_2=0.98,
    epsilon=1e-9,
)

In [None]:
# Current learning rate: evaluate the schedule at the optimizer's step counter.
# This uses only public API; the private `optimizer._decayed_lr` helper is not
# available on every optimizer implementation.
lr_schedule(optimizer.iterations).numpy()

In [None]:
# Sample the schedule at steps 0..999 and plot the resulting staircase.
lr_values = [lr_schedule(step) for step in range(1000)]
plt.plot(lr_values)
plt.ylabel("Learning Rate")
plt.xlabel("Train Step")