In [None]:
import os
#os.chdir('../../10x_tupro/')
#os.chdir('../../maynard_human_brain_analysis/')
#os.chdir('../../her2_positive_breast_tumors/')
#os.chdir('../../human_breast_cancers/')
#os.chdir('../../cosMx_human_liver_cancer/')
#os.chdir('../../cosMx_human_liver_normal/')

In [None]:
import sys
# Make the parent folder importable so the project package `src` can be found.
# The entry is relative, so it is resolved against the current working directory at import time.
sys.path.append('../')
from src.utils import bootstrapping

In [None]:
from sklearn.metrics.cluster import adjusted_rand_score
from plotnine_prism import *
from tqdm import tqdm
import plotnine as p9
import scanpy as sc
import pandas as pd
import shutil
import glob
import yaml

In [None]:
# Show (in UTC) when this notebook was executed.
from time import strftime, gmtime
run_timestamp_utc = strftime("%Y-%m-%d %H:%M:%S", gmtime())
run_timestamp_utc

In [None]:
# Shared metadata for the benchmark (e.g. "model_modality" and "dataset" lookup tables used below).
with open("../model_and_dataset_info.yaml") as info_file:
    model_and_dataset_info = yaml.safe_load(info_file)
model_and_dataset_info

In [None]:
# Root folder of the benchmark outputs: per-split model results (<out_folder>/<split>/...),
# h5ad inputs (<out_folder>/data/h5ad/) and the summary tables/plots (<out_folder>/summary/).
out_folder = "out_benchmark"

In [None]:
# Every clustering result: <out_folder>/<split>/<model>_evaluate/clusters*/model-<sample>-best_param.csv
# Sorted so processing order (and therefore row order of the tables below) does not depend on the
# filesystem's directory listing order.
cluster_paths = sorted(glob.glob(f"{out_folder}/*/*_evaluate/clusters*/*"))
# Fail early with a clear message instead of producing empty tables/plots further down.
assert cluster_paths, f"no clustering results found under {out_folder}/"
print(len(cluster_paths))
cluster_paths[:2]

In [None]:
# Group clustering files by sample; the sample name is the part between "model-" and
# "-best_param.csv" in the path. Each new path goes to the front of its sample's list, so every
# list ends up in reverse discovery order.
sorted_cluster_paths = {}
for path in tqdm(cluster_paths):
    sample_name = path.split("model-")[1].replace("-best_param.csv", "")
    sorted_cluster_paths.setdefault(sample_name, []).insert(0, path)

In [None]:
# Adjusted Rand index of every model's clustering against the sample's `ground_truth` labels.
# Result: one row per (split, model, sample).
ari_records = []
for sample, sample_cluster_paths in tqdm(sorted_cluster_paths.items()):
    adata = sc.read(f"{out_folder}/data/h5ad/{sample}.h5ad")
    # Barcodes are the obs index; this is the same for every clustering of the sample, so set it once.
    adata.obs["Barcode"] = adata.obs.index.values
    for cluster_path in sample_cluster_paths:
        # Path layout: <out_folder>/<split>/<model>_evaluate/clusters*/model-<sample>-best_param.csv
        split = cluster_path.split("/")[1]
        model = cluster_path.split("_evaluate")[0].split("/")[-1]
        clusters = pd.read_csv(cluster_path)
        # First CSV column is the barcode, second the cluster label.
        cluster_label_dict = dict(zip(clusters.iloc[:, 0], clusters.iloc[:, 1]))
        # Direct indexing (not .map) so a barcode missing from the CSV raises KeyError instead of becoming NaN.
        adata.obs[model] = adata.obs.Barcode.apply(lambda x: cluster_label_dict[x]).astype(str)
        ari = adjusted_rand_score(adata.obs.ground_truth, adata.obs[model])
        ari_records.append([split, model, sample, ari])

ari_result = pd.DataFrame(ari_records, columns=["split", "model", "sample", "ari"])
ari_result.head()

In [None]:
# Models ranked by median ARI over all (split, sample) evaluations, best first.
# Displayed explicitly: previously this was not the cell's last expression, so it was computed and discarded.
model_ranking = ari_result.groupby("model").ari.agg("median").sort_values(ascending=False).reset_index()
display(model_ranking)
ari_result

In [None]:
# Median ARI per (model, split).
# - `fine_tune` is the split name up to "_test".
# - `model` becomes an ordered categorical (ascending median of its per-split medians), so plots
#   list models by performance.
# - `modality` is looked up in the YAML by the model-name prefix before the first "_".
tab = ari_result.groupby(["model", "split"])["ari"].agg("median").reset_index()
tab["fine_tune"] = tab["split"].str.split("_test").str[0]
model_str = tab["model"].astype(str)
model_order = (
    tab.assign(model=model_str)
    .groupby("model")["ari"]
    .agg("median")
    .sort_values(ascending=True)
    .index
)
tab["model"] = pd.Categorical(model_str, model_order)
tab["modality"] = tab["model"].apply(lambda name: model_and_dataset_info["model_modality"][name.split("_")[0]])
tab

In [None]:
# The display name of the dataset is keyed by the name of the current working directory
# (assumes the notebook runs inside the dataset folder — see the chdir options in the first cell).
# os.getcwd() replaces the `!pwd` shell call: pure Python, no subprocess, valid outside IPython.
current_dir_name = os.path.basename(os.getcwd())
dataset = model_and_dataset_info["dataset"][current_dir_name]
dataset

In [None]:
# Per-split median ARI of every model as boxplots filled by modality.
# Each split is overlaid as a point. Points get one shape per fine-tuning set only when there are
# fewer than 13 such sets (NOTE(review): presumably the size of the shape palette — confirm).
g = (
    p9.ggplot(tab, p9.aes("model", "ari"))
    + p9.geom_boxplot(p9.aes(fill="modality"), alpha=0.6)
    + p9.theme_bw()
    + p9.coord_flip()
    + scale_fill_prism(palette="colors")
    + p9.ylab("ARI")
    + p9.xlab("Model")
    + p9.ggtitle(dataset)
)

few_fine_tune_sets = len(tab.fine_tune.unique()) < 13
if few_fine_tune_sets:
    split_points = p9.geom_jitter(
        p9.aes(shape="fine_tune"),
        position=p9.position_dodge(width=0.75),
        size=2,
        alpha=0.5,
        show_legend=False,
    )
else:
    split_points = p9.geom_jitter(alpha=0.3, size=0.5)
print(g + split_points)

In [None]:
# Save the per-split table tagged with the dataset name. Create the summary folder on first use,
# otherwise to_csv fails when <out_folder>/summary does not exist yet.
os.makedirs(f"{out_folder}/summary", exist_ok=True)
tab["dataset"] = dataset
tab.to_csv(f"{out_folder}/summary/model_performance_full.csv", index=False)

In [None]:
# Bootstrap each model's per-split median ARIs into ARI_median / ARI_std.
# The per-split medians are rebuilt from `ari_result` instead of being read from `tab`. This cell
# overwrites `tab`, so reusing it made a second run fail (the "ari" column no longer exists).
per_split_ari = ari_result.groupby(["model", "split"]).ari.agg("median").reset_index()
per_split_ari.model = per_split_ari.model.astype(str)
# Same model order as the per-split table (ascending median of per-split medians), so the output
# rows come out in that order.
per_split_ari.model = pd.Categorical(
    per_split_ari.model,
    per_split_ari.groupby("model").ari.agg("median").sort_values(ascending=True).index,
)
# NOTE(review): `bootstrapping` is project code (src.utils); presumably it returns (median, std)
# for each group — confirm, and check whether it is seeded.
boot = per_split_ari.groupby("model").ari.apply(lambda x: bootstrapping(x)).reset_index()
tab = pd.DataFrame(boot["ari"].to_list(), columns=["ARI_median", "ARI_std"], index=boot.model).reset_index()
tab["modality"] = tab.model.apply(lambda x: model_and_dataset_info["model_modality"][x.split("_")[0]])
tab["dataset"] = dataset
tab

In [None]:
# Persist the bootstrapped per-model summary alongside the other summary tables.
bootstrap_summary_csv = f"{out_folder}/summary/model_performance_bootstrapping.csv"
tab.to_csv(bootstrap_summary_csv, index=False)

In [None]:
# Bootstrapped ARI_median (diamond) ± ARI_std (error bar) per model, coloured by modality and
# ordered by ARI_median (lowest at the bottom after coord_flip).
tab.model = tab.model.astype(str)
tab.model = pd.Categorical(tab.model, tab.sort_values("ARI_median", ascending=True).model)
g = (
    p9.ggplot(tab, p9.aes("model", "ARI_median"))
    + p9.geom_point(p9.aes(color="modality"), shape="D", size=3)
    + p9.theme_bw()
    + p9.coord_flip()
    + p9.ylab("ARI")
    + p9.xlab("Model")
    + p9.ggtitle(dataset)
    + p9.geom_errorbar(
        p9.aes(x="model", ymin="ARI_median-ARI_std", ymax="ARI_median+ARI_std", color="modality"),
        alpha=1, size=1, width=0.001,
    )
    # Only `color` is mapped (points and error bars), so a colour scale is all that is needed;
    # a fill scale would have no effect here.
    + scale_color_prism(palette="colors")
)
g

In [None]:
# Export the bootstrap plot at 300 dpi, named after the dataset (spaces replaced by underscores).
png_stem = dataset.replace(" ", "_")
g.save(filename=f"{out_folder}/summary/{png_stem}_bootstrap_benchmark.png", dpi=300)

In [None]:
# One plot per split: ARI per model as boxplots (AESTETIK highlighted in red), points coloured by sample.
for split in ari_result["split"].unique():
    # Boolean mask instead of a string-built query, so split names containing quotes cannot break it.
    split_ari = ari_result[ari_result["split"] == split].copy()
    # Literal hex colours shown via an identity scale. A two-value manual scale assigns colours
    # positionally to the levels present, so in a split without AESTETIK every box turned red.
    split_ari["box_fill"] = ["#990000" if m == "AESTETIK" else "#FFFFFF" for m in split_ari.model]
    split_ari.model = split_ari.model.astype(str)
    split_ari.model = pd.Categorical(
        split_ari.model,
        split_ari.groupby("model").ari.agg("median").sort_values(ascending=True).index,
    )
    print(p9.ggplot(split_ari, p9.aes("model", "ari"))
          + p9.geom_boxplot(p9.aes(fill="box_fill"), show_legend=False, alpha=.8)
          + p9.geom_jitter(p9.aes(color="sample"), position=p9.position_dodge(width=0.75))
          + p9.theme_bw()
          + p9.coord_flip()
          + p9.scale_fill_identity()
          + p9.ylab("ARI")
          + p9.xlab("Model")
          + p9.ggtitle(split)
          + scale_color_prism(palette="colors")
          )
    

In [None]:
# Best hyper-parameters selected during fine-tuning, one row per (split, model).
# Path layout: <out_folder>/<split>/<model>_fine_tune/parameters/best_param.yaml
best_param_rows = []
for param_file in glob.glob(f"{out_folder}/*/*_fine_tune/parameters/best_param.yaml"):
    with open(param_file, "r") as stream:
        params = yaml.safe_load(stream)
    best_param_rows.append([
        param_file.split("/")[1],
        param_file.split("_fine_tune")[0].split("/")[-1],
        str(params),
    ])
data = pd.DataFrame(best_param_rows, columns=["split", "model", "parameters"]).sort_values("model")
data

In [None]:
# Compact listing: split prefix (text before the first "_"), model, parameters.
for split_name, model_name, params in data[["split", "model", "parameters"]].itertuples(index=False):
    print(split_name.split("_")[0], model_name, params)

In [None]:
def get_cluster_path(split, model, sample):
    """Return the best-parameter clustering CSV of ``model`` on ``sample`` in ``split``.

    Searches ``<out_folder>/<split>/<model>_evaluate/clusters*/`` for
    ``model-<sample>-best_param.csv``; only the ``clusters*`` folder is a wildcard.

    Raises:
        AssertionError: if zero or more than one file matches.
    """
    # Escape the literal components so glob metacharacters in names ("[", "]", "*", "?") match literally.
    pattern = (
        f"{glob.escape(out_folder)}/{glob.escape(split)}/{glob.escape(model)}_evaluate/"
        f"clusters*/model-{glob.escape(sample)}-best_param.csv"
    )
    path = glob.glob(pattern)
    assert len(path) == 1, f"{path}; {sample}"
    return path[0]


# For every (sample, model) keep the evaluation with the highest ARI across splits; on ties the
# earliest row of `ari_result` wins. idxmax selects that row directly instead of re-joining on
# float ARI values and de-duplicating, and only the selected rows are looked up on disk.
best_rows = ari_result.loc[ari_result.groupby(["sample", "model"]).ari.idxmax()]
best_split_sample = best_rows.assign(
    path=[
        get_cluster_path(split_name, model_name, sample_name)
        for split_name, model_name, sample_name in zip(best_rows["split"], best_rows["model"], best_rows["sample"])
    ]
)[["sample", "model", "path", "ari"]]
best_split_sample[best_split_sample.model == "AESTETIK"]

In [None]:
# Persist the best split (and its clustering file) per sample and model.
best_split_csv = f"{out_folder}/summary/summary_best_split_sample.csv"
best_split_sample.to_csv(best_split_csv, index=False)