In [461]:
import numpy as np
from pathlib import Path
from fastdtw import fastdtw
import igl
import networkx as nx
from typing import List, Optional, Tuple
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.express as px
from numpy.typing import NDArray

In [489]:
class FoldAnnotation:
    '''
        Annotates a fold on a triangle mesh region as a curve.

        Given a region of the mesh (vertex 'indices'), the class
          1. picks start/end vertices at the two extremes of the region along its principal axis ('compute_path_boundaries'),
          2. finds curvature-weighted shortest paths inside the region between every start/end pair ('compute_paths'),
          3. merges these paths into one averaged curve with a per-point spread estimate ('merge_paths').
        Intermediate results are cached on the instance.
    '''
    @staticmethod
    def _align_paths_with_dtw(paths):
        '''
            Aligns every path to the first one with (fast) Dynamic Time Warping.

            Returns, for each input path, its points in the order visited by the DTW alignment with the reference
            path; points may repeat and the returned paths generally have different lengths.
            Raises ValueError if 'paths' is empty.
        '''
        from scipy.spatial.distance import euclidean
        if len(paths) == 0:
            raise ValueError("Cannot align an empty list of paths.")
        # The first path is the reference that all other paths are warped onto.
        ref_path = np.asarray(paths[0])
        aligned_paths = []
        for path in paths:
            path = np.asarray(path)
            _, alignment = fastdtw(ref_path, path, dist=euclidean)
            # 'alignment' is a sequence of (ref_index, path_index) pairs; keep the positions on this path.
            path_positions = np.array([j for _, j in alignment], dtype=np.intp)
            aligned_paths.append(path[path_positions])
        return aligned_paths

    @staticmethod
    def _resample_path(path, num_points):
        '''
            Linearly resamples 'path' (points stacked along axis 0) to 'num_points' points, spaced uniformly in the
            point-index parameter (not in arc length).
        '''
        from scipy.interpolate import interp1d
        path = np.asarray(path)
        if len(path) == 1:
            # interp1d needs at least two samples; a single point resamples to copies of itself.
            return np.repeat(path, num_points, axis=0)
        original_params = np.linspace(0.0, 1.0, len(path))
        target_params = np.linspace(0.0, 1.0, num_points)
        interpolator = interp1d(original_params, path, axis=0, kind='linear', fill_value="extrapolate")
        return interpolator(target_params)

    @staticmethod
    def _average_aligned_paths(aligned_paths, num_points):
        '''
            Resamples every aligned path to 'num_points' points and averages them point-wise.

            Returns (avg_path, err_path): the point-wise mean path and, for each point, the square root of the
            summed per-coordinate variance across paths (the RMS distance of the paths to the mean at that point).
        '''
        resampled_paths = [FoldAnnotation._resample_path(path, num_points) for path in aligned_paths]
        stacked_paths = np.stack(resampled_paths)
        avg_path = np.mean(stacked_paths, axis=0)
        err_path = np.sqrt(np.var(stacked_paths, axis=0).sum(axis=-1))
        return avg_path, err_path

    def __init__(self,
                 vertices,
                 triangles,
                 vertex_mean_curvature,
                 indices,
                 path_through_positive_curvature: bool,
                 alpha: float,
                 vertex_adj_list: Optional[List[List[int]]] = None,
                 path_quantile_level=0.1,
                 restrict_path_to_boundary=False):
        '''
            vertices : vertex positions, indexed as vertices[i] (presumably (n_vertices, 3) -- 'visualize_paths' reads 3 columns).
            triangles : per-triangle rows of vertex indices.
            vertex_mean_curvature : per-vertex mean curvature, indexed by vertex index.
            indices : vertex indices of the region the paths are confined to.
            path_through_positive_curvature : if True an edge costs exp(-alpha * H), making positive curvature cheap;
                otherwise it costs exp(alpha * H), making negative curvature cheap. H is the mean of the two endpoint curvatures.
            alpha : sharpness of the curvature weighting.
            vertex_adj_list : optional per-vertex neighbor lists; computed with igl.adjacency_list(triangles) when None.
            path_quantile_level : quantile level used by 'compute_path_boundaries' to select start/end vertices.
            restrict_path_to_boundary : if True, start/end vertices must also lie on the boundary of the region.
        '''
        self.vertices = np.asarray(vertices)
        self.triangles = np.asarray(triangles)
        self.vertex_mean_curvature = np.asarray(vertex_mean_curvature)
        self.indices = np.asarray(indices)
        self.path_through_positive_curvature = path_through_positive_curvature
        self.path_quantile_level = path_quantile_level
        if vertex_adj_list is not None:
            self.vertex_adj_list = vertex_adj_list
        else:
            self.vertex_adj_list = igl.adjacency_list(self.triangles)
        self.vertex_adj_graph = nx.from_dict_of_lists({i: nbrs for i, nbrs in enumerate(self.vertex_adj_list)})
        neighbor_counts = [len(nbrs) for nbrs in self.vertex_adj_list]
        from_nodes = np.repeat(np.arange(len(self.vertex_adj_list)), neighbor_counts)
        # Explicit integer dtype: np.concatenate would promote to float if some vertex has no neighbors.
        to_nodes = np.array([j for nbrs in self.vertex_adj_list for j in nbrs], dtype=np.int64)
        sign = -1.0 if self.path_through_positive_curvature else 1.0
        edge_curvature = (self.vertex_mean_curvature[from_nodes] + self.vertex_mean_curvature[to_nodes]) / 2.0
        weights = np.exp(sign * alpha * edge_curvature)
        # NaN / overflowing weights become +inf; paths with infinite total weight are dropped in compute_paths.
        weights[~np.isfinite(weights)] = np.inf
        self.vertex_adj_graph.add_weighted_edges_from(zip(from_nodes, to_nodes, weights))
        self.restrict_path_to_boundary = restrict_path_to_boundary
        if self.restrict_path_to_boundary:
            # Region vertices that have at least one neighbor outside the region (empty if there is none).
            boundary_edges = nx.edge_boundary(self.vertex_adj_graph, self.indices)
            self.boundary_region_indices = np.unique(np.array([u for u, _ in boundary_edges], dtype=np.int64))
        else:
            self.boundary_region_indices = None

        self._start_indices = None
        self._end_indices = None
        self._paths = None
        self._path_weights = None
        self._triangle_centers = np.mean(self.vertices[self.triangles], axis=1)
        self._triangle_normals = igl.per_face_normals(self.vertices, self.triangles)
        self._paths_on_vertices = None
        self._paths_on_triangles = None
        self._paths_on_triangles_projections = None

        self._merged_path = None
        self._merged_path_error = None
        # num_points the cached merged path was computed with.
        self._merged_num_points = None

    def compute_path_boundaries(self) -> Tuple[NDArray[np.int_], NDArray[np.int_]]:
        '''
            Estimates the extremal vertices of the region selected by 'indices' using PCA.

            Region vertices are projected onto the direction of largest variance. Those at or below the
            'path_quantile_level' quantile of the projection become start vertices, and those at or above the
            '1.0 - path_quantile_level' quantile become end vertices. The sign of the principal axis is
            arbitrary, so which extreme is "start" is arbitrary too.
            If 'restrict_path_to_boundary' is set, only vertices on the region boundary are kept.

            Returns (start_indices, end_indices) as sorted unique vertex indices; the result is cached.
            Raises ValueError if the boundary restriction leaves no start or no end vertex.
        '''
        if self._start_indices is not None and self._end_indices is not None:
            return self._start_indices, self._end_indices
        indices = np.asarray(self.indices)
        points = self.vertices[indices]
        points_cov = np.cov(points, rowvar=False, bias=True)
        _, eigvecs = np.linalg.eigh(points_cov)
        # eigh returns eigenvalues in ascending order, so the last eigenvector is the principal axis.
        principal_axis = eigvecs[:, -1]
        projected_points = (points - points.mean(axis=0, keepdims=True)) @ principal_axis
        lower, upper = np.quantile(projected_points, [self.path_quantile_level, 1.0 - self.path_quantile_level])
        # Inclusive comparisons: strict ones select nothing when the quantile equals the min/max (e.g. level 0).
        start_indices = indices[projected_points <= lower]
        end_indices = indices[projected_points >= upper]
        if self.restrict_path_to_boundary:
            start_indices = np.intersect1d(start_indices, self.boundary_region_indices)
            end_indices = np.intersect1d(end_indices, self.boundary_region_indices)
            if len(start_indices) == 0 or len(end_indices) == 0:
                raise ValueError("No start or end indices found on boundary region. Consider disabling restrict_path_to_boundary.")
        self._start_indices = np.unique(start_indices)
        self._end_indices = np.unique(end_indices)
        return self._start_indices, self._end_indices

    def compute_paths(self) -> Tuple[List[NDArray[np.int_]], NDArray[np.float64]]:
        '''
            Uses Dijkstra's algorithm, restricted to the subgraph induced by 'indices', to find the shortest path
            between every pair of start and end vertices returned by 'compute_path_boundaries'.

            The weight between vertices 'i' and 'j' is exp(-alpha * (curvature(i) + curvature(j))/2.0) if
            'path_through_positive_curvature' is True, and exp(alpha * (curvature(i) + curvature(j))/2.0)
            otherwise. Here curvature(i) is the mean curvature at vertex i.

            Returns (paths, path_weights): a list of vertex-index arrays and the total weight of each path.
            Unreachable pairs and paths with non-finite total weight are skipped. The result is cached.
        '''
        if self._paths is not None:
            return self._paths, self._path_weights
        subgraph = self.vertex_adj_graph.subgraph(self.indices)
        start_indices, end_indices = self.compute_path_boundaries()
        paths = []
        path_weights = []
        for source in start_indices:
            distances, shortest_paths = nx.algorithms.single_source_dijkstra(subgraph, source, weight='weight')
            for target in end_indices:
                if target not in shortest_paths:
                    continue
                dist = distances[target]
                if not np.isfinite(dist):
                    continue
                paths.append(np.array(shortest_paths[target]))
                path_weights.append(dist)
        self._paths = paths
        self._path_weights = np.array(path_weights)
        return self._paths, self._path_weights

    def visualize_paths(self, fig: "go.Figure", color='lightgrey', fraction_to_display=1.0, rng=None):
        '''
            Adds to 'fig' the mesh (opacity 0.5, uniform 'color') and the paths from 'compute_paths' as 3D line traces.

            fraction_to_display : if < 1.0, only a random subset of int(len(paths) * fraction_to_display) paths is drawn.
            rng : optional object with a numpy-style 'choice' method (e.g. np.random.default_rng(seed)) used for the
                subsampling, for reproducible figures; the global np.random module is used when None.
        '''
        mesh_trace = go.Mesh3d(
            x=self.vertices[:, 0],
            y=self.vertices[:, 1],
            z=self.vertices[:, 2],
            i=self.triangles[:, 0],
            j=self.triangles[:, 1],
            k=self.triangles[:, 2],
            color=color,
            opacity=0.5,
            name='mesh'
        )
        palette = px.colors.qualitative.Plotly
        fig.add_trace(mesh_trace)
        paths, _ = self.compute_paths()
        if fraction_to_display >= 1.0:
            paths_to_display = np.arange(len(paths))
        else:
            sampler = np.random if rng is None else rng
            paths_to_display = sampler.choice(len(paths), int(len(paths) * fraction_to_display), replace=False)
        for cnt, path_idx in enumerate(paths_to_display):
            path = paths[path_idx]
            fig.add_trace(go.Scatter3d(
                x=self.vertices[path, 0],
                y=self.vertices[path, 1],
                z=self.vertices[path, 2],
                mode='lines',
                opacity=0.8,
                line=dict(color=palette[cnt % len(palette)], width=5),
            ))

    def compute_path_vertices(self) -> List[NDArray[np.float64]]:
        '''
            Returns, for each path from 'compute_paths', the coordinates of the vertices along the path (cached).
        '''
        if self._paths_on_vertices is not None:
            return self._paths_on_vertices
        paths, _ = self.compute_paths()
        path_vertices = [self.vertices[path] for path in paths]
        self._paths_on_vertices = path_vertices
        return path_vertices

    def compute_path_triangles(self) -> Tuple[List[NDArray[np.int_]], List[NDArray[np.float64]]]:
        '''
            For each vertex of each path, picks the triangle containing that vertex whose plane is closest to it.
            Each plane passes through the triangle centroid and uses the normal from igl.per_face_normals.

            Returns (triangle_indices, projections): per path, the chosen triangle for every vertex and the signed
            distance of the vertex to that triangle's plane. The result is cached.
            NOTE(review): a vertex lies in the plane of every triangle containing it, so these distances are ~0 up to
            floating-point error -- confirm this is the intended quantity.
        '''
        if self._paths_on_triangles is not None and self._paths_on_triangles_projections is not None:
            return self._paths_on_triangles, self._paths_on_triangles_projections
        paths, _ = self.compute_paths()
        self._paths_on_triangles = []
        self._paths_on_triangles_projections = []
        for path in paths:
            path = np.asarray(path)
            candidate_triangles = np.flatnonzero(np.isin(self.triangles, path).any(axis=1))
            candidate_corners = self.triangles[candidate_triangles]
            # Signed distance of every path vertex to the plane of every candidate triangle.
            offsets = self.vertices[path][:, None, :] - self._triangle_centers[None, candidate_triangles, :]
            projection = np.sum(self._triangle_normals[None, candidate_triangles, :] * offsets, axis=-1)
            # Only triangles that actually contain the vertex are eligible.
            incident = (candidate_corners[None, :, :] == path[:, None, None]).any(axis=-1)
            best = np.argmin(np.where(incident, np.abs(projection), np.inf), axis=1)
            self._paths_on_triangles.append(candidate_triangles[best])
            self._paths_on_triangles_projections.append(projection[np.arange(len(path)), best])
        return self._paths_on_triangles, self._paths_on_triangles_projections

    def merge_paths(self, num_points):
        '''
            Merges the paths found by 'compute_paths' into a single path. The paths are first aligned with Dynamic
            Time Warping (DTW) and the aligned paths are then averaged.

            The merged path is resampled to 'num_points' points. Also returns an error estimate for each point:
            the square root of the summed per-coordinate variance of the aligned paths at that point.
            The result is cached per 'num_points'.
            Raises ValueError if no path was found.
        '''
        if (self._merged_path is not None and self._merged_path_error is not None
                and self._merged_num_points == num_points):
            return self._merged_path, self._merged_path_error

        path_vertices = self.compute_path_vertices()
        if len(path_vertices) == 0:
            raise ValueError("No paths found between start and end vertices; nothing to merge.")
        aligned_paths = self._align_paths_with_dtw(path_vertices)
        avg_path, err_path = self._average_aligned_paths(aligned_paths, num_points=num_points)
        self._merged_path = avg_path
        self._merged_path_error = err_path
        self._merged_num_points = num_points
        return avg_path, err_path

In [490]:
# Collect the per-mesh segmentation files. Path.glob yields files in an arbitrary, filesystem-dependent
# order, and files are later selected by position, so sort for a reproducible Restart & Run All.
source_folder = 'segmentations/'
segmentation_files = sorted(Path(source_folder).glob('*_segmentation.npy'))

In [491]:
# Position (in segmentation_files) of the mesh to annotate.
SEGMENTATION_INDEX = 1
if len(segmentation_files) <= SEGMENTATION_INDEX:
    raise FileNotFoundError(
        f"Expected at least {SEGMENTATION_INDEX + 1} '*_segmentation.npy' files, found {len(segmentation_files)}."
    )
# NOTE: allow_pickle=True can execute arbitrary code stored in the file -- only load trusted segmentation outputs.
# The .npy file holds a pickled dict, hence the .item() to unwrap the 0-d object array.
segmentations = np.load(segmentation_files[SEGMENTATION_INDEX], allow_pickle=True).item()

In [492]:
# Annotate the fold on the 'hinge' segment: paths are confined to the hinge vertices and, with
# path_through_positive_curvature=False, edges over negative mean curvature are the cheap ones.
fold_annotation = FoldAnnotation(
    vertices=segmentations['vertices'],
    triangles=segmentations['triangles'],
    vertex_mean_curvature=segmentations['vertex_mean_curvature'],
    indices=segmentations['segmentations']['hinge'],
    path_through_positive_curvature=False,
    alpha=10.0,
    vertex_adj_list=segmentations['vertex_adj_list'],
    restrict_path_to_boundary=True,
)