In [20]:
from paths import DATA_DIR, CKPT_FOLDER, PROJECT_FOLDER

import matplotlib.pyplot as plt
import os
import numpy as np
import torch
import anndata
import scanpy as sc
import sklearn
import scvelo as scv

import anndata
import pandas as pd

from IPython.display import display
from torchdyn.core import NeuralODE

from scCFM.datamodules.time_sc_datamodule import TrajectoryDataModule
from scCFM.models.cfm.components.mlp import MLP
from scCFM.models.cfm.cfm_module import CFMLitModule

from scCFM.models.base.vae import VAE
from scCFM.models.base.geometric_vae import GeometricNBVAE
from scCFM.models.base.geodesic_ae import GeodesicAE

from scCFM.datamodules.sc_datamodule import scDataModule
from scCFM.models.cfm.components.eval.distribution_distances import compute_distribution_distances

from notebooks.utils import decode_trajectory_single_step, standardize, compute_prdc

Initialize the device

In [21]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [22]:
def cross_standardize(tensor1, tensor2):
    """
    Standardize tensor across the rows
    """
    mean_t1, std_t1 = tensor1.mean(0), tensor1.std(0)
    tensor1 = (tensor1 - mean_t1) / (std_t1 + 1e-6)
    tensor2 = (tensor2 - mean_t1) / (std_t1 + 1e-6)
    return tensor1, tensor2

def update_dict(ref, tgt):
    for key in tgt:
        if key not in ref:
            ref[key] = []
        ref[key].append(tgt[key])
    return ref

Initialize datamodule

In [23]:
datamodule={'path': PROJECT_FOLDER / 'data/hein_et_al/processed/unperturbed_time_course_host.h5ad', 
            'x_layer': 'X_norm', 
            'cond_keys': ['experimental_time', 'cluster'],
            'use_pca': False, 
            'n_dimensions': None, 
            'train_val_test_split': [1], 
            'batch_size': 512, 
            'num_workers': 2}

# Initialize datamodule
datamodule = scDataModule(**datamodule)

Initialize autoencoders

In [24]:
vae_kwargs={'in_dim': datamodule.in_dim,
       'n_epochs_anneal_kl': 1000, 
       'kl_weight': None, 
       'likelihood': 'nb', 
       'dropout': False, 
       'learning_rate': 0.001, 
       'dropout_p': False, 
       'model_library_size': True, 
       'batch_norm': True, 
       'kl_warmup_fraction': 0.1, 
       'hidden_dims': [256, 10]}
        
geometric_kwargs={'compute_metrics_every': 1, 
           'use_c': True, 
           'l2': True, 
           'eta_interp': 0, 
           'interpolate_z': False, 
           'start_jac_after': 0, 
           'fl_weight': 0.1,
           'detach_theta': True}

geodesic_kwargs={"in_dim": datamodule.in_dim,
          "hidden_dims": [256, 10],
          "batch_norm": True,
          "dropout": False, 
          "dropout_p": False,
          "likelihood": "nb",
          "learning_rate": 0.001}

# Initialize vae and geometric vae
vae = GeometricNBVAE(**geometric_kwargs, vae_kwargs=vae_kwargs).to(device)
geometric_vae = GeometricNBVAE(**geometric_kwargs, vae_kwargs=vae_kwargs).to(device)
geodesic_ae = GeodesicAE(**geodesic_kwargs).to(device)

# Load state dicts and put in eval mode 
vae.load_state_dict(torch.load(PROJECT_FOLDER / "checkpoints/ae/hein_et_al_complete/best_model_vae_lib.ckpt")["state_dict"])
geometric_vae.load_state_dict(torch.load(PROJECT_FOLDER / "checkpoints/ae/hein_et_al_complete/best_model_geometric_lib.ckpt")["state_dict"])
geodesic_ae.load_state_dict(torch.load(PROJECT_FOLDER / "checkpoints/ae/hein_et_al_complete/best_model_geodesic_ae.ckpt")["state_dict"])

vae.eval()
geometric_vae.eval()
geodesic_ae.eval()

GeodesicAE(
  (encoder_layers): MLP(
    (net): Sequential(
      (0): Sequential(
        (0): Linear(in_features=1727, out_features=256, bias=True)
        (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ELU(alpha=1.0)
      )
    )
  )
  (decoder_layers): MLP(
    (net): Sequential(
      (0): Sequential(
        (0): Linear(in_features=10, out_features=256, bias=True)
        (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ELU(alpha=1.0)
      )
      (1): Sequential(
        (0): Linear(in_features=256, out_features=1727, bias=True)
        (1): BatchNorm1d(1727, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ELU(alpha=1.0)
      )
    )
  )
  (decoder_mu_lib): Linear(in_features=256, out_features=1727, bias=True)
  (latent_layer): Linear(in_features=256, out_features=10, bias=True)
  (criterion): MSELoss()
)

## Setup CFMs

In [25]:
leavout_timepoints_folder = CKPT_FOLDER / "trajectory" / "hein_et_al"

Initialize datamodule for trajectory

In [26]:
datamodule_kwargs_vae = {'path': PROJECT_FOLDER / 'data/hein_et_al/flat/hein_lib_complete.h5ad',
                          'x_layer': 'X_latents',
                          'time_key': 'experimental_time', 
                          'use_pca': False, 
                          'n_dimensions': None, 
                          'train_val_test_split': [0.9, 0.1], 
                          'num_workers': 2, 
                          'batch_size': 512, 
                          'model_library_size': True}

datamodule_kwargs_flat = {'path': PROJECT_FOLDER / 'data/hein_et_al/flat/hein_flat_lib_complete.h5ad',
                          'x_layer': 'X_latents',
                          'time_key': 'experimental_time', 
                          'use_pca': False, 
                          'n_dimensions': None, 
                          'train_val_test_split': [0.9, 0.1], 
                          'num_workers': 2, 
                          'batch_size': 512, 
                           'model_library_size': True}

datamodule_kwargs_geodesic= {'path': PROJECT_FOLDER / 'data/hein_et_al/flat/hein_geodesic_complete.h5ad',
                          'x_layer': 'X_latents',
                          'time_key': 'experimental_time', 
                          'use_pca': False, 
                          'n_dimensions': None, 
                          'train_val_test_split': [0.9, 0.1], 
                          'num_workers': 2, 
                          'batch_size': 512, 
                           'model_library_size': True}

# Initialize the datamodules 
datamodule_vae = TrajectoryDataModule(**datamodule_kwargs_vae)
datamodule_flat = TrajectoryDataModule(**datamodule_kwargs_flat)
datamodule_geodesic = TrajectoryDataModule(**datamodule_kwargs_geodesic)

# Mapping real times to index
idx2time = datamodule_vae.idx2time

## Read data

First, read the latent space anndata and plot the results

In [38]:
# Read latent anndata
adata_latent_vae = sc.read_h5ad(DATA_DIR / "hein_et_al" / "flat" / "hein_lib_complete.h5ad")
adata_latent_flat = sc.read_h5ad(DATA_DIR / "hein_et_al" / "flat" / "hein_flat_lib_complete.h5ad")
adata_latent_geodesic = sc.read_h5ad(DATA_DIR / "hein_et_al" / "flat" / "hein_geodesic_complete.h5ad")

# Read real anndata
adata_hein_original = sc.read_h5ad(PROJECT_FOLDER / 'data/hein_et_al/processed/unperturbed_time_course_host.h5ad')
sc.tl.pca(adata_hein_original, n_comps=50)
adata_hein_original.X = adata_hein_original.layers["X_norm"].copy()

Number of experiments 

In [28]:
n_timepoints = len(np.unique(adata_latent_vae.obs.experimental_time))
idx2time = dict(zip(range(n_timepoints), np.unique(adata_latent_vae.obs.experimental_time)))
idx2time

{0: 0.0,
 1: 0.05,
 2: 0.16666666666666666,
 3: 0.23333333333333334,
 4: 0.4,
 5: 0.6,
 6: 0.8,
 7: 1.0}

Initialize model

In [29]:
net_hparams = {"dim": adata_latent_flat.X.shape[1]+1,
                "w": 64,
                "time_varying": True}

cfm_kwargs = {'ot_sampler': 'exact', 
                   'sigma': 0.1, 
                   'use_real_time': False, 
                   'lr': 0.001, 
                   'antithetic_time_sampling': True}

## Evaluation

Load checkpoints

In [47]:
# LATENT SPACE METRICS
leaveout_ckpt_vae_latent = {}
leaveout_ckpt_flat_latent = {}
leaveout_ckpt_geodesic_latent = {}
leaveout_ckpt_previous_latent = {}

In [48]:
# DATA SPACE METRICS
leaveout_ckpt_vae_data = {}
leaveout_ckpt_flat_data = {}
leaveout_ckpt_geodesic_data = {}
leaveout_ckpt_previous = {}

In [49]:
for tp in range(1, n_timepoints-1):
    print(f"Time point {tp}")
    #Pick time 0 observations
    X_adata_t0_latent_vae = torch.from_numpy(adata_latent_vae[adata_latent_vae.obs["experimental_time"]==idx2time[(tp-1)]].X).to(device)
    X_adata_t0_latent_flat = torch.from_numpy(adata_latent_flat[adata_latent_flat.obs["experimental_time"]==idx2time[(tp-1)]].X).to(device)
    X_adata_t0_latent_geodesic = torch.from_numpy(adata_latent_geodesic[adata_latent_geodesic.obs["experimental_time"]==idx2time[(tp-1)]].X).to(device)

    # Pick observations next timepoint 
    X_adata_t1_latent_vae = torch.from_numpy(adata_latent_vae[adata_latent_vae.obs["experimental_time"]==idx2time[tp]].X).to(device)
    X_adata_t1_latent_flat = torch.from_numpy(adata_latent_flat[adata_latent_flat.obs["experimental_time"]==idx2time[tp]].X).to(device)
    X_adata_t1_latent_geodesic = torch.from_numpy(adata_latent_geodesic[adata_latent_geodesic.obs["experimental_time"]==idx2time[tp]].X).to(device)    

    # Collect PCs    
    X_adata_real_pca = torch.from_numpy(adata_eb_original[adata_eb_original.obs["experimental_time"]==idx2time[tp]].obsm["X_pca"]).to(device)
    X_adata_real = torch.from_numpy(adata_eb_original[adata_eb_original.obs["experimental_time"]==idx2time[tp]].layers["X_log"].A).to(device)

    #Pick library sizes
    l_t0_vae = adata_latent_vae.obs.loc[adata_latent_vae.obs["experimental_time"]==idx2time[(tp-1)], "log_library_size"].to_numpy()
    l_t0_flat = adata_latent_flat.obs.loc[adata_latent_flat.obs["experimental_time"]==idx2time[(tp-1)], "log_library_size"].to_numpy()
    l_t0_geodesic = adata_latent_geodesic.obs.loc[adata_latent_geodesic.obs["experimental_time"]==idx2time[(tp-1)], "log_library_size"].to_numpy()

    #Pick library sizes
    l_t0_vae = torch.from_numpy(l_t0_vae).to(device)
    l_t0_flat = torch.from_numpy(l_t0_flat).to(device)
    l_t0_geodesic = torch.from_numpy(l_t0_geodesic).to(device)

    # Initialize nets
    net_vae = MLP(**net_hparams).to(device)
    net_flat = MLP(**net_hparams).to(device)
    net_geodesic = MLP(**net_hparams).to(device)
    cfm_vae = CFMLitModule(net=net_vae, datamodule=datamodule_vae, **cfm_kwargs).to(device)
    cfm_flat = CFMLitModule(net=net_flat, datamodule=datamodule_flat, **cfm_kwargs).to(device)
    cfm_geodesic = CFMLitModule(net=net_geodesic, datamodule=datamodule_geodesic, **cfm_kwargs).to(device)

    # Read the checkpoints
    cfm_vae.load_state_dict(torch.load(leavout_timepoints_folder / f"hein_vae_leaveout_{tp}.ckpt")["state_dict"])
    cfm_flat.load_state_dict(torch.load(leavout_timepoints_folder / f"hein_flat_leaveout_{tp}.ckpt")["state_dict"])
    cfm_geodesic.load_state_dict(torch.load(leavout_timepoints_folder / f"hein_geodesic_leaveout_{tp}.ckpt")["state_dict"])

    mu_adata_predicted_vae, X_adata_predicted_vae, X_adata_latent_vae = decode_trajectory_single_step(X_adata_t0_latent_vae, 
                                                                                 l_t0_vae, 
                                                                                 tp-1, 
                                                                                 cfm_vae, 
                                                                                 vae)
                                                                                
    mu_adata_predicted_flat, X_adata_predicted_flat, X_adata_latent_flat = decode_trajectory_single_step(X_adata_t0_latent_flat, 
                                                                                   l_t0_flat, 
                                                                                   tp-1, 
                                                                                   cfm_flat, 
                                                                                   geometric_vae)
                                                                                  
    mu_adata_predicted_geodesic, X_adata_predicted_geodesic, X_adata_latent_geodesic = decode_trajectory_single_step(X_adata_t0_latent_geodesic, 
                                                                                           l_t0_geodesic, 
                                                                                           tp-1, 
                                                                                           cfm_geodesic, 
                                                                                           geodesic_ae, 
                                                                                           model_type="geodesic_ae")

    ## PREDICT LATENT TRAJECTORIES 
    X_adata_t1_latent_vae, X_adata_latent_vae = cross_standardize(X_adata_t1_latent_vae, X_adata_latent_vae[:,:-1])
    X_adata_t1_latent_flat, X_adata_latent_flat = cross_standardize(X_adata_t1_latent_flat, X_adata_latent_flat[:,:-1])
    X_adata_t1_latent_geodesic, X_adata_latent_geodesic = cross_standardize(X_adata_t1_latent_geodesic, X_adata_latent_geodesic[:,:-1])
                                                                                   
    d_dist_vae_l = compute_distribution_distances(X_adata_t1_latent_vae.unsqueeze(1).to("cpu"), 
                                         X_adata_latent_vae.unsqueeze(1).to("cpu"))
    d_dist_flat_l = compute_distribution_distances(X_adata_t1_latent_flat.unsqueeze(1).to("cpu"),
                                         X_adata_latent_flat.unsqueeze(1).to("cpu"))
    d_dist_geod_l = compute_distribution_distances(X_adata_t1_latent_geodesic.unsqueeze(1).to("cpu"),
                                         X_adata_latent_geodesic.unsqueeze(1).to("cpu"))
    d_dist_prev = compute_distribution_distances(X_adata_t1_latent_vae.unsqueeze(1).to("cpu"),
                                         X_adata_t0_latent_vae.unsqueeze(1).to("cpu"))
    
    d_dist_vae_l = dict(zip(d_dist_vae_l[0], d_dist_vae_l[1]))
    d_dist_flat_l = dict(zip(d_dist_flat_l[0], d_dist_flat_l[1]))
    d_dist_geod_l = dict(zip(d_dist_geod_l[0], d_dist_geod_l[1]))
    d_dist_prev_l = dict(zip(d_dist_prev[0], d_dist_prev[1]))
    
    leaveout_ckpt_vae_latent = update_dict(leaveout_ckpt_vae_latent, d_dist_vae_l)
    leaveout_ckpt_flat_latent = update_dict(leaveout_ckpt_flat_latent, d_dist_flat_l)
    leaveout_ckpt_geodesic_latent = update_dict(leaveout_ckpt_geodesic_latent, d_dist_geod_l)
    leaveout_ckpt_previous_latent = update_dict(leaveout_ckpt_previous, d_dist_prev_l)

    print("predict decoded trajectory")
    X_adata_predicted_vae = anndata.AnnData(X=X_adata_predicted_vae.numpy())
    X_adata_predicted_flat = anndata.AnnData(X=X_adata_predicted_flat.numpy())
    X_adata_predicted_geodesic = anndata.AnnData(X=X_adata_predicted_geodesic.numpy())
    X_adata_prev = adata_hein_original[adata_hein_original.obs["experimental_time"]==idx2time[tp-1]]
    
    sc.pp.log1p(X_adata_predicted_vae)
    sc.pp.log1p(X_adata_predicted_flat)
    sc.pp.log1p(X_adata_prev)
    sc.tl.pca(X_adata_predicted_vae, n_comps=50)
    sc.tl.pca(X_adata_predicted_flat, n_comps=50)
    sc.tl.pca(X_adata_predicted_geodesic, n_comps=50)
    sc.tl.pca(X_adata_prev, n_comps=50)
    
    d_dist_vae_d = compute_prdc(torch.from_numpy(X_adata_predicted_vae.obsm["X_pca"].copy()), 
                                             X_adata_real_pca.to("cpu"), nearest_k=10)
    d_dist_flat_d = compute_prdc(torch.from_numpy(X_adata_predicted_flat.obsm["X_pca"].copy()), 
                                             X_adata_real_pca.to("cpu"), nearest_k=10)
    d_dist_geodesic_d = compute_prdc(torch.from_numpy(X_adata_predicted_geodesic.obsm["X_pca"].copy()), 
                                             X_adata_real_pca.to("cpu"), nearest_k=10)
    d_dist_prev_d = compute_prdc(torch.from_numpy(X_adata_prev.obsm["X_pca"]), 
                                             X_adata_real_pca.to("cpu"), nearest_k=10)

    leaveout_ckpt_vae_data = update_dict(leaveout_ckpt_vae_data, d_dist_vae_d)
    leaveout_ckpt_flat_data = update_dict(leaveout_ckpt_flat_data, d_dist_flat_d)
    leaveout_ckpt_geodesic_data = update_dict(leaveout_ckpt_geodesic_data, d_dist_geodesic_d)
    leaveout_ckpt_previous_latent = update_dict(leaveout_ckpt_previous, d_dist_prev_d)

Time point 1
predict decoded trajectory


  view_to_actual(adata)


Time point 2
predict decoded trajectory


  view_to_actual(adata)


Time point 3
predict decoded trajectory


  view_to_actual(adata)


Time point 4
predict decoded trajectory


  view_to_actual(adata)


Time point 5
predict decoded trajectory


  view_to_actual(adata)


Time point 6
predict decoded trajectory


  view_to_actual(adata)


**Latent**

In [50]:
pd.DataFrame(leaveout_ckpt_vae_latent).mean(0)

1-Wasserstein    2.024811
2-Wasserstein    2.129322
Linear_MMD       0.113428
Poly_MMD         0.302331
RBF_MMD          0.247716
Mean_MSE         0.128685
Mean_L2          0.329420
Mean_L1          0.266932
dtype: float64

In [51]:
pd.DataFrame(leaveout_ckpt_flat_latent).mean(0)

1-Wasserstein    1.796140
2-Wasserstein    1.973514
Linear_MMD       0.163645
Poly_MMD         0.359617
RBF_MMD          0.296946
Mean_MSE         0.166371
Mean_L2          0.374711
Mean_L1          0.305529
dtype: float64

In [52]:
pd.DataFrame(leaveout_ckpt_geodesic_latent).mean(0)

1-Wasserstein    1.684322
2-Wasserstein    1.792405
Linear_MMD       0.044898
Poly_MMD         0.179994
RBF_MMD          0.141314
Mean_MSE         0.054560
Mean_L2          0.203742
Mean_L1          0.159653
dtype: float64

In [53]:
leaveout_ckpt_previous_latent

{'1-Wasserstein': [3.590788820292704,
  3.1652559458982585,
  2.013340199981804,
  1.8891657742112176,
  2.386414933473867,
  3.0120303707244576],
 '2-Wasserstein': [3.693727755567239,
  3.2471132703057606,
  2.0989638739940255,
  2.0031371024756304,
  2.513470513805088,
  3.0667211609764604],
 'Linear_MMD': [0.9655517339706421,
  0.4951860308647156,
  0.1430787742137909,
  0.08102326095104218,
  0.0603393018245697,
  0.24940434098243713],
 'Poly_MMD': [0.9826249202878187,
  0.7036945579331387,
  0.378257550108112,
  0.2846458518071925,
  0.24564059482212972,
  0.4994039857494503],
 'RBF_MMD': [0.9037631750106812,
  0.5949233174324036,
  0.32194235920906067,
  0.25268489122390747,
  0.19701287150382996,
  0.3519399166107178],
 'Mean_MSE': [0.9975139498710632,
  0.5083293318748474,
  0.15305648744106293,
  0.09046776592731476,
  0.11422283947467804,
  0.2876507639884949],
 'Mean_L2': [0.9987562014180754,
  0.7129721816977486,
  0.3912243441314241,
  0.3007785995168452,
  0.3379686959981

In [54]:
pd.DataFrame(leaveout_ckpt_previous_latent).mean(0)

1-Wasserstein    2.676166
2-Wasserstein    2.770522
Linear_MMD       0.332431
Poly_MMD         0.515711
RBF_MMD          0.437044
Mean_MSE         0.358540
Mean_L2          0.546338
Mean_L1          0.459409
precision        0.373793
recall           0.364364
density          0.250444
coverage         0.177981
dtype: float64

**Data space**

In [55]:
pd.DataFrame(leaveout_ckpt_vae_data).mean(0)

precision    0.750372
recall       0.150647
density      1.416885
coverage     0.659755
dtype: float64

In [56]:
pd.DataFrame(leaveout_ckpt_flat_data).mean(0)

precision    0.766666
recall       0.058059
density      3.100853
coverage     0.817091
dtype: float64

In [57]:
pd.DataFrame(leaveout_ckpt_geodesic_data).mean(0)

precision    0.009464
recall       0.883743
density      0.001170
coverage     0.002569
dtype: float64

In [59]:
pd.DataFrame(leaveout_ckpt_previous_latent).mean(0)

1-Wasserstein    2.676166
2-Wasserstein    2.770522
Linear_MMD       0.332431
Poly_MMD         0.515711
RBF_MMD          0.437044
Mean_MSE         0.358540
Mean_L2          0.546338
Mean_L1          0.459409
precision        0.373793
recall           0.364364
density          0.250444
coverage         0.177981
dtype: float64