In [None]:
import pandas as pd
import numpy as np
import plotnine as p9
import yaml
import glob
from plotnine_prism import *

In [None]:
# Project configuration; DATASET_INFO["DATASET_NAME"] maps "<dataset folder>_<slide type>"
# keys to the display names used throughout the figures below.
# YAML is UTF-8 by spec, so decode it explicitly instead of using the platform's
# locale-dependent default encoding.
with open("../config.yaml", "r", encoding="utf-8") as stream:
    DATASET_INFO = yaml.safe_load(stream)
DATASET_INFO

In [None]:
# Every survival-prediction CSV produced by the benchmark runs, excluding COAD.
# glob returns paths in arbitrary, filesystem-dependent order; sort them so the row
# order of the concatenated table (and anything order-dependent downstream) is
# reproducible across machines.
survival_analysis_paths = sorted(
    f for f in glob.glob("../*/out_benchmark*/prediction/survival*.csv") if "COAD" not in f
)
survival_analysis_paths

In [None]:
# Load every survival CSV and tag its rows with the dataset display name.
# Paths follow the glob pattern ../<dataset folder>/out_benchmark<...>_<slide type>/prediction/<file>.csv,
# so part 1 is the dataset folder and the last "_" token of part 2 is the slide type.
survival_tables = []
for csv_path in survival_analysis_paths:
    path_parts = csv_path.split("/")
    dataset_key = f"{path_parts[1]}_{path_parts[2].split('_')[-1]}"
    survival_table = pd.read_csv(csv_path)
    survival_table["Dataset"] = DATASET_INFO["DATASET_NAME"][dataset_key]
    survival_tables.append(survival_table)

data_survival = pd.concat(survival_tables)
# Fix category orders: datasets as listed in the config, training sizes ascending.
data_survival["Dataset"] = pd.Categorical(data_survival["Dataset"], DATASET_INFO["DATASET_NAME"].values())
data_survival["n_patients"] = pd.Categorical(
    data_survival["n_patients"], sorted(data_survival["n_patients"].unique())
)
data_survival

In [None]:
# Figure 6C input: survival C-index at the 125-patient training size.
tab = data_survival.copy().query("n_patients == 125")
# Any method whose name mentions "bulk" is reported under a single "bulk RNA" label.
tab.method = tab.method.apply(lambda name: "bulk RNA" if "bulk" in name else name)

# Split "<cancer> (<slide type>)" display names; bulk rows get their own "bulk\nRNA" type.
tab["Dataset_name"] = pd.Categorical(
    tab["Dataset"].apply(lambda label: label.split(" (")[0]),
    ["TCGA SKCM", "TCGA KIRC"],
)
tab["Dataset_type"] = tab.apply(
    lambda row: "bulk\nRNA" if "bulk" in row.method else row.Dataset.split(" (")[1].replace(")", ""),
    axis=1,
)

position_dodge_width = 0.7
tab["method_rank"] = tab.groupby("Dataset", observed=False)["c_mean"].rank(ascending=False)
tab["Method"] = pd.Categorical(tab["method"], ["DeepSpot", "STNet", "BLEEP", "bulk RNA"])

# Fresh unique RangeIndex (the concatenated index repeats) for label-based drops later.
tab = tab.reset_index()
# Helper used per Dataset_name group to de-duplicate bulk-RNA rows.
def drop_second_bulk(group):
    r"""Return ``group`` without its second bulk-RNA row.

    A row counts as bulk when its ``Dataset_type`` equals ``"bulk\nRNA"``.
    Only the second such row (in order of appearance) is dropped, by index
    label; a third or later bulk row would be kept. When fewer than two bulk
    rows exist, the very same ``group`` object is returned unchanged.
    """
    bulk_labels = group.index[group["Dataset_type"] == "bulk\nRNA"]
    if len(bulk_labels) < 2:
        return group
    return group.drop(bulk_labels[1])

# Keep at most one bulk-RNA bar per cancer type. Bulk rows are presumably present once
# per slide-type folder (FF/FFPE) of a cancer, since every folder's CSV is loaded.
#
# Groups are iterated explicitly instead of via `groupby(...).apply(...)`: that call
# warns about the categorical `observed` default and about operating on the grouping
# column, which pandas is deprecating (future versions exclude `Dataset_name` from the
# applied frame, which would break the facet below). Rows and their order are the same
# as the previous apply + reset_index(drop=True); NaN Dataset_name rows are still dropped.
tab = pd.concat(
    [drop_second_bulk(group) for _, group in tab.groupby("Dataset_name", observed=True)],
    ignore_index=True,
)
tab["Dataset_type"] = pd.Categorical(tab["Dataset_type"], ["bulk\nRNA", "FF", "FFPE"])

tab["Method_Dataset_type"] = tab.apply(lambda x: f"{x.Method}_{x.Dataset_type}", axis=1)
tab["Dataset_type_Method"] = tab.apply(lambda x: f"{x.Dataset_type}_{x.Method}", axis=1)

# Bars and error bars use the same dodge settings so each error bar sits on its bar.
g = (
    p9.ggplot(tab, p9.aes("Dataset_type", "c_mean", fill="Method", group="Method"))
    + p9.geom_col(
        width=0.7,
        position=p9.position_dodge(width=0.8, preserve="single"),
        size=0.7,
    )
    + p9.facet_wrap("~Dataset_name", ncol=2, scales="free_y")
    + p9.geom_errorbar(
        p9.aes(
            x="Dataset_type", ymin="c_mean-c_std", ymax="c_mean+c_std", group="Method"
        ),
        width=0.2, alpha=1, size=0.5, color="black",
        position=p9.position_dodge(width=0.8, preserve="single"),
    )
    + p9.theme_bw()
    + p9.theme(
        panel_spacing_y=0,
        panel_spacing_x=0,
        figure_size=(7, 4.5),
        legend_position="right",
        text=p9.element_text(size=17),
        strip_text=p9.element_text(size=17),
        legend_title=p9.element_text(size=17),
        legend_text=p9.element_text(size=16),
        plot_title=p9.element_text(hjust=0.5, size=17),
        axis_text_x=p9.element_text(angle=90, hjust=0.5),
    )
    + p9.ylab("C-index")
    + p9.xlab("Data type")
    + p9.ggtitle("Survival analysis")
    # Zoom the y-axis to just below the lowest bar so differences are visible.
    + p9.coord_cartesian(ylim=(tab.c_mean.min()-0.01, None))
    + scale_fill_prism(palette="colors")
)
g.save("figures/Figure6C-TCGA_C_index.png", dpi=300)
g

In [None]:
# Figure S5: C-index of every non-bulk method for each training size (n_patients).
tab = data_survival.copy()
# Exclude bulk-RNA rows with the same rule the Figure 6C cell uses (method name contains
# "bulk") rather than one exact spelling. The explicit .copy() avoids
# SettingWithCopyWarning on the column assignments that follow.
tab = tab[~tab.method.str.contains("bulk", regex=False, na=False)].copy()
tab["Dataset_name"] = tab["Dataset"].apply(lambda x: x.split(" (")[0])
tab["Dataset_type"] = tab["Dataset"].apply(lambda x: x.split(" (")[1].replace(")", ""))
position_dodge_width = 0.4
tab["method_rank"] = tab.groupby("Dataset", observed=False).c_mean.rank(ascending=False)
# Legend/colour order: methods sorted by median rank across datasets (rank 1 = highest C-index).
tab["Method"] = pd.Categorical(tab.method, tab.groupby("method").method_rank.agg("median").sort_values().index)
tab["Method_Dataset_type"] = tab.apply(lambda x: f"{x.Method}_{x.Dataset_type}", axis=1)
g = (p9.ggplot(tab, p9.aes("n_patients", "c_mean", color="Method"))
 + p9.geom_line(p9.aes(color="Method", group="Method_Dataset_type"), linetype="dashed",
                position=p9.position_dodge(width=position_dodge_width))
 + p9.geom_point(p9.aes(color="Method", group="Method_Dataset_type"), position=p9.position_dodge(width=position_dodge_width), size=0.7)
 + p9.facet_wrap("~Dataset_name+Dataset_type", ncol=2, scales="free_y")
 + p9.geom_errorbar(p9.aes(x="n_patients", ymin="c_mean-c_std",
                           ymax="c_mean+c_std", color="Method", group="Method_Dataset_type"),
                    width=0.4, alpha=1, size=0.5,
                    position=p9.position_dodge(width=position_dodge_width))
 + p9.theme_bw()
 + p9.theme(panel_spacing_y=0, panel_spacing_x=0, figure_size=(10, 8.),
            legend_position="right",
            text=p9.element_text(size=17),
            strip_text=p9.element_text(size=17),
            legend_title=p9.element_text(size=17),
            legend_text=p9.element_text(size=16),
            plot_title = p9.element_text(hjust = 0.5, size=17))
 + p9.ylab("C-index")
 + p9.xlab("Training #patients")
 + p9.ggtitle("Survival analysis")
 + p9.theme(axis_text_x = p9.element_text(angle = 90, hjust = 1),
            )
 + scale_color_prism(palette = "colors")
 + p9.guides(color=p9.guide_legend(nrow=10, override_aes = p9.aes(shape = ".")))
)
g.save("figures/FigureS5-TCGA_C_index.png", dpi=300)
g

In [None]:
# Tumor-type (metastatic vs primary) prediction CSVs for TCGA SKCM.
# Sorted because glob's order depends on the filesystem; this keeps the concatenated
# table's row order, and the later bulk de-duplication, reproducible.
tumor_type_paths = sorted(
    f
    for f in glob.glob("../TCGA_SKCM/out_benchmark*/prediction/tumor_type_prediction.csv")
    if "KIRC" not in f and "COAD" not in f
)
tumor_type_paths

In [None]:
# Load the tumor-type (metastatic vs primary) predictions and tag each row with its
# dataset display name. The misspelled name `tumpo_type` is kept because the next
# cell reads it.
tumor_tables = []
for csv_path in tumor_type_paths:
    path_parts = csv_path.split("/")
    dataset_key = f"{path_parts[1]}_{path_parts[2].split('_')[-1]}"
    tumor_table = pd.read_csv(csv_path)
    tumor_table["Dataset"] = DATASET_INFO["DATASET_NAME"][dataset_key]
    tumor_tables.append(tumor_table)

tumpo_type = pd.concat(tumor_tables)
# Category orders: datasets as listed in the config, training sizes ascending.
tumpo_type["Dataset"] = pd.Categorical(tumpo_type["Dataset"], DATASET_INFO["DATASET_NAME"].values())
tumpo_type["n_patients"] = pd.Categorical(
    tumpo_type["n_patients"], sorted(tumpo_type["n_patients"].unique())
)
tumpo_type

In [None]:
# Figure 6D input: split "<cancer> (<slide type>)" display names; bulk rows get their
# own "bulk\nRNA" type regardless of which slide-type folder they were loaded from.
tab = tumpo_type.copy()
tab["Dataset_name"] = pd.Categorical(
    tab["Dataset"].apply(lambda label: label.split(" (")[0]),
    ["TCGA SKCM"],
)
tab["Dataset_type"] = tab.apply(
    lambda row: "bulk\nRNA" if "bulk" in row.method else row.Dataset.split(" (")[1].replace(")", ""),
    axis=1,
)

position_dodge_width = 0.3  # also used by the line plot in the next cell
tab["method_rank"] = tab.groupby("Dataset", observed=False)["f1_mean"].rank(ascending=False)
# Any method whose name mentions "bulk" is reported under a single "bulk RNA" label.
tab["method"] = tab["method"].apply(lambda name: "bulk RNA" if "bulk" in name else name)
tab["Method"] = pd.Categorical(tab["method"], ["DeepSpot", "STNet", "BLEEP", "bulk RNA"])

# Fresh unique RangeIndex (the concatenated index repeats) for label-based drops later.
tab = tab.reset_index()
# `drop_second_bulk` is already defined, identically, in the Figure 6C cell above.
# Redefining it here would silently shadow that definition, so it is reused instead.

# Keep at most one bulk-RNA bar per cancer type (drop_second_bulk is applied per
# Dataset_name group).
#
# Groups are iterated explicitly instead of via `groupby(...).apply(...)`: that call
# warns about the categorical `observed` default and about operating on the grouping
# column, which pandas is deprecating (future versions exclude `Dataset_name` from the
# applied frame, which would break the facet below). Rows and their order are the same
# as the previous apply + reset_index(drop=True).
tab = pd.concat(
    [drop_second_bulk(group) for _, group in tab.groupby("Dataset_name", observed=True)],
    ignore_index=True,
)
tab["Dataset_type"] = pd.Categorical(tab["Dataset_type"], ["bulk\nRNA", "FF", "FFPE"])
tab["Method_Dataset_type"] = tab.apply(lambda x: f"{x.Method}_{x.Dataset_type}", axis=1)
tab["Dataset_type_Method"] = tab.apply(lambda x: f"{x.Dataset_type}_{x.Method}", axis=1)

# NOTE(review): unlike the survival bar chart (n_patients == 125), `tab` is not filtered
# to one training size here, and the next cell plots several n_patients values from it.
# If several sizes are present, their bars/error bars are drawn on top of each other —
# confirm this is intended.
g = (
    p9.ggplot(tab, p9.aes("Dataset_type", "f1_mean", fill="Method", group="Method"))
    + p9.geom_col(
        width=0.7,
        position=p9.position_dodge(width=0.8, preserve="single"),
        size=0.7,
    )
    + p9.facet_wrap("~Dataset_name", ncol=1, scales="free_y")
    + p9.geom_errorbar(
        p9.aes(
            x="Dataset_type", ymin="f1_mean-f1_std", ymax="f1_mean+f1_std", group="Method"
        ),
        width=0.2, alpha=1, size=0.5, color="black",
        position=p9.position_dodge(width=0.8, preserve="single"),
    )
    + p9.theme_bw()
    + p9.theme(
        panel_spacing_y=0,
        panel_spacing_x=0,
        figure_size=(4.5, 4.5),
        legend_position="right",
        text=p9.element_text(size=17),
        strip_text=p9.element_text(size=17),
        legend_title=p9.element_text(size=17),
        legend_text=p9.element_text(size=16),
        plot_title=p9.element_text(hjust=0.5, size=17),
        axis_text_x=p9.element_text(angle=90, hjust=0.5),
    )
    + p9.ylab("F1 score")
    + p9.xlab("Data type")
    + p9.ggtitle("Metastatic vs Primary tumor")
    + scale_fill_prism(palette="colors")
    # Hand-picked lower y limit to zoom in on the bars.
    + p9.coord_cartesian(ylim=(0.72, None))
)
g.save("figures/Figure6D-TCGA_F1.png", dpi=300)
g

In [None]:
# Figure S6: F1 of the non-bulk methods for each training size, built from the
# tumor-type table of the previous cell (which also sets position_dodge_width).
# .copy() makes the filtered frame independent, so the column assignments below do not
# trigger SettingWithCopyWarning.
tab = tab[tab.method != "bulk RNA"].copy()
tab["method_rank"] = tab.groupby("Dataset", observed=False).f1_mean.rank(ascending=False)
tab["Method"] = pd.Categorical(tab.method, ["DeepSpot", "STNet", "BLEEP"])

g = (p9.ggplot(tab, p9.aes("n_patients", "f1_mean", color="Method"))
 + p9.geom_line(p9.aes(color="Method", group="Method"), linetype="dashed",
                position=p9.position_dodge(width=position_dodge_width))
 + p9.geom_point(p9.aes(color="Method"), position=p9.position_dodge(width=position_dodge_width), size=0.7)
 + p9.facet_wrap("~Dataset", scales="free_y", ncol=2)
 + p9.geom_errorbar(p9.aes(x="n_patients", ymin="f1_mean-f1_std",
                           ymax="f1_mean+f1_std", color="Method"),
                    width=0.4, alpha=1, size=0.5,
                    position=p9.position_dodge(width=position_dodge_width))
 + p9.theme_bw()
 + p9.theme(panel_spacing_y=0, panel_spacing_x=0, figure_size=(10, 5),
            legend_position="right",
            text=p9.element_text(size=17),
            strip_text=p9.element_text(size=17),
            legend_title=p9.element_text(size=17),
            legend_text=p9.element_text(size=16))
 + p9.ylab("F1 score")
 + p9.xlab("Training #patients")
 + p9.ggtitle("Metastatic vs Primary tumor")
 + p9.theme(axis_text_x = p9.element_text(angle = 90, hjust = 1),
            plot_title = p9.element_text(hjust = 0.5))
 + scale_color_prism(palette = "colors")
 + p9.guides(color=p9.guide_legend(nrow=10, override_aes = p9.aes(shape = ".")))
)
g.save("figures/FigureS6-TCGA_F1_score.png", dpi=300)
g