In [None]:
from pathlib import Path
import json
import inspect
import tqdm
from typing import Any, Optional, Dict, List, Tuple, Callable, Union

import numpy as np
import networkx as nx
import scipy.sparse as sparse
import open3d as o3d
import matplotlib.pyplot as plt
import igl
from scipy.spatial import KDTree

import ipywidgets as widgets
from IPython.display import display
import plotly.graph_objects as go
import plotly.express as px

In [None]:
def _clean_mesh_open3d(mesh: o3d.geometry.TriangleMesh) -> o3d.geometry.TriangleMesh:
    ret_mesh = mesh.remove_duplicated_triangles()
    ret_mesh = ret_mesh.remove_duplicated_vertices()
    ret_mesh = ret_mesh.remove_degenerate_triangles()
    ret_mesh = ret_mesh.remove_non_manifold_edges()
    ret_mesh = ret_mesh.remove_unreferenced_vertices()

    # Keep only the largest connected component
    clusters, lengths, _ = ret_mesh.cluster_connected_triangles()
    clusters = np.asarray(clusters)
    lengths = np.asarray(lengths)
    largest_cluster = np.argmax(lengths)
    ret_mesh.remove_triangles_by_index(
        np.where(clusters != largest_cluster)[0]
    )
    ret_mesh = ret_mesh.remove_unreferenced_vertices()

    # Remove non-manifold vertices
    nm_verts = ret_mesh.get_non_manifold_vertices()
    if len(nm_verts) > 0:
        ret_mesh.remove_vertices_by_index(nm_verts)

    # Final clean-up
    ret_mesh = ret_mesh.remove_non_manifold_edges()
    ret_mesh = ret_mesh.remove_unreferenced_vertices()
    return ret_mesh
def _orient_mesh_by_centroid(vertices : np.ndarray,
                             triangles : np.ndarray,
                             vertex_normals : np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    flip = np.mean((vertices - vertices.mean(axis=0, keepdims=True))*vertex_normals > 0) < 0.5
    if flip:
        triangles = triangles[:, [2, 1, 0]]
        if vertex_normals is not None:
            vertex_normals = -vertex_normals
    return triangles, vertex_normals
def _process_mesh(
    mesh: Optional[o3d.geometry.TriangleMesh] = None,
    vertices: Optional[np.ndarray] = None,
    triangles: Optional[np.ndarray] = None,
    scale: Union[float, Tuple[float, float, float]] = 1.0,
    invert_axis: Tuple[bool, bool, bool] = (False, False, False),
    mesh_clean_pipeline: Callable = _clean_mesh_open3d,
    mesh_clean_pipeline_params: Optional[Dict] = None,
    orient_by_centroid: bool = False,
    return_as_numpy: bool = False,
):
    # --- get V,F as numpy ---
    if mesh is None:
        if vertices is None or triangles is None:
            raise ValueError("Either mesh or both vertices and triangles must be provided.")
        V = np.asarray(vertices, dtype=np.float64).copy()
        F = np.asarray(triangles, dtype=np.int32).copy()
        mesh = o3d.geometry.TriangleMesh()
    else:
        V = np.asarray(mesh.vertices).copy()
        F = np.asarray(mesh.triangles).copy()

    # --- preprocess in numpy ---
    if scale != 1.0:
        if isinstance(scale, (int, float)):
            V *= scale
        else:
            V *= np.array(scale)

    for axis, inv in enumerate(invert_axis):
        if inv:
            mn, mx = V[:, axis].min(), V[:, axis].max()
            V[:, axis] = (mx + mn) - V[:, axis]

    # odd number of reflections => flip winding
    if sum(bool(x) for x in invert_axis) % 2 == 1:
        F = F[:, [0, 2, 1]]

    # --- write back before normals/orientation ---
    mesh.vertices = o3d.utility.Vector3dVector(V)
    mesh.triangles = o3d.utility.Vector3iVector(F)

    # now normals correspond to the current geometry
    mesh.compute_vertex_normals()
    N = np.asarray(mesh.vertex_normals)

    if orient_by_centroid:
        F2, N2 = _orient_mesh_by_centroid(V, F, N)
        F = F2
        # write updated triangles and recompute normals (safest)
        mesh.triangles = o3d.utility.Vector3iVector(F)
        mesh.compute_vertex_normals()
        N = np.asarray(mesh.vertex_normals)


    if mesh_clean_pipeline is not None:
        params = dict(mesh_clean_pipeline_params or {})
        sig = inspect.signature(mesh_clean_pipeline)

        has_var_kwargs = any(
            p.kind == inspect.Parameter.VAR_KEYWORD
            for p in sig.parameters.values()
        )

        if not has_var_kwargs:
            # filter params
            params = {k: v for k, v in params.items() if k in sig.parameters}

        # inject core objects if accepted and not already provided
        if 'mesh' in sig.parameters and 'mesh' not in params:
            params['mesh'] = mesh
        if 'vertices' in sig.parameters and 'vertices' not in params:
            params['vertices'] = V
        if 'triangles' in sig.parameters and 'triangles' not in params:
            params['triangles'] = F

        mesh = mesh_clean_pipeline(**params)


    if return_as_numpy:
        return np.asarray(mesh.vertices), np.asarray(mesh.triangles)
    return mesh
def _load_and_process_mesh(file_path : str | Path, scale : float = 1.0, invert_axis : Tuple[bool, bool, bool] = (False, False, False), orient_by_centroid: bool = False, return_as_numpy: bool = False) -> Union[o3d.geometry.TriangleMesh, Tuple[np.ndarray, np.ndarray]]:
    mesh = o3d.io.read_triangle_mesh(str(file_path))
    return _process_mesh(mesh=mesh,
                            scale=scale,
                            invert_axis=invert_axis,
                            mesh_clean_pipeline=_clean_mesh_open3d,
                            orient_by_centroid=orient_by_centroid, return_as_numpy=return_as_numpy)
def load_and_process_meshes(mesh_info_dict : Union[Dict[Any, dict]], verbose : bool = False) -> Dict[Any, Tuple[np.ndarray, np.ndarray]]:
    processed_meshes = {}   
    pbar = tqdm.tqdm(mesh_info_dict.items(), disable=not verbose)
    for name, fields in pbar:
        pbar.set_description(f"Processing mesh: {name}")
        vertices, triangles = _load_and_process_mesh(fields['path'], scale=fields.get('scale', 1.0),
                                                      invert_axis=fields.get('invert_axis', (False, False, False)), 
                                                      orient_by_centroid=fields.get('orient_by_centroid', False), 
                                                      return_as_numpy=True)
        processed_meshes[name] = (vertices, triangles)
    return processed_meshes

In [70]:
class FoldSegmentation:
    def __init__(self, 
                 initial_params : Dict,
                 vertices : np.ndarray, 
                 triangles : np.ndarray,
                 exclude_boundary_loop : bool = True):
        self.segmentation_params_types = {
                    'min_H': float,
                    'max_H': float,
                    'use_pc2': bool,
                    'pc2_quantile': float,
                    'max_num_clusters': int,
                    'expand_distance': float,
                    'expand_graph_distance': int,
                    'join_method': str,  # 'and' or 'or'
                    }
        self.vertices = vertices
        self.triangles = triangles
        self.vertex_normals = igl.per_vertex_normals(vertices, triangles)
        principal_curvatures = igl.principal_curvature(vertices, triangles)
        self.vertex_pc1_values, self.vertex_pc2_values = principal_curvatures[2], principal_curvatures[3]
        #self.vertex_mean_curvature = (self.vertex_pc1_values + self.vertex_pc2_values) / 2.0
        cotmatrix = igl.cotmatrix(vertices, triangles)
        massmatrix = igl.massmatrix(vertices, triangles, igl.MASSMATRIX_TYPE_VORONOI)
        laplacian = sparse.linalg.inv(massmatrix) @ cotmatrix
        self.vertex_mean_curvature = np.sum((laplacian @ vertices)*self.vertex_normals, axis=1)
        self.boundary_loop = igl.boundary_loop(triangles)
        if exclude_boundary_loop:
            self.vertex_mean_curvature[self.boundary_loop] = np.nan
            self.vertex_pc2_values[self.boundary_loop] = np.nan
        self.vertex_adj_list = igl.adjacency_list(triangles)
        self.adj_graph = nx.from_dict_of_lists({i: nbrs for i, nbrs in enumerate(self.vertex_adj_list)})
        self.params = {}
        for param_name, param_type in self.segmentation_params_types.items():
            if param_name in initial_params:
                if not isinstance(initial_params[param_name], param_type):
                    raise ValueError(f"Parameter {param_name} must be of type {param_type}.")
                self.params[param_name] = initial_params[param_name]
            else:
                raise ValueError(f"Missing required parameter: {param_name}")
        self.tree = KDTree(self.vertices)
        self._mean_curvature_mask = None
        self._pc2_mask = None
        self._clusters = None
        self._expanded_clusters = None
        #self._annotations = None

    def update_parameter(self, param_name: str, param_value, invalidate_caches: bool = True) -> bool:
        if param_name not in self.segmentation_params_types:
            raise ValueError(f"Unknown parameter: {param_name}")
        if not isinstance(param_value, self.segmentation_params_types[param_name]):
            raise ValueError(f"Parameter {param_name} must be of type {self.segmentation_params_types[param_name]}.")
        old_value = self.params[param_name]
        self.params[param_name] = param_value
        parameter_changed = old_value != param_value
        if invalidate_caches and parameter_changed:
            if param_name in ['min_H', 'max_H']:
                self._mean_curvature_mask = None
                self._clusters = None
            if param_name in ['use_pc2', 'pc2_quantile']:
                self._pc2_mask = None
                self._clusters = None
            if param_name in ['max_num_clusters']:
                self._clusters = None
            if param_name in ['expand_distance', 'expand_graph_distance', 'join_method']:
                self._expanded_clusters = None
        return parameter_changed

    def _get_mean_curvature_mask(self):
        if self._mean_curvature_mask is None:
            self._mean_curvature_mask = (self.vertex_mean_curvature >= self.params['min_H']) & (self.vertex_mean_curvature <= self.params['max_H'])
        return self._mean_curvature_mask
    def _get_pc2_mask(self):
        if self._pc2_mask is None:
            if self.params['use_pc2']:
                pc2_threshold = np.nanquantile(self.vertex_pc2_values, self.params['pc2_quantile'])
                self._pc2_mask = self.vertex_pc2_values >= pc2_threshold
            else:
                self._pc2_mask = None
        return self._pc2_mask
    
    def _compute_clusters(self):
        if self._clusters is None:
            mean_curvature_mask = self._get_mean_curvature_mask()
            pc2_mask = self._get_pc2_mask()
            if pc2_mask is not None:
                combined_mask = mean_curvature_mask & pc2_mask
            else:
                combined_mask = mean_curvature_mask
            subgraph = self.adj_graph.subgraph(np.argwhere(combined_mask).flatten())
            sorted_components = sorted(list(nx.connected_components(subgraph)), key=lambda x: len(x), reverse=True)
            if self.params['max_num_clusters'] is None or self.params['max_num_clusters'] == 0:
                raise ValueError("Parameter 'max_num_clusters' must be a positive integer.")
            sorted_components = sorted_components[:self.params['max_num_clusters']]
            self._clusters = [ np.array(list(comp)) for comp in  sorted_components ]
        return self._clusters
    
    @staticmethod
    def _expand_nodes(graph : nx.Graph, nodes, dist : int):
        if dist <= 0:
            return set(nodes)
        inflated = set(nodes)
        frontier = set(nodes)
        for _ in range(dist):
            next_frontier = set()
            for node in frontier:
                next_frontier.update(graph.neighbors(node))
            next_frontier -= inflated
            inflated.update(next_frontier)
            frontier = next_frontier
            if not frontier:
                break
        return inflated

    def _expand_clusters(self):
        clusters = self._compute_clusters()
        if self._expanded_clusters is None:
            self._expanded_clusters = []
            for cluster in clusters:
                
                grown_cluster_by_distance = None
                grown_cluster_by_graph_distance = None
                if self.params['expand_distance'] > 0:
                    grown_cluster_by_distance = set(cluster)
                    indices = self.tree.query_ball_point(self.vertices[cluster], r=self.params['expand_distance'])
                    for nearby_indices in indices:
                        grown_cluster_by_distance.update(nearby_indices)
                if self.params['expand_graph_distance'] > 0:
                    grown_cluster_by_graph_distance = self._expand_nodes(self.adj_graph, cluster, self.params['expand_graph_distance'])

                if self.params['join_method'] == 'and':
                    if grown_cluster_by_distance is not None and grown_cluster_by_graph_distance is not None:
                        final_cluster = grown_cluster_by_distance.intersection(grown_cluster_by_graph_distance)
                    elif grown_cluster_by_distance is not None:
                        final_cluster = grown_cluster_by_distance
                    elif grown_cluster_by_graph_distance is not None:
                        final_cluster = grown_cluster_by_graph_distance
                    else:
                        final_cluster = set(cluster)
                elif self.params['join_method'] == 'or':
                    final_cluster = set(cluster)
                    if grown_cluster_by_distance is not None:
                        final_cluster.update(grown_cluster_by_distance)
                    if grown_cluster_by_graph_distance is not None:
                        final_cluster.update(grown_cluster_by_graph_distance)
                else:
                    raise ValueError(f"Unknown join_method: {self.params['join_method']}")
                self._expanded_clusters.append(np.array(list(final_cluster)))
            self._expanded_clusters = self._expanded_clusters
        return self._expanded_clusters
    
    def run(self):
        return self._expand_clusters()
def save_segmentation(segmentation : FoldSegmentation, clusters_annotations : Optional[List[str]] = None, include_geometry :  bool = True, include_curvatures : True = True) -> Dict:
    if segmentation._clusters is None or segmentation._expanded_clusters is None:
        raise ValueError("No clusters to export. Please run the segmentation first.")
    segmentation_dict = {}
    segmentation_dict['params'] = segmentation.params
    if include_geometry:
        segmentation_dict['vertices'] = segmentation.vertices
        segmentation_dict['triangles'] = segmentation.triangles
    if include_curvatures:
        segmentation_dict['vertex_mean_curvature'] = segmentation.vertex_mean_curvature
        segmentation_dict['vertex_pc1_values'] = segmentation.vertex_pc1_values
        segmentation_dict['vertex_pc2_values'] = segmentation.vertex_pc2_values

    if clusters_annotations is not None:
        annotations_dict = {}
        annotation_counter = 0
        for annotation in clusters_annotations:
            name = None
            clusters = None
            if ':' in annotation:
                sub_annotations = annotation.split(':')
                if len(sub_annotations) == 2:
                    name, clusters = sub_annotations
            else:
                clusters = annotation
            if name is None:
                name = "annotation" + str(annotation_counter)
            if '+' in clusters:
                clusters = clusters.split('+')
                clusters = [int(c) for c in clusters]
            else:
                clusters = [int(clusters)]
            annotations_dict[name] = np.unique(np.concatenate([segmentation._expanded_clusters[c] for c in clusters if c < len(segmentation._expanded_clusters)]))
            annotation_counter += 1
        segmentation_dict['segmentations'] = annotations_dict
        segmentation_dict['annotations'] = clusters_annotations
    return segmentation_dict

## Load Meshes
Meshes are loaded from a dict. The dict must contain an entry per mesh with an arbitrary name. Each entry must be a dict with at least a 'path' field, containing the path of the mesh.
We also accept other fields in each mesh dictionary.
The 'scale' key contains either a single float or a 3-tuple of floats to scale the vertices with.
The 'invert_axis' is a boolean 3-tuple which if set inverts each axis of the mesh as:
$$ x' = x_{\rm max} + x_{\rm min} - x$$
Below we also visualize the meshes using Plotly. At the moment, one can modify manually the meshes by specifying which axes to invert (as above). This results in a new 'mesh_info_dict'. Note: to avoid confusion, we do NOT use the original invert_axis field. In other words, this is useful when no prior 'invert_axis' field is specified.

In [71]:
base_dir = Path("/Users/schimmenti/Desktop/DresdenProjects/wingsurface/final_meshes/wildtype/")
#base_dir = Path("/data/biophys/schimmenti/Repositories/wingsurface/final_meshes/wildtype/")
mesh_info_dict = {p.stem: {'path': str(p.absolute())} for p in base_dir.glob("*.ply")}
mesh_info_dict = { key :   mesh_info_dict[key] for key in list(mesh_info_dict.keys())[:2]}

In [72]:
mesh_dataset = load_and_process_meshes(mesh_info_dict, verbose=True)

Processing mesh: 20210125_ecadGFPnbG4_upcrawling_disc2_scale0.5_fused_surface_blender_split: 100%|██████████| 2/2 [00:00<00:00,  5.85it/s]


### View processed meshes

In [None]:
invert_axes_results = { name : mesh_info_dict[name].get('invert_axis', (False, False, False)) for name in mesh_dataset.keys() }
def view_meshes():
    fig = go.FigureWidget()
    fig.update_layout(
        title="",
        width=800, height=400,
        scene=dict(
            xaxis_title='x', yaxis_title='y', zaxis_title='z',
            aspectmode='data',
            uirevision="keep"  # preserve camera/zoom
        ),
        margin=dict(l=50, r=50, t=60, b=0),
        legend=dict(itemsizing='constant')
    )
    palette = px.colors.qualitative.T10
    sample_id_widget = widgets.Dropdown(options=mesh_dataset.keys(), description='Sample')
    close_button_widget = widgets.Button(description='Close')
    tick_widgets = [widgets.Checkbox(description='Invert X'), 
                    widgets.Checkbox(description='Invert Y'), 
                    widgets.Checkbox(description='Invert Z')]
    
    ui = widgets.HBox([
        sample_id_widget, *tick_widgets, close_button_widget])
    def _replot():
        fig.data = []
        sample_id = sample_id_widget.value
        vertices, triangles = mesh_dataset[sample_id]
        mesh_trace = go.Mesh3d(
            x=vertices[:, 0].max() + vertices[:, 0].min() - vertices[:, 0] if tick_widgets[0].value else vertices[:, 0],
            y=vertices[:, 1].max() + vertices[:, 1].min() - vertices[:, 1] if tick_widgets[1].value else vertices[:, 1], 
            z=vertices[:, 2].max() + vertices[:, 2].min() - vertices[:, 2] if tick_widgets[2].value else vertices[:, 2],
            i=triangles[:, 0],
            j=triangles[:, 1],
            k=triangles[:, 2],
            color=palette[0],
            opacity=0.5,
            name='Mesh'
        )
        fig.add_trace(mesh_trace)

    def on_tick_change(change):
        invert_axes_results[sample_id_widget.value] = tuple(tick_widgets[i].value for i in range(3))
        _replot()
    
    def _update_viewer(change):
        if change is not None:
            # we recover the previous settings
            for i in range(3):
                tick_widgets[i].unobserve_all()
                tick_widgets[i].value = invert_axes_results[sample_id_widget.value][i]
        for i in range(3):
            tick_widgets[i].observe(on_tick_change, names='value')
        
        _replot()
    
    _update_viewer(None)
    def on_sample_change(change):
        _update_viewer(change)
    sample_id_widget.observe(on_sample_change, names='value')
    def on_close_button_clicked(b):
        for _ in range(3):
            invert_axes_results[sample_id_widget.value] = tuple(tick_widgets[i].value for i in range(3))
        fig.close_all()
    close_button_widget.on_click(on_close_button_clicked)
    display(ui, fig)
view_meshes()

In [None]:
new_mesh_info_dict = {}
for name, fields in mesh_info_dict.items():
    new_fields = fields.copy()
    new_fields['invert_axis'] = invert_axes_results[name]
    new_mesh_info_dict[name] = new_fields

## Segment folds

In [73]:
saved_segmentations_folder = "segmentations/"
saved_exported_segmentation_files = list(Path(saved_segmentations_folder).glob('*_segmentation.npy'))

In [79]:
use_existing_segmentations = True
default_params_dict = {
    'min_H': -1.0,
    'max_H': -0.1,
    'use_pc2': False,
    'pc2_quantile': 0.0,
    'max_num_clusters': 10,
    'expand_distance': 5.0,
    'expand_graph_distance': 0,
    'join_method': 'or',
}
segmentation_results = {}
annotation_results = {}
pbar = tqdm.tqdm(mesh_info_dict.items())
for name, fields in pbar:
    old_file =[f for f in saved_exported_segmentation_files if f.stem == (name + "_segmentation")]
    if len(old_file) == 1:
        old_file = old_file[0]
    else:
        old_file = None
    if old_file is not None and use_existing_segmentations is True:
        pbar.set_description(f"Loading {name} (found existing segmentation) ")
        saved_segmentation = np.load(old_file, allow_pickle=True).item()
        params = saved_segmentation.get('params', default_params_dict)
        if 'annotations' in saved_segmentation:
            annotation_results[name] = saved_segmentation['annotations']
        else:
            annotation_results[name] = []
    else:
        pbar.set_description(f"Loading {name} ")
        params = default_params_dict
        annotation_results[name] = []
    segmentation_results[name] = FoldSegmentation(params,
                                                    vertices=mesh_dataset[name][0],
                                                    triangles=mesh_dataset[name][1])

Loading 20220517_ecadGFPnbG4_96hAEL_disc6_scale0.5_fused_surface_blender_split (found existing segmentation) :   0%|          | 0/2 [00:00<?, ?it/s]

Loading 20210125_ecadGFPnbG4_upcrawling_disc2_scale0.5_fused_surface_blender_split (found existing segmentation) : 100%|██████████| 2/2 [00:19<00:00,  9.64s/it]


In [85]:
for name, seg in segmentation_results.items():
    print(f"Sample: {name}")
    print(f"Parameters: {seg.params}")
    print(f"Number of clusters: {len(seg._clusters) if seg._clusters is not None else 'Not computed'}")
    print(f"Number of expanded clusters: {len(seg._expanded_clusters) if seg._expanded_clusters is not None else 'Not computed'}")

Sample: 20220517_ecadGFPnbG4_96hAEL_disc6_scale0.5_fused_surface_blender_split
Parameters: {'min_H': -1.0, 'max_H': -0.1, 'use_pc2': False, 'pc2_quantile': 0.0, 'max_num_clusters': 10, 'expand_distance': 0.0, 'expand_graph_distance': 5, 'join_method': 'or'}
Number of clusters: 10
Number of expanded clusters: 10
Sample: 20210125_ecadGFPnbG4_upcrawling_disc2_scale0.5_fused_surface_blender_split
Parameters: {'min_H': -1.0, 'max_H': -0.1, 'use_pc2': False, 'pc2_quantile': 0.0, 'max_num_clusters': 10, 'expand_distance': 0.0, 'expand_graph_distance': 5, 'join_method': 'or'}
Number of clusters: 10
Number of expanded clusters: 10


In [None]:
def segment_folds(
        curvature_step : float = 0.01,
        quantile_step : float = 0.01,
        distance_step : float = 0.1,
):
    fig = go.FigureWidget()
    fig.update_layout(
        title="",
        width=800, height=680,
        scene=dict(
            xaxis_title='x', yaxis_title='y', zaxis_title='z',
            aspectmode='data',
            uirevision="keep"  # preserve camera/zoom
        ),
        margin=dict(l=0, r=0, t=0, b=0),
        legend=dict(itemsizing='constant')
    )
    fig.layout.legend.y = 0.5
    palette = px.colors.qualitative.T10
    sample_id_widget = widgets.Dropdown(options=[], description='Sample')
    show_hide_clusters_widget = widgets.Checkbox(value=True, description='Show/Hide Clusters')
    show_hide_curvature_widget = widgets.Checkbox(value=False, description='Show/Hide Curvature')
    min_H_widget = widgets.FloatText(value=default_params_dict['min_H'], description='Min H', step=curvature_step)
    max_H_widget = widgets.FloatText(value=default_params_dict['max_H'], description='Max H', step=curvature_step)
    use_pc2_widget = widgets.Checkbox(value=default_params_dict['use_pc2'], description='Use PC2')
    pc2_quantile_widget = widgets.BoundedFloatText(value=default_params_dict['pc2_quantile'], description='PC2 Quantile', step=quantile_step, min=0.0, max=1.0)
    max_num_clusters_widget = widgets.BoundedIntText(value=default_params_dict['max_num_clusters'], description='Num Clusters', min=1)
    expand_distance_widget = widgets.BoundedFloatText(value=default_params_dict['expand_distance'],  description='Expand Distance', step=distance_step, min=0.0)
    expand_graph_distance_widget = widgets.BoundedIntText(value=default_params_dict['expand_graph_distance'], description='Expand Graph Distance', min=0)
    join_method_widget = widgets.Dropdown(options=['and', 'or'], value=default_params_dict['join_method'], description='Join Method')
    clusters_selection_widget = widgets.TagsInput(value=[],allow_duplicates=False)
    close_button_widget = widgets.Button(description='Close')
    row1_widget = widgets.HBox(
        [sample_id_widget, widgets.HBox([ widgets.Label("Cluster IDs", tooltip="Save cluster indices as comma-separated values, e.g. 0,2,5. You can merge clusters using '+', e.g. [0+1,2,5]."),
                                           clusters_selection_widget]), 
                                           close_button_widget],
        layout=widgets.Layout(width="100%", justify_content="space-between", align_items="center")
    )
    row2_left_widget = widgets.VBox([
        show_hide_clusters_widget,
        show_hide_curvature_widget,
        min_H_widget, max_H_widget,
        use_pc2_widget, pc2_quantile_widget,
        max_num_clusters_widget,
        expand_distance_widget,
        expand_graph_distance_widget,
        join_method_widget
    ])
    # Box is better than VBox for “just a figure”
    row2_right_widget = widgets.Box(
        [fig],
        layout=widgets.Layout(width="100%", height="100%", overflow="visible")
    )
    # -----------------------------
    grid = widgets.GridspecLayout(2, 2, width="100%", grid_gap="12px")
    grid[0, :] = row1_widget
    grid[1, 0] = row2_left_widget
    grid[1, 1] = row2_right_widget
    grid.layout.grid_template_columns = "320px 1fr"
    grid.layout.grid_template_rows = "auto 1fr"
    grid.layout.align_items = "flex-start"
    row2_left_widget.layout = widgets.Layout(
        width="100%",
        height="680px",
        overflow_y="auto",
        overflow_x="hidden",
        align_self="flex-start"
    )
    # -----------------------------
    row2_right_widget.layout = widgets.Layout(
        width="100%",
        height="680px",
        align_self="stretch"
    )
    def _plot_mesh(delete_all):
        if delete_all:
            fig.data = []
        else:
            fig.data = tuple(t for t in fig.data if t.name != 'mesh')
        sample_id = sample_id_widget.value
        vertices, triangles = mesh_dataset[sample_id]
        mesh_trace = go.Mesh3d(
            x=vertices[:, 0],
            y=vertices[:, 1], 
            z=vertices[:, 2],
            i=triangles[:, 0],
            j=triangles[:, 1],
            k=triangles[:, 2],
            color=palette[0],
            opacity=0.5,
            name='mesh'
        )
        if show_hide_curvature_widget.value:
            mesh_trace.intensity = segmentation_results[sample_id].vertex_mean_curvature
            mesh_trace.colorscale = 'RdBu'
            mesh_trace.cmid = 0.0
            mesh_trace.cmin, mesh_trace.cmax = np.nanquantile(segmentation_results[sample_id].vertex_mean_curvature, [0.05, 0.95])
            mesh_trace.colorbar.x = 0.95
            mesh_trace.colorbar.y = 0.5
        else:
            mesh_trace.intensity = None
        fig.add_trace(mesh_trace)
    def _show_hide_clusters(b):
        for trace in fig.data:
            if 'cluster' in trace.name:
                if show_hide_clusters_widget.value:
                    trace.visible = True
                else:
                    trace.visible = False    
    def _plot_clusters():
        name = sample_id_widget.value
        segmentation = segmentation_results[name]
        expanded_clusters = segmentation.run()
        if len(fig.data) > 0:
            fig.data = tuple(t for t in fig.data if 'cluster' not in t.name)
        for i, exanded_cluster in enumerate(expanded_clusters):
            cluster_trace = go.Scatter3d(
                x=segmentation.vertices[exanded_cluster, 0],
                y=segmentation.vertices[exanded_cluster, 1],
                z=segmentation.vertices[exanded_cluster, 2],
                mode='markers',
                marker=dict(size=2, color=palette[(i % (len(palette)-1)) + 1]),
                name=f'cluster_{i}'
            )
            cluster_trace.visible = True if show_hide_clusters_widget.value else False
            fig.add_trace(cluster_trace)   
    def _on_parameter_change(change):
        name = sample_id_widget.value
        segmentation = segmentation_results[name]
        segmentation.update_parameter('min_H', min_H_widget.value)
        segmentation.update_parameter('max_H', max_H_widget.value)
        segmentation.update_parameter('use_pc2', use_pc2_widget.value)
        segmentation.update_parameter('pc2_quantile', pc2_quantile_widget.value)
        segmentation.update_parameter('max_num_clusters', max_num_clusters_widget.value)
        segmentation.update_parameter('expand_distance', expand_distance_widget.value)
        segmentation.update_parameter('expand_graph_distance', expand_graph_distance_widget.value)
        segmentation.update_parameter('join_method', join_method_widget.value)
        
        _plot_clusters()
    def _on_annotation_change(change):
        name = sample_id_widget.value
        annotation_results[name] =  clusters_selection_widget.value
    


    def _on_sample_change(change):
        name = sample_id_widget.value
        segmentation = segmentation_results[name]
        #show_hide_clusters_widget.value = True
        #show_hide_curvature_widget.value = False
        min_H_widget.value = segmentation.params['min_H']
        max_H_widget.value = segmentation.params['max_H']
        use_pc2_widget.value = segmentation.params['use_pc2']
        pc2_quantile_widget.value = segmentation.params['pc2_quantile']
        max_num_clusters_widget.value = segmentation.params['max_num_clusters']
        expand_distance_widget.value = segmentation.params['expand_distance']
        expand_graph_distance_widget.value = segmentation.params['expand_graph_distance']
        join_method_widget.value = segmentation.params['join_method']
        if name in annotation_results and annotation_results[name] is not None:
            clusters_selection_widget.value = annotation_results[name]
        else:                
            clusters_selection_widget.value = []
        _plot_mesh(True)
        _plot_clusters()
    sample_id_widget.observe(_on_sample_change, names='value')
    sample_id_widget.options = list(mesh_dataset.keys())
    sample_id_widget.value = list(mesh_dataset.keys())[0]
    for widget in [min_H_widget, max_H_widget, use_pc2_widget, pc2_quantile_widget, max_num_clusters_widget,
                   expand_distance_widget, expand_graph_distance_widget, join_method_widget]:
        widget.observe(_on_parameter_change, names='value')

    show_hide_clusters_widget.observe(_show_hide_clusters, names='value')
    show_hide_curvature_widget.observe(lambda change: _plot_mesh(False), names='value')
    clusters_selection_widget.observe(_on_annotation_change, names='value')
    display(grid)
segment_folds()

GridspecLayout(children=(HBox(children=(Dropdown(description='Sample', options=('20220517_ecadGFPnbG4_96hAEL_d…

### Save results

In [60]:
for name, seg in segmentation_results.items():
    clusters_annotations = annotation_results[name] if name in annotation_results else None
    segmentation_dict = save_segmentation(seg, clusters_annotations=clusters_annotations, include_geometry=True, include_curvatures=True)
    with open(f"segmentations/{name}_segmentation.npy", 'wb') as f:
        np.save(f, segmentation_dict)