In [None]:
import json
import os

import anndata
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.spatial.distance
import scipy.stats
import seaborn as sns
from matplotlib.colors import to_rgba
from tqdm import tqdm

from scmg.preprocessing.data_standardization import GeneNameMapper

# Shared mapper instance, reused further down via `gene_name_mapper.map_gene_names(...)`.
# NOTE(review): the ('human', 'human', 'id', 'name') arguments used there suggest it
# translates human gene IDs into gene names — confirm against GeneNameMapper.
gene_name_mapper = GeneNameMapper()

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

# Global matplotlib / scanpy styling applied to every figure in this notebook.
plt.rcParams["figure.autolayout"] = False
# Font type 42 stores text in PDFs as TrueType, so labels stay editable in vector editors.
matplotlib.rc('pdf', fonttype=42)
plt.rcParams['font.family'] = 'FreeSans'
sc.set_figure_params(vector_friendly=True, dpi_save=300)
# Deliberately set AFTER sc.set_figure_params — presumably because scanpy's defaults
# turn the axes grid on (NOTE(review): confirm); keep this line last.
plt.rcParams['axes.grid'] = False

In [None]:
# All figures saved by this notebook are written into this folder.
plot_output_path = 'hesc_pseudobulk_plots'
if not os.path.isdir(plot_output_path):
    os.makedirs(plot_output_path, exist_ok=True)

In [None]:
# Housekeeping score per gene, turned into a {human gene name: score} lookup that the
# heatmap annotation strips query with `.get(gene, 0)`.
housekeeping_score_df = pd.read_parquet(
    '/GPUData_xingjie/Softwares/SCMG_dev/tests/manifold_generator/global_marker_genes/housekeeping_score.parquet'
)
housekeeping_score_map = {
    gene_name: score
    for gene_name, score in zip(
        housekeeping_score_df['human_gene_name'],
        housekeeping_score_df['housekeeping_score'],
    )
}

In [None]:
# Precomputed per-perturbation similarity (column `hesc_AND_k562_gwps_cos_sim`), exposed
# as a {perturbed gene name: similarity} lookup.
# Note: the name `pert_sim_df` is re-bound further down to a similarity matrix computed
# from `pert_df`; only `pert_sim_map` carries the values loaded here forward.
pert_sim_df = pd.read_parquet(
    'conservation_analysis/perturbation_similarity.parquet'
)
pert_sim_map = {
    gene_name: similarity
    for gene_name, similarity in zip(
        pert_sim_df['perturbed_gene_name'],
        pert_sim_df['hesc_AND_k562_gwps_cos_sim'],
    )
}

In [None]:
# Precomputed per-downstream-gene similarity (column `cos_sim`), exposed as a
# {gene name: similarity} lookup for the column annotation strip.
dg_sim_df = pd.read_parquet(
    'conservation_analysis/downstream_gene_similarity_hesc_k562.parquet'
)
dg_sim_map = {
    gene_name: similarity
    for gene_name, similarity in zip(dg_sim_df['gene_name'], dg_sim_df['cos_sim'])
}

In [None]:
# Pseudo-bulk perturbation matrix analysed throughout this notebook.
# Alternative input that was used at some point:
#   '/GPUData_xingjie/SCMG/perturbation_data/ReplogleWeissman2022_K562_gwps.h5ad'
pseudo_bulk_path = '/GPUData_xingjie/SCMG/hESC_perturb_seq/pseudo_bulk.h5ad'
adata = sc.read_h5ad(pseudo_bulk_path)

adata

In [None]:
# Cluster assignments for perturbed genes (heatmap rows) and downstream genes (columns).
pert_cluster_df = pd.read_csv('clustering/perturbed_gene_clusters_hESC.csv', index_col=0)
dg_cluster_df = pd.read_csv('clustering/downstream_gene_clusters_hESC.csv', index_col=0)

# Keep only labels present in both the cluster tables and the expression matrix.
# The shared labels are collected in cluster-table order (first occurrence wins) instead
# of via a set intersection: the iteration order of a set of strings changes between
# kernel sessions, which would reshuffle rows/columns — and therefore the within-cluster
# order of every heatmap below — on each restart.
obs_labels = set(adata.obs.index)
var_labels = set(adata.var.index)
common_perts = [label for label in dict.fromkeys(pert_cluster_df.index) if label in obs_labels]
common_dgs = [label for label in dict.fromkeys(dg_cluster_df.index) if label in var_labels]

# Restrict all three objects to the shared labels, in the same order.
pert_cluster_df = pert_cluster_df.loc[common_perts].copy()
dg_cluster_df = dg_cluster_df.loc[common_dgs].copy()
adata = adata[common_perts, common_dgs].copy()

In [None]:
# Expression-shift table labelled by gene names: rows are perturbed genes, columns are
# downstream genes.
pert_df = pd.DataFrame(
    adata.X,
    index=adata.obs['perturbed_gene_name'].tolist(),
    columns=adata.var['gene_name'].tolist(),
)

# Pairwise cosine similarity between perturbation profiles (rows of `pert_df`).
# This re-binds `pert_sim_df`, which earlier held the table loaded from parquet.
pert_vectors = pert_df.values
pert_labels = pert_df.index.tolist()
pert_sim_df = pd.DataFrame(
    1 - scipy.spatial.distance.cdist(pert_vectors, pert_vectors, metric='cosine'),
    index=pert_labels,
    columns=pert_labels,
)

# Pairwise cosine similarity between downstream-gene profiles (columns of `pert_df`).
dg_vectors = pert_df.values.T
dg_labels = pert_df.columns.tolist()
downstream_gene_sim_df = pd.DataFrame(
    1 - scipy.spatial.distance.cdist(dg_vectors, dg_vectors, metric='cosine'),
    index=dg_labels,
    columns=dg_labels,
)

In [None]:
# Second perturbation dataset that the hESC shifts are compared against.
adata_other_dataset = sc.read_h5ad('/GPUData_xingjie/SCMG/perturbation_data/ReplogleWeissman2022_K562_gwps.h5ad')
# Alternative comparison dataset:
#   '/GPUData_xingjie/SCMG/perturbation_data/ReplogleWeissman2022_rpe1.h5ad'

other_dataset_pert_df = pd.DataFrame(
    data=adata_other_dataset.X,
    index=list(adata_other_dataset.obs['perturbed_gene_name']),
    columns=list(adata_other_dataset.var['gene_name'])
)

# Align the other dataset to the exact row/column labels of `pert_df` in a single
# vectorised step (replaces a per-column assignment loop over every downstream gene).
# Perturbations or genes that the other dataset lacks come out as NaN.
selected_other_dataset_pert_df = other_dataset_pert_df.reindex(
    index=pert_df.index,
    columns=pert_df.columns,
)

# Conservation score = element-wise product of the two shifts: positive where both
# datasets move in the same direction, negative where they disagree, NaN where the
# other dataset has no measurement.
conservation_df = (selected_other_dataset_pert_df * pert_df)

In [None]:
def _order_genes_by_cluster(cluster_df, clusters, gene_column):
    """Lay genes out cluster by cluster, following the given cluster order.

    Parameters
    ----------
    cluster_df : DataFrame with a 'leiden' column (cluster ID), a 'leiden_color' column
        (one colour per gene) and `gene_column` (gene names).
    clusters : cluster IDs in the order they should appear along the heatmap axis.
        Clusters not listed are left out of the layout.
    gene_column : name of the column holding the gene names.

    Returns
    -------
    (gene_order, colors, centers) where `gene_order` is the concatenated gene names,
    `colors` holds one RGBA tuple per gene in the same order, and `centers` holds, per
    listed cluster, the axis position at the middle of that cluster (used for tick labels).

    Raises
    ------
    ValueError if a cluster ID is listed more than once, since that would duplicate
    rows/columns in every heatmap built from the result.
    """
    if len(set(clusters)) != len(clusters):
        raise ValueError('cluster IDs must be unique in the display order')

    gene_order, colors, centers = [], [], []
    for cluster in clusters:
        members = cluster_df[cluster_df['leiden'] == cluster]
        genes = list(members[gene_column])

        colors.extend(to_rgba(color) for color in members['leiden_color'])
        # Offset of this cluster's first gene plus half its size.
        centers.append(len(gene_order) + (len(genes) // 2))
        gene_order.extend(genes)
    return gene_order, colors, centers


# Perturbations: hand-curated display order of the perturbation clusters.
pert_clusters = [
    5, 11,
    7,
    28, 29, 30,
    6, 22,
    23,
    21,

    9,

    31,
    8,

    18,
    20,
    32,

    17,

    3,
    13,

    24,
    10,
    1,

    12,
    25,
    15,
    33,
    14,

    2,
    26,
    27,
    4,
    16, 19,
    34,
    0,
]
pert_gene_order, pert_cluster_colors, pert_cluster_centers = _order_genes_by_cluster(
    pert_cluster_df, pert_clusters, 'perturbed_gene_name')


# Downstream genes: hand-curated display order of the downstream-gene clusters.
dg_clusters = [
    6,
    23,
    3,
    10,
    20,
    7,
    17,
    12,
    21,
    13,
    19, 18, 4,
    9,
    16,
    5,
    0,
    1, 2,
    8, 22, 14,
    15,
    11,
]
dg_gene_order, dg_cluster_colors, dg_cluster_centers = _order_genes_by_cluster(
    dg_cluster_df, dg_clusters, 'gene_name')

In [None]:
# Expression-shift heatmap (perturbed genes x downstream genes, both in cluster order)
# with annotation strips. Grid layout: the heatmap occupies the top-right cell; the three
# narrow columns to its left and the three thin rows below it hold per-gene annotations.
fig = plt.figure(figsize=(20, 20))
gs = fig.add_gridspec(4, 4, 
                      width_ratios=[0.05, 0.05, 0.05, 1], height_ratios=[1, 0.05, 0.05, 0.05], hspace=0.01, wspace=0.01)

# Main panel; colours saturate at +/-0.5. Rasterized so the saved PDF stays small.
heatmap_ax = fig.add_subplot(gs[0, 3])
sns.heatmap(pert_df.loc[pert_gene_order, dg_gene_order],
            center=0, cmap='seismic', vmax=0.5, vmin=-0.5,
            ax=heatmap_ax, cbar=False, rasterized=True)
heatmap_ax.set_xticks([])
heatmap_ax.set_yticks([])

# Housekeeping score strips (one cell per gene); genes without a score are drawn as 0.
housekeeping_score_row_ax = fig.add_subplot(gs[0, 2])
sns.heatmap([[housekeeping_score_map.get(g, 0)] for g in pert_gene_order], vmin=0, vmax=1, cmap='plasma', cbar=False, 
            rasterized=True, ax=housekeeping_score_row_ax)
housekeeping_score_row_ax.axis('off')

housekeeping_score_col_ax = fig.add_subplot(gs[1, 3])
sns.heatmap([[housekeeping_score_map.get(g, 0) for g in dg_gene_order]], vmin=0, vmax=1, cmap='plasma', cbar=False, 
            rasterized=True, ax=housekeeping_score_col_ax)
housekeeping_score_col_ax.axis('off')

# Cross-dataset similarity strips (values from pert_sim_map / dg_sim_map); genes without
# a value are drawn as 0, i.e. the neutral colour of the diverging colormap.
conservation_row_ax = fig.add_subplot(gs[0, 1])
sns.heatmap([[pert_sim_map.get(g, 0)] for g in pert_gene_order], center=0, cmap='PuOr_r', vmax=1, vmin=-1, 
            rasterized=True, ax=conservation_row_ax, cbar=False)
conservation_row_ax.axis('off')

conservation_col_ax = fig.add_subplot(gs[2, 3])
sns.heatmap([[dg_sim_map.get(g, 0) for g in dg_gene_order]], center=0, cmap='PuOr_r', vmax=1, vmin=-1, 
            rasterized=True, ax=conservation_col_ax, cbar=False)
conservation_col_ax.axis('off')


# Row color strip: one RGBA pixel per perturbed gene, tick labels are the cluster IDs
# placed at each cluster's centre. The y axis is shared with the main heatmap.
# NOTE(review): imshow centres pixels on integer positions; sharing the axis with the
# heatmap may shift the two by half a cell — confirm if exact alignment matters.
row_colors_ax = fig.add_subplot(gs[0, 0], sharey=heatmap_ax)
row_colors_array = np.array(pert_cluster_colors)[:, np.newaxis, :]
row_colors_ax.imshow(row_colors_array, aspect="auto")
row_colors_ax.set_xticks([])
row_colors_ax.set_yticks(pert_cluster_centers, pert_clusters)

# Column color strip: same idea for downstream genes, sharing the heatmap's x axis.
col_colors_ax = fig.add_subplot(gs[3, 3], sharex=heatmap_ax)
col_colors_array = np.array(dg_cluster_colors)[np.newaxis, :, :]
col_colors_ax.imshow(col_colors_array, aspect="auto")
col_colors_ax.set_xticks(dg_cluster_centers, dg_clusters)
col_colors_ax.set_yticks([])

fig.savefig(f'{plot_output_path}/hesc_pseudobulk_exp_shift_heatmap.pdf', dpi=300)

In [None]:
# Conservation-score heatmap: same layout and annotation strips as the expression-shift
# figure, but the main panel shows `conservation_df` (element-wise product of the shifts
# in the two datasets) instead of the raw shifts.
fig = plt.figure(figsize=(20, 20))
gs = fig.add_gridspec(4, 4, 
                      width_ratios=[0.05, 0.05, 0.05, 1], height_ratios=[1, 0.05, 0.05, 0.05], hspace=0.01, wspace=0.01)

# Main panel; colours saturate at +/-0.03. Rasterized so the saved PDF stays small.
heatmap_ax = fig.add_subplot(gs[0, 3])
sns.heatmap(conservation_df.loc[pert_gene_order, dg_gene_order],
            center=0, cmap='coolwarm', vmax=0.03, vmin=-0.03,
            rasterized=True, ax=heatmap_ax, cbar=False)
heatmap_ax.set_xticks([])
heatmap_ax.set_yticks([])

# Housekeeping score strips (one cell per gene); genes without a score are drawn as 0.
housekeeping_score_row_ax = fig.add_subplot(gs[0, 2])
sns.heatmap([[housekeeping_score_map.get(g, 0)] for g in pert_gene_order], vmin=0, vmax=1, cmap='plasma', cbar=False, 
            rasterized=True, ax=housekeeping_score_row_ax)
housekeeping_score_row_ax.axis('off')

housekeeping_score_col_ax = fig.add_subplot(gs[1, 3])
sns.heatmap([[housekeeping_score_map.get(g, 0) for g in dg_gene_order]], vmin=0, vmax=1, cmap='plasma', cbar=False, 
            rasterized=True, ax=housekeeping_score_col_ax)
housekeeping_score_col_ax.axis('off')

# Cross-dataset similarity strips (values from pert_sim_map / dg_sim_map); genes without
# a value are drawn as 0, i.e. the neutral colour of the diverging colormap.
conservation_row_ax = fig.add_subplot(gs[0, 1])
sns.heatmap([[pert_sim_map.get(g, 0)] for g in pert_gene_order], center=0, cmap='PuOr_r', vmax=1, vmin=-1, 
            rasterized=True, ax=conservation_row_ax, cbar=False)
conservation_row_ax.axis('off')

conservation_col_ax = fig.add_subplot(gs[2, 3])
sns.heatmap([[dg_sim_map.get(g, 0) for g in dg_gene_order]], center=0, cmap='PuOr_r', vmax=1, vmin=-1, 
            rasterized=True, ax=conservation_col_ax, cbar=False)
conservation_col_ax.axis('off')


# Row color strip: one RGBA pixel per perturbed gene, tick labels are the cluster IDs
# placed at each cluster's centre. The y axis is shared with the main heatmap.
row_colors_ax = fig.add_subplot(gs[0, 0], sharey=heatmap_ax)
row_colors_array = np.array(pert_cluster_colors)[:, np.newaxis, :]
row_colors_ax.imshow(row_colors_array, aspect="auto")
row_colors_ax.set_xticks([])
row_colors_ax.set_yticks(pert_cluster_centers, pert_clusters)

# Column color strip: same idea for downstream genes, sharing the heatmap's x axis.
col_colors_ax = fig.add_subplot(gs[3, 3], sharex=heatmap_ax)
col_colors_array = np.array(dg_cluster_colors)[np.newaxis, :, :]
col_colors_ax.imshow(col_colors_array, aspect="auto")
col_colors_ax.set_xticks(dg_cluster_centers, dg_clusters)
col_colors_ax.set_yticks([])

fig.savefig(f'{plot_output_path}/hesc_pseudobulk_conservation_score_heatmap.pdf', dpi=300)

In [None]:
# Stand-alone colour bars matching the colour scales of the heatmap figures.
# Each entry: (vmin, vmax, colormap, label, output file).
colorbar_specs = [
    (-0.5, 0.5, 'seismic', 'gene expression shift',
     f'{plot_output_path}/hesc_gene_exp_shift_colorbar.pdf'),
    (0, 1, 'plasma', 'house keeping score',
     f'{plot_output_path}/hesc_house_keeping_score_colorbar.pdf'),
    (-1, 1, 'PuOr_r', 'perturbation vector similarity',
     f'{plot_output_path}/hesc_pert_vec_sim_colorbar.pdf'),
    (-0.03, 0.03, 'coolwarm', 'conservation score',
     f'{plot_output_path}/hesc_conservation_score_colorbar.pdf'),
]

for lower, upper, cmap_name, bar_label, out_file in colorbar_specs:
    fig, ax = plt.subplots(figsize=(4, 4), dpi=300)
    norm = matplotlib.colors.Normalize(vmin=lower, vmax=upper)
    cbar = ax.figure.colorbar(
        matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap_name),
        ax=ax, orientation='vertical', label=bar_label)
    # Only the colour bar itself is wanted; hide the empty host axes.
    ax.axis('off')
    fig.savefig(out_file, dpi=300)

In [None]:
# Perturbation x perturbation cosine similarity, both axes in cluster order, framed by
# cluster colour strips on the left and at the bottom.
ordered_pert_sim = pert_sim_df.loc[pert_gene_order, pert_gene_order]
pert_strip_rgba = np.array(pert_cluster_colors)  # one RGBA row per perturbed gene

fig = plt.figure(figsize=(20, 20))
gs = fig.add_gridspec(
    nrows=2, ncols=2,
    width_ratios=[0.05, 1], height_ratios=[1, 0.05],
    hspace=0.01, wspace=0.01,
)

heatmap_ax = fig.add_subplot(gs[0, 1])
sns.heatmap(ordered_pert_sim, vmin=-1, vmax=1, center=0, cmap='seismic',
            cbar=False, ax=heatmap_ax)
heatmap_ax.set_xticks([])
heatmap_ax.set_yticks([])

# Left strip: cluster colour per row, cluster IDs as tick labels.
row_colors_ax = fig.add_subplot(gs[0, 0], sharey=heatmap_ax)
row_colors_ax.imshow(pert_strip_rgba[:, np.newaxis, :], aspect="auto")
row_colors_ax.set_xticks([])
row_colors_ax.set_yticks(pert_cluster_centers, pert_clusters)

# Bottom strip: the same colours laid out along the columns.
col_colors_ax = fig.add_subplot(gs[1, 1], sharex=heatmap_ax)
col_colors_ax.imshow(pert_strip_rgba[np.newaxis, :, :], aspect="auto")
col_colors_ax.set_xticks(pert_cluster_centers, pert_clusters)
col_colors_ax.set_yticks([])


In [None]:
# Downstream-gene x downstream-gene cosine similarity (from the perturbation data), both
# axes in cluster order, framed by cluster colour strips.
ordered_dg_sim = downstream_gene_sim_df.loc[dg_gene_order, dg_gene_order]
dg_strip_rgba = np.array(dg_cluster_colors)  # one RGBA row per downstream gene

fig = plt.figure(figsize=(10, 10))
gs = fig.add_gridspec(
    nrows=2, ncols=2,
    width_ratios=[0.05, 1], height_ratios=[1, 0.05],
    hspace=0.01, wspace=0.01,
)

heatmap_ax = fig.add_subplot(gs[0, 1])
sns.heatmap(ordered_dg_sim, vmin=-1, vmax=1, center=0, cmap='seismic',
            cbar=False, ax=heatmap_ax)
heatmap_ax.set_xticks([])
heatmap_ax.set_yticks([])

# Left strip: cluster colour per row, cluster IDs as tick labels.
row_colors_ax = fig.add_subplot(gs[0, 0], sharey=heatmap_ax)
row_colors_ax.imshow(dg_strip_rgba[:, np.newaxis, :], aspect="auto")
row_colors_ax.set_xticks([])
row_colors_ax.set_yticks(dg_cluster_centers, dg_clusters)

# Bottom strip: the same colours laid out along the columns.
col_colors_ax = fig.add_subplot(gs[1, 1], sharex=heatmap_ax)
col_colors_ax.imshow(dg_strip_rgba[np.newaxis, :, :], aspect="auto")
col_colors_ax.set_xticks(dg_cluster_centers, dg_clusters)
col_colors_ax.set_yticks([])


In [None]:
# Gene-gene correlation matrix, loaded with gene IDs as labels.
# Alternative input that was used at some point:
#   '/GPUData_xingjie/Softwares/SCMG_dev/tests/manifold_generator/global_gene_correlation/gene_corr_df_measured.parquet'
all_gene_corr_df = pd.read_parquet(
    '/GPUData_xingjie/Softwares/SCMG_dev/tests/manifold_generator/global_gene_correlation/direct_gene_corr_df_measured.parquet'
)

# Relabel a copy of the matrix through the gene name mapper ('id' -> 'name'), leaving the
# ID-labelled original untouched. Rows are mapped first, then columns.
row_gene_names = gene_name_mapper.map_gene_names(
    all_gene_corr_df.index, 'human', 'human', 'id', 'name')
col_gene_names = gene_name_mapper.map_gene_names(
    all_gene_corr_df.columns, 'human', 'human', 'id', 'name')

named_all_gene_corr_df = all_gene_corr_df.copy()
named_all_gene_corr_df.index = row_gene_names
named_all_gene_corr_df.columns = col_gene_names

In [None]:
# Loaded gene-gene correlation restricted to the downstream genes, both axes in cluster
# order, framed by cluster colour strips.
ordered_gene_corr = named_all_gene_corr_df.loc[dg_gene_order, dg_gene_order]
dg_strip_rgba = np.array(dg_cluster_colors)  # one RGBA row per downstream gene

fig = plt.figure(figsize=(10, 10))
gs = fig.add_gridspec(
    nrows=2, ncols=2,
    width_ratios=[0.05, 1], height_ratios=[1, 0.05],
    hspace=0.01, wspace=0.01,
)

heatmap_ax = fig.add_subplot(gs[0, 1])
sns.heatmap(ordered_gene_corr, vmin=-1, vmax=1, center=0, cmap='seismic',
            cbar=False, ax=heatmap_ax)
heatmap_ax.set_xticks([])
heatmap_ax.set_yticks([])

# Left strip: cluster colour per row, cluster IDs as tick labels.
row_colors_ax = fig.add_subplot(gs[0, 0], sharey=heatmap_ax)
row_colors_ax.imshow(dg_strip_rgba[:, np.newaxis, :], aspect="auto")
row_colors_ax.set_xticks([])
row_colors_ax.set_yticks(dg_cluster_centers, dg_clusters)

# Bottom strip: the same colours laid out along the columns.
col_colors_ax = fig.add_subplot(gs[1, 1], sharex=heatmap_ax)
col_colors_ax.imshow(dg_strip_rgba[np.newaxis, :, :], aspect="auto")
col_colors_ax.set_xticks(dg_cluster_centers, dg_clusters)
col_colors_ax.set_yticks([])


In [None]:
# Hierarchically clustered view of the loaded gene-gene correlation for the genes in
# downstream-gene cluster 12.
in_cluster_12 = dg_cluster_df['leiden'].isin([12])
selected_genes = dg_cluster_df.loc[in_cluster_12, 'gene_name'].tolist()
sns.clustermap(
    named_all_gene_corr_df.loc[selected_genes, selected_genes],
    vmin=-1, vmax=1, center=0, cmap='seismic', figsize=(10, 10),
)

In [None]:
# Reference cell dataset, loaded from a file of measured counts and normalised in place:
# each cell is scaled to a total of 1e4, then log1p-transformed.
adata_ct_ref = sc.read_h5ad('../../manifold_generator/ref_cell_adata_measured_count.h5ad')
sc.pp.normalize_total(adata_ct_ref, target_sum=1e4)
sc.pp.log1p(adata_ct_ref)
# Per-gene scaling is intentionally left disabled.
#sc.pp.scale(adata_ct_ref, max_value=10)

# Per-cell sum over all genes, computed AFTER normalisation and log1p; used later as the
# denominator of the selected-gene expression fraction.
adata_ct_ref.obs['total_exp'] = adata_ct_ref.X.sum(axis=1)
adata_ct_ref

In [None]:
# Copy of the reference data whose variables are labelled by human gene name, so genes
# can be selected by name below. The original `adata_ct_ref` keeps its own var index.
human_gene_names = adata_ct_ref.var['human_gene_name']
named_adata = adata_ct_ref.copy()
named_adata.var.index = human_gene_names

In [None]:
# Where in the reference UMAP are the genes of downstream-gene cluster 17 expressed?
in_cluster_17 = dg_cluster_df['leiden'].isin([17])
selected_genes = dg_cluster_df.loc[in_cluster_17, 'gene_name'].tolist()
display(np.array(selected_genes))

# Per-cell summed expression of the selected genes, and its share of the cell's total.
is_selected_gene = named_adata.var.index.isin(selected_genes)
named_adata.obs['selected_exp'] = named_adata[:, is_selected_gene].X.sum(axis=1)
named_adata.obs['selected_exp_frac'] = named_adata.obs['selected_exp'] / named_adata.obs['total_exp']

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 4))
summed_ax, fraction_ax = axes

# Left panel: summed expression on its full colour range.
sc.pl.umap(named_adata, color=['selected_exp'], vmax=None, cmap='cool', ax=summed_ax, show=False)

# Right panel: expression fraction, colour scale capped at the 99th percentile.
vmax = np.quantile(named_adata.obs['selected_exp_frac'], 0.99)
sc.pl.umap(named_adata, color=['selected_exp_frac'], vmax=vmax, cmap='cool', ax=fraction_ax, show=False)

In [None]:
# List the perturbed genes that belong to perturbation cluster 9.
in_cluster_9 = pert_cluster_df['leiden'].isin([9])
np.array(pert_cluster_df.loc[in_cluster_9, 'perturbed_gene_name'])

In [None]:
#np.array(dg_cluster_df[dg_cluster_df['leiden'].isin([6])]['gene_name'])

In [None]:
# Zoom into one block of the big heatmaps: perturbation cluster 28 x downstream cluster 21.
pert_genes = list(pert_cluster_df[pert_cluster_df['leiden'].isin([28])]['perturbed_gene_name'])
dg_genes = list(dg_cluster_df[dg_cluster_df['leiden'].isin([21])]['gene_name'])

# Shrink fonts for these two figures only. A bare `sns.set(font_scale=0.5)` would rewrite
# the global rcParams (theme, font family, grid) and silently restyle every earlier cell
# that is re-run afterwards; the context manager restores the settings on exit.
with sns.plotting_context('notebook', font_scale=0.5):
    # Expression shift, same colour scale as the main expression-shift heatmap.
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    sns.heatmap(pert_df.loc[pert_genes, dg_genes], center=0, vmax=0.5, vmin=-0.5, cmap='seismic', ax=ax)
    plt.show()

    # Conservation score for the same block.
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    sns.heatmap(conservation_df.loc[pert_genes, dg_genes], center=0, vmax=0.03, vmin=-0.03, cmap='RdYlGn', ax=ax)
    plt.show()