In [21]:
#|default_exp graphs
import numpy as np

# Graph Creation Utils

Written by Alex Tong & Guillaume Huguet for [KrishnaswamyLab/HeatGeo: Embedding with the Heat-geodesic dissimilarity](https://github.com/KrishnaswamyLab/HeatGeo). Adapted for diffusion curvature.

This uses the marvelous library PyGSP to handle graph creation and basic graph signal processing.

By default, we use the alpha-decay kernel from [PHATE](https://www.nature.com/articles/s41587-019-0336-3). We have also implemented other types of kernels; to use them, you need to install our package with the `[dev]` extras.

In [22]:
#| export
import graphtools as gt
import pygsp
from typing import Union
from graphtools.matrix import set_diagonal, to_array
from scipy import sparse
from sklearn.preprocessing import normalize
import numpy as np

# Optional dependencies. Each one is guarded separately so that a missing
# package only disables the helper that needs it: the name is bound to the
# ImportError, which the corresponding get_*_graph function checks for.
try:
    import scanpy as sc
except ImportError as imp_err:
    sc = imp_err

try:
    import umap
except ImportError as imp_err:
    umap = imp_err

In [23]:
#| export

def diff_op(graph):
    """
    Build the diffusion operator of a pygsp graph.

    The graph's weight matrix has its diagonal set to 1, and each row of the
    result is then L1-normalized.
    """
    assert isinstance(graph, pygsp.graphs.Graph)
    kernel = set_diagonal(graph.W, 1)
    return normalize(kernel, norm="l1", axis=1)


def kernel_degree(graph):
    """
    Return the row sums of the kernel of a pygsp graph.

    The kernel is the graph's weight matrix with its diagonal set to 1. The
    row sums are returned as a single-column array (one row per node).
    """
    assert isinstance(graph, pygsp.graphs.Graph)
    kernel = set_diagonal(graph.W, 1)
    row_sums = kernel.sum(axis=1)
    return to_array(row_sums).reshape(-1, 1)


def diff_aff(graph):
    """
    Build the symmetrically normalized kernel of a pygsp graph.

    Every kernel entry (i, j) is divided by sqrt(degree_i) * sqrt(degree_j),
    where the kernel is the weight matrix with unit diagonal and the degrees
    come from `kernel_degree`. Sparse kernels are scaled with a sparse
    diagonal matrix; dense kernels are scaled by broadcasting.
    """
    assert isinstance(graph, pygsp.graphs.Graph)
    kernel = set_diagonal(graph.W, 1)
    degree_column = kernel_degree(graph)

    if not sparse.issparse(kernel):
        # Dense case: scale rows by the column vector, columns by its transpose.
        degree_row = degree_column.T
        return (kernel / np.sqrt(degree_column)) / np.sqrt(degree_row)

    # Sparse case: CSR diagonal built from (data, indices, indptr) with
    # exactly one stored entry per row.
    n_nodes = len(degree_column)
    inv_sqrt_degree = sparse.csr_matrix(
        (
            1 / np.sqrt(degree_column.flatten()),
            np.arange(n_nodes),
            np.arange(n_nodes + 1),
        )
    )
    return inv_sqrt_degree @ kernel @ inv_sqrt_degree


###------------------------ Graph Constructors ----------------------------###


def get_knn_graph(X, knn=5, **kwargs):
    """
    Build a pygsp nearest-neighbour graph on `X` using `knn` neighbours.

    Extra keyword arguments are accepted but not used.
    """
    knn_graph = pygsp.graphs.NNGraph(X, k=knn)
    return knn_graph


def get_alpha_decay_graph(
    X,
    knn: int = 5,
    decay: float = 40.0,
    anisotropy: float = 0,
    n_pca: int = None,
    **kwargs
):
    """
    Build a graphtools graph on `X` and return it converted to pygsp.

    `knn`, `decay`, `anisotropy` and `n_pca` are forwarded to `gt.Graph`;
    the random state is fixed to 42. Extra keyword arguments are accepted
    but not used.
    """
    graph_options = dict(
        knn=knn,
        decay=decay,
        anisotropy=anisotropy,
        n_pca=n_pca,
        use_pygsp=True,
        random_state=42,
    )
    graphtools_graph = gt.Graph(X, **graph_options)
    return graphtools_graph.to_pygsp()


def get_scanpy_graph(X, knn=5, **kwargs):
    """
    Build a pygsp graph from scanpy's neighbour connectivities on `X`.

    Raises ImportError when the module-level `sc` holds an ImportError
    instead of the scanpy module. Extra keyword arguments are accepted but
    not used.
    """
    scanpy_missing = isinstance(sc, ImportError)
    if scanpy_missing:
        raise ImportError("Scanpy is not installed.")

    annotated = sc.AnnData(X)
    # Stores the neighbour graph on `annotated` in place.
    sc.pp.neighbors(annotated, n_neighbors=knn)
    return pygsp.graphs.Graph(annotated.obsp["connectivities"])


def get_umap_graph(X, knn=5, **kwargs):  # knn default to 15 in UMAP
    """
    Build a pygsp graph from the graph UMAP fits on `X`.

    UMAP is fitted with `knn` neighbours and the euclidean metric; its
    `graph_` attribute is converted to a dense array before being wrapped.
    Raises ImportError when the module-level `umap` holds an ImportError
    instead of the umap module. Extra keyword arguments are accepted but
    not used.
    """
    umap_missing = isinstance(umap, ImportError)
    if umap_missing:
        raise ImportError("UMAP is not installed.")
    reducer = umap.UMAP(n_neighbors=knn, metric="euclidean")
    reducer.fit(X)
    dense_weights = reducer.graph_.toarray()
    return pygsp.graphs.Graph(dense_weights)

# The Differentiable Kernel

In [24]:
#|export
import jax
import jax.numpy as jnp

def generic_kernel(
        D, # distance matrix
        sigma, # kernel bandwidth
        anisotropic_density_normalization, # exponent applied to the kernel row sums

):
    """
    Gaussian affinity matrix with symmetric density normalization.

    The affinities are `exp(-D**2 / (2*sigma**2)) / (sigma*sqrt(2*pi))`.
    Entry (i, j) is then divided by
    `(rowsum_i * rowsum_j) ** anisotropic_density_normalization`, so an
    exponent of 0 returns the plain Gaussian kernel.

    Only `jax.numpy` is used, so no NumPy float64 scalar enters the
    computation, and the normalization is applied by broadcasting rather
    than by multiplying with explicit dense diagonal matrices.
    """
    gaussian_scale = 1 / (sigma * jnp.sqrt(2 * jnp.pi))
    W = gaussian_scale * jnp.exp(-(D**2) / (2 * sigma**2))
    # The 1e-8 offset keeps the division finite when a row sum is zero.
    inv_density = 1 / ((jnp.sum(W, axis=1) + 1e-8) ** anisotropic_density_normalization)
    # Same result as diag(inv_density) @ W @ diag(inv_density), in O(n^2).
    return W * inv_density[:, None] * inv_density[None, :]

In [None]:
#|export

def diffusion_matrix_from_affinities(
        W # square affinity matrix
):
    """
    Row-normalize an affinity matrix so that every row sums to one.

    A jitter of 1e-8 is first added to the diagonal so that a row with no
    affinity at all still has a non-zero sum (it becomes a pure self-loop).
    Rows are normalized by broadcasting the row sums, which avoids building
    a dense diagonal matrix and an O(n^3) matrix product.
    """
    W = W + jnp.eye(len(W)) * 1e-8
    row_sums = jnp.sum(W, axis=1, keepdims=True)
    return W / row_sums

In [None]:
# Example usage (cell is not exported): run the kernel and the diffusion
# matrix construction end to end on random data.
from diffusion_curvature.utils import random_jnparray
from diffusion_curvature.distances import pairwise_euclidean
X = random_jnparray(100,9)  # presumably 100 points in 9 dimensions — TODO confirm
D = pairwise_euclidean(X,X)  # distances of X against itself
W = generic_kernel(D,0.7,0.5)  # sigma=0.7, anisotropic_density_normalization=0.5
P = diffusion_matrix_from_affinities(W)

In [11]:
# Scratch check: dividing by 1e-8 instead of 0 gives a large but finite value.
# NOTE(review): this cell's execution count is out of order with the cells
# above; re-run from a fresh kernel or remove before publishing.
import jax.numpy as jnp
jnp.max(jnp.array([1, 0.4, 1, 3])/jnp.array([1e-8,1,1,1]))

Array(1.e+08, dtype=float32)

In [9]:
# Scratch check: dividing by an exact 0 yields inf rather than raising.
# NOTE(review): this cell's execution count is out of order with the cells
# above; re-run from a fresh kernel or remove before publishing.
jnp.array([1, 0.4, 1, 3])/jnp.array([0,1,1,1])

Array([inf, 0.4, 1. , 3. ], dtype=float32)