In [None]:
import scanpy as sc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import anndata as ad
import scanpy.external as sce
from sklearn import preprocessing
import pickle5 as pickle
from tqdm import tqdm
from sklearn.neighbors import NearestNeighbors
from sklearn import preprocessing
import sklearn
import seaborn as sns

from utils import *

from sklearn import metrics
# Tiny positive constant. NOTE(review): not referenced in any of the visible cells —
# presumably a divide-by-zero / log(0) guard used elsewhere; confirm or remove.
eps=1e-100

def custom_annot(data, fmt_func):
    """Apply ``fmt_func`` element-wise to ``data`` and return the results as an array.

    Parameters
    ----------
    data : array-like
        Values to be formatted.
    fmt_func : callable
        Function taking a single element of ``data`` and returning its annotation.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``data`` holding the formatted values.
    """
    elementwise_formatter = np.vectorize(fmt_func)
    return elementwise_formatter(data)

# Custom annotation function
def fmt(x):
    """Format ``x`` with no decimal places; a value equal to zero becomes ``''``."""
    if x == 0:
        return ''
    return f'{x:.0f}'


# cell-type annotation of hold-out atlas 3 data

In [None]:
# Load the cell-type transfer result and keep every cell whose 'name' is not 'starmap'.
ad_embed_all = sc.read_h5ad(f"source_data/atlas/transfer_celltype_merscope.h5ad")
not_starmap = ad_embed_all.obs['name'] != 'starmap'
ad_embed_all_merscope = ad_embed_all[not_starmap, :]

# Ground-truth cell types that have more than 10 cells in this subset.
gt_counts = ad_embed_all_merscope.obs['gt_cell_type_main'].value_counts()
keep_celltype = gt_counts[gt_counts > 10].index

# Keep only cells whose ground-truth label AND transferred label are both retained types.
gt_is_kept = ad_embed_all_merscope.obs['gt_cell_type_main'].isin(keep_celltype)
ad_embed_all_merscope = ad_embed_all_merscope[gt_is_kept, :]
transfer_is_kept = ad_embed_all_merscope.obs['transfer_gt_cell_type_main_merscope'].isin(keep_celltype)
ad_embed_all_merscope = ad_embed_all_merscope[transfer_is_kept, :]

In [None]:
# Fraction of cells whose transferred cell type equals the ground-truth cell type.
merscope_obs = ad_embed_all_merscope.obs
metrics.accuracy_score(
    merscope_obs['gt_cell_type_main'],
    merscope_obs['transfer_gt_cell_type_main_merscope'],
)


In [None]:
# Contingency table: rows are ground-truth labels, columns are transferred labels.
gt_labels = ad_embed_all_merscope.obs['gt_cell_type_main']
transferred_labels = ad_embed_all_merscope.obs['transfer_gt_cell_type_main_merscope']
cross_tab = pd.crosstab(gt_labels, transferred_labels)

# Two-step normalization: first make every row sum to 1, then rescale every
# column of that row-normalized table so it sums to 1.
row_totals = cross_tab.sum(axis=1)
row_normalized = cross_tab.div(row_totals, axis=0)
column_totals = row_normalized.sum(axis=0)
cross_tab_normalized = row_normalized.div(column_totals, axis=1)

In [None]:
# Plot heatmap of the normalized ground-truth x transferred-label table.
n_rows, n_cols = cross_tab_normalized.shape
fig, ax = plt.subplots(figsize=(20, 15))
sns.heatmap(cross_tab_normalized, cmap='Blues', ax=ax)
# Thin black cell borders. The number of lines follows the table shape (one per
# cell boundary, including the outer edge) instead of a hard-coded count, so the
# grid stays complete whatever the number of categories is.
ax.hlines(np.arange(n_rows + 1), *ax.get_xlim(), color='k', linewidth=0.1)
ax.vlines(np.arange(n_cols + 1), *ax.get_ylim(), color='k', linewidth=0.1)
ax.set_title("Normalized Correspondence of Two Categories")
# plt.savefig('figures/merscope_celltype.pdf',dpi=300, transparent=True)
plt.show()


# cell-type annotation of hold-out atlas 1 data

In [None]:
# Load the cell-type transfer result and keep only the cells whose 'name' is 'starmap'.
ad_embed_all = sc.read_h5ad(f"source_data/atlas/transfer_celltype_starmap.h5ad")
is_starmap = ad_embed_all.obs['name'] == 'starmap'
ad_embed_all_starmap = ad_embed_all[is_starmap, :]

# Ground-truth cell types that have more than 10 cells in this subset.
gt_counts = ad_embed_all_starmap.obs['gt_cell_type_main'].value_counts()
keep_celltype = gt_counts[gt_counts > 10].index

# Keep only cells whose ground-truth label AND transferred label are both retained types.
gt_is_kept = ad_embed_all_starmap.obs['gt_cell_type_main'].isin(keep_celltype)
ad_embed_all_starmap = ad_embed_all_starmap[gt_is_kept, :]
transfer_is_kept = ad_embed_all_starmap.obs['transfer_gt_cell_type_main_starmap'].isin(keep_celltype)
ad_embed_all_starmap = ad_embed_all_starmap[transfer_is_kept, :]

# Evaluation subset: cells that carry a (non-missing) ground-truth label.
has_gt_label = ad_embed_all_starmap.obs['gt_cell_type_main'].notna()
ad_embed_all_starmap_test = ad_embed_all_starmap[has_gt_label]

In [None]:
# Fraction of evaluation cells whose transferred cell type equals the ground truth.
starmap_test_obs = ad_embed_all_starmap_test.obs
metrics.accuracy_score(
    starmap_test_obs['gt_cell_type_main'],
    starmap_test_obs['transfer_gt_cell_type_main_starmap'],
)

In [None]:
# Contingency table: rows are ground-truth labels, columns are transferred labels.
gt_labels = ad_embed_all_starmap_test.obs['gt_cell_type_main']
transferred_labels = ad_embed_all_starmap_test.obs['transfer_gt_cell_type_main_starmap']
cross_tab = pd.crosstab(gt_labels, transferred_labels)

# Two-step normalization: first make every row sum to 1, then rescale every
# column of that row-normalized table so it sums to 1.
row_totals = cross_tab.sum(axis=1)
row_normalized = cross_tab.div(row_totals, axis=0)
column_totals = row_normalized.sum(axis=0)
cross_tab_normalized = row_normalized.div(column_totals, axis=1)

In [None]:
# Heatmap of the normalized ground-truth x transferred-label table.
n_rows, n_cols = cross_tab_normalized.shape
fig, ax = plt.subplots(figsize=(15, 11.25))
sns.heatmap(cross_tab_normalized, cmap='Blues', ax=ax)
# Thin black cell borders. The number of lines follows the table shape (one per
# cell boundary, including the outer edge) instead of a hard-coded count, so the
# grid stays complete whatever the number of categories is.
ax.hlines(np.arange(n_rows + 1), *ax.get_xlim(), color='k', linewidth=0.1)
ax.vlines(np.arange(n_cols + 1), *ax.get_ylim(), color='k', linewidth=0.1)
ax.set_title("Normalized Correspondence of Two Categories")
# plt.savefig('figures/starmap_celltype.pdf',dpi=300, transparent=True)
plt.show()


# tissue-region annotation of hold-out atlas 1 data

In [None]:
# Load the tissue-region transfer result and keep only the cells whose 'name' is 'starmap'.
ad_embed_all = sc.read_h5ad("source_data/atlas/transfer_tissueregion_starmap.h5ad")
# Take an explicit copy: the neighbor graph and UMAP below write their results
# into this object, so it should own its data rather than be a view of ad_embed_all.
ad_embed_all_starmap = ad_embed_all[ad_embed_all.obs['name'] == 'starmap', :].copy()

# use_rep='X' builds the neighbor graph directly on .X
# (presumably .X already holds the learned embedding — TODO confirm).
sc.pp.neighbors(ad_embed_all_starmap, use_rep='X')
sc.tl.umap(ad_embed_all_starmap)

In [None]:
# Map every ground-truth tissue region to a color, in the order returned by value_counts().
# The palette size follows the number of regions instead of a hard-coded 17, so a
# dataset with more regions no longer fails with an IndexError.
# NOTE(review): 'Paired' is a 12-color qualitative palette, so with more than 12
# regions some colors repeat — consider a larger palette if they must be distinct.
region_names = ad_embed_all_starmap.obs['gt_tissue_region_main'].value_counts().keys()
color_code = sns.color_palette('Paired', len(region_names))
dic_color = dict(zip(region_names, color_code))


In [None]:
# UMAP of the starmap cells colored by ground-truth tissue region.
fig, ax = plt.subplots(figsize=(7, 7))
umap_style = dict(color='gt_tissue_region_main', palette=dic_color, ax=ax)
sc.pl.umap(ad_embed_all_starmap, **umap_style)


In [None]:
# UMAP of the starmap cells colored by the transferred tissue-region label.
fig, ax = plt.subplots(figsize=(7, 7))
umap_style = dict(color='transfer_gt_tissue_region_main_starmap', palette=dic_color, ax=ax)
sc.pl.umap(ad_embed_all_starmap, **umap_style)


In [None]:
# Sub-cluster the cells whose transferred tissue region is 'CTX_1'.
# An explicit copy is taken because sc.tl.leiden writes its result into the
# object it is given; running it on a view of ad_embed_all_starmap would
# trigger an implicit modification of that view.
is_ctx1 = ad_embed_all_starmap.obs['transfer_gt_tissue_region_main_starmap'] == 'CTX_1'
sub_entm = ad_embed_all_starmap[is_ctx1].copy()

sc.tl.leiden(sub_entm, resolution=0.05)

# Corrected labels start out as a plain-string copy of the transferred labels,
# so that a new label value can be assigned below.
ad_embed_all_starmap.obs['transfer_gt_tissue_region_main_starmap_correct'] = (
    ad_embed_all_starmap.obs['transfer_gt_tissue_region_main_starmap'].astype('str')
)
# Cells in Leiden cluster '2' of the CTX_1 subset are relabeled 'ENTm'.
# NOTE(review): the cluster id '2' is hard-coded; Leiden numbering can change with
# the data or library version, so verify this cluster before trusting the result.
entm_cells = sub_entm.obs.index[(sub_entm.obs['leiden'] == '2').to_numpy()]
ad_embed_all_starmap.obs.loc[entm_cells, 'transfer_gt_tissue_region_main_starmap_correct'] = 'ENTm'

In [None]:
# UMAP of the starmap cells colored by the corrected transferred tissue-region label.
fig, ax = plt.subplots(figsize=(7, 7))
umap_style = dict(color='transfer_gt_tissue_region_main_starmap_correct', palette=dic_color, ax=ax)
sc.pl.umap(ad_embed_all_starmap, **umap_style)


In [None]:
# Fraction of cells whose corrected transferred tissue region equals the ground truth.
starmap_obs = ad_embed_all_starmap.obs
metrics.accuracy_score(
    starmap_obs['gt_tissue_region_main'],
    starmap_obs['transfer_gt_tissue_region_main_starmap_correct'],
)



In [None]:
# Contingency table: rows are ground-truth regions, columns are corrected transferred regions.
cross_tab = pd.crosstab(ad_embed_all_starmap.obs['gt_tissue_region_main'],
                        ad_embed_all_starmap.obs['transfer_gt_tissue_region_main_starmap_correct'])
# Two-step normalization: rows are scaled to sum to 1, then every column of the
# row-normalized table is rescaled to sum to 1.
cross_tab_normalized = cross_tab.div(cross_tab.sum(axis=1), axis=0)
cross_tab_normalized = cross_tab_normalized.div(cross_tab_normalized.sum(axis=0), axis=1)

# Plot heatmap
n_rows, n_cols = cross_tab_normalized.shape
fig, ax = plt.subplots(figsize=(12, 9))
sns.heatmap(cross_tab_normalized, cmap='Blues', ax=ax)
# Thin black cell borders. The number of lines follows the table shape (one per
# cell boundary, including the outer edge) instead of the hard-coded 29 that was
# carried over from the cell-type figures.
ax.hlines(np.arange(n_rows + 1), *ax.get_xlim(), color='k', linewidth=0.1)
ax.vlines(np.arange(n_cols + 1), *ax.get_ylim(), color='k', linewidth=0.1)
ax.set_title("Normalized Correspondence of Two Categories")
plt.show()
