##### Copyright 2020 The TensorFlow Authors.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Model Averaging

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/addons/tutorials/average_optimizers_callback"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/addons/blob/master/docs/tutorials/average_optimizers_callback.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/addons/blob/master/docs/tutorials/average_optimizers_callback.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
      <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/addons/docs/tutorials/average_optimizers_callback.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>


## Overview

This notebook demonstrates how to use the Moving Average and Stochastic Weight Averaging optimizers together with the `AverageModelCheckpoint` callback from the TensorFlow Addons package.


## Moving Averaging 

> The advantage of moving averaging is that the averaged weights are less prone to rampant loss shifts or to irregular data representation in the latest batch. It gives a smoothed and more general picture of the model's training up to a given point.

## Stochastic Averaging

> Stochastic Weight Averaging (SWA) converges to wider optima. By doing so, it resembles geometric ensembling. SWA is a simple method to improve model performance: it is used as a wrapper around another optimizer and averages the weights from different points of the inner optimizer's trajectory.

## Model Average Checkpoint 

> `callbacks.ModelCheckpoint` doesn't give you the option to save moving average weights in the middle of training, which is why the model-averaging optimizers require a custom callback. Using the ```update_weights``` parameter, ```AverageModelCheckpoint``` allows you to:
1.   Assign the moving average weights to the model, and save them.
2.   Keep the old non-averaged weights, but the saved model uses the average weights.

## Setup

In [None]:
!pip install -U tensorflow-addons

In [None]:
import tensorflow as tf
import tensorflow_addons as tfa

In [None]:
import numpy as np
import os

## Build Model 

In [None]:
def create_model(opt):
    """Build and compile the small fully-connected classifier used for every optimizer.

    Architecture: Flatten -> Dense(64, relu) -> Dense(64, relu) -> Dense(10, softmax).

    Args:
        opt: The optimizer handed to ``model.compile`` (plain SGD or an averaging wrapper).

    Returns:
        A compiled ``tf.keras`` Sequential model that uses sparse categorical
        cross-entropy loss and reports accuracy.
    """
    hidden_units = (64, 64)

    layer_stack = [tf.keras.layers.Flatten()]
    layer_stack += [tf.keras.layers.Dense(units, activation='relu') for units in hidden_units]
    layer_stack.append(tf.keras.layers.Dense(10, activation='softmax'))

    classifier = tf.keras.models.Sequential(layer_stack)
    classifier.compile(
        optimizer=opt,
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return classifier

## Prepare Dataset

In [None]:
# Load the Fashion MNIST dataset.
train, test = tf.keras.datasets.fashion_mnist.load_data()

# Training split: scale pixel values to [0, 1] and use int32 class labels.
images, labels = train
images = images / 255.0
labels = labels.astype(np.int32)

fmnist_train_ds = tf.data.Dataset.from_tensor_slices((images, labels))
fmnist_train_ds = fmnist_train_ds.shuffle(5000).batch(32)

# Test split: apply the SAME preprocessing as the training split. Without the
# scaling, the models (trained on [0, 1] inputs) would be evaluated on raw
# 0-255 pixel values, which makes the reported test loss/accuracy meaningless.
test_images, test_labels = test
test_images = test_images / 255.0
test_labels = test_labels.astype(np.int32)

In [None]:
np.shape(images)  # dimensions of the (scaled) training-image array

The dataset provides 60k training examples. Since the batch size is set to 32, the optimizer performs 1875 steps per epoch (9375 steps over the 5 epochs used below).

The goal is to compare three optimizers:

*   Unwrapped SGD

This is the traditional SGD optimizer. The learning rate is set to 0.01.

*   SGD with Moving Average

This wrapper computes the moving averages of the weights over the steps. We will use default params here.

*   SGD with Stochastic Weight Averaging

We will begin cyclic averaging at step 1000, averaging the weights every 50 steps.

And see how they perform with the same model.

In [None]:
# Optimizers.
# Each averaging wrapper gets its OWN freshly constructed SGD instance. The
# wrappers step their inner optimizer, so sharing a single SGD object between
# the three models would also share its state, including its `iterations`
# counter. SWA's `start_averaging=1000` would then be measured from the end of
# the earlier runs rather than from the start of its own training.
LEARNING_RATE = 0.01

sgd = tf.keras.optimizers.SGD(LEARNING_RATE)
moving_avg_sgd = tfa.optimizers.MovingAverage(tf.keras.optimizers.SGD(LEARNING_RATE))
stochastic_avg_sgd = tfa.optimizers.SWA(tf.keras.optimizers.SGD(LEARNING_RATE),
                                        start_averaging=1000,
                                        average_period=50)

# Backward-compatible alias for the original (misspelled) variable name.
stocastic_avg_sgd = stochastic_avg_sgd

Both the ```MovingAverage``` and ```SWA``` optimizers are used together with the ```AverageModelCheckpoint``` callback.

In [None]:
# Set up a separate checkpoint directory for each type of optimizer.
# The `{epoch:04d}` field is filled in with the epoch number by the checkpoint callbacks.
vanilla_checkpoint_path = "training/vanilla/cp-{epoch:04d}.ckpt"
ma_checkpoint_path = "training/ma/cp-{epoch:04d}.ckpt"
swa_checkpoint_path = "training/swa/cp-{epoch:04d}.ckpt"

vanilla_checkpoint_dir = os.path.dirname(vanilla_checkpoint_path)
ma_checkpoint_dir = os.path.dirname(ma_checkpoint_path)
swa_checkpoint_dir = os.path.dirname(swa_checkpoint_path)

# Set up the callbacks for each type of optimizer.
# All three save weights only, in the same format, so every run can be restored
# the same way with `tf.train.latest_checkpoint` + `model.load_weights`.
vanilla_callback = tf.keras.callbacks.ModelCheckpoint(filepath=vanilla_checkpoint_path,
                                                      save_weights_only=True,
                                                      verbose=1)

# `update_weights=True`: assign the averaged weights to the model and save them
# (rather than restoring the non-averaged weights after saving).
ma_callback = tfa.callbacks.AverageModelCheckpoint(filepath=ma_checkpoint_path,
                                                   update_weights=True,
                                                   save_weights_only=True,
                                                   verbose=1)

swa_callback = tfa.callbacks.AverageModelCheckpoint(filepath=swa_checkpoint_path,
                                                    update_weights=True,
                                                    save_weights_only=True,
                                                    verbose=1)

## Train Model


### Vanilla SGD Optimizer 

In [None]:
# Build the classifier around the plain (unwrapped) SGD optimizer.
model = create_model(sgd)

# Train for 5 epochs; `vanilla_callback` writes weight checkpoints using the
# `{epoch:04d}` path pattern as training progresses.
model.fit(x=fmnist_train_ds, epochs=5, callbacks=[vanilla_callback])

# Save the final weights once more, under the `epoch=0` slot of the same path pattern.
# Being the most recent save, this checkpoint is the one the lookup cell below expects to find.
model.save_weights(vanilla_checkpoint_path.format(epoch=0))

In [None]:
# Look up the path of the most recent checkpoint in the vanilla SGD directory (None if there is none).
latest = tf.train.latest_checkpoint(checkpoint_dir=vanilla_checkpoint_dir)

In [None]:
# Evaluate results: restore the latest vanilla SGD checkpoint and score it on the test split.
# `latest` is None when no checkpoint was found; fail with a clear message instead of
# letting `load_weights(None)` raise an obscure error.
if latest is None:
    raise FileNotFoundError(
        f"No checkpoint found in {vanilla_checkpoint_dir!r}; run the training cell first.")
model.load_weights(latest)
loss, accuracy = model.evaluate(test_images, test_labels, batch_size=32, verbose=2)
print("Loss :", loss)
print("Accuracy :", accuracy)

### Moving Average SGD

In [None]:
# Build the classifier around the Moving Average optimizer wrapper.
model = create_model(moving_avg_sgd)

# Train for 5 epochs. `ma_callback` was created with `update_weights=True`, so the
# moving-average weights are assigned to the model whenever it saves a checkpoint.
model.fit(x=fmnist_train_ds, epochs=5, callbacks=[ma_callback])

# Save the model's current weights (presumably the averaged ones assigned at the last
# checkpoint save) under the `epoch=0` slot of the path pattern. As the most recent save,
# this checkpoint is the one the lookup cell below expects to find.
model.save_weights(ma_checkpoint_path.format(epoch=0))

In [None]:
# Look up the path of the most recent checkpoint in the Moving Average directory (None if there is none).
latest = tf.train.latest_checkpoint(checkpoint_dir=ma_checkpoint_dir)

In [None]:
# Evaluate results: restore the latest Moving Average checkpoint and score it on the test split.
# `latest` is None when no checkpoint was found; fail with a clear message instead of
# letting `load_weights(None)` raise an obscure error.
if latest is None:
    raise FileNotFoundError(
        f"No checkpoint found in {ma_checkpoint_dir!r}; run the training cell first.")
model.load_weights(latest)
loss, accuracy = model.evaluate(test_images, test_labels, batch_size=32, verbose=2)
print("Loss :", loss)
print("Accuracy :", accuracy)

### Stochastic Weight Average SGD

In [None]:
# Build the classifier around the Stochastic Weight Averaging optimizer wrapper.
model = create_model(stochastic_avg_sgd)

# Train for 5 epochs. `swa_callback` was created with `update_weights=True`, so the
# SWA-averaged weights are assigned to the model whenever it saves a checkpoint.
model.fit(x=fmnist_train_ds, epochs=5, callbacks=[swa_callback])

# Save the model's current weights under the `epoch=0` slot of the path pattern.
# As the most recent save, this checkpoint is the one the lookup cell below expects to find.
model.save_weights(swa_checkpoint_path.format(epoch=0))

In [None]:
# Look up the path of the most recent checkpoint in the SWA directory (None if there is none).
latest = tf.train.latest_checkpoint(checkpoint_dir=swa_checkpoint_dir)

In [None]:
# Evaluate results: restore the latest SWA checkpoint and score it on the test split.
# `latest` is None when no checkpoint was found; fail with a clear message instead of
# letting `load_weights(None)` raise an obscure error.
if latest is None:
    raise FileNotFoundError(
        f"No checkpoint found in {swa_checkpoint_dir!r}; run the training cell first.")
model.load_weights(latest)
loss, accuracy = model.evaluate(test_images, test_labels, batch_size=32, verbose=2)
print("Loss :", loss)
print("Accuracy :", accuracy)

These observations suggest that Stochastic Weight Averaging tends to perform worse than vanilla SGD on training loss but better on test loss.