In [2]:
from paths import DATA_DIR, CKPT_FOLDER, PROJECT_FOLDER

import matplotlib.pyplot as plt
import os
import numpy as np
import torch
import anndata
import scanpy as sc
import sklearn
import scvelo as scv
from pathlib import Path
import pertpy as pt

import anndata
import pandas as pd

from IPython.display import display
from torchdyn.core import NeuralODE

from scCFM.datamodules.time_sc_datamodule import TrajectoryDataModule
from scCFM.models.cfm.components.mlp import MLP
from scCFM.models.cfm.cfm_module import CFMLitModule

from scCFM.models.base.vae import VAE
from scCFM.models.base.geometric_vae import GeometricNBVAE
from scCFM.models.base.geodesic_ae import GeodesicAE

from scCFM.datamodules.sc_datamodule import scDataModule
from scCFM.models.cfm.components.eval.distribution_distances import compute_distribution_distances

from notebooks.utils import decode_trajectory_single_step, standardize, compute_prdc

Initialize the device

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

def cross_standardize(tensor1, tensor2, minmax_scale=False):
    """
    Standardize tensor across the rows
    """
    mean_t1, std_t1 = tensor1.mean(0), tensor1.std(0)
    if minmax_scale:
        min_t1, max_t1 = tensor1.min(0), tensor1.max(0)
        tensor1 = (tensor1 - min_t1) / (max_t1 - min_t1 + 1e-6)
        tensor2 = (tensor2 - min_t1) / (max_t1 - min_t1 + 1e-6)
    else:
        tensor1 = (tensor1 - mean_t1) / (std_t1 + 1e-6)
        tensor2 = (tensor2 - mean_t1) / (std_t1 + 1e-6)
    return tensor1, tensor2

def update_dict(ref, tgt):
    """
    Update a dictionary with the values of another 
    """
    for key in tgt:
        if key not in ref:
            ref[key] = []
        ref[key].append(tgt[key])
    return ref

# def distance_metrics_pertpy(ref, trg):
#     """
#     Distance metrics using pertpy
#     """
#     adata_ref, adata_trg = ref.copy(), trg.copy()
#     adata_together = sc.AnnData(X=np.concatenate([adata_ref.X/10,
#                                                   adata_trg.X/10], axis=0))
#     annot = ["reference" for _ in range(len(adata_ref))] + ["target" for _ in range(len(adata_trg))]
#     annot_df = pd.DataFrame(annot)
#     annot_df.columns = ["dataset_type"]
#     adata_together.obs = annot_df
#     adata_together.obsm["X_data"] = adata_together.X.copy()
#     # Wasserstein
#     w_distance = pt.tl.Distance("wasserstein", obsm_key="X_data")
#     # MMD
#     mmd_distance = pt.tl.Distance("mmd", obsm_key="X_data")
#     df_w = w_distance.pairwise(adata_together, groupby="dataset_type").to_numpy()
#     df_mmd = mmd_distance.pairwise(adata_together, groupby="dataset_type", verbose=False).to_numpy()
#     return {'wassersetein':df_w.max(), 
#             'mmd':df_mmd.max()}

def distance_metrics_pertpy(ref, trg):
    """
    Distance metrics using pertpy
    """
    d =  compute_distribution_distances(torch.tensor(ref.X[:,None,:]), 
                                             torch.tensor(trg.X[:,None,:]))
    return dict(zip(d[0], d[1]))

Initialize datamodule to set vaes up

In [4]:
datamodule={'path': PROJECT_FOLDER / 'data/schiebinger_et_al/processed/schiebinger_et_al.h5ad', 
            'x_layer': 'X_norm', 
            'cond_keys': ['experimental_time', 'cell_sets'],
            'use_pca': False, 
            'n_dimensions': None, 
            'train_val_test_split': [1], 
            'batch_size': 64, 
            'num_workers': 2}

# Initialize datamodule
datamodule = scDataModule(**datamodule)

Initialize autoencoders

In [6]:
vae_kwargs={'in_dim': datamodule.in_dim,
       'n_epochs_anneal_kl': 1000, 
       'kl_weight': None, 
       'likelihood': 'nb', 
       'dropout': False, 
       'learning_rate': 0.001, 
       'dropout_p': False, 
       'model_library_size': True, 
       'batch_norm': True, 
       'kl_warmup_fraction': 0.1, 
       'hidden_dims': [256, 10]}
        
geometric_kwargs={'compute_metrics_every': 1, 
                   'use_c': False, 
                   'trainable_c': False,
                   'l2': True, 
                   'eta_interp': 0, 
                   'interpolate_z': False, 
                   'start_jac_after': 0, 
                   'fl_weight': 0.1,
                   'detach_theta': True}

geodesic_kwargs={"in_dim": datamodule.in_dim,
                  "hidden_dims": [256, 10],
                  "batch_norm": True,
                  "dropout": False, 
                  "dropout_p": False,
                  "likelihood": "nb",
                  "learning_rate": 0.001}

# Initialize vae and geometric vae
vae = VAE(**vae_kwargs).to(device)
geometric_vae = GeometricNBVAE(**geometric_kwargs, vae_kwargs=vae_kwargs).to(device)
geodesic_ae = GeodesicAE(**geodesic_kwargs).to(device)

# Load state dicts and put in eval mode 
vae.load_state_dict(torch.load(PROJECT_FOLDER / "checkpoints/ae/schiebinger_et_al/best_model_vae.ckpt")["state_dict"])
geometric_vae.load_state_dict(torch.load(PROJECT_FOLDER / "checkpoints/ae/schiebinger_et_al/best_model_geometric.ckpt")["state_dict"])
geodesic_ae.load_state_dict(torch.load(PROJECT_FOLDER / "checkpoints/ae/schiebinger_et_al/best_model_geodesic.ckpt")["state_dict"])

<All keys matched successfully>

## Setup CFMs

In [7]:
leavout_timepoints_folder = Path("/home/icb/alessandro.palma/environment/scCFM/project_dir/checkpoints/trajectory/schiebinger_et_al")

Initialize datamodule for trajectory

In [8]:
datamodule_kwargs_vae = {'path': PROJECT_FOLDER / 'data/schiebinger_et_al/flat/schiebinger_et_al_lib.h5ad',
                          'x_layer': 'X_latents',
                          'time_key': 'experimental_time', 
                          'use_pca': False, 
                          'n_dimensions': None, 
                          'train_val_test_split': [0.9, 0.1], 
                          'num_workers': 2, 
                          'batch_size': 512, 
                          'model_library_size': True}

datamodule_kwargs_flat = {'path': PROJECT_FOLDER / 'data/schiebinger_et_al/flat/schiebinger_et_al_flat_lib.h5ad',
                          'x_layer': 'X_latents',
                          'time_key': 'experimental_time', 
                          'use_pca': False, 
                          'n_dimensions': None, 
                          'train_val_test_split': [0.9, 0.1], 
                          'num_workers': 2, 
                          'batch_size': 512, 
                           'model_library_size': True}

datamodule_kwargs_geodesic= {'path': PROJECT_FOLDER / 'data/schiebinger_et_al/flat/schiebinger_et_al_geodesic.h5ad',
                          'x_layer': 'X_latents',
                          'time_key': 'experimental_time', 
                          'use_pca': False, 
                          'n_dimensions': None, 
                          'train_val_test_split': [0.9, 0.1], 
                          'num_workers': 2, 
                          'batch_size': 512, 
                           'model_library_size': True}

# Initialize the datamodules 
datamodule_vae = TrajectoryDataModule(**datamodule_kwargs_vae)
datamodule_flat = TrajectoryDataModule(**datamodule_kwargs_flat)
datamodule_geodesic = TrajectoryDataModule(**datamodule_kwargs_geodesic)

## Read data

First, read the latent space anndata and plot the results

In [9]:
# Read latent anndata
adata_latent_vae = sc.read_h5ad(DATA_DIR / "schiebinger_et_al" / "flat" / "schiebinger_et_al_lib.h5ad")
adata_latent_flat = sc.read_h5ad(DATA_DIR / "schiebinger_et_al" / "flat" / "schiebinger_et_al_flat_lib.h5ad")
adata_latent_geodesic = sc.read_h5ad(DATA_DIR / "schiebinger_et_al" / "flat" / "schiebinger_et_al_geodesic.h5ad")

# Read real anndata
adata_schiebinger_original = sc.read_h5ad(PROJECT_FOLDER / 'data/schiebinger_et_al/processed/schiebinger_et_al.h5ad')
sc.tl.pca(adata_schiebinger_original, n_comps=50)
adata_schiebinger_original.X = adata_schiebinger_original.layers["X_norm"].A.copy()

Number of experiments 

In [10]:
n_timepoints = len(np.unique(adata_latent_vae.obs.experimental_time))
idx2time = dict(zip(range(n_timepoints), np.unique(adata_latent_vae.obs.experimental_time)))
idx2time

{0: 0.0,
 1: 0.027777777777777776,
 2: 0.05555555555555555,
 3: 0.08333333333333333,
 4: 0.1111111111111111,
 5: 0.1388888888888889,
 6: 0.16666666666666666,
 7: 0.19444444444444445,
 8: 0.2222222222222222,
 9: 0.25,
 10: 0.2777777777777778,
 11: 0.3055555555555556,
 12: 0.3333333333333333,
 13: 0.3611111111111111,
 14: 0.3888888888888889,
 15: 0.4166666666666667,
 16: 0.4444444444444444,
 17: 0.4583333333333333,
 18: 0.4722222222222222,
 19: 0.4861111111111111,
 20: 0.5,
 21: 0.5277777777777778,
 22: 0.5555555555555556,
 23: 0.5833333333333334,
 24: 0.6111111111111112,
 25: 0.6388888888888888,
 26: 0.6666666666666666,
 27: 0.6944444444444444,
 28: 0.7222222222222222,
 29: 0.75,
 30: 0.7777777777777778,
 31: 0.8055555555555556,
 32: 0.8333333333333334,
 33: 0.8611111111111112,
 34: 0.8888888888888888,
 35: 0.9166666666666666,
 36: 0.9444444444444444,
 37: 0.9722222222222222,
 38: 1.0}

Initialize model

In [11]:
net_hparams = {"dim": adata_latent_flat.X.shape[1]+1,
                "w": 64,
                "time_varying": True}

cfm_kwargs = {'ot_sampler': 'exact', 
                   'sigma': 0.1, 
                   'use_real_time': False, 
                   'lr': 0.001, 
                   'antithetic_time_sampling': True}

## Evaluation

Load checkpoints

In [12]:
# LATENT SPACE METRICS
leaveout_ckpt_vae_latent = {}
leaveout_ckpt_flat_latent = {}
leaveout_ckpt_geodesic_latent = {}
leaveout_ckpt_previous_latent = {}

# DATA SPACE METRICS
leaveout_ckpt_vae_data = {}
leaveout_ckpt_flat_data = {}
leaveout_ckpt_geodesic_data = {}
leaveout_ckpt_previous_data = {}

for rep in range(1,4):
    for tp in [2, 5, 10, 15, 20, 25, 30]:
        print(f"Time point {tp}")
        #Pick time 0 observations
        X_adata_t0_latent_vae = torch.from_numpy(adata_latent_vae[adata_latent_vae.obs["experimental_time"]==idx2time[(tp-1)]].X).to(device)
        X_adata_t0_latent_flat = torch.from_numpy(adata_latent_flat[adata_latent_flat.obs["experimental_time"]==idx2time[(tp-1)]].X).to(device)
        X_adata_t0_latent_geodesic = torch.from_numpy(adata_latent_geodesic[adata_latent_geodesic.obs["experimental_time"]==idx2time[(tp-1)]].X).to(device)
    
        # Pick observations next timepoint 
        X_adata_t1_latent_vae = torch.from_numpy(adata_latent_vae[adata_latent_vae.obs["experimental_time"]==idx2time[tp]].X).to(device)
        X_adata_t1_latent_flat = torch.from_numpy(adata_latent_flat[adata_latent_flat.obs["experimental_time"]==idx2time[tp]].X).to(device)
        X_adata_t1_latent_geodesic = torch.from_numpy(adata_latent_geodesic[adata_latent_geodesic.obs["experimental_time"]==idx2time[tp]].X).to(device)    
    
        # Collect PCs    
        adata_real = adata_schiebinger_original[adata_schiebinger_original.obs["experimental_time"]==idx2time[tp]]
        X_adata_real_pca = torch.from_numpy(adata_real.obsm["X_pca"]).to(device)
        X_adata_real = torch.from_numpy(adata_real.layers["X_log"].A).to(device)
        
        #Pick library sizes
        l_t0_vae = adata_latent_vae.obs.loc[adata_latent_vae.obs["experimental_time"]==idx2time[(tp-1)], "log_library_size"].to_numpy()
        l_t0_flat = adata_latent_flat.obs.loc[adata_latent_flat.obs["experimental_time"]==idx2time[(tp-1)], "log_library_size"].to_numpy()
        l_t0_geodesic = adata_latent_geodesic.obs.loc[adata_latent_geodesic.obs["experimental_time"]==idx2time[(tp-1)], "log_library_size"].to_numpy()
        l_t0_vae = torch.from_numpy(l_t0_vae).to(device)
        l_t0_flat = torch.from_numpy(l_t0_flat).to(device)
        l_t0_geodesic = torch.from_numpy(l_t0_geodesic).to(device)
    
        # Initialize nets
        net_vae = MLP(**net_hparams).to(device)
        net_flat = MLP(**net_hparams).to(device)
        net_geodesic = MLP(**net_hparams).to(device)
        cfm_vae = CFMLitModule(net=net_vae, datamodule=datamodule_vae, **cfm_kwargs).to(device)
        cfm_flat = CFMLitModule(net=net_flat, datamodule=datamodule_flat, **cfm_kwargs).to(device)
        cfm_geodesic = CFMLitModule(net=net_geodesic, datamodule=datamodule_geodesic, **cfm_kwargs).to(device)
    
        # Read the checkpoints
        cfm_vae.load_state_dict(torch.load(leavout_timepoints_folder / f"schiebinger_vae_leaveout_{tp}_{rep}.ckpt")["state_dict"])
        cfm_flat.load_state_dict(torch.load(leavout_timepoints_folder / f"schiebinger_flat_leaveout_{tp}_{rep}.ckpt")["state_dict"])
        cfm_geodesic.load_state_dict(torch.load(leavout_timepoints_folder / f"schiebinger_geodesic_leaveout_{tp}_{rep}.ckpt")["state_dict"])
    
        mu_adata_predicted_vae, X_adata_predicted_vae, X_adata_latent_vae = decode_trajectory_single_step(X_adata_t0_latent_vae, 
                                                                                     l_t0_vae, 
                                                                                     tp-1, 
                                                                                     cfm_vae, 
                                                                                     vae)
                                                                                    
        mu_adata_predicted_flat, X_adata_predicted_flat, X_adata_latent_flat = decode_trajectory_single_step(X_adata_t0_latent_flat, 
                                                                                       l_t0_flat, 
                                                                                       tp-1, 
                                                                                       cfm_flat, 
                                                                                       geometric_vae)
                                                                                      
        mu_adata_predicted_geodesic, X_adata_predicted_geodesic, X_adata_latent_geodesic = decode_trajectory_single_step(X_adata_t0_latent_geodesic, 
                                                                                               l_t0_geodesic, 
                                                                                               tp-1, 
                                                                                               cfm_geodesic, 
                                                                                               geodesic_ae, 
                                                                                               model_type="geodesic_ae")
    
        ## PREDICT LATENT TRAJECTORIES 
        X_adata_t1_latent_vae, X_adata_latent_vae = cross_standardize(X_adata_t1_latent_vae, X_adata_latent_vae[:,:-1])
        X_adata_t1_latent_flat, X_adata_latent_flat = cross_standardize(X_adata_t1_latent_flat, X_adata_latent_flat[:,:-1])
        X_adata_t1_latent_geodesic, X_adata_latent_geodesic = cross_standardize(X_adata_t1_latent_geodesic, X_adata_latent_geodesic[:,:-1])
                                                                                       
        d_dist_vae_l = compute_distribution_distances(X_adata_t1_latent_vae.unsqueeze(1).to("cpu"), 
                                             X_adata_latent_vae.unsqueeze(1).to("cpu"))
        d_dist_flat_l = compute_distribution_distances(X_adata_t1_latent_flat.unsqueeze(1).to("cpu"),
                                             X_adata_latent_flat.unsqueeze(1).to("cpu"))
        d_dist_geod_l = compute_distribution_distances(X_adata_t1_latent_geodesic.unsqueeze(1).to("cpu"),
                                             X_adata_latent_geodesic.unsqueeze(1).to("cpu"))
        d_dist_prev = compute_distribution_distances(X_adata_t1_latent_vae.unsqueeze(1).to("cpu"),
                                             X_adata_t0_latent_vae.unsqueeze(1).to("cpu"))
        
        d_dist_vae_l = dict(zip(d_dist_vae_l[0], d_dist_vae_l[1]))
        d_dist_flat_l = dict(zip(d_dist_flat_l[0], d_dist_flat_l[1]))
        d_dist_geod_l = dict(zip(d_dist_geod_l[0], d_dist_geod_l[1]))
        d_dist_prev_l = dict(zip(d_dist_prev[0], d_dist_prev[1]))

        leaveout_ckpt_vae_latent = update_dict(leaveout_ckpt_vae_latent, d_dist_vae_l)
        leaveout_ckpt_flat_latent = update_dict(leaveout_ckpt_flat_latent, d_dist_flat_l)
        leaveout_ckpt_geodesic_latent = update_dict(leaveout_ckpt_geodesic_latent, d_dist_geod_l)
        leaveout_ckpt_previous_latent = update_dict(leaveout_ckpt_previous_latent, d_dist_prev_l)
        
        print("predict decoded trajectory")
        X_adata_predicted_vae = anndata.AnnData(X=X_adata_predicted_vae.numpy())
        X_adata_predicted_flat = anndata.AnnData(X=X_adata_predicted_flat.numpy())
        X_adata_predicted_geodesic = anndata.AnnData(X=X_adata_predicted_geodesic.numpy())
        mu_adata_predicted_vae = anndata.AnnData(X=mu_adata_predicted_vae.numpy())
        mu_adata_predicted_flat = anndata.AnnData(X=mu_adata_predicted_flat.numpy())
        mu_adata_predicted_geodesic = anndata.AnnData(X=mu_adata_predicted_geodesic.numpy())
        
        X_adata_prev = adata_schiebinger_original[adata_schiebinger_original.obs["experimental_time"]==idx2time[tp-1]].copy()
        # rnd_idx = np.random.choice(range(len(adata_eb_original)), size=len(X_adata_predicted_vae), replace=False)
        # X_adata_prev = adata_eb_original[rnd_idx]

        X_adata_predicted_vae.layers["X_norm"] = X_adata_predicted_vae.X.copy()
        X_adata_predicted_flat.layers["X_norm"] = X_adata_predicted_flat.X.copy()
        X_adata_predicted_geodesic.layers["X_norm"] = np.exp(X_adata_predicted_geodesic.X.copy())-1
        X_adata_prev.layers["X_norm"] = X_adata_prev.X.copy()
        sc.pp.log1p(X_adata_predicted_vae)
        sc.pp.log1p(X_adata_predicted_flat)
        sc.pp.log1p(X_adata_prev)
        
        sc.tl.pca(X_adata_predicted_vae, n_comps=50)
        sc.tl.pca(X_adata_predicted_flat, n_comps=50)
        sc.tl.pca(X_adata_predicted_geodesic, n_comps=50)
        sc.tl.pca(X_adata_prev, n_comps=50)
        X_adata_predicted_vae.X = X_adata_predicted_vae.layers["X_norm"].copy()
        X_adata_predicted_flat.X = X_adata_predicted_flat.layers["X_norm"].copy()
        X_adata_predicted_geodesic.X = X_adata_predicted_geodesic.layers["X_norm"].copy()
        X_adata_prev.X = X_adata_prev.layers["X_norm"].copy()
        
        # Compute den
        d_dist_vae_d = compute_prdc(torch.from_numpy(X_adata_predicted_vae.obsm["X_pca"]), 
                                                 X_adata_real_pca.to("cpu"), nearest_k=10)
        d_dist_flat_d = compute_prdc(torch.from_numpy(X_adata_predicted_flat.obsm["X_pca"]), 
                                                 X_adata_real_pca.to("cpu"), nearest_k=10)
        d_dist_geodesic_d = compute_prdc(torch.from_numpy(X_adata_predicted_geodesic.obsm["X_pca"]), 
                                                 X_adata_real_pca.to("cpu"), nearest_k=10)
        d_dist_prev_d = compute_prdc(torch.from_numpy(X_adata_prev.obsm["X_pca"]), 
                                                 X_adata_real_pca.to("cpu"), nearest_k=10)

        # d_dist_w_mmd_vae = distance_metrics_pertpy(adata_real, X_adata_predicted_vae)
        # d_dist_w_mmd_flat =distance_metrics_pertpy(adata_real, X_adata_predicted_flat)
        # d_dist_w_mmd_geodesic =distance_metrics_pertpy(adata_real, X_adata_predicted_geodesic)
        # d_dist_w_mmd_prev =distance_metrics_pertpy(adata_real, X_adata_prev)
        d_dist_w_mmd_vae = distance_metrics_pertpy(adata_real, mu_adata_predicted_vae)
        d_dist_w_mmd_flat =distance_metrics_pertpy(adata_real, mu_adata_predicted_flat)
        d_dist_w_mmd_geodesic =distance_metrics_pertpy(adata_real, X_adata_predicted_geodesic)
        d_dist_w_mmd_prev =distance_metrics_pertpy(adata_real, X_adata_prev)
        
        leaveout_ckpt_vae_data = update_dict(leaveout_ckpt_vae_data, d_dist_w_mmd_vae)
        leaveout_ckpt_flat_data = update_dict(leaveout_ckpt_flat_data, d_dist_w_mmd_flat)
        leaveout_ckpt_geodesic_data = update_dict(leaveout_ckpt_geodesic_data, d_dist_w_mmd_geodesic)
        leaveout_ckpt_previous_data = update_dict(leaveout_ckpt_previous_data, d_dist_w_mmd_prev)
        leaveout_ckpt_vae_data = update_dict(leaveout_ckpt_vae_data, d_dist_vae_d)
        leaveout_ckpt_flat_data = update_dict(leaveout_ckpt_flat_data, d_dist_flat_d)
        leaveout_ckpt_geodesic_data = update_dict(leaveout_ckpt_geodesic_data, d_dist_geodesic_d)
        leaveout_ckpt_previous_data = update_dict(leaveout_ckpt_previous_data, d_dist_prev_d)

        leaveout_ckpt_vae_latent = update_dict(leaveout_ckpt_vae_latent, {"rep": rep})
        leaveout_ckpt_flat_latent = update_dict(leaveout_ckpt_flat_latent, {"rep": rep})
        leaveout_ckpt_geodesic_latent = update_dict(leaveout_ckpt_geodesic_latent, {"rep": rep})
        leaveout_ckpt_previous_latent = update_dict(leaveout_ckpt_previous_latent, {"rep": rep})
        leaveout_ckpt_vae_data = update_dict(leaveout_ckpt_vae_data, {"rep": rep})
        leaveout_ckpt_flat_data = update_dict(leaveout_ckpt_flat_data, {"rep": rep})
        leaveout_ckpt_geodesic_data = update_dict(leaveout_ckpt_geodesic_data, {"rep": rep})
        leaveout_ckpt_previous_data = update_dict(leaveout_ckpt_previous_data, {"rep": rep})

Time point 2
predict decoded trajectory
Time point 5
predict decoded trajectory
Time point 10
predict decoded trajectory
Time point 15
predict decoded trajectory
Time point 20
predict decoded trajectory
Time point 25
predict decoded trajectory
Time point 30
predict decoded trajectory
Time point 2
predict decoded trajectory
Time point 5
predict decoded trajectory
Time point 10
predict decoded trajectory
Time point 15
predict decoded trajectory
Time point 20
predict decoded trajectory
Time point 25
predict decoded trajectory
Time point 30
predict decoded trajectory
Time point 2
predict decoded trajectory
Time point 5
predict decoded trajectory
Time point 10
predict decoded trajectory
Time point 15
predict decoded trajectory
Time point 20
predict decoded trajectory
Time point 25
predict decoded trajectory
Time point 30
predict decoded trajectory


In [13]:
tp = [1,2,3,4,5,6,7]*3

In [14]:
# LATENT SPACE METRICS
leaveout_ckpt_vae_latent = pd.DataFrame(leaveout_ckpt_vae_latent)
leaveout_ckpt_flat_latent = pd.DataFrame(leaveout_ckpt_flat_latent)
leaveout_ckpt_geodesic_latent = pd.DataFrame(leaveout_ckpt_geodesic_latent)
leaveout_ckpt_previous_latent = pd.DataFrame(leaveout_ckpt_previous_latent)

# DATA SPACE METRICS
leaveout_ckpt_vae_data = pd.DataFrame(leaveout_ckpt_vae_data)
leaveout_ckpt_flat_data = pd.DataFrame(leaveout_ckpt_flat_data)
leaveout_ckpt_geodesic_data = pd.DataFrame(leaveout_ckpt_geodesic_data)
leaveout_ckpt_previous_data = pd.DataFrame(leaveout_ckpt_previous_data)

In [15]:
# LATENT SPACE METRICS
leaveout_ckpt_vae_latent["time"] = tp
leaveout_ckpt_flat_latent["time"] = tp 
leaveout_ckpt_geodesic_latent["time"] = tp 
leaveout_ckpt_previous_latent["time"] = tp 

# DATA SPACE METRICS
leaveout_ckpt_vae_data["time"] = tp 
leaveout_ckpt_flat_data["time"] = tp 
leaveout_ckpt_geodesic_data["time"] = tp 
leaveout_ckpt_previous_data["time"] = tp 

## Latent space

Geodesic

In [16]:
leaveout_ckpt_geodesic_latent.mean(0)

1-Wasserstein    2.329433
2-Wasserstein    2.492044
Linear_MMD       0.385581
Poly_MMD         0.535126
RBF_MMD          0.454486
Mean_MSE         0.426298
Mean_L2          0.570427
Mean_L1          0.483610
rep              2.000000
time             4.000000
dtype: float64

In [43]:
(leaveout_ckpt_geodesic_latent.groupby("time").std()/np.sqrt(3)).mean(0)

1-Wasserstein    0.015211
2-Wasserstein    0.016414
Linear_MMD       0.008611
Poly_MMD         0.007975
RBF_MMD          0.010070
Mean_MSE         0.008860
Mean_L2          0.008126
Mean_L1          0.010147
rep              0.577350
dtype: float64

VAE

In [18]:
leaveout_ckpt_vae_latent.mean(0)

1-Wasserstein    1.999991
2-Wasserstein    2.078605
Linear_MMD       0.191206
Poly_MMD         0.381490
RBF_MMD          0.292974
Mean_MSE         0.211045
Mean_L2          0.402064
Mean_L1          0.314145
rep              2.000000
time             4.000000
dtype: float64

In [42]:
(leaveout_ckpt_vae_latent.groupby("time").std()/np.sqrt(3)).mean(0)

1-Wasserstein    0.021897
2-Wasserstein    0.021782
Linear_MMD       0.008896
Poly_MMD         0.010692
RBF_MMD          0.012825
Mean_MSE         0.009696
Mean_L2          0.011361
Mean_L1          0.012056
rep              0.577350
dtype: float64

Flat

In [20]:
leaveout_ckpt_flat_latent.mean(0)

1-Wasserstein    1.545736
2-Wasserstein    1.638642
Linear_MMD       0.163053
Poly_MMD         0.346271
RBF_MMD          0.289706
Mean_MSE         0.170591
Mean_L2          0.360636
Mean_L1          0.299171
rep              2.000000
time             4.000000
dtype: float64

In [41]:
(leaveout_ckpt_flat_latent.groupby("time").std()/np.sqrt(3)).mean(0)

1-Wasserstein    0.035848
2-Wasserstein    0.036462
Linear_MMD       0.009945
Poly_MMD         0.015722
RBF_MMD          0.013542
Mean_MSE         0.009390
Mean_L2          0.013847
Mean_L1          0.011714
rep              0.577350
dtype: float64

Baseline

In [22]:
leaveout_ckpt_previous_latent.mean(0)

1-Wasserstein    3.138594
2-Wasserstein    3.213354
Linear_MMD       0.752031
Poly_MMD         0.828778
RBF_MMD          0.699239
Mean_MSE         0.767308
Mean_L2          0.834317
Mean_L1          0.698441
rep              2.000000
time             4.000000
dtype: float64

In [34]:
(leaveout_ckpt_previous_latent.groupby("time").std()/np.sqrt(3)).mean(0)

1-Wasserstein    0.00000
2-Wasserstein    0.00000
Linear_MMD       0.00000
Poly_MMD         0.00000
RBF_MMD          0.00000
Mean_MSE         0.00000
Mean_L2          0.00000
Mean_L1          0.00000
rep              0.57735
dtype: float64

## Data space

**Flat**

Geodesic

In [24]:
leaveout_ckpt_geodesic_data.mean(0)

1-Wasserstein     150.621103
2-Wasserstein    2988.165986
Linear_MMD          6.668410
Poly_MMD            2.263231
RBF_MMD             0.363236
Mean_MSE            0.693917
Mean_L2             0.781490
Mean_L1             0.183662
precision           0.036196
recall              0.230541
density             0.003977
coverage            0.001679
rep                 2.000000
time                4.000000
dtype: float64

In [25]:
(leaveout_ckpt_geodesic_data.groupby("time").std()/np.sqrt(3)).mean(0)

1-Wasserstein     0.657903
2-Wasserstein    35.749086
Linear_MMD        0.106061
Poly_MMD          0.014464
RBF_MMD           0.000501
Mean_MSE          0.000292
Mean_L2           0.000178
Mean_L1           0.000066
precision         0.007817
recall            0.002312
density           0.000934
coverage          0.000142
rep               0.577350
dtype: float64

VAE

In [26]:
leaveout_ckpt_vae_data.mean(0)

1-Wasserstein     91.733385
2-Wasserstein    103.294799
Linear_MMD         1.361351
Poly_MMD           1.087649
RBF_MMD            0.286201
Mean_MSE           0.682997
Mean_L2            0.777758
Mean_L1            0.188730
precision          0.344058
recall             0.050446
density            0.118456
coverage           0.093550
rep                2.000000
time               4.000000
dtype: float64

In [27]:
(leaveout_ckpt_vae_data.groupby("time").std()/np.sqrt(3)).mean(0)

1-Wasserstein    0.591066
2-Wasserstein    0.780703
Linear_MMD       0.015506
Poly_MMD         0.006174
RBF_MMD          0.002134
Mean_MSE         0.006587
Mean_L2          0.004072
Mean_L1          0.001397
precision        0.025628
recall           0.003773
density          0.015675
coverage         0.004605
rep              0.577350
dtype: float64

Flat

In [28]:
leaveout_ckpt_flat_data.mean(0)

1-Wasserstein    86.025011
2-Wasserstein    97.125271
Linear_MMD        1.391957
Poly_MMD          1.098078
RBF_MMD           0.285913
Mean_MSE          0.852192
Mean_L2           0.870995
Mean_L1           0.252778
precision         0.378826
recall            0.072305
density           0.146882
coverage          0.127343
rep               2.000000
time              4.000000
dtype: float64

In [29]:
(leaveout_ckpt_flat_data.groupby("time").std()/np.sqrt(3)).mean(0)

1-Wasserstein    0.461155
2-Wasserstein    0.556764
Linear_MMD       0.021463
Poly_MMD         0.008520
RBF_MMD          0.002822
Mean_MSE         0.020882
Mean_L2          0.009633
Mean_L1          0.003751
precision        0.065256
recall           0.005225
density          0.047184
coverage         0.013797
rep              0.577350
dtype: float64

Baseline

In [30]:
leaveout_ckpt_previous_data.mean(0)

1-Wasserstein    82.719952
2-Wasserstein    92.218817
Linear_MMD        1.078137
Poly_MMD          0.861033
RBF_MMD           0.185775
Mean_MSE          0.583367
Mean_L2           0.665078
Mean_L1           0.115671
precision         0.270361
recall            0.072238
density           0.123425
coverage          0.071784
rep               2.000000
time              4.000000
dtype: float64

In [31]:
(leaveout_ckpt_previous_data.groupby("time").std()/np.sqrt(3)).mean(0)

1-Wasserstein    0.00000
2-Wasserstein    0.00000
Linear_MMD       0.00000
Poly_MMD         0.00000
RBF_MMD          0.00000
Mean_MSE         0.00000
Mean_L2          0.00000
Mean_L1          0.00000
precision        0.00000
recall           0.00000
density          0.00000
coverage         0.00000
rep              0.57735
dtype: float64