# 3D CNN for binary classification of large 3D image regions

## Requirements

* Keras (with the TensorFlow backend)
* nibabel (reads NIfTI files)
* scikit-learn

In [None]:
import numpy as np
import nibabel as nib
import os
from glob import glob
from sklearn.model_selection import StratifiedKFold

from keras.models import Model
from keras.layers import Input, merge, Convolution3D, MaxPooling3D, Flatten, Dense, Dropout, Activation
from keras.layers.normalization import BatchNormalization
from keras.optimizers import Adam, SGD
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from keras import backend as K
from keras import metrics
from keras.regularizers import l2
# Use Theano-style ('th', i.e. channels-first) dimension ordering: image
# tensors are laid out as (samples, channels, dim1, dim2, dim3). The network
# below builds its Input with the channel axis first and normalizes on axis=1.
# NOTE(review): set_image_dim_ordering is a legacy Keras API; newer Keras
# versions use K.set_image_data_format('channels_first') instead — confirm
# against the installed version.
K.set_image_dim_ordering('th')

# Make only CUDA device 0 visible to the process.
# NOTE(review): this only takes effect if set before TensorFlow initializes
# its GPU devices; setting it after importing keras usually works, but
# setting it before the imports is the safe choice.
os.environ['CUDA_VISIBLE_DEVICES']="0"

## Define functions to load and normalize 3D images

In [None]:
################################################
# load image
################################################

def _load_volumes(directory, ids, img_suffix):
    """Load ``directory + id + img_suffix`` for every id and stack them into one array."""
    # img.get_data() is deprecated (removed in nibabel 5); np.asanyarray(img.dataobj)
    # is its documented replacement and returns the same (scaled) array.
    return np.array([np.asanyarray(nib.load(directory + str(subject_id) + img_suffix).dataobj)
                     for subject_id in ids])


def img_load(load_data_flag, img_suffix, dir_class1, dir_class2, dir_all, idno, train_ind, train_train_ind, train_val_ind, test_ind):
    """Build (or load from cache) the train/validation/test arrays for one fold.

    Each subject id is expected to have one volume in ``dir_class1`` and one in
    ``dir_class2``; both volumes are loaded for every subject of a split.

    Args:
        load_data_flag: if true, read the image volumes, assemble the arrays and
            save them as ``.npy`` files under ``dir_all``; if false, load the
            previously saved ``.npy`` files from ``dir_all`` instead.
        img_suffix: file-name suffix of the volumes (e.g. ``'.nii.gz'``).
        dir_class1: path prefix of the class-1 volumes (label 0).
        dir_class2: path prefix of the class-2 volumes (label 1).
        dir_all: path prefix of the ``.npy`` cache files.
        idno: list of subject ids; file paths are ``dir + str(id) + img_suffix``.
        train_ind: indices into ``idno`` of the outer training subjects.
        train_train_ind: indices into ``train_ind`` used for training.
        train_val_ind: indices into ``train_ind`` used for validation.
        test_ind: indices into ``idno`` of the test subjects.

    Returns:
        Tuple ``(img_train, img_val, img_test, label_train, label_val,
        label_test)``. Images get a singleton channel axis at position 1.
        Training samples alternate class1/class2 (labels 0, 1, 0, 1, ...);
        validation and test samples list all class1 volumes (label 0) before
        all class2 volumes (label 1).
    """
    names = ('img_train', 'img_val', 'img_test', 'label_train', 'label_val', 'label_test')

    if not load_data_flag:
        return tuple(np.load(dir_all + name + '.npy') for name in names)

    # NOTE: every volume of a split is held in memory at once, which may not
    # fit for large datasets.
    train_ids = [idno[i] for i in train_ind[train_train_ind]]
    val_ids = [idno[i] for i in train_ind[train_val_ind]]
    test_ids = [idno[i] for i in test_ind]

    img_train_class1 = _load_volumes(dir_class1, train_ids, img_suffix)
    img_train_class2 = _load_volumes(dir_class2, train_ids, img_suffix)
    img_val_class1 = _load_volumes(dir_class1, val_ids, img_suffix)
    img_val_class2 = _load_volumes(dir_class2, val_ids, img_suffix)
    img_test_class1 = _load_volumes(dir_class1, test_ids, img_suffix)
    img_test_class2 = _load_volumes(dir_class2, test_ids, img_suffix)

    # Interleave the classes so consecutive training samples alternate
    # class1, class2, class1, ... (even rows class1, odd rows class2).
    img_train = np.zeros((2 * img_train_class1.shape[0],) + img_train_class1.shape[1:])
    img_train[0::2] = img_train_class1
    img_train[1::2] = img_train_class2
    img_val = np.concatenate((img_val_class1, img_val_class2), axis=0)
    img_test = np.concatenate((img_test_class1, img_test_class2), axis=0)

    # Insert a singleton channel axis at position 1 (channels-first layout).
    img_train = np.expand_dims(img_train, axis=1)
    img_val = np.expand_dims(img_val, axis=1)
    img_test = np.expand_dims(img_test, axis=1)

    label_train = np.zeros(img_train.shape[0])
    label_train[1::2] = 1
    # Label counts come from the per-class arrays; integer-sized, unlike
    # shape[0] / 2, which is a float on Python 3 and breaks np.zeros.
    label_val = np.concatenate((np.zeros(len(img_val_class1)), np.ones(len(img_val_class2))), axis=0)
    label_test = np.concatenate((np.zeros(len(img_test_class1)), np.ones(len(img_test_class2))), axis=0)

    arrays = (img_train, img_val, img_test, label_train, label_val, label_test)
    for name, array in zip(names, arrays):
        np.save(dir_all + name + '.npy', array)

    return arrays
 
################################################
# image normalization
################################################

def img_norm(img, int_max, int_min, int_mean):
    """Clip intensities to ``[int_min, int_max]``, then shift and scale them.

    Computes ``(clip(img, int_min, int_max) - int_mean) / (int_max - int_min)``
    and returns it as a new array; ``img`` itself is not modified.

    Args:
        img: array (or array-like) of raw intensities.
        int_max: upper bound of the intensity window.
        int_min: lower bound of the intensity window.
        int_mean: value subtracted from the clipped intensities.

    Returns:
        The normalized floating-point array.

    Raises:
        ValueError: if ``int_max <= int_min`` (the window would be empty and
            the scaling would divide by zero or flip the sign).
    """
    if int_max <= int_min:
        raise ValueError('int_max (%r) must be greater than int_min (%r)' % (int_max, int_min))
    # Clip out of place so the caller's array is left untouched.
    clipped = np.clip(img, int_min, int_max)
    return (clipped - int_mean) * 1.0 / (int_max - int_min)


## Define network and model

In [None]:
def train_network(img_train, img_val, label_train, label_val, image_size, depth_level,
                  batch_size=8, epochs=10, checkpoint_path='model.hdf5'):
    """Build a network with ``get_network`` and fit it on the training data.

    A ``ModelCheckpoint`` saves the model to ``checkpoint_path`` whenever the
    training loss reaches a new minimum.

    Args:
        img_train, label_train: training images and binary labels.
        img_val, label_val: validation images and labels, evaluated each epoch.
        image_size: spatial size of one volume, forwarded to ``get_network``.
        depth_level: number of conv/pool stages, forwarded to ``get_network``.
        batch_size: mini-batch size (default 8).
        epochs: number of training epochs (default 10).
        checkpoint_path: file the best-training-loss model is written to.

    Returns:
        The Keras model with the weights from the final epoch.
    """
    model = get_network(image_size, depth_level)
    model_checkpoint = ModelCheckpoint(checkpoint_path, monitor='loss', save_best_only=True)
    model.summary()
    # shuffle=False feeds the samples in the order given by the caller.
    model.fit(img_train, label_train, batch_size=batch_size, epochs=epochs, verbose=1,
              shuffle=False, validation_data=(img_val, label_val),
              callbacks=[model_checkpoint])
    return model

def get_network(img_size, depth_level):
    """Build and compile a 3D CNN binary classifier.

    The network has ``depth_level`` stages; each stage has two 3x3x3 'same'
    convolutions (ReLU after the first, batch normalization + ReLU after the
    second) followed by 2x2x2 max pooling. The number of filters starts at 4
    and doubles every stage. The result is flattened into a single sigmoid
    output unit.

    Args:
        img_size: spatial size ``(d1, d2, d3)`` of one volume; the model input
            is channels-first with shape ``(1, d1, d2, d3)``.
        depth_level: number of stages, 1-4; any other value builds 5 stages.

    Returns:
        A Keras ``Model`` compiled with Adam, binary cross-entropy loss and
        an accuracy metric.
    """
    inputs = Input((1, img_size[0], img_size[1], img_size[2]))
    first_channel_num = 4
    l2_weight = 0.00  # weight decay on conv kernels (currently disabled)

    # Only the requested stages are built, so small inputs are not pooled
    # below one voxel by stages the model never uses.
    n_levels = int(depth_level) if depth_level in (1, 2, 3, 4) else 5

    x = inputs
    for level in range(n_levels):
        n_filters = first_channel_num * 2 ** level
        x = Convolution3D(n_filters, (3, 3, 3), activation='relu', padding='same',
                          kernel_regularizer=l2(l2_weight))(x)
        x = Convolution3D(n_filters, (3, 3, 3), activation=None, padding='same',
                          kernel_regularizer=l2(l2_weight))(x)
        # axis=1 is the channel axis in the channels-first layout.
        x = BatchNormalization(axis=1, momentum=0.99, epsilon=0.001)(x)
        x = Activation('relu')(x)
        x = MaxPooling3D(pool_size=(2, 2, 2), strides=2)(x)

    flatten1 = Flatten()(x)
    fc1 = Dense(1, activation='sigmoid', use_bias=True)(flatten1)

    model = Model(inputs=inputs, outputs=fc1)
    model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=["accuracy"])

    return model

## Train and evaluate with cross-validation

In [None]:
####################################################################
# Predefined parameters
####################################################################

img_size       = [128, 128, 128] # size of ROI
int_min        = -1000  # min value for intensity normalization
int_max        = 0      # max value for intensity normalization
int_mean       = -800   # mean value for intensity normalization
# True: read the image volumes and (re)write the .npy cache;
# False: load the previously written .npy cache.
load_data_flag = True

####################################################################
# Predefined directories
####################################################################

dir_class1  = '../sample_class1/'
dir_class2  = '../sample_class2/'
dir_all     = '../sample_all/'

img_suffix  = '.nii.gz' # suffix for 3D data
idno_list   = sorted(glob(dir_class1+'*'+img_suffix))

# Subject ids are the class-1 file names without the suffix; each id is
# also used to look up the matching volume in dir_class2.
idno        = [os.path.basename(path)[:-len(img_suffix)] for path in idno_list]

####################################################################
# Split training, validation and test
####################################################################

# train vs. val vs. test = 4:1:1
# The outer 6-fold split holds out 1/6 of the subjects for testing; the inner
# 5-fold split holds out 1/5 of the remaining 5/6 (i.e. 1/6 overall) for
# validation. shuffle is off, so the splits are deterministic. random_state
# only has an effect with shuffle=True, and scikit-learn >= 0.24 raises a
# ValueError if it is set while shuffle=False, so it is not passed here.
skf_train_test = StratifiedKFold(n_splits=6)
skf_train_val  = StratifiedKFold(n_splits=5)

# Group labels used only to stratify the folds; all ones means one group,
# i.e. a plain (non-stratified) K-fold split.
# change stratify_group if you have prior knowledge (e.g. demographic info) for stratified sampling
stratify_group = np.ones(len(idno))

####################################################################
# Begin training and evaluation
####################################################################

depth_level = 4  # number of conv/pool stages; number of conv layers = depth_level * 2

for fold, (train_ind, test_ind) in enumerate(skf_train_test.split(idno, stratify_group)):

    # Use only the first inner split: it divides this fold's training
    # subjects into training (4/5) and validation (1/5) sets.
    train_train_ind, train_val_ind = next(
        skf_train_val.split([idno[i] for i in train_ind], stratify_group[train_ind]))

    # Each outer fold gets its own cache directory. With one shared directory,
    # every fold overwrites the previous fold's arrays, and with
    # load_data_flag = False every fold would reload the same split.
    dir_fold = dir_all + 'fold%d/' % fold
    if load_data_flag and not os.path.isdir(dir_fold):
        os.makedirs(dir_fold)

    img_train, img_val, img_test, label_train, label_val, label_test = \
    img_load(load_data_flag, img_suffix, dir_class1, dir_class2, dir_fold, idno, train_ind, train_train_ind, train_val_ind, test_ind)

    img_train = img_norm(img_train, int_max, int_min, int_mean)
    img_val   = img_norm(img_val, int_max, int_min, int_mean)
    img_test  = img_norm(img_test, int_max, int_min, int_mean)

    model        = train_network(img_train, img_val, label_train, label_val, img_size, depth_level)
    test_predict = model.predict(img_test)
    test_loss, test_acc = model.evaluate(img_test, label_test)
    print('fold %d: test loss = %.4f, test accuracy = %.4f' % (fold, test_loss, test_acc))