In [None]:
from paths import DATA_DIR, CKPT_FOLDER, PROJECT_FOLDER

import matplotlib.pyplot as plt
import os
import numpy as np
import torch
import scanpy as sc
import sklearn
import scvelo as scv

import anndata
import pandas as pd

from torchdyn.core import NeuralODE

from scCFM.datamodules.time_sc_datamodule import TrajectoryDataModule
from scCFM.models.cfm.components.mlp import MLP
from scCFM.models.cfm.cfm_module import CFMLitModule
from scCFM.models.base.vae import VAE
from scCFM.models.base.geometric_vae import GeometricNBVAE
from scCFM.datamodules.sc_datamodule import scDataModule
from scCFM.models.cfm.components.eval.distribution_distances import compute_distribution_distances

from notebooks.utils import decode_trajectory_single_step

import scib_metrics
from scib_metrics.benchmark import Benchmarker, BatchCorrection, BioConservation
from scib_metrics import silhouette_batch, ilisi_knn, clisi_knn, kbet, graph_connectivity

In [5]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

In [6]:
def standardize(tensor, eps=1e-6):
    """Z-score every feature (column) of ``tensor`` along dimension 0.

    Args:
        tensor: torch.Tensor whose first dimension indexes observations.
        eps: small constant added to the standard deviation so that
            constant features do not cause a division by zero.
    Returns:
        Tensor of the same shape with zero mean and (approximately) unit
        standard deviation along dimension 0.
    """
    # The epsilon must sit inside the denominator: `x / std + eps` would
    # merely shift the result and still yield NaN when std == 0.
    return (tensor - tensor.mean(0)) / (tensor.std(0) + eps)

In [7]:
def compute_pairwise_distance(data_x, data_y=None, n_jobs=8):
    """Euclidean distances between every row of ``data_x`` and every row of ``data_y``.

    Args:
        data_x: array-like of shape [N, feature_dim].
        data_y: array-like of shape [M, feature_dim]; when None, ``data_x``
            is compared against itself.
        n_jobs: number of parallel jobs forwarded to scikit-learn.
    Returns:
        numpy.ndarray of shape [N, M] holding the pairwise distances.
    """
    # Import the submodule explicitly: a bare `import sklearn` does not
    # guarantee that `sklearn.metrics` is available as an attribute.
    from sklearn.metrics import pairwise_distances

    if data_y is None:
        data_y = data_x
    return pairwise_distances(
        data_x, data_y, metric='euclidean', n_jobs=n_jobs)


def get_kth_value(unsorted, k, axis=-1):
    """Return the k-th smallest value (1-indexed) along ``axis``.

    Args:
        unsorted: numpy.ndarray of any dimensionality.
        k: int, rank of the value to extract; must satisfy
            1 <= k <= unsorted.shape[axis].
        axis: axis along which the values are ranked.
    Returns:
        numpy.ndarray with ``axis`` removed, holding the k-th smallest values.
    Raises:
        ValueError: if ``k`` is smaller than 1 or larger than the axis length.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    # After partitioning, position k - 1 along `axis` holds the k-th smallest
    # element, i.e. the maximum of the k nearest-neighbour distances.
    partitioned = np.partition(unsorted, k - 1, axis=axis)
    # np.take honours `axis`, unlike a hard-coded `[..., :k]` slice that
    # always cuts the last axis.
    return np.take(partitioned, k - 1, axis=axis)


def compute_nearest_neighbour_distances(input_features, nearest_k):
    """Radius of each sample's k-nearest-neighbour ball within its own set.

    Args:
        input_features: numpy.ndarray([N, feature_dim], dtype=np.float32)
        nearest_k: int, number of neighbours defining the radius.
    Returns:
        One distance per sample: the distance to its ``nearest_k``-th
        nearest neighbour.
    """
    # The set is compared with itself, so every row contains the sample's own
    # zero distance; asking for rank nearest_k + 1 skips that entry.
    self_distances = compute_pairwise_distance(input_features)
    return get_kth_value(self_distances, k=nearest_k + 1, axis=-1)


def compute_prdc(real_features, fake_features, nearest_k):
    """Precision, recall, density and coverage between two sample sets.

    Args:
        real_features: numpy.ndarray([N, feature_dim], dtype=np.float32)
        fake_features: numpy.ndarray([N, feature_dim], dtype=np.float32)
        nearest_k: int, neighbourhood size used to build the k-NN balls.
    Returns:
        dict with the keys ``precision``, ``recall``, ``density`` and
        ``coverage``.
    """
    # k-NN ball radius around every real / fake sample
    real_radii = compute_nearest_neighbour_distances(real_features, nearest_k)
    fake_radii = compute_nearest_neighbour_distances(fake_features, nearest_k)
    # Rows index real samples, columns index fake samples
    cross_distances = compute_pairwise_distance(real_features, fake_features)

    # [i, j] is True when fake sample j falls inside the ball of real sample i
    fake_in_real_ball = cross_distances < real_radii[:, np.newaxis]
    # [i, j] is True when real sample i falls inside the ball of fake sample j
    real_in_fake_ball = cross_distances < fake_radii[np.newaxis, :]

    # Fraction of fake samples covered by at least one real ball
    precision = fake_in_real_ball.any(axis=0).mean()
    # Fraction of real samples covered by at least one fake ball
    recall = real_in_fake_ball.any(axis=1).mean()
    # Average number of real balls containing a fake sample, scaled by 1 / k
    density = (1. / float(nearest_k)) * fake_in_real_ball.sum(axis=0).mean()
    # Fraction of real samples whose closest fake sample lies inside their ball
    coverage = (cross_distances.min(axis=1) < real_radii).mean()

    return {
        "precision": precision,
        "recall": recall,
        "density": density,
        "coverage": coverage,
    }

## Initialize VAEs

In [10]:
# Keyword arguments for the datamodule; kept under its own name so that
# `datamodule` always refers to the instantiated object.
datamodule_kwargs = {'path': PROJECT_FOLDER / 'data/eb/processed/eb_phate.h5ad',
                     'x_layer': 'X_norm',
                     'cond_keys': ['experimental_time', 'leiden'],
                     'use_pca': False,
                     'n_dimensions': None,
                     'train_val_test_split': [1],
                     'batch_size': 512,
                     'num_workers': 2}

# Initialize datamodule
datamodule = scDataModule(**datamodule_kwargs)

# Hyper-parameters of the underlying VAE; the input size comes from the datamodule
model_vae = {
       'in_dim': datamodule.in_dim,
       'n_epochs_anneal_kl': 1000,
       'kl_weight': None,
       'likelihood': 'nb',
       'dropout': False,
       'learning_rate': 0.001,
       'dropout_p': False,
       'model_library_size': True,
       'batch_norm': True,
       'kl_warmup_fraction': 0.1,
       'hidden_dims': [256, 10]}

# Settings specific to the geometric regularisation
geometric = {'compute_metrics_every': 1,
             'use_c': True,
             'l2': True,
             'eta_interp': 0,
             'interpolate_z': False,
             'start_jac_after': 0,
             'fl_weight': 0.1,
             'detach_theta': True}

# Initialize vae and geometric vae.
# Both are built from the same class and arguments; they differ only in the
# checkpoint loaded below.
vae = GeometricNBVAE(**geometric, vae_kwargs=model_vae).to(device)
geometric_vae = GeometricNBVAE(**geometric, vae_kwargs=model_vae).to(device)

# Load state dicts and put in eval mode.
# map_location=device lets checkpoints saved on a GPU load on a CPU-only machine.
vae_ckpt = torch.load(PROJECT_FOLDER / "checkpoints/ae/eb/best_model_vae_lib.ckpt",
                      map_location=device)
geometric_vae_ckpt = torch.load(PROJECT_FOLDER / "checkpoints/ae/eb/best_model_geometric_lib.ckpt",
                                map_location=device)
vae.load_state_dict(vae_ckpt["state_dict"])
geometric_vae.load_state_dict(geometric_vae_ckpt["state_dict"])
vae.eval()
geometric_vae.eval()

GeometricNBVAE(
  (encoder_layers): MLP(
    (net): Sequential(
      (0): Sequential(
        (0): Linear(in_features=1241, out_features=256, bias=True)
        (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ELU(alpha=1.0)
      )
    )
  )
  (decoder_layers): MLP(
    (net): Sequential(
      (0): Sequential(
        (0): Linear(in_features=10, out_features=256, bias=True)
        (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ELU(alpha=1.0)
      )
    )
  )
  (library_size_decoder): Linear(in_features=10, out_features=1, bias=True)
  (decoder_mu_lib): Linear(in_features=256, out_features=1241, bias=True)
  (mu_logvar): Linear(in_features=256, out_features=20, bias=True)
)

## Setup CFMs

In [11]:
# Folder from which the evaluation loop loads the per-time-point "leaveout" CFM checkpoints
leavout_timepoints_folder = CKPT_FOLDER / "trajectory" / "eb"

Initialize datamodule

In [12]:
# Datamodule settings for the latent file "eb_lib.h5ad"
datamodule_kwargs_vae = dict(
    path=PROJECT_FOLDER / 'data/eb/flat/eb_lib.h5ad',
    x_layer='X_latents',
    time_key='experimental_time',
    use_pca=False,
    n_dimensions=None,
    train_val_test_split=[0.9, 0.1],
    num_workers=2,
    batch_size=512,
    model_library_size=True,
)

# Identical settings for "eb_flat_lib.h5ad"; only the input file differs.
# The split gets a fresh list so the two configs share no mutable state.
datamodule_kwargs_flat = {
    **datamodule_kwargs_vae,
    'path': PROJECT_FOLDER / 'data/eb/flat/eb_flat_lib.h5ad',
    'train_val_test_split': [0.9, 0.1],
}

# Initialize the datamodules (vae first, then flat)
datamodule_vae, datamodule_flat = (
    TrajectoryDataModule(**kwargs)
    for kwargs in (datamodule_kwargs_vae, datamodule_kwargs_flat)
)

# Mapping from time index to real time, as exposed by the datamodule
idx2time = datamodule_vae.idx2time

## Read data

First, read the latent space anndata and plot the results

In [13]:
# Read latent anndata
adata_latent_vae = sc.read_h5ad(DATA_DIR / "eb" / "flat" / "eb_lib.h5ad")
adata_latent_flat = sc.read_h5ad(DATA_DIR / "eb" / "flat" / "eb_flat_lib.h5ad")

# Read real anndata
adata_eb_original = sc.read_h5ad(PROJECT_FOLDER / 'data/eb/processed/eb_phate.h5ad')
sc.tl.pca(adata_eb_original, n_comps=200)
adata_eb_original.X = adata_eb_original.layers["X_norm"].copy()

Number of experimental time points and the mapping from time index to real time

In [14]:
# Enumerate the sorted unique experimental times once: index -> real time
idx2time = dict(enumerate(np.unique(adata_latent_vae.obs.experimental_time)))
# One dictionary entry per distinct time point
n_timepoints = len(idx2time)
idx2time

{0: 0.0, 1: 0.25, 2: 0.5, 3: 0.75, 4: 1.0}

Initialize model

In [15]:
# Velocity-network settings: the input size is the latent dimensionality plus
# one extra coordinate (presumably the log library size, TODO confirm)
net_hparams = dict(
    dim=adata_latent_flat.X.shape[1] + 1,
    w=64,
    time_varying=True,
)

# Options forwarded to CFMLitModule in the evaluation loop
cfm_kwargs = dict(
    ot_sampler='exact',
    sigma=0.1,
    use_real_time=False,
    lr=0.001,
    antithetic_time_sampling=True,
)

## Evaluation

Load checkpoints

In [16]:
# Empty containers, one per model (presumably meant to hold the leave-out
# checkpoints, TODO confirm; nothing in this section fills them)
leaveout_ckpt_vae = {}
leaveout_ckpt_flat = {}
# Backward-compatible alias for the original, misspelled variable name
leaveput_ckpt_vae = leaveout_ckpt_vae

In [None]:
# Number of principal components for the decoded predictions and the
# neighbourhood size used by the PRDC metrics
N_PCA_COMPONENTS = 200
PRDC_NEAREST_K = 30

# Evaluate every intermediate time point: push the cells observed at tp - 1
# through the CFM loaded from the "leaveout_{tp}" checkpoint and compare the
# result with the cells observed at tp.
for tp in range(1, n_timepoints-1):
    print(f"Time point {tp}")
    t_prev, t_curr = idx2time[tp-1], idx2time[tp]

    # Boolean masks are computed once per AnnData and reused below
    is_prev_vae = adata_latent_vae.obs["experimental_time"] == t_prev
    is_prev_flat = adata_latent_flat.obs["experimental_time"] == t_prev
    is_curr_vae = adata_latent_vae.obs["experimental_time"] == t_curr
    is_curr_flat = adata_latent_flat.obs["experimental_time"] == t_curr
    is_curr_real = adata_eb_original.obs["experimental_time"] == t_curr

    # Source (tp - 1) and target (tp) latent observations
    X_adata_t0_latent_vae = torch.from_numpy(adata_latent_vae[is_prev_vae].X).to(device)
    X_adata_t0_latent_flat = torch.from_numpy(adata_latent_flat[is_prev_flat].X).to(device)
    X_adata_t1_latent_vae = torch.from_numpy(adata_latent_vae[is_curr_vae].X).to(device)
    X_adata_t1_latent_flat = torch.from_numpy(adata_latent_flat[is_curr_flat].X).to(device)

    # Real cells at the target time point: PCA scores and log-expression
    adata_real_tp = adata_eb_original[is_curr_real]
    X_adata_real_pca = torch.from_numpy(adata_real_tp.obsm["X_pca"]).to(device)
    X_log = adata_real_tp.layers["X_log"]
    # `.A` was removed from SciPy sparse matrices; densify in a way that works
    # for sparse and dense layers alike
    X_log = X_log.toarray() if hasattr(X_log, "toarray") else np.asarray(X_log)
    # Not consumed by the metrics below; kept so the variable stays available
    X_adata_real = torch.from_numpy(X_log).to(device)

    # Pick library sizes of the source cells
    l_t0_vae = adata_latent_vae.obs.loc[is_prev_vae, "log_library_size"].to_numpy()
    l_t0_flat = adata_latent_flat.obs.loc[is_prev_flat, "log_library_size"].to_numpy()
    l_t0_vae = torch.from_numpy(l_t0_vae).to(device)
    l_t0_flat = torch.from_numpy(l_t0_flat).to(device)

    # Initialize nets
    net_vae = MLP(**net_hparams).to(device)
    net_flat = MLP(**net_hparams).to(device)
    cfm_vae = CFMLitModule(net=net_vae, datamodule=datamodule_vae, **cfm_kwargs).to(device)
    cfm_flat = CFMLitModule(net=net_flat, datamodule=datamodule_flat, **cfm_kwargs).to(device)

    # Read the checkpoints; map_location keeps this working on CPU-only machines
    ckpt_vae = torch.load(leavout_timepoints_folder / f"eb_vae_leaveout_{tp}.ckpt",
                          map_location=device)
    ckpt_flat = torch.load(leavout_timepoints_folder / f"eb_flat_leaveout_{tp}.ckpt",
                           map_location=device)
    cfm_vae.load_state_dict(ckpt_vae["state_dict"])
    cfm_flat.load_state_dict(ckpt_flat["state_dict"])
    # Inference only: switch off any train-time behaviour of the modules
    cfm_vae.eval()
    cfm_flat.eval()

    _, X_adata_predicted_vae, X_adata_latent_vae = decode_trajectory_single_step(X_adata_t0_latent_vae, l_t0_vae, tp-1, cfm_vae, vae)
    _, X_adata_predicted_flat, X_adata_latent_flat = decode_trajectory_single_step(X_adata_t0_latent_flat, l_t0_flat, tp-1, cfm_flat, geometric_vae)

    print("predict latent trajectory")
    # The last column of the predicted latent is dropped before the comparison
    # (presumably the library-size coordinate, TODO confirm)
    print(compute_distribution_distances(standardize(X_adata_t1_latent_vae.unsqueeze(1).to("cpu")),
                                         standardize(X_adata_latent_vae[:,:-1].unsqueeze(1).to("cpu"))))
    print(compute_distribution_distances(standardize(X_adata_t1_latent_flat.unsqueeze(1).to("cpu")),
                                         standardize(X_adata_latent_flat[:,:-1].unsqueeze(1).to("cpu"))))

    print("predict decoded trajectory")
    # .detach().cpu() makes the conversion safe when the decoder output lives
    # on the GPU or still carries gradients
    X_adata_predicted_vae = anndata.AnnData(X=X_adata_predicted_vae.detach().cpu().numpy())
    X_adata_predicted_flat = anndata.AnnData(X=X_adata_predicted_flat.detach().cpu().numpy())
    sc.pp.log1p(X_adata_predicted_vae)
    sc.pp.log1p(X_adata_predicted_flat)
    # NOTE(review): these PCAs are fitted on the predictions alone, whereas
    # X_adata_real_pca comes from the PCA stored on adata_eb_original; the two
    # coordinate systems are presumably not aligned. Confirm whether the
    # comparison should use a shared basis.
    sc.tl.pca(X_adata_predicted_vae, n_comps=N_PCA_COMPONENTS)
    sc.tl.pca(X_adata_predicted_flat, n_comps=N_PCA_COMPONENTS)

    real_pca_cpu = X_adata_real_pca.to("cpu")
    print(compute_distribution_distances(torch.from_numpy(X_adata_predicted_vae.obsm["X_pca"]).unsqueeze(1),
                                         real_pca_cpu.unsqueeze(1)))
    print(compute_distribution_distances(torch.from_numpy(X_adata_predicted_flat.obsm["X_pca"]).unsqueeze(1),
                                         real_pca_cpu.unsqueeze(1)))

    # compute_prdc expects (real, fake): passing the predictions first would
    # swap precision with recall and build density/coverage on the wrong set
    real_pca_np = real_pca_cpu.numpy()
    print(compute_prdc(real_features=real_pca_np,
                       fake_features=X_adata_predicted_vae.obsm["X_pca"],
                       nearest_k=PRDC_NEAREST_K))
    print(compute_prdc(real_features=real_pca_np,
                       fake_features=X_adata_predicted_flat.obsm["X_pca"],
                       nearest_k=PRDC_NEAREST_K))
    print()