In [None]:
import os
import json

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats
from tqdm.auto import tqdm

import anndata
import scanpy as sc

from cytofuture_data.gene_name_mapping import GeneNameMapper

In [None]:
# Build the gene-name mapper from the shared standard gene tables:
# the human and mouse gene lists, followed by the human->mouse and
# mouse->human orthologue maps (the "_best" files), in that order.
_gene_name_tables = (
    '../../standard_genes/gene_names/human_genes.csv',
    '../../standard_genes/gene_names/mouse_genes.csv',
    '../../standard_genes/gene_names/orthologue_map_human2mouse_best.csv',
    '../../standard_genes/gene_names/orthologue_map_mouse2human_best.csv',
)
gene_name_mapper = GeneNameMapper(*_gene_name_tables)

In [None]:
# Load the mouse KnockTF v2 table and map both the target genes and the
# perturbed TFs from mouse gene names to human gene ids.
# The data root can be overridden with the CYTOFUTURE_DATA_DIR environment
# variable; the default is the original author's local path.
knocktf_path = os.path.join(
    os.environ.get('CYTOFUTURE_DATA_DIR', '/home/xingjie/Data/data2/cytofuture'),
    'databases', 'KnockTF2', 'knocktf_v2_main_mouse.txt')
database_df = pd.read_csv(knocktf_path, sep='\t', low_memory=False)

# np.asarray makes the column assignment positional. This is correct whether
# map_gene_names returns a list, an array, or a Series; a Series with a fresh
# RangeIndex would otherwise be label-aligned against the filtered
# (non-contiguous) index below and silently misalign rows.
# NOTE(review): unmapped genes are assumed to come back as the string 'na'.
database_df['target_gene'] = np.asarray(gene_name_mapper.map_gene_names(
    database_df['Gene'], 'mouse', 'human', 'name', 'id'))
database_df = database_df[database_df['target_gene'] != 'na'].copy()

database_df['perturbed_gene'] = np.asarray(gene_name_mapper.map_gene_names(
    database_df['TF'], 'mouse', 'human', 'name', 'id'))
database_df = database_df[database_df['perturbed_gene'] != 'na']

# Both expression means are required downstream.
database_df = database_df.dropna(subset=['Mean_Case', 'Mean_Control'])

database_df

In [None]:
# Build one observation per KnockTF sample: the expression shift
# (log-normalized perturbed minus control) on the standard gene axis,
# with a per-gene measurement mask and the control profile as layers.
standard_gene_df = pd.read_csv('standard_genes.csv').set_index('human_id')

dataset_id = 'knockTF_mouse'

# QC thresholds (heuristics) for keeping a sample.
MIN_GENES_PER_SAMPLE = 5000  # fewer measured genes than this -> drop
MIN_MAX_EXPRESSION = 100     # a max mean below this suggests values are not raw counts

obs_dict = {
    'id' : [],
    'condition' : [],
    'perturbed_gene' : [],
    'perturbation_sign' : [],
}
X_perturb = []
X_control = []
X_measure_mask = []

samples = np.unique(database_df['Sample_ID'])
# Group once up front so each sample lookup does not re-scan the full table.
samples_grouped = database_df.groupby('Sample_ID')

# Get the perturbation data
for c in tqdm(samples):
    sample_df = samples_grouped.get_group(c)
    sample_df = sample_df.set_index('target_gene')
    sample_df = sample_df[~sample_df.index.duplicated(keep='first')]

    # Filter out the dataset if there are too few genes
    # or the expression counts do not look raw (negative or small values)
    if ((sample_df.shape[0] < MIN_GENES_PER_SAMPLE)
        or (sample_df['Mean_Case'].min() < 0)
        or (sample_df['Mean_Control'].min() < 0)
        or (sample_df['Mean_Case'].max() < MIN_MAX_EXPRESSION)
        or (sample_df['Mean_Control'].max() < MIN_MAX_EXPRESSION)):
        continue

    # Align to the standard gene axis; genes not measured in this sample are NaN.
    perturb = sample_df['Mean_Case'].reindex(standard_gene_df.index)
    control = sample_df['Mean_Control'].reindex(standard_gene_df.index)

    # With no (non-zero) overlap with the standard genes, the library-size
    # normalization below would be 0/0 and yield an all-NaN row.
    if perturb.sum() <= 0 or control.sum() <= 0:
        continue

    obs_dict['id'].append(f'{dataset_id}_{c}')
    obs_dict['condition'].append(f'{dataset_id}_{c}')
    obs_dict['perturbed_gene'].append(sample_df['perturbed_gene'].iloc[0])
    obs_dict['perturbation_sign'].append(-1) # -1 for knockdown, 1 for overexpression

    X_perturb.append(np.nan_to_num(perturb.to_numpy(), nan=0).astype(np.float32))
    X_control.append(np.nan_to_num(control.to_numpy(), nan=0).astype(np.float32))
    X_measure_mask.append(perturb.notna().to_numpy().astype(np.float32))

if not X_perturb:
    raise ValueError(f'No {dataset_id} sample passed the quality filters; nothing to save.')

X_perturb = np.array(X_perturb)
X_control = np.array(X_control)
# Scale each sample to a total of 1e4, then log1p.
X_perturb = np.log1p(X_perturb / X_perturb.sum(axis=1, keepdims=True) * 1e4)
X_control = np.log1p(X_control / X_control.sum(axis=1, keepdims=True) * 1e4)
X_shift = X_perturb - X_control

X_measure_mask = np.array(X_measure_mask)

# Create the perturbation anndata
adata_perturb = anndata.AnnData(
    X=X_shift,
    obs=pd.DataFrame(obs_dict).set_index('id'),
    var=standard_gene_df[[]].copy(),
)
adata_perturb.layers['measure_mask'] = X_measure_mask
adata_perturb.layers['control'] = X_control

adata_perturb.obs['perturbed_gene_name'] = gene_name_mapper.map_gene_names(
    adata_perturb.obs['perturbed_gene'], 'human', 'human', 'id', 'name')
adata_perturb.var['gene_name'] = gene_name_mapper.map_gene_names(
    adata_perturb.var.index, 'human', 'human', 'id', 'name')

# Same CYTOFUTURE_DATA_DIR override as the loading cell; default is unchanged.
output_dir = os.path.join(
    os.environ.get('CYTOFUTURE_DATA_DIR', '/home/xingjie/Data/data2/cytofuture'),
    'datasets', 'perturbation_data')
os.makedirs(output_dir, exist_ok=True)
adata_perturb.write_h5ad(os.path.join(output_dir, f'{dataset_id}.h5ad'))

# Plot only measured entries: unmeasured genes are 0 on both sides, so their
# shift is exactly 0 and would add a spurious spike at 0.
fig, ax = plt.subplots()
ax.hist(X_shift[X_measure_mask.astype(bool)], bins=100)
ax.set(title=f'{dataset_id}: expression shift of measured genes',
       xlabel='log1p-normalized perturbed - control', ylabel='count')
adata_perturb