In [1]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns

import celloracle as co

In [2]:
# Norman et al. 2019 Perturb-seq AnnData (the file name suggests HVGs plus
# perturbed genes — confirm against how the file was produced).
NORMAN2019_H5AD = '/dfs/project/perturb-gnn/datasets/Norman2019/Norman2019_hvg+perts.h5ad'
adata = sc.read_h5ad(NORMAN2019_H5AD)

In [3]:
# TF list: tab-separated, no header row; column 0 is renamed to 'Gene'.
TF_names = (
    pd.read_csv('TF_names_v_1.01.txt', delimiter='\t', header=None)
    .rename(columns={0: 'Gene'})
)

# Every token appearing on either side of a '+'-joined condition label
# (this includes 'ctrl').
all_conds = {
    gene
    for condition in adata.obs['condition'].values
    for gene in condition.split('+')
}

# Disabled alternative: treat all perturbed genes as TFs.
# aug_TF_names = list(TF_names['Gene'].values) + list(all_conds)

In [4]:
# Index genes by symbol so columns can be selected by gene name (e.g. adata[:, gene]).
# Guarded so that re-running this cell does not raise KeyError once 'gene_name'
# has already been moved from a column into the index.
# NOTE(review): PCA is not recomputed here; the Oracle setup reads obsm['X_pca'],
# which is assumed to be stored in the loaded .h5ad — confirm.
if 'gene_name' in adata.var.columns:
    adata.var = adata.var.set_index('gene_name')

# Constant placeholder label on every cell.
# NOTE(review): no visible cell reads obs['label'] — confirm it is still needed.
adata.obs['label'] = '0'

In [None]:
# Build the CellOracle object; perturbation conditions serve as the clusters.
# NOTE(review): import_anndata_as_raw_count is meant for raw counts in adata.X and
# needs an existing obsm['X_pca'] embedding — confirm both hold for this file.
oracle = co.Oracle()
oracle.import_anndata_as_raw_count(adata=adata,
                                   cluster_column_name='condition',
                                   embedding_name='X_pca')

oracle.perform_PCA()

# Heuristic constants for choosing the imputation parameters.
PCA_VAR_THRESHOLD = 0.002   # explained-variance-ratio cut-off for keeping PCs
KNN_CELL_FRACTION = 0.025   # k as a fraction of the number of cells

# Number of PCs: the first index where the (shifted) explained-variance ratio
# switches between being above and below PCA_VAR_THRESHOLD.  np.diff on a boolean
# array is True exactly where consecutive values differ.
var_ratio = oracle.pca.explained_variance_ratio_
above_threshold = np.diff(np.cumsum(var_ratio)) > PCA_VAR_THRESHOLD
crossings = np.where(np.diff(above_threshold))[0]
if crossings.size:
    n_comps = int(crossings[0])
else:
    # No crossing: the former one-liner raised IndexError here. Use every computed PC.
    n_comps = len(var_ratio)
    print('No explained-variance crossing at ' + str(PCA_VAR_THRESHOLD)
          + '; using all ' + str(n_comps) + ' PCs.')
# NOTE(review): consider capping n_comps (e.g. min(n_comps, 50)) if the elbow is large.

# With ~91k cells this gives k = 2280 (see the Oracle summary below), i.e. very
# large imputation neighbourhoods.
n_cell = oracle.adata.shape[0]
k = int(KNN_CELL_FRACTION * n_cell)

oracle.knn_imputation(n_pca_dims=n_comps, k=k, balanced=True, b_sight=k*8,
                      b_maxl=k*4, n_jobs=4)

# Load CellOracle's human promoter base GRN and register it as the TF/target prior.
base_GRN = co.data.load_human_promoter_base_GRN()
oracle.import_TF_data(TF_info_matrix=base_GRN)

5045 genes were found in the adata. Note that Celloracle is intended to use around 1000-3000 genes, so the behavior with this number of genes may differ from what is expected.


In [6]:
# Fit the GRN models used later by simulate_shift. GRN_unit='whole' presumably
# fits one GRN over all cells instead of one per cluster (see CellOracle docs).
oracle.fit_GRN_for_simulation(
    GRN_unit='whole',
)

In [5]:
# Cache the fitted Oracle on disk: the first run writes it, later runs reload it.
# (The save used to be commented out, so the load failed whenever the file did
# not exist yet.)
ORACLE_CACHE = "Norman19.celloracle.oracle"
if os.path.exists(ORACLE_CACHE):
    # Reloading replaces the in-memory oracle built by the cells above.
    oracle = co.load_hdf5(ORACLE_CACHE)
else:
    oracle.to_hdf5(ORACLE_CACHE)

In [6]:
# Display the Oracle summary (metadata and pipeline status) as the cell output.
oracle

Oracle object

Meta data
    celloracle version used for instantiation: 0.10.12
    n_cells: 91205
    n_genes: 5045
    cluster_name: condition
    dimensional_reduction_name: X_pca
    n_target_genes_in_TFdict: 27150 genes
    n_regulatory_in_TFdict: 1094 genes
    n_regulatory_in_both_TFdict_and_scRNA-seq: 181 genes
    n_target_genes_both_TFdict_and_scRNA-seq: 3436 genes
    k_for_knn_imputation: 2280
Status
    Gene expression matrix: Ready
    BaseGRN: Ready
    PCA calculation: Done
    Knn imputation: Done
    GRN calculation for simulation: Done

In [7]:
def check_pert(x, pertable_genes):
    """Return True when both sides of condition ``x`` can be simulated.

    ``x`` must look like ``'<gene1>+<gene2>'`` (exactly one '+', otherwise
    ValueError from the unpacking). Each side must either be a member of
    ``pertable_genes`` or the literal ``'ctrl'``.
    """
    first, second = x.split('+')
    return all(gene in pertable_genes or gene == 'ctrl'
               for gene in (first, second))
    
def get_pert_value(g, data=None):
    """Return the mean measured expression of gene ``g`` in its single-gene perturbation.

    The single perturbation of ``g`` may be labelled ``'g+ctrl'`` or ``'ctrl+g'``
    in ``data.obs['condition']``. The first label that matches at least one cell
    is used, and ``'g+ctrl'`` is tried first.

    Args:
        g: Gene name; must be selectable as a column of ``data`` (``data[:, g]``).
        data: AnnData to read from. Defaults to the notebook-level ``adata``.

    Returns:
        float: Mean of ``data.X`` for gene ``g`` over the perturbed cells.

    Raises:
        ValueError: If neither ``'g+ctrl'`` nor ``'ctrl+g'`` labels any cell.
            Previously this case silently produced NaN from an empty selection.
    """
    if data is None:
        data = adata
    conditions = data.obs['condition']
    # `label in series` tests the Series *index* (cell barcodes), not its values,
    # so the old membership check was always False. Compare element-wise instead.
    for label in (g + '+ctrl', 'ctrl+' + g):
        mask = (conditions == label).values
        if mask.any():
            return float(data[mask, g].X.mean())
    raise ValueError('No cells with condition ' + g + '+ctrl or ctrl+' + g)

def get_pert_effect(pert, n_propagation=3):
    """Simulate perturbation ``pert`` with CellOracle and average over control cells.

    Each non-'ctrl' gene in ``pert`` is set to ``get_pert_value(gene)``, its mean
    expression in its own single-gene perturbation. The shift is then propagated
    with ``oracle.simulate_shift``.

    Args:
        pert: Condition string ``'<gene1>+<gene2>'``; either side may be ``'ctrl'``.
        n_propagation: Forwarded to ``oracle.simulate_shift``. The default of 3 is
            the value that was previously hard-coded.

    Returns:
        Mean over axis 0 of ``oracle.adata.layers['simulated_count']`` restricted
        to the cells whose condition is ``'ctrl'``.

    Raises:
        ValueError: If ``pert`` does not contain exactly one '+', names no gene
            other than 'ctrl', or a gene's perturbation value is NaN.
    """
    g1, g2 = pert.split('+')
    pert_conditions = {}
    for gene in (g1, g2):
        if gene == 'ctrl':
            continue
        value = get_pert_value(gene)
        if np.isnan(value):
            # NaN is not < 0, so it used to reach simulate_shift and silently
            # poison the simulated expression.
            raise ValueError('Perturbation value for ' + gene + ' is NaN')
        if value < 0:
            # Negative targets are replaced with 1 (kept from the original analysis).
            # NOTE(review): rationale undocumented — confirm.
            value = 1
        pert_conditions[gene] = value

    if not pert_conditions:
        raise ValueError('No gene to perturb in ' + repr(pert))

    ctrl_idxs = np.where(oracle.adata.obs['condition'] == 'ctrl')[0]
    oracle.simulate_shift(perturb_condition=pert_conditions,
                          ignore_warning=True,
                          n_propagation=n_propagation)

    simulated_ctrl = oracle.adata.layers['simulated_count'][ctrl_idxs, :]
    return simulated_ctrl.mean(0)

In [None]:
# Split-independent inputs, computed once instead of once per split.
ctrl_adata = adata[adata.obs['condition'] == 'ctrl']
# NOTE(review): ctrl_mean is not used in this cell; kept in case later cells read it.
ctrl_mean = ctrl_adata.X.toarray().mean(0)

# Genes that can be perturbed: every perturbed gene that is also in the TF list.
# Iterating unique condition labels avoids splitting the label of every cell.
unique_perts = {gene
                for condition in adata.obs['condition'].unique()
                for gene in condition.split('+')}
tf_set = set(TF_names.iloc[:, 0].values)
pertable_genes = [gene for gene in unique_perts if gene in tf_set]

for split_num in range(1, 6):
    split_file = ('/dfs/project/perturb-gnn/datasets/data/norman/splits/norman_simulation_'
                  + str(split_num) + '_0.75.pkl')
    split_perts = pd.read_pickle(split_file)
    test_perts = split_perts['test']
    pertable_test_perts = [p for p in test_perts if check_pert(p, pertable_genes)]

    perturbed_expression = {}
    for pert in pertable_test_perts:
        # Clear the previous simulation's results before each new simulation.
        oracle._clear_simulation_results()
        if pert in perturbed_expression:
            continue  # duplicate condition in this split
        print(pert)
        try:
            perturbed_expression[pert] = get_pert_effect(pert)
        except Exception as err:
            # Catch Exception (not bare except) so Ctrl-C still stops the run,
            # and report why the condition failed.
            print('Failed: ' + pert + ' (' + repr(err) + ')')

    # Saved as a 0-d object array; read back with np.load(path, allow_pickle=True).item().
    np.save('CellOracle_preds_pert_exp_split_retry_' + str(split_num), perturbed_expression)

MEIS1+ctrl
KLF1+FOXA1
TBX3+TBX2
CEBPE+KLF1
ZNF318+FOXL2
Failed: ZNF318+FOXL2
JUN+CEBPA


  return self.astype(np.float_)._mul_scalar(1./other)


ctrl+MEIS1
ETS2+CEBPE
POU3F2+FOXL2
AHR+KLF1
CEBPB+CEBPA
FOXL2+MEIS1
FOXL2+ctrl
FOSB+CEBPE
FOSB+CEBPB
FOXA3+HOXB9
OSR2+ctrl
ctrl+SPI1
CEBPB+ctrl
CEBPB+OSR2
FEV+ISL2
JUN+ctrl
FOXA1+HOXB9
ZBTB10+ctrl
Failed: ZBTB10+ctrl
CEBPE+SPI1
FOXA1+FOXL2
FOXF1+FOXL2
LYL1+CEBPB
ctrl+CEBPB
PRDM1+ctrl
FOSB+OSR2
FOXL2+HOXB9
ctrl+OSR2
JUN+CEBPB
ZBTB10+SNAI1
Failed: ZBTB10+SNAI1
ctrl+FOXL2
CEBPE+CEBPB
FOXA3+FOXL2
SPI1+ctrl
EGR1+ctrl
ZBTB10+DLX2
Failed: ZBTB10+DLX2
SNAI1+DLX2
ctrl+FOXA1
FOXA3+FOXA1


  return self.astype(np.float_)._mul_scalar(1./other)


ctrl+ETS2
KLF1+FOXA1
HES7+ctrl
ZNF318+FOXL2
Failed: ZNF318+FOXL2
FOXO4+ctrl
JUN+CEBPA
