In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt
import tensorflow as tf
from keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from keras.models import Input
from tensorflow import keras
from config import CFG
from datagenerator import data_generator_3D
from model import UNet, RCAN, UNet_RCAN
from loss import loss
from evaluation_parameters import nmse_psnr_ssim_3D
from tifffile import imwrite

In [None]:
# Cap TensorFlow's allocation on the first visible GPU at 12000 MB (memory_limit is given
# in MB) instead of letting it claim the whole card.
gpus = tf.config.list_physical_devices('GPU')
if len(gpus) > 0:
    try:
        memory_cap = [tf.config.LogicalDeviceConfiguration(memory_limit=12000)]
        tf.config.set_logical_device_configuration(gpus[0], memory_cap)
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as config_error:
        # Raised e.g. when the GPU was already initialised by an earlier cell; report and go on.
        print(config_error)

In [None]:
# Pull the test-split data settings and the model settings out of the project-level CFG dict,
# then build the paired test arrays: x_test is shown as "Low SNR" input and y_test as the
# "High SNR" target throughout this notebook.
data_config, model_config = CFG['data_test'], CFG['model']
x_test, y_test = data_generator_3D(data_config)

In [None]:
# Maximum-intensity projections (MIPs) along axis 3 (presumably the frame/Z axis, matching
# the third spatial entry of the model input shape) for a quick visual sanity check.
mip_x_test = np.max(x_test, axis=3)
mip_y_test = np.max(y_test, axis=3)

# Number of low/high-SNR pairs to preview; the 2x4 grid below holds at most 4 pairs.
N_PREVIEW = 4
# Set to an int for a reproducible preview selection; None draws fresh samples each run.
PREVIEW_SEED = None

preview_rng = np.random.default_rng(PREVIEW_SEED)
# Distinct indices so no pair is shown twice, and never more than the test set holds.
ix = preview_rng.choice(len(x_test), size=min(N_PREVIEW, len(x_test)), replace=False)

# The colormap is loop-invariant, so look it up once.
cmap = plt.get_cmap('magma')
fig = plt.figure(figsize=(15, 7))
preview_pairs = ((mip_x_test, 'Low SNR'), (mip_y_test, 'High SNR'))

for i, sample in enumerate(ix):
    # Pair i occupies grid slots 2*i+1 (low SNR) and 2*i+2 (high SNR).
    for offset, (mips, title) in enumerate(preview_pairs, start=1):
        ax = fig.add_subplot(2, 4, 2 * i + offset)
        ax.imshow(mips[sample, :, :, 0], cmap=cmap)
        ax.set_title(title, fontdict={'fontsize': 18})
        ax.axis('off')

In [None]:
# Explicit name -> builder table instead of eval()-ing a string assembled from the config
# file: eval would execute whatever text the config contains.
MODEL_BUILDERS = {'UNet': UNet, 'RCAN': RCAN, 'UNet_RCAN': UNet_RCAN}
model_type = model_config['model_type']
if model_type not in MODEL_BUILDERS:
    raise ValueError(f"Unknown model_type {model_type!r}; expected one of {sorted(MODEL_BUILDERS)}")

# Per-sample input shape: (patch, patch, number of frames, 1 channel).
n_frames = data_config['fr_end'] - data_config['fr_start']
input_shape = (data_config['patch_size'], data_config['patch_size'], n_frames, 1)

model_input = Input(input_shape)
model = MODEL_BUILDERS[model_type](model_input, model_config)
# One dummy forward pass, presumably so every layer is built before the weights are restored.
model(np.zeros((1,) + input_shape))
model.load_weights(model_config['save_dr'])

prediction2 = np.zeros(x_test.shape)
for i in range(len(x_test)):
    # The model returns a mapping keyed by the model type; take that output's single batch item.
    outputs = model(x_test[i:i + 1], training=False)
    volume = np.asarray(outputs[model_type])[0]
    # Clip negatives first, then scale to a peak of 1. For peak > 0 this equals
    # "divide by max, then clip"; an all-zero/negative output stays 0 instead of NaN/inf.
    volume = np.clip(volume, 0, None)
    peak = volume.max()
    prediction2[i] = volume / peak if peak > 0 else volume

# MIPs along axis 3 for input, prediction and ground truth (used by the comparison plot).
mip_x_test = np.max(x_test, axis=3)
mip_prediction2 = np.max(prediction2, axis=3)
mip_y_test = np.max(y_test, axis=3)

In [None]:
# Show one random test sample: noisy input MIP, model prediction MIP, ground-truth MIP.
ix = np.random.randint(len(mip_prediction2))

# Title the prediction panel with the model actually loaded from the config,
# not a hardcoded architecture name.
panels = (
    (mip_x_test, 'Low SNR Input'),
    (mip_prediction2, f"Prediction by {model_config['model_type']}"),
    (mip_y_test, 'Ground Truth'),
)

fig, axes = plt.subplots(1, len(panels), figsize=(40, 40))
for ax, (mips, title) in zip(axes, panels):
    ax.imshow(mips[ix, :, :, 0], cmap='magma')
    ax.set_title(title, fontdict={'fontsize': 20})
    ax.axis('off')

In [None]:
# Image-quality table with one column per test sample. Rows 0/1 = NMSE, 2/3 = PSNR, 4/5 = SSIM.
# Even rows score the noisy input and odd rows the prediction, both against y_test.
# Assumes nmse_psnr_ssim_3D returns a (3, n_samples) array -- TODO confirm.
imageq_param = np.zeros((6, len(prediction2)))
imageq_param[0::2, :] = nmse_psnr_ssim_3D(x_test, y_test)
imageq_param[1::2, :] = nmse_psnr_ssim_3D(prediction2, y_test)

save_dir = data_config['save_dr']
os.makedirs(save_dir, exist_ok=True)

plt.rcParams.update({'font.size': 30})
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(30, 10))
labels = ['noisy', 'prediction']

# (panel title, first row of its noisy/prediction pair, y-axis upper limit)
# NOTE(review): the third panel says MS-SSIM while the CSV header says SSIM -- confirm
# which one nmse_psnr_ssim_3D computes.
metric_panels = (
    ('NMSE', 0, 1.2 * imageq_param[0:2].max()),
    ('PSNR', 2, 1.5 * imageq_param[2:4].max()),
    ('MS-SSIM', 4, 1),
)
for ax, (title, row, y_max) in zip(axes, metric_panels):
    ax.boxplot([imageq_param[row, :], imageq_param[row + 1, :]],
               vert=True,
               patch_artist=True,
               labels=labels, showfliers=False)
    ax.set_ylim(0, y_max)
    ax.set_title(title, fontsize=30)

# One CSV row per sample; column order matches the header.
np.savetxt(os.path.join(save_dir, 'eval_param.csv'), imageq_param.T,
           header="NMSE_noisy,NMSE_prediction,PSNR_noisy,PSNR_prediction,SSIM_noisy,SSIM_prediction", delimiter=",")
# Save before show(): the figure must still be open when it is written.
fig.savefig(os.path.join(save_dir, 'eval_param.tif'))
plt.show()
plt.close(fig)

In [None]:
def _to_uint16_tzyx(volumes):
    """Convert a (sample, H, W, frame, 1)-ordered float stack to a uint16 TZYX stack.

    Assumes values are intended to lie in [0, 1] and that the last axis is a single
    channel -- TODO confirm against data_generator_3D.

    Args:
        volumes: array whose axis 3 is moved to position 1 (frames after samples).

    Returns:
        uint16 array with the channel axis removed, scaled so 1.0 maps to 65535.
    """
    frames_first = np.moveaxis(volumes, 3, 1)
    # Clip first so out-of-range values saturate instead of overflowing the uint16 cast.
    scaled = np.clip(frames_first, 0, 1) * (2 ** 16 - 1)
    # Drop only the channel axis. A bare squeeze() would also drop T or Z when either has
    # length 1, and the array would then no longer match the 'TZYX' metadata.
    return np.squeeze(scaled.astype(np.uint16), axis=-1)


pred2 = _to_uint16_tzyx(prediction2)
noisy = _to_uint16_tzyx(x_test)
gt = _to_uint16_tzyx(y_test)

save_dir = data_config['save_dr']
os.makedirs(save_dir, exist_ok=True)
for file_name, stack in (('pred.tif', pred2), ('noisy.tif', noisy), ('gt.tif', gt)):
    imwrite(os.path.join(save_dir, file_name), stack, imagej=True, metadata={'axes': 'TZYX'})