In [1]:
from paths import DATA_DIR, CKPT_FOLDER, PROJECT_FOLDER

import matplotlib.pyplot as plt
import os
import numpy as np
import torch
import anndata
import scanpy as sc
import sklearn
import scvelo as scv

import anndata
import pandas as pd

from IPython.display import display
from torchdyn.core import NeuralODE

from scCFM.datamodules.time_sc_datamodule import TrajectoryDataModule
from scCFM.models.cfm.components.mlp import MLP
from scCFM.models.cfm.cfm_module import CFMLitModule

from scCFM.models.base.vae import VAE
from scCFM.models.base.geometric_vae import GeometricNBVAE
from scCFM.models.base.geodesic_ae import GeodesicAE

from scCFM.datamodules.sc_datamodule import scDataModule
from scCFM.models.cfm.components.eval.distribution_distances import compute_distribution_distances

from notebooks.utils import decode_trajectory_single_step, standardize, compute_prdc

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Initialize the device

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
def cross_standardize(tensor1, tensor2):
    """
    Standardize tensor across the rows
    """
    mean_t1, std_t1 = tensor1.mean(0), tensor1.std(0)
    tensor1 = (tensor1 - mean_t1) / (std_t1 + 1e-6)
    tensor2 = (tensor2 - mean_t1) / (std_t1 + 1e-6)
    return tensor1, tensor2

def update_dict(ref, tgt):
    """
    Update a dictionary with the values of another 
    """
    for key in tgt:
        if key not in ref:
            ref[key] = []
        ref[key].append(tgt[key])
    return ref

Initialize datamodule

In [4]:
datamodule={'path': PROJECT_FOLDER / 'data/hein_et_al/processed/unperturbed_time_course_host.h5ad', 
            'x_layer': 'X_norm', 
            'cond_keys': ['experimental_time', 'cluster'],
            'use_pca': False, 
            'n_dimensions': None, 
            'train_val_test_split': [1], 
            'batch_size': 512, 
            'num_workers': 2}

# Initialize datamodule
datamodule = scDataModule(**datamodule)

Initialize autoencoders

In [5]:
vae_kwargs={'in_dim': datamodule.in_dim,
       'n_epochs_anneal_kl': 1000, 
       'kl_weight': None, 
       'likelihood': 'nb', 
       'dropout': False, 
       'learning_rate': 0.001, 
       'dropout_p': False, 
       'model_library_size': True, 
       'batch_norm': True, 
       'kl_warmup_fraction': 0.1, 
       'hidden_dims': [256, 10]}
        
geometric_kwargs={'compute_metrics_every': 1, 
           'use_c': True, 
           'l2': True, 
           'eta_interp': 0, 
           'interpolate_z': False, 
           'start_jac_after': 0, 
           'fl_weight': 0.1,
           'detach_theta': True}

geodesic_kwargs={"in_dim": datamodule.in_dim,
          "hidden_dims": [256, 10],
          "batch_norm": True,
          "dropout": False, 
          "dropout_p": False,
          "likelihood": "nb",
          "learning_rate": 0.001}

# Initialize vae and geometric vae
vae = GeometricNBVAE(**geometric_kwargs, vae_kwargs=vae_kwargs).to(device)
geometric_vae = GeometricNBVAE(**geometric_kwargs, vae_kwargs=vae_kwargs).to(device)
geodesic_ae = GeodesicAE(**geodesic_kwargs).to(device)

# Load state dicts and put in eval mode 
vae.load_state_dict(torch.load(PROJECT_FOLDER / "checkpoints/ae/hein_et_al_complete/best_model_vae_lib.ckpt")["state_dict"])
geometric_vae.load_state_dict(torch.load(PROJECT_FOLDER / "checkpoints/ae/hein_et_al_complete/best_model_geometric_lib.ckpt")["state_dict"])
geodesic_ae.load_state_dict(torch.load(PROJECT_FOLDER / "checkpoints/ae/hein_et_al_complete/best_model_geodesic_ae.ckpt")["state_dict"])

<All keys matched successfully>

## Setup CFMs

In [6]:
leavout_timepoints_folder = CKPT_FOLDER / "trajectory" / "hein_et_al"

Initialize datamodule for trajectory

In [7]:
datamodule_kwargs_vae = {'path': PROJECT_FOLDER / 'data/hein_et_al/flat/hein_lib_complete.h5ad',
                          'x_layer': 'X_latents',
                          'time_key': 'experimental_time', 
                          'use_pca': False, 
                          'n_dimensions': None, 
                          'train_val_test_split': [0.9, 0.1], 
                          'num_workers': 2, 
                          'batch_size': 512, 
                          'model_library_size': True}

datamodule_kwargs_flat = {'path': PROJECT_FOLDER / 'data/hein_et_al/flat/hein_flat_lib_complete.h5ad',
                          'x_layer': 'X_latents',
                          'time_key': 'experimental_time', 
                          'use_pca': False, 
                          'n_dimensions': None, 
                          'train_val_test_split': [0.9, 0.1], 
                          'num_workers': 2, 
                          'batch_size': 512, 
                           'model_library_size': True}

datamodule_kwargs_geodesic= {'path': PROJECT_FOLDER / 'data/hein_et_al/flat/hein_geodesic_complete.h5ad',
                          'x_layer': 'X_latents',
                          'time_key': 'experimental_time', 
                          'use_pca': False, 
                          'n_dimensions': None, 
                          'train_val_test_split': [0.9, 0.1], 
                          'num_workers': 2, 
                          'batch_size': 512, 
                           'model_library_size': True}

# Initialize the datamodules 
datamodule_vae = TrajectoryDataModule(**datamodule_kwargs_vae)
datamodule_flat = TrajectoryDataModule(**datamodule_kwargs_flat)
datamodule_geodesic = TrajectoryDataModule(**datamodule_kwargs_geodesic)

# Mapping real times to index
idx2time = datamodule_vae.idx2time

## Read data

First, read the latent space anndata and plot the results

In [8]:
# Read latent anndata
adata_latent_vae = sc.read_h5ad(DATA_DIR / "hein_et_al" / "flat" / "hein_lib_complete.h5ad")
adata_latent_flat = sc.read_h5ad(DATA_DIR / "hein_et_al" / "flat" / "hein_flat_lib_complete.h5ad")
adata_latent_geodesic = sc.read_h5ad(DATA_DIR / "hein_et_al" / "flat" / "hein_geodesic_complete.h5ad")

# Read real anndata
adata_hein_original = sc.read_h5ad(PROJECT_FOLDER / 'data/hein_et_al/processed/unperturbed_time_course_host.h5ad')
sc.tl.pca(adata_hein_original, n_comps=50)
adata_hein_original.X = adata_hein_original.layers["X_norm"].copy()

Number of experiments 

In [9]:
n_timepoints = len(np.unique(adata_latent_vae.obs.experimental_time))
idx2time = dict(zip(range(n_timepoints), np.unique(adata_latent_vae.obs.experimental_time)))
idx2time

{0: 0.0,
 1: 0.05,
 2: 0.16666666666666666,
 3: 0.23333333333333334,
 4: 0.4,
 5: 0.6,
 6: 0.8,
 7: 1.0}

Initialize model

In [10]:
net_hparams = {"dim": adata_latent_flat.X.shape[1]+1,
                "w": 64,
                "time_varying": True}

cfm_kwargs = {'ot_sampler': 'exact', 
                   'sigma': 0.1, 
                   'use_real_time': False, 
                   'lr': 0.001, 
                   'antithetic_time_sampling': True}

## Evaluation

Load checkpoints

In [27]:
# LATENT SPACE METRICS
leaveout_ckpt_vae_latent = {}
leaveout_ckpt_flat_latent = {}
leaveout_ckpt_geodesic_latent = {}
leaveout_ckpt_previous_latent = {}

In [28]:
# DATA SPACE METRICS
leaveout_ckpt_vae_data = {}
leaveout_ckpt_flat_data = {}
leaveout_ckpt_geodesic_data = {}
leaveout_ckpt_previous_data = {}

In [29]:
for rep in range(1, 4):
    for tp in range(1, n_timepoints-1):
        print(f"Time point {tp}")
        #Pick time 0 observations
        X_adata_t0_latent_vae = torch.from_numpy(adata_latent_vae[adata_latent_vae.obs["experimental_time"]==idx2time[(tp-1)]].X).to(device)
        X_adata_t0_latent_flat = torch.from_numpy(adata_latent_flat[adata_latent_flat.obs["experimental_time"]==idx2time[(tp-1)]].X).to(device)
        X_adata_t0_latent_geodesic = torch.from_numpy(adata_latent_geodesic[adata_latent_geodesic.obs["experimental_time"]==idx2time[(tp-1)]].X).to(device)
    
        # Pick observations next timepoint 
        X_adata_t1_latent_vae = torch.from_numpy(adata_latent_vae[adata_latent_vae.obs["experimental_time"]==idx2time[tp]].X).to(device)
        X_adata_t1_latent_flat = torch.from_numpy(adata_latent_flat[adata_latent_flat.obs["experimental_time"]==idx2time[tp]].X).to(device)
        X_adata_t1_latent_geodesic = torch.from_numpy(adata_latent_geodesic[adata_latent_geodesic.obs["experimental_time"]==idx2time[tp]].X).to(device)    
    
        # Collect PCs
        adata_real = adata_hein_original[adata_hein_original.obs["experimental_time"]==idx2time[tp]]
        X_adata_real_pca = torch.from_numpy(adata_real.obsm["X_pca"]).to(device)
        X_adata_real = torch.from_numpy(adata_real.layers["X_log"].A).to(device)
    
        #Pick library sizes
        l_t0_vae = adata_latent_vae.obs.loc[adata_latent_vae.obs["experimental_time"]==idx2time[(tp-1)], "log_library_size"].to_numpy()
        l_t0_flat = adata_latent_flat.obs.loc[adata_latent_flat.obs["experimental_time"]==idx2time[(tp-1)], "log_library_size"].to_numpy()
        l_t0_geodesic = adata_latent_geodesic.obs.loc[adata_latent_geodesic.obs["experimental_time"]==idx2time[(tp-1)], "log_library_size"].to_numpy()
    
        #Pick library sizes
        l_t0_vae = torch.from_numpy(l_t0_vae).to(device)
        l_t0_flat = torch.from_numpy(l_t0_flat).to(device)
        l_t0_geodesic = torch.from_numpy(l_t0_geodesic).to(device)
    
        # Initialize nets
        net_vae = MLP(**net_hparams).to(device)
        net_flat = MLP(**net_hparams).to(device)
        net_geodesic = MLP(**net_hparams).to(device)
        cfm_vae = CFMLitModule(net=net_vae, datamodule=datamodule_vae, **cfm_kwargs).to(device)
        cfm_flat = CFMLitModule(net=net_flat, datamodule=datamodule_flat, **cfm_kwargs).to(device)
        cfm_geodesic = CFMLitModule(net=net_geodesic, datamodule=datamodule_geodesic, **cfm_kwargs).to(device)
    
        # Read the checkpoints
        cfm_vae.load_state_dict(torch.load(leavout_timepoints_folder / f"hein_vae_leaveout_{tp}_{rep}.ckpt")["state_dict"])
        cfm_flat.load_state_dict(torch.load(leavout_timepoints_folder / f"hein_flat_leaveout_{tp}_{rep}.ckpt")["state_dict"])
        cfm_geodesic.load_state_dict(torch.load(leavout_timepoints_folder / f"hein_geodesic_leaveout_{tp}_{rep}.ckpt")["state_dict"])
    
        mu_adata_predicted_vae, X_adata_predicted_vae, X_adata_latent_vae = decode_trajectory_single_step(X_adata_t0_latent_vae, 
                                                                                     l_t0_vae, 
                                                                                     tp-1, 
                                                                                     cfm_vae, 
                                                                                     vae)
                                                                                    
        mu_adata_predicted_flat, X_adata_predicted_flat, X_adata_latent_flat = decode_trajectory_single_step(X_adata_t0_latent_flat, 
                                                                                       l_t0_flat, 
                                                                                       tp-1, 
                                                                                       cfm_flat, 
                                                                                       geometric_vae)
                                                                                      
        mu_adata_predicted_geodesic, X_adata_predicted_geodesic, X_adata_latent_geodesic = decode_trajectory_single_step(X_adata_t0_latent_geodesic, 
                                                                                               l_t0_geodesic, 
                                                                                               tp-1, 
                                                                                               cfm_geodesic, 
                                                                                               geodesic_ae, 
                                                                                               model_type="geodesic_ae")
    
        ## PREDICT LATENT TRAJECTORIES 
        X_adata_t1_latent_vae, X_adata_latent_vae = cross_standardize(X_adata_t1_latent_vae, X_adata_latent_vae[:,:-1])
        X_adata_t1_latent_flat, X_adata_latent_flat = cross_standardize(X_adata_t1_latent_flat, X_adata_latent_flat[:,:-1])
        X_adata_t1_latent_geodesic, X_adata_latent_geodesic = cross_standardize(X_adata_t1_latent_geodesic, X_adata_latent_geodesic[:,:-1])
                                                                                       
        d_dist_vae_l = compute_distribution_distances(X_adata_t1_latent_vae.unsqueeze(1).to("cpu"), 
                                             X_adata_latent_vae.unsqueeze(1).to("cpu"))
        d_dist_flat_l = compute_distribution_distances(X_adata_t1_latent_flat.unsqueeze(1).to("cpu"),
                                             X_adata_latent_flat.unsqueeze(1).to("cpu"))
        d_dist_geod_l = compute_distribution_distances(X_adata_t1_latent_geodesic.unsqueeze(1).to("cpu"),
                                             X_adata_latent_geodesic.unsqueeze(1).to("cpu"))
        d_dist_prev = compute_distribution_distances(X_adata_t1_latent_vae.unsqueeze(1).to("cpu"),
                                             X_adata_t0_latent_vae.unsqueeze(1).to("cpu"))
        
        d_dist_vae_l = dict(zip(d_dist_vae_l[0], d_dist_vae_l[1]))
        d_dist_flat_l = dict(zip(d_dist_flat_l[0], d_dist_flat_l[1]))
        d_dist_geod_l = dict(zip(d_dist_geod_l[0], d_dist_geod_l[1]))
        d_dist_prev_l = dict(zip(d_dist_prev[0], d_dist_prev[1]))
    
        print("predict decoded trajectory")
        X_adata_predicted_vae = anndata.AnnData(X=X_adata_predicted_vae.numpy())
        X_adata_predicted_flat = anndata.AnnData(X=X_adata_predicted_flat.numpy())
        X_adata_predicted_geodesic = anndata.AnnData(X=X_adata_predicted_geodesic.numpy())
        X_adata_prev = adata_hein_original[adata_hein_original.obs["experimental_time"]==idx2time[tp-1]]
        
        sc.pp.log1p(X_adata_predicted_vae)
        sc.pp.log1p(X_adata_predicted_flat)
        sc.pp.log1p(X_adata_prev)
        sc.tl.pca(X_adata_predicted_vae, n_comps=50)
        sc.tl.pca(X_adata_predicted_flat, n_comps=50)
        sc.tl.pca(X_adata_predicted_geodesic, n_comps=50)
        sc.tl.pca(X_adata_prev, n_comps=50)
        
        d_dist_vae_d = compute_prdc(torch.from_numpy(X_adata_predicted_vae.obsm["X_pca"].copy()), 
                                                 X_adata_real_pca.to("cpu"), nearest_k=10)
        d_dist_flat_d = compute_prdc(torch.from_numpy(X_adata_predicted_flat.obsm["X_pca"].copy()), 
                                                 X_adata_real_pca.to("cpu"), nearest_k=10)
        d_dist_geodesic_d = compute_prdc(torch.from_numpy(X_adata_predicted_geodesic.obsm["X_pca"].copy()), 
                                                 X_adata_real_pca.to("cpu"), nearest_k=10)
        d_dist_prev_d = compute_prdc(torch.from_numpy(X_adata_prev.obsm["X_pca"]), 
                                                 X_adata_real_pca.to("cpu"), nearest_k=10)

        # UPDATE DICTS 
        leaveout_ckpt_vae_latent = update_dict(leaveout_ckpt_vae_latent, d_dist_vae_l)
        leaveout_ckpt_flat_latent = update_dict(leaveout_ckpt_flat_latent, d_dist_flat_l)
        leaveout_ckpt_geodesic_latent = update_dict(leaveout_ckpt_geodesic_latent, d_dist_geod_l)
        leaveout_ckpt_previous_latent = update_dict(leaveout_ckpt_previous_latent, d_dist_prev_l)
        leaveout_ckpt_vae_data = update_dict(leaveout_ckpt_vae_data, d_dist_vae_d)
        leaveout_ckpt_flat_data = update_dict(leaveout_ckpt_flat_data, d_dist_flat_d)
        leaveout_ckpt_geodesic_data = update_dict(leaveout_ckpt_geodesic_data, d_dist_geodesic_d)
        leaveout_ckpt_previous_data = update_dict(leaveout_ckpt_previous_data, d_dist_prev_d)
        
        leaveout_ckpt_vae_latent = update_dict(leaveout_ckpt_vae_latent, {"rep": rep})
        leaveout_ckpt_flat_latent = update_dict(leaveout_ckpt_flat_latent, {"rep": rep})
        leaveout_ckpt_geodesic_latent = update_dict(leaveout_ckpt_geodesic_latent, {"rep": rep})
        leaveout_ckpt_previous_latent = update_dict(leaveout_ckpt_previous_latent, {"rep": rep})
        leaveout_ckpt_vae_data = update_dict(leaveout_ckpt_vae_data, {"rep": rep})
        leaveout_ckpt_flat_data = update_dict(leaveout_ckpt_flat_data, {"rep": rep})
        leaveout_ckpt_geodesic_data = update_dict(leaveout_ckpt_geodesic_data, {"rep": rep})
        leaveout_ckpt_previous_data = update_dict(leaveout_ckpt_previous_data, {"rep": rep})

Time point 1
predict decoded trajectory


  view_to_actual(adata)


Time point 2
predict decoded trajectory


  view_to_actual(adata)


Time point 3
predict decoded trajectory


  view_to_actual(adata)


Time point 4
predict decoded trajectory


  view_to_actual(adata)


Time point 5
predict decoded trajectory


  view_to_actual(adata)


Time point 6
predict decoded trajectory


  view_to_actual(adata)


Time point 1
predict decoded trajectory


  view_to_actual(adata)


Time point 2
predict decoded trajectory


  view_to_actual(adata)


Time point 3
predict decoded trajectory


  view_to_actual(adata)


Time point 4
predict decoded trajectory


  view_to_actual(adata)


Time point 5
predict decoded trajectory


  view_to_actual(adata)


Time point 6
predict decoded trajectory


  view_to_actual(adata)


Time point 1
predict decoded trajectory


  view_to_actual(adata)


Time point 2
predict decoded trajectory


  view_to_actual(adata)


Time point 3
predict decoded trajectory


  view_to_actual(adata)


Time point 4
predict decoded trajectory


  view_to_actual(adata)


Time point 5
predict decoded trajectory


  view_to_actual(adata)


Time point 6
predict decoded trajectory


  view_to_actual(adata)


**Latent**

VAE

In [95]:
pd.DataFrame(leaveout_ckpt_vae_latent).mean(0)

1-Wasserstein    2.016699
2-Wasserstein    2.121451
Linear_MMD       0.109529
Poly_MMD         0.294836
RBF_MMD          0.240606
Mean_MSE         0.124335
Mean_L2          0.321971
Mean_L1          0.258785
rep              2.000000
dtype: float64

In [96]:
pd.DataFrame(leaveout_ckpt_vae_latent).std(0)/np.sqrt(18)

1-Wasserstein    0.084612
2-Wasserstein    0.082567
Linear_MMD       0.024541
Poly_MMD         0.036462
RBF_MMD          0.028124
Mean_MSE         0.024962
Mean_L2          0.034869
Mean_L1          0.028204
rep              0.198030
dtype: float64

Flat VAE

In [97]:
pd.DataFrame(leaveout_ckpt_flat_latent).mean(0)

1-Wasserstein    1.784052
2-Wasserstein    1.962076
Linear_MMD       0.159411
Poly_MMD         0.357878
RBF_MMD          0.294699
Mean_MSE         0.164320
Mean_L2          0.374722
Mean_L1          0.305853
rep              2.000000
dtype: float64

In [98]:
pd.DataFrame(leaveout_ckpt_flat_latent).std(0)/np.sqrt(18)

1-Wasserstein    0.125382
2-Wasserstein    0.135599
Linear_MMD       0.032181
Poly_MMD         0.042932
RBF_MMD          0.035017
Mean_MSE         0.028401
Mean_L2          0.037497
Mean_L1          0.028780
rep              0.198030
dtype: float64

Geodesic VAE

In [99]:
pd.DataFrame(leaveout_ckpt_geodesic_latent).mean(0)

1-Wasserstein    1.988606
2-Wasserstein    2.099862
Linear_MMD       0.123717
Poly_MMD         0.316766
RBF_MMD          0.259798
Mean_MSE         0.144897
Mean_L2          0.343652
Mean_L1          0.282884
rep              2.000000
dtype: float64

In [101]:
pd.DataFrame(leaveout_ckpt_geodesic_latent).std(0)/np.sqrt(18)

1-Wasserstein    0.100507
2-Wasserstein    0.095822
Linear_MMD       0.026820
Poly_MMD         0.037082
RBF_MMD          0.032214
Mean_MSE         0.030431
Mean_L2          0.039705
Mean_L1          0.033472
rep              0.198030
dtype: float64

Baseline

In [102]:
pd.DataFrame(leaveout_ckpt_previous_latent).mean(0)

1-Wasserstein    2.676166
2-Wasserstein    2.770522
Linear_MMD       0.332431
Poly_MMD         0.515711
RBF_MMD          0.437044
Mean_MSE         0.358540
Mean_L2          0.546338
Mean_L1          0.459409
rep              2.000000
dtype: float64

In [103]:
pd.DataFrame(leaveout_ckpt_previous_latent).std(0)/np.sqrt(18)

1-Wasserstein    0.151179
2-Wasserstein    0.149283
Linear_MMD       0.077163
Poly_MMD         0.062531
RBF_MMD          0.058982
Mean_MSE         0.077303
Mean_L2          0.059436
Mean_L1          0.054842
rep              0.198030
dtype: float64

**Data space**

VAE

In [104]:
pd.DataFrame(leaveout_ckpt_vae_data).mean(0)

precision    0.808547
recall       0.130680
density      1.647513
coverage     0.627807
rep          2.000000
dtype: float64

In [105]:
pd.DataFrame(leaveout_ckpt_vae_data).std(0)/np.sqrt(18)

precision    0.035903
recall       0.018926
density      0.165088
coverage     0.031488
rep          0.198030
dtype: float64

Flat VAE

In [106]:
pd.DataFrame(leaveout_ckpt_flat_data).mean(0)

precision    0.804371
recall       0.076660
density      3.919827
coverage     0.848067
rep          2.000000
dtype: float64

In [108]:
pd.DataFrame(leaveout_ckpt_flat_data).std(0)/np.sqrt(18)

precision    0.030477
recall       0.013537
density      0.291145
coverage     0.025187
rep          0.198030
dtype: float64

Geodesic VAE

In [109]:
pd.DataFrame(leaveout_ckpt_geodesic_data).mean(0)

precision    0.130843
recall       0.587871
density      0.019796
coverage     0.026560
rep          2.000000
dtype: float64

In [110]:
pd.DataFrame(leaveout_ckpt_geodesic_data).std(0)/np.sqrt(18)

precision    0.010483
recall       0.017022
density      0.001905
coverage     0.005038
rep          0.198030
dtype: float64

Previous VAE

In [111]:
pd.DataFrame(leaveout_ckpt_previous_data).mean(0)

precision    0.373793
recall       0.364364
density      0.250444
coverage     0.177981
rep          2.000000
dtype: float64

In [112]:
pd.DataFrame(leaveout_ckpt_previous_data).std(0)/np.sqrt(18)

precision    0.043178
recall       0.030804
density      0.056982
coverage     0.028410
rep          0.198030
dtype: float64