In [1]:
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import load_model
import nibabel as nib
from skimage.measure import label

In [2]:
def neuroToRadio(vol, flip_flag):
    """Change from neurological to radiological orientation.

    The first two axes of the 3-D ``vol`` are swapped. If ``flip_flag`` is
    truthy, the (new) second axis is reversed. The (new) first axis is always
    reversed. A view of the input is returned where numpy allows it.
    """
    reoriented = np.transpose(vol, axes=(1, 0, 2))
    if flip_flag:
        reoriented = reoriented[:, ::-1]  # same as np.fliplr
    return reoriented[::-1]               # same as np.flipud

def adjust_HU_value(volume, x=240, y=-160):
    """Clamp intensities into the grey-level window [y, x], in place.

    Values above the upper level ``x`` become ``x`` (rendered white), and
    values below the lower level ``y`` become ``y`` (rendered black). The
    upper clamp is applied first. The same (mutated) array is returned.
    """
    too_bright = volume > x
    volume[too_bright] = x
    too_dark = volume < y
    volume[too_dark] = y
    return volume
    
def normalize(volume):
    """Min-max scale ``volume`` into the range [0, 1].

    A constant input (max == min) would otherwise divide by zero and yield
    NaNs. In that case an all-zero array is returned instead.

    Args:
        volume: numeric numpy array (not modified).

    Returns:
        A new floating-point array with the same shape as ``volume``.
    """
    vol_max = volume.max()
    vol_min = volume.min()
    shifted = volume - vol_min
    value_range = vol_max - vol_min
    if value_range == 0:
        # Uniform input, e.g. a slice whose pixels all share one value:
        # return zeros (as a float array) rather than 0/0 = NaN.
        return shifted * 0.0
    return shifted / value_range

def nifti_to_array(file_path):
    """Load a NIfTI volume and turn it into a stack of 3-channel slices in [0, 1].

    Steps:
      1. Clip intensities with ``adjust_HU_value`` (its default window).
      2. Reorient with ``neuroToRadio(flip_flag=False)``.
      3. Rescale the whole volume to 8-bit.
      4. For every slice along the last axis, replicate it into 3 channels
         (``cv2.COLOR_GRAY2BGR``) and min-max rescale that slice to [0, 1].

    Returns:
        float64 array of shape ``(num_slices, rows, cols, 3)``, where
        ``(rows, cols)`` is the in-plane shape *after* reorientation.
    """
    volume = nib.load(file_path).get_fdata()
    volume = adjust_HU_value(volume)
    volume = neuroToRadio(volume, flip_flag=False)
    volume = normalize(volume)
    volume = (volume * 255).astype(np.uint8)
    # Allocate from the reoriented shape: neuroToRadio swaps the first two
    # axes, so allocating from the raw (w, h) order breaks non-square slices.
    rows, cols, num_slices = volume.shape
    array = np.zeros((num_slices, rows, cols, 3))
    for i in range(num_slices):
        array[i] = cv2.cvtColor(volume[:, :, i], cv2.COLOR_GRAY2BGR)
        array[i] = normalize(array[i])
    return array

def morphology(seg_vol, kernel_size = (15, 15)):
  """Smooth a 3-D mask with 2-D closing+opening along each of its axes.

  Every 2-D slice is processed first along axis 0, then axis 1, then axis 2.
  Each slice gets a morphological closing followed by an opening, using an
  elliptical structuring element of ``kernel_size``. ``seg_vol`` is modified
  in place and also returned.
  """
  depth, rows, cols = seg_vol.shape  # also enforces a 3-D input
  ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, kernel_size)
  for axis, count in enumerate((depth, rows, cols)):
    # View of seg_vol with `axis` first; writing into it updates seg_vol.
    slices = np.moveaxis(seg_vol, axis, 0)
    for idx in range(count):
      closed = cv2.morphologyEx(slices[idx], cv2.MORPH_CLOSE, ellipse)
      slices[idx] = cv2.morphologyEx(closed, cv2.MORPH_OPEN, ellipse)
  return seg_vol
def centroid(img, lcc=False):
  """Return the intensity centroid of ``img`` as integer ``(x, y)`` = (column, row).

  Args:
    img: 2-D mask/image, or a multi-channel image (channel 1 is used).
    lcc: if True, ``img`` is first cast to uint8. When it has more than one
      4-connected foreground component, only the largest one is kept
      (as 255 on a float background of 0).

  Returns:
    ``(cX, cY)`` truncated to int. For an image with zero mass, the image
    centre is returned in the same ``(x, y)`` order.
  """
  if lcc:
    img = img.astype(np.uint8)
    nb_components, output, stats, _ = cv2.connectedComponentsWithStats(img, connectivity=4)
    sizes = stats[:, cv2.CC_STAT_AREA]
    if nb_components > 2:
      # Label 0 is the background; argmax keeps the first of equally sized components.
      max_label = 1 + int(np.argmax(sizes[1:]))
      img = np.where(output == max_label, 255.0, 0.0)

  channel = img[:, :, 1] if img.ndim > 2 else img
  M = cv2.moments(channel)

  if M["m00"] == 0:
    # Return the centre in (x, y) order, matching the regular return below
    # (the old fallback returned (rows//2, cols//2), i.e. (y, x)).
    rows, cols = img.shape[:2]
    return (cols // 2, rows // 2)

  cX = int(M["m10"] / M["m00"])
  cY = int(M["m01"] / M["m00"])
  return (cX, cY)

def to_polar(input_img, center):
  """Warp ``input_img`` to polar coordinates around ``center`` and rotate it.

  The maximum radius is half the image diagonal, so the whole image is
  covered. The warped image is rotated 90 degrees counter-clockwise before
  being returned.
  """
  half_rows = input_img.shape[0] / 2.0
  half_cols = input_img.shape[1] / 2.0
  radius = np.sqrt((half_rows ** 2.0) + (half_cols ** 2.0))
  warped = cv2.linearPolar(input_img, center, radius, cv2.WARP_FILL_OUTLIERS)
  return cv2.rotate(warped, cv2.ROTATE_90_COUNTERCLOCKWISE)

def to_cart(input_img, center):
  """Map an image produced by ``to_polar`` back to Cartesian coordinates.

  First undoes the counter-clockwise rotation applied by ``to_polar``. Then
  applies the inverse polar warp around ``center``, using the same
  half-diagonal radius.
  """
  upright = cv2.rotate(input_img, cv2.ROTATE_90_CLOCKWISE)
  half_w = upright.shape[1] / 2.0
  half_h = upright.shape[0] / 2.0
  radius = np.sqrt((half_w ** 2.0) + (half_h ** 2.0))
  flags = cv2.WARP_FILL_OUTLIERS + cv2.WARP_INVERSE_MAP
  return cv2.linearPolar(upright, center, radius, flags)

def get_CC_largerThanTh(arr, thresh=8000,dbg=False):
    """Keep only connected components of ``arr`` with more than ``thresh`` voxels.

    Args:
        arr: binary array (non-zero = foreground), labelled with
            ``skimage.measure.label`` using its default connectivity.
        thresh: minimum component size (strictly greater-than) to keep.
        dbg: if True, scatter-plot the mask before and after filtering.

    Returns:
        An int array with 1 on the kept components and 0 elsewhere.
        NOTE: if no component exceeds ``thresh``, ``arr`` is returned unchanged
        (all components kept) rather than an empty mask.
    """
    if dbg:
        dbg_CC(arr, prec=0.02)

    print('Applying Connected Component and take components with num pixels > max_pixels')
    labels = label(arr)
    num_labels = labels.max()
    print('Found ', num_labels, 'labels')
    # Size of every label in one pass (instead of one full-volume comparison per label).
    sizes = np.bincount(labels.ravel(), minlength=int(num_labels) + 1)
    for c_label in range(1, num_labels + 1):
        print(c_label, ':', sizes[c_label])
    # Report the largest component (previously max_label was never updated and always printed 0).
    max_label = int(np.argmax(sizes[1:])) + 1 if num_labels > 0 else 0
    print('Max CC label is: ', max_label)

    print('Num liver before CC: ', np.sum(arr))
    keep = sizes > thresh
    keep[0] = False  # label 0 is the background, never a component
    if keep.any():
        # Lookup table indexed by label -> 1 for kept components, 0 otherwise.
        arr = keep[labels].astype(int)
    print('Num liver After CC: ', np.sum(arr))

    if dbg:
        dbg_CC(arr, prec=0.02)
    return arr

def dbg_CC(arr, prec=0.01):
    """Debug view: 3-D scatter of a random ``prec`` fraction of the voxels equal to 1."""
    # Kept from the original; presumably needed to register the '3d' projection on older matplotlib.
    from mpl_toolkits.mplot3d import Axes3D
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    coords = np.where(arr == 1)
    total = len(coords[0])
    sample_size = int(np.round(prec * total))
    chosen = np.random.permutation(total)[:sample_size]

    ax.scatter(coords[0][chosen], coords[1][chosen], coords[2][chosen])
    ax.view_init(elev=230., azim=360)
    plt.show()
    plt.ioff()
    plt.close()

def get_crop_coordinates(img, pad_size=2):
    """Compute a (roughly) square crop window around the foreground of a 2-D mask.

    Input: binary 2-D image; pixels >= 1 are foreground.
    Output: ``(h_min, h_max, w_min, w_max)`` for slicing
    ``img[h_min:h_max, w_min:w_max]``.

    How the window is built:
      * Take the foreground bounding box, padded by ``pad_size``.
      * Widen its shorter side towards a square.
      * Shift it back inside the image where it crosses a border.
      * Clamp the bounds to ``[0, img.shape[0]]`` / ``[0, img.shape[1]]`` so
        they are always valid, non-negative slice indices.

    Raises:
        ValueError: if ``img`` has no foreground pixel.
    """
    im_h, im_w = img.shape
    liver_m, liver_n = np.nonzero(img >= 1)
    if liver_m.size == 0:
        raise ValueError('get_crop_coordinates: image has no foreground pixels')
    h_min = liver_m.min() - pad_size
    h_max = liver_m.max() + pad_size
    h = h_max - h_min + 1
    w_min = liver_n.min() - pad_size
    w_max = liver_n.max() + pad_size
    w = w_max - w_min + 1
    gap = abs(h - w)
    pad_l = int(np.ceil(gap / 2.))
    pad_r = int(np.floor(gap / 2.))

    def _widen(lo, hi, limit):
        # Grow [lo, hi] by the gap, then slide it back inside [0, limit].
        lo -= pad_l
        hi += pad_r
        if lo < 0:
            hi -= lo  # shift right by the overflow (must happen before lo is reset)
            lo = 0
        if hi > limit:
            lo -= hi - limit
            hi = limit
        return lo, hi

    if h > w:
        w_min, w_max = _widen(w_min, w_max, im_w)
    if h < w:
        h_min, h_max = _widen(h_min, h_max, im_h)

    # Padding/shifting can still leave a bound outside the image (e.g. foreground
    # touching a border, or a window wider than the image); negative starts
    # would silently produce empty slices.
    h_min, w_min = max(h_min, 0), max(w_min, 0)
    h_max, w_max = min(h_max, im_h), min(w_max, im_w)
    return h_min, h_max, w_min, w_max

def get_crop_coordinates_3D(img_arr, pad_size=1,dbg=False):
    """Compute one in-plane crop window shared by all slices of a 3-D mask.

    Input: binary 3-D array unpacked as ``(depth, height, width)``; voxels >= 1
    are foreground.
    Output: ``(h_min, h_max, w_min, w_max)`` for slicing
    ``img_arr[:, h_min:h_max, w_min:w_max]``.

    How the window is built:
      * Take the foreground bounding box over all slices, padded by ``pad_size``.
      * Widen its shorter side towards a square.
      * Shift it back inside the image where it crosses a border.
      * Clamp the bounds to ``[0, height]`` / ``[0, width]`` so they are valid,
        non-negative slice indices.

    If ``dbg`` is true, a 3-D scatter of a 2% sample of the foreground, with
    the window outline, is shown. The call then waits for a key/mouse press.

    Raises:
        ValueError: if ``img_arr`` has no foreground voxel.
    """
    im_d, im_h, im_w = img_arr.shape
    liver_z, liver_h, liver_w = np.nonzero(img_arr >= 1)
    if liver_z.size == 0:
        raise ValueError('get_crop_coordinates_3D: volume has no foreground voxels')
    # ndarray.min/max instead of the builtins: these arrays can hold millions of voxels.
    h_min = liver_h.min() - pad_size
    h_max = liver_h.max() + pad_size
    h = h_max - h_min + 1
    w_min = liver_w.min() - pad_size
    w_max = liver_w.max() + pad_size
    w = w_max - w_min + 1
    gap = abs(h - w)
    pad_l = int(np.ceil(gap / 2.))
    pad_r = int(np.floor(gap / 2.))

    def _widen(lo, hi, limit):
        # Grow [lo, hi] by the gap, then slide it back inside [0, limit].
        lo -= pad_l
        hi += pad_r
        if lo < 0:
            hi -= lo  # shift right by the overflow (must happen before lo is reset)
            lo = 0
        if hi > limit:
            lo -= hi - limit
            hi = limit
        return lo, hi

    if h > w:
        w_min, w_max = _widen(w_min, w_max, im_w)
    if h < w:
        h_min, h_max = _widen(h_min, h_max, im_h)

    # Keep every bound inside the image; a negative start would make the
    # caller's slice empty.
    h_min, w_min = max(h_min, 0), max(w_min, 0)
    h_max, w_max = min(h_max, im_h), min(w_max, im_w)

    if dbg:
        from mpl_toolkits.mplot3d import Axes3D
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        num_points = int(np.round(0.02 * len(liver_z)))
        indices = np.random.permutation(len(liver_z))[0:num_points]
        ax.scatter(liver_h[indices], liver_w[indices], liver_z[indices], s=0.8)
        ax.plot([h_min, h_max, h_max, h_min, h_min], [w_min, w_min, w_max, w_max, w_min],
                zs=int(im_d / 2), zdir='z', color='black')
        ax.view_init(elev=180., azim=360)
        plt.show()
        plt.ioff()
        plt.waitforbuttonpress()
        plt.close()

    return h_min, h_max, w_min, w_max

In [3]:
# Directory holding the trained Keras checkpoints; set LITS_MODELS_DIR to run elsewhere.
models_dir = os.environ.get('LITS_MODELS_DIR', 'D:/Study/Thesis/LiTS/models')
liver_model = load_model(os.path.join(models_dir, 'liver_weights_best.h5'))
tumor_model = load_model(os.path.join(models_dir, 'tumor-liver-crops-512x512_weights_best.h5'))
tumor_polar_model = load_model(os.path.join(models_dir, 'tumor-liver-crops-polar-512x512_weights_best.h5'))



In [10]:
# Input volumes and prediction output; set LITS_DATA_DIR / LITS_OUTPUT_DIR to run elsewhere.
data_path = os.environ.get('LITS_DATA_DIR', 'D:/Study/Thesis/LiTS')
filenames = ['test-volume-61.nii']
output_dir = os.environ.get('LITS_OUTPUT_DIR', 'D:/Study/Thesis/LiTS/submit')
# nib.save does not create missing directories.
os.makedirs(output_dir, exist_ok=True)

In [11]:
# Binarisation cut-offs and the in-plane size both tumor models expect.
LIVER_THRESHOLD = 0.5          # liver model probability -> mask
TUMOR_THRESHOLD = 0.5          # Cartesian tumor model; only gates the polar pass and gives its centre
TUMOR_POLAR_THRESHOLD = 0.2    # polar tumor model probability -> mask
LIVER_CROP_SIZE = (512, 512)   # (width, height) passed to cv2.resize

for filename in filenames:
    volume_path = os.path.join(data_path, filename)
    image_arr = nifti_to_array(file_path=volume_path)

    # Liver segmentation on the full slices.
    pred = liver_model.predict(image_arr, verbose=1, batch_size=4)
    liver_seg = (pred[:, :, :, 0] >= LIVER_THRESHOLD).astype(pred.dtype)
    (d, rows, cols) = liver_seg.shape
    liver_seg = morphology(liver_seg)

    seg_liver_CC = get_CC_largerThanTh(np.where(liver_seg > 0, 1, 0), dbg=False)
    # One in-plane crop window shared by every slice.
    (h1, h2, w1, w2) = get_crop_coordinates_3D(seg_liver_CC, dbg=False)
    seg_lesion_arr = np.zeros((d, rows, cols), dtype=np.uint8)

    crop_img_arr = image_arr[:, h1:h2, w1:w2, :]
    crop_mask_arr = liver_seg[:, h1:h2, w1:w2]
    liver_crop_w, liver_crop_h = LIVER_CROP_SIZE
    _, crop_rows, crop_cols = crop_mask_arr.shape
    for i in range(d):
        if crop_mask_arr[i].sum() == 0:
            continue  # no liver on this slice -> no tumor; seg_lesion_arr is already 0 here

        crop_img = cv2.resize(crop_img_arr[i], (liver_crop_w, liver_crop_h), interpolation=cv2.INTER_CUBIC)
        tumor_prob = tumor_model.predict(np.expand_dims(crop_img, axis=0), verbose=0)[0][:, :, 0]
        # Threshold explicitly: casting probabilities straight to uint8 truncates
        # every value below 1.0 to 0.
        pred_tumor = (tumor_prob >= TUMOR_THRESHOLD).astype(np.uint8)
        if pred_tumor.any():
            # Refine with the polar model, unwrapped around the tumor centroid.
            center = centroid(pred_tumor)
            polar_img = to_polar(crop_img, center)
            polar_prob = tumor_polar_model.predict(np.expand_dims(polar_img, axis=0), verbose=0)[0][:, :, 0]
            pred_tumor = to_cart((polar_prob > TUMOR_POLAR_THRESHOLD).astype(np.uint8), center)

        # cv2.resize takes dsize as (width, height) == (cols, rows).
        pred_tumor = cv2.resize(pred_tumor, (crop_cols, crop_rows), interpolation=cv2.INTER_NEAREST)
        pred_tumor[crop_mask_arr[i] == 0] = 0
        seg_lesion_arr[i, h1:h2, w1:w2] = pred_tumor

    # NOTE(review): only the lesion label (2) is written; liver voxels stay 0.
    # Confirm this is the intended submission format.
    seg_lesion_arr[seg_lesion_arr == 1] = 2
    origin_volume = nib.load(volume_path)
    seg_tumor = np.zeros(origin_volume.shape, dtype=np.uint8)

    # Undo neuroToRadio(flip_flag=False): transpose(flipud(s)) == fliplr(transpose(s)).
    for i in range(d):
        seg_tumor[:, :, i] = np.fliplr(np.transpose(seg_lesion_arr[i]))
    seg_tumor = nib.Nifti1Image(seg_tumor, origin_volume.affine, origin_volume.header)
    nib.save(seg_tumor, os.path.join(output_dir, filename.replace('volume', 'segmentation')))
    print(f'Save predicted {filename}')

Applying Connected Component and take components with num pixels > max_pixels
Found  1 labels
1 : 4514450
Max CC label is:  0
Num liver before CC:  4514450
Num liver After CC:  4514450
Save predicted test-volume-61.nii
