In [2]:
#| default_exp diffusion_laziness
# Diffusion Curvature utils
from diffusion_curvature.utils import *
from diffusion_curvature.datasets import *
# Python necessities
import numpy as np
import jax
import jax.numpy as jnp
from fastcore.all import *
import matplotlib.pyplot as plt
# Notebook Helpers
from nbdev.showdoc import *
from tqdm.notebook import trange, tqdm
%load_ext autoreload
%autoreload 2

# Diffusion Laziness Estimators
> What's the shape of this diffusion?

# Wasserstein Spread of Diffusion

In [3]:
#|export
import jax.numpy as jnp
from jax import jit

@jit
def wasserstein_spread_of_diffusion(
                D, # manifold geodesic distances
                Pt, # powered diffusion matrix/t-step ehat diffusions
                ):
        """
        Returns how "spread out" each diffusion is, with wasserstein distance
        Presumes that the manifold distances have been separately calculated
        """
        return jnp.sum(D * Pt, axis=-1)

### Benchmarking

In [4]:
D = np.random.rand(1000,1000)
Pt = np.random.rand(1000,1000)
Pt = Pt / np.sum(Pt, axis=1)[:,None]

In [5]:
%%timeit
wasserstein_spread_of_diffusion(D,Pt)

1.82 ms ± 131 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
key = jax.random.PRNGKey(0)
Djax = jax.random.normal(key, (1000, 1000))
key = jax.random.PRNGKey(10)
Ptjax = jax.random.normal(key, (1000, 1000))

In [7]:
jnp.allclose(Djax,Ptjax)

Array(False, dtype=bool)

In [8]:
%%timeit
wasserstein_spread_of_diffusion(Djax,Ptjax)

20.4 µs ± 126 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


Wow, it's nearly two orders of magnitude faster when using jax arrays.

# Entropy of Diffusion

In [9]:
#|export
import jax.scipy
import jax.numpy as jnp
# def entropy_of_diffusion(Pt):
#         """
#         Returns the pointwise entropy of diffusion from the powered diffusion matrix in the input
#         Assumes that Pt sums to 1
#         """
#         Pt = (Pt + 1e-10) /(1 + 1e-10*Pt.shape[0]) # ensure, for differentiability, that there are no zeros in Pt, but that it still sums to 1.
#         entropy_elementwise = jax.scipy.special.entr(Pt)
#         entropy_of_rows = jnp.sum(entropy_elementwise, axis=-1)
#         # normalize so max value is 1
#         entropy_of_rows = entropy_of_rows / (-jnp.log(1/Pt.shape[0]))
#         return entropy_of_rows
def entropy_of_diffusion(Pt, epsilon=1e-5):
        """
        Returns the pointwise entropy of diffusion from the powered diffusion matrix in the input
        Assumes that Pt sums to 1
        """
        # Use only the elements of Pt that are greater than epsilon
        Pt = Pt * (Pt>epsilon)
        # Normalize Pt so that it sums to 1
        Pt = Pt / (jnp.sum(Pt, axis=-1) + 1e-12)
        # Pt = (Pt + 1e-10) /(1 + 1e-10*Pt.shape[0]) # ensure, for differentiability, that there are no zeros in Pt, but that it still sums to 1.
        entropy_elementwise = jax.scipy.special.entr(Pt)
        entropy_of_rows = jnp.sum(entropy_elementwise, axis=-1)
        # normalize so max value is 1
        # entropy_of_rows = entropy_of_rows / (-jnp.log(1/jnp.sum(Pt>epsilon, axis=-1)))
        return entropy_of_rows

In [10]:
from scipy.stats import entropy
assert jnp.allclose(entropy_of_diffusion(Pt),entropy(Pt,axis=1))

AssertionError: 

In [12]:
jnp.sum(Pt, axis=-1)[:,None]

Array([[1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [0.99999994],
       [0.99999994],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [0.99999994],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.        ],
       [1.   