In [1]:
import pickle
from model import *
from utils import *
from config import config, log_config
from scipy.io import loadmat, savemat

  from ._conv import register_converters as _register_converters


In [4]:
def _report(logger, message):
    """Write `message` to `logger` at DEBUG level and echo it to stdout."""
    logger.debug(message)
    print(message)


def _image_grid(n_images):
    """Return a ``[rows, cols]`` tile layout with ``rows * cols == n_images``.

    ``rows`` is the largest divisor of `n_images` that is <= sqrt(n_images), so the
    layout is as square as possible with rows <= cols (e.g. 50 -> [5, 10], the
    layout this notebook used to hard-code). Deriving it from the batch size keeps
    the saved grids valid when ``config.TRAIN.sample_size`` is not 50.

    Raises:
        ValueError: if `n_images` is not positive.
    """
    if n_images <= 0:
        raise ValueError("cannot lay out {} images in a grid".format(n_images))
    rows = 1
    divisor = 1
    while divisor * divisor <= n_images:
        if n_images % divisor == 0:
            rows = divisor
        divisor += 1
    return [rows, n_images // rows]


def _load_mask(mask_name, mask_perc):
    """Load the sampling mask `mask_name` at `mask_perc` percent from its .mat file.

    Args:
        mask_name: one of "gaussian2d", "gaussian1d", "poisson2d".
        mask_perc: percentage used in the mask file name.

    Returns:
        The array stored under the mask's variable name in the .mat file.

    Raises:
        ValueError: for an unknown `mask_name`.
    """
    if mask_name == "gaussian2d":
        mask_dir = config.TRAIN.mask_Gaussian2D_path
        file_pattern, mat_key = "GaussianDistribution2DMask_{}.mat", 'maskRS2'
    elif mask_name == "gaussian1d":
        mask_dir = config.TRAIN.mask_Gaussian1D_path
        file_pattern, mat_key = "GaussianDistribution1DMask_{}.mat", 'maskRS1'
    elif mask_name == "poisson2d":
        # Poisson masks were previously read from the Gaussian1D directory (looks like a
        # copy-paste slip). Prefer a dedicated `mask_Poisson2D_path` when the config
        # defines one; otherwise fall back to the old location so nothing breaks.
        mask_dir = getattr(config.TRAIN, 'mask_Poisson2D_path', config.TRAIN.mask_Gaussian1D_path)
        file_pattern, mat_key = "PoissonDistributionMask_{}.mat", 'population_matrix'
    else:
        raise ValueError("no such mask exists: {}".format(mask_name))
    return loadmat(os.path.join(mask_dir, file_pattern.format(mask_perc)))[mat_key]


def _evaluate(sess, nmse_op, t_gen, t_image_good, predicted, target):
    """Return per-image ``(nmse, ssim, psnr)`` of `predicted` against `target`.

    NMSE is evaluated by the graph node `nmse_op`, feeding `predicted` into `t_gen`
    and `target` into `t_image_good`. SSIM/PSNR are computed per
    ``(target, predicted)`` pair with the `ssim`/`psnr` helpers via `threading_data`.
    Both arrays are expected in the (x + 1) / 2 rescaled space used by `main_test`.
    """
    nmse_res = sess.run(nmse_op, {t_gen: predicted, t_image_good: target})
    pairs = list(zip(target, predicted))
    ssim_res = threading_data(pairs, fn=ssim)
    psnr_res = threading_data(pairs, fn=psnr)
    return nmse_res, ssim_res, psnr_res


def main_test():
    """Evaluate the trained U-Net generator on a random batch of the test set.

    Steps:
      * loads the pickled target / degraded test arrays and the sampling mask;
      * restores generator weights from
        ``checkpoint_inference_<model>_<mask>_<perc>/unet.npz``;
      * runs inference on ``config.TRAIN.sample_size`` randomly chosen images;
      * reports NMSE / SSIM / PSNR (per image, mean, std) for the generated images
        and for the degraded inputs themselves (the "ZF" baseline in the logs),
        both to stdout and to the inference log;
      * saves image grids under ``samples_inference_<model>_<mask>_<perc>``.

    Raises:
        ValueError: for an unknown mask name.
        FileNotFoundError: if the checkpoint ``.npz`` file does not exist.
    """
    mask_perc = 10
    mask_name = "gaussian2d"
    model_name = "unet"

    # =================================== BASIC CONFIGS =================================== #

    print('[*] run basic configs ... ')

    run_tag = "{}_{}_{}".format(model_name, mask_name, mask_perc)

    log_dir = "log_inference_" + run_tag
    tl.files.exists_or_mkdir(log_dir)
    _, _, log_inference, _, _, _ = logging_setup(log_dir)

    checkpoint_dir = "checkpoint_inference_" + run_tag
    tl.files.exists_or_mkdir(checkpoint_dir)

    save_dir = "samples_inference_" + run_tag
    tl.files.exists_or_mkdir(save_dir)

    # configs
    sample_size = config.TRAIN.sample_size

    # ==================================== PREPARE DATA ==================================== #

    print('[*] load data ... ')
    testing_target_data_path = config.TRAIN.testing_target_data_path
    testing_blurry_data_path = config.TRAIN.testing_blurry_data_path

    # SECURITY: pickle.load can execute arbitrary code -- these config paths must
    # point at trusted files only.
    with open(testing_target_data_path, 'rb') as f:
        X_test_target = pickle.load(f)
    with open(testing_blurry_data_path, 'rb') as f:
        X_test_blurry = pickle.load(f)

    print('X_test_target shape/min/max: ', X_test_target.shape, X_test_target.min(), X_test_target.max())
    print('X_test_blurry shape/min/max: ', X_test_blurry.shape, X_test_blurry.min(), X_test_blurry.max())

    print('[*] loading mask ... ')
    # In this notebook the mask is only saved as an image; the degraded inputs come
    # from X_test_blurry rather than from applying this mask.
    mask = _load_mask(mask_name, mask_perc)

    # ==================================== DEFINE MODEL ==================================== #

    print('[*] define model ... ')

    nw, nh, nz = X_test_target.shape[1:]

    # define placeholders
    t_image_good = tf.placeholder('float32', [sample_size, nw, nh, nz], name='good_image')
    t_image_bad = tf.placeholder('float32', [sample_size, nw, nh, nz], name='bad_image')
    t_gen = tf.placeholder('float32', [sample_size, nw, nh, nz], name='generated_image')

    # define generator network
    net_test = u_net_bn(t_image_bad, is_train=False, reuse=False, is_refine=False)

    # Per-image relative L2 error ||gen - good|| / ||good||, reported as "NMSE".
    nmse_a_0_1 = tf.sqrt(tf.reduce_sum(tf.squared_difference(t_gen, t_image_good), axis=[1, 2, 3]))
    nmse_b_0_1 = tf.sqrt(tf.reduce_sum(tf.square(t_image_good), axis=[1, 2, 3]))
    nmse_0_1 = nmse_a_0_1 / nmse_b_0_1

    # ================================== SAMPLE TEST BATCH ================================== #

    sample_idx = tl.utils.get_random_int(min=0, max=len(X_test_target) - 1, number=sample_size,
                                         seed=config.TRAIN.seed)
    X_samples_good = X_test_target[sample_idx]
    X_samples_bad = X_test_blurry[sample_idx]

    # Rescale with (x + 1) / 2 for the metrics, assuming data nominally in [-1, 1].
    # NOTE(review): the printed minima (about -2.4) show the data is not strictly in
    # [-1, 1], so the rescaled values can fall below 0 -- confirm this is intended.
    x_good_sample_rescaled = (X_samples_good + 1) / 2
    x_bad_sample_rescaled = (X_samples_bad + 1) / 2

    grid = _image_grid(len(X_samples_good))
    # 10x-amplified, clipped |good - bad| difference; saved twice under different names.
    bad_diff_x10 = np.clip(10 * np.abs(X_samples_good - X_samples_bad) / 2, 0, 1)

    tl.visualize.save_images(X_samples_good, grid,
                             os.path.join(save_dir, "sample_image_good.png"))
    tl.visualize.save_images(X_samples_bad, grid,
                             os.path.join(save_dir, "sample_image_bad.png"))
    tl.visualize.save_images(np.abs(X_samples_good - X_samples_bad), grid,
                             os.path.join(save_dir, "sample_image_diff_abs.png"))
    tl.visualize.save_images(np.sqrt(np.abs(X_samples_good - X_samples_bad) / 2 + config.TRAIN.epsilon), grid,
                             os.path.join(save_dir, "sample_image_diff_sqrt_abs.png"))
    # File name kept for compatibility; the content has no sqrt (it is bad_diff_x10).
    tl.visualize.save_images(bad_diff_x10, grid,
                             os.path.join(save_dir, "sample_image_diff_sqrt_abs_10_clip.png"))
    tl.visualize.save_images(threading_data(X_samples_good, fn=distort_img), grid,
                             os.path.join(save_dir, "sample_image_aug.png"))
    # TODO: scipy.misc.imsave is deprecated (see the warning in this cell's output) and
    # removed in SciPy >= 1.2; switch to imageio.imwrite when upgrading SciPy.
    scipy.misc.imsave(os.path.join(save_dir, "mask.png"), mask * 255)

    # ==================================== INFERENCE ==================================== #

    weights_path = os.path.join(checkpoint_dir, 'unet') + '.npz'
    if not os.path.isfile(weights_path):
        # Fail early with a clear message instead of relying on how
        # load_and_assign_npz reports a missing file.
        raise FileNotFoundError("no trained weights found at {}".format(weights_path))

    # The context manager closes the session (releasing its resources) even on error.
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        tl.files.load_and_assign_npz(sess=sess, name=weights_path, network=net_test)

        print('[*] start testing ... ')

        x_gen = sess.run(net_test.outputs, {t_image_bad: X_samples_bad})
        x_gen_0_1 = (x_gen + 1) / 2

        # evaluation for generated data
        nmse_res, ssim_res, psnr_res = _evaluate(
            sess, nmse_0_1, t_gen, t_image_good, x_gen_0_1, x_good_sample_rescaled)

        _report(log_inference, "NMSE testing: {}\nSSIM testing: {}\nPSNR testing: {}\n\n".format(
            nmse_res, ssim_res, psnr_res))
        _report(log_inference,
                "NMSE testing average: {}\nSSIM testing average: {}\nPSNR testing average: {}\n\n".format(
                    np.mean(nmse_res), np.mean(ssim_res), np.mean(psnr_res)))
        _report(log_inference, "NMSE testing std: {}\nSSIM testing std: {}\nPSNR testing std: {}\n\n".format(
            np.std(nmse_res), np.std(ssim_res), np.std(psnr_res)))

        # Baseline: score the degraded network inputs themselves (labelled "ZF" in the logs).
        nmse_res_zf, ssim_res_zf, psnr_res_zf = _evaluate(
            sess, nmse_0_1, t_gen, t_image_good, x_bad_sample_rescaled, x_good_sample_rescaled)

        _report(log_inference, "NMSE ZF testing: {}\nSSIM ZF testing: {}\nPSNR ZF testing: {}\n\n".format(
            nmse_res_zf, ssim_res_zf, psnr_res_zf))
        _report(log_inference,
                "NMSE ZF average testing: {}\nSSIM ZF average testing: {}\nPSNR ZF average testing: {}\n\n".format(
                    np.mean(nmse_res_zf), np.mean(ssim_res_zf), np.mean(psnr_res_zf)))
        _report(log_inference,
                "NMSE ZF std testing: {}\nSSIM ZF std testing: {}\nPSNR ZF std testing: {}\n\n".format(
                    np.std(nmse_res_zf), np.std(ssim_res_zf), np.std(psnr_res_zf)))

    # sample testing images
    tl.visualize.save_images(x_gen, grid,
                             os.path.join(save_dir, "final_generated_image.png"))
    tl.visualize.save_images(np.clip(10 * np.abs(X_samples_good - x_gen) / 2, 0, 1), grid,
                             os.path.join(save_dir, "final_generated_image_diff_abs_10_clip.png"))
    tl.visualize.save_images(bad_diff_x10, grid,
                             os.path.join(save_dir, "final_bad_image_diff_abs_10_clip.png"))

    print("[*] Job finished!")

In [5]:
# Run the inference/evaluation pipeline: prints and logs NMSE/SSIM/PSNR and writes sample image grids.
main_test()

[*] run basic configs ... 
[!] log_inference_unet_gaussian2d_10 exists ...
[!] checkpoint_inference_unet_gaussian2d_10 exists ...
[!] samples_inference_unet_gaussian2d_10 exists ...
[*] load data ... 
X_test_target shape/min/max:  (7101, 320, 320, 1) -2.3643475 1.0
X_test_blurry shape/min/max:  (7101, 320, 320, 1) -2.3760571 1.0
[*] loading mask ... 
[*] define model ... 
  [TL] InputLayer  u_net/input: (50, 320, 320, 1)
  [TL] Conv2dLayer u_net/conv1: shape:[4, 4, 1, 64] strides:[1, 2, 2, 1] pad:SAME act:identity
  [TL] Conv2dLayer u_net/conv2: shape:[4, 4, 64, 128] strides:[1, 2, 2, 1] pad:SAME act:identity
  [TL] BatchNormLayer u_net/bn2: decay:0.900000 epsilon:0.000010 act:<lambda> is_train:False
  [TL] Conv2dLayer u_net/conv3: shape:[4, 4, 128, 256] strides:[1, 2, 2, 1] pad:SAME act:identity
  [TL] BatchNormLayer u_net/bn3: decay:0.900000 epsilon:0.000010 act:<lambda> is_train:False
  [TL] Conv2dLayer u_net/conv4: shape:[4, 4, 256, 512] strides:[1, 2, 2, 1] pad:SAME act:identity
 

`imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imwrite`` instead.


[*] start testing ... 


  cropped = ar[slices]


NMSE testing: [0.4845448  0.50792134 0.6306307  0.73995465 0.6351446  0.8985752
 0.6915286  1.0957898  0.87166363 1.128837   0.959284   0.60525024
 1.0555382  0.89963925 0.825426   0.6701039  0.76496637 1.0106449
 0.82135564 0.7902379  0.70933455 0.82118195 0.70437604 0.5051079
 0.879649   1.0167056  0.79714406 0.6287065  0.8421357  0.9720973
 0.40469018 0.5161609  0.51721096 0.7063881  0.8827013  0.86969614
 0.51896566 0.8140495  0.8185586  1.015555   1.1005418  0.78933877
 0.788019   0.6893532  1.0004158  0.7718212  0.9411407  0.7411088
 0.87945443 0.5856536 ]
SSIM testing: [0.08431448 0.04424626 0.08862281 0.14006152 0.19399762 0.17409586
 0.13420404 0.12618039 0.19431753 0.08338707 0.10464348 0.23931321
 0.07810366 0.15650102 0.25786438 0.17579347 0.18178908 0.09370698
 0.10316552 0.20746282 0.16044188 0.14436687 0.16332029 0.12972883
 0.12141533 0.08404027 0.07203863 0.10376154 0.19170093 0.11018432
 0.06220676 0.12114455 0.3679927  0.07729224 0.19851422 0.09161095
 0.08392531 0.1