In [1]:
from paths import DATA_DIR, CKPT_FOLDER, PROJECT_FOLDER

import matplotlib.pyplot as plt
import os
import numpy as np
import torch
import anndata
import scanpy as sc
import sklearn
import scvelo as scv
from pathlib import Path
import pertpy as pt

import anndata
import pandas as pd

from IPython.display import display
from torchdyn.core import NeuralODE

from scCFM.datamodules.time_sc_datamodule import TrajectoryDataModule
from scCFM.models.cfm.components.mlp import MLP
from scCFM.models.cfm.cfm_module import CFMLitModule

from scCFM.models.base.vae import VAE
from scCFM.models.base.geometric_vae import GeometricNBVAE
from scCFM.models.base.geodesic_ae import GeodesicAE

from scCFM.datamodules.sc_datamodule import scDataModule
from scCFM.models.cfm.components.eval.distribution_distances import compute_distribution_distances

from notebooks.utils import decode_trajectory_single_step, standardize, compute_prdc

Initialize the device

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

def cross_standardize(tensor1, tensor2, minmax_scale=False):
    """
    Standardize tensor across the rows
    """
    mean_t1, std_t1 = tensor1.mean(0), tensor1.std(0)
    if minmax_scale:
        min_t1, max_t1 = tensor1.min(0), tensor1.max(0)
        tensor1 = (tensor1 - min_t1) / (max_t1 - min_t1 + 1e-6)
        tensor2 = (tensor2 - min_t1) / (max_t1 - min_t1 + 1e-6)
    else:
        tensor1 = (tensor1 - mean_t1) / (std_t1 + 1e-6)
        tensor2 = (tensor2 - mean_t1) / (std_t1 + 1e-6)
    return tensor1, tensor2

def update_dict(ref, tgt):
    """
    Update a dictionary with the values of another 
    """
    for key in tgt:
        if key not in ref:
            ref[key] = []
        ref[key].append(tgt[key])
    return ref

# def distance_metrics_pertpy(ref, trg):
#     """
#     Distance metrics using pertpy
#     """
#     adata_ref, adata_trg = ref.copy(), trg.copy()
#     adata_together = sc.AnnData(X=np.concatenate([adata_ref.X,
#                                                   adata_trg.X], axis=0))
#     annot = ["reference" for _ in range(len(adata_ref))] + ["target" for _ in range(len(adata_trg))]
#     annot_df = pd.DataFrame(annot)
#     annot_df.columns = ["dataset_type"]
#     adata_together.obs = annot_df
#     adata_together.obsm["X_data"] = adata_together.X.copy()
#     # Wasserstein
#     w_distance = pt.tl.Distance("wasserstein", obsm_key="X_data")
#     # MMD
#     mmd_distance = pt.tl.Distance("mmd", obsm_key="X_data")
#     df_w = w_distance.pairwise(adata_together, groupby="dataset_type").to_numpy()
#     df_mmd = mmd_distance.pairwise(adata_together, groupby="dataset_type", verbose=False).to_numpy()
#     return {'wassersetein':df_w.max(), 
#             'mmd':df_mmd.max()}

def distance_metrics_pertpy(ref, trg):
    """
    Distance metrics using pertpy
    """
    d =  compute_distribution_distances(torch.tensor(ref.X[:,None,:]), 
                                             torch.tensor(trg.X[:,None,:]))
    return dict(zip(d[0], d[1]))

Initialize datamodule

In [3]:
datamodule={'path': PROJECT_FOLDER / 'data/eb/processed/eb_phate.h5ad', 
            'x_layer': 'X_norm', 
            'cond_keys': ['experimental_time', 'leiden'],
            'use_pca': False, 
            'n_dimensions': None, 
            'train_val_test_split': [1], 
            'batch_size': 512, 
            'num_workers': 2}

# Initialize datamodule
datamodule = scDataModule(**datamodule)

Initialize autoencoders

In [4]:
vae_kwargs={'in_dim': datamodule.in_dim,
               'n_epochs_anneal_kl': 1000, 
               'kl_weight': None, 
               'likelihood': 'nb', 
               'dropout': False, 
               'learning_rate': 0.001, 
               'dropout_p': False, 
               'model_library_size': True, 
               'batch_norm': True, 
               'kl_warmup_fraction': 0.1, 
               'hidden_dims': [256, 10]}
        
geometric_kwargs={'compute_metrics_every': 1, 
                   'use_c': True, 
                   'trainable_c': False,
                   'l2': True, 
                   'eta_interp': 0, 
                   'interpolate_z': False, 
                   'start_jac_after': 0, 
                   'fl_weight': 0.1,
                   'detach_theta': True}

geodesic_kwargs={"in_dim": datamodule.in_dim,
                  "hidden_dims": [256, 10],
                  "batch_norm": True,
                  "dropout": False, 
                  "dropout_p": False,
                  "likelihood": "nb",
                  "learning_rate": 0.001}

# Initialize vae and geometric vae
vae = VAE(**vae_kwargs).to(device)
geometric_vae = GeometricNBVAE(**geometric_kwargs, vae_kwargs=vae_kwargs).to(device)
geodesic_ae = GeodesicAE(**geodesic_kwargs).to(device)

# Load state dicts and put in eval mode 
vae.load_state_dict(torch.load(PROJECT_FOLDER / "checkpoints/ae/eb/best_model_vae_lib.ckpt")["state_dict"])
geometric_vae.load_state_dict(torch.load(PROJECT_FOLDER / "checkpoints/ae/eb/best_model_geometric_lib.ckpt")["state_dict"])
geodesic_ae.load_state_dict(torch.load(PROJECT_FOLDER / "checkpoints/ae/eb/best_model_geodesic_ae.ckpt")["state_dict"])

<All keys matched successfully>

## Setup CFMs

In [5]:
leavout_timepoints_folder = Path("/home/icb/alessandro.palma/environment/scCFM/project_dir/checkpoints/trajectory/eb")

Initialize datamodule for trajectory

In [6]:
datamodule_kwargs_vae = {'path': PROJECT_FOLDER / 'data/eb/flat/eb_lib.h5ad',
                          'x_layer': 'X_latents',
                          'time_key': 'experimental_time', 
                          'use_pca': False, 
                          'n_dimensions': None, 
                          'train_val_test_split': [0.9, 0.1], 
                          'num_workers': 2, 
                          'batch_size': 512, 
                          'model_library_size': True}

datamodule_kwargs_flat = {'path': PROJECT_FOLDER / 'data/eb/flat/eb_flat_lib.h5ad',
                          'x_layer': 'X_latents',
                          'time_key': 'experimental_time', 
                          'use_pca': False, 
                          'n_dimensions': None, 
                          'train_val_test_split': [0.9, 0.1], 
                          'num_workers': 2, 
                          'batch_size': 512, 
                           'model_library_size': True}

datamodule_kwargs_geodesic= {'path': PROJECT_FOLDER / 'data/eb/flat/eb_geodesic.h5ad',
                          'x_layer': 'X_latents',
                          'time_key': 'experimental_time', 
                          'use_pca': False, 
                          'n_dimensions': None, 
                          'train_val_test_split': [0.9, 0.1], 
                          'num_workers': 2, 
                          'batch_size': 512, 
                           'model_library_size': True}

# Initialize the datamodules 
datamodule_vae = TrajectoryDataModule(**datamodule_kwargs_vae)
datamodule_flat = TrajectoryDataModule(**datamodule_kwargs_flat)
datamodule_geodesic = TrajectoryDataModule(**datamodule_kwargs_geodesic)

## Read data

First, read the latent space anndata and plot the results

In [7]:
# Read latent anndata
adata_latent_vae = sc.read_h5ad(DATA_DIR / "eb" / "flat" / "eb_lib.h5ad")
adata_latent_flat = sc.read_h5ad(DATA_DIR / "eb" / "flat" / "eb_flat_lib.h5ad")
adata_latent_geodesic = sc.read_h5ad(DATA_DIR / "eb" / "flat" / "eb_geodesic.h5ad")

# Read real anndata
adata_eb_original = sc.read_h5ad(PROJECT_FOLDER / 'data/eb/processed/eb_phate.h5ad')
sc.tl.pca(adata_eb_original, n_comps=50)
adata_eb_original.X = adata_eb_original.layers["X_norm"].A.copy()

Number of experiments 

In [8]:
n_timepoints = len(np.unique(adata_latent_vae.obs.experimental_time))
idx2time = dict(zip(range(n_timepoints), np.unique(adata_latent_vae.obs.experimental_time)))
idx2time

{0: 0.0, 1: 0.25, 2: 0.5, 3: 0.75, 4: 1.0}

Initialize model

In [9]:
net_hparams = {"dim": adata_latent_flat.X.shape[1]+1,
                "w": 64,
                "time_varying": True}

cfm_kwargs = {'ot_sampler': 'exact', 
                   'sigma': 0.1, 
                   'use_real_time': False, 
                   'lr': 0.001, 
                   'antithetic_time_sampling': True}

## Evaluation

Load checkpoints

In [10]:
# LATENT SPACE METRICS
leaveout_ckpt_vae_latent = {}
leaveout_ckpt_flat_latent = {}
leaveout_ckpt_geodesic_latent = {}
leaveout_ckpt_previous_latent = {}

# DATA SPACE METRICS
leaveout_ckpt_vae_data = {}
leaveout_ckpt_flat_data = {}
leaveout_ckpt_geodesic_data = {}
leaveout_ckpt_previous_data = {}

for rep in range(1, 4):
    for tp in range(1, n_timepoints-1):
        print(f"Time point {tp}")
        #Pick time 0 observations
        X_adata_t0_latent_vae = torch.from_numpy(adata_latent_vae[adata_latent_vae.obs["experimental_time"]==idx2time[(tp-1)]].X).to(device)
        X_adata_t0_latent_flat = torch.from_numpy(adata_latent_flat[adata_latent_flat.obs["experimental_time"]==idx2time[(tp-1)]].X).to(device)
        X_adata_t0_latent_geodesic = torch.from_numpy(adata_latent_geodesic[adata_latent_geodesic.obs["experimental_time"]==idx2time[(tp-1)]].X).to(device)
    
        # Pick observations next timepoint 
        X_adata_t1_latent_vae = torch.from_numpy(adata_latent_vae[adata_latent_vae.obs["experimental_time"]==idx2time[tp]].X).to(device)
        X_adata_t1_latent_flat = torch.from_numpy(adata_latent_flat[adata_latent_flat.obs["experimental_time"]==idx2time[tp]].X).to(device)
        X_adata_t1_latent_geodesic = torch.from_numpy(adata_latent_geodesic[adata_latent_geodesic.obs["experimental_time"]==idx2time[tp]].X).to(device)    
    
        # Collect PCs  
        adata_real = adata_eb_original[adata_eb_original.obs["experimental_time"]==idx2time[tp]]
        X_adata_real_pca = torch.from_numpy(adata_real.obsm["X_pca"]).to(device)
        X_adata_real = torch.from_numpy(adata_real.layers["X_log"].A).to(device)
        
        #Pick library sizes
        l_t0_vae = adata_latent_vae.obs.loc[adata_latent_vae.obs["experimental_time"]==idx2time[(tp-1)], "log_library_size"].to_numpy()
        l_t0_flat = adata_latent_flat.obs.loc[adata_latent_flat.obs["experimental_time"]==idx2time[(tp-1)], "log_library_size"].to_numpy()
        l_t0_geodesic = adata_latent_geodesic.obs.loc[adata_latent_geodesic.obs["experimental_time"]==idx2time[(tp-1)], "log_library_size"].to_numpy()
        l_t0_vae = torch.from_numpy(l_t0_vae).to(device)
        l_t0_flat = torch.from_numpy(l_t0_flat).to(device)
        l_t0_geodesic = torch.from_numpy(l_t0_geodesic).to(device)
    
        # Initialize nets
        net_vae = MLP(**net_hparams).to(device)
        net_flat = MLP(**net_hparams).to(device)
        net_geodesic = MLP(**net_hparams).to(device)
        cfm_vae = CFMLitModule(net=net_vae, datamodule=datamodule_vae, **cfm_kwargs).to(device)
        cfm_flat = CFMLitModule(net=net_flat, datamodule=datamodule_flat, **cfm_kwargs).to(device)
        cfm_geodesic = CFMLitModule(net=net_geodesic, datamodule=datamodule_geodesic, **cfm_kwargs).to(device)
    
        # Read the checkpoints
        cfm_vae.load_state_dict(torch.load(leavout_timepoints_folder / f"eb_vae_leaveout_{tp}_{rep}.ckpt")["state_dict"])
        cfm_flat.load_state_dict(torch.load(leavout_timepoints_folder / f"eb_flat_leaveout_{tp}_{rep}.ckpt")["state_dict"])
        cfm_geodesic.load_state_dict(torch.load(leavout_timepoints_folder / f"eb_geodesic_leaveout_{tp}_{rep}.ckpt")["state_dict"])
    
        mu_adata_predicted_vae, X_adata_predicted_vae, X_adata_latent_vae = decode_trajectory_single_step(X_adata_t0_latent_vae, 
                                                                                     l_t0_vae, 
                                                                                     tp-1, 
                                                                                     cfm_vae, 
                                                                                     vae)
                                                                                    
        mu_adata_predicted_flat, X_adata_predicted_flat, X_adata_latent_flat = decode_trajectory_single_step(X_adata_t0_latent_flat, 
                                                                                       l_t0_flat, 
                                                                                       tp-1, 
                                                                                       cfm_flat, 
                                                                                       geometric_vae)
                                                                                      
        mu_adata_predicted_geodesic, X_adata_predicted_geodesic, X_adata_latent_geodesic = decode_trajectory_single_step(X_adata_t0_latent_geodesic, 
                                                                                               l_t0_geodesic, 
                                                                                               tp-1, 
                                                                                               cfm_geodesic, 
                                                                                               geodesic_ae, 
                                                                                               model_type="geodesic_ae")
    
        ## PREDICT LATENT TRAJECTORIES 
        print("predict latent trajectory")
        X_adata_t1_latent_vae, X_adata_latent_vae = cross_standardize(X_adata_t1_latent_vae, X_adata_latent_vae[:,:-1])
        X_adata_t1_latent_flat, X_adata_latent_flat = cross_standardize(X_adata_t1_latent_flat, X_adata_latent_flat[:,:-1])
        X_adata_t1_latent_geodesic, X_adata_latent_geodesic = cross_standardize(X_adata_t1_latent_geodesic, X_adata_latent_geodesic[:,:-1])
                                                                                       
        d_dist_vae_l = compute_distribution_distances(X_adata_t1_latent_vae.unsqueeze(1).to("cpu"), 
                                             X_adata_latent_vae.unsqueeze(1).to("cpu"))
        d_dist_flat_l = compute_distribution_distances(X_adata_t1_latent_flat.unsqueeze(1).to("cpu"),
                                             X_adata_latent_flat.unsqueeze(1).to("cpu"))
        d_dist_geod_l = compute_distribution_distances(X_adata_t1_latent_geodesic.unsqueeze(1).to("cpu"),
                                             X_adata_latent_geodesic.unsqueeze(1).to("cpu"))
        d_dist_prev = compute_distribution_distances(X_adata_t1_latent_vae.unsqueeze(1).to("cpu"),
                                             X_adata_t0_latent_vae.unsqueeze(1).to("cpu"))
        
        d_dist_vae_l = dict(zip(d_dist_vae_l[0], d_dist_vae_l[1]))
        d_dist_flat_l = dict(zip(d_dist_flat_l[0], d_dist_flat_l[1]))
        d_dist_geod_l = dict(zip(d_dist_geod_l[0], d_dist_geod_l[1]))
        d_dist_prev_l = dict(zip(d_dist_prev[0], d_dist_prev[1]))
        
        leaveout_ckpt_vae_latent = update_dict(leaveout_ckpt_vae_latent, d_dist_vae_l)
        leaveout_ckpt_flat_latent = update_dict(leaveout_ckpt_flat_latent, d_dist_flat_l)
        leaveout_ckpt_geodesic_latent = update_dict(leaveout_ckpt_geodesic_latent, d_dist_geod_l)
        leaveout_ckpt_previous_latent = update_dict(leaveout_ckpt_previous_latent, d_dist_prev_l)
    
        print("predict decoded trajectory")
        X_adata_predicted_vae = anndata.AnnData(X=X_adata_predicted_vae.numpy())
        X_adata_predicted_flat = anndata.AnnData(X=X_adata_predicted_flat.numpy())
        X_adata_predicted_geodesic = anndata.AnnData(X=X_adata_predicted_geodesic.numpy())
        mu_adata_predicted_vae = anndata.AnnData(X=mu_adata_predicted_vae.numpy())
        mu_adata_predicted_flat = anndata.AnnData(X=mu_adata_predicted_flat.numpy())
        mu_adata_predicted_geodesic = anndata.AnnData(X=mu_adata_predicted_geodesic.numpy())
        
        X_adata_prev = adata_eb_original[adata_eb_original.obs["experimental_time"]==idx2time[tp-1]].copy()
        # rnd_idx = np.random.choice(range(len(adata_eb_original)), size=len(X_adata_predicted_vae), replace=False)
        # X_adata_prev = adata_eb_original[rnd_idx]

        X_adata_predicted_vae.layers["X_norm"] = X_adata_predicted_vae.X.copy()
        X_adata_predicted_flat.layers["X_norm"] = X_adata_predicted_flat.X.copy()
        X_adata_predicted_geodesic.layers["X_norm"] = np.exp(X_adata_predicted_geodesic.X.copy())-1
        X_adata_prev.layers["X_norm"] = X_adata_prev.X.copy()
        sc.pp.log1p(X_adata_predicted_vae)
        sc.pp.log1p(X_adata_predicted_flat)
        sc.pp.log1p(X_adata_prev)
        
        sc.tl.pca(X_adata_predicted_vae, n_comps=50)
        sc.tl.pca(X_adata_predicted_flat, n_comps=50)
        sc.tl.pca(X_adata_predicted_geodesic, n_comps=50)
        sc.tl.pca(X_adata_prev, n_comps=50)
        X_adata_predicted_vae.X = X_adata_predicted_vae.layers["X_norm"].copy()
        X_adata_predicted_flat.X = X_adata_predicted_flat.layers["X_norm"].copy()
        X_adata_predicted_geodesic.X = X_adata_predicted_geodesic.layers["X_norm"].copy()
        X_adata_prev.X = X_adata_prev.layers["X_norm"].copy()
        
        # Compute den
        d_dist_vae_d = compute_prdc(torch.from_numpy(X_adata_predicted_vae.obsm["X_pca"]), 
                                                 X_adata_real_pca.to("cpu"), nearest_k=10)
        d_dist_flat_d = compute_prdc(torch.from_numpy(X_adata_predicted_flat.obsm["X_pca"]), 
                                                 X_adata_real_pca.to("cpu"), nearest_k=10)
        d_dist_geodesic_d = compute_prdc(torch.from_numpy(X_adata_predicted_geodesic.obsm["X_pca"]), 
                                                 X_adata_real_pca.to("cpu"), nearest_k=10)
        d_dist_prev_d = compute_prdc(torch.from_numpy(X_adata_prev.obsm["X_pca"]), 
                                                 X_adata_real_pca.to("cpu"), nearest_k=10)

        # d_dist_w_mmd_vae = distance_metrics_pertpy(adata_real, X_adata_predicted_vae)
        # d_dist_w_mmd_flat =distance_metrics_pertpy(adata_real, X_adata_predicted_flat)
        # d_dist_w_mmd_geodesic =distance_metrics_pertpy(adata_real, X_adata_predicted_geodesic)
        # d_dist_w_mmd_prev =distance_metrics_pertpy(adata_real, X_adata_prev)
        d_dist_w_mmd_vae = distance_metrics_pertpy(adata_real, mu_adata_predicted_vae)
        d_dist_w_mmd_flat =distance_metrics_pertpy(adata_real, mu_adata_predicted_flat)
        d_dist_w_mmd_geodesic =distance_metrics_pertpy(adata_real, X_adata_predicted_geodesic)
        d_dist_w_mmd_prev =distance_metrics_pertpy(adata_real, X_adata_prev)
        
        leaveout_ckpt_vae_data = update_dict(leaveout_ckpt_vae_data, d_dist_w_mmd_vae)
        leaveout_ckpt_flat_data = update_dict(leaveout_ckpt_flat_data, d_dist_w_mmd_flat)
        leaveout_ckpt_geodesic_data = update_dict(leaveout_ckpt_geodesic_data, d_dist_w_mmd_geodesic)
        leaveout_ckpt_previous_data = update_dict(leaveout_ckpt_previous_data, d_dist_w_mmd_prev)
        leaveout_ckpt_vae_data = update_dict(leaveout_ckpt_vae_data, d_dist_vae_d)
        leaveout_ckpt_flat_data = update_dict(leaveout_ckpt_flat_data, d_dist_flat_d)
        leaveout_ckpt_geodesic_data = update_dict(leaveout_ckpt_geodesic_data, d_dist_geodesic_d)
        leaveout_ckpt_previous_data = update_dict(leaveout_ckpt_previous_data, d_dist_prev_d)

        leaveout_ckpt_vae_latent = update_dict(leaveout_ckpt_vae_latent, {"rep": rep})
        leaveout_ckpt_flat_latent = update_dict(leaveout_ckpt_flat_latent, {"rep": rep})
        leaveout_ckpt_geodesic_latent = update_dict(leaveout_ckpt_geodesic_latent, {"rep": rep})
        leaveout_ckpt_previous_latent = update_dict(leaveout_ckpt_previous_latent, {"rep": rep})
        leaveout_ckpt_vae_data = update_dict(leaveout_ckpt_vae_data, {"rep": rep})
        leaveout_ckpt_flat_data = update_dict(leaveout_ckpt_flat_data, {"rep": rep})
        leaveout_ckpt_geodesic_data = update_dict(leaveout_ckpt_geodesic_data, {"rep": rep})
        leaveout_ckpt_previous_data = update_dict(leaveout_ckpt_previous_data, {"rep": rep})

Time point 1
predict latent trajectory
predict decoded trajectory
Time point 2
predict latent trajectory
predict decoded trajectory
Time point 3
predict latent trajectory
predict decoded trajectory
Time point 1
predict latent trajectory
predict decoded trajectory
Time point 2
predict latent trajectory
predict decoded trajectory
Time point 3
predict latent trajectory
predict decoded trajectory
Time point 1
predict latent trajectory
predict decoded trajectory
Time point 2
predict latent trajectory
predict decoded trajectory
Time point 3
predict latent trajectory
predict decoded trajectory


In [11]:
tp = [1,2,3]*3

In [12]:
# LATENT SPACE METRICS
leaveout_ckpt_vae_latent = pd.DataFrame(leaveout_ckpt_vae_latent)
leaveout_ckpt_flat_latent = pd.DataFrame(leaveout_ckpt_flat_latent)
leaveout_ckpt_geodesic_latent = pd.DataFrame(leaveout_ckpt_geodesic_latent)
leaveout_ckpt_previous_latent = pd.DataFrame(leaveout_ckpt_previous_latent)

# DATA SPACE METRICS
leaveout_ckpt_vae_data = pd.DataFrame(leaveout_ckpt_vae_data)
leaveout_ckpt_flat_data = pd.DataFrame(leaveout_ckpt_flat_data)
leaveout_ckpt_geodesic_data = pd.DataFrame(leaveout_ckpt_geodesic_data)
leaveout_ckpt_previous_data = pd.DataFrame(leaveout_ckpt_previous_data)

In [13]:
# LATENT SPACE METRICS
leaveout_ckpt_vae_latent["time"] = tp
leaveout_ckpt_flat_latent["time"] = tp 
leaveout_ckpt_geodesic_latent["time"] = tp 
leaveout_ckpt_previous_latent["time"] = tp 

# DATA SPACE METRICS
leaveout_ckpt_vae_data["time"] = tp 
leaveout_ckpt_flat_data["time"] = tp 
leaveout_ckpt_geodesic_data["time"] = tp 
leaveout_ckpt_previous_data["time"] = tp 

## Latent space

**VAE**

In [14]:
leaveout_ckpt_vae_latent.mean(0)

1-Wasserstein    1.964602
2-Wasserstein    2.067490
Linear_MMD       0.087632
Poly_MMD         0.292767
RBF_MMD          0.233982
Mean_MSE         0.091551
Mean_L2          0.296550
Mean_L1          0.240507
rep              2.000000
time             2.000000
dtype: float64

In [15]:
leaveout_ckpt_vae_latent.std(0)/np.sqrt(9)

1-Wasserstein    0.068766
2-Wasserstein    0.065620
Linear_MMD       0.009279
Poly_MMD         0.015491
RBF_MMD          0.008062
Mean_MSE         0.013298
Mean_L2          0.021239
Mean_L1          0.011798
rep              0.288675
time             0.288675
dtype: float64

**Flat VAE**

In [16]:
leaveout_ckpt_flat_latent.mean(0)

1-Wasserstein    1.413430
2-Wasserstein    1.543451
Linear_MMD       0.076828
Poly_MMD         0.266540
RBF_MMD          0.223911
Mean_MSE         0.080226
Mean_L2          0.270242
Mean_L1          0.227224
rep              2.000000
time             2.000000
dtype: float64

In [17]:
leaveout_ckpt_flat_latent.std(0)/np.sqrt(9)

1-Wasserstein    0.084468
2-Wasserstein    0.085379
Linear_MMD       0.013930
Poly_MMD         0.026889
RBF_MMD          0.024301
Mean_MSE         0.015449
Mean_L2          0.029989
Mean_L1          0.028668
rep              0.288675
time             0.288675
dtype: float64

**Geodesic VAE**

In [18]:
leaveout_ckpt_geodesic_latent.mean(0)

1-Wasserstein    2.062476
2-Wasserstein    2.156820
Linear_MMD       0.148777
Poly_MMD         0.359872
RBF_MMD          0.309058
Mean_MSE         0.183059
Mean_L2          0.397406
Mean_L1          0.340322
rep              2.000000
time             2.000000
dtype: float64

In [19]:
leaveout_ckpt_geodesic_latent.std(0)/np.sqrt(9)

1-Wasserstein    0.132530
2-Wasserstein    0.138736
Linear_MMD       0.040133
Poly_MMD         0.049077
RBF_MMD          0.042275
Mean_MSE         0.050251
Mean_L2          0.056044
Mean_L1          0.048060
rep              0.288675
time             0.288675
dtype: float64

**Baseline**

In [20]:
leaveout_ckpt_previous_latent.mean(0)

1-Wasserstein    2.570373
2-Wasserstein    2.694164
Linear_MMD       0.281414
Poly_MMD         0.487632
RBF_MMD          0.388111
Mean_MSE         0.292014
Mean_L2          0.496841
Mean_L1          0.405017
rep              2.000000
time             2.000000
dtype: float64

In [21]:
leaveout_ckpt_previous_latent.std(0)/np.sqrt(9)

1-Wasserstein    0.195134
2-Wasserstein    0.199802
Linear_MMD       0.079529
Poly_MMD         0.073849
RBF_MMD          0.064728
Mean_MSE         0.080947
Mean_L2          0.075136
Mean_L1          0.067284
rep              0.288675
time             0.288675
dtype: float64

## Data space

### VAE

In [22]:
leaveout_ckpt_vae_data.mean(0)

1-Wasserstein    39.926498
2-Wasserstein    43.355890
Linear_MMD        0.084729
Poly_MMD          0.289782
RBF_MMD           0.093701
Mean_MSE          0.093223
Mean_L2           0.304837
Mean_L1           0.121863
precision         0.452893
recall            0.225640
density           0.204526
coverage          0.367233
rep               2.000000
time              2.000000
dtype: float64

In [23]:
(leaveout_ckpt_vae_data.groupby("time").std()/np.sqrt(3)).mean(0)

1-Wasserstein    0.076734
2-Wasserstein    0.190334
Linear_MMD       0.002389
Poly_MMD         0.004118
RBF_MMD          0.000393
Mean_MSE         0.002505
Mean_L2          0.004020
Mean_L1          0.001422
precision        0.010406
recall           0.007099
density          0.015175
coverage         0.014892
rep              0.577350
dtype: float64

### Flat VAE

In [24]:
leaveout_ckpt_flat_data.mean(0)

1-Wasserstein    39.051561
2-Wasserstein    41.995502
Linear_MMD        0.081517
Poly_MMD          0.281975
RBF_MMD           0.091962
Mean_MSE          0.101312
Mean_L2           0.317911
Mean_L1           0.155587
precision         0.538112
recall            0.320581
density           0.302635
coverage          0.491362
rep               2.000000
time              2.000000
dtype: float64

In [25]:
(leaveout_ckpt_flat_data.groupby("time").std()/np.sqrt(3)).mean(0)

1-Wasserstein    0.033349
2-Wasserstein    0.042345
Linear_MMD       0.000749
Poly_MMD         0.001277
RBF_MMD          0.000170
Mean_MSE         0.000545
Mean_L2          0.000839
Mean_L1          0.000468
precision        0.010167
recall           0.008574
density          0.011467
coverage         0.006163
rep              0.577350
dtype: float64

### Geodesic VAE

In [26]:
leaveout_ckpt_geodesic_data.mean(0)

1-Wasserstein     88.590220
2-Wasserstein    193.089633
Linear_MMD         0.972043
Poly_MMD           0.975561
RBF_MMD            0.175337
Mean_MSE           0.324564
Mean_L2            0.568741
Mean_L1            0.145377
precision          0.065568
recall             0.334207
density            0.008051
coverage           0.016565
rep                2.000000
time               2.000000
dtype: float64

In [27]:
(leaveout_ckpt_geodesic_data.groupby("time").std()/np.sqrt(3)).mean(0)

1-Wasserstein    1.066383
2-Wasserstein    9.547412
Linear_MMD       0.025047
Poly_MMD         0.012023
RBF_MMD          0.000915
Mean_MSE         0.003162
Mean_L2          0.002775
Mean_L1          0.000484
precision        0.001966
recall           0.012303
density          0.000386
coverage         0.000950
rep              0.577350
dtype: float64

### Baseline

In [28]:
leaveout_ckpt_previous_data.mean(0)

1-Wasserstein    45.482058
2-Wasserstein    47.811630
Linear_MMD        0.201387
Poly_MMD          0.437026
RBF_MMD           0.127186
Mean_MSE          0.185233
Mean_L2           0.413694
Mean_L1           0.077155
precision         0.189846
recall            0.408392
density           0.038347
coverage          0.124779
rep               2.000000
time              2.000000
dtype: float64

In [29]:
(leaveout_ckpt_previous_data.groupby("time").std()/np.sqrt(3)).mean(0)

1-Wasserstein    0.00000
2-Wasserstein    0.00000
Linear_MMD       0.00000
Poly_MMD         0.00000
RBF_MMD          0.00000
Mean_MSE         0.00000
Mean_L2          0.00000
Mean_L1          0.00000
precision        0.00000
recall           0.00000
density          0.00000
coverage         0.00000
rep              0.57735
dtype: float64