In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from scipy.integrate import solve_ivp

# Seed NumPy's legacy global RNG once so every sampling cell below is
# reproducible under Restart & Run All (DEFAULT_SEED was previously unused).
# NOTE(review): assumes the `gaussian` samplers draw from np.random's global
# state — confirm; if they accept an explicit Generator, pass one instead.
DEFAULT_SEED = 42
np.random.seed(DEFAULT_SEED)

from gaussian import MultivariateNormal, DynamicMultivariateNormal, VarianceExploding, VariancePreserving, SubVariancePreserving
from numerical import EulerSolver, BroydenSolver

In [None]:
def cube_vertices(dim, side_len=1.0, var=1e-2):
    """Build one isotropic Gaussian per vertex of an axis-aligned hypercube.

    ``side_len`` is effectively the half-width: every vertex coordinate is
    ``-side_len`` or ``+side_len``, so the cube's edge length is
    ``2 * side_len``.

    Args:
        dim: number of dimensions; ``2 ** dim`` vertices are produced.
        side_len: absolute value of every vertex coordinate.
        var: scalar variance placed on the diagonal of each covariance.

    Returns:
        List of ``DynamicMultivariateNormal(dim, vertex, var * I)``, one per
        vertex, ordered as produced by ``np.meshgrid`` with its default
        ``"xy"`` indexing.
    """
    corners_1d = [np.array([-side_len, side_len])] * dim
    grids = np.meshgrid(*corners_1d)
    corner_points = np.stack(grids, axis=-1).reshape(-1, dim)
    covariance = var * np.eye(dim)
    return [DynamicMultivariateNormal(dim, corner, covariance) for corner in corner_points]

In [None]:
def plot_simulation(mix, t, x, show_every=50):
    """Plot snapshots of simulated particles over the evolving density and score.

    One panel is drawn every ``show_every`` saved time steps, starting at
    ``t[0]``. Each panel shows the particles as a scatter, contours of
    ``log(1e-8 + mix.density(t_i, .))`` and a quiver of the score returned by
    ``mix.score_with_div(t_i, .)``, coloured by the returned divergence.
    Assumes a 2-D state: density and score are evaluated on a 2-D grid.

    Args:
        mix: object exposing ``density(t, x)`` and ``score_with_div(t, x)``.
        t: 1-D sequence of saved times.
        x: array indexed as ``x[:, step, :]`` (presumably
            ``(num_samples, len(t), 2)`` — confirm against the solver).
        show_every: stride, in saved steps, between consecutive panels.

    Returns:
        ``(fig, ax)`` where ``ax`` is always a 1-D array of Axes, including
        when only a single panel is drawn.

    Raises:
        ValueError: if ``show_every < 1`` or ``t`` is empty.
    """
    if show_every < 1:
        raise ValueError(f"show_every must be >= 1, got {show_every}")
    if len(t) == 0:
        raise ValueError("t must contain at least one time point")

    num_plots = (len(t) - 1) // show_every + 1
    # squeeze=False always yields a 2-D Axes array; without it a single panel
    # comes back as a bare Axes and `ax[i]` fails.
    fig, ax = plt.subplots(1, num_plots, figsize=(15, 3), squeeze=False)
    ax = ax[0]

    for i in range(num_plots):
        si = i * show_every
        ti = t[si]
        xi = x[:, si, :]

        # Evaluation grid padded by 1.0 around the current particle cloud.
        x1_cont = np.linspace(xi[:, 0].min() - 1.0, xi[:, 0].max() + 1.0, 200)
        x2_cont = np.linspace(xi[:, 1].min() - 1.0, xi[:, 1].max() + 1.0, 200)
        x_cont = np.stack(np.meshgrid(x1_cont, x2_cont), -1)
        # Every 10th grid point (20 x 20) for the quiver so arrows stay legible.
        x1_quiv = x1_cont[5::10]
        x2_quiv = x2_cont[5::10]
        x_quiv = np.stack(np.meshgrid(x1_quiv, x2_quiv), -1)
        score, div_score = mix.score_with_div(ti, x_quiv)

        ax[i].scatter(xi[:, 0], xi[:, 1], s=1)
        # The 1e-8 floor keeps log() finite where the density underflows to 0.
        ax[i].contour(x_cont[:, :, 0], x_cont[:, :, 1], np.log(1e-8 + mix.density(ti, x_cont)), levels=10, alpha=0.5, cmap="plasma")
        ax[i].quiver(x_quiv[:, :, 0], x_quiv[:, :, 1], score[:, :, 0], score[:, :, 1], div_score, alpha=0.8)
        ax[i].set_title(f"t = {float(ti):.3g}")
    return fig, ax

In [None]:
# 2-D mixture with one Gaussian per vertex of the square [-1, 1]^2; use `dim`
# rather than a repeated literal so the two cannot drift apart.
dim = 2
norms = cube_vertices(dim)

### Explicit Euler plots

In [None]:
t_min = 1e-6
tf = 2.0
mix = VarianceExploding(norms)

num_sample = 5000
num_save = 500

x_init = mix.sample(num_sample)

# Integrate the same initial samples on both time grids so the two figures are
# directly comparable: `{}` keeps EulerSolver's default grid, the second entry
# passes `linear_ts=True`.
time_grids = [({}, "default time grid"), ({"linear_ts": True}, "linear_ts=True")]
for solver_kwargs, grid_label in time_grids:
    sol = EulerSolver(mix.ode_with_jac, t_min, tf, num_save, **solver_kwargs)
    t, x, _ = sol(x_init)
    fig = plot_simulation(mix, t, x, show_every=120)[0]
    fig.suptitle(f"VE, explicit Euler, {grid_label}", y=1.02)
    # Render each figure explicitly; this also suppresses the (fig, ax) repr.
    plt.show()

In [None]:
t_min = 1e-6
tf = 1.0
mix = VariancePreserving(norms)

num_sample = 5000
num_save = 500

x_init = mix.sample(num_sample)

# Integrate the same initial samples on both time grids so the two figures are
# directly comparable: `{}` keeps EulerSolver's default grid, the second entry
# passes `linear_ts=True`.
time_grids = [({}, "default time grid"), ({"linear_ts": True}, "linear_ts=True")]
for solver_kwargs, grid_label in time_grids:
    sol = EulerSolver(mix.ode_with_jac, t_min, tf, num_save, **solver_kwargs)
    t, x, _ = sol(x_init)
    fig = plot_simulation(mix, t, x, show_every=120)[0]
    fig.suptitle(f"VP, explicit Euler, {grid_label}", y=1.02)
    # Render each figure explicitly; this also suppresses the (fig, ax) repr.
    plt.show()

In [None]:
t_min = 1e-6
tf = 1.0
mix = SubVariancePreserving(norms)

num_sample = 5000
num_save = 500

x_init = mix.sample(num_sample)

# Integrate the same initial samples on both time grids so the two figures are
# directly comparable: `{}` keeps EulerSolver's default grid, the second entry
# passes `linear_ts=True`.
time_grids = [({}, "default time grid"), ({"linear_ts": True}, "linear_ts=True")]
for solver_kwargs, grid_label in time_grids:
    sol = EulerSolver(mix.ode_with_jac, t_min, tf, num_save, **solver_kwargs)
    t, x, _ = sol(x_init)
    fig = plot_simulation(mix, t, x, show_every=120)[0]
    fig.suptitle(f"sub-VP, explicit Euler, {grid_label}", y=1.02)
    # Render each figure explicitly; this also suppresses the (fig, ax) repr.
    plt.show()

### Implicit Euler plots (using Broyden method)

In [None]:
t_min = 1e-6
tf = 2.0
mix = VarianceExploding(norms)

num_sample = 5000
num_save = 500

x_init = mix.sample(num_sample)

# Integrate the same initial samples on both time grids so the two figures are
# directly comparable: `{}` keeps BroydenSolver's default grid, the second
# entry passes `linear_ts=True`.
time_grids = [({}, "default time grid"), ({"linear_ts": True}, "linear_ts=True")]
for solver_kwargs, grid_label in time_grids:
    sol = BroydenSolver(mix.ode, mix.ode_with_jac, t_min, tf, num_save, **solver_kwargs)
    t, x, _ = sol(x_init)
    fig = plot_simulation(mix, t, x, show_every=120)[0]
    fig.suptitle(f"VE, implicit Euler (Broyden), {grid_label}", y=1.02)
    # Render each figure explicitly; this also suppresses the (fig, ax) repr.
    plt.show()

In [None]:
t_min = 1e-6
tf = 1.0
mix = VariancePreserving(norms)

num_sample = 5000
num_save = 500

x_init = mix.sample(num_sample)

# Integrate the same initial samples on both time grids so the two figures are
# directly comparable: `{}` keeps BroydenSolver's default grid, the second
# entry passes `linear_ts=True`.
time_grids = [({}, "default time grid"), ({"linear_ts": True}, "linear_ts=True")]
for solver_kwargs, grid_label in time_grids:
    sol = BroydenSolver(mix.ode, mix.ode_with_jac, t_min, tf, num_save, **solver_kwargs)
    t, x, _ = sol(x_init)
    fig = plot_simulation(mix, t, x, show_every=120)[0]
    fig.suptitle(f"VP, implicit Euler (Broyden), {grid_label}", y=1.02)
    # Render each figure explicitly; this also suppresses the (fig, ax) repr.
    plt.show()

In [None]:
t_min = 1e-6
tf = 1.0
mix = SubVariancePreserving(norms)

num_sample = 5000
num_save = 500

x_init = mix.sample(num_sample)

# Integrate the same initial samples on both time grids so the two figures are
# directly comparable: `{}` keeps BroydenSolver's default grid, the second
# entry passes `linear_ts=True`.
time_grids = [({}, "default time grid"), ({"linear_ts": True}, "linear_ts=True")]
for solver_kwargs, grid_label in time_grids:
    sol = BroydenSolver(mix.ode, mix.ode_with_jac, t_min, tf, num_save, **solver_kwargs)
    t, x, _ = sol(x_init)
    fig = plot_simulation(mix, t, x, show_every=120)[0]
    fig.suptitle(f"sub-VP, implicit Euler (Broyden), {grid_label}", y=1.02)
    # Render each figure explicitly; this also suppresses the (fig, ax) repr.
    plt.show()

### Negative log-likelihood per formulation (VE, VP, sub-VP)

In [None]:
def compute_nll_euler_inf(mix, prior, n_data=5000, t_min=1e-6, tf=1.0, rtol=1e-3, atol=1e-6):
    """Estimate the NLL of ``mix`` in bits per dimension by integrating its ODE.

    Despite the name, this integrates with ``scipy.integrate.solve_ivp``
    (adaptive RK45 by default), not a fixed-step Euler scheme. Samples drawn
    from ``mix`` are transported from ``t_min`` to ``tf`` together with the
    accumulated log-density change returned by ``mix.extended_ode``; the
    transported points are then scored under ``prior``.

    Args:
        mix: object exposing ``dim``, ``sample(n)`` and ``extended_ode(t, x)``
            returning ``(dx, dlogp)`` for an ``(n, dim)`` state ``x``;
            ``dlogp`` may be shaped ``(n,)`` or ``(n, 1)``.
        prior: object exposing ``density(x)`` for an ``(n, dim)`` array,
            returning one density per row.
        n_data: number of samples to transport.
        t_min: start of the integration interval.
        tf: end of the integration interval.
        rtol: relative tolerance forwarded to ``solve_ivp`` (SciPy's default).
        atol: absolute tolerance forwarded to ``solve_ivp`` (SciPy's default).

    Returns:
        ``(x, nll)``: transported samples of shape ``(n_data, dim)`` and the
        scalar ``-(delta_logp + log prior(x)).mean() / ln(2) / dim``.

    Raises:
        RuntimeError: if ``solve_ivp`` stops before reaching ``tf``.
    """
    def flat_extended_ode(t, x_cumdiv_flat):
        # solve_ivp works on a flat vector; each row is [x (dim values), logp].
        # The accumulated log-density column does not feed back into the dynamics.
        x, _ = np.split(x_cumdiv_flat.reshape(-1, mix.dim + 1), [mix.dim], -1)
        dx, dlogp = mix.extended_ode(t, x)
        # Accept both (n,) and (n, 1) log-density rates.
        return np.concatenate([dx, np.reshape(dlogp, (-1, 1))], 1).flatten()

    x_data = mix.sample(n_data)
    delta_logp = np.zeros((n_data, 1))
    x_logp_init = np.concatenate([x_data, delta_logp], axis=1)

    sol = solve_ivp(flat_extended_ode, (t_min, tf), x_logp_init.flatten(), rtol=rtol, atol=atol)
    if not sol.success:
        # Otherwise sol.y[:, -1] is wherever integration stopped, not the state at tf.
        raise RuntimeError(f"solve_ivp failed to reach tf={tf}: {sol.message}")

    x_logp_fin = sol.y[:, -1].reshape(n_data, mix.dim + 1)
    x, delta_logp = np.split(x_logp_fin, [mix.dim], -1)
    prior_fin = np.log(prior.density(x))
    return x, -(delta_logp[:, 0] + prior_fin).mean() / np.log(2.0) / mix.dim

In [None]:
tf = 2.0
mix = VarianceExploding(norms)
# Gaussian prior whose covariance is the noise added by the VE process up to tf.
prior = MultivariateNormal(mix.dim, cov=mix.added_noise_sq(tf))
x, nll = compute_nll_euler_inf(mix, prior, tf=tf)

fig, ax = plt.subplots()
ax.scatter(*x.T, s=1)

# Prior-density contours over a box padded by 1.0 around the transported samples.
x1_cont = np.linspace(x[:, 0].min() - 1.0, x[:, 0].max() + 1.0, 200)
x2_cont = np.linspace(x[:, 1].min() - 1.0, x[:, 1].max() + 1.0, 200)
x_cont = np.stack(np.meshgrid(x1_cont, x2_cont), -1)
ax.contour(x_cont[:, :, 0], x_cont[:, :, 1], prior.density(x_cont), levels=10, alpha=0.5, cmap="plasma")

# compute_nll_euler_inf divides by ln(2) and by dim, so the value is in bits/dim.
ax.set_title(f"VE NLL: {nll:.2f} bits/dim")
plt.show()

In [None]:
tf = 1.0
mix = VariancePreserving(norms)
prior = MultivariateNormal(mix.dim)
# Pass tf explicitly; previously it was assigned but ignored, so editing it
# above had no effect on the integration horizon.
x, nll = compute_nll_euler_inf(mix, prior, tf=tf)

fig, ax = plt.subplots()
ax.scatter(*x.T, s=1)

# Prior-density contours over a box padded by 1.0 around the transported samples.
x1_cont = np.linspace(x[:, 0].min() - 1.0, x[:, 0].max() + 1.0, 200)
x2_cont = np.linspace(x[:, 1].min() - 1.0, x[:, 1].max() + 1.0, 200)
x_cont = np.stack(np.meshgrid(x1_cont, x2_cont), -1)
ax.contour(x_cont[:, :, 0], x_cont[:, :, 1], prior.density(x_cont), levels=10, alpha=0.5, cmap="plasma")

# compute_nll_euler_inf divides by ln(2) and by dim, so the value is in bits/dim.
ax.set_title(f"VP NLL: {nll:.2f} bits/dim")
plt.show()

In [None]:
tf = 1.0
mix = SubVariancePreserving(norms)
prior = MultivariateNormal(mix.dim)
# Pass tf explicitly; previously it was assigned but ignored, so editing it
# above had no effect on the integration horizon.
x, nll = compute_nll_euler_inf(mix, prior, tf=tf)

fig, ax = plt.subplots()
ax.scatter(*x.T, s=1)

# Prior-density contours over a box padded by 1.0 around the transported samples.
x1_cont = np.linspace(x[:, 0].min() - 1.0, x[:, 0].max() + 1.0, 200)
x2_cont = np.linspace(x[:, 1].min() - 1.0, x[:, 1].max() + 1.0, 200)
x_cont = np.stack(np.meshgrid(x1_cont, x2_cont), -1)
ax.contour(x_cont[:, :, 0], x_cont[:, :, 1], prior.density(x_cont), levels=10, alpha=0.5, cmap="plasma")

# compute_nll_euler_inf divides by ln(2) and by dim, so the value is in bits/dim.
ax.set_title(f"sub-VP NLL: {nll:.2f} bits/dim")
plt.show()

### VE NLL (bits/dim) vs. data dimension (adaptive `solve_ivp`)

In [None]:
# Loop-invariant integration horizon, hoisted out of the loop.
tf = 2.0
dims = [1, 2, 4, 8]
for dim in dims:
    norms = cube_vertices(dim)
    mix = VarianceExploding(norms)
    # Gaussian prior whose covariance is the noise added by the VE process up to tf.
    prior = MultivariateNormal(mix.dim, cov=mix.added_noise_sq(tf))
    x, nll = compute_nll_euler_inf(mix, prior, tf=tf)
    # compute_nll_euler_inf divides by ln(2) and by dim, so the value is in bits/dim.
    print(f"dim: {dim}, nll: {nll:.4f} bits/dim")

### VE NLL vs. data dimension with fixed step counts (10, 20, 40, 80)

#### Explicit Euler

In [None]:
# Loop-invariant settings, hoisted out of the loops.
t_min = 1e-6
tf = 2.
num_sample = 5000
steps_list = [10, 20, 40, 80]

dims = [1, 2, 4, 8]
for dim in dims:
    norms = cube_vertices(dim)
    mix = VarianceExploding(norms)
    # NOTE(review): `prior` is not passed to EulerSolver below. Confirm that the
    # solver's third output already includes the prior term; otherwise these
    # numbers are not comparable with the solve_ivp NLLs above.
    prior = MultivariateNormal(mix.dim, cov=mix.added_noise_sq(tf))

    # Draw one sample set per dimension and reuse it for every step count, so
    # differences between rows reflect discretization error rather than Monte
    # Carlo noise. The plotting cells likewise reuse x_init across solver runs.
    x = mix.sample(num_sample)
    for steps in steps_list:
        sol = EulerSolver(mix.ode_with_jac, t_min, tf, steps)
        _, _, nll = sol(x)
        print("dim:", dim, "steps:", steps, "nll:", nll.mean())

    print('-'*5)

#### Implicit Euler (using Broyden method)

In [None]:
# Loop-invariant settings, hoisted out of the loops.
t_min = 1e-6
tf = 2.
num_sample = 5000
steps_list = [10, 20, 40, 80]

dims = [1, 2, 4, 8]
for dim in dims:
    norms = cube_vertices(dim)
    mix = VarianceExploding(norms)
    # NOTE(review): `prior` is not passed to BroydenSolver below. Confirm that
    # the solver's third output already includes the prior term; otherwise
    # these numbers are not comparable with the solve_ivp NLLs above.
    prior = MultivariateNormal(mix.dim, cov=mix.added_noise_sq(tf))

    # Draw one sample set per dimension and reuse it for every step count, so
    # differences between rows reflect discretization error rather than Monte
    # Carlo noise. The plotting cells likewise reuse x_init across solver runs.
    x = mix.sample(num_sample)
    for steps in steps_list:
        sol = BroydenSolver(mix.ode, mix.ode_with_jac, t_min, tf, steps)
        _, _, nll = sol(x)
        print("dim:", dim, "steps:", steps, "nll:", nll.mean())

    print('-'*5)