In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import backend as K
from tensorflow.keras import layers
from tensorflow.keras.callbacks import ModelCheckpoint
from CLR.clr_callback import CyclicLR


from src.model import UNet_2d


import io
import json
from os import path
import glob

In [None]:
# Size of each raw 'plane'/'sim' record as stored in the TFRecords
# (used as the FixedLenFeature shape in _parse_function).
obj_dims = 648, 486

In [None]:
# Expose only CUDA device 0 to this process. This only takes effect if set
# before TensorFlow initializes its GPU devices.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [None]:
# Run identifier: selects the dataset directory and names the model/log outputs.
model_description = "model-7.1"

In [None]:
# Seed every RNG the notebook relies on so runs are repeatable: NumPy's global
# state, and TensorFlow's global seed (which drives tf.random.shuffle and
# Dataset.shuffle in create_dataset; without it those differ on every run).
SEED = 0
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
# Dataset path: TFRecord files live under ../data/<model_description>/{train,validation}/
record_dir = path.join('..', 'data', model_description)
# glob() returns files in an arbitrary, filesystem-dependent order; sorting
# makes the file list (and therefore a seeded shuffle of it) reproducible.
train_dataset_files = sorted(glob.glob(path.join(record_dir, 'train', '*')))
val_dataset_files = sorted(glob.glob(path.join(record_dir, 'validation', '*')))


# Paths for saving/loading model weights, predictions
base_path = path.join('..', 'models', model_description)
model_weights_path = path.join(base_path, model_description)  # checkpoint filename prefix
image_path = path.join(base_path, 'prediction-images')

In [None]:
# Create the output directories. Unlike os.mkdir, makedirs(exist_ok=True)
# also creates missing parents (e.g. ../models) and does not raise
# FileExistsError when the cell is re-run.
os.makedirs(base_path, exist_ok=True)
os.makedirs(image_path, exist_ok=True)

In [None]:
# NOTE(review): these count TFRecord *files*; they equal the number of samples
# only if each file holds exactly one example -- confirm.
TRAIN_LENGTH = len(train_dataset_files)
VAL_LENGTH = len(val_dataset_files)
obj_dims = (648, 486)
# Model input shape: (batch, *obj_dims, channels=1). The batch dimension is
# left as None so any batch size is accepted.
input_shape = (None, *obj_dims, 1)

# Dataset Creation

In [None]:
def _min_max_normalize(x):
    """Linearly rescale tensor ``x`` to [0, 1]; a constant tensor maps to all zeros (no NaN)."""
    x_min = tf.reduce_min(x)
    x_max = tf.reduce_max(x)
    return tf.math.divide_no_nan(x - x_min, x_max - x_min)


def _parse_function(example_proto):
    """Parse one serialized TFRecord example into a noisy (input, target) pair.

    Args:
        example_proto: scalar string tensor holding one serialized example
            with float32 features 'plane' and 'sim', each of shape ``obj_dims``.

    Returns:
        (sim, plane):
            sim   -- float32 tensor of shape (*obj_dims, 1): 'sim' min-max
                     normalized, corrupted with random Gaussian noise, then
                     renormalized to [0, 1].
            plane -- float32 tensor of shape (640, 448): 'plane' cropped with
                     [4:644, 19:467] and min-max normalized to [0, 1].

    Note:
        This runs inside ``Dataset.map``, so it is traced once into a TF graph.
        Randomness must come from ``tf.random`` ops, which are re-evaluated for
        every element; NumPy random calls would execute only once at trace time
        and bake a single noise realization into every sample.
    """
    feature_description = {
        'plane': tf.io.FixedLenFeature(obj_dims, tf.float32),
        'sim': tf.io.FixedLenFeature(obj_dims, tf.float32),
    }
    example = tf.io.parse_single_example(example_proto, feature_description)

    # Crop to target image size (640 x 448), then normalize to [0, 1].
    plane = example['plane'][4:644, 19:467]
    plane = _min_max_normalize(plane)

    sim = _min_max_normalize(example['sim'])

    # Adding noise to simulated measurements: amplitude `a` and offset `b`
    # are drawn independently for each sample. (A constant offset is removed
    # again by the renormalization below; it is kept to match the recipe.)
    a = tf.random.uniform([], minval=0.0063, maxval=0.0063 * 4)
    b = tf.random.uniform([], minval=0.06, maxval=0.1)
    noise = a * tf.random.normal(tf.shape(sim)) + b

    # Renormalize values to [0, 1]
    sim = _min_max_normalize(sim + noise)

    # Expand to channel dimension
    sim = sim[..., tf.newaxis]

    return sim, plane

def create_dataset(filenames, batch_size):
    """Build an infinite, shuffled, batched dataset from TFRecord files.

    Args:
        filenames: sequence of paths to TFRecord files containing samples.
        batch_size: number of (sim, plane) pairs per batch.

    Returns:
        A repeating ``tf.data.Dataset`` of batches of ``_parse_function``
        outputs. Because it repeats forever, callers must bound iteration
        (e.g. ``steps_per_epoch`` / ``validation_steps``).
    """
    # File order is shuffled once here; individual examples are additionally
    # mixed through a 256-element shuffle buffer.
    shuffled_files = tf.random.shuffle(filenames)
    return (
        tf.data.TFRecordDataset(shuffled_files)
        .map(_parse_function)
        .shuffle(256)
        .repeat()
        .batch(batch_size)
    )

In [None]:
BATCH_SIZE = 8
# Whole batches per pass over each split (integer division drops any remainder).
STEPS_PER_EPOCH, VAL_STEPS = TRAIN_LENGTH // BATCH_SIZE, VAL_LENGTH // BATCH_SIZE

# Samples actually consumed per training epoch.
SAMPLES_PER_EPOCH = STEPS_PER_EPOCH * BATCH_SIZE

# step_size for the (currently commented-out) CyclicLR callback: 8 epochs' worth of steps.
CLR_STEPS = 8 * STEPS_PER_EPOCH

# Both datasets repeat forever; model.fit bounds them via steps_per_epoch / validation_steps.
train_dataset = create_dataset(train_dataset_files, BATCH_SIZE)
val_dataset = create_dataset(val_dataset_files, BATCH_SIZE)

# Logging

In [None]:
def plot_to_image(figure):
    """Render a matplotlib figure to a PNG image tensor for TensorBoard.

    Args:
        figure: the matplotlib Figure to render. It is closed by this call
            (even on error) and must not be used afterwards.

    Returns:
        uint8 tensor of shape (1, height, width, 4): the decoded RGBA PNG with
        a leading batch dimension, as ``tf.summary.image`` expects.
    """
    try:
        # Save *this* figure (not whatever pyplot considers "current") to an
        # in-memory PNG.
        with io.BytesIO() as buf:
            figure.savefig(buf, format='png')
            png_bytes = buf.getvalue()
    finally:
        # Closing the figure prevents it from being displayed directly inside
        # the notebook, and frees it even if savefig failed.
        plt.close(figure)
    # Convert PNG bytes to a TF image and add the batch dimension.
    image = tf.image.decode_png(png_bytes, channels=4)
    return tf.expand_dims(image, 0)


def scaled_mse_loss(y_actual, y_pred):
    """Sum of squared errors over every element of the batch.

    Unlike Keras' 'mean_squared_error' this is not divided by the number of
    elements, so its magnitude grows with batch and image size.
    NOTE(review): not currently used -- model.compile uses 'mean_squared_error'.
    """
    return K.sum(K.square(y_actual - y_pred))


def plot_image_tensorboard(epoch, logs):
    """Keras LambdaCallback hook: log a prediction vs. ground-truth figure.

    Predicts on the first example of the global ``val_sample`` batch with the
    global ``model`` and writes a side-by-side figure to TensorBoard via the
    global ``file_writer``, using ``epoch`` as the summary step.

    Args:
        epoch: epoch index supplied by Keras.
        logs: metrics dict supplied by Keras (unused).
    """
    # Predict a single sample; [np.newaxis] restores the batch dimension.
    pred = model.predict(val_sample[0][0][np.newaxis])

    figure = plt.figure(figsize=(10, 10))

    # Plot the prediction
    ax_pred = figure.add_subplot(1, 2, 1)
    ax_pred.set_title("prediction")
    ax_pred.imshow(pred[0].astype(np.float32))

    # Plot groundtruth
    ax_truth = figure.add_subplot(1, 2, 2)
    ax_truth.set_title("ground truth")
    ax_truth.imshow(val_sample[1][0].numpy().astype(np.float32))

    plot_image = plot_to_image(figure)
    with file_writer.as_default():
        tf.summary.image("Prediction vs Ground Truth", plot_image, step=epoch)

import datetime

# TensorBoard run directory, e.g. logs/model-7.1-fit20240101-120000 (the
# timestamp is appended directly after '-fit', with no separator).
log_dir = (os.path.join('logs', model_description + '-fit')
           + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,        # histograms every epoch
    profile_batch=1000000,   # far beyond the expected batch count, so profiling effectively never runs
)
# Separate writer for the custom prediction-image summaries.
file_writer = tf.summary.create_file_writer(log_dir)

plot_image_tensorboard_cb = keras.callbacks.LambdaCallback(
    on_epoch_end=plot_image_tensorboard)

# Save model after epochs: periodic weights snapshot. save_freq is counted in
# batches, i.e. every 10 epochs' worth of steps; suffix is .e<epoch:03d>.
checkpoint_cb = ModelCheckpoint(
    model_weights_path + '.e{epoch:03d}',
    monitor='val_loss',
    verbose=0,
    save_best_only=False,
    save_weights_only=True,
    mode='auto',
    save_freq=10 * STEPS_PER_EPOCH,
)
# Best-so-far weights by val_loss (default save_freq: end of every epoch).
checkpoint_best_cb = ModelCheckpoint(
    model_weights_path + '.best',
    monitor='val_loss',
    verbose=0,
    save_best_only=True,
    save_weights_only=True,
    mode='auto',
)

# Training

In [None]:
model = UNet_2d()
# Adam with its hyper-parameters spelled out explicitly.
adam = tf.keras.optimizers.Adam(
    learning_rate=1e-4,
    beta_1=0.9,
    beta_2=0.999,
    amsgrad=False,
)

# Loss and reported metric are both mean squared error.
# (Previously tried: tf.keras.losses.MeanAbsoluteError() as loss and metric.)
model.compile(
    optimizer=adam,
    loss='mean_squared_error',
    metrics=['mean_squared_error'],
)
model.build(input_shape)
model.summary()

In [None]:
# Total number of training epochs passed to model.fit.
EPOCHS: int = 450

In [None]:
# One (sim_batch, plane_batch) tuple from the validation set, used by the
# TensorBoard image-logging callback.
(val_sample,) = val_dataset.take(1)

In [None]:
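# Optional cyclic learning-rate schedule (currently disabled). To enable it,
# uncomment the lines below and add `clr` to the callbacks passed to model.fit.
# clr = CyclicLR(base_lr=1e-4, max_lr=5e-4,
#                         step_size=CLR_STEPS)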

In [None]:
# Callbacks: TensorBoard prediction images, TensorBoard scalars/histograms,
# periodic checkpoints and best-val_loss checkpoint. Append `clr` here to use
# the cyclic learning-rate schedule.
training_callbacks = [
    plot_image_tensorboard_cb,
    tensorboard_callback,
    checkpoint_cb,
    checkpoint_best_cb,
]
model.fit(
    train_dataset,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    callbacks=training_callbacks,
    validation_data=val_dataset,
    validation_steps=VAL_STEPS,
)

# Final weights, written under the bare checkpoint prefix in TF format.
model.save_weights(model_weights_path, save_format='tf')

# Prediction Visualization, Timing Tests

In [None]:
# Number of examples to predict and plot for each split.
NUM_DISPLAY: int = 15

In [None]:
# Epoch of the periodic checkpoint to evaluate. It must correspond to a file
# written by checkpoint_cb (which saves every 10 epochs' worth of steps, so
# presumably a multiple of 10).
CHECKPOINT_EPOCH = 220

model = UNet_2d()
# Same '.e{epoch:03d}' suffix pattern that checkpoint_cb uses ('.e220' here).
model.load_weights(model_weights_path + '.e{:03d}'.format(CHECKPOINT_EPOCH))

In [None]:
import time

In [None]:
%%javascript
// Override the (classic Notebook) output-area hook so it always answers
// "do not scroll": long outputs such as the tall prediction figures below
// are shown at full height instead of inside a scroll box.
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

In [None]:
# Buffers for NUM_DISPLAY examples: predictions and targets at the cropped
# target size (640 x 448), model inputs at the raw record size (648 x 486).
preds, ground_truths = (np.zeros((NUM_DISPLAY, 640, 448)) for _ in range(2))
sims = np.zeros((NUM_DISPLAY, 648, 486))


### Train

In [None]:
# Predict NUM_DISPLAY training examples one at a time and time the loop.
# NOTE: the measured time also includes the input pipeline (TFRecord parsing;
# the first element also waits for the dataset's 256-element shuffle buffer
# to fill), so it over-estimates pure model latency.
t0 = time.perf_counter()
n_predicted = 0
for i, (sim, plane) in enumerate(train_dataset.unbatch().take(NUM_DISPLAY)):
    preds[i] = model.predict(sim[np.newaxis])  # add back the batch dimension
    ground_truths[i] = plane
    sims[i] = np.squeeze(sim)                  # drop the channel dimension
    n_predicted = i + 1
t1 = time.perf_counter()

# Average over the samples actually predicted (at most NUM_DISPLAY), not over
# the dataset length.
print('Prediction time per sample:', (t1 - t0) / max(n_predicted, 1), 's')

In [None]:
# Grid: one row per example; prediction on the left, ground truth on the right.
rows, columns = NUM_DISPLAY, 2
fig = plt.figure(figsize=(30, 15 * NUM_DISPLAY))
for i in range(rows):
    s = fig.add_subplot(rows, columns, columns * i + 1)
    s.set_title("prediction {0}".format(i), size=20)
    s.imshow(preds[i])

    s = fig.add_subplot(rows, columns, columns * i + 2)
    s.set_title("ground truth {0}".format(i), size=20)
    s.imshow(ground_truths[i])

fig.savefig(path.join(image_path, 'train.pdf'))

### Validation

In [None]:
# Predict NUM_DISPLAY validation examples one at a time and time the loop.
# NOTE: the measured time also includes the input pipeline (TFRecord parsing;
# the first element also waits for the dataset's 256-element shuffle buffer
# to fill), so it over-estimates pure model latency.
t0 = time.perf_counter()
n_predicted = 0
for i, (sim, plane) in enumerate(val_dataset.unbatch().take(NUM_DISPLAY)):
    preds[i] = model.predict(sim[np.newaxis])  # add back the batch dimension
    ground_truths[i] = plane
    sims[i] = np.squeeze(sim)                  # drop the channel dimension
    n_predicted = i + 1
t1 = time.perf_counter()

# Average over the samples actually predicted (at most NUM_DISPLAY), not over
# the whole validation set.
print('Prediction time per sample:', (t1 - t0) / max(n_predicted, 1), 's')

In [None]:
# Grid: one row per example; prediction on the left, ground truth on the right.
rows, columns = NUM_DISPLAY, 2
fig = plt.figure(figsize=(30, 15 * NUM_DISPLAY))
for i in range(rows):
    s = fig.add_subplot(rows, columns, columns * i + 1)
    s.set_title("prediction {0}".format(i), size=20)
    s.imshow(preds[i])

    s = fig.add_subplot(rows, columns, columns * i + 2)
    s.set_title("ground truth {0}".format(i), size=20)
    s.imshow(ground_truths[i])

fig.savefig(path.join(image_path, 'validation.pdf'))

In [None]:
# Quick look at the first stored ground-truth image (the validation loop above
# wrote these most recently).
plt.imshow(X=ground_truths[0])