In [None]:
import os
import sys
# Make the parent directory (resolved against the notebook's current working
# directory) importable so the project-local `evaluation` package can be found.
# This must run before the `evaluation.utils` import below.
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
# Latent-space integration metrics used by the evaluation cell.
from evaluation.utils import entropy_batch_mixing, knn_purity, nmi, asw_c, asw_b
import scanpy as sc
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score
import json
import numpy as np
import warnings
# Suppress every FutureWarning for the rest of the kernel session to keep cell output readable.
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
# Run switches:
#   save_umaps   -> recompute neighbors / Leiden / UMAP and save *_umap*.png per result
#   calc_metrics -> score every result and write metric_scores.txt per result directory
save_umaps, calc_metrics = False, True

# Figure whose result directories are processed: 0 = full integration, 1-7 = paper figures.
figure = 2

# Model whose outputs are evaluated: "scvi" or "scanvi".
model = "scanvi"

# Conditioning variants to evaluate: True -> deeply injected ("deep_cond"),
# False -> first-layer injected ("first_cond"). Figure 0 always uses [True].
deep_conds = [False]

In [None]:
# Dataset name; only used in the figure 0 (full integration) and figure 3 result paths,
# e.g. 'pancreas', 'brain' or 'immune_all_human_fig6'.
data = "immune_all_human_fig6"

# Test numbers to evaluate.
#   Figures 3, 4, 6: reference-to-query ratio index.
#   Figure 0 with scanvi: label ratio index (label_ratio_<n> directory).
#   Figure 2: surgery variant, 1=reference, 2=freezed_expr, 3=freezed, 4=unfreezed.
#   Figures 1, 5, 7 and scvi full integration ignore this and use [0].
test_nrs = list(range(1, 5))

# Figure 5 only: which out-of-distribution (OOD) experiment to evaluate.
ood_nr = 2

In [None]:
# Evaluate every saved AnnData result for the configured figure: optionally redraw
# UMAPs, then score the latent space and write metric_scores.txt per result directory.
# Reads the config cells above (figure, model, data, test_nrs, ood_nr, deep_conds,
# save_umaps, calc_metrics) without modifying them, so this cell can be re-run
# after changing any of those values.

BENCHMARK_ROOT = os.path.expanduser('~/Documents/benchmarking_adata')

# Surgery-variant sub-directory for each figure-2 test number.
FIGURE_2_SURGERY_DIRS = {1: 'reference', 2: 'freezed_expr', 3: 'freezed', 4: 'unfreezed'}


def _result_dir(figure, model, data, test_nr, deep_label, ood_nr, root=BENCHMARK_ROOT):
    """Return the result directory (with trailing slash) for one figure/test combination.

    Raises:
        ValueError: for an unknown figure number, or a figure-2 test number without a
            surgery variant, instead of silently reusing a parent or stale directory.
    """
    if figure == 0:
        if model == 'scvi':
            return f'{root}/full_integration/{model}/{data}/'
        return f'{root}/full_integration/{model}/{data}/label_ratio_{test_nr}/'
    if figure == 1:
        return f'{root}/figure_1/{model}/{deep_label}/'
    if figure == 2:
        if test_nr not in FIGURE_2_SURGERY_DIRS:
            raise ValueError(
                f"figure 2 expects test_nr in {sorted(FIGURE_2_SURGERY_DIRS)}, got {test_nr!r}"
            )
        return f'{root}/figure_2/{model}/{deep_label}/{FIGURE_2_SURGERY_DIRS[test_nr]}/'
    if figure == 3:
        return f'{root}/figure_3/{model}/{data}/test_{test_nr}_{deep_label}/'
    if figure == 4:
        return f'{root}/figure_4/{model}/test_{test_nr}_{deep_label}/'
    if figure == 5:
        return f'{root}/figure_5/{model}/ood_{ood_nr}_{deep_label}/'
    if figure == 6:
        return f'{root}/figure_6/{model}/test_{test_nr}_{deep_label}/'
    if figure == 7:
        return f'{root}/figure_7/{model}/{deep_label}/'
    raise ValueError(f"unknown figure {figure!r}; expected an integer 0-7")


def _load_adatas(dir_path):
    """Read every .h5ad file in dir_path into a dict keyed by file name without extension.

    Files are visited in sorted order so the key order (and therefore the order in
    metric_scores.txt) does not depend on the filesystem's listing order.
    """
    return {
        os.path.splitext(file_name)[0]: sc.read(os.path.join(dir_path, file_name))
        for file_name in sorted(os.listdir(dir_path))
        if file_name.endswith(".h5ad")
    }


def _save_umaps(adata_dict, dir_path, model):
    """Compute neighbors, Leiden clusters and UMAP per result and save UMAP plots.

    Writes <key>_umap.png (batch / celltype) and, for scanvi, <key>_umap_pred.png
    (predictions / celltype) into dir_path. The scanpy pp/tl calls store their
    results on each AnnData (scanpy's default in-place behaviour).
    """
    for key, adata in adata_dict.items():
        sc.pp.neighbors(adata)
        sc.tl.leiden(adata)
        sc.tl.umap(adata)
        plt.figure()
        sc.pl.umap(adata, color=["batch", "celltype"], frameon=False, ncols=1, show=False)
        plt.savefig(os.path.join(dir_path, f'{key}_umap.png'), bbox_inches='tight')
        if model == "scanvi":
            sc.pl.umap(adata, color=["predictions", "celltype"], frameon=False, ncols=1, show=False)
            plt.savefig(os.path.join(dir_path, f'{key}_umap_pred.png'), bbox_inches='tight')
        # Close this result's figures; otherwise every figure created in the loop stays
        # open (and in memory) until the cell finishes.
        plt.close('all')


def _compute_metrics(adata, model):
    """Score one integrated AnnData and return the dict stored in metric_scores.txt.

    bio_conservation = mean(knn purity, NMI, cell-type ASW);
    batch_mixing     = mean(entropy of batch mixing, batch ASW);
    latent_score     = mean(batch_mixing, bio_conservation).
    With a single batch, the batch metrics are reported as 0 and latent_score equals
    bio_conservation. For scanvi, accuracy and macro F1 of `predictions` against
    `celltype` are added (and printed).
    """
    knn_s = knn_purity(adata)
    nmi_s = nmi(adata)
    asw_c_s = asw_c(adata)
    bio_con = (knn_s + nmi_s + asw_c_s) / 3
    if len(np.unique(adata.obs.batch)) > 1:
        ebm_s = entropy_batch_mixing(adata)
        asw_b_s = asw_b(adata)
        batch_mix = (ebm_s + asw_b_s) / 2
        latent_overall = (batch_mix + bio_con) / 2
    else:
        # A single batch has nothing to mix: report 0 and score on bio conservation only.
        ebm_s = asw_b_s = batch_mix = 0
        latent_overall = bio_con
    data_result = {
        "latent_score": latent_overall,
        "batch_mixing": batch_mix,
        "bio_conservation": bio_con,
        "ebm": ebm_s,
        "knn": knn_s,
        "nmi": nmi_s,
        "asw_b": asw_b_s,
        "asw_c": asw_c_s,
    }
    if model == "scanvi":
        data_result["accuracy"] = np.mean(adata.obs.predictions == adata.obs.celltype.tolist())
        print("Accuracy:", data_result["accuracy"])
        # sklearn's signature is f1_score(y_true, y_pred): ground truth goes first.
        data_result["f1"] = f1_score(adata.obs.celltype, adata.obs.predictions, average='macro')
        print("F1:", data_result["f1"])
    return data_result


# Figures 1, 5, 7 and scvi full integration have no test number in their paths, so
# they run once with test_nr 0; figure 0 is always evaluated with deep conditioning.
# Local copies keep the config variables from the cells above untouched.
run_test_nrs = [0] if figure in (1, 5, 7) or (figure == 0 and model == 'scvi') else test_nrs
run_deep_conds = [True] if figure == 0 else deep_conds

for deep_cond in run_deep_conds:
    deep_label = "deep_cond" if deep_cond else "first_cond"
    for test_nr in run_test_nrs:
        dir_path = _result_dir(figure, model, data, test_nr, deep_label, ood_nr)
        adata_dict = _load_adatas(dir_path)

        if save_umaps:
            _save_umaps(adata_dict, dir_path, model)

        if calc_metrics:
            results = dict()
            for key, adata in adata_dict.items():
                print("\n", key)
                results[key] = _compute_metrics(adata, model)
            # Written only when scores were computed in this iteration; otherwise
            # `results` would be undefined or belong to a different directory.
            with open(os.path.join(dir_path, 'metric_scores.txt'), 'w') as filehandle:
                # Metric helpers may return numpy scalars (e.g. float32) that json
                # cannot serialize natively; coerce such values to Python floats.
                json.dump(results, filehandle, default=float)