In [1]:
#libraries
import scanpy as sc
import scvi
import anndata
import matplotlib.pyplot as plt
import numpy as np
import torch
import pandas as pd
import seaborn as sb
import cellrank as cr
from cellrank.kernels import CytoTRACEKernel
import scvelo as scv
import scanpy.external as sce
import palantir

In [None]:
# Load the AnnData object holding raw counts, then list the distinct
# `batch` values so we can confirm every lane made it in.
raw_h5ad_path = '../../path_to_raw_h5ad'
adata = anndata.read_h5ad(raw_h5ad_path)
np.unique(adata.obs['batch'])

In [None]:
# Subset to the mesenchymal compartment.
# NOTE(review): `mesenchymal_lineage` (the list of `celltype` labels to keep) is not
# defined anywhere in this notebook - define it before running this cell.
# `.copy()` turns the boolean-indexed view into a standalone AnnData, so the in-place
# preprocessing in the next cell (filter_genes / normalisation / log1p) acts on a real
# object rather than a view.
adata = adata[adata.obs['celltype'].isin(mesenchymal_lineage)].copy()

In [None]:
# Preprocess, then integrate lanes with scVI (batch correction on `batch`).

sc.pp.filter_genes(adata, min_cells=3)
# scVI is set up below with layer="counts", so it reads raw counts from that layer.
# normalize/log1p overwrite adata.X in place, so stash the raw matrix first
# (unless the loaded object already carries a counts layer).
if "counts" not in adata.layers:
    adata.layers["counts"] = adata.X.copy()
sc.pp.normalize_total(adata, target_sum=1e4)  # replaces deprecated normalize_per_cell
sc.pp.log1p(adata)
# subset=True also subsets adata.layers, so the counts layer keeps only the HVGs.
sc.pp.highly_variable_genes(adata, n_top_genes=2000, subset=True)
# PCA on the selected HVGs; the neighbour graph below is built from X_scVI, not PCA.
sc.pp.pca(adata)


def preprocess(adata, batch_key='batch', layer='counts', n_layers=2, n_latent=10,
               dropout_rate=0.2, seed=0):
    """Train a batch-aware scVI model and build the neighbour graph / UMAP from it.

    Parameters
    ----------
    adata : AnnData
        Preprocessed object; `adata.layers[layer]` must hold raw counts.
    batch_key : str
        `adata.obs` column passed to scVI as the batch covariate.
    layer : str
        Layer scVI reads counts from.
    n_layers, n_latent, dropout_rate : int, int, float
        scVI architecture hyperparameters.
    seed : int
        Value assigned to `scvi.settings.seed`.

    Returns
    -------
    The same AnnData, with `obsm["X_scVI"]`, a neighbour graph built on it, and a UMAP.
    """
    scvi.settings.seed = seed
    scvi.model.SCVI.setup_anndata(adata, layer=layer, batch_key=batch_key)
    vae = scvi.model.SCVI(adata, n_layers=n_layers, n_latent=n_latent,
                          dropout_rate=dropout_rate)
    vae.train()
    adata.obsm["X_scVI"] = vae.get_latent_representation()
    sc.pp.neighbors(adata, use_rep="X_scVI", method='gauss')
    sc.tl.umap(adata)
    return adata


adata = preprocess(adata)

In [None]:
# Palantir pseudotime - function definition
def run_palantir(adata, latent_embedding, umap_embedding, root_markers, n_eigs=None,
                 anno_key='finestanno', num_waypoints=500, seed=5):
    """Run Palantir, rooted at the cell with the highest score for `root_markers`.

    Parameters
    ----------
    adata : AnnData
        Cells to order. A 'start_score' column is added to `adata.obs`
        (via sc.tl.score_genes).
    latent_embedding : str
        `adata.obsm` key passed as `pca_key` to palantir's diffusion maps.
    umap_embedding : str
        Basis passed to sc.pl.embedding for plotting the root score.
    root_markers : list of str
        Genes whose combined score picks the root (early) cell.
    n_eigs : int or None
        None lets palantir choose the number of diffusion components;
        otherwise `n_eigs + 1` is requested.
    anno_key : str
        `adata.obs` column printed for the chosen root cell (skipped if absent).
    num_waypoints : int
        Forwarded to palantir.core.run_palantir.
    seed : int
        Passed to np.random.seed before diffusion maps and again before Palantir.

    Returns
    -------
    The object returned by palantir.core.run_palantir
    (presumably palantir's PResults - exposes `.pseudotime`; confirm for your version).
    """
    print('============= Palantir Run ===========')
    sc.tl.score_genes(adata, gene_list=root_markers, score_name='start_score')
    sc.pl.embedding(adata, basis=umap_embedding, color=['start_score'], s=15,
                    legend_loc='on data')

    print('Root cell determination')
    start_cell_id = int(np.argmax(adata.obs['start_score'].to_numpy()))
    start_cell = adata.obs_names[start_cell_id]
    print('ROOT CELL: ', start_cell_id, start_cell)
    if anno_key in adata.obs:
        print(adata.obs.loc[[start_cell], anno_key])
    print(adata.shape)

    # Diffusion maps on the supplied latent representation.
    np.random.seed(seed)
    palantir.utils.run_diffusion_maps(adata, pca_key=latent_embedding)
    if n_eigs is None:
        ms_data = palantir.utils.determine_multiscale_space(adata)
    else:
        print('Taking n_eigs = ', n_eigs)
        # NOTE(review): +1 presumably accounts for the leading trivial component - confirm.
        ms_data = palantir.utils.determine_multiscale_space(adata, n_eigs=n_eigs + 1)
    print('DC representation shape:', ms_data.shape)

    np.random.seed(seed)
    print('Running Palantir')
    pr_res = palantir.core.run_palantir(ms_data, early_cell=start_cell,
                                        use_early_cell_as_start=True,
                                        num_waypoints=num_waypoints)
    return pr_res

In [None]:
# Run Palantir on the scVI latent space written by `preprocess` (obsm "X_scVI"),
# plotting on the UMAP basis used elsewhere in this notebook; root = highest TBX5 score.
pr_res = run_palantir(adata, 'X_scVI', 'X_umap', ['TBX5'])

In [None]:
# Plot: copy Palantir pseudotime onto adata.obs (aligned by cell barcode) so scanpy
# and the CellRank PseudotimeKernel below can read it as 'palantir_pseudotime'.
adata.obs['palantir_pseudotime'] = pr_res.pseudotime
sc.pl.embedding(adata, basis='umap', color='palantir_pseudotime')
palantir.plot.plot_palantir_results(adata, pr_res, s=3)
plt.show()

In [None]:
# CellRank: transition matrix directed by Palantir pseudotime, projected onto the UMAP
# and coloured by the fine annotation.
pk_palantir = cr.kernels.PseudotimeKernel(
    adata,
    time_key="palantir_pseudotime",
).compute_transition_matrix()
pk_palantir.plot_projection(
    basis="umap",
    color="finestanno",
    legend_loc="right",
    recompute=True,
)