In [None]:
import os
import json

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats
from tqdm import tqdm

import anndata
import scanpy as sc

from cytofuture_data.gene_name_mapping import GeneNameMapper

In [None]:
# Gene-name lookup tables, handed to GeneNameMapper positionally in this order:
# human gene table, mouse gene table, then the best-hit orthologue maps
# human -> mouse and mouse -> human (as named by the files).
_gene_name_tables = (
    '../../standard_genes/gene_names/human_genes.csv',
    '../../standard_genes/gene_names/mouse_genes.csv',
    '../../standard_genes/gene_names/orthologue_map_human2mouse_best.csv',
    '../../standard_genes/gene_names/orthologue_map_mouse2human_best.csv',
)
gene_name_mapper = GeneNameMapper(*_gene_name_tables)

In [None]:
# Create a map from Entrez ID to Ensembl ID.
# Rows without an Entrez ID are dropped, and only the first row per Entrez ID
# (in file order) is kept, so each Entrez ID maps to exactly one Ensembl ID.
ENTREZ_COL = 'NCBI gene (formerly Entrezgene) ID'
ENSEMBL_COL = 'Gene stable ID'

entrez_name_df = (
    pd.read_csv('/home/xingjie/Data/data2/cytofuture/databases/PertOrg/mart_export_ensembl_entrez.csv')
    .dropna(subset=[ENTREZ_COL])
    .drop_duplicates(subset=ENTREZ_COL)
)
display(entrez_name_df)

# Vectorised construction instead of a row-by-row `iterrows` loop.
# The Entrez column presumably parses as float because it has missing values;
# `.tolist()` after `astype(int)` yields plain Python ints as keys, so lookups
# with integer GeneIDs later in the notebook hit.
entrez_to_ensembl_map = dict(zip(
    entrez_name_df[ENTREZ_COL].astype(int).tolist(),
    entrez_name_df[ENSEMBL_COL],
))

In [None]:
# Exploratory look at the raw PertOrg metadata: which edit types occur among the
# mouse datasets (motivates the edit-type filter in the next cell).
# The next cell re-loads and filters the table on its own, so this cell defines
# no state that later cells depend on (it previously assigned `database_df`,
# which was immediately overwritten).
database_raw_df = pd.read_csv(
    '/home/xingjie/Data/data2/cytofuture/databases/PertOrg/Genetical_perurbation_Datasets.csv')
database_raw_df.loc[database_raw_df['organism'] == 'Mus_musculus', 'edittype'].value_counts()

In [None]:
# Load the PertOrg dataset metadata and keep the mouse datasets whose edit type
# has a known perturbation direction (the keys of `pert_sign_map` in the next cell).
# NOTE(review): the path spells 'perurbation'; presumably this matches the file
# on disk, so do not "correct" it without checking.
PERT_EDIT_TYPES = ['knockout', 'transgene', 'overexpression', 'knockdown', 'knockin']

database_df = (
    pd.read_csv('/home/xingjie/Data/data2/cytofuture/databases/PertOrg/Genetical_perurbation_Datasets.csv')
    .loc[lambda df: df['organism'] == 'Mus_musculus']
    .loc[lambda df: df['edittype'].isin(PERT_EDIT_TYPES)]
    .set_index('Pertorgid')
)
# The processing loop reads fields via `database_df.loc[pertorgid, col]`, which
# only returns a scalar when every Pertorgid appears exactly once.
assert database_df.index.is_unique, 'duplicate Pertorgid rows in PertOrg metadata'
database_df

In [None]:
# Build one AnnData of PertOrg mouse perturbation signatures on the standard
# (human) gene set and write it to disk. Each observation is one PertOrg dataset:
#   X                      log-normalised shift, perturbed - control
#   layers['control']      log-normalised control expression
#   layers['measure_mask'] 1.0 where the dataset reported the gene, else 0.0
DEG_DIR = ('/home/xingjie/Data/data2/cytofuture/databases/PertOrg/'
           'Differential_Gene_Expression_Analysis_Result_of_Each_PertOrg_Dataset')
OUTPUT_DIR = '/home/xingjie/Data/data2/cytofuture/datasets/perturbation_data'
MIN_GENES = 5000     # datasets with fewer mapped genes are dropped
MIN_MAX_EXPR = 100   # AveExpr that is negative or never reaches this is treated as not raw counts
TARGET_SUM = 1e4     # per-dataset total after normalisation, before log1p

standard_gene_df = pd.read_csv('standard_genes.csv').set_index('human_id')
pert_sign_map = {'knockout': -1, 'transgene': 1,
                 'overexpression': 1, 'knockdown': -1,
                 'knockin': 1}
dataset_id = 'PertOrg'


def _load_pertorg_dataset(pertorgid):
    """Load and quality-filter one PertOrg differential-expression table.

    Returns the table indexed by human gene ID (as produced by
    `gene_name_mapper`), with float `AveExpr_Case` / `AveExpr_Control`
    columns and one row per gene. Returns None when the file is missing or
    the table fails the quality filters.
    """
    dataset_file = os.path.join(DEG_DIR, f'{pertorgid}.txt')
    if not os.path.exists(dataset_file):
        return None

    dataset_df = pd.read_csv(dataset_file, sep='\t')
    # `.copy()` so the column assignments below act on an independent frame
    # rather than a filtered slice (avoids SettingWithCopyWarning).
    dataset_df = dataset_df[dataset_df['GeneID'].notna()].copy()
    # Entrez ID -> mouse Ensembl ID -> human gene ID; 'na' marks unmappable genes.
    dataset_df['target_gene'] = gene_name_mapper.map_gene_names(
        dataset_df['GeneID'].astype(int).map(entrez_to_ensembl_map),
        'mouse', 'human', 'id', 'id')
    dataset_df = dataset_df[dataset_df['target_gene'] != 'na'].copy()
    dataset_df['AveExpr_Control'] = dataset_df['AveExpr_Control'].astype(float)
    dataset_df['AveExpr_Case'] = dataset_df['AveExpr_Case'].astype(float)
    dataset_df = (dataset_df
                  .dropna(subset=['AveExpr_Control', 'AveExpr_Case'])
                  .set_index('target_gene'))
    # Several rows can map to the same human gene; keep only the first.
    dataset_df = dataset_df[~dataset_df.index.duplicated(keep='first')]

    # Drop tables with too few genes, or whose expression values do not look
    # like raw counts (negative values, or a maximum below MIN_MAX_EXPR).
    case, control = dataset_df['AveExpr_Case'], dataset_df['AveExpr_Control']
    if (len(dataset_df) < MIN_GENES
            or case.min() < 0 or control.min() < 0
            or case.max() < MIN_MAX_EXPR or control.max() < MIN_MAX_EXPR):
        return None
    return dataset_df


def _log_normalize(X, target_sum=TARGET_SUM):
    """Scale each row of X to sum to `target_sum`, then apply log1p."""
    return np.log1p(X / X.sum(axis=1, keepdims=True) * target_sum)


obs_dict = {
    'id': [],
    'condition': [],
    'perturbed_gene': [],
    'perturbation_sign': [],
}
X_perturb = []
X_control = []
X_measure_mask = []

for pertorgid in tqdm(database_df.index):
    # Map the perturbed mouse gene name to a human gene ID; skip unmappable genes.
    pg = gene_name_mapper.map_gene_names([database_df.loc[pertorgid, 'genotype']],
                                         'mouse', 'human', 'name', 'id')[0]
    if pg == 'na':
        continue

    dataset_df = _load_pertorg_dataset(pertorgid)
    if dataset_df is None:
        continue

    # Align to the standard gene list without mutating standard_gene_df;
    # genes this dataset did not report become NaN and are masked out.
    perturb = dataset_df['AveExpr_Case'].reindex(standard_gene_df.index)
    control = dataset_df['AveExpr_Control'].reindex(standard_gene_df.index)
    measure_mask = perturb.notna().to_numpy().astype(np.float32)
    perturb = np.nan_to_num(perturb.to_numpy(), nan=0).astype(np.float32)
    control = np.nan_to_num(control.to_numpy(), nan=0).astype(np.float32)

    # Row normalisation divides by the total over the standard genes; a dataset
    # with no expression there would silently become an all-NaN row.
    if perturb.sum() == 0 or control.sum() == 0:
        continue

    # -1 for knockout/knockdown, +1 for transgene/overexpression/knockin
    perturbation_sign = pert_sign_map[database_df.loc[pertorgid, 'edittype']]

    obs_dict['id'].append(f'{dataset_id}_{pertorgid}')
    obs_dict['condition'].append(f'{dataset_id}_{pertorgid}')
    obs_dict['perturbed_gene'].append(pg)
    obs_dict['perturbation_sign'].append(perturbation_sign)
    X_perturb.append(perturb)
    X_control.append(control)
    X_measure_mask.append(measure_mask)

if not X_perturb:
    raise ValueError('No PertOrg dataset passed the filters; nothing to build.')

X_perturb = _log_normalize(np.stack(X_perturb))
X_control = _log_normalize(np.stack(X_control))
X_shift = X_perturb - X_control
X_measure_mask = np.stack(X_measure_mask)

# Create the perturbation anndata
adata_perturb = anndata.AnnData(
    X=X_shift,
    obs=pd.DataFrame(obs_dict).set_index('id'),
    var=standard_gene_df[[]].copy(),
)
adata_perturb.layers['measure_mask'] = X_measure_mask
adata_perturb.layers['control'] = X_control

adata_perturb.obs['perturbed_gene_name'] = gene_name_mapper.map_gene_names(
    adata_perturb.obs['perturbed_gene'], 'human', 'human', 'id', 'name')
adata_perturb.var['gene_name'] = gene_name_mapper.map_gene_names(
    adata_perturb.var.index, 'human', 'human', 'id', 'name')

adata_perturb.write_h5ad(os.path.join(OUTPUT_DIR, f'{dataset_id}.h5ad'))
adata_perturb

In [None]:
# Distribution of expression shifts over all datasets and genes. The y-axis is
# capped at 1000 so the sparsely populated tails of the histogram stay visible.
fig, ax = plt.subplots()
ax.hist(adata_perturb.X.flatten(), bins=100)
ax.set_ylim(0, 1000)
ax.set(title='PertOrg expression shifts (log-normalised perturbed - control)',
       xlabel='expression shift', ylabel='count (axis capped at 1000)');