In [None]:
import os
import json

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats
from tqdm.auto import tqdm

import anndata
import scanpy as sc

from cytofuture_data.gene_name_mapping import GeneNameMapper

In [None]:
# Gene name <-> id mapper built from four tables that share one directory:
# human and mouse gene lists plus human->mouse and mouse->human orthologue maps
# (as the file names indicate).
gene_names_dir = '../../standard_genes/gene_names'
gene_name_mapper = GeneNameMapper(
    f'{gene_names_dir}/human_genes.csv',
    f'{gene_names_dir}/mouse_genes.csv',
    f'{gene_names_dir}/orthologue_map_human2mouse_best.csv',
    f'{gene_names_dir}/orthologue_map_mouse2human_best.csv',
)

In [None]:
# Load the OmniPath interaction table (tab-separated) and normalise both gene
# symbol columns to upper case.
omnipath_path = ('/home/xingjie/Data/data2/cytofuture/databases/omnipath/'
                 'omnipath_webservice_interactions__latest.tsv')
op_df = pd.read_csv(omnipath_path, sep='\t', low_memory=False)
op_df = op_df.assign(
    source_genesymbol=op_df['source_genesymbol'].str.upper(),
    target_genesymbol=op_df['target_genesymbol'].str.upper(),
)

print(op_df.shape)
op_df.columns

In [None]:
# Gene-by-gene correlation matrix; used below to filter edges and to infer the
# sign of edges that OmniPath does not annotate.
gene_corr_path = '../../gene_correlations/gene_corr_df_measured.parquet'
gene_corr_df = pd.read_parquet(gene_corr_path)
gene_corr_df

In [None]:
# Build the signed, weighted TF -> target edge list from OmniPath.
# NOTE: this cell overwrites op_df and finally drops every other column, so
# re-running it requires re-running the OmniPath loading cell first.

# Keep only transcriptional edges.
op_df = op_df[op_df['type'] == 'transcriptional'].copy()

# Remove redundant edges: per (source, target) symbol pair keep the row with the
# highest curation effort. A stable sort makes ties resolve to the first row in
# file order, so the surviving row is reproducible.
op_df = op_df.sort_values('curation_effort', ascending=False, kind='stable')
op_df = op_df.drop_duplicates(subset=['source_genesymbol', 'target_genesymbol'])

op_df['source_id'] = gene_name_mapper.map_gene_names(op_df['source_genesymbol'],
                                                     'human', 'human', 'name', 'id')
op_df['target_id'] = gene_name_mapper.map_gene_names(op_df['target_genesymbol'],
                                                     'human', 'human', 'name', 'id')

# Drop edges with an unmapped endpoint ('na'), then keep only edges whose two
# endpoints both appear in the gene correlation matrix.
op_df = op_df[(op_df['source_id'] != 'na') & (op_df['target_id'] != 'na')]
op_df = op_df[op_df['source_id'].isin(gene_corr_df.index)
              & op_df['target_id'].isin(gene_corr_df.index)].copy()

# Edge weight from curation effort: 0.1 at effort 0, approaching 1 as effort grows.
op_df['weight'] = (1 + op_df['curation_effort']) / (10 + op_df['curation_effort'])

# Edge sign: +1 for consensus stimulation, -1 for consensus inhibition
# (stimulation wins if both are set). Unannotated edges fall back to the
# source/target gene correlation.
# NOTE(review): the fallback stores the raw correlation (a value in [-1, 1]),
# not its sign, so those edges are additionally scaled by |corr| downstream —
# confirm this soft sign is intended rather than np.sign(corr).
is_stimulation = (op_df['consensus_stimulation'] == 1).to_numpy()
is_inhibition = (op_df['consensus_inhibition'] == 1).to_numpy()
needs_corr = ~(is_stimulation | is_inhibition)

source_ids = op_df['source_id'].to_numpy()
target_ids = op_df['target_id'].to_numpy()
corr_fallback = np.full(len(op_df), np.nan)
corr_rows = np.flatnonzero(needs_corr)
# Look up only the unannotated edges (row = source, column = target); .at raises
# KeyError if a target id is missing from the matrix columns.
corr_fallback[corr_rows] = [gene_corr_df.at[source_ids[i], target_ids[i]]
                            for i in tqdm(corr_rows)]

op_df['sign'] = np.select([is_stimulation, is_inhibition], [1.0, -1.0],
                          default=corr_fallback)
op_df = op_df[['source_id', 'target_id', 'sign', 'weight']].copy()

In [None]:
# Build one pseudo-perturbation (overexpression) sample per TF from the signed
# OmniPath edges and save it as an AnnData (obs = TFs, var = standard gene panel).

# Restrict edges to targets in the standard gene panel.
standard_gene_df = pd.read_csv('standard_genes.csv').set_index('human_id')
op_df = op_df[op_df['target_id'].isin(standard_gene_df.index)].copy()

dataset_id = 'omnipath'
# Multiplier turning sign * weight into a pseudo log2 fold change.
log2fc_scale = 3
output_dir = '/home/xingjie/Data/data2/cytofuture/datasets/perturbation_data'
n_genes = len(standard_gene_df)

obs_dict = {
    'id': [],
    'condition': [],
    'perturbed_gene': [],
    'perturbation_sign': [],
}
X = []
X_measure_mask = []

# groupby visits TFs in sorted order (as np.unique would) and keeps each group's
# rows in op_df order, without re-scanning the whole of op_df for every TF.
tf_groups = op_df.groupby('source_id', sort=True)
for tf, sample_df in tqdm(tf_groups, total=tf_groups.ngroups):
    sample_df = sample_df.set_index('target_id')
    # One edge per target: keep the first row (op_df is sorted by descending
    # curation_effort upstream, so presumably the best-curated edge).
    sample_df = sample_df[~sample_df.index.duplicated(keep='first')]

    obs_dict['id'].append(f'{dataset_id}_{tf}')
    obs_dict['condition'].append(dataset_id)
    obs_dict['perturbed_gene'].append(tf)
    obs_dict['perturbation_sign'].append(1)  # -1 for knockdown, 1 for overexpression

    # Align to the gene panel without mutating standard_gene_df; non-targets get
    # NaN from reindex and are then set to 0 (no change).
    log2fcs = (log2fc_scale * sample_df['sign'] * sample_df['weight']).reindex(
        standard_gene_df.index)
    X.append(np.nan_to_num(log2fcs.to_numpy(), nan=0).astype(np.float32))
    # Every panel gene is flagged as measured, including the zero-filled non-targets.
    X_measure_mask.append(np.ones(n_genes, dtype=np.float32))

# Stack into (n_tfs, n_genes); an empty edge list yields a (0, n_genes) matrix
# rather than a 1-D array whose shape cannot match var.
X = np.array(X, dtype=np.float32) if X else np.zeros((0, n_genes), dtype=np.float32)
X_measure_mask = (np.array(X_measure_mask, dtype=np.float32) if X_measure_mask
                  else np.zeros((0, n_genes), dtype=np.float32))

# Create the perturbation anndata
adata_perturb = anndata.AnnData(
    X=X,
    obs=pd.DataFrame(obs_dict).set_index('id'),
    var=standard_gene_df[[]].copy(),
)
adata_perturb.layers['measure_mask'] = X_measure_mask
adata_perturb.layers['control'] = np.zeros(X.shape, dtype=np.float32)

adata_perturb.obs['perturbed_gene_name'] = gene_name_mapper.map_gene_names(
    adata_perturb.obs['perturbed_gene'], 'human', 'human', 'id', 'name')
adata_perturb.var['gene_name'] = gene_name_mapper.map_gene_names(
    adata_perturb.var.index, 'human', 'human', 'id', 'name')

adata_perturb.write_h5ad(os.path.join(output_dir, f'{dataset_id}.h5ad'))

plt.hist(adata_perturb.X.flatten(), bins=100)
adata_perturb

In [None]:
# Zoomed histogram of all pseudo log2fc entries. Non-target genes are filled with
# 0, which presumably dominates the counts, so the y-axis is capped at 1000 to
# make the non-zero values visible.
fig, ax = plt.subplots()
ax.hist(adata_perturb.X.flatten(), bins=100)
ax.set_ylim(0, 1000)
ax.set(xlabel='pseudo log2 fold change', ylabel='number of entries',
       title='OmniPath pseudo-perturbations (y-axis clipped at 1000)');