In [None]:
import tensorflow as tf
import numpy as np
from keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from keras.models import Input
from tensorflow import keras
import matplotlib.pyplot as plt
from config_3D import CFG
from data_generator import data_generator_3D
from UNet_RCAN_3D import UNet_RCAN
from loss import loss_3D

In [None]:
# Expose the first physical GPU as a single logical device with a fixed
# memory limit (10000; the TensorFlow docs specify this value in MB).
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    primary_gpu = gpus[0]
    try:
        memory_cap = tf.config.LogicalDeviceConfiguration(memory_limit=10000)
        tf.config.set_logical_device_configuration(primary_gpu, [memory_cap])
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as err:
        # Report the problem and keep going instead of aborting the notebook.
        # NOTE(review): per the TensorFlow GPU guide this is raised when the
        # device was already initialised before this cell ran — confirm.
        print(err)

In [None]:
# Split the shared configuration into its data, model and callback sections.
data_config, model_config, callback = CFG['data'], CFG['model'], CFG['callbacks']

# Build the training pair from the data settings.
# NOTE(review): presumably x_train is the noisy input and y_train the clean
# target (they are plotted as 'Low SNR' / 'High SNR' below) — confirm against
# data_generator_3D.
x_train, y_train = data_generator_3D(data_config)

In [None]:
# Preview four randomly chosen training pairs as maximum-intensity projections.
# The projection collapses axis 3 of each array.
# NOTE(review): assumes the arrays are laid out as
# (sample, row, col, depth, channel) — confirm against data_generator_3D.
mip_x_train = np.max(x_train, axis=3)
mip_y_train = np.max(y_train, axis=3)

# Indices are drawn with replacement, so the same sample may appear twice.
ix = np.random.randint(0, len(x_train), 4)
fig = plt.figure(figsize=(15, 7))

# The colormap does not change between panels, so look it up once.
cmap = plt.get_cmap('magma')

# Each pair occupies two adjacent cells of the 2x4 grid: input, then target.
panels = (('Low SNR', mip_x_train), ('High SNR', mip_y_train))
for i, sample_idx in enumerate(ix):
    for offset, (title, projections) in enumerate(panels, start=1):
        ax = fig.add_subplot(2, 4, 2 * i + offset)
        ax.imshow(projections[sample_idx, :, :, 0].squeeze(), cmap)
        ax.set_title(title, fontdict={'fontsize': 18})
        ax.axis('off')

In [None]:
# Input spec: square patch, the selected frame range, and one trailing channel.
patch_size = data_config['patch_size']
n_frames = data_config['fr_end'] - data_config['fr_start']
model_input = Input((patch_size, patch_size, n_frames, 1))

# NOTE(review): model_input is not passed to anything in the visible cells —
# confirm whether UNet_RCAN is expected to receive it.
model = UNet_RCAN(model_config)

# Adam with the configured learning rate, trained against the project's 3D loss.
optimizer = keras.optimizers.Adam(learning_rate=model_config['lr'])
model.compile(optimizer=optimizer, loss=loss_3D)

# Training callbacks: stop early, lower the learning rate on a plateau, and
# checkpoint only the best weights. EarlyStopping and ModelCheckpoint are left
# on their default monitored quantity.
callbacks = [
    EarlyStopping(
        patience=callback['patience_stop'],
        verbose=1,
    ),
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=callback['factor_lr'],
        patience=callback['patience_lr'],
    ),
    ModelCheckpoint(
        filepath=model_config['save_dr'],
        verbose=1,
        save_best_only=True,
        save_weights_only=True,
    ),
]

In [None]:
# Number of leading training samples to fit on. None uses the whole dataset;
# set a small integer (e.g. 10) only for a quick smoke test. The previous
# hard-coded [0:10] slice left a single validation sample driving early
# stopping, learning-rate reduction and checkpoint selection.
N_TRAIN_SAMPLES = None

# validation_split holds out a fraction of the supplied samples for the
# 'val_loss' that the callbacks monitor.
results = model.fit(
    x=x_train[:N_TRAIN_SAMPLES],
    y=y_train[:N_TRAIN_SAMPLES],
    batch_size=model_config['batch_size'],
    epochs=model_config['n_epochs'],
    verbose=1,
    callbacks=callbacks,
    validation_split=0.1,
)