In [None]:
import os
import warnings

from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import anndata
import scanpy as sc
import umap
import scipy.stats

import torch

from scmg.model.contrastive_embedding import (CellEmbedder, embed_adata)

from scmg.preprocessing.data_standardization import GeneNameMapper
# Gene-name mapping helper from the project's preprocessing module.
# NOTE(review): not referenced by any cell visible here — presumably kept for interactive
# symbol lookups; TODO confirm it is still needed.
gene_name_mapper = GeneNameMapper()


In [None]:
# Directory that receives every PDF figure saved by this notebook (created if missing).
plot_output_path = 'pathway_rewiring_plots/'
os.makedirs(name=plot_output_path, exist_ok=True)

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

plt.rcParams["figure.autolayout"] = False
matplotlib.rc('pdf', fonttype=42)
plt.rcParams['font.family'] = 'FreeSans'
sc.set_figure_params(vector_friendly=True, dpi_save=300)
plt.rcParams['axes.grid'] = False

In [None]:
# Perturbation datasets pooled for this analysis. Commented-out entries are other screens
# that are available but not used here; the hESC pseudo-bulk file is marked as the test set.
pert_data_files = [
#    '/GPUData_xingjie/SCMG/perturbation_data/AdamsonWeissman2016_GSM2406681_10X010.h5ad',
#    '/GPUData_xingjie/SCMG/perturbation_data/FrangiehIzar2021_RNA.h5ad',
#    '/GPUData_xingjie/SCMG/perturbation_data/hESC_TF_screen.h5ad',
#    '/GPUData_xingjie/SCMG/perturbation_data/JiangSatija2024_IFNB.h5ad',
#    '/GPUData_xingjie/SCMG/perturbation_data/JiangSatija2024_IFNG.h5ad',
#    '/GPUData_xingjie/SCMG/perturbation_data/JiangSatija2024_INS.h5ad',
#    '/GPUData_xingjie/SCMG/perturbation_data/JiangSatija2024_TGFB.h5ad',
#    '/GPUData_xingjie/SCMG/perturbation_data/JiangSatija2024_TNFA.h5ad',
#    '/GPUData_xingjie/SCMG/perturbation_data/Joung_TFScreen_HS_2023.h5ad',
#    '/GPUData_xingjie/SCMG/perturbation_data/knockTF_human.h5ad',
#    '/GPUData_xingjie/SCMG/perturbation_data/knockTF_mouse.h5ad',
#    #'/GPUData_xingjie/SCMG/perturbation_data/omnipath.h5ad',
#    '/GPUData_xingjie/SCMG/perturbation_data/PertOrg.h5ad',
    '/GPUData_xingjie/SCMG/perturbation_data/ReplogleWeissman2022_K562_essential.h5ad',
    '/GPUData_xingjie/SCMG/perturbation_data/ReplogleWeissman2022_K562_gwps.h5ad',
    '/GPUData_xingjie/SCMG/perturbation_data/ReplogleWeissman2022_rpe1.h5ad',
#    '/GPUData_xingjie/SCMG/perturbation_data/TianKampmann2021_CRISPRa.h5ad',
#    '/GPUData_xingjie/SCMG/perturbation_data/TianKampmann2021_CRISPRi.h5ad',
    '/GPUData_xingjie/SCMG/hESC_perturb_seq/pseudo_bulk_filtered.h5ad', # Test
]

# Load each dataset in turn and report how many rows (perturbation profiles) it contributes.
adata_pert_list = []
for data_file in pert_data_files:
    loaded_adata = sc.read_h5ad(data_file)
    adata_pert_list.append(loaded_adata)
    print(os.path.basename(data_file), loaded_adata.n_obs)

# Stack the datasets along observations.
# NOTE(review): anndata.concat does not carry .var columns over by default, hence the gene
# names are re-attached from the first dataset (pandas aligns them on the var index).
adata_pert = anndata.concat(adata_pert_list, axis=0)
adata_pert.var['gene_name'] = adata_pert_list[0].var['gene_name']

adata_pert

In [None]:
# Mask out the direct target genes: zero each perturbation's own target gene so the
# knockdown of the target itself is excluded from the summaries below.
gene_index = adata_pert.var_names
for row, target_gene in enumerate(adata_pert.obs['perturbed_gene']):
    if target_gene in gene_index:
        adata_pert.X[row, gene_index.get_loc(target_gene)] = 0

# Per-perturbation magnitude summaries over the remaining genes.
# NOTE(review): assumes adata_pert.X is a dense ndarray (np.linalg.norm is not sparse-aware).
adata_pert.obs['effect_size'] = np.linalg.norm(adata_pert.X, axis=1)
adata_pert.obs['max_effect'] = np.abs(adata_pert.X).max(axis=1)

In [None]:
# Gene -> expression-program annotation, indexed by gene ID.
gene_exp_cluster_df = pd.read_csv(
    '../../classify_genes/systematic_classification/gene_exp_cluster_annotation.csv',
    index_col=0,
)
# Look each measured gene up by ID; genes absent from the annotation table get NaN.
adata_pert.var['gene_exp_cluster'] = (
    gene_exp_cluster_df['cluster_name'].reindex(adata_pert.var_names).to_numpy()
)
gene_exp_cluster_counts = adata_pert.var['gene_exp_cluster'].value_counts()
gene_exp_cluster_counts

In [None]:
# Gene names annotated to the 'integrated stress response' program.
is_isr_gene = adata_pert.var['gene_exp_cluster'] == 'integrated stress response'
np.array(adata_pert.var.loc[is_isr_gene, 'gene_name'])

In [None]:
# Score every gene-expression program for every perturbation:
#     score = sqrt(n_genes) * mean(X over the program's genes)  ( = sum / sqrt(n_genes) ),
# a size-normalised sum (presumably so programs of different size share a scale).
adata_pert.uns['gene_program_names'] = list(gene_exp_cluster_df['cluster_name'].unique())
program_names = adata_pert.uns['gene_program_names']
program_scores = np.zeros((adata_pert.n_obs, len(program_names)))

for j, program in enumerate(program_names):
    # Use the var-aligned annotation rather than the raw CSV gene list. Genes from the CSV
    # that are not in the pooled gene space (anndata.concat keeps only shared genes by
    # default) are skipped instead of making adata_pert[:, genes] raise a KeyError.
    in_program = (adata_pert.var['gene_exp_cluster'] == program).to_numpy()
    n_genes = int(in_program.sum())
    if n_genes == 0:
        # No measured genes (this also covers a NaN program name): leave the column at zero.
        continue
    program_scores[:, j] = np.sqrt(n_genes) * np.mean(adata_pert.X[:, in_program], axis=1)

adata_pert.obsm['gene_program'] = program_scores

In [None]:
# Functional gene embedding (one row per gene name). Keep only perturbations whose target
# has an embedding, then attach that embedding row-wise as obsm['gene_func_emb'].
gene_func_emb_df = pd.read_parquet('gene_func_emb_no_hesc.parquet')
has_embedding = adata_pert.obs['perturbed_gene_name'].isin(gene_func_emb_df.index)
adata_pert = adata_pert[has_embedding].copy()
adata_pert.obsm['gene_func_emb'] = gene_func_emb_df.loc[adata_pert.obs['perturbed_gene_name']].to_numpy()

In [None]:
# Number of perturbation profiles per source dataset after the embedding filter above.
adata_pert.obs['condition'].value_counts()

In [None]:
# Split the pooled object into one independent copy per source dataset.
condition_labels = adata_pert.obs['condition']
adata_k562 = adata_pert[condition_labels == 'ReplogleWeissman2022_K562_gwps'].copy()
adata_k562_e = adata_pert[condition_labels == 'ReplogleWeissman2022_K562_essential'].copy()
adata_rpe1 = adata_pert[condition_labels == 'ReplogleWeissman2022_rpe1'].copy()
adata_hesc = adata_pert[condition_labels == 'hESC_perturb_seq'].copy()

# Perturbed genes present in both K562 screens (sorted, unique).
common_genes = np.intersect1d(adata_k562_e.obs['perturbed_gene'], adata_k562.obs['perturbed_gene'])
len(common_genes)

In [None]:
# Re-key observations by perturbed-gene ID so datasets can be matched by name.
adata_k562.obs_names = list(adata_k562.obs['perturbed_gene'])
adata_k562_e.obs_names = list(adata_k562_e.obs['perturbed_gene'])

# The two K562 screens are compared row-by-row after this subset (positional pairing).
# That pairing is only well-defined if each perturbed gene occurs once per screen.
assert adata_k562.obs_names.is_unique and adata_k562_e.obs_names.is_unique, \
    'perturbed_gene must be unique within each K562 screen'

adata_k562 = adata_k562[common_genes, :].copy()
adata_k562_e = adata_k562_e[common_genes, :].copy()

adata_rpe1.obs_names = list(adata_rpe1.obs['perturbed_gene'])
adata_hesc.obs_names = list(adata_hesc.obs['perturbed_gene'])

In [None]:
# Reproducibility of each shared perturbation between the two K562 screens: the Pearson
# correlation of its profile in gwps vs essential. Rows are paired by position, so verify
# first that both objects list the same perturbations in the same order.
assert adata_k562.obs_names.equals(adata_k562_e.obs_names), 'K562 screens are not row-aligned'

adata_k562.obs['k562_e_corr'] = [
    scipy.stats.pearsonr(adata_k562.X[i], adata_k562_e.X[i])[0]
    for i in range(adata_k562.shape[0])
]

ax = adata_k562.obs['k562_e_corr'].hist(bins=50)
ax.set_xlabel('Pearson r (K562 gwps vs essential)')
ax.set_ylabel('Number of perturbations');

In [None]:
# Keep only perturbations whose response is reproducible across the two K562 screens.
K562_E_CORR_MIN = 0.4  # strict lower bound on k562_e_corr
is_reproducible = (adata_k562.obs['k562_e_corr'] > K562_E_CORR_MIN).to_numpy()
high_k562_intra_corr_genes = adata_k562.obs_names[is_reproducible].tolist()
print(len(high_k562_intra_corr_genes))

adata_k562, adata_k562_e = (
    adata_k562[high_k562_intra_corr_genes, :].copy(),
    adata_k562_e[high_k562_intra_corr_genes, :].copy(),
)

In [None]:
# kNN graph on the functional gene embedding, then a 2-D layout and fine-grained
# (high-resolution) Leiden clustering, all with fixed seeds.
sc.pp.neighbors(
    adata_k562,
    use_rep='gene_func_emb',
    n_neighbors=5,
    metric='cosine',  # 'euclidean' was also tried
)
sc.tl.umap(adata_k562, random_state=0)
sc.tl.leiden(
    adata_k562,
    resolution=5,
    flavor="igraph",
    n_iterations=10,
    random_state=0,
)

In [None]:
# Cache of the clustered K562 object. The hand-written pg_name_map below is keyed by
# Leiden cluster IDs, which presumably belong to one specific clustering. So reuse the
# cached object when it exists (discarding the fresh result above). Otherwise write the
# fresh result, so a clean Restart & Run All no longer fails on a missing file.
function_cluster_cache = 'adata_function_cluster.h5ad'
if os.path.exists(function_cluster_cache):
    adata_k562 = sc.read_h5ad(function_cluster_cache)
else:
    adata_k562.write_h5ad(function_cluster_cache)
adata_k562

In [None]:
# Leiden cluster IDs drawn on the functional-embedding UMAP (used to annotate clusters).
fig, ax = plt.subplots(dpi=200, figsize=(4, 4))

sc.pl.umap(
    adata_k562,
    color='leiden',
    ax=ax,
    s=20,
    legend_loc='on data',
    legend_fontsize=10,
    legend_fontoutline=2,
    show=False,
)

In [None]:
# Perturbed genes in Leiden cluster '45' (inspected while naming clusters).
np.array(adata_k562.obs.loc[adata_k562.obs['leiden'] == '45', 'perturbed_gene_name'])

In [None]:
# Manual annotation of Leiden cluster IDs -> functional cluster names. Several IDs share a
# name; 'mixed functions' marks clusters without a coherent function.
pg_name_map = {
    '0' : 'ribosome large unit biogenesis',
    '1' : 'mixed functions',
    '2' : 'histone acetylation',
    '3' : 'rRNA processing 1',
    '4' : 'ribosome large unit protein',
    '5' : 'ribosome small unit biogenesis',
    '6' : 'proteasome',
    '7' : 'chromatin organization',
    '8' : 'DNA replication',
    '9' : 'mitochondrion organization',
    '10' : 'protein neddylation',
    '11' : 'mixed functions',
    '12' : 'pol II transcription',
    '13' : 'translation initiation',
    '14' : 'mixed functions',
    '15' : 'mixed functions',
    '16' : 'mediator complex',
    '17' : 'rRNA processing 2',
    '18' : 'protein processing in ER',
    '19' : 'pol II transcription',
    '20' : 'ribosome large unit biogenesis',
    '21' : 'DNA replication',
    '22' : 'RNA metabolism',
    '23' : 'mitochondrial translation',
    '24' : 'mixed functions',
    '25' : 'ribosome small unit biogenesis',
    '26' : 'spliceosome',
    '27' : 'transcription initiation',
    '28' : 'CCT complex',
    '29' : 'spliceosome',
    '30' : 'mitochondrion organization',
    '31' : 'rRNA processing 2',
    '32' : 'Recycling Of eIF2:GDP',
    '33' : 'mediator complex',
    '34' : 'membrane fission',
    '35' : 'histone acetylation',
    '36' : 'mitochondrial transcription',
    '37' : 'spliceosome',
    '38' : 'integrator complex',
    '39' : 'ribosome large unit protein',
    '40' : 'pol II enlongation',
    '41' : 'mRNA surveillance',
    '42' : 'RNA degradation',
    '43' : 'spliceosome',
    '44' : 'rRNA processing 1',
    '45' : 'ribosome small unit protein',
    '46' : 'RNA methylation',
    '47' : 'mitochondrial translation',
    '48' : 'spliceosome',
    '49' : 'mRNA Polyadenylation',
}

adata_k562.obs['cluster_name'] = adata_k562.obs['leiden'].map(pg_name_map)

# .map leaves unknown IDs as NaN. Fail loudly if the clustering produced IDs that were
# never annotated (e.g. the clustering was recomputed) instead of writing NaN labels.
unannotated_leiden_ids = sorted(set(adata_k562.obs['leiden'].astype(str)) - set(pg_name_map))
assert not unannotated_leiden_ids, f'Leiden clusters missing from pg_name_map: {unannotated_leiden_ids}'

adata_k562.write_h5ad('adata_function_cluster_annotated.h5ad')

In [None]:
# Named functional clusters on the same UMAP.
fig, ax = plt.subplots(dpi=200, figsize=(5, 5))

sc.pl.umap(
    adata_k562,
    color='cluster_name',
    title='gene function cluster',
    ax=ax,
    s=100,
    legend_loc='on data',
    legend_fontsize=8,
    legend_fontoutline=2,
    show=False,
)

In [None]:
# Keep only perturbations in a coherent functional cluster: drop the 'mixed functions'
# catch-all and any unannotated (NaN) rows. Re-embed them for the summary figure.
is_annotated = (
    adata_k562.obs['cluster_name'].notna()
    & (adata_k562.obs['cluster_name'] != 'mixed functions')
)
adata_anno = adata_k562[is_annotated].copy()

sc.pp.neighbors(adata_anno, n_neighbors=10, use_rep='gene_func_emb',
    metric='cosine',  # 'correlation' and 'euclidean' were also tried
)
sc.tl.umap(adata_anno, random_state=0)

fig, ax = plt.subplots(figsize=(8, 8), dpi=100)
# show=False, as in the other plot cells: otherwise scanpy calls plt.show() before the
# figure is saved below.
sc.pl.umap(adata_anno, color='cluster_name', legend_loc='on data', ax=ax,
           legend_fontsize=10, legend_fontoutline=2, show=False)
fig.savefig(f'{plot_output_path}/k562_gene_function_clusters_umap.pdf')

In [None]:
# Persist the perturbation -> functional-cluster annotation (index = perturbed-gene ID).
annotation_columns = ['perturbed_gene_name', 'cluster_name']
adata_anno.obs.loc[:, annotation_columns].to_csv('perturbation_cluster_annotation.csv')

In [None]:
# Perturbed-gene ID -> functional cluster name, for the annotated perturbations only.
pg_id_to_cluster_map = adata_anno.obs['cluster_name'].to_dict()

In [None]:
def _cluster_gene_program_means(adata, id_to_cluster, cluster_names):
    """Annotate ``adata`` with functional clusters and average gene-program scores per cluster.

    Args:
        adata: perturbation AnnData whose obs_names are perturbed-gene IDs and which carries
            ``obsm['gene_program']`` and ``uns['gene_program_names']``.
        id_to_cluster: perturbed-gene ID -> functional cluster name.
        cluster_names: clusters to report (rows of the result, in this order).

    Returns:
        Float DataFrame (clusters x gene programs) of mean program scores. A cluster with no
        member perturbation in ``adata`` gets an all-NaN row.

    Side effect: (re)writes ``adata.obs['cluster_name']`` (NaN for unannotated perturbations).
    """
    adata.obs['cluster_name'] = adata.obs.index.map(id_to_cluster)
    cluster_means = pd.DataFrame(
        index=cluster_names,
        columns=adata.uns['gene_program_names'],
        dtype=float,
    )
    for cluster in cluster_names:
        # Boolean mask rather than name-based subsetting: same rows, and it does not
        # require unique obs_names.
        in_cluster = (adata.obs['cluster_name'] == cluster).to_numpy()
        cluster_means.loc[cluster] = adata.obsm['gene_program'][in_cluster].mean(axis=0)
    return cluster_means


functional_cluster_names = list(adata_anno.obs['cluster_name'].unique())

k562_group_df = _cluster_gene_program_means(adata_k562, pg_id_to_cluster_map, functional_cluster_names)
rpe1_group_df = _cluster_gene_program_means(adata_rpe1, pg_id_to_cluster_map, functional_cluster_names)
hesc_group_df = _cluster_gene_program_means(adata_hesc, pg_id_to_cluster_map, functional_cluster_names)

In [None]:
# Row order for heatmaps: functional clusters grouped roughly from chromatin/transcription,
# through RNA processing and translation, to proteostasis and mitochondria.
pg_cluster_order = [
    'chromatin organization', 'DNA replication', 'histone acetylation', 'CCT complex',
    'mediator complex', 'transcription initiation', 'integrator complex', 'pol II transcription', 'pol II enlongation',
    'spliceosome', 'RNA metabolism', 'RNA degradation', 'RNA methylation', 'mRNA Polyadenylation', 'mRNA surveillance',
    'ribosome small unit biogenesis', 'ribosome small unit protein', 'rRNA processing 1', 'rRNA processing 2',
    'ribosome large unit biogenesis', 'ribosome large unit protein',
    'translation initiation', 'Recycling Of eIF2:GDP',
    'protein neddylation', 'proteasome',
    'protein processing in ER', 'membrane fission',
    'mitochondrion organization', 'mitochondrial translation', 'mitochondrial transcription',
]

# Column order: the gene-expression programs shown. Cell-type / tissue / immune programs
# ('pluripotency', 'Hox genes', 'glia', 'neural development', 'neuronal', 'peripheral neurons',
# 'visual perception', 'retinal epithelium', 'melanin biosynthesis', 'epithelial',
# 'respiratory epithelium', 'kidney', 'intestine', 'pancreatic', 'pancreatic islet', 'liver',
# 'epidermal', 'mesothelial', 'adipocyte', 'mesenchymal', 'smooth muscle', 'endothelial',
# 'bone', 'interferon signaling', 'TNF signaling', 'immune system', 'myeloid', 'macrophage',
# 'B cell', 'T cell', 'natural killer', 'mast cell', 'erythroid', 'megakaryocyte', 'muscle',
# 'heart', 'cilia') are deliberately left out.
gene_program_order = [
    'chromatin structure', 'DNA replication/repair', 'cell cycle (G1/S)', 'cell cycle (prometaphase)', 'cell cycle (M phase)',
    'spliceosome', 'proliferation',
    'ribosome biogenesis', 'ribosomal protein genes', 'mitochondrial encoded', 'cholesterol biosynthesis', 'mitochondrial ribosome',
    'Golgi vesicle transport', 'unfolded protein response',
    'integrated stress response', 'p53 signaling', 'lysosome/autophagy',
]


def _plot_program_shift_heatmap(group_df, title, vlim, filename):
    """Plot a clusters x programs heatmap of mean program shifts and save it as a PDF.

    Rows and columns follow ``pg_cluster_order`` / ``gene_program_order``. The diverging
    colour scale is centred at 0 and clipped to [-vlim, vlim]. The file is written to
    ``plot_output_path/filename``. Returns ``(fig, ax)``.
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    sns.heatmap(
        group_df.loc[pg_cluster_order, gene_program_order],
        cmap='seismic', center=0, cbar_kws={'label': 'mean gene program shift'},
        vmax=vlim, vmin=-vlim,
        ax=ax,
    )
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
    ax.set_xlabel('Gene expression programs')
    ax.set_ylabel('Functional clusters')
    ax.set_title(title)
    fig.savefig(f'{plot_output_path}/{filename}')
    return fig, ax


_plot_program_shift_heatmap(k562_group_df, 'K562', 2, 'k562_gene_program_shift_heatmap.pdf')
_plot_program_shift_heatmap(rpe1_group_df, 'RPE1', 2, 'rpe1_gene_program_shift_heatmap.pdf')
# hESC shifts are plotted on a narrower scale than the two Replogle lines.
_plot_program_shift_heatmap(hesc_group_df, 'hESC', 1, 'hesc_gene_program_shift_heatmap.pdf');

In [None]:
# Individual readout genes to average per functional cluster.
genes_to_show = [
    'ATF5', 'DDIT3', 'GDF15', 'XBP1', 'HSPA5', 'PDIA4', 'ATF4', 'ATF3', 'GARS1', 'EPRS1', 'SARS1', 'WARS1', 'MYC'
]

# Drop readout genes that are not measured. The check uses K562's var only, on the
# assumption (as before) that every dataset shares adata_pert's var — TODO confirm.
_k562_gene_names = set(adata_k562.var['gene_name'])
genes_to_show = [g for g in genes_to_show if g in _k562_gene_names]


def _cluster_readout_means(adata, genes, id_to_cluster, cluster_names):
    """Annotate ``adata`` with functional clusters and average readout-gene values per cluster.

    Args:
        adata: perturbation AnnData whose obs_names are perturbed-gene IDs and whose
            ``var['gene_name']`` contains every entry of ``genes``.
        genes: readout gene names (columns of the result, in this order).
        id_to_cluster: perturbed-gene ID -> functional cluster name.
        cluster_names: clusters to report (rows of the result, in this order).

    Returns:
        Float DataFrame (clusters x genes) of mean ``adata.X`` values. A cluster with no
        member perturbation in ``adata`` gets an all-NaN row.

    Side effect: (re)writes ``adata.obs['cluster_name']``.
    """
    adata.obs['cluster_name'] = adata.obs.index.map(id_to_cluster)
    # Column positions do not depend on the cluster, so resolve them once (first match wins).
    var_gene_names = list(adata.var['gene_name'])
    gene_indices = [var_gene_names.index(gene) for gene in genes]

    cluster_means = pd.DataFrame(index=cluster_names, columns=genes, dtype=float)
    for cluster in cluster_names:
        in_cluster = (adata.obs['cluster_name'] == cluster).to_numpy()
        cluster_means.loc[cluster] = adata[in_cluster].X[:, gene_indices].mean(axis=0)
    return cluster_means


functional_cluster_names = list(adata_anno.obs['cluster_name'].unique())

k562_pathway_gene_df = _cluster_readout_means(adata_k562, genes_to_show, pg_id_to_cluster_map, functional_cluster_names)
rpe1_pathway_gene_df = _cluster_readout_means(adata_rpe1, genes_to_show, pg_id_to_cluster_map, functional_cluster_names)
hesc_pathway_gene_df = _cluster_readout_means(adata_hesc, genes_to_show, pg_id_to_cluster_map, functional_cluster_names)

In [None]:
# Mean shift of each readout gene per functional cluster, one heatmap per cell line.
# The columns are individual genes, so the axis and colour-bar labels say so. The previous
# labels were copied from the program-level heatmaps.
for pathway_gene_df, title in [
    (k562_pathway_gene_df, 'K562'),
    (rpe1_pathway_gene_df, 'RPE1'),
    (hesc_pathway_gene_df, 'hESC'),
]:
    fig, ax = plt.subplots(figsize=(15, 8))
    sns.heatmap(pathway_gene_df.loc[pg_cluster_order],
                cmap='seismic', center=0, cbar_kws={'label': 'mean expression shift'},
                vmax=0.5, vmin=-0.5,
                ax=ax
                )
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
    ax.set_xlabel('Readout genes')
    ax.set_ylabel('Functional clusters')
    ax.set_title(title)

In [None]:
pathway_of_interest = 'cytoplasmic translation 1'

#pathway_of_interest = 'transcription initiation'
#pathway_of_interest = 'translation initiation'
#pathway_of_interest = 'mitochondrial transcription'

# A name that matches no functional cluster silently yields empty gene lists here and in
# the next cells. 'cytoplasmic translation 1' is not among the names in pg_name_map, so
# say so explicitly.
_known_clusters = set(adata_k562.obs['cluster_name'].dropna())
if pathway_of_interest not in _known_clusters:
    warnings.warn(
        f'{pathway_of_interest!r} is not a functional cluster name; '
        f'choose one of {sorted(_known_clusters)}'
    )

gene_list1 = np.array(adata_k562.obs[adata_k562.obs['cluster_name'] == pathway_of_interest]['perturbed_gene_name'])
gene_list1

In [None]:
# RPE1 perturbations assigned to the pathway of interest.
in_pathway_rpe1 = adata_rpe1.obs['cluster_name'] == pathway_of_interest
gene_list2 = np.array(adata_rpe1.obs.loc[in_pathway_rpe1, 'perturbed_gene_name'])
gene_list2

In [None]:
# hESC perturbations assigned to the pathway of interest.
in_pathway_hesc = adata_hesc.obs['cluster_name'] == pathway_of_interest
gene_list3 = np.array(adata_hesc.obs.loc[in_pathway_hesc, 'perturbed_gene_name'])
gene_list3

In [None]:
# Pathway members perturbed in all three cell lines (sorted, unique).
common_genes = np.intersect1d(np.intersect1d(gene_list1, gene_list2), gene_list3)
common_genes

In [None]:
# Hand-picked example perturbations for the per-gene program heatmaps below
# (one row per gene, shown in this order).
genes_to_show = [
    'INO80B', 'NELFA', 'NELFB', 'NELFCD', 'NELFE', 'TAF2', 'TAF6',
    'DPH2', 'EIF1AX', 'EIF3M', 'EIF4E', 'EIF4G1', 'STRAP', 'XRCC5',
    'AHCY', 'LRPPRC', 'MTPAP', 'POLRMT', 'REXO2', 'SSBP1', 'TEFM', 'TFAM',
]

In [None]:
# Reminder of the available dataset labels ('condition') before choosing one to plot below.
adata_pert.obs['condition'].value_counts()

In [None]:
dataset_to_show = 'ReplogleWeissman2022_K562_gwps'
#genes_to_show = common_genes

adata_to_show = adata_pert[adata_pert.obs['condition'] == dataset_to_show]
adata_to_show = adata_to_show[adata_to_show.obs['perturbed_gene_name'].isin(genes_to_show)].copy()

# Gene-program scores of the selected perturbations (rows: perturbed gene names).
gene_program_show_df = pd.DataFrame(
    index=adata_to_show.obs['perturbed_gene_name'],
    columns=adata_pert.uns['gene_program_names'],
    data=adata_to_show.obsm['gene_program']
)

# Not every example gene is necessarily perturbed (or kept) in this dataset. Plot the ones
# present, in the requested order, and report the rest instead of raising a KeyError in .loc.
genes_present = [g for g in genes_to_show if g in gene_program_show_df.index]
genes_missing = [g for g in genes_to_show if g not in gene_program_show_df.index]
if genes_missing:
    print(f'{dataset_to_show}: not available: {genes_missing}')

fig, ax = plt.subplots(figsize=(6, 6))
# NOTE(review): the scores are sqrt(n_genes) * mean shift, not a root-mean-square; the
# 'RMS' colour-bar label is kept as-is — TODO confirm the intended wording.
sns.heatmap(gene_program_show_df.loc[genes_present, gene_program_order], cmap='seismic', ax=ax, center=0, vmax=4, vmin=-4,
            cbar_kws={'label': 'RMS'})
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')

ax.set_xlabel('Gene expression programs')
ax.set_ylabel('Perturbed genes')
ax.set_title(dataset_to_show)
fig.savefig(f'{plot_output_path}/{dataset_to_show}_example_genes_heatmap.pdf')

In [None]:
dataset_to_show = 'ReplogleWeissman2022_rpe1'
#genes_to_show = common_genes

adata_to_show = adata_pert[adata_pert.obs['condition'] == dataset_to_show]
adata_to_show = adata_to_show[adata_to_show.obs['perturbed_gene_name'].isin(genes_to_show)].copy()

# Gene-program scores of the selected perturbations (rows: perturbed gene names).
gene_program_show_df = pd.DataFrame(
    index=adata_to_show.obs['perturbed_gene_name'],
    columns=adata_pert.uns['gene_program_names'],
    data=adata_to_show.obsm['gene_program']
)

# Not every example gene is necessarily perturbed (or kept) in this dataset. Plot the ones
# present, in the requested order, and report the rest instead of raising a KeyError in .loc.
genes_present = [g for g in genes_to_show if g in gene_program_show_df.index]
genes_missing = [g for g in genes_to_show if g not in gene_program_show_df.index]
if genes_missing:
    print(f'{dataset_to_show}: not available: {genes_missing}')

fig, ax = plt.subplots(figsize=(6, 6))
# NOTE(review): the scores are sqrt(n_genes) * mean shift, not a root-mean-square; the
# 'RMS' colour-bar label is kept as-is — TODO confirm the intended wording.
sns.heatmap(gene_program_show_df.loc[genes_present, gene_program_order], cmap='seismic', ax=ax, center=0, vmax=4, vmin=-4,
            cbar_kws={'label': 'RMS'})
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')

ax.set_xlabel('Gene expression programs')
ax.set_ylabel('Perturbed genes')
ax.set_title(dataset_to_show)
fig.savefig(f'{plot_output_path}/{dataset_to_show}_example_genes_heatmap.pdf')

In [None]:
dataset_to_show = 'hESC_perturb_seq'
#genes_to_show = common_genes

adata_to_show = adata_pert[adata_pert.obs['condition'] == dataset_to_show]
adata_to_show = adata_to_show[adata_to_show.obs['perturbed_gene_name'].isin(genes_to_show)].copy()

# Gene-program scores of the selected perturbations (rows: perturbed gene names).
gene_program_show_df = pd.DataFrame(
    index=adata_to_show.obs['perturbed_gene_name'],
    columns=adata_pert.uns['gene_program_names'],
    data=adata_to_show.obsm['gene_program']
)

# Not every example gene is necessarily perturbed (or kept) in this dataset. Plot the ones
# present, in the requested order, and report the rest instead of raising a KeyError in .loc.
genes_present = [g for g in genes_to_show if g in gene_program_show_df.index]
genes_missing = [g for g in genes_to_show if g not in gene_program_show_df.index]
if genes_missing:
    print(f'{dataset_to_show}: not available: {genes_missing}')

fig, ax = plt.subplots(figsize=(6, 6))
# hESC uses a narrower colour range than the Replogle datasets.
# NOTE(review): the scores are sqrt(n_genes) * mean shift, not a root-mean-square; the
# 'RMS' colour-bar label is kept as-is — TODO confirm the intended wording.
sns.heatmap(gene_program_show_df.loc[genes_present, gene_program_order], cmap='seismic', ax=ax, center=0, vmax=2, vmin=-2,
            cbar_kws={'label': 'RMS'})
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')

ax.set_xlabel('Gene expression programs')
ax.set_ylabel('Perturbed genes')
ax.set_title(dataset_to_show)
fig.savefig(f'{plot_output_path}/{dataset_to_show}_example_genes_heatmap.pdf')

In [None]:
# Survey every perturbation whose gene symbol starts with 'EIF' in one dataset.
#dataset_to_show = 'ReplogleWeissman2022_K562_gwps'
#dataset_to_show = 'ReplogleWeissman2022_rpe1'
dataset_to_show = 'hESC_perturb_seq'

in_dataset = adata_pert.obs['condition'] == dataset_to_show
adata_to_show = adata_pert[in_dataset]
is_eif = adata_to_show.obs['perturbed_gene_name'].str.startswith('EIF')
adata_to_show = adata_to_show[is_eif].copy()

gene_program_show_df = pd.DataFrame(
    adata_to_show.obsm['gene_program'],
    index=adata_to_show.obs['perturbed_gene_name'],
    columns=adata_pert.uns['gene_program_names'],
)

# Rows sorted alphabetically by perturbed gene; columns in the curated program order.
heatmap_data = gene_program_show_df[gene_program_order].sort_index()
fig, ax = plt.subplots(figsize=(6, 10))
sns.heatmap(heatmap_data, ax=ax, cmap='seismic', center=0, vmin=-2, vmax=2,
            cbar_kws={'label': 'RMS'})
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')

ax.set_xlabel('Gene Expression Programs')
ax.set_ylabel('Perturbed Genes')
ax.set_title(dataset_to_show)

In [None]:
# Alphabetical list of the perturbed genes shown in the previous heatmap.
np.array(list(gene_program_show_df.index.sort_values()))

In [None]:
# Display the pooled perturbation AnnData summary (shape and obs/var/obsm/uns keys).
adata_pert

In [None]:
# Perturbed genes selected for a closer look.
pg_to_show = [
    'TADA3', 'YEATS2',
    'TFAM', 'TEFM', 'POLRMT',
]

# Readout genes to show for them, one related group per line.
readout_gene_to_show = [
    'BAX', 'MDM2', 'ELOB',
    'ASPM', 'HMMR', 'TPX2',
    'ATF3', 'ATF4', 'ATF5', 'DDIT3',
    'GARS1', 'IARS1', 'LARS1',
    'INSIG1', 'MVD', 'FDPS',
]