In [None]:
# Reload edited project modules automatically before each cell executes,
# so changes to imported .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render inline figures at retina (2x) resolution.
%config InlineBackend.figure_format='retina'

In [None]:
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import anndata
import scanpy as sc

import torch

from scmg.model.contrastive_embedding import CellEmbedder, embed_adata, decode_adata, score_marker_genes
from scmg.preprocessing.data_standardization import GeneNameMapper, standardize_adata

# Shared mapper instance; later cells call gene_name_mapper.map_gene_names(...)
# to convert gene identifiers (mouse name -> human id, human id -> human name).
gene_name_mapper = GeneNameMapper()

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

# Global matplotlib settings applied before scanpy's own figure defaults.
plt.rcParams.update({
    "figure.autolayout": False,
    # Font type 42 embeds TrueType fonts in saved PDFs.
    "pdf.fonttype": 42,
    "font.family": "FreeSans",
})

sc.set_figure_params(vector_friendly=True, dpi_save=300)

# Set after sc.set_figure_params on purpose: the grid must stay disabled
# regardless of what scanpy's defaults configure.
plt.rcParams['axes.grid'] = False

In [None]:
# Directory (relative to the notebook) that receives every PDF saved below.
output_path = 'Treutlein_2016_neuron_transdifferentiation_plots'
# exist_ok=True keeps re-runs from failing when the directory already exists.
os.makedirs(output_path, exist_ok=True)

In [None]:
# Load the autoencoder model

model_path = '../../../contrastive_embedding/trained_embedder'

# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# SECURITY NOTE: torch.load unpickles arbitrary Python objects. Only load
# checkpoints from a trusted source.
# 'model.pt' stores the whole pickled module (not just tensors), so it must be
# loaded with weights_only=False; newer PyTorch releases default to
# weights_only=True and would refuse to unpickle the module class.
# map_location lets a checkpoint saved on GPU be opened on a CPU-only host.
model = torch.load(os.path.join(model_path, 'model.pt'),
                   map_location=device, weights_only=False)

# Overwrite the parameters with the best checkpoint's state dict.
best_state_dict = torch.load(os.path.join(model_path, 'best_state_dict.pth'),
                             map_location=device)
model.load_state_dict(best_state_dict)

model.to(device)
# Inference mode; as the last expression this also displays the module repr.
model.eval()

In [None]:
# Read the query dataset, convert its gene index, and embed the cells.
query_h5ad = '/GPUData_xingjie/SCMG/perturbation_trajectories/Treutlein_2016_neuron_transdifferentiation.h5ad'
adata = sc.read_h5ad(query_h5ad)
adata.obs_names_make_unique()

# Replace the gene index with the mapper's output
# (presumably mouse gene names -> human gene ids; confirm against GeneNameMapper).
human_gene_ids = gene_name_mapper.map_gene_names(
    adata.var.index, 'mouse', 'human', 'name', 'id')
adata.var.index = human_gene_ids

# Drop genes whose mapped identifier is the literal 'na'
# (presumably the mapper's marker for "no match" - TODO confirm).
has_mapped_id = adata.var.index != 'na'
adata = adata[:, has_mapped_id].copy()
adata.var_names_make_unique()

# Later cells read adata.obsm['X_ce_latent'], presumably written by this call.
embed_adata(model, adata, batch_size=4096)
adata

In [None]:
# Keep only cells whose 'assignment' label is one of the groups analysed below.
kept_assignments = [
    'MEF', 'Myocyte', 'Neuron',
    'd2_induced', 'd2_intermediate',
    'd5_earlyiN', 'd5_intermediate',
]
is_kept = adata.obs['assignment'].isin(kept_assignments)
adata = adata[is_kept].copy()

In [None]:
# Reference cells used for the projection (CellTypeSearcher input and the
# source of the 'cell_type' / 'major_cell_type' annotations used later).
ct_ref_h5ad = '../../../manifold_generator/ref_cell_adata.h5ad'
adata_ct_ref = sc.read_h5ad(ct_ref_h5ad)
adata_ct_ref

In [None]:
# Reference cells from the measured-count file, brought to log1p of
# expression normalised to 1e4 total per cell.
measured_ref_h5ad = '../../../manifold_generator/ref_cell_adata_measured_count.h5ad'
adata_measured_ref = sc.read_h5ad(measured_ref_h5ad)

# Both steps modify adata_measured_ref in place.
sc.pp.normalize_total(adata_measured_ref, target_sum=1e4)
sc.pp.log1p(adata_measured_ref)

adata_measured_ref

In [None]:
from scmg.model.cell_type_search import CellTypeSearcher

# Build the searcher over the reference cells; the next cell queries it with
# cts.search_ref_cell(...) using the query cells' latent embedding.
cts = CellTypeSearcher(adata_ct_ref)

In [None]:
# Match every query cell to a reference cell using its latent embedding.
# The returned frame is read below for the columns
# 'umap_x', 'umap_y', 'distance' and 'ref_cell'.
cell_match_df = cts.search_ref_cell(adata.obsm['X_ce_latent'])

# Add a small random shift to the UMAP coordinates to avoid overlapping points.
# A dedicated, seeded generator makes the jitter (and therefore the saved
# figures) identical on every Restart & Run All, instead of depending on the
# global NumPy random state.
JITTER_SEED = 0
jitter_rng = np.random.default_rng(JITTER_SEED)

projected_xy = cell_match_df[['umap_x', 'umap_y']].values
# Gaussian noise (sd 0.2), clipped so no point moves more than 0.5 per axis.
jitter = np.clip(jitter_rng.normal(0, 0.2, size=projected_xy.shape), -0.5, 0.5)
adata.obsm['X_project_umap'] = projected_xy + jitter

# Per-cell projection results.
adata.obs['project_dist'] = cell_match_df['distance'].values
adata.obs['ref_cell'] = cell_match_df['ref_cell'].values
# Look up the annotated cell type of each matched reference cell.
adata.obs['projected_cell_type'] = adata_ct_ref.obs['cell_type'].loc[
                                            cell_match_df['ref_cell']].values

In [None]:
# Number of query cells that were matched to each reference cell.
projected_cell_count_map = adata.obs['ref_cell'].value_counts().to_dict()

# Reference cells that no query cell was matched to get a count of 0.
adata_ct_ref.obs['projected_cell_count'] = [
    projected_cell_count_map.get(ref_name, 0)
    for ref_name in adata_ct_ref.obs.index
]

In [None]:
# Reference UMAP coloured by major cell type, then by how many query cells
# landed on each reference cell (colour scale capped at 20).
sc.pl.umap(adata_ct_ref, color='major_cell_type')

count_plot_kwargs = dict(
    color='projected_cell_count',
    vmax=20,
    cmap='inferno_r',
    alpha=1,
    s=20,
)
sc.pl.umap(adata_ct_ref, **count_plot_kwargs)

In [None]:
# Reference UMAP coloured by major cell type, for orientation.
sc.pl.umap(adata_ct_ref, color='major_cell_type')

# Overlay: uncoloured reference cells as background, query cells drawn on the
# same axes at their (jittered) projected coordinates.
fig, ax = plt.subplots(figsize=(5, 5), dpi=300)
sc.pl.umap(adata_ct_ref, ax=ax, show=False)
sc.pl.embedding(
    adata,
    basis='X_project_umap',
    color='assignment',
    ax=ax,
    s=10,
)

projected_umap_pdf = os.path.join(output_path, 'umap_projected_cells.pdf')
fig.savefig(projected_umap_pdf)

In [None]:
# Build copies of the query and reference objects that share a gene index
# and are restricted to the genes present in both.

# Query: replace the gene index with the mapper's output
# (presumably human gene ids -> human gene names; confirm against GeneNameMapper).
adata_named = adata.copy()
query_gene_names = gene_name_mapper.map_gene_names(
    adata_named.var.index, 'human', 'human', 'id', 'name')
adata_named.var.index = query_gene_names
adata_named.var_names_make_unique()

# Reference: index by the 'human_gene_name' column and drop 'na' entries.
adata_ref_named = adata_measured_ref.copy()
adata_ref_named.var.index = adata_measured_ref.var['human_gene_name']
ref_has_name = adata_ref_named.var.index != 'na'
adata_ref_named = adata_ref_named[:, ref_has_name].copy()
adata_ref_named.var_names_make_unique()

# Restrict both objects to the shared genes (np.intersect1d returns them sorted).
common_genes = np.intersect1d(adata_named.var.index, adata_ref_named.var.index)
adata_named = adata_named[:, common_genes].copy()
adata_ref_named = adata_ref_named[:, common_genes].copy()

# Only the query is normalised here; adata_measured_ref was already
# normalised and log-transformed before it was copied.
sc.pp.normalize_total(adata_named, target_sum=1e4)
sc.pp.log1p(adata_named)

In [None]:
# Compute a UMAP of the query cells from their own expression:
# top-2000 highly variable genes -> scale -> 50-component PCA -> 20-NN graph -> UMAP.
sc.pp.highly_variable_genes(adata_named, n_top_genes=2000)

is_hvg = adata_named.var['highly_variable']
adata_named_hvg = adata_named[:, is_hvg].copy()

sc.pp.scale(adata_named_hvg)
sc.pp.pca(adata_named_hvg, n_comps=50)
sc.pp.neighbors(adata_named_hvg, n_neighbors=20)
sc.tl.umap(adata_named_hvg)

In [None]:
# Alternative (disabled): build the UMAP from the latent embedding instead of
# the HVG/PCA pipeline.
#sc.pp.neighbors(adata_named, use_rep='X_ce_latent', n_neighbors=20)
#sc.tl.umap(adata_named)

# Copy the coordinates computed on the HVG subset onto adata_named, which is
# the object the later sc.pl.umap(adata_named, ...) calls plot from.
# adata_named_hvg is a gene-wise subset of adata_named, so the cell order matches.
adata_named.obsm['X_umap'] = adata_named_hvg.obsm['X_umap']

In [None]:
# UMAP of the query cells coloured by their 'assignment' label.
fig, ax = plt.subplots(figsize=(3, 3), dpi=300)
sc.pl.umap(adata_named, color=['assignment'], s=50, ax=ax)

transdiff_umap_pdf = os.path.join(output_path, 'umap_transdiff_cells.pdf')
fig.savefig(transdiff_umap_pdf)

In [None]:
# Rows: query 'assignment'; columns: projected reference cell type.
ct_confusion_df = pd.crosstab(adata.obs['assignment'], adata.obs['projected_cell_type'])

# Convert counts to per-row fractions so every assignment row sums to 1.
ct_confusion_df = ct_confusion_df.div(ct_confusion_df.sum(axis=1), axis=0)

# Show the five most frequent projected cell types for each assignment.
for ct_query, projected_fractions in ct_confusion_df.iterrows():
    display(projected_fractions.sort_values(ascending=False).head(5))

In [None]:
# Heatmap of a hand-picked subset of the row-normalised confusion table.
fig, ax = plt.subplots(figsize=(5.5, 4), dpi=300)

# Row order of the heatmap (query assignments).
heatmap_rows = [
    'MEF',
    'd2_intermediate',
    'd5_intermediate',
    'd2_induced',
    'd5_earlyiN',
    'Neuron',
    'Myocyte',
]

# Column order of the heatmap (projected reference cell types).
# 'Pituitary gland cells' is intentionally left out.
heatmap_cols = [
    'embryonic fibroblast',
    'Early fibroblasts',
    'Midbrain dopaminergic',
    'Hindbrain serotoninergic',
    'cell of skeletal muscle',
]

# No ax is passed, so seaborn draws on the current axes (the one just created).
confusion_subset = ct_confusion_df.loc[heatmap_rows, heatmap_cols]
sns.heatmap(confusion_subset, cmap='viridis')

confusion_pdf = os.path.join(output_path, 'cell_type_confusion_matrix.pdf')
fig.savefig(confusion_pdf)

In [None]:
# Per-gene mean and standard deviation over all reference cells; used in the
# next cell to z-score the group-level mean expression.
# NOTE(review): .std(axis=0) requires a dense matrix here - confirm X is not sparse.
ref_expr = adata_measured_ref.X
adata_measured_ref.var['mean'] = ref_expr.mean(axis=0)
adata_measured_ref.var['std'] = ref_expr.std(axis=0)

In [None]:
# Compare the mean expression of one query group with the reference cell type
# it was projected onto, on a per-gene z-score scale.
query_ct = 'Myocyte'
projected_ct = 'cell of skeletal muscle'


def _mean_expression_zscore(expr_matrix, gene_stats):
    """Z-score the per-gene mean of ``expr_matrix`` against reference statistics.

    Parameters
    ----------
    expr_matrix
        Cells x genes matrix (dense array or scipy sparse matrix).
    gene_stats
        DataFrame with per-gene 'mean' and 'std' columns, one row per column
        of ``expr_matrix`` and in the same order.

    Returns
    -------
    pandas.Series indexed like ``gene_stats``.
    """
    # np.asarray(...).ravel() turns the 1 x n_genes np.matrix that sparse
    # matrices return into a flat vector; for dense arrays it is a no-op.
    gene_means = np.asarray(expr_matrix.mean(axis=0)).ravel()
    # The std is floored at 0.1 so near-constant genes do not blow up the score.
    return (gene_means - gene_stats['mean']) / np.maximum(gene_stats['std'], 0.1)


# Query cells with the chosen assignment that were also projected onto the
# chosen reference cell type.
query_mask = (
    (adata.obs['assignment'] == query_ct)
    & (adata.obs['projected_cell_type'] == projected_ct)
)
# standardize_adata presumably aligns the genes with the reference gene order,
# which the element-wise subtraction below relies on - TODO confirm.
adata_selected_query = standardize_adata(adata[query_mask].copy())

sc.pp.normalize_total(adata_selected_query, target_sum=1e4)
sc.pp.log1p(adata_selected_query)

adata_selected_ref = adata_measured_ref[
    adata_measured_ref.obs['cell_type'] == projected_ct
].copy()

x_query = _mean_expression_zscore(adata_selected_query.X, adata_measured_ref.var)
x_ref = _mean_expression_zscore(adata_selected_ref.X, adata_measured_ref.var)

comp_df = pd.DataFrame({
    'gene': list(adata_measured_ref.var['human_gene_name']),
    'x_query': x_query,
    'x_ref': x_ref,
}).set_index('gene')

comp_df['x_diff'] = comp_df['x_query'] - comp_df['x_ref']

# Only keep the genes measured in the query dataset
measured_gene_names = gene_name_mapper.map_gene_names(
    adata.var.index, 'human', 'human', 'id', 'name')
comp_df = comp_df[comp_df.index.isin(measured_gene_names)].copy()

In [None]:
# Pick up to four genes per category to label in the scatter plot.

# High in both query and reference, ranked by reference z-score.
both_high = (comp_df['x_query'] > 2) & (comp_df['x_ref'] > 2)
double_pos_genes = comp_df[both_high].sort_values(
    'x_ref', ascending=False).head(4).index.values

# High in the query only, ranked by the largest query-minus-reference gap.
query_only = (comp_df['x_query'] > 3) & (comp_df['x_ref'] < 1)
q_pos_r_neg_genes = comp_df[query_only].sort_values(
    'x_diff', ascending=False).head(4).index.values

# High in the reference only, ranked by the most negative gap.
ref_only = (comp_df['x_query'] < 1) & (comp_df['x_ref'] > 3)
q_neg_r_pos_genes = comp_df[ref_only].sort_values(
    'x_diff', ascending=True).head(4).index.values

for heading, selected_genes in (
    ('Double positive genes:', double_pos_genes),
    ('Query positive, ref negative genes:', q_pos_r_neg_genes),
    ('Query negative, ref positive genes:', q_neg_r_pos_genes),
):
    print(heading)
    display(selected_genes)

# Manual override: the automatically selected reference-only genes shown above
# are replaced by this fixed list for all downstream plots.
q_neg_r_pos_genes = np.array(['MYOZ1', 'ATP2A1', 'MYBPC2', 'MYOG'])


In [None]:
import scipy.stats

# Scatter of per-gene z-scores: query group (x) against reference cell type (y).
fig, ax = plt.subplots(figsize=(4, 4), dpi=300)

# rasterized=True keeps the saved PDF small despite one point per gene.
ax.scatter(comp_df['x_query'], comp_df['x_ref'], s=1, rasterized=True)

# Label the selected genes at their data coordinates.
labeled_genes = np.concatenate([double_pos_genes, q_pos_r_neg_genes, q_neg_r_pos_genes])
comp_df_to_show = comp_df[comp_df.index.isin(labeled_genes)]

for gene_name, gene_row in comp_df_to_show.iterrows():
    ax.text(gene_row['x_query'], gene_row['x_ref'], gene_name, fontsize=8)

# Zero lines on both axes (axhline/axvline default to position 0).
ax.axhline(c='grey', lw=0.5)
ax.axvline(c='grey', lw=0.5)
ax.set_xlabel(query_ct + ' normalized expression', fontsize=10)
ax.set_ylabel(projected_ct +  ' normalized expression', fontsize=10)

fig.savefig(os.path.join(output_path, f'{query_ct}_{projected_ct}_gene_expression_scatter.pdf'))

# Pearson correlation over the same genes that are plotted, i.e. only genes
# measured in the query dataset. The unfiltered x_query / x_ref vectors also
# contain genes absent from the query, which would make the reported value
# disagree with the figure.
scipy.stats.pearsonr(comp_df['x_query'], comp_df['x_ref'])

In [None]:
# Expression of the selected genes on the reference UMAP and on the query UMAP.
genes_to_show = np.concatenate([double_pos_genes, q_pos_r_neg_genes, q_neg_r_pos_genes])

global_umap_pdf = os.path.join(
    output_path, f'{query_ct}_{projected_ct}_gene_exp_global_umap.pdf')
fig = sc.pl.umap(adata_ref_named, color=genes_to_show, vmax=None,
                 cmap='inferno_r', return_fig=True)
fig.savefig(global_umap_pdf)

query_umap_pdf = os.path.join(
    output_path, f'{query_ct}_{projected_ct}_gene_exp_query_umap.pdf')
fig = sc.pl.umap(adata_named, color=genes_to_show,
                 cmap='inferno_r', s=50, return_fig=True)
fig.savefig(query_umap_pdf)



In [None]:
# Stacked violin plot of a fixed marker-gene panel across the assignment groups.
violin_genes = [
    'SERPINE1', 'F3', 'FBLN2', 'TAGLN', 'ACTA2', 'PPIC', 'COL1A2', 'GPX8', 'PCOLCE',
    'TUBB3', 'MLLT11', 'UCHL1', 'ZCCHC12', 'GAP43', 'SYT4', 'SYNGR3', 'STMN3', 'CALY', 'GNG3',
    'COX8C', 'TNNI1', 'PGAM2', 'ACTA1', 'ENO3', 'TNNT3', 'TNNC2', 'MYL11', 'MYL1',
]
# Shorter alternative panel: ['COL1A2', 'LOX', 'SYT4', 'TUBB3', 'COX8C', 'TNNC2']

# Order in which the assignment groups are drawn.
assignment_order = [
    'MEF', 'd2_intermediate', 'd2_induced', 'd5_intermediate',
    'd5_earlyiN', 'Neuron', 'Myocyte',
]

fig = sc.pl.stacked_violin(
    adata_named,
    violin_genes,
    groupby='assignment',
    vmax=3,
    categories_order=assignment_order,
    return_fig=True,
)

fig.savefig(os.path.join(output_path, 'marker_genes_violin.pdf'))

In [None]:
# Three marker genes shown on the reference UMAP and on the query UMAP.
genes_to_show = ['PCOLCE', 'TUBB3', 'COX8C']

fig = sc.pl.umap(adata_ref_named, color=genes_to_show, vmax=None,
                 cmap='inferno_r', return_fig=True)
fig.savefig(os.path.join(output_path, 'marker_genes_umap_global.pdf'))

# The query figure gets an extra panel coloured by 'assignment'.
query_panels = genes_to_show + ['assignment']
fig = sc.pl.umap(adata_named, color=query_panels,
                 cmap='inferno_r', s=50, return_fig=True)
fig.savefig(os.path.join(output_path, 'marker_genes_umap_transdiff.pdf'))

In [None]:
# Highlight one reference cell type on the reference UMAP with a 0/1 indicator.
ct_of_i = "Midbrain dopaminergic"

is_ct_of_i = adata_ref_named.obs['cell_type'] == ct_of_i
adata_ref_named.obs['ct_of_i'] = is_ct_of_i.astype(int)

sc.pl.umap(adata_ref_named, color='ct_of_i', cmap='inferno_r')

In [None]:
# Rank genes for each assignment group with the Wilcoxon method, then show
# the top five genes per group as a dot plot.
sc.tl.rank_genes_groups(adata_named, groupby="assignment", method="wilcoxon")

dotplot_options = dict(
    groupby="assignment",
    standard_scale="var",
    n_genes=5,
)
sc.pl.rank_genes_groups_dotplot(adata_named, **dotplot_options)

In [None]:
# Genes to display on the reference UMAP. Earlier candidate panels are kept
# below, disabled, for reference:
#   'S100A6', 'TAGLN2', 'SERPINE1', 'HMGA2', 'CNN2'
#   'GOLT1B', 'PRDX1', 'PTGES3', 'CALU', 'SSR1'
#   'MYL1', 'ACTA1', 'TNNI1', 'TPM2', 'TNNC2'
#   'DPYSL2', 'PRKAR1B', 'GRIA2', 'SYT11', 'INPP5F'
#   'H3Y1', 'DNER', 'MARCKSL1', 'YWHAZ', 'SEC11C'
#   'PPIB', 'B2M', 'HES6', 'ARL6IP1', 'TMSB4Y'
genes_to_show = ['COL3A1', 'ITM2A', 'GAS2', 'ABI3BP', 'POSTN']

sc.pl.umap(
    adata_ref_named,
    color=genes_to_show,
    vmax=None,
    cmap='inferno_r',
)