In [None]:
import os

from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import anndata
import scanpy as sc
import umap
import scipy.stats

import torch

from scmg.model.contrastive_embedding import (CellEmbedder,  embed_adata)

from scmg.preprocessing.data_standardization import GeneNameMapper
# Instantiate the project's gene-name mapper once at import time.
# NOTE(review): `gene_name_mapper` is not referenced by any of the cells visible
# in this notebook — confirm whether it is still needed.
gene_name_mapper = GeneNameMapper()


In [None]:
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

# Global figure configuration shared by every plot in this notebook.
plt.rcParams["figure.autolayout"] = False
# pdf.fonttype=42 makes matplotlib's PDF backend embed TrueType (Type 42) fonts.
matplotlib.rc('pdf', fonttype=42)
plt.rcParams['font.family'] = 'FreeSans'
sc.set_figure_params(vector_friendly=True, dpi_save=300)
# Set after sc.set_figure_params — presumably because scanpy's defaults would
# otherwise turn the grid back on; keep this ordering (TODO confirm).
plt.rcParams['axes.grid'] = False

In [None]:
# Directory that receives every PDF saved by this notebook; created on first run.
# The trailing slash is part of the value, so paths built as
# f'{plot_output_path}/...' contain a harmless double slash.
plot_output_path = 'causal_perts_for_exp_programs_plots/'
os.makedirs(plot_output_path, exist_ok=True)

In [None]:
# Perturbation datasets to load, one .h5ad file per dataset. Entries that are
# commented out are excluded from the analysis.
# NOTE(review): these are absolute, machine-specific paths; the notebook only
# runs where /GPUData_xingjie is mounted.
pert_data_files = [
    '/GPUData_xingjie/SCMG/perturbation_data/AdamsonWeissman2016_GSM2406681_10X010.h5ad',
    '/GPUData_xingjie/SCMG/perturbation_data/FrangiehIzar2021_RNA.h5ad',
    '/GPUData_xingjie/SCMG/perturbation_data/hESC_TF_screen.h5ad',
    '/GPUData_xingjie/SCMG/perturbation_data/JiangSatija2024_IFNB.h5ad',
    '/GPUData_xingjie/SCMG/perturbation_data/JiangSatija2024_IFNG.h5ad',
    '/GPUData_xingjie/SCMG/perturbation_data/JiangSatija2024_INS.h5ad',
    '/GPUData_xingjie/SCMG/perturbation_data/JiangSatija2024_TGFB.h5ad',
    '/GPUData_xingjie/SCMG/perturbation_data/JiangSatija2024_TNFA.h5ad',
#    '/GPUData_xingjie/SCMG/perturbation_data/Joung_TFScreen_HS_2023.h5ad',
#    '/GPUData_xingjie/SCMG/perturbation_data/knockTF_human.h5ad',
#    '/GPUData_xingjie/SCMG/perturbation_data/knockTF_mouse.h5ad',
#    #'/GPUData_xingjie/SCMG/perturbation_data/omnipath.h5ad',
#    '/GPUData_xingjie/SCMG/perturbation_data/PertOrg.h5ad',
    '/GPUData_xingjie/SCMG/perturbation_data/ReplogleWeissman2022_K562_essential.h5ad',
    '/GPUData_xingjie/SCMG/perturbation_data/ReplogleWeissman2022_K562_gwps.h5ad',
    '/GPUData_xingjie/SCMG/perturbation_data/ReplogleWeissman2022_rpe1.h5ad',
    '/GPUData_xingjie/SCMG/perturbation_data/TianKampmann2021_CRISPRa.h5ad',
    '/GPUData_xingjie/SCMG/perturbation_data/TianKampmann2021_CRISPRi.h5ad',
    '/GPUData_xingjie/SCMG/hESC_perturb_seq/pseudo_bulk_filtered.h5ad', # Test

    '/GPUData_xingjie/SCMG/perturbation_data/ZhuMarson2025_T_cell_Rest.h5ad',
    '/GPUData_xingjie/SCMG/perturbation_data/ZhuMarson2025_T_cell_Stim8hr.h5ad',
    '/GPUData_xingjie/SCMG/perturbation_data/ZhuMarson2025_T_cell_Stim48hr.h5ad',
    '/GPUData_xingjie/SCMG/perturbation_data/Xaira2025_HCT116.h5ad',
    '/GPUData_xingjie/SCMG/perturbation_data/Xaira2025_HEK293T.h5ad',
]

# Read every file and report its basename and number of observations (rows).
adata_pert_list = []
for pdf in pert_data_files:
    adata_pert_list.append(sc.read_h5ad(pdf))
    print(os.path.basename(pdf), adata_pert_list[-1].shape[0])

# Stack all datasets along the observation axis. anndata.concat is called with
# its default join — presumably only genes shared by all files are kept; confirm.
adata_pert = anndata.concat(adata_pert_list, axis=0)
# Gene names are taken from the first dataset's var table; pandas aligns the
# values on the var index when assigning.
adata_pert.var['gene_name'] = adata_pert_list[0].var['gene_name']

adata_pert

In [None]:
# Zero the entry of each perturbation's own target gene, so that the direct
# effect on the targeted gene does not count towards the summary statistics.
gene_ids = adata_pert.var_names
for row, target in enumerate(adata_pert.obs['perturbed_gene']):
    # Targets that are not among the measured genes have nothing to mask.
    if target not in gene_ids:
        continue
    adata_pert.X[row, gene_ids.get_loc(target)] = 0

# Per-perturbation summaries computed on the masked matrix:
# Euclidean norm of the row, and its largest absolute entry.
masked_matrix = adata_pert.X
adata_pert.obs['effect_size'] = np.linalg.norm(masked_matrix, axis=1)
adata_pert.obs['max_effect'] = np.abs(masked_matrix).max(axis=1)

In [None]:
# Load the gene -> expression-cluster annotation; the first CSV column becomes the index.
gene_exp_cluster_df = pd.read_csv('../../classify_genes/systematic_classification/gene_exp_cluster_annotation.csv', index_col=0)
# Attach the cluster name to each gene. Assignment aligns on the index, so genes
# in adata_pert.var without a row in the annotation table get NaN.
adata_pert.var['gene_exp_cluster'] = gene_exp_cluster_df['cluster_name']
# Number of genes in adata_pert assigned to each cluster (displayed as the cell output).
gene_exp_cluster_counts = adata_pert.var['gene_exp_cluster'].value_counts()
gene_exp_cluster_counts

In [None]:
# Full annotation table; the first CSV column becomes the index.
gene_exp_cluster_all_df = pd.read_csv('../../classify_genes/systematic_classification/gene_exp_cluster_annotation_all.csv', index_col=0)

# A gene is treated as "essential" in this notebook when either criterion holds:
#   1. its 'entropy_decoded' value is at least 6.4, or
#   2. its 'annotation' is one of the programs listed below.
has_high_entropy = gene_exp_cluster_all_df['entropy_decoded'] >= 6.4
in_listed_program = gene_exp_cluster_all_df['annotation'].isin([
    'chromatin structure', 'DNA replication/repair', 'cell cycle (G1/S)', 'cell cycle (prometaphase)', 'cell cycle (M phase)',
    'spliceosome', 'proliferation',
    'ribosome biogenesis', 'ribosomal protein genes', 'mitochondrial encoded', 'cholesterol biosynthesis', 'mitochondrial ribosome',
    'Golgi vesicle transport', 'unfolded protein response',
    'integrated stress response', 'p53 signaling', 'lysosome/autophagy',
])

# Array of gene names satisfying either criterion.
essential_genes = gene_exp_cluster_all_df.loc[has_high_entropy | in_listed_program, 'gene_name'].values

In [None]:
# One "gene program" per distinct cluster name, in order of first appearance
# in the annotation table.
program_names = list(gene_exp_cluster_df['cluster_name'].unique())
adata_pert.uns['gene_program_names'] = program_names

# (n_perturbations x n_programs) activity matrix, filled one column per program.
adata_pert.obsm['gene_program'] = np.zeros((adata_pert.shape[0], len(program_names)))

for col, program in enumerate(program_names):
    member_genes = list(gene_exp_cluster_df[gene_exp_cluster_df['cluster_name'] == program].index)
    if not member_genes:
        # Nothing to average; the column stays at zero.
        continue
    # Program activity = sqrt(n_genes) * mean over the program's genes.
    # (A factor of 1 instead of sqrt(n_genes) would give the plain mean.)
    adata_pert.obsm['gene_program'][:, col] = np.sqrt(len(member_genes)) * np.mean(adata_pert[:, member_genes].X, axis=1)

In [None]:
# Root-mean-square program activity per dataset ("condition").
#
# Result: DataFrame indexed by the sorted unique values of obs['condition'],
# with one column per entry of uns['gene_program_names']; each cell is
# sqrt(mean(activity ** 2)) over the perturbations of that condition.
#
# Computed with a single groupby instead of filling a float32-initialised frame
# one cell at a time: the element-wise `.loc` writes stored float64 values into
# float32 columns (deprecated silent upcasting in recent pandas) and built an
# AnnData view for every condition.
program_activity_df = pd.DataFrame(
    adata_pert.obsm['gene_program'],
    columns=adata_pert.uns['gene_program_names'],
)
# Plain array of labels, so that unused categories of a categorical column
# cannot produce empty groups.
condition_labels = np.asarray(adata_pert.obs['condition'])

gene_program_rmss_df = np.sqrt(
    program_activity_df.pow(2).groupby(condition_labels).mean()
)

In [None]:
# Divide every row (condition) by its own mean across programs, so each
# condition's RMS values are expressed relative to that condition's average.
gene_program_rmss_df_normalized = gene_program_rmss_df.div(
    gene_program_rmss_df.mean(axis='columns'),
    axis='index',
)

In [None]:
# Heatmap of RMS program activity (rows: selected datasets, columns: programs).
# NOTE(review): this plots gene_program_rmss_df, not the row-normalised
# gene_program_rmss_df_normalized computed in the previous cell — confirm
# which one is intended.
conditions_to_show = ['hESC_TF_screen', 'ReplogleWeissman2022_K562_gwps', 'ReplogleWeissman2022_rpe1', 'TianKampmann2021_CRISPRa', 'TianKampmann2021_CRISPRi', ]
# Display order of the gene programs along the x axis; reused by the plotting
# cells further down the notebook.
gene_program_order = [
    'chromatin structure', 'DNA replication/repair', 'cell cycle (G1/S)', 'cell cycle (prometaphase)', 'cell cycle (M phase)', 
    'spliceosome', 'proliferation',
    'ribosome biogenesis', 'ribosomal protein genes', 'mitochondrial encoded', 'cholesterol biosynthesis', 'mitochondrial ribosome', 
    'Golgi vesicle transport', 'unfolded protein response', 
    'integrated stress response', 'p53 signaling', 'lysosome/autophagy',  
       
    'pluripotency', 'Hox genes', 'glia', 'neural development', 'neuronal', 'peripheral neurons', 
    'visual perception', 'retinal epithelium', 'melanin biosynthesis', 
    'epithelial', 'respiratory epithelium', 'kidney', 'intestine', 'pancreatic', 'pancreatic islet', 'liver', 'epidermal', 
    'mesothelial', 'adipocyte',  'mesenchymal', 'smooth muscle', 'endothelial', 'bone', 

    'interferon signaling', 'TNF signaling', 'immune system', 'myeloid', 'macrophage', 'B cell', 'T cell', 'natural killer',
    'mast cell', 'erythroid', 'megakaryocyte',
       
    'muscle', 'heart', 'cilia', 
       
    ]

fig, ax = plt.subplots(figsize=(25, 2))
# vmax=1 saturates the colour scale at an RMS of 1.
sns.heatmap(gene_program_rmss_df.loc[conditions_to_show, gene_program_order], cmap='Reds', ax=ax, vmax=1,
            cbar_kws={'label': 'RMS'})
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')

ax.set_xlabel('Gene Expression Programs')
ax.set_ylabel('Perturbation Datasets')

In [None]:
# Number of perturbations (rows) per dataset/condition; these are the names
# accepted by `dataset_to_show` in the plotting cells below.
adata_pert.obs['condition'].value_counts()

In [None]:
# Strip plot of gene-program activity for one dataset: one column per program,
# one dot per perturbation, coloured by whether the perturbed gene is in
# `essential_genes`. The figure is saved as a PDF under `plot_output_path`.
# NOTE(review): this cell is repeated almost verbatim for several datasets;
# a shared helper function taking the dataset name would remove the duplication.
dataset_to_show = 'Xaira2025_HEK293T'
adata_to_show = adata_pert[adata_pert.obs['condition'] == dataset_to_show].copy()

# Quantities that do not depend on the program being drawn are computed once
# here rather than once per program column.
n_perturbations = adata_to_show.shape[0]
perturbed_gene_names = adata_to_show.obs['perturbed_gene_name'].to_numpy()
essential_mask = adata_to_show.obs['perturbed_gene_name'].isin(essential_genes).values
point_colors = np.where(essential_mask, 'orange', 'blue').tolist()
# Offsets in [0, 0.5): perturbations are spread left-to-right in row order
# within the half-unit-wide band centred on each program's tick.
x_spread = np.arange(n_perturbations) / n_perturbations / 2

fig, ax = plt.subplots(figsize=(25, 4))

for i, program in enumerate(gene_program_order):
    x = i - 0.25 + x_spread
    y = adata_to_show.obsm['gene_program'][:, adata_pert.uns['gene_program_names'].index(program)]

    ax.axvline(i, color='lightgray', lw=0.5, zorder=-1)
    ax.scatter(x, y, s=1, edgecolor='none', c=point_colors, rasterized=True)

    # Label the three lowest and three highest activities, but only those whose
    # magnitude exceeds 1. dict.fromkeys removes repeated indices (possible when
    # the dataset has fewer than six perturbations) so no label is drawn twice.
    order = np.argsort(y)
    extreme_indices = dict.fromkeys(order[:3].tolist() + order[-3:].tolist())
    for j in extreme_indices:
        if np.abs(y[j]) > 1:
            ax.text(x[j], y[j], perturbed_gene_names[j], fontsize=10)

ax.legend(handles=[
    matplotlib.lines.Line2D([], [], marker='o', color='orange', linestyle='None', markersize=5, label='Essential genes perturbed'),
    matplotlib.lines.Line2D([], [], marker='o', color='blue', linestyle='None', markersize=5, label='Non-essential genes perturbed'),
], loc='lower right', fontsize=12)
ax.axhline(0, color='lightgray', lw=0.5, zorder=-1)
ax.set_xticks(np.arange(len(gene_program_order)), gene_program_order, rotation=45, ha='right')
ax.set_xlim(-1, len(gene_program_order))
ax.set_xlabel('Gene Expression Programs')
ax.set_ylabel('Program Activity')
ax.set_title(dataset_to_show)

plt.savefig(f'{plot_output_path}/{dataset_to_show}_gene_program_activity.pdf')
plt.show()

In [None]:
# Strip plot of gene-program activity for one dataset: one column per program,
# one dot per perturbation, coloured by whether the perturbed gene is in
# `essential_genes`. The figure is saved as a PDF under `plot_output_path`.
dataset_to_show = 'Xaira2025_HCT116'
adata_to_show = adata_pert[adata_pert.obs['condition'] == dataset_to_show].copy()

# Quantities that do not depend on the program being drawn are computed once
# here rather than once per program column.
n_perturbations = adata_to_show.shape[0]
perturbed_gene_names = adata_to_show.obs['perturbed_gene_name'].to_numpy()
essential_mask = adata_to_show.obs['perturbed_gene_name'].isin(essential_genes).values
point_colors = np.where(essential_mask, 'orange', 'blue').tolist()
# Offsets in [0, 0.5): perturbations are spread left-to-right in row order
# within the half-unit-wide band centred on each program's tick.
x_spread = np.arange(n_perturbations) / n_perturbations / 2

fig, ax = plt.subplots(figsize=(25, 4))

for i, program in enumerate(gene_program_order):
    x = i - 0.25 + x_spread
    y = adata_to_show.obsm['gene_program'][:, adata_pert.uns['gene_program_names'].index(program)]

    ax.axvline(i, color='lightgray', lw=0.5, zorder=-1)
    ax.scatter(x, y, s=1, edgecolor='none', c=point_colors, rasterized=True)

    # Label the three lowest and three highest activities, but only those whose
    # magnitude exceeds 1. dict.fromkeys removes repeated indices (possible when
    # the dataset has fewer than six perturbations) so no label is drawn twice.
    order = np.argsort(y)
    extreme_indices = dict.fromkeys(order[:3].tolist() + order[-3:].tolist())
    for j in extreme_indices:
        if np.abs(y[j]) > 1:
            ax.text(x[j], y[j], perturbed_gene_names[j], fontsize=10)

ax.legend(handles=[
    matplotlib.lines.Line2D([], [], marker='o', color='orange', linestyle='None', markersize=5, label='Essential genes perturbed'),
    matplotlib.lines.Line2D([], [], marker='o', color='blue', linestyle='None', markersize=5, label='Non-essential genes perturbed'),
], loc='lower right', fontsize=12)
ax.axhline(0, color='lightgray', lw=0.5, zorder=-1)
ax.set_xticks(np.arange(len(gene_program_order)), gene_program_order, rotation=45, ha='right')
ax.set_xlim(-1, len(gene_program_order))
ax.set_xlabel('Gene Expression Programs')
ax.set_ylabel('Program Activity')
ax.set_title(dataset_to_show)

plt.savefig(f'{plot_output_path}/{dataset_to_show}_gene_program_activity.pdf')
plt.show()

In [None]:
# Strip plot of gene-program activity for one dataset: one column per program,
# one dot per perturbation, coloured by whether the perturbed gene is in
# `essential_genes`. The figure is saved as a PDF under `plot_output_path`.
dataset_to_show = 'ZhuMarson2025_T_cell_Rest'
adata_to_show = adata_pert[adata_pert.obs['condition'] == dataset_to_show].copy()

# Quantities that do not depend on the program being drawn are computed once
# here rather than once per program column.
n_perturbations = adata_to_show.shape[0]
perturbed_gene_names = adata_to_show.obs['perturbed_gene_name'].to_numpy()
essential_mask = adata_to_show.obs['perturbed_gene_name'].isin(essential_genes).values
point_colors = np.where(essential_mask, 'orange', 'blue').tolist()
# Offsets in [0, 0.5): perturbations are spread left-to-right in row order
# within the half-unit-wide band centred on each program's tick.
x_spread = np.arange(n_perturbations) / n_perturbations / 2

fig, ax = plt.subplots(figsize=(25, 4))

for i, program in enumerate(gene_program_order):
    x = i - 0.25 + x_spread
    y = adata_to_show.obsm['gene_program'][:, adata_pert.uns['gene_program_names'].index(program)]

    ax.axvline(i, color='lightgray', lw=0.5, zorder=-1)
    ax.scatter(x, y, s=1, edgecolor='none', c=point_colors, rasterized=True)

    # Label the three lowest and three highest activities, but only those whose
    # magnitude exceeds 1. dict.fromkeys removes repeated indices (possible when
    # the dataset has fewer than six perturbations) so no label is drawn twice.
    order = np.argsort(y)
    extreme_indices = dict.fromkeys(order[:3].tolist() + order[-3:].tolist())
    for j in extreme_indices:
        if np.abs(y[j]) > 1:
            ax.text(x[j], y[j], perturbed_gene_names[j], fontsize=10)

ax.legend(handles=[
    matplotlib.lines.Line2D([], [], marker='o', color='orange', linestyle='None', markersize=5, label='Essential genes perturbed'),
    matplotlib.lines.Line2D([], [], marker='o', color='blue', linestyle='None', markersize=5, label='Non-essential genes perturbed'),
], loc='lower right', fontsize=12)
ax.axhline(0, color='lightgray', lw=0.5, zorder=-1)
ax.set_xticks(np.arange(len(gene_program_order)), gene_program_order, rotation=45, ha='right')
ax.set_xlim(-1, len(gene_program_order))
ax.set_xlabel('Gene Expression Programs')
ax.set_ylabel('Program Activity')
ax.set_title(dataset_to_show)

plt.savefig(f'{plot_output_path}/{dataset_to_show}_gene_program_activity.pdf')
plt.show()

In [None]:
# Strip plot of gene-program activity for one dataset: one column per program,
# one dot per perturbation, coloured by whether the perturbed gene is in
# `essential_genes`. The figure is saved as a PDF under `plot_output_path`.
dataset_to_show = 'ZhuMarson2025_T_cell_Stim8hr'
adata_to_show = adata_pert[adata_pert.obs['condition'] == dataset_to_show].copy()

# Quantities that do not depend on the program being drawn are computed once
# here rather than once per program column.
n_perturbations = adata_to_show.shape[0]
perturbed_gene_names = adata_to_show.obs['perturbed_gene_name'].to_numpy()
essential_mask = adata_to_show.obs['perturbed_gene_name'].isin(essential_genes).values
point_colors = np.where(essential_mask, 'orange', 'blue').tolist()
# Offsets in [0, 0.5): perturbations are spread left-to-right in row order
# within the half-unit-wide band centred on each program's tick.
x_spread = np.arange(n_perturbations) / n_perturbations / 2

fig, ax = plt.subplots(figsize=(25, 4))

for i, program in enumerate(gene_program_order):
    x = i - 0.25 + x_spread
    y = adata_to_show.obsm['gene_program'][:, adata_pert.uns['gene_program_names'].index(program)]

    ax.axvline(i, color='lightgray', lw=0.5, zorder=-1)
    ax.scatter(x, y, s=1, edgecolor='none', c=point_colors, rasterized=True)

    # Label the three lowest and three highest activities, but only those whose
    # magnitude exceeds 1. dict.fromkeys removes repeated indices (possible when
    # the dataset has fewer than six perturbations) so no label is drawn twice.
    order = np.argsort(y)
    extreme_indices = dict.fromkeys(order[:3].tolist() + order[-3:].tolist())
    for j in extreme_indices:
        if np.abs(y[j]) > 1:
            ax.text(x[j], y[j], perturbed_gene_names[j], fontsize=10)

ax.legend(handles=[
    matplotlib.lines.Line2D([], [], marker='o', color='orange', linestyle='None', markersize=5, label='Essential genes perturbed'),
    matplotlib.lines.Line2D([], [], marker='o', color='blue', linestyle='None', markersize=5, label='Non-essential genes perturbed'),
], loc='lower right', fontsize=12)
ax.axhline(0, color='lightgray', lw=0.5, zorder=-1)
ax.set_xticks(np.arange(len(gene_program_order)), gene_program_order, rotation=45, ha='right')
ax.set_xlim(-1, len(gene_program_order))
ax.set_xlabel('Gene Expression Programs')
ax.set_ylabel('Program Activity')
ax.set_title(dataset_to_show)

plt.savefig(f'{plot_output_path}/{dataset_to_show}_gene_program_activity.pdf')
plt.show()

In [None]:
# Strip plot of gene-program activity for one dataset: one column per program,
# one dot per perturbation, coloured by whether the perturbed gene is in
# `essential_genes`. The figure is saved as a PDF under `plot_output_path`.
dataset_to_show = 'ZhuMarson2025_T_cell_Stim48hr'
adata_to_show = adata_pert[adata_pert.obs['condition'] == dataset_to_show].copy()

# Quantities that do not depend on the program being drawn are computed once
# here rather than once per program column.
n_perturbations = adata_to_show.shape[0]
perturbed_gene_names = adata_to_show.obs['perturbed_gene_name'].to_numpy()
essential_mask = adata_to_show.obs['perturbed_gene_name'].isin(essential_genes).values
point_colors = np.where(essential_mask, 'orange', 'blue').tolist()
# Offsets in [0, 0.5): perturbations are spread left-to-right in row order
# within the half-unit-wide band centred on each program's tick.
x_spread = np.arange(n_perturbations) / n_perturbations / 2

fig, ax = plt.subplots(figsize=(25, 4))

for i, program in enumerate(gene_program_order):
    x = i - 0.25 + x_spread
    y = adata_to_show.obsm['gene_program'][:, adata_pert.uns['gene_program_names'].index(program)]

    ax.axvline(i, color='lightgray', lw=0.5, zorder=-1)
    ax.scatter(x, y, s=1, edgecolor='none', c=point_colors, rasterized=True)

    # Label the three lowest and three highest activities, but only those whose
    # magnitude exceeds 1. dict.fromkeys removes repeated indices (possible when
    # the dataset has fewer than six perturbations) so no label is drawn twice.
    order = np.argsort(y)
    extreme_indices = dict.fromkeys(order[:3].tolist() + order[-3:].tolist())
    for j in extreme_indices:
        if np.abs(y[j]) > 1:
            ax.text(x[j], y[j], perturbed_gene_names[j], fontsize=10)

ax.legend(handles=[
    matplotlib.lines.Line2D([], [], marker='o', color='orange', linestyle='None', markersize=5, label='Essential genes perturbed'),
    matplotlib.lines.Line2D([], [], marker='o', color='blue', linestyle='None', markersize=5, label='Non-essential genes perturbed'),
], loc='lower right', fontsize=12)
ax.axhline(0, color='lightgray', lw=0.5, zorder=-1)
ax.set_xticks(np.arange(len(gene_program_order)), gene_program_order, rotation=45, ha='right')
ax.set_xlim(-1, len(gene_program_order))
ax.set_xlabel('Gene Expression Programs')
ax.set_ylabel('Program Activity')
ax.set_title(dataset_to_show)

plt.savefig(f'{plot_output_path}/{dataset_to_show}_gene_program_activity.pdf')
plt.show()

In [None]:
# Strip plot of gene-program activity for one dataset: one column per program,
# one dot per perturbation, coloured by whether the perturbed gene is in
# `essential_genes`. The figure is saved as a PDF under `plot_output_path`.
dataset_to_show = 'hESC_TF_screen'
adata_to_show = adata_pert[adata_pert.obs['condition'] == dataset_to_show].copy()

# Quantities that do not depend on the program being drawn are computed once
# here rather than once per program column.
n_perturbations = adata_to_show.shape[0]
perturbed_gene_names = adata_to_show.obs['perturbed_gene_name'].to_numpy()
essential_mask = adata_to_show.obs['perturbed_gene_name'].isin(essential_genes).values
point_colors = np.where(essential_mask, 'orange', 'blue').tolist()
# Offsets in [0, 0.5): perturbations are spread left-to-right in row order
# within the half-unit-wide band centred on each program's tick.
x_spread = np.arange(n_perturbations) / n_perturbations / 2

fig, ax = plt.subplots(figsize=(25, 4))

for i, program in enumerate(gene_program_order):
    x = i - 0.25 + x_spread
    y = adata_to_show.obsm['gene_program'][:, adata_pert.uns['gene_program_names'].index(program)]

    ax.axvline(i, color='lightgray', lw=0.5, zorder=-1)
    ax.scatter(x, y, s=1, edgecolor='none', c=point_colors, rasterized=True)

    # Label the three lowest and three highest activities, but only those whose
    # magnitude exceeds 1. dict.fromkeys removes repeated indices (possible when
    # the dataset has fewer than six perturbations) so no label is drawn twice.
    order = np.argsort(y)
    extreme_indices = dict.fromkeys(order[:3].tolist() + order[-3:].tolist())
    for j in extreme_indices:
        if np.abs(y[j]) > 1:
            ax.text(x[j], y[j], perturbed_gene_names[j], fontsize=10)

ax.legend(handles=[
    matplotlib.lines.Line2D([], [], marker='o', color='orange', linestyle='None', markersize=5, label='Essential genes perturbed'),
    matplotlib.lines.Line2D([], [], marker='o', color='blue', linestyle='None', markersize=5, label='Non-essential genes perturbed'),
], loc='lower right', fontsize=12)
ax.axhline(0, color='lightgray', lw=0.5, zorder=-1)
ax.set_xticks(np.arange(len(gene_program_order)), gene_program_order, rotation=45, ha='right')
ax.set_xlim(-1, len(gene_program_order))
ax.set_xlabel('Gene Expression Programs')
ax.set_ylabel('Program Activity')
ax.set_title(dataset_to_show)

plt.savefig(f'{plot_output_path}/{dataset_to_show}_gene_program_activity.pdf')
plt.show()

In [None]:
# Strip plot of gene-program activity for one dataset: one column per program,
# one dot per perturbation, coloured by whether the perturbed gene is in
# `essential_genes`. The figure is saved as a PDF under `plot_output_path`.
dataset_to_show = 'ReplogleWeissman2022_K562_gwps'
adata_to_show = adata_pert[adata_pert.obs['condition'] == dataset_to_show].copy()

# Quantities that do not depend on the program being drawn are computed once
# here rather than once per program column.
n_perturbations = adata_to_show.shape[0]
perturbed_gene_names = adata_to_show.obs['perturbed_gene_name'].to_numpy()
essential_mask = adata_to_show.obs['perturbed_gene_name'].isin(essential_genes).values
point_colors = np.where(essential_mask, 'orange', 'blue').tolist()
# Offsets in [0, 0.5): perturbations are spread left-to-right in row order
# within the half-unit-wide band centred on each program's tick.
x_spread = np.arange(n_perturbations) / n_perturbations / 2

fig, ax = plt.subplots(figsize=(25, 4))

for i, program in enumerate(gene_program_order):
    x = i - 0.25 + x_spread
    y = adata_to_show.obsm['gene_program'][:, adata_pert.uns['gene_program_names'].index(program)]

    ax.axvline(i, color='lightgray', lw=0.5, zorder=-1)
    ax.scatter(x, y, s=1, edgecolor='none', c=point_colors, rasterized=True)

    # Label the three lowest and three highest activities, but only those whose
    # magnitude exceeds 1. dict.fromkeys removes repeated indices (possible when
    # the dataset has fewer than six perturbations) so no label is drawn twice.
    order = np.argsort(y)
    extreme_indices = dict.fromkeys(order[:3].tolist() + order[-3:].tolist())
    for j in extreme_indices:
        if np.abs(y[j]) > 1:
            ax.text(x[j], y[j], perturbed_gene_names[j], fontsize=10)

ax.legend(handles=[
    matplotlib.lines.Line2D([], [], marker='o', color='orange', linestyle='None', markersize=5, label='Essential genes perturbed'),
    matplotlib.lines.Line2D([], [], marker='o', color='blue', linestyle='None', markersize=5, label='Non-essential genes perturbed'),
], loc='lower right', fontsize=12)
ax.axhline(0, color='lightgray', lw=0.5, zorder=-1)
ax.set_xticks(np.arange(len(gene_program_order)), gene_program_order, rotation=45, ha='right')
ax.set_xlim(-1, len(gene_program_order))
ax.set_xlabel('Gene Expression Programs')
ax.set_ylabel('Program Activity')
ax.set_title(dataset_to_show)

plt.savefig(f'{plot_output_path}/{dataset_to_show}_gene_program_activity.pdf')
plt.show()

In [None]:
# Strip plot of gene-program activity for one dataset: one column per program,
# one dot per perturbation, coloured by whether the perturbed gene is in
# `essential_genes`. The figure is saved as a PDF under `plot_output_path`.
dataset_to_show = 'ReplogleWeissman2022_rpe1'
adata_to_show = adata_pert[adata_pert.obs['condition'] == dataset_to_show].copy()

# Quantities that do not depend on the program being drawn are computed once
# here rather than once per program column.
n_perturbations = adata_to_show.shape[0]
perturbed_gene_names = adata_to_show.obs['perturbed_gene_name'].to_numpy()
essential_mask = adata_to_show.obs['perturbed_gene_name'].isin(essential_genes).values
point_colors = np.where(essential_mask, 'orange', 'blue').tolist()
# Offsets in [0, 0.5): perturbations are spread left-to-right in row order
# within the half-unit-wide band centred on each program's tick.
x_spread = np.arange(n_perturbations) / n_perturbations / 2

fig, ax = plt.subplots(figsize=(25, 4))

for i, program in enumerate(gene_program_order):
    x = i - 0.25 + x_spread
    y = adata_to_show.obsm['gene_program'][:, adata_pert.uns['gene_program_names'].index(program)]

    ax.axvline(i, color='lightgray', lw=0.5, zorder=-1)
    ax.scatter(x, y, s=1, edgecolor='none', c=point_colors, rasterized=True)

    # Label the three lowest and three highest activities, but only those whose
    # magnitude exceeds 1. dict.fromkeys removes repeated indices (possible when
    # the dataset has fewer than six perturbations) so no label is drawn twice.
    order = np.argsort(y)
    extreme_indices = dict.fromkeys(order[:3].tolist() + order[-3:].tolist())
    for j in extreme_indices:
        if np.abs(y[j]) > 1:
            ax.text(x[j], y[j], perturbed_gene_names[j], fontsize=10)

ax.legend(handles=[
    matplotlib.lines.Line2D([], [], marker='o', color='orange', linestyle='None', markersize=5, label='Essential genes perturbed'),
    matplotlib.lines.Line2D([], [], marker='o', color='blue', linestyle='None', markersize=5, label='Non-essential genes perturbed'),
], loc='lower right', fontsize=12)
ax.axhline(0, color='lightgray', lw=0.5, zorder=-1)
ax.set_xticks(np.arange(len(gene_program_order)), gene_program_order, rotation=45, ha='right')
ax.set_xlim(-1, len(gene_program_order))
ax.set_xlabel('Gene Expression Programs')
ax.set_ylabel('Program Activity')
ax.set_title(dataset_to_show)

plt.savefig(f'{plot_output_path}/{dataset_to_show}_gene_program_activity.pdf')
plt.show()

In [None]:
# Strip plot of gene-program activity for one dataset: one column per program,
# one dot per perturbation, coloured by whether the perturbed gene is in
# `essential_genes`. The figure is saved as a PDF under `plot_output_path`.
dataset_to_show = 'hESC_perturb_seq'
adata_to_show = adata_pert[adata_pert.obs['condition'] == dataset_to_show].copy()

# Quantities that do not depend on the program being drawn are computed once
# here rather than once per program column.
n_perturbations = adata_to_show.shape[0]
perturbed_gene_names = adata_to_show.obs['perturbed_gene_name'].to_numpy()
essential_mask = adata_to_show.obs['perturbed_gene_name'].isin(essential_genes).values
point_colors = np.where(essential_mask, 'orange', 'blue').tolist()
# Offsets in [0, 0.5): perturbations are spread left-to-right in row order
# within the half-unit-wide band centred on each program's tick.
x_spread = np.arange(n_perturbations) / n_perturbations / 2

fig, ax = plt.subplots(figsize=(25, 4))

for i, program in enumerate(gene_program_order):
    x = i - 0.25 + x_spread
    y = adata_to_show.obsm['gene_program'][:, adata_pert.uns['gene_program_names'].index(program)]

    ax.axvline(i, color='lightgray', lw=0.5, zorder=-1)
    ax.scatter(x, y, s=1, edgecolor='none', c=point_colors, rasterized=True)

    # Label the three lowest and three highest activities, but only those whose
    # magnitude exceeds 1. dict.fromkeys removes repeated indices (possible when
    # the dataset has fewer than six perturbations) so no label is drawn twice.
    order = np.argsort(y)
    extreme_indices = dict.fromkeys(order[:3].tolist() + order[-3:].tolist())
    for j in extreme_indices:
        if np.abs(y[j]) > 1:
            ax.text(x[j], y[j], perturbed_gene_names[j], fontsize=10)

ax.legend(handles=[
    matplotlib.lines.Line2D([], [], marker='o', color='orange', linestyle='None', markersize=5, label='Essential genes perturbed'),
    matplotlib.lines.Line2D([], [], marker='o', color='blue', linestyle='None', markersize=5, label='Non-essential genes perturbed'),
], loc='lower right', fontsize=12)
ax.axhline(0, color='lightgray', lw=0.5, zorder=-1)
ax.set_xticks(np.arange(len(gene_program_order)), gene_program_order, rotation=45, ha='right')
ax.set_xlim(-1, len(gene_program_order))
ax.set_xlabel('Gene Expression Programs')
ax.set_ylabel('Program Activity')
ax.set_title(dataset_to_show)

plt.savefig(f'{plot_output_path}/{dataset_to_show}_gene_program_activity.pdf')
plt.show()

In [None]:
# Heatmap of program activity for a hand-picked set of perturbed genes in one
# dataset. Rows are perturbed genes (in the order listed below), columns are
# gene programs (in `gene_program_order`).
dataset_to_show = 'hESC_perturb_seq'
genes_to_show = [
    'MED19', 'SP1', 'SUPT20H', 'NANOG', 'SOX2', 'POU5F1',
    'HARS', 'TARS', 'RPP14', 'RPP30', 'TSEN2', 'RRP9', 'PDCD11',
       'DDX56', 'DDX21', 'HEATR1', 'BRIX1', 'MRP63', 'RCL1', 'KRR1',
       'EIF2B4', 'EIF2B5', 'EIF2S2', 'EIF2B3', 'RNF214', 'UBE2T', 'SHFM1',
       'GEMIN5', 'SMNDC1', 'PDCD7', 'PPIE', 'RNGTT', 'EXOSC10', 'GPN3',
       'POLR2M', 'TAF6', 'MED22', 'ZC3H8', 'ZNF574', 'FOXD3', 'SKA1',
       'SKA3', 'CENPI', 'CCNH', 'MNAT1', 'CHAF1B', 'TIMELESS', 'MMS22L',
       'BRCA2', 'BRIP1', 'SDE2', 'REV3L', 'RFWD3', 'TANGO6', 'EXOC3',
       'DERL2', 'PHB', 'IPO7', 'BCL2L1', 'C22orf15', 'MTBP', 'NCAPH', 'MED4'
     ]

adata_to_show = adata_pert[adata_pert.obs['condition'] == dataset_to_show]
adata_to_show = adata_to_show[adata_to_show.obs['perturbed_gene_name'].isin(genes_to_show)].copy()

# Keep only the requested genes that are actually present in this dataset;
# selecting a missing label with .loc below would raise a KeyError.
genes_to_show = [g for g in genes_to_show if g in adata_to_show.obs['perturbed_gene_name'].values]

gene_program_show_df = pd.DataFrame(
    index=adata_to_show.obs['perturbed_gene_name'],
    columns=adata_pert.uns['gene_program_names'],
    data=adata_to_show.obsm['gene_program']
)


fig, ax = plt.subplots(figsize=(25, 15))
# The plotted values are signed program activities (diverging colour map
# centred on 0), so the colour bar is labelled accordingly rather than 'RMS'.
sns.heatmap(gene_program_show_df.loc[genes_to_show, gene_program_order], cmap='seismic', ax=ax, center=0,
            cbar_kws={'label': 'Program Activity'})
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')

# Columns (x) are gene programs and rows (y) are perturbed genes.
ax.set_xlabel('Gene Expression Programs')
ax.set_ylabel('Perturbed Genes')
ax.set_title(dataset_to_show)

In [None]:
# Heatmap of program activity for a hand-picked set of perturbed genes in one
# dataset. Rows are perturbed genes (in the order listed below), columns are
# gene programs (in `gene_program_order`).
# NOTE(review): some genes are listed more than once (TP73, SOX5, ETV2, CEBPB),
# so their rows are repeated in the heatmap — confirm this is intended.
dataset_to_show = 'hESC_TF_screen'
genes_to_show = [
    'NEUROD4', 'E2F1', 'NEUROD1',
    'E2F4', 'MYC', 
    'ZIC1', 'KLF12', 'KLF9', 
    'SOX10', 'CEBPB', 'SREBF2',
    'NR2C2', 'NR2E1', 'TBX5',
    'NR4A3', 'FOXF1', 'TP73',
    'PATZ1', 'FLI1', 'BANP', 
    'MEF2C', 'ELF4',
    'NFIB', 'BMP2',
    'SOX9',
    'NEUROG2', 'NEUROG3', 'MYT1',
    'TCF4', 'SOX5', 'TP73',
    'ETV2', 'SOX5', 'KLF15',
    'ETS1', 'ETS2', 'ETV2',
    'KLF4', 
    'SPI1', 'IRF1', 'SPIB',
    'CEBPA', 'CEBPB',
    'MYOD1', 'MYF5',
    'FOXJ1',
    'PTF1A'
     ]
adata_to_show = adata_pert[adata_pert.obs['condition'] == dataset_to_show]
adata_to_show = adata_to_show[adata_to_show.obs['perturbed_gene_name'].isin(genes_to_show)].copy()

# Keep only the requested genes that are actually present in this dataset;
# selecting a missing label with .loc below would raise a KeyError.
genes_to_show = [g for g in genes_to_show if g in adata_to_show.obs['perturbed_gene_name'].values]

gene_program_show_df = pd.DataFrame(
    index=adata_to_show.obs['perturbed_gene_name'],
    columns=adata_pert.uns['gene_program_names'],
    data=adata_to_show.obsm['gene_program']
)


fig, ax = plt.subplots(figsize=(25, 15))
# The plotted values are signed program activities (diverging colour map
# centred on 0), so the colour bar is labelled accordingly rather than 'RMS'.
sns.heatmap(gene_program_show_df.loc[genes_to_show, gene_program_order], cmap='seismic', ax=ax, center=0,
            cbar_kws={'label': 'Program Activity'})
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')

# Columns (x) are gene programs and rows (y) are perturbed genes.
ax.set_xlabel('Gene Expression Programs')
ax.set_ylabel('Perturbed Genes')
ax.set_title(dataset_to_show)

In [None]:
# Per-perturbation shift metrics for one dataset:
#   top_program_shift  - largest activity among the programs listed below
#   pluripotency_shift - activity of the 'pluripotency' program
#   shift_diff         - top_program_shift minus |pluripotency_shift|
dataset_to_show = 'hESC_TF_screen'

# Programs eligible for the "top program" maximum. The three stress-related
# programs in the commented-out line at the end are deliberately left out.
programs_to_consider = [
    'chromatin structure', 'DNA replication/repair', 'cell cycle (G1/S)', 'cell cycle (prometaphase)', 'cell cycle (M phase)',
    'spliceosome', 'proliferation',
    'ribosome biogenesis', 'ribosomal protein genes', 'mitochondrial encoded', 'cholesterol biosynthesis', 'mitochondrial ribosome',
    'Golgi vesicle transport', 'unfolded protein response',

    'pluripotency', 'Hox genes', 'glia', 'neural development', 'neuronal', 'peripheral neurons',
    'visual perception', 'retinal epithelium', 'melanin biosynthesis',
    'epithelial', 'respiratory epithelium', 'kidney', 'intestine', 'pancreatic', 'pancreatic islet', 'liver', 'epidermal',
    'mesothelial', 'adipocyte',  'mesenchymal', 'smooth muscle', 'endothelial', 'bone',

    'interferon signaling', 'TNF signaling', 'immune system', 'myeloid', 'macrophage', 'B cell', 'T cell', 'natural killer',
    'mast cell', 'erythroid', 'megakaryocyte',

    'muscle', 'heart', 'cilia',

    # 'integrated stress response', 'p53 signaling', 'lysosome/autophagy',
]
# Column positions of those programs in obsm['gene_program'].
programs_ids_to_consider = [adata_pert.uns['gene_program_names'].index(p) for p in programs_to_consider]

adata_to_show = adata_pert[adata_pert.obs['condition'] == dataset_to_show].copy()

program_activity = adata_to_show.obsm['gene_program']
pluripotency_column = adata_to_show.uns['gene_program_names'].index('pluripotency')

adata_to_show.obs['top_program_shift'] = program_activity[:, programs_ids_to_consider].max(axis=1)
adata_to_show.obs['pluripotency_shift'] = program_activity[:, pluripotency_column]
adata_to_show.obs['shift_diff'] = (
    adata_to_show.obs['top_program_shift'] - adata_to_show.obs['pluripotency_shift'].abs()
)

In [None]:
# Names of the 20 perturbed genes with the largest shift_diff, largest first.
adata_to_show.obs.sort_values('shift_diff', ascending=False)['perturbed_gene_name'].head(20).values

In [None]:
# Scatter of pluripotency shift (x) against top program shift (y) for every
# perturbation in `adata_to_show`; selected genes are drawn in red and labelled.
# Relies on `adata_to_show` and `dataset_to_show` set by the preceding cells.
genes_to_highlight_red = [
    'E2F1', 'E2F4',
    'FOXM1',
    'MYC',
    'KLF14', 'KLF9', 'CBX8',
    'SREBF2',
    'PLXNB3', 'WNT3A', 'NR2C2',
    'SPI1', 'IRF1', 'SPIB',
    'MYT1', 'ETV2', 'MYOD1', 'MYF5', 'FOXJ1',
]

pluripotency_values = adata_to_show.obs['pluripotency_shift']
top_program_values = adata_to_show.obs['top_program_shift']
perturbed_names = adata_to_show.obs['perturbed_gene_name']

fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(pluripotency_values, top_program_values, s=5, edgecolor='none', c='grey', rasterized=True)

# Row positions of the perturbations whose gene is on the highlight list.
indices_to_highlight_red = [
    row for row, gene in enumerate(perturbed_names) if gene in genes_to_highlight_red
]
# Show the highlighted genes ordered by decreasing pluripotency shift.
display(adata_to_show.obs.iloc[indices_to_highlight_red].sort_values('pluripotency_shift', ascending=False)['perturbed_gene_name'].values)

for row in indices_to_highlight_red:
    x_value = pluripotency_values.iloc[row]
    y_value = top_program_values.iloc[row]
    # Only genes whose top program shift exceeds 1 are marked and labelled.
    if y_value > 1:
        ax.text(x_value, y_value, perturbed_names.iloc[row], fontsize=10)
        ax.scatter(x_value, y_value, s=20, edgecolor='none', c='red')

# Reference lines through the origin.
ax.axhline(0, color='lightgray', lw=1, zorder=-1)
ax.axvline(0, color='lightgray', lw=1, zorder=-1)
ax.set_xlabel('Pluripotency Program Shift')
ax.set_ylabel('Top Program Shift')
fig.savefig(f'{plot_output_path}/{dataset_to_show}_pluripotency_vs_top_program_shift.pdf')

In [None]:
# Heatmap of program activity for the highlighted transcription factors.
# Rows are perturbed genes (in the order listed below), columns are gene
# programs (in `gene_program_order`). Saved as a PDF under `plot_output_path`.
dataset_to_show = 'hESC_TF_screen'
genes_to_show = [
    'E2F1', 'E2F4', 
    'FOXM1',
    'MYC', 
    'KLF14', 'KLF9', 'CBX8', 
    'SREBF2', 
    'PLXNB3', 'WNT3A', 'NR2C2', 
    'SPI1', 'IRF1', 'SPIB',
    'MYT1', 'ETV2', 'MYOD1', 'MYF5', 'FOXJ1', 
]
adata_to_show = adata_pert[adata_pert.obs['condition'] == dataset_to_show]
adata_to_show = adata_to_show[adata_to_show.obs['perturbed_gene_name'].isin(genes_to_show)].copy()

# Keep only the requested genes that are actually present in this dataset;
# selecting a missing label with .loc below would raise a KeyError.
genes_to_show = [g for g in genes_to_show if g in adata_to_show.obs['perturbed_gene_name'].values]

gene_program_show_df = pd.DataFrame(
    index=adata_to_show.obs['perturbed_gene_name'],
    columns=adata_pert.uns['gene_program_names'],
    data=adata_to_show.obsm['gene_program']
)

fig, ax = plt.subplots(figsize=(25, 6))
# The plotted values are signed program activities clipped to [-4, 4], so the
# colour bar is labelled accordingly rather than 'RMS'.
sns.heatmap(gene_program_show_df.loc[genes_to_show, gene_program_order], cmap='seismic', ax=ax, center=0, vmax=4, vmin=-4,
            cbar_kws={'label': 'Program Activity'})
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')

ax.set_xlabel('Gene Expression Programs')
ax.set_ylabel('Perturbed Genes')
ax.set_title(dataset_to_show)
plt.savefig(f'{plot_output_path}/{dataset_to_show}_selected_genes_gene_program_activity.pdf')

In [None]:
# Per-perturbation shift metrics for one dataset:
#   top_program_shift  - largest activity over all gene programs
#   pluripotency_shift - activity of the 'pluripotency' program
#   shift_diff         - top_program_shift minus |pluripotency_shift|
dataset_to_show = 'hESC_perturb_seq'
adata_to_show = adata_pert[adata_pert.obs['condition'] == dataset_to_show].copy()

program_activity = adata_to_show.obsm['gene_program']
pluripotency_column = adata_to_show.uns['gene_program_names'].index('pluripotency')

adata_to_show.obs['top_program_shift'] = program_activity.max(axis=1)
adata_to_show.obs['pluripotency_shift'] = program_activity[:, pluripotency_column]
adata_to_show.obs['shift_diff'] = (
    adata_to_show.obs['top_program_shift'] - adata_to_show.obs['pluripotency_shift'].abs()
)

In [None]:
# Names of the 10 perturbed genes with the lowest pluripotency shift, lowest first.
adata_to_show.obs.sort_values('pluripotency_shift', ascending=True)['perturbed_gene_name'].head(10).values

In [None]:
# Scatter of pluripotency shift (x) against top program shift (y) for every
# perturbation in `adata_to_show`; selected genes are drawn in red and labelled.
# Relies on `adata_to_show` set by the preceding cells.
genes_to_highlight_red = [
    'DCTN5', 'ETF1', 'RAD18', 'DSEL', 'GRK4', 'POU5F1', 'EIF3B',
    'FBLN5', 'SC5D', 'KIAA0753',
]

pluripotency_values = adata_to_show.obs['pluripotency_shift']
top_program_values = adata_to_show.obs['top_program_shift']
perturbed_names = adata_to_show.obs['perturbed_gene_name']

fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(pluripotency_values, top_program_values, s=5, edgecolor='none', c='grey')

# Row positions of the perturbations whose gene is on the highlight list.
indices_to_highlight_red = [
    row for row, gene in enumerate(perturbed_names) if gene in genes_to_highlight_red
]
for row in indices_to_highlight_red:
    x_value = pluripotency_values.iloc[row]
    y_value = top_program_values.iloc[row]
    # Only genes whose top program shift exceeds 1 are marked and labelled.
    if y_value > 1:
        ax.text(x_value, y_value, perturbed_names.iloc[row], fontsize=10)
        ax.scatter(x_value, y_value, s=20, edgecolor='none', c='red')

# Reference lines through the origin.
ax.axhline(0, color='lightgray', lw=1, zorder=-1)
ax.axvline(0, color='lightgray', lw=1, zorder=-1)
ax.set_xlabel('Pluripotency Program Shift')
ax.set_ylabel('Top Program Shift')