In [1]:
from cnn.model import build_net

from spleen_dataset.dataloader import SpleenDataloader, SpleenDataset, get_training_augmentation, get_validation_augmentation
from spleen_dataset.config import dataset_folder
from spleen_dataset.utils import get_split_deterministic, get_list_of_patients

from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping, TensorBoard
import matplotlib.pyplot as plt
import random
import numpy as np

import tensorflow as tf
# Ask TensorFlow which GPUs it can see; when at least one exists, configure the first.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Alternative kept for reference: cap the memory with a virtual device instead.
        #tf.config.experimental.set_virtual_device_configuration(gpus[0], [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=6400)])
        # Memory growth is explicitly switched off for the first GPU only.
        tf.config.experimental.set_memory_growth(gpus[0], False)
    except RuntimeError as err:
        # Report the problem and carry on instead of aborting the cell
        # (presumably raised when the device was already initialised - see TF docs).
        print(err)

2022-09-26 18:11:47.299504: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-09-26 18:11:47.323209: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-09-26 18:11:47.323576: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero


In [2]:
# List of patients obtained from `dataset_folder` (presumably patient identifiers - confirm
# against spleen_dataset.utils). The training cell splits this list into train/validation folds.
patients = get_list_of_patients(dataset_folder)

In [3]:
# Patch size handed to both augmentation factories below; the training cell also uses it,
# with a trailing 1 appended (presumably a single channel), as the input shape for build_net.
patch_size = (128, 128)
# Batch size given to both the training and the validation SpleenDataloader.
batch_size = 32
# Number of classes passed to build_net (presumably background + spleen - TODO confirm).
num_classes = 2
# Augmentation objects for the training and validation dataloaders, both built for `patch_size`.
train_augmentation = get_training_augmentation(patch_size)
val_augmentation = get_validation_augmentation(patch_size)

# Ordered block names describing the network; every name is a key of `fn_dict`.
# All nine blocks are 3x3 VGG blocks whose filter count climbs from 64 to 1024
# and then mirrors back down to 64.
net_list = [
    f'vgg_3_{filters}'
    for filters in (64, 128, 256, 512, 1024, 512, 256, 128, 64)
]


# Catalogue of available building blocks, keyed by the names used in `net_list`.
# Every entry holds the block type, its parameters and a 'prob' value (the latter is
# not read anywhere in this notebook itself).
#
# Only the VGG family is enabled. Families that were listed but disabled, each with
# kernel 3 and filters 32/64/128/256/512 and the same 'prob' of 1/36:
#   'resnet_3_*'   -> 'ResNetBlock'
#   'xception_3_*' -> 'XceptionBlock'
#   'mbconv_3_*'   -> 'MBConvBlock'
fn_dict = {
    f'vgg_3_{width}': {
        'block': 'VGGBlock',
        'params': {'kernel': 3, 'filters': width},
        'prob': 1/36,
    }
    for width in (32, 64, 128, 256, 512, 1024)
}



In [4]:
# Settings for the repeated k-fold cross-validation run in this cell.

# Collects the validation Dice values of the last `evaluation_epochs` epochs of every fold.
val_gen_dice_coef_list = []
# Number of folds per repetition.
num_splits = 5
# Number of repetitions; the repetition index is also used as `random_state` for the split.
num_initializations = 3
# Training epochs per fold.
epochs = 50
# Only the final 20% of the epochs of each fold contribute to the reported score.
# Clamp to at least 1: with fewer than 5 epochs int(0.2 * epochs) is 0, and the slice
# `history[-0:]` would then silently select the *whole* history instead of its tail.
evaluation_epochs = max(1, int(0.2 * epochs))

def learning_rate_fn(epoch):
    """Return the learning rate for ``epoch`` using polynomial decay.

    The rate is ``1e-3 * (1 - epoch / epochs) ** 0.9``, where ``epochs`` is the
    notebook-level number of training epochs, so it starts at 1e-3 for epoch 0 and
    decreases with every epoch.

    Args:
        epoch: Epoch index passed in by the ``LearningRateScheduler`` callback
            (counted from 0 according to the Keras documentation).

    Returns:
        The learning rate as a Python ``float``.
    """
    initial_lr = 1e-3
    power = 0.9
    return float(initial_lr * (1 - (epoch / float(epochs))) ** power)


# Repeated k-fold cross-validation: `num_initializations` repetitions of `num_splits` folds.
# NOTE(review): only the patient split receives `initialization` as its random_state; nothing
# in this cell seeds TensorFlow, NumPy or `random`, so weight initialisation and augmentation
# presumably differ between runs - confirm whether that is intended.
for initialization in range(num_initializations):

    for fold in range(num_splits):
        # A fresh model is built for every fold. Reset Keras' global state first so that the
        # graphs of earlier folds do not keep accumulating in memory.
        tf.keras.backend.clear_session()

        train_patients, val_patients = get_split_deterministic(
            patients, fold=fold, num_splits=num_splits, random_state=initialization
        )

        train_dataset = SpleenDataset(train_patients, only_non_empty_slices=True)
        val_dataset = SpleenDataset(val_patients, only_non_empty_slices=True)

        train_dataloader = SpleenDataloader(train_dataset, batch_size, train_augmentation)
        val_dataloader = SpleenDataloader(val_dataset, batch_size, val_augmentation)

        model = build_net((*patch_size, 1), num_classes, fn_dict, net_list)

        # One scheduler instance per fit() call; the schedule function itself is shared.
        lr_callback = tf.keras.callbacks.LearningRateScheduler(learning_rate_fn, verbose=False)

        history = model.fit(
            train_dataloader,
            validation_data=val_dataloader,
            epochs=epochs,
            verbose=1,
            callbacks=[lr_callback],
        )

        # Validation Dice of the last `evaluation_epochs` epochs of this fold; these values
        # feed the mean/std reported after the loops.
        fold_val_dice = history.history['val_gen_dice_coef'][-evaluation_epochs:]
        print(fold_val_dice)
        val_gen_dice_coef_list.extend(fold_val_dice)

        # Plot the training and validation Dice coefficient against the epoch number.
        fig, ax = plt.subplots()
        ax.plot(history.history['gen_dice_coef'])
        ax.plot(history.history['val_gen_dice_coef'])
        ax.set_title('Model: Generalized Dice Coeficient')
        ax.set_ylabel('Dice Coef')
        ax.set_xlabel('Epoch')
        ax.legend(['Train', 'Test'], loc='upper left')
        plt.show()
        # One figure is created per fold; close it explicitly so figures cannot pile up
        # when the backend does not do so on show().
        plt.close(fig)

# Summarise every collected validation Dice value (all folds, all initializations)
# as mean and standard deviation.
collected_dice = np.asarray(val_gen_dice_coef_list)
mean_val_gen_dice_coef = collected_dice.mean()
std_val_gen_dice_coef = collected_dice.std()

print(f'Dice {mean_val_gen_dice_coef} +- {std_val_gen_dice_coef}')

2022-09-26 18:12:10.581406: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-09-26 18:12:10.581968: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-09-26 18:12:10.582132: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-09-26 18:12:10.582220: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zer

Epoch 1/50


2022-09-26 18:12:16.332977: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8400
2022-09-26 18:12:18.693106: I tensorflow/core/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory
2022-09-26 18:12:20.200230: W tensorflow/core/framework/op_kernel.cc:1733] INVALID_ARGUMENT: required broadcastable shapes
2022-09-26 18:12:20.200269: W tensorflow/core/framework/op_kernel.cc:1733] INVALID_ARGUMENT: required broadcastable shapes


InvalidArgumentError: Graph execution error:

Detected at node 'mul_3' defined at (most recent call last):
    File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/home/guilherme/git/segqnas/.venv/lib/python3.8/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/home/guilherme/git/segqnas/.venv/lib/python3.8/site-packages/traitlets/config/application.py", line 976, in launch_instance
      app.start()
    File "/home/guilherme/git/segqnas/.venv/lib/python3.8/site-packages/ipykernel/kernelapp.py", line 712, in start
      self.io_loop.start()
    File "/home/guilherme/git/segqnas/.venv/lib/python3.8/site-packages/tornado/platform/asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "/usr/lib/python3.8/asyncio/base_events.py", line 570, in run_forever
      self._run_once()
    File "/usr/lib/python3.8/asyncio/base_events.py", line 1859, in _run_once
      handle._run()
    File "/usr/lib/python3.8/asyncio/events.py", line 81, in _run
      self._context.run(self._callback, *self._args)
    File "/home/guilherme/git/segqnas/.venv/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue
      await self.process_one()
    File "/home/guilherme/git/segqnas/.venv/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 499, in process_one
      await dispatch(*args)
    File "/home/guilherme/git/segqnas/.venv/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell
      await result
    File "/home/guilherme/git/segqnas/.venv/lib/python3.8/site-packages/ipykernel/kernelbase.py", line 730, in execute_request
      reply_content = await reply_content
    File "/home/guilherme/git/segqnas/.venv/lib/python3.8/site-packages/ipykernel/ipkernel.py", line 383, in do_execute
      res = shell.run_cell(
    File "/home/guilherme/git/segqnas/.venv/lib/python3.8/site-packages/ipykernel/zmqshell.py", line 528, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/home/guilherme/git/segqnas/.venv/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2881, in run_cell
      result = self._run_cell(
    File "/home/guilherme/git/segqnas/.venv/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2936, in _run_cell
      return runner(coro)
    File "/home/guilherme/git/segqnas/.venv/lib/python3.8/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/home/guilherme/git/segqnas/.venv/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3135, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/home/guilherme/git/segqnas/.venv/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3338, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/home/guilherme/git/segqnas/.venv/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3398, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/tmp/ipykernel_420686/2702699341.py", line 40, in <cell line: 7>
      history = model.fit(
    File "/home/guilherme/git/segqnas/.venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/home/guilherme/git/segqnas/.venv/lib/python3.8/site-packages/keras/engine/training.py", line 1409, in fit
      tmp_logs = self.train_function(iterator)
    File "/home/guilherme/git/segqnas/.venv/lib/python3.8/site-packages/keras/engine/training.py", line 1051, in train_function
      return step_function(self, iterator)
    File "/home/guilherme/git/segqnas/.venv/lib/python3.8/site-packages/keras/engine/training.py", line 1040, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/home/guilherme/git/segqnas/.venv/lib/python3.8/site-packages/keras/engine/training.py", line 1030, in run_step
      outputs = model.train_step(data)
    File "/home/guilherme/git/segqnas/.venv/lib/python3.8/site-packages/keras/engine/training.py", line 894, in train_step
      return self.compute_metrics(x, y, y_pred, sample_weight)
    File "/home/guilherme/git/segqnas/.venv/lib/python3.8/site-packages/keras/engine/training.py", line 987, in compute_metrics
      self.compiled_metrics.update_state(y, y_pred, sample_weight)
    File "/home/guilherme/git/segqnas/.venv/lib/python3.8/site-packages/keras/engine/compile_utils.py", line 501, in update_state
      metric_obj.update_state(y_t, y_p, sample_weight=mask)
    File "/home/guilherme/git/segqnas/.venv/lib/python3.8/site-packages/keras/utils/metrics_utils.py", line 70, in decorated
      update_op = update_state_fn(*args, **kwargs)
    File "/home/guilherme/git/segqnas/.venv/lib/python3.8/site-packages/keras/metrics/base_metric.py", line 140, in update_state_fn
      return ag_update_state(*args, **kwargs)
    File "/home/guilherme/git/segqnas/.venv/lib/python3.8/site-packages/keras/metrics/base_metric.py", line 646, in update_state
      matches = ag_fn(y_true, y_pred, **self._fn_kwargs)
    File "/home/guilherme/git/segqnas/cnn/metric.py", line 23, in dsc
      intersect = K.sum(y_true_f * y_pred_f, axis=-1)
Node: 'mul_3'
required broadcastable shapes
	 [[{{node mul_3}}]] [Op:__inference_train_function_5039]

In [None]:
#!tensorboard --logdir='./logs'