In [1]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

from scipy.integrate import solve_ivp

# Default root seed; used as the default `seed` of compute_quasi_exact_nll,
# which spawns one np.random.SeedSequence child per batch from it.
DEFAULT_SEED = 42

from gaussian import MultivariateNormal, DynamicMultivariateNormal, VarianceExploding, VariancePreserving, SubVariancePreserving

In [7]:
def cube_vertices(dim, side_len=1.0, var=1e-2):
    """Build one isotropic Gaussian component per vertex of a hypercube.

    The cube is axis-aligned, has one corner at the origin and edge length
    ``side_len``, so it has ``2**dim`` vertices.

    Args:
        dim: Dimension of the ambient space (number of cube axes).
        side_len: Edge length of the cube.
        var: Scalar variance; every component uses covariance ``var * I``.

    Returns:
        list of ``DynamicMultivariateNormal(dim, vertex, cov)`` objects, ordered
        as the flattened ``np.meshgrid`` output (default ``"xy"`` indexing).
    """
    edge_coords = [np.array([0.0, side_len])] * dim
    coord_grids = np.meshgrid(*edge_coords)
    # Each row of `corners` is one vertex: a dim-vector whose entries are 0.0 or side_len.
    corners = np.stack(coord_grids, axis=-1).reshape(-1, dim)
    # A single covariance array is shared by all components.
    shared_cov = np.eye(dim) * var
    components = []
    for corner in corners:
        components.append(DynamicMultivariateNormal(dim, corner, shared_cov))
    return components

In [8]:
def compute_quasi_exact_nll(
    mix, prior, n_data=10_000, batch_size=1_000, t_min=1e-6, tf=1.0, seed=DEFAULT_SEED
):
    """Estimate the negative log-likelihood of ``mix`` in bits per dimension.

    Samples drawn from ``mix`` are pushed through the extended ODE given by
    ``mix.extended_ode`` (state derivative plus a ``dlogp`` term) from ``t_min``
    to ``tf`` with tight tolerances. The integrated ``dlogp`` term is added to
    the prior log-density of the endpoints, and the negated sum is averaged
    over all ``n_data`` samples.

    Args:
        mix: Model exposing ``dim``, ``sample(n, seed)`` and
            ``extended_ode(t, x) -> (dx, dlogp)``.
        prior: Distribution exposing ``density(x)``, evaluated at the ODE endpoints.
        n_data: Total number of samples used for the estimate.
        batch_size: Maximum number of samples integrated jointly per ODE solve.
            The last batch holds the remainder when it does not divide ``n_data``.
        t_min: Start time of the integration.
        tf: End time of the integration.
        seed: Root seed; one ``SeedSequence`` child is spawned per batch.

    Returns:
        ``(x_fin, bpd)``. ``x_fin`` holds the ODE endpoints of the LAST batch only.
        ``bpd`` is the sample-weighted mean NLL divided by ``dim * ln(2)``
        (bits per dimension).

    Raises:
        ValueError: If ``n_data`` or ``batch_size`` is not positive.
        RuntimeError: If ``solve_ivp`` reports a failed integration.
    """
    if n_data <= 0 or batch_size <= 0:
        raise ValueError(
            f"n_data and batch_size must be positive, got {n_data} and {batch_size}"
        )

    dim = mix.dim

    def flat_extended_ode(t, x_cumdiv_flat):
        # solve_ivp works on a flat vector; each row is [x (dim entries), accumulated dlogp].
        x, _ = np.split(x_cumdiv_flat.reshape(-1, dim + 1), [dim], -1)
        dx, dlogp = mix.extended_ode(t, x)
        return np.concatenate([dx, dlogp], 1).flatten()

    def solve_batch(x_init):
        """Integrate one batch; return its endpoints and its mean NLL in nats."""
        n_batch = x_init.shape[0]
        delta_logp_init = np.zeros((n_batch, 1))
        x_logp_init = np.concatenate([x_init, delta_logp_init], axis=1).flatten()

        sol = solve_ivp(
            flat_extended_ode,
            (t_min, tf),
            x_logp_init,
            rtol=1e-10,
            atol=1e-10,
            t_eval=np.array([t_min, tf]),
        )
        # On failure sol.y may stop before tf (its last column could even be the
        # initial state), so reading sol.y[:, -1] would silently give a wrong NLL.
        if not sol.success:
            raise RuntimeError(f"ODE integration failed: {sol.message}")

        x_logp_fin = sol.y[:, -1].reshape(n_batch, dim + 1)
        x_fin, delta_logp = np.split(x_logp_fin, [dim], -1)
        log_prior = np.log(prior.density(x_fin))
        return x_fin, -(delta_logp[:, 0] + log_prior).mean()

    num_batches = n_data // batch_size + (n_data % batch_size > 0)
    child_seeds = np.random.SeedSequence(seed).spawn(num_batches)
    nll_total = 0.0
    for batch_idx, child_seed in enumerate(child_seeds):
        # The final batch only draws the remaining samples so exactly n_data are used.
        n_batch = min(batch_size, n_data - batch_idx * batch_size)
        x_init = mix.sample(n_batch, child_seed)
        x_fin, nll_batch = solve_batch(x_init)
        # Weight each batch mean by its size so a short last batch is not over-counted.
        nll_total += nll_batch * n_batch

    bpd_cst = 1.0 / (np.log(2.0) * dim)
    return x_fin, (nll_total / n_data) * bpd_cst

In [9]:
for dim in range(2, 6):
    print(f"Dimension {dim} - VE")
    # The mixture does not depend on tf, so build it once per dimension
    # rather than once per (dim, tf) pair.
    mix = VarianceExploding(cube_vertices(dim))
    for tf in [1.0, 20.0, 40.0, 60.0, 80.0, 100.0]:
        # Isotropic Gaussian prior whose variance is mix.added_noise_sq(tf)
        # (presumably the noise variance accumulated by the VE SDE up to tf — confirm).
        prior = MultivariateNormal(mix.dim, cov=mix.added_noise_sq(tf) * np.eye(mix.dim))

        # NOTE: compute_quasi_exact_nll returns the NLL in bits per dimension.
        x_bw, nll = compute_quasi_exact_nll(mix, prior, tf=tf)
        print(f"tf={tf}: nll={nll}")

Dimension 2 - VE
tf=1.0: nll=-0.07914062447730236
tf=20.0: nll=-0.2622657275358472
tf=40.0: nll=-0.18390859378481877
tf=60.0: nll=-0.030760761347738575
tf=80.0: nll=0.1648992084007826
tf=100.0: nll=0.3711655900300173
Dimension 3 - VE
tf=1.0: nll=-0.07219784221226758
tf=20.0: nll=0.23539383185448318
tf=40.0: nll=1.0760998209634547
tf=60.0: nll=1.633057154259444
tf=80.0: nll=2.0383644720132903
tf=100.0: nll=2.355796112173455
Dimension 4 - VE
tf=1.0: nll=-0.06845292165460785
tf=20.0: nll=1.1609602580799225
tf=40.0: nll=2.1282864031483633
tf=60.0: nll=2.7072008002496277
tf=80.0: nll=3.1201214831174644
tf=100.0: nll=3.4410697952916167
Dimension 5 - VE
tf=1.0: nll=-0.06802151657433862
tf=20.0: nll=1.7762474955548995
tf=40.0: nll=2.7628125023108927
tf=60.0: nll=3.345287043698614
tf=80.0: nll=3.7594537572157245
tf=100.0: nll=4.080978802693484


In [10]:
for dim in range(2, 6):
    print("Dimension {} - VP".format(dim))
    for beta_max in (10.0, 15.0, 20.0, 30.0, 50.0, 80.0):
        # beta_max is a constructor argument, so every value needs its own mixture.
        mix = VariancePreserving(cube_vertices(dim), beta_max=beta_max)
        # Prior built from MultivariateNormal's default parameters
        # (presumably N(0, I) — confirm in gaussian.py).
        prior = MultivariateNormal(mix.dim)

        # Relies on compute_quasi_exact_nll's defaults (t_min=1e-6, tf=1.0);
        # the returned value is in bits per dimension.
        x_bw, nll = compute_quasi_exact_nll(mix, prior)
        print("beta_max={}: nll={}".format(beta_max, nll))

Dimension 2 - VP
beta_max=10.0: nll=-0.27326817491199856
beta_max=15.0: nll=-0.27413922948746283
beta_max=20.0: nll=-0.27414448144653997
beta_max=30.0: nll=-0.2741159980322
beta_max=50.0: nll=-0.27410590111636657
beta_max=80.0: nll=-0.2740964040274491
Dimension 3 - VP
beta_max=10.0: nll=-0.2760049894729151
beta_max=15.0: nll=-0.2773749782937044
beta_max=20.0: nll=-0.2775207680183966
beta_max=30.0: nll=-0.2775350702249296
beta_max=50.0: nll=-0.2775081909124967
beta_max=80.0: nll=-0.2774680603326158
Dimension 4 - VP
beta_max=10.0: nll=-0.27010211192133377
beta_max=15.0: nll=-0.2714160205702662
beta_max=20.0: nll=-0.27151648916538873
beta_max=30.0: nll=-0.27146231354730654
beta_max=50.0: nll=-0.2713223260592701
beta_max=80.0: nll=-0.27113440557137997
Dimension 5 - VP
beta_max=10.0: nll=-0.27073640184162556
beta_max=15.0: nll=-0.27186124359031594
beta_max=20.0: nll=-0.2718436379077664
beta_max=30.0: nll=-0.27159617363947575
beta_max=50.0: nll=-0.2711151657337706
beta_max=80.0: nll=-0.27046