In [None]:
from paths import DATA_DIR, CKPT_FOLDER, PROJECT_FOLDER

import matplotlib.pyplot as plt
import os
import numpy as np
import torch
import anndata
import scanpy as sc
import sklearn
import scvelo as scv

import anndata
import pandas as pd

from IPython.display import display
from torchdyn.core import NeuralODE

from scCFM.datamodules.time_sc_datamodule import TrajectoryDataModule
from scCFM.models.cfm.components.mlp import MLP
from scCFM.models.cfm.cfm_module import CFMLitModule

from scCFM.models.base.vae import VAE
from scCFM.models.base.geometric_vae import GeometricNBVAE
from scCFM.models.base.geodesic_ae import GeodesicAE

from scCFM.datamodules.sc_datamodule import scDataModule
from scCFM.models.cfm.components.eval.distribution_distances import compute_distribution_distances

from notebooks.utils import decode_trajectory_single_step, standardize, compute_prdc

Initialize the device

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
def cross_standardize(tensor1, tensor2):
    """
    Standardize tensor across the rows
    """
    mean_t1, std_t1 = tensor1.mean(0), tensor1.std(0)
    tensor1 = (tensor1 - mean_t1) / (std_t1 + 1e-6)
    tensor2 = (tensor2 - mean_t1) / (std_t1 + 1e-6)
    return tensor1, tensor2

def update_dict(ref, tgt):
    for key in tgt:
        if key not in ref:
            ref[key] = []
        ref[key].append(tgt[key])
    return ref

Initialize datamodule

In [None]:
datamodule={'path': PROJECT_FOLDER / 'data/eb/processed/eb_phate.h5ad', 
            'x_layer': 'X_norm', 
            'cond_keys': ['experimental_time', 'leiden'],
            'use_pca': False, 
            'n_dimensions': None, 
            'train_val_test_split': [1], 
            'batch_size': 512, 
            'num_workers': 2}

# Initialize datamodule
datamodule = scDataModule(**datamodule)

Initialize autoencoders

In [None]:
vae_kwargs={'in_dim': datamodule.in_dim,
       'n_epochs_anneal_kl': 1000, 
       'kl_weight': None, 
       'likelihood': 'nb', 
       'dropout': False, 
       'learning_rate': 0.001, 
       'dropout_p': False, 
       'model_library_size': True, 
       'batch_norm': True, 
       'kl_warmup_fraction': 0.1, 
       'hidden_dims': [256, 10]}
        
geometric_kwargs={'compute_metrics_every': 1, 
           'use_c': True, 
           'l2': True, 
           'eta_interp': 0, 
           'interpolate_z': False, 
           'start_jac_after': 0, 
           'fl_weight': 0.1,
           'detach_theta': True}

geodesic_kwargs={"in_dim": datamodule.in_dim,
          "hidden_dims": [256, 10],
          "batch_norm": True,
          "dropout": False, 
          "dropout_p": False,
          "likelihood": "nb",
          "learning_rate": 0.001}

# Initialize vae and geometric vae
vae = GeometricNBVAE(**geometric_kwargs, vae_kwargs=vae_kwargs).to(device)
geometric_vae = GeometricNBVAE(**geometric_kwargs, vae_kwargs=vae_kwargs).to(device)
geodesic_ae = GeodesicAE(**geodesic_kwargs).to(device)

# Load state dicts and put in eval mode 
vae.load_state_dict(torch.load(PROJECT_FOLDER / "checkpoints/ae/eb/best_model_vae_lib.ckpt")["state_dict"])
geometric_vae.load_state_dict(torch.load(PROJECT_FOLDER / "checkpoints/ae/eb/best_model_geometric_lib.ckpt")["state_dict"])
geodesic_ae.load_state_dict(torch.load(PROJECT_FOLDER / "checkpoints/ae/eb/best_model_geodesic_ae.ckpt")["state_dict"])

# vae.eval()
# geometric_vae.eval()
# geodesic_ae.eval()

## Setup CFMs

In [None]:
leavout_timepoints_folder = CKPT_FOLDER / "trajectory" / "eb"

Initialize datamodule for trajectory

In [None]:
datamodule_kwargs_vae = {'path': PROJECT_FOLDER / 'data/eb/flat/eb_lib.h5ad',
                          'x_layer': 'X_latents',
                          'time_key': 'experimental_time', 
                          'use_pca': False, 
                          'n_dimensions': None, 
                          'train_val_test_split': [0.9, 0.1], 
                          'num_workers': 2, 
                          'batch_size': 512, 
                          'model_library_size': True}

datamodule_kwargs_flat = {'path': PROJECT_FOLDER / 'data/eb/flat/eb_flat_lib.h5ad',
                          'x_layer': 'X_latents',
                          'time_key': 'experimental_time', 
                          'use_pca': False, 
                          'n_dimensions': None, 
                          'train_val_test_split': [0.9, 0.1], 
                          'num_workers': 2, 
                          'batch_size': 512, 
                           'model_library_size': True}

datamodule_kwargs_geodesic= {'path': PROJECT_FOLDER / 'data/eb/flat/eb_geodesic.h5ad',
                          'x_layer': 'X_latents',
                          'time_key': 'experimental_time', 
                          'use_pca': False, 
                          'n_dimensions': None, 
                          'train_val_test_split': [0.9, 0.1], 
                          'num_workers': 2, 
                          'batch_size': 512, 
                           'model_library_size': True}

# Initialize the datamodules 
datamodule_vae = TrajectoryDataModule(**datamodule_kwargs_vae)
datamodule_flat = TrajectoryDataModule(**datamodule_kwargs_flat)
datamodule_geodesic = TrajectoryDataModule(**datamodule_kwargs_geodesic)

# Mapping real times to index
idx2time = datamodule_vae.idx2time

## Read data

First, read the latent space anndata and plot the results

In [None]:
# Read latent anndata
adata_latent_vae = sc.read_h5ad(DATA_DIR / "eb" / "flat" / "eb_lib.h5ad")
adata_latent_flat = sc.read_h5ad(DATA_DIR / "eb" / "flat" / "eb_flat_lib.h5ad")
adata_latent_geodesic = sc.read_h5ad(DATA_DIR / "eb" / "flat" / "eb_geodesic.h5ad")

# Read real anndata
adata_eb_original = sc.read_h5ad(PROJECT_FOLDER / 'data/eb/processed/eb_phate.h5ad')
sc.tl.pca(adata_eb_original, n_comps=50)
adata_eb_original.X = adata_eb_original.layers["X_norm"].copy()

Number of experiments 

In [40]:
n_timepoints = len(np.unique(adata_latent_vae.obs.experimental_time))
idx2time = dict(zip(range(n_timepoints), np.unique(adata_latent_vae.obs.experimental_time)))
idx2time

{0: 0.0, 1: 0.25, 2: 0.5, 3: 0.75, 4: 1.0}

Initialize model

In [41]:
net_hparams = {"dim": adata_latent_flat.X.shape[1]+1,
                "w": 64,
                "time_varying": True}

cfm_kwargs = {'ot_sampler': 'exact', 
                   'sigma': 0.1, 
                   'use_real_time': False, 
                   'lr': 0.001, 
                   'antithetic_time_sampling': True}

## Evaluation

Load checkpoints

In [42]:
# LATENT SPACE METRICS
leaveout_ckpt_vae_latent = {}
leaveout_ckpt_flat_latent = {}
leaveout_ckpt_geodesic_latent = {}
leaveout_ckpt_previous_latent = {}

In [43]:
# DATA SPACE METRICS
leaveout_ckpt_vae_data = {}
leaveout_ckpt_flat_data = {}
leaveout_ckpt_geodesic_data = {}
leaveout_ckpt_previous = {}

In [47]:
for tp in range(1, n_timepoints-1):
    print(f"Time point {tp}")
    #Pick time 0 observations
    X_adata_t0_latent_vae = torch.from_numpy(adata_latent_vae[adata_latent_vae.obs["experimental_time"]==idx2time[(tp-1)]].X).to(device)
    X_adata_t0_latent_flat = torch.from_numpy(adata_latent_flat[adata_latent_flat.obs["experimental_time"]==idx2time[(tp-1)]].X).to(device)
    X_adata_t0_latent_geodesic = torch.from_numpy(adata_latent_geodesic[adata_latent_geodesic.obs["experimental_time"]==idx2time[(tp-1)]].X).to(device)

    # Pick observations next timepoint 
    X_adata_t1_latent_vae = torch.from_numpy(adata_latent_vae[adata_latent_vae.obs["experimental_time"]==idx2time[tp]].X).to(device)
    X_adata_t1_latent_flat = torch.from_numpy(adata_latent_flat[adata_latent_flat.obs["experimental_time"]==idx2time[tp]].X).to(device)
    X_adata_t1_latent_geodesic = torch.from_numpy(adata_latent_geodesic[adata_latent_geodesic.obs["experimental_time"]==idx2time[tp]].X).to(device)    

    # Collect PCs    
    X_adata_real_pca = torch.from_numpy(adata_eb_original[adata_eb_original.obs["experimental_time"]==idx2time[tp]].obsm["X_pca"]).to(device)
    X_adata_real = torch.from_numpy(adata_eb_original[adata_eb_original.obs["experimental_time"]==idx2time[tp]].layers["X_log"].A).to(device)
    adata_real = adata_eb_original[adata_eb_original.obs["experimental_time"]==idx2time[tp]]
    
    #Pick library sizes
    l_t0_vae = adata_latent_vae.obs.loc[adata_latent_vae.obs["experimental_time"]==idx2time[(tp-1)], "log_library_size"].to_numpy()
    l_t0_flat = adata_latent_flat.obs.loc[adata_latent_flat.obs["experimental_time"]==idx2time[(tp-1)], "log_library_size"].to_numpy()
    l_t0_geodesic = adata_latent_geodesic.obs.loc[adata_latent_geodesic.obs["experimental_time"]==idx2time[(tp-1)], "log_library_size"].to_numpy()

    #Pick library sizes
    l_t0_vae = torch.from_numpy(l_t0_vae).to(device)
    l_t0_flat = torch.from_numpy(l_t0_flat).to(device)
    l_t0_geodesic = torch.from_numpy(l_t0_geodesic).to(device)

    # Initialize nets
    net_vae = MLP(**net_hparams).to(device)
    net_flat = MLP(**net_hparams).to(device)
    net_geodesic = MLP(**net_hparams).to(device)
    cfm_vae = CFMLitModule(net=net_vae, datamodule=datamodule_vae, **cfm_kwargs).to(device)
    cfm_flat = CFMLitModule(net=net_flat, datamodule=datamodule_flat, **cfm_kwargs).to(device)
    cfm_geodesic = CFMLitModule(net=net_geodesic, datamodule=datamodule_geodesic, **cfm_kwargs).to(device)

    # Read the checkpoints
    cfm_vae.load_state_dict(torch.load(leavout_timepoints_folder / f"eb_vae_leaveout_{tp}.ckpt")["state_dict"])
    cfm_flat.load_state_dict(torch.load(leavout_timepoints_folder / f"eb_flat_leaveout_{tp}.ckpt")["state_dict"])
    cfm_geodesic.load_state_dict(torch.load(leavout_timepoints_folder / f"eb_geodesic_leaveout_{tp}.ckpt")["state_dict"])

    mu_adata_predicted_vae, X_adata_predicted_vae, X_adata_latent_vae = decode_trajectory_single_step(X_adata_t0_latent_vae, 
                                                                                 l_t0_vae, 
                                                                                 tp-1, 
                                                                                 cfm_vae, 
                                                                                 vae)
                                                                                
    mu_adata_predicted_flat, X_adata_predicted_flat, X_adata_latent_flat = decode_trajectory_single_step(X_adata_t0_latent_flat, 
                                                                                   l_t0_flat, 
                                                                                   tp-1, 
                                                                                   cfm_flat, 
                                                                                   geometric_vae)
                                                                                  
    mu_adata_predicted_geodesic, X_adata_predicted_geodesic, X_adata_latent_geodesic = decode_trajectory_single_step(X_adata_t0_latent_geodesic, 
                                                                                           l_t0_geodesic, 
                                                                                           tp-1, 
                                                                                           cfm_geodesic, 
                                                                                           geodesic_ae, 
                                                                                           model_type="geodesic_ae")

    ## PREDICT LATENT TRAJECTORIES 
    print("predict latent trajectory")
    X_adata_t1_latent_vae, X_adata_latent_vae = cross_standardize(X_adata_t1_latent_vae, X_adata_latent_vae[:,:-1])
    X_adata_t1_latent_flat, X_adata_latent_flat = cross_standardize(X_adata_t1_latent_flat, X_adata_latent_flat[:,:-1])
    X_adata_t1_latent_geodesic, X_adata_latent_geodesic = cross_standardize(X_adata_t1_latent_geodesic, X_adata_latent_geodesic[:,:-1])
                                                                                   
    d_dist_vae_l = compute_distribution_distances(X_adata_t1_latent_vae.unsqueeze(1).to("cpu"), 
                                         X_adata_latent_vae.unsqueeze(1).to("cpu"))
    d_dist_flat_l = compute_distribution_distances(X_adata_t1_latent_flat.unsqueeze(1).to("cpu"),
                                         X_adata_latent_flat.unsqueeze(1).to("cpu"))
    d_dist_geod_l = compute_distribution_distances(X_adata_t1_latent_geodesic.unsqueeze(1).to("cpu"),
                                         X_adata_latent_geodesic.unsqueeze(1).to("cpu"))
    d_dist_prev = compute_distribution_distances(X_adata_t1_latent_vae.unsqueeze(1).to("cpu"),
                                         X_adata_t0_latent_vae.unsqueeze(1).to("cpu"))
    
    d_dist_vae_l = dict(zip(d_dist_vae_l[0], d_dist_vae_l[1]))
    d_dist_flat_l = dict(zip(d_dist_flat_l[0], d_dist_flat_l[1]))
    d_dist_geod_l = dict(zip(d_dist_geod_l[0], d_dist_geod_l[1]))
    d_dist_prev_l = dict(zip(d_dist_prev[0], d_dist_prev[1]))
    
    leaveout_ckpt_vae_latent = update_dict(leaveout_ckpt_vae_latent, d_dist_vae_l)
    leaveout_ckpt_flat_latent = update_dict(leaveout_ckpt_flat_latent, d_dist_flat_l)
    leaveout_ckpt_geodesic_latent = update_dict(leaveout_ckpt_geodesic_latent, d_dist_geod_l)
    leaveout_ckpt_previous_latent = update_dict(leaveout_ckpt_previous, d_dist_prev_l)

    print("predict decoded trajectory")
    X_adata_predicted_vae = anndata.AnnData(X=X_adata_predicted_vae.numpy())
    X_adata_predicted_flat = anndata.AnnData(X=X_adata_predicted_flat.numpy())
    X_adata_predicted_geodesic = anndata.AnnData(X=X_adata_predicted_geodesic.numpy())
    X_adata_prev = adata_eb_original[adata_eb_original.obs["experimental_time"]==idx2time[tp-1]]
    
    sc.pp.log1p(X_adata_predicted_vae)
    sc.pp.log1p(X_adata_predicted_flat)
    sc.pp.log1p(X_adata_prev)
    sc.tl.pca(X_adata_predicted_vae, n_comps=50)
    sc.tl.pca(X_adata_predicted_flat, n_comps=50)
    sc.tl.pca(X_adata_predicted_geodesic, n_comps=50)
    sc.tl.pca(X_adata_prev, n_comps=50)
    
    d_dist_vae_d = compute_prdc(torch.from_numpy(X_adata_predicted_vae.obsm["X_pca"]), 
                                             X_adata_real_pca.to("cpu"), nearest_k=10)
    d_dist_flat_d = compute_prdc(torch.from_numpy(X_adata_predicted_flat.obsm["X_pca"]), 
                                             X_adata_real_pca.to("cpu"), nearest_k=10)
    d_dist_geodesic_d = compute_prdc(torch.from_numpy(X_adata_predicted_geodesic.obsm["X_pca"]), 
                                             X_adata_real_pca.to("cpu"), nearest_k=10)
    d_dist_prev_d = compute_prdc(torch.from_numpy(X_adata_prev.obsm["X_pca"]), 
                                             X_adata_real_pca.to("cpu"), nearest_k=10)

    leaveout_ckpt_vae_data = update_dict(leaveout_ckpt_vae_data, d_dist_vae_d)
    leaveout_ckpt_flat_data = update_dict(leaveout_ckpt_flat_data, d_dist_flat_d)
    leaveout_ckpt_geodesic_data = update_dict(leaveout_ckpt_geodesic_data, d_dist_geodesic_d)
    leaveout_ckpt_previous_latent = update_dict(leaveout_ckpt_previous, d_dist_prev_d)

Time point 1
predict decoded trajectory


  view_to_actual(adata)


Time point 2
predict decoded trajectory


  view_to_actual(adata)


Time point 3
predict decoded trajectory


  view_to_actual(adata)


In [48]:
pd.DataFrame(leaveout_ckpt_geodesic_latent).mean(0)

1-Wasserstein    2.189452
2-Wasserstein    2.287932
Linear_MMD       0.186619
Poly_MMD         0.406263
RBF_MMD          0.339504
Mean_MSE         0.233476
Mean_L2          0.454132
Mean_L1          0.386530
dtype: float64

In [49]:
pd.DataFrame(leaveout_ckpt_vae_latent).mean(0)

1-Wasserstein    2.007355
2-Wasserstein    2.109688
Linear_MMD       0.082285
Poly_MMD         0.285634
RBF_MMD          0.230820
Mean_MSE         0.089981
Mean_L2          0.297048
Mean_L1          0.241165
dtype: float64

In [50]:
pd.DataFrame(leaveout_ckpt_flat_latent).mean(0)

1-Wasserstein    1.504800
2-Wasserstein    1.635964
Linear_MMD       0.090862
Poly_MMD         0.291474
RBF_MMD          0.251780
Mean_MSE         0.093645
Mean_L2          0.294366
Mean_L1          0.256488
dtype: float64

In [51]:
pd.DataFrame(leaveout_ckpt_previous_latent).mean(0)

1-Wasserstein    2.754464
2-Wasserstein    2.879792
Linear_MMD       0.359602
Poly_MMD         0.558430
RBF_MMD          0.450974
Mean_MSE         0.370892
Mean_L2          0.567490
Mean_L1          0.469690
precision        0.183880
recall           0.466403
density          0.036168
coverage         0.126623
dtype: float64

**Data space**

In [52]:
pd.DataFrame(leaveout_ckpt_geodesic_data).mean(0)

precision    0.084946
recall       0.291387
density      0.011258
coverage     0.021029
dtype: float64

In [53]:
pd.DataFrame(leaveout_ckpt_vae_data).mean(0)

precision    0.492681
recall       0.211001
density      0.245806
coverage     0.404954
dtype: float64

In [54]:
pd.DataFrame(leaveout_ckpt_flat_data).mean(0)

precision    0.573021
recall       0.294820
density      0.348883
coverage     0.518967
dtype: float64

In [56]:
pd.DataFrame(leaveout_ckpt_previous_latent).mean(0)

1-Wasserstein    2.754464
2-Wasserstein    2.879792
Linear_MMD       0.359602
Poly_MMD         0.558430
RBF_MMD          0.450974
Mean_MSE         0.370892
Mean_L2          0.567490
Mean_L1          0.469690
precision        0.183880
recall           0.466403
density          0.036168
coverage         0.126623
dtype: float64