In [1]:
import itertools

import numpy as np
import scipy as sp
import pandas as pd
import networkx as nx

import meshplot as mp
import pyvista as pv
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns

from src import shapes

# Define The Figure and the Morse Function

In [2]:
def get_linear_morse(vector=None):
    if vector is None:
        vector = np.random.random(4)
    vector = np.array(vector)
    def f(points):
        return points @ vector
    return f

direction = np.random.random(3)
direction /= np.linalg.norm(direction)

In [3]:
f = lambda p: np.linalg.norm(p, axis=-1, ord=2)
f_linear = lambda p: p[:, 1]
#f = lambda p: (np.random.random(3)*p).sum(axis=-1)*(np.random.random(3)*p).sum(axis=-1)

In [4]:
def cylindrical_twist(vertices, k=1.0, mode="x", scale=1.0):
    """
    Nonlinear cylindrical twist diffeomorphism on R^3.

    vertices: (n,3) array
    mode:
      - "z": angle depends on z  (theta = k * tanh(z/scale))
      - "r": angle depends on radius r (theta = k * tanh(r/scale))
    k: twist strength (radians, roughly bounded by +/-k for tanh)
    scale: controls how quickly tanh saturates
    """
    if mode == 'x':
        v = vertices[:, [1, 2, 0]]
        v = cylindrical_twist(v, k=k, mode="z", scale=scale)
        v = v[:, [2, 0, 1]]
        return v
        
        
    v = vertices.copy()
    x, y, z = v[:, 0], v[:, 1], v[:, 2]

    r = np.sqrt(x*x + y*y)

    if mode == "z":
        theta = k * np.tanh(z / scale)
    elif mode == "r":
        theta = k * np.tanh(r / scale)
    else:
        raise ValueError('mode must be "z" or "r"')

    c, s = np.cos(theta), np.sin(theta)

    v[:, 0] = c * x - s * y
    v[:, 1] = s * x + c * y
    
    return v

In [5]:
n, m = 13, 12
vertices, faces = shapes.get_halftori_bouquet(leaves=2, n=n, m=m, l0=0.9, glue=False)

vertices = cylindrical_twist(vertices, k=-0.3, scale=1.5, mode='x')

vertices, faces = shapes.split_large_edges(vertices, faces, max_length=1.0)


print(f'faces.shape = {faces.shape}')

face_mean_values = f_linear(vertices[faces]).mean(axis=1)

p = mp.plot(vertices, faces, face_mean_values, shading={"wireframe": True})

faces.shape = (788, 3)


Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0, 0.0,…

In [6]:
faces_pv = np.hstack([np.full((faces.shape[0], 1), 3, dtype=faces.dtype), faces]).ravel()

mesh = pv.PolyData(vertices, faces_pv)
mesh.point_data["values"] = f_linear(vertices)  # per-vertex scalars

p = pv.Plotter(window_size=(600, 600))
p.add_mesh(
    mesh,
    scalars="values",
    cmap="viridis",
    smooth_shading=False,   # helps show linear interpolation nicely
    show_edges=True,      # set True if you want to see triangle edges
)
p.add_scalar_bar(title="values")
p.show()

Widget(value='<iframe src="http://localhost:38611/index.html?ui=P_0x7725bd25f9e0_0&reconnect=auto" class="pyvi…

# Paths

In [7]:
from src.ms import MorseSmale

In [8]:
ms = MorseSmale(faces, f(vertices), vertices, forest_method='steepest')
#ms = MorseSmale(faces, f(vertices), vertices, forest_method='spaning')

paths = list(ms.iterate_paths())

In [9]:
from src import vis
from src.vis import plot_paths, plot_segmentation_forests, plot_ms_comparition


In [10]:
ms0 =  MorseSmale(faces, f(vertices), vertices, forest_method='steepest')
ms1 =  MorseSmale(faces, f_linear(vertices), vertices, forest_method='steepest')

ms0.define_critical_points()
ms1.define_critical_points()

In [11]:
faces_components = ms.define_decomposition_by_paths()
faces_components

array([0, 1, 1, 0, 1, 1, 0, 0, 2, 3, 2, 1, 0, 3, 2, 3, 2, 2, 3, 3, 1, 0,
       2, 3, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 1, 3, 3, 0, 3, 2, 2, 3, 3,
       1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 3, 3, 3, 3, 1, 0, 1, 0, 2, 3,
       2, 3, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 2, 2, 2, 1, 3, 3, 0, 3, 2, 2,
       2, 3, 3, 3, 1, 0, 2, 3, 1, 0, 2, 3, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2,
       2, 2, 3, 3, 3, 3, 1, 0, 2, 3, 1, 1, 1, 0, 0, 0, 2, 2, 2, 3, 3, 3,
       1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 1, 2, 3, 3, 3, 0, 2, 2, 3, 3, 1, 1,
       1, 0, 0, 0, 2, 2, 2, 3, 3, 3, 1, 0, 2, 3, 1, 0, 2, 3, 1, 0, 1, 0,
       1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 1, 3, 3, 0, 3, 2, 2, 3, 3, 3,
       2, 3, 3, 2, 1, 0, 2, 3, 1, 0, 0, 2, 2, 3, 1, 1, 1, 1, 0, 0, 0, 0,
       2, 2, 2, 2, 3, 3, 3, 3, 1, 1, 1, 0, 0, 0, 2, 2, 2, 3, 3, 3, 1, 1,
       1, 0, 0, 0, 2, 2, 2, 3, 3, 3, 1, 1, 1, 1, 0, 0, 0, 2, 1, 2, 3, 3,
       3, 0, 2, 2, 3, 3, 1, 1, 1, 0, 0, 0, 2, 2, 2, 3, 3, 3, 1, 0, 2, 3,
       1, 0, 1, 0, 3, 2, 2, 3, 1, 0, 2, 3, 1, 0, 2,

In [12]:
new_mesh = vis.get_pv_mesh(ms.vertices, ms.faces)
new_mesh.cell_data['component'] = ms.define_decomposition_by_paths()


pl = pv.Plotter(window_size=(600, 600))
pl.add_mesh(new_mesh, scalars="component", cmap="rainbow", smooth_shading=False, show_edges=True, categories=True)
for path in ms.get_paths():
    pl.add_mesh(pv.lines_from_points(vertices[path]), color='white', line_width=4)

pl.show()

Widget(value='<iframe src="http://localhost:38611/index.html?ui=P_0x7725a421edb0_1&reconnect=auto" class="pyvi…

In [13]:
face_graph = ms.get_face_graph()
face_pos = {i: ms.vertices[ms.faces[i]].mean(axis=0) for i in face_graph.nodes()}

face_graph_reduced = face_graph.copy()

# edges from graph
edges0, edges1, attrs = zip(*face_graph_reduced.edges(data=True))
edges = np.column_stack([edges0, edges1])

# edges from paths (already made undirected + unique)
paths_edges = np.concatenate([np.column_stack([path[:-1], path[1:]]) for path in ms.get_paths()])
paths_edges = np.unique(np.sort(paths_edges, axis=1), axis=0)

# ---- IMPORTANT: also sort graph edges for undirected matching
edges_u = np.sort(edges, axis=1)

# make contiguous for view trick
edges_u = np.ascontiguousarray(edges_u)
paths_u = np.ascontiguousarray(paths_edges)

# row-wise intersection via void view
dtype = np.dtype((np.void, edges_u.dtype.itemsize * edges_u.shape[1]))
edges_view = edges_u.view(dtype).ravel()
paths_view = paths_u.view(dtype).ravel()

to_remove_view = np.intersect1d(edges_view, paths_view)
edges_to_remove = to_remove_view.view(edges_u.dtype).reshape(-1, 2)

# convert to list of tuples for networkx
face_graph_reduced.remove_edges_from(map(tuple, edges_to_remove))
face_graph_reduced.remove_edges_from(map(lambda e: (e[1], e[0]), edges_to_remove))


In [14]:
face_graph = ms.get_face_graph()
face_pos = {i: ms.vertices[ms.faces[i]].mean(axis=0) for i in face_graph.nodes()}


self = ms

# represent face_graph edges as pairs of vertex ids triplets
edges_face_repr = self.faces[np.array(list(self.get_face_graph().edges))]

# define the edges of the complex coresponding the edges of the graph
edges_edge_repr = -1*np.ones([edges_face_repr.shape[0], 2], dtype=int)
for j0, j1 in itertools.product(itertools.combinations(range(3), 2), repeat=2):
    cond = (edges_face_repr[:, 0, list(j0)] == edges_face_repr[:, 1, list(j1)]).all(axis=1)
    edges_edge_repr[cond] = edges_face_repr[cond][:, 0, (list(j0))]
edges_edge_repr = np.sort(edges_edge_repr, axis=1)

# edges of the complex inclued into paths
paths_edges = np.concatenate([np.transpose([path[:-1], path[1:]]) for path in ms.paths])
paths_edges = np.unique(np.sort(paths_edges, axis=1), axis=0)

# remove edges from graph, which are included into paths
remove_conds = edges_edge_repr[:, None, :, None] == paths_edges[None, :, None, :]
remove_conds = (remove_conds[:, :, 0, 0] & remove_conds[:, :, 1, 1]).any(axis=-1)

edges_to_remove = [edge for edge, cond in zip(self.get_face_graph().edges, remove_conds) if cond]

face_graph_reduced = face_graph.copy()
face_graph_reduced.remove_edges_from(edges_to_remove)

In [15]:
for g in [face_graph, face_graph_reduced]:
    print('Nodes:', g.number_of_nodes(), 'Edges:', g.number_of_edges(), 'Components:', nx.number_connected_components(g))

Nodes: 788 Edges: 1182 Components: 1
Nodes: 788 Edges: 1078 Components: 8


In [16]:

pl = pv.Plotter(shape=(1, 2), window_size=(1000, 500))
pl.subplot(0, 0)
#vis.add_graph_to_plotter(pl, face_graph, face_pos)
vis.add_graph_to_plotter_by_components(pl, face_graph, face_pos)

for path in ms.get_paths():
    pl.add_mesh(pv.lines_from_points(vertices[path]), color='black', line_width=4)

pl.subplot(0, 1)
#vis.add_graph_to_plotter(pl, face_graph_reduced, face_pos)
vis.add_graph_to_plotter_by_components(pl, face_graph_reduced, face_pos)

for path in ms.get_paths():
    pl.add_mesh(pv.lines_from_points(vertices[path]), color='black', line_width=4)


pl.link_views()
pl.show()


Widget(value='<iframe src="http://localhost:38611/index.html?ui=P_0x772597a60350_2&reconnect=auto" class="pyvi…

In [17]:
assert False

AssertionError: 

In [None]:
def compact_mesh(V, F):
    """
    Remove vertices not referenced by any face and reindex faces.
    Returns V2, F2, old2new, new2old.
    """
    V = np.asarray(V)
    F = np.asarray(F)

    used = np.zeros(len(V), dtype=bool)
    used[F.reshape(-1)] = True

    new2old = np.nonzero(used)[0]
    old2new = -np.ones(len(V), dtype=int)
    old2new[new2old] = np.arange(len(new2old))

    V2 = V[new2old]
    F2 = old2new[F]

    return V2, F2, old2new, new2old

In [None]:
V2, F2, old2new, new2old = compact_mesh(vertices, surrounding)

start_idx = old2new[path[0]]
end_idx = old2new[path[-1]]

In [None]:
F2_pv = np.hstack([np.full((surrounding.shape[0], 1), 3, dtype=faces.dtype), F2]).ravel()

part_mesh = pv.PolyData(V2, F2_pv)

pl = pv.Plotter(window_size=(600, 600))
pl.add_mesh(part_mesh, color='white', smooth_shading=False, show_edges=True)
pl.show()

Widget(value='<iframe src="http://localhost:38733/index.html?ui=P_0x75cfb4a83bc0_4&reconnect=auto" class="pyvi…