In [1]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from evaluation.utils import entropy_batch_mixing, knn_purity
import scanpy as sc
import seaborn as sns
import scIB as scib
import pandas as pd
from matplotlib import pyplot as plt

In [2]:
# --- Run configuration -------------------------------------------------------
# Column names used for batch and cell-type annotations in `adata.obs`.
batch_key = 'study'
label_key = 'cell_type'
# Identifiers interpolated into the result paths and written into the score table.
model = 'scanvi'
dataset = 'brain'
version = 'first'
# Each ratio selects one `test_<ratio>_...` result folder when metrics are computed.
ratios = [1, 2, 3]
# NOTE(review): the outputs recorded further down show data == 'pancreas' and
# four result rows, which does not match dataset = 'brain' / three ratios here.
# Re-run from a fresh kernel before trusting those outputs.

# --- Reference data ----------------------------------------------------------
adata_path = os.path.expanduser('~/Documents/benchmarking_datasets/mouse_brain_subsampled_normalized_hvg.h5ad')
adata = sc.read(adata_path)

In [3]:
def compute_metrics(latent_adata, adata, rqr=None, batch_key='study', label_key='cell_type',
                    method=None, data=None, n_neighbors=15):
    """Score a latent representation against the reference data.

    Runs ``scib.metrics.metrics`` on the pair ``(adata, latent_adata)``, keeps a
    fixed subset of its metrics, and appends entropy of batch mixing (``EBM``),
    kNN purity (``KNN``) and bookkeeping columns.

    Parameters
    ----------
    latent_adata
        Object holding the latent representation in ``.X``. It is modified in
        place: ``.X`` is also stored under ``obsm['X_pca']``.
    adata
        Reference object; ``adata.obs[batch_key]`` is used to count batches.
    rqr
        Ratio numerator. When given, ``rqr / n_batches`` rounded to 2 decimals
        is stored in the ``rqr`` column; when ``None`` the column holds ``None``.
    batch_key, label_key
        ``obs`` column names for batch and cell-type annotations.
    method, data
        Values for the ``method`` / ``data`` columns. When ``None`` they fall
        back to the notebook-level ``model`` / ``dataset`` variables, which is
        what the function always did before these parameters existed.
    n_neighbors
        Neighbourhood size passed to ``entropy_batch_mixing`` and ``knn_purity``.

    Returns
    -------
    pandas.DataFrame
        The selected scIB metrics plus ``EBM``, ``KNN``, ``method``, ``data``,
        ``rqr``, ``reference_time`` and ``query_time`` columns.
    """
    metric_columns = ['NMI_cluster/label', 'ARI_cluster/label', 'ASW_label', 'ASW_label/batch',
                      'PCR_batch', 'isolated_label_F1', 'isolated_label_silhouette', 'graph_conn']

    # Presumably scIB looks up the embedding under 'X_pca' -- TODO confirm.
    # Note that this mutates the caller's object.
    latent_adata.obsm['X_pca'] = latent_adata.X
    print(adata.shape, latent_adata.shape)
    n_batches = len(adata.obs[batch_key].unique())

    scores = scib.metrics.metrics(adata, latent_adata, batch_key, label_key,
                                  nmi_=True, ari_=True, silhouette_=True, pcr_=True, graph_conn_=True,
                                  isolated_labels_=True, hvg_score_=False)
    # Copy the column subset so the assignments below write to an independent
    # frame rather than a slice of the transposed scIB result.
    scores = scores.T[metric_columns].copy()

    scores['EBM'] = entropy_batch_mixing(latent_adata, batch_key, n_neighbors=n_neighbors)
    scores['KNN'] = knn_purity(latent_adata, label_key, n_neighbors=n_neighbors)

    scores['method'] = model if method is None else method
    scores['data'] = dataset if data is None else data

    if rqr is None:
        # Nothing to round: calling .round() on an object column holding None
        # fails on several pandas versions.
        scores['rqr'] = None
    else:
        scores['rqr'] = rqr / n_batches
        scores['rqr'] = scores['rqr'].round(2)

    # Timing columns are placeholders; this function does not measure them.
    scores['reference_time'] = 0.0
    scores['query_time'] = 0.0

    return scores

## Compute metrics for all ratios

In [4]:
# Collect one metrics frame per ratio and concatenate once after the loop,
# rather than re-concatenating the growing table on every iteration.
score_frames = []
for ratio in ratios:
    # Result folders are keyed by test number, which equals the ratio here.
    test_num = ratio
    latent_path = os.path.expanduser(
        f'~/Documents/benchmarking_results/figure_3/{model}/{dataset}/test_{test_num}_{version}_cond/full_data.h5ad'
    )
    latent_adata = sc.read(latent_path)
    # Expose the latent file's 'batch' / 'celltype' annotations under the
    # column names configured in `batch_key` / `label_key`.
    latent_adata.obs[batch_key] = latent_adata.obs['batch'].values
    latent_adata.obs[label_key] = latent_adata.obs['celltype'].values
    score_frames.append(compute_metrics(latent_adata, adata, ratio, batch_key, label_key))

# `scores` stays None when `ratios` is empty, as before; row index labels are
# left as returned by compute_metrics.
scores = pd.concat(score_frames, axis=0) if score_frames else None

(15681, 1000) (15681, 10)
clustering...
NMI...
ARI...
silhouette score...
PC regression...
isolated labels...
Graph connectivity...
Calculating EBM with n_cat = 5


  _,labs = connected_components(adata_post_sub.uns['neighbors']['connectivities'], connection='strong')


EBM: 0.21688406380342495
KNN-P: 0.8641900910088618
(15681, 1000) (15681, 10)
clustering...
NMI...
ARI...
silhouette score...
PC regression...
isolated labels...
Graph connectivity...
Calculating EBM with n_cat = 5


  _,labs = connected_components(adata_post_sub.uns['neighbors']['connectivities'], connection='strong')


EBM: 0.2372365691236798
KNN-P: 0.8722375698217373
(15681, 1000) (15681, 10)
clustering...
NMI...
ARI...
silhouette score...
PC regression...
isolated labels...
Graph connectivity...
Calculating EBM with n_cat = 5


  _,labs = connected_components(adata_post_sub.uns['neighbors']['connectivities'], connection='strong')


EBM: 0.32779604651626465
KNN-P: 0.8856723157832422
(15681, 1000) (15681, 10)
clustering...
NMI...
ARI...
silhouette score...
PC regression...
isolated labels...
Graph connectivity...
Calculating EBM with n_cat = 5


  _,labs = connected_components(adata_post_sub.uns['neighbors']['connectivities'], connection='strong')


EBM: 0.3479491217464211
KNN-P: 0.899339174367721


In [5]:
# Display the combined metrics table (one row per evaluated ratio).
scores

Unnamed: 0,NMI_cluster/label,ARI_cluster/label,ASW_label,ASW_label/batch,PCR_batch,isolated_label_F1,isolated_label_silhouette,graph_conn,EBM,KNN,method,data,rqr,reference_time,query_time
0,0.680134,0.562045,0.595679,0.842552,0.296231,0.824792,0.554949,0.98941,0.216884,0.86419,scanvi,pancreas,0.2,0.0,0.0
0,0.741267,0.742484,0.609009,0.885071,0.645215,0.850882,0.55946,0.988062,0.237237,0.872238,scanvi,pancreas,0.4,0.0,0.0
0,0.750032,0.749108,0.621274,0.901401,0.746723,0.855715,0.566215,0.987378,0.327796,0.885672,scanvi,pancreas,0.6,0.0,0.0
0,0.765541,0.757958,0.63155,0.899998,0.73749,0.872497,0.56615,0.987276,0.347949,0.899339,scanvi,pancreas,0.8,0.0,0.0


In [6]:
# Write the combined metrics into the model/dataset results folder; the row
# index is not written to the file.
scores_csv_path = os.path.expanduser(
    f'~/Documents/benchmarking_results/figure_3/{model}/{dataset}/rqr_{dataset}_{model}_{version}_cond.csv'
)
scores.to_csv(scores_csv_path, index=False)