In [None]:
# Automatically reload edited imported modules (e.g. the project's scmg package)
# before each cell executes.
%load_ext autoreload
%autoreload 2
# Render inline matplotlib figures at high (retina) resolution.
%config InlineBackend.figure_format='retina'

In [None]:
import os

from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import anndata
import scanpy as sc

import torch

from scmg.preprocessing.data_standardization import GeneNameMapper, standardize_adata
from scmg.model.contrastive_embedding import CellEmbedder, embed_adata, decode_adata, score_marker_genes

# Shared gene-identifier translator; used further down (map_gene_names) to turn
# the decoded reference's gene IDs into gene names.
gene_name_mapper = GeneNameMapper()

In [None]:
# Load the autoencoder model
model_ce_path = '../../contrastive_embedding/trained_embedder/'

# Prefer the first GPU, but fall back to CPU so the notebook still runs on
# machines without CUDA. map_location below keeps tensors saved on a GPU
# loadable on either device.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# model.pt appears to hold a pickled full model object (its methods
# load_state_dict/to/eval are called below), not just weights. torch>=2.6
# defaults to weights_only=True, which refuses such files, so opt out
# explicitly (the keyword exists since torch 1.13).
# SECURITY: unpickling can execute arbitrary code -- only load trusted files.
model = torch.load(os.path.join(model_ce_path, 'model.pt'),
                   map_location=device, weights_only=False)

# Overwrite the pickled parameters with the best checkpoint (plain state dict).
model.load_state_dict(torch.load(os.path.join(model_ce_path, 'best_state_dict.pth'),
                                 map_location=device))

model.to(device)
model.eval()

In [None]:
# Reference cells from the manifold generator. Their .X is later reused as the
# contrastive-embedding latent (obsm['X_ce_latent']) for decoding.
adata_ct_ref = anndata.read_h5ad('../../manifold_generator/ref_cell_adata.h5ad')
adata_ct_ref

In [None]:
# Load the hESC Perturb-seq AnnData. The location defaults to the original
# machine-specific path but can be overridden via the SCMG_PERTURB_SEQ_H5AD
# environment variable, so the notebook is not tied to one host.
perturb_seq_h5ad = os.environ.get(
    'SCMG_PERTURB_SEQ_H5AD',
    '/GPUData_xingjie/SCMG/hESC_perturb_seq/adata_single_gene_pert.h5ad',
)
adata = sc.read_h5ad(perturb_seq_h5ad)
# Cell names must be unique before they are used to index adata_obs_l2.csv.
adata.obs_names_make_unique()

# Normalize each cell to 10,000 total counts, then log1p. Both steps overwrite
# adata.X in place (presumably raw counts on load -- confirm), so re-run the
# whole cell, which re-reads the file, rather than these two lines alone.
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
adata

In [None]:
# Attach the clustering labels and UMAP coordinates stored in adata_obs_l2.csv.
l2_obs_df = pd.read_csv('adata_obs_l2.csv', index_col=0)

# Reorder the CSV rows to match adata's cells once, and take every column from
# that aligned frame. `.loc` raises KeyError if a cell is missing from the CSV,
# instead of the implicit index alignment silently filling NaN for it.
l2_obs_aligned = l2_obs_df.loc[adata.obs_names]
adata.obs['leiden_l1'] = l2_obs_aligned['leiden_l1'].astype(str).to_numpy()
adata.obs['cluster'] = l2_obs_aligned['cluster'].astype(str).to_numpy()
adata.obsm['X_umap'] = l2_obs_aligned[['umap_x', 'umap_y']].to_numpy()
adata

In [None]:
# ANK2 expression (log-normalized) on the full-dataset UMAP.
sc.pl.umap(adata, color='ANK2', cmap='cool')

In [None]:
# Highlight ELP5-perturbed cells, colored by `feature_call`, on top of an
# uncolored backdrop of every cell in the full-dataset UMAP.
elp5_cells = adata[adata.obs['perturbed_gene'].isin(['ELP5'])]

fig, ax = plt.subplots(figsize=(5, 5))
sc.pl.umap(adata, ax=ax, show=False)
sc.pl.umap(elp5_cells, color='feature_call', s=20, ax=ax, show=False)

In [None]:
# Human-readable name for every sub-cluster ID. Each name is a short label
# followed by the ID itself, so the ID stays visible in plot legends.
cluster_annotation_map = {
    cluster_id: f'{label}_{cluster_id}'
    for cluster_id, label in [
        ('0_0', 'non_target_like'),
        ('1_0', 'non_target_like'),
        ('2_0', 'non_target_like'),
        ('3_0', 'non_target_like'),
        ('4_0', 'non_target_like'),
        ('5_0', 'non_target_like'),
        ('6_0', 'non_target_like'),
        ('7_0', 'non_target_like'),
        ('8_0', 'non_target_like'),
        ('9_0', 'non_target_like'),
        ('10_0', 'mitochondrial_translation'),
        ('11_0', 'mitochondrial_gene_expression'),
        ('11_1', 'mitochondrion_organization'),
        ('11_2', 'LAMTOR3'),
        ('12_0', 'low_count'),
        ('12_1', 'low_count_Golgi_vesicle_transport'),
        ('12_2', 'non_target_like'),
        ('12_3', 'low_count'),
        ('12_4', 'low_count'),
        ('13_0', 'exit_pluripotency'),
        ('13_1', 'mesoderm'),
        ('13_2', 'mesoderm_endoderm'),
        ('13_3', 'cardiac_mesoderm'),
        ('13_4', 'cardiac_mesoderm'),
        ('13_5', 'anterior_neural_tube'),
        ('13_6', 'mixed_germ_layer'),
        ('13_7', 'ectoderm'),
        ('13_8', 'endoderm'),
        ('13_9', 'vascular_smooth_muscle'),
        ('13_10', 'advanced_mesoderm'),
        ('13_11', 'axial_mesoderm'),
        ('14_0', 'multivesicular_body_assembly'),
        ('15_0', 'low_count_DNA_damage_response'),
        ('15_1', 'spindle_assembly_checkpoint_signaling'),
        ('15_2', 'mRNA_processing'),
        ('15_3', 'CCR4-NOT_complex'),
        ('15_4', 'RNA_exosome_complex'),
        ('15_5', 'nucleosome_disassembly'),
        ('16_0', 'transcription_initiation'),
        ('16_1', 'epigenetic_regulation'),
        ('16_2', 'heterochromatin_formation'),
        ('17_0', 'low_count'),
        ('17_1', 'low_count'),
        ('18_0', 'low_count_translational_initiation'),
        ('19_0', 'translational_initiation'),
        ('20_0', 'fibroblast'),
        ('20_1', 'fibroblast'),
        ('20_2', 'fibroblast'),
        ('20_3', 'fibroblast'),
        ('21_0', 'unknown'),
        ('22_0', 'unknown'),
    ]
}

# Coarse partition of the annotated sub-clusters into three groups; each ID
# above is listed in exactly one group.
cluster_groups = {
    'non_target_like': [
        '0_0', '1_0', '2_0', '3_0', '4_0', '5_0', '6_0', '7_0', '8_0', '9_0',
        '12_2',
    ],
    'development_aligned': [
        '13_0', '13_1', '13_2', '13_3', '13_4', '13_5',
        '13_6', '13_7', '13_8', '13_9', '13_10', '13_11',
        '20_0', '20_1', '20_2', '20_3',
    ],
    'development_orthogonal': [
        '10_0',
        '11_0', '11_1', '11_2',
        '12_0', '12_1', '12_3', '12_4',
        '14_0',
        '15_0', '15_1', '15_2', '15_3', '15_4', '15_5',
        '16_0', '16_1', '16_2',
        '17_0', '17_1',
        '18_0', '19_0', '21_0', '22_0',
    ],
}

# Attach the human-readable annotation to every cell.
adata.obs['cluster_name'] = adata.obs['cluster'].map(cluster_annotation_map)

# Sanity checks. `.map` silently yields NaN for IDs absent from the mapping,
# which would surface as a spurious NaN group in the plots. Also check that the
# three cluster groups cover every annotated cluster exactly once.
unmapped_clusters = sorted(
    adata.obs.loc[adata.obs['cluster_name'].isna(), 'cluster'].unique())
assert not unmapped_clusters, f'clusters missing from cluster_annotation_map: {unmapped_clusters}'
grouped_clusters = [cid for members in cluster_groups.values() for cid in members]
assert len(grouped_clusters) == len(set(grouped_clusters)), 'a cluster is listed in more than one group'
assert set(grouped_clusters) == set(cluster_annotation_map), \
    'cluster_groups does not match the keys of cluster_annotation_map'

In [None]:
# Re-embed only the development-aligned clusters, so the UMAP resolves their
# internal structure instead of the whole screen's.
dev_clusters = cluster_groups['development_aligned']
adata_d = adata[adata.obs['cluster'].isin(dev_clusters)].copy()

# HVG selection on the log-normalized data; the full gene set is kept in .raw
# so individual genes can still be plotted after subsetting to HVGs.
sc.pp.highly_variable_genes(adata_d, min_mean=0.0125, max_mean=3, min_disp=0.5)
adata_d.raw = adata_d.copy()
adata_d = adata_d[:, adata_d.var['highly_variable']].copy()

# Scale (clipped at 10) -> PCA -> 20-NN graph -> UMAP.
sc.pp.scale(adata_d, max_value=10)
sc.tl.pca(adata_d, svd_solver='arpack')
sc.pp.neighbors(adata_d, n_neighbors=20)
sc.tl.umap(adata_d)

sc.pl.umap(adata_d, color='cluster_name')

In [None]:
# Annotated clusters on the full-dataset UMAP (coordinates from adata_obs_l2.csv).
sc.pl.umap(adata, color='cluster_name')

In [None]:
# NOTE(review): this cell repeats the previous plot verbatim; consider removing
# one of the two.
sc.pl.umap(adata, color='cluster_name')

In [None]:
# Decode the reference embeddings back to gene space with the trained model.
# NOTE(review): assumes adata_ct_ref.X holds the contrastive-embedding latent
# that decode_adata reads from obsm['X_ce_latent'] -- confirm.
adata_ct_ref.obsm['X_ce_latent'] = adata_ct_ref.X
adata_ref_decoded = decode_adata(model, adata_ct_ref, adata_ct_ref.obs['dataset_id'])
adata_ref_decoded.var['gene_name'] = gene_name_mapper.map_gene_names(
    adata_ref_decoded.var.index, 'human', 'human', 'id', 'name')

# Re-index the decoded genes by name so they can be plotted by symbol, dropping
# IDs the mapper reported as 'na'. Setting var_names from plain values avoids
# naming the index 'gene_name' (which would clash with the column of the same
# name), and subsetting first avoids a redundant full copy.
decoded_gene_names = adata_ref_decoded.var['gene_name'].astype(str).to_numpy()
has_name = decoded_gene_names != 'na'
adata_ref_named = adata_ref_decoded[:, has_name].copy()
adata_ref_named.var_names = decoded_gene_names[has_name]
adata_ref_named.var_names_make_unique()

In [None]:
# Gene(s) to inspect. The next cell reuses the same list as perturbation targets.
gene_to_plot = ['GAL']

# Same gene on the decoded reference, the full screen and the
# development-aligned re-embedding.
for plot_adata in (adata_ref_named, adata, adata_d):
    sc.pl.umap(plot_adata, color=gene_to_plot, cmap='inferno_r')

In [None]:
# Highlight cells whose perturbed gene is in `gene_to_plot`, colored by
# `feature_call`, over an uncolored backdrop of the development-aligned UMAP.
perturbed_cells_d = adata_d[adata_d.obs['perturbed_gene'].isin(gene_to_plot)]

fig, ax = plt.subplots(figsize=(5, 5))
sc.pl.umap(adata_d, ax=ax, show=False)
sc.pl.umap(perturbed_cells_d, color='feature_call', s=20, ax=ax, show=False)

In [None]:
# Candidate perturbed genes to inspect.
# NOTE(review): not referenced by any other visible cell -- confirm it is still needed.
gene_list = [
    'ARHGEF37', 'NANOG', 'PRDM14', 'SUPT20H', 'SOX2',
    'ZFP90', 'ZNF396', 'EIF3B', 'RAD18', 'DSEL',
    'MPPED1', 'PET117', 'POU5F1', 'DCTN5', 'ETF1',
    'GRK4', 'SC5D', 'KIAA0753', 'MBD5', 'ALS2',
    'CPSF4', 'FBLN5', 'CHIC2', 'CLTC', 'CUL1',
    'HECTD4', 'MED19', 'OTUB2', 'PGPEP1', 'RARA',
    'SP1', 'TADA2B', 'UCMA', 'USP8', 'TAF12',
    'GEMIN5', 'CENPI', 'RPP14', 'PDCD11', 'ZC3H8',
    'CCNH', 'DCTN2', 'FOXD3', 'MED22', 'SKA3',
    'BRIX1',
]

In [None]:
# Perturbed gene(s) highlighted in this and the following overlay plot.
pgs = ['TAF12']

# Cells carrying these perturbations, colored by `feature_call`, over an
# uncolored backdrop of the development-aligned UMAP.
pgs_cells_d = adata_d[adata_d.obs['perturbed_gene'].isin(pgs)]

fig, ax = plt.subplots(figsize=(5, 5))
sc.pl.umap(adata_d, ax=ax, show=False)
sc.pl.umap(pgs_cells_d, color='feature_call', s=20, ax=ax, show=False)

In [None]:
# Cells perturbed in any gene of `pgs`, colored by `feature_call`, over an
# uncolored backdrop of the full-dataset UMAP.
pgs_cells = adata[adata.obs['perturbed_gene'].isin(pgs)]

fig, ax = plt.subplots(figsize=(5, 5))
sc.pl.umap(adata, ax=ax, show=False)
sc.pl.umap(pgs_cells, color='feature_call', s=20, ax=ax, show=False)