In [1]:
import numpy as np
import scanpy as sc
import pandas as pd
import anndata as ad
import scib_metrics
import scib
import os
#import faiss
from scipy import sparse
#import torch
from rich import print

In [2]:
#TF_CPP_MIN_LOG_LEVEL=0

In [3]:
#import torch
#torch.cuda.is_available()

In [4]:
def define_path(model_type, integration_type):
    """Return the results directory for one scGLUE model/integration pair.

    Parameters
    ----------
    model_type : str
        Name of the model sub-directory (the notebook uses ``'paired'`` and
        ``'unpaired'``).
    integration_type : str
        Name of the integration sub-directory (the notebook uses
        ``'trimodal'``, ``'full'`` and ``'cite'``).

    Returns
    -------
    str
        Relative directory path, always terminated by ``'/'``.
    """
    results_root = "../results/scglue/point_nine_corr/all_samples/"
    # The trailing empty component yields the final '/' separator.
    return results_root + '/'.join((model_type, integration_type, ''))

In [5]:
def load_adata(adata_path):
    """Read ``combined.h5ad`` from ``adata_path`` and print a short summary.

    Parameters
    ----------
    adata_path : str
        Directory prefix (including its trailing separator) that is
        concatenated with ``"combined.h5ad"``.

    Returns
    -------
    The object returned by ``ad.read_h5ad``.

    Notes
    -----
    ``obsp['distances']`` and ``obsp['connectivities']`` are accessed
    directly, so a file without those entries raises ``KeyError``.
    """
    print("loading trimodal adata..\n\n")
    adata = ad.read_h5ad(adata_path + "combined.h5ad")

    # Report whether the main matrix and both neighbour graphs are sparse.
    print(sparse.issparse(adata.X))
    for graph_key in ('distances', 'connectivities'):
        print(sparse.issparse(adata.obsp[graph_key]))

    print(adata)
    return adata

In [6]:
def remove_obs_names_duplicates(combined):
    """Make observation names unique by appending each row's position.

    Every name in ``combined.obs.index`` (not only the duplicated ones) is
    rewritten in place as ``"<name>_<row position>"``. The number of
    duplicated names is printed before and after the renaming.

    Parameters
    ----------
    combined
        Object exposing a pandas-like ``obs`` frame whose ``index`` holds the
        observation names; it is modified in place.

    Returns
    -------
    None
    """
    def _report_duplicates():
        # Count index entries flagged as repeats of an earlier entry.
        print("Number of duplicate observation names:", sum(combined.obs.index.duplicated()))

    _report_duplicates()

    # Suffix every name with its positional index so all names become unique.
    combined.obs.index = [
        f"{obs_name}_{position}"
        for position, obs_name in enumerate(combined.obs.index)
    ]

    _report_duplicates()

In [7]:
def compute_neighbours(combined):
    """Compute the neighbour graph of ``combined`` on the ``X_glue`` embedding.

    Calls ``sc.pp.neighbors`` with 15 neighbours and
    ``use_rep='X_glue'``.

    Parameters
    ----------
    combined
        Annotated data object passed straight to ``sc.pp.neighbors``.

    Returns
    -------
    None
    """
    print("computing neighbours on scglue latent space..\n\n")
    sc.pp.neighbors(combined, n_neighbors=15, use_rep='X_glue')

In [8]:
def define_variables(combined, label_key, batch_key):
    """Collect the inputs needed by the metric functions.

    Adds two integer-coded columns to ``combined.obs``
    (``<label_key>_code`` and ``<batch_key>_code``) and pulls out the
    ``X_glue`` embedding and the neighbour graphs.

    Parameters
    ----------
    combined
        Object exposing ``obsm['X_glue']``, a pandas-like ``obs`` frame and
        ``obsp['distances']`` / ``obsp['connectivities']``; ``obs`` is
        modified in place.
    label_key : str
        Name of the ``obs`` column holding the labels.
    batch_key : str
        Name of the ``obs`` column holding the batches.

    Returns
    -------
    tuple
        ``(X, label_key, batch_key, distances_nn, connectivities)``.

    Raises
    ------
    KeyError
        If ``X_glue``, either ``obs`` column, or either ``obsp`` entry is
        missing.
    """
    print("defining variables..\n\n")
    X = combined.obsm['X_glue']

    for key in (label_key, batch_key):
        # Coerce to categorical first: ``.cat`` only exists on categorical
        # columns, and plain string/object columns would otherwise raise
        # AttributeError. Already-categorical columns keep their codes.
        as_category = combined.obs[key].astype('category')
        combined.obs[key + '_code'] = as_category.cat.codes.to_numpy()  # for scib

    distances_nn = combined.obsp['distances']
    connectivities = combined.obsp['connectivities']

    return X, label_key, batch_key, distances_nn, connectivities

In [9]:
def run_metrics(combined, X, label_key, batch_key, distances_nn, connectivities):
    """Compute the enabled scIB metrics on the ``X_glue`` embedding.

    Batch-effect metrics currently computed: ``silhouette_batch``.
    Label-conservation metrics computed: ``silhouette_label``,
    ``nmi_ari_leiden`` and ``clisi_knn``.

    The ``pcr``, ``graph_connectivity`` and ``ilisi_knn`` computations are
    disabled (their calls are kept below as comments). They are therefore
    left out of the returned dictionary; previously they were still
    referenced when building it, which raised ``NameError`` on a fresh
    kernel after the expensive metrics had already run.

    Parameters
    ----------
    combined
        Annotated data object passed to the ``scib.metrics`` functions.
    X
        Embedding matrix; only needed by the disabled ``pcr`` call.
    label_key : str
        ``obs`` column with labels; ``<label_key>_code`` must also exist.
    batch_key : str
        ``obs`` column with batches.
    distances_nn
        Neighbour distance graph; only needed by the disabled
        ``graph_connectivity`` call.
    connectivities
        Neighbour connectivity graph used for Leiden-based NMI/ARI.

    Returns
    -------
    tuple of dict
        ``(batch_effect_metrics, label_conserv_metrics)``.
    """
    print("running metrics..\n\n")

    # --- removal of batch effects -------------------------------------------
    # Disabled metrics: to re-enable one, uncomment its call AND add the
    # result to ``batch_effect_metrics`` below.
    print("skipping pcr, graph connectivity and ilisi knn (disabled)..\n\n")
    #pcr = scib_metrics.utils.principal_component_regression(X=X, covariate=combined.obs[batch_key+'_code'])
    #graph_connectivity = scib_metrics.graph_connectivity(X=distances_nn, labels=combined.obs[label_key+'_code'])
    #ilisi_knn = scib.metrics.ilisi_graph(combined, batch_key, "knn", use_rep='X_glue')
    print("running silhouette batch (scib implementation)..\n\n")
    silhouette_batch = scib.metrics.silhouette_batch(combined, batch_key, label_key, 'X_glue')

    # --- conservation of variance from cell identity labels -----------------
    print("running silhouette label.. \n\n")
    silhouette_label = scib.metrics.silhouette(combined, label_key, 'X_glue')
    print("running nmi, ari with leiden..\n\n")
    nmi_ari_leiden = scib_metrics.nmi_ari_cluster_labels_leiden(
        X=connectivities, labels=combined.obs[label_key+'_code'])
    print("running clisi knn..\n\n")
    clisi_knn = scib.metrics.clisi_graph(combined, label_key, 'knn', use_rep='X_glue')

    print("generating batch effect metrics dictionary.. \n\n")
    # Only metrics that were actually computed are reported.
    batch_effect_metrics = {'silhouette_batch': silhouette_batch}

    print("generating label conservation metrics dictionary.. \n\n")
    label_conserv_metrics = {
        'silhouette_label': silhouette_label,
        'nmi_ari_leiden': nmi_ari_leiden,
        'clisi_knn': clisi_knn,
    }

    print("batch effect metrics for current integration coefficient:")
    print(batch_effect_metrics)
    print("label conservation metrics for current integration coefficient:")
    print(label_conserv_metrics)

    return batch_effect_metrics, label_conserv_metrics

In [10]:
def save_metrics(adata_path, batch_effect_metrics, label_conserv_metrics, label_key, batch_key):
    """Write both metric dictionaries to a text report.

    The report is written to
    ``<adata_path>/scib/scib-<batch_key>-<label_key>.txt``; the ``scib``
    directory is created if needed and an existing report is overwritten.

    Parameters
    ----------
    adata_path : str
        Directory of the analysed dataset, with or without a trailing
        separator.
    batch_effect_metrics : dict
        Metric name -> value, written under the batch-effect heading.
    label_conserv_metrics : dict
        Metric name -> value, written under the label-conservation heading.
    label_key, batch_key : str
        Keys used for the metrics; they only affect the file name.

    Returns
    -------
    None
    """
    # os.path.join tolerates an ``adata_path`` with or without a trailing '/'.
    metrics_dir = os.path.join(adata_path, 'scib')
    os.makedirs(metrics_dir, exist_ok=True)
    report_file = os.path.join(metrics_dir, f"scib-{batch_key}-{label_key}.txt")

    sections = (
        ("batch_effect_metrics:\n\n", batch_effect_metrics),
        ("label conservation metrics:\n\n", label_conserv_metrics),
    )
    with open(report_file, "w") as report:
        for section_index, (header, metrics) in enumerate(sections):
            if section_index:
                # Blank line between sections goes into the file (it used to
                # be printed to stdout, leaving the report without one).
                report.write("\n")
            report.write(header)
            for key, value in metrics.items():
                report.write(f"{key}: {value}\n")

In [11]:
def main(model_types = ['paired', 'unpaired'],
         integration_types = ['trimodal', 'full', 'cite'],
         label_key='Annotation_major_subset',
         batch_key='Domain'):
    """Run the metric pipeline for every model/integration combination.

    For each ``(model_type, integration_type)`` pair the combined AnnData is
    loaded, its observation names are made unique, the metric inputs are
    collected and the metrics are computed and printed.

    Parameters
    ----------
    model_types : iterable of str
        Model variants to evaluate (outer loop).
    integration_types : iterable of str
        Integration variants to evaluate (inner loop).
    label_key : str
        ``obs`` column holding the labels.
    batch_key : str
        ``obs`` column holding the batches.

    Returns
    -------
    None

    Notes
    -----
    Neighbour recomputation and saving of the metrics are currently
    disabled, so results are only printed and the neighbour graphs stored in
    the loaded file are used.
    """
    print("analysis starting..\n\n")

    # All (model, integration) combinations, model type varying slowest.
    run_grid = [(model, integration)
                for model in model_types
                for integration in integration_types]

    for model_type, integration_type in run_grid:
        print(f"Computing metrics for model type:'{model_type}' and integration type '{integration_type}'\n\n")
        adata_path = define_path(model_type, integration_type)
        combined = load_adata(adata_path)
        remove_obs_names_duplicates(combined)
        metric_inputs = define_variables(combined, label_key, batch_key)
        X, label_key, batch_key, distances_nn, connectivities = metric_inputs
        # Disabled: compute_neighbours(combined)
        metric_results = run_metrics(combined, X, label_key, batch_key,
                                     distances_nn, connectivities)
        batch_effect_metrics, label_conserv_metrics = metric_results
        # Disabled: save_metrics(adata_path, batch_effect_metrics, label_conserv_metrics, label_key, batch_key)

    print("analysis finished")

In [None]:
# Paired trimodal model: 'Annotation_major_subset' labels, 'Domain' batches.
main(
    model_types=['paired'],
    integration_types=['trimodal'],
    label_key='Annotation_major_subset',
    batch_key='Domain',
)

  utils.warn_names_duplicates("obs")


In [None]:
# Paired trimodal model: 'Annotation_cell_type' labels, 'Domain' batches.
main(
    model_types=['paired'],
    integration_types=['trimodal'],
    label_key='Annotation_cell_type',
    batch_key='Domain',
)

In [None]:
# Paired trimodal model: 'Annotation_major_subset' labels, 'Domain_major' batches.
main(
    model_types=['paired'],
    integration_types=['trimodal'],
    label_key='Annotation_major_subset',
    batch_key='Domain_major',
)

In [None]:
# Paired trimodal model: 'Annotation_cell_type' labels, 'Domain_major' batches.
main(
    model_types=['paired'],
    integration_types=['trimodal'],
    label_key='Annotation_cell_type',
    batch_key='Domain_major',
)