In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import backend as K
from tensorflow.keras import layers
from tensorflow.keras.callbacks import ModelCheckpoint

from src.model import *

import io
import json
from os import path
import glob

In [None]:
# Shape of each 2-D array stored in the TFRecords (the 'plane' and 'sim' features).
# NOTE(review): presumably (height, width) in pixels — confirm against the data writer.
obj_dims = (648, 486)

In [None]:
# Limit CUDA device visibility for this process to device index 0.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

In [None]:
# Run identifier: names the data sub-directory, the model output directory,
# the weight-file prefix and the TensorBoard log directory.
model_description = 'model_test'

In [None]:
# Input records live in ../data/<model_description>/{train,validation}/
record_dir = os.path.join('..', 'data', model_description)
train_dataset_files, val_dataset_files = (
    glob.glob(os.path.join(record_dir, split, '*'))
    for split in ('train', 'validation')
)


# Outputs go to ../models/<model_description>/: weight files share the
# <model_description> prefix, figures are written to prediction-images/
base_path = os.path.join('..', 'models', model_description)
model_weights_path = os.path.join(base_path, model_description)
image_path = os.path.join(base_path, 'prediction-images')

In [None]:
# Create the output directories up front; exist_ok=True makes re-running this cell harmless.
os.makedirs(base_path, exist_ok=True)
os.makedirs(image_path, exist_ok=True)

In [None]:
# Number of record files found for each split.
# NOTE(review): the step counts derived from these assume one sample per file — confirm.
TRAIN_LENGTH = len(train_dataset_files)
VAL_LENGTH = len(val_dataset_files)

# Spatial size of one sample. The network input shape is derived from it
# (batch axis first, single channel last) so the two can never drift apart.
obj_dims = (648, 486)
input_shape = (None, *obj_dims, 1)

# Dataset Creation

In [None]:
def _normalize_unit_range(image):
    """Linearly rescale `image` so its minimum maps to 0 and its maximum to 1.

    A constant image (max == min) maps to all zeros instead of NaN.
    """
    image_min = tf.reduce_min(image)
    image_range = tf.reduce_max(image) - image_min
    # divide_no_nan returns 0 wherever the denominator is 0, which avoids the
    # 0/0 -> NaN that a plain division produces for a constant image.
    return tf.math.divide_no_nan(image - image_min, image_range)


def _parse_function(example_proto):
    """Decode one serialized example into an (input, target) training pair.

    Args:
        example_proto: serialized example holding float32 'plane' and 'sim'
            features, each of shape `obj_dims`.

    Returns:
        Tuple `(sim, plane)`: `sim` is normalized to [0, 1] and has a trailing
        channel axis added; `plane` is normalized to [0, 1] with shape `obj_dims`.
    """
    feature_description = {
        'plane': tf.io.FixedLenFeature(obj_dims, tf.float32),
        'sim': tf.io.FixedLenFeature(obj_dims, tf.float32),
    }
    example = tf.io.parse_single_example(example_proto, feature_description)

    # Both arrays are used at full resolution; any downsampling (e.g. with
    # tf.image.resize) would go here, before normalization.
    plane = _normalize_unit_range(example['plane'])
    sim = _normalize_unit_range(example['sim'])

    # Expand to channel dimension
    sim = sim[..., tf.newaxis]

    return sim, plane

def create_dataset(filenames, batch_size, shuffle_buffer=256):
    """Build the input pipeline for a list of TFRecord files.

    Args:
        filenames: sequence of paths to TFRecord files containing samples.
        batch_size: number of (sim, plane) pairs per batch.
        shuffle_buffer: size of the sample-level shuffle buffer.

    Returns:
        A tf.data.Dataset that yields batches of `(sim, plane)` pairs parsed by
        `_parse_function`. The dataset is shuffled and repeats indefinitely, so
        callers must bound iteration themselves (e.g. `steps_per_epoch`).

    Raises:
        ValueError: if `filenames` is empty.
    """
    # An empty list would otherwise fail further down with an unrelated
    # dtype error; fail early with a message that points at the cause.
    if len(filenames) == 0:
        raise ValueError('create_dataset received no TFRecord files; check the dataset path.')

    # The experimental alias is available across TF 2.x releases.
    autotune = tf.data.experimental.AUTOTUNE

    # Shuffle at file level first, then at sample level after parsing.
    filenames = tf.random.shuffle(filenames)
    raw_dataset = tf.data.TFRecordDataset(filenames)

    dataset = raw_dataset.map(_parse_function, num_parallel_calls=autotune)
    dataset = dataset.shuffle(shuffle_buffer)
    dataset = dataset.repeat()
    dataset = dataset.batch(batch_size)
    # Overlap input preparation with model execution.
    dataset = dataset.prefetch(autotune)

    return dataset

In [None]:
BATCH_SIZE = 4

# Ceiling division: the number of batches needed to cover every file once.
STEPS_PER_EPOCH = (TRAIN_LENGTH + BATCH_SIZE - 1) // BATCH_SIZE
VAL_STEPS = (VAL_LENGTH + BATCH_SIZE - 1) // BATCH_SIZE

# Same pipeline for both splits; only the file list differs.
train_dataset = create_dataset(filenames=train_dataset_files, batch_size=BATCH_SIZE)
val_dataset = create_dataset(filenames=val_dataset_files, batch_size=BATCH_SIZE)

# Logging

In [None]:
def plot_to_image(figure):
    """Render a matplotlib figure to a batched RGBA image tensor.

    Args:
        figure: the matplotlib figure to render. It is closed by this call and
            must not be used afterwards.

    Returns:
        uint8 tensor of shape (1, height, width, 4) holding the PNG rendering.
    """
    # Save the plot to a PNG in memory. Saving through the figure itself
    # (rather than plt.savefig) guarantees the supplied figure is the one
    # rendered, even if another figure is currently active.
    buf = io.BytesIO()
    figure.savefig(buf, format='png')
    # Closing the figure prevents it from being displayed directly inside
    # the notebook.
    plt.close(figure)
    buf.seek(0)
    # Convert PNG buffer to TF image
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    # Add the batch dimension
    return tf.expand_dims(image, 0)


def scaled_mse_loss(y_actual, y_pred):
    """Sum of squared differences between `y_actual` and `y_pred`.

    Despite the name, the result is neither averaged nor rescaled: it is the
    sum over every element, returned as a scalar tensor.
    """
    # Plain TensorFlow ops are used instead of the Keras backend alias so the
    # loss does not depend on a short global name that other cells may rebind.
    return tf.reduce_sum(tf.square(y_actual - y_pred))

def plot_image_tensorboard(epoch, logs):
    """Log a prediction / ground-truth comparison image to TensorBoard.

    Intended as an `on_epoch_end` hook. Reads the notebook globals `model`,
    `val_sample` and `file_writer`.

    Args:
        epoch: epoch index, used as the summary step.
        logs: metrics dict passed by Keras; not used.
    """
    # First element of the stored validation batch: network input and its target.
    sample_input = val_sample[0][0]
    sample_target = val_sample[1][0]

    # Predict on a batch of one.
    pred = model.predict(sample_input[np.newaxis])

    figure = plt.figure(figsize=(10,10))
    panels = (
        ("prediction", pred[0].astype(np.float32)),
        ("ground truth", sample_target.numpy().astype(np.float32)),
    )
    # Side-by-side panels: prediction on the left, ground truth on the right.
    for position, (panel_title, panel_image) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(panel_title)
        plt.imshow(panel_image)

    plot_image = plot_to_image(figure)
    with file_writer.as_default():
        tf.summary.image("Prediction vs Ground Truth", plot_image, step=epoch)

import datetime

# TensorBoard run directory: logs/<model_description>-fit<timestamp>
log_dir = (
    os.path.join('logs', model_description + '-fit')
    + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
)
# NOTE(review): profile_batch is presumably set this high so the profiled batch is never reached — confirm.
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
    profile_batch=1000000,
)
# Separate writer used by plot_image_tensorboard for its image summaries.
file_writer = tf.summary.create_file_writer(log_dir)

plot_image_tensorboard_cb = keras.callbacks.LambdaCallback(on_epoch_end=plot_image_tensorboard)

# Periodic weight snapshots, written every 10 * STEPS_PER_EPOCH batches.
checkpoint_cb = ModelCheckpoint(
    model_weights_path + '.e{epoch:03d}',
    monitor='val_loss',
    verbose=0,
    save_best_only=False,
    save_weights_only=True,
    mode='auto',
    save_freq=10*STEPS_PER_EPOCH,
)
# Weights of the best epoch so far according to val_loss.
checkpoint_best_cb = ModelCheckpoint(
    model_weights_path + '.best',
    monitor='val_loss',
    verbose=0,
    save_best_only=True,
    save_weights_only=True,
    mode='auto',
)

# Training

In [None]:
def SSIMLoss(y_true, y_pred):
    """Return 1 minus the mean SSIM between `y_true` and `y_pred`.

    A trailing axis of size 1 is appended to both inputs before they are
    passed to `tf.image.ssim`, which is called with max_val = 1.0.
    """
    true_with_channel = tf.expand_dims(y_true, axis=-1)
    pred_with_channel = tf.expand_dims(y_pred, axis=-1)

    ssim_values = tf.image.ssim(true_with_channel, pred_with_channel, 1.0)
    return 1 - tf.reduce_mean(ssim_values)

In [None]:
# NOTE(review): absolute, user-specific paths — they will not resolve on other machines.
comps_path = '/home/rshuai/research/u-net-reconstruction/data/PSFs/processed/rank24_z0_ds2_normalized/comps.npy'
weights_path = '/home/rshuai/research/u-net-reconstruction/data/PSFs/processed/rank24_z0_ds2_normalized/weights.npy'

# Load in comps and weights
h = np.load(comps_path)
weights = np.load(weights_path)

# Weighted sum of the components along axis 0, then scaled so the peak value is 1.
psf = tf.squeeze(tf.math.reduce_sum(h*weights, axis=0))
psf = psf / tf.math.reduce_max(psf)
# NOTE(review): this rebinds `K`, which the import cell bound to the Keras backend
# module (`from tensorflow.keras import backend as K`). Once this cell has run,
# any code calling K.<function> fails with AttributeError.
K = 1

In [None]:
# PSF stack for the multi-Wiener U-Net, moved from (9, 648, 486) to channels-last.
registered_psfs_path = '/home/rshuai/research/u-net-reconstruction/data/PSFs/9_psfs/psfs_Z1_1_9_registered.npy'

psfs = np.transpose(np.load(registered_psfs_path), (1, 2, 0))
assert psfs.shape == (648, 486, 9)

# One initial value of 1.0 per PSF, shaped to broadcast over the spatial axes.
Ks = np.ones(shape=(1, 1, 9))

In [None]:
# The PSF at index 4 along the last axis is the central one.
psf = psfs[..., 4]

In [None]:
# Downsample PSFs for downsampled inputs
# psfs = tf.image.resize(psfs, [648//2, 486//2])

In [None]:
# Build the multi-Wiener model for the sample size in `obj_dims`, using the
# registered PSF stack and the initial K values.
model = UNet_multiwiener_resize(*obj_dims, psfs, Ks,
                         encoding_cs=[24, 64, 128, 256, 512, 1024],
                         center_cs=1024,
                         decoding_cs=[512, 256, 128, 64, 24, 24],
                         skip_connections=[True, True, True, True, True, False])

# Alternative architectures (same channel configuration):
# model = UNet_wiener(648, 486, psf, K, 
#                          encoding_cs=[24, 64, 128, 256, 512, 1024],
#                          center_cs=1024,
#                          decoding_cs=[512, 256, 128, 64, 24, 24],
#                          skip_connections=[True, True, True, True, True, False])

# model = UNet(648, 486,
#                  encoding_cs=[24, 64, 128, 256, 512, 1024],
#                  center_cs=1024,
#                  decoding_cs=[512, 256, 128, 64, 24, 24],
#                  skip_connections=[True, True, True, True, True, False])


initial_learning_rate = 1e-4
adam = tf.keras.optimizers.Adam(learning_rate=initial_learning_rate, beta_1=0.9, beta_2=0.999, amsgrad=False)

# `metrics` is documented as a list of metrics; a bare callable is rejected by
# newer Keras releases, so wrap it.
model.compile(optimizer=adam, loss=SSIMLoss, metrics=[SSIMLoss])
# input_shape is (None, 648, 486, 1): batch axis, spatial size, one channel.
model.build(input_shape)

model.summary()

In [None]:
# Upper bound on the number of training epochs passed to model.fit.
EPOCHS = 450

In [None]:
# Grab one validation batch; the per-epoch TensorBoard image callback predicts on its first element.
# create_dataset shuffles its data, so the chosen sample differs between runs.
val_sample = next(iter(val_dataset)) # Used for logging and plotting.

In [None]:
# Stop when val_loss has not improved by at least 1e-4 for 15 consecutive epochs.
# The weights from the final epoch are kept (no rollback to the best epoch).
earlystop_cb = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    mode='min',
    min_delta=0.0001,
    patience=15,
    baseline=None,
    restore_best_weights=False,
    verbose=0,
)

In [None]:
# Train with TensorBoard logging, periodic/best checkpointing and early stopping.
# Both datasets repeat indefinitely, so the per-epoch step counts are passed explicitly.
model.fit(train_dataset, epochs=EPOCHS, steps_per_epoch=STEPS_PER_EPOCH, 
          callbacks=[plot_image_tensorboard_cb, tensorboard_callback, checkpoint_cb, checkpoint_best_cb, earlystop_cb], 
          validation_data=val_dataset, validation_steps=VAL_STEPS)

# Store the weights as they are when training ends, under the un-suffixed prefix.
model.save_weights(model_weights_path, save_format='tf')

# Prediction Visualization, Timing Tests

In [None]:
# Number of samples to run through the model and plot for each split.
NUM_DISPLAY = 15

In [None]:
# Re-create the architecture that was trained, then restore a saved checkpoint.
#
# Other architectures used previously:
#   UNet_2d()
#   UNet_2d_wiener(np.zeros_like(psf), K)
#   UNet_2d_wiener_components(comps, K)
#   UNet_wiener(648, 486, psf, K, <same channel / skip settings as below>)
#   UNet(648, 486, <same channel / skip settings as below>)
model = UNet_multiwiener_resize(
    648, 486, psfs, Ks,
    encoding_cs=[24, 64, 128, 256, 512, 1024],
    center_cs=1024,
    decoding_cs=[512, 256, 128, 64, 24, 24],
    skip_connections=[True, True, True, True, True, False],
)

# Checkpoint suffix to restore; 'best' is the file written by checkpoint_best_cb.
epoch = 'best'
model.load_weights('{}.{}'.format(model_weights_path, epoch))

In [None]:
import time

In [None]:
%%javascript
// Override the notebook's output-area scroll check so it always answers "no";
// tall outputs (the large figures below) are then shown in full.
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

In [None]:
# One (648, 486) slot per displayed sample for predictions, targets and inputs.
# Sizes used previously: (640, 480) and (208, 124) for preds / ground_truths,
# and a 6-channel (648, 486, 6) layout for sims.
preds, ground_truths, sims = (
    np.zeros((NUM_DISPLAY, 648, 486)) for _ in range(3)
)



In [None]:
# With downsampled
# preds = np.zeros((NUM_DISPLAY, 320, 240))
# ground_truths = np.zeros((NUM_DISPLAY, 320, 240))
# sims = np.zeros((NUM_DISPLAY, 648//2, 486//2))



### Train

In [None]:
# Run single-sample inference on NUM_DISPLAY samples drawn from train_dataset,
# keeping inputs, targets and predictions for the plots below.
# The timed interval also includes reading and parsing the samples.
t0 = time.perf_counter()
i = 0
for sim, plane in train_dataset.unbatch():
    preds[i] = model.predict(sim[np.newaxis])  # add a batch axis of size 1
    ground_truths[i] = plane
    sims[i] = np.squeeze(sim)  # remove size-1 axes (the channel axis added in _parse_function)
    i += 1
    # The dataset repeats forever, so stop explicitly once the buffers are full.
    if i == NUM_DISPLAY:
        break
        
# assert(i == VAL_LENGTH)

t1 = time.perf_counter()

print('Prediction time per sample:', (t1 - t0) / NUM_DISPLAY, 's')

In [None]:
# One row per sample: prediction on the left, ground truth on the right.
rows, columns = NUM_DISPLAY, 2
fig = plt.figure(figsize=(30, 15 * NUM_DISPLAY))
for i in range(NUM_DISPLAY):
    # Left column: model output, clipped to the displayable [0, 1] range.
    s = fig.add_subplot(rows, columns, columns*i + 1)
    s.set_title("prediction {0}".format(i), size=20)
    s.imshow(np.clip(preds[i], 0, 1))

    # Right column: the corresponding target.
    s = fig.add_subplot(rows, columns, columns*i + 2)
    s.set_title("ground truth {0}".format(i), size=20)
    s.imshow(ground_truths[i])

fig.savefig(path.join(image_path, 'train_{}.pdf'.format(epoch)))

### Validation

In [None]:
# Run single-sample inference on NUM_DISPLAY samples drawn from val_dataset,
# overwriting the buffers filled by the training-split run.
# The timed interval also includes reading and parsing the samples.
t0 = time.perf_counter()
i = 0
for sim, plane in val_dataset.unbatch():
    preds[i] = model.predict(sim[np.newaxis])  # add a batch axis of size 1
    ground_truths[i] = plane
    sims[i] = np.squeeze(sim)  # remove size-1 axes (the channel axis added in _parse_function)
    i += 1
    # The dataset repeats forever, so stop explicitly once the buffers are full.
    if i == NUM_DISPLAY:
        break
        
# assert(i == VAL_LENGTH)

t1 = time.perf_counter()

print('Prediction time per sample:', (t1 - t0) / NUM_DISPLAY, 's')

In [None]:
# One row per sample: prediction on the left, ground truth on the right.
rows, columns = NUM_DISPLAY, 2
fig = plt.figure(figsize=(30, 15 * NUM_DISPLAY))
for i in range(NUM_DISPLAY):
    # Left column: model output, clipped to the displayable [0, 1] range.
    s = fig.add_subplot(rows, columns, columns*i + 1)
    s.set_title("prediction {0}".format(i), size=20)
    s.imshow(np.clip(preds[i], 0, 1))

    # Right column: the corresponding target.
    s = fig.add_subplot(rows, columns, columns*i + 2)
    s.set_title("ground truth {0}".format(i), size=20)
    s.imshow(ground_truths[i])

fig.savefig(path.join(image_path, 'validation_{}.pdf'.format(epoch)))

In [None]:
# Display the first stored ground-truth image on its own.
plt.imshow(ground_truths[0])