In [None]:
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.spatial
import scipy.stats

import anndata
import scanpy as sc
import umap

import torch
from scmg.model.contrastive_embedding import CellEmbedder, embed_adata, decode_cell_state_embedding

from scmg.preprocessing.data_standardization import GeneNameMapper
# Build the gene-name mapper once for the whole notebook.
# NOTE(review): `gene_name_mapper` is not referenced by any other cell visible
# here — confirm it is still needed.
gene_name_mapper = GeneNameMapper()


In [None]:
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

# Global figure style for every plot in this notebook. The statements are kept
# in their original order because `sc.set_figure_params` may itself modify
# rcParams; the grid is therefore switched off only after that call.
plt.rcParams.update({"figure.autolayout": False})
matplotlib.rcParams['pdf.fonttype'] = 42  # Type 42 (TrueType) fonts in saved PDFs
plt.rcParams.update({'font.family': 'FreeSans'})
sc.set_figure_params(dpi_save=300, vector_friendly=True)
plt.rcParams.update({'axes.grid': False})

In [None]:
# Directory (relative to the working directory) that receives every figure
# saved by this notebook; created on first run, left untouched if it exists.
plot_output_path = 'classify_genes_plots/'
os.makedirs(plot_output_path, exist_ok=True)

In [None]:
# Load the trained contrastive cell embedder and put it in inference mode.
model_ce_path = '../../contrastive_embedding/trained_embedder/'

# Use the first GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# `model.pt` holds the whole pickled module (it is used as a module right below),
# so it needs `weights_only=False`: newer PyTorch releases default to
# weights-only loading and would refuse it.
# SECURITY: unpickling can execute arbitrary code — only load checkpoints you trust.
# `map_location` lets a checkpoint saved on a GPU be restored on `device`.
model_ce = torch.load(os.path.join(model_ce_path, 'model.pt'),
                      map_location=device, weights_only=False)

# Overwrite the parameters with the best checkpoint's state dict.
model_ce.load_state_dict(
    torch.load(os.path.join(model_ce_path, 'best_state_dict.pth'), map_location=device))

model_ce.to(device)
model_ce.eval()

In [None]:
from tqdm import tqdm
def get_cluster_mean_expression_matrix_low_mem(adata, cluster_column):
    '''Get a dataframe of mean gene expression of each cluster.

    Clusters are processed one at a time, so only a single cluster's slice of
    ``adata.X`` is materialised at any moment.

    Parameters
    ----------
    adata
        AnnData-like object; ``obs``, ``var.index``, ``shape``, boolean-mask
        indexing and ``.X`` are used.
    cluster_column
        Name of the ``adata.obs`` column holding the cluster label of each cell.

    Returns
    -------
    pandas.DataFrame
        One row per unique cluster label (in ``np.unique`` order) and one
        column per entry of ``adata.var.index``, pre-allocated as float32.
    '''
    # Pull the labels out once instead of re-reading the obs column per cluster.
    labels = np.asarray(adata.obs[cluster_column])
    cluster_names = np.unique(labels)
    cluster_mean_df = pd.DataFrame(
        np.zeros((len(cluster_names), adata.shape[1]), dtype=np.float32),
        index=cluster_names, columns=adata.var.index)

    for c in tqdm(cluster_names):
        X_c = adata[labels == c].X
        # A sparse `.X` returns a 1 x n_genes `np.matrix` from `mean`; flatten
        # it so the row assignment always receives a plain 1-D vector.
        cluster_mean_df.loc[c] = np.asarray(X_c.mean(axis=0)).ravel()

    return cluster_mean_df

In [None]:
# Read the measured-count reference cells, then normalise every cell to a total
# of 1e4 and log1p-transform; both scanpy calls operate on `adata` directly.
adata = sc.read_h5ad('../../manifold_generator/ref_cell_adata_measured_count.h5ad')
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

# Dataset ids containing '_HS_' are labelled human; all others are labelled mouse.
adata.obs['species'] = [
    'mouse' if '_HS_' not in dataset_id else 'human'
    for dataset_id in adata.obs['dataset_id']
]
adata

In [None]:
# Mean normalised, log1p-transformed measured expression per `cell_type`
# (rows: cell types, columns: genes).
ct_mean_measured_df = get_cluster_mean_expression_matrix_low_mem(adata, 'cell_type')

In [None]:
# Dataset ids that make up the human 10x reference subset. The two Cowan retina
# entries are deliberately left disabled.
human_10x_dataset_ids = [
    'Arutyunyan_Placenta_HS_2023:all',
    'Bhaduri_CtxDev_HS_2021:all',
    'Conde_Immune_HS_2022:all',
    # 'Cowan_Retina_HS_2020:fovea',
    # 'Cowan_Retina_HS_2020:periphery',
    'Elmentaite_intestine_HS_2021:all',
    'Eraslan_MultiTissue_HS_2022:all',
    'Fawkner-Corbett_IntestineDev_HS_2021:all',
    'He_LungDev_HS_2022:all',
    'Jardine_BloodDev_HS_2021:normal',
    'Khaled_Breast_HS_2023:all',
    'Kuppe_Heart_HS_2022:all',
    'Lake_Kidney_HS_2023:all',
    'Lengyel_FallopianTube_HS_2022:all',
    'Litvinukova_Heart_HS_2020:all',
    'Park_Thymus_HS_2020:all',
    'Sikkema_Lung_HS_2023:core',
    'Streets_Adipose_HS_2023:all',
    'Suo_ImmuneDev_HS_2022:all',
    'Tabula_Sapiens_HS_2022:all',
    'VentoTormo_Placenta_HS_2018:all',
    'Wiedemann_Skin_HS_2023:all',
    'Yu_MultiTissue_HS_2021:all',
]

# Keep only cells from those datasets; `.copy()` detaches the subset from `adata`.
is_human_10x = adata.obs['dataset_id'].isin(human_10x_dataset_ids)
adata_10x_human = adata[is_human_10x].copy()
adata_10x_human

In [None]:
# Accumulate, per gene, the `measure_mask` values over the human 10x datasets.
# Only the first cell's mask row of each dataset is read
# (assumes the mask is identical for all cells of a dataset — TODO confirm).
adata_10x_human.var['n_measured_batches'] = 0

dataset_ids = adata_10x_human.obs['dataset_id']
for batch in np.unique(dataset_ids):
    batch_cells = adata_10x_human[dataset_ids == batch]
    first_cell_mask = batch_cells.layers['measure_mask'][0]
    adata_10x_human.var['n_measured_batches'] += list(first_cell_mask)

# Distribution of the accumulated per-gene totals.
adata_10x_human.var['n_measured_batches'].hist(bins=50)

In [None]:
from sklearn.neighbors import NearestNeighbors
# For every reference cell, find its single nearest human-10x cell in the
# `X_scmg` embedding and record that neighbour's row position and dataset id.
neigh = NearestNeighbors(n_neighbors=1)
neigh.fit(adata_10x_human.obsm['X_scmg'])

# With n_neighbors=1 `kneighbors` returns an (n_cells, 1) array; take the only
# column so the obs column and the index used below are plain 1-D integer arrays.
neighbor_idx = neigh.kneighbors(adata.obsm['X_scmg'], return_distance=False)[:, 0]
adata.obs['10x_human_neighbor_idx'] = neighbor_idx
adata.obs['10x_human_neighbor_batch_id'] = np.asarray(
    adata_10x_human.obs['dataset_id'])[neighbor_idx]

In [None]:
# Decode every cell's `X_scmg` embedding back to expression, passing as dataset
# name the dataset of its nearest human-10x neighbour.
# Alternatives tried previously for `dataset_names`: one fixed dataset for all
# cells ('Tabula_Sapiens_HS_2022:all', 'Qiu_whole_embryo_dev_MM_2024:all',
# 'Tabula_Muris_MM_2020:10x') or each cell's own `dataset_id`.
neighbor_batch_ids = adata.obs['10x_human_neighbor_batch_id'].values
adata_decoded = decode_cell_state_embedding(
    model_ce,
    adata.obsm['X_scmg'],
    dataset_names=neighbor_batch_ids,
)

# Carry the cell annotations and UMAP coordinates over from the measured data.
adata_decoded.obs = adata.obs.copy()
adata_decoded.obsm['X_umap'] = adata.obsm['X_umap']
adata_decoded

In [None]:
# Mean decoded expression per `cell_type` (rows: cell types, columns: genes),
# computed with the same helper as the measured means.
ct_mean_decoded_df = get_cluster_mean_expression_matrix_low_mem(adata_decoded, 'cell_type')

In [None]:
def entropy(v):
    '''Shannon entropy (natural log) of `v` after normalising it to sum to 1.

    Parameters
    ----------
    v : numpy.ndarray
        Weights, e.g. the mean expression of one gene in each cell type.

    Returns
    -------
    float
        ``-sum(p * log(p))`` over the strictly positive entries of
        ``p = v / sum(v)``. Returns 0.0 when `v` sums to zero.
    '''
    total = np.sum(v)
    if total == 0:
        # Nothing to normalise; avoid the 0/0 division (and its RuntimeWarning).
        return 0.0
    p = v / total
    p = p[p > 0]  # Remove zero probabilities
    return -np.sum(p * np.log(p))

# Per-gene comparison of the measured and decoded cell-type mean profiles.
# One list per output column; the key order below fixes the column order.
gene_comp_dict = {
    column: [] for column in (
        'gene_id',
        'gene_name',
        'mean_exp_measured',
        'mean_exp_decoded',
        'max_ct_exp_measured',
        'max_ct_exp_decoded',
        'entropy_measured',
        'entropy_decoded',
        'corr',
    )
}

for gene in tqdm(adata.var.index):
    # Profiles across cell types for this gene.
    v_measured = np.array(ct_mean_measured_df[gene])
    v_decoded = np.array(ct_mean_decoded_df[gene])

    gene_stats = {
        'gene_id': gene,
        'gene_name': adata.var.loc[gene, 'human_gene_name'],
        'mean_exp_measured': v_measured.mean(),
        'mean_exp_decoded': v_decoded.mean(),
        'max_ct_exp_measured': v_measured.max(),
        'max_ct_exp_decoded': v_decoded.max(),
        'entropy_measured': entropy(v_measured),
        'entropy_decoded': entropy(v_decoded),
        # Pearson correlation between the two cell-type profiles.
        'corr': np.corrcoef(v_measured, v_decoded)[0, 1],
    }
    for column, value in gene_stats.items():
        gene_comp_dict[column].append(value)

gene_comp_df = pd.DataFrame(gene_comp_dict).set_index('gene_id')


In [None]:
# Replace the decoded AnnData's per-gene table with (a copy of) the comparison
# statistics, then persist it. Columns added to `gene_comp_df` after this point
# are not part of the saved file because a copy is stored here.
adata_decoded.var = gene_comp_df.copy()
adata_decoded.write_h5ad('adata_decoded_human_10x.h5ad')
adata_decoded

In [None]:
# Scatter of per-gene mean expression (measured vs decoded) with the
# least-squares regression line and Pearson r annotated.
fig, ax = plt.subplots(figsize=(4,4), dpi=100)

# A single `linregress` call provides slope, intercept and Pearson r, so no
# separate correlation computation is needed. `slope` and `intercept` stay
# module-level names because the next cell reuses them.
fit = scipy.stats.linregress(gene_comp_df['mean_exp_measured'], gene_comp_df['mean_exp_decoded'])
slope, intercept, r = fit.slope, fit.intercept, fit.rvalue
ax.axline((0, intercept), slope=slope, color='black', linestyle='--')

# Rasterise the points so the saved PDF stays small.
ax.scatter(gene_comp_df['mean_exp_measured'], gene_comp_df['mean_exp_decoded'], s=1, c='grey', rasterized=True)
ax.set_xlabel('Mean Expression (Measured)')
ax.set_ylabel('Mean Expression (Decoded)')
ax.set_title('Mean Expression: Measured vs Decoded')

ax.text(0.05, 0.9, f'R={r:.3f}', transform=ax.transAxes)
fig.savefig(os.path.join(plot_output_path, 'mean_expression_measured_vs_decoded.pdf'))
plt.show()

In [None]:
# Value predicted for each gene by the regression line (`slope`, `intercept`),
# and the decoded mean's signed deviation from that prediction.
fitted_mean = gene_comp_df['mean_exp_measured'] * slope + intercept
gene_comp_df['mean_exp_fitted'] = fitted_mean
gene_comp_df['mean_exp_diff_from_fitted'] = gene_comp_df['mean_exp_decoded'] - fitted_mean

In [None]:
# Inspect all genes ordered from the most negative to the most positive
# deviation from the regression line.
gene_comp_df.sort_values('mean_exp_diff_from_fitted')

In [None]:
# Look up the comparison statistics of a few hand-picked genes by name.
gene_comp_df[gene_comp_df['gene_name'].isin(['H3Y1', 'TMSB4Y', 'MT-CO3', 'GAPDH'])]

In [None]:
# Names of the 20 genes whose decoded mean lies furthest above the regression line.
np.array(gene_comp_df.sort_values('mean_exp_diff_from_fitted', ascending=False)['gene_name'][:20])

In [None]:
# Histogram of each gene's maximum cell-type mean (measured), with the 0.2
# cut-off used by the following cells marked as a dashed line.
fig, ax = plt.subplots(figsize=(4,4), dpi=100)
max_measured = gene_comp_df['max_ct_exp_measured']
max_measured.hist(bins=100, range=(0,3), ax=ax, color='grey')
ax.axvline(x=0.2, color='black', linestyle='--')
ax.grid(False)
ax.set(xlabel='Max Cell Type Expression (Measured)', ylabel='Frequency')
fig.savefig(os.path.join(plot_output_path, 'max_cell_type_expression_measured_hist.pdf'))

In [None]:
# Example genes sitting just above the 0.2 expression cut-off.
max_measured = gene_comp_df['max_ct_exp_measured']
just_above_cutoff = (max_measured > 0.2) & (max_measured < 0.205)
np.array(gene_comp_df.loc[just_above_cutoff, 'gene_name'])

In [None]:
# Histogram of the measured-vs-decoded correlation, restricted to genes whose
# maximum measured cell-type mean exceeds 0.2.
fig, ax = plt.subplots(figsize=(4,4), dpi=100)

is_expressed = gene_comp_df['max_ct_exp_measured'] > 0.2
gene_comp_df.loc[is_expressed, 'corr'].hist(bins=100, ax=ax, color='grey')
ax.set(xlabel='Correlation (Measured vs Decoded)', ylabel='Frequency')
ax.grid(False)
fig.savefig(os.path.join(plot_output_path, 'correlation_measured_vs_decoded_hist.pdf'))

In [None]:
# Histogram of the decoded-profile entropy for genes whose maximum measured
# cell-type mean exceeds 0.2; the dashed line marks the 6.4 cut-off.
fig, ax = plt.subplots(figsize=(4,4), dpi=100)

is_expressed = gene_comp_df['max_ct_exp_measured'] > 0.2
gene_comp_df.loc[is_expressed, 'entropy_decoded'].hist(bins=100, range=(3, 6.7), ax=ax, color='grey')
ax.axvline(x=6.4, color='black', linestyle='--')
ax.set(xlabel='Entropy (Decoded)', ylabel='Frequency')
ax.grid(False)
fig.savefig(os.path.join(plot_output_path, 'entropy_decoded_hist.pdf'))

In [None]:
# Up to 30 example genes sitting just above the 6.4 decoded-entropy cut-off
# (among genes passing the 0.2 expression cut-off).
is_expressed = gene_comp_df['max_ct_exp_measured'] > 0.2
decoded_entropy = gene_comp_df['entropy_decoded']
just_above_entropy_cutoff = is_expressed & (decoded_entropy > 6.4) & (decoded_entropy < 6.41)
np.array(gene_comp_df.loc[just_above_entropy_cutoff, 'gene_name'])[:30]

In [None]:
# Names of expressed genes whose measured-vs-decoded correlation falls in the
# narrow (0.8, 0.801) window.
is_expressed = gene_comp_df['max_ct_exp_measured'] > 0.2
corr = gene_comp_df['corr']
corr_near_0_8 = is_expressed & (corr < 0.801) & (corr > 0.8)
np.array(gene_comp_df.loc[corr_near_0_8, 'gene_name'])

In [None]:
# Full statistics rows for the same genes: expressed (max measured > 0.2) and
# correlation inside the (0.8, 0.801) window.
is_expressed = gene_comp_df['max_ct_exp_measured'] > 0.2
corr = gene_comp_df['corr']
gene_comp_df[is_expressed & (corr < 0.801) & (corr > 0.8)]

In [None]:
# Rows for genes whose name starts with 'SPI1'.
# `na=False` treats missing gene names as non-matching; without it a single
# NaN name turns the mask non-boolean and the row selection raises.
gene_comp_df[
    (gene_comp_df['gene_name'].str.startswith('SPI1', na=False))
]


In [None]:
# Work on a copy of the decoded data whose genes are indexed by human gene name
# instead of gene id, so genes can be requested by symbol when plotting.
adata_named = adata_decoded.copy()

# Names are taken positionally from `adata.var`
# (assumes `adata_decoded` has the same gene order as `adata` — TODO confirm).
adata_named.var.index = list(adata.var['human_gene_name'])
# Several ids can map to the same name; make the index unique.
adata_named.var_names_make_unique()

In [None]:
# List the gene names that start with the 'MT-' prefix.
np.array(adata_named.var[adata_named.var.index.str.startswith('MT-')].index)

In [None]:
# Genes to colour the UMAP by. Alternative set used previously:
# ['H3Y1', 'TMSB4Y', 'MT-CO3', 'GAPDH'].
genes_to_plot = ['NANOG', 'POU5F1', 'SOX2']

# Silently drop any requested gene that is absent from the named AnnData.
available_names = set(adata_named.var_names)
genes_to_plot = [gene for gene in genes_to_plot if gene in available_names]

sc.pl.umap(adata_named, color=genes_to_plot, cmap='inferno_r')