In [None]:
from   ipywidgets import interactive, fixed
import matplotlib.pyplot as plt
from   model.cnn import CNN
import numpy as np
import pandas as pd
import random
from   scipy.ndimage import gaussian_filter1d
from   scipy.stats import norm
from   source.dataloader import BatchgenWrapper, TrainingDataGenerator, ValidationDataGenerator
from   source.metrics import roc_auc
from   source.preprocessing import get_preprocessed_polyp_segmentation_mask_info, train_validation_split
from   source.visualization import plotter_batch, plotter_gradcam
import tensorflow as tf
from   tensorflow.keras.callbacks import Callback, EarlyStopping
from   tensorflow.keras.optimizers import Adam, SGD
%config Completer.use_jedi = False

#### Settings

In [None]:
# Directory that is scanned for preprocessed data
data_dir = "preprocessed-data"

# Number of trained estimators in the ensemble
n_estimators = 3

# Image properties
raw_image_size = (100,100,100)
patch_size     = (50,50,50) # Field-of-view of the network
n_channels     = 1 # 1 = CT image only, 2 = CT image + manual expert segmentation mask

# Training
batch_size = 1
train_size = 0.8 # Train-Validation Split

# Model
persistance        = True  # If True, weights, history and validation results are written to 'results/'
pretrained         = True  # If True, `pretrained_weights` are loaded into the model before training
pretrained_weights = 'weights/noseg_cnn_pretraining'

# Fix global seeds (as good as possible) to ensure reproducibility of results.
# NumPy, Python's `random` module and TensorFlow each keep their own random
# state, so every one of them has to be seeded explicitly.
seed = 42
np.random.seed(seed)
random.seed(seed)
tf.random.set_seed(seed)

### Meta information

Load the CT scan metadata from 'ct_info.csv' and list the preprocessed CT scans and segmentation masks that are available in 'preprocessed-data/'.

In [None]:
# Read the CT scan information table and display it
ct_info_csv = 'ct_info.csv'
df_ct_info  = pd.read_csv(ct_info_csv)
df_ct_info

In [None]:
# Collect information about the preprocessed files found under `data_dir`
# (project helper from source.preprocessing) and display the resulting table
df_preprocessed_info = get_preprocessed_polyp_segmentation_mask_info(data_dir)
df_preprocessed_info

Merge information into a single dataframe

In [None]:
# Inner join: keep only rows that are present in both tables
merge_keys = ['patient', 'polyp', 'segmentation']
df_data    = pd.merge(df_ct_info, df_preprocessed_info, how='inner', on=merge_keys)
df_data

### Test training data generator and augmentations (sanity check)

In [None]:
# Sanity-check generator built on the complete (unsplit) dataframe
train_data_generator = BatchgenWrapper(
    data                 = df_data,
    batch_size           = batch_size,
    raw_image_size       = raw_image_size,
    patch_size           = patch_size,
    n_channels           = n_channels,
    one_hot              = False,
    num_processes        = 1,
    num_cached_per_queue = 1,
)

In [None]:
# Fetch the first batch; element 0 is used as the inputs, element 1 as the targets
test_batch = train_data_generator[0]
X_test_batch, y_test_batch = test_batch[0], test_batch[1]
print('X_test_batch:', X_test_batch.shape, ', y_test_batch:', y_test_batch.shape)

In [None]:
# Interactive viewer for the test batch. Slider ranges are derived from the
# batch shape: axis 0 -> sample, axes 1-3 -> slices, axis 4 -> channel.
batch_shape = X_test_batch.shape
interactive(
    plotter_batch,
    batch        = fixed(test_batch),
    sample_nr    = (0, batch_shape[0]-1),
    channel      = (0, batch_shape[4]-1),
    slice_x      = (0, batch_shape[1]-1),
    slice_y      = (0, batch_shape[2]-1),
    slice_z      = (0, batch_shape[3]-1),
    cmap         = ["gist_yarg", "cool", "inferno", "magma", "plasma", "viridis"],
    reverse_cmap = [True, False],
)

#### Model

In [None]:
# Build the network. The spatial input shape is taken from `patch_size` so the
# model always matches the patches produced by the data generators, instead of
# repeating the patch dimensions as literals.
model = CNN(input_shape=tuple(patch_size) + (n_channels,), classes=1, dropout=0.1, mc=False)
model.summary()

In [None]:
if pretrained:
    # Get layer names
    layer_names = [layer.name for layer in model.layers]
    print('network layers:', layer_names)

    # Freeze the first nine layers so that fine-tuning only updates the
    # remaining ones. `trainable` has to be False to freeze a layer, and the
    # flag is set before `compile` so that it is taken into account.
    for layer_name in layer_names[:9]:
        model.get_layer(layer_name).trainable = False

    # Verify trainability
    for layer in model.layers:
        print('layer:', layer.name, 'trainable:', layer.trainable)

# `learning_rate` is the supported keyword; `lr` is a deprecated alias
model.compile(optimizer=SGD(learning_rate=0.01),
              loss='binary_crossentropy',
              metrics=['accuracy', roc_auc])

if pretrained:
    model.load_weights(pretrained_weights)

# Snapshot of the starting weights (pretrained if loaded above); the training
# loop restores it so that every estimator starts from the same state
initial_weights = model.get_weights()

#### Callbacks

In [None]:
# Early stopping on the validation loss; the best weights seen are restored
# when training is stopped. The patience is four times the number of training
# batches implied by the train fraction and the batch size.
n_train_batches = (df_data.shape[0]*train_size)//batch_size
es_patience     = 4*n_train_batches

cb_es = EarlyStopping(
    monitor              = 'val_loss',
    mode                 = 'min',
    patience             = es_patience,
    restore_best_weights = True,
)

#### Training and validation

In [None]:
for est in range(n_estimators):

    print('\nEstimator #{:d}\n'.format(est))

    # Draw a new train/validation partition; the estimator index is the split seed
    df_data_train, df_data_valid = train_validation_split(df_data,
                                                          train_size=train_size,
                                                          random_state=est)
    print("\nDatasets: Train set =", df_data_train.shape[0], ", Validation set =", df_data_valid.shape[0])

    # Generator for the training split
    train_data_generator = BatchgenWrapper(
        data                 = df_data_train,
        batch_size           = batch_size,
        raw_image_size       = raw_image_size,
        patch_size           = patch_size,
        n_channels           = n_channels,
        one_hot              = False,
        num_processes        = 1,
        num_cached_per_queue = 1,
    )

    # Generator for the validation split: unshuffled, batch size equal to the split size
    valid_data_generator = ValidationDataGenerator(
        data        = df_data_valid,
        batch_size  = df_data_valid.shape[0],
        patch_size  = patch_size,
        n_channels  = n_channels,
        num_threads = 1,
        shuffle     = False,
    )

    # Fetch the first batch of each generator and report the input shapes
    train_test_batch = train_data_generator[0]
    print('\nTrain batch:', train_test_batch[0].shape)

    valid_test_batch = valid_data_generator[0]
    print('Valid batch:', valid_test_batch[0].shape)

    # Every estimator starts from the same weight snapshot
    model.set_weights(initial_weights)

    # Fit silently and keep only the per-epoch metrics dictionary
    history = model.fit(train_data_generator,
                        epochs=1000,
                        validation_data=valid_data_generator,
                        verbose=0,
                        callbacks=[cb_es]).history

    # Learning curves: raw values (faint) and Gaussian-smoothed values (bold)
    keys    = ['loss','roc_auc']
    fig, ax = plt.subplots(1, len(keys), figsize=(8*len(keys),6), num=0, clear=True)
    for i, key in enumerate(keys):
        panel   = ax[i]
        val_key = 'val_'+key
        panel.plot(history[key], c='orange', alpha=0.5)
        panel.plot(history[val_key], c='lightblue', alpha=0.5)
        panel.plot(gaussian_filter1d(history[key], 3), c='red', lw=2, label='training')
        panel.plot(gaussian_filter1d(history[val_key], 3), c='blue', lw=2, label='validation')
        panel.legend(fontsize=16)
        panel.set_xlabel('epoch', fontsize=20)
        panel.set_ylabel(key, fontsize=20)
        panel.tick_params(labelsize=16)
    plt.show()
    plt.close()

    # Validation: predictions and true labels are joined along axis 1
    y_true_valid      = np.asarray(valid_data_generator[0][1])[..., np.newaxis]
    predictions_valid = np.asarray(model.predict(valid_data_generator))
    eval_valid        = np.concatenate((predictions_valid, y_true_valid), axis=1)

    ################################
    ######### Persistance ##########
    ################################

    if persistance:
        # Output files are numbered starting at 1
        est_id          = str(est+1)
        weights_file    = 'results/weights_model_{:s}'.format(est_id)
        history_file    = 'results/history_model_{:s}.npy'.format(est_id)
        valid_eval_file = 'results/valid_eval_model_{:s}.npy'.format(est_id)

        print('\nSave weights:               {:s}'.format(weights_file))
        print('Save history:               {:s}'.format(history_file))
        print('Save validation evaluation: {:s}'.format(valid_eval_file))

        model.save_weights(weights_file)
        np.save(history_file,    history)
        np.save(valid_eval_file, eval_valid)