In [None]:
# Reload imported modules automatically before code is executed, so edits to
# local modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Ask the inline matplotlib backend for 'retina' (high-resolution) figures.
%config InlineBackend.figure_format='retina'

In [None]:
import os
import json

from tqdm import tqdm
import numpy as np
import scipy.spatial
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import anndata
import scanpy as sc

In [None]:
import scipy.stats
import statsmodels.stats.multitest

def adjust_p_value_matrix_by_BH(p_val_mtx):
    '''Adjust a matrix of p-values with the Benjamini/Hochberg FDR procedure.

    All entries are pooled into a single family of tests. Nothing here
    requires or exploits symmetry: for a symmetric matrix every off-diagonal
    test is simply counted twice.

    Parameters
    ----------
    p_val_mtx : array-like
        p-values of any shape (numpy array, DataFrame, nested list, ...).
        NaN entries mark tests that were not performed.

    Returns
    -------
    numpy.ndarray
        Adjusted p-values with the same shape as the input. NaN entries stay
        NaN and are excluded from the number of tests.
    '''
    p_values = np.asarray(p_val_mtx, dtype=float)
    adjusted = np.full(p_values.shape, np.nan)

    # The BH step-up uses a cumulative minimum over the sorted p-values, so a
    # single NaN would turn every adjusted value into NaN. Correct only the
    # entries that hold a real p-value.
    is_tested = ~np.isnan(p_values)
    if is_tested.any():
        adjusted[is_tested] = statsmodels.stats.multitest.multipletests(
            p_values[is_tested], method='fdr_bh')[1]

    return adjusted

In [None]:
# Read the per-cell table (first CSV column becomes the index) and report its size
obs_df = pd.read_csv('adata_obs_l2.csv', index_col=0)
print(obs_df.shape[0])

# Discard cells that have 20 UMIs or fewer
has_enough_umis = obs_df['num_umis'].gt(20)
obs_df = obs_df.loc[has_enough_umis]
obs_df

In [None]:
# Drop guides that are seen in 20 cells or fewer
pg_counts = obs_df['feature_call'].value_counts()
well_covered_guides = pg_counts.loc[pg_counts.gt(20)].index
is_well_covered = obs_df['feature_call'].isin(well_covered_guides)
obs_df = obs_df.loc[is_well_covered].copy()
obs_df

In [None]:
# Manual annotation of the clusters: each key is a value of obs_df['cluster']
# and each value is the class label assigned to that cluster. Many clusters
# share one label, so this mapping collapses the clusters into a smaller set
# of annotation classes.
# NOTE(review): the keys look like '<cluster>_<subcluster>' ids -- confirm
# against the clustering step that produced obs_df['cluster'].
class_anno_map = {
'0_0' : 'non-targeting enriched',
'10_0' : 'non-targeting like',
'11_0' : 'upregulation of cholesterol biosynthesis',
'12_0' : 'upregulation of stress response',
'13_0' : 'non-targeting enriched',
'14_0' : 'pert cell cycle',
'14_1' : 'pert spliceosome',
'14_2' : 'pert mRNA-3 processing',
'14_3' : 'pert mRNA transcription',
'14_4' : 'pert mRNA transcription',
'14_5' : 'pert mRNA transcription',
'15_0' : 'germ layer differentiation',
'15_1' : 'germ layer differentiation',
'15_10' : 'germ layer differentiation',
'15_11' : 'germ layer differentiation',
'15_12' : 'germ layer differentiation',
'15_13' : 'germ layer differentiation',
'15_14' : 'mesenchymal differentiation',
'15_2' : 'germ layer differentiation',
'15_3' : 'germ layer differentiation',
'15_4' : 'germ layer differentiation',
'15_5' : 'non-targeting enriched',
'15_6' : 'germ layer differentiation',
'15_7' : 'germ layer differentiation',
'15_8' : 'germ layer differentiation',
'15_9' : 'germ layer differentiation',
'16_0' : 'low UMI count',
'16_1' : 'low UMI count',
'16_2' : 'low UMI count',
'16_3' : 'low UMI count',
'16_4' : 'low UMI count',
'16_5' : 'low UMI count',
'16_6' : 'low UMI count',
'16_7' : 'pert DBR1',
'17_0' : 'non-targeting enriched',
'18_0' : 'pert translation',
'18_1' : 'pert mTOR signaling',
'18_10' : 'pert translation',
'18_11' : 'pert translation',
'18_12' : 'pert translation',
'18_2' : 'pert translation',
'18_3' : 'pert translation',
'18_4' : 'pert translation',
'18_5' : 'pert mTOR signaling',
'18_6' : 'pert translation',
'18_7' : 'pert translation',
'18_8' : 'pert translation',
'18_9' : 'pert translation',
'19_0' : 'non-targeting enriched',
'19_1' : 'non-targeting like',
'1_0' : 'non-targeting enriched',
'20_0' : 'pert mRNA transcription',
'20_1' : 'pert mRNA transcription',
'20_10' : 'pert mRNA deadenylation',
'20_11' : 'pert mRNA transcription',
'20_2' : 'pert GNB2L1',
'20_3' : 'pert mRNA deadenylation',
'20_4' : 'pert mRNA transcription',
'20_5' : 'pert mRNA transcription',
'20_6' : 'pert mRNA transcription',
'20_7' : 'pert mRNA transcription',
'20_8' : 'pert mRNA transcription',
'20_9' : 'pert mRNA transcription',
'21_0' : 'non-targeting enriched',
'21_1' : 'pert DBR1',
'22_0' : 'non-targeting enriched',
'23_0' : 'pert ubiquitin E3 ligase',
'23_1' : 'pert protein neddylation',
'24_0' : 'low mito-genes',
'24_1' : 'low mito-genes',
'24_2' : 'upregulation of stress response',
'25_0' : 'mesenchymal differentiation',
'25_1' : 'mesenchymal differentiation',
'25_2' : 'mesenchymal differentiation',
'25_3' : 'low UMI count',
'25_4' : 'mesenchymal differentiation',
'26_0' : 'pert DBR1',
'27_0' : 'pert RNA methylation',
'28_0' : 'pert DNA damage checkpoint',
'28_1' : 'pert DNA damage checkpoint',
'2_0' : 'non-targeting enriched',
'3_0' : 'non-targeting enriched',
'4_0' : 'non-targeting enriched',
'5_0' : 'non-targeting enriched',
'6_0' : 'non-targeting enriched',
'7_0' : 'non-targeting enriched',
'8_0' : 'non-targeting enriched',
'9_0' : 'non-targeting enriched',
'9_1' : 'non-targeting enriched',
}

# Attach the class label to every cell. Series.map leaves NaN for any cluster
# value that is not a key of class_anno_map.
obs_df['class_anno'] = obs_df['cluster'].map(class_anno_map)

In [None]:
# Test every (annotation class, guide) pair for over-/under-representation with
# a chi-square test on the 2x2 table "cell is in the class" x "cell carries the
# guide". Results are stored as class x guide matrices and written to parquet.
output_path = 'enrichment_results'
os.makedirs(output_path, exist_ok=True)

all_clusters = np.unique(obs_df['class_anno'])
all_guides = np.unique(obs_df['feature_call'])
n_cells = len(obs_df)

# Count the cells of every (class, guide) pair in one pass. Together with the
# class and guide totals these counts determine all four cells of each 2x2
# table, so the full cell table does not have to be re-tabulated per pair.
positive_count_df = (
    pd.crosstab(obs_df['class_anno'], obs_df['feature_call'])
    .reindex(index=all_clusters, columns=all_guides, fill_value=0)
    .rename_axis(index=None, columns=None)
)
cluster_sizes = (obs_df['class_anno'].value_counts()
                 .reindex(all_clusters, fill_value=0).to_numpy())
guide_sizes = (obs_df['feature_call'].value_counts()
               .reindex(all_guides, fill_value=0).to_numpy())
joint_counts = positive_count_df.to_numpy()

# NaN marks pairs for which no test could be performed.
log2fc = np.full(joint_counts.shape, np.nan)
pval = np.full(joint_counts.shape, np.nan)

for i, cluster_of_interest in enumerate(tqdm(all_clusters)):
    n_in_cluster = cluster_sizes[i]

    for j, pg_of_interest in enumerate(all_guides):
        n_with_guide = guide_sizes[j]

        # If the class or the guide covers none or all of the cells, the table
        # is not 2x2 and the test is undefined. Leave log2fc/pval as NaN; the
        # positive count stays available because it is still well defined.
        if n_in_cluster in (0, n_cells) or n_with_guide in (0, n_cells):
            print(f'Skipping {cluster_of_interest} x {pg_of_interest}: '
                  'contingency table is not 2x2')
            continue

        n_both = joint_counts[i, j]
        # Rows: outside / inside the class. Columns: without / with the guide.
        contingency_table = np.array([
            [n_cells - n_in_cluster - n_with_guide + n_both, n_with_guide - n_both],
            [n_in_cluster - n_both, n_both],
        ])

        # Tuple unpacking works for both the legacy tuple and the newer result
        # object returned by scipy.
        _, p_value, _, expected_freq = scipy.stats.chi2_contingency(contingency_table)
        pval[i, j] = p_value
        # Observed / expected count of cells in the class that carry the guide;
        # the small constant keeps the logarithm finite when the count is zero.
        log2fc[i, j] = np.log2(n_both / expected_freq[1, 1] + 1e-6)

log2fc_df = pd.DataFrame(log2fc, index=all_clusters, columns=all_guides)
pval_df = pd.DataFrame(pval, index=all_clusters, columns=all_guides)

positive_count_df.to_parquet(os.path.join(output_path, 'class_anno_enrich_guide_positive_count.parquet'))
log2fc_df.to_parquet(os.path.join(output_path, 'class_anno_enrich_guide_log2fc.parquet'))
pval_df.to_parquet(os.path.join(output_path, 'class_anno_enrich_guide_pval.parquet'))

In [None]:
# Flatten the class x guide result matrices into one long table with a row per
# (class, guide) pair, add the perturbed gene of each guide and the
# Benjamini/Hochberg-adjusted p-value, then save it.
guide_to_gene_map = dict(zip(obs_df['feature_call'], obs_df['perturbed_gene']))

clusters = positive_count_df.index
guides = positive_count_df.columns

# Row order is class-major / guide-minor, which is what a row-major ravel of
# the matrices yields. log2fc and pval are aligned by label first.
cep_dict = {
    'class_anno': np.repeat(clusters.to_numpy(), len(guides)),
    'feature_call': np.tile(guides.to_numpy(), len(clusters)),
    'perturbed_gene': [guide_to_gene_map[pg] for pg in guides] * len(clusters),
    'positive_count': positive_count_df.to_numpy().ravel(),
    'log2fc': log2fc_df.loc[clusters, guides].to_numpy().ravel(),
    'pval': pval_df.loc[clusters, guides].to_numpy().ravel(),
}

cep_df = pd.DataFrame(cep_dict)

# Pairs that could not be tested carry a NaN p-value. The BH step-up takes a
# cumulative minimum over the sorted p-values, so one NaN would turn every
# adjusted value into NaN. Correct only the pairs that were actually tested.
raw_pvals = cep_df['pval'].to_numpy(dtype=float)
is_tested = ~np.isnan(raw_pvals)
pval_adj = np.full(raw_pvals.shape, np.nan)
if is_tested.any():
    pval_adj[is_tested] = statsmodels.stats.multitest.multipletests(
        raw_pvals[is_tested], method='fdr_bh')[1]
cep_df['pval_adj'] = pval_adj

cep_df.to_parquet(os.path.join(output_path, 'class_anno_enrich_guide.parquet'))

cep_df