Benchmark Palantir runtime
---

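In this notebook, we benchmark the runtime of Palantir's `run_palantir` function, which computes the fate probabilities. Only the `run_palantir` call itself is timed; preprocessing (PCA, diffusion maps, multiscale space) is not included in the measured time.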

# Preliminaries

## Dependency notebooks

1. [MK_2020-10-16_gpcca.ipynb](MK_2020-10-16_gpcca.ipynb) - computes the CellRank terminal states (`gpcca.pickle`) that are passed to Palantir below.

## Import packages

In [1]:
# import standard packages
from pathlib import Path
from collections import defaultdict
from math import ceil
import sys
import pickle

# import numerical packages
import numpy as np
import pandas as pd

# import single-cell packages
import scanpy as sc
import scvelo as scv
import palantir

# import utilities
import utils.utilities as ul

findfont: Font family ['Raleway'] not found. Falling back to DejaVu Sans.
findfont: Font family ['Lato'] not found. Falling back to DejaVu Sans.


## Print package versions for reproducibility

In [2]:
# Record package versions for reproducibility: scanpy and its dependencies are printed,
# Palantir's version is shown as the cell output.
sc.logging.print_header()
palantir.__version__

scanpy==1.6.0 anndata==0.7.4 umap==0.4.6 numpy==1.19.2 scipy==1.5.2 pandas==1.1.3 scikit-learn==0.23.2 statsmodels==0.12.0 python-igraph==0.8.2 louvain==0.7.0 leidenalg==0.8.2


'1.0.0'

## Set up paths

In [3]:
# Add the directory three levels up to the import path so that `paths` (which defines DATA_DIR) can be imported.
# NOTE(review): `utils.utilities` is imported in the first cell, before this path is added - confirm it resolves without it.
sys.path.insert(0, "../../..")  # this depends on the notebook depth and must be adapted per notebook

from paths import DATA_DIR

## Load the data

Load the raw data rather than the preprocessed version; preprocessing is rerun separately for every subset.

In [4]:
# Full raw dataset; each benchmark subset is later taken from it with `adata[ixs].copy()`.
adata = sc.read(DATA_DIR / "morris_data" / "adata.h5ad")
del adata.layers  # we don't need any of these
adata

AnnData object with n_obs × n_vars = 104679 × 22630
    obs: 'batch'

### Load supplementary information data

We need the `Reprogramming Day` annotation to select the root cell: the first cell from the earliest reprogramming day.

In [5]:
# Supplementary table 4: the first column is used as index (cell IDs); `header=2` takes the
# column names from the third line of the file and skips the lines above it.
annot = pd.read_csv(DATA_DIR / "morris_data" / "annotations" / "supp_table_4.csv", index_col=0, header=2)
print(annot.shape)
annot.head()

(85010, 12)


Unnamed: 0_level_0,Genes,UMIs,% Mitochondrial RNA,Cell Cycle Phase,Timecourse,Reprogramming Day,CellTagMEF,CellTagD3,CellTagD13,CellTagMEF.1,CellTagD3.1,CellTagD13.1
Cell ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
_HF-1_AGAATAGGTAGCGCTC-1,4551,6497,7.3,G1,1,0,16.0,,,16.0,,
_HF-1_AGGTCCGTCAACACGT-1,4625,6726,7.39,G1,1,0,20.0,,,20.0,,
_HF-1_CACATAGGTTCCCTTG-1,5488,6869,6.03,G1,1,0,31.0,,,31.0,,
_HF-1_CCAGCGACAGAGCCAA-1,5549,6995,7.07,G1,1,0,39.0,,,39.0,,
_HF-1_CCGGTAGCACGTGAGA-1,4901,6541,7.03,G1,1,0,,,,,,


### Load the subsets and splits

In [6]:
# Mapping from subset size to a DataFrame with one column per split; the column values are
# used to index `adata` when subsetting (see `benchmark_palantir`).
dfs = ul.get_split(DATA_DIR / "morris_data" / "splits")
list(dfs.keys())

[10000, 20000, 30000, 40000, 50000, 60000, 70000, 80000, 90000, 100000]

## Define utility functions

In [7]:
def clean_orig_names(series):
    """Return the second ``:``-separated field of every cell name.

    Parameters
    ----------
    series
        Index or Series of cell names of the form ``"<prefix>:<name>[:...]"``.

    Returns
    -------
    numpy.ndarray
        String array with the second ``:``-separated field of each name, in input order.

    Raises
    ------
    ValueError
        If any name does not contain a ``:`` separator.
    """
    # Positional Series (default RangeIndex) so no index alignment can interfere.
    names = pd.Series(np.asarray(series, dtype=object))
    second_field = names.str.split(":").str[1]

    missing = second_field.isna()
    if missing.any():
        examples = names[missing].head(3).tolist()
        raise ValueError(f"Expected cell names of the form `<prefix>:<name>`, got e.g. {examples}.")

    return second_field.to_numpy().astype(str)

In [8]:
def add_annotations(adata, annot):
    """Merge the supplementary annotations into ``adata.obs`` and rename the cells (in place).

    Cells are matched on a cleaned barcode:

    * ``adata``: second ``:``-separated field of ``obs.index`` (via ``clean_orig_names``);
    * ``annot``: third ``_``-separated field of its index, with ``-`` replaced by ``x-``.

    Afterwards ``adata.obs`` is indexed by the cleaned names, so ``adata.obs_names`` change.
    ``'Reprogramming Day'`` is converted to a categorical. Cells without a match in ``annot``
    get NaN annotations. The caller's ``annot`` is not modified.

    Raises
    ------
    AssertionError
        If the cleaned cell names in ``adata`` are not unique.
    """
    cleaned_names = clean_orig_names(adata.obs.index)
    adata.obs['genes_cleaned'] = cleaned_names
    assert pd.Index(cleaned_names).is_unique, "Cleaned cell names in `adata` are not unique."

    # Build the join key on a copy: `annot` is shared across all splits and must stay untouched.
    annot_keys = (
        annot.index.to_series()
        .str.split("_").str[2]
        .str.replace("-", "x-", regex=False)
    )
    annot_keyed = annot.assign(genes_cleaned=annot_keys.to_numpy())

    merged = (
        adata.obs.merge(annot_keyed, how='left', on='genes_cleaned')
        # if `annot` has several rows for one cleaned barcode, keep only the first match per cell
        .drop_duplicates('genes_cleaned')
        .set_index('genes_cleaned', drop=True)
    )

    adata.obs = merged
    adata.obs['Reprogramming Day'] = adata.obs['Reprogramming Day'].astype('category')
    
    
def select_root_cell(adata):
    """Return the name of the first cell that belongs to the earliest ``Reprogramming Day``.

    ``adata.obs['Reprogramming Day']`` must be categorical; missing values are not categories,
    so they are ignored when looking for the earliest day.
    """
    day = adata.obs['Reprogramming Day']
    earliest_day = np.nanmin(day.cat.categories)

    on_earliest_day = (day == earliest_day).to_numpy()
    return day.index[on_earliest_day][0]

In [9]:
def load_cellrank_final_states(adata, data, size, col):
    """Pick one representative cell per CellRank lineage for the given split.

    For every lineage (column of ``lin_probs``) the cell with the highest probability is
    chosen, never picking the same cell twice.

    Parameters
    ----------
    adata
        Subset whose ``obs.index`` holds the cleaned cell names.
    data
        Nested mapping; ``data[size][col]`` must provide ``'terminal_states'`` (whose index
        holds the original cell names) and ``'lin_probs'`` (assumed row-aligned with that
        index - TODO confirm).
    size, col
        Subset size and split column selecting the entry in ``data``.

    Returns
    -------
    list or None
        One cell name per lineage, or ``None`` if no usable states could be determined.
    """
    try:
        entry = data[size][col]
        # old name: main_states
        index = clean_orig_names(entry['terminal_states'].index)

        # Cheap early exit before any slicing.
        # NOTE(review): the threshold applies to all entries, before filtering to this subset - confirm intent.
        if len(index) < 3:
            return None

        valid_ixs = np.isin(index, adata.obs.index)
        lin_probs = pd.DataFrame(entry['lin_probs'][valid_ixs, :], index=index[valid_ixs])
        n_cells, n_lineages = lin_probs.shape

        # Each lineage needs its own distinct cell.
        if n_cells < n_lineages:
            print(f"Too few cells ({n_cells}) to pick one per lineage ({n_lineages}).")
            return None

        chosen = []
        for lin in range(n_lineages):  # be extra careful
            remaining = lin_probs[~np.isin(lin_probs.index, chosen)]
            assert len(remaining) + len(chosen) == n_cells, "Sanity check failed"

            # we select the most likely cell from each terminal state
            chosen.append(remaining.index[np.argmax(remaining.values[:, lin])])

        return chosen
    except Exception as e:
        # Best effort: report with enough context to diagnose (a bare KeyError only shows the key).
        print(f"Unexpected error for size `{size}`, split `{col}`: {type(e).__name__}: `{e}`.")
        return None

In [10]:
def palantir_preprocess(adata, min_cells=10, n_top_genes=1500, n_comps=300):
    """Preprocess ``adata`` in place and return Palantir's multiscale diffusion space.

    Steps:
    1. gene filtering;
    2. total-count normalization;
    3. ``log1p``;
    4. highly variable gene selection (``cell_ranger`` flavor);
    5. PCA on those genes;
    6. Palantir diffusion maps;
    7. ``determine_multiscale_space``.

    Parameters
    ----------
    adata
        AnnData object; modified in place by the scanpy steps (PCA is stored in ``obsm['X_pca']``).
    min_cells
        Passed to ``sc.pp.filter_genes``: minimum number of cells a gene must be detected in.
    n_top_genes
        Number of highly variable genes used for the PCA.
    n_comps
        Number of principal components computed and passed to the diffusion maps.

    Returns
    -------
    The multiscale space returned by ``palantir.utils.determine_multiscale_space``.
    """
    sc.pp.filter_genes(adata, min_cells=min_cells)
    sc.pp.normalize_total(adata)
    sc.pp.log1p(adata)
    sc.pp.highly_variable_genes(adata, flavor='cell_ranger', n_top_genes=n_top_genes)

    print("Running PCA")
    sc.pp.pca(adata, use_highly_variable=True, n_comps=n_comps)

    print("Diff maps")
    # X_pca already has `n_comps` columns, so no extra slicing is needed.
    pca = pd.DataFrame(adata.obsm['X_pca'], index=adata.obs_names)
    dm_res = palantir.utils.run_diffusion_maps(pca)

    print("MS space")
    return palantir.utils.determine_multiscale_space(dm_res)

In [11]:
def benchmark_palantir(adata, dfs, annot, final_states_path, results_path=None, n_jobs=32):
    """Time Palantir's ``run_palantir`` on every subset size and split.

    For each ``size`` in ``dfs`` and each split column, the selected cells are subset from
    ``adata`` and annotated. They are then preprocessed with ``palantir_preprocess`` and
    passed to the timed ``run_palantir``. The root cell comes from ``select_root_cell`` and
    the terminal states from ``load_cellrank_final_states``. Only ``run_palantir`` is timed.

    Parameters
    ----------
    adata
        Full dataset to subset.
    dfs
        Mapping ``size -> DataFrame``; the values of each column select the cells of one split.
    annot
        Supplementary annotation table passed to ``add_annotations``.
    final_states_path
        Pickle with the CellRank results, indexed as ``data[size][col]``.
    results_path
        Where results are checkpointed via ``ul.save_results``; defaults to
        ``DATA_DIR / "benchmarking" / "runtime_analysis" / "palantir.pickle"``.
    n_jobs
        Number of jobs passed to ``run_palantir``.

    Returns
    -------
    dict
        ``{size: {col: {'core_ts': elapsed}}}`` for every split that ran successfully.
        Skipped or failed splits are reported via ``print`` and omitted.
    """
    if results_path is None:
        results_path = DATA_DIR / "benchmarking" / "runtime_analysis" / "palantir.pickle"

    # Plain nested dicts: unlike a defaultdict with lambda factories, they can be pickled.
    res = {}
    run_palantir = ul.timeit(palantir.core.run_palantir)

    # Load the terminal states from CellRank.
    # NOTE: pickle.load can execute arbitrary code - only use trusted, locally produced files.
    with open(final_states_path, 'rb') as fin:
        fs_data = pickle.load(fin)

    for size, split in dfs.items():
        for col in split.columns:
            try:
                print(f"Subsetting data to `{size}`, split `{col}`.")
                ixs = split[col].values
                bdata = adata[ixs].copy()
                add_annotations(bdata, annot)

                assert bdata.n_obs == size, f"Expected `{size}` cells, found `{bdata.n_obs}`."

                root_cell = select_root_cell(bdata)
                final_states = load_cellrank_final_states(bdata, fs_data, size, col)

                if final_states is None:
                    print("No final states found, skipping")
                    continue
                if root_cell in final_states:
                    print("Root cell is in final states, skipping")
                    continue

                print("Preprocessing")
                ms_data = palantir_preprocess(bdata)

                # this is the configuration whose timings are used in the figures
                print(f"Running with CellRank terminal states `root_cell={root_cell}` and "
                      f"`final_states={final_states}`")
                _, time_ts = run_palantir(ms_data,
                                          root_cell,
                                          terminal_states=final_states,
                                          knn=30,
                                          num_waypoints=int(ceil(size * 0.15)),  # 15% of the cells
                                          n_jobs=n_jobs,
                                          scale_components=False,
                                          use_early_cell_as_start=True)
            except Exception as e:
                print(f"Unable to run `Palantir` with size `{size}` on split `{col}`. Reason: `{e}`.")
                continue

            res.setdefault(size, {}).setdefault(col, {})['core_ts'] = time_ts

            # Checkpoint after every successful split. A failed save is reported on its own,
            # instead of being mistaken for a failed Palantir run.
            try:
                ul.save_results(res, results_path)
            except Exception as e:
                print(f"Unable to save intermediate results to `{results_path}`. Reason: `{e}`.")

    return res

# Run the benchmarks

In [None]:
# Long-running: times Palantir for every subset size and split, using the CellRank terminal states
# stored in gpcca.pickle; results are also written to disk after each successful split.
res_pala = benchmark_palantir(adata, dfs, annot,
                              final_states_path=DATA_DIR / "benchmarking" / "runtime_analysis" / "gpcca.pickle")

Subsetting data to `10000`, split `0`.
Preprocessing
Running PCA
Diff maps
Determing nearest neighbor graph...
MS space
Running with CellRank terminal states `root_cell=CCTATTAAGCATCATCx-1` and `final_states=['TACCTATAGACCACGAx-6', 'CTCGTCATCACTTCATx-14', 'TCAGGTAGTTTGACACx-5']`
Sampling and flocking waypoints...
Time for determining waypoints: 0.06873098214467367 minutes
Determining pseudotime...
Shortest path distances using 30-nearest neighbor graph...
Time for shortest paths: 0.49483526945114137 minutes
Iteratively refining the pseudotime...
Correlation at iteration 1: 0.9999
Entropy and branch probabilities...
Markov chain construction...
Computing fundamental matrix and absorption probabilities...
Project results to all cells...
Subsetting data to `10000`, split `1`.
Unexpected error: `'1'`.
Preprocessing
Running PCA


## Save the results

In [None]:
# Final save of all results (benchmark_palantir already writes this same file after each successful split).
ul.save_results(res_pala, DATA_DIR / "benchmarking" / "runtime_analysis" / "palantir.pickle")