## pipeline

In [8]:
import numpy as np
import os
import pickle
import copy
import edt
import matplotlib.pyplot as plt
import time
import cv2
import pandas as pd
from sklearn.metrics.cluster import adjusted_rand_score
from skimage.metrics import adapted_rand_error
import h5py

import torch
from torch import from_numpy as from_numpy
from torchsummary import summary

from func.run_pipeline_super_vox import segment_super_vox_2_channel, semantic_segment_crop_and_cat_2_channel_output, img_3d_erosion_or_expansion, \
generate_super_vox_by_watershed, get_outlayer_of_a_3d_shape, get_crop_by_pixel_val, Cluster_Super_Vox, assign_boudary_voxels_to_cells_with_watershed, \
delete_too_small_cluster, reassign
from func.run_pipeline import segment, assign_boudary_voxels_to_cells, dbscan_of_seg, semantic_segment_crop_and_cat
from func.cal_accuracy import IOU_and_Dice_Accuracy, VOI
from func.network import VoxResNet, CellSegNet_basic_lite
from func.unet_3d_basic import UNet3D_basic
from func.ultis import save_obj, load_obj

### init model

In [9]:
#model=UNet3D_basic(in_channels = 1, out_channels = 2)
#load_path=''
#model=VoxResNet(input_channel=1, n_classes=2, output_func = "softmax")
#load_path=''
model=CellSegNet_basic_lite(input_channel=1, n_classes=2, output_func = "softmax")
load_path='output/model_LRP_retrained.pkl'
checkpoint = torch.load(load_path)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

#summary(model, (1, 64, 64, 64))

CellSegNet_basic_lite(
  (conv1): Conv3d(1, 16, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  (conv2): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (bnorm1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
  (resmodule1): ResModule(
    (batchnorm_module): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_module): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  )
  (conv4): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
  (resmodule2): ResModule(
    (batchnorm_module): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_module): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  )
  (conv5): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
  (resmodule3): R

### dataset info

In [10]:
data_dict = load_obj("dataset_info/LRP_dataset_info")
data_dict_test = data_dict["test"]

### seg one img

parameter setting

In [11]:
# we do not input the whole raw image to the model one time but input raw image crops
crop_cube_size=128
stride=64

# hyperparameter for TASCAN, min touching area of two super pixels if they belong to the same cell
min_touching_area=30

choose a test image and load it

In [12]:
print("there are test imgs: "+str(data_dict_test.keys()))
case = 'Movie2_T00010_crop_gt.h5'
print("for test case "+str(case)+" : "+str(data_dict_test[case]))

there are test imgs: dict_keys(['Movie2_T00010_crop_gt.h5', 'Movie1_t00006_crop_gt.h5', 'Movie1_t00045_crop_gt.h5', 'Movie2_T00020_crop_gt.h5'])
for test case Movie2_T00010_crop_gt.h5 : data/CellSeg_dataset/LateralRootPrimordia_processed_wide_boundary/test/Movie2_T00010_crop_gt.h5


In [13]:
hf = h5py.File(data_dict_test[case], 'r')
print(hf.keys())
raw_img = np.array(hf["raw"], dtype=np.float)
hand_seg = np.array(hf["ins"], dtype=np.float)
boundary_gt = np.array(hf["boundary"], dtype=np.float)
background_gt = np.array(hf["background"], dtype=np.float)
foreground_gt = np.array(hf["foreground"], dtype=np.float)

print("raw_img shape: "+str(raw_img.shape))
print("hand_seg shape: "+str(hand_seg.shape))
raw_img_shape = raw_img.shape
hand_seg_shape = hand_seg.shape

<KeysViewHDF5 ['background', 'boundary', 'edge_background', 'edge_boundary', 'edge_foreground', 'foreground', 'ins', 'raw']>


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  raw_img = np.array(hf["raw"], dtype=np.float)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  hand_seg = np.array(hf["ins"], dtype=np.float)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  boundary_gt = np.array(hf["boundary"], dtype=np.float)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  background_gt = np.array(hf["background"], dtype=np.float)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  foreground_gt = np.array(hf["foreground"], dtype=np.float)


raw_img shape: (201, 225, 795)
hand_seg shape: (201, 225, 795)


feed raw image crops to the model

In [None]:
start = time.time()

transposes = [[0,1,2],[2,0,1],[0,2,1]]#,[1,0,2]]
reverse_transposes = [[0,1,2],[1,2,0],[0,2,1]]#,[1,0,2]]

# feed the raw img to the model
print('Feed raw img to model. Use different transposes')
raw_img_size=raw_img.shape
seg_boundary_comp = np.zeros(raw_img_size)
seg_img_boundary_comp = np.zeros(raw_img_size)
for idx, transpose in enumerate(transposes):
    print(str(idx+1)+": Transpose the image to be: "+str(transpose))
    with torch.no_grad():
        seg_img=\
        semantic_segment_crop_and_cat_2_channel_output(raw_img.transpose(transpose), model, device,
                                                       crop_cube_size=crop_cube_size, stride=stride)
    seg_img_boundary=seg_img['boundary']
    seg_img_foreground=seg_img['foreground']
    torch.cuda.empty_cache()
    
    # argmax
    print('argmax', end='\r')
    # probability map to 0 1 segment
    seg_foreground=np.array(seg_img_foreground-seg_img_boundary>0, dtype=np.int)
    seg_boundary=1 - seg_foreground
        
    seg_foreground=seg_foreground.transpose(reverse_transposes[idx])
    seg_boundary=seg_boundary.transpose(reverse_transposes[idx])
    seg_img_foreground=seg_img_foreground.transpose(reverse_transposes[idx])
    seg_img_boundary=seg_img_boundary.transpose(reverse_transposes[idx])
    
    seg_boundary_comp+=seg_boundary
    seg_img_boundary_comp+=seg_img_boundary

print("Get model semantic seg by combination")
seg_boundary_comp = np.array(seg_boundary_comp>0, dtype=np.int)
seg_foreground_comp = 1 - seg_boundary_comp

end = time.time()

print("Time elapsed: ", end - start)

Feed raw img to model. Use different transposes
1: Transpose the image to be: [0, 1, 2]
argmaxs of segment_3d_img: 99%

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  seg_foreground=np.array(seg_img_foreground-seg_img_boundary>0, dtype=np.int)


2: Transpose the image to be: [2, 0, 1]
argmaxs of segment_3d_img: 99%3: Transpose the image to be: [0, 2, 1]
Progress of segment_3d_img: 81%

In [None]:
# show current result

N=200

plt.figure()
plt.title("raw_img")
plt.imshow(raw_img[:,:,N])
plt.figure()
plt.title("hand_seg")
plt.imshow(reassign(hand_seg[:,:,N]))
plt.figure()
plt.title("model_seg_boundary")
plt.imshow(seg_boundary_comp[:,:,N])

TASCAN

generate super vox by watershed

In [None]:
# Generate super vox by watershed
how_close_are_the_super_vox_to_boundary=2
min_touching_percentage=0.51

seg_foreground_erosion=1-img_3d_erosion_or_expansion(1-seg_foreground_comp, kernel_size=how_close_are_the_super_vox_to_boundary+1, device=device)
seg_foreground_super_voxel_by_ws = generate_super_vox_by_watershed(seg_foreground_erosion)
# from skimage.measure import label
# from skimage.segmentation import join_segmentations, watershed
# from skimage.feature import peak_local_max
# seg_foreground_edt=edt.edt(np.array(seg_foreground_erosion, dtype=np.uint32, order='F'),
#                            black_border=True, order='F',parallel=1)
# min_distance_between_cells = 5
# coords = peak_local_max(seg_foreground_edt, min_distance=min_distance_between_cells,
#                         labels=np.array(seg_foreground_erosion>0))
# mask = np.zeros(seg_foreground_edt.shape, dtype=bool)
# mask[tuple(coords.T)] = True
# markers = label(mask==True)
# seg_foreground_super_voxel_by_ws = watershed(-seg_foreground_edt, markers=markers, mask=np.array(seg_foreground_comp>0), connectivity=min_touching_area)

In [None]:
print("There are "+str(len(np.unique(seg_foreground_super_voxel_by_ws)))+" super voxels")

super voxel clustearing

In [None]:
#Super voxel clustering
cluster_super_vox=Cluster_Super_Vox(min_touching_area=min_touching_area, min_touching_percentage=min_touching_percentage)
cluster_super_vox.fit(copy.deepcopy(seg_foreground_super_voxel_by_ws))
seg_foreground_single_cell_with_boundary = cluster_super_vox.output_3d_img

delete too small cells

In [None]:
# Delete too small cells
min_cell_size_threshold=10
seg_foreground_single_cell_with_boundary = delete_too_small_cluster(seg_foreground_single_cell_with_boundary, threshold=min_cell_size_threshold)

assign boudary voxels to their nearest cells

In [None]:
# Assign boudary voxels to their nearest cells
seg_final=assign_boudary_voxels_to_cells_with_watershed(seg_foreground_single_cell_with_boundary, seg_boundary_comp, compactness=1)

see the final result

In [None]:
N=250
#print("There are "+str(len(np.unique(seg_foreground_single_cell_with_boundary)))+" cells")
plt.figure()
plt.title("raw_img")
plt.imshow(raw_img[:,:,N])
plt.figure()
plt.title("hand_seg")
plt.imshow(reassign(hand_seg[:,:,N]))
plt.figure()
plt.title("model_seg")
plt.imshow(reassign(seg_final[:,:,N]))

In [None]:
def colorful_seg(seg):
    unique_vals, val_counts = np.unique(seg, return_counts=True)
    
    background_val = unique_vals[np.argsort(val_counts)[::-1][0]]
    
    seg_RGB = []
    for i in range(seg.shape[0]):
        mask_gray = cv2.normalize(src=seg[i,:,:], dst=None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1)
        seg_slice_RGB = cv2.cvtColor(mask_gray, cv2.COLOR_GRAY2RGB)
        seg_RGB.append(seg_slice_RGB)
    seg_RGB = np.array(seg_RGB)
    
    for idx, unique_val in enumerate(unique_vals):
        print(str(idx/len(unique_vals)), end="\r")
        if unique_val == background_val:
            COLOR = np.array([0,0,0], dtype=int)
        else:
            COLOR = np.array(np.random.choice(np.arange(256), size=3, replace=False), dtype=int)
        
        locs = np.where(seg==unique_val)
        
        for i in range(3):
            seg_RGB[locs[0], locs[1], locs[2], i] = COLOR[i]
        
    return seg_RGB

In [None]:
seg_final_RGB = colorful_seg(seg_final)
hand_seg_RGB = colorful_seg(hand_seg)

In [None]:
N=250

fig = plt.figure(figsize=(5,5))
plt.axis('off')
plt.imshow(seg_final_RGB[:,:,N],cmap="gray")
#plt.savefig('seg_final_RGB_'+str(N)+'.png',bbox_inches='tight',dpi=fig.dpi,pad_inches=0.0)