In [4]:
import subprocess
import os
import sys
import matplotlib.backends.backend_pdf
import scanpy as sc
import matplotlib.pyplot as pl
import anndata as ad
import pandas as pd
import numpy as np
import seaborn as sns
import scvelo as scv
scv.settings.verbosity=1

from pathlib import Path

# Jupyter stuff
from tqdm.auto import tqdm
from IPython.display import clear_output
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:95% !important; }</style>"))

%matplotlib inline

# Custom functions
sys.path.insert(1, '../..')
%load_ext autoreload
%autoreload 2
from utils import *

# scperturb package
sys.path.insert(1, '../../package/src/')
from scperturb import *

from pathlib import Path
figure_path = Path('../../figures/')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


> Computing time: E-tests require permutation testing of randomly assigned nearest neighbors, which may be computationally expensive. In the future, datasets are going to be large. How does the computation time scale as a function of #perturbations, #cells per perturbation, and #nearest neighbors? The GWPS data set might be a good test case scenario.

- #perturbations: the E-test scales linearly with the number of perturbations.
- #cells per perturbation: valid point, to be tested. We expect quadratic scaling in the worst case.
- #nearest neighbors: not applicable.

In [5]:
# Data and scratch locations. Both can be overridden via environment variables
# (SCPERTURB_DATADIR / SCPERTURB_TEMPDIR) so the notebook can run outside the
# original cluster; the defaults are the original paths.
DATADIR = Path(os.environ.get('SCPERTURB_DATADIR', '/data/gpfs-1/users/peidlis_c/work/data/perturbation_resource_paper'))
TEMPDIR = Path(os.environ.get('SCPERTURB_TEMPDIR', "/fast/scratch/users/peidlis_c/perturbation_resource_paper/"))

def leftsided_chebyshev_nodes(N):
    """Return the N+1 left-sided Chebyshev nodes of the first kind, shifted by +1.

    An odd number (2N+1) of first-kind Chebyshev points is generated so that
    the node 0 is part of the set. The first N+1 points (the negative nodes
    followed by 0) are kept and shifted by +1, which maps them into (0, 1]
    with the nodes clustering towards 0.
    """
    all_nodes = np.polynomial.chebyshev.chebpts1(2 * N + 1)
    left_half = all_nodes[:N + 1]
    return 1 + left_half

In [None]:
# Loading the GWPS dataset is opt-in: set LOAD_GWPS = True to read it.
# Otherwise `adata` is None, so displaying it below does not raise a
# NameError on a fresh kernel.
LOAD_GWPS = False
adata = sc.read(DATADIR / 'ReplogleWeissman2022_K562_gwps.h5ad') if LOAD_GWPS else None

In [None]:
# Show a summary of `adata`; requires `adata` to have been defined by an earlier cell.
adata

Approximate formula for the computation time:
$$
T \approx K \cdot I \cdot (M \cdot N)^x
$$

Complexity of PCA (see <https://stackoverflow.com/questions/20507646/how-is-the-complexity-of-pca-ominp3-n3>):

$d$: dimension of the gene-expression (GEX) data

Computing the covariance matrix is $O(d^2n)$, where $n$ is the number of cells, and its eigenvalue decomposition is $O(d^3)$. So the complexity of PCA is $O(d^2n+d^3)$.

This cost can be reduced, e.g. by using chunked PCA or by fitting the PCA on only a subset of cells.

Let $p$ be the dimension of the PCA space.
The computational cost of the pairwise distance computation is $$\frac{(n+m)(n+m-1)}{2} \cdot 3p$$

Cost of one permutation:
- shuffling: $O(n)$
- summations: $n(n-1)p + m(m-1)p + nmp$
- divisions for the three averages: $3$
- final sum: $3$

- n: number of cells per perturbation (assuming all perturbations have the same size)
- m: number of control cells
- k: number of perturbations
- s: number of permutations in the E-test
- d: dimension of the data (e.g. 2000 HVGs)
- p: dimension of the PCA space (e.g. 50)

$$\text{Complexity}_{\text{Etest}}=\text{Complexity}_{\text{PCA}}+\text{Complexity}_{\text{Pairwise distance}}+\text{Complexity}_{\text{permutation test}}$$
$$\text{Complexity}_{\text{PCA}}=O(d^2(k*n+m)+d^3)$$
$$\text{Complexity}_{\text{Pairwise distance}}=O(k*(n+m)^2*p)$$
$$\text{Complexity}_{\text{permutation test}}=O(s*k*p*(n^2+m^2+nm))$$

In [10]:
# Load the Papalexi & Satija (2021) ECCITE RNA dataset; it is the input to the timing loop below.
eccite_rna_path = DATADIR / 'PapalexiSatija2021_eccite_RNA.h5ad'
odata = sc.read(eccite_rna_path)

In [12]:
import time

# Benchmark the runtimes of PCA, E-distance and E-test for increasing N_min in
# equal_subsampling. Durations are measured with time.perf_counter(), a
# monotonic high-resolution clock meant for timing intervals.
# NOTE(review): the tqdm totals in the recorded output shrink as n grows
# (97 -> 47), which suggests a larger N_min also drops perturbations. `sizes`
# records the resulting #perturbations and #cells so the timings can be read
# against both.
# NOTE(review): the first iteration may include one-off warm-up cost (pca at
# n=10 was slower than at n=20 in the recorded output).
n_values = [10, 20, 50, 100, 200]  # N_min values passed to equal_subsampling
n_var_max = 2000  # max total features to select
etest_runs = 1000  # permutations per E-test


def _elapsed(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)``, discard its result and return the elapsed seconds."""
    t0 = time.perf_counter()
    func(*args, **kwargs)
    return time.perf_counter() - t0


times = {}  # n -> {'pca': seconds, 'edist': seconds, 'etest': seconds}
sizes = {}  # n -> {'n_perturbations': int, 'n_cells': int}
for n in n_values:
    t = {}
    # basic qc and pp
    adata = odata.copy()
    sc.pp.filter_cells(adata, min_counts=1000)
    adata = equal_subsampling(adata, 'perturbation', N_min=n)
    adata.layers['counts'] = adata.X.copy()
    sc.pp.normalize_per_cell(adata)
    sc.pp.filter_genes(adata, min_cells=50)
    sc.pp.log1p(adata)
    sizes[n] = {'n_perturbations': adata.obs['perturbation'].nunique(),
                'n_cells': adata.n_obs}

    # select at most n_var_max HVGs, computed on the raw counts layer
    sc.pp.highly_variable_genes(adata, n_top_genes=n_var_max,
                                subset=False, flavor='seurat_v3', layer='counts')

    t['pca'] = _elapsed(sc.pp.pca, adata, use_highly_variable=True)
    t['edist'] = _elapsed(edist_to_control, adata)
    t['etest'] = _elapsed(etest, adata, runs=etest_runs)

    times[n] = t

100%|██████████| 97/97 [00:00<00:00, 213.63it/s]
100%|██████████| 1000/1000 [00:46<00:00, 21.39it/s]
100%|██████████| 92/92 [00:00<00:00, 189.60it/s]
100%|██████████| 1000/1000 [00:45<00:00, 21.97it/s]
100%|██████████| 84/84 [00:00<00:00, 180.92it/s]
100%|██████████| 1000/1000 [00:44<00:00, 22.43it/s]
100%|██████████| 69/69 [00:00<00:00, 166.11it/s]
100%|██████████| 1000/1000 [00:43<00:00, 22.86it/s]
100%|██████████| 47/47 [00:00<00:00, 135.14it/s]
100%|██████████| 1000/1000 [00:44<00:00, 22.44it/s]


In [13]:
# Runtime table in seconds: one column per subsample size n, one row per step (pca, edist, etest).
pd.DataFrame.from_dict(times)

Unnamed: 0,10,20,50,100,200
pca,0.889914,0.577637,0.968837,1.601215,2.109844
edist,0.931559,0.893482,0.851183,0.749554,0.604733
etest,49.750102,48.326092,47.633734,46.801937,47.337088


# Speedup with GPUs

In [3]:
# https://jejjohnson.github.io/research_journal/tutorials/jax/lab_tutorials/pairwise/
#https://github.com/google/jax/issues/1918