In [3]:
%run Utility_general.ipynb
%run Utility_topo.ipynb

import sys
from numpy import linalg as LA
from mpl_toolkits.mplot3d import Axes3D

# # installed lib for persistence image
# import PersistenceImages.persistence_images as pimg

# python library compile from cpp
sys.path.insert(0, './persis_lib_cpp')
from persis_homo_optimal import *

class Edges_(object):
    """Topology-driven loss for batches of generated images.

    Each image is binarized and its persistent homology is computed. Holes
    that pass a persistence filter and a dangling-edge test are selected.
    Loss is accumulated on boundary points close to each selected hole.
    """

    def __init__(self, params, debug=False):
        """Read configuration from ``params``.

        Args:
            params: mapping with the keys
                ``image_watershed`` (binarization threshold),
                ``target_topo_threshold`` (passed to ``Utility_topo.topo_filter``),
                ``number_of_pts_to_fix`` (boundary points gathered per hole),
                ``hole_test_kernel_radius`` (side of the square uint8 kernel),
                ``hole_test_iteration``, ``hole_test_border_width`` (passed to
                ``Utility_topo.dangling_edge_test``) and
                ``hole_test_threshold`` (minimum dangling-edge score to accept
                a section).
            debug: forwarded to ``Utility_topo.compute_dist_homology``.
        """
        self.pc    = Persistence_Computer()
        self.debug = debug

        # Parse parameters
        self.watershed       = params["image_watershed"]
        self.topo_threshold  = params["target_topo_threshold"]
        self.pts_to_fix      = params["number_of_pts_to_fix"]
        self.hole_kernel_r   = params["hole_test_kernel_radius"]
        self.hole_iterations = params["hole_test_iteration"]
        self.hole_bordwidth  = params["hole_test_border_width"]
        self.hole_threshold  = params["hole_test_threshold"]

    def _prepare_batch(self, d_mat):
        """Squeeze singleton axes and binarize the batch.

        Returns:
            ``(d_mat, data)``: the squeezed batch, always with a leading batch
            axis, and a uint8 mask of the same shape. The mask is 255 where
            the value exceeds ``self.watershed`` and 0 elsewhere.
        """
        d_mat = np.squeeze(d_mat)
        if d_mat.ndim == 2:
            # np.squeeze also removes a size-1 batch axis. Restore it so the
            # per-image indexing (data.shape[1], data.shape[2], d_mat[i])
            # keeps working for a batch of one image.
            d_mat = d_mat[np.newaxis, ...]
        # Build the mask in one pass. The previous two-step in-place
        # assignment reset the freshly written 255s to 0 whenever
        # watershed >= 255.
        data = np.where(d_mat > self.watershed, 255, 0).astype(np.uint8)
        return d_mat, data

    def treat_edges(self, d_mat):
        """Return the mean topological loss over a batch of generated images.

        Args:
            d_mat: NumPy array holding the batch. Singleton axes are squeezed
                away. The result is assumed to be (batch, height, width).
                NOTE(review): confirm against caller.

        Returns:
            ``np.mean`` of the per-image losses from ``compute_topo_loss``.
        """
        d_mat, data = self._prepare_batch(d_mat)

        # compute persistent homology for each image
        bx, by, _dx, _dy, pd = Utility_topo.compute_dist_homology(data, self.pc, self.debug)
        target_ind           = Utility_topo.topo_filter(pd, self.topo_threshold)
        bnd_res, red_res     = Utility_topo.compute_bnd_red_cv_batch(data)
        fix_x, fix_y, crtB_dt = self.detection_fix_points(
            bx, by, pd, target_ind, bnd_res, red_res, [data.shape[1], data.shape[2]])
        loss = self.compute_topo_loss(d_mat, fix_x, fix_y, crtB_dt)
        return np.mean(loss)

    def detection_fix_points(self, bx, by, pd, ind, bnd_res, red_res, shape):
        """Collect boundary points to penalize for each image in the batch.

        For image ``i``, each persistence-pair index in ``ind[i]`` is checked
        in turn. The pair's birth point must lie inside some contour of
        ``bnd_res[i]``. Its section (read from the label map
        ``red_res[i][1]`` at ``[y][x]``) must also pass the dangling-edge
        test. If both hold, the closest ``self.pts_to_fix`` points of that
        contour are gathered. The pair's birth value ``pd[i][idx][0]`` is
        recorded once per accepted hole.

        Returns:
            ``(fix_x, fix_y, crt_birth_distrfm_val)``, three lists of length
            ``len(ind)``. Each entry is the per-image list of x coordinates,
            y coordinates and per-hole birth values.
        """
        bat_size = len(ind)
        fix_x    = [None] * bat_size
        fix_y    = [None] * bat_size
        crt_birth_distrfm_val = [None] * bat_size
        kernel = np.ones((self.hole_kernel_r, self.hole_kernel_r), np.uint8)

        for i in range(bat_size):
            set_x, set_y, crt_b_dt_val = [], [], []
            # Memoize the dangling-edge verdict per section label, both pass
            # and fail. The test then runs at most once per section per
            # image; its inputs are fixed for a given (image, label).
            section_ok = {}
            for idx in ind[i]:
                px, py = bx[i][idx], by[i][idx]
                countour_idx = Utility_topo.return_countour_with_p_inside(bnd_res[i], (px, py))
                if countour_idx < 0:
                    continue  # birth point is not inside any contour
                section_label = red_res[i][1][py][px]
                if section_label not in section_ok:
                    hole_test = Utility_topo.dangling_edge_test(section_label, red_res[i],
                                shape, kernel, self.hole_iterations, self.hole_bordwidth)
                    section_ok[section_label] = hole_test > self.hole_threshold
                if not section_ok[section_label]:
                    continue
                pts_x, pts_y = Utility_general.find_closest_N_points((px, py),
                               bnd_res[i][countour_idx], self.pts_to_fix)
                set_x.extend(pts_x)
                set_y.extend(pts_y)
                crt_b_dt_val.append(pd[i][idx][0])
            fix_x[i] = set_x
            fix_y[i] = set_y
            crt_birth_distrfm_val[i] = crt_b_dt_val
        return fix_x, fix_y, crt_birth_distrfm_val

    def compute_topo_loss(self, data_origin, fix_x, fix_y, crt_birth_distrfm_val):
        """Return one accumulated loss value per image.

        Each fixed point ``j`` of image ``i`` contributes
        ``(data_origin[i][y][x] + 1.0) * w``. Here ``w`` is the birth value of
        the hole the point belongs to. The hole index is taken to be
        ``j // self.pts_to_fix``.

        NOTE(review): that mapping assumes every accepted hole contributed
        exactly ``pts_to_fix`` points. If ``find_closest_N_points`` can
        return fewer (e.g. for a short contour), weights misalign. Confirm
        its contract.
        """
        bat_size = data_origin.shape[0]
        loss = [0.0] * bat_size

        for i in range(bat_size):
            l_ = 0.0
            for j, (x, y) in enumerate(zip(fix_x[i], fix_y[i])):
                cur_val = crt_birth_distrfm_val[i][int(j // self.pts_to_fix)]
                # Offset by 1.0 so a pixel value of -1.0 contributes zero.
                # -1.0 is presumably the generator's lower output bound.
                l_ = l_ + (data_origin[i][y][x] - (-1.0)) * cur_val
            loss[i] = l_
        return loss