# Setting up and training simple convolutional networks with TensorFlow and Keras

In this tutorial, we will learn how to set up a simple 3D convolutional network with TensorFlow and Keras.
As an example, we will set up a network that takes a batch of 3D tensors with 2 channels (e.g. PET and MR) as input and outputs a batch of 3D tensors with 1 channel (the denoised and deblurred PET image).
Moreover, we will see how to train the model and how to monitor the training.

The model that we will set up in this tutorial looks like the figure below, except that we won't split and
concatenate the features in the first layer.

![Network architecture (figure from the pyapetnet repository)](https://raw.githubusercontent.com/gschramm/pyapetnet/master/figures/fig_1_apetnet.png)

## Setting up a simple network

In [None]:
# import python modules used in this tutorial
import tensorflow as tf

Before setting up our first model, we define a short helper function that allows us to visualize models in a matplotlib figure.

In [None]:
import os
import matplotlib.pyplot as plt
import matplotlib.image  as mpimg
from tempfile import NamedTemporaryFile

def show_model(model):
  """ save the structure of a model into a temporary png file and show it with matplotlib
      note: tf.keras.utils.plot_model requires the pydot and graphviz packages
  """
  # delete = False + close() allows plot_model to (re)open the file by name on all platforms
  tmp_file = NamedTemporaryFile(prefix = 'cnn_model_', suffix = '.png', dir = '.', delete = False)
  tmp_file.close()
  tf.keras.utils.plot_model(model, to_file = tmp_file.name, show_shapes = True, dpi = 192)
  img = mpimg.imread(tmp_file.name)
  os.remove(tmp_file.name)

  fig, ax = plt.subplots(figsize = (12,12))
  ax.imshow(img)
  ax.set_axis_off()

  return fig, ax

Let's set up the network described above. We can build the whole network from layers that are predefined in Keras, which makes life easy. Since our desired output (the denoised and deblurred PET image) is "close" to the first input channel (the noisy and blurry PET image), we add the first input channel to the output, such that the network only has to learn the difference (residual) between the two. The batch and spatial dimensions of all layers are "None": the input layer leaves them unspecified and all layers preserve them (the convolutions use ```padding = 'same'```). This in turn means that the model can be applied to all batch sizes and spatial dimensions.
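The two properties described above, the residual connection and the independence of the batch and spatial dimensions, can be checked on a deliberately tiny stand-alone model. This is a minimal sketch, not the model of this tutorial: the number of filters and layers are arbitrary assumptions, and the first channel is extracted by slicing the input tensor instead of using a split layer.

```python
import numpy as np
import tensorflow as tf

# tiny stand-in network: 2 input channels, 'same'-padded 3D convolutions and a
# residual connection that adds the first input channel (PET) to the output
inp = tf.keras.layers.Input(shape = (None, None, None, 2))
pet = inp[:, :, :, :, :1]                       # first channel, channel axis kept -> (..., 1)
h   = tf.keras.layers.Conv3D(4, (3,3,3), padding = 'same', activation = 'relu')(inp)
h   = tf.keras.layers.Conv3D(1, (1,1,1), padding = 'same')(h)
out = tf.keras.layers.Add()([h, pet])
tiny_model = tf.keras.Model(inputs = inp, outputs = out)

# no layer changes the batch or spatial dimensions, so the same model
# accepts inputs with different batch sizes and image sizes
out_a = tiny_model(np.random.rand(1, 16, 16, 16, 2).astype(np.float32))
out_b = tiny_model(np.random.rand(3, 21, 10, 7, 2).astype(np.float32))
print(out_a.shape, out_b.shape)
```

Only the channel dimension changes (from 2 to 1); the batch and spatial dimensions of the output always match the input.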

In [None]:
def simple_model(nfeat          = 30,      # number of features (filters) of the Conv3D layers
                 kernel_shape   = (3,3,3), # kernel shape of the Conv3D layers
                 nhidden_layers = 6,       # number of hidden layers
                 batch_norm     = True,    # use batch normalization between Conv3D and activation
                 add_final_relu = True):   # add a final ReLU activation at the end to clip negative values

  # set up the input layer for batches of 3D tensors with two channels
  inp = tf.keras.layers.Input(shape = (None, None, None, 2), name = 'input_layer')

  # add a split layer such that we can add the first channel (PET) to the output
  split = tf.keras.layers.Lambda(lambda x: tf.split(x, num_or_size_splits = 2, axis = -1), name = 'split')(inp)

  # add all "hidden" layers: Conv3D -> (BatchNormalization) -> PReLU
  x = inp
  for i in range(nhidden_layers):
    x = tf.keras.layers.Conv3D(nfeat, kernel_shape, padding = 'same',
                               kernel_initializer = 'glorot_uniform', name = f'conv3d_{i+1}')(x)
    if batch_norm:
      x = tf.keras.layers.BatchNormalization(name = f'batchnorm_{i+1}')(x)
    # one PReLU slope per feature, shared across the three spatial axes
    x = tf.keras.layers.PReLU(shared_axes = [1,2,3], name = f'prelu_{i+1}')(x)

  # add a (1,1,1) Conv3D layer with 1 feature to reduce along the feature dimension
  x = tf.keras.layers.Conv3D(1, (1,1,1), padding = 'same', name = 'conv_final',
                             kernel_initializer = 'glorot_uniform')(x)

  # add the first input channel (PET) to the output (residual connection)
  x = tf.keras.layers.Add(name = 'add')([x, split[0]])

  # add a final ReLU to clip negative values
  if add_final_relu:
    x = tf.keras.layers.ReLU(name = 'final_relu')(x)

  model = tf.keras.Model(inputs = inp, outputs = x)

  return model

In [None]:
model = simple_model()

Let's print a summary of all layers, connections and the number of trainable parameters.

In [None]:
# summary() prints the table itself (and returns None), so no print() is needed
model.summary()
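The parameter counts reported by the summary can be reproduced by hand. The following is a minimal sketch, assuming the usual Keras parameterization: each Conv3D layer has one weight per kernel voxel, input channel and filter plus one bias per filter; each BatchNormalization layer has a trainable scale and offset (gamma, beta) and a non-trainable moving mean and variance per feature; and each PReLU layer has one slope per feature because its slopes are shared over the three spatial axes. The function ```count_params``` is a hypothetical helper, not part of Keras.

```python
def count_params(nfeat = 30, kernel_shape = (3,3,3), nhidden_layers = 6, nin = 2, batch_norm = True):
  """ count the (trainable, non-trainable) parameters of simple_model() by hand """
  kvol = kernel_shape[0] * kernel_shape[1] * kernel_shape[2]   # voxels per kernel

  trainable     = 0
  non_trainable = 0
  nch = nin   # number of channels entering the current Conv3D layer
  for i in range(nhidden_layers):
    trainable += kvol * nch * nfeat + nfeat   # Conv3D kernel weights + biases
    if batch_norm:
      trainable     += 2 * nfeat   # gamma and beta
      non_trainable += 2 * nfeat   # moving mean and moving variance
    trainable += nfeat             # PReLU slopes (one per feature)
    nch = nfeat

  trainable += nch * 1 + 1   # final (1,1,1) Conv3D: one weight per feature + one bias
  return trainable, non_trainable

print(count_params())   # (123871, 360)
```

For the default settings, most parameters sit in the five 30-to-30 feature Conv3D layers (24,330 each). The sum of both numbers should match the total reported by ```model.summary()``` and ```model.count_params()```.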

Let's visualize the model using the helper function defined above.

In [None]:
fig, ax = show_model(model)

## Training a neural network with TensorFlow and Keras

Training a Keras model is done via ```model.fit()```, as described in https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit

First, we need to configure the training using ```model.compile()```, where we specify a loss function and an optimizer (including its learning rate). See https://www.tensorflow.org/api_docs/python/tf/keras/Model#compile

In this tutorial, we use the mean squared error (MSE) as the loss and the popular Adam optimizer with a learning rate of 1e-3. Note that:
- you can also use other loss functions that measure the distance between the predicted and the target image, such as the mean absolute error (MAE)
- 1e-3 is the default learning rate (step size) of the Adam optimizer and works well for many applications. In principle, we could also try a bigger learning rate, but that might lead to divergence. Using a smaller learning rate (e.g. 3e-4) is also possible, but will slow down training. A nice simplified explanation of different optimizers is shown here: https://www.youtube.com/watch?v=gmwxUy7NYpA
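To see how the choice of loss changes what the training emphasizes, here is a small NumPy sketch with made-up toy values (not data from this tutorial): both predictions have the same MAE, but the MSE penalizes the single large error four times more strongly than many small ones. In Keras, switching to the MAE would mean passing ```tf.keras.losses.MeanAbsoluteError()``` instead of ```tf.keras.losses.MeanSquaredError()``` to ```model.compile()```.

```python
import numpy as np

def mse(pred, target):
  """ mean squared error over all voxels """
  return np.mean((pred - target)**2)

def mae(pred, target):
  """ mean absolute error over all voxels """
  return np.mean(np.abs(pred - target))

# toy "images" with 4 voxels (hypothetical values for illustration)
target_toy = np.array([0.0, 1.0, 1.0, 1.0], dtype = np.float32)
pred_a     = np.array([0.5, 0.5, 1.5, 0.5], dtype = np.float32)   # many small errors
pred_b     = np.array([0.0, 1.0, 1.0, 3.0], dtype = np.float32)   # one large error

for name, pred_toy in [('small errors', pred_a), ('one large error', pred_b)]:
  print(f'{name:16s} MSE = {mse(pred_toy, target_toy):.3f}  MAE = {mae(pred_toy, target_toy):.3f}')
```

Because of the squaring, a network trained with the MSE is pushed harder to avoid large local errors, whereas the MAE weights all deviations linearly.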

In [None]:
learning_rate = 1e-3
loss          = tf.keras.losses.MeanSquaredError()

model.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = learning_rate), loss = loss)

After configuring the loss function and the optimizer (including its learning rate), we can start training. In this tutorial, we use simulated blobs to demonstrate how the training works. Training on "real" data works the same way; we only have to replace the data loader as shown in the previous tutorial.

In [None]:
# generate artificial training and validation data
from scipy.ndimage import gaussian_filter
import numpy as np

ndata      = 50            # number of simulated data sets
ntrain     = 40            # number of data sets used for training (rest used for validation)
batch_size = 10
im_shape   = (29,29,29)

np.random.seed(1)

x = np.zeros((ndata,) + im_shape + (2,), dtype = np.float32)
y = np.zeros((ndata,) + im_shape + (1,), dtype = np.float32)

for i in range(ndata):
  # generate a few random binary blobs by thresholding smoothed random images
  # (Gaussian sigma = FWHM / 2.35 with a FWHM of 4.5 voxels)
  blobs  = (gaussian_filter(np.random.rand(*im_shape), 4.5/2.35) > 0.5).astype(np.float32)
  # the target image is the blob image scaled by a random factor and shifted by a random constant
  target = (np.random.rand() + 0.5)*blobs + 0.2*np.random.rand()

  y[i,:,:,:,0] = target
  # the first input channel is a blurred version of the target image
  x[i,:,:,:,0] = gaussian_filter(target, 3)
  # the second input channel shows the same blobs with a different (random, possibly inverted)
  # contrast, mimicking a structural prior image (e.g. MR)
  x[i,:,:,:,1] = (np.random.rand() - 0.5)*blobs + 0.2*np.random.rand()

train_loader  = tf.data.Dataset.from_tensor_slices((x[:ntrain,...], y[:ntrain,...]))
# shuffle with a buffer covering all training data sets, batch, and prefetch 2 batches
train_dataset = train_loader.shuffle(ntrain).batch(batch_size).prefetch(2)
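To check what the ```tf.data``` pipeline above actually delivers to ```model.fit()```, one can iterate over it and inspect the batch shapes. This sketch uses small dummy arrays (an assumption to keep it fast) with the same layout as ```x``` and ```y```: 40 data sets split into batches of 10 give 4 batches per epoch. ```prefetch(tf.data.AUTOTUNE)``` would be an alternative to a fixed prefetch buffer of 2.

```python
import numpy as np
import tensorflow as tf

# dummy arrays with the same layout as x and y above: (data set, dim0, dim1, dim2, channel)
x_toy = np.zeros((40, 8, 8, 8, 2), dtype = np.float32)
y_toy = np.zeros((40, 8, 8, 8, 1), dtype = np.float32)

toy_dataset = (tf.data.Dataset.from_tensor_slices((x_toy, y_toy))
               .shuffle(40)    # buffer >= number of elements -> full reshuffle in every epoch
               .batch(10)
               .prefetch(2))

# one pass over the dataset corresponds to one training epoch
batch_shapes = [(tuple(xb.shape), tuple(yb.shape)) for xb, yb in toy_dataset]
print(len(batch_shapes), batch_shapes[0])
```

If the number of data sets were not a multiple of the batch size, the last batch would simply be smaller.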

Let's run a short demo training. We only use a few (50) epochs to get results quickly. In real trainings, usually around 500 - 1000 epochs are needed.

In [None]:
nepochs = 50
history = model.fit(train_dataset, epochs = nepochs, validation_data = (x[ntrain:,...], y[ntrain:,...]))

Let's plot the evolution of the training and validation loss to see how well the training worked.

In [None]:
fig, ax = plt.subplots()
ax.semilogy(np.arange(1, nepochs + 1), history.history['loss'], label = 'loss')
ax.semilogy(np.arange(1, nepochs + 1), history.history['val_loss'], label = 'validation loss')
ax.set_xlabel('epoch')
ax.legend()
ax.grid(ls=':')
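Besides looking at the plot, the loss history can also be evaluated numerically, e.g. to find the epoch with the lowest validation loss. This sketch uses hypothetical loss values (made up for illustration); in the notebook, ```history.history``` would take the place of ```example_history```. Epochs in which the validation loss rises while the training loss still falls are a typical sign of overfitting.

```python
import numpy as np

# hypothetical loss values for illustration only; in the notebook use history.history instead
example_history = {'loss':     [0.90, 0.50, 0.30, 0.20, 0.15],
                   'val_loss': [0.95, 0.60, 0.40, 0.45, 0.50]}

train_loss = np.array(example_history['loss'])
val_loss   = np.array(example_history['val_loss'])

# epoch with the lowest validation loss (counted from 1, as on the x-axis of the plot above)
best_epoch = int(np.argmin(val_loss)) + 1

# epochs after the best one in which the validation loss increased
# while the training loss still decreased
suspicious = [e + 1 for e in range(best_epoch, len(val_loss))
              if val_loss[e] > val_loss[e - 1] and train_loss[e] < train_loss[e - 1]]

print(best_epoch, suspicious)   # 3 [4, 5]
```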

After training, we can use the trained model to make predictions from the validation data.

In [None]:
pred = model.predict(x[ntrain:,...])

Let's show a few of the validation data sets and the corresponding predictions.

In [None]:
# enable interactive plots with the ipympl package
%matplotlib widget
import pymirc.viewer as pv

vi = pv.ThreeAxisViewer([x[ntrain:,...,0].squeeze(), x[ntrain:,...,1].squeeze(), y[ntrain:,...,0].squeeze(), pred[...,0].squeeze()],
                        imshow_kwargs = [{'vmin':0,'vmax':1.6},{'vmin':-0.6,'vmax':0.6},{'vmin':0,'vmax':1.6},{'vmin':0,'vmax':1.6}],
                        rowlabels = ['input 0', 'input 1', 'target', 'prediction'])

## Now it's your turn

Now we know everything we need to set up and train a simple 3D convolutional neural network for structure-guided deblurring and denoising.

Now it's your turn to:
1. Train a network on the simulated BrainWeb data using the TensorFlow data input pipeline of the previous tutorial. Make sure to use small random patches (approx. 29x29x29 voxels) to avoid GPU out-of-memory errors. To get decent training results, we recommend using 500-1000 epochs. Use the first 40 data sets for training and the last 20 for validation to monitor potential overfitting.
2. (optional) Perform some stress tests of your network. What happens, e.g., when you change (flip) the contrast of the MR images? What happens if you apply spatial flips to the images?
3. (optional) Calculate the PSNR between the predictions and the target images.
4. (optional) Have a look at the Keras callbacks that can be passed to ```model.fit()``` to, e.g., save the model with the best validation loss or to dynamically decrease the learning rate. https://www.tensorflow.org/api_docs/python/tf/keras/callbacks
5. (optional) Run trainings with different network hyperparameters (number of hidden layers, number of features, batch size, ...) and compare the training and validation loss. Visualize a few predictions from the validation data as well. In total, we have 3x20=60 simulated data sets.
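As a possible starting point for exercises 2 and 3, here is a NumPy sketch of a PSNR function and of two input manipulations for the stress tests. Several choices are assumptions: the peak value of the PSNR defaults to the maximum of the target image (other conventions, such as the target's value range, are also common), the contrast flip is implemented as a sign change of the second (MR) channel, and the function names are hypothetical helpers.

```python
import numpy as np

def psnr(pred, target, peak = None):
  """ peak signal-to-noise ratio (in dB) between a prediction and a target image
      assumption: if no peak value is given, the maximum of the target image is used
  """
  if peak is None:
    peak = target.max()
  mse = np.mean((pred - target)**2)
  return 10 * np.log10(peak**2 / mse)

def flip_contrast(inp):
  """ copy of a batch of inputs (..., 2) with the sign of the second (MR) channel inverted """
  flipped = inp.copy()
  flipped[..., 1] *= -1
  return flipped

def flip_spatial(inp, axis = 1):
  """ batch of 3D images (batch, dim0, dim1, dim2, channel) mirrored along one spatial axis """
  return np.flip(inp, axis = axis)

# toy check: a constant error of 0.1 on an image with peak value 1 gives MSE = 0.01,
# i.e. a PSNR of 10*log10(1/0.01), approximately 20 dB
target_toy = np.ones((4, 4, 4))
pred_toy   = target_toy + 0.1
print(psnr(pred_toy, target_toy))
```

In the notebook, ```psnr(pred[i,...,0], y[ntrain + i,...,0])``` would evaluate the i-th validation data set, and ```model.predict(flip_contrast(x[ntrain:]))``` would run the contrast stress test. Since the learned 3D kernels are in general not mirror-symmetric, ```model.predict(flip_spatial(x[ntrain:]))``` need not agree exactly with ```flip_spatial(model.predict(x[ntrain:]))```.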