In [15]:
import os
import random

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import PIL
import scipy
import scipy.ndimage
import sklearn
import sklearn.feature_extraction
import tifffile
from keras.models import Model, load_model
from PIL import Image
from skimage import data
from skimage.transform import rescale
from skimage.transform import rotate
from tifffile import imsave, imread

In [16]:
# Load the three trained U-Nets, one per slicing view (axial, sagittal, coronal).
AXIAL_MODEL_PATH = '/home/sisyphe/sol-test/AAA_Jarod/MODEL_TRAINED/patch_128_unet_solal_and_i_osbackground_final_axial_inference_from_trained.h5'
SAGITTAL_MODEL_PATH = '/home/sisyphe/sol-test/UNet/PP3/Sagittal_PP1+100PS/model_trained/patch_128_augment_unet_osbackground_sagittal_pp3_from_trained.h5'
CORONAL_MODEL_PATH = '/home/sisyphe/sol-test/UNet/PP1/Coronal_PP5+50PJ/model_trained/patch_128_augment_unet_osbackground_coronal_pp1_from_trained.h5'

model_axial = load_model(AXIAL_MODEL_PATH)
model_sagittal = load_model(SAGITTAL_MODEL_PATH)
model_coronal = load_model(CORONAL_MODEL_PATH)

In [17]:
def norm(array):
    """Min-max rescale ``array`` to the range [0, 1].

    Parameters
    ----------
    array : array_like
        Input of any shape; min and max are taken over all elements.

    Returns
    -------
    numpy.ndarray
        Float array of the same shape. A constant input (max == min), e.g. an
        all-black slice, maps to all zeros instead of the NaNs that a 0/0
        division would produce (NaNs would otherwise reach the model).
    """
    array = np.asarray(array)
    lowest = array.min()
    span = array.max() - lowest
    if span == 0:
        return np.zeros(array.shape)
    return (array - lowest) / span

In [18]:
def my_patch_generator_non_overlap_for_reconstruction(image, patch_size):
    """Split a 2D image into non-overlapping ``patch_size`` x ``patch_size`` tiles.

    The image is min-max normalised with ``norm`` and zero-padded at the bottom
    and on the right so both sides become multiples of ``patch_size``. Tiles are
    emitted in row-major order (left to right, then top to bottom), which is the
    order ``my_patch_reconstructor`` consumes.

    Parameters
    ----------
    image : numpy.ndarray
        2D array (height, width).
    patch_size : int
        Tile edge length in pixels (the callers in this notebook pass 128).
        The argument is honoured rather than being overridden by a constant.

    Returns
    -------
    numpy.ndarray
        Array of shape (n_tiles, patch_size, patch_size).
    """
    img = norm(image)
    height, width = img.shape[0], img.shape[1]

    # Amount needed to reach the next multiple of patch_size (0 if already one).
    pad_height = (-height) % patch_size
    pad_width = (-width) % patch_size
    if pad_height or pad_width:
        img = np.pad(img, pad_width=((0, pad_height), (0, pad_width)))

    patches = [
        img[i:i + patch_size, j:j + patch_size]
        for i in range(0, img.shape[0], patch_size)
        for j in range(0, img.shape[1], patch_size)
    ]
    return np.asarray(patches)

In [19]:
def my_patch_reconstructor(patch_list, original_size, patch_size):
    """Stitch row-major square tiles back into a 2D image and crop off padding.

    Inverse of the non-overlapping patch generator: the canvas is enlarged to
    the next multiple of ``patch_size`` in each direction, tiles are placed
    left to right then top to bottom, and the result is cropped back to
    ``original_size``.

    Parameters
    ----------
    patch_list : indexable sequence of 2D arrays
        Tiles in row-major order, each ``patch_size`` x ``patch_size``.
    original_size : tuple
        (height, width) of the image before padding.
    patch_size : int
        Tile edge length in pixels.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``original_size``.
    """
    out_height, out_width = original_size[0], original_size[1]

    # Integer ceiling to a whole number of tiles.
    canvas_height = -(-out_height // patch_size) * patch_size
    canvas_width = -(-out_width // patch_size) * patch_size
    tiles_per_row = canvas_width // patch_size

    canvas = np.zeros((canvas_height, canvas_width))
    for row, top in enumerate(range(0, canvas_height, patch_size)):
        for col, left in enumerate(range(0, canvas_width, patch_size)):
            tile = patch_list[row * tiles_per_row + col]
            canvas[top:top + patch_size, left:left + patch_size] = tile

    return canvas[0:out_height, 0:out_width]

In [20]:
# Reconstruct a 3D segmentation mask for a TIFF volume given its path: normalise the volume,
# reorder its axes for a coronal or sagittal inference (view argument), predict on 2D patches
# slice by slice, stitch and hole-fill each slice mask, then restore the original axis order.

def reconstruction_3d_mask_segmentation_fortif(image3d_path, model, type='AXIAL', patch_size=128):
    """Segment a 3D TIFF volume slice by slice with a 2D patch-based model.

    Steps: read the volume, min-max normalise it, move the slicing axis of the
    requested view to axis 0, tile each 2D slice into non-overlapping patches,
    predict, take the argmax over the last axis of the predictions, stitch the
    patches back, fill holes in the slice mask, then restore the input axis
    order.

    Parameters
    ----------
    image3d_path : str
        Path to a 3D TIFF readable by ``tifffile.imread``.
    model : object
        Anything with a ``predict(patches)`` method. Presumably returns
        per-pixel class scores with classes on the last axis, since the argmax
        is taken over axis -1 (TODO confirm against the training code).
    type : {'AXIAL', 'CORONAL', 'SAGITTAL'}
        Slicing direction: 'AXIAL' iterates axis 0 as stored, 'CORONAL' axis 1,
        'SAGITTAL' axis 2. (The name shadows the builtin but is kept so
        existing keyword callers keep working.)
    patch_size : int, default 128
        Tile edge length; must match the input size the model was trained on
        and the size the patch generator produces.

    Returns
    -------
    numpy.ndarray
        Integer 0/1 mask with the same shape as the stored volume. Hole filling
        is binary, so any nonzero predicted label becomes foreground.

    Raises
    ------
    ValueError
        If ``type`` is not one of the supported views (previously an unknown
        value silently ran as axial).
    """
    # Permutation that brings the slicing axis of each view to the front.
    forward_axes = {'AXIAL': None, 'CORONAL': (1, 0, 2), 'SAGITTAL': (2, 0, 1)}
    if type not in forward_axes:
        raise ValueError(f"type must be one of {sorted(forward_axes)}, got {type!r}")
    axes = forward_axes[type]

    img = norm(np.asarray(imread(image3d_path)))
    if axes is not None:
        img = np.transpose(img, axes)

    # Integer labels so the result is readable as a Napari label layer.
    mask = np.zeros(img.shape, dtype=int)
    for i in range(img.shape[0]):
        patches = my_patch_generator_non_overlap_for_reconstruction(img[i], patch_size)
        predictions = model.predict(patches)
        labels = np.argmax(predictions, axis=-1)
        slice_mask = my_patch_reconstructor(labels, (img.shape[1], img.shape[2]), patch_size)
        # scipy.ndimage.binary_fill_holes is the supported entry point; the
        # scipy.ndimage.morphology namespace is deprecated. True/False -> 1/0.
        mask[i] = scipy.ndimage.binary_fill_holes(slice_mask)

    if axes is not None:
        # argsort of a permutation is its inverse: back to the stored axis order.
        mask = np.transpose(mask, np.argsort(axes))

    return mask

In [21]:
# Reconstruct a 3D segmentation mask for a volume that is already loaded as an array in the notebook.

def reconstruction_3d_mask_segmentation_fortif_imagearraydirectly(image3d, model, view='AXIAL', patch_size=128):
    """Segment an in-memory 3D volume slice by slice with a 2D patch-based model.

    The axes are reordered so the requested view is iterated along axis 0. Each
    2D slice is tiled into non-overlapping patches by the patch generator
    (which min-max normalises every slice), predicted, reduced with an argmax
    over the last axis, stitched back and hole-filled. The mask is returned in
    the input's axis order. The volume as a whole is not normalised here.

    Parameters
    ----------
    image3d : array_like
        3D volume.
    model : object
        Anything with a ``predict(patches)`` method. Presumably returns
        per-pixel class scores with classes on the last axis (TODO confirm).
    view : {'AXIAL', 'CORONAL', 'SAGITTAL'}
        Slicing direction: 'AXIAL' iterates axis 0 as given, 'CORONAL' axis 1,
        'SAGITTAL' axis 2.
    patch_size : int, default 128
        Tile edge length; must match the model's input size and the size the
        patch generator produces.

    Returns
    -------
    numpy.ndarray
        Integer 0/1 mask with the same shape as ``image3d``. Hole filling is
        binary, so any nonzero predicted label becomes foreground.

    Raises
    ------
    ValueError
        If ``view`` is not one of the supported views.
    """
    # Permutation that brings the slicing axis of each view to the front.
    forward_axes = {'AXIAL': None, 'CORONAL': (1, 0, 2), 'SAGITTAL': (2, 0, 1)}
    if view not in forward_axes:
        raise ValueError(f"view must be one of {sorted(forward_axes)}, got {view!r}")
    axes = forward_axes[view]

    img = np.asarray(image3d)
    if axes is not None:
        # Reorder the *input* volume; the output buffer does not exist yet
        # (transposing it here raised UnboundLocalError for SAGITTAL).
        img = np.transpose(img, axes)

    # Integer labels so the result is readable as a Napari label layer.
    final_mask = np.zeros(img.shape, dtype=int)
    for i in range(img.shape[0]):
        patches = my_patch_generator_non_overlap_for_reconstruction(img[i], patch_size)
        predictions = model.predict(patches)
        labels = np.argmax(predictions, axis=-1)
        slice_mask = my_patch_reconstructor(labels, (img.shape[1], img.shape[2]), patch_size)
        # Non-deprecated fill-holes entry point; True/False -> 1/0 on assignment.
        final_mask[i] = scipy.ndimage.binary_fill_holes(slice_mask)

    if axes is not None:
        # argsort of a permutation is its inverse: back to the input axis order.
        final_mask = np.transpose(final_mask, np.argsort(axes))

    return final_mask

In [22]:
def three_networks_vote(inference1, inference2, inference3, threshold=1):
    """Combine three binary masks into one by counting votes per voxel.

    A voxel is set to 1 when at least ``threshold`` of the three masks mark it.
    The default ``threshold=1`` reproduces the original behaviour, i.e. the
    union (logical OR) of the masks; ``threshold=2`` gives a majority vote and
    ``threshold=3`` the intersection.

    Parameters
    ----------
    inference1, inference2, inference3 : array_like
        Masks of identical (or broadcast-compatible) shape, assumed to hold
        0/1 values.
    threshold : int or float, default 1
        Minimum summed vote for a voxel to be kept.

    Returns
    -------
    numpy.ndarray
        0/1 array with the dtype of the summed votes (int for bool inputs).
    """
    masks = [np.asarray(m) for m in (inference1, inference2, inference3)]
    # Adding bool arrays is a logical OR in NumPy; count votes as integers instead.
    masks = [m.astype(int) if m.dtype == bool else m for m in masks]
    votes = masks[0] + masks[1] + masks[2]
    return (votes >= threshold).astype(votes.dtype)

In [23]:
# Segment one 3D TIFF (given by path) with three models, one per view, and combine their masks.
# `view` lists the view of each model in order, e.g. ['AXIAL', 'SAGITTAL', 'CORONAL'].

def four_vote_reconstruction(model1, model2, model3, image_path, view, threshold=None):
    """Run three view-specific models on one volume and combine their masks.

    Despite the name, exactly three models are used. Model ``k`` is run with
    ``view[k]`` through ``reconstruction_3d_mask_segmentation_fortif`` and the
    resulting 0/1 masks are summed.

    Parameters
    ----------
    model1, model2, model3 : object
        Models passed to ``reconstruction_3d_mask_segmentation_fortif``.
    image_path : str
        Path to the 3D TIFF volume (read once per model).
    view : sequence of str
        At least three view names ('AXIAL', 'CORONAL', 'SAGITTAL'); entries
        beyond the third are ignored.
    threshold : int or None, default None
        None returns the per-voxel vote count (0-3), as before. Otherwise a 0/1
        mask where the count is >= threshold is returned (1 = union,
        2 = majority, 3 = intersection).

    Returns
    -------
    numpy.ndarray
        Vote counts, or a 0/1 mask when ``threshold`` is given.

    Raises
    ------
    ValueError
        If ``view`` has fewer than three entries (checked before any inference).
    """
    if len(view) < 3:
        raise ValueError(f"view needs one entry per model (3), got {len(view)}: {view!r}")

    models = (model1, model2, model3)
    votes = sum(
        reconstruction_3d_mask_segmentation_fortif(image_path, model, type=model_view)
        for model, model_view in zip(models, view)
    )
    if threshold is None:
        return votes
    return (votes >= threshold).astype(votes.dtype)

In [24]:
# Segment one volume with all three view-specific models; `test` holds the summed per-voxel votes.
image_path = '/home/sisyphe/sol-test/MOG/SH_04108-N.tif'
views = ['AXIAL', 'SAGITTAL', 'CORONAL']
test = four_vote_reconstruction(
    model_axial,
    model_sagittal,
    model_coronal,
    image_path,
    view=views,
)



  reconstructed_mask = scipy.ndimage.morphology.binary_fill_holes(reconstructed_mask, structure=None, output=None, origin=0)




In [25]:
# Sanity check: the combined mask keeps the (z, height, width) shape of the input volume.
np.shape(test)

(222, 512, 512)

In [26]:
# Write the combined mask to disk. tifffile.imsave is a deprecated alias of imwrite
# (it emitted a warning in this cell), so call the supported writer directly.
tifffile.imwrite('/home/sisyphe/sol-test/code/test4108_AJ_CPP1_SPP3.tif', test)

  imsave('/home/sisyphe/sol-test/code/test4108_AJ_CPP1_SPP3.tif', test)
