In [None]:
import glob
from pathlib import Path

import numpy as np
import pandas as pd
import plotnine as p9
import yaml
from plotnine_prism import *

In [None]:
# Load the shared analysis config. Later cells use DATASET_INFO["DATASET_NAME"] as a
# mapping from dataset directory name to the display label used in tables and figures.
# YAML is UTF-8 by spec; pass the encoding explicitly so the result does not depend on
# the platform's default locale encoding (e.g. cp1252 on Windows).
with open("../config.yaml", encoding="utf-8") as stream:
    DATASET_INFO = yaml.safe_load(stream)
DATASET_INFO

In [None]:
# Benchmark outputs, at most one per dataset directory:
#   ../<dataset>/out_benchmark/evaluation/pearson_variation.csv
# glob returns matches in a filesystem-dependent order; sort them so the concatenated
# table (and everything derived from it) is reproducible across machines.
pearson_variation_paths = sorted(glob.glob("../*/out_benchmark/evaluation/pearson_variation.csv"))
pearson_variation_paths


In [None]:
# Combine the per-dataset Pearson-variation tables into one frame. Each CSV lives at
# ../<dataset>/out_benchmark/evaluation/pearson_variation.csv. Only datasets listed
# under DATASET_NAME in the config are kept, each labelled with its configured display name.
dataset_names = DATASET_INFO["DATASET_NAME"]
pearson_frames = []
# Sort because glob order is filesystem-dependent; this keeps row order reproducible.
for file in sorted(pearson_variation_paths):
    # The dataset directory is two levels above the CSV. pathlib understands the native
    # path separator, unlike splitting the string on "/".
    dataset = Path(file).parents[2].name
    if dataset not in dataset_names:
        continue
    tab = pd.read_csv(file)
    tab["Dataset"] = dataset_names[dataset]
    pearson_frames.append(tab)

if not pearson_frames:
    # pd.concat([]) would fail with an unhelpful "No objects to concatenate".
    raise ValueError(
        "No pearson_variation.csv found for any dataset listed under DATASET_NAME "
        "in ../config.yaml; check the working directory and benchmark outputs."
    )

# ignore_index: the per-file indices overlap, so give the combined frame unique labels.
data_var = pd.concat(pearson_frames, ignore_index=True)
# Fix the dataset (facet) order to the order of the DATASET_NAME values in the config.
data_var["Dataset"] = pd.Categorical(data_var["Dataset"], categories=list(dataset_names.values()))
data_var = data_var.astype({"pearson_median": float, "pearson_std": float})
data_var

In [None]:
# Tumor Profiler rows at the 500 most-variable genes, ascending by pearson_median.
(
    data_var
    .loc[lambda frame: (frame["Dataset"] == "Tumor Profiler (n=18)") & frame["top_n"].isin([500])]
    .sort_values("pearson_median")
)

In [None]:
# Every dataset at the 3000 most-variable genes, ascending by pearson_median.
(
    data_var
    .loc[data_var["top_n"].isin([3000])]
    .sort_values("pearson_median")
)

In [None]:
# Figure 3A: median Pearson correlation (error bars: +/- pearson_std) of every model
# against the number of most-variable genes, one free-y panel per dataset.
tab = data_var.copy()
position_dodge_width = 0.5  # shared horizontal offset so lines, points and error bars align

# Rank models within each dataset (rank 1 = highest pearson_median), then order the
# legend by each model's median rank.
# NOTE(review): the rank is computed over all rows of a dataset, i.e. pooled across
# every top_n value rather than per (Dataset, top_n) -- confirm this is intended.
tab["model_rank"] = tab.groupby("Dataset", observed=False)["pearson_median"].rank(ascending=False)
model_order = (
    tab.groupby("model")["model_rank"]
    .median()
    # Stable sort: models with equal median rank keep groupby's sorted-name order
    # instead of an arbitrary one.
    .sort_values(kind="stable")
    .index
)
tab["Model"] = pd.Categorical(tab["model"], categories=model_order)
tab["top_n"] = tab["top_n"].astype("category")  # discrete x axis: one tick per top_n value

g = (
    p9.ggplot(tab, p9.aes("top_n", "pearson_median", color="Model"))
    + p9.geom_line(p9.aes(group="Model"), linetype="dashed",
                   position=p9.position_dodge(width=position_dodge_width))
    + p9.geom_point(position=p9.position_dodge(width=position_dodge_width), size=0.7)
    + p9.facet_wrap("~Dataset", scales="free_y", ncol=2)
    + p9.geom_errorbar(p9.aes(ymin="pearson_median-pearson_std",
                              ymax="pearson_median+pearson_std"),
                       width=0.4, alpha=1, size=0.5,
                       position=p9.position_dodge(width=position_dodge_width))
    + p9.theme_bw()
    + p9.theme(panel_spacing_y=0, panel_spacing_x=0, figure_size=(14, 6),
               legend_position="right",
               text=p9.element_text(size=17),
               strip_text=p9.element_text(size=17),
               legend_title=p9.element_text(size=17),
               legend_text=p9.element_text(size=16),
               axis_text_x=p9.element_text(angle=90, hjust=1))
    + p9.ylab("Pearson correlation")
    + p9.xlab("Most variable genes")
    + scale_color_prism(palette="colors")
    + p9.guides(color=p9.guide_legend(nrow=10, override_aes=p9.aes(shape=".")))
)

# Saving does not create missing directories, so a fresh checkout without figures/
# would fail here with FileNotFoundError.
Path("figures").mkdir(exist_ok=True)
g.save("figures/Figure3A-benchmark.png", dpi=300)
g