### Training U-Net model for sparse-view artifact correction.

In [None]:
import math
import scipy.io
import random
import os
import sys
import time
import numpy as np
import matplotlib.pyplot as plt
import imshow_grid as ig
from datetime import datetime
from skimage.metrics import structural_similarity as ssim

import tensorflow as tf
from tensorflow.keras import regularizers
from tensorflow.keras import optimizers
from tensorflow.keras.optimizers import schedules
from tensorflow.keras.models import Sequential, Model, load_model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, LearningRateScheduler, TensorBoard
from tensorflow.keras.layers import Input, Conv2D, Conv3D, Conv3DTranspose, Lambda, Reshape, Add, MaxPooling2D, UpSampling2D, Subtract, Activation
from tensorflow.keras.layers import Concatenate
import tensorflow.keras.backend as K
from tensorboard import summary
from IPython import display
from IPython.display import clear_output

import ImportantFunctions as ImFunc

#### 1. Setup

In [None]:
# set GPU:
# NOTE: CUDA only honours these variables if they are set before TensorFlow
# initializes its GPU devices (i.e. before the first device query below).
# Replace "?" with the index (or comma-separated indices) of the GPU(s) to use.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "?"

# check GPU in-use:
print(tf.test.is_built_with_cuda())
gpus = tf.config.list_physical_devices('GPU')
print(gpus)
print(tf.test.gpu_device_name())
# tf.test.is_gpu_available() is deprecated; the recommended replacement is to
# check whether any physical GPU device is visible.
print(len(gpus) > 0)

In [None]:
# set parameters:
patience = 20          # EarlyStopping patience (epochs without val_loss improvement)
batch_size = 6
init_lr = 0.001        # initial learning rate handed to the model factory
num_epochs = 30
angle_list = [128]     # presumably the number(s) of projection angles -- confirm in ImFunc
N = 512                # image size (passed to the model as img_dim)
geometry = 'parallel'  # "parallel" or "fanflat" possible
tag = 'dualUnet'       # tag of the used U-Net variant

# Fail fast on a typo here instead of after data loading / model building.
if geometry not in ('parallel', 'fanflat'):
    raise ValueError(
        "geometry must be 'parallel' or 'fanflat', got {!r}".format(geometry)
    )

In [None]:
# set paths:
train_path, val_path, test_path = ???, ???, ???
checkpoint_dir, TB_logs_dir = ???, ???

if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

if not os.path.exists(TB_logs_dir):
    os.makedirs(TB_logs_dir)

In [None]:
# create paths for checkpoints and Tensorboard logs:
# The run name and timestamp are computed once so that both directories get
# exactly the same suffix. Two separate datetime.now() calls could straddle a
# minute boundary and produce mismatched checkpoint / log directory names.
run_name = "{}_{}_bs{}_lr{}_ep{}".format(tag, ImFunc.get_anglenames(angle_list), batch_size, init_lr, num_epochs)
run_timestamp = datetime.now().strftime("%Y_%m_%d__%H_%M")

full_checkpoint_path = os.path.join(checkpoint_dir, run_name, run_timestamp)
full_TB_logs_path = os.path.join(TB_logs_dir, run_name, run_timestamp)

os.makedirs(full_checkpoint_path, exist_ok=True)
os.makedirs(full_TB_logs_path, exist_ok=True)

In [None]:
# Event directory shared by the manual summary writer and the TensorBoard callback.
TB_event_dir = os.path.join(full_TB_logs_path, "logs")
file_writer = tf.summary.create_file_writer(TB_event_dir)

# Save the model after every epoch to <full_checkpoint_path>/<epoch>.
# save_best_only is left at its default (False), so monitor/mode do not
# filter which epochs are written.
checkpoint = ModelCheckpoint(filepath=os.path.join(full_checkpoint_path, '{epoch:d}'),
                             monitor='val_loss', verbose=0, save_freq='epoch',
                             mode='auto')

# Stop after `patience` epochs without val_loss improvement; the weights of the
# last epoch are kept (restore_best_weights=False).
earlystop = EarlyStopping(monitor='val_loss', min_delta=0, patience=patience,
                          verbose=1, mode='min', restore_best_weights=False)

# Weight histograms every epoch, batch metrics every 20 batches, no profiling.
tensorboard = TensorBoard(log_dir=TB_event_dir,
                          histogram_freq=1, write_graph=True, write_images=False,
                          update_freq=20, profile_batch=0, embeddings_freq=0,
                          embeddings_metadata=None)

# ImFunc.lr_scheduler is assumed to be a ready-made Keras callback instance -- confirm.
MyCallbacks = [checkpoint, earlystop, tensorboard, ImFunc.lr_scheduler]

#### 2. Generate data

In [None]:
# Batch generators for the training and validation splits, built by
# ImFunc.generate_batches_residual with identical geometry, batch size and
# angle settings (training first, then validation).
train_gen, val_gen = (
    ImFunc.generate_batches_residual(split_path, geometry, batch_size, angle_list)
    for split_path in (train_path, val_path)
)

# Steps per epoch for each split, as computed by ImFunc.get_number_of_steps:
steps_per_epoch, val_steps = (
    ImFunc.get_number_of_steps(split_path, geometry, batch_size)
    for split_path in (train_path, val_path)
)

#### 3. Train

In [None]:
# Build the model via ImFunc.make_or_restore_current_model (presumably it
# restores from full_checkpoint_path when a checkpoint exists there).
# NOTE(review): full_checkpoint_path carries a fresh timestamp each run, so a
# restore can only happen within the same session -- confirm this is intended.
model = ImFunc.make_or_restore_current_model(
    full_checkpoint_path, tag, init_lr, img_dim=N
)

# Train from the generator, validating each epoch on val_gen and running the
# callbacks collected in MyCallbacks.
history = model.fit(
    train_gen,
    epochs=num_epochs,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_gen,
    validation_steps=val_steps,
    callbacks=MyCallbacks,
    verbose=1,
)

#### 4. Restore a model from a checkpoint (../Checkpoints/..)

In [None]:
# Restore the model saved for `wanted_epoch` and optionally override its
# learning rate before continuing training.
wanted_epoch = 11
new_lr = None  # set to a float to override the restored optimizer's learning rate

model = ImFunc.restore_model_from_epoch(full_checkpoint_path, wanted_epoch)

# `optimizer.lr` is a deprecated alias; `optimizer.learning_rate` is the
# supported attribute. NOTE(review): if ImFunc.lr_scheduler is a
# LearningRateScheduler callback, it will overwrite this value on the next fit().
if new_lr is not None:
    K.set_value(model.optimizer.learning_rate, new_lr)

print("Current learning rate: {}".format(K.get_value(model.optimizer.learning_rate)))