# Simple neural network example
### Christian Igel, 2019

We use TensorFlow 2.x:

In [None]:
import tensorflow as tf
print("TensorFlow version:", tf.__version__)
import numpy as np
import matplotlib.pyplot as plt

# Load the TensorBoard notebook extension
%load_ext tensorboard.notebook

Generate and visualize toy data:

In [None]:
def generate_sine_data(N, noise = 0):
    """Sample N points of a (optionally noisy) sine curve on [0, 2*pi).

    Args:
        N: Number of samples to draw.
        noise: Standard deviation of the zero-mean Gaussian noise added to
            the targets; 0 yields the exact sine values.

    Returns:
        Tuple (x, y) of arrays with shape (N, 1). The inputs x are sorted in
        ascending order, and y[i] = sin(x[i]) + noise sample.
    """
    # Uniform inputs, sorted so that line plots run left to right
    inputs = np.sort(np.random.rand(N, 1) * 2 * np.pi, axis=0)
    # Additive Gaussian perturbation of the targets
    perturbation = np.random.normal(loc=0, scale=noise, size=(N, 1))
    targets = np.sin(inputs) + perturbation
    return inputs, targets

In [None]:
# Noisy samples for fitting and for model selection ...
x_train, y_train = generate_sine_data(N=50, noise=0.5)
x_val, y_val = generate_sine_data(N=50, noise=0.5)
# ... and noise-free samples of the underlying function for testing
x_test, y_test = generate_sine_data(N=100, noise=0)

print(
    "Shape of training input and labels:",
    x_train.shape,
    y_train.shape,
)

In [None]:
# Plot the three data sets in one figure. Each curve gets a label so that
# the noisy training/validation samples can be told apart from the
# noise-free test curve.
fig, ax = plt.subplots()
for inputs, targets, label in (
    (x_train, y_train, "training data"),
    (x_val, y_val, "validation data"),
    (x_test, y_test, "test data (noise-free)"),
):
    ax.plot(inputs.reshape([-1]), targets.reshape([-1]), label=label)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.legend()
plt.show()

Next we define the model in a way that makes it easy to create several models of the same type:

In [None]:
def my_model():
    """Build a fresh, untrained network: 1 input -> 64 sigmoid units -> 1 linear output.

    Every call creates new layers with newly initialized weights, so the
    function can be used to obtain several independent models of the same
    architecture.

    Returns:
        An uncompiled `tf.keras.models.Sequential` model.
    """
    # Factories instead of shared objects: every layer gets its own
    # initializer instances, exactly as if they were written out inline.
    def kernel_init():
        return tf.initializers.VarianceScaling(scale=0.01**2)

    def bias_init():
        return tf.initializers.TruncatedNormal(stddev=0.01)

    net = tf.keras.models.Sequential()
    # Hidden layer
    net.add(tf.keras.layers.Dense(
        64,
        activation='sigmoid',
        input_shape=(1,),
        kernel_initializer=kernel_init(),
        bias_initializer=bias_init(),
    ))
    # Linear output layer for regression
    net.add(tf.keras.layers.Dense(
        1,
        activation='linear',
        kernel_initializer=kernel_init(),
        bias_initializer=bias_init(),
    ))
    return net

model = my_model()

# summary() prints the layer table itself and returns None; wrapping it in
# print() would additionally emit a stray "None" line.
model.summary()

Next we define some operations carried out in the training loop. 
First, we store information about the training progress in a format that can be visualized by TensorBoard. Second, we store the network weights that give the lowest validation error.

In [None]:
# Write training progress for TensorBoard after every batch
tensorboard_cb = tf.keras.callbacks.TensorBoard(
    log_dir='./logs/run2',
    update_freq='batch',
)

# Keep only the weights that achieved the lowest validation loss so far
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    './logs/model_val.hdf5',
    monitor='val_loss',
    save_best_only=True,
    verbose=1,
)

# Optional, currently disabled:
#   tf.keras.callbacks.EarlyStopping(monitor='loss', patience=20)
callbacks = [tensorboard_cb, checkpoint_cb]

Small data sets can be fed into the model as arrays:

In [None]:
# Define optimization algorithm.
# `learning_rate` is the documented argument name; `lr` is only a deprecated
# alias that newer Keras releases no longer accept.
sgd = tf.optimizers.SGD(learning_rate=0.2)

# Compile model (i.e., build compute graph)
model.compile(optimizer=sgd,
              loss='MSE')

# Training loop: mini-batches of 25 samples, validation after every epoch
model.fit(x_train, y_train, batch_size=25, epochs=100, 
          validation_data=(x_val, y_val), validation_freq=1, 
          #steps_per_epoch=x_train.shape[0],
          callbacks=callbacks)

For larger data sets, use the `Dataset` API:

In [None]:
# Mini-batch size used for all three datasets
batch_size = 8

# Prepare the training dataset.
# The samples are sorted by x, so the shuffle buffer has to span the whole
# data set; a smaller buffer would only mix neighbouring samples and every
# batch would cover just a narrow range of inputs.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=x_train.shape[0]).batch(batch_size)

# Prepare the validation dataset.
# Keras expects datasets to be batched already; without batch() every
# element would be a single sample without a batch dimension.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)

# Prepare the test dataset
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_dataset = test_dataset.batch(batch_size)

In [None]:
# Clear any logs from previous runs.
# NOTE: ./logs also holds the checkpoint file model_val.hdf5 written by the
# ModelCheckpoint callback, so weights saved by an earlier run are removed
# as well.
!rm -rf ./logs
!sync

In [None]:
# Define optimization algorithm.
# `learning_rate` is the documented argument name; `lr` is only a deprecated
# alias that newer Keras releases no longer accept.
# Alternative: SGD with momentum (older TF 2.x releases also took decay=1e-6)
#opt = tf.optimizers.SGD(learning_rate=0.01, momentum=0.9)
opt = tf.optimizers.Adam(learning_rate=0.01)

# Compile model (i.e., build compute graph).
# NOTE: this re-compiles the existing `model` object, so any weights learned
# in an earlier fit() call are kept and training continues from there.
model.compile(optimizer=opt,
              loss='MSE')

# Training loop: batching is defined by the dataset itself, so no batch_size
# is passed here; validation is run after every epoch.
model.fit(train_dataset, epochs=2000, verbose=1,
          validation_data=val_dataset, validation_freq=1, 
          callbacks=callbacks)

Let's visualize the training and evaluate the model:

In [None]:
# Show TensorBoard inline, reading the event files written below ./logs
%tensorboard --logdir logs

In [None]:
# Predictions of the model in its current (last-epoch) state on the training inputs
pred_train = model.predict(x_train)

In [None]:
# Compare the fitted function with the noisy training samples and the
# noise-free test curve; labels make the three curves distinguishable.
fig, ax = plt.subplots()
ax.plot(x_train.reshape([-1]), y_train.reshape([-1]), label="training data")
ax.plot(x_test.reshape([-1]), y_test.reshape([-1]), label="test data (noise-free)")
ax.plot(x_train.reshape([-1]), pred_train.reshape([-1]), label="model prediction")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.legend()
plt.show()

In [None]:
# File written by the ModelCheckpoint callback whenever the validation loss improved
best_weights_path = 'logs/model_val.hdf5'

# Build an untrained network with the same architecture as the trained one
# and overwrite its weights with the best ones found during training
best_model = my_model()
best_model.load_weights(best_weights_path)

In [None]:
# Predictions of the checkpointed (lowest validation loss) model
pred_best_train = best_model.predict(x_train)

# Four curves in one figure are unreadable without labels, so each curve is
# named and a legend is shown.
fig, ax = plt.subplots()
ax.plot(x_train.reshape([-1]), y_train.reshape([-1]), label="training data")
ax.plot(x_test.reshape([-1]), y_test.reshape([-1]), label="test data (noise-free)")
ax.plot(x_train.reshape([-1]), pred_train.reshape([-1]), label="final model")
ax.plot(x_train.reshape([-1]), pred_best_train.reshape([-1]), label="best validation model")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.legend()
plt.show()