## (1) Load Processed Data from the Paper

Currently, three processed datasets are supported: `norman`, `adamson`, and `dixit`.

In [1]:
import sys
sys.path.append('../../')

from omnicell.models.gears.pertdata import PertData


## (2) Create your own Perturb-Seq data
Prepare a scanpy `adata` object that satisfies the following:
1. The `adata.obs` dataframe has `condition` and `cell_type` columns, where `condition` is the perturbation name for each cell. Control cells use the label `ctrl`, single perturbations use `A+ctrl` or `ctrl+A`, and combination perturbations use `A+B`.
2. The `adata.var` dataframe has a `gene_name` column containing the gene symbol of each gene.
3. `adata.X` stores the post-perturbation gene expression.
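
The condition formats listed above can be checked before the data is handed to GEARS. The sketch below is a minimal illustration in plain Python. `classify_condition` is a hypothetical helper (it is not part of GEARS or scanpy), and it assumes that `+` is the only separator and `ctrl` the only control token, as described in the list.

```python
def classify_condition(condition: str) -> str:
    """Classify a condition label as 'control', 'single', or 'combo'.

    Hypothetical helper for illustration only; raises ValueError on labels
    that do not match the formats `ctrl`, `A+ctrl` / `ctrl+A`, or `A+B`.
    """
    parts = condition.split("+")

    # Reject empty tokens (e.g. "A+") and labels with more than two parts.
    if len(parts) > 2 or any(part == "" for part in parts):
        raise ValueError(f"Malformed condition label: {condition!r}")

    if len(parts) == 1:
        # A single token must be the control label; a bare gene name
        # such as "TRAFD1" is missing its "+ctrl" suffix.
        if parts[0] != "ctrl":
            raise ValueError(f"Single perturbation lacks '+ctrl': {condition!r}")
        return "control"

    # Two tokens: count how many of them are actual perturbations.
    genes = [part for part in parts if part != "ctrl"]
    if not genes:
        raise ValueError(f"Malformed condition label: {condition!r}")
    return "single" if len(genes) == 1 else "combo"


if __name__ == "__main__":
    for label in ["ctrl", "TRAFD1+ctrl", "ctrl+STAT4", "IRF3+STAT6"]:
        print(label, "->", classify_condition(label))
```

Applied to a whole column, something like `adata.obs["condition"].map(classify_condition).value_counts()` would summarize how many labels fall into each class and raise on the first malformed one.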

Here is an example using the Satija IFNB dataset (`Seurat_IFNB.h5ad`).

In [2]:
import scanpy as sc
adata = sc.read('/orcd/data/omarabu/001/Omnicell_datasets/satija_IFNB_raw/Seurat_IFNB.h5ad')
adata

AnnData object with n_obs × n_vars = 328542 × 34025
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'sample', 'bc1_well', 'bc2_well', 'bc3_well', 'percent.mito', 'cell_type', 'pathway', 'RNA_snn_res.0.9', 'seurat_clusters', 'sample_ID', 'Batch_info', 'guide', 'gene', 'mixscale_score'
    var: 'gene'

In [3]:
adata.var["gene_name"] = adata.var["gene"]
adata.obs["condition"] = adata.obs["gene"]


In [4]:
# Relabel NT as ctrl and every other entry X as X+ctrl

perts = [p for p in adata.obs["condition"].unique() if p != "NT"]
adata.obs["condition"] = adata.obs["condition"].replace({"NT":"ctrl"})
adata.obs["condition"] = adata.obs["condition"].replace({p:p+"+ctrl" for p in perts})



In [5]:
adata.obs

| | orig.ident | nCount_RNA | nFeature_RNA | sample | bc1_well | bc2_well | bc3_well | percent.mito | cell_type | pathway | RNA_snn_res.0.9 | seurat_clusters | sample_ID | Batch_info | guide | gene | mixscale_score | condition |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 07_48_88_1_1_1_1_1_1_1_1_1 | 7 | 9816 | 4122 | A549_IFNB | A7 | D12 | H4 | 1.161369 | A549 | IFNB | 15.0 | 15.0 | sample_1 | Rep1 | TRAFD1g3 | TRAFD1 | -0.290358 | TRAFD1+ctrl |
| 06_04_63_1_1_1_1_1_1_1_1_1 | 6 | 9359 | 4112 | A549_IFNB | A6 | A4 | F3 | 3.835880 | A549 | IFNB | 15.0 | 15.0 | sample_1 | Rep1 | HES4g3 | HES4 | 0.121449 | HES4+ctrl |
| 06_28_67_1_1_1_1_1_1_1_1_1 | 6 | 8999 | 3854 | A549_IFNB | A6 | C4 | F7 | 9.189910 | A549 | IFNB | 15.0 | 15.0 | sample_1 | Rep1 | NTg8 | NT | 0.000000 | ctrl |
| 06_27_93_1_1_1_1_1_1_1_1_1 | 6 | 8384 | 3600 | A549_IFNB | A6 | C3 | H9 | 3.268130 | A549 | IFNB | 15.0 | 15.0 | sample_1 | Rep1 | STAT5Ag1 | STAT5A | 0.377627 | STAT5A+ctrl |
| 06_81_38_1_1_1_1_1_1_1_1_1 | 6 | 7925 | 3580 | A549_IFNB | A6 | G9 | D2 | 3.798107 | A549 | IFNB | 15.0 | 15.0 | sample_1 | Rep1 | STAT4g2 | STAT4 | 1.000000 | STAT4+ctrl |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 09_10_48_2_2 | 9 | 2362 | 1491 | MCF7_IFNB | A9 | A10 | D12 | 1.820491 | MCF7 | IFNB | | | sample_16 | Rep2 | NFE2L3g3 | NFE2L3 | -0.009685 | NFE2L3+ctrl |
| 10_73_27_2_2 | 10 | 2143 | 1483 | MCF7_IFNB | A10 | G1 | C3 | 0.326645 | MCF7 | IFNB | | | sample_16 | Rep2 | IRF3g2 | IRF3 | -0.402627 | IRF3+ctrl |
| 10_24_48_2_2 | 10 | 1921 | 1322 | MCF7_IFNB | A10 | B12 | D12 | 6.507028 | MCF7 | IFNB | | | sample_16 | Rep2 | STAT6g2 | STAT6 | -0.356580 | STAT6+ctrl |
| 10_12_27_2_2 | 10 | 1676 | 1169 | MCF7_IFNB | A10 | A12 | C3 | 2.923628 | MCF7 | IFNB | | | sample_16 | Rep2 | HERC5g3 | HERC5 | -0.045158 | HERC5+ctrl |


### Suggested normalization

For raw count data, we recommend the following normalization, followed by subsetting to the 5,000 most highly variable genes.

In [6]:
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, n_top_genes=5000, subset=True)
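
For intuition, the first two calls can be sketched in plain Python. With its default `target_sum=None`, `sc.pp.normalize_total` rescales each cell so that its total count equals the median of the per-cell totals, and `sc.pp.log1p` then applies `log(1 + x)`. The helper `normalize_and_log` below is hypothetical (it is not a scanpy function) and works on a small list of lists instead of an AnnData object. Its handling of all-zero cells is an assumption made for this sketch, and the highly-variable-gene selection is not covered.

```python
import math
from statistics import median


def normalize_and_log(counts):
    """Median-total normalization followed by log1p (illustrative sketch).

    counts: list of cells, each a list of non-negative raw counts per gene.
    """
    totals = [sum(cell) for cell in counts]
    target = median(totals)  # per-cell total every cell is rescaled to

    normalized = []
    for cell, total in zip(counts, totals):
        if total == 0:
            # Assumption for this sketch: leave all-zero cells unchanged.
            normalized.append([0.0 for _ in cell])
            continue
        scale = target / total
        normalized.append([math.log1p(x * scale) for x in cell])
    return normalized


if __name__ == "__main__":
    raw = [[1, 3], [2, 2], [10, 10]]  # per-cell totals: 4, 4, 20
    for row in normalize_and_log(raw):
        print([round(v, 4) for v in row])
```

After this transform, undoing the log with `math.expm1` gives every non-empty cell the same total, so differences in sequencing depth no longer dominate the expression values.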

### Create dataloader

GEARS takes over from here. Processing new data takes around 15 minutes for 5K genes and 100K cells.

In [7]:
import sys
sys.path.append('../')

from omnicell.models.gears.pertdata import PertData

pert_data = PertData('./data')  # folder where processed data is saved
pert_data.new_data_process(dataset_name='satija', adata=adata)  # dataset name and adata object to process
print(f"Data processed and saved in {pert_data.data_path}")
pert_data.load(data_path='./data/satija')  # load the processed data; the path is the save folder + dataset_name
print(f"Data loaded from {pert_data.data_path}")
pert_data.prepare_split(split='simulation', seed=1)  # create the data split with a fixed seed
print("Data split with seed 1")
pert_data.get_dataloader(batch_size=32, test_batch_size=128)  # prepare the data loaders
print("Data loader prepared")

Found local copy...
Found local copy...
Creating pyg object for each cell in the data...
Creating dataset file...
  0%|          | 0/62 [00:00<?, ?it/s]TEEST TIMER: 0.00 seconds
Adata shape: (5339, 5000)
Duration for double loop in TRAFD1+ctrl: 1074.72 seconds
  0%|          | 0/62 [17:54<?, ?it/s]


ValueError: not enough values to unpack (expected 3, got 2)