In [None]:
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format='retina'

In [None]:
import os
import json

from tqdm import tqdm
import numpy as np
import scipy.spatial
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import anndata
import scanpy as sc

In [None]:
# Manual annotation for every single-cell cluster id. The same keys are later
# looked up with the values of adata.obs['cluster'] (via cluster_annotation_map).
# Insertion order is kept exactly as curated.
class_anno_map = {
    '0_0': 'non-targeting enriched',
    '10_0': 'non-targeting like',
    '11_0': 'upregulation of lipid biosynthesis',
    '12_0': 'upregulation of stress response',
    '13_0': 'non-targeting enriched',
    '14_0': 'pert cell cycle',
    '14_1': 'pert spliceosome',
    '14_2': 'pert mRNA-3 processing',
    '14_3': 'pert mRNA transcription',
    '14_4': 'pert mRNA transcription',
    '14_5': 'pert mRNA transcription',
    '15_0': 'germ layer differentiation',
    '15_1': 'germ layer differentiation',
    '15_10': 'germ layer differentiation',
    '15_11': 'germ layer differentiation',
    '15_12': 'germ layer differentiation',
    '15_13': 'germ layer differentiation',
    '15_14': 'mesenchymal differentiation',
    '15_2': 'germ layer differentiation',
    '15_3': 'germ layer differentiation',
    '15_4': 'germ layer differentiation',
    '15_5': 'non-targeting enriched',
    '15_6': 'germ layer differentiation',
    '15_7': 'germ layer differentiation',
    '15_8': 'germ layer differentiation',
    '15_9': 'germ layer differentiation',
    '16_0': 'low UMI count',
    '16_1': 'low UMI count',
    '16_2': 'low UMI count',
    '16_3': 'low UMI count',
    '16_4': 'low UMI count',
    '16_5': 'low UMI count',
    '16_6': 'low UMI count',
    '16_7': 'pert DBR1',
    '17_0': 'non-targeting enriched',
    '18_0': 'pert translation',
    '18_1': 'pert mTOR signaling',
    '18_10': 'pert translation',
    '18_11': 'pert translation',
    '18_12': 'pert translation',
    '18_2': 'pert translation',
    '18_3': 'pert translation',
    '18_4': 'pert translation',
    '18_5': 'pert mTOR signaling',
    '18_6': 'pert translation',
    '18_7': 'pert translation',
    '18_8': 'pert translation',
    '18_9': 'pert translation',
    '19_0': 'non-targeting enriched',
    '19_1': 'non-targeting like',
    '1_0': 'non-targeting enriched',
    '20_0': 'pert mRNA transcription',
    '20_1': 'pert mRNA transcription',
    '20_10': 'pert mRNA deadenylation',
    '20_11': 'pert mRNA transcription',
    '20_2': 'pert GNB2L1',
    '20_3': 'pert mRNA deadenylation',
    '20_4': 'pert mRNA transcription',
    '20_5': 'pert mRNA transcription',
    '20_6': 'pert mRNA transcription',
    '20_7': 'pert mRNA transcription',
    '20_8': 'pert mRNA transcription',
    '20_9': 'pert mRNA transcription',
    '21_0': 'non-targeting enriched',
    '21_1': 'pert DBR1',
    '22_0': 'non-targeting enriched',
    '23_0': 'pert ubiquitin E3 ligase',
    '23_1': 'pert protein neddylation',
    '24_0': 'low mito-genes',
    '24_1': 'low mito-genes',
    '24_2': 'upregulation of stress response',
    '25_0': 'mesenchymal differentiation',
    '25_1': 'mesenchymal differentiation',
    '25_2': 'mesenchymal differentiation',
    '25_3': 'low UMI count',
    '25_4': 'mesenchymal differentiation',
    '26_0': 'pert DBR1',
    '27_0': 'pert RNA methylation',
    '28_0': 'pert DNA damage checkpoint',
    '28_1': 'pert DNA damage checkpoint',
    '2_0': 'non-targeting enriched',
    '3_0': 'non-targeting enriched',
    '4_0': 'non-targeting enriched',
    '5_0': 'non-targeting enriched',
    '6_0': 'non-targeting enriched',
    '7_0': 'non-targeting enriched',
    '8_0': 'non-targeting enriched',
    '9_0': 'non-targeting enriched',
    '9_1': 'non-targeting enriched',
}

# Display name per cluster id: '<annotation>_<cluster id>'. Appending the id keeps
# names unique even though many clusters share the same annotation text.
cluster_annotation_map = {
    cluster_id: f'{annotation}_{cluster_id}'
    for cluster_id, annotation in class_anno_map.items()
}

In [None]:
# Load the AnnData object from an absolute path on the analysis machine.
# NOTE(review): the path is machine-specific; the notebook will not run elsewhere without editing it.
adata_path = '/GPUData_xingjie/SCMG/hESC_perturb_seq/adata_single_gene_pert.h5ad'
adata = sc.read_h5ad(adata_path)
adata

In [None]:
# Per-cell table read from CSV; the first CSV column becomes the index.
l2_obs_df = pd.read_csv('adata_obs_l2.csv', index_col=0)

# Assigning a Series into adata.obs aligns on the index, so labels are matched
# to cells by index label rather than by row position. Labels are stored as strings.
adata.obs['leiden_l1'] = l2_obs_df['leiden_l1'].astype(str)
adata.obs['cluster'] = l2_obs_df['cluster'].astype(str)

# obsm arrays are positional, so the CSV rows are first put into adata's cell
# order; .loc raises KeyError if a cell is missing from the CSV.
umap_coords = l2_obs_df.loc[adata.obs.index, ['umap_x', 'umap_y']]
adata.obsm['X_umap'] = umap_coords.to_numpy()
adata

In [None]:
# Translate each cell's cluster id into its '<annotation>_<cluster id>' name.
# Ids that are not keys of cluster_annotation_map come out as NaN.
sc_cluster_names = adata.obs['cluster'].map(cluster_annotation_map)
adata.obs['sc_cluster_name'] = sc_cluster_names
adata

In [None]:
# Table of perturbed-gene clusters read from CSV; the first CSV column becomes the index.
bulk_cluster_df = pd.read_csv(
    '../pseudo_bulk_analysis/clustering/perturbed_gene_clusters_hESC.csv',
    index_col=0,
)

# Lookup: perturbed gene name -> leiden cluster label.
# If a gene name occurs more than once, the last row wins.
bulk_cluster_map = dict(
    zip(bulk_cluster_df['perturbed_gene_name'], bulk_cluster_df['leiden'])
)
bulk_cluster_df

In [None]:
# Give every cell the bulk cluster of its perturbed gene.
# Genes that are not keys of bulk_cluster_map come out as NaN.
bulk_cluster_names = adata.obs['perturbed_gene'].map(bulk_cluster_map)
adata.obs['bulk_cluster_name'] = bulk_cluster_names
adata

In [None]:
# Show the per-cell metadata table, now including the cluster columns added above.
adata.obs

In [None]:
# Keep only cells that received a bulk cluster assignment (non-NaN).
local_df = adata.obs[adata.obs['bulk_cluster_name'].notna()]

# Contingency table of cell counts: rows are bulk cluster ids, columns are
# single-cell cluster names. The int cast relies on the NaN rows being dropped above.
conf_df = pd.crosstab(local_df['bulk_cluster_name'].astype(int), local_df['sc_cluster_name'])

# Row-normalised table: every bulk-cluster row sums to 1.
# The previous column-normalised assignment to this same name was overwritten
# before any use, so that dead store has been removed.
conf_frac_df = conf_df.div(conf_df.sum(axis=1), axis=0)

In [None]:
# Grand total of cells in the contingency table, computed once and reused.
n_cells_total = conf_df.sum().sum()

# Marginal probabilities of each bulk cluster (rows) and single-cell cluster (columns).
row_probs = conf_df.sum(axis=1) / n_cells_total
col_probs = conf_df.sum(axis=0) / n_cells_total

# Expected counts if row and column membership were independent.
# Despite the name, exp_df is a plain 2-D ndarray (np.outer output) shaped like conf_df.
exp_df = np.outer(row_probs, col_probs) * n_cells_total

# Observed / expected; 1 means the observed count equals the independence expectation.
enrichment_df = conf_df / exp_df

In [None]:
# Curated row order for the enrichment heatmap: a hand-picked permutation of
# the bulk cluster ids 0..34.
bulk_cluster_order = [
    0, 30, 28, 29, 7, 6, 21, 22, 23, 11, 24, 17, 34, 32,
    27, 4, 12, 25, 14, 2, 15, 26, 33, 13, 3, 16, 18, 8, 5,
    1, 9, 10,
    19, 20, 31,
]

# Curated column order for the enrichment heatmap. Every entry is a
# '<annotation>_<cluster id>' name, listed one per line in display order.
sc_cluster_order = [
    'germ layer differentiation_15_0',
    'germ layer differentiation_15_1',
    'germ layer differentiation_15_10',
    'germ layer differentiation_15_11',
    'germ layer differentiation_15_12',
    'germ layer differentiation_15_13',
    'germ layer differentiation_15_2',
    'germ layer differentiation_15_3',
    'germ layer differentiation_15_4',
    'germ layer differentiation_15_6',
    'germ layer differentiation_15_7',
    'germ layer differentiation_15_8',
    'germ layer differentiation_15_9',
    'low UMI count_16_6',
    'non-targeting enriched_15_5',

    'upregulation of lipid biosynthesis_11_0',
    'upregulation of stress response_12_0',

    'low UMI count_16_0',
    'low UMI count_16_1',
    'low UMI count_16_2',
    'low UMI count_16_3',
    'low UMI count_16_4',
    'low UMI count_16_5',
    'upregulation of stress response_24_2',
    'low mito-genes_24_0',
    'low mito-genes_24_1',
    'low UMI count_25_3',
    'mesenchymal differentiation_15_14',
    'mesenchymal differentiation_25_0',
    'mesenchymal differentiation_25_1',
    'mesenchymal differentiation_25_2',
    'mesenchymal differentiation_25_4',

    'pert mTOR signaling_18_5',
    'pert mTOR signaling_18_1',
    'pert translation_18_0',
    'pert translation_18_2',
    'pert translation_18_12',
    'pert translation_18_4',
    'pert translation_18_6',
    'pert translation_18_7',
    'pert translation_18_9',

    'pert translation_18_3',
    'pert translation_18_8',
    'pert translation_18_10',
    'pert mRNA transcription_20_11',
    'pert mRNA transcription_20_4',
    'pert mRNA transcription_20_6',
    'pert mRNA transcription_20_7',
    'pert GNB2L1_20_2',
    'pert mRNA transcription_20_9',

    'pert mRNA transcription_20_0',
    'pert mRNA transcription_20_1',
    'pert mRNA transcription_20_8',
    'pert mRNA transcription_14_5',
    'pert mRNA transcription_20_5',

    'pert mRNA transcription_14_3',
    'pert mRNA transcription_14_4',
    'pert spliceosome_14_1',

    'pert RNA methylation_27_0',
    'pert mRNA-3 processing_14_2',

    'pert mRNA deadenylation_20_10',
    'pert mRNA deadenylation_20_3',
    'pert DBR1_16_7',
    'pert DBR1_21_1',
    'pert DBR1_26_0',
    'pert DNA damage checkpoint_28_0',
    'pert DNA damage checkpoint_28_1',

    'pert protein neddylation_23_1',
    'pert ubiquitin E3 ligase_23_0',

    'pert cell cycle_14_0',
    'pert translation_18_11',

    'non-targeting like_19_1',
    'non-targeting like_10_0',
    'non-targeting enriched_0_0',
    'non-targeting enriched_13_0',
    'non-targeting enriched_17_0',
    'non-targeting enriched_19_0',
    'non-targeting enriched_1_0',
    'non-targeting enriched_21_0',
    'non-targeting enriched_22_0',
    'non-targeting enriched_2_0',
    'non-targeting enriched_3_0',
    'non-targeting enriched_4_0',
    'non-targeting enriched_5_0',
    'non-targeting enriched_6_0',
    'non-targeting enriched_7_0',
    'non-targeting enriched_8_0',
    'non-targeting enriched_9_0',
    'non-targeting enriched_9_1',
]

# Put rows and columns into the curated display order. .loc with label lists
# raises KeyError if any listed label is absent from the table.
enrichment_df = enrichment_df.loc[bulk_cluster_order, sc_cluster_order]

In [None]:
# Heatmap of observed/expected enrichment between bulk clusters (rows) and
# single-cell clusters (columns), saved as a PDF.
plot_path = 'hesc_sc_analysis_plots/bulk_sc_clusters_enrichment.pdf'
# savefig does not create missing directories; make sure the target folder
# exists so a fresh checkout can run this cell.
os.makedirs(os.path.dirname(plot_path), exist_ok=True)

fig, ax = plt.subplots(figsize=(25, 10))
# center=1 puts the colormap midpoint at enrichment 1 (observed == expected);
# vmax=15 caps the colour scale.
g = sns.heatmap(enrichment_df, cmap='seismic', cbar_kws={'label': 'enrichment'}, ax=ax, vmax=15, center=1)
g.set_yticklabels(g.get_yticklabels(), rotation=0)
# bbox_inches='tight' keeps the long rotated x tick labels inside the saved page.
fig.savefig(plot_path, bbox_inches='tight')