In [None]:
import os

import numpy as np
import pandas as pd

import anndata
import scanpy as sc

import torch

from scmg.preprocessing.data_standardization import GeneNameMapper
from scmg.model.contrastive_embedding import CellEmbedder, embed_adata

# Module-level GeneNameMapper instance, built with default arguments.
# NOTE(review): no visible cell references this object — confirm it is still needed.
gene_name_mapper = GeneNameMapper()

In [None]:
import matplotlib
import matplotlib.pyplot as plt

# Global figure defaults for every plot in this notebook.
# pdf.fonttype 42 embeds TrueType fonts, so text in saved PDFs stays editable.
matplotlib.rc('pdf', fonttype=42)
plt.rcParams.update({
    "figure.autolayout": False,
    'font.family': 'FreeSans',
})
# Scanpy-specific defaults: vector-friendly output, 300 dpi when saving.
sc.set_figure_params(vector_friendly=True, dpi_save=300)

In [None]:
from scimilarity.utils import lognorm_counts, align_dataset
from scimilarity import CellQuery

# Path handed to CellQuery (presumably the SCimilarity model directory — confirm).
SCIMILARITY_MODEL_PATH = '/GPUData_xingjie/test/scimilarity/model_v1.1'
cq = CellQuery(SCIMILARITY_MODEL_PATH)

def get_scimilarity_emb(adata):
    """Compute SCimilarity embeddings and store them in ``adata.obsm["X_scimilarity"]``.

    A copy of ``adata`` is re-indexed by ``var['feature_name']``, aligned to
    ``cq.gene_order``, given a ``'counts'`` layer holding a copy of its ``X``
    and passed through ``lognorm_counts``. The resulting ``X`` is embedded
    with ``cq.get_embeddings``. Only the new ``obsm`` entry is written to the
    input object; nothing is returned.
    """
    aligned = adata.copy()
    aligned.var.index = aligned.var['feature_name']
    aligned = align_dataset(aligned, cq.gene_order)

    # Stash X under 'counts' before normalising (presumably raw counts — confirm).
    aligned.layers['counts'] = aligned.X.copy()
    normalized = lognorm_counts(aligned)

    embeddings = cq.get_embeddings(normalized.X)
    adata.obsm["X_scimilarity"] = embeddings

In [None]:
# Load the autoencoder model
model_ce_path = '../trained_embedder/'

# Prefer the first GPU, but fall back to the CPU so the notebook still loads
# the model on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# model.pt stores a fully pickled module (load_state_dict is called on the
# result), so it must be read with weights_only=False; PyTorch >= 2.6 defaults
# to weights_only=True and would refuse it.
# NOTE: unpickling executes arbitrary code — only load checkpoints you trust.
# map_location places every stored tensor on `device` regardless of the device
# the checkpoint was saved from.
model_ce = torch.load(os.path.join(model_ce_path, 'model.pt'),
                      map_location=device, weights_only=False)
model_ce.load_state_dict(
    torch.load(os.path.join(model_ce_path, 'best_state_dict.pth'),
               map_location=device))

model_ce.to(device)
model_ce.eval()

def get_scmg_embedding(adata):
    """Embed ``adata`` with ``model_ce`` and store the result in ``adata.obsm["X_scmg"]``.

    Works on a copy whose ``var`` index is replaced by ``var['feature_id']``.
    ``embed_adata`` is run on that copy with ``batch_size=8192`` and the
    copy's ``obsm["X_ce_latent"]`` entry is then assigned to the input object.
    Nothing is returned.
    """
    scratch = adata.copy()
    scratch.var.index = scratch.var['feature_id']
    embed_adata(model_ce, scratch, batch_size=8192)

    latent = scratch.obsm["X_ce_latent"]
    adata.obsm["X_scmg"] = latent

In [None]:
def subsample_adata_by_ct(adata, ct_count_series, random_state=None):
    """Randomly pick a fixed number of cells per cell type, without replacement.

    Parameters
    ----------
    adata
        Object exposing ``obs`` (a DataFrame with a ``'cell_type'`` column)
        and supporting ``adata[list_of_obs_names].copy()``.
    ct_count_series
        Series indexed by cell type whose values are the number of cells to
        draw for that type. Values are truncated with ``int()``.
    random_state
        Optional seed. ``None`` (default) draws from NumPy's global RNG, so
        ``np.random.seed`` still controls the result; any other value seeds a
        private ``np.random.RandomState`` for a reproducible draw.

    Returns
    -------
    A copy of ``adata`` restricted to the selected cells, grouped in the
    order of ``ct_count_series``.

    Raises
    ------
    ValueError
        If more cells are requested for a cell type than ``adata`` contains.
    """
    sampler = np.random if random_state is None else np.random.RandomState(random_state)

    cell_types = adata.obs['cell_type']
    obs_names = adata.obs.index
    selected_indices = []

    for ct, count in ct_count_series.items():
        n = int(count)
        if n <= 0:
            # Nothing requested for this type; skip the lookup and the draw.
            continue

        indices = obs_names[(cell_types == ct).to_numpy()].values
        if n > len(indices):
            raise ValueError(
                f"Cannot sample {n} cells of type {ct!r}: only {len(indices)} available")
        selected_indices.extend(sampler.choice(indices, n, replace=False))

    return adata[selected_indices].copy()

def calc_integration_matrices(adata, emb_keys=None, n_neighbors=30):
    """Score each embedding by the composition of its neighbor graph.

    For every embedding key, ``sc.pp.neighbors`` is run on ``adata`` (this
    overwrites the neighbor graph stored on ``adata``) and the non-zero
    entries of ``adata.obsp['connectivities']`` are treated as edges.

    Parameters
    ----------
    adata
        AnnData with ``obs['batch']``, ``obs['cell_type']`` and one ``obsm``
        entry per embedding key.
    emb_keys
        Embedding keys to evaluate. Defaults to
        ``['uce', 'scgpt', 'scvi', 'geneformer', 'X_scimilarity', 'X_scmg']``.
    n_neighbors
        Passed to ``sc.pp.neighbors``; defaults to 30.

    Returns
    -------
    pandas.DataFrame
        One row per embedding with columns ``emb_key``,
        ``frac_interdataset`` (fraction of edges joining different batches)
        and ``frac_cell_type_match`` (fraction of edges joining cells with
        the same cell type).
    """
    if emb_keys is None:
        emb_keys = ['uce', 'scgpt', 'scvi', 'geneformer', 'X_scimilarity', 'X_scmg']

    # The labels do not depend on the embedding, so read them once.
    batch_labels = np.asarray(adata.obs['batch'])
    cell_type_labels = np.asarray(adata.obs['cell_type'])

    records = []
    for emb_key in emb_keys:
        sc.pp.neighbors(adata, use_rep=emb_key, n_neighbors=n_neighbors)
        rows, cols = adata.obsp['connectivities'].nonzero()
        n_edges = len(rows)

        records.append({
            'emb_key': emb_key,
            'frac_interdataset':
                np.sum(batch_labels[rows] != batch_labels[cols]) / n_edges,
            'frac_cell_type_match':
                np.sum(cell_type_labels[rows] == cell_type_labels[cols]) / n_edges,
        })

    return pd.DataFrame(
        records, columns=['emb_key', 'frac_interdataset', 'frac_cell_type_match'])

def save_umap_plots(adata, output_path, emb_keys=None, n_neighbors=30):
    """Save one UMAP figure per embedding, colored by batch and cell type.

    For every embedding key the neighbor graph and UMAP stored on ``adata``
    are recomputed (overwriting earlier ones) and the figure is written to
    ``<output_path>/umap_<emb_key>.pdf``.

    Parameters
    ----------
    adata
        AnnData with ``obs['batch']``, ``obs['cell_type']`` and one ``obsm``
        entry per embedding key.
    output_path
        Existing directory that receives the PDF files.
    emb_keys
        Embedding keys to plot. Defaults to
        ``['uce', 'scgpt', 'scvi', 'geneformer', 'X_scimilarity', 'X_scmg']``.
    n_neighbors
        Passed to ``sc.pp.neighbors``; defaults to 30.
    """
    if emb_keys is None:
        emb_keys = ['uce', 'scgpt', 'scvi', 'geneformer', 'X_scimilarity', 'X_scmg']

    for emb_key in emb_keys:
        sc.pp.neighbors(adata, use_rep=emb_key, n_neighbors=n_neighbors)
        sc.tl.umap(adata)

        fig = sc.pl.umap(adata, color=['batch', 'cell_type'], frameon=False,
                         return_fig=True, show=False)
        fig.savefig(os.path.join(output_path, f'umap_{emb_key}.pdf'))
        # Release the figure: this runs once per embedding per dataset pair,
        # and unclosed figures otherwise pile up in pyplot's registry.
        plt.close(fig)

In [None]:
# Pairs of dataset names to compare against each other.
# Each name is the stem of an `.h5ad` file read by the comparison loop, and
# each pair gets its own output folder `comparison_results/{ds1}_AND_{ds2}`.
dataset_pairs = [
    ('Tabula_Sapiens_HS_2022_all', 'Burclaff_intestine_HS_2022_all'),
    ('Tabula_Sapiens_HS_2022_all', 'Travaglini_Lung_HS_2021_10x'),
    ('Suo_ImmuneDev_HS_2022_all', 'Triana_BoneMarrow_HS_2021_healthy'),
    ('Eraslan_MultiTissue_HS_2022_all', 'Travaglini_Lung_HS_2021_10x'),
    ('Han_HS_2020_all', 'Travaglini_Lung_HS_2021_10x'),
    ('Tabula_Sapiens_HS_2022_all', 'Suo_ImmuneDev_HS_2022_all'),
    ('Tabula_Sapiens_HS_2022_all', 'Han_HS_2020_all'),
    ('Tabula_Sapiens_HS_2022_all', 'Eraslan_MultiTissue_HS_2022_all'),
    ('Tabula_Sapiens_HS_2022_all', 'Cao_dev_HS_2020_all'),
    ('Cao_dev_HS_2020_all', 'Han_HS_2020_all')

]

In [None]:

# Seed for the random per-cell-type subsampling, so reruns pick the same cells.
SUBSAMPLE_SEED = 0

for ds1, ds2 in dataset_pairs:
    output_path = os.path.join('comparison_results', f'{ds1}_AND_{ds2}')
    os.makedirs(output_path, exist_ok=True)

    print(ds1, ds2)

    adata1 = sc.read_h5ad(f'{ds1}.h5ad')
    adata2 = sc.read_h5ad(f'{ds2}.h5ad')

    # Drop cells whose only annotation is the generic 'native cell' label.
    adata1 = adata1[adata1.obs['cell_type'] != 'native cell']
    adata2 = adata2[adata2.obs['cell_type'] != 'native cell']

    # Subsample the datasets to have the same number of cells per cell type
    ct_count_df = adata1.obs['cell_type'].value_counts().to_frame()
    ct_count_df['count2'] = adata2.obs['cell_type'].value_counts()
    # Cell types missing from the second dataset get a count of 0.
    ct_count_df = ct_count_df.fillna(0)
    ct_count_df['subsample_count'] = ct_count_df.min(axis=1)
    print(np.sum(ct_count_df['subsample_count'] > 0), "common cell types")

    # subsample_adata_by_ct draws from NumPy's global RNG; re-seeding for
    # every pair makes each pair's subsample reproducible and independent of
    # which pairs were processed before it.
    np.random.seed(SUBSAMPLE_SEED)
    adata1 = subsample_adata_by_ct(adata1, ct_count_df['subsample_count'])
    adata2 = subsample_adata_by_ct(adata2, ct_count_df['subsample_count'])

    # Both helpers write their embedding into .obsm of the object passed in.
    get_scimilarity_emb(adata1)
    get_scimilarity_emb(adata2)

    get_scmg_embedding(adata1)
    get_scmg_embedding(adata2)

    # obs['batch'] records which of the two datasets each cell came from.
    adata_merged = anndata.concat([adata1, adata2], label='batch')

    metric_df = calc_integration_matrices(adata_merged)
    metric_df.to_csv(os.path.join(output_path, 'integration_metrics.csv'))

    save_umap_plots(adata_merged, output_path)


In [None]:
# Load the metric dataframes
# One CSV per dataset pair is read back and tagged with the pair name, then
# all of them are stacked into a single long table.
metric_df_list = []

for ds1, ds2 in dataset_pairs:
    pair_name = f'{ds1}_AND_{ds2}'
    output_path = os.path.join('comparison_results', pair_name)
    pair_metrics = pd.read_csv(os.path.join(output_path, 'integration_metrics.csv'),
                               index_col=0)
    pair_metrics['dataset_pair'] = pair_name
    metric_df_list.append(pair_metrics)

# Every CSV carries its own 0..n-1 row labels; ignore_index gives the stacked
# table a unique index instead of repeating those labels once per pair.
metric_df = pd.concat(metric_df_list, ignore_index=True)

In [None]:
# Dataset pairs labelled 'Yes' in the 'excluded from training' column.
pairs_marked_yes = [
    'Tabula_Sapiens_HS_2022_all_AND_Burclaff_intestine_HS_2022_all',
    'Tabula_Sapiens_HS_2022_all_AND_Travaglini_Lung_HS_2021_10x',
    'Suo_ImmuneDev_HS_2022_all_AND_Triana_BoneMarrow_HS_2021_healthy',
    'Eraslan_MultiTissue_HS_2022_all_AND_Travaglini_Lung_HS_2021_10x',
    'Han_HS_2020_all_AND_Travaglini_Lung_HS_2021_10x',
]
# Dataset pairs labelled 'No' in the same column.
pairs_marked_no = [
    'Tabula_Sapiens_HS_2022_all_AND_Suo_ImmuneDev_HS_2022_all',
    'Tabula_Sapiens_HS_2022_all_AND_Han_HS_2020_all',
    'Tabula_Sapiens_HS_2022_all_AND_Eraslan_MultiTissue_HS_2022_all',
    'Tabula_Sapiens_HS_2022_all_AND_Cao_dev_HS_2020_all',
    'Cao_dev_HS_2020_all_AND_Han_HS_2020_all',
]
integration_type_map = {
    **dict.fromkeys(pairs_marked_yes, 'Yes'),
    **dict.fromkeys(pairs_marked_no, 'No'),
}
metric_df['excluded from training'] = metric_df['dataset_pair'].map(integration_type_map)

# Embedding keys ordered by ascending median inter-dataset fraction; this
# order is reused for the x axis of the plots.
median_interdataset = metric_df.groupby('emb_key')['frac_interdataset'].median()
plot_order = median_interdataset.sort_values().index.values

In [None]:
import seaborn as sns
from statannotations.Annotator import Annotator


# Box plot of the inter-dataset neighbor fraction per embedding, annotated
# with Welch t-tests of every other embedding against X_scmg, and with one
# point per dataset pair drawn on top.
interdataset_kwargs = dict(data=metric_df, x='emb_key', y='frac_interdataset',
                           order=plot_order)
pairs_vs_scmg = [(other, 'X_scmg')
                 for other in ('scvi', 'X_scimilarity', 'scgpt', 'geneformer', 'uce')]

fig, ax = plt.subplots(figsize=(10, 4))

sns.boxplot(showfliers=False, color='.8', ax=ax, **interdataset_kwargs)

annotator = Annotator(ax, pairs_vs_scmg, **interdataset_kwargs)
annotator.configure(test='t-test_welch', text_format='star', loc='outside')
annotator.apply_and_annotate()

# Individual dataset pairs, colored by the 'excluded from training' label.
sns.stripplot(hue='excluded from training', alpha=1, ax=ax, palette='Dark2',
              **interdataset_kwargs)

ax.set_ylim(0, 0.15)
ax.set_xlabel('Embedding method')
ax.set_ylabel('Fraction of inter-dataset neighbors')

fig.savefig('comparison_results/barplot_frac_interdatset.pdf')

In [None]:
# Box plot of the same-cell-type neighbor fraction per embedding, annotated
# with Welch t-tests of every other embedding against X_scmg, and with one
# point per dataset pair drawn on top.
cell_type_kwargs = dict(data=metric_df, x='emb_key', y='frac_cell_type_match',
                        order=plot_order)
pairs_vs_scmg = [(other, 'X_scmg')
                 for other in ('scvi', 'X_scimilarity', 'scgpt', 'geneformer', 'uce')]

fig, ax = plt.subplots(figsize=(10, 4))

sns.boxplot(showfliers=False, color='.8', ax=ax, **cell_type_kwargs)

annotator = Annotator(ax, pairs_vs_scmg, **cell_type_kwargs)
annotator.configure(test='t-test_welch', text_format='star', loc='outside')
annotator.apply_and_annotate()

# Individual dataset pairs, colored by the 'excluded from training' label.
sns.stripplot(hue='excluded from training', alpha=1, ax=ax, palette='Dark2',
              **cell_type_kwargs)

ax.set_xlabel('Embedding method')
ax.set_ylabel('Fraction of same cell type neighbors')
# Cap the y axis at 1 and leave the lower limit automatic.
ax.set_ylim(None, 1)
fig.savefig('comparison_results/barplot_frac_same_cell_type.pdf')

In [None]:
# Import the submodule explicitly: a bare `import scipy` does not guarantee
# that `scipy.stats` is loaded. This form still binds the name `scipy`.
import scipy.stats

# Independent two-sample t-test on the same-cell-type neighbor fraction,
# geneformer vs. X_scmg, with one value per dataset pair in each group.
geneformer_match = metric_df.loc[metric_df['emb_key'] == 'geneformer',
                                 'frac_cell_type_match']
scmg_match = metric_df.loc[metric_df['emb_key'] == 'X_scmg',
                           'frac_cell_type_match']

scipy.stats.ttest_ind(geneformer_match, scmg_match)

In [None]:
# Paired t-test on the same-cell-type neighbor fraction, geneformer vs. X_scmg.
# Values are paired by position, so both selections must list the dataset
# pairs in the same order (metric_df is concatenated in dataset_pairs order).
is_geneformer = metric_df['emb_key'] == 'geneformer'
is_scmg = metric_df['emb_key'] == 'X_scmg'

scipy.stats.ttest_rel(
    metric_df.loc[is_geneformer, 'frac_cell_type_match'],
    metric_df.loc[is_scmg, 'frac_cell_type_match'],
)