In [1]:
import cv2

class Utility_topo(object):
    """Static helpers for post-processing persistence-homology results.

    Groups helpers that sort/filter persistence diagrams, flatten critical
    points, and compute boundaries / connected components with OpenCV.
    All methods are static; the class only serves as a namespace.
    """

    @staticmethod
    def sort_persis_results_by_persistence(t, bd, increase=True):
        """Reorder ``t`` by the persistence of the index-aligned ``bd`` pairs.

        The function is dimension specific: pass the per-dimension entries
        (e.g. ``crt_[1]`` together with ``pd_[1]``).

        Args:
            t: Sequence of results aligned index-by-index with ``bd``.
            bd: Sequence of (birth, death, ...) pairs; the persistence of
                entry ``i`` is ``bd[i][1] - bd[i][0]``.
            increase: Ascending persistence if True, descending otherwise.
                Ties keep their original relative order (stable sort).

        Returns:
            A new list with the elements of ``t`` in sorted order.
        """
        assert len(t) == len(bd)
        persistence = [pair[1] - pair[0] for pair in bd]
        order = sorted(range(len(t)), key=persistence.__getitem__, reverse=not increase)
        return [t[k] for k in order]

    @staticmethod
    def topo_filter(pd, threshold):
        """Select, per batch item, the indices of sufficiently persistent pairs.

        Args:
            pd: Batch of persistence diagrams; ``pd[i][j]`` is a
                (birth, death, ...) pair.
            threshold: Minimum persistence (``death - birth``) to keep.

        Returns:
            One list per batch item holding the indices ``j`` whose pair has
            a nonzero birth and persistence >= ``threshold``.
        """
        return [
            [
                j
                for j, pair in enumerate(diagram)
                # Structures born at 0 are skipped.
                if pair[0] != 0 and pair[1] - pair[0] >= threshold
            ]
            for diagram in pd
        ]

    @staticmethod
    def return_countour_with_p_inside(countours, p):
        """Return the index of the single contour containing point ``p``.

        Points lying on a contour edge count as inside
        (``cv2.pointPolygonTest`` with ``measureDist=False`` returns 0 there).

        Args:
            countours: Sequence of OpenCV contours.
            p: (x, y) point to test.

        Returns:
            The contour index if exactly one contour contains ``p``; -1 if
            none does; None (after printing a message) if several do.
        """
        # The binding parses ``pt`` as a float point; some OpenCV versions
        # reject non-float (e.g. numpy integer) coordinates, so normalize.
        pt = (float(p[0]), float(p[1]))
        record = [
            i
            for i, contour in enumerate(countours)
            if cv2.pointPolygonTest(contour, pt, False) >= 0
        ]
        if len(record) == 1:
            return record[0]
        if not record:
            return -1
        print("return_countour_with_p_inside: unhandled case...")
        return None

    @staticmethod
    def dangling_edge_test(label, red, shape, kernel, iter_num, border_width):
        """Measure how many pixels a morphological closing adds to one region.

        Builds a 0/255 mask of the pixels of ``red[1]`` equal to ``label``.
        It then closes the mask: ``iter_num`` dilations followed by
        ``iter_num`` erosions with ``kernel``. A frame of ``border_width``
        along the image border is blanked, and the pixels that the closing
        added are summed.

        Args:
            label: Component label to test.
            red: Pair whose second item is a 2-D label image (e.g. the
                ``(num_labels, labels)`` result of ``cv2.connectedComponents``).
            shape: Mask shape; ``shape[0]`` rows by ``shape[1]`` columns.
            kernel: Structuring element for ``cv2.dilate`` / ``cv2.erode``.
            iter_num: Number of dilation and of erosion iterations.
            border_width: Thickness of the zeroed border frame.

        Returns:
            Sum of ``(closed - mask) / 255`` outside the border frame.
        """
        labels = np.asarray(red[1])[:shape[0], :shape[1]]
        mask = np.zeros(shape)
        mask[labels == label] = 255
        mask_dp = cv2.dilate(mask, kernel, iterations=iter_num)
        mask_dp = cv2.erode(mask_dp, kernel, iterations=iter_num)
        diff_im = mask_dp - mask
        # OpenCV points are (x, y) = (column, row), so the bottom-right corner
        # is (width - 1, height - 1); this matters for non-square images.
        cv2.rectangle(diff_im, (0, 0), (shape[1] - 1, shape[0] - 1), 0, border_width)
        diff_im = diff_im / 255
        return np.sum(diff_im)

    @staticmethod
    def extract_crt_points(crt_, dim):
        """Flatten critical-point coordinates for the requested dimensions.

        Args:
            crt_: Per-dimension sequences of critical points; each point
                ``crt_[d][j]`` holds (birth_x, birth_y, death_x, death_y, ...).
                Plain lists work, as do vector wrappers implementing
                ``__len__``.
            dim: Dimensions to extract: 0-dim ``[0]``, 1-dim ``[1]``, both
                ``[0, 1]``. An empty list yields four empty lists.

        Returns:
            Tuple ``(birth_x, birth_y, death_x, death_y)`` of lists,
            concatenated over ``dim`` in the given order.
        """
        data_dims = len(crt_)
        assert all(d < data_dims for d in dim)
        birth_x, birth_y, death_x, death_y = [], [], [], []
        for d in dim:
            for point in crt_[d]:
                birth_x.append(point[0])
                birth_y.append(point[1])
                death_x.append(point[2])
                death_y.append(point[3])
        return birth_x, birth_y, death_x, death_y

    @staticmethod
    def compute_bnd_red_cv(img, low_th, high_th, connectivity):
        """Binarize ``img`` with Otsu and return its contours and components.

        Args:
            img: Image to binarize; OpenCV implements Otsu thresholding only
                for 8-bit single-channel images.
            low_th: ``thresh`` argument of ``cv2.threshold``. OpenCV ignores it
                because THRESH_OTSU computes the threshold itself.
            high_th: ``maxval`` assigned to foreground pixels.
            connectivity: 4 or 8, used by ``cv2.connectedComponents``.

        Returns:
            ``(contours, red)``: the external contours, and the
            ``(num_labels, labels)`` result of ``cv2.connectedComponents``.
        """
        _, thresh = cv2.threshold(img, low_th, high_th, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        # findContours returns (image, contours, hierarchy) in OpenCV 3.x and
        # (contours, hierarchy) in 2.4 / 4.x; index -2 is the contours in all.
        contours = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[-2]
        # Keywords are required: the binding's 2nd positional parameter is the
        # optional output ``labels`` array, not ``connectivity``.
        red = cv2.connectedComponents(thresh, connectivity=connectivity, ltype=cv2.CV_32S)
        return contours, red

    @staticmethod
    def compute_bnd_red_cv_batch(data):
        """Apply ``compute_bnd_red_cv`` to every image ``data[i]`` of a batch.

        Uses ``low_th=0``, ``high_th=255`` and 8-connectivity.

        Returns:
            ``(bnd_res, red_res)``: per-item contour lists and per-item
            connected-component results.
        """
        results = [
            Utility_topo.compute_bnd_red_cv(data[i, :], 0, 255, 8)
            for i in range(data.shape[0])
        ]
        bnd_res = [bnd for bnd, _ in results]
        red_res = [red for _, red in results]
        return bnd_res, red_res

    @staticmethod
    def compute_dist_homology(data, pc, debug=False):
        """Run persistent homology on each image of a batch through ``pc``.

        Each ``data[i]`` is transformed by ``Utility_general.nz2_nearest_z``
        (described as a distance transform — TODO confirm) and flattened to
        doubles. It is then loaded into ``pc`` and processed with ``pc.run()``.
        Only dimension-1 results are kept, sorted by ascending persistence.

        Args:
            data: Batch indexed as ``data[i]``; ``data.shape`` is read as
                ``(batch, height, width)``.
            pc: Persistence-computation object providing
                ``source_from_mat_from_double``, ``run``, ``return_pers_V``,
                ``return_pers_BD``, ``return_bnd``, ``return_red`` and ``clear``.
            debug: Also collect the sorted dimension-1 boundaries/reductions.

        Returns:
            ``(birth_x, birth_y, death_x, death_y, pd)`` lists with one entry
            per batch item, plus ``(bnd, red)`` lists when ``debug`` is True.
        """
        # computes distance transform first and persistence homology
        # defaults double type, and dim 1
        bat_size = data.shape[0]
        height   = data.shape[1]
        width    = data.shape[2]
        
        birth_x_list_ = [None] * bat_size
        birth_y_list_ = [None] * bat_size
        death_x_list_ = [None] * bat_size
        death_y_list_ = [None] * bat_size
        pd_list_      = [None] * bat_size
        
        # Output containers are allocated once and refilled by pc each item.
        crt_ = cpp_nested3_IntVector()
        pd_  = cpp_nested3_DoubleVector()
        flatten_cpp = cpp_nested1_DoubleVector(height * width)
        
        if debug:
            bnd_list_ = [None] * bat_size
            red_list_ = [None] * bat_size
            bnd_ = cpp_nested4_IntVector()
            red_ = cpp_nested4_IntVector()
        
        for i in range(bat_size):
            flatten_ = np.concatenate(Utility_general.nz2_nearest_z(data[i,:]))
            for j in range(height * width):
                flatten_cpp[j] = float(flatten_[j])
            pc.source_from_mat_from_double("dummy.dat", flatten_cpp, height, width)
            
            pc.run()
            pc.return_pers_V(crt_)
            pc.return_pers_BD(pd_)
            if debug:
                pc.return_bnd(bnd_)
                pc.return_red(red_)
                bnd_[1] = Utility_topo.sort_persis_results_by_persistence(bnd_[1], pd_[1], True)
                red_[1] = Utility_topo.sort_persis_results_by_persistence(red_[1], pd_[1], True)
                bnd_list_[i] = bnd_[1]
                red_list_[i] = red_[1]
            pc.clear()

            # pd_[1] is the sort key for the other results, so it is
            # re-ordered last.
            crt_[1] = Utility_topo.sort_persis_results_by_persistence(crt_[1], pd_[1], True)
            pd_[1]  = Utility_topo.sort_persis_results_by_persistence(pd_[1], pd_[1], True)
            birth_x, birth_y, death_x, death_y = Utility_topo.extract_crt_points(crt_, [1])
            birth_x_list_[i] = birth_x
            birth_y_list_[i] = birth_y
            death_x_list_[i] = death_x
            death_y_list_[i] = death_y
            pd_list_[i]      = pd_[1]
        
        if debug:
            return birth_x_list_, birth_y_list_, death_x_list_, death_y_list_, pd_list_, bnd_list_, red_list_
        else:
            return birth_x_list_, birth_y_list_, death_x_list_, death_y_list_, pd_list_