In [None]:
#|default_exp ricci.rays
## Standard libraries
import os
import math
import numpy as np
import time
from fastcore.all import *
from nbdev.showdoc import *
# Configure environment
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']='false' # Tells Jax not to hog all of the memory to this process.

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import set_matplotlib_formats
# set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgba
import seaborn as sns
sns.set()

## Progress bar
from tqdm.auto import tqdm

## project specifics
import diffusion_curvature
import pygsp
import jax
import jax.numpy as jnp
jax.devices()

from diffusion_curvature.graphs import *
from diffusion_curvature.datasets import *
from diffusion_curvature.core import *
from diffusion_curvature.utils import *
from diffusion_curvature.comparison_space import *

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# 3a Diffusion Rays
> Extending edges to scales where geometry is meaningful

To obtain our Ricci curvature on nodes, it's desirable to have edges. Of course, nodes already come with edges, but these edges are usually too short for meaningful geometric analysis, and modifying them doesn't change much. They lie in the noise domain, where everything is either locally Euclidean or deviates from Euclidean only because of noise.

This is a method for extending the directions suggested by local edges across larger parts of the data by, in effect, shooting geodesic rays from each point. To do this, we build a diffusion map, sending the points to high-dimensional representations of themselves, in which the Euclidean distance roughly corresponds to the manifold distance, with some asterisks. We can then take the directions suggested by edges as vectors, and extend them further into the diffusion space. Locally, at least, this gives a convincing approximation of a geodesic ray shot from the point.

# Implementation

In [None]:
#|export
from sklearn.preprocessing import normalize
import warnings

class DiffusionRays:
    """
    Approximate geodesic rays around a point using diffusion-map coordinates.

    The graph's affinity matrix is symmetrically normalized and eigendecomposed
    to obtain diffusion coordinates. For a point `i`, `rays(i)` takes the
    directions towards its nearest neighbors (in diffusion distance) and, for
    each direction, collects the points lying closest to the straight line
    through `i` in that direction.
    """

    def __init__(
        self,
        G, # Graph with W for affinity matrix
        t=25, # Scaling parameter for diffusion maps
        knn=5, # Number of rays to construct around each point, based in k nearest neighbors
        num_steps=20, # Num points within each ray #TODO: would be best to discern this dynamically
        percent_manifold_to_cover=0.3, # Size of each ray
    ):
        self.t = t
        self.knn = knn
        self.num_steps = num_steps
        self.name = "Diffusion Ray Curvature"  # to be programmatically accessed by printing functions, e.g.
        # Fraction of the largest diffusion distance from a point that is used
        # as the inclusion radius in `rays`.
        self.radius = percent_manifold_to_cover
        self.graph = G
        self.A = G.W.toarray()
        self.num_points = len(self.A)

        degrees = np.sum(self.A, axis=1)
        if np.any(degrees <= 0):
            # 1/sqrt(degree) would be inf/NaN and poison the eigendecomposition.
            raise ValueError(
                "DiffusionRays requires every node to have strictly positive degree."
            )
        inv_sqrt_degree = 1 / degrees ** 0.5
        # Symmetric diffusion operator D^{-1/2} A D^{-1/2}. Row/column scaling
        # gives the same matrix as multiplying by dense diagonal matrices,
        # without the O(n^3) matmuls.
        # TODO: can we keep things sparse for longer
        self.Ms = inv_sqrt_degree[:, None] * self.A * inv_sqrt_degree[None, :]
        # TODO: DEMD already does eigendecomposition with fast algorithms. Can we reuse that?
        print("Eigendecomposing diffusion matrix")
        self.E, self.V = np.linalg.eigh(self.Ms)
        # Correct the eigenvectors of Ms to those of M.
        self.V = inv_sqrt_degree[:, None] * self.V
        # Diffusion coordinates: each eigenvector scaled by its eigenvalue**t.
        # NOTE(review): a non-integer `t` would give NaN for negative eigenvalues.
        self.diffusion_coordinates = self.V * (self.E ** self.t)

    def diffusion_distances_to(self, i):
        """
        Returns the Euclidean distance, in diffusion coordinates, from every
        point to point `i` as an array of length `num_points`.
        """
        # Broadcasting subtracts row i from every row.
        return np.linalg.norm(
            self.diffusion_coordinates - self.diffusion_coordinates[i],
            axis=1,
        )

    def rays(self, i):
        """
        Returns indxs of diffusion rays around point i

        The result is an integer array of shape (knn, num_steps). Row m holds
        the indices of the `num_steps` points (within the inclusion radius of
        `i`) that lie closest to the line through `i` and its m-th nearest
        neighbor.

        Raises ValueError if fewer than `num_steps` points fall within the
        inclusion radius.
        """
        # Find max diffusion distance from i
        distances_to_i = self.diffusion_distances_to(i)
        max_dist_to_i = np.max(distances_to_i)
        print(max_dist_to_i, 'max dist')
        # Convert this distance to a radius of inclusion
        radius = self.radius * max_dist_to_i
        print("radius", radius)
        # Order all points by diffusion distance to i and keep the k closest,
        # excluding i itself (explicitly, so duplicates of i cannot displace it).
        order = np.argsort(distances_to_i)
        order = order[order != i]
        knn = order[: self.knn]
        ray_coords = np.zeros((self.knn, self.num_steps), dtype=int)
        x = self.diffusion_coordinates[i]
        # Only consider points within the specified radius from the central
        # point. This does not depend on the neighbor, so compute it once.
        points_within_radius = np.flatnonzero(distances_to_i <= radius)
        if len(knn) and len(points_within_radius) < self.num_steps:
            raise ValueError(
                f"Only {len(points_within_radius)} points lie within the radius of point {i}, "
                f"but num_steps={self.num_steps} are needed per ray. "
                "Increase percent_manifold_to_cover or decrease num_steps."
            )
        # Offsets from x to each candidate, and their squared lengths
        # (the squared "hypotenuse" from x to p).
        offsets = self.diffusion_coordinates[points_within_radius] - x
        squared_dist_to_x = np.sum(offsets ** 2, axis=1)
        for m, k in enumerate(knn):
            y = self.diffusion_coordinates[k]
            normalized_ray_direction = (y - x) / np.linalg.norm(y - x)
            print(points_within_radius.shape)
            # Squared length of each candidate's projection onto y - x.
            # NOTE(review): because the projection is squared, points on either
            # side of x along this line are treated alike — confirm that a
            # two-sided line (rather than a one-sided ray) is intended.
            squared_projection = (offsets @ normalized_ray_direction) ** 2
            # Squared distance from each candidate to the line through x and y.
            squared_dist_to_line = squared_dist_to_x - squared_projection
            # take the closest num_step points
            ray_coords[m] = points_within_radius[
                np.argsort(squared_dist_to_line)[: self.num_steps]
            ]
        return ray_coords

    def pointwise_curvature(self):
        """
        Returns an [n_points] sized array of pointwise curvatures.

        NOTE(review): relies on `self.curvature(i)`, which is not defined in
        this class — presumably supplied elsewhere (e.g. a subclass); confirm.
        """
        # Imported locally so the method does not depend on a later notebook cell.
        from tqdm.auto import trange

        # TODO: How can we speed up redundant distance calculations?
        # TODO: We probably don't need to compute the curvature of every point. Can we sample points, and then average the curvatures around them?
        curvatures = np.empty(self.num_points)
        for i in trange(self.num_points):
            curvatures[i] = self.curvature(i)[0]
        return curvatures

In [None]:
"""
Utilities for estimating geodesics on a graph
"""
import numpy as np
import numpy as np
import graphtools
import time
from tqdm.notebook import trange
from sklearn.metrics.pairwise import euclidean_distances


class LineMAGICv1:
    """
    Smooths an ordered line of point indices by diffusing the points'
    coordinates along the line and snapping the results back to real points.
    """

    def __init__(self, A, X):
        """
        Needs an affinity matrix for future calculations.

        A - affinity matrix exposing `.toarray()` (it is densified here).
        X - array of point coordinates, indexed by the same point indices as A.
        """
        self.A = A.toarray()
        self.X = X

    def average_neighbors(self, line_idxs, t=1):
        """
        Takes a list of idxs for a line (assumed to be ordered)
        Returns a new line, comprised of the closest points to the averages of each point's neighbors.

        The returned list starts with `line_idxs[0]`, then holds one snapped
        index per entry of `line_idxs`, and ends with `line_idxs[-1]`, so it
        has `len(line_idxs) + 2` entries.
        NOTE(review): this grows the line by two on every call — confirm that
        the duplicated endpoints are intended.
        """
        # 1. Construct a local diffusion operator containing just the transition probabilities along the line.
        # TODO: Could improve by making a different type of kernel, as in SUGAR, to incorporate information from surrounding neighbors.
        # Fancy indexing copies, so self.A is not modified below.
        local_A = self.A[np.ix_(line_idxs, line_idxs)]
        # add artificially high degree between the endpoints and their neighbors
        # Literally: for row i, the first column becomes the sum of the entries
        # in columns i+1 .. second-to-last, and the last column becomes the sum
        # of the entries in columns 1 .. i-2.
        # NOTE(review): for i == 0 the slice `1 : i - 1` is `1 : -1`, i.e. the
        # whole interior — confirm this is intended.
        for i in range(len(line_idxs)):
            local_A[i, 0] = np.sum(local_A[i, i + 1 : -1])
            local_A[i, -1] = np.sum(local_A[i, 1 : i - 1])

        D = np.sum(local_A, axis=1)
        print("local a has shape", local_A.shape, "D has shape", D.shape)
        # Row-normalize so each row of P sums to 1. Dividing by `D` without
        # the new axis would scale column j by degree j, and P @ local_x would
        # no longer be a weighted average.
        P = local_A / D[:, None]
        P = np.linalg.matrix_power(P, t)
        # 2. Apply local diffusion operator to the coordinates of the points on the line
        local_x = self.X[line_idxs]
        averaged_coords = P @ local_x
        # 3. Find the actual points which are closest to each of the averaged points
        # TODO: Here we assume a (local) euclidean metric on the underlying points. Can this assumption be relaxed?
        new_line = [line_idxs[0]]
        for avg in averaged_coords:
            distances = np.linalg.norm(self.X - avg, axis=1)
            # Only the nearest point is needed, so argmin replaces a full sort.
            new_line.append(np.argmin(distances))
        new_line.append(line_idxs[-1])
        return new_line

    def iterative_averaging(self, line, num_iterations, t=1):
        """
        Applies `average_neighbors` to `line` `num_iterations` times, feeding
        each result into the next call, and returns the final line.
        """
        for _ in range(num_iterations):
            line = self.average_neighbors(line, t=t)
        return line


class MIDGeodesic:
    """
    Builds approximate geodesics between two points by repeatedly inserting
    "midpoints", found where diffusions started from both endpoints overlap.
    """

    def __init__(
        self,
        X,
        P,
        D,
        neighborhood_size,
        min_intersection_size=1,
        max_intersection_size=20,
        threshold_of_intersection=0,
        num_iterations=4,
    ):
        """
        Prepares for midpoint/geodesic computation by taking powers of the
        diffusion operator.
        As inputs,
        X - the raw points
        P - n x n diffusion operator of your graph. May be a dense ndarray or a sparse matrix.
        D - n x n ndarray - the euclidean distances between points on your graph. (If the graph doesn't come from a pointcloud in R^n, you can just supply some constant distance between all connected neighbors.)
        neighborhood_size - number of additional powers of P to precompute.
        min_intersection_size / max_intersection_size - bounds on the number of nodes in an accepted intersection.
        threshold_of_intersection - diffusion values at or below this are ignored.
        num_iterations - number of midpoint-insertion rounds used by `geodesic`.
        """
        self.X = X
        self.Pt = [P]  # The stored powers of the diffusion operator: Pt[k] is P^(k+1)
        start = time.time()
        print(f"Taking {neighborhood_size} powers of the diffusion operator.")
        for _ in range(
            neighborhood_size
        ):  # TODO: there may be a more efficient way to do this...
            self.Pt.append(self.Pt[-1] @ P)
        end = time.time()
        print(f"Finished in {end-start}s")
        self.D = D
        self.neighborhood_size = neighborhood_size
        self.min_intersection_size = min_intersection_size
        self.max_intersection_size = max_intersection_size
        self.threshold_of_intersection = threshold_of_intersection
        self.intersections_cached = []  # for debugging
        self.num_iterations = num_iterations

    @staticmethod
    def _dense_row(matrix, row):
        """Returns row `row` of a dense or sparse matrix as a flat ndarray."""
        selected = matrix[row]
        if hasattr(selected, "toarray"):
            selected = selected.toarray()
        return np.asarray(selected).ravel()

    def midpoint(self, x, y, previous_indices_of_intersection=None, recursion_number=0):
        """
        Finds first scale at which diffusions from x and y have nonempty intersection.
        If the intersection size is at least self.min_intersection_size and at most self.max_intersection_size
        returns the point closest to the euclidean center of the intersection.
        If the intersection size meets the minimum threshold but exceeds the max_intersection_size, picks two far-apart
        points of the intersection and takes their midpoint, restricted to the prior intersection. Repeats recursively
        (at most 4 levels) until the size of the intersection is no larger than max_intersection_size.
        As inputs:
        x, y - the indices of the points to find a midpoint between.
        previous_indices_of_intersection - for internal use when recursively finding intersecting intersections.
        recursion_number - for internal use; current recursion depth.
        Raises ValueError if no stored power of the diffusion operator yields a large enough intersection.
        """
        if x == y:
            return x  # A hacky way to make the number of points match up across MID geodesics
        indices_of_intersection = []
        # Walk through the stored powers. The error is raised only when every
        # stored power has been tried without success, so a hit on the last
        # power is accepted rather than discarded.
        for t, diffusion_power in enumerate(self.Pt):
            # keep only the points each diffusion reaches above the threshold
            diffused_x = self._dense_row(diffusion_power, x) > self.threshold_of_intersection
            diffused_y = self._dense_row(diffusion_power, y) > self.threshold_of_intersection
            # the intersection is where both diffusions are present
            intersection = diffused_x & diffused_y
            print(f" for {t} found sum of intersection", np.sum(intersection))
            indices_of_intersection = intersection.nonzero()[0]
            # if this function is taking the midpoint of endpoints of a previous intersection, we limit this intersection to the previous intersection
            if previous_indices_of_intersection is not None:
                indices_of_intersection = np.intersect1d(
                    indices_of_intersection, previous_indices_of_intersection
                )
            if len(indices_of_intersection) >= self.min_intersection_size:
                break
        else:
            raise ValueError(
                f"Cannot take the midpoint between {x} and {y}: points are outside of the specified neighborhood radius. Did not compute sufficient diffusion powers. Try reinitializing MIDRayCurvature with a higher neighborhood_size"
            )
        # optional step: chop off outliers in the intersections
        # TODO
        # if there are too many nodes in the intersection, we can narrow further by taking the midpoints of the endpoints of the intersection
        print(
            f"{recursion_number} - Found {len(indices_of_intersection)} nodes in the intersection (v {self.max_intersection_size})"
        )
        if (
            len(indices_of_intersection) > self.max_intersection_size
            and recursion_number < 4
        ):
            print(f"{recursion_number} - Recursively taking another intersection...")
            # find distances to some arbitrary point (the first one) from each point in the intersection
            distances_to_rando = self.D[indices_of_intersection[0]][
                indices_of_intersection
            ]
            # find the index of the point with the maximum such distance. Call this the first endpoint.
            new_x = indices_of_intersection[np.argmax(distances_to_rando)]
            # find the point in the intersection furthest from this new x; call it new y
            # (measured from new_x, not from the original x)
            distances_to_new_x = self.D[new_x][indices_of_intersection]
            new_y = indices_of_intersection[np.argmax(distances_to_new_x)]
            # take the midpoint between these
            midpoint = self.midpoint(
                new_x,
                new_y,
                previous_indices_of_intersection=indices_of_intersection,
                recursion_number=recursion_number + 1,
            )
        else:
            # take the point closest to the euclidean center of the intersection
            small_neighborhood_of_midpoints = np.asarray(self.X)[indices_of_intersection]
            if small_neighborhood_of_midpoints.ndim == 1:
                # treat scalar coordinates as one-dimensional points
                small_neighborhood_of_midpoints = small_neighborhood_of_midpoints[:, None]
            # Per-coordinate mean (axis=0) and per-point distance (axis=1).
            # Without the axes both reduce to scalars and argmin always
            # selects the first candidate.
            euclidean_center = np.mean(small_neighborhood_of_midpoints, axis=0)
            distances_to_euclidean_center = np.linalg.norm(
                small_neighborhood_of_midpoints - euclidean_center, axis=1
            )
            # find closest candidate
            midpoint = indices_of_intersection[np.argmin(distances_to_euclidean_center)]
        return midpoint

    def insert_midpoints(self, sorted_list):
        """
        Internal function.
        Given a list of sorted indices, inserts the index of the midpoint between each consecutive pair of points, and returns a new list.
        A midpoint equal to the first point of its pair is not inserted. Lists
        with fewer than two entries are returned unchanged (as a new list).
        """
        sorted_list = list(sorted_list)
        if len(sorted_list) < 2:
            # No consecutive pairs, so nothing to insert.
            return sorted_list
        new_list = []
        for idx1, idx2 in zip(sorted_list[:-1], sorted_list[1:]):
            m = self.midpoint(idx1, idx2)
            if m == idx1:
                new_list.append(idx1)
            else:
                new_list.extend([idx1, m])
        new_list.append(sorted_list[-1])
        return new_list

    def recursive_midpoints(self, sorted_list, num_iterations=4):
        """
        Applies `insert_midpoints` `num_iterations` times and returns the
        resulting list.
        """
        for _ in range(num_iterations):
            sorted_list = self.insert_midpoints(sorted_list)
        return sorted_list

    def geodesic(self, start, end):
        """
        Returns the list of point indices obtained by starting from
        [start, end] and running `self.num_iterations` rounds of midpoint insertion.
        """
        se = [start, end]
        return self.recursive_midpoints(se, num_iterations=self.num_iterations)


class MIDMAGICv1(object):
    """
    Choose points on manifold which lie along the approximate geodesic between two points.

    Combines a `MIDGeodesic` (rough midpoint-based path) with a `LineMAGICv1`
    (iterative smoothing of that path).
    """

    def __init__(
        self,
        X,
        A,
        P,
        num_midpoint_iterations=4,
        num_averaging_iterations=6,
        neighborhood_size=5,
    ):
        """
        X - the raw points
        A - affinity matrix, handed to LineMAGICv1
        P - diffusion operator, handed to MIDGeodesic
        num_midpoint_iterations - passed to MIDGeodesic as `num_iterations`
        num_averaging_iterations - number of smoothing passes used in `geodesic`
        neighborhood_size - passed to MIDGeodesic
        """
        # Pairwise distances between the rows of X, as returned by euclidean_distances.
        pairwise_distances = euclidean_distances(X)
        self.neighborhood_size = neighborhood_size
        self.linemagic = LineMAGICv1(A, X)
        self.midrays = MIDGeodesic(
            X,
            P,
            pairwise_distances,
            neighborhood_size,
            num_iterations=num_midpoint_iterations,
        )
        self.num_averaging_iterations = num_averaging_iterations
        # Expose the diffusion-operator powers stored by the MIDGeodesic helper.
        self.Pts = self.midrays.Pt

    def geodesic(self, start, end):
        """
        Returns the smoothed line between `start` and `end`: a rough geodesic
        from the midpoint helper, passed through `num_averaging_iterations`
        rounds of line averaging.
        """
        coarse_path = self.midrays.geodesic(start, end)
        return self.linemagic.iterative_averaging(
            coarse_path, self.num_averaging_iterations
        )


# Tests

In [None]:
# Build the test fixture: a 2000-point torus sample and a graph over it.
# `ks` is the second value returned by `torus` (presumably reference curvatures — TODO confirm).
X, ks = torus(2000, use_guide_points=True)
# NOTE: despite its name, `A` is passed below as the graph `G` of DiffusionRays,
# which reads its affinity matrix from `G.W`.
A = get_adaptive_graph(X)



In [None]:
# Construct the ray finder on the torus graph; this runs the eigendecomposition.
# Each ray's inclusion radius is 30% of the largest diffusion distance from its
# centre point, and the diffusion time is t=12.
DRAY = DiffusionRays(
    A, 
    percent_manifold_to_cover = 0.3,
    t=12
)

Eigendecomposing diffusion matrix


In [None]:
# Indices of the points on each ray around point 1 (one row per ray).
ray_idxs = DRAY.rays(1)

0.013351628015262173 max dist
radius 0.004005488404578652
(55,)
(55,)
(55,)
(55,)
(55,)


In [None]:
# Visual check: highlight one ray on the point cloud.
signal = np.zeros(len(X))
# Mark the points of the third ray (row 2 of ray_idxs) with 1 ...
signal[ray_idxs[2]] = 1
# ... and the centre point (index 1) with 0.5 so it stands out from the ray.
signal[1]=0.5
plot_3d(X, signal, use_plotly=True)

## Export

In [None]:
# sync changes to the library
from IPython.display import display, Javascript
import time
# Ask the notebook front end to save a checkpoint via its JavaScript API
# (presumably so the file on disk is current before syncing — confirm; this
# call targets the `IPython.notebook` object of the classic notebook UI).
display(Javascript('IPython.notebook.save_checkpoint();'))
# Pause for two seconds before running the sync command.
time.sleep(2)
# Shell escape: run the `nbsync` task through pixi.
!pixi run nbsync