In [None]:
from cnn.model import build_net

from spleen_dataset.dataloader import SpleenDataloader, SpleenDataset, get_training_augmentation, get_validation_augmentation
from spleen_dataset.config import dataset_folder
from spleen_dataset.utils import get_split_deterministic, get_list_of_patients

from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping, TensorBoard
import matplotlib.pyplot as plt
import random
import numpy as np

import tensorflow as tf
# GPU setup: configure memory growth on the first visible GPU only (no-op when
# no GPU is found). A RuntimeError from TensorFlow is printed instead of raised,
# so the notebook keeps going (presumably this happens when the GPU was already
# initialized by an earlier cell run).
# NOTE(review): `False` appears to be TensorFlow's default memory-growth
# setting, so this does not switch on incremental allocation; `True` was
# probably intended — confirm before changing.
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    if gpus:
        tf.config.experimental.set_memory_growth(gpus[0], False)
except RuntimeError as err:
    print(err)

In [None]:
# Patients found under `dataset_folder` (imported from spleen_dataset.config).
# The cross-validation loop below splits this list into train/validation folds
# with get_split_deterministic.
patients = get_list_of_patients(dataset_folder)

In [None]:
# Training configuration: spatial patch size fed to the network, mini-batch
# size, and number of output classes passed to build_net.
patch_size, batch_size, num_classes = (128, 128), 32, 2

# Augmentation pipelines for the train and validation dataloaders, both built
# for the same patch size (training pipeline first, then validation).
train_augmentation, val_augmentation = (
    get_training_augmentation(patch_size),
    get_validation_augmentation(patch_size),
)

# Ordered sequence of cell keys handed to build_net together with fn_dict;
# every entry is a key of fn_dict ('<block>_<cell>_<kernel>', here all VGG
# blocks with 3x3 kernels).
# NOTE(review): this sequence has two downscaling ('_d_') cells but three
# upscaling ('_u_') cells. If each d/u cell halves/doubles the resolution, the
# output would be 2x the input size — confirm build_net reconciles this.
net_list = [
    'vgg_n_3',
    'vgg_d_3', 'vgg_d_3',
    'vgg_n_3',
    'vgg_u_3', 'vgg_u_3',
    'vgg_n_3', 'vgg_n_3', 'vgg_n_3',
    'vgg_u_3',
]

def _build_fn_dict():
    """Enumerate the cell/block/kernel options that are passed to build_net.

    Keys have the form '<block>_<cell>' or '<block>_<cell>_<kernel>':
    block prefix den/inc/ind/res/vgg (DenseBlock, InceptionBlock,
    IdentityBlock, ResNetBlock, VGGBlock), cell code d/n/u
    (DownscalingCell, NonscalingCell, UpscalingCell) and, for every block
    except IdentityBlock, a kernel size of 3, 5 or 7. Each value is a dict
    with keys 'cell', 'block' and (when present) 'kernel', in that order.
    Entries are inserted in sorted key order.

    Returns:
        dict mapping each key string to its spec dict (39 entries).
    """
    block_types = (
        ('den', 'DenseBlock'),
        ('inc', 'InceptionBlock'),
        ('ind', 'IdentityBlock'),
        ('res', 'ResNetBlock'),
        ('vgg', 'VGGBlock'),
    )
    cell_types = (
        ('d', 'DownscalingCell'),
        ('n', 'NonscalingCell'),
        ('u', 'UpscalingCell'),
    )
    kernel_sizes = (3, 5, 7)

    specs = {}
    for block_code, block_name in block_types:
        for cell_code, cell_name in cell_types:
            prefix = f'{block_code}_{cell_code}'
            if block_name == 'IdentityBlock':
                # IdentityBlock entries carry no 'kernel' key.
                specs[prefix] = {'cell': cell_name, 'block': block_name}
                continue
            for kernel in kernel_sizes:
                specs[f'{prefix}_{kernel}'] = {'cell': cell_name, 'block': block_name, 'kernel': kernel}
    return specs


fn_dict = _build_fn_dict()

In [None]:
# Repeated k-fold cross-validation: for every initialization, train a fresh
# model on each fold and collect the validation generalized Dice coefficient
# over the last `evaluation_epochs` epochs of that run.
val_gen_dice_coef_list = []
num_splits = 5
num_initializations = 3
epochs = 50
# Score each run on its final 20% of epochs. Clamped to >= 1 because the slice
# `[-0:]` used below would silently return *every* epoch, not none.
evaluation_epochs = max(1, int(0.2 * epochs))


def learning_rate_fn(epoch):
    """Polynomial-decay schedule for Keras' LearningRateScheduler.

    Decays from 1e-3 at epoch 0 toward 1e-4 following
    (lr_0 - lr_end) * (1 - epoch / epochs) ** 0.9 + lr_end,
    where `epochs` is the module-level epoch budget defined above.

    Args:
        epoch: zero-based epoch index supplied by Keras.

    Returns:
        The learning rate (float) for that epoch.
    """
    initial_learning_rate = 1e-3
    end_learning_rate = 1e-4
    power = 0.9
    decay = (1 - epoch / float(epochs)) ** power
    return (initial_learning_rate - end_learning_rate) * decay + end_learning_rate


def _plot_dice_history(history, run_label):
    """Plot train vs. validation generalized Dice per epoch for one run.

    Args:
        history: History object returned by `model.fit`; reads its
            'gen_dice_coef' and 'val_gen_dice_coef' series.
        run_label: text appended to the title to identify the run.
    """
    fig, ax = plt.subplots()
    ax.plot(history.history['gen_dice_coef'], label='Train')
    ax.plot(history.history['val_gen_dice_coef'], label='Validation')
    ax.set_title(f'Model: Generalized Dice Coefficient {run_label}')
    ax.set_ylabel('Dice Coef')
    ax.set_xlabel('Epoch')
    ax.legend(loc='upper left')
    plt.show()
    # One figure is created per run inside the CV loop; release it explicitly.
    plt.close(fig)


for initialization in range(num_initializations):
    for fold in range(num_splits):
        # Drop Keras global state left by the previous run's model; building
        # many models in one process otherwise keeps accumulating memory.
        tf.keras.backend.clear_session()

        # Seed Python, NumPy and TensorFlow so each (initialization, fold) run
        # is reproducible on re-execution (some GPU ops may still be
        # non-deterministic).
        run_seed = initialization * num_splits + fold
        random.seed(run_seed)
        np.random.seed(run_seed)
        tf.random.set_seed(run_seed)

        train_patients, val_patients = get_split_deterministic(
            patients, fold=fold, num_splits=num_splits, random_state=initialization)

        train_dataset = SpleenDataset(train_patients, only_non_empty_slices=True)
        val_dataset = SpleenDataset(val_patients, only_non_empty_slices=True)

        train_dataloader = SpleenDataloader(train_dataset, batch_size, train_augmentation)
        val_dataloader = SpleenDataloader(val_dataset, batch_size, val_augmentation)

        # Input shape is (*patch_size, 1) — presumably channels-last,
        # single-channel slices; confirm against build_net.
        model = build_net((*patch_size, 1), num_classes, fn_dict, net_list)

        lr_callback = tf.keras.callbacks.LearningRateScheduler(learning_rate_fn, verbose=0)

        history = model.fit(
            train_dataloader,
            validation_data=val_dataloader,
            epochs=epochs,
            verbose=0,
            callbacks=[lr_callback],
        )

        final_val_dice = history.history['val_gen_dice_coef'][-evaluation_epochs:]
        print(f'initialization {initialization}, fold {fold}: {final_val_dice}')
        val_gen_dice_coef_list.extend(final_val_dice)

        _plot_dice_history(history, f'(initialization {initialization}, fold {fold})')

# NOTE: this pools the last `evaluation_epochs` per-epoch values of every run,
# so the std mixes epoch-to-epoch jitter with fold/initialization variance.
mean_val_gen_dice_coef = np.mean(val_gen_dice_coef_list)
std_val_gen_dice_coef = np.std(val_gen_dice_coef_list)

print(f'Dice {mean_val_gen_dice_coef} +- {std_val_gen_dice_coef}')

In [None]:
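# NOTE: the training loop above attaches no TensorBoard callback (only the
# learning-rate scheduler), so this notebook's training does not log to ./logs.
#!tensorboard --logdir='./logs'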