In [None]:
import numpy as np

import matplotlib as mpl
import matplotlib.pyplot as plt

from scipy.integrate import solve_ivp

DEFAULT_SEED = 42

from gaussian import MultivariateNormal, DynamicMultivariateNormal, VarianceExploding, VariancePreserving, SubVariancePreserving

In [None]:
def cube_vertices(dim, side_len=1.0, var=1e-2):
    """Build one isotropic Gaussian component per vertex of a ``dim``-dimensional cube.

    Every coordinate of a vertex is either ``-side_len`` or ``+side_len`` (so the
    cube is centred at the origin and its edges have length ``2 * side_len``).

    Parameters
    ----------
    dim : number of spatial dimensions; ``2**dim`` components are produced.
    side_len : absolute value of every vertex coordinate.
    var : scalar variance; each component's covariance is ``var * np.eye(dim)``.

    Returns
    -------
    list of ``DynamicMultivariateNormal``, one per vertex, in ``np.meshgrid`` order.
    """
    axis_values = np.array([-side_len, side_len])
    # The same two candidate values along every axis; meshgrid enumerates all combinations.
    grids = np.meshgrid(*[axis_values for _ in range(dim)])
    corners = np.stack(grids, axis=-1).reshape(-1, dim)
    cov = var * np.eye(dim)

    components = []
    for corner in corners:
        components.append(DynamicMultivariateNormal(dim, corner, cov))
    return components

In [None]:
def simulate_sde(
    mix, tf=1.0, num_sample=2000, nt=50_000, num_save=200, seed=DEFAULT_SEED
):
    """Integrate the forward SDE of ``mix`` with Euler-Maruyama and save snapshots.

    The time grid is ``tf * xi**7`` for ``xi`` uniform on ``[0, 1]``, so the steps
    are much finer near ``t = 0`` than near ``t = tf``.

    Parameters
    ----------
    mix : object providing ``sample(n, seed=...)``, ``f(t, x)`` and ``g(t)``.
    tf : final time.
    num_sample : number of trajectories.
    nt : number of integration steps.
    num_save : requested number of saved steps (a snapshot is kept every
        ``nt // num_save`` steps, plus the initial state).
    seed : seed passed to ``mix.sample`` and used for the Brownian increments.

    Returns
    -------
    t_save : saved times, ``t[::save_every]``.
    x : array indexed as ``x[sample, saved_index, coordinate]``; ``x[:, k]`` is the
        state at ``t_save[k]``, with ``x[:, 0]`` the initial sample.

    NOTE(review): the original comment called this scheme exact when ``mix.f == 0``;
    the step below uses ``g(t_i) * sqrt(dt_i)``, which is exact only if ``g`` is
    constant over each step - confirm against ``mix.g``.
    """
    x_init = mix.sample(num_sample, seed=seed)
    rho = 7
    # np.power is available on every NumPy version (np.pow only exists on NumPy >= 2.0).
    t = tf * np.power(np.linspace(0.0, 1.0, nt + 1), rho)

    # Guard against num_save > nt, which would give a zero stride.
    save_every = max(1, nt // num_save)
    t_save = t[::save_every]
    # Size the buffer from the saved grid so it always matches t_save, even when
    # nt is not a multiple of num_save.
    x = x_init[:, None, :].repeat(len(t_save), 1)

    rng = np.random.default_rng(seed)

    xi = x_init
    for i, (ti, dti) in enumerate(zip(t, np.diff(t))):
        dwi = np.sqrt(dti) * rng.normal(size=xi.shape)
        xi = xi + dti * mix.f(ti, xi) + mix.g(ti) * dwi
        # After this update xi is the state at t[i + 1]; store it in the slot whose
        # saved time is t[i + 1]. Slot 0 keeps the initial sample (time t[0]).
        step = i + 1
        if step % save_every == 0:
            x[:, step // save_every, :] = xi

    return t_save, x


def simulate_ode(
    mix, tf=1.0, num_sample=2_000, num_save=200, seed=DEFAULT_SEED
):
    """Integrate ``mix.ode`` forward from samples of ``mix`` using ``solve_ivp``.

    Parameters
    ----------
    mix : object providing ``sample(n, seed=...)``, ``ode(t, x)`` and ``dim``.
    tf : final time.
    num_sample : number of trajectories.
    num_save : number of saved steps; the solution is evaluated on the
        ``num_save + 1`` times ``tf * xi**7`` with ``xi`` uniform on ``[0, 1]``.
    seed : seed passed to ``mix.sample``.

    Returns
    -------
    t : evaluation times reported by the solver.
    x : array indexed as ``x[sample, saved_index, coordinate]``.

    Raises
    ------
    RuntimeError
        If ``solve_ivp`` reports that the integration did not succeed.
    """

    def flat_ode(t, x_flat):
        # solve_ivp works on 1-D state vectors; restore (sample, coordinate) for mix.ode.
        x = x_flat.reshape(-1, mix.dim)
        return mix.ode(t, x).flatten()

    x_init = mix.sample(num_sample, seed=seed)
    rho = 7
    # np.power is available on every NumPy version (np.pow only exists on NumPy >= 2.0).
    t = tf * np.power(np.linspace(0.0, 1.0, num_save + 1), rho)

    solve_params = dict(
        rtol=1e-10, atol=1e-10, t_eval=t
    )
    sol = solve_ivp(flat_ode, (0.0, tf), x_init.flatten(), **solve_params)
    # A failed solve returns fewer time points than requested; report the solver's
    # message instead of letting the reshape below fail with a shape error.
    if not sol.success:
        raise RuntimeError(f"solve_ivp failed: {sol.message}")

    # sol.y is (num_sample * dim, num_times); unflatten, then put time on axis 1.
    x = sol.y.reshape(num_sample, mix.dim, num_save + 1).transpose(0, 2, 1)
    return sol.t, x

## Square (2D)

In [None]:
def plot2d(mix, t, x, show_every=50):
    """Plot samples, density contours and the score field at selected saved times.

    Parameters
    ----------
    mix : object providing ``density(t, x)`` and ``score_with_div(t, x)``.
    t : sequence of saved times.
    x : array indexed as ``x[sample, time_index, coordinate]``; the first two
        coordinates are plotted.
    show_every : stride between the displayed time indices.

    Returns
    -------
    (fig, ax) exactly as returned by ``plt.subplots(1, num_plots, ...)``.
    """
    num_plots = (len(t) - 1) // show_every + 1
    fig, ax = plt.subplots(1, num_plots, figsize=(15, 3))
    # With a single panel plt.subplots returns a bare Axes rather than an array,
    # so indexing it would fail; iterate over a 1-D view instead.
    axes = np.atleast_1d(ax)

    for i, axi in enumerate(axes):
        si = i * show_every
        ti = t[si]
        xi = x[:, si, :]
        axi.scatter(*xi.T, s=1)
        # Label each panel with its time, as plot3d does.
        axi.set_title(f"t = {ti:.2f}")

        # Plotting window: bounding box of the samples padded by 15% on each side.
        x_min, x_max = xi.min(0), xi.max(0)
        x_range = x_max - x_min
        x_min, x_max = x_min - 0.15 * x_range, x_max + 0.15 * x_range

        # 200 x 200 grid on which the density is evaluated.
        x_cont_flat = np.linspace(x_min, x_max, 200)
        x_cont = np.stack(np.meshgrid(*x_cont_flat.T), -1)
        dens = mix.density(ti, x_cont)

        x1_cont, x2_cont = x_cont[:, :, 0], x_cont[:, :, 1]
        axi.contour(x1_cont, x2_cont, dens, levels=10, alpha=0.5, cmap="plasma")

        # Coarser grid (every 10th grid coordinate) for the score arrows.
        x_quiv_flat = x_cont_flat[5::10]
        x_quiv = np.stack(np.meshgrid(*x_quiv_flat.T), -1)
        score, div_score = mix.score_with_div(ti, x_quiv)

        x1_quiv, x2_quiv = x_quiv[:, :, 0], x_quiv[:, :, 1]
        score1, score2 = score[:, :, 0], score[:, :, 1]
        # Arrows follow the score; their colour is taken from div_score.
        axi.quiver(x1_quiv, x2_quiv, score1, score2, div_score, alpha=1.0)
    return fig, ax

### Variance exploding

In [None]:
# Variance-exploding mixture built on the four vertices returned by cube_vertices(2).
mix2d = VarianceExploding(cube_vertices(2))
tf = 20.0  # final time; the trailing "np.sqrt(4.0)" looks like a leftover earlier value - TODO confirm
# Zero-argument mean, isotropic covariance equal to the noise added by time tf.
# NOTE(review): prior2d is not used by any later cell visible in this notebook.
prior2d = MultivariateNormal(mix2d.dim, cov=mix2d.added_noise_sq(tf) * np.eye(mix2d.dim))
# Forward trajectories from the stochastic (SDE) and deterministic (ODE) simulators.
t_sde, x_sde = simulate_sde(mix2d, tf=tf)
t_ode, x_ode = simulate_ode(mix2d, tf=tf)

#### Noising

In [None]:
# SDE samples with density contours and score arrows at every 50th saved time.
fig, ax = plot2d(mix2d, t_sde, x_sde)

#### Pseudo-noising (with the ODE)

In [None]:
# Same panels for the ODE trajectories.
fig, ax = plot2d(mix2d, t_ode, x_ode)

#### Comparing the SDE and the ODE

In [None]:
def nonuniform_time(tf, nt):
    """Return ``nt + 1`` increasing times: 0 followed by ``tf / 2**k`` for k = nt-1, ..., 0.

    Each time after the first non-zero one doubles the previous one, ending at ``tf``.
    """
    fractions = [0.0]
    for exponent in range(nt - 1, -1, -1):
        fractions.append(1.0 / (2**exponent))
    return tf * np.array(fractions)


def simulate_ve_sde(
    mix: VarianceExploding, nt, tf, num_sample=1000, seed=DEFAULT_SEED
):
    """Sample the variance-exploding forward process on the ``nonuniform_time`` grid.

    Each step adds Gaussian noise whose variance is the increase of
    ``mix.added_noise_sq`` over that step; no drift term is applied.

    Parameters
    ----------
    mix : mixture providing ``sample``, ``added_noise_sq`` and ``dim``.
    nt : number of steps (``nt + 1`` time points).
    tf : final time.
    num_sample : number of trajectories.
    seed : seed passed to ``mix.sample`` and used for the noise draws.

    Returns
    -------
    times : the ``nt + 1`` time points.
    paths : array indexed as ``paths[time_index, sample, coordinate]``.
    """
    times = nonuniform_time(tf, nt)
    steps = np.diff(times)
    start = mix.sample(num_sample, seed=seed)
    # Pre-fill every time slot with the initial sample; slots 1..nt are overwritten below.
    paths = np.repeat(start[None, :, :], nt + 1, axis=0)

    noise = np.random.default_rng(seed).normal(size=(nt, num_sample, mix.dim))

    for k in range(nt):
        t_k, dt_k = times[k], steps[k]
        # Variance gained between t_k and t_k + dt_k.
        var_gain = mix.added_noise_sq(t_k + dt_k) - mix.added_noise_sq(t_k)
        paths[k + 1] = paths[k] + np.sqrt(var_gain) * noise[k]

    return times, paths


def simulate_reverse_ve_ode(
    mix: VarianceExploding, nt, tf, num_sample=1000, seed=DEFAULT_SEED
):
    """Integrate ``mix.ode`` backwards in time, from ``tf`` down to 0.

    Initial points are drawn from a Gaussian with covariance
    ``mix.added_noise_sq(tf) * I``. The solver runs in the reversed variable
    ``s = tf - t`` and is evaluated at ``s = tf - nonuniform_time(tf, nt)[::-1]``.

    Parameters
    ----------
    mix : mixture providing ``ode``, ``added_noise_sq`` and ``dim``.
    nt : number of intervals (``nt + 1`` evaluation times).
    tf : time at which the backward integration starts.
    num_sample : number of trajectories.
    seed : seed passed to the prior's ``sample``.

    Returns
    -------
    t : forward times of the saved states, running from ``tf`` down to 0.
    x : array indexed as ``x[time_index, sample, coordinate]`` in that same order.

    Raises
    ------
    RuntimeError
        If ``solve_ivp`` reports that the integration did not succeed.
    """

    def flat_ode(t, x_flat):
        # Time reversal: d x / d s = -ode(tf - s, x), on the flattened state.
        x = x_flat.reshape(-1, mix.dim)
        return -mix.ode(tf - t, x).flatten()

    t = tf - nonuniform_time(tf, nt)[::-1]

    prior = MultivariateNormal(mix.dim, cov=mix.added_noise_sq(tf) * np.eye(mix.dim))
    x_init = prior.sample(num_sample, seed=seed)

    solve_params = dict(rtol=1e-10, atol=1e-10, t_eval=t)
    sol = solve_ivp(flat_ode, (0.0, tf), x_init.flatten(), **solve_params)
    # A failed solve returns fewer time points than requested; report the solver's
    # message instead of letting the reshape below fail with a shape error.
    if not sol.success:
        raise RuntimeError(f"solve_ivp failed: {sol.message}")

    x = sol.y.T.reshape(nt + 1, num_sample, mix.dim)
    return tf - sol.t, x

In [None]:
# Keep the component list: its means are used further down to colour samples.
rvs = cube_vertices(2)
mix2d = VarianceExploding(rvs)
nt, tf = 5, 2.0
t_sde, x_sde = simulate_ve_sde(mix2d, nt, tf)
t_ode, x_ode = simulate_reverse_ve_ode(mix2d, nt, tf)
# The reverse ODE output runs from time tf down to 0; flip it so that index 0 is
# time 0, the same ordering as the SDE output.
t_ode, x_ode = t_ode[::-1], x_ode[::-1]

In [None]:
def plot_x(ax, x_samples, cols):
    """Scatter one frame of 2-D samples on ``ax`` with symmetric limits and no ticks.

    Parameters
    ----------
    ax : axes to draw on.
    x_samples : array of shape ``(num_points, 2)``.
    cols : per-point colours, forwarded to ``ax.scatter`` as ``c``.

    Returns
    -------
    (fig, ax) where ``fig`` is the figure that owns ``ax``.
    """
    ax.scatter(*x_samples.T, s=2, c=cols)

    # Square window centred on the origin, 5% wider than the largest |coordinate|.
    x_max = np.abs(x_samples).max()
    ax.set_xlim(-x_max - 0.05 * x_max, x_max + 0.05 * x_max)
    ax.set_ylim(-x_max - 0.05 * x_max, x_max + 0.05 * x_max)

    x_ticks = dict(bottom=False, top=False, labelbottom=False)
    y_ticks = dict(left=False, right=False, labelleft=False)
    ax.tick_params(**x_ticks, **y_ticks)
    ax.set_aspect("equal")

    # Return the figure that owns this axes instead of whatever notebook-global
    # `fig` happens to exist when the function is called.
    return ax.figure, ax

# One RGBA colour per mixture component, picked from the "tab20c" colormap.
c0 = mpl.colormaps["tab20c"](0)
c1 = mpl.colormaps["tab20c"](8)
c2 = mpl.colormaps["tab20c"](5)
c3 = mpl.colormaps["tab20c"](16)
c = np.array([c0, c1, c2, c3])

# One row per component, in the order of `rvs`.
# NOTE(review): assumes rv.mean.mean is the component's mean vector - confirm against gaussian.py.
means = np.array([rv.mean.mean for rv in rvs])

In [None]:
# Colour each sample by the component mean closest (squared Euclidean distance)
# to its position in the first frame; the colour is kept across all frames.
c_sde = c[np.square(x_sde[0, :, None, :] - means).sum(-1).argmin(-1)]

# One panel per time point.
fig, ax = plt.subplots(1, len(t_sde), figsize=(3 * 6, 3))

for i, xi in enumerate(x_sde):
    plot_x(ax[i], xi, c_sde)

fig.tight_layout(pad=0.5)
# NOTE(review): assumes the "img/" directory already exists - confirm, or create it before running.
fig.savefig("img/song_sde.pdf")

In [None]:
# Colour each sample by the component mean closest (squared Euclidean distance)
# to its position in frame 0, which is time 0 after the flip applied to x_ode earlier.
c_ode = c[np.square(x_ode[0, :, None, :] - means).sum(-1).argmin(-1)]

# One panel per time point.
fig, ax = plt.subplots(1, len(t_ode), figsize=(3 * 6, 3))

for i, xi in enumerate(x_ode):
    plot_x(ax[i], xi, c_ode)

fig.tight_layout(pad=0.5)
# NOTE(review): assumes the "img/" directory already exists - confirm, or create it before running.
fig.savefig("img/song_ode.pdf")

### Variance preserving

In [None]:
# Variance-preserving mixture on the same four vertices; the simulators run with
# their default final time (tf=1.0).
mix2d = VariancePreserving(cube_vertices(2))
# NOTE(review): prior2d is not used by any later cell visible in this notebook.
prior2d = MultivariateNormal(mix2d.dim)
t_sde, x_sde = simulate_sde(mix2d)
t_ode, x_ode = simulate_ode(mix2d)

In [None]:
# SDE samples with density contours and score arrows at every 50th saved time.
fig, ax = plot2d(mix2d, t_sde, x_sde)

In [None]:
# Same panels for the ODE trajectories.
fig, ax = plot2d(mix2d, t_ode, x_ode)

### Compare clustering (VP-case)

In [None]:
def plot2d_clusters(t, x, show_every=50, side_len=1.0):
    """Scatter trajectories at selected times, coloured by their starting quadrant.

    Parameters
    ----------
    t : sequence of saved times (only its length is used).
    x : array indexed as ``x[sample, time_index, coordinate]`` with two coordinates.
    show_every : stride between the displayed time indices.
    side_len : scale of the thresholds that split the initial points into four groups.

    Returns
    -------
    (fig, ax) exactly as returned by ``plt.subplots(1, num_plots, ...)``.
    """
    num_plots = (len(t) - 1) // show_every + 1
    fig, ax = plt.subplots(1, num_plots, figsize=(15, 3))
    # With a single panel plt.subplots returns a bare Axes rather than an array,
    # so indexing it would fail; iterate over a 1-D view instead.
    axes = np.atleast_1d(ax)

    # Group membership is decided once, from the positions at time index 0.
    # NOTE(review): the split is at 0.5 * side_len, while cube_vertices places the
    # components at +/- side_len (centred on 0) - confirm the intended threshold.
    right = x[:, 0, 0] > 0.5 * side_len
    upper = x[:, 0, 1] > 0.5 * side_len
    cmap = mpl.color_sequences["tab10"]
    # The four masks below are exhaustive, so every row of `colors` gets assigned.
    colors = np.empty((len(x), 3))
    colors[right & upper] = cmap[0]
    colors[right & ~upper] = cmap[1]
    colors[~right & upper] = cmap[2]
    colors[~right & ~upper] = cmap[3]

    for i, axi in enumerate(axes):
        si = i * show_every
        xi = x[:, si, :]
        axi.scatter(*xi.T, s=1, c=colors)

    return fig, ax

In [None]:
# SDE trajectories coloured by the group each sample starts in.
fig, ax = plot2d_clusters(t_sde, x_sde)

In [None]:
# ODE trajectories coloured by the group each sample starts in.
fig, ax = plot2d_clusters(t_ode, x_ode)

## Cube (3D)

In [None]:
def plot3d(mix, t, x, show_indices=None):
    """3-D scatter of the samples at selected saved times, coloured by ``div_score``.

    Parameters
    ----------
    mix : object providing ``score_with_div(t, x)``.
    t : sequence of saved times.
    x : array indexed as ``x[sample, time_index, coordinate]`` with three coordinates.
    show_indices : time indices to display; by default roughly six evenly spaced
        indices from 0 to the last one.

    Returns
    -------
    The figure holding one 3-D subplot per displayed index.
    """
    fig = plt.figure(figsize=(15, 3))

    if show_indices is None:
        # Stay inside 0 .. len(t) - 1 and never use a zero step (short t).
        step = max(1, (len(t) - 1) // 5)
        show_indices = range(0, len(t), step)
    for i, si in enumerate(show_indices):
        ax = fig.add_subplot(1, len(show_indices), i + 1, projection="3d")
        ti = t[si]
        xi = x[:, si, :]

        # Second output of score_with_div, with its trailing axis dropped.
        div_score = mix.score_with_div(ti, xi)[1][..., 0]
        # Rescale to [0, 1] per panel before mapping through the "plasma" colormap.
        norm_col = mpl.colors.Normalize(div_score.min(), div_score.max())
        colors = mpl.colormaps["plasma"](norm_col(div_score))

        ax.scatter(*xi.T, s=1, c=colors, alpha=0.5)
        ax.set_title(f"t = {ti:.2f}")

    return fig

### Variance exploding

In [None]:
# Variance-exploding mixture on the eight vertices returned by cube_vertices(3).
mix3d = VarianceExploding(cube_vertices(3))
tf = 40.0
# Indices into the saved time grid to display; -1 selects the last saved time.
show_indices = [0, 2, 5, 10, -1]
t_sde, x_sde = simulate_sde(mix3d, tf=tf)
t_ode, x_ode = simulate_ode(mix3d, tf=tf)

In [None]:
# SDE samples at the chosen saved indices; the trailing ";" hides the returned figure's repr.
plot3d(mix3d, t_sde, x_sde, show_indices=show_indices);

In [None]:
# ODE samples at the same saved indices; the trailing ";" hides the returned figure's repr.
plot3d(mix3d, t_ode, x_ode, show_indices=show_indices);

### Variance preserving

In [None]:
# Variance-preserving mixture on the same eight vertices, simulated with the default tf=1.0.
mix3d = VariancePreserving(cube_vertices(3))
# NOTE(review): show_indices is assigned here but the plot3d calls in this section
# do not pass it, so they use plot3d's default indices - confirm which is intended.
show_indices = [0, 2, 5, 10, -1]
t_sde, x_sde = simulate_sde(mix3d)
t_ode, x_ode = simulate_ode(mix3d)

In [None]:
# SDE samples at plot3d's default indices; the trailing ";" hides the returned figure's repr.
plot3d(mix3d, t_sde, x_sde);

In [None]:
# ODE samples at plot3d's default indices; the trailing ";" hides the returned figure's repr.
plot3d(mix3d, t_ode, x_ode);