## Batch Normalization Example
References:
1. https://towardsdatascience.com/batch-normalization-in-practice-an-example-with-keras-and-tensorflow-2-0-b1ec28bde96f
2. https://towardsdatascience.com/tensorboard-hyperparameter-optimization-a51ef7af71f5

In [1]:
import datetime as dt
import numpy as np
import pandas as pd
import os
import tensorflow as tf

from matplotlib import pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from tensorboard.plugins.hparams import api as hp
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, BatchNormalization

In [2]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

### Parameters

In [24]:
TEST_SIZE = 0.2  # fraction of all samples held out by train_test_split
VAL_SIZE = 0.25  # fraction of the training split passed to Keras as validation_split
EPOCH = 100      # training epochs per hyperparameter combination

# Log Directory: timestamped per notebook session, so each sweep's runs
# are written to their own folder
LOG_DIR = os.path.join('.', 'logs', 'hp_opt', dt.datetime.now().strftime('%Y%m%d-%H%M%S'))

# Hyperparameters swept by the grid search
HP_NUM_UNITS = hp.HParam('num_units', hp.Discrete([32, 64, 128]))     # base hidden-layer width
HP_BN = hp.HParam('batch_normalization', hp.Discrete([True, False]))  # insert BatchNormalization layers?
HP_BATCH = hp.HParam('batch_size', hp.Discrete([16, 32, 64]))         # mini-batch size for fit()

### Load Data

In [4]:
# Bunch providing .data / .feature_names (features) and .target / .target_names (labels)
iris = load_iris(return_X_y=False)

### Data Preprocessing

In [5]:
# Extract Feature Data
# Cast to the builtin float: np.float was only a deprecated alias for it
# (NumPy 1.20) and was removed in NumPy 1.24, where it raises AttributeError.
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df = df.astype(float)

In [6]:
# Extract Label Data: attach the integer class ids (0 ~ 2), then translate
# each id into its species name taken from iris.target_names
species_by_id = dict(enumerate(iris.target_names))
df['label'] = iris.target
df['label'] = df['label'].map(species_by_id)

In [7]:
# Label => One-Hot Encoding
# dtype=float keeps the target columns numeric: since pandas 2.0,
# get_dummies defaults to bool columns, which would turn y into a bool array.
label = pd.get_dummies(df['label'], prefix='label', dtype=float)
df = pd.concat([df, label], axis=1)

# The string label column is redundant once it has been one-hot encoded
df = df.drop(columns=['label'])

In [8]:
# DataFrame => Numpy Array
# Feature matrix: the four measurement columns, in this fixed order
feature_columns = ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
# Target matrix: one one-hot column per species
label_columns = ['label_setosa', 'label_versicolor', 'label_virginica']

X = df[feature_columns].to_numpy()
y = df[label_columns].to_numpy()

In [9]:
# Split data
# stratify=y keeps the three classes in the same proportions in the train and
# test splits, which matters on a dataset this small. (A 2-D one-hot y is
# accepted: each row is treated as one class label.)
# NOTE(review): X_test / y_test are held out but not used by any cell shown here.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=TEST_SIZE, stratify=y)

### Create Model

In [19]:
# Create log files: register the swept hyperparameters and the reported metric
# ('accuracy', shown as 'Acc') with the HParams dashboard under LOG_DIR
hparam_list = [HP_NUM_UNITS, HP_BN, HP_BATCH]
metric_list = [hp.Metric('accuracy', display_name='Acc')]
with tf.summary.create_file_writer(LOG_DIR).as_default():
    hp.hparams_config(hparams=hparam_list, metrics=metric_list)

In [23]:
# Model
def create_model(hparams):
    """Build, compile and train one MLP for a hyperparameter combination.

    Despite its name, this also fits the model and returns a metric rather
    than the model. It trains on the notebook globals ``X_train`` /
    ``y_train`` for ``EPOCH`` epochs. ``VAL_SIZE`` of the training data is
    passed to Keras as ``validation_split``.

    Args:
        hparams: dict keyed by ``HP_NUM_UNITS`` (base hidden-layer width),
            ``HP_BN`` (whether to insert BatchNormalization layers) and
            ``HP_BATCH`` (mini-batch size).

    Returns:
        The validation accuracy recorded after the final epoch.
    """
    units = hparams[HP_NUM_UNITS]

    # Create
    model = Sequential()
    # Input width comes from the data instead of a hard-coded 4.
    model.add(Dense(units, input_shape=(X_train.shape[1],), activation='relu'))
    # Remaining hidden widths. When enabled, a BatchNormalization layer sits
    # between each pair of consecutive hidden Dense layers. This gives the
    # same stack as before: u, BN, 2u, BN, 2u, BN, u, BN, u.
    for width in (units * 2, units * 2, units, units):
        if hparams[HP_BN]:
            model.add(BatchNormalization())
        model.add(Dense(width, activation='relu'))
    # One softmax output per one-hot target column.
    model.add(Dense(y_train.shape[1], activation='softmax'))

    # Compile
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

    # Train
    history = model.fit(X_train, y_train, batch_size=hparams[HP_BATCH], epochs=EPOCH,
                        validation_split=VAL_SIZE, verbose=0)

    return history.history['val_accuracy'][-1]

### Train Model

In [21]:
def run(run_dir, hparams):
    """Train one hyperparameter combination and log it for the HParams dashboard.

    Args:
        run_dir: directory that receives this run's summary files.
        hparams: dict mapping the HP_* HParam objects to the values to try.
    """
    with tf.summary.create_file_writer(run_dir).as_default():
        # Record hyperparameter values used by this run
        hp.hparams(hparams)

        # Record metric under the 'accuracy' tag declared in hparams_config.
        # tf.summary.scalar accepts a plain Python number, so no tensor
        # conversion/reshape round-trip is needed.
        acc = float(create_model(hparams))
        tf.summary.scalar('accuracy', acc, step=1)

In [None]:
# Train model with different hyperparameter
# Enumerate every (num_units, batch_size, batch_normalization) combination in
# the same nested order as the domains, then train and log one run for each.
hparam_grid = [
    {HP_NUM_UNITS: num_units, HP_BATCH: batch_size, HP_BN: bn_layer}
    for num_units in HP_NUM_UNITS.domain.values
    for batch_size in HP_BATCH.domain.values
    for bn_layer in HP_BN.domain.values
]

session_num = 0
for hparams in hparam_grid:
    print('Run %d' % session_num)
    run(os.path.join(LOG_DIR, 'run_%d' % session_num), hparams)
    session_num += 1

### Check the optimization results in TensorBoard

In [None]:
%tensorboard --logdir='./logs/hp_opt/20200719-090551'