In [1]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import csv 

from cnn import model

from cnn.input import (
    get_list_of_patients,
    get_training_augmentation,
    get_validation_augmentation,
    Dataset,
    Dataloader,
    get_split_deterministic,
)

2022-12-18 14:18:18.688055: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [2]:
# Training configuration for the spleen dataset (used by every run below).
spleen_train_params = {
    "data_path": "spleen_dataset/data/Task09_Spleen_preprocessed",
    "num_classes": 2,
    "num_channels": 1,
    "skip_slices": 0,
    "image_size": 128,
    "batch_size": 32,
    "epochs": 100,
    "eval_epochs": 20,
    "folds": 5,
    "initializations": 5,
    "stem_filters": 16,
    "max_depth": 4
}

# Training configuration for the prostate dataset. Assign it to `train_params`
# before running the prostate experiments at the end of the notebook.
prostate_train_params = {
    "data_path": "prostate_dataset/data/Task05_Prostate_preprocessed",
    "num_classes": 3,
    "num_channels": 2,
    "skip_slices": 0,
    "image_size": 128,
    "batch_size": 32,
    "epochs": 100,
    "eval_epochs": 20,
    "folds": 5,
    "initializations": 5,
    "stem_filters": 16,
    "max_depth": 4
}

# Active configuration passed to cross_val_train.
train_params = spleen_train_params

# Baselines: one non-scaling cell, four downscaling cells, four upscaling cells,
# one non-scaling cell (cell types as defined in layer_dict below).
unet_net_list = ["vgg_n_3"] + ["vgg_d_3"] * 4 + ["vgg_u_3"] * 4 + ["vgg_n_3"]
resunet_net_list = ["res_n_3"] + ["res_d_3"] * 4 + ["res_u_3"] * 4 + ["res_n_3"]

# Candidate architectures compared against the baselines, per dataset.
experiment_1_spleen_net_list = ["vgg_n_3","vgg_d_3","vgg_u_3","vgg_d_3","vgg_d_3","vgg_d_3","vgg_d_3","vgg_d_3","vgg_n_3","vgg_u_3"]
experiment_2_spleen_net_list = ["den_n_3","res_d_5","vgg_d_5","res_d_5","inc_d_3","res_u_3","res_u_3","res_u_5","inc_u_3","ide_n"]
experiment_3_spleen_net_list = ["res_n_3","den_d_3","res_u_7","inc_d_7","vgg_d_3","vgg_d_3","res_d_5","inc_d_7","res_n_7","den_u_3"]
experiment_4_spleen_net_list = ["inc_n_3","den_d_3","den_d_3","res_u_5","res_n_3","res_n_5","res_d_5","den_n_5","den_n_5","den_d_3"]

experiment_1_prostate_net_list = ["vgg_u_3","vgg_u_3","vgg_d_3","vgg_d_3","vgg_d_3","vgg_d_3","vgg_d_3","vgg_u_3","vgg_u_3","vgg_n_3"]
experiment_2_prostate_net_list = ["ide_n","vgg_d_3","vgg_d_5","inc_d_3","den_d_3","den_u_5","vgg_u_5","res_u_5","inc_u_7","vgg_n_7"]
experiment_3_prostate_net_list = ["vgg_u_5","inc_u_5","vgg_d_7","res_d_3","inc_d_3","ide_d","den_d_7","vgg_u_7","res_u_5","den_n_5"]
experiment_4_prostate_net_list = ["vgg_d_7","vgg_d_7","inc_d_5","den_d_5","res_d_5","vgg_u_7","vgg_d_5","vgg_n_5","inc_n_5","den_d_7"]


def build_layer_dict(kernel_sizes=(3, 5, 7)):
    """Return the layer lookup table passed to ``model.build_net`` as ``layer_dict``.

    Keys follow ``<block>[_<cell>][_<kernel>]``:

    * block prefix ``den``/``inc``/``ide``/``res``/``vgg`` maps to
      DenseBlock/InceptionBlock/IdentityBlock/ResNetBlock/VGGBlock;
    * cell letter ``d``/``n``/``u`` maps to DownscalingCell/NonscalingCell/UpscalingCell;
      keys without a cell letter carry no ``"cell"`` entry;
    * kernel size is one of ``kernel_sizes``. IdentityBlock entries carry no kernel.

    Entries are inserted in the order block prefix -> bare kernels -> d/n/u cells
    -> kernels. Each entry's fields are ordered ``cell``, ``block``, ``kernel``.
    """
    block_types = {
        "den": "DenseBlock",
        "inc": "InceptionBlock",
        "ide": "IdentityBlock",
        "res": "ResNetBlock",
        "vgg": "VGGBlock",
    }
    cell_types = {
        "d": "DownscalingCell",
        "n": "NonscalingCell",
        "u": "UpscalingCell",
    }

    table = {}
    for prefix, block in block_types.items():
        if block == "IdentityBlock":
            # Identity blocks have no convolution, hence no kernel variants.
            table[prefix] = {"block": block}
            for letter, cell in cell_types.items():
                table[f"{prefix}_{letter}"] = {"cell": cell, "block": block}
            continue
        for kernel in kernel_sizes:
            table[f"{prefix}_{kernel}"] = {"block": block, "kernel": kernel}
        for letter, cell in cell_types.items():
            for kernel in kernel_sizes:
                table[f"{prefix}_{letter}_{kernel}"] = {"cell": cell, "block": block, "kernel": kernel}
    return table


layer_dict = build_layer_dict()

# Fail fast on a typo in any architecture definition.
assert all(
    layer_name in layer_dict
    for arch in (
        unet_net_list, resunet_net_list,
        experiment_1_spleen_net_list, experiment_2_spleen_net_list,
        experiment_3_spleen_net_list, experiment_4_spleen_net_list,
        experiment_1_prostate_net_list, experiment_2_prostate_net_list,
        experiment_3_prostate_net_list, experiment_4_prostate_net_list,
    )
    for layer_name in arch
), "a net list references a layer name missing from layer_dict"

In [3]:
def cross_val_train(experiment_name, train_params, layer_dict, net_list, cell_list=None):
    """Run repeated k-fold cross-validation of one architecture and report validation DSC.

    For every (initialization, fold) pair, a fresh network is built from
    ``net_list``/``layer_dict`` and trained on that fold's split for
    ``train_params["epochs"]`` epochs with a polynomial learning-rate decay. The
    ``val_gen_dice_coef`` values of the last ``train_params["eval_epochs"]``
    epochs are collected. A per-run ``mean +- std`` line is printed, and all
    collected values are written as one CSV row to ``{experiment_name}.csv``.

    Args:
        experiment_name: Basename of the CSV file written to the working directory.
        train_params: Dict with keys ``data_path``, ``num_classes``, ``num_channels``,
            ``skip_slices``, ``image_size``, ``batch_size``, ``epochs``,
            ``eval_epochs``, ``folds``, ``initializations``, ``stem_filters``,
            ``max_depth``.
        layer_dict: Layer-name -> cell/block/kernel spec table, forwarded to
            ``model.build_net``.
        net_list: Sequence of layer names describing the architecture.
        cell_list: Optional; forwarded unchanged to ``model.build_net``.

    Returns:
        Tuple ``(mean, std)`` over all collected validation DSC values, pooled
        across folds, initializations and evaluation epochs.

    Raises:
        ValueError: If ``eval_epochs`` is smaller than 1.
    """
    data_path = train_params["data_path"]
    num_classes = train_params["num_classes"]
    num_channels = train_params["num_channels"]
    skip_slices = train_params["skip_slices"]
    image_size = train_params["image_size"]
    batch_size = train_params["batch_size"]
    epochs = train_params["epochs"]
    eval_epochs = train_params["eval_epochs"]
    num_folds = train_params["folds"]
    num_initializations = train_params["initializations"]
    stem_filters = train_params["stem_filters"]
    max_depth = train_params["max_depth"]

    # history[-0:] selects *every* epoch rather than none, so reject
    # non-positive values instead of silently evaluating the whole run.
    if eval_epochs < 1:
        raise ValueError(f"eval_epochs must be a positive integer, got {eval_epochs!r}")

    patch_size = (image_size, image_size, num_channels)

    patients = get_list_of_patients(data_path)
    train_augmentation = get_training_augmentation(patch_size)
    val_augmentation = get_validation_augmentation(patch_size)

    def learning_rate_fn(epoch):
        """Polynomial decay (power 0.9) from 1e-3 towards 1e-4 over `epochs` epochs."""
        initial_learning_rate = 1e-3
        end_learning_rate = 1e-4
        power = 0.9
        return (
            (initial_learning_rate - end_learning_rate)
            * (1 - epoch / float(epochs)) ** power
        ) + end_learning_rate

    total_runs = num_folds * num_initializations
    val_gen_dice_coef_list = []

    for initialization in range(num_initializations):
        for fold in range(num_folds):
            # A new model is built for each of the folds * initializations runs.
            # Release the Keras global state left by the previous run so that
            # memory does not keep growing across the loop.
            tf.keras.backend.clear_session()

            net = model.build_net(
                input_shape=patch_size,
                num_classes=num_classes,
                stem_filters=stem_filters,
                max_depth=max_depth,
                layer_dict=layer_dict,
                net_list=net_list,
                cell_list=cell_list,
            )

            # random_state=initialization: each initialization also uses its own split.
            train_patients, val_patients = get_split_deterministic(
                patients,
                fold=fold,
                num_splits=num_folds,
                random_state=initialization,
            )

            train_dataset = Dataset(
                data_path=data_path,
                patients=train_patients,
                only_non_empty_slices=True,
            )

            val_dataset = Dataset(
                data_path=data_path,
                patients=val_patients,
                only_non_empty_slices=True,
            )

            train_dataloader = Dataloader(
                dataset=train_dataset,
                batch_size=batch_size,
                skip_slices=skip_slices,
                augmentation=train_augmentation,
                shuffle=True,
            )

            # Validation never skips slices and is not shuffled.
            val_dataloader = Dataloader(
                dataset=val_dataset,
                batch_size=batch_size,
                skip_slices=0,
                augmentation=val_augmentation,
                shuffle=False,
            )

            lr_callback = tf.keras.callbacks.LearningRateScheduler(
                learning_rate_fn, verbose=False
            )

            history = net.fit(
                train_dataloader,
                validation_data=val_dataloader,
                epochs=epochs,
                verbose=0,
                callbacks=[lr_callback],
            )

            # Only the final epochs are scored, to average out epoch-to-epoch noise.
            history_eval_epochs = history.history["val_gen_dice_coef"][-eval_epochs:]
            val_gen_dice_coef_list.extend(history_eval_epochs)

            run_index = initialization * num_folds + fold + 1
            run_mean = np.mean(history_eval_epochs)
            run_std = np.std(history_eval_epochs)
            print(f"{run_index}/{total_runs}: {run_mean} +- {run_std}")

    mean_dsc = np.mean(val_gen_dice_coef_list)
    std_dsc = np.std(val_gen_dice_coef_list)

    # The csv module requires file objects opened with newline="" to avoid
    # spurious blank lines / doubled line endings on some platforms.
    with open(f'{experiment_name}.csv', 'w', newline="") as f:
        csv.writer(f).writerow(val_gen_dice_coef_list)

    return mean_dsc, std_dsc

In [4]:
# Baseline: U-Net built from VGG blocks; displays (mean, std) of validation DSC.
cross_val_train(
    train_params=train_params,
    layer_dict=layer_dict,
    net_list=unet_net_list,
    experiment_name='unet_spleen_baseline',
)

2022-12-18 14:18:20.867831: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-12-18 14:18:22.583519: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 22122 MB memory:  -> device: 0, name: NVIDIA A30, pci bus id: 0000:3b:00.0, compute capability: 8.0
2022-12-18 14:18:22.585277: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 1170 MB memory:  -> device: 1, name: NVIDIA A30, pci bus id: 0000:af:00.0, compute capability: 8.0
2022-12-18 14:18:22.586960: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replic

1/25: 0.9590796172618866 +- 0.0017484044537968695
2/25: 0.9418266177177429 +- 0.004545904136302152
3/25: 0.9343091249465942 +- 0.009001829737679995
4/25: 0.9528744488954544 +- 0.000680469969283874
5/25: 0.9514697462320327 +- 0.0012796527482774158
6/25: 0.957508385181427 +- 0.0018557272996450004
7/25: 0.9591067761182785 +- 0.0017399494241993182
8/25: 0.939325925707817 +- 0.002497109703725119
9/25: 0.9445828586816788 +- 0.0020870417191923897
10/25: 0.9495691448450089 +- 0.002440497241147055
11/25: 0.9598509579896927 +- 0.002294107952322626
12/25: 0.9495642751455307 +- 0.003317877646237473
13/25: 0.9455975115299224 +- 0.0028008467795169133
14/25: 0.9511992812156678 +- 0.0017041212961312407
15/25: 0.966061481833458 +- 0.0010702267408615044
16/25: 0.9488098233938217 +- 0.0009402186526117369
17/25: 0.9495856583118438 +- 0.002323914978746634
18/25: 0.9635136544704437 +- 0.0007691265930208228
19/25: 0.9516364812850953 +- 0.002033662914306704
20/25: 0.9544284731149674 +- 0.0008218248078323849
2

(0.9517141169309616, 0.008844633555743473)

In [5]:
# Baseline: residual U-Net built from ResNet blocks; displays (mean, std) of validation DSC.
cross_val_train(
    train_params=train_params,
    layer_dict=layer_dict,
    net_list=resunet_net_list,
    experiment_name='resunet_spleen_baseline',
)

1/25: 0.9580796629190445 +- 0.002556151374757031
2/25: 0.958809107542038 +- 0.0010651437957863103
3/25: 0.9461138695478439 +- 0.0033431075564829844
4/25: 0.953932011127472 +- 0.002258761090500036
5/25: 0.9557862192392349 +- 0.0008229310022145897
6/25: 0.956477802991867 +- 0.0018993791802995688
7/25: 0.9376673251390457 +- 0.006294451194151809
8/25: 0.9333404719829559 +- 0.004581879442697995
9/25: 0.9403485596179962 +- 0.007442805387409854
10/25: 0.9511627197265625 +- 0.0016080574001112226
11/25: 0.9563082933425904 +- 0.0014407544461849204
12/25: 0.9524090498685837 +- 0.0014889968428627224
13/25: 0.9568515390157699 +- 0.001049255336899913
14/25: 0.9570066332817078 +- 0.0012430766911020975
15/25: 0.9591039776802063 +- 0.0030753178065585497
16/25: 0.9562106937170028 +- 0.0011456051887867712
17/25: 0.9560454607009887 +- 0.0008607421423792429
18/25: 0.9580247849225998 +- 0.0017193582653436002
19/25: 0.9498190313577652 +- 0.0029715462955811284
20/25: 0.9598807632923126 +- 0.000446167005311694

(0.9529121185541153, 0.007879200235252943)

In [6]:
# Candidate architecture 1 on spleen.
cross_val_train(
    train_params=train_params,
    layer_dict=layer_dict,
    net_list=experiment_1_spleen_net_list,
    experiment_name='experiment_1_spleen',
)

1/25: 0.9541149467229844 +- 0.002503541616601207
2/25: 0.9509168207645416 +- 0.0040296094526084065
3/25: 0.946830439567566 +- 0.0035015486521456566
4/25: 0.95502310693264 +- 0.0013298342000699992
5/25: 0.9577293902635574 +- 0.0009652850941584457
6/25: 0.9572775781154632 +- 0.004682834595339301
7/25: 0.9621615737676621 +- 0.0012445919428731242
8/25: 0.9467104107141495 +- 0.003050951899862184
9/25: 0.956660145521164 +- 0.003213890813749987
10/25: 0.9545305281877517 +- 0.0007363977080063361
11/25: 0.9505843311548233 +- 0.0027965248271467303
12/25: 0.9572696119546891 +- 0.0018795136675005205
13/25: 0.9454546719789505 +- 0.004517236042355438
14/25: 0.9480833083391189 +- 0.0016336932109292906
15/25: 0.9679292768239975 +- 0.0008336887345192551
16/25: 0.9487917572259903 +- 0.0006282828417036524
17/25: 0.9533411115407944 +- 0.0011391023288820753
18/25: 0.9606295198202133 +- 0.0016101682593299073
19/25: 0.9512471139431 +- 0.001468219155822834
20/25: 0.9590654999017716 +- 0.001514747507445293
21/

(0.9540767500400543, 0.006269004425904412)

In [7]:
# Candidate architecture 2 on spleen.
cross_val_train(
    train_params=train_params,
    layer_dict=layer_dict,
    net_list=experiment_2_spleen_net_list,
    experiment_name='experiment_2_spleen',
)

1/25: 0.9586336076259613 +- 0.0015296674651668514
2/25: 0.9521344661712646 +- 0.0018799378507229289
3/25: 0.9431748479604721 +- 0.0023584449777548128
4/25: 0.9533429741859436 +- 0.0012945018276879505
5/25: 0.9552649736404419 +- 0.001315313309668335
6/25: 0.9613033324480057 +- 0.0006681284281454982
7/25: 0.9616694778203965 +- 0.0007620260650738953
8/25: 0.9300987869501114 +- 0.0035955751018805883
9/25: 0.9544148027896882 +- 0.002155538950400743
10/25: 0.9611081629991531 +- 0.0011751988898796853
11/25: 0.9570348113775253 +- 0.0015405312508649839
12/25: 0.9532516121864318 +- 0.0020191445752437546
13/25: 0.9537991464138031 +- 0.0021480603303381174
14/25: 0.9511320471763611 +- 0.0013085587107267839
15/25: 0.9684049993753433 +- 0.0006929646537940563
16/25: 0.9564472198486328 +- 0.0009185129917181294
17/25: 0.956426602602005 +- 0.0009138430142291659
18/25: 0.964323765039444 +- 0.0008415489987128799
19/25: 0.9470460265874863 +- 0.003731269459559279
20/25: 0.9621238321065902 +- 0.00117169652109

(0.9554525258541107, 0.007804035073413699)

In [8]:
# Candidate architecture 3 on spleen.
cross_val_train(
    train_params=train_params,
    layer_dict=layer_dict,
    net_list=experiment_3_spleen_net_list,
    experiment_name='experiment_3_spleen',
)

1/25: 0.9582314223051072 +- 0.0010080013248845935
2/25: 0.9527800440788269 +- 0.004334426235167483
3/25: 0.9437697112560273 +- 0.0026870757314198203
4/25: 0.9530692040920258 +- 0.0016631264125441256
5/25: 0.9583071231842041 +- 0.0012056526620755138
6/25: 0.9605179369449616 +- 0.0012967398940384874
7/25: 0.9600597769021988 +- 0.001064443932078062
8/25: 0.9464783042669296 +- 0.002391589025631526
9/25: 0.9567580074071884 +- 0.0010827843304920576
10/25: 0.9581779718399048 +- 0.000774247823638988
11/25: 0.9486613184213638 +- 0.0018207323317174658
12/25: 0.9657627820968628 +- 0.0013346658435434595
13/25: 0.9553629428148269 +- 0.0011079460013464757
14/25: 0.9466056287288666 +- 0.0008927417258934188
15/25: 0.9686016619205475 +- 0.0009300531433592611
16/25: 0.9553572863340378 +- 0.0008422526772411041
17/25: 0.9550779730081558 +- 0.0010268978392425831
18/25: 0.9626703441143036 +- 0.0007131460939318709
19/25: 0.9519388765096665 +- 0.0013352564756483698
20/25: 0.9632912576198578 +- 0.0009508719086

(0.9563304506540299, 0.006759491394000532)

In [9]:
# Candidate architecture 4 on spleen.
cross_val_train(
    train_params=train_params,
    layer_dict=layer_dict,
    net_list=experiment_4_spleen_net_list,
    experiment_name='experiment_4_spleen',
)

1/25: 0.9602617084980011 +- 0.002107420213527357
2/25: 0.9533402383327484 +- 0.0030326379955598667
3/25: 0.9286989241838455 +- 0.005967437542173941
4/25: 0.9534731864929199 +- 0.0014622114444631195
5/25: 0.9534440219402314 +- 0.0006621717528118438
6/25: 0.9625387042760849 +- 0.0019248732461814757
7/25: 0.9626891642808915 +- 0.0010700523109551216
8/25: 0.9413607448339463 +- 0.002647161341669799
9/25: 0.954887393116951 +- 0.0011459818421624842
10/25: 0.956618195772171 +- 0.001591457599488605
11/25: 0.9501158893108368 +- 0.00282410553619478
12/25: 0.9634248316287994 +- 0.0015166314424155708
13/25: 0.950113770365715 +- 0.005070848900737652
14/25: 0.9611994832754135 +- 0.0014200916450315673
15/25: 0.9608319103717804 +- 0.005174864871226907
16/25: 0.9583731442689896 +- 0.0015916039670073004
17/25: 0.9577411562204361 +- 0.0006961160930244055
18/25: 0.9660636603832244 +- 0.000916739024797402
19/25: 0.9540601372718811 +- 0.0009760237267881314
20/25: 0.9593404948711395 +- 0.0014869653100241735
2

(0.9557779929637908, 0.008722250182194234)

In [10]:
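# Prostate experiments (this and the next three cells) are disabled because they need the
# prostate training configuration (num_classes=3, num_channels=2) as `train_params`, not the spleen one.
# cross_val_train(experiment_name='experiment_1_prostate', train_params=train_params, layer_dict=layer_dict, net_list=experiment_1_prostate_net_list)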

In [11]:
# cross_val_train(experiment_name='experiment_2_prostate', train_params=train_params, layer_dict=layer_dict, net_list=experiment_2_prostate_net_list)

In [12]:
# cross_val_train(experiment_name='experiment_3_prostate', train_params=train_params, layer_dict=layer_dict, net_list=experiment_3_prostate_net_list)

In [13]:
# cross_val_train(experiment_name='experiment_4_prostate', train_params=train_params, layer_dict=layer_dict, net_list=experiment_4_prostate_net_list)