In [None]:
# Reload imported modules automatically before each cell runs, so edits to the
# project code take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render inline figures in the high-resolution 'retina' format.
%config InlineBackend.figure_format='retina'

In [None]:
import os

from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import anndata
import scanpy as sc
import umap
import scipy.stats

import torch

from scmg.model.contrastive_embedding import (CellEmbedder, 
                                        decode_cell_state_embedding, embed_adata)

from scmg.model.manifold_generation import ConditionalDiffusionModel, generate_cells
from scmg.preprocessing.data_standardization import GeneNameMapper
# Single mapper instance shared by the cells below that relabel the var index
# of the generated cells before plotting individual genes.
gene_name_mapper = GeneNameMapper()


In [None]:
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

# Global matplotlib defaults shared by every figure in this notebook.
plt.rcParams.update({
    'figure.autolayout': False,
    'pdf.fonttype': 42,  # Type 42 (TrueType) fonts in exported PDFs
    'font.family': 'FreeSans',
})

# scanpy applies its own figure defaults here ...
sc.set_figure_params(vector_friendly=True, dpi_save=300)

# ... so the grid is switched off last; this value holds regardless of what
# set_figure_params changed.
plt.rcParams['axes.grid'] = False

In [None]:
# Folder (relative to the notebook) that receives the PDF figures saved below.
output_path = 'causal_gene_plots'
# exist_ok=True makes re-running this cell a no-op when the folder already exists.
os.makedirs(output_path, exist_ok=True)

In [None]:
# Load the autoencoder model
model_ce_path = '../../contrastive_embedding/trained_embedder/'

# Use the first GPU when one is available, otherwise fall back to the CPU so
# the notebook still runs on a machine without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# map_location remaps the stored tensors onto `device`; without it torch.load
# tries to restore them onto the device they were saved from, which fails when
# that device does not exist here.
# NOTE(security): model.pt is a pickled module object; unpickling can execute
# arbitrary code, so only load checkpoints from a trusted source.
model_ce = torch.load(os.path.join(model_ce_path, 'model.pt'), map_location=device)
model_ce.load_state_dict(
    torch.load(os.path.join(model_ce_path, 'best_state_dict.pth'), map_location=device))

model_ce.to(device)
model_ce.eval()  # inference mode: the model is only used for embedding/decoding here

# Displayed as the cell output.
model_ce.dataset_id_map

In [None]:
# All active perturbation datasets live in one directory; only the file stems differ.
pert_data_dir = '/GPUData_xingjie/SCMG/perturbation_data'
pert_data_names = [
    'AdamsonWeissman2016_GSM2406681_10X010',
    'FrangiehIzar2021_RNA',
    'hESC_TF_screen',
    'JiangSatija2024_IFNB',
    'JiangSatija2024_IFNG',
    'JiangSatija2024_INS',
    'JiangSatija2024_TGFB',
    'JiangSatija2024_TNFA',
    'Joung_TFScreen_HS_2023',
    'knockTF_human',
    'knockTF_mouse',
    # 'omnipath',  (disabled)
    'PertOrg',
    'ReplogleWeissman2022_K562_essential',
    'ReplogleWeissman2022_K562_gwps',
    'ReplogleWeissman2022_rpe1',
    'TianKampmann2021_CRISPRa',
    'TianKampmann2021_CRISPRi',
]
pert_data_files = [f'{pert_data_dir}/{name}.h5ad' for name in pert_data_names]
# Disabled test input: '/GPUData_xingjie/SCMG/hESC_perturb_seq/pseudo_bulk.h5ad'

# Read every dataset and report its number of rows (observations).
adata_pert_list = []
for pert_file in pert_data_files:
    adata_single = sc.read_h5ad(pert_file)
    adata_pert_list.append(adata_single)
    print(os.path.basename(pert_file), adata_single.shape[0])

# Stack all datasets along the observation axis. anndata.concat does not carry
# the var annotations over here, so gene names are copied from the first dataset
# (pandas aligns them on the var index).
adata_pert = anndata.concat(adata_pert_list, axis=0)
adata_pert.var['gene_name'] = adata_pert_list[0].var['gene_name']

adata_pert

In [None]:
# Mask out the direct target genes: for every perturbation (row) whose
# perturbed gene is one of the measured genes (columns), zero that gene's own
# entry in X.
perturbed_genes = adata_pert.obs['perturbed_gene'].to_numpy()

# Column position of each row's perturbed gene. get_indexer yields -1 for genes
# that are absent from var_names (it requires var_names to be unique).
target_cols = adata_pert.var_names.get_indexer(perturbed_genes)
target_rows = np.flatnonzero(target_cols >= 0)

# One vectorised write replaces the per-row Python loop with single-element
# assignments; the resulting matrix is the same.
if target_rows.size > 0:
    adata_pert.X[target_rows, target_cols[target_rows]] = 0

In [None]:
# Embed the control profile of every perturbation with the contrastive embedder.
adata_pert_ctl = adata_pert.copy()

# The 'control' layer is mapped through exp(x) - 1 before embedding (presumably
# it is stored log1p-transformed -- TODO confirm). np.expm1 computes the same
# quantity without the precision loss of exp(x) - 1 for values close to zero.
adata_pert_ctl.X = np.expm1(adata_pert_ctl.layers['control'])
embed_adata(model_ce, adata_pert_ctl, batch_size=8192)

# embed_adata leaves its result in obsm['X_ce_latent']; keep it on the main
# object under a control-specific key.
adata_pert.obsm['X_ctl_ce_latent'] = adata_pert_ctl.obsm['X_ce_latent']
adata_pert

In [None]:
# Load the diffusion model
model_d_path = '../../manifold_generator/trained_diffusion_model'

# Use the first GPU when one is available, otherwise fall back to the CPU so
# the notebook still runs on a machine without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# map_location remaps the stored tensors onto `device`; without it torch.load
# tries to restore them onto the device they were saved from.
# NOTE(security): model.pt is a pickled module object; unpickling can execute
# arbitrary code, so only load checkpoints from a trusted source.
model_d = torch.load(os.path.join(model_d_path, 'model.pt'), map_location=device)
model_d.load_state_dict(
    torch.load(os.path.join(model_d_path, 'best_state_dict.pth'), map_location=device))

model_d.to(device)
model_d.eval()  # inference mode: the model is only used for generation here

In [None]:
# Number of cells requested for each conditioning cell type.
cells_per_type = 300

# Maps cell type -> number of cells to generate. The insertion order is the
# order in which the conditioning labels are laid out downstream.
# Gastrulation types currently left out: 'Inner cell mass',
# 'Caudal neuroectoderm', 'Forebrain/midbrain', 'Definitive endoderm'.
target_cell_type_dict = dict.fromkeys(
    [
        'Epiblast',
        'Primitive streak and adjacent ectoderm',
        'Rostral neuroectoderm',
        'Nascent mesoderm',
        'Anterior primitive streak',
    ],
    cells_per_type,
)

# Alternative (disabled) hematopoiesis configuration:
#target_cell_type_dict = {
#    'hematopoietic multipotent progenitor cell' : 600,
#    'megakaryocyte-erythroid progenitor cell' : 300,
##    'common myeloid progenitor' : 300,
#    'granulocyte monocyte progenitor cell' : 300,
##    'megakaryocyte' : 300,
##    'erythroid progenitor cell' : 300,
#    'common lymphoid progenitor' : 300,
#}

In [None]:
# One conditioning label per cell to generate: every cell type is repeated as
# many times as requested in target_cell_type_dict, in dictionary order.
cond_classes = [
    cell_type_name
    for cell_type_name, n_cells in target_cell_type_dict.items()
    for _ in range(n_cells)
]

In [None]:
# Seed the global RNGs so repeated runs of the notebook produce the same cells
# (assumes generate_cells samples from torch's / numpy's global generators --
# TODO confirm against its implementation).
torch.manual_seed(0)
np.random.seed(0)

# One generated cell per conditioning label in cond_classes.
generated_cells = generate_cells(model_d, cond_classes)

In [None]:
# Decode the generated embeddings with the contrastive embedder, passing the
# same dataset label for every generated row.
decode_dataset_id = 'Qiu_Organogenesis_MM_2022:all'
n_generated = generated_cells.shape[0]
adata_generated = decode_cell_state_embedding(
    model_ce, generated_cells, [decode_dataset_id] * n_generated)

# Attach the conditioning label each row was generated with.
adata_generated.obs['cell_type'] = cond_classes
adata_generated

In [None]:
# Build the 30-nearest-neighbour graph on the 'X_ce_latent' representation,
# then compute the UMAP layout from that graph.
sc.pp.neighbors(
    adata_generated,
    n_neighbors=30,
    use_rep='X_ce_latent',
)
sc.tl.umap(adata_generated)

In [None]:
# UMAP of the generated cells coloured by conditioning cell type, saved as PDF.
umap_pdf_path = os.path.join(output_path, 'umap_generated_cells_gastrulation.pdf')

fig, ax = plt.subplots(figsize=(4, 4), dpi=300)
sc.pl.umap(
    adata_generated,
    color='cell_type',
    legend_loc='on data',
    legend_fontsize=5,
    frameon=False,
    ax=ax,
    show=False,  # keep the figure open so it can be saved below
)
fig.savefig(umap_pdf_path)

In [None]:
from scmg.model.causal_prediction import CausalGenePredictor

# Per-gene statistics table; only its 'std' column is used here.
gene_stats_path = '/GPUData_xingjie/Softwares/SCMG_dev/tests/manifold_generator/global_marker_genes/global_gene_stats.parquet'
gene_stats_df = pd.read_parquet(gene_stats_path)

# Standard deviations looked up by the var index of adata_pert, i.e. ordered
# like the columns of the perturbation matrix.
gene_std_values = gene_stats_df.loc[adata_pert.var.index.values]['std'].values

causal_gene_predictor = CausalGenePredictor(adata_pert, gene_std_values)

# Endoderm

In [None]:
# Transition of interest: Epiblast -> Anterior primitive streak.
# Other candidate targets: 'Nascent mesoderm', 'Rostral neuroectoderm',
# 'Caudal neuroectoderm'.
source_cts = ['Epiblast']
target_cts = ['Anterior primitive streak']

generated_cell_types = adata_generated.obs['cell_type']
source_cell_mask = generated_cell_types.isin(source_cts)
target_cell_mask = generated_cell_types.isin(target_cts)

# Mean expression per gene over the generated cells of each group.
source_mean_exp = adata_generated.X[source_cell_mask].mean(axis=0)
target_mean_exp = adata_generated.X[target_cell_mask].mean(axis=0)

# Score the perturbations against the target-minus-source expression shift.
expression_shift = target_mean_exp - source_mean_exp
pert_match_df = causal_gene_predictor.calc_causal_scores(expression_shift)
pert_match_df

In [None]:
# Rank by causal_score (highest first) and keep only the first, i.e.
# best-scoring, row for every perturbed gene.
pert_match_df = (
    pert_match_df
    .sort_values('causal_score', ascending=False)
    .drop_duplicates('perturbed_gene', keep='first')
)

In [None]:
# The 20 rows with the highest causal_score.
pert_match_df.sort_values('causal_score', ascending=False).head(20)

In [None]:
# Genes annotated by name on the scatter plot.
genes_to_label = ['GSC', 'EOMES', 'LHX1', 'TBXT', 'POU5F1', 'FOXA2', 'SOX2']

fig, ax = plt.subplots(figsize=(4, 4), dpi=300)

# Zero reference lines, then one dot per row of pert_match_df.
ax.axhline(c='grey', lw=0.5)
ax.axvline(c='grey', lw=0.5)
ax.scatter(pert_match_df['gene_shift_z'], pert_match_df['pert_match_score'],
           s=1, rasterized=True)

label_df = pert_match_df[pert_match_df['perturbed_gene_name'].isin(genes_to_label)]
for x_pos, y_pos, gene_label in zip(label_df['gene_shift_z'],
                                    label_df['pert_match_score'],
                                    label_df['perturbed_gene_name']):
    ax.text(x_pos, y_pos, gene_label, fontsize=8)

ax.set(
    xlabel='Gene expression shift (z-score)',
    ylabel='Perturbation match score',
    title='Anterior primitive streak',
)
fig.savefig(os.path.join(output_path, 'scatter_causal_genes_anterior_primitive_streak.pdf'))

In [None]:
# Rows of the score table for a hand-picked set of genes.
genes_of_interest = ['POU5F1', 'SOX2', 'EOMES', 'FOXA2', 'LHX1', 'TBXT']
is_gene_of_interest = pert_match_df['perturbed_gene_name'].isin(genes_of_interest)
pert_match_df[is_gene_of_interest]

In [None]:
# Plot the relationship between causal genes and perturbation shifts:
# per-gene cell-state shift (x) against the profile of one perturbation (y).
pert_id = 'knockTF_human_DataSet_01_350'

# Both vectors are divided by the predictor's per-gene standard deviations so
# that genes are on a comparable scale.
gene_stds = causal_gene_predictor.gene_stds
x = (target_mean_exp - source_mean_exp) / gene_stds
y = adata_pert[pert_id].X[0] / gene_stds

# Genes with large values in both vectors get the highest score.
match_scores = np.abs(x * y)

# Explicit figure/axes instead of the implicit pyplot state, with labelled
# axes so the figure can be read on its own.
fig, ax = plt.subplots()
ax.axhline(c='grey', lw=0.5)
ax.axvline(c='grey', lw=0.5)
ax.scatter(x, y, s=1)

# Annotate the ten genes with the highest match score.
gene_names = causal_gene_predictor.adata_pert.var['gene_name']
for i in np.argsort(-match_scores)[:10]:
    ax.text(x[i], y[i], gene_names.iloc[i], fontsize=8)

ax.set_xlabel('Target - source expression shift / gene std')
ax.set_ylabel('Perturbation shift / gene std')
ax.set_title(pert_id);

In [None]:
# Copy of the generated cells whose var index is replaced by the mapper output
# (requested conversion 'id' -> 'name'), so genes can be selected by name when
# plotting.
adata_named = adata_generated.copy()
gene_ids = adata_named.var.index.values
adata_named.var.index = gene_name_mapper.map_gene_names(
    gene_ids, 'human', 'human', 'id', 'name')

In [None]:
# Colour the UMAP by the expression of the selected gene(s).
genes_to_plot = ['FOXA2']
sc.pl.umap(
    adata_named,
    color=genes_to_plot,
    cmap='viridis',
    vmax=None,
)

# Mesoderm

In [None]:
# Transition of interest: Epiblast -> Nascent mesoderm.
# Other candidate targets: 'Anterior primitive streak', 'Rostral neuroectoderm',
# 'Caudal neuroectoderm'.
source_cts = ['Epiblast']
target_cts = ['Nascent mesoderm']

generated_cell_types = adata_generated.obs['cell_type']
source_cell_mask = generated_cell_types.isin(source_cts)
target_cell_mask = generated_cell_types.isin(target_cts)

# Mean expression per gene over the generated cells of each group.
source_mean_exp = adata_generated.X[source_cell_mask].mean(axis=0)
target_mean_exp = adata_generated.X[target_cell_mask].mean(axis=0)

# Score the perturbations against the target-minus-source expression shift.
expression_shift = target_mean_exp - source_mean_exp
pert_match_df = causal_gene_predictor.calc_causal_scores(expression_shift)
pert_match_df

In [None]:
# Rank by causal_score (highest first) and keep only the first, i.e.
# best-scoring, row for every perturbed gene.
pert_match_df = (
    pert_match_df
    .sort_values('causal_score', ascending=False)
    .drop_duplicates('perturbed_gene', keep='first')
)

In [None]:
# The 20 rows with the highest causal_score.
pert_match_df.sort_values('causal_score', ascending=False).head(20)

In [None]:
# Genes annotated by name on the scatter plot.
genes_to_label = ['EOMES', 'TBXT', 'EVX1', 'POU5F1', 'GSC', 'SNAI1', 'SOX2', 'ETV2', 'LHX1', 'OTX2']

fig, ax = plt.subplots(figsize=(4, 4), dpi=300)

# Zero reference lines, then one dot per row of pert_match_df.
ax.axhline(c='grey', lw=0.5)
ax.axvline(c='grey', lw=0.5)
ax.scatter(pert_match_df['gene_shift_z'], pert_match_df['pert_match_score'],
           s=1, rasterized=True)

label_df = pert_match_df[pert_match_df['perturbed_gene_name'].isin(genes_to_label)]
for x_pos, y_pos, gene_label in zip(label_df['gene_shift_z'],
                                    label_df['pert_match_score'],
                                    label_df['perturbed_gene_name']):
    ax.text(x_pos, y_pos, gene_label, fontsize=8)

ax.set(
    xlabel='Gene expression shift (z-score)',
    ylabel='Perturbation match score',
    title='Nascent mesoderm',
)
fig.savefig(os.path.join(output_path, 'scatter_causal_genes_nascent_mesoderm.pdf'))

In [None]:
# Rows of the score table for a hand-picked set of genes.
genes_of_interest = ['POU5F1', 'SOX2', 'EOMES', 'ETV2', 'TBXT', 'EVX1', 'SNAI1']
is_gene_of_interest = pert_match_df['perturbed_gene_name'].isin(genes_of_interest)
pert_match_df[is_gene_of_interest]

In [None]:
# Plot the relationship between causal genes and perturbation shifts:
# per-gene cell-state shift (x) against the profile of one perturbation (y).
pert_id = 'PertOrg_Pertg09141'

# Both vectors are divided by the predictor's per-gene standard deviations so
# that genes are on a comparable scale.
gene_stds = causal_gene_predictor.gene_stds
x = (target_mean_exp - source_mean_exp) / gene_stds
y = adata_pert[pert_id].X[0] / gene_stds

# Genes with large values in both vectors get the highest score.
match_scores = np.abs(x * y)

# Explicit figure/axes instead of the implicit pyplot state, with labelled
# axes so the figure can be read on its own.
fig, ax = plt.subplots()
ax.axhline(c='grey', lw=0.5)
ax.axvline(c='grey', lw=0.5)
ax.scatter(x, y, s=1)

# Annotate the ten genes with the highest match score.
gene_names = causal_gene_predictor.adata_pert.var['gene_name']
for i in np.argsort(-match_scores)[:10]:
    ax.text(x[i], y[i], gene_names.iloc[i], fontsize=8)

ax.set_xlabel('Target - source expression shift / gene std')
ax.set_ylabel('Perturbation shift / gene std')
ax.set_title(pert_id);

In [None]:
# Copy of the generated cells whose var index is replaced by the mapper output
# (requested conversion 'id' -> 'name'), so genes can be selected by name when
# plotting.
adata_named = adata_generated.copy()
gene_ids = adata_named.var.index.values
adata_named.var.index = gene_name_mapper.map_gene_names(
    gene_ids, 'human', 'human', 'id', 'name')

In [None]:
# Colour the UMAP by the expression of the selected gene(s).
genes_to_plot = ['OTX2']
sc.pl.umap(
    adata_named,
    color=genes_to_plot,
    cmap='viridis',
    vmax=None,
)

# Ectoderm

In [None]:
# Transition of interest: Epiblast -> Rostral neuroectoderm.
# Other candidate targets: 'Nascent mesoderm', 'Anterior primitive streak',
# 'Caudal neuroectoderm'.
source_cts = ['Epiblast']
target_cts = ['Rostral neuroectoderm']

generated_cell_types = adata_generated.obs['cell_type']
source_cell_mask = generated_cell_types.isin(source_cts)
target_cell_mask = generated_cell_types.isin(target_cts)

# Mean expression per gene over the generated cells of each group.
source_mean_exp = adata_generated.X[source_cell_mask].mean(axis=0)
target_mean_exp = adata_generated.X[target_cell_mask].mean(axis=0)

# Score the perturbations against the target-minus-source expression shift.
expression_shift = target_mean_exp - source_mean_exp
pert_match_df = causal_gene_predictor.calc_causal_scores(expression_shift)
pert_match_df

In [None]:
# Rank by causal_score (highest first) and keep only the first, i.e.
# best-scoring, row for every perturbed gene.
pert_match_df = (
    pert_match_df
    .sort_values('causal_score', ascending=False)
    .drop_duplicates('perturbed_gene', keep='first')
)

In [None]:
# The 20 rows with the highest causal_score.
pert_match_df.sort_values('causal_score', ascending=False).head(20)

In [None]:
# Genes annotated by name on the scatter plot.
genes_to_label = ['HESX1', 'POU5F1', 'FOXH1', 'DLX5', 'NANOG', 'LMO1', 'FOXB1', 'LHX5', 'SOX2']

fig, ax = plt.subplots(figsize=(4, 4), dpi=300)

# Zero reference lines, then one dot per row of pert_match_df.
ax.axhline(c='grey', lw=0.5)
ax.axvline(c='grey', lw=0.5)
ax.scatter(pert_match_df['gene_shift_z'], pert_match_df['pert_match_score'],
           s=1, rasterized=True)

label_df = pert_match_df[pert_match_df['perturbed_gene_name'].isin(genes_to_label)]
for x_pos, y_pos, gene_label in zip(label_df['gene_shift_z'],
                                    label_df['pert_match_score'],
                                    label_df['perturbed_gene_name']):
    ax.text(x_pos, y_pos, gene_label, fontsize=8)

ax.set(
    xlabel='Gene expression shift (z-score)',
    ylabel='Perturbation match score',
    title='Rostral neuroectoderm',
)
fig.savefig(os.path.join(output_path, 'scatter_causal_genes_rostral_neuroectoderm.pdf'))

In [None]:
# Rows of the score table for a hand-picked set of genes.
genes_of_interest = ['POU5F1', 'NANOG', 'FOXH1', 'SOX2', 'DLX5']
is_gene_of_interest = pert_match_df['perturbed_gene_name'].isin(genes_of_interest)
pert_match_df[is_gene_of_interest]

In [None]:
# Plot the relationship between causal genes and perturbation shifts:
# per-gene cell-state shift (x) against the profile of one perturbation (y).
pert_id = 'hESC_TF_screen_DLX5'

# Both vectors are divided by the predictor's per-gene standard deviations so
# that genes are on a comparable scale.
gene_stds = causal_gene_predictor.gene_stds
x = (target_mean_exp - source_mean_exp) / gene_stds
y = adata_pert[pert_id].X[0] / gene_stds

# Genes with large values in both vectors get the highest score.
match_scores = np.abs(x * y)

# Explicit figure/axes instead of the implicit pyplot state, with labelled
# axes so the figure can be read on its own.
fig, ax = plt.subplots()
ax.axhline(c='grey', lw=0.5)
ax.axvline(c='grey', lw=0.5)
ax.scatter(x, y, s=1)

# Annotate the ten genes with the highest match score.
gene_names = causal_gene_predictor.adata_pert.var['gene_name']
for i in np.argsort(-match_scores)[:10]:
    ax.text(x[i], y[i], gene_names.iloc[i], fontsize=8)

ax.set_xlabel('Target - source expression shift / gene std')
ax.set_ylabel('Perturbation shift / gene std')
ax.set_title(pert_id);

In [None]:
# Copy of the generated cells whose var index is replaced by the mapper output
# (requested conversion 'id' -> 'name'), so genes can be selected by name when
# plotting.
adata_named = adata_generated.copy()
gene_ids = adata_named.var.index.values
adata_named.var.index = gene_name_mapper.map_gene_names(
    gene_ids, 'human', 'human', 'id', 'name')

In [None]:
# Colour the UMAP by the expression of the selected gene(s).
genes_to_plot = ['DLX5']
sc.pl.umap(
    adata_named,
    color=genes_to_plot,
    cmap='viridis',
    vmax=None,
)