In [None]:
# Automatically reload imported modules before each cell runs, so edits to
# library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render inline figures in the 'retina' (high-resolution) format.
%config InlineBackend.figure_format='retina'

In [None]:
import os

from tqdm import tqdm
import numpy as np
import pandas as pd
import scipy.stats
import matplotlib.pyplot as plt
import seaborn as sns

import anndata
import scanpy as sc

import torch

from scmg.preprocessing.data_standardization import GeneNameMapper, standardize_adata
from scmg.model.contrastive_embedding import CellEmbedder, embed_adata, decode_adata, score_marker_genes
from scmg.model.cell_type_search import CellTypeSearcher

# Shared mapper used by later cells to translate human gene identifiers
# between their 'name' and 'id' forms.
gene_name_mapper = GeneNameMapper()

In [None]:
# Scanpy figure defaults: vector-friendly output, saved at 300 dpi.
sc.set_figure_params(vector_friendly=True, dpi_save=300)

# Directory that receives every figure saved by this notebook; created on
# first run and reused afterwards.
plot_output_path = 'hesc_sc_analysis_plots'
os.makedirs(plot_output_path, exist_ok=True)

In [None]:
# Load the autoencoder model
model_ce_path = '../../contrastive_embedding/trained_embedder/'

# Use the first GPU when one is available; otherwise fall back to the CPU so
# the notebook still runs on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# `map_location` remaps tensors that were saved on a GPU onto `device`, which
# is what allows the checkpoint to be opened on a CPU-only machine.
# NOTE(review): 'model.pt' holds a whole pickled module, so only load it from
# a trusted source. Newer PyTorch releases may also require passing
# `weights_only=False` explicitly for this call — confirm against the
# installed version.
model = torch.load(os.path.join(model_ce_path, 'model.pt'),
                   map_location=device)
model.load_state_dict(
    torch.load(os.path.join(model_ce_path, 'best_state_dict.pth'),
               map_location=device))

model.to(device)
model.eval()

In [None]:
# Reference cells that the perturbed cells are projected onto. Later cells
# read its UMAP layout and its `cell_type`, `major_cell_type` and
# `dataset_id` annotations.
adata_ct_ref = sc.read_h5ad('../../manifold_generator/ref_cell_adata.h5ad')
adata_ct_ref

In [None]:
# Query dataset (per the file name: single-gene perturbations, hESC Perturb-seq).
# NOTE(review): this is a machine-specific absolute path — consider reading it
# from a configurable data directory so the notebook runs elsewhere.
adata = sc.read_h5ad('/GPUData_xingjie/SCMG/hESC_perturb_seq/adata_single_gene_pert.h5ad')
adata

In [None]:
# Attach the precomputed cluster labels and UMAP coordinates stored in
# 'adata_obs_l2.csv' (rows are keyed by cell barcode via the first column).
l2_obs_df = pd.read_csv('adata_obs_l2.csv', index_col=0)

# Label columns are copied as strings; pandas aligns them on the cell index.
for label_col in ('leiden_l1', 'cluster'):
    adata.obs[label_col] = l2_obs_df[label_col].astype(str)

# Coordinates are looked up explicitly in the order of `adata.obs`.
umap_xy = l2_obs_df.loc[adata.obs.index, ['umap_x', 'umap_y']]
adata.obsm['X_umap'] = umap_xy.values
adata

In [None]:
# Re-key the genes: translate the human gene names in `adata.var.index` to
# gene ids.
# NOTE(review): this overwrites the index in place, so re-running the cell
# would feed ids back into a name->id mapping — run it once per fresh `adata`.
mapped_gene_ids = gene_name_mapper.map_gene_names(
    adata.var.index, 'human', 'human', 'name', 'id')
adata.var.index = mapped_gene_ids

# Drop the genes whose mapped id is the literal 'na' (presumably the mapper's
# marker for "no match" — TODO confirm).
has_gene_id = adata.var_names != 'na'
adata = adata[:, has_gene_id]
adata

In [None]:
# Only map the targeting clusters
targeting_clusters = [
    '10_0', '11_0', '12_0',
    '14_0', '14_1', '14_2', '14_3', '14_4', '14_5',
    '15_0', '15_1', '15_10', '15_11', '15_12', '15_13', '15_14',
    '15_2', '15_3', '15_4', '15_6', '15_7', '15_8', '15_9',
    '18_0', '18_1', '18_10', '18_11', '18_12',
    '18_2', '18_3', '18_4', '18_5', '18_6', '18_7', '18_8', '18_9',
    '19_1',
    '20_0', '20_1', '20_10', '20_11',
    '20_2', '20_3', '20_4', '20_5', '20_6', '20_7', '20_8', '20_9',
    '21_1', '23_0', '23_1',
    '24_0', '24_1', '24_2',
    '25_0', '25_1', '25_2', '25_3', '25_4',
    '26_0', '27_0', '28_0', '28_1',
    #'16_0', '16_1', '16_2', '16_3', '16_4', '16_5', '16_6', '16_7', '25_3', 
]

# Keep only the cells whose cluster is listed above; `.copy()` turns the
# resulting view into an independent AnnData object.
is_targeting_cell = adata.obs['cluster'].isin(targeting_clusters)
adata = adata[is_targeting_cell].copy()
adata

In [None]:
# Embed the query cells with the loaded model, 8192 cells per batch.
# Presumably this stores the result in `adata.obsm['X_ce_latent']`, which the
# reference search below reads — TODO confirm against `embed_adata`.
embed_adata(model, adata, batch_size=8192)

In [None]:
# Searcher built from the reference cells; used below to match each query
# embedding to a reference cell. (`CellTypeSearcher` is imported in the
# notebook's import cell.)
cts = CellTypeSearcher(adata_ct_ref)

In [None]:
# Match every query cell's latent embedding to a reference cell.
cell_match_df = cts.search_ref_cell(adata.obsm['X_ce_latent'])

# Add a small random shift to the UMAP coordinates to avoid overlapping points.
# The jitter comes from a seeded generator so that re-running the notebook
# reproduces exactly the same projected coordinates (and saved figures).
JITTER_SEED = 0
JITTER_STD = 0.2   # standard deviation of the Gaussian shift
JITTER_MAX = 0.5   # shifts are clipped to [-JITTER_MAX, JITTER_MAX]
jitter_rng = np.random.default_rng(JITTER_SEED)

matched_umap = cell_match_df[['umap_x', 'umap_y']].values
jitter = np.clip(jitter_rng.normal(0, JITTER_STD, matched_umap.shape),
                 -JITTER_MAX, JITTER_MAX)
adata.obsm['X_project_umap'] = matched_umap + jitter

# Per-cell bookkeeping: distance to the match, the matched reference cell, and
# that reference cell's annotated cell type.
adata.obs['project_dist'] = cell_match_df['distance'].values
adata.obs['ref_cell'] = cell_match_df['ref_cell'].values
adata.obs['projected_cell_type'] = adata_ct_ref.obs['cell_type'].loc[
                                            cell_match_df['ref_cell']].values

In [None]:
# Reference manifold coloured by major cell type, for orientation.
sc.pl.umap(adata_ct_ref, color='major_cell_type')

# Overlay: draw the reference UMAP first (show=False so the axes can be drawn
# on again), then plot the projected query cells on the same axes, coloured by
# the distance stored in `project_dist`.
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
sc.pl.umap(adata_ct_ref, ax=ax, show=False)
sc.pl.embedding(adata, basis='X_project_umap', color='project_dist', ax=ax,
                s=0.1)

In [None]:
# Distribution of the projection distance within each cluster (outliers
# hidden). `sns` is imported in the notebook's import cell; the trailing
# semicolon suppresses the Axes repr.
fig, ax = plt.subplots(figsize=(30, 4))
sns.boxplot(data=adata.obs, x='cluster', y='project_dist', ax=ax,
            showfliers=False);

In [None]:
# Number of cells in each cluster.
adata.obs['cluster'].value_counts()

In [None]:
# Most frequent projected cell types (up to 30) among the cells of cluster 15_0.
# The subset is taken on the AnnData object (not on `.obs`) on purpose: that is
# what the counts below have always been computed from.
in_cluster_15_0 = (adata.obs['cluster'] == '15_0')
#in_cluster_15_0 = in_cluster_15_0 & (adata.obs['project_dist'] < 2)
projected_types_15_0 = adata[in_cluster_15_0].obs['projected_cell_type']
projected_types_15_0.value_counts().head(30)

In [None]:
# Contingency table: clusters (rows) x projected cell types (columns).
cross_df = pd.crosstab(index=adata.obs['cluster'],
                       columns=adata.obs['projected_cell_type'])

# Convert counts to per-cluster fractions (each row sums to 1).
cells_per_cluster = cross_df.sum(axis=1)
cross_frac_df = cross_df.div(cells_per_cluster, axis=0)

# Keep only the cell types that make up more than 10% of at least one cluster.
ct_frac_maxes = cross_frac_df.max(axis=0)
selected_cts = ct_frac_maxes.index[(ct_frac_maxes > 0.1).to_numpy()]
cross_frac_df = cross_frac_df[selected_cts]

In [None]:
# Heatmap of the per-cluster cell-type fractions. Columns (cell types) are
# hierarchically clustered; rows keep their original order
# (row_cluster=False). `sns` is imported in the notebook's import cell.
sns.clustermap(cross_frac_df, cmap='gnuplot', figsize=(15, 25), row_cluster=False)

In [None]:
# Manual annotation of every cluster id with a descriptive group name.
# Several cluster ids can share one name (e.g. 'pert translation'), while the
# differentiation clusters carry their cluster id as a suffix so each keeps a
# distinct name.
cluster_annotation_map = {
'0_0' : 'non-targeting enriched',
'10_0' : 'non-targeting like',
'11_0' : 'upregulation of lipid biosynthesis',
'12_0' : 'upregulation of stress response',
'13_0' : 'non-targeting enriched',
'14_0' : 'pert cell cycle',
'14_1' : 'pert spliceosome',
'14_2' : 'pert mRNA-3 processing',
'14_3' : 'pert mRNA transcription',
'14_4' : 'pert mRNA transcription',
'14_5' : 'pert mRNA transcription',
'15_0' : 'germ layer differentiation_15_0',
'15_1' : 'germ layer differentiation_15_1',
'15_10' : 'germ layer differentiation_15_10',
'15_11' : 'germ layer differentiation_15_11',
'15_12' : 'germ layer differentiation_15_12',
'15_13' : 'germ layer differentiation_15_13',
'15_14' : 'mesenchymal differentiation_15_14',
'15_2' : 'germ layer differentiation_15_2',
'15_3' : 'germ layer differentiation_15_3',
'15_4' : 'germ layer differentiation_15_4',
'15_5' : 'non-targeting enriched',
'15_6' : 'germ layer differentiation_15_6',
'15_7' : 'germ layer differentiation_15_7',
'15_8' : 'germ layer differentiation_15_8',
'15_9' : 'germ layer differentiation_15_9',
'16_0' : 'low UMI count',
'16_1' : 'low UMI count',
'16_2' : 'low UMI count',
'16_3' : 'low UMI count',
'16_4' : 'low UMI count',
'16_5' : 'low UMI count',
'16_6' : 'low UMI count',
'16_7' : 'pert DBR1',
'17_0' : 'non-targeting enriched',
'18_0' : 'pert translation',
'18_1' : 'pert mTOR signaling',
'18_10' : 'pert translation',
'18_11' : 'pert translation',
'18_12' : 'pert translation',
'18_2' : 'pert translation',
'18_3' : 'pert translation',
'18_4' : 'pert translation',
'18_5' : 'pert mTOR signaling',
'18_6' : 'pert translation',
'18_7' : 'pert translation',
'18_8' : 'pert translation',
'18_9' : 'pert translation',
'19_0' : 'non-targeting enriched',
'19_1' : 'non-targeting like',
'1_0' : 'non-targeting enriched',
'20_0' : 'pert mRNA transcription',
'20_1' : 'pert mRNA transcription',
'20_10' : 'pert mRNA deadenylation',
'20_11' : 'pert mRNA transcription',
'20_2' : 'pert GNB2L1',
'20_3' : 'pert mRNA deadenylation',
'20_4' : 'pert mRNA transcription',
'20_5' : 'pert mRNA transcription',
'20_6' : 'pert mRNA transcription',
'20_7' : 'pert mRNA transcription',
'20_8' : 'pert mRNA transcription',
'20_9' : 'pert mRNA transcription',
'21_0' : 'non-targeting enriched',
'21_1' : 'pert DBR1',
'22_0' : 'non-targeting enriched',
'23_0' : 'pert ubiquitin E3 ligase',
'23_1' : 'pert protein neddylation',
'24_0' : 'low mito-genes',
'24_1' : 'low mito-genes',
'24_2' : 'upregulation of stress response',
'25_0' : 'mesenchymal differentiation_25_0',
'25_1' : 'mesenchymal differentiation_25_1',
'25_2' : 'mesenchymal differentiation_25_2',
'25_3' : 'low UMI count',
'25_4' : 'mesenchymal differentiation_25_4',
'26_0' : 'pert DBR1',
'27_0' : 'pert RNA methylation',
'28_0' : 'pert DNA damage checkpoint',
'28_1' : 'pert DNA damage checkpoint',
'2_0' : 'non-targeting enriched',
'3_0' : 'non-targeting enriched',
'4_0' : 'non-targeting enriched',
'5_0' : 'non-targeting enriched',
'6_0' : 'non-targeting enriched',
'7_0' : 'non-targeting enriched',
'8_0' : 'non-targeting enriched',
'9_0' : 'non-targeting enriched',
'9_1' : 'non-targeting enriched',
}

# Fraction of each projected cell type within every annotated cluster group
# (rows are annotation names and sum to 1 before the selection below).
anno_cross_df = pd.crosstab(adata.obs['cluster'].map(cluster_annotation_map), 
                       adata.obs['projected_cell_type'])
anno_cross_frac_df = anno_cross_df.div(anno_cross_df.sum(axis=1), axis=0)

# Display order of the heatmap columns (projected cell types).
heatmap_cell_type_order = [
    'Epiblast', 'embryonic stem cell',
    'Anterior', 'Early ectoderm',
    #'Primordial germ cells', 'primordial germ cell', 'Lens epithelial cells',
    'Endoderm', 'Gut',
    'Emergent Mesoderm', 'mesodermal cell', 'Advanced Mesoderm', 'Cardiac mesoderm',
    'Vascular smooth muscle', 'Early fibroblasts', 'fibroblast',
    'chondrocyte', 'smooth muscle cell', 'embryonic fibroblast', 'uterine smooth muscle cell',
]

# Display order of the heatmap rows (annotated cluster groups).
heatmap_annotation_order = [
    'non-targeting like',
    'pert ubiquitin E3 ligase',
    'upregulation of lipid biosynthesis',
    'upregulation of stress response',
    'low mito-genes',
    'pert translation',
    'pert mTOR signaling',
    'pert cell cycle',
    'pert spliceosome',
    'pert mRNA-3 processing',
    'pert mRNA transcription',
    'pert GNB2L1',
    'pert mRNA deadenylation',
    'pert DNA damage checkpoint',
    'pert RNA methylation',
    'pert DBR1',
    'germ layer differentiation_15_0',
    'germ layer differentiation_15_12',
    'germ layer differentiation_15_9',
    'germ layer differentiation_15_10',
    'germ layer differentiation_15_11',
    'germ layer differentiation_15_7',
    'germ layer differentiation_15_1',
    'germ layer differentiation_15_2',
    'germ layer differentiation_15_8',
    'germ layer differentiation_15_13',
    'germ layer differentiation_15_3',
    'germ layer differentiation_15_6',
    'germ layer differentiation_15_4',
    'mesenchymal differentiation_15_14',
    'mesenchymal differentiation_25_1',
    'mesenchymal differentiation_25_0',
    'mesenchymal differentiation_25_2',
    'mesenchymal differentiation_25_4',
]

# Keep only the labels that actually occur in the table. Rows now get the same
# treatment as columns, so a group without any cells no longer raises a
# KeyError; anything omitted is reported instead of being dropped silently.
shown_cell_types = [c for c in heatmap_cell_type_order
                    if c in anno_cross_frac_df.columns]
shown_annotations = [a for a in heatmap_annotation_order
                     if a in anno_cross_frac_df.index]
missing_annotations = [a for a in heatmap_annotation_order
                       if a not in anno_cross_frac_df.index]
if missing_annotations:
    print(f'Annotations without cells (omitted from the heatmap): {missing_annotations}')

anno_cross_frac_df = anno_cross_frac_df.loc[shown_annotations, shown_cell_types]

fig, ax = plt.subplots(figsize=(15, 10))
g = sns.heatmap(anno_cross_frac_df, cmap='viridis', cbar_kws={'label': 'Fraction of cells'}, ax=ax, vmax=0.6)
g.set_xticklabels(g.get_xticklabels(), rotation=-30)
# bbox_inches='tight' keeps the long row labels and rotated column labels
# inside the saved page instead of clipping them at the figure edge.
fig.savefig(os.path.join(plot_output_path, 'project_pert_clusters_to_global_manifold_heatmap.pdf'),
            bbox_inches='tight')

In [None]:
# Build a lookup from annotated cluster name to plot colour using the
# per-cluster colours stored in 'cluster_colors.csv'. When several cluster ids
# share one annotation name, the colour of the last such row wins.
cluster_color_df = pd.read_csv('cluster_colors.csv')

cluster_name_color_map = {}
for cluster_id, cluster_color in zip(cluster_color_df['cluster'], cluster_color_df['color']):
    cluster_name_color_map[cluster_annotation_map[cluster_id]] = cluster_color
cluster_name_color_map

In [None]:
# Reference manifold coloured by major cell type, for orientation.
sc.pl.umap(adata_ct_ref, color='major_cell_type')

# Background layer: the reference UMAP, drawn without showing so that the
# query cells can be added to the same axes.
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
sc.pl.umap(adata_ct_ref, ax=ax, show=False)

# Restrict the overlay to the differentiation-related cluster groups.
adata.obs['cluster_name'] = adata.obs['cluster'].map(cluster_annotation_map)
adata_differentiation = adata[adata.obs['cluster_name'].isin([
    'germ layer differentiation_15_0',
    'germ layer differentiation_15_12',
    'germ layer differentiation_15_9',
    'germ layer differentiation_15_10',
    'germ layer differentiation_15_11',
    'germ layer differentiation_15_7',
    'germ layer differentiation_15_1',
    'germ layer differentiation_15_2',
    'germ layer differentiation_15_8',
    'germ layer differentiation_15_13',
    'germ layer differentiation_15_3',
    'germ layer differentiation_15_6',
    'germ layer differentiation_15_4',
    'mesenchymal differentiation_15_14',
    'mesenchymal differentiation_25_1',
    'mesenchymal differentiation_25_0',
    'mesenchymal differentiation_25_2',
    'mesenchymal differentiation_25_4',])].copy()

# Scanpy reads category colours from `.uns['<column>_colors']`, listed in the
# order of the column's categories.
adata_differentiation.obs['cluster_name'] = adata_differentiation.obs['cluster_name'].astype('category')
adata_differentiation.uns['cluster_name_colors'] = [cluster_name_color_map[c] for c in adata_differentiation.obs['cluster_name'].cat.categories]

# show=False: finish drawing first and save before the figure is displayed.
sc.pl.embedding(adata_differentiation, basis='X_project_umap', color='cluster_name', ax=ax,
                s=1, show=False)

# bbox_inches='tight' extends the saved page to include the categorical legend,
# which scanpy places outside the axes in the right margin.
fig.savefig(os.path.join(plot_output_path, 'project_pert_clusters_to_global_manifold_umap.pdf'),
            bbox_inches='tight')

In [None]:
# Coarse grouping of a few clusters, summarised as a pie chart of cell counts.
# NOTE(review): '0_0' is not part of `targeting_clusters`, so once `adata` has
# been filtered to those clusters the 'non_target_like' group has no cells —
# confirm this is intended.
cluster_groups = {
    'non_target_like' : ['0_0', ],
    'development' : ['15_0'],
    'orthogonal' : ['18_0'], 
}

# Invert the grouping into a cluster id -> group name lookup.
cluster_to_cluster_group_map = {}
for group_name, member_clusters in cluster_groups.items():
    for member_cluster in member_clusters:
        cluster_to_cluster_group_map[member_cluster] = group_name

# Cells of clusters that belong to no group get NaN and are not counted.
adata.obs['cluster_group'] = adata.obs['cluster'].map(cluster_to_cluster_group_map)
cluster_group_counts = adata.obs['cluster_group'].value_counts()

fig, ax = plt.subplots(figsize=(3, 3))
ax.pie(cluster_group_counts, labels=cluster_group_counts.index, autopct='%1.1f%%')
plt.show()

In [None]:
# Expose the reference matrix under the 'X_ce_latent' key and decode it with
# the model, passing each reference cell's `dataset_id`.
# NOTE(review): presumably `adata_ct_ref.X` already holds latent embeddings and
# `decode_adata` reads `obsm['X_ce_latent']` — confirm against the helper.
adata_ct_ref.obsm['X_ce_latent'] = adata_ct_ref.X
adata_ref_decoded = decode_adata(model, adata_ct_ref, adata_ct_ref.obs['dataset_id'])
# Annotate the decoded genes (indexed by id) with their human gene names.
adata_ref_decoded.var['gene_name'] = gene_name_mapper.map_gene_names(
    adata_ref_decoded.var.index, 'human', 'human', 'id', 'name')

# Copy of the decoded reference indexed by gene name: genes mapped to the
# literal 'na' are dropped and the remaining names are made unique.
adata_ref_named = adata_ref_decoded.copy()
adata_ref_named.var.index = adata_ref_decoded.var['gene_name']
adata_ref_named = adata_ref_named[:, adata_ref_named.var.index != 'na'].copy()
adata_ref_named.var_names_make_unique()

# Same re-indexing by gene name for the query data. 'na' entries are not
# removed here; they drop out in the intersection below because the reference
# no longer contains 'na'.
adata_named = adata.copy()
adata_named.var.index = gene_name_mapper.map_gene_names(
    adata_named.var.index, 'human', 'human', 'id', 'name')
adata_named.var_names_make_unique()


# Restrict both objects to the shared genes, in the same (sorted) order.
common_genes = np.intersect1d(adata_named.var.index, adata_ref_named.var.index)
adata_named = adata_named[:, common_genes].copy()
adata_ref_named = adata_ref_named[:, common_genes].copy()

# Only the query is normalised (to 1e4 total counts per cell) and
# log1p-transformed here.
# NOTE(review): presumably the decoded reference is already on a comparable
# scale — confirm.
sc.pp.normalize_total(adata_named, target_sum=1e4)
sc.pp.log1p(adata_named)

In [None]:
# Per-gene mean and standard deviation of the decoded reference expression,
# computed across all reference cells (axis 0) and stored on `.var` for the
# z-scoring below.
adata_ref_decoded.var['mean'] = adata_ref_decoded.X.mean(axis=0)
adata_ref_decoded.var['std'] = adata_ref_decoded.X.std(axis=0)

In [None]:
# Compare the mean expression profile of one (cluster, projected cell type)
# group of query cells with the decoded reference cells of that cell type.
projected_ct = 'Primordial germ cells'
selected_cluster = '16_2'


def _zscore_against_reference(mean_expression):
    """Z-score a per-gene mean profile against the decoded reference.

    Subtracts the per-gene reference mean and divides by the per-gene
    reference standard deviation, floored at 0.1 so that near-constant genes
    do not blow up. Reads `adata_ref_decoded.var['mean']` and `['std']`.

    NOTE(review): `mean_expression` is combined with the reference statistics
    by position, so it must list genes in the same order as
    `adata_ref_decoded.var` — presumably guaranteed by `standardize_adata`;
    confirm.
    """
    centered = mean_expression - adata_ref_decoded.var['mean']
    return centered / np.maximum(adata_ref_decoded.var['std'], 0.1)


# Fail early with a clear message when the selection is empty. This happens,
# for example, when `selected_cluster` is not part of `targeting_clusters` and
# its cells were therefore removed from `adata`; the means below would
# otherwise silently become NaN.
query_mask = ((adata.obs['projected_cell_type'] == projected_ct)
              & (adata.obs['cluster'] == selected_cluster))
if not query_mask.any():
    raise ValueError(
        f'No query cells with projected cell type {projected_ct!r} in cluster '
        f'{selected_cluster!r}; check that the cluster is still present in `adata`.')

adata_selecte_query = standardize_adata(adata[query_mask].copy())

sc.pp.normalize_total(adata_selecte_query, target_sum=1e4)
sc.pp.log1p(adata_selecte_query)

reference_mask = adata_ref_decoded.obs['cell_type'] == projected_ct
if not reference_mask.any():
    raise ValueError(f'No reference cells annotated as {projected_ct!r}.')

adata_selected_ref = adata_ref_decoded[reference_mask].copy()
#adata_selected_ref = adata_ref_decoded[
#    list(adata_selecte_query.obs['ref_cell'].values)
#].copy()

x_query = _zscore_against_reference(adata_selecte_query.X.mean(axis=0))
x_ref = _zscore_against_reference(adata_selected_ref.X.mean(axis=0))

In [None]:
#adata_selecte_query1 = standardize_adata(adata[
#    (adata.obs['cluster'] == '3_0')
#].copy())
#
#sc.pp.normalize_total(adata_selecte_query, target_sum=1e4)
#sc.pp.log1p(adata_selecte_query)
#
#adata_selecte_query2 = standardize_adata(adata[
#    (adata.obs['cluster'] == '13_0')
#].copy())
#
#sc.pp.normalize_total(adata_selecte_query, target_sum=1e4)
#sc.pp.log1p(adata_selecte_query)
#
#x_query = adata_selecte_query1.X.mean(axis=0) - adata_ref_decoded.var['mean']
#x_query = x_query / np.maximum(adata_ref_decoded.var['std'], 0.1)
#x_ref = adata_selecte_query2.X.mean(axis=0) - adata_ref_decoded.var['mean']
#x_ref = x_ref / np.maximum(adata_ref_decoded.var['std'], 0.1)

In [None]:
# Scatter of the query vs reference z-scored profiles, one point per gene.
# Work on flat float arrays: `x_query` / `x_ref` may be pandas Series indexed
# by gene id, where integer lookups like `x[i]` rely on a deprecated
# positional fallback.
x_query_arr = np.asarray(x_query, dtype=float).ravel()
x_ref_arr = np.asarray(x_ref, dtype=float).ravel()

# The product is large and positive where both profiles deviate in the same
# direction, and negative where they deviate in opposite directions.
agreement = x_query_arr * x_ref_arr

fig, ax = plt.subplots()
ax.scatter(x_query_arr, x_ref_arr, s=1)

# Label the 10 genes with the strongest agreement ...
top_agree_gene_indices = np.argsort(-agreement)[:10]
for i in top_agree_gene_indices:
    ax.text(x_query_arr[i], x_ref_arr[i], adata_ref_decoded.var['gene_name'].iloc[i],
            fontsize=8)

# ... and, in red, the 5 genes with the strongest disagreement.
top_disagree_gene_indices = np.argsort(agreement)[:5]
for i in top_disagree_gene_indices:
    ax.text(x_query_arr[i], x_ref_arr[i], adata_ref_decoded.var['gene_name'].iloc[i],
            fontsize=8, color='red')

# Guide lines: both axes and the identity diagonal.
ax.axhline(c='grey', lw=0.5)
ax.axvline(c='grey', lw=0.5)
ax.axline((0, 0), slope=1, color='grey', lw=0.5)

# Pearson correlation between the two profiles (`scipy.stats` is imported in
# the notebook's import cell).
scipy.stats.pearsonr(x_query_arr, x_ref_arr)

In [None]:
# Expression of one gene on the reference manifold ...
gene_to_show = 'L1TD1'
sc.pl.umap(adata_ref_named, color=gene_to_show, cmap='inferno_r')

# ... and in the selected query cells, overlaid on the reference UMAP.
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
sc.pl.umap(adata_ct_ref, ax=ax, show=False)

# Build the mask entirely from `adata_named.obs` (the object being sliced)
# rather than mixing it with `adata.obs`, so the selection cannot go out of
# sync if the two objects ever diverge.
selected_named_cells = (
    (adata_named.obs['projected_cell_type'] == projected_ct)
    & (adata_named.obs['cluster'] == selected_cluster)
)

sc.pl.embedding(
    adata_named[selected_named_cells],
    basis='X_project_umap', color=gene_to_show, ax=ax,
                cmap='inferno_r', s=10)