## pipeline

In [1]:
import numpy as np
import os
import pickle
import copy
import edt
import matplotlib.pyplot as plt
import time
import cv2
import pandas as pd
from sklearn.metrics.cluster import adjusted_rand_score
from skimage.metrics import adapted_rand_error

import torch
from torch import from_numpy as from_numpy
from torchsummary import summary

from func.run_pipeline_super_vox import segment_super_vox_3_channel, semantic_segment_crop_and_cat_3_channel_output, img_3d_erosion_or_expansion, \
generate_super_vox_by_watershed, get_outlayer_of_a_3d_shape, get_crop_by_pixel_val, Cluster_Super_Vox, assign_boudary_voxels_to_cells_with_watershed, \
delete_too_small_cluster, reassign
from func.run_pipeline import segment, assign_boudary_voxels_to_cells, dbscan_of_seg, semantic_segment_crop_and_cat
from func.cal_accuracy import IOU_and_Dice_Accuracy, VOI
from func.network import VoxResNet, CellSegNet_basic_lite
from func.unet_3d_basic import UNet3D_basic
from func.ultis import save_obj, load_obj

In [None]:
HMS_data_dict = load_obj("dataset_info/HMS_dataset_info")
HMS_data_dict_test = HMS_data_dict["test"]
print("Test cases: "+str(HMS_data_dict_test.keys()))
case = "135"
print("for test case "+str(case)+" : "+str(HMS_data_dict_test[case]))

# you may load the image using another path
raw_img=np.load(HMS_data_dict_test[case]["raw"]).astype(float)
hand_seg=np.load(HMS_data_dict_test[case]["ins"]).astype(float)

# np.save('seg_foreground_super_voxel_by_ws_graph.npy', seg_foreground_super_voxel_by_ws)

In [3]:
seg_foreground_super_voxel_by_ws = np.load('seg_foreground_super_voxel_by_ws_graph.npy')

In [5]:
len(np.unique(seg_foreground_super_voxel_by_ws))

1951

In [6]:
# get the size of each super voxel

unique, counts = np.unique(seg_foreground_super_voxel_by_ws, return_counts=True)
dict(zip(unique, counts))

{0: 9125466,
 1: 69,
 2: 35,
 3: 3,
 4: 852,
 5: 53,
 6: 6,
 7: 89,
 8: 36,
 9: 222,
 10: 394,
 11: 7,
 12: 34,
 13: 114,
 14: 9,
 15: 132,
 16: 52,
 17: 108,
 18: 21,
 19: 110,
 20: 262,
 21: 325,
 22: 14,
 23: 181,
 24: 60,
 25: 6,
 26: 161,
 27: 104,
 28: 5,
 29: 15,
 30: 9,
 31: 58,
 32: 5,
 33: 156,
 34: 1,
 35: 9,
 36: 8,
 37: 307,
 38: 2126,
 39: 338,
 40: 3,
 41: 240,
 42: 275,
 43: 66,
 44: 31,
 45: 260,
 46: 277,
 47: 576,
 48: 306,
 49: 273,
 50: 147,
 51: 105,
 52: 73,
 53: 62,
 54: 16,
 55: 289,
 56: 85,
 57: 860,
 58: 495,
 59: 127,
 60: 71,
 61: 831,
 62: 218,
 63: 9,
 64: 6,
 65: 4,
 66: 15,
 67: 191,
 68: 41,
 69: 87,
 70: 8,
 71: 7,
 72: 515,
 73: 543,
 74: 474,
 75: 3,
 76: 172,
 77: 44,
 78: 410,
 79: 53,
 80: 214,
 81: 603,
 82: 36,
 83: 4,
 84: 5,
 85: 8,
 86: 28,
 87: 14,
 88: 169,
 89: 54,
 90: 865,
 91: 11,
 92: 12,
 93: 155,
 94: 2,
 95: 29,
 96: 22,
 97: 81,
 98: 4,
 99: 302,
 100: 3,
 101: 4,
 102: 433,
 103: 195,
 104: 6,
 105: 617,
 106: 162,
 107: 22,
 10

In [20]:
class Super_Vox_To_Graph_depr():
    def __init__(self, boundary_extend=2):
        super(Super_Vox_To_Graph, self).__init__
        self.boundary_extend = boundary_extend

        self.UN_PROCESSED = 0
        self.LONELY_POINT = -1
        self.A_LARGE_NUM = 100000000

    def fit(self, input_3d_img, restrict_area_3d=None):
        self.input_3d_img = input_3d_img

        if restrict_area_3d is None:
            self.restrict_area_3d = np.array(input_3d_img==0, dtype=np.int8)
        else:
            self.restrict_area_3d = restrict_area_3d

        unique_vals, unique_val_counts = np.unique(self.input_3d_img, return_counts=True)
        unique_val_counts = unique_val_counts[unique_vals>0]
        unique_vals = unique_vals[unique_vals>0]
        sort_locs = np.argsort(unique_val_counts)[::-1]
        self.unique_vals = unique_vals[sort_locs]

        self.val_labels = dict()
        for unique_val in self.unique_vals:
            self.val_labels[unique_val] = self.UN_PROCESSED

        self.val_outlayer_area = dict()
        for idx, unique_val in enumerate(self.unique_vals):
            # print("get val_outlayer area of all vals: "+str(idx/len(self.unique_vals)))
            self.val_outlayer_area[unique_val] = self.A_LARGE_NUM

        """
        neighborhoods_dict:
        {
            voxel: {
                        neighbor_1: touching area,
                        neighbor_2: touching area
                    }
            ...
        }
        """
        neighborhoods_dict = {}
        for idx, current_val in enumerate(self.unique_vals):
            # print('processing: '+str(idx/len(self.unique_vals))+' pixel val: '+str(current_val))
            if self.val_labels[current_val]!=self.UN_PROCESSED:
                continue
            valid_neighbor_vals = self.regionQuery(current_val)
            neighborhoods_dict[current_val] = valid_neighbor_vals
            # if len(valid_neighbor_vals)>0:
            #     # print('Assign label '+str(current_val)+' to current val\'s neighbors: '+str(valid_neighbor_vals))
            #    self.val_labels[current_val] = current_val
            #    self.growCluster(valid_neighbor_vals, current_val)
            # else:
            #    self.val_labels[current_val] = self.LONELY_POINT

        # self.output_3d_img = self.input_3d_img
        return neighborhoods_dict

    def get_outlayer_area(self, current_val):
        current_crop_img, current_restrict_area = get_crop_by_pixel_val(self.input_3d_img, current_val,
                                                                        boundary_extend=self.boundary_extend,
                                                                        crop_another_3d_img_by_the_way=self.restrict_area_3d)
        current_crop_img_onehot = np.array(current_crop_img==current_val, dtype=np.int8)
        current_crop_img_onehot_outlayer = get_outlayer_of_a_3d_shape(current_crop_img_onehot)

        assert current_crop_img_onehot_outlayer.shape == current_restrict_area.shape

        current_crop_img_onehot_outlayer[current_restrict_area>0]=0
        current_crop_outlayer_area = np.sum(current_crop_img_onehot_outlayer)

        return current_crop_outlayer_area

    def regionQuery(self, current_val):
        current_crop_img, current_restrict_area = get_crop_by_pixel_val(self.input_3d_img, current_val,
                                                                        boundary_extend=self.boundary_extend,
                                                                        crop_another_3d_img_by_the_way=self.restrict_area_3d)

        current_crop_img_onehot = np.array(current_crop_img==current_val, dtype=np.int8)
        current_crop_img_onehot_outlayer = get_outlayer_of_a_3d_shape(current_crop_img_onehot)

        assert current_crop_img_onehot_outlayer.shape == current_restrict_area.shape

        current_crop_img_onehot_outlayer[current_restrict_area>0]=0
        current_crop_outlayer_area = np.sum(current_crop_img_onehot_outlayer)

        neighbor_vals, neighbor_val_counts = np.unique(current_crop_img[current_crop_img_onehot_outlayer>0], return_counts=True)
        neighbor_val_counts = neighbor_val_counts[neighbor_vals>0]
        neighbor_vals = neighbor_vals[neighbor_vals>0]

        print("current_crop_outlayer_area: "+str(current_crop_outlayer_area))

        valid_neighbor_vals = self.neighborCheck(neighbor_vals, neighbor_val_counts, current_crop_outlayer_area)


        print("valid_neighbor_vals: "+str(valid_neighbor_vals))

        return valid_neighbor_vals

    def neighborCheck(self, neighbor_vals, neighbor_val_counts, current_crop_outlayer_area):
        neighbor_val_counts = neighbor_val_counts[neighbor_vals>0]
        neighbor_vals = neighbor_vals[neighbor_vals>0]

        valid_neighbor_vals_dict = {}
        for idx, neighbor_val in enumerate(neighbor_vals):
            print("touching_area: "+str(neighbor_val_counts[idx]), end="\r")
            valid_neighbor_vals_dict[neighbor_val] = neighbor_val_counts[idx]

        # double_checked_valid_neighbor_vals = []
        # for valid_neighbor_val in valid_neighbor_vals_dict.keys():
        #    if self.val_labels[valid_neighbor_val]==self.UN_PROCESSED or \
        #     self.val_labels[valid_neighbor_val]==self.LONELY_POINT:
        #        double_checked_valid_neighbor_vals.append(valid_neighbor_val)

        return valid_neighbor_vals_dict

In [104]:
class Super_Vox_To_Graph():
    def __init__(self, boundary_extend=2):
        super(Super_Vox_To_Graph, self).__init__
        self.boundary_extend = boundary_extend

        self.UN_PROCESSED = 0
        self.LONELY_POINT = -1
        self.A_LARGE_NUM = 100000000

    def get_neighbors_and_touching_area(self, input_3d_img, restrict_area_3d=None):
        """
        Parameters
        ----------
        input_3d_img
        restrict_area_3d

        Returns numpy array with each column representing two super voxels touching
                -> shape: supervoxel_1, neighbor_1, touching_area(between supervoxel_1 and neighbor_1)
                          supervoxel_1, neighbor_2, touching_area(between supervoxel_1 and neighbor_2)
                          ...
        -------

        """
        self.input_3d_img = input_3d_img

        if restrict_area_3d is None:
            self.restrict_area_3d = np.array(input_3d_img==0, dtype=np.int8)
        else:
            self.restrict_area_3d = restrict_area_3d

        unique_vals, unique_val_counts = np.unique(self.input_3d_img, return_counts=True)
        unique_val_counts = unique_val_counts[unique_vals>0]
        unique_vals = unique_vals[unique_vals>0]
        sort_locs = np.argsort(unique_val_counts)[::-1]
        self.unique_vals = unique_vals[sort_locs]

        self.val_labels = dict()
        for unique_val in self.unique_vals:
            self.val_labels[unique_val] = self.UN_PROCESSED

        self.val_outlayer_area = dict()
        for idx, unique_val in enumerate(self.unique_vals):
            # print("get val_outlayer area of all vals: "+str(idx/len(self.unique_vals)))
            self.val_outlayer_area[unique_val] = self.A_LARGE_NUM

        """
        neighborhoods:
        np array:
        supervoxel1, neighbor_1, touching_area
        supervoxel1, neighbor_2, touching_area
        ...
        """
        neighborhoods = []
        for idx, current_val in enumerate(self.unique_vals):
            # print('processing: '+str(idx/len(self.unique_vals))+' pixel val: '+str(current_val))
            if self.val_labels[current_val]!=self.UN_PROCESSED:
                continue
            valid_neighbor_vals = self.regionQuery(current_val)
            if len(valid_neighbor_vals) != 0:
                neighborhoods.append(valid_neighbor_vals)
            # if len(valid_neighbor_vals)>0:
            #     # print('Assign label '+str(current_val)+' to current val\'s neighbors: '+str(valid_neighbor_vals))
            #    self.val_labels[current_val] = current_val
            #    self.growCluster(valid_neighbor_vals, current_val)
            # else:
            #    self.val_labels[current_val] = self.LONELY_POINT

        # self.output_3d_img = self.input_3d_img
        neighborhoods = np.vstack(neighborhoods)

        return neighborhoods

    def get_outlayer_area(self, current_val):
        current_crop_img, current_restrict_area = get_crop_by_pixel_val(self.input_3d_img, current_val,
                                                                        boundary_extend=self.boundary_extend,
                                                                        crop_another_3d_img_by_the_way=self.restrict_area_3d)
        current_crop_img_onehot = np.array(current_crop_img==current_val, dtype=np.int8)
        current_crop_img_onehot_outlayer = get_outlayer_of_a_3d_shape(current_crop_img_onehot)

        assert current_crop_img_onehot_outlayer.shape == current_restrict_area.shape

        current_crop_img_onehot_outlayer[current_restrict_area>0]=0
        current_crop_outlayer_area = np.sum(current_crop_img_onehot_outlayer)

        return current_crop_outlayer_area

    def regionQuery(self, current_val):
        current_crop_img, current_restrict_area = get_crop_by_pixel_val(self.input_3d_img, current_val,
                                                                        boundary_extend=self.boundary_extend,
                                                                        crop_another_3d_img_by_the_way=self.restrict_area_3d)

        current_crop_img_onehot = np.array(current_crop_img==current_val, dtype=np.int8)
        current_crop_img_onehot_outlayer = get_outlayer_of_a_3d_shape(current_crop_img_onehot)

        assert current_crop_img_onehot_outlayer.shape == current_restrict_area.shape

        current_crop_img_onehot_outlayer[current_restrict_area>0]=0
        current_crop_outlayer_area = np.sum(current_crop_img_onehot_outlayer)

        neighbor_vals, neighbor_val_counts = np.unique(current_crop_img[current_crop_img_onehot_outlayer>0], return_counts=True)
        neighbor_val_counts = neighbor_val_counts[neighbor_vals>0]
        neighbor_vals = neighbor_vals[neighbor_vals>0]

        print("current_crop_outlayer_area: "+str(current_crop_outlayer_area))

        valid_neighbor_vals = self.neighborCheck(current_val, neighbor_vals, neighbor_val_counts, current_crop_outlayer_area)


        print("valid_neighbor_vals: "+str(valid_neighbor_vals))

        return valid_neighbor_vals

    def neighborCheck(self, current_val, neighbor_vals, neighbor_val_counts, current_crop_outlayer_area):
        neighbor_val_counts = neighbor_val_counts[neighbor_vals>0]
        neighbor_vals = neighbor_vals[neighbor_vals>0]

        valid_neighbor_vals = np.empty((len(neighbor_vals), 3))
        valid_neighbor_vals[:,0] = current_val
        for idx, neighbor_val in enumerate(neighbor_vals):
            print("touching_area: "+str(neighbor_val_counts[idx]), end="\r")
            # valid_neighbor_vals_dict[neighbor_val] = neighbor_val_counts[idx]
            valid_neighbor_vals[idx, 1] = neighbor_val
            valid_neighbor_vals[idx, 2] = neighbor_val_counts[idx]

        # double_checked_valid_neighbor_vals = []
        # for valid_neighbor_val in valid_neighbor_vals_dict.keys():
        #    if self.val_labels[valid_neighbor_val]==self.UN_PROCESSED or \
        #     self.val_labels[valid_neighbor_val]==self.LONELY_POINT:
        #        double_checked_valid_neighbor_vals.append(valid_neighbor_val)

        return valid_neighbor_vals

    def add_ground_truth_node_labels(self, input_3d_img, groundtruth_img, neighbors_and_touching_area):
        # add ground truth column to neighbors_and_touching_area matrix
        neighbors_and_touching_area = np.c_[(neighbors_and_touching_area,
                                                np.zeros(len(neighbors_and_touching_area)))]
        unique_values = np.unique(neighbors_and_touching_area[:,0])

        # get the ground truth cell label for each super voxel
        groundtruth_labels = {}
        for idx, value in enumerate(unique_values):
            # get values of groundtruth that overlap with each supervoxel
            overlapping_voxels = groundtruth_img[np.where(input_3d_img == value)]
            # get the most occuring groundtruth voxel label for each supervoxel
            gt_label = np.bincount(overlapping_voxels.astype(int)).argmax()

            groundtruth_labels[value] = gt_label

        # set ground 4th column of neighbors_and_touching_area to 1
        # if both supervoxels have the same groundtruth cell label, 0 otherwise
        for idx, col in enumerate(neighbors_and_touching_area):
            if groundtruth_labels[col[0]] == groundtruth_labels[col[1]]:
                neighbors_and_touching_area[idx, 3] = 1

        return neighbors_and_touching_area

In [105]:
super_vox_to_graph = Super_Vox_To_Graph()

In [55]:
neighbors = super_vox_to_graph.get_neighbors_and_touching_area(seg_foreground_super_voxel_by_ws)

current_crop_outlayer_area: 4
valid_neighbor_vals: [[1.239e+03 1.326e+03 1.000e+00]
 [1.239e+03 1.374e+03 3.000e+00]]
current_crop_outlayer_area: 0
valid_neighbor_vals: []
current_crop_outlayer_area: 0
valid_neighbor_vals: []
current_crop_outlayer_area: 25
valid_neighbor_vals: [[507. 234.  25.]]
current_crop_outlayer_area: 10
valid_neighbor_vals: [[1297. 1245.   10.]]
current_crop_outlayer_area: 24
valid_neighbor_vals: [[346. 379.   1.]
 [346. 400.   9.]
 [346. 631.  14.]]
current_crop_outlayer_area: 2
valid_neighbor_vals: [[1648. 1256.    2.]]
current_crop_outlayer_area: 0
valid_neighbor_vals: []
current_crop_outlayer_area: 8
valid_neighbor_vals: [[167. 276.   3.]
 [167. 311.   5.]]
current_crop_outlayer_area: 60
valid_neighbor_vals: [[1858. 1636.   60.]]
current_crop_outlayer_area: 0
valid_neighbor_vals: []
current_crop_outlayer_area: 73
valid_neighbor_vals: [[1709. 1622.   57.]
 [1709. 1711.    6.]
 [1709. 1751.   10.]]
current_crop_outlayer_area: 11
valid_neighbor_vals: [[761. 870.

In [56]:
seg_foreground_super_voxel_by_ws
hand_seg

array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       ...,

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0.

In [106]:
neighbors_with_gt = super_vox_to_graph.add_ground_truth_node_labels(seg_foreground_super_voxel_by_ws,
                                                                    hand_seg,
                                                                    neighbors)

In [107]:
neighbors_with_gt


array([[1.239e+03, 1.326e+03, 1.000e+00, 0.000e+00],
       [1.239e+03, 1.374e+03, 3.000e+00, 1.000e+00],
       [5.070e+02, 2.340e+02, 2.500e+01, 1.000e+00],
       ...,
       [1.409e+03, 1.429e+03, 1.000e+00, 1.000e+00],
       [1.655e+03, 1.637e+03, 1.000e+00, 1.000e+00],
       [1.291e+03, 1.292e+03, 1.000e+00, 1.000e+00]])