In [None]:
import matplotlib as mpl

# Raise the default figure resolution so inline previews match the 300-dpi exports.
mpl.rcParams.update({"figure.dpi": 300})

In [None]:
import glob
from pathlib import Path

import yaml
from matplotlib import pyplot as plt
from plotnine_prism import *  # kept before the explicit aliases below so they take precedence
import plotnine as p9
import pandas as pd
import numpy as np

In [None]:
# Load the benchmark configuration. Explicit UTF-8 so decoding does not depend
# on the platform's locale encoding.
with open("../model_and_dataset_info.yaml", "r", encoding="utf-8") as stream:
    model_and_dataset_info = yaml.safe_load(stream)

# `dataset` must be a mapping: report discovery filters run directories by its keys.
dataset_name = model_and_dataset_info["dataset"]
assert isinstance(dataset_name, dict), (
    "Expected 'dataset' in model_and_dataset_info.yaml to be a mapping, "
    f"got {type(dataset_name).__name__}"
)
dataset_name

In [None]:
# Discover the bootstrapped performance summary of every benchmark run
# (`../<dataset>/out_benchmark/summary/...`) and keep only the run directories
# listed under `dataset` in the config.
# glob returns matches in arbitrary, filesystem-dependent order; sorting makes the
# concatenated results table reproducible. Path.parts handles both "/" and "\\".
perf_report_files = sorted(
    report_path
    for report_path in glob.glob("../*/out_benchmark/summary/model_performance_bootstrapping.csv")
    if Path(report_path).parts[1] in dataset_name.keys()
)
perf_report_files

In [None]:
# Stack the per-dataset bootstrap summaries into one table.
if not perf_report_files:
    raise FileNotFoundError(
        "No model_performance_bootstrapping.csv reports matched the configured datasets"
    )

data = pd.concat(
    (pd.read_csv(report_path) for report_path in perf_report_files),
    ignore_index=True,  # every CSV restarts its index at 0; keep row labels unique
)
# Capitalised copy used as the plot's facet/axis label.
data["Dataset"] = data["dataset"]
data = data.astype({"ARI_median": float, "ARI_std": float})
data

In [None]:
# Datasets (in facet order) and methods shown in the ARI benchmark figure.
datasets = ['LIBD Human DLPFC', 'Human Breast Cancer', 'Tumor Profiler', 'Human Liver Normal', 'Human Liver Cancer']
models = ['SpaGCN', 'BayesSpace', 'AESTETIK', 'GraphST', 'STAGATE', 'MUSE', 'Leiden', 'stLearn']
# Input-modality levels in legend order.
modalities = ['transcriptomics',
              'transcriptomics + spatial',
              'transcriptomics + image',
              'transcriptomics + spatial + image']

# Horizontal dodge shared by points and error bars so they line up per model.
position_dodge_width = 0.8

# Single filter pass (the dataset condition used to be applied twice).
tab = data.query("model in @models and dataset in @datasets").copy()
assert not tab.empty, "No rows left after filtering to the selected models/datasets"

tab["Dataset"] = pd.Categorical(tab["Dataset"], datasets)

# Rank models within each dataset by median ARI (1 = best; ties get the average
# rank), then order the model levels by their median rank across datasets.
# A stable sort keeps tied models in groupby's (alphabetical) key order.
tab["model_rank"] = tab.groupby("dataset")["ARI_median"].rank(ascending=False)
model_order = tab.groupby("model")["model_rank"].median().sort_values(kind="stable").index
tab["model"] = pd.Categorical(tab["model"], model_order)
tab["Model"] = tab["model"]

# Labels missing from `modalities` would silently become NaN in the categorical.
unknown_modalities = set(tab["modality"].dropna()) - set(modalities)
assert not unknown_modalities, f"Unexpected modality labels: {sorted(unknown_modalities)}"
tab["Modality"] = pd.Categorical(tab["modality"], modalities)

In [None]:
# Median ARI (point) ± one bootstrap standard deviation (error bar) per model,
# one facet per dataset; colour encodes the model and shape the input modality.

# Shared y-axis breaks every 0.1 ARI. The bounds use floor/ceil rather than int():
# int() truncates toward zero and dropped every break below 0 for negative ARI.
# The error-bar ends are included in the range so they also get breaks.
# Building the breaks from integers avoids accumulated float error
# (np.arange(0, 1, 0.1)[3] == 0.30000000000000004).
ari_extent = pd.concat([tab["ARI_median"],
                        tab["ARI_median"] - tab["ARI_std"],
                        tab["ARI_median"] + tab["ARI_std"]])
ari_break_low = int(np.floor(ari_extent.min()))
ari_break_high = int(np.ceil(ari_extent.max()))
ari_breaks = (np.arange(ari_break_low * 10, ari_break_high * 10 + 1) / 10).tolist()

p = (p9.ggplot(tab, p9.aes("Dataset", "ARI_median"))
 + p9.geom_point(p9.aes(color="Model", shape="Modality"), size=3,
                 position=p9.position_dodge(width=position_dodge_width))
 + p9.facet_wrap("~Dataset", scales="free", ncol=5)
 + p9.geom_errorbar(p9.aes(ymin="ARI_median-ARI_std", ymax="ARI_median+ARI_std", color="Model"),
                    width=0.001, alpha=1, size=1,
                    position=p9.position_dodge(width=position_dodge_width))
 + p9.theme_bw()
 + p9.theme(subplots_adjust={'wspace': 0.15}, figure_size=(20, 5), axis_text_x=p9.element_blank(),
            legend_position="right",
            text=p9.element_text(size=15),
            strip_text=p9.element_text(size=17),
            legend_title=p9.element_text(size=17),
            legend_text=p9.element_text(size=16))
 + p9.ylab("ARI")
 + p9.xlab("")
 + scale_color_prism(palette="colors")
 # Colour legend shows plain dots; the shape legend carries the modality.
 + p9.guides(color=p9.guide_legend(nrow=4, override_aes={"shape": "."}))
 + p9.scale_y_continuous(breaks=ari_breaks)
)

# Create the output folder first; saving into a missing directory raises.
figure_path = Path("figures") / "ari_performance.png"
figure_path.parent.mkdir(parents=True, exist_ok=True)
p.save(str(figure_path), dpi=300)
p