In [None]:
#| default_exp diffusion_laziness
# Diffusion Curvature utils
from diffusion_curvature.utils import *
from diffusion_curvature.datasets import *
# Python necessities
import numpy as np
import jax
import jax.numpy as jnp
from fastcore.all import *
import matplotlib.pyplot as plt
# Notebook Helpers
from nbdev.showdoc import *
from tqdm.notebook import trange, tqdm
from fastcore.all import *
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Diffusion Laziness Estimators
> What's the shape of this diffusion?

# Wasserstein Spread of Diffusion

In [None]:
#|export
import jax.numpy as jnp
from jax import jit

@jit
def wasserstein_spread_of_diffusion(
    D:jax.Array, # manifold geodesic distances
    Pt:jax.Array, # powered diffusion matrix/t-step heat diffusions
):
    """
    Wasserstein-style measure of how "spread out" each diffusion is.

    Weights the distances in `D` by the diffusion mass in `Pt` (elementwise
    product) and sums over the last axis; for 2-D inputs, entry i of the result
    is sum_j D[i, j] * Pt[i, j].
    The manifold distances are taken as given and must be computed separately.
    """
    distance_weighted_mass = jnp.multiply(D, Pt)
    return distance_weighted_mass.sum(axis=-1)

### Benchmarking

In [None]:
# Synthetic NumPy benchmark inputs: D stands in for a distance matrix
D = np.random.rand(1000,1000)
Pt = np.random.rand(1000,1000)
# Normalize each row of Pt so that it sums to 1
Pt = Pt / np.sum(Pt, axis=1)[:,None]

In [None]:
%%timeit
wasserstein_spread_of_diffusion(D,Pt)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


1.08 ms ± 57.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
# Same-shaped benchmark inputs created directly as jax arrays.
# The values are standard-normal samples, so they are only meaningful for timing,
# not as valid distances or probabilities.
key = jax.random.PRNGKey(0)
Djax = jax.random.normal(key, (1000, 1000))
# Different seed so the second matrix is not a copy of the first
key = jax.random.PRNGKey(10)
Ptjax = jax.random.normal(key, (1000, 1000))

In [None]:
%%timeit
wasserstein_spread_of_diffusion(Djax,Ptjax)

706 µs ± 11.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


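With jax arrays the call is about 1.5× faster in the runs shown above (706 µs vs. 1.08 ms), presumably because NumPy inputs have to be converted to jax arrays on every call.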

# Entropy of Diffusion

In [None]:
#|export
import jax.scipy
import jax.numpy as jnp

def entropy_of_diffusion(
    Pt:jax.Array, # powered diffusion matrix
    epsilon=1e-5, # threshold for small values, for speed
):
    """
    Returns the pointwise (per-row) entropy of diffusion from the powered diffusion matrix in the input.
    Assumes that each row of Pt sums to 1.

    Entries at or below `epsilon` are zeroed, each row is renormalized to sum to 1,
    and the Shannon entropy (natural log) of every row is returned; the result has
    the shape of `Pt` without its last axis. A row whose entries are all at or
    below `epsilon` yields an entropy of 0.
    """
    # Use only the elements of Pt that are greater than epsilon
    Pt = Pt * (Pt > epsilon)
    # Renormalize every row so that it sums to 1. keepdims=True keeps the sums shaped
    # (..., 1) so that row i is divided by its own sum; a bare (n,) vector would
    # broadcast along the columns and divide column j by the sum of row j instead.
    row_sums = jnp.sum(Pt, axis=-1, keepdims=True)
    Pt = Pt / (row_sums + 1e-12)  # 1e-12 guards against division by zero for empty rows
    # entr(p) = -p * log(p), with entr(0) = 0, so zeroed entries contribute nothing
    entropy_elementwise = jax.scipy.special.entr(Pt)
    # NOTE: the entropy is deliberately left unnormalized (it is not divided by the
    # log of the number of retained entries), so its maximum depends on the support size.
    return jnp.sum(entropy_elementwise, axis=-1)

In [None]:
from scipy.stats import entropy

In [None]:
assert jnp.allclose(entropy_of_diffusion(Pt),entropy(Pt,axis=1), atol = 1e-3)

# Diffusion Laziness Model

The above functions estimate the laziness of a powered diffusion matrix, at a single time. Here, we extend the computation over multiple times, as well as making it more convenient to call.

## Get Multiple Powers of Diffusion at Once

The first step is taking a diffusion matrix, and a list of times, and producing powerings of the matrix for each time.
We do this by starting with the smallest time; each subsequent time is split into a sum of two smaller exponents (i.e. an additive partition), preferring powers that have already been computed and cached, so that existing results are reused whenever possible.

In [None]:
#|export
from typing import List

def get_matrix_power_recursive(
    desired_power:int, # positive integer exponent to compute
    Pt_dict:dict, # cache of {power: matrix}; should be, by default, {1 : P}
):
    """
    Returns (P^desired_power, Pt_dict), reusing and extending the cache of known powers.

    `Pt_dict` maps already-computed exponents to their matrices. It is updated in
    place with every intermediate power computed along the way, and is also returned.
    The requested power is built as a product of two smaller powers: the largest
    cached power below it when that is at least half of it, otherwise the floor-half
    power (computed recursively).

    Raises ValueError if `desired_power` is not cached and is smaller than 1, or if
    the cache holds no power in [1, desired_power) to build from.
    """
    if desired_power in Pt_dict:
        return Pt_dict[desired_power], Pt_dict
    if desired_power < 1:
        raise ValueError(f"desired_power must be a positive integer, got {desired_power}")
    # Given no existing keys, we factor things into the closest powers of two. But if there is a large existing key (larger than the power of two), we'll use that.
    best_power = desired_power // 2
    # Only cached powers strictly below the target are usable as a factor. Taking the
    # global maximum would make `desired_power - u` negative (and recurse forever)
    # whenever the cache already holds a power larger than the one requested.
    usable_powers = [k for k in Pt_dict if 1 <= k < desired_power]
    if not usable_powers:
        raise ValueError(
            "Pt_dict must contain the base matrix under key 1 (or another power below desired_power)"
        )
    max_power = max(usable_powers)
    if max_power >= best_power:
        u = max_power
    else:
        u = best_power
        # Populate the cache with P^best_power; the matrix itself is read from the cache below
        _, Pt_dict = get_matrix_power_recursive(best_power, Pt_dict)
    P_minusmax, Pt_dict = get_matrix_power_recursive(desired_power - u, Pt_dict)
    Pt = Pt_dict[u] @ P_minusmax
    Pt_dict[desired_power] = Pt
    return Pt, Pt_dict


def powers_of_diffusion(
    P:jax.Array, # diffusion matrix
    ts:List[int], # list of times
)->List[jax.Array]:
    """
    Returns list[P^t for t in ts], but done efficiently.

    A single cache of already-computed powers (seeded with {1: P}) is shared across
    all requested times, so later powers are assembled from earlier results via
    `get_matrix_power_recursive`. The output list follows the order of `ts`.
    """
    Pt_dict = {1: P}
    Pts = []
    for t in ts:
        # Each call may add new intermediate powers to the shared cache
        Pt, Pt_dict = get_matrix_power_recursive(t, Pt_dict)
        Pts.append(Pt)
    return Pts

Sanity check

In [None]:
A = random_jnparray(10,10)
A_19, A_power_dict = get_matrix_power_recursive(
    19, {1:A}
)

In [None]:
jnp.allclose(jnp.linalg.matrix_power(A, 19), A_19)

Array(True, dtype=bool)

In [None]:
A_power_dict.keys()

dict_keys([1, 2, 4, 5, 9, 10, 19])

Speed test: jnp's built-in `matrix_power` is, as expected, faster -- but only by a factor of 5. That's pretty good. Quite likely Python overhead is the cause of the slowness, not the algorithm we've used.

In [None]:
%%timeit
t = np.random.randint(40,60)
At, Adict = get_matrix_power_recursive(t, { 1 : A })

41.6 µs ± 70.1 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
%%timeit
t = np.random.randint(40,60)
At = jnp.linalg.matrix_power(A, t)

8.88 µs ± 44.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


Now let's test the real use case: getting powered matrices from a list of powers.

In [None]:
ts = [1,2,3,4,5,6,7,8,10,13,15,17,21,25]
Pts = powers_of_diffusion(A, ts)
# Assert every comparison: a bare jnp.allclose(...) inside a loop is neither displayed
# nor checked, so a mismatch would otherwise pass silently.
for t, powered in zip(ts, Pts):
    assert jnp.allclose(powered, jnp.linalg.matrix_power(A, t)), f"P^{t} does not match jnp.linalg.matrix_power"

Timing it, we've now become faster than the bare metal -- though surprisingly not by that much.

In [None]:
%%timeit
ts = np.arange(1,100)
Pts = powers_of_diffusion(A, ts)

581 µs ± 3.09 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
%%timeit
ts = np.arange(1,100)
Pts = [jnp.linalg.matrix_power(A, t) for t in ts]

710 µs ± 106 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


We also see that computing a hundred sequential powers instead of one only scales by a factor of ten, versus, for the basic jax, a factor of 100. That's the important constant.

## Diffusion Laziness Wrapper

In [None]:
from typing import Literal

class DiffusionLaziness():
    DIFFUSION_TYPES = Literal['diffusion matrix','heat kernel']
    LAZINESS_METHODS = Literal['Entropic', 'Wasserstein']
    def __init__(
        diffusion_type:DIFFUSION_TYPES = "diffusion matrix",
        laziness_method:LAZINESS_METHODS = "Entropic",
    ):
        store_attr()


    def fit_transform(
        G, # graph
        t, # time or list of times.
    ):
        # compute diffusion matrix from graph
        