In [None]:
%matplotlib inline
from openpyxl import Workbook, load_workbook

import sys
import nibabel as nib
import numpy as np
from scipy import ndimage
from keras import backend as K
from sklearn.metrics import accuracy_score, precision_score

from utils import *
from model_FCNN_aniso_loc import generate_model


In [None]:
# Reload the project modules (utils, model_FCNN_aniso_loc, callback_custom) so edits
# to them take effect without restarting the kernel, then re-bind the names imported
# from them (this repeats the imports of the previous cell).
from importlib import reload

import utils
reload(utils)
# NOTE(review): star import — read_data, save_data, build_set, shuffle,
# extract_patches, extract_aux, reconstruct_volume and precision_global are used
# below without being defined in this notebook, so they presumably come from here
# (as may `plt`). Explicit imports would make that dependency visible.
from utils import *

import model_FCNN_aniso_loc
reload(model_FCNN_aniso_loc)
from model_FCNN_aniso_loc import generate_model

# callback_custom is only reloaded here; its classes are imported in the callback cell.
import callback_custom
reload(callback_custom)

# Parameter setting

In [None]:
# Parameter setting: every constant a reader might want to tune lives in this cell.

num_classes = 11   # network output classes: background (0) + 10 foreground classes
num_channel = 1    # image channels fed to the network (QSM only)
num_dim_loc = 3    # auxiliary location channels (X, Y, Z); 0 disables them

# K-fold validation (K=10)
#    iK: iK-th fold validation (1-based). Fold iK tests on cases
#        n_test*(iK-1)+1 .. n_test*iK and trains on all remaining cases.
iK = 1
n_training = 18
n_test = 2

n_folds = (n_training + n_test) // n_test
assert 1 <= iK <= n_folds, 'iK must be between 1 and {}'.format(n_folds)

idxs_test = list(range(1+n_test*(iK-1),1+n_test*iK))
idxs_training = sorted(list(set(range(1,1+n_training+n_test))-set(idxs_test)))
assert len(idxs_training) == n_training

patience = 20
patience_ft = 5   # NOTE(review): not referenced elsewhere in this notebook
# The second '{}' is left in both templates and filled with the training step index later.
model_filename = 'models/iK{}_outrun_step_{}.h5'.format(iK, '{}')
# Per-epoch log written by CSVLogger; extension fixed from the misspelled '.cvs'.
csv_filename = 'log/iK{}_outrun_step_{}.csv'.format(iK, '{}')

nb_epoch = 100
validation_split = 0.10
monitor = 'val_loss'

# Raw label values 2..num_classes are renumbered to contiguous network classes
# 1..num_classes-1; 0 stays background. Raw values with no entry here (e.g. 1)
# end up as 0 because map_class_label starts from an all-zero array.
class_mapper = {0: 0}
class_mapper.update({raw: raw - 1 for raw in range(2, num_classes + 1)})
# Inverse mapping: network class -> raw label value.
class_mapper_inv = {net: raw for raw, net in class_mapper.items()}

matrix_size = (160, 220, 48)  # shape every volume is read into

extraction_step = (9,9,3)
extraction_step_ft = (9, 9, 3)   # NOTE(review): not referenced elsewhere in this notebook

segment_size = (27, 27, 9)  # input patch size
core_size = (9, 9, 3)       # presumably the central patch region that gets labelled (predictions are reshaped to it)

# 1. Read data

## 1.1 Training data

In [None]:
# One slot per training case: the QSM volume, the X/Y/Z coordinate maps and the label map.
QSM_train = np.empty((n_training,) + matrix_size, dtype=precision_global)
XYZ_train = np.empty((n_training, 3) + matrix_size, dtype=precision_global)
label_train = np.empty((n_training,) + matrix_size, dtype=precision_global)

for slot, case_idx in enumerate(idxs_training):
    QSM_train[slot] = read_data(case_idx, 'QSM')
    for coord_axis, coord_key in enumerate(('X', 'Y', 'Z')):
        XYZ_train[slot, coord_axis] = read_data(case_idx, coord_key)
    label_train[slot] = read_data(case_idx, 'label')

# Channel-first network input: (case, channel, *matrix_size). Channel 0 is QSM;
# when location features are enabled the three coordinate maps are appended.
data_train = np.stack([QSM_train], axis=1)
if num_dim_loc > 0:
    aux_train = XYZ_train
    data_train = np.concatenate([data_train, aux_train], axis=1)

## 1.2 Test data

In [None]:
# One slot per test case: the QSM volume, the X/Y/Z coordinate maps and the label map.
QSM_test = np.empty((n_test,) + matrix_size, dtype=precision_global)
XYZ_test = np.empty((n_test, 3) + matrix_size, dtype=precision_global)
label_test = np.empty((n_test,) + matrix_size, dtype=precision_global)

for slot, case_idx in enumerate(idxs_test):
    QSM_test[slot] = read_data(case_idx, 'QSM')
    for coord_axis, coord_key in enumerate(('X', 'Y', 'Z')):
        XYZ_test[slot, coord_axis] = read_data(case_idx, coord_key)
    label_test[slot] = read_data(case_idx, 'label')

# Channel-first network input: (case, channel, *matrix_size). Channel 0 is QSM;
# when location features are enabled the three coordinate maps are appended.
data_test = np.stack([QSM_test], axis=1)
if num_dim_loc > 0:
    aux_test = XYZ_test
    data_test = np.concatenate([data_test, aux_test], axis=1)

# 2. Pre-processing

## 2.1 Normalization

In [None]:
# Fixed affine intensity scaling shared by train and test: x -> (x - 127) / 128.
# NOTE(review): applied to every channel, including the X/Y/Z location maps — confirm intended.
input_mean, input_std = 127.0, 128.0
data_train, data_test = [(volume - input_mean) / input_std for volume in (data_train, data_test)]

## 2.2 Map class label

In [None]:
def map_class_label(arr_label, mapper=None):
    """Relabel an array of label values according to a value mapping.

    Parameters
    ----------
    arr_label : array_like
        Label values of any shape.
    mapper : dict, optional
        ``{source_value: target_value}``. Defaults to the module-level
        ``class_mapper`` (looked up at call time), which preserves the original
        behaviour; pass e.g. ``class_mapper_inv`` to map in the other direction.

    Returns
    -------
    numpy.ndarray
        float64 array with the same shape as ``arr_label``. Elements whose value
        is not a key of ``mapper`` are set to 0.
    """
    if mapper is None:
        mapper = class_mapper
    arr_label = np.asarray(arr_label)
    res = np.zeros(arr_label.shape)
    # Matching is always done against the untouched input, so mappings never chain.
    for source_value, target_value in mapper.items():
        res[arr_label == source_value] = target_value
    return res
    
# Renumber ground-truth labels into the contiguous network class indices.
label_train, label_test = map_class_label(label_train), map_class_label(label_test)

## 2.3 Prepare patches

In [None]:
x_train, y_train, aux_train = build_set(data_train, label_train, extraction_step, segment_size, core_size, None, num_dim_loc)

# Shuffle all patches with one shared permutation so x/y/aux rows stay aligned
# (shuffle(x) appears to permute x in place and return the index order it used,
# which is then applied to y and aux).
#
# Keras' validation_split takes the LAST fraction of the arrays before its own
# per-epoch shuffling, so this permutation decides which patches are held out for
# validation. NumPy's global RNG was previously seeded only later (in the model
# cell), so that hold-out changed on every run; seed it here first.
# NOTE(review): assumes utils.shuffle draws from NumPy's global RNG — TODO confirm.
patch_shuffle_seed = 47
np.random.seed(patch_shuffle_seed)
idxs_shuffle = shuffle(x_train)
shuffle(y_train, idxs_shuffle);
shuffle(aux_train, idxs_shuffle);

# 3. Training

## 3.1 Generate model

In [None]:
# Seed NumPy's global RNG before constructing the network so that anything drawing
# from it during model construction is repeatable.
seed = 47
np.random.seed(seed)

model_spec = (num_classes, num_channel, segment_size, core_size, num_dim_loc)
model = generate_model(*model_spec)

## 3.2 Configure callback

In [None]:
from keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint

from callback_custom import EarlyStoppingLowLR, ReduceLROnPlateauBestWeight

# The checkpoint and the LR scheduler share one weights file for this fold/step.
best_weights_path = model_filename.format('1')

# Keep only the best weights (by `monitor`) seen so far.
checkpointer = ModelCheckpoint(filepath=best_weights_path, monitor=monitor, verbose=0,
                               save_best_only=True, save_weights_only=True)

# Per-epoch metrics, ';'-separated.
csv_logger = CSVLogger(csv_filename.format(1), separator=';')

# Custom early stopping; thresh_LR presumably ties stopping to a minimum learning rate.
stopper = EarlyStoppingLowLR(patience=patience, monitor=monitor, thresh_LR=1e-5)

# Custom plateau scheduler: lowers the LR by `factor` when `monitor` stops improving
# (presumably reloading the best weights from best_weights_path).
learning_rate_reduction = ReduceLROnPlateauBestWeight(
    filepath=best_weights_path,
    monitor=monitor,
    patience=patience,
    verbose=1,
    factor=0.1,
    min_lr=1.001e-5,
)

# Order preserved: the checkpoint is written before the scheduler runs each epoch.
callbacks = [checkpointer, csv_logger, learning_rate_reduction, stopper]

## 3.3 Start training

In [None]:
# Start training from a learning rate of 1e-3, overriding the optimizer's current value.
K.set_value(model.optimizer.lr, 1e-3)

# Two network inputs: the image patches and the auxiliary array returned by build_set.
fit_options = dict(
    epochs=nb_epoch,
    validation_split=validation_split,
    verbose=1,
    callbacks=callbacks,
)
history = model.fit([x_train, aux_train], y_train, **fit_options)

# x_train / y_train / aux_train are intentionally kept in memory here;
# `del` them after training if RAM is tight.

# Learning curves for the run above: one figure for accuracy, one for loss.
# The second curve in each figure is the validation_split hold-out (reported by
# Keras as val_*), not the test set, so it is labelled 'validation' (was 'test').
# NOTE(review): `plt` is never imported in this notebook; it presumably comes from
# `from utils import *` — TODO confirm or import matplotlib.pyplot explicitly.
for history_key, title, ylabel in (('categorical_accuracy', 'model accuracy', 'accuracy'),
                                   ('loss', 'model loss', 'loss')):
    fig, ax = plt.subplots()
    ax.plot(history.history[history_key])
    ax.plot(history.history['val_' + history_key])
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('epoch')
    ax.legend(['train', 'validation'], loc='upper left')
    plt.show()

# 4. Validation

## 4.1 Load best model

In [None]:
# Rebuild the network and load trained weights into it for evaluation.
model = generate_model(num_classes, num_channel, segment_size, core_size, num_dim_loc)

# NOTE(review): by default this loads a fixed file rather than the checkpoint that
# ModelCheckpoint wrote for this fold (model_filename.format('1')). Unless
# weights_optimal.h5 was trained without the cases in idxs_test, the test metrics
# below may be optimistic. Set use_pretrained_weights = False to evaluate this run.
use_pretrained_weights = True
weights_path = 'models/weights_optimal.h5' if use_pretrained_weights else model_filename.format(1)
model.load_weights(weights_path)

## 4.2 Apply to Test data

In [None]:
# Apply the trained network to every test case patch by patch and stitch the
# predicted patch cores back into full volumes.

# Channels stored in data_test: num_channel image channels followed by the
# num_dim_loc location channels.
n_input_channels = num_channel + num_dim_loc

# Patches per volume. The preallocation below already assumes every volume (all
# read into matrix_size) yields the same count, so compute it from data already in
# memory instead of re-reading case 1 from disk.
len_patch = extract_patches(data_test[0, 0], patch_shape=segment_size, extraction_step=extraction_step).shape[0]

segmentations_test = []

for i_case, case_idx in enumerate(idxs_test):

    print(case_idx)
    input_test = data_test[i_case, :, :, :, :]

    # Fill ALL channels, not only the image ones. x_test is allocated with the
    # location channels and extract_aux() is given num_dim_loc, so the auxiliary
    # input is presumably taken from those trailing channels. Previously they were
    # left at zero, so the network saw all-zero location features at test time even
    # though training (build_set on the full data_train) used the real X/Y/Z maps.
    x_test = np.zeros((len_patch, n_input_channels) + segment_size, dtype=precision_global)
    for i_channel in range(n_input_channels):
        x_test[:, i_channel, :, :, :] = extract_patches(input_test[i_channel], patch_shape=segment_size, extraction_step=extraction_step)
    x_test, aux_test = extract_aux(x_test, core_size, num_dim_loc)

    pred = model.predict([x_test, aux_test], verbose=1)
    # argmax over axis 2 — presumably the class axis of a (patch, voxel, class) output —
    # then one core_size block per patch.
    pred_classes = np.argmax(pred, axis=2)
    pred_classes = pred_classes.reshape((len(pred_classes),) + core_size)
    # NOTE(review): reconstruct_volume is not given the extraction step; this appears
    # to rely on extraction_step == core_size so the cores tile the volume — confirm.
    segmentation = reconstruct_volume(pred_classes, matrix_size)

    segmentations_test.append(segmentation)

segmentations_test = np.stack(segmentations_test, axis=0)

## 4.3 Post-processing

In [None]:
# Post-processing: for every foreground class keep only its largest connected
# component (scipy.ndimage.label default connectivity); voxels in the smaller
# components are reset to 0 (background).
for i_case in range(len(idxs_test)):
    segmentation = np.squeeze(segmentations_test[i_case, :, :, :])
    cleaned = np.zeros(segmentation.shape, dtype=segmentation.dtype)

    for class_idx in class_mapper_inv:
        mask = segmentation == class_idx
        if class_idx == 0 or not mask.any():
            continue
        labeled_mask, _ = ndimage.label(mask)
        component_sizes = np.bincount(labeled_mask.ravel())
        # Index 0 counts the unlabelled voxels, so search from label 1 upwards.
        largest_label = component_sizes[1:].argmax() + 1
        cleaned[labeled_mask == largest_label] = class_idx

    segmentations_test[i_case, :, :, :] = cleaned

## 4.4 Calculate metrics (precision and Dice score)

In [None]:
#orig_stdout = sys.stdout
#f = open(output_filename, 'w')
#sys.stdout = f

def calc_dice(m1, m2):
    """Dice coefficient between the elements equal to 1 in two NumPy-like masks.

    Computes ``2 * |A & B| / (|A| + |B|)`` with ``A = (m1 == 1)`` and
    ``B = (m2 == 1)``. With NumPy inputs that contain no 1 at all, the result
    is nan (0 / 0).
    """
    in_first = (m1 == 1)
    in_second = (m2 == 1)
    overlap = (in_first & in_second).sum()
    return 2 * overlap / (in_first.sum() + in_second.sum())


# Table 1 — one tab-separated row per test case: the case index, the overall voxel
# accuracy, then for each class the precision among voxels predicted as that class
# ('N/A' for background or when the class was never predicted).
for i_case, case_idx in enumerate(idxs_test):
    truth_case = label_test[i_case, :, :, :]
    seg_case = segmentations_test[i_case, :, :, :]
    row = [str(case_idx), '{:.4f}'.format(accuracy_score(truth_case.flat, seg_case.flat))]
    for class_idx in class_mapper_inv:
        class_mask = np.squeeze(seg_case) == class_idx
        if class_idx == 0 or class_mask.sum() == 0:
            row.append('N/A')
        else:
            precision = precision_score(truth_case[class_mask], seg_case[class_mask], average='micro')
            row.append('{:.4f}'.format(precision))
    print('\t'.join(row), end='\t\n')


# Table 2 — one tab-separated row per test case: the case index, then the Dice score
# for every class. 'N/A' is printed for background (class 0, not evaluated) and for
# classes absent from both ground truth and prediction (Dice undefined), matching
# Table 1. Previously both cases printed a bare 0, which biases any average of the
# table. A class that is present in the ground truth but never predicted scores 0.0000.
for i_case, case_idx in enumerate(idxs_test):
    truth_case = label_test[i_case, :, :, :]
    seg_case = segmentations_test[i_case, :, :, :]
    row = [str(case_idx)]
    for class_idx in class_mapper_inv:
        truth_mask = truth_case == class_idx
        seg_mask = seg_case == class_idx
        if class_idx == 0 or not (truth_mask.any() or seg_mask.any()):
            row.append('N/A')
        else:
            row.append('{:.4f}'.format(calc_dice(truth_mask.ravel(), seg_mask.ravel())))
    print('\t'.join(row), end='\t\n')


#sys.stdout = orig_stdout
#f.close()

## 4.5 Save segmentation

In [None]:
# Save each post-processed segmentation after mapping network classes back to the
# raw label values (class_mapper_inv); values with no entry are left unchanged.
# NOTE(review): this saves under the same kind ('label') that read_data uses for the
# ground truth — confirm save_data writes somewhere that does not overwrite it.
for i_case, case_idx in enumerate(idxs_test):
    print(case_idx)

    predicted = np.squeeze(segmentations_test[i_case, :, :, :])
    # Remap into a copy while matching against the untouched `predicted`, so a voxel
    # that was already remapped (e.g. 1 -> 2) is not remapped again (2 -> 3).
    segmentation = np.copy(predicted)
    for net_class, raw_label in class_mapper_inv.items():
        segmentation[predicted == net_class] = raw_label

    save_data(segmentation, case_idx, 'label')

print("Finished saving.")