In [1]:
import shutil
import sys

import tensorflow as tf

import cifar
# Only show TensorFlow log messages at WARN level or above; INFO-level output
# (such as the Estimator's periodic training logs) is suppressed.
tf.logging.set_verbosity(tf.logging.WARN)

## Download and extract the dataset

In [2]:
# Download and extract CIFAR-10 through the local `cifar` helper module.
# NOTE(review): the helper's source is not shown here — presumably it skips the
# download when the data is already present; TODO confirm.
cifar.prepare_cifar_10()
# Class names from the helper, presumably indexed by integer class id (TODO confirm).
# NOTE(review): not referenced by any later cell of this notebook.
cifar10_labels = cifar.cifar10_labels()

## Define the model function

In [3]:
def model(features, labels, mode, params):
    """Estimator ``model_fn``: BN + conv stack followed by dense/dropout layers.

    Architecture (all sizes come from ``params``):
      * for each (filters, kernel, stride) triple: batch norm -> conv2d + ReLU
      * flatten -> batch norm
      * for each (units, dropout rate) pair: dense + ReLU
        -> batch norm (only if ``params['with_bn']``) -> dropout
      * dense logits layer with ``params['n_classes']`` outputs

    Args:
        features: dict with key ``'images'``; the input functions declare it as
            float32 with shape (None, 32, 32, 3).
        labels: integer (sparse) class ids; unused in PREDICT mode.
        mode: a ``tf.estimator.ModeKeys`` value.
        params: dict with keys
            ``filters``, ``kern``, ``strides`` - equal-length lists, one entry per conv layer;
            ``dense``, ``drop`` - equal-length lists (units, dropout rate) per hidden dense layer;
            ``n_classes`` - number of output classes;
            ``with_bn`` - whether to batch-normalize after each hidden dense layer;
            ``momentum`` - batch-norm moving-average momentum.

    Returns:
        ``tf.estimator.EstimatorSpec``: predictions (``class``, ``score``) in
        PREDICT mode, loss + accuracy metric in EVAL mode, loss + Adam train op
        in TRAIN mode.

    Raises:
        ValueError: if the per-layer parameter lists have mismatched lengths
            (``zip`` would otherwise silently drop the extra layers).
    """
    if not len(params['filters']) == len(params['kern']) == len(params['strides']):
        raise ValueError("params 'filters', 'kern' and 'strides' must have the same length")
    if len(params['dense']) != len(params['drop']):
        raise ValueError("params 'dense' and 'drop' must have the same length")

    # Batch norm (batch vs. moving statistics) and dropout (on/off) both depend on this.
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    net = features['images']

    for filt, kern, stride in zip(params['filters'], params['kern'], params['strides']):
        net = tf.layers.batch_normalization(net,
                                            momentum=params['momentum'],
                                            training=is_training)
        net = tf.layers.conv2d(net, filt,
                               kern, stride, activation=tf.nn.relu)

    net = tf.layers.flatten(net)
    net = tf.layers.batch_normalization(net,
                                        momentum=params['momentum'],
                                        training=is_training)

    for units, drop in zip(params['dense'], params['drop']):
        net = tf.layers.dense(net, units, activation=tf.nn.relu)

        if params['with_bn']:
            # Experiment toggle. For every dense layer after the first, its input
            # came out of the previous layer's dropout, so this BN layer normalizes
            # dropout-affected activations ("BN after dropout").
            net = tf.layers.batch_normalization(net,
                                                momentum=params['momentum'],
                                                training=is_training)
        net = tf.layers.dropout(net, drop, training=is_training)

    logits = tf.layers.dense(net, params['n_classes'])
    cls = tf.argmax(logits, -1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions={
            "class": cls,
            "score": tf.nn.softmax(logits)
        })

    loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops={
            "accuracy": tf.metrics.accuracy(labels, cls)
        })

    adam = tf.train.AdamOptimizer()

    # tf.layers.batch_normalization registers its moving mean/variance updates in
    # UPDATE_OPS; they must run together with the train step or inference would
    # use stale statistics.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        opt = adam.minimize(loss, global_step=tf.train.get_global_step())

    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=opt)

In [4]:
def _cifar_dataset(generator):
    """Wrap a ``cifar`` batch generator in a ``tf.data.Dataset``.

    The generator must yield ``({"images": images}, labels)`` pairs. Images are
    declared float32 with shape (None, 32, 32, 3); labels are int64 with an
    unknown shape. No ``repeat()`` is applied, so each Estimator ``train`` /
    ``evaluate`` call consumes the generator until it is exhausted.
    """
    output_types = ({"images": tf.float32}, tf.int64)
    output_shapes = ({"images": tf.TensorShape([None, 32, 32, 3])}, tf.TensorShape(None))
    return tf.data.Dataset.from_generator(generator, output_types, output_shapes)


def inp_fn():
    """Training ``input_fn`` backed by ``cifar.cifar10_train``."""
    return _cifar_dataset(cifar.cifar10_train)


def test_inp_fn():
    """Evaluation ``input_fn`` backed by ``cifar.cifar10_test``."""
    return _cifar_dataset(cifar.cifar10_test)

In [5]:
# Hyper-parameters handed to the Estimator as `params` (read by `model`).
hparams = dict(
    # Convolutional stack - one entry per conv layer.
    filters=[30, 50, 60],               # output channels
    kern=[[3, 3]] * 3,                  # kernel sizes
    strides=[[2, 2], [1, 1], [1, 1]],   # strides
    # Fully connected stack - one entry per hidden dense layer.
    dense=[3500, 700],                  # units
    drop=[0.4, 0.5],                    # dropout rates (fraction of units dropped)
    n_classes=10,                       # CIFAR-10 classes
    with_bn=True,                       # batch norm after each hidden dense layer
    momentum=0.75,                      # batch-norm moving-average momentum
)

## Dropout with BN after it (`with_bn=True`)

In [6]:
# Start from scratch on every run: if 'drwbn-ckpts' already holds a checkpoint, the
# Estimator restores it and training silently resumes from the old global_step.
shutil.rmtree('drwbn-ckpts', ignore_errors=True)
drwbn = tf.estimator.Estimator(
    model, 'drwbn-ckpts',
    # save_summary_steps=2 gives fine-grained TensorBoard curves; the fixed graph-level
    # seed makes this run comparable with the "without BN after dropout" run below.
    config=tf.estimator.RunConfig(save_summary_steps=2, tf_random_seed=42),
    params=hparams)

#### Start TensorBoard

In [7]:
# Launch TensorBoard (default port 6006) for the "BN after dropout" run. `start` on
# Windows and the trailing `&` elsewhere detach it so the cell returns immediately.
if sys.platform.startswith('win'):
    get_ipython().system_raw("start tensorboard --logdir drwbn-ckpts")
else:
    get_ipython().system_raw("tensorboard --logdir drwbn-ckpts &")

In [8]:
# 30 rounds of: train on the full training generator, then report test-set metrics.
for _ in range(30):
    drwbn.train(inp_fn)
    eval_metrics = drwbn.evaluate(test_inp_fn)
    print(eval_metrics)

{'accuracy': 0.5226, 'loss': 1.3642848, 'global_step': 20}
{'accuracy': 0.5917, 'loss': 1.1845695, 'global_step': 40}
{'accuracy': 0.6246, 'loss': 1.1305641, 'global_step': 60}
{'accuracy': 0.6149, 'loss': 1.2645891, 'global_step': 80}
{'accuracy': 0.6093, 'loss': 1.4570963, 'global_step': 100}
{'accuracy': 0.6105, 'loss': 1.6007249, 'global_step': 120}
{'accuracy': 0.6072, 'loss': 1.660339, 'global_step': 140}
{'accuracy': 0.6148, 'loss': 1.7931831, 'global_step': 160}
{'accuracy': 0.6144, 'loss': 1.7849233, 'global_step': 180}
{'accuracy': 0.6165, 'loss': 1.825989, 'global_step': 200}
{'accuracy': 0.6178, 'loss': 1.8372722, 'global_step': 220}
{'accuracy': 0.627, 'loss': 1.8474933, 'global_step': 240}
{'accuracy': 0.6287, 'loss': 1.8738902, 'global_step': 260}
{'accuracy': 0.6262, 'loss': 1.8911006, 'global_step': 280}
{'accuracy': 0.6263, 'loss': 1.9236542, 'global_step': 300}
{'accuracy': 0.6279, 'loss': 1.9401443, 'global_step': 320}
{'accuracy': 0.6273, 'loss': 1.950819, 'global_

## Dropout without BN after it (`with_bn=False`; BN is kept in the convolutional layers)

In [9]:
# Same hyper-parameters without BN after the hidden dense layers. Build a copy rather
# than mutating `hparams`, so re-running the earlier cells still uses with_bn=True.
hparams_wobn = dict(hparams, with_bn=False)
# Start from scratch: an existing checkpoint in 'drwobn-ckpts' would otherwise be resumed.
shutil.rmtree('drwobn-ckpts', ignore_errors=True)
drwobn = tf.estimator.Estimator(
    model, 'drwobn-ckpts',
    # Same summary frequency and graph-level seed as the "BN after dropout" run.
    config=tf.estimator.RunConfig(save_summary_steps=2, tf_random_seed=42),
    params=hparams_wobn)

#### Start TensorBoard

In [10]:
# Launch a second TensorBoard on port 6007 so it does not clash with the first one
# (default port 6006). `start` on Windows and the trailing `&` elsewhere detach it.
if sys.platform.startswith('win'):
    get_ipython().system_raw("start tensorboard --logdir drwobn-ckpts --port 6007")
else:
    get_ipython().system_raw("tensorboard --logdir drwobn-ckpts --port 6007 &")

In [11]:
# 30 rounds of: train on the full training generator, then report test-set metrics.
for _ in range(30):
    drwobn.train(inp_fn)
    eval_metrics = drwobn.evaluate(test_inp_fn)
    print(eval_metrics)

{'accuracy': 0.4223, 'loss': 1.6613123, 'global_step': 20}
{'accuracy': 0.5224, 'loss': 1.3384712, 'global_step': 40}
{'accuracy': 0.5726, 'loss': 1.208832, 'global_step': 60}
{'accuracy': 0.6025, 'loss': 1.1299886, 'global_step': 80}
{'accuracy': 0.6201, 'loss': 1.1092131, 'global_step': 100}
{'accuracy': 0.6266, 'loss': 1.1479524, 'global_step': 120}
{'accuracy': 0.6286, 'loss': 1.1950808, 'global_step': 140}
{'accuracy': 0.6219, 'loss': 1.2852988, 'global_step': 160}
{'accuracy': 0.6252, 'loss': 1.3611405, 'global_step': 180}
{'accuracy': 0.6312, 'loss': 1.4825554, 'global_step': 200}
{'accuracy': 0.6028, 'loss': 1.7469711, 'global_step': 220}
{'accuracy': 0.6209, 'loss': 1.6120105, 'global_step': 240}
{'accuracy': 0.6206, 'loss': 1.7405858, 'global_step': 260}
{'accuracy': 0.621, 'loss': 1.8695172, 'global_step': 280}
{'accuracy': 0.6173, 'loss': 1.9884601, 'global_step': 300}
{'accuracy': 0.6183, 'loss': 2.04527, 'global_step': 320}
{'accuracy': 0.6216, 'loss': 1.9960701, 'global_

## Results
#### BN after Dropout
![Graph (BN after dropout)](https://github.com/ilango100/batch_norm/raw/93279ed23212d0141aea6521bd3f01f8c4afde86//images/drwbn.png)
#### BN not after Dropout
![Graph (BN not after dropout)](https://github.com/ilango100/batch_norm/raw/93279ed23212d0141aea6521bd3f01f8c4afde86//images/drwobn.png)

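As the TensorBoard graphs show, accuracy is more stable when BN is not applied after dropout, and the model without BN after dropout reaches a higher accuracy than the model with BN after dropout. The printed test-set results above are much closer, however. Peak test accuracy is 0.6312 without BN after dropout versus 0.6287 with it, and the run without BN after dropout actually shows the largest single dip (0.6312 → 0.6028 between steps 200 and 220). In both runs the test loss rises from roughly 1.1 to roughly 2.0 after the first few evaluations, a sign of overfitting. Because each configuration was trained only once, the difference should be confirmed over several random seeds before drawing firm conclusions.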

Read more about the interaction between batch normalization and dropout in the paper [Understanding the Disharmony between Dropout and Batch Normalization by Variance Shift](https://arxiv.org/abs/1801.05134) (Li et al., 2018). You can run this notebook on your local machine or in [Google Colab](https://colab.research.google.com/drive/1D9_ltbQaT7yCpnEyg8il1jpRrxEQEWfx).