### This is the tutorial for batch normalization on the same sample MNIST dataset.

In [1]:
import tensorflow as tf
import pandas as pd
import numpy as np
from tqdm import tqdm

  from ._conv import register_converters as _register_converters


In [2]:
#Data Locations
train_path = './data/mnist_example/digitstrain.csv'
val_path = './data/mnist_example/digitsvalid.csv'

train_data = pd.read_csv(train_path)
X_train = np.asarray(train_data.values[:,:784])
Y_train = np.asarray(train_data.values[:,784])
print(Y_train.shape)

(2999,)


In [3]:
n_inputs = 28*28
n_hidden1 = 300
n_hidden2 = 100
n_outputs = 10

In [4]:
X = tf.placeholder(tf.float32, shape = [None, n_inputs], name = "X")
training = tf.placeholder_with_default(True, shape = (), name = "training")
y = tf.placeholder(tf.int64, shape = [None], name = "y")

In [5]:
#Method-1
hidden1 = tf.layers.dense(X, n_hidden1, name = "hidden1")
bn1 = tf.layers.batch_normalization(hidden1, training = training, momentum = 0.9)
bn1_act = tf.nn.elu(bn1)
hidden2 = tf.layers.dense(bn1_act, n_hidden2, name = "hidden2")
bn2 = tf.layers.batch_normalization(hidden2, training = training, momentum = 0.9)
bn2_act = tf.nn.elu(bn2)
logits_before_bn= tf.layers.dense(bn2_act, n_outputs, name = "outputs")
logits = tf.layers.batch_normalization(logits_before_bn, training= training, momentum = 0.9)

In [5]:
# Note :Run only one of the methods, else it will throw a scope error
#Method-2 : to avoid repitition .
from functools import partial

my_batch_norm_layer = partial(tf.layers.batch_normalization, training = training, momentum = 0.9)
hidden1 = tf.layers.dense(X, n_hidden1, name = "hidden1")
bn1 = my_batch_norm_layer(hidden1)
bn1_act = tf.nn.elu(bn1)
hidden2 = tf.layers.dense(bn1_act, n_hidden2, name = "hidden2")
bn2 = my_batch_norm_layer(hidden2)
bn2_act = tf.nn.elu(bn2)
logits_before_bn = tf.layers.dense(bn2_act, n_outputs, name = "outputs")
logits = my_batch_norm_layer(logits_before_bn)

In [6]:
with tf.name_scope("loss"):
    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels = y,
                                                             logits = logits)
    loss = tf.reduce_mean(xentropy, name = "loss")

In [7]:
learning_rate = 0.01
with tf.name_scope("train"):
    optimizer = tf.train.GradientDescentOptimizer(learning_rate= learning_rate)
    training_op = optimizer.minimize(loss)

with tf.name_scope("eval"):
    correct = tf.nn.in_top_k(logits, y, 1)
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

init = tf.global_variables_initializer()
saver = tf.train.Saver()

In [8]:
batch_size = 100
val_data = pd.read_csv(val_path)
test = val_data.values 
train_data = train_data.values
np.random.shuffle(train_data)
with tf.Session() as sess:
    init.run()
    accs = []
    for epoch in tqdm(range(100)):
        np.random.shuffle(train_data)
        ind = 0
        while ind < train_data.shape[0]:
            X_batch = train_data[ind:ind+batch_size, :784]
            y_batch = train_data[ind:ind+batch_size,784]
            sess.run(training_op, feed_dict = {X : X_batch, y : y_batch, training: True})
            ind += batch_size
        accs.append(accuracy.eval(feed_dict = {X : test[:,:784], y : test[:,784]}))
    for i in range(len(accs)):
        print("Epoch : {!r} , Validation Accuracy : {!r}".format(i, accs[i]))
    #print ( "" % epoch, acc_val)
    save_path = saver.save(sess, './trained_models/batch_norm.ckpt')

100%|██████████| 100/100 [00:17<00:00,  5.78it/s]


Epoch : 0 , Validation Accuracy : 0.6846847
Epoch : 1 , Validation Accuracy : 0.7877878
Epoch : 2 , Validation Accuracy : 0.8228228
Epoch : 3 , Validation Accuracy : 0.8368368
Epoch : 4 , Validation Accuracy : 0.8498498
Epoch : 5 , Validation Accuracy : 0.8578579
Epoch : 6 , Validation Accuracy : 0.8648649
Epoch : 7 , Validation Accuracy : 0.8678679
Epoch : 8 , Validation Accuracy : 0.8718719
Epoch : 9 , Validation Accuracy : 0.8748749
Epoch : 10 , Validation Accuracy : 0.8768769
Epoch : 11 , Validation Accuracy : 0.8808809
Epoch : 12 , Validation Accuracy : 0.8818819
Epoch : 13 , Validation Accuracy : 0.8828829
Epoch : 14 , Validation Accuracy : 0.8818819
Epoch : 15 , Validation Accuracy : 0.8808809
Epoch : 16 , Validation Accuracy : 0.8848849
Epoch : 17 , Validation Accuracy : 0.8878879
Epoch : 18 , Validation Accuracy : 0.8878879
Epoch : 19 , Validation Accuracy : 0.8878879
Epoch : 20 , Validation Accuracy : 0.8948949
Epoch : 21 , Validation Accuracy : 0.8948949
Epoch : 22 , Validat