In [2]:
#| default_exp core
from diffusion_curvature.kernels import *
from diffusion_curvature.datasets import *
from fastcore.all import *
import numpy as np
from nbdev import show_doc

# Implementation (PyGSP + JAX)
> Curvature computations on any graphtools graph

This notebook implements diffusion curvature atop the popular PyGSP library. To compute the curvature of any PyGSP graph (including the PyGSP-compatible graphs built by graphtools), instantiate a `DiffusionCurvature` object with your choice of parameters and pass the graph to its `curvature` method.

What follows is a literate implementation, showing the steps of the algorithm applied to our old friend, the torus.

The implementation of Diffusion Curvature involves several big pieces, each of which can be performed with different strategies:

1. Simulating heat diffusion on the manifold, either by powering the diffusion matrix or by a Chebyshev approximation of the heat equation using the graph Laplacian.
2. Computing the "spreads" of diffusion. This can be done either via the entropy or the Wasserstein distance.
3. Constructing a Euclidean comparison space with approximately the same sampling as the input graph.
4. (Experimental) Verification that the above is working by differentiating the spreads of diffusion over time.

We implement everything generically in JAX (a high-performance NumPy replacement that can compile to the GPU), treating each of the above as modules that can be parametrically tuned. The computations are functional in spirit: each step takes a graph object as input and returns the quantities it computes (diffusion matrices, spreads of diffusion, curvatures) as arrays, rather than modifying the graph.

In [3]:
# Our sample dataset for testing the rest of the notebook
from diffusion_curvature.datasets import torus
N_TORUS_POINTS = 5000  # first positional argument to `torus` (presumably the number of sampled points)
X_torus, ks_torus = torus(N_TORUS_POINTS, use_guide_points=True)

# Graph Construction
Our 'Graphs' notebook has code to create several varieties of PyGSP graphs from point-cloud data.
We also intend to provide heuristics to sanity-check the graphs and to choose the optimal construction parameters.
**TODO:** implement these heuristics.

In [4]:
from diffusion_curvature.graphs import get_alpha_decay_graph, get_knn_graph, get_scanpy_graph
G_torus = get_alpha_decay_graph(
    X_torus,
    knn=15,
    decay=20,
    anisotropy=1.0,
)

# Simulated Heat Diffusion

In [4]:
#|export
from typing import Literal, get_args, get_origin
from inspect import getfullargspec
def enforce_literals(function):
    """Decorator that raises AssertionError on Literal check failure.

    Every argument whose annotation is a `typing.Literal[...]` must be one of that
    Literal's options. Arguments are bound with `inspect.signature`, so positional,
    keyword *and default* values are all checked. Other annotations (including the
    return annotation) are ignored.

    It also accepts a class, in which case the constructor signature is checked.
    Note that the class is then replaced by a plain wrapper function.
    """
    import functools
    import inspect

    signature = inspect.signature(function)

    # `updated=()`: don't copy a decorated class's __dict__ onto the wrapper function.
    @functools.wraps(function, updated=())
    def decorator(*args, **kwargs):
        bound = signature.bind(*args, **kwargs)
        bound.apply_defaults()  # defaults have to satisfy the Literal too
        for name, value in bound.arguments.items():
            type_ = signature.parameters[name].annotation
            if get_origin(type_) is not Literal:
                continue
            options = get_args(type_)
            if value not in options:
                raise AssertionError(f"'{value}' is not in {options} for '{name}'")
        return function(*args, **kwargs)

    return decorator


# Example Literal aliases; not referenced elsewhere in the visible notebook.
_TYPES = Literal["solar", "view", "both"]
_NUMS = Literal[1, 2, 3, 4, 5]

In [6]:
#| export
import pygsp
import jax
import jax.numpy as jnp
from fastcore.all import *
import skdim

from diffusion_curvature.graphs import diff_aff, diff_op
from diffusion_curvature.heat_diffusion import heat_diffusion_on_signal, kronecker_delta, jax_power_matrix
from diffusion_curvature.diffusion_laziness import wasserstein_spread_of_diffusion, entropy_of_diffusion
from diffusion_curvature.distances import phate_distances
from diffusion_curvature.comparison_space import EuclideanComparisonSpace, fit_comparison_space_model
import diffusion_curvature

_DIFFUSION_TYPES = Literal['diffusion matrix','heat kernel']
_LAZINESS_METHOD = Literal['Wasserstein','Entropic']
_COMPARISON_METHOD = Literal['Ollivier', 'Subtraction']

@enforce_literals
class DiffusionCurvature():
    def __init__(
            self,
            diffusion_type:_DIFFUSION_TYPES = 'diffusion matrix', # Either ['diffusion matrix','heat kernel']
            laziness_method: _LAZINESS_METHOD = 'Wasserstein', # Either ['Wasserstein','Entropic']
            comparison_method: _COMPARISON_METHOD = 'Ollivier',
            distance_method:function = phate_distances,
            dimest = None, # Dimension estimator to use. If none, defaults to kNN.
            different_comparison_space_for_each_point = True, # If true, constructs a comparison space for every point in the manifold. If false, only constructs unique comparison spaces for each unique dimension.
    ):
        store_attr()
        self.D = None
        if self.dimest is None:
            self.dimest = skdim.id.KNN()
    def unsigned_curvature(
            self,
            G:pygsp.graphs.Graph, # PyGSP input Graph
            t:int, # Scale at which to compute curvature; number of steps of diffusion.
            idx=None, # the index at which to compute curvature. If None, computes for all points. TODO: Implement
            # The below are used internally
            _also_return_first_scale = False, # if True, calculates the laziness measure at both specified t and t=1. The distances (if used) are calcualted with the larger t.
            D = None, # Supply manifold distances yourself to override their computation. Only used with the Wasserstein laziness method.
    ):
        n = G.L.shape[0]
        # Compute diffusion matrix
        match self.diffusion_type:
            case 'diffusion matrix':
                P = diff_op(G)
                P = jnp.array(P)
                Pt = jax_power_matrix(P,t) 
                if idx: Pt = Pt[idx] # TODO: Could be more efficient here
            case 'heat kernel':
                signal = jnp.eye(n) if idx else kronecker_delta(n,idx=idx)
                Ps = heat_diffusion_on_signal(G, signal, [1,t])
                P = Ps[0]
                Pt = Ps[1]
            case _:
                raise ValueError(f"Diffusion Type {self.diffusion_type} not in {_DIFFUSION_TYPES}")
        match self.laziness_method:
            case "Wasserstein":
                if not D: D = self.distance_method(G) if not idx else self.distance_method(G)[idx] #TODO: Could be more efficient here
                laziness = wasserstein_spread_of_diffusion(D,Pt)
                if _also_return_first_scale: laziness_nought = wasserstein_spread_of_diffusion(D,P)
            case "Entropic":
                laziness = entropy_of_diffusion(Pt)
                if _also_return_first_scale: laziness_nought = entropy_of_diffusion(P)
            case _:
                raise ValueError(f"Laziness Method {self.laziness_method} not in {_LAZINESS_METHOD}")
        if _also_return_first_scale: 
            return laziness, laziness_nought
        else:
            return laziness
    def curvature(
            self,
            G:pygsp.graphs.Graph, # Input Graph
            t:int, # Scale
            idx=None, # the index at which to compute curvature. If None, computes for all points.
            dim = None, # the INTRINSIC dimension of your manifold, as an int for global dimension or list of pointwise dimensions; if none, tries to estimate pointwise.
    ):
        def fit_comparison_space(dimension, jump_of_diffusion, num_points_in_comparison):
            model = EuclideanComparisonSpace(dimension=dimension, num_points=num_points_in_comparison, jump_of_diffusion=jump_of_diffusion)
            params = fit_comparison_space_model(model, max_epochs=1000)
            euclidean_stuffs = model.apply(params) # dictionary containing A, P, D
            G_euclidean = pygsp.graphs.Graph(
                W = euclidean_stuffs['A'],
                lap_type = G.lap_type, # type of laplacian; we'll use the same as inputted.
                )
            return G_euclidean, euclidean_stuffs['D']


        # Start by estimating the manifold's unsigned curvature, i.e. spreads of diffusion
        manifold_spreads, manifold_spreads_nought = self.unsigned_curvature(G,t,idx, _also_return_first_scale=True)

        n = G.L.shape[0]
        if dim is None: # The dimension wasn't supplied; we'll estimate it pointwise
            print("estimating local dimension of each point... may take a while")
            ldims = self.dimest.fit_pw(
                                G.data, #TODO: Currently this requires underlying points!
                                n_neighbors = 100,
                                n_jobs = 1)
            dims_per_point = np.round(ldims.dimension_pw_).astype(int)
        else: # the dimension *was* supplied, but it may be either a single global dimension or a local dimension for each point
            if isinstance(dim, int):
                dims_per_point = np.ones(G.P.shape[0], dtype=int)*dim
            else:
                dims_per_point = dim
        
        flat_spreads = jnp.zeros(n)
        num_points_in_comparison = n / 5 # TODO: Can surely find a better heuristic here

        if self.different_comparison_space_for_each_point:
            # iterate through each point, extract single step spread of diffuusion, erect comparison space tailored to that point
            for i in range(n):
                G_euclidean, euclidean_D = fit_comparison_space(
                    dimension = dims_per_point[i],
                    jump_of_diffusion = manifold_spreads[i],
                    num_points_in_comparison = num_points_in_comparison,
                    )
                flat_spreads = flat_spreads.at[i].set(
                    self.unsigned_curvature(G_euclidean,t,idx=0,D=euclidean_D)
                )
        else:
            raise NotImplementedError # TODO: Implement this! Need to average spreads of all points of a dimension. Create a mapping
            # between points and dimensions to get an easy mask of points per dimension. Use this to replace the set stuff below.
            unique_dims = set(dims_per_point)
            unique_flat_lazinesses = {}
            for d in unique_dims:
                G_euclidean, euclidean_D = fit_comparison_space(
                    dimension = d,
                    jump_of_diffusion = manifold_spreads[i], # Need average spreads in dimension
                    num_points_in_comparison = num_points_in_comparison,
                    )
                flat_spreads = flat_spreads.at[i].set(
                    self.unsigned_curvature(G_euclidean,t,idx=0,D=euclidean_D)
                )

                G_flat = euclidean_comparison_space(G, dimension=d)
                G_flat = self.power_diffusion_matrix(G_flat,t)
                unique_flat_lazinesses = self.unsigned_curvature(G_flat, t, idx=0)

        match self.comparison_method:
            case "Ollivier":
                ks = 1 - manifold_spreads/flat_spreads
            case "Subtraction":
                ks = flat_spreads - manifold_spreads
            case _:
                raise ValueError(f'Comparison method must be in {_COMPARISON_METHOD}')        

    


        

NameError: name 'function' is not defined

In [5]:
# Laplacian type of the demo graph; `DiffusionCurvature.curvature` builds its comparison-space graphs with this same lap_type.
G_torus.lap_type

'combinatorial'

In [None]:
# Construction parameters of the demo graph.
# NOTE(review): the saved output (knn=5, decay=40, anisotropy=0) appears not to match the
# get_alpha_decay_graph call above (knn=15, decay=20, anisotropy=1.0); it is likely stale — re-run to refresh.
G_torus.get_params()

{'n_pca': None,
 'random_state': None,
 'kernel_symm': '+',
 'theta': None,
 'anisotropy': 0,
 'knn': 5,
 'decay': 40,
 'bandwidth': None,
 'bandwidth_scale': 1.0,
 'distance': 'euclidean',
 'precomputed': 'affinity'}

In [None]:
# Concrete class of the demo graph (the saved output shows a graphtools TraditionalPyGSPGraph).
type(G_torus)

graphtools.graphs.TraditionalPyGSPGraph