In [None]:
import os
# Cap the thread pools of the numerical backends. These variables are set before
# the numerical libraries are imported in the next cell.
num_threads = "16"
for thread_env_var in (
    "OMP_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "MKL_NUM_THREADS",
    "VECLIB_MAXIMUM_THREADS",
    "NUMEXPR_NUM_THREADS",
):
    os.environ[thread_env_var] = num_threads

In [None]:
import matplotlib.pyplot as plt
import torch
from torch.autograd import grad, Variable
import autograd
import autograd.numpy as np
import copy
import scipy as sp
from scipy import stats
from sklearn import metrics
import sys
import ot
import gwot
from gwot import models, sim, ts, util
import gwot.bridgesampling as bs
import dcor
from tqdm import tqdm

sys.path.append("..")
import importlib
import models
import random
import model_fig1 as model_sim
import mmd

In [None]:
# Base size of one plot cell; the figure sizes used further down are multiples of it
# (passed to matplotlib's figsize, which is measured in inches).
PLT_CELL = 2.5

In [None]:
import glob

# Collect the saved result files from the current working directory. The two
# patterns are disjoint: one matches names starting with out_N_, the other
# names starting with out_gwot_N_.
# glob returns matches in arbitrary, filesystem-dependent order, so the lists are
# sorted to keep every per-file array derived from them (and any [0] inspection)
# reproducible from run to run.
fnames_all = sorted(glob.glob("out_N_*.npy"))
fnames_all_gwot = sorted(glob.glob("out_gwot_N_*.npy"))

In [None]:
# Parse the integer stored at "_"-split position 4 of each out_N_* filename and at
# position 5 of each out_gwot_N_* filename (those names have one extra leading token).
# Per the variable names this is the random seed of the run.
srand_tokens = [path.split("_")[4] for path in fnames_all]
srand_tokens_gwot = [path.split("_")[5] for path in fnames_all_gwot]
srand_all = np.array([int(token) for token in srand_tokens])
srand_all_gwot = np.array([int(token) for token in srand_tokens_gwot])

In [None]:
# The last "_"-separated token of each filename is "<value>.npy"; strip the
# extension and parse the value as a float (the lamda parameter, per the variable names).
lamda_tokens = [path.split("_")[6] for path in fnames_all]
lamda_tokens_gwot = [path.split("_")[7] for path in fnames_all_gwot]
lamda_all = np.array([float(token.split(".npy")[0]) for token in lamda_tokens])
lamda_all_gwot = np.array([float(token.split(".npy")[0]) for token in lamda_tokens_gwot])

In [None]:
# Parse the integer stored at "_"-split position 2 of each out_N_* filename and at
# position 3 of each out_gwot_N_* filename (the N parameter, per the variable names).
N_tokens = [path.split("_")[2] for path in fnames_all]
N_tokens_gwot = [path.split("_")[3] for path in fnames_all_gwot]
N_all = np.array([int(token) for token in N_tokens])
N_all_gwot = np.array([int(token) for token in N_tokens_gwot])

In [None]:
# Each .npy file holds a single pickled object (presumably a dict) that is unwrapped
# with .item(0); only one entry of it is kept per file.
# NOTE: allow_pickle = True unpickles arbitrary Python objects, so only load files
# that were produced locally by trusted code.
x_all = []
for path in fnames_all:
    saved = np.load(path, allow_pickle = True).item(0)
    x_all.append(saved["x"])

x_gwot_all = []
for path in fnames_all_gwot:
    saved = np.load(path, allow_pickle = True).item(0)
    x_gwot_all.append(saved["samples_gwot"])

In [None]:
# setup simulation object
# Build the simulation object from the settings exposed by model_fig1 (imported as model_sim).
# NOTE(review): this rebinds the name `sim`, shadowing the `sim` module imported from gwot above.
sim = gwot.sim.Simulation(
    V = model_sim.Psi,
    dV = model_sim.dPsi,
    birth_death = False,
    N = None,
    T = model_sim.T,
    d = model_sim.dim,
    D = model_sim.D,
    t_final = model_sim.t_final,
    ic_func = model_sim.ic_func,
    pool = None,
)

# Reference copy used as ground truth: N is set to 1000 for each of the model_sim.T
# entries (presumably the number of particles per time point), then sampled.
gt_counts = np.array([1_000] * model_sim.T)
gt_steps_scale = int(model_sim.sim_steps / sim.T)

sim_gt = copy.deepcopy(sim)
sim_gt.N = gt_counts
sim_gt.sample(steps_scale = gt_steps_scale);

In [None]:
# Scatter the first coordinate of every ground-truth sample against the time value
# looked up from its time index.
time_grid = np.linspace(0, model_sim.t_final, model_sim.T)
plt.scatter(time_grid[sim_gt.t_idx], sim_gt.x[:, 0], alpha = 0.01, color = "blue")

In [None]:
# Inspect the shape of the first loaded sample array; its first axis is the one the
# distance computation below iterates over and matches against sim_gt.t_idx.
x_all[0].shape

In [None]:
# Energy distance between the ground-truth samples at time index i and the
# reconstructed samples at the same index: one row per result file, one column per index.
with torch.no_grad():
    langevin_rows = []
    for j in tqdm(range(len(x_all)), position = 0, leave = True):
        n_timepoints = x_all[j].shape[0]
        langevin_rows.append([dcor.energy_distance(sim_gt.x[sim_gt.t_idx == i, :], x_all[j][i, :]) for i in range(n_timepoints)])
    d_reconstruct = np.array(langevin_rows)

# Same computation for the gWOT samples (performed outside the no_grad context).
gwot_rows = []
for j in tqdm(range(len(x_gwot_all)), position = 0, leave = True):
    n_timepoints = x_gwot_all[j].shape[0]
    gwot_rows.append([dcor.energy_distance(sim_gt.x[sim_gt.t_idx == i, :], x_gwot_all[j][i, :]) for i in range(n_timepoints)])
d_gwot = np.array(gwot_rows)

In [None]:
# Per-time-index energy distance, averaged over all files matching one fixed
# (N, lamda) setting per method.
gwot_rows_sel = (N_all_gwot == 1) & (lamda_all_gwot == 0.005)
langevin_rows_sel = (N_all == 1) & (lamda_all == 0.05)

plt.plot(d_gwot[gwot_rows_sel, :].mean(0), 'o-', label = "gWOT")
plt.plot(d_reconstruct[langevin_rows_sel, :].mean(0), 'o-', label = "Langevin")
plt.legend()

In [None]:
# Show one matched filename, e.g. to check the "_"-split positions used when parsing
# N, the seed and lamda out of the names.
fnames_all[0]

In [None]:
# Sorted unique values of each parsed parameter; these define the axes of the
# (N, lamda, seed, time) distance tensors built from them.
# Only the values are needed, so np.unique is called without return_index. This also
# avoids assigning to `_`, which IPython uses for the last cell output.
N_vals = np.unique(N_all)
N_vals_gwot = np.unique(N_all_gwot)
lamda_vals = np.unique(lamda_all)
lamda_vals_gwot = np.unique(lamda_all_gwot)
srand_vals = np.unique(srand_all)
srand_vals_gwot = np.unique(srand_all_gwot)

In [None]:
# Rearrange the per-file distance rows into dense tensors with axes
# (N, lamda, seed, time index). Combinations with no result file stay NaN.
langevin_shape = (len(N_vals), len(lamda_vals), len(srand_vals), sim_gt.T)
d_reconstruct_tensor = np.full(langevin_shape, float("NaN"))
for (n_val, lamda_val, srand_val) in zip(N_all, lamda_all, srand_all):
    # Rows of d_reconstruct that belong to this parameter combination.
    file_rows = (N_all == n_val) & (lamda_all == lamda_val) & (srand_all == srand_val)
    d_reconstruct_tensor[N_vals == n_val, lamda_vals == lamda_val, srand_vals == srand_val, :] = d_reconstruct[file_rows, :].flatten()

gwot_shape = (len(N_vals_gwot), len(lamda_vals_gwot), len(srand_vals_gwot), sim_gt.T)
d_gwot_tensor = np.full(gwot_shape, float("NaN"))
for (n_val, lamda_val, srand_val) in zip(N_all_gwot, lamda_all_gwot, srand_all_gwot):
    # Rows of d_gwot that belong to this parameter combination.
    file_rows = (N_all_gwot == n_val) & (lamda_all_gwot == lamda_val) & (srand_all_gwot == srand_val)
    d_gwot_tensor[N_vals_gwot == n_val, lamda_vals_gwot == lamda_val, srand_vals_gwot == srand_val, :] = d_gwot[file_rows, :].flatten()

In [None]:
# For the smallest N value (index 0 on the N axis): square root of the time-averaged
# energy distance, then averaged over seeds, as a function of lamda.
langevin_curve = np.sqrt(d_reconstruct_tensor[0].mean(-1)).mean(-1)
gwot_curve = np.sqrt(d_gwot_tensor[0].mean(-1)).mean(-1)

plt.subplot(1, 2, 1)
plt.plot(lamda_vals, langevin_curve, 'o-')
plt.title("Langevin")

plt.subplot(1, 2, 2)
plt.plot(lamda_vals_gwot, gwot_curve, 'o-')
plt.title("gWOT")

In [None]:
# Heatmaps of the seed-averaged RMS energy distance over the (N, lamda) grid,
# one panel per method.
# The plain .mean() calls propagate NaN, so any (N, lamda) cell with a missing
# seed or time point is drawn as NaN rather than averaged over what is present.
plt.figure(figsize = (3*PLT_CELL, 3/2*PLT_CELL))
plt.subplot(1, 2, 1)
im = plt.imshow(np.sqrt(d_reconstruct_tensor.mean(-1)).mean(-1), origin = "lower")
plt.xticks(range(len(lamda_vals)), lamda_vals, rotation = 30)
plt.yticks(range(len(N_vals)), N_vals)
plt.colorbar(im,fraction=0.038, pad=0.04)
# Raw string: "\l" is not a valid escape sequence in an ordinary string literal.
plt.xlabel(r"$\lambda$")
plt.ylabel("N")
plt.title("Langevin")

plt.subplot(1, 2, 2)
im = plt.imshow(np.sqrt(d_gwot_tensor.mean(-1)).mean(-1), origin = "lower")
plt.xticks(range(len(lamda_vals_gwot)), lamda_vals_gwot, rotation = 30)
plt.yticks(range(len(N_vals_gwot)), N_vals_gwot)
plt.colorbar(im,fraction=0.038, pad=0.04)
plt.xlabel(r"$\lambda$")
plt.ylabel("N")
plt.title("gWOT")

plt.tight_layout()

In [None]:
# Display the Langevin table indexed (N, lamda): square root of the time-averaged
# energy distance, then averaged over seeds.
np.sqrt(d_reconstruct_tensor.mean(-1)).mean(-1)

In [None]:
# One curve per N value: seed-averaged RMS energy distance as a function of lamda.
# The table is indexed (N, lamda), so it is transposed to put lamda on the x axis.
langevin_rms_by_N_lamda = np.sqrt(d_reconstruct_tensor.mean(-1)).mean(-1)
plt.plot(lamda_vals, langevin_rms_by_N_lamda.T, 'o-');

In [None]:
# Seed-averaged RMS energy distance as a function of N, one curve per lamda value,
# with one panel per method on a shared y range.
plt.subplot(1, 2, 1)
plt.plot(N_vals, np.sqrt(d_reconstruct_tensor.mean(-1)).mean(-1), 'o-', label = "Langevin");
plt.ylim(0, 0.5)
plt.legend()
plt.xscale("log")
plt.subplot(1, 2, 2)
# The first axis of d_gwot_tensor is indexed by N_vals_gwot, so plot against that
# axis rather than the Langevin N_vals.
plt.plot(N_vals_gwot, np.sqrt(d_gwot_tensor.mean(-1)).mean(-1), 'o-', label = "gWOT");
plt.ylim(0, 0.5)
plt.legend()
plt.xscale("log")

In [None]:
# Display the gWOT table indexed (N, lamda): square root of the time-averaged energy
# distance, then averaged over seeds. It is computed explicitly instead of showing
# `means`, because that name is only assigned by the figure cell further down and is
# undefined when the notebook is run top to bottom on a fresh kernel.
np.sqrt(d_gwot_tensor.mean(-1)).mean(-1)

In [None]:
# Final figure: for every N, the best (lowest) seed-averaged RMS energy distance over
# lamda, with the across-seed standard deviation at that best lamda as the error bar.
plt.figure(figsize = (PLT_CELL, 1.75*PLT_CELL))

# (distance tensor, values of its N axis, legend label, colour) for each method.
# Each tensor is paired with its own N axis; the gWOT curve must use N_vals_gwot,
# not the Langevin N_vals.
for (dist_tensor, n_axis, label, color) in [
    (d_reconstruct_tensor, N_vals, "MFL", "blue"),
    (d_gwot_tensor, N_vals_gwot, "gWOT", "red"),
]:
    # Square root of the time-averaged distance, still indexed (N, lamda, seed).
    rms = np.sqrt(dist_tensor.mean(-1))
    sds = np.std(rms, axis = 2)
    means = rms.mean(-1)
    # NaN-aware selection: (N, lamda) cells with missing runs have NaN means and are skipped.
    min_idx = np.nanargmin(means, axis = 1)
    sds_minmean = np.array([x[y] for (x, y) in zip(sds, min_idx)])
    plt.errorbar(n_axis, np.nanmin(means, 1), sds_minmean, label = label, color = color, marker = "o")

plt.ylabel("RMS Energy Distance")
plt.xlabel("N")

plt.xscale("log")
plt.legend()

plt.tight_layout()
plt.savefig("../fig1_distances.pdf")