In [None]:
!pip install SimpleITK
from keras import backend as K
import numpy as np
import SimpleITK as sitk
from keras.models import load_model
import matplotlib.pyplot as plt
import tensorflow as tf



class KerasParas:
    """Mutable bag of settings that steer Keras model inference."""

    def __init__(self):
        # tensor layout the network expects ('channels_first' or 'channels_last')
        self.img_format = 'channels_first'
        # path of the saved model file; None until configured
        self.model_path = None
        # name of the training loss; not read by the inference code shown here
        self.loss = None
        # which output to use when predict() returns a list of outputs
        self.outID = 0
        # threshold applied to network output to obtain 0/1 labels
        self.thd = 0.5


class PreParas:
    """Mutable bag of patch-extraction settings for sliding-window inference."""

    def __init__(self):
        # number of label classes; placeholder until configured
        self.n_class = ''
        # step between successive windows along each axis
        self.patch_strides = []
        # size of the window fed to the network along each axis
        self.patch_dims = []
        # size of the region each prediction is written back into
        self.patch_label_dims = []


def dice_coef(y_true, y_pred):
    """Smoothed Dice coefficient between two tensors, taken over all elements.

    ``K.epsilon()`` is added to both numerator and denominator so that two
    all-zero inputs score 1 instead of dividing by zero.
    """
    truth = K.flatten(y_true)
    guess = K.flatten(y_pred)
    eps = K.epsilon()
    overlap = K.sum(truth * guess)
    denominator = K.sum(truth) + K.sum(guess) + eps
    return (2.0 * overlap + eps) / denominator


def dice_coef_loss(y_true, y_pred):
    """Training loss: the negated Dice coefficient, so lower means better overlap."""
    score = dice_coef(y_true, y_pred)
    return -score


def dice_coef_np(y_true, y_pred):
    """Dice coefficient of two NumPy arrays (typically 0/1 masks).

    Both inputs are flattened, so only their element counts must agree.

    :param y_true: reference array
    :param y_pred: predicted array
    :return: 2 * sum(y_true * y_pred) / (sum(y_true) + sum(y_pred)). When the
        denominator is zero (e.g. two empty masks) 1.0 is returned instead of
        the NaN a 0/0 division would produce, matching the limit of the
        epsilon-smoothed Keras ``dice_coef``.
    """
    y_true_f = np.asarray(y_true).ravel()
    y_pred_f = np.asarray(y_pred).ravel()
    denominator = np.sum(y_true_f) + np.sum(y_pred_f)
    if denominator == 0:
        # Two empty masks agree perfectly.
        return 1.0
    intersection = np.sum(y_true_f * y_pred_f)
    return 2.0 * intersection / denominator


def resample_img(imgobj, new_spacing, interpolator, new_size=None):
    """Resample a SimpleITK image onto a grid with the requested voxel spacing.

    Direction and origin are copied from ``imgobj``; only spacing and size change.
    NOTE: an identical ``resample_img`` is defined again later in this file and
    shadows this one at runtime.

    :param imgobj: SimpleITK image to resample
    :param new_spacing: output voxel spacing, one value per image dimension
    :param interpolator: SimpleITK interpolator constant, e.g. ``sitk.sitkLinear``
    :param new_size: output size in voxels; when omitted it is chosen so the
        output spans the same physical extent, rounded up to whole voxels
    :return: the resampled SimpleITK image
    """
    resample = sitk.ResampleImageFilter()
    resample.SetInterpolator(interpolator)
    resample.SetOutputDirection(imgobj.GetDirection())
    resample.SetOutputOrigin(imgobj.GetOrigin())
    resample.SetOutputSpacing(new_spacing)

    if new_size is None:
        # np.int was removed in NumPy 1.24; it was only an alias of builtin int.
        orig_size = np.array(imgobj.GetSize(), dtype=int)
        orig_spacing = np.array(imgobj.GetSpacing())
        scaled = orig_size * (orig_spacing / np.asarray(new_spacing, dtype=float))
        # Image dimensions must be whole voxels, and SetSize wants plain Python ints.
        new_size = [int(s) for s in np.ceil(scaled)]

    resample.SetSize(new_size)
    return resample.Execute(imgobj)


def dim_2_categorical(label, num_class):
    """One-hot encode an integer label map along a new leading class axis.

    :param label: 2-D or 3-D ndarray of class indices
    :param num_class: number of classes; output channel ``i`` is 1 where ``label == i``
    :return: float64 array of shape ``(num_class,) + label.shape`` holding 0/1 values
    :raises ValueError: if ``label`` is neither 2-D nor 3-D
    """
    if label.ndim not in (2, 3):
        raise ValueError(
            'dim_2_categorical expects a 2-D or 3-D label array, got %d-D' % label.ndim)
    # Shape the class indices as (num_class, 1, ..., 1) so one comparison
    # broadcasts every class against the whole map at once.
    classes = np.arange(num_class).reshape((num_class,) + (1,) * label.ndim)
    return (classes == label[np.newaxis]).astype(np.float64)


def _predict_patch(patch, seg_net, keras_paras):
    """Run ``seg_net`` on one patch; return (squeezed raw output, thresholded 0/1 copy).

    A batch axis of size 1 is prepended. For ``channels_last`` models the first
    patch axis is moved to the end before prediction.
    """
    batch = patch.reshape((1,) + patch.shape)
    if keras_paras.img_format == 'channels_last':
        batch = np.transpose(batch, (0, 2, 3, 1))

    output = seg_net.predict(batch, batch_size=1, verbose=0)
    # Multi-output models return a list; keep only the configured output.
    if isinstance(output, list):
        output = output[keras_paras.outID]
    output = np.squeeze(output)

    labels = output.copy()
    labels[labels >= keras_paras.thd] = 1
    labels[labels < keras_paras.thd] = 0
    return output, labels


def out_LabelHot_map_2D(img, seg_net, pre_paras, keras_paras):
    """Sliding-window segmentation of a 3-D volume.

    Windows of size ``pre_paras.patch_dims`` are swept over ``img`` twice. The
    first sweep is anchored at the low corner and steps forward. The second is
    anchored at the high corner and steps backward, so the far edges are still
    covered when the strides do not evenly divide the image. Each window's
    prediction is written into the window's middle slice over a
    ``patch_label_dims[1] x patch_label_dims[2]`` region.

    :param img: 3-D array unpacked as (length, col, row)
    :param seg_net: model exposing ``predict(batch, batch_size=..., verbose=...)``
    :param pre_paras: PreParas providing patch_dims, patch_label_dims,
        patch_strides and n_class
    :param keras_paras: KerasParas providing img_format, outID and thd
    :return: ``(label_map, likelihood_map, counter_map)``
        - label_map: uint8 per-voxel argmax of accumulated one-hot votes; voxels
          never covered (and ties) resolve to class 0.
        - likelihood_map: accumulated raw output divided by the coverage count.
        - counter_map: number of covering windows, clamped below at 1e-9.
    """
    patch_dims = pre_paras.patch_dims
    label_dims = pre_paras.patch_label_dims
    strides = pre_paras.patch_strides
    n_class = pre_paras.n_class

    length, col, row = img.shape
    # Vote counts per class. int32 rather than uint8: with fine strides a voxel
    # can be covered by more than 255 windows, which would wrap a uint8 counter.
    categorical_map = np.zeros((n_class, length, col, row), dtype=np.int32)
    likelihood_map = np.zeros((length, col, row), dtype=np.float32)
    counter_map = np.zeros((length, col, row), dtype=np.float32)
    length_step = int(patch_dims[0] / 2)

    def accumulate(middle, rows, cols, output, labels):
        # Add one window's one-hot vote, raw output and coverage to the running maps.
        onehot = dim_2_categorical(labels, n_class)
        categorical_map[:, middle, rows, cols] += onehot.astype(categorical_map.dtype)
        likelihood_map[middle, rows, cols] += output
        counter_map[middle, rows, cols] += 1

    # Sweep 1: windows anchored at the low corner, stepping forward.
    for i in range(0, length - patch_dims[0] + 1, strides[0]):
        for j in range(0, col - patch_dims[1] + 1, strides[1]):
            for k in range(0, row - patch_dims[2] + 1, strides[2]):
                patch = img[i:i + patch_dims[0],
                            j:j + patch_dims[1],
                            k:k + patch_dims[2]]
                output, labels = _predict_patch(patch, seg_net, keras_paras)
                accumulate(i + length_step,
                           slice(j, j + label_dims[1]),
                           slice(k, k + label_dims[2]),
                           output, labels)

    # Sweep 2: windows anchored at the high corner, stepping backward.
    for i in range(length, patch_dims[0] - 1, -strides[0]):
        for j in range(col, patch_dims[1] - 1, -strides[1]):
            for k in range(row, patch_dims[2] - 1, -strides[2]):
                patch = img[i - patch_dims[0]:i,
                            j - patch_dims[1]:j,
                            k - patch_dims[2]:k]
                output, labels = _predict_patch(patch, seg_net, keras_paras)
                accumulate(i - patch_dims[0] + length_step,
                           slice(j - label_dims[1], j),
                           slice(k - label_dims[2], k),
                           output, labels)

    # Majority vote per voxel across all slices at once.
    label_map = categorical_map.argmax(axis=0).astype(np.uint8)

    # Avoid division by zero for voxels no window touched.
    counter_map = np.maximum(counter_map, 10e-10)
    likelihood_map = np.divide(likelihood_map, counter_map)

    return label_map, likelihood_map, counter_map


def resample_img(imgobj, new_spacing, interpolator, new_size=None):
    """Resample a SimpleITK image onto a grid with the requested voxel spacing.

    Direction and origin are copied from ``imgobj``; only spacing and size change.
    NOTE: this redefines an identical ``resample_img`` declared earlier in the
    file; this later definition is the one in effect.

    :param imgobj: SimpleITK image to resample
    :param new_spacing: output voxel spacing, one value per image dimension
    :param interpolator: SimpleITK interpolator constant, e.g. ``sitk.sitkLinear``
    :param new_size: output size in voxels; when omitted it is chosen so the
        output spans the same physical extent, rounded up to whole voxels
    :return: the resampled SimpleITK image
    """
    resample = sitk.ResampleImageFilter()
    resample.SetInterpolator(interpolator)
    resample.SetOutputDirection(imgobj.GetDirection())
    resample.SetOutputOrigin(imgobj.GetOrigin())
    resample.SetOutputSpacing(new_spacing)

    if new_size is None:
        # np.int was removed in NumPy 1.24; it was only an alias of builtin int.
        orig_size = np.array(imgobj.GetSize(), dtype=int)
        orig_spacing = np.array(imgobj.GetSpacing())
        scaled = orig_size * (orig_spacing / np.asarray(new_spacing, dtype=float))
        # Image dimensions must be whole voxels, and SetSize wants plain Python ints.
        new_size = [int(s) for s in np.ceil(scaled)]

    resample.SetSize(new_size)
    return resample.Execute(imgobj)


def min_max_normalization(img):
    """Linearly rescale an array into [0, 1] as float32.

    :param img: array-like of numeric intensities; it is not modified
    :return: new float32 array where the minimum maps to 0 and the maximum to 1.
        A constant input (max == min) yields all zeros rather than the NaNs a
        0/0 division would produce.
    """
    new_img = np.asarray(img, dtype=np.float32)
    min_val = np.min(new_img)
    value_range = np.max(new_img) - min_val
    if value_range == 0:
        return np.zeros_like(new_img)
    # Arithmetic produces a fresh array, so the caller's input is never mutated.
    return (new_img - min_val) / value_range


def rescale_voxels(input_path, factor=10):
    """Read a NIfTI file and divide its voxel size by ``factor`` to undo an upscale.

    Both the ``pixdim[1]``..``pixdim[3]`` header entries and the image spacing are
    divided by ``factor``. The default of 10 corrects the previous 10x voxel
    upscale. Old and new header values are printed.

    :param input_path: path of the .nii / .nii.gz file
    :param factor: divisor applied to every voxel dimension
    :return: SimpleITK image object with the corrected spacing
    :raises ValueError: if any of the pixdim header entries is absent or empty
    """
    imgobj = sitk.ReadImage(input_path)
    for key in ('pixdim[1]', 'pixdim[2]', 'pixdim[3]'):
        # GetMetaData raises for a missing key rather than returning '', so
        # check presence first to reach the informative error below.
        original_value = imgobj.GetMetaData(key) if imgobj.HasMetaDataKey(key) else ''
        if original_value == '':
            raise ValueError('Voxel parameter not set for file: ' + str(input_path))
        print('Old voxel dimension: ' + original_value)
        imgobj.SetMetaData(key, str(round(float(original_value) / factor, 5)))
        print('New voxel dimension: ' + imgobj.GetMetaData(key))
    imgobj.SetSpacing([param / factor for param in imgobj.GetSpacing()])
    return imgobj


def preprocess(input):
    """Resample and normalise an image as described in the Hsu et al. paper.

    The image is resampled with linear interpolation to a spacing of
    0.1 x 0.1 x 1, and its intensities are then min-max scaled to [0, 1].

    :param input: SimpleITK image object, or a path string for ``sitk.ReadImage``
    :return: tuple of (normalised image array, resampled SimpleITK image object)
    :raises TypeError: if ``input`` is neither a SimpleITK image nor a string
    """
    # ``input`` shadows the builtin, but the parameter name is kept for keyword callers.
    if isinstance(input, sitk.Image):
        imgobj = input
    elif isinstance(input, str):
        imgobj = sitk.ReadImage(input)
    else:
        raise TypeError('Input is not defined correctly!')
    # Re-sample in-plane to 0.1 x 0.1 with a slice spacing of 1.
    resampled_imgobj = resample_img(imgobj, new_spacing=[0.1, 0.1, 1], interpolator=sitk.sitkLinear)
    print('Image resampled!')
    img_array = sitk.GetArrayFromImage(resampled_imgobj)
    img = min_max_normalization(img_array)
    return img, resampled_imgobj

In [None]:
# Collect every .nii.gz one directory below nii-files. Paths containing
# 'whole_brain' are treated as masks; the remaining paths, minus any that
# contain 'bias2', are the images to segment.
files = tf.io.gfile.glob('/content/drive/MyDrive/mri-dataset/nii-files/*/*.nii.gz')

images = [p for p in files if 'whole_brain' not in p and 'bias2' not in p]
masks = [p for p in files if 'whole_brain' in p]

In [None]:
# Patch extraction: single-slice 128 x 128 windows, stepped 32 voxels in-plane,
# with two label classes.
pre_paras = PreParas()
pre_paras.n_class = 2
pre_paras.patch_strides = [1, 32, 32]
pre_paras.patch_dims = [1, 128, 128]
pre_paras.patch_label_dims = [1, 128, 128]

# Inference settings for the saved 2-D U-Net.
keras_paras = KerasParas()
keras_paras.model_path = '/content/drive/MyDrive/mri-dataset/rat_brain-2d_unet.hdf5'
keras_paras.img_format = 'channels_last'
keras_paras.loss = 'dice_coef_loss'
keras_paras.thd = 0.5
keras_paras.outID = 0

# custom_objects maps the custom Dice loss/metric names to their Python
# implementations so Keras can resolve them while loading the model.
seg_net = load_model(
    keras_paras.model_path,
    custom_objects={'dice_coef_loss': dice_coef_loss, 'dice_coef': dice_coef},
)


# Directory the predicted label / likelihood volumes are written to.
OUTPUT_DIR = '/content/drive/MyDrive/mri-dataset/nii-files/'
# Spacing used when resampling predictions back to the source image size.
# NOTE(review): hard-coded; presumably matches the geometry of these particular
# scans relative to the 0.1 working spacing — confirm before reusing on new data.
RESTORE_SPACING = (1.3671875, 1.3671875, 1.0)

for input_path, mask_input_path in zip(images, masks):
    print(input_path)
    # Read unmodified only to recover the original voxel grid size.
    imgobj = sitk.ReadImage(input_path)
    img_rescaled = rescale_voxels(input_path)
    normed_array, resampled_imgobj = preprocess(img_rescaled)

    out_label_map, out_likelihood_map, counter_map = out_LabelHot_map_2D(
        normed_array, seg_net, pre_paras, keras_paras)

    out_label_img = sitk.GetImageFromArray(out_label_map.astype(np.uint8))
    # np.float was removed in NumPy 1.24; it was an alias of 64-bit float.
    out_likelihood_img = sitk.GetImageFromArray(out_likelihood_map.astype(np.float64))

    resampled_label_map = resample_img(out_label_img, new_spacing=RESTORE_SPACING,
                                       new_size=imgobj.GetSize(),
                                       interpolator=sitk.sitkNearestNeighbor)
    resampled_likelihood_img = resample_img(out_likelihood_img, new_spacing=RESTORE_SPACING,
                                            new_size=imgobj.GetSize(),
                                            interpolator=sitk.sitkNearestNeighbor)

    input_path_split = input_path.split('/')
    file_name = input_path_split[-1]
    label_path = OUTPUT_DIR + file_name + '_label.nii'
    likelihood_path = OUTPUT_DIR + file_name + '_likelihood.nii'

    sitk.WriteImage(resampled_label_map, label_path)
    sitk.WriteImage(resampled_likelihood_img, likelihood_path)

    print(label_path + ' is saved!')