In [None]:
import glob
from pathlib import Path

import numpy as np
import pandas as pd
import plotnine as p9
import yaml
from plotnine_prism import *

In [None]:
# Read the project-level YAML configuration; its "DATASET_NAME" mapping is used
# further down to turn directory names into display labels.
with open("../config.yaml") as config_file:
    DATASET_INFO = yaml.safe_load(config_file)
DATASET_INFO

In [None]:
# Collect every benchmark evaluation table that sits next to this notebook's
# parent directory, skipping any path that mentions "COAD".
# glob returns matches in arbitrary (filesystem-dependent) order, so the list is
# sorted to make the loading order - and everything derived from it - reproducible.
pearson_variation_paths = sorted(
    path
    for path in glob.glob("../*/out_benchmark*/prediction/model_evaluation_table.csv")
    if "COAD" not in path
)
pearson_variation_paths

In [None]:
# Load every evaluation table and tag its rows with the dataset display label
# looked up in DATASET_INFO["DATASET_NAME"].
if not pearson_variation_paths:
    # pd.concat would otherwise fail with the much less helpful
    # "No objects to concatenate".
    raise ValueError(
        "No model_evaluation_table.csv files were found; "
        "check the working directory and the glob pattern."
    )

dataset_labels = DATASET_INFO["DATASET_NAME"]
tables = []
for file in pearson_variation_paths:
    # Paths have the layout ../<dataset>/out_benchmark..._<slide type>/prediction/<table>.csv.
    # Path.parts splits on the platform's separator, unlike str.split("/").
    path_parts = Path(file).parts
    dataset_name = path_parts[1]
    slide_type = path_parts[2].split("_")[-1]
    label = dataset_labels[f"{dataset_name}_{slide_type}"]
    tables.append(pd.read_csv(file).assign(Dataset=label))

# ignore_index gives the combined frame unique row labels instead of repeating
# 0..n-1 once per input file.
data_var = pd.concat(tables, ignore_index=True)
# Category order follows the order of the entries in the configuration file.
data_var["Dataset"] = pd.Categorical(
    data_var["Dataset"], categories=list(dataset_labels.values())
)
data_var = data_var.astype({"pearson_median": float, "pearson_std": float})
data_var

In [None]:
# Keep every row except those belonging to the two excluded models.
excluded_models = ["LinearRegression", "MLP"]
is_excluded = data_var["model"].isin(excluded_models)
data_var = data_var.loc[~is_excluded]

In [None]:
# Figure 6B: pearson_median (error bars: +/- pearson_std) per method against
# top_n, with one panel per dataset.
tab = data_var.copy()
position_dodge_width = 0.4

# Rank all rows within each dataset (rank 1 = highest pearson_median), then order
# the methods by their median rank so the legend lists the best-ranked first.
tab["model_rank"] = tab.groupby("Dataset", observed=False)["pearson_median"].rank(ascending=False)
method_order = tab.groupby("model")["model_rank"].median().sort_values().index
tab["Method"] = pd.Categorical(tab["model"], categories=method_order)
# A categorical top_n gives a discrete, evenly spaced x axis.
tab["top_n"] = tab["top_n"].astype("category")

# x, y and color are set once in the plot-level mapping and inherited by every
# layer; the layers only add what is specific to them.
g = (
    p9.ggplot(tab, p9.aes("top_n", "pearson_median", color="Method"))
    + p9.geom_line(
        p9.aes(group="Method"),
        linetype="dashed",
        size=1,
        position=p9.position_dodge(width=position_dodge_width),
    )
    + p9.geom_point(
        position=p9.position_dodge(width=position_dodge_width),
        size=0.7,
    )
    + p9.facet_wrap("~Dataset", scales="free_y", ncol=2)
    + p9.geom_errorbar(
        p9.aes(ymin="pearson_median-pearson_std", ymax="pearson_median+pearson_std"),
        width=0.4,
        alpha=1,
        size=0.5,
        position=p9.position_dodge(width=position_dodge_width),
    )
    + p9.theme_bw()
    + p9.theme(
        panel_spacing_y=0,
        panel_spacing_x=0,
        figure_size=(10, 5),
        legend_position="right",
        text=p9.element_text(size=17),
        strip_text=p9.element_text(size=17),
        legend_title=p9.element_text(size=17),
        legend_text=p9.element_text(size=16),
        axis_text_x=p9.element_text(angle = 90, hjust = 1),
    )
    + p9.ylab("Pearson correlation")
    + p9.xlab("Most variable genes")
    + scale_color_prism(palette = "colors")
    + p9.guides(color=p9.guide_legend(nrow=10, override_aes = p9.aes(shape = ".")))
)

# Saving fails if the target directory is missing (e.g. on a fresh checkout).
Path("figures").mkdir(parents=True, exist_ok=True)
g.save("figures/Figure6B-TCGA_benchmark.png", dpi=300)
g

In [None]:
# Variant of Figure 6B: one panel per cohort, with the data type encoded as the
# line type instead of getting its own panel.
tab = data_var.copy()
position_dodge_width = 0.2

# Rank all rows within each dataset (rank 1 = highest pearson_median), then order
# the methods by their median rank so the legend lists the best-ranked first.
tab["model_rank"] = tab.groupby("Dataset", observed=False)["pearson_median"].rank(ascending=False)
method_order = tab.groupby("model")["model_rank"].median().sort_values().index
tab["Method"] = pd.Categorical(tab["model"], categories=method_order)

# The code below assumes dataset labels of the form "<cohort> (<data type>)".
# The part before " (" becomes the panel; only the two cohorts listed here are
# kept as categories, any other value turns into NaN.
tab["Dataset_name"] = tab["Dataset"].apply(lambda label: label.split(" (")[0])
tab["Dataset_name"] = pd.Categorical(tab["Dataset_name"], ["TCGA SKCM", "TCGA KIRC"])
# Rows whose model name contains "bulk" are labelled "bulk\nRNA"; all other rows
# take the text between the parentheses of their dataset label.
tab["Dataset_type"] = tab.apply(
    lambda row: "bulk\nRNA"
    if "bulk" in row["model"]
    else row["Dataset"].split(" (")[1].replace(")", ""),
    axis=1,
)
tab["Data type"] = tab["Dataset_type"]

# Combined keys: one line / dodge group per (method, data type) pair.
# "Dataset_type_Method" is not referenced by the figure in this cell.
method_names = tab["Method"].astype(str)
tab["Method_Dataset_type"] = method_names + "_" + tab["Dataset_type"]
tab["Dataset_type_Method"] = tab["Dataset_type"] + "_" + method_names

# A categorical top_n gives a discrete, evenly spaced x axis.
tab["top_n"] = tab["top_n"].astype("category")

# x, y and color are set once in the plot-level mapping and inherited by every
# layer; the layers only add what is specific to them.
g = (
    p9.ggplot(tab, p9.aes("top_n", "pearson_median", color="Method"))
    + p9.geom_line(
        p9.aes(group="Method_Dataset_type", linetype="Data type"),
        position=p9.position_dodge(width=position_dodge_width),
    )
    + p9.geom_point(
        p9.aes(group="Method_Dataset_type"),
        position=p9.position_dodge(width=position_dodge_width),
        size=0.7,
    )
    + p9.facet_wrap("~Dataset_name", scales="free_y", ncol=2)
    + p9.geom_errorbar(
        p9.aes(
            ymin="pearson_median-pearson_std",
            ymax="pearson_median+pearson_std",
            group="Method_Dataset_type",
        ),
        width=0.4,
        alpha=1,
        size=0.5,
        position=p9.position_dodge(width=position_dodge_width),
    )
    + p9.theme_bw()
    + p9.theme(
        panel_spacing_y=0,
        panel_spacing_x=0,
        figure_size=(10, 4),
        legend_position="right",
        text=p9.element_text(size=17),
        strip_text=p9.element_text(size=17),
        legend_title=p9.element_text(size=17),
        legend_text=p9.element_text(size=16),
        axis_text_x=p9.element_text(angle = 90, hjust = 1),
    )
    + p9.ylab("Pearson correlation")
    + p9.xlab("Most variable genes")
    + scale_color_prism(palette = "colors")
    + p9.guides(color=p9.guide_legend(nrow=10, override_aes = p9.aes(shape = ".")))
)

# Saving fails if the target directory is missing (e.g. on a fresh checkout).
Path("figures").mkdir(parents=True, exist_ok=True)
g.save("figures/Figure6B-TCGA_benchmark_linetype.png", dpi=300)
g