# NOTE:
This is the same tutorial notebook as in the main UCE repo. We have re-run it with Tabula Sapiens v2 and the Human Brain Cell Atlas, both of which take a long time to run, for reproducibility and to demonstrate how to run the benchmark on large datasets by repeated resampling.

# Large Scale Embedding benchmarks

This notebook includes an example showing how to run large scale embedding benchmarks using scIB [(single-cell integration benchmark)](https://www.nature.com/articles/s41592-021-01336-8)

We use the GPU accelerated version implemented here: https://github.com/YosefLab/scib-metrics

Please follow installation instructions in that repo. 

*Note: installing Faiss can be difficult and may take some time*

*Running the full benchmarking suite on many cells can take many hours, even on GPUs with large amounts of memory, such as A100s, and with many threads*

## Load Imports and define Benchmark Function

In [1]:
import numpy as np
import scanpy as sc

from scib_metrics.benchmark import Benchmarker

import faiss

from scib_metrics.nearest_neighbors import NeighborsResults

# Faiss GPU-accelerated nearest-neighbor methods
def faiss_hnsw_nn(X: np.ndarray, k: int):
    """Find the ``k`` nearest neighbors of every row of ``X`` with a faiss HNSW index.

    A flat HNSW index using the L2 metric is created, handed to
    ``faiss.index_cpu_to_gpu`` targeting device 0, filled with ``X`` and then
    queried with ``X`` itself.

    See https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md
    for index param details.

    Parameters
    ----------
    X
        Points to index and query, one per row. Converted to a C-contiguous
        float32 array before use.
    k
        Number of neighbors requested per row.

    Returns
    -------
    NeighborsResults
        Neighbor indices, and the square roots of the distances returned by
        the faiss search.
    """
    n_links = 32  # the ``M`` argument of IndexHNSWFlat
    points = np.ascontiguousarray(X, dtype=np.float32)
    gpu_resources = faiss.StandardGpuResources()
    cpu_index = faiss.IndexHNSWFlat(points.shape[1], n_links, faiss.METRIC_L2)
    device_index = faiss.index_cpu_to_gpu(gpu_resources, 0, cpu_index)
    device_index.add(points)
    squared_dists, neighbor_ids = device_index.search(points, k)
    # Drop both index handles before building the result.
    del cpu_index
    del device_index
    # distances are squared
    return NeighborsResults(indices=neighbor_ids, distances=np.sqrt(squared_dists))


def faiss_brute_force_nn(X: np.ndarray, k: int):
    """Gpu brute force nearest neighbor search using faiss.

    Builds a flat (exhaustive) L2 index, moves it to GPU device 0, adds ``X``
    and queries the index with ``X`` itself.

    Parameters
    ----------
    X
        Points to index and query, one per row. Converted to a C-contiguous
        float32 array before use.
    k
        Number of neighbors requested per row.

    Returns
    -------
    NeighborsResults
        Neighbor indices and Euclidean distances (square roots of the squared
        L2 distances returned by faiss, clamped at zero).
    """
    X = np.ascontiguousarray(X, dtype=np.float32)
    res = faiss.StandardGpuResources()
    index = faiss.IndexFlatL2(X.shape[1])
    gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
    gpu_index.add(X)
    distances, indices = gpu_index.search(X, k)
    del index
    del gpu_index
    # distances are squared. Clamp at zero before the square root: a squared
    # distance that comes back slightly negative through floating-point
    # round-off would otherwise become NaN and poison downstream metrics.
    return NeighborsResults(indices=indices, distances=np.sqrt(np.maximum(distances, 0)))

  from anndata import __version__ as anndata_version
  return module_get_attr_redirect(attr_name, deprecated_mapping=_DEPRECATED)
  return module_get_attr_redirect(attr_name, deprecated_mapping=_DEPRECATED)
  return module_get_attr_redirect(attr_name, deprecated_mapping=_DEPRECATED)
  return module_get_attr_redirect(attr_name, deprecated_mapping=_DEPRECATED)
  return module_get_attr_redirect(attr_name, deprecated_mapping=_DEPRECATED)
  return module_get_attr_redirect(attr_name, deprecated_mapping=_DEPRECATED)
  return module_get_attr_redirect(attr_name, deprecated_mapping=_DEPRECATED)


In [2]:
import warnings
# Install a blanket "ignore" filter for Python warnings for the rest of the session.
warnings.filterwarnings("ignore")
# Benchmarker is also imported in the first cell; BioConservation and BatchCorrection are new here.
from scib_metrics.benchmark import Benchmarker, BioConservation, BatchCorrection
import pandas as pd

## Benchmarking Function, returns dataframe of scores
def benchmark(
    ad,
    label_key="cell_type",
    batch_key="sample_id",
    obsm_keys=None,
    batch_correction_metrics=None,
    n_jobs=48,
):
    """Run the scib-metrics benchmark on ``ad`` and return the score table.

    Parameters
    ----------
    ad
        AnnData object holding the cells and the embeddings to score.
    label_key
        Name of the ``obs`` column with cell-type labels.
    batch_key
        Name of the ``obs`` column with batch identifiers.
    obsm_keys
        Names of the ``obsm`` entries (embeddings) to score. Defaults to
        ``["X_uce", "X_scGPT", "X_geneformer"]`` when omitted.
    batch_correction_metrics
        Forwarded unchanged to ``Benchmarker``. The default ``None`` matches
        what this function has always passed.
        NOTE(review): how ``None`` is interpreted depends on the installed
        scib-metrics version -- confirm. Pass e.g.
        ``BatchCorrection(pcr_comparison=False)`` to choose metrics explicitly.
    n_jobs
        Forwarded to ``Benchmarker`` as ``n_jobs``.

    Returns
    -------
    The result of ``Benchmarker.get_results(min_max_scale=False)``
    (a dataframe of scores).
    """
    # Resolve the default here instead of using a mutable list as the default argument.
    if obsm_keys is None:
        obsm_keys = ["X_uce", "X_scGPT", "X_geneformer"]
    print(f"Running using CT key: {label_key}")

    bm = Benchmarker(
        ad,
        batch_key=batch_key,
        label_key=label_key,
        embedding_obsm_keys=list(obsm_keys),
        bio_conservation_metrics=BioConservation(),
        batch_correction_metrics=batch_correction_metrics,
        n_jobs=n_jobs,
    )
    # Neighbors are computed with the faiss brute-force search defined in this notebook.
    bm.prepare(neighbor_computer=faiss_brute_force_nn)
    bm.benchmark()
    return bm.get_results(min_max_scale=False)

### Load in anndata

The original tutorial benchmarks cells from developing mouse brain; in this re-run, the cell below loads the Tabula Sapiens v2 dataset instead.

You can download an AnnData object with UCE, scGPT and Geneformer embeddings precalculated from [here](https://drive.google.com/drive/folders/1f63fh0ykgEhCrkd_EVvIootBw7LYDVI7).

In [3]:
# Read the AnnData file from the local export_data/ directory, then echo its summary.
tabula_ad = sc.read("export_data/new_tabula_all_zero_shot_lognorm.h5ad")
tabula_ad

AnnData object with n_obs × n_vars = 581430 × 45792
    obs: 'donor', 'tissue', 'anatomical_position', 'method', 'cdna_plate', 'library_plate', 'notes', 'cdna_well', 'old_index', 'assay', 'sample_id', 'sample', 'replicate', '10X_run', '10X_barcode', 'ambient_removal', 'donor_method', 'donor_assay', 'donor_tissue', 'donor_tissue_assay', 'cell_ontology_class', 'cell_ontology_id', 'compartment', 'broad_cell_class', 'free_annotation', 'manually_annotated', 'published_2022', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ercc', 'pct_counts_ercc', '_scvi_batch', '_scvi_labels', 'scvi_leiden_donorassay_full', 'age', 'sex', 'ethnicity', 'n_genes', 'donor_num', 'cell_type_coarse', 'n_counts'
    var: 'n_cells'
    uns: 'log1p', 'neighbors', 'pca', 'umap'
    obsm: 'X_geneformer', 'X_pca', 'X_scarches_seed0', 'X_scgpt', 'X_scvi_seed0', 'X_tgpt', 'X_uce'
    varm: 'PCs'
    obsp: 'connectivities', 'distances'

In [4]:
# obs column with the cell-type labels; passed to benchmark() as label_key
cell_type_column = "cell_ontology_class"
# obs column with the batch identifiers; passed to benchmark() as batch_key
batch_column = "sample_id"

In [5]:
# Number of unique cell types (missing values, if any, count as one value)
tabula_ad.obs[cell_type_column].nunique(dropna=False)

162

In [6]:
# Number of unique batches (missing values, if any, count as one value)
tabula_ad.obs[batch_column].nunique(dropna=False)

167

# Running the Benchmark on the full dataset

In [20]:
# Inspect the UCE embedding matrix stored in obsm (numpy abbreviates the printed array).
tabula_ad.obsm["X_uce"]

array([[ 3.8547866e-02, -2.5216730e-02,  9.2354268e-03, ...,
        -8.5126543e-03,  6.3443370e-03, -5.0794845e-03],
       [ 1.2552227e-02,  1.3078367e-02, -8.7940758e-03, ...,
        -2.1045651e-02, -5.7123909e-03,  1.3679673e-02],
       [ 4.4793326e-02, -1.8928792e-02,  1.3293503e-02, ...,
        -9.8256329e-03,  7.7304617e-03, -3.6268145e-02],
       ...,
       [ 1.3699106e-02, -2.5936974e-02, -2.7112808e-05, ...,
         2.1766354e-03, -3.2877855e-02, -6.6242784e-02],
       [-9.9301739e-03, -1.1609541e-02, -9.7841416e-03, ...,
        -1.6720660e-02,  5.5857226e-03, -1.1244461e-02],
       [ 4.5787334e-04, -2.6317516e-02, -1.3252082e-02, ...,
        -1.6145144e-02,  1.5398547e-02, -2.4077728e-02]], dtype=float32)

In [None]:
# Draw 560,000 cells into a new AnnData (copy=True), leaving tabula_ad untouched.
# NOTE(review): no random_state is passed, so the draw relies on scanpy's default seed -- confirm this is the intended, reproducible sample.
subsample_ad = sc.pp.subsample(tabula_ad, n_obs=560_000, copy=True)

In [None]:
# Score each listed obsm embedding on the subsampled cells, then show the score table.
new_tabula_benchmark_results_df = benchmark(
    subsample_ad,
    label_key=cell_type_column,
    batch_key=batch_column,
    obsm_keys=[
        "X_uce",
        "X_geneformer",
        "X_scgpt",
        "X_tgpt",
        "X_pca",
        "X_scvi_seed0",
        "X_scarches_seed0",
    ],
)
new_tabula_benchmark_results_df

In [None]:
# 'X_pca', 'X_scvi_seed0', 'X_scarches_seed0'

In [None]:
# Always raises ZeroDivisionError.
# NOTE(review): presumably a deliberate stop so that "Run All" halts before the resampling section below -- confirm.
1/0

# Running the Benchmark using Resampling (Human Brain Cell Atlas)

Running the benchmark on the full dataset can take a very long time. Instead, we can run it on medium-sized samples of cells.

In [None]:
# Size of each random subsample drawn in the resampling loop (passed as n_obs)
sample_size = 100_000 # number of cells

In [None]:
from tqdm.auto import tqdm

# NOTE(review): `ad` and `N_RESAMPLES` are not assigned in the cells visible in
# this notebook; both must already exist in the kernel before this cell runs.

# One score table per resample, collected for aggregation afterwards.
sample_score_dfs = []

for i in tqdm(range(N_RESAMPLES)):
    # benchmark one sample
    # sample is drawn with random state i, so every iteration uses a different seed
    subsample_ad = sc.pp.subsample(ad, copy=True, n_obs=sample_size, random_state=i)
    # obsm_keys is not passed, so benchmark() uses its default embedding keys
    sample_df = benchmark(subsample_ad, label_key=cell_type_column,  batch_key=batch_column)
    # show the results for this sample (the score table, not the AnnData object)
    display(sample_df)
    # add it to the results for all samples
    sample_score_dfs.append(sample_df)

# Final Scores

We can aggregate the scores from all the samples by taking the mean (and the standard deviation) of each score.

In [None]:
# Stack the per-sample score tables and average every score per embedding.
# Note: we drop the "Metric Type" row since it contains strings which we can't take the mean of
# The aggregation is named by string: pandas already maps np.mean to its own
# groupby mean, but passing the NumPy callable is deprecated.
grouped_mean = (
    pd.concat([df.drop("Metric Type").reset_index() for df in sample_score_dfs])
    .groupby("Embedding")
    .agg("mean")
)

In [None]:
# Stack the per-sample score tables and take the standard deviation of every score per embedding.
# Note: we drop the "Metric Type" row since it contains strings which we can't take the std of
# The aggregation is named by string: pandas currently maps np.std to its own
# groupby std (sample standard deviation, ddof=1). Passing the NumPy callable
# is deprecated, and once it is used directly the result would switch to
# NumPy's default ddof=0.
grouped_std = (
    pd.concat([df.drop("Metric Type").reset_index() for df in sample_score_dfs])
    .groupby("Embedding")
    .agg("std")
)

In [None]:
# Display the mean scores per embedding across all resamples.
grouped_mean

In [None]:
# Show only the "Bio conservation" column of the per-embedding mean scores.
grouped_mean["Bio conservation"]