In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

# Cap the memory TensorFlow may claim on the first visible GPU.
# memory_limit=12000 is handed to LogicalDeviceConfiguration as-is.
gpus = tf.config.list_physical_devices(device_type='GPU')
if gpus:
    try:
        tf.config.set_logical_device_configuration(
            device=gpus[0],
            logical_devices=[
                tf.config.LogicalDeviceConfiguration(memory_limit=12000),
            ],
        )
        logical_gpus = tf.config.list_logical_devices(device_type='GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as err:
        # Best effort: report the configuration failure and carry on with
        # TensorFlow's default device setup.
        print(err)

from tensorflow import keras
from keras.models import Input
from keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau,LearningRateScheduler

from datagenerator_2D import data_generator
from Model_2D_RCAN import make_generator
from loss_2D import generator_loss

In [None]:
# --- Data / patch-extraction configuration -------------------------------
patch_size = 256   # forwarded to data_generator; also sets the model input height/width
n_patches = 16     # forwarded to data_generator
channel_n = 0      # forwarded to data_generator as its channel argument

augment = False
shuffle = True
add_noise = False
l_poisson = 1.235  # forwarded to data_generator as `lp`

threshold = 0.0
ratio = 1.0

GT_image_dr = r'D:\Projects\Denoising-STED\20220913-RPI\UNet-RCAN-different psnr\sequential\train\drift\dAverage.tif'
lowSNR_image_dr = r'D:\Projects\Denoising-STED\20220913-RPI\UNet-RCAN-different psnr\sequential\train\drift\d1frame.tif'

# The previous call wrote `n_channel=channel_n, threshold, ratio, ...`, i.e.
# positional arguments after a keyword argument, which Python rejects with
# "SyntaxError: positional argument follows keyword argument".
# channel_n is now passed positionally so threshold and ratio can follow it.
# NOTE(review): assumes `n_channel` is the 5th positional parameter of
# data_generator, directly followed by threshold and ratio -- confirm against
# datagenerator_2D.data_generator.
x_train, y_train, x_valid, y_valid = data_generator(
    GT_image_dr, lowSNR_image_dr, patch_size, n_patches,
    channel_n, threshold, ratio,
    lp=l_poisson, augment=augment, shuffle=shuffle, add_noise=add_noise)

In [None]:
# Show four randomly drawn training pairs side by side on a 2x4 grid:
# low-SNR input in the odd panels, matching high-SNR target in the even panels.
ix = np.random.randint(0, len(x_train), 4)
fig = plt.figure(figsize=(15, 7))
cmap = plt.get_cmap('magma')  # loop-invariant, so looked up once

for i in range(4):
    # (image stack, panel title, subplot position) for the two panels of pair i.
    panels = (
        (x_train, 'Low SNR', 2 * i + 1),
        (y_train, 'High SNR', 2 * i + 2),
    )
    for stack, title, position in panels:
        ax = fig.add_subplot(2, 4, position)
        # Only channel 0 of the selected sample is displayed.
        ax.imshow(stack[ix[i], :, :, 0].squeeze(), cmap=cmap)
        ax.set_title(title, fontdict={'fontsize': 18})
        ax.axis('off')

In [None]:
# --- Generator architecture hyper-parameters -----------------------------
filters = [32, 128, 256]

num_filters = filters[0]
r = 8
# Integer division keeps this an int (32 // 8 == 4). The previous `/` produced
# the float 4.0.
# NOTE(review): presumably used as a convolution channel count inside
# make_generator, which should be an integer -- confirm.
filters_cab = num_filters // r
num_RG = 5
num_RCAB = 5

# Single-channel square patches of side `patch_size`.
generator_input = Input((patch_size, patch_size, 1))
generator = make_generator(
    generator_input, filters, num_filters, filters_cab, num_RG, num_RCAB,
    kernel_shape=3, dropout=0.2)

In [None]:
# --- Training configuration ----------------------------------------------
# Despite its name, this is the path of the .h5 checkpoint file, not a folder.
model_save_directory = r"D:\Projects\Denoising-STED\20220913-RPI\UNet-RCAN-different psnr\sequential\mitochondria_lp_1.235.h5"
n_epochs = 200
gen_lr = 5e-5
batch_size = 1

gen_opt = keras.optimizers.Adam(learning_rate=gen_lr)
generator.compile(loss=generator_loss, optimizer=gen_opt)

# Callbacks are handed to fit() in this order:
#   1. early stopping with a patience of 50 epochs,
#   2. learning-rate reduction by a factor of 0.2 after 5 epochs without
#      improvement of val_loss,
#   3. checkpointing of the best weights only to `model_save_directory`.
callbacks = [
    EarlyStopping(verbose=1, patience=50),
    ReduceLROnPlateau(patience=5, factor=0.2, monitor='val_loss'),
    ModelCheckpoint(
        filepath=model_save_directory,
        save_weights_only=True,
        save_best_only=True,
        verbose=1,
    ),
]

In [None]:
# Optional: start from previously saved weights (uncomment both lines below).
# load_model_directory = r"D:\Projects\Denoising-STED\20220913-RPI\STED power dependence\tubulin\tubulin_STED70.h5" 
# generator.load_weights(load_model_directory)

In [None]:
# Fit the generator on the training arrays. Validation data comes from
# validation_split=0.1; the x_valid / y_valid arrays are not passed here.
results = generator.fit(
    x=x_train,
    y=y_train,
    epochs=n_epochs,
    batch_size=batch_size,
    validation_split=0.1,
    callbacks=callbacks,
    verbose=1,
)