# Model 26 Coarse Gated Segmentation

A coarse gating network (several downsamplings) computes a sigmoid mask that is upsampled and multiplied with the input features.  This is followed by repeated contracting-and-expanding residual blocks, in the style of U-Net or a pyramid pooling network.


GPU: Tesla K80 11GB.

### Results



### To Try

- Loss Function
    - Try filtering out empty crops from training, so dice or jaccard coefficients are valid.
    - Try removing smoothing from numerator of loss function, so there is no way to improve the network on empty patches.
- u-net or v-net
- spatial pyramid pooling
- Use small patches, small batches, batchnorm
- Make sure my loss functions work.
- mean squared error loss
- Add dilated convolution stack to end of network (small fov increase).
- Using Dropout (try 0.1)
- A shallow u-net: Pooling once and taking advantage of the smaller volume to increase channels and layers.  This would lead to a greatly increased fov.  
- Experiment with downsampling: try a stride-2 2x2x2 conv as in V-Net (though the paper offers little justification for why this is better than the usual stride-2 3x3x3 conv).


## Imports and Constants, etc.

In [None]:
import datetime
import importlib
import keras
from keras.layers import (Dense, SimpleRNN, Input, Conv1D, 
                          LSTM, GRU, AveragePooling3D, MaxPooling3D, GlobalMaxPooling3D,
                          Conv3D, UpSampling3D, BatchNormalization, 
                          Concatenate, Add, Multiply,
                          GaussianNoise, Dropout, Conv3DTranspose, 
                         )
from keras.models import Model
import nibabel as nib
import numpy as np
import pandas as pd
from pathlib import Path
import pickle
import projd
import random
import re
import scipy
import shutil
import SimpleITK # xvertseg MetaImage files
import sys
from sklearn.model_selection import train_test_split
import uuid

import matplotlib.pyplot as plt # data viz
import seaborn as sns # data viz

import imageio # display animated volumes
from IPython.display import Image # display animated volumes

from IPython.display import SVG # visualize model
from keras.utils.vis_utils import model_to_dot # visualize model

# Put $PROJECT_ROOT/src on the module search path so local code can be imported.
_project_src = Path(projd.cwd_token_dir('notebooks')) / 'src'
src_dir = str(_project_src)
already_importable = src_dir in sys.path
if not already_importable:
    sys.path.append(src_dir)

import util
import preprocessing
import datagen
import modelutil
import xvertseg
import augmentation
import metrics

# Prefix of every artifact (logs, checkpoints, tensorboard runs) of this notebook.
MODEL_NAME = 'model_26'

# Root directory under which all datasets and project outputs live.
DATA_DIR = Path('/data2').expanduser()
# DATA_DIR = Path('~/data/2018').expanduser()
# UVMMC
NORMAL_SCANS_DIR = DATA_DIR / 'uvmmc/nifti_normals'
PROJECT_DATA_DIR = DATA_DIR / 'uvm_deep_learning_project'
PP_IMG_DIR = PROJECT_DATA_DIR / 'uvmmc' / 'preprocessed' # preprocessed scans dir
PP_MD_PATH = PROJECT_DATA_DIR / 'uvmmc' / 'preprocessed_metadata.pkl'
# xVertSeg
XVERTSEG_DIR = DATA_DIR / 'xVertSeg.v1'
PP_XVERTSEG_DIR = PROJECT_DATA_DIR / 'xVertSeg.v1' / 'preprocessed' # preprocessed scans dir
PP_XVERTSEG_MD_PATH = PROJECT_DATA_DIR / 'xVertSeg.v1' / 'preprocessed_metadata.pkl'


# Output locations
MODELS_DIR = PROJECT_DATA_DIR / 'models'
LOG_DIR = PROJECT_DATA_DIR / 'log'
TENSORBOARD_DIR = PROJECT_DATA_DIR / 'tensorboard'
TMP_DIR = DATA_DIR / 'tmp'

# Create every directory that is missing.  exist_ok=True replaces the former
# exists()-then-mkdir() check, so re-running the cell is harmless and there is
# no window between the check and the creation.
for d in [DATA_DIR, NORMAL_SCANS_DIR, PROJECT_DATA_DIR, PP_IMG_DIR, MODELS_DIR, LOG_DIR,
          TENSORBOARD_DIR, TMP_DIR, PP_MD_PATH.parent, PP_XVERTSEG_DIR, PP_XVERTSEG_MD_PATH.parent]:
    d.mkdir(parents=True, exist_ok=True)
        
# Render matplotlib figures inline in the notebook and apply seaborn's default theme.
%matplotlib inline
sns.set()

# Automatically reload modules before running each cell, so edits to the
# local project code take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Hyperparameters

In [None]:
# Base hyperparameters.  Each entry of MODELS below starts from this dict and
# overrides individual keys; `hp` is the configuration actually used.
HP = dict()

# --- dataset splitting ---
SD = 'seed'
VS = 'validation_split'
TS = 'test_split'

HP[SD] = 25 # random seed for dataset shuffling and splitting.
HP[VS] = 0.2 # VALIDATION_SPLIT = 0.2 # 3 samples for validation
HP[TS] = 0.134 # 2 samples for test

# --- training loop ---
BS = 'batch_size'
NB = 'n_batches' # number of batches in an epoch or None for the sensible default
EP = 'epochs'
MQS = 'max_queue_size'

HP[BS] = 1
HP[NB] = 10 # Used to increase epoch size.
HP[MQS] = 20
HP[EP] = 100

# --- patches and input ---
PS = 'patch_shape'
IS = 'input_shape'
BMT = 'binary_mask_thresh'

# PATCH_SHAPE = (32, 32, 32)
# PATCH_SHAPE = (64, 64, 64) # Used to crop images for training (data augmentation, memory, speed)
HP[PS] = (128, 128, 128) # Big.  Good for visualization.
# PATCH_SHAPE = None # Full sized images

# INPUT_SHAPE = (PATCH_SHAPE + (1,)) # Model input shape adds channel dimension, but not examples dim.
HP[IS] = (None, None, None, 1) # Accept variable size volumes/images.

HP[BMT] = 0.5 # > threshold = 1. <= thresh = 0.

# --- augmentation ---
TR = 'transpose'
FL = 'flip'
GS = 'grey_std'
REQUIRE_MASK = 'require_mask'


HP[TR] = False
HP[FL] = 0.5
HP[GS] = 0.01
HP[REQUIRE_MASK] = False


# --- architecture ---
KS = 'kernel_size'
DROPOUT = 'dropout_rate'
NOISE = 'noise_rate' # std dev of gaussian noise
NC = 'n_c' # number of channels
ND = 'n_d' # number of downsamplings
NR = 'n_r' # number of residual blocks
NBL = 'n_blocks'
NDG = 'n_dg' # number of downsamplings for the gate
NRG = 'n_rg' # number of residuals at the end (bottom) of the gate network
NRDG = 'n_rdg' # number of residuals immediately after a downsampling in the gate network
MAXC = 'max_c' # maximum number of channels, to keep parameters in check
HP[KS] = (3, 3, 3) # (5, 5, 5) # (7, 7, 7)
HP[NC] = 16
HP[NDG] = 1
HP[NRG] = 1
HP[NRDG] = 1
HP[ND] = 1
HP[NR] = 1
HP[NBL] = 1
# No channel cap unless a model variant sets one.  Without this default the
# base configuration has no MAXC key and hp[MAXC] raises KeyError.
HP[MAXC] = None
HP[DROPOUT] = None # 0.1
HP[NOISE] = 0.0001

# --- loss ---
LOSS = 'loss'
W0 = 'w0'
W1 = 'w1'
HP[LOSS] = 'smooth_dice_loss'
HP[W0] = 1 # binary cross entropy weight for class 0
HP[W1] = 100 # weight informed by the 1-to-0 ratio in the training data.

# Model variants: 'id' is appended to MODEL_NAME, 'hp' is HP plus overrides.
MODELS = [
    {'id': 'ndg5', 'hp': {**HP}},
    {'id': 'ndg5a', 'hp': {**HP,
                           NDG: 4, NRG: 8, NRDG: 2,
                           ND: 3, NR: 2, NBL: 4,
                           MAXC: 64,
                          }},
    {'id': 'ndg4', 'hp': {**HP,
                          NDG: 4, NRG: 8, NRDG: 2,
                          ND: 3, NR: 2, NBL: 4,
                          MAXC: 64,
                          LOSS: 'rough_dice_loss',
                         }},
    {'id': 'rqmsk', 'hp': {**HP,
                           NDG: 4, NRG: 8, NRDG: 2,
                           ND: 3, NR: 2, NBL: 4,
                           MAXC: 64,
                           LOSS: 'smooth_dice_loss',
                           REQUIRE_MASK: True,
                          }},
]

# The variant trained and evaluated by the rest of the notebook.
md = MODELS[-1]
hp = md['hp']

In [None]:
# Display the hyperparameters of the selected model variant.
hp


## Data Generation

In [None]:
def infos_func():
    """Read the xVertSeg metadata stored at PP_XVERTSEG_MD_PATH."""
    return xvertseg.read_xvertseg_metadata(PP_XVERTSEG_MD_PATH)


train_gen, val_gen, test_gen = xvertseg.get_xvertseg_datagens(
    infos_func,
    seed=hp[SD],
    validation_split=hp[VS],
    test_split=hp[TS],
)

# Training: cropped, augmented patches; epoch length and mask requirement come from hp.
train_config = dict(
    batch_size=hp[BS],
    length=hp[NB],
    crop_shape=hp[PS],
    flip=hp[FL],
    transpose=hp[TR],
    gray_std=hp[GS],
    require_mask=hp[REQUIRE_MASK],
)
train_gen.config(**train_config).reindex()

# Validation: same cropping and augmentation settings as training.
val_config = dict(
    batch_size=hp[BS],
    crop_shape=hp[PS],
    flip=hp[FL],
    transpose=hp[TR],
    gray_std=hp[GS],
)
val_gen.config(**val_config).reindex()
# val_gen.config(batch_size=1).reindex() # Test full image

# Test: one full image per batch.
test_gen.config(batch_size=1).reindex()

## Build Model

In [None]:
def residual_block(x, n_c, kernel_size=(3,3,3), activation='relu'):
    '''
    Apply one 'same'-padded 3D convolution to x and add x back onto the result.

    x: input tensor; also used as the skip branch.
    n_c: number of filters of the convolution.
        NOTE(review): presumably must equal the channel count of x for the
        addition to be valid -- confirm.
    kernel_size, activation: forwarded to Conv3D.

    returns: the sum of x and the convolved x.
    '''
    convolved = Conv3D(n_c, kernel_size=kernel_size, padding='same',
                       activation=activation)(x)
    return Add()([x, convolved])


def residual_pyramid_block(x, n_c=None, max_c=None, n_d=1, n_r=1, kernel_size=(3, 3, 3)):
    '''
    Contract x through n_d stride-2 convolutions, refine every contracted
    level with n_r residual blocks, then expand back one level at a time,
    adding each upsampled level to the level above it.

    n_c: channel count at depth 0; defaults to the channel count of x
        (channels last).
    max_c: optional upper bound on the channel count at any depth.
    n_d: number of downsamplings (depth of the pyramid).
    n_r: number of residual blocks applied at each contracted level.
    kernel_size: kernel of every convolution in the block.

    returns: the merged depth-0 tensor (x itself when n_d is 0).
    '''
    base_c = int(x.shape[-1]) if n_c is None else n_c  # channels last

    def chans(depth):
        # Channels double with every level of depth, limited to max_c if given.
        doubled = base_c * (2**depth)
        return doubled if max_c is None else min(doubled, max_c)

    levels = {0: x}  # depth -> tensor; depth 0 is the full-resolution skip connection
    contracted_depths = range(1, n_d + 1)

    # Contract: each level is a strided convolution of the level above it.
    # (u-net would be maxpool, conv, conv.)
    for depth in contracted_depths:
        x = Conv3D(chans(depth), kernel_size=kernel_size, strides=(2,2,2), padding='same',
                   activation='relu')(x)
        levels[depth] = x

    # Refine: residual blocks on every contracted level.  (u-net has none.)
    for depth in contracted_depths:
        refined = levels[depth]
        for _ in range(n_r):
            refined = residual_block(refined, chans(depth), kernel_size=kernel_size)
        levels[depth] = refined

    # Expand: deepest level first, upsample to the channel count of the level
    # above and add that level's tensor.
    for depth in range(n_d - 1, -1, -1):
        expanded = Conv3DTranspose(chans(depth), kernel_size=kernel_size, strides=(2,2,2),
                                   padding='same', activation='relu')(levels[depth + 1])
        levels[depth] = Add()([expanded, levels[depth]])

    return levels[0]
    
    
def build_model(input_shape, n_c=4, max_c=None, n_blocks=4, n_r=4, n_d=4, n_rg=1, n_dg=4, n_rdg=3, noise=None,
                loss='binary_crossentropy', metrics=None, kernel_size=3):
    '''
    Build and compile the coarse-gated segmentation network.

    A gate sub-network downsamples the first-layer features n_dg times,
    produces a single-channel sigmoid mask, upsamples it back and multiplies
    it with the first-layer features.  The gated features then pass through
    n_blocks residual pyramid blocks and a final sigmoid convolution.

    input_shape: shape given to the Keras Input layer.
    n_c: channels of the first and the last hidden convolution.
    max_c: optional cap on the channel count inside the gate network.
    n_blocks: number of residual pyramid blocks in the segmentation part.
    n_r, n_d: forwarded to residual_pyramid_block.
    n_dg: number of stride-2 downsamplings in the gate network.
    n_rdg: residual blocks applied after each gate downsampling.
    n_rg: residual blocks applied at the bottom of the gate network.
    noise: stddev of Gaussian noise added to the input; falsy disables it.
    loss: loss passed to model.compile.
    metrics: optional iterable of extra metrics; 'accuracy' is always included.
    kernel_size: kernel of the gate and segmentation convolutions
        (the first convolution always uses 5x5x5).

    returns: compiled Keras model
    '''
    # None instead of a mutable [] default; copy so the caller's list is never shared.
    extra_metrics = [] if metrics is None else list(metrics)

    x_input = Input(shape=input_shape)
    x = x_input

    # noise regularization
    if noise:
        x = GaussianNoise(stddev=noise)(x)

    x = Conv3D(n_c, kernel_size=(5, 5, 5), padding='same', activation='relu')(x)

    #
    # COARSE GATE
    #

    x_init = x # save x for the gated merge

    def chans(d):
        if max_c is not None:
            print('chans max_c', max_c)
            return min(n_c * (2**d), max_c)
        else:
            return n_c * (2**d) # The usual double channels when you increase depth.

    for i in range(n_dg):
        # downsample
        n_cd = chans(i+1)
        x = Conv3D(n_cd, kernel_size=kernel_size, strides=(2,2,2), padding='same',
                   activation='relu')(x)
        for j in range(n_rdg):
            print('gate downsampling residuals', i, j, n_rdg, n_cd)
            x = residual_block(x, n_cd, kernel_size=kernel_size)

    for i in range(n_rg):
        print('gate bottom residuals', i, n_rg, chans(n_dg))
        x = residual_block(x, chans(n_dg), kernel_size=kernel_size)

    # coarse sigmoid
    x = Conv3D(1, kernel_size=(1, 1, 1), activation='sigmoid')(x)

    # upsample the coarse mask by 2**n_dg per axis (undoing the n_dg stride-2
    # downsamplings) and gate the first-layer features with it
    x = UpSampling3D(size=(2**n_dg, 2**n_dg, 2**n_dg))(x)
    x = Multiply()([x, x_init])

    #
    # SEGMENTATION
    #

    # NOTE(review): max_c is not forwarded here, so the pyramid blocks are not
    # channel-capped.  Left as is because forwarding it would change the
    # architecture of already-trained checkpoints -- confirm the intent.
    for i in range(n_blocks):
        x = residual_pyramid_block(x, n_d=n_d, n_r=n_r, kernel_size=kernel_size)

    x = Conv3D(n_c, kernel_size=kernel_size, padding='same', activation='relu')(x)
    y = Conv3D(1, kernel_size=(1, 1, 1), activation='sigmoid')(x)

    model = Model(inputs=x_input, outputs=y)
    model.compile(optimizer='adam', loss=loss, metrics=['accuracy'] + extra_metrics)
    return model

In [None]:
# Loss Function
# Resolve the configured loss name into the loss callable and into loss_name,
# the key under which the loss is passed as a custom object when a saved
# model is reloaded.
if hp[LOSS] == 'rough_dice_loss':
    loss = metrics.make_dice_coefficient_loss(smooth_numerator=False, smooth=1e-5)
    loss_name = 'dice_coefficient_loss'
elif hp[LOSS] == 'smooth_dice_loss':
    loss = metrics.make_dice_coefficient_loss(smooth_numerator=True, smooth=1e-5)
    loss_name = 'dice_coefficient_loss'
elif hp[LOSS] == 'dice_loss':
    loss = metrics.dice_coefficient_loss
    loss_name = 'dice_coefficient_loss'
elif hp[LOSS] == 'dice2_loss':
    loss = metrics.dice_coefficient2_loss
    loss_name = 'dice_coefficient2_loss'
elif hp[LOSS] == 'weighted_binary_crossentropy_loss':
    loss = metrics.weighted_binary_crossentropy_loss_func(w0=hp[W0], w1=hp[W1])
    loss_name = 'weighted_binary_crossentropy_loss'
else:
    # Fail here with a clear message instead of a NameError on `loss` below
    # (or, worse, silently reusing the loss of a previously run cell).
    raise ValueError(f'Unknown loss {hp[LOSS]!r}')

# hp.get: variants that do not set a channel cap have no MAXC entry.
model = build_model(input_shape=hp[IS], n_c=hp[NC], max_c=hp.get(MAXC), n_r=hp[NR], n_d=hp[ND],
                    n_blocks=hp[NBL], n_rg=hp[NRG], n_dg=hp[NDG], n_rdg=hp[NRDG],
                    noise=hp[NOISE], kernel_size=hp[KS],
                    loss=loss,
                    metrics=[metrics.dice_coefficient, metrics.binary_dice_coefficient])
# summary() prints the table itself and returns None, so it is not wrapped in print().
model.summary()
SVG(model_to_dot(model).create(prog='dot', format='svg'))

In [None]:
# Run name shared by the tensorboard, logger and checkpoint callbacks.
model_name = f"{MODEL_NAME}_{md['id']}"

callback_factories = [
    (modelutil.get_tensorboard_callback, TENSORBOARD_DIR),
    (modelutil.get_logger_callback, LOG_DIR),
    (modelutil.get_checkpoint_callback, MODELS_DIR),
]
callbacks = [make_callback(out_dir, model_name)
             for make_callback, out_dir in callback_factories]

# datagen shuffles every epoch
history = model.fit_generator(
    train_gen,
    epochs=hp[EP],
    validation_data=val_gen,
    callbacks=callbacks,
    max_queue_size=hp[MQS],
    use_multiprocessing=False,
    shuffle=True,
)


## Experimental Notes

## Visualize Training Progress

In [None]:
# Load the training metrics from a log file.
# Of all log files whose name starts with MODEL_NAME, take the one that sorts last.
# NOTE(review): presumably this is the most recent run -- confirm that the
# log file names sort chronologically.
matching_logs = sorted(LOG_DIR.glob(f'{MODEL_NAME}*_log.csv'))
log_path = matching_logs[-1]
print(log_path)
log_data = pd.read_csv(log_path)

In [None]:
# Every 10th row of metrics plus the last one.  When the last row is itself a
# 10th row it would appear twice, so repeated index labels are dropped.
log_summary = pd.concat([log_data[::10], log_data[-1:]])
log_summary[~log_summary.index.duplicated()]

In [None]:
# Plot Training and Validation Accuracy
axes = plt.gca()
axes.set_ylim([0.0,1.0]) # Show results on 0..1 range
plt.plot(log_data["acc"])
plt.plot(log_data["val_acc"])
plt.legend(['Training Accuracy', "Validation Accuracy"])
plt.show()

# Plot Training and Validation Loss
plt.plot(log_data["loss"])
plt.plot(log_data["val_loss"])
plt.legend(['Training Loss', "Validation Loss"])
plt.show()

# Plot Training and Validation Dice Coefficient
# The second curve is the validation column; plotting the training column
# twice hid the validation curve behind an identical line.
plt.plot(log_data["dice_coefficient"])
plt.plot(log_data["val_dice_coefficient"])
plt.legend(['Training Dice Coefficient', "Validation Dice Coefficient"])
plt.show()



### Confusion Matrix Results Over Time

Visualize how the results of the model improve over time.

TODO: Why do the confusion matrices look broken for epoch 10 and 20?


In [None]:
model_name = MODEL_NAME + '_' + md['id']

# Custom loss and metric functions handed to the model loader by name.
# (A comma was missing after the loss entry, which made this cell a syntax error.)
custom_objects = {
    loss_name: loss,
    'dice_coefficient': metrics.dice_coefficient,
    'binary_dice_coefficient': metrics.binary_dice_coefficient,
}

epochs = [100]
for epoch in epochs:
    print('Epoch', epoch)
    model = modelutil.get_epoch_model(MODELS_DIR, model_name, epoch,
                                      custom_objects=custom_objects)
    modelutil.plot_binary_confusion_matrix(model, train_gen)
    

### Visualize Masks by Epoch

In [None]:
# Evaluate full images
# train_gen.config(batch_size=1, length=10, num_samples=1, crop_shape=None, flip=None, transpose=None, gray_std=None)

# Use the same run name as training and the confusion-matrix cell; the bare
# MODEL_NAME lacks the variant id.
model_name = MODEL_NAME + '_' + md['id']

# Custom loss and metric functions handed to the model loader by name.
# loss_name/loss replace the undefined name `dice_coefficient_loss`, and come
# after the dice2 entry so the configured loss wins if the names coincide.
custom_objects = {
    'dice_coefficient2_loss': metrics.dice_coefficient2_loss,
    loss_name: loss,
    'dice_coefficient': metrics.dice_coefficient,
    'binary_dice_coefficient': metrics.binary_dice_coefficient,
}

epochs = [1, 10, 30]
for epoch in epochs:
    print('Epoch', epoch)
    model = modelutil.get_epoch_model(MODELS_DIR, model_name, epoch,
                                      custom_objects=custom_objects)
    for i in range(len(train_gen)):
        print('Sequence', i)
        x, y = train_gen[i]
        print(x.shape)
        # Predict once per batch instead of once per sample, and binarize with
        # the configured threshold (BINARY_MASK_THRESH was never defined).
        print('predicting...')
        y_pred = model.predict_on_batch(x)
        y_pred = y_pred > hp[BMT]
        for j in range(x.shape[0]): # batch size
            print('Input')
            display(util.animate_crop(x[j, :, :, :, 0], step=20))
            print('True')
            display(util.animate_crop(y[j, :, :, :, 0], step=20))
            print('Predicted')
            display(util.animate_crop(y_pred[j, :, :, :, 0], step=20))
            

    