In [None]:
import matplotlib as mpl
mpl.rcParams['figure.dpi'] = 300
import matplotlib.pyplot as plt
from plotnine_prism import *
import plotnine as p9
import pandas as pd
import numpy as np
import glob
import yaml
import sys
sys.path.append("../")
from src.utils import bootstrapping, normalize

In [None]:
# Read the shared model/dataset configuration that sits one directory above
# this notebook; `safe_load` parses plain YAML without constructing arbitrary
# Python objects.
with open("../model_and_dataset_info.yaml", "r") as config_file:
    model_and_dataset_info = yaml.safe_load(config_file)
model_and_dataset_info

In [None]:
# Maps a dataset directory name to the label shown in the figure.
# Insertion order matters: the values are later used as the category order
# of the `dataset` column, which determines the order on the plot's x-axis.
target_dataset = {
    "10x_TuPro_v2": "Tumor Profiler (~4)",
    "simulated_data_5_clusters": "Simulated Data (5)",
    "maynard_human_brain_analysis": "LIBD Human DLPFC (~6)",
    "simulated_data_10_clusters": "Simulated Data (10)",
    "simulated_data_15_clusters": "Simulated Data (15)",
}

In [None]:
# Collect every loss file written by the multi-triplet-loss ablation runs.
# `glob.glob` returns matches in arbitrary (filesystem-dependent) order, so
# the result is sorted to keep row order and displayed examples reproducible
# across machines and re-runs.
loss_paths = sorted(
    glob.glob("../*/out_ablation/*/AESTETIK_triplet_loss_multi_evaluate/loss/*")
)
# Keep only files whose second path component (the directory directly under
# "..") is listed under the "dataset" key of the configuration.
loss_paths = [
    path
    for path in loss_paths
    if path.split("/")[1] in model_and_dataset_info["dataset"].keys()
]
len(loss_paths)

In [None]:
# Pair each multi-triplet-loss file with its single-triplet-loss counterpart.
# The counterpart path is derived purely by swapping the run-directory name;
# nothing here checks that the derived file exists.
loss_pairs = [
    [
        multi_path,
        multi_path.replace(
            "AESTETIK_triplet_loss_multi_evaluate",
            "AESTETIK_triplet_loss_single_evaluate",
        ),
    ]
    for multi_path in loss_paths
]
loss_pairs[1]

In [None]:
# For every (multi, single) pair of loss curves, record the standard deviation
# of the step-to-step loss differences, and remember the pair with the largest
# multi/single deviation ratio among those where the multi curve reaches a
# lower minimum than the single curve within steps 20..49.
compare = []
max_std = -100
for path_on, path_off in loss_pairs:
    path_parts = path_on.split("/")
    # File name with its "model-" prefix and "-best_param.npy" suffix removed.
    sample = path_parts[-1].replace("model-", "").replace("-best_param.npy", "")
    # Directory name -> display label (raises KeyError for unlisted datasets).
    dataset = target_dataset[path_parts[1]]
    train_split = path_parts[3]

    multi_on, multi_off = np.load(path_on), np.load(path_off)
    multi_on_std = np.diff(multi_on).std()
    multi_off_std = np.diff(multi_off).std()
    compare.append([sample, train_split, dataset, multi_on_std, multi_off_std])

    std_ratio = multi_on_std / multi_off_std
    if std_ratio > max_std and multi_on[20:50].min() < multi_off[20:50].min():
        print("found")
        max_std = std_ratio
        multi_on_plot, multi_off_plot = multi_on, multi_off

compare_df = pd.DataFrame(
    compare,
    columns=["sample", "train_split", "dataset", "multi_on_std", "multi_off_std"],
)
# Fix the category order of the labels to the order given in `target_dataset`.
compare_df["dataset"] = pd.Categorical(compare_df["dataset"], target_dataset.values())
compare_df

In [None]:
# Average the two loss-deviation columns within every (train_split, dataset)
# group and reshape to long format (one row per group and per column, with
# the column name in `variable` and the mean in `value`).
#
# The numeric columns are selected explicitly: `compare_df` also holds the
# string column `sample`, and taking the mean over it raises a TypeError on
# pandas >= 2.0 (older versions silently dropped it).
std_columns = ["multi_on_std", "multi_off_std"]
tab = (
    compare_df
    # `dataset` is categorical; observed=False keeps the historical behaviour
    # of emitting every category combination (empty ones become NaN below).
    .groupby(["train_split", "dataset"], observed=False)[std_columns]
    .mean()
    .reset_index()
    .melt(id_vars=["train_split", "dataset"])
)
# Drop the NaN rows of empty groups; `.copy()` gives an independent frame so
# that adding columns to `tab` later does not trigger SettingWithCopyWarning.
tab = tab[~tab.value.isna()].copy()
tab.head()

In [None]:
#tab["value"] = tab.groupby(["dataset"]).value.transform(lambda x: normalize(x))

In [None]:
# Summarise the per-split values into one (median, std)-style pair per
# dataset and triplet-loss variant.
# NOTE: this cell overwrites `tab` with a frame that no longer has the
# `variable` / `value` columns, so re-running it without first re-running the
# cell that builds `tab` will fail.

# Label each row by loss variant: column names containing "on" are the
# multi-triplet runs, everything else is treated as single-triplet.
tab["Triplet loss"] = tab.variable.apply(lambda x: "multi" if "on" in x else "single")
# Fix the category order to single, multi.
tab["Triplet loss"] = pd.Categorical(tab["Triplet loss"], ["single", "multi"])
# Apply `bootstrapping` (imported from src.utils) to the values of every
# (dataset, Triplet loss) group.
# NOTE(review): presumably it returns a two-element [median, std] result, as
# implied by the column names used below — confirm against src.utils.
tab = tab.groupby(["dataset", "Triplet loss"]).value.apply(lambda x: bootstrapping(x)).reset_index()
# Replace missing results with a [nan, nan] pair so that every entry can be
# expanded into two columns below (presumably these are group combinations
# without any data — verify).
tab.loc[tab.value.isna(), "value"] = tab.loc[tab.value.isna(), "value"].apply(lambda x: [np.nan,np.nan])
# Expand the list-valued `value` column into `value_median` / `value_std`,
# carrying `dataset` and `Triplet loss` along as regular columns.
tab = pd.DataFrame(tab["value"].to_list(), columns=['value_median', 'value_std'], index=[tab["dataset"], tab["Triplet loss"]]).reset_index()
tab

In [None]:
#tab["value_median"] = tab.groupby("dataset").value_median.transform(lambda x: (x - x.mean()))

In [None]:
# Plot, per dataset, the summarised loss deviation of the single vs. multi
# triplet-loss variants: a diamond at `value_median` with an error bar that
# spans value_median - value_std .. value_median + value_std.

# Horizontal offset between the two variants within one dataset; shared by the
# points and the error bars so that they stay aligned.
position_dodge_width = 0.5
p = (p9.ggplot(tab, p9.aes("dataset", "value_median"))
 + p9.geom_point(p9.aes(color="Triplet loss"), shape="D", size=3, position=p9.position_dodge(width=position_dodge_width))
 + p9.theme_bw()
 + scale_color_prism(palette = "colors")
 # Axis label typo fixed ("diviation" -> "deviation").
 + p9.ylab("Loss deviation")
 + p9.xlab("")
 + p9.geom_errorbar(p9.aes(x="dataset", ymin="value_median-value_std",ymax="value_median+value_std", color="Triplet loss"),
                    width=0.001, alpha=1, size=1,
                   position=p9.position_dodge(width=position_dodge_width))
 # NOTE(review): `subplots_adjust` may be deprecated in newer plotnine
 # releases — confirm against the installed version.
 + p9.theme(subplots_adjust={'wspace': 0.0}, figure_size=(6, 5), axis_text_x = p9.element_text(angle = 15, hjust=0.5))
 + p9.theme(text=p9.element_text(size=15),
            strip_text=p9.element_text(size=17),
            legend_title=p9.element_text(size=17),
            legend_text=p9.element_text(size=16))
)
# Write the figure to disk; the target directory "figures" is not created
# here, so it has to exist already.
p.save(filename = "figures/loss_ablation.png", dpi=300)
p

In [None]:
#df = pd.DataFrame(zip(multi_on, multi_off), columns=["multi", "single"]).melt(var_name="Triplet loss", value_name="Loss")
#df["Iteration"] = [*list(range(0,100)), *list(range(0,100))]
#df["Triplet loss"] = pd.Categorical(df["Triplet loss"], ["single", "multi"])
#
#p = (p9.ggplot(df, p9.aes("Iteration", "Loss", color="Triplet loss")) 
# + p9.geom_line(size=1)
# + p9.theme_bw()
# + scale_color_prism(palette = "colors")
# + p9.theme(subplots_adjust={'wspace': 0.0}, figure_size=(8, 5))
# + p9.theme(
#            text=p9.element_text(size=15),
#            strip_text=p9.element_text(size=17),
#            legend_title=p9.element_text(size=17),
#            legend_text=p9.element_text(size=16))
#)
#p.save(filename = "figures/loss_example.png", dpi=300)
#p