In [1]:
#|default_exp distances
from diffusion_curvature.utils import *

# Manifold Distances

Here we provide reimplementations of algorithms for estimating the manifold distances of a graph.

## PHATE Distances

Wasserstein Diffusion Curvature -- despite the name -- requires only manifold distances. This saves quite a bit of computation, but doesn't alleviate the need for a good approximation of the manifold's geodesic distance. Here, we implement one straightforward and accurate manifold distance: the one proposed by Moon et al. in PHATE (2019). The PHATE distance is an extension of the diffusion distance: instead of calculating the L2 distances between diffusion coordinates (which correspond roughly to the rows of the diffusion matrix), it takes the L2 distances between the log-transformed diffusions. This flips the weighting from local to global: two diffusions where one assigns a small mass and the other a minuscule mass to the same point end up much farther apart than two diffusions that differ only at their centers. The log transform has the additional advantage that, through the WAWA formulation of the heat equation, it recovers the distance term.

It is defined as:
$$d_p(x,y) = \| \log(p_y^t)-\log(p_x^t) \|_2 $$

In [None]:
#|export
from sklearn.metrics import pairwise_distances
import numpy as np
import scipy
import graphtools

def phate_distances_graphtools(G:graphtools.api.Graph):
    """Compute PHATE distances from the diffusion operator on `G` and store them in `G.D`.

    Every entry of `G.Pt` is mapped through `-log(p + 1e-6)` and the pairwise
    distances between the resulting rows (sklearn's `pairwise_distances` with
    its default metric) are attached to the graph.

    Parameters
    ----------
    G : graphtools.api.Graph
        Graph object exposing a non-None `Pt` attribute. `Pt` may be a dense
        array (anything `np.asarray` accepts) or a scipy sparse matrix of any
        format; sparse input is densified before the log transform.

    Returns
    -------
    graphtools.api.Graph
        The same object `G`, mutated in place so that `G.D` holds the square
        distance matrix.

    Raises
    ------
    AssertionError
        If `G.Pt` is None.
    """
    assert G.Pt is not None
    Pt = G.Pt
    if scipy.sparse.issparse(Pt):
        # TODO: There's likely a more efficient way of doing this.
        # But I mustn't tempt the devil of premature optimization
        Pt = Pt.toarray()
    else:
        # Covers plain ndarrays as well as ndarray subclasses / array-likes,
        # which previously fell through both branches and left `D` undefined.
        Pt = np.asarray(Pt)
    # The small offset keeps log() finite where the diffusion assigns zero mass.
    log_Pts = -np.log(Pt + 1e-6)
    G.D = pairwise_distances(log_Pts)
    return G

In [2]:
#|export
import jax.numpy as jnp
def pairwise_euclidean(x, y):
    """Pairwise Euclidean distances between the rows of `x` and the rows of `y`, in JAX.

    Courtesy of [jakevdp](https://github.com/google/jax/discussions/11841).

    Parameters
    ----------
    x, y : 2-D arrays whose last axes have matching (broadcastable) length.

    Returns
    -------
    Array of shape `(x.shape[0], y.shape[0])`; entry `[i, j]` is the distance
    between `x[i]` and `y[j]`.

    Raises
    ------
    AssertionError
        If either input is not 2-D.

    Notes
    -----
    The broadcasted difference materialises an `(n_x, n_y, d)` intermediate.
    I would want to use something like PyKeops for this, if being done differentiably.
    """
    assert x.ndim == y.ndim == 2
    sq_dists = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    # sqrt has an infinite derivative at 0, which turns into NaN gradients for
    # coincident points (e.g. the whole diagonal when x is y). Route exact zeros
    # around the sqrt with the "double where" pattern; forward values are
    # unchanged and NaN inputs still propagate because NaN == 0 is False.
    is_zero = sq_dists == 0
    safe_sq = jnp.where(is_zero, 1.0, sq_dists)
    return jnp.where(is_zero, 0.0, jnp.sqrt(safe_sq))
  
def phate_distances(Pt):
    """PHATE distances between the rows of a diffusion matrix, computed in JAX.

    Each entry of `Pt` is mapped through `-log(p + 1e-6)` and the Euclidean
    distances between the transformed rows are returned.

    Parameters
    ----------
    Pt : 2-D dense array (the powered diffusion matrix).

    Returns
    -------
    Square array of pairwise distances, one row/column per row of `Pt`.

    Notes
    -----
    Uses the JAX helper `pairwise_euclidean` rather than sklearn's
    `pairwise_distances`, so the whole computation stays inside JAX (sklearn
    would convert the array to NumPy, which fails under `jax.grad`/`jax.jit`
    tracing). The helper broadcasts an `(n, n, d)` intermediate, so memory
    grows cubically for a square `Pt`.
    """
    # The small offset keeps log() finite where the diffusion assigns zero mass.
    log_Pts = -jnp.log(Pt + 1e-6)
    D = pairwise_euclidean(log_Pts, log_Pts)
    return D

def phate_distances_differentiable(Pt):
    """Thin alias for `phate_distances`: forwards `Pt` unchanged and returns its result."""
    distances = phate_distances(Pt)
    return distances

In [3]:
# Smoke test for pairwise_euclidean: distances from a small random array to itself.
# `random_jnparray` is not defined in this notebook — presumably it comes from the
# `diffusion_curvature.utils` star import and returns an (8, 3) jax array; TODO confirm.
A = random_jnparray(8,3)
D = pairwise_euclidean(A,A)

In [None]:
# Shape check: pairwise_euclidean(A, A) has one row and one column per row of A.
D.shape

(8, 8)

In [None]:
from diffusion_curvature.datasets import torus
import graphtools

In [None]:
# Build a torus dataset; the helper returns two values, unpacked here as the
# points and their Gaussian curvature (n is presumably the sample count — TODO confirm).
X_torus, torus_gaussian_curvature = torus(n=3000)
# Graph over the points, using graphtools' default settings.
G_torus = graphtools.Graph(X_torus)
# Store the 4th power of P as `Pt`, the attribute phate_distances_graphtools reads.
# NOTE(review): `**` is a matrix power only if P is a scipy sparse *matrix*; on a
# dense ndarray it would be elementwise — confirm P's type.
G_torus.Pt = G_torus.P ** 4

In [None]:
# Use the graph-based variant: it reads G_torus.Pt, stores the distances in
# G_torus.D and returns the graph. `phate_distances` expects the diffusion
# matrix itself and returns a bare distance matrix, so it cannot be used here.
G_torus = phate_distances_graphtools(G_torus)

In [None]:
# Show the distance matrix held in the graph's `D` attribute.
G_torus.D

array([[ 0.        , 93.7875039 , 94.59402414, ..., 92.54080119,
        90.6213935 , 92.93120075],
       [93.7875039 ,  0.        , 89.16673659, ..., 86.98550107,
        84.94065264, 87.40071844],
       [94.59402414, 89.16673659,  0.        , ..., 87.64808383,
        85.61906139, 82.96883047],
       ...,
       [92.54080119, 86.98550107, 87.64808383, ...,  0.        ,
        83.30204347, 85.8508315 ],
       [90.6213935 , 84.94065264, 85.61906139, ..., 83.30204347,
         0.        , 83.77829279],
       [92.93120075, 87.40071844, 82.96883047, ..., 85.8508315 ,
        83.77829279,  0.        ]])

In [5]:
!nbdev_export