In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

from scipy.integrate import solve_ivp

DEFAULT_SEED = 42

from gaussian import MultivariateNormal, DynamicMultivariateNormal, VarianceExploding, VariancePreserving, SubVariancePreserving

In [29]:
def cube_vertices(dim, side_len=1.0, var=1e-2):
    vertices1d = np.array([-side_len, side_len])
    all_vertices1d = vertices1d.reshape(1, 2).repeat(dim, axis=0)
    all_vertices = np.meshgrid(*all_vertices1d)
    vertices = np.stack(all_vertices, axis=-1).reshape(-1, dim)
    var = var * np.eye(dim)
    return [DynamicMultivariateNormal(dim, vertex, var) for vertex in vertices]

In [None]:
def compute_quasi_exact_nll(
    mix, prior, n_data=50_000, batch_size=2_000, t_min=0.0, tf=1.0, seed=DEFAULT_SEED
):
    def flat_extended_ode(t, x_cumdiv_flat):
        x, _ = np.split(x_cumdiv_flat.reshape(-1, mix.dim + 1), [mix.dim], -1)
        dx, dlogp = mix.extended_ode(t, x)
        return np.concatenate([dx, dlogp[..., None]], 1).flatten()

    def solve_batch(x_init):
        n_data = x_init.shape[0]
        delta_logp = np.zeros((n_data, 1))
        x_logp_init = np.concatenate([x_init, delta_logp], axis=1).flatten()

        solve_params = dict(method="DOP853", rtol=1e-10, atol=1e-10)
        sol = solve_ivp(flat_extended_ode, (t_min, tf), x_logp_init, **solve_params)

        x_logp_fin = sol.y[:, -1].reshape(n_data, mix.dim + 1)
        x_fin, delta_logp = np.split(x_logp_fin, [mix.dim], -1)
        prior_fin = np.log(prior.density(x_fin))
        h_init = -np.log(mix.density(0.0, x_init))
        h_fin = -np.log(mix.density(tf, x_fin))
        nll_emp = -(delta_logp[:, 0] + prior_fin).mean()
        nll_theo = -(h_fin - h_init + prior_fin).mean()

        return x_fin, nll_emp / (np.log(2.0) * mix.dim), nll_theo / (np.log(2.0) * mix.dim)

    num_batches = n_data // batch_size + (n_data % batch_size > 0)
    seeds = np.random.SeedSequence(seed).spawn(num_batches)
    nll_emp, nll_theo = 0.0, 0.0
    for seed in seeds:
        x_init = mix.sample(batch_size, seed=seed)
        x_fin, nll_emp_batch, nll_theo_batch = solve_batch(x_init)
        nll_emp += nll_emp_batch
        nll_theo += nll_theo_batch

    return x_fin, nll_emp / num_batches, nll_theo / num_batches

In [103]:
for pow_dim in range(3, 4):
    dim = 2 ** pow_dim
    print(f"Dimension {dim} - VE")
    for tf in [2.0]:
    # for tf in [0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0]:
        mix = VarianceExploding(cube_vertices(dim))
        prior = MultivariateNormal(mix.dim, cov=mix.added_noise_sq(tf) * np.eye(mix.dim))

        # seeds = np.random.SeedSequence(DEFAULT_SEED).spawn(2)
        # x_init = mix.sample(100_000, t=0.0, seed=seeds[0])
        # x_fin = mix.sample(100_000, t=tf, seed=seeds[1])
        # h_init = np.mean(np.log(mix.density(0.0, x_init)))
        # h_fin = np.mean(np.log(mix.density(tf, x_fin)))
        # h_prior = np.mean(np.log(prior.density(x_fin)))
        # nll = -(h_init - h_fin + h_prior) / (np.log(2.0) * mix.dim)
        # nll = -h_init / (np.log(2.0) * mix.dim)

        x_bw, nll_emp, nll_theo = compute_quasi_exact_nll(mix, prior, tf=tf)
        print(f"tf={tf}: nll={nll_emp} & {nll_theo} (empirical & theoretical)")

Dimension 8 - VE
tot_prob: 9.664e+08, inv_prob: 8.900e+08, tot_score: 2.560e-10, tot_tensor_score: 1.201e-09, tot_jac_score: -1.990e-09
tot_prob: 1.032e+09, inv_prob: 9.578e+08, tot_score: 2.462e-10, tot_tensor_score: 1.143e-09, tot_jac_score: -1.847e-09


KeyboardInterrupt: 

In [17]:
for dim in range(2, 6):
    print(f"Dimension {dim} - VP")
    for beta_max in [5.0, 10.0, 15.0, 20.0, 30.0, 50.0]:
        mix = VariancePreserving(cube_vertices(dim), beta_max=beta_max)
        prior = MultivariateNormal(mix.dim)

        x_bw, nll = compute_quasi_exact_nll(mix, prior)
        print(f"beta_max={beta_max}: nll={nll}")

Dimension 2 - VP
beta_max=5.0: nll=-0.26473580390107515
beta_max=10.0: nll=-0.26465953666208264
beta_max=15.0: nll=-0.2646459555562728
beta_max=20.0: nll=-0.26464052251283554
beta_max=30.0: nll=-0.264635484640041
beta_max=50.0: nll=-0.26462878531056677
Dimension 3 - VP
beta_max=5.0: nll=-0.27835336742055167
beta_max=10.0: nll=-0.27830832560036056
beta_max=15.0: nll=-0.2782836397812913
beta_max=20.0: nll=-0.27827116429084714
beta_max=30.0: nll=-0.2782548878596516
beta_max=50.0: nll=-0.2782270085170098
Dimension 4 - VP
beta_max=5.0: nll=-0.2764088676445521


KeyboardInterrupt: 