In [None]:
import os
import numpy as np
import tensorflow as tf
from keras.models import Input
from tifffile import imwrite
from config import CFG
from datagenerator import data_generator, data_generator_test
from model import UNet_RCAN
import matplotlib.pyplot as plt

In [None]:
# Cap how much memory TensorFlow may claim on the first visible GPU by giving it a
# single logical device with an explicit memory limit (TF documents this limit in MB).
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    first_gpu = gpus[0]
    try:
        memory_cap = [tf.config.LogicalDeviceConfiguration(memory_limit=12000)]
        tf.config.set_logical_device_configuration(first_gpu, memory_cap)
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as err:
        # TF raises RuntimeError if the device configuration is changed after the GPUs
        # have been initialised (e.g. when this cell is re-run); report it and continue.
        print(err)

In [None]:
# Pull the test-data and model sections out of the shared project config.
data_config = CFG['data_test']
model_config = CFG['model']

if not data_config['train']:
    # Inference-only mode: only the (noisy) inputs are loaded, no ground truth.
    x_test = data_generator_test(data_config)
else:
    # Paired mode: inputs, a second array (`w_test`, not used in the cells shown here),
    # and the ground-truth targets `y_test`.
    x_test, w_test, y_test = data_generator(data_config)

In [None]:
# Build the network named in the config, load its trained weights, and run it on every
# test patch, keeping both output heads normalised to [0, 1].
input_size = data_config['patch_size']
# Spatial size of the `model_type` output head: the arrays below assume it is the input
# patch upscaled by data_config['scale'].
patch_size = int(input_size * data_config['scale'])

# Resolve the model constructor by name instead of eval()-ing a config string: eval
# would execute arbitrary code if the config were malformed or untrusted.
model_builder = globals().get(model_config['model_type'])
if not callable(model_builder):
    raise ValueError(f"Unknown model_type {model_config['model_type']!r}; it must name "
                     "a model constructor imported in this notebook.")
model_input = Input((input_size, input_size, 1))
model = model_builder(model_input, model_config)
# Dummy forward pass kept from the original; presumably ensures all weights exist before
# load_weights - TODO confirm it is required for UNet_RCAN.
model(np.zeros((1, input_size, input_size, 1)))
model.load_weights(model_config['save_dr'])


def _normalize_to_unit(image):
    """Return `image` with negatives set to 0, rescaled so its maximum is 1.

    Clipping happens before the division. An image with no positive values is
    returned as all zeros instead of being divided by a zero or negative maximum.
    Dividing first produced NaN, or sign-flipped values >= 1.
    """
    image = np.clip(image, 0, None)
    peak = image.max()
    return image / peak if peak > 0 else image


n_samples = len(x_test)
prediction1 = np.zeros((n_samples, input_size, input_size, 1))  # 'UNet' output head
prediction2 = np.zeros((n_samples, patch_size, patch_size, 1))  # `model_type` output head

for i in range(n_samples):
    # Batch of one; the outputs are indexed by head name.
    prediction = model(x_test[i:i + 1], training=False)
    # [0] drops the batch axis; float64 matches the precision of the result arrays.
    prediction1[i] = _normalize_to_unit(np.asarray(prediction['UNet'], dtype=np.float64)[0])
    prediction2[i] = _normalize_to_unit(
        np.asarray(prediction[model_config['model_type']], dtype=np.float64)[0])

In [None]:
# Visual spot-check: one random test sample, noisy input beside both model outputs.
ix = np.random.randint(len(prediction2))
panels = (
    (x_test, 'Low SNR Input'),
    (prediction1, 'Prediction-scaled'),
    (prediction2, 'Prediction'),
)
# A single row of roughly square images needs a wide, short canvas rather than a square one.
fig, axes = plt.subplots(1, len(panels), figsize=(40, 15))
for ax, (images, title) in zip(axes, panels):
    ax.imshow(images[ix, :, :, 0], cmap='magma')
    ax.set_title(title, fontdict={'fontsize': 20})
    ax.axis('off')
plt.show()

In [None]:
# Export results as 16-bit ImageJ TIFF stacks (one frame per test patch). Predictions
# and the noisy input are written in both modes; ground truth is added when it was loaded.
save_dir = data_config['save_dr']
os.makedirs(save_dir, exist_ok=True)


def _save_uint16_stack(images, filename):
    """Write a float image stack to `save_dir/filename` as a uint16 ImageJ 'TYX' TIFF.

    Assumes intensities in [0, 1] (true for the normalised predictions; TODO confirm
    for x_test / y_test). Values are clipped first so out-of-range data saturates
    instead of wrapping around on the integer cast. Only a trailing singleton
    channel axis is dropped, so a single-patch stack keeps its T axis.
    """
    stack = (np.clip(images, 0, 1) * (2 ** 16 - 1)).astype(np.uint16)
    if stack.ndim == 4 and stack.shape[-1] == 1:
        stack = stack[..., 0]
    imwrite(os.path.join(save_dir, filename), stack, imagej=True,
            metadata={'axes': 'TYX'})


_save_uint16_stack(prediction1, 'pred_scaled.tif')
_save_uint16_stack(prediction2, 'pred.tif')
_save_uint16_stack(x_test, 'noisy.tif')
if data_config['train']:
    _save_uint16_stack(y_test, 'gt.tif')