In [1]:
import cv2
import copy

class Utility_topo(object):
    
    @staticmethod
    def binarize_data(images, watershed):
        '''
        d should have value range between -1.0 and 1.0
        (output from tanh() layer).
        this function binarize generated images.
        '''
        d = copy.copy(images)
        d = np.squeeze(d)
        if len(d.shape) == 2:
            d = np.expand_dims(d, axis=0)
        d[d >  watershed] = 255.0
        d[d <= watershed] = 0.0
        d = d.astype(np.uint8)
        return d
    
    @staticmethod
    def convert_phc_pd_2_persim(phc_pd, dim):
        """ 
        convert from output of persistent homology [chao version]
        to persim style
        """
        assert(dim < len(phc_pd))
        num_ = len(phc_pd[dim])
        dgm_ = np.zeros([num_, 2])
        for i in range(num_):
            dgm_[i][0] = phc_pd[dim][i][0]
            dgm_[i][1] = phc_pd[dim][i][1]
        return dgm_
    
    @staticmethod
    def extract_dim_from_list(d, dim):
        """
        extract a single dimension from a list data
        d should have form N * [num_ * [dim1, dim2, ...]]
        where num_ depends on the particular item
        """
        assert(dim < len(d[0][0]))
        N   = len(d)
        res = [None] * N
        for i in range(N):
            l_ = [0.] * len(d[i])
            for j in range(len(d[i])):
                l_[j] = d[i][j][dim]
            res[i] = np.asarray(l_)
        return res
    
    @staticmethod
    def convert_phc_pd_2_persim_batch(phc_pd, dim_list):
        dgm_list = []
        for dim in dim_list:
            dgm_list.append(Utility_topo.convert_phc_pd_2_persim(phc_pd, dim))
        return dgm_list
    
    # Sort per-structure results (e.g. bnd or red) by persistence; operates on the data of a single dimension.
    @staticmethod
    def sort_persis_results_by_persistence(t, bd, increase=True):
        assert(len(t) == len(bd))
        set_num = len(t)
        
        persistence = [0] * set_num
        for i in range(set_num):
            persistence[i] = bd[i][1] - bd[i][0]
        index = sorted(range(set_num), reverse=not increase, key=lambda k: persistence[k])
        t_sorted = [None] * set_num
        for i in range(set_num):
            t_sorted[i] = t[index[i]]
        return t_sorted
    
    @staticmethod
    def topo_filter_retindex(pd, threshold):
        bat_size = len(pd)
        target_topo_index = [None] * bat_size
        
        for i in range(bat_size):
            filtered_index = []
            for j in range(len(pd[i])):
                # pass if structure birth at 0
                if (pd[i][j][0] == 0):
                    continue
                if (pd[i][j][1] - pd[i][j][0] >= threshold):
                    filtered_index.append(j)
            target_topo_index[i] = filtered_index
            
        return target_topo_index
    
    @staticmethod
    def topo_filter_retmat(pd, threshold):
        """
        pd should have form of batch_size * [(num_structure, 2)]
        where num_structure depends on the particular item
        """
        bat_size = len(pd)
        res_ = []
        for i in range(bat_size):
            shape_ = pd[i].shape
            pd_loc = []
            for j in range(shape_[0]):
                if pd[i][j][1] - pd[i][j][0] > threshold:
                    pd_loc.append(pd[i][j])
            res_.append(np.asarray(pd_loc))
        return res_
    
    @staticmethod
    def topo_filter_retmat_mul(B, D, PD, threshold):
        '''
        B, D, PD should be in form of batch_size * [(num_structure, 2)]
        which is the new_form outputs.
        This function is the same as topo_filter_retmat except that it
        filters B, D as well.
        '''
        bat_size = len(PD)
        B_  = []
        D_  = []
        PD_ = []
        for i in range(bat_size):
            shape_ = PD[i].shape
            b_loc  = []
            d_loc  = []
            pd_loc = []
            assert(B[i].shape[0] == shape_[0])
            assert(D[i].shape[0] == shape_[0])
            for j in range(shape_[0]):
                if PD[i][j, 1] - PD[i][j, 0] > threshold:
                    b_loc.append(B[i][j])
                    d_loc.append(D[i][j])
                    pd_loc.append(PD[i][j])
            B_.append(np.asarray(b_loc))
            D_.append(np.asarray(d_loc))
            PD_.append(np.asarray(pd_loc))
            
        if len(PD_) < bat_size:
            print("Warning: some file is killed by filter.")
        return B_, D_, PD_
    
    @staticmethod
    def topo_filter_retmat_bndorred_mul(data, PD, threshold):
        '''
        Note this function will NOT sort PD itself!
        bnd and red(data): batch_num * [num_structure * [point_num * [x, y]]]
        '''
        bat_size = len(PD)
        data_ = []
        for i in range(bat_size):
            shape_  = PD[i].shape
            dat_loc = []
            assert(len(data[i]) == shape_[0])
            for j in range(shape_[0]):
                if PD[i][j, 1] - PD[i][j, 0] > threshold:
                    dat_loc.append(data[i][j])
            data_.append(dat_loc)
            
        if len(data_) < bat_size:
            print("Warning: some file is killed by filter.")
        return data_
    
    @staticmethod
    def gather_sons_bnd_cv(h, father):
        '''
        gather all son indices for the father
        ===== inputs
        h: hierachy for single image, output from Utility_topo.compute_bnd_red_cv_batch
        father: integer
        '''
        sons = []
        if h[0, father, 2] != -1:
            p = h[0,:,3] 
            for i in range(len(p)):
                if p[i] == father:
                    sons.append(i)
        return sons

    @staticmethod
    def complete_contours_bnd_cv(h, contours):
        '''
        complete each structure with its direct sons.
        ===== inputs
        h: hierachy for single image, output from Utility_topo.compute_bnd_red_cv_batch
        contours: boundary or contour for single image, output from Utility_topo.compute_bnd_red_cv_batch
        '''
        contour_num = len(contours)
        contour_cpl = copy.copy(contours)
        for i in range(contour_num):
            sons = Utility_topo.gather_sons_bnd_cv(h, i)
            for j in range(len(sons)):
                contour_cpl[i] = np.vstack((contour_cpl[i], contours[sons[j]]))
        return contour_cpl
    
    @staticmethod
    def return_countour_with_p_inside(contours, hierarchy, p):
        '''
        return the contour index with p inside.
        contours CANNOT be output from complete_contours_bnd_cv.
        '''
        contour_num = len(contours)
        record = []
        pt_num = []
        for i in range(contour_num):
            if hierarchy[0, i, 3] == -1 and cv2.pointPolygonTest(contours[i], p, False) >= 0:
                record.append(i)
                pt_num.append(len(contours))
        if len(record) == 1:
            return record[0]
        elif len(record) == 0:
            return -1
        else:
            print("return_countour_with_p_inside: nested structures detected, return the closest boundary.")
            _, idx = min((min_val, min_idx) for (min_idx, min_val) in enumerate(pt_num))
            return record[idx]
            
    @staticmethod
    def dangling_edge_test(label, red, shape, kernel, iter_num, border_width):
        '''
        Measure how much a morphological closing adds to one labelled region.

        The pixels carrying `label` form a 0/255 mask; the mask is dilated
        and then eroded iter_num times with `kernel`, and the pixels gained
        by that closing are counted, ignoring a frame of border_width drawn
        along the image border.
        ===== inputs
        label:  label value to test
        red:    red[1] is read as the (height, width) label image
                (presumably the cv2.connectedComponents result returned by
                compute_bnd_red_cv -- confirm against the caller)
        shape:  (height, width) of the image
        kernel: structuring element for cv2.dilate / cv2.erode
        iter_num: number of dilate / erode iterations
        border_width: thickness of the ignored border frame
        ===== outputs
        sum of the closing difference divided by 255
        '''
        labels = np.asarray(red[1])[:shape[0], :shape[1]]
        mask = np.zeros(shape)
        mask[labels == label] = 255
        mask_dp = cv2.dilate(mask, kernel, iterations=iter_num)
        mask_dp = cv2.erode(mask_dp, kernel, iterations=iter_num)
        diff_im = mask_dp - mask
        # OpenCV points are (x, y) = (column, row): the far corner is
        # (width - 1, height - 1). Using the height for both coordinates
        # misplaces the frame on non-square images.
        cv2.rectangle(diff_im, (0, 0), (shape[1]-1, shape[0]-1), (0), border_width)
        diff_im = diff_im / 255
        return np.sum(diff_im)
    
    @staticmethod
    def extract_crt_points(crt_, dim):
        # dim: designate dimension
        # 0-dim: [0], 1-dim: [1], 0 and 1-dim: [0, 1]
        data_dims = crt_.size()
        birth_x = []
        birth_y = []
        death_x = []
        death_y = []
        assert(max(dim) < data_dims)
        for item in dim:
            for j in range(len(crt_[item])):
                birth_x.append(crt_[item][j][0])
                birth_y.append(crt_[item][j][1])
                death_x.append(crt_[item][j][2])
                death_y.append(crt_[item][j][3])
        return birth_x, birth_y, death_x, death_y
    
    @staticmethod
    def extract_birth_death_from_crt(crt_, dim):
        '''
        ===== inputs
        crt_ should be in form of dims * [num_structure * [4]] list
        dim, integer, the dimension to extract
        ===== outputs
        b_ and d_ are in form of (num_structure x 2) numpy array
        '''
        assert(dim < len(crt_))
        num_structure = len(crt_[dim])
        b_ = np.ones((num_structure, 2))
        d_ = np.ones((num_structure, 2))
        for i in range(num_structure):
            b_[i, 0] = crt_[dim][i][0]
            b_[i, 1] = crt_[dim][i][1]
            d_[i, 0] = crt_[dim][i][2]
            d_[i, 1] = crt_[dim][i][3]
        return b_, d_
    
    @staticmethod
    def compute_bnd_red_cv(img, low_th, high_th, connectivity):
        '''
        Threshold a single image and extract its contours and connected
        components with OpenCV.
        ===== inputs
        img:          single image passed to cv2.threshold
        low_th:       threshold argument of cv2.threshold
                      (NOTE(review): THRESH_OTSU is requested, so OpenCV is
                      expected to pick the threshold itself -- confirm)
        high_th:      value given to the pixels above the threshold
        connectivity: connectivity passed to cv2.connectedComponents
        ===== outputs
        contours:  contours found with RETR_CCOMP / CHAIN_APPROX_NONE
        hierarchy: matching contour hierarchy
        red:       raw return value of cv2.connectedComponents
        '''
        _, thresh = cv2.threshold(img,low_th,high_th,cv2.THRESH_BINARY+cv2.THRESH_OTSU)
        # findContours returns (image, contours, hierarchy) on OpenCV 3.x but
        # (contours, hierarchy) on 2.x and 4.x; the last two entries are the
        # same everywhere, so slice instead of unpacking a fixed arity.
        contours, hierarchy = cv2.findContours(thresh, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)[-2:]
        red = cv2.connectedComponents(thresh, connectivity, cv2.CV_32S)
        return contours, hierarchy, red
    
    @staticmethod
    def compute_bnd_red_cv_batch(data):
        bat_size = data.shape[0]
        bnd_res = [None] * bat_size
        hcy_res = [None] * bat_size
        red_res = [None] * bat_size
        for i in range(bat_size):
            bnd_res[i], hcy_res[i], red_res[i] = Utility_topo.compute_bnd_red_cv(data[i,:], 0, 255, 8)
        return bnd_res, hcy_res, red_res
    
    @staticmethod
    def dist_trfm_batch(data):
        '''
        data should be np array with shape (batch_size, height, width)
        returns a list: batch_size * [(height*width,)]
        '''
        assert(len(data.shape) == 3)
        assert(type(data[0][0,0] == 'np.uint8'))
        bat_size  = data.shape[0]
        trfm_list = [None] * bat_size
        for i in range(bat_size):
            trfm_list[i] = np.concatenate(Utility_general.nz2_nearest_z(data[i,:]))
        return trfm_list
    
    @staticmethod
    def compute_dist_homology(dshape, trfm, pc, dim, debug=False, old_form=True):
        '''
        Run the persistence computation `pc` on every item of a batch and
        collect the results of dimension `dim`, sorted by increasing
        persistence.

        compute persistence homology at dim, default double type for faster computation
        ===== inputs
        dshape: 3-element shape, read as (batch_size, height, width)
        trfm:   per-item flattened data; trfm[i][j] is read for
                j in range(height * width)
        pc:     persistence computer exposing source_from_mat_from_double,
                run, return_pers_V, return_pers_BD, return_bnd, return_red
                and clear
        dim:    dimension whose results are extracted
        debug:  if True, also return the bnd and red results
        old_form: selects the output layout (see below)
        ===== old_form
        outputs birth_x_list_, birth_y_list_, death_x_list_, death_y_list_ in a form of 
        batch_num * [num_structure] and pd_list_ in batch_num * [num_structure * [dim0, dim1, ...]]
        ===== not old_form
        outputs b_, d_, and pd_ in form of batch_num * [(num_structure, 2)] list
        ===== in case of debug
        bnd_list_ and red_list_: batch_num * [num_structure * [point_num * [x, y]]]
        (appended after the regular outputs)
        '''
        assert(len(dshape) == 3)
        bat_size = dshape[0]
        height   = dshape[1]
        width    = dshape[2]
        
        # Per-item result slots; which set is used depends on the layout.
        if old_form:
            birth_x_list_ = [None] * bat_size
            birth_y_list_ = [None] * bat_size
            death_x_list_ = [None] * bat_size
            death_y_list_ = [None] * bat_size
            pd_list_      = [None] * bat_size
        else:
            B_  = [None] * bat_size
            D_  = [None] * bat_size
            PD_ = [None] * bat_size
        
        # Containers handed to pc; the same objects are reused for every item.
        crt_ = cpp_nested3_IntVector()
        pd_  = cpp_nested3_DoubleVector()
        flatten_cpp = cpp_nested1_DoubleVector(height * width)
        
        if debug:
            bnd_list_ = [None] * bat_size
            red_list_ = [None] * bat_size
            bnd_ = cpp_nested4_IntVector()
            red_ = cpp_nested4_IntVector()
        
        for i in range(bat_size):
            # Copy item i into the input buffer, value by value, as floats.
            for j in range(height * width):
                flatten_cpp[j] = float(trfm[i][j])
            # NOTE(review): "dummy.dat" looks like a placeholder name since the
            # data comes from flatten_cpp -- confirm against pc's implementation.
            pc.source_from_mat_from_double("dummy.dat", flatten_cpp, height, width)
            
            pc.run()
            pc.return_pers_V(crt_)
            pc.return_pers_BD(pd_)
            if debug:
                pc.return_bnd(bnd_)
                pc.return_red(red_)
                # Sorted with the still-unsorted pd_[dim] as key, i.e. the same
                # key that is used for crt_ and pd_ below.
                bnd_[dim] = Utility_topo.sort_persis_results_by_persistence(bnd_[dim], pd_[dim], True)
                red_[dim] = Utility_topo.sort_persis_results_by_persistence(red_[dim], pd_[dim], True)
                bnd_list_[i] = bnd_[dim]
                red_list_[i] = red_[dim]
            pc.clear()

            # Order matters: crt_ must be sorted BEFORE pd_, because the sort
            # key is read from pd_[dim] in its original order.
            crt_[dim] = Utility_topo.sort_persis_results_by_persistence(crt_[dim], pd_[dim], True)
            pd_[dim]  = Utility_topo.sort_persis_results_by_persistence(pd_[dim], pd_[dim], True)
            
            if old_form:
                birth_x, birth_y, death_x, death_y = Utility_topo.extract_crt_points(crt_, [dim])
                birth_x_list_[i] = birth_x
                birth_y_list_[i] = birth_y
                death_x_list_[i] = death_x
                death_y_list_[i] = death_y
                # NOTE(review): pd_ is reused by the next iteration; whether
                # pd_[dim] is a copy or a live view cannot be told from here.
                pd_list_[i]      = pd_[dim]
            else:      
                B_[i], D_[i] = Utility_topo.extract_birth_death_from_crt(crt_, dim)
                PD_[i] = np.asarray(pd_[dim])
                
        if old_form:
            if debug:
                return birth_x_list_, birth_y_list_, death_x_list_, death_y_list_, pd_list_, bnd_list_, red_list_
            else:
                return birth_x_list_, birth_y_list_, death_x_list_, death_y_list_, pd_list_
        else:
            if debug:
                return B_, D_, PD_, bnd_list_, red_list_
            else:
                return B_, D_, PD_
        
    def topo_force_(set1, set2, G, mapping):
        '''
        Assume target on the LEFT and reference on the RIGHT.
        The result has the same length as set1 but each item
        with dim+1 element. The result stores at each item the
        location the current point moves to along with the
        matching distance.
        '''
        M   = len(set1)
        dim = len(set1[0].shape)             # ONLY WORKS FOR 1 DIM OR 2 DIM
        assert(M == len(G))
        assert(M == len(mapping))
        assert(dim == len(set2[0].shape))

        res = [None] * M
        for i in range(M):
            nl  = G[i].shape[0]
            nr  = G[i].shape[1]
            m_  = np.ones((nl, dim+1)) * -1
            s2_ = set2[mapping[i]]
            if dim == 1:
                s2_ = np.expand_dims(s2_, axis=1)
            if nl > nr:
                g_ = np.argmax(G[i], axis=0).astype(np.int64)
                for j in range(len(g_)):
                    for k in range(dim):
                        m_[g_[j], k] = s2_[j, k]
                    m_[g_[j], -1] = G[i][g_[j], j]
            elif nl <= nr:
                g_ = np.argmax(G[i], axis=1).astype(np.int64)
                for j in range(len(g_)):
                    for k in range(dim):
                        m_[j, k] = s2_[g_[j], k]
                    m_[j, -1] = G[i][j, g_[j]]
            res[i] = m_
        return res
    
    def topo_force_blind_(set1):
        '''
        Force all persistence dots to the left regardless of the matching.
        This version is the same as the CVPR version. Note this version's
        death and distance are NOT to be used!!
        ===== inputs
        set1: batch_num * [(structure_num * dims)] in case of 2 dimensions or
              batch_num * [(structure_num,)] in case of 1 dimension
        '''
        M   = len(set1)
        dim = len(set1[0].shape)             # ONLY WORKS FOR 1 DIM OR 2 DIM
        
        res = [None] * M
        for i in range(M):
            nl = set1[i].shape[0]
            m_ = np.ones((nl, dim+1)) * -1
            for j in range(nl):
                m_[j, 0] = 0
            res[i] = m_
        return res
    
    def apply_force_(b1, force, B_, bnd, hcy, fixN, split_form=False,
                     opposite2=False, connect=False, shape=None, thickness=1):
        '''
        ===== input
        b1:    birth time, batch_size * [(num,)]
        force: output from Utility_topo.topo_force_, batch_size * [(num, dim + 1)]
        B_:    birth coordinates, batch_size * [(num, dim)]
        bnd:   output from Utility_topo.compute_bnd_red_cv_batch, batch_size * [struct_num * [point_num * [1 * [x ,y]]]
        hcy:   output from Utility_topo.compute_bnd_red_cv_batch, batch_size * [(1, struct_num, 4)]
        fixN:  number of points to fix for each critical point
        split_form: if True, outputs batch_size * [structure_num * [integers]], else outputs
                    batch_size * [integers]
        opposite2/connect/shape/thickness: refer to Utility_general.find_closest_N_points
        ===== output
        flt_list: index of the filtered structures, batch_size * [filtered_num]
        force_x_ / force_y_: the coordinates to be changed, batch_size * [filtered_num * [fix_num]]
        '''
        bat_size = len(b1)
        assert(bat_size == len(force))
        assert(bat_size == len(B_))

        flt_list = [None] * bat_size
        force_x_ = [None] * bat_size
        force_y_ = [None] * bat_size
        for i in range(bat_size):
            struct_num = b1[i].shape[0]
            file_l_ = []
            file_x_ = []
            file_y_ = []
            bnd_cp  = Utility_topo.complete_contours_bnd_cv(hcy[i], bnd[i])
            for j in range(struct_num):
                # ==== FILTER ONE =====
                if b1[i][j] == 0 or force[i][j,0] != 0:
                    continue
                # ==== FILTER TWO =====
                x_ = B_[i][j,0]
                y_ = B_[i][j,1]
                ctor_idx = Utility_topo.return_countour_with_p_inside(bnd[i], hcy[i], (x_, y_))
                if ctor_idx < 0:
                    continue
                pts_x, pts_y = Utility_general.find_closest_N_points((x_, y_), bnd_cp[ctor_idx], fixN,
                               opposite2=opposite2, connect=connect, shape=shape, thickness=thickness)
                file_l_.append(j)
                if split_form:
                    file_x_.append(pts_x)
                    file_y_.append(pts_y)
                else:
                    file_x_ = file_x_ + pts_x
                    file_y_ = file_y_ + pts_y
            flt_list[i] = file_l_
            force_x_[i] = file_x_
            force_y_[i] = file_y_
        return flt_list, force_x_, force_y_