# Overview

In this notebook we will cover the basic concepts for training a specific convolutional neural network architecture known as a U-net in TensorFlow. Using this architecture, we can estimate the boundaries of brain tumor tissue types (i.e. perform segmentation) from multimodal MR images. The data used in this tutorial come from the MICCAI Brain Tumor Segmentation Challenge (BraTS). More information about the BraTS Challenge can be found here: http://braintumorsegmentation.org/

For the basics of TensorFlow operation, neural networks and training, consider reviewing the preceding lectures in this series:

- **01 - Introduction to Data, Tensorflow and Deep Learning**
- **02 - Training a Classifier**
- **03 - Inference with a Classifier**

### U-net

The U-net is a popular network design in the family of encoder-decoder architectures. In this type of architecture, an encoding (collapsing) arm gradually reduces the spatial dimensions with pooling layers (or strided convolutions), while a decoding (expanding) arm gradually recovers the object details and spatial dimensions. Shortcut connections from the encoder to the decoder are usually added to help the decoder recover object details more accurately. For further reading, see the original paper (https://arxiv.org/abs/1505.04597) as well as this blog post reviewing different segmentation techniques (http://blog.qure.ai/notes/semantic-segmentation-deep-learning-review).
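The shape bookkeeping of the two arms can be sketched in plain Python. This is a minimal illustration that assumes 2 x 2 pooling with stride 2 and a depth of four levels. The actual depth used by `net.create_unet()` is not stated here, so treat `depth=4` and the helper `unet_spatial_sizes` as hypothetical. The sketch also shows why the input size matters: every pooling step must divide evenly for the decoder to recover the original size and for the shortcut connections to line up.

```python
# Hypothetical sketch: spatial size of a square input through an encoder with
# `depth` 2 x 2 pooling steps (stride 2) and a decoder that undoes each step
# with a stride-2 upsampling. depth=4 is an assumption, not net.create_unet()'s.

def unet_spatial_sizes(size, depth=4):
    """Return (encoder_sizes, decoder_sizes) for a square input of `size`."""
    encoder = [size]
    for _ in range(depth):
        if encoder[-1] % 2:
            raise ValueError("size %d is not divisible by 2" % encoder[-1])
        encoder.append(encoder[-1] // 2)  # pooling halves height and width
    decoder = [encoder[-1]]
    for _ in range(depth):
        decoder.append(decoder[-1] * 2)  # upsampling doubles height and width
    return encoder, decoder


enc, dec = unet_spatial_sizes(240, depth=4)
print("encoder:", enc)  # [240, 120, 60, 30, 15]
print("decoder:", dec)  # [15, 30, 60, 120, 240]

# Each decoder level matches an encoder level in size, which is what lets the
# shortcut connections combine encoder and decoder feature maps.
for level, (e, d) in enumerate(zip(reversed(enc), dec)):
    print("level %d: encoder %d x %d <-> decoder %d x %d" % (level, e, e, d, d))
```

With a 240 x 240 input, a fifth pooling level would fail because 15 is odd. This is one reason U-net depth is usually chosen together with the input size.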

### Importing modules

To train our U-net implementation, we will require the `os` module from the Python standard library and the open-source `tensorflow` library, as well as the custom `net` module created for this tutorial.

In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # silence TensorFlow's C++ INFO, WARNING and ERROR log messages

import tensorflow as tf
import net

### Hyperparameter variables

Hyperparameters are parameters whose values are set before the learning process begins and which in turn influence and direct the training process. Below are the three most important hyperparameters to vary in this experiment. Each is discussed in more detail where it is used in the code below.

In [2]:
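iterations = 2000     # total number of parameter updates
batch_size = 16       # image/label pairs used per update
learning_rate = 1e-3  # step size for the Adam optimizer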

### Preparation

Here we perform some basic preparatory steps: creating the output directory for saving training outputs, defining an `ops` dictionary to store operations, and resetting any existing default graph.

In [3]:
output_dir = '../exp_unet' 
os.makedirs(output_dir, exist_ok=True)
ops = {}
tf.reset_default_graph()

### Data batch

A data **mini-batch** is the collection of image and label pairs used to perform one update of the network parameters. The more images and labels used for each update, the more likely that update is to reflect the underlying population data. The trade-off is that each network update requires more computation time. For image matrices of the size in our dataset, a batch size of 16 or 32 is a good initial starting point.
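The batch-size trade-off can be illustrated with a small simulation. This sketch uses made-up data: a hypothetical population of binary labels (30% "tumor", not taken from BraTS). It draws random mini-batches and measures how much the batch mean fluctuates from one batch to the next. The spread shrinks roughly as 1/sqrt(batch_size), so quadrupling the batch size roughly halves the noise in each update while quadrupling the work.

```python
import random
import statistics

# Hypothetical population: 30% of examples are "tumor" (1), 70% are not (0).
# These proportions are illustrative only.
population = [1] * 300 + [0] * 700


def batch_mean_spread(batch_size, trials=2000, seed=0):
    """Standard deviation of the batch mean across many random mini-batches."""
    rng = random.Random(seed)
    means = []
    for _ in range(trials):
        batch = rng.sample(population, batch_size)  # sample without replacement
        means.append(sum(batch) / batch_size)
    return statistics.pstdev(means)


for bs in (4, 16, 32):
    print("batch_size=%2d  spread of batch means=%.3f" % (bs, batch_mean_spread(bs)))
```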

To implement batching, we will use the prepared template method `net.init_batch()` to load a number of slices simultaneously:

In [4]:
batch = net.init_batch(batch_size, root='../data')

### Placeholders

A TensorFlow **placeholder** is an entry point for feeding actual data values into the model. We must define each **placeholder**, and all downstream operations performed on it, before passing any data into the model.

The placeholder `X` will serve as the entry point for introducing image data into the graph. The placeholder `y` will serve as the entry point for introducing the correct target label at each voxel location:
```
0 = background (no tumor)
1 = edema
2 = non-enhancing tumor 
3 = necrosis
4 = enhancing tumor
```

The placeholder `mode` indicates whether the graph is being executed for training (`True`) or for validation (`False`).

In [5]:
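# Input images: batch x height x width x 4 channels (one per MR modality)
X = tf.placeholder(tf.float32, shape=[None, 240, 240, 4], name='X')
# Target labels, one-hot encoded: one channel per class (0-4)
y = tf.placeholder(tf.float32, shape=[None, 240, 240, 5], name='y')
# True during training, False during validation
mode = tf.placeholder(tf.bool, name='mode')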
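The `y` placeholder has 5 channels in its last axis, one per class, so each integer label (0-4, as listed above) must be one-hot encoded before it is fed in. Below is a minimal sketch on a tiny 2 x 3 label map standing in for a 240 x 240 slice. The helper `one_hot` is hypothetical; how `net.init_batch()` actually prepares the labels is not shown here.

```python
# Hypothetical sketch: one-hot encode an integer label map (values 0-4) into
# the [height, width, num_classes] layout used by the `y` placeholder.
NUM_CLASSES = 5


def one_hot(label_map, num_classes=NUM_CLASSES):
    """Return a nested list of shape [height, width, num_classes]."""
    return [
        [[1.0 if c == label else 0.0 for c in range(num_classes)] for label in row]
        for row in label_map
    ]


labels = [
    [0, 0, 1],
    [2, 3, 4],
]
encoded = one_hot(labels)
print(encoded[0][2])  # [0.0, 1.0, 0.0, 0.0, 0.0] -> edema
print(encoded[1][2])  # [0.0, 0.0, 0.0, 0.0, 1.0] -> enhancing tumor
```

With NumPy the same encoding is commonly written as `np.eye(5)[labels]`, followed by stacking slices along a leading batch axis to match `[None, 240, 240, 5]`.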

### Network

In this example we will use a template U-net created by the `net.create_unet()` method. The encoder arm of the architecture is similar to the classifier from earlier in this tutorial series, built from alternating series of convolutions, ReLU non-linearities and max-pooling. In addition, a symmetric decoder arm uses convolution-transpose operations to gradually upsample the feature maps and recover high-frequency object details.

To implement this architecture, simply call the `net.create_unet()` method:

In [6]:
pred = net.create_unet(X, training=mode)
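To make the two resampling operations concrete (max-pooling in the encoder arm, convolution-transpose in the decoder arm), here is a minimal plain-Python sketch on a single-channel 2D feature map. Real layers operate on batches of multi-channel tensors and learn their kernels. The fixed 2 x 2 kernel below is an illustrative assumption, not a value from `net.create_unet()`.

```python
# Minimal sketches of 2 x 2 max-pooling and a stride-2 transposed convolution
# on a single-channel feature map stored as nested lists.

def max_pool_2x2(fmap):
    """2 x 2 max-pooling with stride 2 (encoder arm): halves height and width."""
    return [
        [max(fmap[r][c], fmap[r][c + 1], fmap[r + 1][c], fmap[r + 1][c + 1])
         for c in range(0, len(fmap[0]) - 1, 2)]
        for r in range(0, len(fmap) - 1, 2)
    ]


def conv_transpose_2x2(fmap, kernel):
    """Transposed convolution with a 2 x 2 kernel and stride 2 (decoder arm).

    Each input value scatters a scaled copy of the kernel into its own 2 x 2
    block of the output, doubling height and width. Because the stride equals
    the kernel size, the blocks do not overlap.
    """
    h, w = len(fmap), len(fmap[0])
    out = [[0.0] * (2 * w) for _ in range(2 * h)]
    for r in range(h):
        for c in range(w):
            for kr in range(2):
                for kc in range(2):
                    out[2 * r + kr][2 * c + kc] += fmap[r][c] * kernel[kr][kc]
    return out


x = [
    [1, 2, 0, 1],
    [3, 4, 1, 0],
    [0, 1, 5, 6],
    [2, 0, 7, 8],
]
pooled = max_pool_2x2(x)
print(pooled)  # [[4, 1], [2, 8]]
up = conv_transpose_2x2(pooled, kernel=[[1.0, 0.5], [0.5, 0.25]])
print(len(up), len(up[0]))  # 4 4
```

Pooling discards the exact position of each maximum. This is why the U-net's shortcut connections, which pass encoder feature maps directly to the decoder, help recover fine boundaries.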

### Loss and error

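Next, based on these predictions, we need to give the algorithm feedback on how correct the network is. To do so, we will use a modified **Dice** score loss function. A comparison between the formal definition and our modified approximation is shown here:

$$
\text{Dice (formal)} = \frac{2 \times |y_{\text{pred}} \cap y_{\text{true}}|}{|y_{\text{pred}}| + |y_{\text{true}}|}
$$

$$
\text{Dice (approx)} = \frac{2 \times \sum y_{\text{pred}} \cdot y_{\text{true}} + d}{\sum y_{\text{pred}} + \sum y_{\text{true}} + d}
$$

Here *d* is a small delta (1e-7) added to both the numerator and the denominator to prevent division by zero. The sums run over all voxels, and $y_{\text{pred}}$ holds the network's predicted (softmax) scores. The approximation is necessary because the formal Dice score is computed on binary (thresholded) predictions, and thresholding is non-differentiable. Replacing the intersection with the voxel-wise product of predicted scores and true labels gives a smooth function that can be optimized with gradient-based methods.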
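Both definitions can be compared on a toy example. This minimal sketch computes the approximate (soft) Dice score above for a single binary channel, with predictions and labels flattened to 1D lists. It also computes the formal Dice score on thresholded predictions for contrast. How `net.loss_dice()` handles multiple classes and batches is not shown, and the 0.5 threshold in `hard_dice` is an assumption for illustration.

```python
D = 1e-7  # small delta added to numerator and denominator


def soft_dice(y_pred, y_true, d=D):
    """Approximate Dice: uses predicted scores directly, so it is differentiable."""
    intersection = sum(p * t for p, t in zip(y_pred, y_true))
    return (2 * intersection + d) / (sum(y_pred) + sum(y_true) + d)


def hard_dice(y_pred, y_true, threshold=0.5):
    """Formal Dice on binarized predictions (thresholding is not differentiable)."""
    mask = [1 if p >= threshold else 0 for p in y_pred]
    intersection = sum(m * t for m, t in zip(mask, y_true))
    total = sum(mask) + sum(y_true)
    return 2 * intersection / total if total else 1.0  # both empty: treat as perfect


y_true = [1, 0, 1, 0]
y_pred = [0.9, 0.1, 0.8, 0.2]  # predicted scores for the positive class
print(round(soft_dice(y_pred, y_true), 4))  # 0.85
print(hard_dice(y_pred, y_true))            # 1.0
```

The hard score is already perfect here and would not change under small adjustments to the predictions. The soft score keeps increasing as the scores move toward the labels, which is the signal gradient descent needs. When both masks are empty, the delta also makes `soft_dice` return 1.0 instead of dividing zero by zero.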

To implement the Dice score loss function, we will use the prepared template function `net.loss_dice()`.

In [7]:
ops['dice'] = net.loss_dice(pred, y)

### Optimizer

An optimizer is a strategy for updating the network parameters using the gradients of the loss function, computed through backpropagation. We will be using the Adam optimizer for our tutorials, an algorithm for first-order gradient-based optimization of stochastic objective functions, based on adaptive estimates of lower-order moments. For further reading, see the original paper: https://arxiv.org/abs/1412.6980
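To make "adaptive estimates of lower-order moments" concrete, here is a sketch of the Adam update for a single scalar parameter. It is written from the update rule in the paper with the paper's suggested defaults (beta1=0.9, beta2=0.999, epsilon=1e-8). It is not TensorFlow's internal implementation, and the class `ScalarAdam` is hypothetical. It also shows a useful property: the size of Adam's first step is about `learning_rate`, regardless of how large the gradient is.

```python
import math


class ScalarAdam:
    """Adam for one scalar parameter, following the update rule in the paper."""

    def __init__(self, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = learning_rate, beta1, beta2, eps
        self.m = 0.0  # running estimate of the gradient's first moment (mean)
        self.v = 0.0  # running estimate of the second moment (uncentered variance)
        self.t = 0    # step counter, used for bias correction

    def step(self, param, grad):
        """Return the updated parameter after one Adam step."""
        self.t += 1
        self.m = self.beta1 * self.m + (1 - self.beta1) * grad
        self.v = self.beta2 * self.v + (1 - self.beta2) * grad ** 2
        m_hat = self.m / (1 - self.beta1 ** self.t)  # bias-corrected moments
        v_hat = self.v / (1 - self.beta2 ** self.t)
        return param - self.lr * m_hat / (math.sqrt(v_hat) + self.eps)


# Minimize f(x) = k * x^2 (gradient 2 * k * x) starting from x = 1.0.
for k in (1.0, 100.0):
    opt = ScalarAdam(learning_rate=1e-3)
    x = opt.step(1.0, grad=2 * k * 1.0)
    print("k=%5.1f  first step size = %.6f" % (k, 1.0 - x))
```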

A key hyperparameter here is the optimizer **learning rate**, which scales the magnitude of the update applied to each parameter in one iteration. A higher learning rate results in a correspondingly larger, more aggressive "step" towards a minimum of the loss function. However, a learning rate that is too high may cause the network to overshoot the minimum and, even worse, may lead to network instability. Without other guiding heuristics, a good initial learning rate for most experiments is `1e-3`, which is the value we set for our `learning_rate` hyperparameter.
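The overshoot and instability effects can be reproduced on a toy problem. This sketch uses plain gradient descent (not Adam) on f(x) = x^2, whose minimum is at x = 0. Each update is x <- x - lr * 2x = (1 - 2 * lr) * x, so the behaviour depends entirely on the learning rate. The specific values are chosen only for illustration.

```python
# Plain gradient descent on f(x) = x^2 (gradient 2x), minimum at x = 0.

def gradient_descent(lr, x0=1.0, steps=50):
    """Return x after `steps` updates of x <- x - lr * f'(x)."""
    x = x0
    for _ in range(steps):
        x -= lr * 2 * x
    return x


for lr in (0.01, 0.1, 0.9, 1.1):
    print("lr=%.2f  x after 50 steps = %.3g" % (lr, gradient_descent(lr)))
```

With `lr=0.01` the iterate is still far from the minimum after 50 steps, and `lr=0.1` converges quickly. With `lr=0.9` each step overshoots the minimum and flips sign, but still converges. With `lr=1.1` each overshoot is larger than the last and the iterate diverges, a toy version of the instability described above.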

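Note that `tf.control_dependencies()` here ensures that the operations in the `UPDATE_OPS` collection are executed before each optimizer step. These include, for example, the moving-average updates of batch normalization layers, if the network uses them. Because we want to maximize the Dice score, the optimizer minimizes its negative.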

In [8]:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

with tf.control_dependencies(update_ops):
    global_step = tf.train.get_or_create_global_step()
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    ops['train'] = optimizer.minimize(-ops['dice'], global_step=global_step)

### Collections

After creating the placeholders and predictions, we add them to named graph collections so they can be easily retrieved for inference once training is complete.

In [9]:
# --- Save key placeholders/operations for future reference
tf.add_to_collection("inputs", X)
tf.add_to_collection("inputs", mode)
tf.add_to_collection("outputs", pred)

### TensorBoard

TensorBoard is a useful utility for tracking various statistics during the network training process. Here we set up operations to create log files that can be loaded using the TensorBoard interface.

In [10]:
# --- Add data to TensorBoard
tf.summary.histogram('softmax-scores', pred)
tf.summary.scalar('dice', ops['dice'])
ops['summary'] = tf.summary.merge_all()

# Network training

Now that the graph, loss function and optimizer have been configured, it is time to run the training algorithm. To begin, we create a new `tf.Session` and initialize our basic objects to enable saving intermediate checkpoints and writing log data. In addition, we initialize `coord` and `threads` objects to handle asynchronous loading of input data into batches:
```
sess, saver, writer_train, writer_valid = net.init_session(sess, output_dir)
```

To perform the actual training, we will construct a loop that repeats parameter updates a total of `iterations` times. For each update, we start by loading the data into batches `X_batch` and `y_batch`:
```
X_batch, y_batch = sess.run([batch['train']['X'], batch['train']['y']])
```

Then we call `sess.run()` to run one iteration of the training process. Specifically, we request that the network output `error` (the Dice score computed by `ops['dice']`), `summary` (used for creating logs) and `step` (the global step reflecting the total number of iterations). The `ops['train']` operation corresponding to the optimizer node is also run. Its return value is `None`, so it is assigned to the throwaway `_` variable.
```
_, error, summary, step = sess.run(
    [ops['train'], ops['dice'], ops['summary'], global_step],
    feed_dict={
        X: X_batch,
        y: y_batch,
        mode: True})
```

Finally, every 10 updates, we ask the network to run against a separate validation cohort (i.e. completely separate from the training dataset) to track the overall generalization of the algorithm's learned representation:
```
if not i % 10:
    ...
```
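The schedule produced by the `if not i % 10:` condition used in the training loop is easy to check in plain Python. Because `0 % 10 == 0`, validation also runs at the very first iteration, right after the first parameter update. Over 2000 iterations this gives 200 validation batches.

```python
# Iterations (out of 2000) that trigger a validation batch under `if not i % 10:`.
iterations = 2000
valid_iters = [i for i in range(iterations) if not i % 10]

print(len(valid_iters))  # 200
print(valid_iters[:3])   # [0, 10, 20]
print(valid_iters[-1])   # 1990
```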

This entire training process can be executed by running the following cell:

In [11]:
with tf.Session() as sess:

    # --- Initialize variables, checkpoint saver and TensorBoard writers
    sess, saver, writer_train, writer_valid = net.init_session(sess, output_dir)

    # --- Start background threads that load input data into batches
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    try:
        errors = {'train': 0, 'valid': 0}

        for i in range(iterations):

            # --- Load a training batch and run one parameter update
            X_batch, y_batch = sess.run([batch['train']['X'], batch['train']['y']])
            _, error, summary, step = sess.run(
                [ops['train'], ops['dice'], ops['summary'], global_step],
                feed_dict={
                    X: X_batch,
                    y: y_batch,
                    mode: True})

            writer_train.add_summary(summary, step)
            errors = net.update_ema(errors, error, mode='train', iteration=i)
            net.print_status(errors, step)

            # --- Every 10th iteration run a single validation batch
            if not i % 10:

                X_batch, y_batch = sess.run([batch['valid']['X'], batch['valid']['y']])
                error, summary = sess.run(
                    [ops['dice'], ops['summary']],
                    feed_dict={
                        X: X_batch,
                        y: y_batch,
                        mode: False})

                # --- Log against the global step so train/valid curves share an x-axis
                writer_valid.add_summary(summary, step)
                errors = net.update_ema(errors, error, mode='valid', iteration=i)
                net.print_status(errors, step)

    finally:
        # --- Stop loader threads and save a checkpoint (also on interruption)
        coord.request_stop()
        coord.join(threads)
        saver.save(sess, '%s/checkpoint/model.ckpt' % output_dir)

0002000 | Dice (train) : 0.99345 | Dice (valid): 0.99386

Above, you will see updates of the training status, including the number of iterations and the Dice scores on both the training and validation data.

### Final thoughts

Feel free to continue training the algorithm until it converges at a reasonable accuracy. Once complete, shut down the kernel (top menu > `Kernel` > `Shutdown`; you can keep this tab open in your browser to retrain later) so that its resources can be used in the next notebook. You are now ready to move on to **05 - Inference with a U-net** to use the newly trained network on data.