In [None]:
# Auto-reload edited modules (e.g. the scmg package) before each cell runs,
# and render inline figures at high (retina) resolution.
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format='retina'

In [None]:
import os

from tqdm import tqdm
import numpy as np
import pandas as pd
import scipy.sparse
import scipy.stats

import anndata
import scanpy as sc

import torch

from scmg.preprocessing.data_standardization import GeneNameMapper
from scmg.model.contrastive_embedding import CellEmbedder, embed_adata, decode_adata

# Shared mapper used below to convert gene IDs ('id') to gene names ('name')
# for both the observed and the decoded AnnData objects.
gene_name_mapper = GeneNameMapper()

In [None]:
import matplotlib
import matplotlib.pyplot as plt

# Manual figure layout, TrueType (Type 42) font embedding in PDFs instead of
# Type 3, FreeSans as the font family; then layer scanpy's figure defaults on top.
plt.rcParams.update({
    'figure.autolayout': False,
    'pdf.fonttype': 42,
    'font.family': 'FreeSans',
})
sc.set_figure_params(vector_friendly=True, dpi_save=300)

In [None]:
# Load the autoencoder model: the pickled module itself, then its best checkpoint weights.
model_ce_path = '../../trained_embedder/'

# Fall back to CPU when no GPU is visible instead of failing at model.to().
# NOTE(review): confirm embed_adata/decode_adata follow the model's device rather
# than hard-coding CUDA internally.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# model.pt appears to be a whole pickled nn.Module. torch>=2.6 defaults to
# weights_only=True, which refuses such files, so opt out explicitly
# (weights_only requires torch>=1.13). This unpickles arbitrary code:
# only load trusted checkpoints.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model = torch.load(os.path.join(model_ce_path, 'model.pt'),
                   map_location=device, weights_only=False)
model.load_state_dict(torch.load(os.path.join(model_ce_path, 'best_state_dict.pth'),
                                 map_location=device))

model.to(device)
model.eval()

model.dataset_id_map

In [None]:
# Dataset from the test_embedding folder. Genes are re-indexed by their
# 'feature_id' column, which the mapper later treats as gene IDs ('id' -> 'name').
test_h5ad_path = '../../test_embedding/Burclaff_intestine_HS_2022_all.h5ad'
adata = sc.read_h5ad(test_h5ad_path)
feature_ids = adata.var['feature_id']
adata.var.index = feature_ids
adata

In [None]:
# Embed every cell with the contrastive encoder. The result is presumably written
# to adata.obsm['X_ce_latent'] (used as use_rep below). TODO confirm in embed_adata.
embed_adata(model, adata, batch_size=8192)

In [None]:
# 30-NN graph on the learned latent space, then a 2-D UMAP layout of that graph.
sc.pp.neighbors(adata, n_neighbors=30, use_rep='X_ce_latent')
sc.tl.umap(adata)

In [None]:
# UMAP of the latent embedding coloured by annotated cell type, saved as a PDF.
fig = sc.pl.umap(adata, return_fig=True, frameon=False, color='cell_type')
fig.savefig('umap_cell_type.pdf', dpi=300)

In [None]:
# Decode every cell conditioned on the same dataset label, then switch the decoded
# genes from IDs to names so they can be matched against the observed data.
n_cells_to_decode = adata.shape[0]
adata_pred_named = decode_adata(model, adata, ['Tabula_Sapiens_HS_2022:all'] * n_cells_to_decode)

pred_gene_names = gene_name_mapper.map_gene_names(
    adata_pred_named.var.index, 'human', 'human', 'id', 'name')
adata_pred_named.var.index = pred_gene_names
adata_pred_named.var_names_make_unique()

In [None]:
# Observed expression keyed by gene name, restricted to genes the decoder also outputs.
adata_gene_named = adata.copy()

adata_gene_named.var.index = gene_name_mapper.map_gene_names(
    adata_gene_named.var.index, 'human', 'human', 'id', 'name')
adata_gene_named.var_names_make_unique()
# .copy() materializes the gene subset. Otherwise the in-place normalization below
# would operate on an AnnData view and trigger an implicit view-to-copy conversion.
adata_gene_named = adata_gene_named[:, adata_gene_named.var.index.isin(
    adata_pred_named.var.index)].copy()
assert adata_gene_named.n_vars > 0, 'no genes shared between observed and decoded data'

# NOTE(review): total counts are computed over the shared-gene subset only, not
# all measured genes. Confirm this matches how the decoder output is scaled.
sc.pp.normalize_total(adata_gene_named, target_sum=1e4)
sc.pp.log1p(adata_gene_named)

In [None]:
# Keep highly variable genes using dispersion thresholds (scanpy's default
# 'seurat' flavor) on the log-normalized observed data.
sc.pp.highly_variable_genes(adata_gene_named, min_mean=0.05, max_mean=3, min_disp=0.5)

adata_gene_named = adata_gene_named[:, adata_gene_named.var.highly_variable].copy()
# Densify for per-gene slicing and correlation. X may already be dense,
# and ndarray has no .toarray().
if scipy.sparse.issparse(adata_gene_named.X):
    adata_gene_named.X = adata_gene_named.X.toarray()

# Align decoded genes to the same HVG set and order. Copy so later .uns
# assignments act on a real object rather than a view.
adata_pred_named = adata_pred_named[:, adata_gene_named.var.index].copy()

In [None]:
# Wilcoxon marker genes per cell type on the observed (log-normalized HVG) data.
sc.tl.rank_genes_groups(adata_gene_named, groupby="cell_type", method="wilcoxon")

# Top-3 markers per cell type. Presumably this plot also stores
# uns['dendrogram_cell_type'] (scanpy adds a dendrogram by default), which the
# decoded-data heatmap copies. Note that scanpy prefixes the plot type to `save`
# and writes under sc.settings.figdir.
heatmap_kwargs = dict(n_genes=3, swap_axes=True, cmap='inferno', vmax=5)
sc.pl.rank_genes_groups_heatmap(
    adata_gene_named,
    save='rank_genes_groups_heatmap.pdf',
    **heatmap_kwargs,
)

In [None]:
# Reuse the observed-data marker ranking and cell-type dendrogram, so the decoded
# heatmap shows the same genes in the same group order.
if 'dendrogram_cell_type' not in adata_gene_named.uns:
    # Presumably created as a side effect of the observed-data heatmap. Compute it
    # explicitly so this cell does not silently depend on that plotting call.
    sc.tl.dendrogram(adata_gene_named, groupby='cell_type')
adata_pred_named.uns['rank_genes_groups'] = adata_gene_named.uns['rank_genes_groups']
adata_pred_named.uns['dendrogram_cell_type'] = adata_gene_named.uns['dendrogram_cell_type']

# NOTE(review): vmax=5 mirrors the observed heatmap. Confirm the decoded values
# are on a comparable scale.
sc.pl.rank_genes_groups_heatmap(
    adata_pred_named,
    n_genes=3,
    swap_axes=True,
    cmap='inferno',
    vmax=5,
    save='rank_genes_groups_heatmap_decoded.pdf',
)

In [None]:
def get_cluster_mean_expression_matrix_low_mem(adata, cluster_column):
    '''Get a dataframe of mean gene expression of each cluster.

    Clusters are processed one at a time, subsetting the cells of each cluster
    in turn (presumably to keep peak memory low, hence "low_mem").

    Parameters
    ----------
    adata : AnnData
        Cells x genes; ``adata.X`` may be a dense array or a scipy sparse matrix.
    cluster_column : str
        Column of ``adata.obs`` holding each cell's cluster label.

    Returns
    -------
    pandas.DataFrame
        float32 frame with one row per cluster label present in the data
        (sorted, as returned by ``np.unique``) and one column per gene
        (``adata.var.index``).
    '''
    labels = adata.obs[cluster_column].values
    cluster_names = np.unique(labels)
    cluster_mean_df = pd.DataFrame(np.zeros((len(cluster_names), adata.shape[1]), dtype=np.float32),
                                   index=cluster_names, columns=adata.var.index)

    for c in tqdm(cluster_names):
        X_c = adata[labels == c].X
        # For sparse X, .mean(axis=0) returns a (1, n_genes) np.matrix. Flatten it
        # to 1-D so the row assignment works for dense and sparse input alike.
        cluster_mean_df.loc[c] = np.asarray(X_c.mean(axis=0)).ravel()

    return cluster_mean_df

# Per-cell-type mean expression for the observed and the decoded data.
ct_gene_exp_true_df, ct_gene_exp_pred_df = (
    get_cluster_mean_expression_matrix_low_mem(cell_data, 'cell_type')
    for cell_data in (adata_gene_named, adata_pred_named)
)

In [None]:
import scipy.stats

# Cell-type-level agreement: for each gene present in both objects, the Pearson
# correlation between observed and decoded per-cell-type mean expression.
shared_genes = [g for g in adata_gene_named.var.index if g in adata_pred_named.var.index]
ct_corr_df = pd.DataFrame({
    'gene': shared_genes,
    'corr': [
        scipy.stats.pearsonr(ct_gene_exp_true_df[g].values, ct_gene_exp_pred_df[g].values)[0]
        for g in tqdm(shared_genes)
    ],
})

In [None]:
# Distribution of per-gene correlations at the cell-type (mean expression) level.
fig, ax = plt.subplots(figsize=(4, 4), dpi=300)
ax.hist(ct_corr_df['corr'], bins=20, density=False, color='gray')
ax.grid(False)
ax.set(title='cell type level', xlabel='Pearson correlation', ylabel='Number of genes')
fig.savefig('cell_type_correlation_hist.pdf')

In [None]:
# Single-cell-level agreement: for each gene, the Pearson correlation across all
# cells between observed (log-normalized) and decoded expression.
# Both matrices are expected dense here (the observed X is densified after HVG
# selection). NOTE(review): confirm decode_adata returns a dense X.
obs_X = adata_gene_named.X
pred_X = adata_pred_named.X
pred_var_index = adata_pred_named.var.index

corr_dict = {'gene': [], 'corr': []}

for j, gene in enumerate(tqdm(adata_gene_named.var.index)):
    if gene in pred_var_index:
        # Index the matrices directly instead of building two AnnData views per gene.
        k = pred_var_index.get_loc(gene)
        gene_true = np.asarray(obs_X[:, j]).ravel()
        gene_pred = np.asarray(pred_X[:, k]).ravel()
        corr = scipy.stats.pearsonr(gene_true, gene_pred)[0]

        corr_dict['gene'].append(gene)
        corr_dict['corr'].append(corr)

corr_df = pd.DataFrame(corr_dict)

In [None]:
# Distribution of per-gene correlations computed across individual cells.
fig, ax = plt.subplots(figsize=(4, 4), dpi=300)
ax.hist(corr_df['corr'], bins=20, density=False, color='gray')
ax.grid(False)
ax.set(title='single cell level', xlabel='Pearson correlation', ylabel='Number of genes')
fig.savefig('single_cell_correlation_hist.pdf')

In [None]:
# Peek at a hand-picked window (ranks 1030-1049) of genes sorted by correlation,
# presumably to find example genes for the UMAP panels below.
corr_df.sort_values(by='corr').iloc[1030:1050]

In [None]:
genes_to_plot = ['BRME1', 'LRRC31', 'HK2', 'DPP4']

# Same genes on the same UMAP, observed expression first, then decoded expression.
for plot_adata, pdf_name in (
    (adata_gene_named, 'umap_genes.pdf'),
    (adata_pred_named, 'umap_genes_decoded.pdf'),
):
    fig = sc.pl.umap(plot_adata, color=genes_to_plot, return_fig=True, cmap='inferno_r')
    fig.savefig(pdf_name, dpi=300)

corr_df[corr_df['gene'].isin(genes_to_plot)]