# NOTE:
This is the same tutorial notebook from the main UCE repo. We have re run it with tabula sapiens v2 and the human brain cell atlas, which take a long time to run, for reproducibility and to demonstrate how to run the benchmark for large datasets by repeatedly resampling

# Large Scale Embedding benchmarks

This notebook includes an example showing how to run large scale embedding benchmarks using scIB [(single-cell integration benchmark)](https://www.nature.com/articles/s41592-021-01336-8)

We use the GPU accelerated version implemented here: https://github.com/YosefLab/scib-metrics

Please follow installation instructions in that repo. 

*Note: installing Faiss can be difficult and may take some time*

*Running the full benchmarking suite on many cells can take many hours, even on GPUs with large amounts of memory, such as A100s, and with many threads*

## Load Imports and define Benchmark Function

In [1]:
import numpy as np
import scanpy as sc

from scib_metrics.benchmark import Benchmarker

import faiss

from scib_metrics.nearest_neighbors import NeighborsResults

# Faiss GPU accelerate nearest neighbors methods
def faiss_hnsw_nn(X: np.ndarray, k: int):
    """Gpu HNSW nearest neighbor search using faiss.

    See https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md
    for index param details.
    """
    X = np.ascontiguousarray(X, dtype=np.float32)
    res = faiss.StandardGpuResources()
    M = 32
    index = faiss.IndexHNSWFlat(X.shape[1], M, faiss.METRIC_L2)
    gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
    gpu_index.add(X)
    distances, indices = gpu_index.search(X, k)
    del index
    del gpu_index
    # distances are squared
    return NeighborsResults(indices=indices, distances=np.sqrt(distances))


def faiss_brute_force_nn(X: np.ndarray, k: int):
    """Gpu brute force nearest neighbor search using faiss."""
    X = np.ascontiguousarray(X, dtype=np.float32)
    res = faiss.StandardGpuResources()
    index = faiss.IndexFlatL2(X.shape[1])
    gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
    gpu_index.add(X)
    distances, indices = gpu_index.search(X, k)
    del index
    del gpu_index
    # distances are squared
    return NeighborsResults(indices=indices, distances=np.sqrt(distances))

  from anndata import __version__ as anndata_version
  return module_get_attr_redirect(attr_name, deprecated_mapping=_DEPRECATED)
  return module_get_attr_redirect(attr_name, deprecated_mapping=_DEPRECATED)
  return module_get_attr_redirect(attr_name, deprecated_mapping=_DEPRECATED)
  return module_get_attr_redirect(attr_name, deprecated_mapping=_DEPRECATED)
  return module_get_attr_redirect(attr_name, deprecated_mapping=_DEPRECATED)
  return module_get_attr_redirect(attr_name, deprecated_mapping=_DEPRECATED)
  return module_get_attr_redirect(attr_name, deprecated_mapping=_DEPRECATED)


In [21]:
import warnings
warnings.filterwarnings("ignore")
from scib_metrics.benchmark import Benchmarker, BioConservation, BatchCorrection
import pandas as pd

## Benchmarking Function, returns dataframe of scores
def benchmark(ad, label_key="cell_type", batch_key="sample_id", obsm_keys=["X_uce", "X_scGPT", "X_geneformer"], do_batchcons=False):
    print(f"Running using CT key:", label_key)
    biocons = BioConservation()
    if do_batchcons:
        batchcons = BatchCorrection(pcr_comparison=False)
    else:
        batchcons = BatchCorrection(ilisi_knn=False, kbet_per_label=False, graph_connectivity=False, pcr_comparison=False)
    bm = Benchmarker(
        ad,
        batch_key=batch_key,
        label_key=label_key,
        embedding_obsm_keys=obsm_keys,
        bio_conservation_metrics=biocons,
        batch_correction_metrics=batchcons,
        n_jobs=64,
    )
    bm.prepare(neighbor_computer=faiss_brute_force_nn)
    bm.benchmark()
    df = bm.get_results(min_max_scale=False)
    return df

### Load in anndata

For this example, we will benchmark cells from developing mouse brain.

You can download an anndata object with UCE, scGPT and Geneformer embeddings precalulated from [here](https://drive.google.com/drive/folders/1f63fh0ykgEhCrkd_EVvIootBw7LYDVI7)

In [3]:
tabula_ad = sc.read("export_data/new_tabula_scib.h5ad")
tabula_ad

AnnData object with n_obs × n_vars = 581430 × 45792
    obs: 'donor', 'tissue', 'anatomical_position', 'method', 'cdna_plate', 'library_plate', 'notes', 'cdna_well', 'old_index', 'assay', 'sample_id', 'sample', 'replicate', '10X_run', '10X_barcode', 'ambient_removal', 'donor_method', 'donor_assay', 'donor_tissue', 'donor_tissue_assay', 'cell_ontology_class', 'cell_ontology_id', 'compartment', 'broad_cell_class', 'free_annotation', 'manually_annotated', 'published_2022', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ercc', 'pct_counts_ercc', '_scvi_batch', '_scvi_labels', 'scvi_leiden_donorassay_full', 'age', 'sex', 'ethnicity', 'n_genes', 'donor_num', 'cell_type_coarse', 'n_counts'
    var: 'n_cells'
    uns: 'log1p', 'neighbors', 'umap'
    obsm: 'X_geneformer', 'X_scarches', 'X_scgpt', 'X_scvi', 'X_tgpt', 'X_uce'
    obsp: 'connectivities', 'distances'

In [4]:
cell_type_column = "cell_ontology_class"
batch_column = "sample_id"

In [5]:
len(tabula_ad.obs[cell_type_column].unique()) # Number of unique cell types

162

In [6]:
len(tabula_ad.obs[batch_column].unique()) # Number of unique batches

167

# Running the Benchmark on the full dataset

In [7]:
# X_geneformer', 'X_scgpt', 'X_uce', 'X_tgpt', 'X_scvi', 'X_scarches'

In [None]:
new_tabula_benchmark_results_df = benchmark(tabula_ad, label_key=cell_type_column,  batch_key=batch_column, obsm_keys=['X_uce', 'X_geneformer',  'X_scgpt', "X_tgpt", 'X_scvi', 'X_scarches'], do_batchcons=True)
new_tabula_benchmark_results_df

Running using CT key: cell_ontology_class


Computing neighbors:   0%|                                | 0/6 [00:00<?, ?it/s]

In [19]:
# NOTE: For scVI and scARCHES, the results are just for one seed (seed 0) shown here
# The result in Supplementary Table 1 is after running 10 seeds for each and averaging scores

In [13]:
new_tabula_benchmark_results_df

Unnamed: 0_level_0,Isolated labels,KMeans NMI,KMeans ARI,Silhouette label,cLISI,Silhouette batch,iLISI,KBET,Graph connectivity,Batch correction,Bio conservation,Total
Embedding,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
X_uce,0.647063,0.704803,0.229917,0.537544,0.999913,0.87662,0.006863,0.414848,0.691408,0.497434,0.623848,0.573283
X_geneformer,0.530591,0.553988,0.142943,0.454141,0.99937,0.846715,0.006589,0.312182,0.644451,0.452484,0.536207,0.502718
X_scgpt,0.452378,0.290269,0.039445,0.403705,0.993694,0.816081,0.036039,0.29483,0.349412,0.37409,0.435898,0.411175
X_tgpt,0.555026,0.497,0.111297,0.449734,0.998918,0.875747,0.007018,0.298695,0.598287,0.444937,0.522395,0.491412
X_scvi,0.592119,0.670051,0.194243,0.511048,0.999397,0.822333,0.015383,0.435243,0.701316,0.493569,0.593372,0.55345
X_scarches,0.615001,0.681848,0.218949,0.542676,0.999307,0.827007,0.017227,0.42448,0.739538,0.502063,0.611556,0.567759
Metric Type,Bio conservation,Bio conservation,Bio conservation,Bio conservation,Bio conservation,Batch correction,Batch correction,Batch correction,Batch correction,Aggregate score,Aggregate score,Aggregate score


In [None]:
# 'X_pca', 'X_scvi_seed0', 'X_scarches_seed0'

# Running the Benchmark using Resampling (Human Brain Cell Atlas)

Running the benchmark on the full dataset can take a very long time. Instead, we can run on medium sized samples of cells.

In [15]:
sample_size = 500_000 # number of cells

In [16]:
N_RESAMPLES = 2 # actually is 10 in the paper but this takes a very long time to run

In [18]:
hbca_ad = sc.read("export_data/brain_atlas_uce_scgpt_geneformer.h5ad")
hbca_ad

AnnData object with n_obs × n_vars = 2480956 × 17928
    obs: 'ROIGroup', 'ROIGroupCoarse', 'ROIGroupFine', 'roi', 'organism_ontology_term_id', 'disease_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'assay_ontology_term_id', 'sex_ontology_term_id', 'development_stage_ontology_term_id', 'donor_id', 'suspension_type', 'dissection', 'fraction_mitochondrial', 'fraction_unspliced', 'cell_cycle_score', 'total_genes', 'total_UMIs', 'sample_id', 'supercluster_term', 'cluster_id', 'subcluster_id', 'cell_type_ontology_term_id', 'tissue_ontology_term_id', 'is_primary_data', 'tissue_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', 'n_genes', 'n_counts'
    var: 'Biotype', 'Chromosome', 'End', 'Gene', 'Start', 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length', 'ensembl_id', 'n_cells'
    uns: 'batch_condition', 'citation', 'log1p', 'schema_referenc

In [23]:
cell_type_column = "cluster_id"
batch_key = "sample_id" # not used

In [24]:
len(hbca_ad.obs[cell_type_column].unique())

382

In [25]:
from tqdm.auto import tqdm
sample_score_dfs = []

for i in tqdm(range(N_RESAMPLES)):
    # benchmark one sample
    # sample is drawn with random state i
    subsample_ad = sc.pp.subsample(hbca_ad, copy=True, n_obs=sample_size, random_state=i)
    sample_df = benchmark(subsample_ad, label_key=cell_type_column,  batch_key=batch_column)
    # show the results for this sample
    display(subsample_ad)
    # add it to the results for all samples
    sample_score_dfs.append(sample_df)

  0%|          | 0/2 [00:00<?, ?it/s]

Running using CT key: cluster_id



Computing neighbors:   0%|                                                                                                     | 0/3 [00:00<?, ?it/s][A
Computing neighbors:  33%|███████████████████████████████                                                              | 1/3 [01:22<02:44, 82.49s/it][A
Computing neighbors:  67%|██████████████████████████████████████████████████████████████                               | 2/3 [01:58<00:55, 55.40s/it][A
Computing neighbors: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [02:35<00:00, 51.76s/it][A
Embeddings:   0%|[32m                                                                                                              [0m| 0/3 [00:00<?, ?it/s][0m
Metrics:   0%|[34m                                                                                                                [0m| 0/10 [00:00<?, ?it/s][0m[A
Metrics:   0%|[34m                                       

AnnData object with n_obs × n_vars = 500000 × 17928
    obs: 'ROIGroup', 'ROIGroupCoarse', 'ROIGroupFine', 'roi', 'organism_ontology_term_id', 'disease_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'assay_ontology_term_id', 'sex_ontology_term_id', 'development_stage_ontology_term_id', 'donor_id', 'suspension_type', 'dissection', 'fraction_mitochondrial', 'fraction_unspliced', 'cell_cycle_score', 'total_genes', 'total_UMIs', 'sample_id', 'supercluster_term', 'cluster_id', 'subcluster_id', 'cell_type_ontology_term_id', 'tissue_ontology_term_id', 'is_primary_data', 'tissue_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', 'n_genes', 'n_counts'
    var: 'Biotype', 'Chromosome', 'End', 'Gene', 'Start', 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length', 'ensembl_id', 'n_cells'
    uns: 'batch_condition', 'citation', 'log1p', 'schema_reference

Running using CT key: cluster_id



Computing neighbors:   0%|                                                                                                     | 0/3 [00:00<?, ?it/s][A
Computing neighbors:  33%|███████████████████████████████                                                              | 1/3 [00:36<01:13, 36.80s/it][A
Computing neighbors:  67%|██████████████████████████████████████████████████████████████                               | 2/3 [00:53<00:24, 24.85s/it][A
Computing neighbors: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [01:09<00:00, 23.26s/it][A
Embeddings:   0%|[32m                                                                                                              [0m| 0/3 [00:00<?, ?it/s][0m
Metrics:   0%|[34m                                                                                                                [0m| 0/10 [00:00<?, ?it/s][0m[A
Metrics:   0%|[34m                                       

AnnData object with n_obs × n_vars = 500000 × 17928
    obs: 'ROIGroup', 'ROIGroupCoarse', 'ROIGroupFine', 'roi', 'organism_ontology_term_id', 'disease_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'assay_ontology_term_id', 'sex_ontology_term_id', 'development_stage_ontology_term_id', 'donor_id', 'suspension_type', 'dissection', 'fraction_mitochondrial', 'fraction_unspliced', 'cell_cycle_score', 'total_genes', 'total_UMIs', 'sample_id', 'supercluster_term', 'cluster_id', 'subcluster_id', 'cell_type_ontology_term_id', 'tissue_ontology_term_id', 'is_primary_data', 'tissue_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', 'n_genes', 'n_counts'
    var: 'Biotype', 'Chromosome', 'End', 'Gene', 'Start', 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length', 'ensembl_id', 'n_cells'
    uns: 'batch_condition', 'citation', 'log1p', 'schema_reference

### Final Scores

We can aggregate the scores from all the samples, taking the mean value (and standard deviation of the score)

In [26]:
grouped_mean = pd.concat([df.drop("Metric Type").reset_index() for df in sample_score_dfs]).groupby("Embedding").agg(np.mean)
# Note: we drop the "Metric Type" row since it contains strings which we can't take the mean of

In [27]:
grouped_std = pd.concat([df.drop("Metric Type").reset_index() for df in sample_score_dfs]).groupby("Embedding").agg(np.std)
# Note: we drop the "Metric Type" row since it contains strings which we can't take the std of

In [28]:
grouped_mean

Unnamed: 0_level_0,Isolated labels,KMeans NMI,KMeans ARI,Silhouette label,cLISI,Silhouette batch,Batch correction,Bio conservation,Total
Embedding,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
X_geneformer,0.441117,0.282088,0.039184,0.393853,0.990547,0.758299,0.758299,0.429358,0.560934
X_scGPT,0.56094,0.596169,0.141131,0.483329,0.997043,0.809811,0.809811,0.555722,0.657358
X_uce,0.580308,0.702658,0.25047,0.50807,0.997679,0.771417,0.771417,0.607837,0.673269


In [29]:
grouped_mean["Bio conservation"]

Embedding
X_geneformer    0.429358
X_scGPT         0.555722
X_uce           0.607837
Name: Bio conservation, dtype: object