In [None]:
import os

# Root folder of the project data. Set the BELEG_BASE_PATH environment variable
# to use another location instead of editing this hard-coded Windows path.
base_path = os.environ.get("BELEG_BASE_PATH", "D:/BELEG/")
# All paths below (and in later cells) are built by plain string concatenation,
# so the base path must end with a separator.
if not base_path.endswith(("/", "\\")):
    base_path += "/"
# each split folder contains *.tif, *_MRI.nii, *_LABEL.nii, *_GT.png files
path_training = base_path + "Dataset/training-cropped-masked/"
path_testing = base_path + "Dataset/testing-cropped-masked/"

In [None]:
import os
import tensorflow as tf

from tensorflow.python.framework import constant_op, dtypes, ops
from tensorflow.python.ops import array_ops, control_flow_ops, math_ops, nn, nn_ops
import os

import imageio
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
from tensorflow.keras import backend as keras

import glob
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from matplotlib import pyplot
import ipywidgets
from ipywidgets import interact

# Training hyper-parameters (read by the loss functions and the training cells)
batch_size, epoch = 1, 30   # samples per step, number of training epochs
lr = 0.002                  # learning rate
lamda = 1                   # weight applied to the SSIM term of the losses (name referenced as-is)

In [None]:
@tf.function
def save_png(tensor, name):
    """Encode `tensor` as a PNG, write it to 'pred/<name>', and return `tensor` unchanged.

    Args:
        tensor: image to save; tf.image.encode_png expects a rank-3
            [height, width, channels] image.
        name: file name relative to the 'pred/' directory.

    NOTE(review): tf.cast truncates, so float values in [0, 1] become 0 or 1.
    This assumes `tensor` already holds 0-255 values; otherwise use
    tf.image.convert_image_dtype as save_epoch_prediction does. Confirm.
    """
    # Cast the argument itself; the previous code read an undefined local `img`
    # here, which raised UnboundLocalError as soon as the function was traced.
    img = tf.cast(tensor, tf.uint8)
    img = tf.image.encode_png(img)
    tf.io.write_file('pred/' + name, img)
    return tensor

def show_volume(volume):
  """Print summary statistics of `volume`, then display an interactive slice viewer.

  An integer slider `x` selects an index along axis 2 of `volume`; the chosen
  2-D slice is drawn with matplotlib's 'jet' colormap.
  """
  show_information(volume)
  # Bring axis 2 to the front so that slices[k] is the k-th slice along that axis.
  slices = tf.transpose(volume, (2, 0, 1))

  def draw_slice(x):
    plt.imshow(slices[x], cmap='jet')

  last_index = tf.shape(volume)[2] - 1
  slider = ipywidgets.IntSlider(min=0, max=last_index, step=1, value=0)
  interact(draw_slice, x=slider)

def show_information(x):
  """Print the dynamic shape, static shape, mean, minimum and maximum of `x` via tf.print."""
  summaries = (
      ("Tensor Shape:", tf.shape),
      ("Shape:", lambda t: t.shape),
      ("Mean:", tf.reduce_mean),
      ("Min:", tf.reduce_min),
      ("Max:", tf.reduce_max),
  )
  # Each statistic is computed right before it is printed, in the order listed.
  for label, summarize in summaries:
    tf.print(label, summarize(x))

@tf.function
def normalize_dataset(tensor):
    """Min-max scale `tensor` into [0, 1] using its own global minimum and maximum.

    tf.divide performs true division, so integer inputs are promoted to a
    floating dtype (e.g. int16 -> float32), and float inputs keep their dtype.

    A constant tensor (max == min) would otherwise produce 0/0 = NaN for every
    element. It is mapped to all zeros instead, so one empty or blank volume
    cannot poison training with NaNs.
    """
    lowest = tf.reduce_min(tensor)
    value_range = tf.subtract(tf.reduce_max(tensor), lowest)
    scaled = tf.divide(tf.subtract(tensor, lowest), value_range)
    # The scalar condition broadcasts over `scaled`; the NaN branch is discarded.
    return tf.where(value_range > 0, scaled, tf.zeros_like(scaled))
  
@tf.function
def to_float_dataset(tensor):
    """Return `tensor` cast to float32; values keep their magnitude (no rescaling)."""
    return tf.cast(x=tensor, dtype=tf.float32)

@tf.function
def zero_to_one_dataset(tensor):
    """Convert `tensor` to float32 with tf.image.convert_image_dtype.

    Integer inputs are rescaled by the maximum of their dtype
    (for int16, 0..32767 maps to 0..1); float inputs are only cast.
    """
    return tf.image.convert_image_dtype(image=tensor, dtype=tf.float32)

# Loss functions. Two supervision options exist in this notebook:
#   * unsupervised: SSIM between the prediction and each slice of the input
#     volume (loss_ssim_unsupervised; y_true is the input volume itself)
#   * supervised: SSIM between the prediction and an NSST ground truth
#     (loss_ssim_supervised_nsst; y_true is the NSST image)
@tf.function
def loss_ssim_unsupervised(y_true, y_pred):
    """Sum over slices of lamda * (1 - SSIM(slice, prediction)).

    The single predicted 2-D image is compared against every slice taken along
    the last axis of y_true. Minimising the loss makes the prediction
    structurally similar to all slices at once.

    Args:
        y_true: input volume batch, indexed by the code as [batch, H, W, slices]
            (the transpose below moves the last axis to the front).
        y_pred: network output; axis 0 is squeezed away, so it must have size 1.
            In practice this loss only works with batch_size == 1.

    Returns:
        Scalar float32 tensor: the summed per-slice SSIM losses.
    """
    # NOTE(review): the original comments described 240x240x155 volumes, but the
    # model in this notebook is built with Input(shape=(128,128,128)); confirm sizes.
    # y_true: [batch, H, W, slices] -> [slices, H, W, batch] so map_fn iterates
    # over slices; the batch axis becomes each slice's trailing channel axis.
    y_true = tf.transpose(y_true, perm=[3,1,2,0])
    # y_pred: [1, H, W, C] -> [H, W, C] (fails if the batch axis is not size 1).
    y_pred = tf.squeeze(y_pred, 0)

    def ssim_l1_loss(elems):
        # Loss for one slice. Despite the name, there is no L1 term here.
        # max_val = 1 assumes intensities scaled to [0, 1] (see normalize_dataset).
        ssim_layer_loss = lamda * (1 - tf.image.ssim(elems, y_pred, max_val = 1))
        #tf.print(ssim_layer_loss)
        return ssim_layer_loss

    # One loss value per slice.
    loss = tf.map_fn(ssim_l1_loss, y_true, dtype=tf.float32)
    # Total loss = sum of the per-slice SSIM losses.
    return tf.reduce_sum(loss)
    
@tf.function
def loss_binary_crossentropy_unsupervised(y_true, y_pred):
    """Sum over slices of lamda * binary cross-entropy(slice, prediction).

    This is the cross-entropy counterpart of loss_ssim_unsupervised. The single
    predicted image is compared against every slice along the last axis of the
    input volume y_true. Axis 0 of y_pred is squeezed away, so this only works
    with batch_size == 1.

    Returns:
        Scalar tensor: the element-wise cross-entropies summed over all slices
        and pixels.
    """
    # [batch, H, W, slices] -> [slices, H, W, batch] so map_fn iterates over slices.
    y_true = tf.transpose(y_true, perm=[3,1,2,0])
    y_pred = tf.squeeze(y_pred, 0)

    def binary_crossentropy_l1_loss(elems):
        # Cross-entropy is already a loss (lower is better). The former
        # `1 - crossentropy` inverted it, so minimising it maximised the error.
        return lamda * tf.keras.backend.binary_crossentropy(elems, y_pred)

    loss = tf.map_fn(binary_crossentropy_l1_loss, y_true, dtype=tf.float32)
    return tf.reduce_sum(loss)

# Supervised alternative: compare the prediction to an NSST ground truth.
# Here y_true is that NSST ground truth, not the input volume.
@tf.function
def loss_ssim_supervised_nsst(y_true, y_pred):
    """Weighted SSIM + L1 loss of the prediction against the NSST ground truth.

    loss = lamda * (1 - SSIM(y_pred, y_true)) + (1 - lamda) * mean(|y_pred - y_true|)

    With lamda == 1 (the value set in the parameters cell), the L1 term is zero.
    SSIM uses max_val = 1.0, i.e. intensities are expected in [0, 1].
    """
    structural_term = 1 - tf.image.ssim(y_pred, y_true, max_val = 1.0)
    absolute_error = tf.reduce_mean(tf.abs(y_pred - y_true))
    return lamda * structural_term + (1 - lamda) * absolute_error

In [None]:
print("PREPARING DATASET")
# Loaders for the 3-D MRI volumes (*.nii) and the 2-D *_GT.png images.
# nibabel's img.get_data() is deprecated (it raises in nibabel >= 5);
# np.asanyarray(img.dataobj) is its documented replacement and keeps the
# on-disk dtype (scaling applied), just like get_data() did.
get_volume = lambda file: np.asanyarray(nib.load(file).dataobj)
get_image = lambda file: pyplot.imread(file)


def _load_split(path, split_name):
    """Load every *.nii volume in `path`, in sorted filename order, into one numpy array.

    Raises:
        FileNotFoundError: if the folder contains no *.nii files.
    """
    # glob order depends on the filesystem; sorting keeps volume indices
    # (and the i.png prediction files) reproducible across machines.
    files = sorted(glob.glob(path + "*.nii"))
    if not files:
        raise FileNotFoundError(f"No *.nii files found for the {split_name} dataset in {path!r}")
    return np.array([get_volume(file) for file in files])


# NOTE(review): the original note described 100 int16 volumes of 240x240x155 voxels,
# but the model expects (128, 128, 128) inputs; confirm the cropped size.
train_volumes_mri = _load_split(path_training, "training")
print("Found training dataset")
# Per-volume min-max normalization to [0, 1]; map_fn collects the results as float32.
train_volumes_mri = tf.map_fn(normalize_dataset, train_volumes_mri, dtype=tf.float32)
print("Normalized training dataset")

test_volumes_mri = _load_split(path_testing, "testing")
print("Found testing dataset")
test_volumes_mri = tf.map_fn(normalize_dataset, test_volumes_mri, dtype=tf.float32)
print("Normalized testing dataset")

In [None]:
# Report whether TensorFlow sees a GPU (replaces the deprecated tf.test.is_gpu_available()).
print(len(tf.config.list_physical_devices('GPU')) > 0)
# Summarize one training volume: index 10 when it exists, otherwise the last volume.
sample_index = min(10, int(train_volumes_mri.shape[0]) - 1)
show_information(train_volumes_mri[sample_index])

In [None]:
# U-Net encoder/decoder built from 2-D convolutions. With channels_last, the last
# axis of each (128, 128, 128) input volume is treated as 128 input channels.
UNET_INPUT_SHAPE = (128, 128, 128)


def _unet_conv(filters, kernel_size):
    """Create one ReLU Conv2D layer with the settings shared by every U-Net convolution."""
    return Conv2D(filters, kernel_size, activation = 'relu', padding = 'same',
                  kernel_initializer = 'he_normal', data_format = 'channels_last')


def _unet_double_conv(x, filters):
    """Apply two 3x3 convolutions in sequence (layers are created in application order)."""
    x = _unet_conv(filters, 3)(x)
    return _unet_conv(filters, 3)(x)


def _unet_up(x, filters):
    """Upsample by 2, then apply a 2x2 convolution (the decoder's up-convolution)."""
    up_conv = _unet_conv(filters, 2)  # created before the UpSampling2D layer, as in the original
    return up_conv(UpSampling2D(size = (2,2))(x))


# `input` shadows the builtin, but the name is kept because the model cell refers to it.
input = Input(shape=UNET_INPUT_SHAPE)

# Contracting path: 64 -> 128 -> 256 -> 512 filters; each pooling halves H and W.
conv1 = _unet_double_conv(input, 64)
pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
conv2 = _unet_double_conv(pool1, 128)
pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
conv3 = _unet_double_conv(pool2, 256)
pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
conv4 = _unet_double_conv(pool3, 512)
pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

# Bottleneck.
conv5 = _unet_double_conv(pool4, 1024)

# Expanding path: upsample, concatenate the matching encoder features (skip connection), convolve.
up6 = _unet_up(conv5, 512)
merge6 = concatenate([conv4,up6], axis = 3)
conv6 = _unet_double_conv(merge6, 512)

up7 = _unet_up(conv6, 256)
merge7 = concatenate([conv3,up7], axis = 3)
conv7 = _unet_double_conv(merge7, 256)

up8 = _unet_up(conv7, 128)
merge8 = concatenate([conv2,up8], axis = 3)
conv8 = _unet_double_conv(merge8, 128)

up9 = _unet_up(conv8, 64)
merge9 = concatenate([conv1,up9], axis = 3)
conv9 = _unet_double_conv(merge9, 64)

# Single-channel output image.
# NOTE(review): ReLU leaves the output unbounded above, while the SSIM losses use
# max_val=1; consider a sigmoid if outputs must stay in [0, 1].
network = _unet_conv(1, 3)(conv9)

In [None]:
print("DEFINING MODEL")
model = tf.keras.Model(inputs = input, outputs = network)
# Learning rate comes from the parameters cell instead of a repeated literal.
# No 'accuracy' metric: the output is a continuous image, so Keras' accuracy
# (thresholded/exact-match agreement) is meaningless for this model.
model.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = lr), loss = loss_ssim_unsupervised)
# Alternative unsupervised loss:
# model.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = lr), loss = loss_binary_crossentropy_unsupervised)

checkpoint_path = base_path + "tfgan/checkpoints-unet/checkpoint-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
# Weights-only checkpoints; the file name is templated with the epoch number.
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, save_weights_only=True, verbose=1)
# Resume from the most recent checkpoint in the same directory, if one exists.
latest = tf.train.latest_checkpoint(checkpoint_dir)
if latest is None:
    print("No checkpoints")
else:
    print("Loading latest checkpoint")
    model.load_weights(latest)

tb_callback = tf.keras.callbacks.TensorBoard(log_dir=base_path + 'logs')

In [None]:
class save_epoch_prediction(tf.keras.callbacks.Callback):
    """Keras callback that writes PNG previews of the test-set predictions after each epoch.

    One file per test volume is written to
    <base_path>tfgan/checkpoints-unet/pred/<i>.png; the files are overwritten every epoch.
    """

    def on_epoch_end(self, epoch, logs=None):
        predictions = self.model.predict(test_volumes_mri, batch_size = 1, verbose = 1)
        for i, pred in enumerate(predictions):
            # saturate=True clamps before the uint8 cast. The ReLU output can
            # exceed 1.0, and an unclamped float->uint8 cast of values above 255
            # yields garbage instead of saturating at 255.
            img = tf.image.convert_image_dtype(pred, tf.uint8, saturate=True)
            img = tf.image.encode_png(img)
            name = base_path + 'tfgan/checkpoints-unet/pred/' + str(i) + '.png'
            tf.io.write_file(name, img)

In [None]:
print("STARTING TRAINING PHASE")
# Unsupervised training: the input volumes are also the targets (y_true) of the loss.
# Besides checkpoints and PNG previews, log to TensorBoard via tb_callback from the
# model cell (it was constructed there but never passed to fit).
history_fit = model.fit(train_volumes_mri, train_volumes_mri, epochs = epoch, batch_size = batch_size,
                        callbacks = [cp_callback, tb_callback, save_epoch_prediction()])

In [None]:
print("STARTING TESTING PHASE")
show_information(test_volumes_mri)
# Same input-as-target setup as training.
history_eval = model.evaluate(test_volumes_mri, test_volumes_mri, batch_size = batch_size)

# Predictions for every test volume, using the shared batch_size like evaluate above.
predictions = model.predict(test_volumes_mri, batch_size = batch_size, verbose = 1)
print("max pixel value:", tf.reduce_max(predictions))

In [None]:
# Compare the value ranges of the predictions with both normalized input splits.
for heading, values in (("predictions", predictions),
                        ("\ntrain-dataset", train_volumes_mri),
                        ("\ntest-dataset", test_volumes_mri)):
    tf.print(heading)
    show_information(values)