In [None]:
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format='retina'

In [None]:
import os

from tqdm import tqdm
import numpy as np
import pandas as pd

import anndata
import scanpy as sc

import torch

from scmg.preprocessing.data_standardization import GeneNameMapper
from scmg.model.contrastive_embedding import CellEmbedder, embed_adata, decode_adata
from scmg.model.cell_type_search import CellTypeSearcher

# Project helper, presumably for translating between gene naming schemes.
# NOTE(review): `gene_name_mapper` is not referenced anywhere else in this
# notebook — remove it (and the GeneNameMapper import) if it is truly unused.
gene_name_mapper = GeneNameMapper()

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

# Figure defaults. Plain matplotlib settings go first, then scanpy's figure
# params, then overrides that must be applied after scanpy.
plt.rcParams.update({
    'figure.autolayout': False,
    'pdf.fonttype': 42,  # embed fonts in PDFs as TrueType (Type 42)
    'font.family': 'FreeSans',
})
sc.set_figure_params(vector_friendly=True, dpi_save=300)
# Set after sc.set_figure_params — presumably because scanpy's defaults turn
# the axes grid back on.
plt.rcParams['axes.grid'] = False

In [None]:
# Directory that receives every figure saved by this notebook; created if missing.
output_path = 'Burclaff_intestine_cell_type_search_plots'
os.makedirs(output_path, exist_ok=True)

In [None]:
# Load the trained cell embedder: the pickled model object from `model.pt`,
# then the best checkpoint weights from `best_state_dict.pth`.
model_ce_path = '../../../contrastive_embedding/trained_embedder/'

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# `model.pt` holds a whole pickled module (not just tensors), so it needs
# weights_only=False (torch >= 2.6 defaults to True). Unpickling can execute
# arbitrary code: only load this file from a trusted source.
model = torch.load(os.path.join(model_ce_path, 'model.pt'),
                   map_location=device, weights_only=False)
model.load_state_dict(torch.load(os.path.join(model_ce_path, 'best_state_dict.pth'),
                                 map_location=device))

model.to(device)
# Inference mode; trailing ';' suppresses the long module repr.
model.eval();

In [None]:
# Annotated reference cells: used below to build the CellTypeSearcher, as the
# UMAP background for projected query cells, and to look up matched cell types.
adata_ct_ref = sc.read_h5ad('../../ref_cell_adata.h5ad')
adata_ct_ref

In [None]:
# Query dataset (Burclaff 2022 intestine, per the file name). Genes are
# re-indexed by their feature IDs.
adata = sc.read_h5ad('../../../contrastive_embedding/test_embedding/Burclaff_intestine_HS_2022_all.h5ad')
adata.var.index = adata.var['feature_id']
adata

In [None]:
# Embed every query cell with the trained model.
# NOTE(review): presumably stores the embedding in adata.obsm['X_ce_latent']
# (read by later cells) — confirm against scmg's embed_adata.
embed_adata(model, adata, batch_size=8192)

In [None]:
# Major cell-type annotation table; only its 'major_cell_type' column is
# inspected in this notebook.
major_ct_df = pd.read_csv('../../../cell_type_analysis/major_cell_type_annotation.csv')
major_ct_df

In [None]:
# Distinct (non-missing) major cell types, alphabetically.
sorted(major_ct_df['major_cell_type'].dropna().unique())

In [None]:
# Searcher over the annotated reference cells; query embeddings are matched
# against it in the cells below.
cts = CellTypeSearcher(adata_ct_ref)

In [None]:
# Number of query cells per annotated cell type.
adata.obs['cell_type'].value_counts()

In [None]:
# Match every query cell to a reference cell. The matched cell's reference UMAP
# coordinates become the query cell's projected position, and the match
# distance is kept per cell.
query_emb = adata.obsm['X_ce_latent']
cell_match_df = cts.search_ref_cell(query_emb)

adata.obs['project_dist'] = cell_match_df['distance'].values
adata.obsm['X_project_umap'] = cell_match_df[['umap_x', 'umap_y']].values

In [None]:
# Reference atlas coloured by major cell type, for orientation.
sc.pl.umap(adata_ct_ref, color='major_cell_type')

# Overlay: reference cells as background, query cells drawn at the UMAP
# position of their matched reference cell, coloured by their own annotation.
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=300)
sc.pl.umap(adata_ct_ref, ax=ax, show=False)
# show=False stops scanpy from calling plt.show() before the figure is saved.
sc.pl.embedding(adata, basis='X_project_umap', color='cell_type', ax=ax,
                frameon=False, show=False)

# figure.autolayout is off, so use a tight bbox to keep scanpy's default
# right-margin legend (outside the axes) inside the saved PDF.
fig.savefig(os.path.join(output_path, 'umap_projected_cells.pdf'), bbox_inches='tight')

In [None]:
# Reference cells in grey. Each query cell is drawn at its matched reference
# cell's UMAP position and coloured by the match distance.
fig, ax = plt.subplots()
ax.scatter(adata_ct_ref.obsm['X_umap'][:, 0], adata_ct_ref.obsm['X_umap'][:, 1],
           s=1, c='lightgrey')
sct = ax.scatter(cell_match_df['umap_x'], cell_match_df['umap_y'],
                 s=1, c=cell_match_df['distance'], cmap='viridis')
fig.colorbar(sct, ax=ax, label='projection distance')
ax.set_xlabel('UMAP 1')
ax.set_ylabel('UMAP 2');

In [None]:
# One row per query cell type, one column per reference cell type, filled with
# the 'weight' returned by the searcher for that query type's cells.
query_cell_types = np.unique(adata.obs['cell_type'])
query_latent = adata.obsm['X_ce_latent']

all_ct_match_df = pd.DataFrame(
    np.zeros((len(query_cell_types), len(cts.cell_types)), dtype=np.float32),
    index=query_cell_types,
    columns=cts.cell_types,
)

for query_ct in tqdm(query_cell_types):
    in_query_ct = adata.obs['cell_type'] == query_ct
    ct_match_df = cts.search_ref_cell_types(query_latent[in_query_ct])
    # NOTE(review): assigning a Series to a row aligns it on its index against
    # the columns, so this assumes ct_match_df is indexed by reference cell
    # type; reference types absent from it would become NaN. Confirm.
    all_ct_match_df.loc[query_ct] = ct_match_df['weight']
    

In [None]:
# Top-5 reference cell types by match weight, for each query cell type.
for _, ref_weights in all_ct_match_df.iterrows():
    display(ref_weights.sort_values(ascending=False).head(5))

In [None]:
# Projected query cells, with each annotated cell type labelled in place.
fig, ax = plt.subplots(figsize=(10, 10))
sc.pl.embedding(
    adata,
    basis='X_project_umap',
    color='cell_type',
    ax=ax,
    legend_loc='on data',
    legend_fontsize=5,
)

In [None]:
# Query cell types shown on the x axis (in this order).
query_cts_to_show = [
    'BEST4+ intestinal epithelial cell, human',
    'intestinal crypt stem cell of colon',
    'transit amplifying cell of colon',
    'enterocyte of epithelium of large intestine',
    'intestinal crypt stem cell of small intestine',
    'transit amplifying cell of small intestine',
    'enterocyte of epithelium of small intestine',
    'microfold cell of epithelium of small intestine',
    'epithelial cell of small intestine',
    'enteroendocrine cell of colon',
    'enteroendocrine cell of small intestine',
    'intestinal tuft cell',
    'tuft cell of colon',
    'paneth cell of epithelium of small intestine',
    'progenitor cell',
    'colon goblet cell',
    'small intestine goblet cell',
]

# Reference cell types shown on the y axis (in this order).
ref_cts_to_show = [
    'paneth cell of colon',
    'transit amplifying cell',
    'intestinal crypt stem cell',
    'transit amplifying cell of colon',
    'transit amplifying cell of small intestine',
    'Midgut/Hindgut epithelial cells',
    'intestinal enteroendocrine cell',
    'intestinal tuft cell',
    'Intestinal goblet cells',
    'large intestine goblet cell',
]

# Raises KeyError if any listed cell type is missing from all_ct_match_df.
ct_mtx_to_show = all_ct_match_df.loc[query_cts_to_show, ref_cts_to_show]

fig, ax = plt.subplots(figsize=(10, 6))
# vmax caps the colour scale: scores above 0.4 all get the top colour.
sns.heatmap(ct_mtx_to_show.T, cmap='viridis', ax=ax, vmax=0.4,
            cbar_kws={'label': 'cell type match score'})
ax.set_xlabel('query cell type')
ax.set_ylabel('reference cell type')

# figure.autolayout is off, so use a tight bbox to keep the long tick labels
# from being clipped in the saved PDF.
fig.savefig(os.path.join(output_path, 'cell_type_match_heatmap.pdf'), bbox_inches='tight')

In [None]:
# Store, per query cell, the matched reference cell and that cell's type.
matched_ref_cells = cell_match_df['ref_cell']
adata.obs['project_dist'] = cell_match_df['distance'].values
adata.obs['ref_cell'] = matched_ref_cells.values
adata.obs['projected_cell_type'] = adata_ct_ref.obs['cell_type'].loc[matched_ref_cells].values

# Number of query cells matched to each reference cell (0 when none).
adata_ct_ref.obs['projected_cell_count'] = (
    adata.obs['ref_cell']
    .value_counts()
    .reindex(adata_ct_ref.obs.index, fill_value=0)
    .values
)

# Row-normalised confusion table: for each query cell type, the fraction of
# its cells whose matched reference cell has each reference cell type.
ct_confusion_df = pd.crosstab(adata.obs['cell_type'], adata.obs['projected_cell_type'],
                              normalize='index')

for _, projected_fractions in ct_confusion_df.iterrows():
    display(projected_fractions.sort_values(ascending=False).head(5))

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(10, 6))

# Query cell types shown on the x axis (in this order).
confusion_query_cts = [
    'epithelial cell of small intestine',
    'BEST4+ intestinal epithelial cell, human',
    'intestinal crypt stem cell of small intestine',
    'intestinal crypt stem cell of colon',
    'enterocyte of epithelium of large intestine',
    'transit amplifying cell of colon',
    'transit amplifying cell of small intestine',
    'enterocyte of epithelium of small intestine',
    'microfold cell of epithelium of small intestine',
    'intestinal tuft cell',
    'tuft cell of colon',
    'paneth cell of epithelium of small intestine',
    'small intestine goblet cell',
    'progenitor cell',
    'colon goblet cell',
    'enteroendocrine cell of colon',
    'enteroendocrine cell of small intestine',
]

# Projected reference cell types shown on the y axis (in this order).
confusion_ref_cts = [
    'epithelial cell',
    'intestinal crypt stem cell',
    'transit amplifying cell',
    'transit amplifying cell of colon',
    'transit amplifying cell of small intestine',
    'enterocyte',
    'enterocyte of colon',
    'intestinal tuft cell',
    'intestine goblet cell',
    'large intestine goblet cell',
    'enteroendocrine cell',
]

# ct_confusion_df rows are normalised per query cell type, so each column of
# the transposed matrix holds fractions of that query type's cells.
# vmax caps the colour scale at 0.7.
sns.heatmap(ct_confusion_df.loc[confusion_query_cts, confusion_ref_cts].T,
            cmap='viridis', vmax=0.7, ax=ax,
            cbar_kws={'label': 'fraction of query cells'})
ax.set_xlabel('query cell type')
ax.set_ylabel('projected reference cell type')

# figure.autolayout is off, so use a tight bbox to keep the long tick labels
# from being clipped in the saved PDF.
fig.savefig(os.path.join(output_path, 'cell_type_projection_confusion_matrix.pdf'),
            bbox_inches='tight')