In [None]:
#|default_exp compute_mean_flat_entropies

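# 1c1 Mean Entropy of Uniform Flat Spaces

This notebook estimates the average entropy at the origin of randomly sampled flat point clouds in $[-1, 1]^d$. The entropy is computed with `DiffusionCurvature.unsigned_curvature`. Results are produced for every combination of dimension, k-NN value and diffusion time, and are cached in an HDF5 database.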

In [None]:
#|export
import jax.numpy as jnp
from diffusion_curvature.core import DiffusionCurvature, get_adaptive_graph
from diffusion_curvature.utils import *
from tqdm.auto import tqdm
import graphtools

def average_flat_entropies(
        dim,
        t,
        num_trials,
        num_points_in_comparison = 10000,
        graph_former = get_adaptive_graph
):
    """Average the entropy-based unsigned curvature at the centre of random flat point clouds.

    Each trial does the following:
    - prepends the origin as point 0 to ``num_points_in_comparison - 1`` random
      points drawn as ``2*random_jnparray(...) - 1``. Assuming ``random_jnparray``
      is uniform on [0, 1) (TODO confirm), these points lie in [-1, 1]^dim;
    - builds a graph with ``graph_former``;
    - evaluates ``DiffusionCurvature.unsigned_curvature`` at the origin
      (``idx=0``) with diffusion time ``t``.

    The per-trial values are then averaged.

    Parameters
    ----------
    dim : int
        Dimension of the flat space the points are sampled from.
    t : int
        Diffusion time forwarded to ``unsigned_curvature``.
    num_trials : int
        Number of independent random point clouds to average over. Must be >= 1.
    num_points_in_comparison : int, default 10000
        Total number of points per cloud, including the origin.
    graph_former : callable, default ``get_adaptive_graph``
        Maps the point array to the graph object that ``unsigned_curvature`` consumes.

    Returns
    -------
    jax scalar
        Mean of the per-trial ``unsigned_curvature`` values.

    Raises
    ------
    ValueError
        If ``num_trials`` < 1. Without this check the mean of an empty array
        would silently be NaN.
    """
    if num_trials < 1:
        raise ValueError(f"num_trials must be at least 1, got {num_trials}")
    DC = DiffusionCurvature(
        laziness_method="Entropic",
        flattening_method="Fixed",
        comparison_method="Subtraction",
        points_per_cluster=None, # construct separate comparison spaces around each point
        comparison_space_size_factor=1
    )
    # The origin is always row 0, so idx=0 below evaluates at the centre of the cloud.
    origin = jnp.zeros((1, dim))
    trial_values = []
    for _ in range(num_trials):
        Rn = jnp.concatenate([origin, 2*random_jnparray(num_points_in_comparison-1, dim) - 1])
        G = graph_former(Rn)
        trial_values.append(DC.unsigned_curvature(G, t, idx=0))
    # Collect into a Python list and convert once. This avoids copying a
    # fixed-size jax array on every trial via .at[i].set.
    return jnp.mean(jnp.asarray(trial_values))

In [None]:
#|export
import h5py
from fastcore.all import *
def _knn_graph_former(knn):
    """Return a graph builder producing a graphtools kNN graph (``knn`` neighbours) converted to pygsp."""
    def graph_former(X):
        # decay=None and anisotropy=1 follow the graphtools call this notebook used before get_adaptive_graph.
        return graphtools.Graph(X, anisotropy=1, knn=knn, decay=None).to_pygsp()
    return graph_former

@call_parse
def create_mean_entropy_database(
    outfile = "../data/entropies_averaged.h5",
    dimensions:Param("(Intrinsic) Dimensions to Take Mean Entropies over", int, nargs='+') = [1,2,3,4,5,6],
    knns:Param("k-nearest neighbor values to compute", int, nargs='+') = [5,10,15],
    ts:Param("time values to compute", int, nargs='+') = [25,30,35],
    num_trials:Param("Number of random flat samples averaged per setting", int) = 100,
    num_points_in_comparison:Param("Points per flat sample, including the origin", int) = 10000,
):
    """Fill an HDF5 database with mean flat-space entropies for every (dim, knn, t) setting.

    The file is opened in append mode. Results are stored at ``/<dim>/<knn>/<t>``,
    with all keys as strings. Settings that already have a dataset are skipped, so
    an interrupted run can be resumed. Each new dataset is flushed to disk right
    after it is written.

    Returns the still-open ``h5py.File``; the caller is responsible for closing it.
    If an error or interrupt occurs, the file is closed before the exception propagates.
    """
    f = h5py.File(outfile,'a')
    try:
        for dim in tqdm(dimensions, desc="dim"):
            dim_group = f.require_group(str(dim))
            for knn in tqdm(knns, desc="knn", leave=False):
                knn_group = dim_group.require_group(str(knn))
                # Bind knn explicitly so the graph construction actually depends on it.
                graph_former = _knn_graph_former(knn)
                for t in tqdm(ts, desc="t", leave=False):
                    if str(t) in knn_group:
                        continue  # computed on a previous run
                    afe = average_flat_entropies(
                        dim,
                        t,
                        num_trials,
                        num_points_in_comparison=num_points_in_comparison,
                        graph_former=graph_former,
                    )
                    knn_group.create_dataset(str(t), data=afe)
                    f.flush()  # persist progress; each setting can take a long time
    except BaseException:
        # Includes KeyboardInterrupt: don't leave the HDF5 file open or unflushed.
        f.close()
        raise
    return f


In [None]:
#|export
import h5py
def load_average_entropies(filename):
    """Read the whole mean-entropy HDF5 file into plain nested dicts.

    The file is traversed three group levels deep: ``result[dim][knn][t]`` holds
    the value of the dataset at ``/<dim>/<knn>/<t>``, read with ``[()]``. All keys
    are the HDF5 names, which are strings. The file is opened read-only, and all
    data is read before the file is closed.
    """
    with h5py.File(filename, 'r') as db:
        return {
            dim_key: {
                knn_key: {
                    t_key: t_dset[()]
                    for t_key, t_dset in knn_grp.items()
                }
                for knn_key, knn_grp in dim_grp.items()
            }
            for dim_key, dim_grp in db.items()
        }

In [None]:
# Load the cached mean flat-space entropies as nested dicts keyed str(dim) -> str(knn) -> str(t).
d = load_average_entropies('../data/entropies_averaged.h5')

In [None]:
# Show the loaded table; each leaf is the stored mean entropy for one (dim, knn, t) setting.
d

{'1': {'10': {'25': 4.1617084},
  '15': {'25': 4.5504103},
  '5': {'25': 3.4661007}},
 '2': {'10': {'25': 5.775014}, '15': {'25': 6.2073145}, '5': {'25': 4.912664}},
 '3': {'10': {'25': 7.301388},
  '15': {'25': 7.7671103},
  '5': {'25': 6.2948627}},
 '4': {'10': {'25': 8.641593}, '15': {'25': 8.925011}, '5': {'25': 7.6750712}},
 '5': {'10': {'25': 9.111419}, '15': {'25': 9.178771}, '5': {'25': 8.638502}},
 '6': {'10': {'25': 9.189743}, '15': {'25': 9.203704}, '5': {'25': 9.015003}}}

In [3]:
# Regenerate the `compute_mean_flat_entropies` module (set by `#|default_exp`) from the `#|export` cells.
!nbdev_export