#### This notebook computes, for each hyperparameter setting, the median adjusted Rand index (ARI) across samples, and saves the parameters of the best-scoring setting as `best_param.yaml`.

In [None]:
import os
#os.chdir('../../10x_tupro/')
#os.chdir('../../maynard_human_brain_analysis/')
#os.chdir('../../her2_positive_breast_tumors/')
#os.chdir('../../../../')

In [None]:
from sklearn.metrics.cluster import adjusted_rand_score
from tqdm import tqdm
import scanpy as sc
import pandas as pd
import shutil
import glob
import yaml

In [None]:
from time import gmtime, strftime

# Display the UTC wall-clock time at which this evaluation was run.
now_utc = gmtime()
strftime("%Y-%m-%d %H:%M:%S", now_utc)

In [None]:
# Run configuration: which model's fine-tuning output to score, the root
# output folder, and the cross-validation split whose results are read.
model = "AESTETIK"
out_folder = "out_benchmark"
cross_validation_combination = (
    # Sample IDs before the "_test" marker: these are the samples scored below.
    "151675_151676"
    "_test_"
    # Sample IDs after the marker (not used by this notebook).
    "151507_151508_151509_151510_151669_151670_151671_151672_151673_151674"
)

In [None]:
# Sample IDs to evaluate: everything before the first "_test" marker, with any
# "split_train_" tag removed, split on underscores.
samples = (
    cross_validation_combination
    .partition("_test")[0]
    .replace("split_train_", "")
    .split("_")
)
samples

In [None]:
import warnings

# For every sample and every hyperparameter setting, compute the ARI between the
# predicted clusters (one CSV per setting) and adata.obs.ground_truth.
# Result: DataFrame with columns model (setting name), sample, ari.
ari_rows = []
for sample in tqdm(samples):
    # Load the sample; its obs index is used as the spot barcode.
    adata_path = os.path.join(out_folder, "data", "h5ad", f"{sample}.h5ad")
    adata = sc.read(adata_path)
    adata.obs["Barcode"] = adata.obs.index.values

    # Cluster files are named model-<sample>-<setting>.csv.
    cluster_prefix = f"model-{sample}-"
    cluster_dir = os.path.join(
        out_folder, cross_validation_combination, f"{model}_fine_tune", "clusters"
    )
    # sorted(): glob order is arbitrary; make processing order reproducible.
    cluster_paths = sorted(glob.glob(os.path.join(cluster_dir, f"{cluster_prefix}*.csv")))
    if not cluster_paths:
        warnings.warn(f"No cluster files matching {cluster_prefix}*.csv in {cluster_dir}")

    for cluster_path in cluster_paths:
        # Strip only the leading prefix (guaranteed by the glob pattern);
        # str.replace would also delete later occurrences inside the name.
        stem = os.path.splitext(os.path.basename(cluster_path))[0]
        param_name = stem[len(cluster_prefix):]

        # First CSV column: barcode; second column: cluster label.
        df = pd.read_csv(cluster_path)
        cluster_label_dict = dict(zip(df.iloc[:, 0], df.iloc[:, 1]))

        # Barcodes absent from the CSV map to NaN and become the string "nan",
        # which the ARI treats as one extra cluster -- make that visible.
        labels = adata.obs.Barcode.map(cluster_label_dict)
        n_missing = int(labels.isna().sum())
        if n_missing:
            warnings.warn(
                f"{n_missing} of {len(labels)} barcodes in sample {sample} have no "
                f"cluster label in {os.path.basename(cluster_path)}"
            )
        adata.obs[param_name] = labels.astype(str)

        ari = adjusted_rand_score(adata.obs.ground_truth, adata.obs[param_name])
        ari_rows.append([param_name, sample, ari])

# Create DataFrame
ari_result = pd.DataFrame(ari_rows, columns=["model", "sample", "ari"])
ari_result.head()

In [None]:
# Collapse the per-sample ARIs into one median ARI per hyperparameter setting and
# rank best-first. Ties on the median are broken alphabetically by setting name,
# so the top row (selected as the best setting below) is deterministic; the
# default quicksort gives no guaranteed order among equal values.
ari_result = (
    ari_result.groupby("model").ari.agg("median").reset_index()
    .sort_values(["ari", "model"], ascending=[False, True])
)
ari_result

In [None]:
# Best hyperparameter setting: the first row of the ranked table.
top_model = ari_result["model"].iloc[0]
top_model

In [None]:
# YAML config of the best setting, built with os.path.join like the other
# paths in this notebook.
parameter_path = os.path.join(
    out_folder,
    cross_validation_combination,
    f"{model}_fine_tune",
    "parameters",
    f"{top_model}.yaml",
)
parameter_path

In [None]:
# Load and display the best setting's hyperparameters (yaml.safe_load builds
# only standard YAML types, never arbitrary Python objects).
with open(parameter_path) as param_file:
    parameters = yaml.safe_load(param_file)
parameters

In [None]:
# Copy the winning config to best_param.yaml in the same parameters directory.
# The directory is derived from parameter_path rather than rebuilt by hand.
shutil.copy(parameter_path, os.path.join(os.path.dirname(parameter_path), "best_param.yaml"))