## Training a 2D model
Define the parameters, import the dependencies and configure the GPUs.

In [None]:
# Notebook settings -- edit these before running the cells below.
config_path = "config_2D_test.yml"  # YAML file with the data / model / callbacks sections
GPU_idx = "0"  # exported as CUDA_VISIBLE_DEVICES in the next cell
GPU_max_memory = 22_000  # memory_limit for the logical GPU (MB, per the TF docs)

In [None]:
# Imports and environment setup for the training run.
import yaml
import csv
import os
# Restrict the visible GPU(s) before TensorFlow is imported so that CUDA only
# exposes the device chosen in the parameter cell. Keep this line above the
# tensorflow import.
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_idx
import numpy as np
import matplotlib.pyplot as plt
import shutil
import tensorflow as tf
# NOTE(review): the callbacks come from the standalone `keras` package while
# Input/optimizers come from `tensorflow.keras`; confirm both resolve to the
# same Keras installation in this environment.
from keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.layers import Input
from tensorflow import keras
from data_generator import data_generator_2D
from UNet_RCAN_2D import UNet_RCAN
from loss import custom_loss_with_l2_reg
# Seeds TensorFlow's RNG only; NumPy's RNG (used later to pick the samples
# that are plotted) is not seeded in this notebook.
tf.random.set_seed(42)
print(tf.__version__)
# IPython shell escapes reporting the CUDA toolkit version (not plain Python).
# NOTE(review): /usr/local/cuda/version.txt may not exist on newer CUDA
# toolkits (version.json instead) -- verify on this machine.
!nvcc --version
!cat /usr/local/cuda/version.txt

In [None]:
# Cap TensorFlow's memory on the first visible GPU to GPU_max_memory.
# A RuntimeError from TensorFlow (per the TF docs, raised when the GPU has
# already been initialized) is printed rather than propagated.
gpus = tf.config.list_physical_devices('GPU')
if len(gpus) > 0:
    try:
        memory_cap = [tf.config.LogicalDeviceConfiguration(memory_limit=GPU_max_memory)]
        tf.config.set_logical_device_configuration(gpus[0], memory_cap)
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        print(e)

### Data loading
Data is loaded and preprocessed (patches cropped, normalized).

The config file is read, a few training samples are displayed, and the config is printed out to double-check the parameters.

In [None]:
# read config file
with open(config_path, "r") as f:
    config = yaml.safe_load(f)

# Sections of the YAML config used throughout the notebook.
data_config = config['data']
model_config = config['model']
callback = config['callbacks']
x_train, y_train = data_generator_2D(data_config)

# Show a few random training pairs: noisy input (Low SNR) next to its target
# (High SNR). Sampling without replacement avoids showing one patch twice.
n_show = min(4, len(x_train))
ix = np.random.choice(len(x_train), size=n_show, replace=False)
cmap = plt.get_cmap('magma')
fig = plt.figure(figsize=(15, 7))

for i, sample in enumerate(ix):
    # Only channel 0 of each patch is displayed.
    fig.add_subplot(2, 4, 2 * i + 1)
    plt.imshow(x_train[sample, :, :, 0].squeeze(), cmap)
    plt.title('Low SNR', fontdict={'fontsize': 18})
    plt.axis('off')

    fig.add_subplot(2, 4, 2 * i + 2)
    plt.imshow(y_train[sample, :, :, 0].squeeze(), cmap)
    plt.title('High SNR', fontdict={'fontsize': 18})
    plt.axis('off')

# create model instance
# NOTE(review): model_input is not referenced anywhere else in this notebook;
# confirm whether it is still needed.
model_input = Input((data_config['patch_size'], data_config['patch_size'], 1))
model = UNet_RCAN(model_config)

# Enable gradient clipping only for a real clip value; False, null/None and 0
# all disable it. A membership test is needed because False == 0 in Python,
# so identity/equality checks against False and 0 cannot tell them apart.
clip_value = model_config.get('clip_value')
if clip_value not in (None, False, 0):
    optimizer = keras.optimizers.Adam(learning_rate=model_config['lr'], clipvalue=clip_value)
else:
    optimizer = keras.optimizers.Adam(learning_rate=model_config['lr'])

loss_fn = custom_loss_with_l2_reg(
    model,
    model_config['loss_type'],
    data_config['patch_size'],
    model_config['norm_factor'],
    model_config['edge_regularization'],
    model_config['l2_regularization'],
)
model.compile(optimizer=optimizer, loss=loss_fn)

# With val_split == 0 there is no validation data, so Keras logs no
# 'val_loss'; watch the training loss instead, otherwise early stopping,
# LR reduction and best-only checkpointing would never act.
monitor = 'loss' if callback['val_split'] == 0 else 'val_loss'
callbacks = [
    EarlyStopping(monitor=monitor, patience=callback['patience_stop'], verbose=1),
    ReduceLROnPlateau(monitor=monitor, factor=callback['factor_lr'], patience=callback['patience_lr']),
    ModelCheckpoint(filepath=model_config['save_dr'], monitor=monitor, verbose=1,
                    save_best_only=True, save_weights_only=True),
]

print(config)
print(model_config)

### Model training

In [None]:
# Create the output directory (config copy, loss history) before training.
os.makedirs(model_config['save_config'], exist_ok=True)

fit_options = {
    "batch_size": model_config['batch_size'],
    "epochs": model_config["n_epochs"],
    "verbose": 1,
    "callbacks": callbacks,
    "validation_split": callback["val_split"],
}
results = model.fit(x=x_train, y=y_train, **fit_options)

### Save results
The model weights are saved as h5, the loss history as CSV and the loss curve as PNG, and the config file is copied for documentation.

In [None]:
# Keep a copy of the configuration used for this run alongside its results.
config_copy_path = model_config['save_config'] + "/config.yml"
shutil.copyfile(config_path, config_copy_path)
    
# Save training history
def save_loss_plot(training_history, save_dir, img_format="png"):
    """Plot the per-epoch losses recorded during training and save the figure.

    The total training loss is drawn as a solid blue line. Every other
    training key ending in ``_loss`` (e.g. ``output_1_loss``) is drawn as a
    dashed blue line. When the history contains ``val_loss``, the validation
    counterparts are drawn in orange.

    Whether validation curves exist is decided from the history itself rather
    than from the global config, and per-output keys are discovered by name.
    Models with other output names or counts therefore do not raise KeyError.

    Args:
        training_history: object returned by ``model.fit``; only its
            ``history`` dict (metric name -> list of per-epoch values) is read.
        save_dir: directory the image is written to as ``loss.<img_format>``.
        img_format: file extension / format used for the saved image.
    """
    history = training_history.history
    has_val = "val_loss" in history
    train_color, val_color = "#1f77b4", "#ff8010"

    # Per-output losses, split into training and validation keys.
    train_parts = [k for k in history
                   if k.endswith("_loss") and k != "loss" and not k.startswith("val_")]
    val_parts = [k for k in history
                 if k.endswith("_loss") and k != "val_loss" and k.startswith("val_")]

    fig = plt.figure()
    x_ticks = range(1, len(history["loss"]) + 1)
    plt.plot(x_ticks, history["loss"], label="training", color=train_color)
    if has_val:
        plt.plot(x_ticks, history["val_loss"], label="validation", color=val_color)
    for key in train_parts:
        label = "train " + key if has_val else key
        plt.plot(x_ticks, history[key], "--", label=label, color=train_color)
    for key in val_parts:
        plt.plot(x_ticks, history[key], "--", label="val " + key[len("val_"):], color=val_color)
    plt.legend(loc="upper right")
    plt.xlabel("epochs")
    plt.ylabel("loss")
    fig.savefig(os.path.join(save_dir, "loss." + img_format))

def save_loss_txt(training_history, save_dir):
    """Write the training history to ``<save_dir>/loss.csv``.

    The header row holds the history keys (e.g. ``loss``, ``val_loss``). Each
    following row holds one epoch's values in the same column order.

    Args:
        training_history: object with a ``history`` dict mapping metric name
            to a list of per-epoch values (as returned by ``model.fit``).
        save_dir: existing directory in which ``loss.csv`` is created.
    """
    history = training_history.history
    # The csv module requires newline=""; without it, Windows writes "\r\r\n"
    # line endings, which show up as blank rows between epochs.
    csv_path = os.path.join(save_dir, "loss.csv")
    with open(csv_path, "w", newline="", encoding="utf-8") as outfile:
        writer = csv.writer(outfile)
        writer.writerow(history.keys())
        writer.writerows(zip(*history.values()))
    
# Persist the results, most valuable artifact first, so that a failure in a
# later step (e.g. plotting) cannot lose the trained weights or the loss history.
# NOTE(review): ModelCheckpoint writes the best-scoring weights to this same
# save_dr path (save_best_only=True). This call overwrites them with the
# weights from the last epoch. Confirm which set is meant to be kept.
model.save_weights(model_config["save_dr"])
save_loss_txt(results, model_config['save_config'])
save_loss_plot(results, model_config['save_config'], img_format="png")