In [32]:
import os
import glob

import numpy as np
import scipy.io as sio
from PIL import Image
import matplotlib.pyplot as plt
import tensorflow as tf

import time

from src.new_model import *

In [33]:
# 2-D shape of each 'plane' / 'sim' feature stored in the TFRecords.
obj_dims = (648, 486)
BATCH_SIZE = 4
# Indices into the sorted list of test files selecting the images to time.
best_is = [265, 442, 480]

test_files = np.array(sorted(glob.glob('../data/main_dataset/test/*')))
output_dir = '../results/'
best_images_dir = '../results/best_images'
res_target_path = '../data/real-data/resTargetZ_1.mat'

# Whether to save ground truths to results directory (set False to save time).
save_gts_flag = False

# Fail fast with a clear message (e.g. wrong working directory / missing data)
# instead of an opaque IndexError when test_files[best_is] is evaluated later.
assert len(test_files) > max(best_is, default=-1), (
    'found {} test files; need more than {} to index best_is'.format(
        len(test_files), max(best_is, default=-1)))

# Preprocess Resolution Target

In [34]:
# Load variable 'b' from the resolution-target .mat file, rescale it with
# normalize() and transpose it (presumably so its shape matches obj_dims -- TODO confirm).
res = normalize(sio.loadmat(res_target_path)['b']).T
# Cast to float16 and add a leading batch axis and a trailing channel axis.
res = res.astype(np.float16)[None, ..., None]

# Create Dataset

In [35]:
def _min_max_normalize(x):
    """
    Linearly rescale tensor x so its minimum maps to 0 and its maximum to 1.

    A constant input (max == min) yields all zeros instead of NaNs.
    """
    x_min = tf.reduce_min(x)
    x_max = tf.reduce_max(x)
    # divide_no_nan returns 0 where the denominator is 0, so a constant input
    # does not produce 0/0 = NaN (which would poison any downstream use).
    return tf.math.divide_no_nan(x - x_min, x_max - x_min)


def _parse_function(example_proto):
    """
    Parse one serialized TFRecord example into a (sim, plane) pair.

    Both 'plane' and 'sim' are read as float32 features of shape obj_dims and
    are min-max normalized to [0, 1] independently of each other.

    Inputs:
        - example_proto: scalar string tensor holding one serialized example
    Returns:
        - sim: normalized simulated image with a trailing channel axis
          (shape obj_dims + (1,))
        - plane: normalized plane (shape obj_dims)
    """
    feature_description = {
        'plane': tf.io.FixedLenFeature(obj_dims, tf.float32),
        'sim': tf.io.FixedLenFeature(obj_dims, tf.float32),
    }
    example = tf.io.parse_single_example(example_proto, feature_description)

    plane = _min_max_normalize(example['plane'])
    # Expand to channel dimension
    sim = _min_max_normalize(example['sim'])[..., tf.newaxis]

    return sim, plane

def create_dataset(filenames, batch_size):
    """
    Build a batched dataset of parsed (sim, plane) pairs from TFRecord files.

    Inputs:
        - filenames: string array of paths to TFRecord files containing samples
        - batch_size: number of consecutive examples grouped into each batch
    Returns:
        - tf.data.Dataset yielding batches of (sim, plane); the last batch may
          be smaller than batch_size
    """
    return (
        tf.data.TFRecordDataset(filenames)
        .map(_parse_function)
        .batch(batch_size)
    )

In [36]:
test_dataset = create_dataset(filenames=test_files[best_is], batch_size=BATCH_SIZE)  # only the selected test files

In [37]:
# Unbatch the dataset into individual simulated inputs; planes are not needed for timing.
# TFRecordDataset reads files in the order given, so sims[k] lines up with best_is[k]
# provided each TFRecord file holds exactly one example -- the assertion checks the count.
sims = [sim for sim, _ in test_dataset.unbatch()]
assert len(sims) == len(best_is), (
    'expected one example per file ({} files) but got {} examples'.format(
        len(best_is), len(sims)))

# Timing Tests

In [46]:
def time_model(model, sims, best_is, n, n_warmup=10):
    """
    Time test on each simulated best image.

    For every image, prints the mean wall-clock time of one prediction
    (including conversion of the output to NumPy), averaged over n runs.

    Inputs:
        - model: model to test; called as model(batch) on a batch of one image
        - sims: simulated images, each without a batch dimension
        - best_is: index of simulated image, in order of sims (used only in
          the printed labels)
        - n: number of predictions to average times over (must be >= 1)
        - n_warmup: number of untimed dummy predictions run first, so one-time
          start-up cost is excluded from the measurements
    Raises:
        - ValueError: if n < 1, or if sims and best_is differ in length
    """
    if n < 1:
        raise ValueError('n must be at least 1, got {}'.format(n))
    if len(sims) != len(best_is):
        raise ValueError('sims ({}) and best_is ({}) must have the same length'
                         .format(len(sims), len(best_is)))
    if len(sims) == 0:
        return

    # Predict on dummies to warm up GPU. The dummy has the same shape and dtype
    # as the real inputs so warm-up exercises the same input signature.
    warmup_batch = np.zeros_like(sims[0])[None, ...]
    for _ in range(n_warmup):
        model(warmup_batch)

    for sim, label in zip(sims, best_is):
        batch = sim[None, ...]  # add the batch axis once, outside the timed loop
        t0 = time.perf_counter()  # monotonic, high-resolution interval clock
        for _ in range(n):
            # Converting the output to NumPy presumably blocks until the
            # (possibly asynchronous) prediction has actually finished.
            np.clip(model(batch), 0, 1)
        elapsed = time.perf_counter() - t0
        print('sim_{}: {}s per prediction'.format(label, elapsed / n))
        
def time_model_res_target(model, res_target, n, n_warmup=50):
    """
    Time test on the res_target.

    Prints the mean wall-clock time of one prediction on res_target
    (including conversion of the output to NumPy), averaged over n runs.

    Inputs:
        - model: model to test
        - res_target: preprocessed res_target; passed to the model as-is, so
          it must already include the batch dimension
        - n: number of predictions to average times over (must be >= 1)
        - n_warmup: number of untimed dummy predictions run first, so one-time
          start-up cost is excluded from the measurement
    Raises:
        - ValueError: if n < 1
    """
    if n < 1:
        raise ValueError('n must be at least 1, got {}'.format(n))

    # Predict on dummies to warm up GPU. The dummy is derived from res_target
    # itself (not from a notebook global), so it matches its shape and dtype.
    warmup_batch = np.zeros_like(res_target)
    for _ in range(n_warmup):
        model(warmup_batch)

    t0 = time.perf_counter()  # monotonic, high-resolution interval clock
    for _ in range(n):
        # Converting the output to NumPy presumably blocks until the
        # (possibly asynchronous) prediction has actually finished.
        np.clip(model(res_target), 0, 1)
    elapsed = time.perf_counter() - t0
    print('res_target: {}s per prediction'.format(elapsed / n))

## UNet, 9 Learnable Wiener Deconvolutions

In [47]:
# PSFs and Ks are placeholder zeros; the trained values are presumably restored
# from the checkpoint by load_weights below -- TODO confirm.
# The last axis (9) is the number of Wiener deconvolutions, per the section title.
# Image dims come from obj_dims so the model always matches the dataset features.
psfs = np.zeros((*obj_dims, 9))
Ks = np.zeros((1, 1, 9))
model = UNet_multiwiener_resize(*obj_dims, psfs, Ks,
                                encoding_cs=[24, 64, 128, 256, 512, 1024],
                                center_cs=1024,
                                decoding_cs=[512, 256, 128, 64, 24, 24],
                                skip_connections=[True, True, True, True, True, False])

model.load_weights('../models/model_final_5/model_final_5.best')

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fb9106fce80>

In [48]:
time_model(model, sims, best_is, n=100)  # mean over 100 timed predictions per image

sim_265: 0.03986137628555298s per prediction
sim_442: 0.03904870510101319s per prediction
sim_480: 0.039069685935974124s per prediction


In [49]:
time_model_res_target(model, res, n=100)  # mean over 100 timed predictions

res_target: 0.041536185741424564s per prediction


## UNet, 1 Learnable Wiener Deconvolution

In [50]:
# psf and K are placeholders; the trained values are presumably restored from
# the checkpoint by load_weights below -- TODO confirm.
# Image dims come from obj_dims so the model always matches the dataset features.
psf = np.zeros(obj_dims)
K = 0
model = UNet_wiener(*obj_dims, psf, K,
                    encoding_cs=[24, 64, 128, 256, 512, 1024],
                    center_cs=1024,
                    decoding_cs=[512, 256, 128, 64, 24, 24],
                    skip_connections=[True, True, True, True, True, False])

model.load_weights('../models/model_final_2/model_final_2.best')

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fb97e543208>

In [51]:
time_model(model, sims, best_is, n=100)  # mean over 100 timed predictions per image

sim_265: 0.031772370338439944s per prediction
sim_442: 0.030518879890441896s per prediction
sim_480: 0.030229434967041016s per prediction


In [52]:
time_model_res_target(model, res, n=100)  # mean over 100 timed predictions

res_target: 0.03144089221954346s per prediction


## Basic UNet

In [53]:
# Image dims come from obj_dims so the model always matches the dataset features.
model = UNet(*obj_dims,
             encoding_cs=[24, 64, 128, 256, 512, 1024],
             center_cs=1024,
             decoding_cs=[512, 256, 128, 64, 24, 24],
             skip_connections=[True, True, True, True, True, False])

model.load_weights('../models/model_final_6/model_final_6.best')

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fb909b25fd0>

In [54]:
time_model(model, sims, best_is, n=100)  # mean over 100 timed predictions per image

sim_265: 0.02665058135986328s per prediction
sim_442: 0.026938197612762452s per prediction
sim_480: 0.0268117094039917s per prediction


In [55]:
time_model_res_target(model, res, n=100)  # mean over 100 timed predictions

res_target: 0.027177495956420897s per prediction
