In [None]:
import numpy as np
import matplotlib.pyplot as plt
from triUtils import *
import open3d as o3d
import pathlib
import scipy.stats as stats
from scipy.special import erf
from math import sqrt
import matplotlib

In [None]:
# Directory holding the meshes; only the .ply files directly inside it are loaded.
# NOTE(review): absolute, machine-specific path - adjust before running elsewhere.
folder = '/Users/schimmenti/Desktop/DresdenProjects/wingsurface/vtk_meshes/'

# One (vertices, triangles) pair of numpy arrays per mesh file.
meshes = []
# The directory listing is sorted because glob() yields entries in arbitrary,
# filesystem-dependent order, while the segmentation loop keys its results by the
# position of each mesh in this list. Sorting keeps those indices stable across runs.
for file in sorted(pathlib.Path(folder).glob("*")):
    if(file.is_dir()):
        continue
    if(file.name.endswith('.ply')):
        mesh = o3d.io.read_triangle_mesh(str(file.absolute()))
        verts, tris = np.asarray(mesh.vertices), np.asarray(mesh.triangles)
        meshes.append((verts, tris))

In [None]:
from matplotlib.widgets import Slider, TextBox, Button
import copy
from scipy.optimize import minimize
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from sklearn.metrics.pairwise import euclidean_distances

segmentation_params = {}

for idx in range(len(meshes)):
    verts, tris = meshes[idx]

    edgelist = compute_edgelist(tris)

    tri_normals, tri_areas = compute_triangle_normals(verts, tris, return_areas=True)
    vert_normals = compute_vertex_normals(verts, tris)
    vert_barycentric_areas = compute_vertex_barycentric_areas(verts, tris, tri_areas)

    M_taubin = compute_taubin_matrices(verts, tris, vert_normals, tri_areas)
    kappa_1, kappa_2 = compute_taubin_principal_curvatures(M_taubin, vert_normals)
    mean_curvature = -(kappa_1+kappa_2)/2

    cot_lapl_matrix = compute_cot_laplacian(verts, tris, normalize_by_areas=False, return_areas=False)

    hist, edges = np.histogram(mean_curvature[mean_curvature>0],bins=15,density=True)
    hist, edges = hist[hist>0], edges[:-1][hist>0]
    slopes = []
    for n in range(3, hist.shape[0]):
        slopes.append( stats.linregress(edges[:n],np.log(hist[:n])).slope)
    median_slope = np.median(slopes)

    min_h = -0.5/median_slope
    max_h = 10*min_h


    %matplotlib qt
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d') 
    interactive_axes = [plt.axes([0.1, 0.9, 0.1, 0.03]), plt.axes([0.1, 0.8, 0.1, 0.03]),
                        plt.axes([0.1, 0.7, 0.1, 0.03]),plt.axes([0.1, 0.6, 0.1, 0.03]),
                        plt.axes([0.1, 0.5, 0.1, 0.03]), plt.axes([0.1, 0.4, 0.1, 0.03]),
                        plt.axes([0.1, 0.3, 0.1, 0.03]), plt.axes([0.1, 0.2, 0.1, 0.03])]
    
    minH_widget = Slider(interactive_axes[0], 'MinH', 0.00, 3*min_h, valinit=min_h, valstep=0.005)
    maxH_widget = Slider(interactive_axes[1], 'MaxH', 0.0, 5.0, valinit=max_h, valstep=0.25)
    numpts_widget =  TextBox(interactive_axes[2], 'Min Num Points', initial=str(20))
    numcomp_widget = TextBox(interactive_axes[3], 'Num Components', initial=str(10))
    selection_widget = TextBox(interactive_axes[4], 'Folds',"")
    cswidths_widget = TextBox(interactive_axes[5], 'Cross section widths', "")
    submit_widget = Button(interactive_axes[6], 'Compute')
    skip_widget = Button(interactive_axes[7], 'Skip')

    params_dict= {}
    params_dict['num_comp'] = int(numcomp_widget.text)
    params_dict['curvature_min'] = float(minH_widget.val)
    params_dict['curvature_max'] = float(maxH_widget.val)
    params_dict['min_num_pts'] = int(numpts_widget.text)

    folds_plots = []
    bottom_plots = []


    skip_sample = False

    def skip_click(stuff):
        global skip_sample
        skip_sample = True
        plt.close()

    def update_folds(val):
        """Recompute the curvature clusters and redraw them when the settings changed.

        Reads the current values of the MinH/MaxH sliders and of the 'Num Components'
        and 'Min Num Points' text boxes. The clustering is recomputed when any of
        them differs from the values cached in ``params_dict``, or when no legend
        exists yet (nothing has been drawn by this function so far). Otherwise the
        figure is left untouched.

        On a recompute the previous cluster scatter plots and any artists in
        ``bottom_plots`` are removed, one scatter per cluster is drawn (labelled
        with the cluster index), the legend is rebuilt and a redraw is requested.

        Parameters
        ----------
        val : unused; present only so the function can serve as a widget callback.
        """
        try:
            new_numComp = int(numcomp_widget.text)
            new_min_num_pts = int(numpts_widget.text)
        except ValueError:
            # Text boxes accept arbitrary input; keep the current plot on bad values.
            print('Num Components and Min Num Points must be integers')
            return
        new_curvature_min = float(minH_widget.val)
        new_curvature_max = float(maxH_widget.val)

        # The legend must be inspected BEFORE it is removed: Legend.remove() resets
        # ax.get_legend() to None, so testing afterwards would report "nothing drawn
        # yet" on every call and defeat the change detection below.
        legend = ax.get_legend()
        first_run = legend is None
        params_changed = (new_numComp != params_dict['num_comp']
                          or new_curvature_min != params_dict['curvature_min']
                          or new_curvature_max != params_dict['curvature_max']
                          or new_min_num_pts != params_dict['min_num_pts'])
        if not (first_run or params_changed):
            return

        print('Computing clusters..')
        clusters = compute_graph_clustering(edgelist, mean_curvature, new_curvature_min, new_curvature_max,
                                            min_num_points=new_min_num_pts, num_max_clusters=new_numComp)

        # Cache the settings only once the clustering succeeded, so a failed attempt
        # is retried on the next click instead of being mistaken for "unchanged".
        params_dict['num_comp'] = new_numComp
        params_dict['curvature_min'] = new_curvature_min
        params_dict['curvature_max'] = new_curvature_max
        params_dict['min_num_pts'] = new_min_num_pts

        # Replace the previous cluster scatter plots.
        for el in folds_plots:
            el.remove()
        folds_plots.clear()
        for cluster_index, cluster in enumerate(clusters):
            # Columns 1, 2, 0 of verts are plotted on the x, y, z axes of the figure.
            folds_plots.append(ax.scatter(verts[cluster,1], verts[cluster,2], verts[cluster,0],
                                          s=10, label=cluster_index))

        # Artists in bottom_plots belong to the previous clustering; drop them too.
        for el in bottom_plots:
            el.remove()
        bottom_plots.clear()

        if legend is not None:
            legend.remove()
        ax.legend()
        fig.canvas.draw_idle()

    # Fold name -> integer entered by the user in the 'Folds' text box.
    foldcomps = {}
    def select_folds(val):
        """Parse the comma-separated 'Folds' text box into ``foldcomps``.

        The position of an entry decides its key: positions 0, 1 and 2 are stored
        under 'hn', 'hh' and 'hp', and any later position i under 'v<i-3>'. Empty
        entries are skipped without shifting the positions of the entries after
        them. Every other entry is converted with int(). The previous content of
        ``foldcomps`` is discarded first.

        `val` is the argument passed by the widget callback; the text is read from
        the widget itself.
        """
        leading_keys = ('hn', 'hh', 'hp')
        entries = selection_widget.text.split(',')
        foldcomps.clear()
        for position, entry in enumerate(entries):
            if not entry:
                continue
            if position < len(leading_keys):
                key = leading_keys[position]
            else:
                key = 'v%i' % (position - 3)
            foldcomps[key] = int(entry)

    # Connect the widgets to their callbacks.
    selection_widget.on_submit(select_folds)
    skip_widget.on_clicked(skip_click)
    submit_widget.on_clicked(lambda _: update_folds(None))

    # Background: every mesh vertex as a small gray dot, using the same column
    # order (1, 2, 0) as the cluster scatter plots drawn by update_folds.
    ax.scatter(verts[:,1], verts[:, 2], verts[:, 0], color='gray', s=1, alpha=0.75)
    ax.set_aspect('equal')

    # Run the interactive session; execution resumes once the call returns.
    plt.show(block=True)

    params_dict['components'] = foldcomps
    # Record whether the session was ended with the Skip button. The flag was set
    # by skip_click but never read, so skipped meshes could not be told apart from
    # segmented ones. All other keys are stored exactly as before.
    params_dict['skipped'] = skip_sample
    # Shallow copy: 'components' still refers to foldcomps, which is created anew
    # for every mesh.
    segmentation_params[idx] = copy.copy(params_dict)
    print(idx,params_dict)