In [1]:
from paths import DATA_DIR, CKPT_FOLDER, PROJECT_FOLDER

import matplotlib.pyplot as plt
import numpy as np
import torch
import scanpy as sc
import scvelo as scv
import cellrank as cr

import anndata
import pandas as pd
import seaborn as sns

from torchdyn.core import NeuralODE

from scCFM.datamodules.time_sc_datamodule import TrajectoryDataModule
from scCFM.models.cfm.components.mlp import MLP
from scCFM.models.cfm.cfm_module import CFMLitModule
from scCFM.models.base.vae import VAE
from scCFM.models.base.geometric_vae import GeometricNBVAE
from scCFM.datamodules.sc_datamodule import scDataModule
from scCFM.models.cfm.components.eval.distribution_distances import compute_distribution_distances

from notebooks.utils import (standardize_adata,
                             add_keys_to_dict,
                             real_reconstructed_cells_adata,
                             add_velocity_to_adata, 
                             compute_velocity_projection, 
                             compute_trajectory, 
                             decode_trajectory)

import scib
from scib_metrics.benchmark import Benchmarker, BatchCorrection

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
  self.seed = seed
  self.dl_pin_memory_gpu_training = (


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
adata_latent_vae = sc.read_h5ad(DATA_DIR / "eb" / "flat" / "eb_lib.h5ad")
adata_latent_flat = sc.read_h5ad(DATA_DIR / "eb" / "flat" / "eb_flat_lib.h5ad")

In [4]:
def get_trajectory(adata, cluster, gene_list, n_tp=4):
    condition = np.logical_and(adata.obs.experimental_time==1.0, adata.obs.leiden==cluster)
    cells_t1 = np.array(adata[condition].obs.index).astype(np.int32) 

    # Collect the indices of the observations mapping to ncc
    indices_lineage = [cells_t1]
    
    tmp_indices = cells_t1.copy()
    for _ in range(n_tp):
        tmp_indices -= n_x0
        indices_lineage.append(tmp_indices.copy())

    # Collect observations mapping to lineage 
    X_lineage = []

    for idxs in indices_lineage[::-1]:
        X_lineage.append(adata.X[idxs])
    
    X_lineage = np.stack(X_lineage, axis=1)

    # Collect indices of genes of interest
    genes_indices = list(adata.var[adata.var.gene_name.isin(gene_list)].index)
    return X_lineage,genes_indices

def plot_gene_trajectory_dpt(X_method_list, X_reference, pseudotime, idx2time, method_names, idx, gene_name):
    # Add reference name
    method_names.append("pseudotime")
    # Get the associated time
    real_times = [idx2time[i] for i in idx2time]
    time = [real_times*X.shape[0] for X in X_method_list]
    time.append(pseudotime)
    # Subset to values of interest
    X_list = [X[:,:,idx].ravel() for X in X_method_list]
    X_list.append(X_reference[:, idx])
    # Method 
    method = [[method_names[i]]*X_list[i].shape[0] for i in range(len(X_list))]

    # Concatenate
    X_list = np.concatenate(X_list)
    time = np.concatenate(time)
    method = np.concatenate(method)
    
    # Dictironary to plot 
    plotting_dict = {"gene_expression": X_list, 
                     "time": time,
                    "method": method}
    sns.lineplot(plotting_dict, x="time", y="gene_expression", hue="method")
    plt.title(gene_name)

def plot_gene_trajectory(X_method_list, X_original, data_times, idx2time, method_names, idx, gene_name):
    # Append name real data
    method_names.append("real_data")
    # Get the associated time
    real_times = [idx2time[i] for i in idx2time]
    time = [real_times*X.shape[0] for X in X_method_list]
    time.append(data_times)
    # Subset to values of interest
    X_list = [X[:,:,idx].ravel() for X in X_method_list]
    X_list.append(X_original[:, idx])
    # Method 
    method = [[method_names[i]]*X_list[i].shape[0] for i in range(len(X_list))]

    # Concatenate
    X_list = np.concatenate(X_list)
    time = np.concatenate(time)
    method = np.concatenate(method)
    
    # Dictironary to plot 
    plotting_dict = {"gene_expression": X_list, 
                     "time": time,
                    "method": method}
    sns.lineplot(plotting_dict, x="time", y="gene_expression", hue="method")
    plt.title(gene_name)

Initialize datamodule

In [5]:
datamodule_kwargs_vae = {'path': '/nfs/homedirs/pala/scCFM/project_dir/data/eb/flat/eb_lib.h5ad',
                          'x_layer': 'X_latents',
                          'time_key': 'experimental_time', 
                          'use_pca': False, 
                          'n_dimensions': None, 
                          'train_val_test_split': [0.9, 0.1], 
                          'num_workers': 2, 
                          'batch_size': 512, 
                          'model_library_size': True}

datamodule_kwargs_flat = {'path': '/nfs/homedirs/pala/scCFM/project_dir/data/eb/flat/eb_flat_lib.h5ad',
                          'x_layer': 'X_latents',
                          'time_key': 'experimental_time', 
                          'use_pca': False, 
                          'n_dimensions': None, 
                          'train_val_test_split': [0.9, 0.1], 
                          'num_workers': 2, 
                          'batch_size': 512, 
                           'model_library_size': True}

# Initialize the datamodules 
datamodule_vae = TrajectoryDataModule(**datamodule_kwargs_vae)
datamodule_flat = TrajectoryDataModule(**datamodule_kwargs_flat)

# Mapping real times to index
idx2time = datamodule_vae.idx2time

Velocity field network

In [6]:
net_hparams = {"dim": datamodule_vae.dim,
                "w": 64,
                "time_varying": True}

net_vae = MLP(**net_hparams).to(device)
net_flat = MLP(**net_hparams).to(device)

CFM model 

In [7]:
cfm_kwargs = {'ot_sampler': 'exact', 
                   'sigma': 0.1, 
                   'use_real_time': False, 
                   'lr': 0.001, 
                   'antithetic_time_sampling': True}

cfm_vae = CFMLitModule(net=net_vae, datamodule=datamodule_vae, **cfm_kwargs).to(device)
cfm_flat = CFMLitModule(net=net_flat, datamodule=datamodule_flat, **cfm_kwargs).to(device)

In [8]:
cfm_vae.load_state_dict(torch.load(CKPT_FOLDER / "trajectory" / "eb" / "best_cfm_model_eb.ckpt")["state_dict"])
cfm_flat.load_state_dict(torch.load(CKPT_FOLDER / "trajectory" / "eb" / "best_cfm_model_flat_eb.ckpt")["state_dict"])

<All keys matched successfully>

**Decoded trajectories**

In [9]:
datamodule={'path': '/nfs/homedirs/pala/scCFM/project_dir/data/eb/processed/eb_phate.h5ad', 
            'x_layer': 'X_norm', 
            'cond_keys': ['experimental_time', 'leiden'],
            'use_pca': False, 
            'n_dimensions': None, 
            'train_val_test_split': [1], 
            'batch_size': 512, 
            'num_workers': 2}

# Initialize datamodule
datamodule = scDataModule(**datamodule)

Initialize variational autoencoders

In [10]:
model_vae={
       'in_dim': datamodule.in_dim,
       'n_epochs_anneal_kl': 1000, 
       'kl_weight': None, 
       'likelihood': 'nb', 
       'dropout': False, 
       'learning_rate': 0.001, 
       'dropout_p': False, 
       'model_library_size': True, 
       'batch_norm': True, 
       'kl_warmup_fraction': 0.1, 
       'hidden_dims': [256, 10]}
        
geometric={'compute_metrics_every': 1, 
           'use_c': True, 
           'l2': True, 
           'eta_interp': 0, 
           'interpolate_z': False, 
           'start_jac_after': 0, 
           'fl_weight': 0.1,
           'detach_theta': True}

In [11]:
# Initialize vae and geometric vae
vae = GeometricNBVAE(**geometric, vae_kwargs=model_vae).to(cfm_vae.device)
geometric_vae = GeometricNBVAE(**geometric, vae_kwargs=model_vae).to(cfm_vae.device)

In [12]:
# Load state dicts and put in eval mode 
vae.load_state_dict(torch.load(PROJECT_FOLDER / "checkpoints/ae/eb/best_model_vae_lib.ckpt")["state_dict"])
geometric_vae.load_state_dict(torch.load(PROJECT_FOLDER / "checkpoints/ae/eb/best_model_geometric_lib.ckpt")["state_dict"])

vae.eval()
geometric_vae.eval()

GeometricNBVAE(
  (encoder_layers): MLP(
    (net): Sequential(
      (0): Sequential(
        (0): Linear(in_features=1241, out_features=256, bias=True)
        (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ELU(alpha=1.0)
      )
    )
  )
  (decoder_layers): MLP(
    (net): Sequential(
      (0): Sequential(
        (0): Linear(in_features=10, out_features=256, bias=True)
        (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ELU(alpha=1.0)
      )
    )
  )
  (library_size_decoder): Linear(in_features=10, out_features=1, bias=True)
  (decoder_mu_lib): Linear(in_features=256, out_features=1241, bias=True)
  (mu_logvar): Linear(in_features=256, out_features=20, bias=True)
)

Compute the trajectories 

In [13]:
#Pick time 0 observations
X_adata_t0_vae = torch.from_numpy(adata_latent_vae[adata_latent_vae.obs["experimental_time"]==0].X).to(device)
X_adata_t0_flat = torch.from_numpy(adata_latent_flat[adata_latent_flat.obs["experimental_time"]==0].X).to(device)

#Library values of observations at time 0
l_t0_vae = adata_latent_vae.obs.loc[adata_latent_vae.obs["experimental_time"]==0, "log_library_size"].to_numpy()
l_t0_flat = adata_latent_flat.obs.loc[adata_latent_flat.obs["experimental_time"]==0, "log_library_size"].to_numpy()
l_t0_vae = torch.from_numpy(l_t0_vae).to(device)
l_t0_flat = torch.from_numpy(l_t0_flat).to(device)

In [14]:
# Number of observations at x0
n_x0 = X_adata_t0_vae.shape[0]

In [15]:
# Collect data trajectories
mu_traj_vae, x_traj_vae, times_traj_vae = decode_trajectory(X_adata_t0_vae,
                                                            l_t0_vae,
                                                            cfm_vae,
                                                            vae,
                                                            idx2time, 
                                                            device, 
                                                            False,
                                                            keep_time_d=False, 
                                                           append_last=False)

mu_traj_flat, x_traj_flat, times_traj_flat = decode_trajectory(X_adata_t0_flat,
                                                               l_t0_flat,
                                                               cfm_flat,
                                                               geometric_vae,
                                                               idx2time, 
                                                               device, 
                                                               False,
                                                               keep_time_d=False)

# Real time cells (for plotting)
times_traj_vae = pd.DataFrame(times_traj_vae)
times_traj_flat = pd.DataFrame(times_traj_flat)

times_traj_vae.columns = ["experimental_time"]
times_traj_flat.columns = ["experimental_time"]

In [16]:
# Create anndatas
adata_x_traj_vae = anndata.AnnData(X=x_traj_vae.detach().numpy(), 
                                     obs=times_traj_vae)
adata_x_traj_flat = anndata.AnnData(X=x_traj_flat.detach().numpy(),
                                      obs=times_traj_flat)



ValueError: Observations annot. `obs` must have number of rows of `X` (22653), but has 17655653 rows.

We now have simulated datasets. We read the true datasets and visualize it 

In [None]:
adata_eb_original = sc.read_h5ad('/nfs/homedirs/pala/scCFM/project_dir/data/eb/processed/eb_phate.h5ad')
adata_eb_original.X = adata_eb_original.layers["X_norm"].copy()
adata_eb_original.var = adata_eb_original.var

Annotate genes

In [None]:
adata_x_traj_vae.var = adata_eb_original.var.copy()
adata_x_traj_flat.var = adata_eb_original.var.copy()

In [None]:
sc.pl.embedding(adata_eb_original, basis="X_phate", color=["leiden", "experimental_time"])

In [None]:
sc.pl.umap(adata_eb_original, color=["leiden", "experimental_time"])

In [None]:
#log1p and neighbors calculation
sc.pp.log1p(adata_x_traj_vae)
sc.pp.log1p(adata_x_traj_flat)

sc.tl.pca(adata_x_traj_vae)
sc.tl.pca(adata_x_traj_flat)

sc.pp.neighbors(adata_x_traj_vae)
sc.pp.neighbors(adata_x_traj_flat)

sc.tl.umap(adata_x_traj_vae)
sc.tl.umap(adata_x_traj_flat)

In [None]:
sc.tl.ingest(adata_x_traj_vae, adata_eb_original, obs="leiden", embedding_method="umap")
sc.tl.ingest(adata_x_traj_flat, adata_eb_original, obs="leiden", embedding_method="umap")

In [None]:
sc.pl.umap(adata_x_traj_flat, color="leiden")

In [None]:
sc.pl.umap(adata_x_traj_vae, color=["leiden", "experimental_time"])

In [None]:
adata_x_traj_flat.var = adata_x_traj_flat.var.reset_index()
adata_x_traj_vae.var = adata_x_traj_vae.var.reset_index()
adata_eb_original.var = adata_eb_original.var.reset_index()

### Neural crest cells

In [None]:
X_neural_crest_cells_flat, genes_neural_crest_cells = get_trajectory(adata_x_traj_flat, '3', ["NGFR", "GYPC", "CXCR4", "PDGFRB"], n_tp=4)
X_neural_crest_cells_vae, genes_neural_crest_cells = get_trajectory(adata_x_traj_vae, '3', ["NGFR", "GYPC", "CXCR4", "PDGFRB"], n_tp=4)

In [None]:
adata_eb_original_neural_crest = adata_eb_original[adata_eb_original.obs.leiden.isin(['7', '4', '8', '3'])]
sc.pl.embedding(adata_eb_original_neural_crest, basis="X_phate", color=["leiden", "dpt_pseudotime"])
adata_eb_original_neural_crest.uns['iroot'] = np.flatnonzero(adata_eb_original_neural_crest.obs['leiden']  == '7')[0]
sc.tl.dpt(adata_eb_original_neural_crest)
adata_eb_original_neural_crest = adata_eb_original_neural_crest[adata_eb_original_neural_crest.obs.sort_values(by="dpt_pseudotime").index]

Plot pseudotime 

In [None]:
plot_gene_trajectory([X_neural_crest_cells_flat, X_neural_crest_cells_vae], 
                     np.array(adata_eb_original_neural_crest.X.A), 
                     np.array(adata_eb_original_neural_crest.obs.experimental_time), 
                     idx2time, 
                     ["flat", "vae", "real"],
                     257, 
                     "NGFR")

In [None]:
plot_gene_trajectory([X_neural_crest_cells_flat, X_neural_crest_cells_vae], 
                     np.array(adata_eb_original_neural_crest.X.A), 
                     np.array(adata_eb_original_neural_crest.obs.experimental_time), 
                     idx2time, 
                     ["flat", "vae", "real"],
                     455, 
                     "GYPC")

In [None]:
plot_gene_trajectory([X_neural_crest_cells_flat, X_neural_crest_cells_vae], 
                     np.array(adata_eb_original_neural_crest.X.A), 
                     np.array(adata_eb_original_neural_crest.obs.experimental_time), 
                     idx2time, 
                     ["flat", "vae", "real"],
                     755, 
                     "CXCR4")

In [None]:
plot_gene_trajectory([X_neural_crest_cells_flat, X_neural_crest_cells_vae], 
                     np.array(adata_eb_original_neural_crest.X.A), 
                     np.array(adata_eb_original_neural_crest.obs.experimental_time), 
                     idx2time, 
                     ["flat", "vae", "real"],
                     824, 
                     "PDGFRB")

**Neuron subtypes**