### Implementing Batch Normalization

References:
1. [Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf) (the original paper)
2. [How Does Batch Normalization Help Optimization?](https://arxiv.org/pdf/1805.11604.pdf) (useful background on why batch norm works)

During training, the distribution of the inputs to each layer changes as the parameters of the preceding layers are updated, which makes training harder. The paper [1] calls this internal covariate shift and addresses it by normalizing the input to each layer.
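
To make the idea concrete, here is a minimal NumPy sketch (not part of the original experiment) of how a change in one layer's weights shifts the statistics of the next layer's input, and how normalizing each feature over the mini-batch removes that shift. The random inputs, the layer shape and the size of the weight update are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical mini-batch: 60 flattened images with pixel values in [0, 1).
x = rng.uniform(size=(60, 784))

# Weights of a 784 -> 100 layer before and after a made-up parameter update.
w_before = rng.normal(scale=0.05, size=(784, 100))
w_after = w_before + 0.05


def standardize(z, eps=1e-5):
    """Normalize each feature (column) to zero mean, unit variance over the batch."""
    return (z - z.mean(axis=0)) / np.sqrt(z.var(axis=0) + eps)


z_before = x @ w_before  # input to the next layer before the update
z_after = x @ w_after    # input to the next layer after the update

# The raw inputs to the next layer move noticeably after the update ...
print("raw mean before / after:", z_before.mean(), z_after.mean())

# ... while the normalized inputs have the same per-feature mean and variance.
n_before, n_after = standardize(z_before), standardize(z_after)
print("normalized mean before / after:", n_before.mean(), n_after.mean())
```

The next layer therefore sees inputs on a stable scale even though the weights feeding it keep changing.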

In this notebook, we implement a batch normalization layer and compare two networks (one with batch norm and one without) trained on the MNIST dataset.

In [25]:
import datetime
import matplotlib.pyplot as plt
import numpy as np
import os

import sys
import tensorflow as tf

from tensorflow.python.ops import nn_ops

sys.path.append("..")

from datasets.mnist import MNIST_DATASET

In [26]:
%load_ext tensorboard

The batch size is set to 60, following Section 4.1 of [1].

In [27]:
BATCH_SIZE = 60
SHUFFLE_BUFFER_SIZE = 100

The dataset is read from local disk as NumPy arrays and converted to tf.data.Dataset objects.

In [28]:
dataset_path = '/Users/rohit/Desktop/datasets/mnist'
x_train, y_train, x_test, y_test = MNIST_DATASET.load_dataset(dataset_path=dataset_path, reshape=False)

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

We use a simple fully connected network with sigmoid activations. Each 28x28 input image is passed in as a (784, ) vector. We train for 50 epochs, i.e. 50k iterations, and plot accuracy and loss in TensorBoard. RMSprop is an arbitrary choice of optimizer; one can also try Adam, SGD, etc.
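
The iteration count quoted above can be checked with a little arithmetic. The sketch below assumes the standard MNIST training split of 60,000 images; the constant `NUM_TRAIN_EXAMPLES` is introduced here for illustration only.

```python
import math

NUM_TRAIN_EXAMPLES = 60_000  # assumption: standard MNIST training split
BATCH_SIZE = 60              # same value as in the notebook
EPOCHS = 50

# One iteration processes one mini-batch, so an epoch takes this many iterations.
steps_per_epoch = math.ceil(NUM_TRAIN_EXAMPLES / BATCH_SIZE)
total_steps = steps_per_epoch * EPOCHS

print(steps_per_epoch, total_steps)
```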

In [30]:
model_one = tf.keras.Sequential([
    tf.keras.layers.Dense(100, activation='sigmoid', input_shape=(784, )),
    tf.keras.layers.Dense(100, activation='sigmoid'),
    tf.keras.layers.Dense(100, activation='sigmoid'),
    tf.keras.layers.Dense(10, )
])

model_one.build()

model_one.summary()

Model: "sequential_4"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_16 (Dense)            (None, 100)               78500     
                                                                 
 dense_17 (Dense)            (None, 100)               10100     
                                                                 
 dense_18 (Dense)            (None, 100)               10100     
                                                                 
 dense_19 (Dense)            (None, 10)                1010      
                                                                 
Total params: 99,710
Trainable params: 99,710
Non-trainable params: 0
_________________________________________________________________
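
The parameter counts in the summary can be reproduced by hand: a Dense layer has one weight per input-output pair plus one bias per output unit. The helper `dense_params` below is a hypothetical name used only for this check.

```python
def dense_params(n_in, n_out):
    """Number of trainable parameters in a Dense layer: weights plus biases."""
    return n_in * n_out + n_out


# Layer widths of the network above: 784 inputs, three hidden layers, 10 outputs.
layer_sizes = [784, 100, 100, 100, 10]

counts = [dense_params(n_in, n_out)
          for n_in, n_out in zip(layer_sizes, layer_sizes[1:])]
total = sum(counts)

print(counts, total)
```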


In [31]:
model_one.compile(optimizer=tf.keras.optimizers.RMSprop(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['sparse_categorical_accuracy'])

In [None]:
logdir = os.path.join("logs", datetime.datetime.now().strftime("batch_norm-%Y%m%d-%H%M%S"))
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)

model_one.fit(train_dataset, epochs=50, validation_data=test_dataset, callbacks=[tensorboard_callback])

We implement the batch norm layer as per Algorithm 1 in [1], shown below.

<p align="center">
    <img src="images/batchnorm.1.png" width="400" />
</p>
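
For reference, the transform in Algorithm 1 of [1] can be written out as follows for a mini-batch $\mathcal{B} = \{x_1, \dots, x_m\}$ of values of a single activation, with learned parameters $\gamma$ and $\beta$ and a small constant $\epsilon$ for numerical stability:

$$
\begin{aligned}
\mu_{\mathcal{B}} &= \frac{1}{m} \sum_{i=1}^{m} x_i \\
\sigma_{\mathcal{B}}^2 &= \frac{1}{m} \sum_{i=1}^{m} \left(x_i - \mu_{\mathcal{B}}\right)^2 \\
\hat{x}_i &= \frac{x_i - \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}} \\
y_i &= \gamma \hat{x}_i + \beta
\end{aligned}
$$

In a fully connected layer, these four steps are applied independently to each of the output features.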

In [33]:
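class BatchNorm1D(tf.keras.layers.Layer):
    """Batch normalization for (batch, features) inputs, as in Algorithm 1 of [1]."""

    def __init__(self, **kwargs):
        # Forward the standard Layer arguments (name, dtype, trainable, ...) to the base class.
        super().__init__(**kwargs)
        # Small constant added to the variance for numerical stability.
        self.epsilon = 1e-5

    def build(self, input_shape):
        # One learned scale (gamma) and one learned shift (beta) per feature.
        self.gamma = self.add_weight(
            name='gamma',
            shape=(input_shape[-1], ),
            initializer='ones',
            trainable=True)

        self.beta = self.add_weight(
            name='beta',
            shape=(input_shape[-1], ),
            initializer='zeros',
            trainable=True)

    def call(self, inputs):
        # Mini-batch mean and (biased) variance of each feature.
        # Note: batch statistics are used at inference time too; no moving averages are tracked.
        mean = tf.math.reduce_mean(inputs, axis=0)
        variance = tf.math.reduce_mean(tf.math.square(inputs - mean), axis=0)

        # Normalize, then scale and shift with the learned parameters.
        normalized_input = (inputs - mean) / tf.math.sqrt(variance + self.epsilon)

        return self.gamma * normalized_input + self.beta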
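
As a sanity check on the forward pass, the same computation can be sketched in plain NumPy. This is an independent reference under stated assumptions, not the layer itself: the function name `batch_norm_forward` and the random activations are made up for illustration, and `eps` is set to 1e-5, the same value the layer uses.

```python
import numpy as np


def batch_norm_forward(x, gamma, beta, eps=1e-5):
    """Training-time batch norm for a (batch, features) array."""
    mean = x.mean(axis=0)                    # mini-batch mean per feature
    var = ((x - mean) ** 2).mean(axis=0)     # biased mini-batch variance per feature
    x_hat = (x - mean) / np.sqrt(var + eps)  # normalize
    return gamma * x_hat + beta              # scale and shift


rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=(60, 100))  # made-up activations

# With gamma = 1 and beta = 0 (the layer's initial values), every feature
# comes out with approximately zero mean and unit standard deviation.
y = batch_norm_forward(x, gamma=np.ones(100), beta=np.zeros(100))

# Other values of gamma and beta set the output standard deviation and mean.
y_scaled = batch_norm_forward(x, gamma=np.full(100, 2.0), beta=np.full(100, 0.5))

print(np.abs(y.mean(axis=0)).max(), y.std(axis=0).min(), y.std(axis=0).max())
```

Because gamma and beta are learned, the network can undo the normalization where that helps, which is the reason [1] gives for adding them.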

This model is the same as the first one, except that a batch norm layer is inserted after each hidden Dense layer, with the sigmoid activation applied after it. The rest of the hyper-parameters are unchanged.

In [34]:
# tf.keras.backend.clear_session()

model_two = tf.keras.Sequential([
    tf.keras.layers.Dense(100, input_shape=(784, )),
    BatchNorm1D(),
    tf.keras.layers.Activation(activation='sigmoid'),
    tf.keras.layers.Dense(100),
    BatchNorm1D(),
    tf.keras.layers.Activation(activation='sigmoid'),
    tf.keras.layers.Dense(100),
    BatchNorm1D(),
    tf.keras.layers.Activation(activation='sigmoid'),
    tf.keras.layers.Dense(10, )
])

model_two.build()

model_two.summary()

Model: "sequential_5"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_20 (Dense)            (None, 100)               78500     
                                                                 
 batch_norm1d_6 (BatchNorm1D  (None, 100)              200       
 )                                                               
                                                                 
 activation_6 (Activation)   (None, 100)               0         
                                                                 
 dense_21 (Dense)            (None, 100)               10100     
                                                                 
 batch_norm1d_7 (BatchNorm1D  (None, 100)              200       
 )                                                               
                                                                 
 activation_7 (Activation)   (None, 100)              

In [None]:
model_two.compile(optimizer=tf.keras.optimizers.RMSprop(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['sparse_categorical_accuracy'])

logdir = os.path.join("logs", datetime.datetime.now().strftime("batch_norm-%Y%m%d-%H%M%S"))
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)

model_two.fit(train_dataset, epochs=50, validation_data=test_dataset, callbacks=[tensorboard_callback])

To understand the difference between the two trained models, we look at the TensorBoard graphs.

First, we compare the test accuracy. Red is the network with batch norm and green is the one without. The network with batch norm trains faster and reaches a higher accuracy overall. This is consistent with the observations in Figure 1(a) of [1].

<p align="center">
    <img src="images/batchnorm.2.png" width="800" />
</p>