In [None]:
import matplotlib as mpl
# Set matplotlib's default figure resolution to 300 dpi for the figures in this notebook.
mpl.rcParams['figure.dpi'] = 300

In [None]:
import sys
# Add the parent directory to the module search path; presumably this is what lets
# the `src` package imported in the next cell resolve — verify against the repo layout.
sys.path.append("../")

In [None]:
from src.utils import bootstrapping, normalize
from matplotlib import pyplot as plt
from plotnine_prism import *
import patchworklib as pw
import plotnine as p9
import numpy as np
import pandas as pd
import numpy as np
import yaml
import glob

In [None]:
# Collect the ablation performance reports; the folder matched by `*` is later
# used as the dataset key.
# glob returns matches in arbitrary, filesystem-dependent order, so the list is
# sorted to keep the downstream row order reproducible across machines and re-runs.
perf_report_files = sorted(glob.glob("../*/out_ablation/summary/model_performance_full.csv"))
print(len(perf_report_files))
perf_report_files

In [None]:
# Parse the shared YAML metadata file and keep its "model_modality" and
# "dataset" entries under their own names for the cells below.
with open("../model_and_dataset_info.yaml", "r") as stream:
    model_and_dataset_info = yaml.safe_load(stream)

model_modality, dataset_name = (
    model_and_dataset_info["model_modality"],
    model_and_dataset_info["dataset"],
)

In [None]:
# Read every performance report, tag its rows with the display name of the
# dataset (looked up from the folder component of the path), and stack them.
frames = []
for file in perf_report_files:
    df = pd.read_csv(file)
    dataset = dataset_name[file.split("/")[1]]
    df = df.assign(dataset=dataset)
    frames.append(df)
data = pd.concat(frames)

# Order the dataset categories as they appear in the YAML mapping, keeping only
# the datasets that actually occur in the loaded reports.
datasets_present = data["dataset"].unique()
dataset_categories = [name for name in dataset_name.values() if name in datasets_present]
data["dataset"] = pd.Categorical(data["dataset"], dataset_categories)
data

In [None]:
# Inspect which dataset labels are present after loading.
data.dataset.unique()

In [None]:
def _hparam_label(model_name):
    """Join the middle '_'-separated tokens of a model name with newlines.

    A name without any underscore is returned unchanged.
    """
    tokens = model_name.split("_")
    if len(tokens) > 1:
        return "\n".join(tokens[1:-1])
    return model_name


def _hparam_value(model_name):
    """Extract the trailing token of a model name as the hyperparameter value.

    Digit-only tokens and the literal "-1" become ints, a name without any
    underscore maps to "full", and every other token is kept as a string.
    """
    tokens = model_name.split("_")
    last = tokens[-1]
    if last.isnumeric() or last == "-1":
        return int(last)
    if len(tokens) == 1:
        return "full"
    return last


data["hparam"] = data["model"].apply(_hparam_label)
data["hparam_values"] = data["model"].apply(_hparam_value)

In [None]:
# Inspect the hyperparameter names parsed from the model identifiers.
data.hparam.unique()

In [None]:
def transform_num(x):
    """Parse the hyperparameter value encoded at the end of a model name.

    The value is the last '_'-separated token of ``x.model``.

    Parameters
    ----------
    x : row-like object
        Anything exposing a string ``model`` attribute; in this notebook it is
        a DataFrame row passed by ``DataFrame.apply(..., axis=1)``.

    Returns
    -------
    int, float or str
        An int when the token is an integer (including negative ones such as
        "-1"), a float when it contains a "." and parses as a number, and the
        unchanged token otherwise (e.g. "kmeans", "on").
    """
    value = x.model.split("_")[-1]

    # Integers first, so "7" stays an int rather than becoming 7.0.
    try:
        return int(value)
    except ValueError:
        pass

    # Only dotted tokens are tried as floats; a dotted token that is not a
    # number (e.g. "v1.x") falls through and is returned as-is instead of raising.
    if "." in value:
        try:
            return float(value)
        except ValueError:
            pass

    return value
        
        

In [None]:
# Build the per-(model, dataset) ARI summary table.
# Rows whose model is exactly 'AESTETIK' are excluded.
tab = data.copy().query("model != 'AESTETIK'")
# Rescale ARI within each (hparam, dataset) group with the project helper `normalize`
# (presumably a z-score, given the "z-score" axis label used when plotting — TODO confirm).
tab["ari"] = tab.groupby(["hparam", "dataset"]).ari.transform(lambda x: normalize(x))
# Summarize each (model, dataset) group with the project helper `bootstrapping`;
# presumably it returns a (median, std) pair, judging by the column names below — verify in src.utils.
tab = tab.groupby(["model", "dataset"]).ari.apply(lambda x: bootstrapping(x)).reset_index()
# Entries that came back as NaN are replaced by a [nan, nan] placeholder so that
# every entry expands into two columns on the next line.
tab.loc[tab.ari.isna(), "ari"] = tab.loc[tab.ari.isna(), "ari"].apply(lambda x: [np.nan,np.nan])
# Expand each two-element entry into the ARI_median / ARI_std columns, keyed by model and dataset.
tab = pd.DataFrame(tab["ari"].to_list(), columns=['ARI_median', 'ARI_std'], index=[tab.model, tab.dataset]).reset_index()
# Hyperparameter name: the middle '_'-separated tokens of the model name joined by spaces
# (names without an underscore are kept whole).
tab["hparam"] = tab.model.apply(lambda x: " ".join(x.split("_")[1:-1]) if len(x.split("_")) > 1 else x)
# Hyperparameter value: the last token of the model name, parsed by transform_num.
tab["hparam_values"] = tab.apply(lambda x: transform_num(x), axis=1)
tab.head()

In [None]:
# Inspect the hyperparameter names in the summary table; these are the keys
# expected by the `hparam_of_int` mapping.
tab.hparam.unique()

In [None]:
# Inspect the parsed hyperparameter values (a mix of numbers and strings).
tab.hparam_values.unique()

In [None]:
# Hyperparameters of interest: maps the names found in `tab.hparam` to the
# labels shown as facet titles.
# NOTE(review): a later cell reassigns this to only 'window size' before plotting.
hparam_of_int = {'train size': "Train size", 
                 'triplet loss': "Triplet loss", 
                 'rec loss': "Reconstruction loss", 
                 'refine cluster': "Refine cluster", 
                 'clustering method': "Clustering method"}

In [None]:
# Relative facet widths: one tenth of the number of distinct values tried for
# each hyperparameter of interest, with a floor of 0.4.
distinct_settings = tab[["hparam", "hparam_values"]].drop_duplicates()
facet_width_by_hparam = dict(distinct_settings["hparam"].value_counts() / 10)
scace_facet = [max(0.4, facet_width_by_hparam[key]) for key in hparam_of_int]
scace_facet

In [None]:
# Narrow the hyperparameters of interest to the window size and recompute the
# relative facet widths (one tenth of the number of distinct values, floor 0.4).
hparam_of_int = {'window size': "Window size"}
distinct_settings = tab[["hparam", "hparam_values"]].drop_duplicates()
facet_width_by_hparam = dict(distinct_settings["hparam"].value_counts() / 10)
scace_facet = [max(0.4, facet_width_by_hparam[key]) for key in hparam_of_int]
scace_facet

In [None]:
# Inspect the dataset labels available for the plot's dataset filter.
tab.dataset.unique()

In [None]:
# Plot ARI_median (with ARI_median ± ARI_std error bars) against the
# hyperparameter value, one facet per entry of `hparam_of_int`, coloured by dataset.
plot_data = tab.copy()
plot_data = plot_data.query("dataset in ['LIBD Human DLPFC', 'Tumor Profiler', 'Simulated Data (10)']").copy()
plot_data.dataset = pd.Categorical(plot_data.dataset, [d for d in dataset_name.values() if d in plot_data.dataset.unique()])

# Display order of the x-axis categories. Values missing from this list become
# NaN when converted to the categorical below.
order = ["off", "on", "multi", "0", "0.01", "0.1", "0.25", "0.5", "0.75", "1" ,"10", "100", "1000", "2000", "5000", "all", "1.5", "2", 
         "3", "5", "7", "9", "11", 'kmeans', 'mclust', "bgm"]
position_dodge_width = 0.6
# Capitalised copy of the dataset column, used as the colour mapping.
plot_data["Dataset"] = plot_data.dataset
plot_data.hparam = plot_data.hparam.astype(str)
# .copy() makes the filtered frame independent, so the column assignments below
# do not trigger pandas' SettingWithCopyWarning.
plot_data = plot_data.query("hparam in @hparam_of_int.keys()").copy()
plot_data.hparam = plot_data.hparam.apply(lambda x: hparam_of_int[x])
plot_data.hparam = pd.Categorical(plot_data.hparam.values, hparam_of_int.values())
plot_data.hparam_values = pd.Categorical(plot_data.hparam_values.astype(str), order)
p = (p9.ggplot(plot_data, p9.aes("hparam_values", "ARI_median")) 
 + p9.geom_point(p9.aes(color="Dataset"), shape="D", size=3, position=p9.position_dodge(width=position_dodge_width))
 # The connecting lines use the same "Dataset" column as the points and error bars.
 + p9.geom_line(p9.aes(color="Dataset", group="Dataset"), linetype="dashed", alpha=0.4, size=1, position=p9.position_dodge(width=position_dodge_width))
 + p9.theme_bw()
 + p9.facet_grid("~hparam", scales="free_x", space={"x": scace_facet, "y":[1]})
 + scale_color_prism(palette = "colors")
 + p9.ylab("ARI\nz-score")
 + p9.xlab("Hyperparameter value")
 + p9.geom_errorbar(p9.aes(x="hparam_values", ymin="ARI_median-ARI_std",ymax="ARI_median+ARI_std", color="Dataset"), 
                    width=0.001, alpha=1, size=1,
                   position=p9.position_dodge(width=position_dodge_width))
 + p9.theme(subplots_adjust={'wspace': 0.0}, figure_size=(6, 6), axis_text_x = p9.element_text(angle = 25))
 + p9.theme(text=p9.element_text(size=15),
            strip_text=p9.element_text(size=17),
            legend_title=p9.element_text(size=17),
            legend_text=p9.element_text(size=16))
)
p.save(filename = "figures/ablation_study_ari_only_window.png", dpi=300)
p