In [None]:
import os

# Cap the thread pools of the numerical backends (OpenMP, OpenBLAS, MKL,
# Apple vecLib, numexpr). These variables are generally only honoured if they
# are set before those libraries are first imported, hence this first cell.
num_threads = "16"
for _thread_var in (
    "OMP_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "MKL_NUM_THREADS",
    "VECLIB_MAXIMUM_THREADS",
    "NUMEXPR_NUM_THREADS",
):
    os.environ[_thread_var] = num_threads
del _thread_var

In [None]:
import matplotlib.pyplot as plt
import torch
import copy
import scipy as sp
from scipy import stats
from sklearn import metrics
import sys
import ot
import gwot
from gwot import models, sim, ts, util
import gwot.bridgesampling as bs
import dcor
from tqdm import tqdm
import numpy as np

sys.path.append("..")
import importlib
import models
import random
import mmd

In [None]:
# Side length (in inches, as consumed by `figsize`) of one square subplot
# panel; the figures below are sized as multiples of it.
PLT_CELL: float = 2.5

In [None]:
import glob

# One MFL run per file "out_N_<N>_<..>_<srand>_<..>_<lamda>.npy", each holding a
# pickled dict; "_"-separated fields 2, 4 and 6 are parsed as N, srand and lamda.
# The names are sorted so that run order (and hence fnames_all[0], x_gt_all[0],
# tsdata_all[0]) does not depend on the filesystem's listing order.
fnames_all = sorted(glob.glob("out_N_*.npy"))
if not fnames_all:
    raise FileNotFoundError("no files matching 'out_N_*.npy' in the working directory")
_fields = [f.split("_") for f in fnames_all]
N_all = np.array([int(p[2]) for p in _fields])
srand_all = np.array([int(p[4]) for p in _fields])
lamda_all = np.array([float(p[6].split(".npy")[0]) for p in _fields])

# Read and unpickle every file exactly once, then pull out the fields we need.
# NOTE: allow_pickle=True means unpickling can execute arbitrary code — only
# load outputs you produced yourself.
_results = [np.load(f, allow_pickle = True).item(0) for f in fnames_all]
x_all = [r["model_x"] for r in _results]
x_gt_all = [r["X_gt"] for r in _results]
day_gt = _results[0]["day_gt"]
tsdata_all = [r["tsdata"] for r in _results]
del _results, _fields

In [None]:
# One gWOT run per file "out_gwot_N_<N>_<..>_<srand>_<..>_<lamda>.npy", each
# holding a pickled dict; "_"-separated fields 3, 5 and 7 are parsed as N, srand
# and lamda. Sorted for a filesystem-independent run order.
fnames_all_gwot = sorted(glob.glob("out_gwot_N_*.npy"))
if not fnames_all_gwot:
    raise FileNotFoundError("no files matching 'out_gwot_N_*.npy' in the working directory")
_fields_gwot = [f.split("_") for f in fnames_all_gwot]
N_all_gwot = np.array([int(p[3]) for p in _fields_gwot])
srand_all_gwot = np.array([int(p[5]) for p in _fields_gwot])
lamda_all_gwot = np.array([float(p[7].split(".npy")[0]) for p in _fields_gwot])

# Read and unpickle every file exactly once.
# NOTE: allow_pickle=True means unpickling can execute arbitrary code — only
# load outputs you produced yourself.
_results_gwot = [np.load(f, allow_pickle = True).item(0) for f in fnames_all_gwot]
x_all_gwot = [r["samples_gwot"] for r in _results_gwot]
x_gt_all_gwot = [r["X_gt"] for r in _results_gwot]
tsdata_all_gwot = [r["tsdata"] for r in _results_gwot]
del _results_gwot, _fields_gwot

In [None]:
# `days`: the sorted distinct values of `day_gt`.
# `day_idx`: for every ground-truth row, the position of its day within `days`,
# so that `day_idx == i` selects the ground-truth rows of the i-th day.
_day_unique = np.unique(day_gt, return_inverse = True)
days = _day_unique[0]
day_idx = _day_unique[1]
del _day_unique

In [None]:
# Energy distance between the ground-truth rows of day i and the MFL
# reconstruction at time index i, for every run (rows, in `x_all` order) and
# day (columns). The square root is taken, presumably because dcor's energy
# distance behaves like a squared metric.
_reconstruct_rows = []
with torch.no_grad():
    for _run in tqdm(range(len(x_all)), position = 0, leave = True):
        _gt, _pred = x_gt_all[_run], x_all[_run]
        _reconstruct_rows.append(
            [dcor.energy_distance(_gt[day_idx == _day, :], _pred[_day, :]) for _day in range(len(days))]
        )
    d_reconstruct = np.sqrt(np.array(_reconstruct_rows))
del _reconstruct_rows

In [None]:
# Same per-run (rows) / per-day (columns) energy distance as `d_reconstruct`,
# for the gWOT reconstructions, using the gWOT lists loaded above
# (`x_gt_all_gwot`, `x_all_gwot`).
# NOTE(review): `day_idx` is derived from the MFL files' `day_gt`; this assumes
# the gWOT X_gt rows follow the same ordering — confirm.
d_gwot = np.sqrt(np.array([
    [dcor.energy_distance(x_gt_all_gwot[j][day_idx == i, :], x_all_gwot[j][i, :]) for i in range(len(days))]
    for j in tqdm(range(len(x_all_gwot)), position = 0, leave = True)
]))

In [None]:
# Baseline: energy distance between the full ground truth of day i and the
# subsample each run actually observed at time index i (tsdata.x rows with
# t_idx == i). Rows: runs in `x_all` order; columns: days.
_sample_rows = []
for _run in tqdm(range(len(x_all)), position = 0, leave = True):
    _gt, _obs = x_gt_all[_run], tsdata_all[_run]
    _sample_rows.append(
        [dcor.energy_distance(_gt[day_idx == _day, :], _obs.x[_obs.t_idx == _day, :]) for _day in range(len(days))]
    )
d_sample = np.sqrt(np.array(_sample_rows))
del _sample_rows

In [None]:
# Sorted grids of the distinct N / lamda / srand values found in the MFL
# (`*_vals`) and gWOT (`*_vals_gwot`) file names.
N_vals = np.unique(N_all)
N_vals_gwot = np.unique(N_all_gwot)
lamda_vals = np.unique(lamda_all)
lamda_vals_gwot = np.unique(lamda_all_gwot)
srand_vals = np.unique(srand_all)
srand_vals_gwot = np.unique(srand_all_gwot)

In [None]:
# Scatter the flat (run, day) matrix `d_reconstruct` into a dense
# (N, lamda, srand, day) tensor; grid cells with no matching run stay NaN.
d_reconstruct_tensor = np.full((len(N_vals), len(lamda_vals), len(srand_vals), d_reconstruct.shape[-1]), float("NaN"))
# Grid coordinates of every run, found in one vectorised pass (the *_vals grids
# come from np.unique, so they are sorted as searchsorted requires).
_i_N = np.searchsorted(N_vals, N_all)
_i_lamda = np.searchsorted(lamda_vals, lamda_all)
_i_srand = np.searchsorted(srand_vals, srand_all)
assert (np.array_equal(N_vals[_i_N], N_all)
        and np.array_equal(lamda_vals[_i_lamda], lamda_all)
        and np.array_equal(srand_vals[_i_srand], srand_all)), "run parameters missing from the *_vals grids"
# Each (N, lamda, srand) combination must come from exactly one run, otherwise
# one run's distances would silently overwrite another's.
assert len(set(zip(_i_N.tolist(), _i_lamda.tolist(), _i_srand.tolist()))) == len(N_all), "duplicate (N, lamda, srand) combination among runs"
d_reconstruct_tensor[_i_N, _i_lamda, _i_srand, :] = d_reconstruct

In [None]:
# Scatter the flat (run, day) matrix `d_gwot` into a dense
# (N, lamda, srand, day) tensor over the gWOT grids; missing runs stay NaN.
d_gwot_tensor = np.full((len(N_vals_gwot), len(lamda_vals_gwot), len(srand_vals_gwot), d_gwot.shape[-1]), float("NaN"))
# Grid coordinates of every gWOT run (the *_vals_gwot grids come from
# np.unique, so they are sorted as searchsorted requires).
_i_N = np.searchsorted(N_vals_gwot, N_all_gwot)
_i_lamda = np.searchsorted(lamda_vals_gwot, lamda_all_gwot)
_i_srand = np.searchsorted(srand_vals_gwot, srand_all_gwot)
assert (np.array_equal(N_vals_gwot[_i_N], N_all_gwot)
        and np.array_equal(lamda_vals_gwot[_i_lamda], lamda_all_gwot)
        and np.array_equal(srand_vals_gwot[_i_srand], srand_all_gwot)), "run parameters missing from the *_vals_gwot grids"
# Each (N, lamda, srand) combination must come from exactly one run.
assert len(set(zip(_i_N.tolist(), _i_lamda.tolist(), _i_srand.tolist()))) == len(N_all_gwot), "duplicate (N, lamda, srand) combination among gWOT runs"
d_gwot_tensor[_i_N, _i_lamda, _i_srand, :] = d_gwot

In [None]:
# Scatter the flat (run, day) matrix `d_sample` onto the MFL parameter grid:
# a dense (N, lamda, srand, day) tensor; missing runs stay NaN.
d_sample_tensor = np.full((len(N_vals), len(lamda_vals), len(srand_vals), d_sample.shape[-1]), float("NaN"))
# Grid coordinates of every run (the *_vals grids come from np.unique, so they
# are sorted as searchsorted requires).
_i_N = np.searchsorted(N_vals, N_all)
_i_lamda = np.searchsorted(lamda_vals, lamda_all)
_i_srand = np.searchsorted(srand_vals, srand_all)
assert (np.array_equal(N_vals[_i_N], N_all)
        and np.array_equal(lamda_vals[_i_lamda], lamda_all)
        and np.array_equal(srand_vals[_i_srand], srand_all)), "run parameters missing from the *_vals grids"
# Each (N, lamda, srand) combination must come from exactly one run.
assert len(set(zip(_i_N.tolist(), _i_lamda.tolist(), _i_srand.tolist()))) == len(N_all), "duplicate (N, lamda, srand) combination among runs"
d_sample_tensor[_i_N, _i_lamda, _i_srand, :] = d_sample

In [None]:
# Overview: per-day energy distance (mean ± std over runs) for every lamda
# value, pooled over N and seeds. Blue: MFL reconstruction, red: gWOT,
# green: raw subsample. Only the first curve of each method gets a legend entry.
for _k, l in enumerate(np.unique(lamda_all)):
    _sel = lamda_all == l
    plt.errorbar(days, d_reconstruct[_sel, :].mean(0), d_reconstruct[_sel, :].std(0), color = "blue", label = "MFL" if _k == 0 else "_nolegend_")
for _k, l in enumerate(np.unique(lamda_all_gwot)):
    _sel = lamda_all_gwot == l
    plt.errorbar(days, d_gwot[_sel, :].mean(0), d_gwot[_sel, :].std(0), color = "red", label = "gWOT" if _k == 0 else "_nolegend_")
for _k, l in enumerate(np.unique(lamda_all)):
    _sel = lamda_all == l
    plt.errorbar(days, d_sample[_sel, :].mean(0), d_sample[_sel, :].std(0), color = "green", label = "Subsample" if _k == 0 else "_nolegend_")
plt.xlabel("day")
plt.ylabel("Energy Distance")
plt.legend();

In [None]:
# lamda sweep, one panel per method:
# - Curve: energy distance averaged over days and seeds. Error bars are the
#   std over seeds of each seed's day-average.
# - Green band: raw-subsample baseline. Solid line is the mean; dashed lines
#   are ±1 std over seeds of the per-seed day-average.
# Only the first N value (index 0) is plotted, so the curve lines up with the
# 1-D lamda grid. The baseline uses lamda index 0 of the subsample tensor.
# NaN-aware statistics are used throughout, because missing runs are NaN.
_n = 0
_base = d_sample_tensor[_n, 0, :, :]   # (srand, day)
_base_mean = np.nanmean(_base)
_base_std = np.nanstd(np.nanmean(_base, axis = 1))
_panels = [
    ("Langevin", "MFL", lamda_vals, d_reconstruct_tensor, "blue"),
    ("gWOT", "gWOT", lamda_vals_gwot, d_gwot_tensor, "red"),
]
for _col, (_title, _label, _lamdas, _tensor, _color) in enumerate(_panels, start = 1):
    plt.subplot(1, 2, _col)
    _t = _tensor[_n]   # (lamda, srand, day)
    plt.errorbar(_lamdas, np.nanmean(_t, axis = (1, 2)), np.nanstd(np.nanmean(_t, axis = 2), axis = 1), marker = "o", color = _color, label = _label)
    plt.hlines(_base_mean, min(_lamdas), max(_lamdas), color = "green")
    plt.hlines([_base_mean + _base_std, _base_mean - _base_std], min(_lamdas), max(_lamdas), linestyle = "dashed", color = "green", label = "samples")
    plt.title(_title)
    plt.xlabel("$\\lambda$")
    plt.xscale("log")
    plt.ylim(0.55, 1.75)
    plt.legend()

In [None]:
# Inspect the gWOT regularisation grid (distinct lamda values parsed from the gWOT file names).
lamda_vals_gwot

In [None]:
# Per-day error at each method's best lamda, compared with the raw-subsample
# baseline. "Best" means the lowest energy distance averaged over days and seeds.
# nanargmin is used because lamda values whose runs are all missing average to
# NaN, and plain argmin would pick them. Only the first N value (index 0) is
# used, so the argmin is taken over lamda alone.
_n = 0
_best_mfl = np.nanargmin(np.nanmean(d_reconstruct_tensor[_n], axis = (1, 2)))
_best_gwot = np.nanargmin(np.nanmean(d_gwot_tensor[_n], axis = (1, 2)))

plt.figure(figsize = (PLT_CELL, PLT_CELL))
for _errs, _color, _label in [
    (d_reconstruct_tensor[_n, _best_mfl], "blue", "MFL"),
    (d_gwot_tensor[_n, _best_gwot], "red", "gWOT"),
    (d_sample_tensor[_n, 0], "green", "Subsample"),
]:
    # _errs is (srand, day): plot mean ± std over seeds for each day.
    plt.errorbar(days, np.nanmean(_errs, 0), np.nanstd(_errs, 0), marker = "o", color = _color, label = _label)
plt.xlabel("day")
plt.ylabel("Energy Distance")
plt.legend(prop = {"size" : 8})
plt.ylim(0.25, 2.5)
plt.title("Error")
plt.tight_layout()
plt.savefig("../reprogramming_distances.pdf")

In [None]:
# Inspect the shape of the first run's ground-truth matrix (its rows are the ones selected by `day_idx`).
x_gt_all[0].shape

In [None]:
# Example run shown in the snapshot figures: the first MFL run with
# lamda == LAMDA_SHOW. The UMAP cells below reuse it through `i` and `M`.
# np.isclose is used instead of == because lamda values are floats parsed from file names.
LAMDA_SHOW = 0.025
_matches = np.flatnonzero(np.isclose(lamda_all, LAMDA_SHOW))
assert _matches.size > 0, f"no MFL run with lamda == {LAMDA_SHOW}"
i = _matches[0]
# Points per time index in the MFL output, read from its shape rather than
# hard-coded. x_all[i] is indexed as [time, (presumably) particle, coordinate].
M = x_all[i].shape[1]

fig = plt.figure(figsize = (3*PLT_CELL, PLT_CELL))
plt.subplot(1, 3, 2)
with torch.no_grad():
    # One colour value per point, time-major, matching how scatter flattens the
    # (time, particle) arrays.
    plt.scatter(x_all[i][:, :, 0], x_all[i][:, :, 1], c = np.kron(np.linspace(0, 1, len(days)), np.ones(M)), alpha = 0.5, s = 4)
plt.xlabel("PC1"); plt.ylabel("PC2")
plt.gca().get_yaxis().set_visible(False)
plt.title("MFL")
plt.xlim(-20, 20); plt.ylim(-20, 20)
plt.subplot(1, 3, 1)
plt.scatter(tsdata_all[i].x[:, 0], tsdata_all[i].x[:, 1], c = tsdata_all[i].t_idx, alpha = 1, s = 4)
plt.xlabel("PC1"); plt.ylabel("PC2")
plt.title("Subsample")
plt.xlim(-20, 20); plt.ylim(-20, 20)
plt.subplot(1, 3, 3)
im = plt.scatter(x_gt_all[i][:, 0], x_gt_all[i][:, 1], c = day_gt, alpha = 0.05, s = 4, rasterized = True)
plt.xlabel("PC1"); plt.ylabel("PC2")
plt.title("Full dataset")
plt.gca().get_yaxis().set_visible(False)
plt.xlim(-20, 20); plt.ylim(-20, 20)

plt.tight_layout()

fig.subplots_adjust(right=0.9)
cbar_ax = fig.add_axes([0.925, 0.15, 0.025, 0.7])
# The colorbar is built from an opaque ScalarMappable that shares the scatter's
# norm and colormap; the scatter itself is drawn with alpha=0.05. This replaces
# set_alpha + Colorbar.draw_all, which is deprecated since Matplotlib 3.6 and
# removed in later releases.
cb = fig.colorbar(plt.cm.ScalarMappable(norm = im.norm, cmap = im.cmap), cax = cbar_ax)
cbar_ax.set_title("day")

plt.savefig("../reprogramming_snapshots.pdf", dpi = 300)

In [None]:
import anndata
import umap

In [None]:
# ADATA_PATH = "data_repr.h5ad"
# adata = anndata.read_h5ad(ADATA_PATH)
# adata = adata[(adata.obs.day >= 2.5) & (adata.obs.day < 6.5), :]

# Fit a UMAP embedding on the full ground truth of the example run `i` (chosen
# in the PCA snapshot cell). The same run's subsample and MFL points are then
# projected into that embedding.
# NOTE(review): UMAP is stochastic; pass random_state=... for a reproducible layout.
trans = umap.UMAP(n_neighbors = 25, verbose = True)
X_gt_umap = trans.fit_transform(x_gt_all[i])

plt.scatter(X_gt_umap[:, 0], X_gt_umap[:, 1], c = day_gt, alpha = 0.1, marker = ".")

# Subsample of the SAME run `i`; its points are coloured with tsdata_all[i].t_idx when plotted.
X_sample_umap = trans.transform(tsdata_all[i].x)

with torch.no_grad():
    # Flatten (time, point, coord) -> (time*point, coord), time-major.
    X_langevin_umap = trans.transform(x_all[i].reshape(-1, x_all[i].shape[-1]))

In [None]:
# UMAP counterpart of the PCA snapshot figure. It shows MFL points, the run's
# subsample and the full ground truth of run `i`, all in the embedding fitted above.
fig = plt.figure(figsize = (3*PLT_CELL, PLT_CELL))
plt.subplot(1, 3, 2)
# Time-major colour vector, M points per time index (M is set in the PCA snapshot cell).
plt.scatter(X_langevin_umap[:, 0], X_langevin_umap[:, 1], c = np.kron(np.linspace(0, 1, len(days)), np.ones(M)), alpha = 0.5, s = 4)
plt.xlabel("UMAP1"); plt.ylabel("UMAP2")
plt.gca().get_yaxis().set_visible(False)
plt.title("MFL")
plt.subplot(1, 3, 1)
plt.scatter(X_sample_umap[:, 0], X_sample_umap[:, 1], c = tsdata_all[i].t_idx, alpha = 1, s = 4)
plt.xlabel("UMAP1"); plt.ylabel("UMAP2")
plt.title("Subsample")
plt.subplot(1, 3, 3)
im = plt.scatter(X_gt_umap[:, 0], X_gt_umap[:, 1], c = day_gt, alpha = 0.05, s = 4, rasterized = True)
plt.xlabel("UMAP1"); plt.ylabel("UMAP2")
plt.title("Full dataset")
plt.gca().get_yaxis().set_visible(False)

plt.tight_layout()

fig.subplots_adjust(right=0.9)
cbar_ax = fig.add_axes([0.925, 0.15, 0.025, 0.7])
# The colorbar is built from an opaque ScalarMappable that shares the scatter's
# norm and colormap; the scatter itself uses alpha=0.05. This avoids
# Colorbar.draw_all, which is deprecated since Matplotlib 3.6 and removed later.
cb = fig.colorbar(plt.cm.ScalarMappable(norm = im.norm, cmap = im.cmap), cax = cbar_ax)
cbar_ax.set_title("day")

# Saved under its own name so it does not overwrite the PCA snapshot figure
# written to "../reprogramming_snapshots.pdf".
plt.savefig("../reprogramming_snapshots_umap.pdf", dpi = 300)