# Introduction to Brain Segmentation with Keras

## MAIN 2019 Educational Course 

### Thomas Funck

### McGill University

### **Contact**: email: [tffunck@gmail.com](mailto:tffunck@gmail.com), Twitter: [@tffunck](https://twitter.com/tffunck)

## Configuring basic options

In [31]:
from minc_keras import create_dir_verbose, setup_dirs
from utils import *

### Set input and label string
input_str='pet.mnc' 
label_str='dtissue.mnc'

### Set filename for .csv that will store data frame 
images_fn='data.csv'

### Set source directory from which data will be read
source_dir="/data1/users/tfunck/pet/data_ses/"

### Set the target directory where output results will be saved
target_dir="/data1/users/tfunck/pet/results"

### Set ratios for train/validation/test
ratios=[0.7,0.15]

### Set clobber to True to overwrite existing files, or to False to keep them
### Feel free to change if needed
clobber=True

### Size of batches that will be passed to the model (default: 2)
batch_size=2

### Image dimensions. We are slicing the 3D images into 2D slices. This serves to augment the data
### and make training faster
feature_dim=2

### Output activation function
activation_output="softmax"
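To make the idea of slicing 3D images into 2D slices concrete, here is a minimal NumPy sketch. The volume shape and the choice of slicing axis are assumptions made for illustration; `prepare_data` may slice the images differently.

```python
import numpy as np

def volume_to_slices(volume, axis=0):
    """Split a 3D volume into a stack of 2D slices with a trailing channel axis."""
    # Move the slicing axis to the front so that slices[i] is one 2D slice.
    slices = np.moveaxis(volume, axis, 0)
    # Add a channel dimension: (n_slices, rows, cols, 1).
    return slices[..., np.newaxis]

# Hypothetical 3D image: 10 slices of 12 x 8 voxels.
volume = np.arange(10 * 12 * 8, dtype=np.float32).reshape(10, 12, 8)
slices = volume_to_slices(volume, axis=0)
print(slices.shape)
```

Each 3D image then contributes many 2D training examples, which is why slicing increases the number of samples available to the network.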

### This is just a little housekeeping

In [32]:
setup_dirs()  
### Set filename for .csv file that will contain info about input images
images_fn = set_model_name(images_fn, report_dir, '.csv')
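The training and evaluation cells further down also refer to `model_fn`, `loss`, `metric`, `nb_epoch` and `categorical_functions`, which are not assigned in the cells shown here (they may be provided by `utils`). If they are not, a sketch like the following could define them. Every value is an assumption chosen for illustration, not the course's actual setting.

```python
# All values below are illustrative assumptions.
model_fn = 'model.hdf5'            # hypothetical filename for the saved model
loss = 'categorical_crossentropy'  # a standard Keras loss name
metric = 'categorical_accuracy'    # a standard Keras metric name
nb_epoch = 10                      # hypothetical number of training epochs

# Assumed meaning: losses that require one-hot encoded labels.
categorical_functions = ['categorical_crossentropy']

print('Labels will be one-hot encoded:', loss in categorical_functions)
```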

## Organize input and label images into train/validate/test splits
#### One of the hard parts of doing deep learning in practice is organizing your data.
#### The following section is a bit complicated because it involves organizing the input and label files into 
#### a data frame and assigning them to train/validate/test splits.

### Example data set
#### Train : Data on which network will be trained
#### Validation : Data on which network is evaluated between iterations
#### Test : Data for final evaluation of network
#### 16 images 
![](https://github.com/tfunck/minc_keras/blob/master/images/splits_a_1.png?raw=1)
### Multiple splits are possible, depending on the amount and structure of the data
![](https://github.com/tfunck/minc_keras/blob/master/images/splits_a.png?raw=1)

### Splits with correlated data
#### 16 images
#### Data can be correlated
#### For example, your images may have been collected from different centers or on different scanners
![](https://github.com/tfunck/minc_keras/blob/master/images/splits_b_1.png?raw=1)

#### Example: 3 subtypes of images (e.g. scanner type)
#### ***Don't*** create splits with only one subtype
![](https://github.com/tfunck/minc_keras/blob/master/images/splits_b_2.png?raw=1)
#### ***Do*** balance your subtypes between splits
![](https://github.com/tfunck/minc_keras/blob/master/images/splits_b_3.png?raw=1)
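Balancing subtypes between splits can be sketched as a stratified split: shuffle the images within each subtype and apply the same ratios to every subtype. The function below is a minimal illustration in plain Python, not the method used by `prepare_data`; the image names, the scanner labels and the 60/20/20 ratios are hypothetical.

```python
import random
from collections import Counter, defaultdict

def stratified_split(items, subtype_of, ratios, seed=0):
    """Assign items to train/validate/test so each subtype follows `ratios`.

    `ratios` holds the train and validate fractions; the rest goes to test.
    """
    rng = random.Random(seed)
    by_subtype = defaultdict(list)
    for item in items:
        by_subtype[subtype_of[item]].append(item)

    splits = {'train': [], 'validate': [], 'test': []}
    for subtype in sorted(by_subtype):
        group = by_subtype[subtype]
        rng.shuffle(group)
        n_train = round(ratios[0] * len(group))
        n_validate = round(ratios[1] * len(group))
        splits['train'] += group[:n_train]
        splits['validate'] += group[n_train:n_train + n_validate]
        splits['test'] += group[n_train + n_validate:]
    return splits

# Hypothetical data set: 30 images acquired on 3 scanner types.
images = ['img%02d' % i for i in range(30)]
scanner = {name: 'scanner_%d' % (i % 3) for i, name in enumerate(images)}
splits = stratified_split(images, scanner, ratios=[0.6, 0.2])
for name, members in splits.items():
    print(name, dict(Counter(scanner[m] for m in members)))
```

Because the ratios are applied per subtype, every split contains every scanner type in the same proportion.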

### Splits with correlated data and repeated subjects
#### 3 images x 5 subjects
#### 3 subtypes of images (e.g. scanner type)
![](https://github.com/tfunck/minc_keras/blob/master/images/splits_c.png?raw=1)

![](https://github.com/tfunck/minc_keras/blob/master/images/splits_c_3.png?raw=1)
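With repeated subjects, one common way to avoid leaking information between splits is to assign whole subjects, rather than individual images, to a split. The sketch below illustrates that idea under assumed names and ratios; it is not necessarily how `prepare_data` handles subjects.

```python
import random
from collections import Counter

def split_by_subject(image_subject, ratios, seed=0):
    """Split images so that all images of one subject land in the same split."""
    subjects = sorted(set(image_subject.values()))
    random.Random(seed).shuffle(subjects)
    n_train = round(ratios[0] * len(subjects))
    n_validate = round(ratios[1] * len(subjects))
    split_of_subject = {}
    for i, subject in enumerate(subjects):
        if i < n_train:
            split_of_subject[subject] = 'train'
        elif i < n_train + n_validate:
            split_of_subject[subject] = 'validate'
        else:
            split_of_subject[subject] = 'test'
    return {image: split_of_subject[s] for image, s in image_subject.items()}

# 3 images (one per hypothetical scanner type) for each of 5 subjects.
image_subject = {'sub%d_scan%d' % (s, k): 'sub%d' % s
                 for s in range(5) for k in range(3)}
assignment = split_by_subject(image_subject, ratios=[0.6, 0.2])
print(Counter(assignment.values()))
```

Since every subject here has one image per scanner type, splitting by subject also keeps the scanner types balanced across splits.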

In [None]:


[images, data] = prepare_data(source_dir, data_dir, report_dir, input_str, label_str, ratios, batch_size,feature_dim, images_fn,  clobber=clobber)

train : expected/real ratio = 70.00 / 70.29
validate : expected/real ratio = 15.00 / 15.22
Saving train images: 180 / 194
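The expected and real ratios printed above need not match exactly, because images are assigned to splits in whole units. The sketch below shows the effect with simple rounding. The total of 138 images is a hypothetical number, and giving the remainder to the test split is an assumption about how the two values in `ratios` are used.

```python
def split_counts(n_images, ratios):
    """Turn train/validate ratios into whole-image counts; test gets the rest."""
    n_train = round(ratios[0] * n_images)
    n_validate = round(ratios[1] * n_images)
    n_test = n_images - n_train - n_validate
    return {'train': n_train, 'validate': n_validate, 'test': n_test}

n_images = 138  # hypothetical total number of images
counts = split_counts(n_images, [0.7, 0.15])
for name, expected in [('train', 0.7), ('validate', 0.15)]:
    real = counts[name] / n_images
    print('%s : expected/real ratio = %3.2f / %3.2f' % (name, 100 * expected, 100 * real))
```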

## Building a U-NET in Keras

![](https://github.com/tfunck/minc_keras/blob/master/images/unet.png?raw=1)

Ronneberger, Fischer, and Brox. 2015. "U-net: Convolutional networks for biomedical image segmentation." International Conference on Medical Image Computing and Computer-Assisted Intervention. https://arxiv.org/abs/1505.04597

In [None]:
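### 1) Define architecture of neural network
import numpy as np
from math import sqrt
import keras
from keras.layers import (Input, BatchNormalization, Convolution2D, MaxPooling2D,
                          UpSampling2D, ZeroPadding2D)

Y_validate=np.load(data["validate_y_fn"]+'.npy')
nlabels=len(np.unique(Y_validate)) #Number of unique labels in the labeled images

#Assumption: the saved arrays have shape (slices, rows, columns, channels),
#so the image dimensions can be read from the validation inputs
image_dim=np.load(data["validate_x_fn"]+'.npy').shape
img_rows=image_dim[1]
img_cols=image_dim[2]
nMLP=16
nRshp=int(sqrt(nMLP))
nUpSm=int(image_dim[0]/nRshp)
image = Input(shape=(img_rows, img_cols, 1))

BN1 = BatchNormalization()(image)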

In [None]:
conv1 = Convolution2D(32, (3, 3), activation='relu', padding='same')(BN1)
conv1 = Convolution2D(32, (3, 3), activation='relu', padding='same')(conv1)
pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

In [None]:
conv2 = Convolution2D(64, (3, 3), activation='relu', padding='same')(pool1)
conv2 = Convolution2D(64, (3, 3), activation='relu', padding='same')(conv2)
pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

In [None]:
conv3 = Convolution2D(128, (3, 3), activation='relu', padding='same')(pool2)
conv3 = Convolution2D(128, (3, 3), activation='relu', padding='same')(conv3)
pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

In [None]:
conv4 = Convolution2D(256, (3, 3), activation='relu', padding='same')(pool3)
conv4 = Convolution2D(256, (3, 3), activation='relu', padding='same')(conv4)
pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

In [None]:
conv5 = Convolution2D(512, (3, 3), activation='relu', padding='same')(pool4)
conv5 = Convolution2D(512, (3, 3), activation='relu', padding='same')(conv5)

In [None]:
#Upsample and concatenate with the matching encoder output along the channel axis
up6 = keras.layers.concatenate([UpSampling2D(size=(2, 2))(conv5), conv4], axis=3)
conv6 = Convolution2D(256, (3, 3), activation='relu', padding='same')(up6)
conv6 = Convolution2D(256, (3, 3), activation='relu', padding='same')(conv6)

In [None]:
conv6_up = UpSampling2D(size=(2, 2))(conv6)
#Pad one row and one column so the upsampled map matches the size of conv3
conv6_pad = ZeroPadding2D( ((1,0),(1,0)) )(conv6_up)
up7 = keras.layers.concatenate([conv6_pad, conv3], axis=3)
conv7 = Convolution2D(128, (3, 3), activation='relu', padding='same')(up7)
conv7 = Convolution2D(128, (3, 3), activation='relu', padding='same')(conv7)
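The zero padding in the cell above compensates for a size mismatch: a 2x2 max pooling halves each dimension with floor division, so an odd-sized feature map cannot be recovered exactly by 2x upsampling. The sketch below traces one spatial dimension through four poolings and reports how many pixels are missing at each decoder level. The input sizes of 100 and 96 are hypothetical examples, not the dimensions of the course data.

```python
def encoder_sizes(size, n_pools=4):
    """Sizes of one spatial dimension after each 2x2 max pooling (floor division)."""
    sizes = [size]
    for _ in range(n_pools):
        sizes.append(sizes[-1] // 2)
    return sizes

def padding_needed(size, n_pools=4):
    """Pixels missing after 2x upsampling at each decoder level, deepest first."""
    sizes = encoder_sizes(size, n_pools)
    missing = []
    current = sizes[-1]
    for skip in reversed(sizes[:-1]):
        upsampled = 2 * current
        missing.append(skip - upsampled)
        current = skip  # after padding, the decoder map has the skip's size
    return missing

print(encoder_sizes(100))   # [100, 50, 25, 12, 6]
print(padding_needed(100))  # [0, 1, 0, 0]
print(padding_needed(96))   # [0, 0, 0, 0]
```

For a size of 100, only the second decoder level is one pixel short, which corresponds to a single padding step like the one before `up7`. A size divisible by 16, such as 96, needs no padding at all.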

In [None]:
up8 = keras.layers.concatenate([UpSampling2D(size=(2, 2))(conv7), conv2], axis=3)
conv8 = Convolution2D(64, (3, 3), activation='relu', padding='same')(up8)
conv8 = Convolution2D(64, (3, 3), activation='relu', padding='same')(conv8)

In [None]:
up9 = keras.layers.concatenate([UpSampling2D(size=(2, 2))(conv8), conv1], axis=3)
conv9 = Convolution2D(32, (3, 3), activation='relu', padding='same')(up9)
conv9 = Convolution2D(32, (3, 3), activation='relu', padding='same')(conv9)

In [None]:
#1x1 convolution that maps the 32 feature channels to one channel per label
conv10 = Convolution2D(nlabels, (1, 1), activation=activation_output)(conv9)

model = keras.models.Model(inputs=[image], outputs=conv10)
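The final 1x1 convolution produces `nlabels` values per pixel, and a softmax output activation turns them into per-pixel probabilities that sum to 1. The NumPy sketch below shows that computation on made-up scores; the 2 x 2 image and the 3 labels are hypothetical.

```python
import numpy as np

def softmax(scores, axis=-1):
    """Numerically stable softmax along `axis`."""
    shifted = scores - scores.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

# Hypothetical scores for a 2 x 2 image with 3 labels: shape (rows, cols, nlabels).
rng = np.random.RandomState(0)
scores = rng.normal(size=(2, 2, 3))
probabilities = softmax(scores)

# The predicted label of each pixel is the channel with the highest probability.
predicted_labels = probabilities.argmax(axis=-1)
print(probabilities.sum(axis=-1))
print(predicted_labels)
```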

In [None]:
### 2) Train network on data
import json
from os.path import exists, splitext
from keras.callbacks import ModelCheckpoint
from keras.utils import to_categorical

model_fn = set_model_name(model_fn, model_dir)
history_fn = splitext(model_fn)[0] + '_history.json'

print('Model:', model_fn)
if not exists(model_fn) or clobber:
    #If model_fn does not exist, or user wishes to write over (clobber) existing model
    #then train a new model and save it

    #Load input images for training data
    X_train=np.load(data["train_x_fn"]+'.npy')
    #Load labels for training data
    Y_train=np.load(data["train_y_fn"]+'.npy')
    #Load input images for validation data set
    X_validate=np.load(data["validate_x_fn"]+'.npy')
    #Set optimizer
    ada = keras.optimizers.Adam(0.0001)
    #Create filename to save checkpoints
    checkpoint_fn = splitext(model_fn)[0]+"_checkpoint-{epoch:02d}-{val_loss:.2f}.hdf5"
    #Create checkpoint callback; a lower validation loss is better, so use mode='min'
    checkpoint = ModelCheckpoint(checkpoint_fn, monitor='val_loss', verbose=0, save_best_only=True, mode='min')
    #Compile the model
    model.compile(loss=loss, optimizer=ada, metrics=[metric])

    print("Running with", nb_epoch, "epochs")
    if loss in categorical_functions :
        #Convert training data to categorical format
        Y_train = to_categorical(Y_train, num_classes=nlabels)
        #Convert validation data to categorical format
        Y_validate = to_categorical(Y_validate, num_classes=nlabels)
    #Fit model
    history = model.fit([X_train], Y_train, validation_data=([X_validate], Y_validate), epochs=nb_epoch, callbacks=[checkpoint])
    #Save model
    model.save(model_fn)

    #Save training history
    with open(history_fn, 'w+') as fp:
        json.dump(history.history, fp)
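Converting labels to categorical format means replacing each integer label with a one-hot vector. The NumPy sketch below illustrates the encoding on a small made-up label slice; it is a stand-in for what `to_categorical` does, not its implementation.

```python
import numpy as np

def one_hot(labels, num_classes):
    """Convert integer labels of any shape to one-hot vectors on a new last axis."""
    labels = np.asarray(labels, dtype=int)
    return np.eye(num_classes, dtype=np.float32)[labels]

# Hypothetical 2 x 3 label slice with 3 tissue classes.
label_slice = np.array([[0, 1, 2],
                        [2, 2, 0]])
encoded = one_hot(label_slice, num_classes=3)
print(encoded.shape)   # (2, 3, 3)
print(encoded[0, 1])   # [0. 1. 0.]
```

Passing the number of classes explicitly keeps the width of the encoding fixed, even when a subset of the data happens not to contain the highest label.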
        
        

In [None]:
### 3) Evaluate model on test data
from keras.models import load_model

model = load_model(model_fn)
X_test=np.load(data["test_x_fn"]+'.npy')
Y_test=np.load(data["test_y_fn"]+'.npy')
if loss in categorical_functions :
    #Pass num_classes so the encoding matches training even if a label is absent from the test set
    Y_test=to_categorical(Y_test, num_classes=nlabels)
test_score = model.evaluate(X_test, Y_test, verbose=1)
print('Test: Loss=', test_score[0], 'Metric=', test_score[1])
#np.savetxt(report_dir+os.sep+'model_evaluate.csv', np.array(test_score) )

### 4) Produce prediction
#predict(model_fn, validate_dir, data_dir, images_fn, images_to_predict=images_to_predict, category="validate", verbose=verbose)
#predict(model_fn, train_dir, data_dir, images_fn, images_to_predict=images_to_predict, category="train", verbose=verbose)
predict(model_fn, test_dir, data_dir, images_fn, loss, images_to_predict=images_to_predict, category="test", verbose=verbose)
plot_loss(metric, history_fn, model_fn, report_dir)
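The history file written during training can also be inspected directly, for example to find the epoch with the lowest validation loss. The sketch below assumes the JSON layout produced by `json.dump(history.history, fp)`, that is, one list of per-epoch values per key. The numbers and the temporary filename are made up for illustration.

```python
import json
import os
import tempfile

def best_epoch(history, key='val_loss'):
    """Return (1-based epoch, value) with the lowest value of `key`."""
    values = history[key]
    index = min(range(len(values)), key=values.__getitem__)
    return index + 1, values[index]

# Hypothetical history with made-up numbers.
example_history = {'loss': [0.9, 0.6, 0.5, 0.45],
                   'val_loss': [0.95, 0.7, 0.72, 0.8]}
example_fn = os.path.join(tempfile.mkdtemp(), 'model_history.json')
with open(example_fn, 'w') as fp:
    json.dump(example_history, fp)

with open(example_fn) as fp:
    loaded = json.load(fp)
print(best_epoch(loaded))  # (2, 0.7)
```

In this made-up example the training loss keeps falling while the validation loss rises after epoch 2, the usual sign of overfitting.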