In [None]:
#|default_exp distances

# Manifold Distances

Here we provide reimplementations of algorithms for estimating the manifold distances of a graph.

## PHATE Distances

Wasserstein Diffusion Curvature -- despite the name -- requires only manifold distances. This saves quite a bit of computation, but it doesn't remove the need for a good approximation of the manifold's geodesic distance. Here, we implement one straightforward and accurate manifold distance: the one proposed by Moon et al. in PHATE (2019). The PHATE distance is a variant of the diffusion distance: instead of taking the L2 distances between diffusion coordinates (which correspond roughly to the rows of the diffusion matrix), it takes the L2 distances between the log-transformed diffusions. This shifts the weighting from local to global: two diffusions in which one assigns a small mass and the other a minuscule mass to the same region end up much further apart than two diffusions that differ only near their centers. The log transform has the additional advantage of, through the WAWA formulation of the heat equation, recovering the distance term.

It is defined as:
$$d_p(x,y) = \| \log(p_y^t)-\log(p_x^t) \|_2 $$
where $p_x^t$ is the row of the $t$-step diffusion matrix $P^t$ corresponding to the point $x$.
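
As a rough illustration of this reweighting, the sketch below compares the plain diffusion distance with the log-transformed distance on three made-up diffusion rows, using NumPy only. The matrix `P_t`, the helper `pairwise_l2`, and the offset `eps` are assumptions chosen for the example; they are not part of this notebook's exported API.

```python
import numpy as np

def pairwise_l2(A):
    """Euclidean distances between all pairs of rows of A."""
    diff = A[:, None, :] - A[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))

# Hypothetical diffusions over five nodes; each row sums to 1.
P_t = np.array([
    [0.90, 0.0999, 0.0001, 0.0, 0.0],  # x: minuscule mass on node 2
    [0.90, 0.0900, 0.0100, 0.0, 0.0],  # y: same center as x, 100x more mass on node 2
    [0.80, 0.1999, 0.0001, 0.0, 0.0],  # z: same tail as x, different mass near the center
])

eps = 1e-6  # assumed offset to avoid log(0)
diffusion_D = pairwise_l2(P_t)            # L2 between raw diffusions
phate_D = pairwise_l2(np.log(P_t + eps))  # L2 between log-transformed diffusions

# Raw diffusions: x is closer to y (they differ only in the tail).
print(diffusion_D[0, 1] < diffusion_D[0, 2])
# Log-transformed diffusions: x is closer to z (they differ only near the center).
print(phate_D[0, 2] < phate_D[0, 1])
```

In this toy case the ordering of neighbours flips: the log transform makes the difference in the tails dominate, which is the local-to-global shift described above.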

In [None]:
#|export
from sklearn.metrics import pairwise_distances
import numpy as np
import scipy.sparse
import graphtools

def phate_distances(G: graphtools.api.Graph):
    """Compute PHATE distances from the diffused operator `G.Pt` and store them in `G.D`."""
    assert G.Pt is not None, "Set G.Pt (the diffused operator P^t) before calling phate_distances"
    if scipy.sparse.issparse(G.Pt):
        # TODO: There's likely a more efficient way of doing this.
        # But I mustn't tempt the devil of premature optimization
        Pt_np = G.Pt.toarray()
    else:
        Pt_np = np.asarray(G.Pt)
    # The small offset avoids log(0); the sign does not affect the pairwise distances.
    log_Pts = -np.log(Pt_np + 1e-6)
    # Euclidean (L2) distances between the rows of the log-transformed diffusion matrix
    D = pairwise_distances(log_Pts)
    G.D = D
    return G

In [None]:
from diffusion_curvature.datasets import torus
X_torus, torus_gaussian_curvature = torus(n=3000)
import graphtools
G_torus = graphtools.Graph(X_torus)
G_torus.Pt = G_torus.P ** 4
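
One caveat about the `** 4` above: for SciPy sparse matrix classes such as `csr_matrix`, `**` is a matrix power, while for a NumPy `ndarray` it is an elementwise power, which would not give the $t$-step diffusion. Whether `G.P` is sparse or dense depends on how the graph was built, so the sketch below uses repeated `@` products, which are matrix products in both cases. The helper name `diffuse` and the small matrix `P` are hypothetical and only for illustration.

```python
import numpy as np

def diffuse(P, t):
    """Return the t-step diffusion operator P^t via repeated matrix products."""
    if t < 1:
        raise ValueError("t must be a positive integer")
    Pt = P
    for _ in range(t - 1):
        Pt = Pt @ P  # `@` is a matrix product for NumPy arrays and SciPy sparse matrices
    return Pt

# Hypothetical 3-node row-stochastic diffusion matrix.
P = np.array([[0.50, 0.50, 0.00],
              [0.25, 0.50, 0.25],
              [0.00, 0.50, 0.50]])

Pt = diffuse(P, 4)    # matrix power: rows still sum to 1
elementwise = P ** 4  # elementwise power on an ndarray: rows no longer sum to 1

print(Pt.sum(axis=1))
print(elementwise.sum(axis=1))
```

With a helper like this, the assignment could be written as `G_torus.Pt = diffuse(G_torus.P, 4)` regardless of the storage format.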

In [None]:
G_torus = phate_distances(G_torus)

In [None]:
G_torus.D

array([[ 0.        , 79.44588272, 91.79222008, ..., 86.89214593,
        76.89037532, 84.1055431 ],
       [79.44588272,  0.        , 80.39207031, ..., 74.74836671,
        75.23454511, 71.46405865],
       [91.79222008, 80.39207031,  0.        , ..., 87.75808631,
        88.17255902, 84.99987688],
       ...,
       [86.89214593, 74.74836671, 87.75808631, ...,  0.        ,
        83.05921696, 79.6832004 ],
       [76.89037532, 75.23454511, 88.17255902, ..., 83.05921696,
         0.        , 80.13944645],
       [84.1055431 , 71.46405865, 84.99987688, ..., 79.6832004 ,
        80.13944645,  0.        ]])