In [93]:
from paths import DATA_DIR, CKPT_FOLDER, PROJECT_FOLDER

import matplotlib.pyplot as plt
import os
import numpy as np
import torch
import scanpy as sc
import sklearn
import scvelo as scv

import anndata
import pandas as pd

from IPython.display import display
from torchdyn.core import NeuralODE

from scCFM.datamodules.time_sc_datamodule import TrajectoryDataModule
from scCFM.models.cfm.components.mlp import MLP
from scCFM.models.cfm.cfm_module import CFMLitModule

from scCFM.models.base.vae import VAE
from scCFM.models.base.geometric_vae import GeometricNBVAE
from scCFM.models.base.geodesic_ae import GeodesicAE

from scCFM.datamodules.sc_datamodule import scDataModule
from scCFM.models.cfm.components.eval.distribution_distances import compute_distribution_distances

from notebooks.utils import decode_trajectory_single_step, standardize, compute_prdc

Initialize the device

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

Initialize datamodule

In [3]:
# Keyword arguments for the gene-space datamodule. They are kept under their own
# name so that `datamodule` always refers to the scDataModule instance and never
# to a plain dict.
datamodule_kwargs = {'path': PROJECT_FOLDER / 'data/eb/processed/eb_phate.h5ad',
                     'x_layer': 'X_norm',
                     'cond_keys': ['experimental_time', 'leiden'],
                     'use_pca': False,
                     'n_dimensions': None,
                     # single split fraction — presumably all cells land in one split; TODO confirm
                     'train_val_test_split': [1],
                     'batch_size': 512,
                     'num_workers': 2}

# Initialize datamodule
datamodule = scDataModule(**datamodule_kwargs)

Initialize autoencoders

In [4]:
# Hyperparameters of the underlying VAE (passed to GeometricNBVAE as `vae_kwargs`)
vae_kwargs = {'in_dim': datamodule.in_dim,
              'n_epochs_anneal_kl': 1000,
              'kl_weight': None,
              'likelihood': 'nb',
              'dropout': False,
              'learning_rate': 0.001,
              'dropout_p': False,
              'model_library_size': True,
              'batch_norm': True,
              'kl_warmup_fraction': 0.1,
              'hidden_dims': [256, 10]}

# Hyperparameters specific to the geometric regularisation
geometric_kwargs = {'compute_metrics_every': 1,
                    'use_c': True,
                    'l2': True,
                    'eta_interp': 0,
                    'interpolate_z': False,
                    'start_jac_after': 0,
                    'fl_weight': 0.1,
                    'detach_theta': True}

# Hyperparameters of the geodesic autoencoder
geodesic_kwargs = {"in_dim": datamodule.in_dim,
                   "hidden_dims": [256, 10],
                   "batch_norm": True,
                   "dropout": False,
                   "dropout_p": False,
                   "likelihood": "nb",
                   "learning_rate": 0.001}

# Initialize vae and geometric vae.
# Both are GeometricNBVAE instances built with the same hyperparameters; they
# differ only in the checkpoint loaded below.
vae = GeometricNBVAE(**geometric_kwargs, vae_kwargs=vae_kwargs).to(device)
geometric_vae = GeometricNBVAE(**geometric_kwargs, vae_kwargs=vae_kwargs).to(device)
geodesic_ae = GeodesicAE(**geodesic_kwargs).to(device)

# Load state dicts and put in eval mode.
# `map_location=device` lets checkpoints that were saved from a GPU be read on a
# CPU-only machine as well.
ae_checkpoints = (
    (vae, "checkpoints/ae/eb/best_model_vae_lib.ckpt"),
    (geometric_vae, "checkpoints/ae/eb/best_model_geometric_lib.ckpt"),
    (geodesic_ae, "checkpoints/ae/eb/best_model_geodesic_ae.ckpt"),
)
for ae_model, ae_ckpt_relpath in ae_checkpoints:
    ae_ckpt = torch.load(PROJECT_FOLDER / ae_ckpt_relpath, map_location=device)
    ae_model.load_state_dict(ae_ckpt["state_dict"])
    ae_model.eval()

# Show the architecture of the last model as the cell output
geodesic_ae

GeodesicAE(
  (encoder_layers): MLP(
    (net): Sequential(
      (0): Sequential(
        (0): Linear(in_features=1241, out_features=256, bias=True)
        (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ELU(alpha=1.0)
      )
    )
  )
  (decoder_layers): MLP(
    (net): Sequential(
      (0): Sequential(
        (0): Linear(in_features=10, out_features=256, bias=True)
        (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ELU(alpha=1.0)
      )
      (1): Sequential(
        (0): Linear(in_features=256, out_features=1241, bias=True)
        (1): BatchNorm1d(1241, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ELU(alpha=1.0)
      )
    )
  )
  (decoder_mu_lib): Linear(in_features=256, out_features=1241, bias=True)
  (latent_layer): Linear(in_features=256, out_features=10, bias=True)
  (criterion): MSELoss()
)

## Setup CFMs

In [5]:
# Folder with the CFM checkpoints named `eb_<model>_leaveout_<tp>.ckpt`, which the
# evaluation loop further down loads for every left-out time point.
leavout_timepoints_folder = CKPT_FOLDER / "trajectory" / "eb"

Initialize datamodule for trajectory

In [19]:
def _trajectory_datamodule_kwargs(h5ad_path):
    """Build the TrajectoryDataModule settings for one latent-space file.

    All three latent spaces share every setting except the input file, so only
    the path is a parameter. A fresh dict (and a fresh split list) is returned
    on every call.
    """
    return {'path': h5ad_path,
            'x_layer': 'X_latents',
            'time_key': 'experimental_time',
            'use_pca': False,
            'n_dimensions': None,
            'train_val_test_split': [0.9, 0.1],
            'num_workers': 2,
            'batch_size': 512,
            'model_library_size': True}


datamodule_kwargs_vae = _trajectory_datamodule_kwargs(PROJECT_FOLDER / 'data/eb/flat/eb_lib.h5ad')
datamodule_kwargs_flat = _trajectory_datamodule_kwargs(PROJECT_FOLDER / 'data/eb/flat/eb_flat_lib.h5ad')
datamodule_kwargs_geodesic = _trajectory_datamodule_kwargs(PROJECT_FOLDER / 'data/eb/flat/eb_geodesic.h5ad')

# One trajectory datamodule per latent space
datamodule_vae = TrajectoryDataModule(**datamodule_kwargs_vae)
datamodule_flat = TrajectoryDataModule(**datamodule_kwargs_flat)
datamodule_geodesic = TrajectoryDataModule(**datamodule_kwargs_geodesic)

# Index -> real time mapping, taken from the VAE datamodule
idx2time = datamodule_vae.idx2time

## Read data

First, read the latent-space AnnData objects (VAE, flat VAE and geodesic AE) and the original gene-space AnnData.

In [58]:
# Latent representations produced by the three autoencoders
_latent_dir = DATA_DIR / "eb" / "flat"
adata_latent_vae = sc.read_h5ad(_latent_dir / "eb_lib.h5ad")
adata_latent_flat = sc.read_h5ad(_latent_dir / "eb_flat_lib.h5ad")
adata_latent_geodesic = sc.read_h5ad(_latent_dir / "eb_geodesic.h5ad")

# Original gene-space data.
# NOTE(review): the PCA is computed on the `.X` stored in the file, i.e. *before*
# `.X` is replaced by the "X_norm" layer on the next line — confirm this order is intended.
adata_eb_original = sc.read_h5ad(PROJECT_FOLDER / 'data/eb/processed/eb_phate.h5ad')
sc.tl.pca(adata_eb_original, n_comps=50)
adata_eb_original.X = adata_eb_original.layers["X_norm"].copy()

Number of experimental time points and the mapping from time-point index to real time

In [59]:
# Distinct experimental times (np.unique returns them sorted), computed once
_unique_times = np.unique(adata_latent_vae.obs.experimental_time)
n_timepoints = len(_unique_times)

# Position in the sorted list -> experimental time
idx2time = dict(enumerate(_unique_times))
idx2time

{0: 0.0, 1: 0.25, 2: 0.5, 3: 0.75, 4: 1.0}

Define the hyperparameters of the velocity network and of the CFM module

In [60]:
# Velocity-network settings. The input width is the latent dimensionality plus
# one extra coordinate (presumably the library size — TODO confirm).
net_hparams = dict(
    dim=adata_latent_flat.X.shape[1] + 1,
    w=64,
    time_varying=True,
)

# Settings forwarded to CFMLitModule
cfm_kwargs = dict(
    ot_sampler='exact',
    sigma=0.1,
    use_real_time=False,
    lr=0.001,
    antithetic_time_sampling=True,
)

## Evaluation

Load checkpoints

In [61]:
# Containers for the leave-one-timepoint-out checkpoints, one per latent space
leaveout_ckpt_vae = {}
leaveout_ckpt_flat = {}
leaveout_ckpt_geodesic = {}

# Backward-compatible alias for the original, misspelled name
leaveput_ckpt_vae = leaveout_ckpt_vae

In [106]:
# Leave-one-timepoint-out evaluation in latent space.
#
# For every intermediate time point `tp`, the cells of the previous time point
# are pushed through the CFM stored in `eb_<model>_leaveout_{tp}.ckpt` with
# `decode_trajectory_single_step`, and the resulting latent states are compared
# with the observed latent states of `tp` via PRDC.
#
# NOTE: `cross_standardize` is defined in a later cell of this notebook; run
# that cell first, otherwise this loop raises a NameError.
# NOTE: the tensors assigned here stay in the notebook namespace after the loop
# (holding the values of the last time point) and are read by the cells below.
for tp in range(1, n_timepoints-1):
    print(f"Time point {tp}")
    time_t0, time_t1 = idx2time[tp-1], idx2time[tp]

    # Boolean cell masks, computed once per AnnData and time point
    mask_t0_vae = adata_latent_vae.obs["experimental_time"] == time_t0
    mask_t0_flat = adata_latent_flat.obs["experimental_time"] == time_t0
    mask_t0_geodesic = adata_latent_geodesic.obs["experimental_time"] == time_t0
    mask_t1_vae = adata_latent_vae.obs["experimental_time"] == time_t1
    mask_t1_flat = adata_latent_flat.obs["experimental_time"] == time_t1
    mask_t1_geodesic = adata_latent_geodesic.obs["experimental_time"] == time_t1
    mask_t1_real = adata_eb_original.obs["experimental_time"] == time_t1

    # Latent observations at the source time point (t0)
    X_adata_t0_latent_vae = torch.from_numpy(adata_latent_vae[mask_t0_vae].X).to(device)
    X_adata_t0_latent_flat = torch.from_numpy(adata_latent_flat[mask_t0_flat].X).to(device)
    X_adata_t0_latent_geodesic = torch.from_numpy(adata_latent_geodesic[mask_t0_geodesic].X).to(device)

    # Latent observations at the target time point (t1)
    X_adata_t1_latent_vae = torch.from_numpy(adata_latent_vae[mask_t1_vae].X).to(device)
    X_adata_t1_latent_flat = torch.from_numpy(adata_latent_flat[mask_t1_flat].X).to(device)
    X_adata_t1_latent_geodesic = torch.from_numpy(adata_latent_geodesic[mask_t1_geodesic].X).to(device)

    # Gene-space references at t1 (PCs and log counts). They are only needed by
    # the decoded-space evaluation, which is currently disabled (see the end of the loop).
    # `.toarray()` replaces the `.A` shorthand, which recent SciPy releases no longer provide.
    X_adata_real_pca = torch.from_numpy(adata_eb_original[mask_t1_real].obsm["X_pca"]).to(device)
    X_adata_real = torch.from_numpy(adata_eb_original[mask_t1_real].layers["X_log"].toarray()).to(device)

    # Log library sizes of the t0 cells
    l_t0_vae = torch.from_numpy(adata_latent_vae.obs.loc[mask_t0_vae, "log_library_size"].to_numpy()).to(device)
    l_t0_flat = torch.from_numpy(adata_latent_flat.obs.loc[mask_t0_flat, "log_library_size"].to_numpy()).to(device)
    l_t0_geodesic = torch.from_numpy(adata_latent_geodesic.obs.loc[mask_t0_geodesic, "log_library_size"].to_numpy()).to(device)

    # Initialize nets
    net_vae = MLP(**net_hparams).to(device)
    net_flat = MLP(**net_hparams).to(device)
    net_geodesic = MLP(**net_hparams).to(device)
    cfm_vae = CFMLitModule(net=net_vae, datamodule=datamodule_vae, **cfm_kwargs).to(device)
    cfm_flat = CFMLitModule(net=net_flat, datamodule=datamodule_flat, **cfm_kwargs).to(device)
    cfm_geodesic = CFMLitModule(net=net_geodesic, datamodule=datamodule_geodesic, **cfm_kwargs).to(device)

    # Read the checkpoints of the models associated with leave-out index `tp`.
    # `map_location=device` keeps this working on CPU-only machines when the
    # checkpoints were written from a GPU.
    cfm_vae.load_state_dict(
        torch.load(leavout_timepoints_folder / f"eb_vae_leaveout_{tp}.ckpt", map_location=device)["state_dict"])
    cfm_flat.load_state_dict(
        torch.load(leavout_timepoints_folder / f"eb_flat_leaveout_{tp}.ckpt", map_location=device)["state_dict"])
    cfm_geodesic.load_state_dict(
        torch.load(leavout_timepoints_folder / f"eb_geodesic_leaveout_{tp}.ckpt", map_location=device)["state_dict"])

    # Push the t0 cells one step forward and decode them with the matching autoencoder
    _, X_adata_predicted_vae, X_adata_latent_vae = decode_trajectory_single_step(
        X_adata_t0_latent_vae, l_t0_vae, tp-1, cfm_vae, vae)

    _, X_adata_predicted_flat, X_adata_latent_flat = decode_trajectory_single_step(
        X_adata_t0_latent_flat, l_t0_flat, tp-1, cfm_flat, geometric_vae)

    _, X_adata_predicted_geodesic, X_adata_latent_geodesic = decode_trajectory_single_step(
        X_adata_t0_latent_geodesic, l_t0_geodesic, tp-1, cfm_geodesic, geodesic_ae,
        model_type="geodesic_ae")

    print("predict latent trajectory")
    # Scale observed and predicted latents with the statistics of the observed
    # cells. The last column of the prediction is dropped first (presumably the
    # appended library-size coordinate — TODO confirm).
    X_adata_t1_latent_vae, X_adata_latent_vae = cross_standardize(X_adata_t1_latent_vae, X_adata_latent_vae[:,:-1])
    X_adata_t1_latent_flat, X_adata_latent_flat = cross_standardize(X_adata_t1_latent_flat, X_adata_latent_flat[:,:-1])
    # NOTE(review): the geodesic prediction is computed above but is neither
    # standardized nor scored here.

    real_vae_cpu, pred_vae_cpu = X_adata_t1_latent_vae.to("cpu"), X_adata_latent_vae.to("cpu")
    real_flat_cpu, pred_flat_cpu = X_adata_t1_latent_flat.to("cpu"), X_adata_latent_flat.to("cpu")

    # PRDC of observed vs. predicted latents (VAE first, then flat VAE)
    print(compute_prdc(real_vae_cpu, pred_vae_cpu, nearest_k=5))
    print(compute_prdc(real_flat_cpu, pred_flat_cpu, nearest_k=5))
    # Reference rows: observed latents vs. standard-normal noise of the same
    # shape. No seed is set in this cell, so these two rows vary between runs.
    print(compute_prdc(real_vae_cpu, torch.randn_like(pred_vae_cpu), nearest_k=5))
    print(compute_prdc(real_flat_cpu, torch.randn_like(pred_flat_cpu), nearest_k=5))

    # The decoded-space evaluation (log1p + 50-component PCA of the decoded
    # predictions, PRDC against `X_adata_real_pca` with nearest_k=30) and the
    # `compute_distribution_distances` variants are currently disabled.

Time point 1
predict latent trajectory
{'precision': 0.5393893705239352, 'recall': 0.6461573650503202, 'density': 0.30976253298153034, 'coverage': 0.2381061299176578}
{'precision': 0.846646571213263, 'recall': 0.7903890160183067, 'density': 0.645440844009043, 'coverage': 0.43157894736842106}
{'precision': 0.3742932529212213, 'recall': 0.9338975297346752, 'density': 0.14346023369770072, 'coverage': 0.18412625800548948}
{'precision': 0.0546345139412208, 'recall': 0.8988558352402746, 'density': 0.016126601356443105, 'coverage': 0.032494279176201374}
Time point 2
predict latent trajectory
{'precision': 0.7152333028362305, 'recall': 0.736827661909989, 'density': 0.4806953339432754, 'coverage': 0.5532381997804611}
{'precision': 0.9462242562929062, 'recall': 0.9368652209717266, 'density': 0.8555606407322656, 'coverage': 0.8210266264068076}
{'precision': 0.3389752973467521, 'recall': 0.9179473106476399, 'density': 0.12872827081427266, 'coverage': 0.27497255762897915}
{'precision': 0.0599542334

In [99]:
# Distribution distances between observed and predicted latents of the flat
# model; both tensors are the ones left behind by the evaluation loop.
distances_flat = compute_distribution_distances(
    X_adata_t1_latent_flat.unsqueeze(1).to("cpu"),
    X_adata_latent_flat.unsqueeze(1).to("cpu"),
)
display(pd.DataFrame(distances_flat))

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10
0,1-Wasserstein,2-Wasserstein,Linear_MMD,Poly_MMD,RBF_MMD,Mean_MSE,Mean_L2,Mean_L1,Median_MSE,Median_L2,Median_L1
1,1.251247,1.326609,0.052445,0.229008,0.208673,0.069506,0.263639,0.241806,,,


In [100]:
# Distribution distances between observed and predicted latents of the VAE
# model; both tensors are the ones left behind by the evaluation loop.
distances_vae = compute_distribution_distances(
    X_adata_t1_latent_vae.unsqueeze(1).to("cpu"),
    X_adata_latent_vae.unsqueeze(1).to("cpu"),
)
display(pd.DataFrame(distances_vae))

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10
0,1-Wasserstein,2-Wasserstein,Linear_MMD,Poly_MMD,RBF_MMD,Mean_MSE,Mean_L2,Mean_L1,Median_MSE,Median_L2,Median_L1
1,1.811036,1.933194,0.082131,0.286585,0.217289,0.073672,0.271425,0.207192,,,


In [64]:
def cross_standardize(tensor1, tensor2, eps=1e-6):
    """
    Standardize two tensors with the column statistics of the first one.

    The mean and standard deviation are computed over dim 0 of ``tensor1``
    only and are then applied to both tensors, so ``tensor2`` is expressed in
    the reference frame of ``tensor1``.

    Args:
        tensor1: reference tensor; its dim-0 mean and std define the scaling.
        tensor2: tensor rescaled with the statistics of ``tensor1``; it must
            broadcast against those statistics.
        eps: constant added to the std to avoid division by zero. Defaults to
            1e-6, the value that was previously hard-coded.

    Returns:
        Tuple ``(tensor1_standardized, tensor2_standardized)``.
    """
    mean_t1 = tensor1.mean(0)
    # Compute the denominator once and reuse it for both tensors
    scale_t1 = tensor1.std(0) + eps
    return (tensor1 - mean_t1) / scale_t1, (tensor2 - mean_t1) / scale_t1

In [108]:
# Path to a pickle stored in a `.submitit` folder of a hydra multirun.
# NOTE(review): this is an absolute, user-specific path and will not resolve on
# other machines; consider deriving it from a configurable project folder.
f = "/home/icb/alessandro.palma/environment/scCFM/scCFM/train_hydra/multirun/2023-09-23/17-38-14/.submitit/13690774_0/13690774_0_submitted.pkl"

In [112]:
import pickle as pkl

# `f` holds the pickle path before the first run of this cell and the unpickled
# object afterwards. The guard makes the cell safe to re-run: once `f` is no
# longer a path, the load is skipped instead of failing inside `open`.
# WARNING: unpickling executes arbitrary code — only load files you trust.
if isinstance(f, (str, bytes, os.PathLike)):
    with open(f, "rb") as file:
        f = pkl.load(file)

In [114]:
# Show the attributes of the unpickled object
vars(f)

{'function': <hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher at 0x7fee84bff640>,
 'args': (['model.leaveout_timepoint=2',
   'hydra=CFM_schiebinger_leaveout',
   'datamodule=CFM_schiebinger_leaveout',
   'logger=CFM_schiebinger_leaveout',
   'train=CFM_schiebinger_leaveout'],
  'hydra.sweep.dir',
  0,
  'job_id_for_0',
  {'instances': {hydra.core.config_store.ConfigStore: <hydra.core.config_store.ConfigStore at 0x7feecd936f20>,
    hydra.version.VersionBase: <hydra.version.VersionBase at 0x7feecd947430>,
    hydra._internal.sources_registry.SourcesRegistry: <hydra._internal.sources_registry.SourcesRegistry at 0x7feecd947550>,
    hydra.core.utils.JobRuntime: <hydra.core.utils.JobRuntime at 0x7feecd9475b0>,
    hydra.core.global_hydra.GlobalHydra: <hydra.core.global_hydra.GlobalHydra at 0x7feecd9476d0>},
   'omegaconf_resolvers': {'oc.create': <function omegaconf.omegaconf.OmegaConf.register_new_resolver.<locals>.resolver_wrapper(config: omegaconf.basecontainer.Ba