In [None]:
# Enable autoreload (mode 2) so edits to imported project modules are picked
# up live without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Ask the inline backend for 'retina' figure output in notebook cells.
%config InlineBackend.figure_format='retina'

In [None]:
import os

from tqdm import tqdm
import numpy as np
import pandas as pd
import seaborn as sns
import scipy.spatial

import anndata
import scanpy as sc
import umap

import torch
from scmg.model.contrastive_embedding import CellEmbedder, embed_adata, decode_cell_state_embedding

from scmg.preprocessing.data_standardization import GeneNameMapper
# Shared GeneNameMapper instance, built with its default arguments.
# NOTE(review): not referenced in the cells visible here — presumably used
# further down the notebook; confirm before removing.
gene_name_mapper = GeneNameMapper()


In [None]:
import matplotlib
import matplotlib.pyplot as plt

# Matplotlib defaults for every figure exported below, applied in the same
# order as individual assignments would be.
plt.rcParams.update({
    'figure.autolayout': False,
    'pdf.fonttype': 42,          # fonttype 42 in the PDF backend
    'font.family': 'FreeSans',
})

# Scanpy-level figure settings; kept after the rcParams changes above.
sc.set_figure_params(vector_friendly=True, dpi_save=300)

In [None]:
# Folder that collects every figure saved by this notebook; created on first
# run and silently reused afterwards.
output_path = 'marker_TF_plots'
os.makedirs(name=output_path, exist_ok=True)

In [None]:
# Load the reference-cell AnnData object from the parent directory.
adata = sc.read_h5ad('../ref_cell_adata_measured_count.h5ad')
# Cast the expression matrix to float32 before normalising.
# NOTE(review): the file name suggests X holds raw measured counts — confirm.
adata.X = adata.X.astype(np.float32)
# Scale every cell to a total of 1e4 and then apply log1p. The return values
# are not captured, so both steps rely on scanpy updating `adata` itself.
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
# Display the object summary as the cell output.
adata

In [None]:
# Independent copy of the data whose variables are labelled by the
# 'human_gene_name' column, so genes can be requested by that name when plotting.
adata_named = adata.copy()
adata_named.var.index = adata_named.var['human_gene_name'].tolist()

In [None]:
fig, ax = plt.subplots(figsize=(4, 4), dpi=300)

# Draw the cells in a random order so that no cell type is systematically
# painted on top of the others. The permutation comes from a seeded generator
# so the exported PDF is identical on every Restart & Run All (the previous
# unseeded np.random.shuffle produced a different overlay each run).
rng = np.random.default_rng(0)
indices = rng.permutation(adata_named.obs.index.values).tolist()

# legend_fontsize=0 with legend_loc='on data' is kept from the original call.
sc.pl.umap(adata_named[indices], color='cell_type',
           legend_fontsize=0, legend_loc='on data',
           ax=ax, s=1, frameon=False, palette='tab20')
fig.savefig(os.path.join(output_path, 'umap_cell_type.pdf'))

In [None]:
# Differential-expression table read from a parquet file next to the notebook;
# the frame is shown as the cell output.
de_scores_file = 'cell_type_DE_scores.parquet'
all_gene_de_df = pd.read_parquet(de_scores_file)
all_gene_de_df

In [None]:
# A row of the DE table counts as a marker gene when ALL of these hold:
#   * pval_adj < 0.01
#   * fc > 5
#   * foreground_exp_frac > 0.2
#   * foreground_exp_frac > 3 * background_exp_frac
is_marker = (
    (all_gene_de_df['pval_adj'] < 0.01)
    & (all_gene_de_df['fc'] > 5)
    & (all_gene_de_df['foreground_exp_frac'] > 0.2)
    & (all_gene_de_df['foreground_exp_frac'] > 3 * all_gene_de_df['background_exp_frac'])
)

# Copy the selection so it is independent of the full table, then rank the
# markers from highest to lowest 'fc'.
marker_gene_df = all_gene_de_df[is_marker].copy()
marker_gene_df = marker_gene_df.sort_values('fc', ascending=False)
marker_gene_df

In [None]:
# Transcription-factor gene list read from CSV; its 'TF_name' column is used
# below to pick TFs out of the marker genes.
tf_list_file = 'tf_genes_Tfome.csv'
tf_df = pd.read_csv(tf_list_file)
tf_df

In [None]:
# Keep only the marker genes whose human gene name appears in the TF list.
is_tf = marker_gene_df['human_gene_name'].isin(tf_df['TF_name'])
marker_tf_df = marker_gene_df[is_tf]
marker_tf_df

In [None]:
# Number of marker genes / marker TFs found for each cell type.
marker_gene_count_map = marker_gene_df['cell_type'].value_counts().to_dict()
marker_tf_count_map = marker_tf_df['cell_type'].value_counts().to_dict()

# Map through an object-dtype view of the labels. Mapping a categorical column
# directly can return a Categorical, on which fillna(0) may fail because 0 is
# not an existing category, and which would be plotted as discrete labels
# instead of a continuous count.
cell_type_labels = adata.obs['cell_type'].astype(object)

# Cell types without any marker get a count of 0; counts are stored as ints.
adata.obs['marker_gene_count'] = (
    cell_type_labels.map(marker_gene_count_map).fillna(0).astype(int)
)
adata.obs['marker_tf_count'] = (
    cell_type_labels.map(marker_tf_count_map).fillna(0).astype(int)
)

# Genes with non-zero expression per cell. `.sum(axis=1)` works for dense
# arrays and scipy-sparse matrices alike; np.asarray(...).ravel() flattens the
# (n_obs, 1) matrix a sparse X would produce into a plain 1-D vector.
adata.obs['n_exp_genes'] = np.asarray((adata.X > 0).sum(axis=1)).ravel()

In [None]:
# One UMAP per per-cell summary column, each drawn on its own 4x4 inch axes
# and written to '<output_path>/umap_<column>.pdf'.
for obs_key in ('n_exp_genes', 'marker_gene_count', 'marker_tf_count'):
    fig, ax = plt.subplots(figsize=(4, 4), dpi=300)
    sc.pl.umap(adata, color=obs_key,
               vmax=None, cmap='gnuplot', ax=ax)
    fig.savefig(os.path.join(output_path, f'umap_{obs_key}.pdf'))

In [None]:
# Quick look: distribution of the number of marker genes per cell type.
marker_genes_per_type = marker_gene_df['cell_type'].value_counts()
marker_genes_per_type.hist(bins=30)

In [None]:
fig, ax = plt.subplots(figsize=(4, 4), dpi=300)

# Number of marker TFs per cell type.
marker_tfs_per_type = marker_tf_df['cell_type'].value_counts()

# Unit-width bins centred on the integers 1, 2, 3, ... The edges always reach
# 55.5 (the original fixed range) and are extended when some cell type has
# more than 55 marker TFs, so no cell type is silently dropped from the plot.
# Counts of 0 lie below the first edge (0.5) and are not drawn.
max_tf_count = int(marker_tfs_per_type.max()) if len(marker_tfs_per_type) else 0
bin_edges = np.arange(max(56, max_tf_count + 1)) + 0.5

marker_tfs_per_type.hist(bins=bin_edges, color='grey', ax=ax)
ax.grid(False)
ax.set_xlabel('Number of marker TFs')
ax.set_ylabel('Number of cell types')
fig.savefig(os.path.join(output_path, 'hist_marker_tf_count.pdf'))

In [None]:
# Marker-TF tally per cell type, in value_counts' default order.
tf_cell_types = marker_tf_df['cell_type']
tf_cell_types.value_counts()

In [None]:
# Marker TFs of the 'erythrocyte' cell type, highest 'fc' first.
is_erythrocyte = marker_tf_df['cell_type'] == 'erythrocyte'
marker_tf_df[is_erythrocyte].sort_values('fc', ascending=False)

In [None]:
# Expression UMAPs for a hand-written list of genes, one panel per gene.
# NOTE(review): the list presumably comes from the erythrocyte marker-TF
# table shown earlier — confirm.
cell_type = 'erythrocyte'
genes = ['KLF1', 'GATA1', 'NFE2', 'GFI1B', 'TAL1', 'LYL1']

# return_fig=True hands back the figure so it can be written to disk below.
fig = sc.pl.umap(
    adata_named,
    color=genes,
    vmax=None,
    cmap='inferno_r',
    return_fig=True,
)

# The output file name carries the cell type the panel belongs to.
marker_tf_pdf = os.path.join(output_path, f'umap_marker_tfs_{cell_type}.pdf')
fig.savefig(marker_tf_pdf, dpi=300)

In [None]:
# Top 20 DE rows for TAL1 across all cell types, highest 'fc' first.
# (A further restriction to a single cell type, e.g. 'Epiblast', could be
# AND-ed into the mask when needed.)
is_tal1 = all_gene_de_df['human_gene_name'] == 'TAL1'
all_gene_de_df[is_tal1].sort_values('fc', ascending=False).head(20)