In [1]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from cnn import model

from cnn.input import (
    get_list_of_patients,
    get_training_augmentation,
    get_validation_augmentation,
    Dataset,
    Dataloader,
    get_split_deterministic,
)

2022-12-01 01:36:16.757573: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [2]:
# Shared training configuration for every cross-validation experiment below.
train_params = dict(
    data_path="spleen_dataset/data/Task09_Spleen_preprocessed",
    num_classes=2,
    num_channels=1,
    skip_slices=0,  # forwarded to the training Dataloader only (validation uses 0)
    image_size=128,  # patches are (image_size, image_size, num_channels)
    batch_size=32,
    epochs=50,
    eval_epochs=10,  # trailing epochs whose validation DSC is pooled into the score
    folds=5,
    initializations=5,  # repeated splits, random_state = 0 .. initializations - 1
    stem_filters=16,
    max_depth=4,
)

# U-Net-like baseline: a non-scaling VGG 3x3 layer, four downscaling and four
# upscaling VGG 3x3 layers, then a closing non-scaling VGG 3x3 layer.
unet_net_list = ["vgg_n_3"] + ["vgg_d_3"] * 4 + ["vgg_u_3"] * 4 + ["vgg_n_3"]

# Spleen experiment 1: every layer code carries its own cell type, so no
# cell_list is passed for it.
# NOTE(review): this sequence has 6 downscaling codes (vgg_d_3 / ide_d) but only
# 1 upscaling code; presumably build_net reconciles this with max_depth — confirm.
experiment_1_spleen_net_list = (
    "vgg_d_3 vgg_d_3 vgg_n_3 ide_d vgg_n_3 vgg_d_3 vgg_u_3 ide_d vgg_n_3 ide_d"
).split()

# Spleen experiment 2: block/kernel codes without a cell letter; the cell types
# come from experiment_2_spleen_cell_list (presumably matched index by index).
experiment_2_spleen_net_list = (
    "inc_7 vgg_5 vgg_5 res_5 den_7 vgg_7 den_5 res_5 inc_7 inc_5"
).split()

# Cell layout for experiment 2: one non-scaling cell, four downscaling cells,
# four upscaling cells, and a closing non-scaling cell.
experiment_2_spleen_cell_list = (
    ["NonscalingCell"]
    + ["DownscalingCell"] * 4
    + ["UpscalingCell"] * 4
    + ["NonscalingCell"]
)

# Lookup table from short layer codes to layer specs passed to model.build_net.
# Code scheme: "<block>[_<cell letter>]_<kernel>", with d/n/u standing for
# Downscaling/Nonscaling/UpscalingCell. IdentityBlock codes carry no kernel.
# Codes without a cell letter have no "cell" key (presumably the cell then comes
# from a separate cell_list). Insertion order is block, then cell-less/d/n/u,
# then kernel 3/5/7.
layer_dict = {
    (code if kernel is None else f"{code}_{kernel}"): {
        **({} if cell is None else {"cell": cell}),
        "block": block,
        **({} if kernel is None else {"kernel": kernel}),
    }
    for prefix, block in (
        ("den", "DenseBlock"),
        ("inc", "InceptionBlock"),
        ("ide", "IdentityBlock"),
        ("res", "ResNetBlock"),
        ("vgg", "VGGBlock"),
    )
    for code, cell in (
        (prefix, None),
        (f"{prefix}_d", "DownscalingCell"),
        (f"{prefix}_n", "NonscalingCell"),
        (f"{prefix}_u", "UpscalingCell"),
    )
    for kernel in ((None,) if block == "IdentityBlock" else (3, 5, 7))
}

In [3]:
def _make_poly_decay_schedule(
    epochs, initial_learning_rate=1e-3, end_learning_rate=1e-4, power=0.9
):
    """Return a ``schedule(epoch) -> learning rate`` polynomial-decay function.

    lr(epoch) = (initial - end) * (1 - epoch / epochs) ** power + end, so the
    rate is ``initial_learning_rate`` at epoch 0 and approaches
    ``end_learning_rate`` as ``epoch`` approaches ``epochs``.
    """

    def schedule(epoch):
        remaining = 1 - epoch / float(epochs)
        return (
            (initial_learning_rate - end_learning_rate) * remaining**power
        ) + end_learning_rate

    return schedule


def _make_fold_dataloaders(
    data_path,
    train_patients,
    val_patients,
    batch_size,
    skip_slices,
    train_augmentation,
    val_augmentation,
):
    """Build the (train, validation) dataloaders for one cross-validation fold.

    Both datasets keep only non-empty slices. The training loader shuffles and
    applies ``skip_slices``. The validation loader is unshuffled and uses
    ``skip_slices=0``.
    """
    train_dataset = Dataset(
        data_path=data_path,
        patients=train_patients,
        only_non_empty_slices=True,
    )
    val_dataset = Dataset(
        data_path=data_path,
        patients=val_patients,
        only_non_empty_slices=True,
    )
    train_dataloader = Dataloader(
        dataset=train_dataset,
        batch_size=batch_size,
        skip_slices=skip_slices,
        augmentation=train_augmentation,
        shuffle=True,
    )
    val_dataloader = Dataloader(
        dataset=val_dataset,
        batch_size=batch_size,
        skip_slices=0,
        augmentation=val_augmentation,
        shuffle=False,
    )
    return train_dataloader, val_dataloader


def _plot_dsc_history(history, title="DSC"):
    """Plot per-epoch training and validation DSC from a Keras ``History``."""
    fig, ax = plt.subplots()
    ax.plot(history.history["gen_dice_coef"], label="Train")
    ax.plot(history.history["val_gen_dice_coef"], label="Validation")
    ax.set_title(title)
    ax.set_ylabel("DSC")
    ax.set_xlabel("Epoch")
    ax.legend(loc="upper left")
    plt.show()
    # Release the figure explicitly; this runs once per fold and initialization.
    plt.close(fig)


def cross_val_train(train_params, layer_dict, net_list, cell_list=None, seed=None):
    """Repeated k-fold cross-validation of a network built by ``model.build_net``.

    For each of ``train_params["initializations"]`` repetitions, the patients are
    split into ``train_params["folds"]`` folds with ``random_state=initialization``.
    For every fold a fresh network is built and trained for
    ``train_params["epochs"]`` epochs with a polynomial learning-rate decay
    (1e-3 -> 1e-4, power 0.9). The validation ``gen_dice_coef`` values of its
    last ``eval_epochs`` epochs are pooled across all runs. After each run, the
    running mean/std is printed and the DSC curves are plotted.

    Args:
        train_params: dict with keys ``data_path``, ``num_classes``,
            ``num_channels``, ``skip_slices``, ``image_size``, ``batch_size``,
            ``epochs``, ``eval_epochs``, ``folds``, ``initializations``,
            ``stem_filters`` and ``max_depth``.
        layer_dict: layer-code -> spec mapping, forwarded to ``model.build_net``.
        net_list: sequence of layer codes, forwarded to ``model.build_net``.
        cell_list: optional cell names, forwarded to ``model.build_net``.
        seed: optional base seed. When given,
            ``tf.keras.utils.set_random_seed(seed + run_index)`` is called
            before each network is built, which makes weights and shuffling
            reproducible. The default ``None`` leaves global RNG state untouched.

    Returns:
        ``(mean, std)`` of the pooled validation DSC values.

    Raises:
        ValueError: if ``eval_epochs`` < 1. A slice ``[-0:]`` would silently
            pool every epoch, and negative values would drop leading epochs.
    """
    data_path = train_params["data_path"]
    num_classes = train_params["num_classes"]
    num_channels = train_params["num_channels"]
    skip_slices = train_params["skip_slices"]
    image_size = train_params["image_size"]
    batch_size = train_params["batch_size"]
    epochs = train_params["epochs"]
    eval_epochs = train_params["eval_epochs"]
    num_folds = train_params["folds"]
    num_initializations = train_params["initializations"]
    stem_filters = train_params["stem_filters"]
    max_depth = train_params["max_depth"]

    if eval_epochs < 1:
        raise ValueError(f"eval_epochs must be >= 1, got {eval_epochs}")

    patch_size = (image_size, image_size, num_channels)

    patients = get_list_of_patients(data_path)
    train_augmentation = get_training_augmentation(patch_size)
    val_augmentation = get_validation_augmentation(patch_size)

    # The schedule depends only on `epochs`, so build it once for all runs.
    lr_schedule = _make_poly_decay_schedule(epochs)

    total_runs = num_folds * num_initializations
    val_gen_dice_coef_list = []

    for initialization in range(num_initializations):
        for fold in range(num_folds):
            run_index = initialization * num_folds + fold

            # Drop Keras global state from the previous model so memory does
            # not accumulate across the total_runs networks built here.
            tf.keras.backend.clear_session()
            if seed is not None:
                tf.keras.utils.set_random_seed(seed + run_index)

            net = model.build_net(
                input_shape=patch_size,
                num_classes=num_classes,
                stem_filters=stem_filters,
                max_depth=max_depth,
                layer_dict=layer_dict,
                net_list=net_list,
                cell_list=cell_list,
            )

            train_patients, val_patients = get_split_deterministic(
                patients,
                fold=fold,
                num_splits=num_folds,
                random_state=initialization,
            )

            train_dataloader, val_dataloader = _make_fold_dataloaders(
                data_path,
                train_patients,
                val_patients,
                batch_size,
                skip_slices,
                train_augmentation,
                val_augmentation,
            )

            lr_callback = tf.keras.callbacks.LearningRateScheduler(
                lr_schedule, verbose=False
            )

            history = net.fit(
                train_dataloader,
                validation_data=val_dataloader,
                epochs=epochs,
                verbose=0,
                callbacks=[lr_callback],
            )

            # eval_epochs >= 1 is guaranteed above, so this is a true tail slice.
            val_gen_dice_coef_list.extend(
                history.history["val_gen_dice_coef"][-eval_epochs:]
            )

            mean_dsc = np.mean(val_gen_dice_coef_list)
            std_dsc = np.std(val_gen_dice_coef_list)
            # 1-based so the last run reads "total_runs/total_runs".
            print(f"{run_index + 1}/{total_runs}: {mean_dsc} +- {std_dsc}")

            _plot_dsc_history(
                history, title=f"DSC (initialization {initialization}, fold {fold})"
            )

    mean_dsc = np.mean(val_gen_dice_coef_list)
    std_dsc = np.std(val_gen_dice_coef_list)

    return mean_dsc, std_dsc

In [4]:
# Baseline: U-Net-like VGG network; layer codes carry their own cell types.
cross_val_train(
    train_params,
    layer_dict,
    unet_net_list,
    cell_list=None,
)

2022-12-01 01:36:18.901278: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-12-01 01:36:20.687391: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 22122 MB memory:  -> device: 0, name: NVIDIA A30, pci bus id: 0000:3b:00.0, compute capability: 8.0
2022-12-01 01:36:20.689215: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 22122 MB memory:  -> device: 1, name: NVIDIA A30, pci bus id: 0000:af:00.0, compute capability: 8.0
2022-12-01 01:36:20.690861: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/repli

In [None]:
# Spleen experiment 1: layer codes carry their own cell types.
cross_val_train(
    train_params,
    layer_dict,
    experiment_1_spleen_net_list,
    cell_list=None,
)

In [None]:
# Spleen experiment 2: cell-less block codes plus a separate cell layout.
cross_val_train(
    train_params,
    layer_dict,
    experiment_2_spleen_net_list,
    cell_list=experiment_2_spleen_cell_list,
)