In [37]:
from paths import DATA_DIR, CKPT_FOLDER, PROJECT_FOLDER

import matplotlib.pyplot as plt
import os
import numpy as np
import torch
import anndata
import scanpy as sc
import sklearn
import scvelo as scv
from pathlib import Path
import pertpy as pt

import anndata
import pandas as pd

from IPython.display import display
from torchdyn.core import NeuralODE

from scCFM.datamodules.time_sc_datamodule import TrajectoryDataModule
from scCFM.models.cfm.components.mlp import MLP
from scCFM.models.cfm.cfm_module import CFMLitModule

from scCFM.models.base.vae import VAE
from scCFM.models.base.geometric_vae import GeometricNBVAE
from scCFM.models.base.geodesic_ae import GeodesicAE

from scCFM.datamodules.sc_datamodule import scDataModule
from scCFM.models.cfm.components.eval.distribution_distances import compute_distribution_distances

from notebooks.utils import decode_trajectory_single_step, standardize, compute_prdc

Initialize the device

In [38]:
device = "cuda" if torch.cuda.is_available() else "cpu"

def cross_standardize(tensor1, tensor2, minmax_scale=False):
    """
    Standardize tensor across the rows
    """
    mean_t1, std_t1 = tensor1.mean(0), tensor1.std(0)
    if minmax_scale:
        min_t1, max_t1 = tensor1.min(0), tensor1.max(0)
        tensor1 = (tensor1 - min_t1) / (max_t1 - min_t1 + 1e-6)
        tensor2 = (tensor2 - min_t1) / (max_t1 - min_t1 + 1e-6)
    else:
        tensor1 = (tensor1 - mean_t1) / (std_t1 + 1e-6)
        tensor2 = (tensor2 - mean_t1) / (std_t1 + 1e-6)
    return tensor1, tensor2

def update_dict(ref, tgt):
    """
    Update a dictionary with the values of another 
    """
    for key in tgt:
        if key not in ref:
            ref[key] = []
        ref[key].append(tgt[key])
    return ref

# def distance_metrics_pertpy(ref, trg):
#     """
#     Distance metrics using pertpy
#     """
#     adata_ref, adata_trg = ref.copy(), trg.copy()
#     adata_together = sc.AnnData(X=np.concatenate([adata_ref.X,
#                                                   adata_trg.X], axis=0))
#     annot = ["reference" for _ in range(len(adata_ref))] + ["target" for _ in range(len(adata_trg))]
#     annot_df = pd.DataFrame(annot)
#     annot_df.columns = ["dataset_type"]
#     adata_together.obs = annot_df
#     adata_together.obsm["X_data"] = adata_together.X.copy()
#     # Wasserstein
#     w_distance = pt.tl.Distance("wasserstein", obsm_key="X_data")
#     # MMD
#     mmd_distance = pt.tl.Distance("mmd", obsm_key="X_data")
#     df_w = w_distance.pairwise(adata_together, groupby="dataset_type").to_numpy()
#     df_mmd = mmd_distance.pairwise(adata_together, groupby="dataset_type", verbose=False).to_numpy()
#     return {'wassersetein':df_w.max(), 
#             'mmd':df_mmd.max()}

def distance_metrics_pertpy(ref, trg):
    """
    Distance metrics using pertpy
    """
    d =  compute_distribution_distances(torch.tensor(ref.X[:,None,:]), 
                                             torch.tensor(trg.X[:,None,:]))
    return dict(zip(d[0], d[1]))

Initialize datamodule

In [39]:
datamodule={'path': PROJECT_FOLDER / 'data/hein_et_al/processed/unperturbed_time_course_host.h5ad', 
            'x_layer': 'X_norm', 
            'cond_keys': ['experimental_time', 'cluster'],
            'use_pca': False, 
            'n_dimensions': None, 
            'train_val_test_split': [1], 
            'batch_size': 512, 
            'num_workers': 2}

# Initialize datamodule
datamodule = scDataModule(**datamodule)

Initialize autoencoders

In [40]:
vae_kwargs={'in_dim': datamodule.in_dim,
       'n_epochs_anneal_kl': 1000, 
       'kl_weight': None, 
       'likelihood': 'nb', 
       'dropout': False, 
       'learning_rate': 0.001, 
       'dropout_p': False, 
       'model_library_size': True, 
       'batch_norm': True, 
       'kl_warmup_fraction': 0.1, 
       'hidden_dims': [256, 10]}
        
geometric_kwargs={'compute_metrics_every': 1, 
           'use_c': True, 
           'trainable_c': False,
           'l2': True, 
           'eta_interp': 0, 
           'interpolate_z': False, 
           'start_jac_after': 0, 
           'fl_weight': 0.1,
           'detach_theta': True}

geodesic_kwargs={"in_dim": datamodule.in_dim,
          "hidden_dims": [256, 10],
          "batch_norm": True,
          "dropout": False, 
          "dropout_p": False,
          "likelihood": "nb",
          "learning_rate": 0.001}

# Initialize vae and geometric vae
vae = VAE(**vae_kwargs).to(device)
geometric_vae = GeometricNBVAE(**geometric_kwargs, vae_kwargs=vae_kwargs).to(device)
geodesic_ae = GeodesicAE(**geodesic_kwargs).to(device)

# Load state dicts and put in eval mode 
vae.load_state_dict(torch.load(PROJECT_FOLDER / "checkpoints/ae/hein_et_al_complete/best_model_vae_lib.ckpt")["state_dict"])
geometric_vae.load_state_dict(torch.load(PROJECT_FOLDER / "checkpoints/ae/hein_et_al_complete/best_model_geometric_lib.ckpt")["state_dict"])
geodesic_ae.load_state_dict(torch.load(PROJECT_FOLDER / "checkpoints/ae/hein_et_al_complete/best_model_geodesic_ae.ckpt")["state_dict"])

<All keys matched successfully>

## Setup CFMs

In [41]:
leavout_timepoints_folder = Path("/home/icb/alessandro.palma/environment/scCFM/project_dir/checkpoints/trajectory/hein_et_al")

Initialize datamodule for trajectory

In [42]:
datamodule_kwargs_vae = {'path': PROJECT_FOLDER / 'data/hein_et_al/flat/hein_lib_complete.h5ad',
                          'x_layer': 'X_latents',
                          'time_key': 'experimental_time', 
                          'use_pca': False, 
                          'n_dimensions': None, 
                          'train_val_test_split': [0.9, 0.1], 
                          'num_workers': 2, 
                          'batch_size': 512, 
                          'model_library_size': True}

datamodule_kwargs_flat = {'path': PROJECT_FOLDER / 'data/hein_et_al/flat/hein_flat_lib_complete.h5ad',
                          'x_layer': 'X_latents',
                          'time_key': 'experimental_time', 
                          'use_pca': False, 
                          'n_dimensions': None, 
                          'train_val_test_split': [0.9, 0.1], 
                          'num_workers': 2, 
                          'batch_size': 512, 
                           'model_library_size': True}

datamodule_kwargs_geodesic= {'path': PROJECT_FOLDER / 'data/hein_et_al/flat/hein_geodesic_complete.h5ad',
                          'x_layer': 'X_latents',
                          'time_key': 'experimental_time', 
                          'use_pca': False, 
                          'n_dimensions': None, 
                          'train_val_test_split': [0.9, 0.1], 
                          'num_workers': 2, 
                          'batch_size': 512, 
                           'model_library_size': True}

# Initialize the datamodules 
datamodule_vae = TrajectoryDataModule(**datamodule_kwargs_vae)
datamodule_flat = TrajectoryDataModule(**datamodule_kwargs_flat)
datamodule_geodesic = TrajectoryDataModule(**datamodule_kwargs_geodesic)

## Read data

First, read the latent space anndata and plot the results

In [43]:
# Read latent anndata
adata_latent_vae = sc.read_h5ad(DATA_DIR / "hein_et_al" / "flat" / "hein_lib_complete.h5ad")
adata_latent_flat = sc.read_h5ad(DATA_DIR / "hein_et_al" / "flat" / "hein_flat_lib_complete.h5ad")
adata_latent_geodesic = sc.read_h5ad(DATA_DIR / "hein_et_al" / "flat" / "hein_geodesic_complete.h5ad")

# Read real anndata
adata_hein_original = sc.read_h5ad(PROJECT_FOLDER / 'data/hein_et_al/processed/unperturbed_time_course_host.h5ad')
sc.tl.pca(adata_hein_original, n_comps=50)
adata_hein_original.X = adata_hein_original.layers["X_norm"].A.copy()

Number of experiments 

In [44]:
n_timepoints = len(np.unique(adata_latent_vae.obs.experimental_time))
idx2time = dict(zip(range(n_timepoints), np.unique(adata_latent_vae.obs.experimental_time)))
idx2time

{0: 0.0,
 1: 0.05,
 2: 0.16666666666666666,
 3: 0.23333333333333334,
 4: 0.4,
 5: 0.6,
 6: 0.8,
 7: 1.0}

Initialize model

In [45]:
net_hparams = {"dim": adata_latent_flat.X.shape[1]+1,
                "w": 64,
                "time_varying": True}

cfm_kwargs = {'ot_sampler': 'exact', 
                   'sigma': 0.1, 
                   'use_real_time': False, 
                   'lr': 0.001, 
                   'antithetic_time_sampling': True}

## Evaluation

Load checkpoints

In [1]:
# LATENT SPACE METRICS
leaveout_ckpt_vae_latent = {}
leaveout_ckpt_flat_latent = {}
leaveout_ckpt_geodesic_latent = {}
leaveout_ckpt_previous_latent = {}

# DATA SPACE METRICS
leaveout_ckpt_vae_data = {}
leaveout_ckpt_flat_data = {}
leaveout_ckpt_geodesic_data = {}
leaveout_ckpt_previous_data = {}

for rep in range(1, 4):
    for tp in range(1, n_timepoints-1):
        print(f"Time point {tp}")
        #Pick time 0 observations
        X_adata_t0_latent_vae = torch.from_numpy(adata_latent_vae[adata_latent_vae.obs["experimental_time"]==idx2time[(tp-1)]].X).to(device)
        X_adata_t0_latent_flat = torch.from_numpy(adata_latent_flat[adata_latent_flat.obs["experimental_time"]==idx2time[(tp-1)]].X).to(device)
        X_adata_t0_latent_geodesic = torch.from_numpy(adata_latent_geodesic[adata_latent_geodesic.obs["experimental_time"]==idx2time[(tp-1)]].X).to(device)
    
        # Pick observations next timepoint 
        X_adata_t1_latent_vae = torch.from_numpy(adata_latent_vae[adata_latent_vae.obs["experimental_time"]==idx2time[tp]].X).to(device)
        X_adata_t1_latent_flat = torch.from_numpy(adata_latent_flat[adata_latent_flat.obs["experimental_time"]==idx2time[tp]].X).to(device)
        X_adata_t1_latent_geodesic = torch.from_numpy(adata_latent_geodesic[adata_latent_geodesic.obs["experimental_time"]==idx2time[tp]].X).to(device)    
    
        # Collect PCs
        adata_real = adata_hein_original[adata_hein_original.obs["experimental_time"]==idx2time[tp]]
        X_adata_real_pca = torch.from_numpy(adata_real.obsm["X_pca"]).to(device)
        X_adata_real = torch.from_numpy(adata_real.layers["X_log"].A).to(device)
    
        #Pick library sizes
        l_t0_vae = adata_latent_vae.obs.loc[adata_latent_vae.obs["experimental_time"]==idx2time[(tp-1)], "log_library_size"].to_numpy()
        l_t0_flat = adata_latent_flat.obs.loc[adata_latent_flat.obs["experimental_time"]==idx2time[(tp-1)], "log_library_size"].to_numpy()
        l_t0_geodesic = adata_latent_geodesic.obs.loc[adata_latent_geodesic.obs["experimental_time"]==idx2time[(tp-1)], "log_library_size"].to_numpy()
        l_t0_vae = torch.from_numpy(l_t0_vae).to(device)
        l_t0_flat = torch.from_numpy(l_t0_flat).to(device)
        l_t0_geodesic = torch.from_numpy(l_t0_geodesic).to(device)
    
        # Initialize nets
        net_vae = MLP(**net_hparams).to(device)
        net_flat = MLP(**net_hparams).to(device)
        net_geodesic = MLP(**net_hparams).to(device)
        cfm_vae = CFMLitModule(net=net_vae, datamodule=datamodule_vae, **cfm_kwargs).to(device)
        cfm_flat = CFMLitModule(net=net_flat, datamodule=datamodule_flat, **cfm_kwargs).to(device)
        cfm_geodesic = CFMLitModule(net=net_geodesic, datamodule=datamodule_geodesic, **cfm_kwargs).to(device)
    
        # Read the checkpoints
        cfm_vae.load_state_dict(torch.load(leavout_timepoints_folder / f"hein_vae_leaveout_{tp}_{rep}.ckpt")["state_dict"])
        cfm_flat.load_state_dict(torch.load(leavout_timepoints_folder / f"hein_flat_leaveout_{tp}_{rep}.ckpt")["state_dict"])
        cfm_geodesic.load_state_dict(torch.load(leavout_timepoints_folder / f"hein_geodesic_leaveout_{tp}_{rep}.ckpt")["state_dict"])
    
        mu_adata_predicted_vae, X_adata_predicted_vae, X_adata_latent_vae = decode_trajectory_single_step(X_adata_t0_latent_vae, 
                                                                                     l_t0_vae, 
                                                                                     tp-1, 
                                                                                     cfm_vae, 
                                                                                     vae)
                                                                                    
        mu_adata_predicted_flat, X_adata_predicted_flat, X_adata_latent_flat = decode_trajectory_single_step(X_adata_t0_latent_flat, 
                                                                                       l_t0_flat, 
                                                                                       tp-1, 
                                                                                       cfm_flat, 
                                                                                       geometric_vae)
                                                                                      
        mu_adata_predicted_geodesic, X_adata_predicted_geodesic, X_adata_latent_geodesic = decode_trajectory_single_step(X_adata_t0_latent_geodesic, 
                                                                                               l_t0_geodesic, 
                                                                                               tp-1, 
                                                                                               cfm_geodesic, 
                                                                                               geodesic_ae, 
                                                                                               model_type="geodesic_ae")
    
        ## PREDICT LATENT TRAJECTORIES 
        X_adata_t1_latent_vae, X_adata_latent_vae = cross_standardize(X_adata_t1_latent_vae, X_adata_latent_vae[:,:-1])
        X_adata_t1_latent_flat, X_adata_latent_flat = cross_standardize(X_adata_t1_latent_flat, X_adata_latent_flat[:,:-1])
        X_adata_t1_latent_geodesic, X_adata_latent_geodesic = cross_standardize(X_adata_t1_latent_geodesic, X_adata_latent_geodesic[:,:-1])
                                                                                       
        d_dist_vae_l = compute_distribution_distances(X_adata_t1_latent_vae.unsqueeze(1).to("cpu"), 
                                             X_adata_latent_vae.unsqueeze(1).to("cpu"))
        d_dist_flat_l = compute_distribution_distances(X_adata_t1_latent_flat.unsqueeze(1).to("cpu"),
                                             X_adata_latent_flat.unsqueeze(1).to("cpu"))
        d_dist_geod_l = compute_distribution_distances(X_adata_t1_latent_geodesic.unsqueeze(1).to("cpu"),
                                             X_adata_latent_geodesic.unsqueeze(1).to("cpu"))
        d_dist_prev = compute_distribution_distances(X_adata_t1_latent_vae.unsqueeze(1).to("cpu"),
                                             X_adata_t0_latent_vae.unsqueeze(1).to("cpu"))
        
        d_dist_vae_l = dict(zip(d_dist_vae_l[0], d_dist_vae_l[1]))
        d_dist_flat_l = dict(zip(d_dist_flat_l[0], d_dist_flat_l[1]))
        d_dist_geod_l = dict(zip(d_dist_geod_l[0], d_dist_geod_l[1]))
        d_dist_prev_l = dict(zip(d_dist_prev[0], d_dist_prev[1]))

        leaveout_ckpt_vae_latent = update_dict(leaveout_ckpt_vae_latent, d_dist_vae_l)
        leaveout_ckpt_flat_latent = update_dict(leaveout_ckpt_flat_latent, d_dist_flat_l)
        leaveout_ckpt_geodesic_latent = update_dict(leaveout_ckpt_geodesic_latent, d_dist_geod_l)
        leaveout_ckpt_previous_latent = update_dict(leaveout_ckpt_previous_latent, d_dist_prev_l)
    
        print("predict decoded trajectory")
        X_adata_predicted_vae = anndata.AnnData(X=X_adata_predicted_vae.numpy())
        X_adata_predicted_flat = anndata.AnnData(X=X_adata_predicted_flat.numpy())
        X_adata_predicted_geodesic = anndata.AnnData(X=X_adata_predicted_geodesic.numpy())
        mu_adata_predicted_vae = anndata.AnnData(X=mu_adata_predicted_vae.numpy())
        mu_adata_predicted_flat = anndata.AnnData(X=mu_adata_predicted_flat.numpy())
        mu_adata_predicted_geodesic = anndata.AnnData(X=mu_adata_predicted_geodesic.numpy())
        
        X_adata_prev = adata_hein_original[adata_hein_original.obs["experimental_time"]==idx2time[tp-1]].copy()
        # rnd_idx = np.random.choice(range(len(adata_hein_original)), size=len(X_adata_predicted_vae), replace=False)
        # X_adata_prev = adata_hein_original[rnd_idx]

        X_adata_predicted_vae.layers["X_norm"] = X_adata_predicted_vae.X.copy()
        X_adata_predicted_flat.layers["X_norm"] = X_adata_predicted_flat.X.copy()
        X_adata_predicted_geodesic.layers["X_norm"] = np.exp(X_adata_predicted_geodesic.X.copy())-1
        X_adata_prev.layers["X_norm"] = X_adata_prev.X.copy()
        sc.pp.log1p(X_adata_predicted_vae)
        sc.pp.log1p(X_adata_predicted_flat)
        sc.pp.log1p(X_adata_prev)
        
        sc.tl.pca(X_adata_predicted_vae, n_comps=50)
        sc.tl.pca(X_adata_predicted_flat, n_comps=50)
        sc.tl.pca(X_adata_predicted_geodesic, n_comps=50)
        sc.tl.pca(X_adata_prev, n_comps=50)
        X_adata_predicted_vae.X = X_adata_predicted_vae.layers["X_norm"].copy()
        X_adata_predicted_flat.X = X_adata_predicted_flat.layers["X_norm"].copy()
        X_adata_predicted_geodesic.X = X_adata_predicted_geodesic.layers["X_norm"].copy()
        X_adata_prev.X = X_adata_prev.layers["X_norm"].copy()
        
        # Compute den
        d_dist_vae_d = compute_prdc(torch.from_numpy(X_adata_predicted_vae.obsm["X_pca"].copy()), 
                                                 X_adata_real_pca.to("cpu"), nearest_k=10)
        d_dist_flat_d = compute_prdc(torch.from_numpy(X_adata_predicted_flat.obsm["X_pca"].copy()), 
                                                 X_adata_real_pca.to("cpu"), nearest_k=10)
        d_dist_geodesic_d = compute_prdc(torch.from_numpy(X_adata_predicted_geodesic.obsm["X_pca"].copy()), 
                                                 X_adata_real_pca.to("cpu"), nearest_k=10)
        d_dist_prev_d = compute_prdc(torch.from_numpy(X_adata_prev.obsm["X_pca"].copy()), 
                                                 X_adata_real_pca.to("cpu"), nearest_k=10)

        # d_dist_w_mmd_vae = distance_metrics_pertpy(adata_real, X_adata_predicted_vae)
        # d_dist_w_mmd_flat =distance_metrics_pertpy(adata_real, X_adata_predicted_flat)
        # d_dist_w_mmd_geodesic =distance_metrics_pertpy(adata_real, X_adata_predicted_geodesic)
        # d_dist_w_mmd_prev =distance_metrics_pertpy(adata_real, X_adata_prev)
        d_dist_w_mmd_vae = distance_metrics_pertpy(adata_real, mu_adata_predicted_vae)
        d_dist_w_mmd_flat =distance_metrics_pertpy(adata_real, mu_adata_predicted_flat)
        d_dist_w_mmd_geodesic =distance_metrics_pertpy(adata_real, X_adata_predicted_geodesic)
        d_dist_w_mmd_prev =distance_metrics_pertpy(adata_real, X_adata_prev)
        
        leaveout_ckpt_vae_data = update_dict(leaveout_ckpt_vae_data, d_dist_w_mmd_vae)
        leaveout_ckpt_flat_data = update_dict(leaveout_ckpt_flat_data, d_dist_w_mmd_flat)
        leaveout_ckpt_geodesic_data = update_dict(leaveout_ckpt_geodesic_data, d_dist_w_mmd_geodesic)
        leaveout_ckpt_previous_data = update_dict(leaveout_ckpt_previous_data, d_dist_w_mmd_prev)
        leaveout_ckpt_vae_data = update_dict(leaveout_ckpt_vae_data, d_dist_vae_d)
        leaveout_ckpt_flat_data = update_dict(leaveout_ckpt_flat_data, d_dist_flat_d)
        leaveout_ckpt_geodesic_data = update_dict(leaveout_ckpt_geodesic_data, d_dist_geodesic_d)
        leaveout_ckpt_previous_data = update_dict(leaveout_ckpt_previous_data, d_dist_prev_d)

        leaveout_ckpt_vae_latent = update_dict(leaveout_ckpt_vae_latent, {"rep": rep})
        leaveout_ckpt_flat_latent = update_dict(leaveout_ckpt_flat_latent, {"rep": rep})
        leaveout_ckpt_geodesic_latent = update_dict(leaveout_ckpt_geodesic_latent, {"rep": rep})
        leaveout_ckpt_previous_latent = update_dict(leaveout_ckpt_previous_latent, {"rep": rep})
        leaveout_ckpt_vae_data = update_dict(leaveout_ckpt_vae_data, {"rep": rep})
        leaveout_ckpt_flat_data = update_dict(leaveout_ckpt_flat_data, {"rep": rep})
        leaveout_ckpt_geodesic_data = update_dict(leaveout_ckpt_geodesic_data, {"rep": rep})
        leaveout_ckpt_previous_data = update_dict(leaveout_ckpt_previous_data, {"rep": rep})

NameError: name 'n_timepoints' is not defined

In [47]:
mu_adata_predicted_geodesic.X.max()

3.0483348

In [48]:
mu_adata_predicted_geodesic.X.max()

3.0483348

In [49]:
tp = [1,2,3,4,5,6]*3

In [50]:
# LATENT SPACE METRICS
leaveout_ckpt_vae_latent = pd.DataFrame(leaveout_ckpt_vae_latent)
leaveout_ckpt_flat_latent = pd.DataFrame(leaveout_ckpt_flat_latent)
leaveout_ckpt_geodesic_latent = pd.DataFrame(leaveout_ckpt_geodesic_latent)
leaveout_ckpt_previous_latent = pd.DataFrame(leaveout_ckpt_previous_latent)

# DATA SPACE METRICS
leaveout_ckpt_vae_data = pd.DataFrame(leaveout_ckpt_vae_data)
leaveout_ckpt_flat_data = pd.DataFrame(leaveout_ckpt_flat_data)
leaveout_ckpt_geodesic_data = pd.DataFrame(leaveout_ckpt_geodesic_data)
leaveout_ckpt_previous_data = pd.DataFrame(leaveout_ckpt_previous_data)

In [51]:
# LATENT SPACE METRICS
leaveout_ckpt_vae_latent["time"] = tp
leaveout_ckpt_flat_latent["time"] = tp 
leaveout_ckpt_geodesic_latent["time"] = tp 
leaveout_ckpt_previous_latent["time"] = tp 

# DATA SPACE METRICS
leaveout_ckpt_vae_data["time"] = tp 
leaveout_ckpt_flat_data["time"] = tp 
leaveout_ckpt_geodesic_data["time"] = tp 
leaveout_ckpt_previous_data["time"] = tp 

## Latent space

### VAE

In [52]:
leaveout_ckpt_vae_latent.mean(0)

1-Wasserstein    2.016699
2-Wasserstein    2.121451
Linear_MMD       0.109529
Poly_MMD         0.294836
RBF_MMD          0.240606
Mean_MSE         0.124335
Mean_L2          0.321971
Mean_L1          0.258785
rep              2.000000
time             3.500000
dtype: float64

In [53]:
leaveout_ckpt_vae_latent.std(0)/np.sqrt(18)

1-Wasserstein    0.084612
2-Wasserstein    0.082567
Linear_MMD       0.024541
Poly_MMD         0.036462
RBF_MMD          0.028124
Mean_MSE         0.024962
Mean_L2          0.034869
Mean_L1          0.028204
rep              0.198030
time             0.414208
dtype: float64

Flat VAE

In [54]:
leaveout_ckpt_flat_latent.mean(0)

1-Wasserstein    1.784052
2-Wasserstein    1.962076
Linear_MMD       0.159411
Poly_MMD         0.357878
RBF_MMD          0.294699
Mean_MSE         0.164320
Mean_L2          0.374722
Mean_L1          0.305853
rep              2.000000
time             3.500000
dtype: float64

In [55]:
(leaveout_ckpt_flat_latent.groupby("time").std()/np.sqrt(3)).mean(0)

1-Wasserstein    0.010301
2-Wasserstein    0.012283
Linear_MMD       0.002355
Poly_MMD         0.004570
RBF_MMD          0.004549
Mean_MSE         0.002360
Mean_L2          0.004618
Mean_L1          0.004033
rep              0.577350
dtype: float64

Geodesic VAE

In [56]:
leaveout_ckpt_geodesic_latent.mean(0)

1-Wasserstein    1.988606
2-Wasserstein    2.099862
Linear_MMD       0.123717
Poly_MMD         0.316766
RBF_MMD          0.259798
Mean_MSE         0.144897
Mean_L2          0.343652
Mean_L1          0.282884
rep              2.000000
time             3.500000
dtype: float64

In [57]:
(leaveout_ckpt_geodesic_latent.groupby("time").std()/np.sqrt(3)).mean(0)

1-Wasserstein    0.014962
2-Wasserstein    0.017538
Linear_MMD       0.006610
Poly_MMD         0.010507
RBF_MMD          0.008059
Mean_MSE         0.008274
Mean_L2          0.011682
Mean_L1          0.008496
rep              0.577350
dtype: float64

Baseline

In [58]:
leaveout_ckpt_previous_latent.mean(0)

1-Wasserstein    2.676166
2-Wasserstein    2.770522
Linear_MMD       0.332431
Poly_MMD         0.515711
RBF_MMD          0.437044
Mean_MSE         0.358540
Mean_L2          0.546338
Mean_L1          0.459409
rep              2.000000
time             3.500000
dtype: float64

In [59]:
(leaveout_ckpt_previous_latent.groupby("time").std()/np.sqrt(3)).mean(0)

1-Wasserstein    0.00000
2-Wasserstein    0.00000
Linear_MMD       0.00000
Poly_MMD         0.00000
RBF_MMD          0.00000
Mean_MSE         0.00000
Mean_L2          0.00000
Mean_L1          0.00000
rep              0.57735
dtype: float64

### Data space

### VAE

In [60]:
leaveout_ckpt_vae_data.mean(0)

1-Wasserstein    20.729193
2-Wasserstein    21.293918
Linear_MMD        0.061055
Poly_MMD          0.211279
RBF_MMD           0.107533
Mean_MSE          0.078647
Mean_L2           0.266358
Mean_L1           0.172970
precision         0.806755
recall            0.128265
density           1.730449
coverage          0.623003
rep               2.000000
time              3.500000
dtype: float64

In [61]:
(leaveout_ckpt_vae_data.groupby("time").std()/np.sqrt(3)).mean(0)

1-Wasserstein    0.004779
2-Wasserstein    0.005733
Linear_MMD       0.000066
Poly_MMD         0.000110
RBF_MMD          0.000046
Mean_MSE         0.000076
Mean_L2          0.000141
Mean_L1          0.000097
precision        0.009982
recall           0.006468
density          0.114479
coverage         0.016224
rep              0.577350
dtype: float64

### Flat VAE

In [62]:
leaveout_ckpt_flat_data.mean(0)

1-Wasserstein    26.455578
2-Wasserstein    28.401863
Linear_MMD        0.091276
Poly_MMD          0.256261
RBF_MMD           0.129980
Mean_MSE          0.105373
Mean_L2           0.311699
Mean_L1           0.212559
precision         0.817493
recall            0.078406
density           3.873392
coverage          0.844084
rep               2.000000
time              3.500000
dtype: float64

In [63]:
(leaveout_ckpt_flat_data.groupby("time").std()/np.sqrt(3)).mean(0)

1-Wasserstein    0.030261
2-Wasserstein    0.045585
Linear_MMD       0.000086
Poly_MMD         0.000206
RBF_MMD          0.000217
Mean_MSE         0.000149
Mean_L2          0.000240
Mean_L1          0.000210
precision        0.013144
recall           0.003712
density          0.106844
coverage         0.009352
rep              0.577350
dtype: float64

## Geodesic VAE

In [64]:
leaveout_ckpt_geodesic_data.mean(0)

1-Wasserstein    24.944361
2-Wasserstein    47.553928
Linear_MMD        0.087869
Poly_MMD          0.259042
RBF_MMD           0.129626
Mean_MSE          0.094511
Mean_L2           0.286499
Mean_L1           0.174464
precision         0.130843
recall            0.587871
density           0.019796
coverage          0.026560
rep               2.000000
time              3.500000
dtype: float64

In [65]:
(leaveout_ckpt_geodesic_data.groupby("time").std()/np.sqrt(3)).mean(0)

1-Wasserstein    0.155584
2-Wasserstein    4.283589
Linear_MMD       0.001194
Poly_MMD         0.002123
RBF_MMD          0.000590
Mean_MSE         0.000704
Mean_L2          0.001026
Mean_L1          0.000433
precision        0.013981
recall           0.012860
density          0.002844
coverage         0.002006
rep              0.577350
dtype: float64

### Previous VAE

In [66]:
leaveout_ckpt_previous_data.mean(0)

1-Wasserstein    25.040489
2-Wasserstein    25.630669
Linear_MMD        0.048910
Poly_MMD          0.198033
RBF_MMD           0.101330
Mean_MSE          0.061222
Mean_L2           0.225436
Mean_L1           0.088712
precision         0.373793
recall            0.364364
density           0.250447
coverage          0.177981
rep               2.000000
time              3.500000
dtype: float64

In [67]:
(leaveout_ckpt_previous_data.groupby("time").std()/np.sqrt(3)).mean(0)

1-Wasserstein    0.00000
2-Wasserstein    0.00000
Linear_MMD       0.00000
Poly_MMD         0.00000
RBF_MMD          0.00000
Mean_MSE         0.00000
Mean_L2          0.00000
Mean_L1          0.00000
precision        0.00000
recall           0.00000
density          0.00000
coverage         0.00000
rep              0.57735
dtype: float64