## pipeline

In [12]:
import numpy as np
import os
import pickle
import copy
import edt
import matplotlib.pyplot as plt
import time
import cv2
import pandas as pd
from sklearn.metrics.cluster import adjusted_rand_score
from skimage.metrics import adapted_rand_error

import torch
from torch import from_numpy as from_numpy
from torchsummary import summary

from func.run_pipeline_super_vox import segment_super_vox_3_channel, semantic_segment_crop_and_cat_3_channel_output, \
img_3d_erosion_or_expansion, segment_super_vox_2_channel, semantic_segment_crop_and_cat_2_channel_output, \
generate_super_vox_by_watershed, get_outlayer_of_a_3d_shape, get_crop_by_pixel_val, Cluster_Super_Vox, assign_boudary_voxels_to_cells_with_watershed, \
delete_too_small_cluster, reassign
from func.run_pipeline import segment, assign_boudary_voxels_to_cells, dbscan_of_seg, semantic_segment_crop_and_cat
from func.cal_accuracy import IOU_and_Dice_Accuracy, VOI
from func.network import VoxResNet, CellSegNet_basic_lite
from func.unet_3d_basic import UNet3D_basic
from func.ultis import save_obj, load_obj

### init model

In [13]:
# model=UNet3D_basic(in_channels = 1, out_channels = 3)
# load_path=''
# model=VoxResNet(input_channel=1, n_classes=3, output_func = "softmax")
# load_path=""
model=CellSegNet_basic_lite(input_channel=1, n_classes=3, output_func = "softmax")
load_path='output/model_HMS_delete_fake_cells.pkl'
checkpoint = torch.load(load_path)
model.load_state_dict(checkpoint['model_state_dict'])

model.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

#summary(model, (1, 64, 64, 64))

CellSegNet_basic_lite(
  (conv1): Conv3d(1, 16, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  (conv2): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (bnorm1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
  (resmodule1): ResModule(
    (batchnorm_module): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_module): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  )
  (conv4): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
  (resmodule2): ResModule(
    (batchnorm_module): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_module): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  )
  (conv5): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
  (resmodule3): R

### dataset info

In [14]:
HMS_data_dict = load_obj("dataset_info/HMS_dataset_info")
HMS_data_dict_test = HMS_data_dict["test"]

### seg one img

parameter setting

In [15]:
# we do not input the whole raw image to the model one time but input raw image crops
crop_cube_size=64
stride=32

# hyperparameter for TASCAN, min touching area of two super pixels if they belong to the same cell
min_touching_area=30

choose a test image and load it

In [16]:
print("Test cases: "+str(HMS_data_dict_test.keys()))
case = "135"
print("for test case "+str(case)+" : "+str(HMS_data_dict_test[case]))

# you may load the image using another path
raw_img=np.load(HMS_data_dict_test[case]["raw"]).astype(float)
hand_seg=np.load(HMS_data_dict_test[case]["ins"]).astype(float)

Test cases: dict_keys(['135', '120', '65', '90'])
for test case 135 : {'raw': 'data/CellSeg_dataset/HMS_processed/raw/135.npy', 'background': 'data/CellSeg_dataset/HMS_processed/segmentation/135/135_background_3d_mask.npy', 'boundary': 'data/CellSeg_dataset/HMS_processed/segmentation/135/135_boundary_3d_mask.npy', 'foreground': 'data/CellSeg_dataset/HMS_processed/segmentation/135/135_foreground_3d_mask.npy', 'ins': 'data/CellSeg_dataset/HMS_processed/segmentation/135/135_ins.npy'}


feed raw image crops to the model

In [17]:
start = time.time()

# feed the raw img to the model
print('Feed raw img to model')
raw_img_size=raw_img.shape

seg_background_comp = np.zeros(raw_img_size)
seg_boundary_comp = np.zeros(raw_img_size)

transposes = [[0,1,2]]#,[2,0,1],[0,2,1]]
reverse_transposes = [[0,1,2]]#,[1,2,0],[0,2,1]]

for idx, transpose in enumerate(transposes):
    print(str(idx+1)+": Transpose the image to be: "+str(transpose))
    with torch.no_grad():
        seg_img=\
        semantic_segment_crop_and_cat_3_channel_output(raw_img.transpose(transpose), model, device, crop_cube_size=crop_cube_size, stride=stride)
    seg_img_background=seg_img['background']
    seg_img_boundary=seg_img['boundary']
    seg_img_foreground=seg_img['foreground']
    torch.cuda.empty_cache()

    # argmax
    print('argmax', end='\r')
    seg=[]
    seg.append(seg_img_background)
    seg.append(seg_img_boundary)
    seg.append(seg_img_foreground)
    seg=np.array(seg)
    seg_argmax=np.argmax(seg, axis=0)
    # probability map to 0 1 segment
    seg_background=np.zeros(seg_img_background.shape)
    seg_background[np.where(seg_argmax==0)]=1
    seg_foreground=np.zeros(seg_img_foreground.shape)
    seg_foreground[np.where(seg_argmax==2)]=1
    seg_boundary=np.zeros(seg_img_boundary.shape)
    seg_boundary[np.where(seg_argmax==1)]=1

    seg_background=seg_background.transpose(reverse_transposes[idx])
    seg_foreground=seg_foreground.transpose(reverse_transposes[idx])
    seg_boundary=seg_boundary.transpose(reverse_transposes[idx])

    seg_background_comp+=seg_background
    seg_boundary_comp+=seg_boundary
#print("Get model semantic seg by combination")
seg_background_comp = np.array(seg_background_comp>0, dtype=float)
seg_boundary_comp = np.array(seg_boundary_comp>0, dtype=float)
seg_foreground_comp = np.array(1 - seg_background_comp - seg_boundary_comp>0, dtype=float)

end = time.time()

print("Time elapsed: ", end - start)

Feed raw img to model
1: Transpose the image to be: [0, 1, 2]
argmaxs of segment_3d_img: 99%Time elapsed:  230.41491889953613


TASCAN

generate super vox by watershed

In [18]:
# Generate super vox by watershed
how_close_are_the_super_vox_to_boundary=2
min_touching_percentage=0.51
#min_touching_percentage=0.02

seg_foreground_erosion=1-img_3d_erosion_or_expansion(1-seg_foreground_comp, kernel_size=how_close_are_the_super_vox_to_boundary+1, device=device)
seg_foreground_super_voxel_by_ws = generate_super_vox_by_watershed(seg_foreground_erosion)


super voxel clustering

In [19]:
import multiprocess

In [28]:
result_df_final = pd.DataFrame(columns=["min_touching_area", "split", "merge"])
result_list = []

In [23]:
def img_3d_interpolate(img_3d, output_size, device = torch.device('cpu'), mode='nearest'):
    img_3d = img_3d.reshape(1,1,img_3d.shape[0],img_3d.shape[1],img_3d.shape[2])
    img_3d=torch.from_numpy(img_3d).float().to(device)
    img_3d=torch.nn.functional.interpolate(img_3d, size=output_size, mode='nearest')
    img_3d=img_3d.detach().cpu().numpy()
    img_3d=img_3d.reshape(img_3d.shape[2],img_3d.shape[3],img_3d.shape[4])

    return img_3d

In [29]:
for value in range(10):
    seg_foreground_erosion=1-img_3d_erosion_or_expansion(1-seg_foreground_comp, kernel_size=how_close_are_the_super_vox_to_boundary+1, device=device)
    seg_foreground_super_voxel_by_ws = generate_super_vox_by_watershed(seg_foreground_erosion)

    min_touching_area = (value+1)*5
    # Super voxel clustering
    cluster_super_vox=Cluster_Super_Vox(min_touching_area=min_touching_area, min_touching_percentage=min_touching_percentage)
    cluster_super_vox.fit(seg_foreground_super_voxel_by_ws)
    seg_foreground_single_cell_with_boundary = cluster_super_vox.output_3d_img

    # Delete too small cells
    min_cell_size_threshold=10

    seg_foreground_single_cell_with_boundary_delete_too_small = delete_too_small_cluster(seg_foreground_single_cell_with_boundary, threshold=min_cell_size_threshold)

    # Assign boudary voxels to their nearest cells
    seg_final=assign_boudary_voxels_to_cells_with_watershed(seg_foreground_single_cell_with_boundary_delete_too_small, seg_boundary_comp, seg_background_comp, compactness=1)

    # Reassign unique numbers
    seg_final=reassign(seg_final)

    # Evaluate with Metrics

    import func.cal_accuracy

    #accuracy_result = iou_and_dice_accuracy.cal_accuracy()
    voi = VOI(seg_final.astype(int), hand_seg.astype(int))

    print(f"VOI for {min_touching_area}:")
    print(voi)
    accuracy=IOU_and_Dice_Accuracy(img_3d_interpolate(hand_seg, output_size = seg_final.shape), img_3d_interpolate(seg_final, output_size = seg_final.shape))
    accuracy_record=accuracy.cal_accuracy_II()
    iou=np.array(accuracy_record[:,1]>0.7, dtype=float)
    print(str(sum(iou)/len(iou)))
    dice=np.array(accuracy_record[:,2]>0.7, dtype=float)
    print(str(sum(dice)/len(dice)))


    #result_df = result_df_final.append({"min_touching_area": min_touching_area,
    #                              "split": voi[0],
    #                              "merge": voi[1]},
    #                             ignore_index=True)
    result_list.append({"min_touching_area": min_touching_area,
                        "split": voi[0],
                        "merge": voi[1],
                        "iou": str(sum(iou)/len(iou)),
                        "dice": str(sum(dice)/len(dice))})



#pool_obj = multiprocess.Pool(8)
#pool_obj.map(run_voi, range(18))

reassign unique numbers progress: 0.99723756906077348]]37]724]ayer_area: 0current_crop_outlayer_area: 1current_crop_outlayer_area: 83current_crop_outlayer_area: 60current_crop_outlayer_area: 8current_crop_outlayer_area: 0current_crop_outlayer_area: 23current_crop_outlayer_area: 15current_crop_outlayer_area: 0current_crop_outlayer_area: 6current_crop_outlayer_area: 6current_crop_outlayer_area: 4current_crop_outlayer_area: 10current_crop_outlayer_area: 57current_crop_outlayer_area: 0current_crop_outlayer_area: 6current_crop_outlayer_area: 0current_crop_outlayer_area: 59current_crop_outlayer_area: 0current_crop_outlayer_area: 18current_crop_outlayer_area: 0current_crop_outlayer_area: 62current_crop_outlayer_area: 51current_crop_outlayer_area: 4current_crop_outlayer_area: 93current_crop_outlayer_area: 50current_crop_outlayer_area: 15current_crop_outlayer_area: 61current_crop_outlayer_area: 64current_crop_outlayer_area: 54current_crop_outlayer_area: 38current_crop_outlayer_area: 138cur

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  iou=np.array(accuracy_record[:,1]>0.7, dtype=np.float)
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  dice=np.array(accuracy_record[:,2]>0.7, dtype=np.float)


0.44341801385681295
0.5935334872979214
VOI for 10:ique numbers progress: 0.994962216624685100]610]1724]ayer_area: 10current_crop_outlayer_area: 3current_crop_outlayer_area: 8current_crop_outlayer_area: 105current_crop_outlayer_area: 0current_crop_outlayer_area: 0current_crop_outlayer_area: 32current_crop_outlayer_area: 0current_crop_outlayer_area: 16current_crop_outlayer_area: 3current_crop_outlayer_area: 2current_crop_outlayer_area: 67current_crop_outlayer_area: 3current_crop_outlayer_area: 0current_crop_outlayer_area: 57current_crop_outlayer_area: 0current_crop_outlayer_area: 6current_crop_outlayer_area: 0current_crop_outlayer_area: 0current_crop_outlayer_area: 125current_crop_outlayer_area: 21current_crop_outlayer_area: 2current_crop_outlayer_area: 0current_crop_outlayer_area: 22current_crop_outlayer_area: 0current_crop_outlayer_area: 99current_crop_outlayer_area: 2current_crop_outlayer_area: 1current_crop_outlayer_area: 8current_crop_outlayer_area: 19current_crop_outlayer_area: 

In [30]:
result_list

[{'min_touching_area': 5,
  'split': 1.0100333438059486,
  'merge': 1.097907684622354,
  'iou': '0.44341801385681295',
  'dice': '0.5935334872979214'},
 {'min_touching_area': 10,
  'split': 1.0472691240646186,
  'merge': 1.0770707086971962,
  'iou': '0.43648960739030024',
  'dice': '0.5981524249422633'},
 {'min_touching_area': 15,
  'split': 1.0839207736338081,
  'merge': 1.0653870775371244,
  'iou': '0.43187066974595845',
  'dice': '0.5912240184757506'},
 {'min_touching_area': 20,
  'split': 1.147748133128398,
  'merge': 1.0327874437050772,
  'iou': '0.4295612009237875',
  'dice': '0.5912240184757506'},
 {'min_touching_area': 25,
  'split': 1.1558894390221717,
  'merge': 1.0276680498378403,
  'iou': '0.4295612009237875',
  'dice': '0.5889145496535797'},
 {'min_touching_area': 30,
  'split': 1.1643798217879262,
  'merge': 1.0203083411909382,
  'iou': '0.43187066974595845',
  'dice': '0.5866050808314087'},
 {'min_touching_area': 35,
  'split': 1.1764786558050049,
  'merge': 1.0200916785

In [31]:
with open("parameter_sensitivity_hms.pickle", "wb") as f:
    pickle.dump(result_list, f)