In [None]:
import os
import json

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats
from tqdm.auto import tqdm

import anndata
import scanpy as sc

from cytofuture_data.gene_name_mapping import GeneNameMapper

In [None]:
# Build the gene name mapper from the four gene-name CSV files below.
# The files are passed positionally, in this order.
gene_name_files = [
    '../../standard_genes/gene_names/human_genes.csv',
    '../../standard_genes/gene_names/mouse_genes.csv',
    '../../standard_genes/gene_names/orthologue_map_human2mouse_best.csv',
    '../../standard_genes/gene_names/orthologue_map_mouse2human_best.csv',
]
gene_name_mapper = GeneNameMapper(*gene_name_files)

In [None]:
# Root directory of the perturb-seq datasets and the identifier used for
# everything exported from this notebook.
data_path = '/home/xingjie/Data/data2/cytofuture/datasets/perturb-seq'
dataset_id = 'JiangSatija2024_INS'

# Build the input path from `data_path` so the directory is defined in one place.
adata = sc.read_h5ad(os.path.join(data_path, 'raw_data', 'JiangSatija2024',
                                  'Seurat_object_INS_Perturb_seq.h5ad'))

# Continue with the AnnData stored in `.raw`
# (presumably the unnormalized counts of all genes -- TODO confirm).
adata = adata.raw.to_adata()
adata

In [None]:
# Label every cell with "<cell_type>_<gene>". The concatenation is vectorized
# because iterating over the rows of `obs` is very slow for large datasets.
adata.obs['condition'] = (adata.obs['cell_type'].astype(str) + '_'
                          + adata.obs['gene'].astype(str))

# Set the gene IDs to Ensembl IDs
adata.var['ensembl_id'] = gene_name_mapper.map_gene_names(adata.var.index, 'human', 'human', 'name', 'id')
adata.var.index = adata.var['ensembl_id']
# Keep only genes whose ID starts with 'ENSG'. `na=False` treats missing IDs
# as "not matched" so the mask stays strictly boolean.
adata = adata[:, adata.var.index.str.startswith('ENSG', na=False)].copy()

# Scale every cell to a total of 1e4 (in place)
sc.pp.normalize_total(adata, target_sum=1e4)
adata

In [None]:
# List every condition together with its number of cells, tab-separated
condition_col = 'condition'
condition_cell_counts = adata.obs[condition_col].value_counts()
for cond_name, cell_count in condition_cell_counts.items():
    print(f'{cond_name}\t{cell_count}')

In [None]:
# Get the data frame of perturbations
n_cell_threshold = 50
perturbation_dict = {
    'condition' : [],
    'perturbed_gene' : [],
    'control_condition': [],
}

# Look up the cell type and the targeted gene of every condition from the obs
# columns instead of parsing the condition string, so that names containing
# '_' are handled correctly.
condition_info = (
    adata.obs[[condition_col, 'cell_type', 'gene']]
    .astype(str)
    .drop_duplicates(subset=condition_col)
    .set_index(condition_col)
)
# Condition label of every cell, computed once and reused for all masks
condition_labels = adata.obs[condition_col].astype(str).values

for c in condition_cell_counts.index:
    # Ignore conditions with too few cells
    n_cells = int(condition_cell_counts.loc[c])
    if n_cells < n_cell_threshold:
        continue

    cell_type = condition_info.loc[c, 'cell_type']
    target_gene_name = condition_info.loc[c, 'gene']

    perturbed_genes = gene_name_mapper.map_gene_names([target_gene_name],
                                'human', 'human', 'name', 'id')
    perturbed_genes = np.intersect1d(perturbed_genes, adata.var.index)

    # Only consider single gene perturbations
    if len(perturbed_genes) != 1:
        continue
    pg = perturbed_genes[0]
    control_condition = cell_type + '_NT'

    # Without control cells neither the fold-change nor the test is defined
    control_mask = condition_labels == control_condition
    if not control_mask.any():
        continue

    control_exps = adata[control_mask, pg].X.toarray().reshape(-1)
    perturbed_exps = adata[condition_labels == c, pg].X.toarray().reshape(-1)

    # A gene that is not expressed in the control cannot show a knockdown;
    # skipping it also avoids a division by zero below.
    control_mean = np.mean(control_exps)
    if control_mean <= 0:
        continue

    # Calculate the fold-change and p-val for the perturbed gene
    fc = np.mean(perturbed_exps) / control_mean
    pval = scipy.stats.mannwhitneyu(control_exps, perturbed_exps, alternative='two-sided')[1]

    # Only keep the perturbations that are effective
    if (fc < 0.8) and (pval < 1e-2):
        perturbation_dict['condition'].append(c)
        perturbation_dict['perturbed_gene'].append(pg)
        perturbation_dict['control_condition'].append(control_condition)

        print(f'{c}\t{n_cells}\t{pg}\t{fc:.2f}\t{pval:.2e}')

perturbation_df = pd.DataFrame(perturbation_dict).set_index('condition')
perturbation_df

In [None]:
def get_cluster_mean_expression_matrix(adata, cluster_column):
    '''Get a dataframe of mean gene expression of each cluster.

    The full expression matrix is converted to a dense dataframe, so the
    memory use grows with n_cells x n_genes.

    Args:
        adata: AnnData-like object providing `.X` (sparse or dense),
            `.obs` and `.var`.
        cluster_column: Name of the column in `adata.obs` holding the
            cluster label of each cell.

    Returns:
        A dataframe with one row per cluster that contains at least one cell
        and one column per entry of `adata.var.index`.
    '''
    X = adata.X
    # Sparse matrices expose `toarray`; anything else is treated as dense.
    dense_X = X.toarray() if hasattr(X, 'toarray') else np.asarray(X)
    cell_exp_mtx = pd.DataFrame(dense_X, index=adata.obs[cluster_column],
                                columns=adata.var.index)
    # Group on the index level (not by name, which could clash with a gene
    # column). `observed=True` drops unused categories of a categorical
    # cluster column instead of emitting all-NaN rows for them.
    cluster_mean_exp = cell_exp_mtx.groupby(level=0, observed=True).mean()
    return cluster_mean_exp

def get_cluster_mean_expression_matrix_low_mem(adata, cluster_column):
    '''Get a dataframe of mean gene expression of each cluster.

    The clusters are processed one at a time, so only the rows of a single
    cluster are sliced out of `adata` at any moment.

    Args:
        adata: AnnData-like object providing `.X`, `.obs`, `.var`, `.shape`
            and boolean row indexing.
        cluster_column: Name of the column in `adata.obs` holding the
            cluster label of each cell.

    Returns:
        A float32 dataframe with one row per unique cluster label and one
        column per entry of `adata.var.index`.
    '''
    # Materialize the labels once; they are compared against every cluster.
    cluster_labels = np.asarray(adata.obs[cluster_column].values)
    cluster_names = np.unique(cluster_labels)
    cluster_mean_df = pd.DataFrame(np.zeros((len(cluster_names), adata.shape[1]), dtype=np.float32),
                                   index=cluster_names, columns=adata.var.index)

    for c in tqdm(cluster_names):
        X_c = adata[cluster_labels == c].X
        # The mean of a sparse matrix is a (1, n_genes) np.matrix and the mean
        # of a dense array is 1-D; flatten both to a 1-D row and cast it to
        # the dtype of the output frame so the assignment never upcasts.
        cluster_mean_df.loc[c] = np.asarray(X_c.mean(axis=0)).ravel().astype(np.float32)

    return cluster_mean_df

In [None]:
# Average the expression of every retained perturbation and of the controls
# they are compared against.
conditions_to_keep = (perturbation_df.index.tolist()
                      + perturbation_df['control_condition'].tolist())
selected_conditions = np.unique(conditions_to_keep)

is_selected_cell = adata.obs[condition_col].isin(selected_conditions)
adata_selected = adata[is_selected_cell]

condition_mean_exp_df = get_cluster_mean_expression_matrix_low_mem(
    adata_selected, condition_col)
condition_mean_exp_df


In [None]:
# Standard gene list that defines the columns (var) of the exported data
standard_gene_df = pd.read_csv('standard_genes.csv').set_index('human_id')

# Fail early with a clear message instead of an opaque axis error below
assert len(perturbation_df) > 0, 'No effective perturbations were selected.'

perturbed_conditions = list(perturbation_df.index)
control_conditions = list(perturbation_df['control_condition'])

# One obs row per exported perturbation
obs_df = pd.DataFrame({
    'id' : [f'{dataset_id}_{c}' for c in perturbed_conditions],
    # NOTE(review): this column holds the dataset ID, not the condition label
    # (which is only part of 'id') -- confirm that this is the intended schema.
    'condition' : dataset_id,
    'perturbed_gene' : list(perturbation_df['perturbed_gene']),
    'perturbation_sign' : -1, # -1 for knockdown, 1 for overexpression
}).set_index('id')

# Align the mean expression of every condition to the standard gene list.
# Standard genes that are absent from this dataset become NaN. Working on a
# re-indexed copy leaves `standard_gene_df` untouched.
aligned_mean_exp_df = condition_mean_exp_df.reindex(columns=standard_gene_df.index)
perturb_means = aligned_mean_exp_df.loc[perturbed_conditions].values
control_means = aligned_mean_exp_df.loc[control_conditions].values

# 1 where a standard gene has a value in this dataset, 0 otherwise
X_measure_mask = (~pd.isna(perturb_means)).astype(np.float32)

# Missing genes count as zero expression
X_perturb = np.nan_to_num(perturb_means, nan=0).astype(np.float32)
X_control = np.nan_to_num(control_means, nan=0).astype(np.float32)

# Re-normalize each profile to a total of 1e4 over the standard genes,
# log-transform, and take the perturbed - control difference.
X_perturb = np.log1p(X_perturb / X_perturb.sum(axis=1, keepdims=True) * 1e4)
X_control = np.log1p(X_control / X_control.sum(axis=1, keepdims=True) * 1e4)
X_shift = X_perturb - X_control

# Create the perturbation anndata
adata_perturb = anndata.AnnData(
    X=X_shift,
    obs=obs_df,
    var=standard_gene_df[[]].copy(),  # keep only the gene index, no columns
)
adata_perturb.layers['measure_mask'] = X_measure_mask
adata_perturb.layers['control'] = X_control

# Add gene names next to the gene IDs
adata_perturb.obs['perturbed_gene_name'] = gene_name_mapper.map_gene_names(
    adata_perturb.obs['perturbed_gene'], 'human', 'human', 'id', 'name')
adata_perturb.var['gene_name'] = gene_name_mapper.map_gene_names(
    adata_perturb.var.index, 'human', 'human', 'id', 'name')

adata_perturb.write_h5ad(os.path.join(
    '/home/xingjie/Data/data2/cytofuture/datasets/perturbation_data',
    f'{dataset_id}.h5ad'))

# Distribution of all expression shifts as a quick sanity check
fig, ax = plt.subplots()
ax.hist(adata_perturb.X.flatten(), bins=100)
ax.set(xlabel='log1p expression shift (perturbed - control)',
       ylabel='Number of entries')
adata_perturb