# Running CellOracle Model for benchmark GRN inference

In [None]:
### Running CellOracle Model for benchmark GRN inference
import celloracle as co

In [5]:
import splicejac as sp

In [14]:
import scanpy as sc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scvelo as scv
from velovi import preprocess_data
import math
import random
from scvelo import logging as logg
import torch.nn.functional as F
import os

In [7]:
from paths import FIG_DIR, DATA_DIR

In [8]:
SAVE_FIGURES = True
if SAVE_FIGURES:
    os.makedirs(FIG_DIR / 'simulation' / 'dyngen_results',exist_ok = True)

SAVE_DATASETS = True
if SAVE_DATASETS:
    os.makedirs(DATA_DIR / 'simulation' / 'dyngen_results',exist_ok = True)
    os.makedirs(DATA_DIR / 'simulation' / 'dyngen_results' / 'copy_file',exist_ok = True)

In [9]:
from scipy.spatial.distance import cdist
import pandas as pd
import sklearn
import scipy.stats as stats
from typing_extensions import Literal
import anndata
import scipy
import matplotlib.pyplot as plt
## define function
import torch
def csgn_groundtruth(adata):
    csgn_array = adata.obsm["regulatory_network_sc"].toarray()
    csgn_tensor = torch.zeros([len(adata.uns["regulators"]),len(adata.uns["targets"]),csgn_array.shape[0]])
    
    for k in range(csgn_array.shape[0]):
        ## generate a 3D tensor to indicate the ground truth network for each cell
        grnboost_m = np.zeros((len(adata.uns["regulators"]),len(adata.uns["targets"])))
        grnboost_m = pd.DataFrame(grnboost_m,index = adata.uns["regulators"], columns = adata.uns["targets"])
        for i in range(adata.uns["regulatory_network"].shape[0]):
            #ind = (adata.uns["regulatory_network"]["regulator"] == j) & (adata.uns["regulatory_network"]["target"] == i)
            regulator = adata.uns["regulatory_network"].iloc[i]["regulator"]
            target = adata.uns["regulatory_network"].iloc[i]["target"]
            grnboost_m.loc[regulator,target] = csgn_array[k,i]
        tensor = torch.tensor(np.array(grnboost_m))
        csgn_tensor[:,:,k] = tensor
        
    return csgn_tensor

def csgn_benchmark2(GRN,W,csgn):
    csgn[csgn!=0] = 1
    if len(GRN.shape)>2:
        print("Input is cell type specific GRN...")
        score = []
        for i in range(csgn.shape[2]):
            net = csgn[:,:,i]
            #auprc = sklearn.metrics.average_precision_score(W.T.ravel(), np.abs(GRN[:,:,i].numpy().ravel()))
            pre = GRN[:,:,i][np.array(W.T)==1]
            gt = net.T[np.array(W.T)==1]
            gt[gt!=0] = 1
            
            number = min(10000,len(gt))
            pre,index = torch.topk(pre,number)
            fpr, tpr, thresholds = sklearn.metrics.roc_curve(y_true = gt[index], y_score = pre, pos_label = 1) #positive class is 1; negative class is 0
            auroc = sklearn.metrics.auc(fpr, tpr)
            score.append(auroc)
    else:
        print("Input is global GRN...")
        score = []
        for i in range(csgn.shape[2]):
            net = csgn[:,:,i]
            pre = GRN[np.array(W.T)==1]
            gt = net.T[np.array(W.T)==1]
            gt[gt!=0] = 1
            
            number = min(10000,len(gt))
            pre,index = torch.topk(pre,number)
            fpr, tpr, thresholds = sklearn.metrics.roc_curve(y_true = gt[index], y_score = pre, pos_label = 1) #positive class is 1; negative class is 0
            auroc = sklearn.metrics.auc(fpr, tpr)
            score.append(auroc)
    return score

In [47]:
def sanity_check(adata):
    csgn = adata.uns["csgn"]
    gene_name = adata.var.index.tolist()
    full_name = adata.uns["regulators"]
    index = [i in gene_name for i in full_name]
    full_name = full_name[index]
    adata = adata[:,full_name].copy()
    W = adata.uns["skeleton"]
    W = W[index,:]
    W = W[:,index]
    adata.uns["skeleton"] = W 
    W = adata.uns["network"]
    W = W[index,:]
    W = W[:,index]
    csgn = csgn[index,:,:]
    csgn = csgn[:,index,:]
    adata.uns["network"] = W
    adata.uns["regulators"] = gene_name
    adata.uns["targets"] = gene_name
    adata.uns["csgn"] = csgn
    
    return adata

## import dataset

In [10]:
time_corr_all = list()
gene_time_corr_all = list()
gene_velo_corr_all = list()
AUC_GRN_result = list()

folder_path = os.getcwd() + '/RegVelo_datasets/dyngen_simulation/'
# Get a list of all files in the folder
files = os.listdir(folder_path)
files = [file for file in files if file.endswith(".h5ad")]

In [12]:
len(files)

50

In [18]:
address = os.getcwd() + '/RegVelo_datasets/dyngen_simulation/' + adata_name
adata = sc.read_h5ad(address)

In [25]:
for adata_name in files:
    address = os.getcwd() + '/RegVelo_datasets/dyngen_simulation/' + adata_name
    
    adata = sc.read_h5ad(address)
    adata_raw = adata.copy()
    csgn = csgn_groundtruth(adata)
    adata.uns["csgn"] = csgn

    adata.X = adata.X.copy()
    adata.layers["spliced"] = adata.layers["counts_spliced"].copy()
    adata.layers["unspliced"] = adata.layers["counts_unspliced"].copy()


    scv.pp.filter_and_normalize(adata, min_shared_counts=10)
    sc.tl.pca(adata)
    sc.pp.neighbors(adata, n_neighbors=30, n_pcs=30)
    scv.pp.moments(adata)

    adata.X = np.log1p(adata.X.copy())


    sc.tl.leiden(adata)
    adata_raw.obs["cluster"] = adata.obs["leiden"].copy()
    adata_raw.obsm["X_pca"] = adata.obsm["X_pca"].copy()
    adata_raw.layers["spliced"] = adata_raw.layers["counts_spliced"].copy()
    adata_raw.layers["unspliced"] = adata_raw.layers["counts_unspliced"].copy()


    adata = preprocess_data(adata)
    adata = sanity_check(adata)
    adata.uns["Ms"] = adata.layers["Ms"]
    adata.uns["Mu"] = adata.layers["Mu"]

    W = adata.uns["skeleton"].copy()
    W = torch.tensor(np.array(W)).int()
    W = torch.ones(W.shape)
    intersection = list(set(adata.uns["regulators"]).intersection(adata.uns["targets"]))

    for i in intersection:
        index1 = [j == i for j in adata.uns["regulators"]]
        index2 = [j == i for j in adata.uns["targets"]]
        W[index1,index2] = 0

    
    target_list = []
    for i in range(adata.uns["skeleton"].shape[0]):
        a = list(adata.uns["targets"][W[i,:] == 1])
        a = ','.join(a)
        target_list.append(a)
    d = {"TF":adata.uns["regulators"],"Target_genes":target_list}
    df = pd.DataFrame(data=d)
    TF_to_TG_dictionary = {}

    for TF, TGs in zip(df.TF, df.Target_genes):
        # convert target gene to list
        TG_list = TGs.replace(" ", "").split(",")
        # store target gene list in a dictionary
        TF_to_TG_dictionary[TF] = TG_list

    # We invert the dictionary above using a utility function in celloracle.
    TG_to_TF_dictionary = co.utility.inverse_dictionary(TF_to_TG_dictionary)

    net = co.Net(gene_expression_matrix=adata.to_df(), # Input gene expression matrix as data frame
                 TFinfo_dic=TG_to_TF_dictionary, # Input base GRN
                 verbose=True
                 )
    
    net.fit_All_genes(bagging_number=20,
                 alpha=10, verbose=True)
    net.updateLinkList(verbose=True)
    inference_result = net.linkList.copy()
    celloracle_m = np.zeros((len(adata.uns["targets"]),len(adata.uns["regulators"])))
    celloracle_m = pd.DataFrame(celloracle_m,index = adata.uns["targets"], columns = adata.uns["regulators"])
    for i in adata.uns["targets"]:
        for j in adata.uns["regulators"]:
            ind = (inference_result["source"] == j) & (inference_result["target"] == i)
            if sum(ind) > 0:
                pdd = inference_result[ind]
                celloracle_m.loc[i,j] = pdd["-logp"].values

    # In[72]:
    score6 = csgn_benchmark2(torch.tensor(np.array(celloracle_m)),W,adata.uns["csgn"])

    ### Visualize the Violin Plots
    AUC_GRN_result.append(np.mean(score6))
    print(np.mean(score6))
    # In[79]:

    df = pd.DataFrame(AUC_GRN_result)
    df.columns = ["celloracle"]

    if SAVE_DATASETS:
        df.to_csv(DATA_DIR / 'simulation' / 'dyngen_results' /  'AUROC_res_all_celloracle.csv')

    print("Done " + adata_name + "!")

Filtered out 1 genes that are detected 10 counts (shared).
Normalized count data: X, spliced, unspliced.
Logarithmized X.
computing moments based on connectivities
    finished (0:00:00) --> added 
    'Ms' and 'Mu', moments of un/spliced abundances (adata.layers)
computing velocities
    finished (0:00:00) --> added 
    'velocity', velocity vectors for each individual cell (adata.layers)


  0%|          | 0/90 [00:00<?, ?it/s]

initiating Net object ...
gem_shape: (300, 79)
initiation completed.


  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Input is global GRN...
0.6173405545834783
Done dataset_sim24.h5ad!
Filtered out 5 genes that are detected 10 counts (shared).
Normalized count data: X, spliced, unspliced.
Logarithmized X.
computing moments based on connectivities
    finished (0:00:00) --> added 
    'Ms' and 'Mu', moments of un/spliced abundances (adata.layers)
computing velocities
    finished (0:00:00) --> added 
    'velocity', velocity vectors for each individual cell (adata.layers)


  0%|          | 0/90 [00:00<?, ?it/s]

initiating Net object ...
gem_shape: (300, 63)
initiation completed.


  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

Input is global GRN...
0.5295938221814589
Done dataset_sim20.h5ad!
Filtered out 1 genes that are detected 10 counts (shared).
Normalized count data: X, spliced, unspliced.
Logarithmized X.
computing moments based on connectivities
    finished (0:00:00) --> added 
    'Ms' and 'Mu', moments of un/spliced abundances (adata.layers)
computing velocities
    finished (0:00:00) --> added 
    'velocity', velocity vectors for each individual cell (adata.layers)


  0%|          | 0/90 [00:00<?, ?it/s]

initiating Net object ...
gem_shape: (300, 83)
initiation completed.


  0%|          | 0/83 [00:00<?, ?it/s]

  0%|          | 0/83 [00:00<?, ?it/s]

Input is global GRN...
0.603686281854155
Done dataset_sim39.h5ad!
Filtered out 2 genes that are detected 10 counts (shared).
Normalized count data: X, spliced, unspliced.
Logarithmized X.
computing moments based on connectivities
    finished (0:00:00) --> added 
    'Ms' and 'Mu', moments of un/spliced abundances (adata.layers)
computing velocities
    finished (0:00:00) --> added 
    'velocity', velocity vectors for each individual cell (adata.layers)


  0%|          | 0/90 [00:00<?, ?it/s]

initiating Net object ...
gem_shape: (300, 63)
initiation completed.


  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

Input is global GRN...
0.542335643167199
Done dataset_sim23.h5ad!
Filtered out 1 genes that are detected 10 counts (shared).
Normalized count data: X, spliced, unspliced.
Logarithmized X.
computing moments based on connectivities
    finished (0:00:00) --> added 
    'Ms' and 'Mu', moments of un/spliced abundances (adata.layers)
computing velocities
    finished (0:00:00) --> added 
    'velocity', velocity vectors for each individual cell (adata.layers)


  0%|          | 0/90 [00:00<?, ?it/s]

initiating Net object ...
gem_shape: (300, 72)
initiation completed.


  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

Input is global GRN...
0.53910608158487
Done dataset_sim10.h5ad!
Filtered out 1 genes that are detected 10 counts (shared).
Normalized count data: X, spliced, unspliced.
Logarithmized X.
computing moments based on connectivities
    finished (0:00:00) --> added 
    'Ms' and 'Mu', moments of un/spliced abundances (adata.layers)
computing velocities
    finished (0:00:00) --> added 
    'velocity', velocity vectors for each individual cell (adata.layers)


  0%|          | 0/90 [00:00<?, ?it/s]

initiating Net object ...
gem_shape: (300, 84)
initiation completed.


  0%|          | 0/84 [00:00<?, ?it/s]

  0%|          | 0/84 [00:00<?, ?it/s]

Input is global GRN...
0.6706305515129201
Done dataset_sim18.h5ad!
Filtered out 1 genes that are detected 10 counts (shared).
Normalized count data: X, spliced, unspliced.
Logarithmized X.
computing moments based on connectivities
    finished (0:00:00) --> added 
    'Ms' and 'Mu', moments of un/spliced abundances (adata.layers)
computing velocities
    finished (0:00:00) --> added 
    'velocity', velocity vectors for each individual cell (adata.layers)


  0%|          | 0/90 [00:00<?, ?it/s]

initiating Net object ...
gem_shape: (300, 75)
initiation completed.


  0%|          | 0/75 [00:00<?, ?it/s]

  0%|          | 0/75 [00:00<?, ?it/s]

Input is global GRN...
0.612868422627679
Done dataset_sim8.h5ad!
Filtered out 1 genes that are detected 10 counts (shared).
Normalized count data: X, spliced, unspliced.
Logarithmized X.
computing moments based on connectivities
    finished (0:00:00) --> added 
    'Ms' and 'Mu', moments of un/spliced abundances (adata.layers)
computing velocities
    finished (0:00:00) --> added 
    'velocity', velocity vectors for each individual cell (adata.layers)


  0%|          | 0/90 [00:00<?, ?it/s]

initiating Net object ...
gem_shape: (300, 84)
initiation completed.


  0%|          | 0/84 [00:00<?, ?it/s]

  0%|          | 0/84 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [48]:
time_corr_all = list()
gene_time_corr_all = list()
gene_velo_corr_all = list()
AUC_GRN_result = list()

for adata_name in files:
    address = os.getcwd() + '/RegVelo_datasets/dyngen_simulation/' + adata_name
    
    adata = sc.read_h5ad(address)
    adata_raw = adata.copy()
    csgn = csgn_groundtruth(adata)
    adata.uns["csgn"] = csgn

    adata.X = adata.X.copy()
    adata.layers["spliced"] = adata.layers["counts_spliced"].copy()
    adata.layers["unspliced"] = adata.layers["counts_unspliced"].copy()


    scv.pp.filter_and_normalize(adata, min_shared_counts=10)
    sc.tl.pca(adata)
    sc.pp.neighbors(adata, n_neighbors=30, n_pcs=30)
    scv.pp.moments(adata)

    adata.X = np.log1p(adata.X.copy())


    #sc.tl.leiden(adata)
    adata_raw.obs["cluster"] = adata.obs["leiden"].copy()
    adata_raw.obsm["X_pca"] = adata.obsm["X_pca"].copy()
    adata_raw.layers["spliced"] = adata_raw.layers["counts_spliced"].copy()
    adata_raw.layers["unspliced"] = adata_raw.layers["counts_unspliced"].copy()


    adata = preprocess_data(adata)
    adata = sanity_check(adata)
    
    adata.uns["Ms"] = adata.layers["Ms"]
    adata.uns["Mu"] = adata.layers["Mu"]

    W = adata.uns["skeleton"].copy()
    W = torch.tensor(np.array(W)).int()
    W = torch.ones(W.shape)
    intersection = list(set(adata.uns["regulators"]).intersection(adata.uns["targets"]))

    for i in intersection:
        index1 = [j == i for j in adata.uns["regulators"]]
        index2 = [j == i for j in adata.uns["targets"]]
        W[index1,index2] = 0

    ## We ignore the cell label information and assume all cells is the same label
    adata.obs["clusters"] = "1"
    sp.tl.estimate_jacobian(adata,n_top_genes = adata.shape[1],min_shared_counts=0)
    weight_quantile=.5
    genes = list(adata.var_names)
    n = len(genes)
    A = adata.uns['average_jac']["1"][0][0:n, n:].copy().T
    A = A.T

    # In[72]:
    score6 = csgn_benchmark2(torch.tensor(np.array(A)),W,adata.uns["csgn"])

    ### Visualize the Violin Plots
    AUC_GRN_result.append(np.mean(score6))
    print(np.mean(score6))
    # In[79]:

    df = pd.DataFrame(AUC_GRN_result)
    df.columns = ["spliceAJ"]
    
    if SAVE_DATASETS:
        df.to_csv(DATA_DIR / 'simulation' / 'dyngen_results' /  'AUROC_res_all_spliceJAC.csv')

    print("Done " + adata_name + "!")

Filtered out 1 genes that are detected 10 counts (shared).
Normalized count data: X, spliced, unspliced.
Logarithmized X.
computing moments based on connectivities
    finished (0:00:00) --> added 
    'Ms' and 'Mu', moments of un/spliced abundances (adata.layers)
computing velocities
    finished (0:00:00) --> added 
    'velocity', velocity vectors for each individual cell (adata.layers)
Extracted 79 highly variable genes.
Logarithmized X.
Running quick regression...
Running subset regression on the 1 cluster...
Input is global GRN...
0.7344667306942315
Done dataset_sim24.h5ad!
Filtered out 5 genes that are detected 10 counts (shared).
Normalized count data: X, spliced, unspliced.
Logarithmized X.
computing moments based on connectivities
    finished (0:00:00) --> added 
    'Ms' and 'Mu', moments of un/spliced abundances (adata.layers)
computing velocities
    finished (0:00:00) --> added 
    'velocity', velocity vectors for each individual cell (adata.layers)
Extracted 63 highly 

Exception ignored in: <function tqdm.__del__ at 0x7f0affa0baf0>
Traceback (most recent call last):
  File "/home/icb/weixu.wang/miniconda3/envs/celloracle_env/lib/python3.8/site-packages/tqdm/std.py", line 1147, in __del__
    def __del__(self):
KeyboardInterrupt: 


computing velocities
    finished (0:00:00) --> added 
    'velocity', velocity vectors for each individual cell (adata.layers)
Extracted 84 highly variable genes.
Logarithmized X.
Running quick regression...
Running subset regression on the 1 cluster...
Input is global GRN...
0.5310126513520153
Done dataset_sim18.h5ad!


KeyboardInterrupt: 