## Batch Normalization

Batch normalization is a technique for improving the performance and stability of [artificial neural networks](https://en.wikipedia.org/wiki/Artificial_neural_network). It normalizes the inputs to any layer in a neural network so that they have zero mean and unit variance.[1]
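
As a minimal, dependency-free sketch of what "zero mean, unit variance" means in practice, the snippet below normalizes each feature column of a small batch and then applies a scale and shift. The function name `batch_norm_forward`, the parameters `gamma`, `beta`, and `eps`, and the sample values are assumptions made for illustration; they are not part of TensorFlow.

```python
import math

def batch_norm_forward(batch, gamma, beta, eps=1e-5):
    """Normalize each feature (column) of `batch` over the batch dimension."""
    n = len(batch)
    n_features = len(batch[0])
    # Per-feature mean over the batch.
    means = [sum(row[j] for row in batch) / n for j in range(n_features)]
    # Per-feature biased (population) variance over the batch.
    variances = [sum((row[j] - means[j]) ** 2 for row in batch) / n
                 for j in range(n_features)]
    # Standardize, then scale by gamma and shift by beta.
    return [
        [gamma[j] * (row[j] - means[j]) / math.sqrt(variances[j] + eps) + beta[j]
         for j in range(n_features)]
        for row in batch
    ]

# Hypothetical batch: 3 examples, 2 features on very different scales.
batch = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]
normed = batch_norm_forward(batch, gamma=[1.0, 1.0], beta=[0.0, 0.0])
```

With `gamma` set to ones and `beta` to zeros, both columns of `normed` end up on the same scale even though the raw features differ by a factor of ten.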

TensorFlow is a Python library for fast numerical computing created and released by Google.

It is a foundation library that can be used to create deep learning models directly, or through wrapper libraries built on top of TensorFlow that simplify the process.

In [1]:
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

In [2]:
def batch_norm(x, phase_train, name='bn', decay=0.99, reuse=None, affine=True):
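    """
    Batch-normalize the input, using batch statistics during training
    and exponential moving averages of them otherwise.

    Parameters
    ----------
    x
        Tensor, 4D BHWD input maps (2D inputs are reshaped to 4D)
    phase_train
        boolean tf.Variable, true indicates training phase
    name
        string, variable scope name
    decay
        float, decay rate of the exponential moving averages
    reuse
        whether to reuse the variables in the scope
    affine
        whether to affine-transform outputs

    Returns
    -------
    normed
        batch-normalized maps
    """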
    with tf.variable_scope(name, reuse=reuse):
        og_shape = x.get_shape().as_list()
        if len(og_shape) == 2:
            x = tf.reshape(x, [-1, 1, 1, og_shape[1]])
        shape = x.get_shape().as_list()
        beta = tf.get_variable(name='beta', shape=[shape[-1]],
                               initializer=tf.constant_initializer(0.0),
                               trainable=True)
        gamma = tf.get_variable(name='gamma', shape=[shape[-1]],
                                initializer=tf.constant_initializer(1.0),
                                trainable=affine)

        batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')
        ema = tf.train.ExponentialMovingAverage(decay=decay)
        ema_apply_op = ema.apply([batch_mean, batch_var])
        ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)

        def mean_var_with_update():
            """Update the moving averages, then return the batch mean and variance."""
            with tf.control_dependencies([ema_apply_op]):
                return tf.identity(batch_mean), tf.identity(batch_var)
        mean, var = control_flow_ops.cond(phase_train,
                                          mean_var_with_update,
                                          lambda: (ema_mean, ema_var))

        # Note: tf.nn.batch_normalization is the non-deprecated
        # alternative to the call below.
        normed = tf.nn.batch_norm_with_global_normalization(
            x, mean, var, beta, gamma, 1e-5, affine)
        if len(og_shape) == 2:
            normed = tf.reshape(normed, [-1, og_shape[-1]])
    return normed
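
The train/inference switch in `batch_norm` can be sketched without TensorFlow. The class `RunningMoments`, the function `select_moments`, and the starting values of 0 and 1 are hypothetical names and assumptions for illustration; `tf.train.ExponentialMovingAverage` may initialize its averages differently.

```python
class RunningMoments:
    """Tracks exponential moving averages of a batch mean and variance."""

    def __init__(self, decay=0.99, mean=0.0, var=1.0):
        self.decay = decay
        self.mean = mean  # assumed starting value
        self.var = var    # assumed starting value

    def update(self, batch_mean, batch_var):
        # average = decay * average + (1 - decay) * new_value
        self.mean = self.decay * self.mean + (1.0 - self.decay) * batch_mean
        self.var = self.decay * self.var + (1.0 - self.decay) * batch_var


def select_moments(is_training, batch_mean, batch_var, running):
    """Return batch statistics (after updating the averages) when training,
    and the stored moving averages otherwise."""
    if is_training:
        running.update(batch_mean, batch_var)
        return batch_mean, batch_var
    return running.mean, running.var


running = RunningMoments(decay=0.9)
# Training step: the batch statistics are used and the averages are updated.
train_stats = select_moments(True, 2.0, 4.0, running)
# Inference step: the current batch statistics are ignored.
eval_stats = select_moments(False, 5.0, 9.0, running)
```

A `decay` close to 1 makes the averages change slowly, so they reflect many training batches rather than only the most recent one.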

### References
[1] Batch normalization https://en.wikipedia.org/wiki/Batch_normalization