In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt

DEFAULT_SEED = 42

from gaussian import MultivariateNormal, VarianceExploding, VariancePreserving, SubVariancePreserving
from numerical import EulerSolver
from utils import cube_vertices, solve_numerical_scheme, solve_flow

In [None]:
# Component means of the 2-D Gaussian mixture (presumably the vertices of the
# square [-1, 1]^2 or similar; TODO confirm against utils.cube_vertices).
norms = cube_vertices(dim=2)

# Formulation name -> (forward-process object, final time `tf` handed to the solvers).
# VE integrates up to tf=2.0; VP and sub-VP up to tf=1.0.
mix_dict = {
    name: (process_cls(norms), final_time)
    for name, process_cls, final_time in [
        ('VE', VarianceExploding, 2.0),
        ('VP', VariancePreserving, 1.0),
        ('sub-VP', SubVariancePreserving, 1.0),
    ]
}

# Numbers of solver steps to sweep: 10 doubled five times -> [10, 20, 40, 80, 160, 320].
timesteps_list = [10 * 2 ** power for power in range(6)]

### Reference solution: `solve_ivp` per formulation (VE, VP, sub-VP)

In [None]:
# Reference samples for each formulation, produced by solve_flow (the section
# header says this is the solve_ivp-based solver). These are later scattered as
# 2-D points and compared pointwise against the explicit-Euler final samples.
solve_ivp_dict = {}
for name, (mixture, t_final) in mix_dict.items():
    # Gaussian prior at the final time; covariance comes from the mixture's
    # added_noise_sq evaluated at t_final.
    prior_at_tf = MultivariateNormal(mixture.dim, cov=mixture.added_noise_sq(t_final))
    # solve_flow returns a pair; only the first element (the samples) is kept.
    samples, _ = solve_flow(mixture, prior_at_tf, tf=t_final)
    solve_ivp_dict[name] = samples

### Explicit Euler per formulation (VE, VP, sub-VP) and per number of timesteps (10, 20, 40, 80, 160, 320), with both the EDM and the linear timestep schedule

In [None]:
# Explicit-Euler samples for every (formulation, number of timesteps) pair.
#
# Layout: explicit_euler_dict[formulation][n_timesteps] is a 2-tuple indexed by
# int(linear_ts):
#   [0] -> samples with linear_ts=False (EDM schedule)
#   [1] -> samples with linear_ts=True  (linear schedule)
# make_plots and the MSE cell below rely on this ordering.
N_SAMPLES = 5000  # number of samples drawn per Euler run
T_MIN = 1e-6      # smallest time passed to the solver as t_min

explicit_euler_dict = {}
for k, (mix, tf) in mix_dict.items():
    print(k)
    cur_mix_dict = {}
    for n_ts in timesteps_list:
        print(f'  n_timesteps={n_ts}')
        final_samples = []
        # Order (False, True) fixes the tuple layout documented above.
        for linear_ts in (False, True):
            _, x, _ = solve_numerical_scheme(
                solver=EulerSolver,
                mix=mix,
                n_samples=N_SAMPLES,
                t_min=T_MIN,
                tf=tf,
                n_timesteps=n_ts,
                linear_ts=linear_ts,
            )
            # Keep only the last entry along axis 1 (presumably the final time
            # step of an (n_samples, n_steps, dim) trajectory; TODO confirm).
            final_samples.append(x[:, -1, :])
        cur_mix_dict[n_ts] = tuple(final_samples)
    explicit_euler_dict[k] = cur_mix_dict

    print('-'*5)

### Plots: final samples from `solve_ivp` vs. explicit Euler

In [None]:
def make_plots(linear_ts, formulations=None, timesteps=None, reference=None, euler=None, dpi=300):
    """Scatter the solve_ivp reference samples next to the explicit-Euler samples.

    Rows are formulations and columns are ``['solve_ivp', *timesteps]``. The
    first column shows the reference samples; each other column shows the Euler
    samples for that number of timesteps.

    Parameters
    ----------
    linear_ts : bool
        Which Euler schedule to show: False -> EDM (tuple index 0),
        True -> linear (tuple index 1).
    formulations : iterable of str, optional
        Row keys. Defaults to the keys of the notebook-level ``mix_dict``.
    timesteps : sequence of int, optional
        Column keys. Defaults to the notebook-level ``timesteps_list``.
    reference : dict, optional
        formulation -> (N, 2) samples. Defaults to ``solve_ivp_dict``.
    euler : dict, optional
        formulation -> n_timesteps -> (edm_samples, linear_samples).
        Defaults to ``explicit_euler_dict``.
    dpi : int, optional
        Figure resolution (default 300, as before).

    Returns
    -------
    (fig, axes)
        The matplotlib figure and a 2-D array of axes.

    Raises
    ------
    ValueError
        If there are no formulations to plot.
    """
    # Fall back to notebook globals so existing calls keep working; explicit
    # arguments avoid the hidden-state dependency.
    formulations = list(mix_dict) if formulations is None else list(formulations)
    timesteps = timesteps_list if timesteps is None else list(timesteps)
    reference = solve_ivp_dict if reference is None else reference
    euler = explicit_euler_dict if euler is None else euler

    if not formulations:
        raise ValueError('make_plots needs at least one formulation to plot')

    # One row per formulation (previously hard-coded to 3 rows).
    n_rows, n_cols = len(formulations), len(timesteps) + 1
    # squeeze=False keeps `axes` 2-D even for a single row, so axes[i, j] always works.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows / 3), dpi=dpi, squeeze=False)
    fig.suptitle(f'linear_ts={linear_ts}')

    for ax, col in zip(axes[0], ['solve_ivp', *timesteps]):
        ax.set_title(col)

    schedule_idx = int(linear_ts)  # 0 = EDM schedule, 1 = linear schedule
    for i, k in enumerate(formulations):
        axes[i, 0].set_ylabel(k, rotation=90, size='large')
        axes[i, 0].scatter(*reference[k].T, s=1)
        for j, n_ts in enumerate(timesteps):
            axes[i, j + 1].scatter(*euler[k][n_ts][schedule_idx].T, s=1)

    return fig, axes

In [None]:
# EDM timestep schedule (linear_ts=False); trailing ';' suppresses the (fig, axes) repr.
make_plots(False);

In [None]:
# Linear timestep schedule (linear_ts=True); trailing ';' suppresses the (fig, axes) repr.
make_plots(True);

### MSE between `solve_ivp` and explicit Euler, per formulation (VE, VP, sub-VP), per number of timesteps (10, 20, 40, 80, 160, 320), and per timestep schedule (EDM, linear)

In [None]:
# Pointwise MSE between the solve_ivp reference samples and the Euler samples.
# NOTE(review): a pointwise MSE is only meaningful if solve_flow and
# solve_numerical_scheme start from the *same* initial noise samples (same count,
# same RNG draws). DEFAULT_SEED is defined in the imports cell but never passed
# anywhere, so confirm how these helpers seed their sampling.
for k in mix_dict:
    print(k)
    reference = solve_ivp_dict[k]
    for n_ts in timesteps_list:
        # Tuple layout from the Euler cell: (EDM schedule, linear schedule).
        x_edm, x_linear = explicit_euler_dict[k][n_ts]
        for approx in (x_edm, x_linear):
            # Fail loudly instead of letting numpy broadcasting average over
            # mismatched sample sets.
            assert tuple(approx.shape) == tuple(reference.shape), (
                f"{k}, n_timesteps={n_ts}: Euler samples {tuple(approx.shape)} "
                f"vs solve_ivp samples {tuple(reference.shape)}"
            )
        mse_edm = ((reference - x_edm) ** 2).mean()
        mse_linear = ((reference - x_linear) ** 2).mean()
        print(f"timesteps: {n_ts}, MSE EDM: {mse_edm}, MSE linear: {mse_linear}")
    print('-'*5)