In [None]:
from sklearn.neighbors import KDTree
import os
import json
import torch
import trimesh
import numpy as np
import glob
import pickle
from tqdm import tqdm
import jax.numpy as jnp

In [None]:
### PASS IN LANGEVIN OUTPUT AND MESH###
mode_path = '/content/mesh_1_0.1_10000_modes.pickle'
# SECURITY: pickle.load can execute arbitrary code while deserialising, so
# mode_path must only ever point at a file you produced yourself.
with open(mode_path, 'rb') as f:
    modes = pickle.load(f)

mesh = trimesh.load('/content/mesh.obj') ## the mesh needs to be centered + normalized!

# Reference recipe for producing a centered + normalized mesh. It is kept as
# comments (not as a string literal) so the cell does not evaluate it as its
# last expression.
# NOTE(review): the recipe needs Open3D bound to `o3d`, which is not among the
# imports visible at the top of this notebook; the trimesh.load on its first
# line is immediately overwritten by the Open3D load on the second.
#
# mesh = trimesh.load('/content/cat4.off')
# mesh = o3d.io.read_triangle_mesh('/content/cat4.off')
# center_of_mass = mesh.get_center()
# mesh.translate(-center_of_mass)
# norms = np.linalg.norm(np.asarray(mesh.vertices), axis=1)
# mesh.scale(1/np.max(norms), center=(0,0,0))
# o3d.io.write_triangle_mesh('t.obj', mesh)

In [None]:
def safe_norm(v, axis=None, eps=1e-6):
    """Norm of ``v`` with very small magnitudes snapped to exactly zero.

    Args:
        v: Array whose norm is taken with ``jnp.linalg.norm``.
        axis: Axis (or axes) forwarded to ``jnp.linalg.norm``; ``None`` uses
            that function's default reduction over the whole input.
        eps: Magnitudes strictly below this threshold are replaced by 0.

    Returns:
        The computed norm(s), with every entry smaller than ``eps`` set to 0.
    """
    magnitude = jnp.linalg.norm(v, axis=axis)
    # Keep the "< eps" orientation: a NaN magnitude compares False and is
    # therefore passed through unchanged.
    below_threshold = magnitude < eps
    return jnp.where(below_threshold, 0, magnitude)
#for clustering
def dbscan(D, eps, MinPts):
    """Assign a DBSCAN cluster label to every point in ``D``.

    Args:
        D: Indexable collection of points; ``len(D)`` is the point count.
        eps: Neighbourhood radius handed to ``region_query``.
        MinPts: Minimum neighbourhood size (as returned by ``region_query``)
            for a point to seed a new cluster.

    Returns:
        A list with one integer per point: ``-1`` for noise, otherwise a
        cluster id starting at 1.
    """
    UNVISITED, NOISE = 0, -1
    assignments = [UNVISITED] * len(D)
    cluster_id = 0
    for idx in range(len(D)):
        # Anything already claimed by a cluster or marked as noise is never
        # used as a new seed.
        if assignments[idx] != UNVISITED:
            continue
        nearby = region_query(D, idx, eps)
        if len(nearby) >= MinPts:
            # Dense enough: open the next cluster and expand it from here.
            cluster_id += 1
            grow_cluster(D, assignments, idx, nearby, cluster_id, eps, MinPts)
        else:
            # Provisional noise; grow_cluster may later absorb it as a border point.
            assignments[idx] = NOISE
    return assignments

def grow_cluster(D, labels, P, NeighborPts, C, eps, MinPts):
    """Expand cluster ``C`` outward from seed point ``P``.

    Args:
        D: Indexable collection of points.
        labels: Per-point label list, updated in place (0 = unvisited,
            -1 = noise, positive = cluster id).
        P: Index of the seed point.
        NeighborPts: Indices of the seed's neighbours; this list itself is
            not modified.
        C: Id of the cluster being grown.
        eps: Neighbourhood radius handed to ``region_query``.
        MinPts: Minimum neighbourhood size for a point to keep expanding
            the cluster.
    """
    labels[P] = C
    # Private, growing work queue so the caller's list stays untouched.
    frontier = list(NeighborPts)
    cursor = 0
    while cursor < len(frontier):
        candidate = frontier[cursor]
        cursor += 1
        state = labels[candidate]
        if state == -1:
            # Former noise becomes a border point; it is not expanded further.
            labels[candidate] = C
        elif state == 0:
            labels[candidate] = C
            reachable = region_query(D, candidate, eps)
            # Only points with a dense neighbourhood push the frontier outward.
            if len(reachable) >= MinPts:
                frontier.extend(reachable)

def region_query(D, P, eps):
    """Return the indices of all points of ``D`` closer than ``eps`` to point ``P``.

    The distance is the ``safe_norm`` of the difference between two points
    (a geodesic distance was considered as an alternative metric). The query
    point itself is always part of the result when ``eps`` is positive,
    because its distance to itself is 0.

    All distances are computed in one vectorised call instead of one
    JAX call per point, which matters because DBSCAN issues one query per point.

    Args:
        D: Collection of equally-shaped points, indexable along its first axis.
        P: Index of the query point in ``D``.
        eps: Neighbourhood radius (strict ``<`` comparison).

    Returns:
        A list of Python ints, in ascending order, indexing the neighbours.
    """
    if len(D) == 0:
        return []
    # Offsets of every point from the query point, in the input's own dtype.
    offsets = np.asarray(D) - np.asarray(D[P])
    # Flatten any trailing axes so each row is one point's offset vector.
    offsets = offsets.reshape(len(offsets), -1)
    within = np.asarray(safe_norm(offsets, axis=1) < eps)
    return np.flatnonzero(within).tolist()

def compute_centroids(data, labels):
    """Average the points of every cluster, ignoring noise.

    Args:
        data: Array of points, indexable with a boolean mask along axis 0.
        labels: One label per point; ``-1`` marks noise and is skipped.

    Returns:
        A NumPy array with one centroid per distinct non-noise label, ordered
        by ascending label. When there is no non-noise label the result is an
        empty array of shape ``(0,) + data.shape[1:]`` instead of an error.
    """
    # Explicit conversion so the mask below is a proper boolean array even
    # when labels arrives as a plain Python list.
    labels = np.asarray(labels)
    cluster_ids = np.unique(labels)
    cluster_ids = cluster_ids[cluster_ids != -1]

    if cluster_ids.size == 0:
        # np.stack refuses an empty list; report "no clusters" as zero rows.
        return np.empty((0,) + tuple(np.shape(data)[1:]))

    centroids = [jnp.mean(data[labels == cid], axis=0) for cid in cluster_ids]
    return np.stack(centroids)

def createPlane(normal, point_on_plane):
    """Build a square mesh lying in the plane given by a normal and a point.

    The square has side length 2, is centred on ``point_on_plane`` and is
    made double-sided by listing each triangle in both winding orders.

    Args:
        normal: 3-vector perpendicular to the plane; it is normalised here,
            so it need not be unit length (but must not be the zero vector).
        point_on_plane: 3-vector used as the centre of the square.

    Returns:
        A ``trimesh.Trimesh`` with 4 vertices, 4 faces and a uniform vertex
        colour of ``[247, 247, 121, 100]``.
    """
    # Work on plain float64 NumPy arrays so the arithmetic below never mixes
    # array types coming from different libraries.
    normal = np.asarray(normal, dtype=np.float64)
    point_on_plane = np.asarray(point_on_plane, dtype=np.float64)
    normal = normal / np.linalg.norm(normal)

    # Find a vector in the plane. The helper axis must not be parallel to the
    # normal in EITHER direction: for normal == [-1, 0, 0] a cross product
    # with [1, 0, 0] is the zero vector and normalising it would give NaNs.
    if np.allclose(np.abs(normal), [1, 0, 0]):
        v1 = np.cross(normal, [0, 1, 0])
    else:
        v1 = np.cross(normal, [1, 0, 0])

    v1 = v1 / np.linalg.norm(v1)
    # Second in-plane direction, perpendicular to both the normal and v1.
    v2 = np.cross(normal, v1)
    v2 = v2 / np.linalg.norm(v2)

    half_width = 1
    half_height = 1

    # Calculate the corners
    corner1 = point_on_plane + half_width * v1 + half_height * v2
    corner2 = point_on_plane - half_width * v1 + half_height * v2
    corner3 = point_on_plane - half_width * v1 - half_height * v2
    corner4 = point_on_plane + half_width * v1 - half_height * v2

    vertices = np.array([corner1, corner2, corner3, corner4])

    # Define the faces of the rectangle: two triangles, then the same two
    # with reversed winding so the square is visible from both sides.
    faces = np.array([
        [0, 1, 2],
        [0, 2, 3],
        [2, 1, 0],
        [3, 2, 0]
    ])

    # Create a mesh for the rectangle (same colour on all four vertices)
    vc1 = np.tile([247, 247, 121, 100], (4, 1))

    plane_mesh = trimesh.Trimesh(vertices=vertices, faces=faces,vertex_colors=vc1)

    return plane_mesh

def plot_plane(idx):
    """Return a scene showing the mesh vertices together with one plane.

    Reads the notebook-level ``mesh``, ``n`` and ``point`` variables.

    Args:
        idx: Index into ``n`` and ``point`` selecting which plane to draw.

    Returns:
        A ``trimesh.Scene`` holding the vertex point cloud and the plane
        built by ``createPlane`` for that index.
    """
    cloud = trimesh.PointCloud(mesh.vertices, colors = [0,191,255, 100])
    return trimesh.Scene([cloud, createPlane(n[idx], point[idx])])


In [None]:
# Cluster the loaded modes with the DBSCAN implementation defined in this
# notebook, then reduce every cluster to its centroid.
my_labels = dbscan(modes, eps=0.03, MinPts=3) #TODO:this has to be tuned!
# Points labelled -1 (noise) are left out by compute_centroids.
centroid = compute_centroids(modes, my_labels)
centroids = np.array(centroid)
print("total number of modes found: " + str(len(centroids)))

# Split each centroid into a direction (n) and a length (norm, taken over the
# last axis); n and point are what the plane-building cell passes to
# createPlane as the normal and the point on the plane.
# NOTE(review): centroids is a NumPy array while norm is a torch tensor, so
# the concrete type of n, d and point depends on how the two libraries
# resolve mixed operands — confirm before relying on it.
norm = torch.norm(torch.tensor(np.array(centroids)), dim=-1)
n = centroids / norm[:,None]
# presumably each mode encodes a plane as n * (d + 1), d being the signed
# offset of the plane along n — TODO confirm against the code that produced
# the modes.
d = norm - 1
# The point reached by moving a distance d from the origin along n.
point = n * d[...,None]

total number of modes found: 3


In [None]:
###visualize planes###
# Build one plane mesh per entry of n, using point[i] as the plane's centre.
allplanes = []
for i in range(len(n)):
  pl = createPlane(n[i], point[i])
  allplanes.append(pl)

###only visualizing the first two planes for sanity check...###
# NOTE(review): within this cell `points` is only referenced by the
# commented-out scene further down.
points = trimesh.PointCloud(mesh.vertices, colors = [0,191,255])
# NOTE(review): indices 0 and 1 are used directly, so this needs at least two
# entries in n; with fewer the indexing is expected to fail — confirm whether
# that case can occur. These duplicate allplanes[0] and allplanes[1].
plane1 = createPlane(n[0], point[0])
plane2 = createPlane(n[1], point[1])
# Give every mesh vertex the same colour [220, 220, 220, 255]
# (presumably an opaque light grey in RGBA — confirm).
mesh.visual.vertex_colors[np.arange(len(mesh.vertices))] = [220,220,220, 255]

#scene = trimesh.Scene([points])
scene = trimesh.Scene([mesh, plane1, plane2])
scene.show()