In [1]:
from cnn.model import build_net

from spleen_dataset.dataloader import SpleenDataloader, SpleenDataset, get_training_augmentation, get_validation_augmentation
from spleen_dataset.config import dataset_folder
from spleen_dataset.utils import get_split_deterministic, get_list_of_patients

from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping, TensorBoard
import matplotlib.pyplot as plt
import random
import numpy as np

import tensorflow as tf
# Cap TensorFlow's memory use on the first GPU by exposing it as a single
# virtual device (memory_limit is expressed in MB per the TensorFlow docs).
gpus = tf.config.experimental.list_physical_devices('GPU')
if len(gpus) > 0:
    try:
        tf.config.experimental.set_virtual_device_configuration(
            gpus[0],
            [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=5120)],
        )
    except RuntimeError as err:
        # The error is only printed, so the notebook keeps running without the cap.
        print(err)

In [None]:
# Patients returned by get_list_of_patients for dataset_folder; the training
# loop below splits this list into train/validation folds.
patients = get_list_of_patients(dataset_folder)

In [None]:
# Spatial patch size handed to both augmentation pipelines and, with a single
# trailing channel, used as the network input shape in the training loop.
patch_size = (128, 128)
# Batch size of the training and validation dataloaders.
batch_size = 32
# Number of output classes passed to build_net
# (presumably background + spleen -- TODO confirm).
num_classes = 2
# Augmentation pipelines used by the training and validation dataloaders respectively.
train_augmentation = get_training_augmentation(patch_size)
val_augmentation = get_validation_augmentation(patch_size)

# Ordered sequence of fn_dict keys handed to build_net. Every entry is an
# Xception block with kernel size 3; only the filter count changes.
# NOTE(review): the third entry repeats 64, which breaks the otherwise symmetric
# 32-64-128-256-512-256-128-64-32 pattern -- confirm this is intentional.
net_list = [
    f'xception_3_{filters}'
    for filters in (32, 64, 64, 256, 512, 256, 128, 64, 32)
]

# Catalogue of available building blocks, keyed '<family>_<kernel>_<filters>'.
# Each entry names the block class, its constructor parameters and a 'prob' value.
# NOTE(review): the 20 entries each carry prob = 1/36, so they sum to 20/36, not 1 --
# confirm how build_net (or any sampler) uses 'prob'.
fn_dict = {
    f'{family}_3_{filters}': {
        'block': block_name,
        'params': {'kernel': 3, 'filters': filters},
        'prob': 1/36,
    }
    for family, block_name in (
        ('vgg', 'VGGBlock'),
        ('resnet', 'ResNetBlock'),
        ('xception', 'XceptionBlock'),
        ('mbconv', 'MBConvBlock'),
    )
    for filters in (32, 64, 128, 256, 512)
}

In [None]:
def plot_history_curves(history, metric, title, ylabel, legend_loc):
    """Plot the training and validation series of one logged metric per epoch.

    Args:
        history: object returned by ``model.fit``; its ``history`` dict is read.
        metric: key of the training series; the validation series is read
            from ``'val_' + metric``.
        title: figure title.
        ylabel: label of the y axis.
        legend_loc: matplotlib legend location string.
    """
    fig, ax = plt.subplots()
    ax.plot(history.history[metric])
    ax.plot(history.history[f'val_{metric}'])
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Epoch')
    ax.legend(['Train', 'Test'], loc=legend_loc)
    plt.show()
    # Release the figure: this runs once per fold and initialization.
    plt.close(fig)


# Repeated k-fold cross-validation: for every initialization the patients are
# split into num_splits folds, a fresh network is trained on each fold, and the
# best checkpoint is scored patient by patient on the held-out fold.
val_gen_dice_coef_list = []
num_splits = 5
num_initializations = 3
epochs = 50

for initialization in range(num_initializations):

    for fold in range(num_splits):
        # Drop the Keras global state left by the previous model; without this,
        # building many models in one process keeps accumulating memory.
        tf.keras.backend.clear_session()

        # Seed every RNG in scope so each (initialization, fold) run is repeatable.
        # NOTE(review): GPU kernels may still be non-deterministic.
        run_seed = initialization * num_splits + fold
        random.seed(run_seed)
        np.random.seed(run_seed)
        tf.random.set_seed(run_seed)

        train_patients, val_patients = get_split_deterministic(
            patients, fold=fold, num_splits=num_splits, random_state=initialization)

        train_dataset = SpleenDataset(train_patients, only_non_empty_slices=True)
        val_dataset = SpleenDataset(val_patients, only_non_empty_slices=True)

        train_dataloader = SpleenDataloader(train_dataset, batch_size, train_augmentation)
        val_dataloader = SpleenDataloader(val_dataset, batch_size, val_augmentation)

        model = build_net((*patch_size, 1), num_classes, fn_dict, net_list)

        # NOTE(review): the same path is reused by every fold; if a fold never
        # writes a checkpoint, load_weights below would pick up an older file.
        checkpoint_filepath = '/tmp/checkpoint'
        model_checkpoint_callback = ModelCheckpoint(
            filepath=checkpoint_filepath,
            save_weights_only=True,
            monitor='val_gen_dice_coef',
            mode='max',
            save_best_only=True)

        history = model.fit(
            train_dataloader,
            validation_data=val_dataloader,
            epochs=epochs,
            verbose=1,
            callbacks=[
                model_checkpoint_callback
            ]
        )

        # Restore the weights of the epoch with the best validation Dice.
        model.load_weights(checkpoint_filepath)

        # Score each held-out patient separately so the final statistics are per patient.
        for patient in val_patients:
            patient_dataset = SpleenDataset([patient], only_non_empty_slices=True)
            patient_dataloader = SpleenDataloader(patient_dataset, 1, val_augmentation, shuffle=False)
            results = model.evaluate(patient_dataloader)
            # NOTE(review): assumes gen_dice_coef is the last value returned by
            # evaluate -- confirm against the metrics compiled in build_net.
            val_gen_dice_coef_patient = results[-1]
            val_gen_dice_coef_list.append(val_gen_dice_coef_patient)

        # Dice coefficient as a function of the number of epochs.
        plot_history_curves(
            history, 'gen_dice_coef',
            'Model: Generalized Dice Coeficient', 'Dice Coef', 'upper left')

        # Dice loss as a function of the number of epochs (the title now names
        # the loss instead of repeating the coefficient title).
        plot_history_curves(
            history, 'loss',
            'Model: Generalized Dice Loss', 'Dice Loss', 'upper right')

# Aggregate the per-patient scores over all folds and initializations.
mean_val_gen_dice_coef = (np.mean(val_gen_dice_coef_list))
std_val_gen_dice_coef = (np.std(val_gen_dice_coef_list))

print(f'Dice {mean_val_gen_dice_coef} +- {std_val_gen_dice_coef}')

NameError: name 'regularizers' is not defined

In [None]:
#!tensorboard --logdir='./logs'