In [1]:
import os
import pickle

import numpy as np
import pandas as pd
import scanpy as sc
from gears import PertData
from perturbench.data.accessors.norman19 import Norman19
from perturbench.data.datasplitter import PerturbationDataSplitter

# IPython autoreload (mode 2): reload imported modules before executing each cell,
# so edits to project code (e.g. perturbench) are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
# Load the processed Norman19 dataset through the perturbench accessor and
# display the AnnData summary.
norman19_accessor = Norman19()
adata = norman19_accessor.get_anndata()
adata

Loading processed data from: ../perturbench_data/norman19_processed.h5ad


AnnData object with n_obs × n_vars = 111445 × 5666
    obs: 'guide_id', 'read_count', 'UMI_count', 'coverage', 'gemgroup', 'good_coverage', 'number_of_cells', 'tissue_type', 'cell_type', 'cancer', 'disease', 'perturbation_type', 'celltype', 'organism', 'perturbation', 'nperts', 'ngenes', 'ncounts', 'percent_mito', 'percent_ribo', 'condition', 'cov_merged', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes'
    var: 'ensemble_id', 'ncounts', 'ncells', 'n_cells', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'highly_variable_nbatches'
    uns: 'hvg', 'log1p', 'rank_genes_groups_cov'
    layers: 'counts'

In [3]:
# Copy the var index into an explicit 'gene_name' column.
# NOTE(review): presumably the var_names are gene symbols (a separate 'ensemble_id'
# column exists) and GEARS' PertData reads this column — confirm.
adata.var['gene_name'] = adata.var_names.to_numpy()

In [4]:
def to_gears_condition(cond: str) -> str:
    """Convert a condition label to the GEARS naming scheme.

    - the control label ('control', or an already-converted 'ctrl') -> 'ctrl'
    - a single perturbation 'A' -> 'A+ctrl'
    - a combination 'A+B' (or an already-converted 'A+ctrl') -> unchanged

    Accepting 'ctrl' as a control label makes this cell idempotent: re-running
    it no longer turns 'ctrl' into 'ctrl+ctrl'.
    """
    genes = cond.split('+')
    if len(genes) == 1:
        return 'ctrl' if genes[0] in ('control', 'ctrl') else f'{cond}+ctrl'
    return cond


# Series.map keeps the obs index, so labels stay aligned with their cells.
adata.obs['condition'] = (
    adata.obs['condition'].astype(str).map(to_gears_condition).astype('category')
)
adata.obs['condition'].value_counts()

condition
ctrl             11855
KLF1+ctrl         1960
BAK1+ctrl         1457
CEBPE+ctrl        1233
CEBPE+RUNX1T1     1219
                 ...  
CBL+UBASH3A         64
CEBPB+CEBPA         64
C3orf72+FOXL2       59
JUN+CEBPB           59
JUN+CEBPA           54
Name: count, Length: 237, dtype: int64

In [5]:
# Conditions dropped before building the GEARS dataset.
# NOTE(review): 'IER5L+ctrl' was listed twice originally (redundant for `isin`);
# confirm that no other condition was intended in its place.
perts_exclude = ['LYL1+IER5L', 'IER5L+ctrl', 'KIAA1804+ctrl']

# Boolean indexing returns an AnnData *view*; `.copy()` materializes it so the
# obs assignment below does not emit anndata's implicit-modification warning.
adata = adata[~adata.obs['condition'].isin(perts_exclude)].copy()
# Drop the now-empty categories of the excluded conditions.
adata.obs['condition'] = adata.obs['condition'].cat.remove_unused_categories()
adata.shape

  adata.obs['condition'] = adata.obs.condition.cat.remove_unused_categories()


(110139, 5666)

In [9]:
# Build train/val/test assignments per covariate (here only 'cell_type').
# NOTE(review): the splitter reads the raw 'perturbation' column, whose control
# label is 'control' — not the GEARS-style 'condition' column built above.
data_splitter = PerturbationDataSplitter(
    obs_dataframe=adata.obs,
    covariate_keys=['cell_type'],
    perturbation_key='perturbation',
    perturbation_control_value='control',
)

# Hold out combination perturbations for val/test, at most 70% per covariate,
# with a fixed seed for reproducibility.
train_test_split = data_splitter.split_combinations(
    max_heldout_fraction_per_covariate=0.7,
    seed=42,
)

           train  val  test
('K562',)    143   46    47


  for covariates, df in self.obs_dataframe.groupby(self.covariate_keys):


In [10]:
# Attach the split label to every cell. `.values` drops the index, so this
# assumes `train_test_split` is ordered like `adata.obs` — TODO confirm.
adata.obs['combo_split'] = train_test_split.values

# Map each split name to the list of condition labels it contains, in order
# of first appearance. 'ctrl' is kept only in non-val/test splits.
split_dict = {}
for split_val in adata.obs['combo_split'].unique():
    in_split = adata.obs['combo_split'] == split_val
    # Read obs directly instead of building a full AnnData view per split.
    split_perts = list(adata.obs.loc[in_split, 'condition'].unique())
    if split_val in ('val', 'test'):
        # Filter rather than list.remove: a split with no control cells would
        # otherwise raise ValueError.
        split_perts = [pert for pert in split_perts if pert != 'ctrl']
    split_dict[split_val] = split_perts

In [11]:
# Report how many condition labels ended up in each split.
for split_name, split_perts in split_dict.items():
    print(f'{split_name} {len(split_perts)}')

train 143
val 45
test 46


In [None]:
gears_datapath = '../perturbench_data/gears/'
dataset_name = 'norman19'

# Create the output folder first: the split pickle is written into it before
# PertData is constructed, so open(..., 'wb') would fail on a fresh checkout.
os.makedirs(gears_datapath, exist_ok=True)

# Persist the split dict next to the GEARS data.
split_dict_path = os.path.join(gears_datapath, f'{dataset_name}_gears_split.pkl')
with open(split_dict_path, 'wb') as f:
    pickle.dump(split_dict, f)

# Process `adata` into GEARS' on-disk format under `gears_datapath`, registered
# as `dataset_name`; skip_calc_de=False keeps the DE-gene computation enabled.
pert_data = PertData(gears_datapath)
pert_data.new_data_process(dataset_name=dataset_name, adata=adata, skip_calc_de=False)