# Common trunk: Violin Plots and Precision-Recall Curves

## Packages and functions

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.patches as patches
import matplotlib.colors as mcolors
import os
import time
from scipy.stats import sem, t
from sklearn.metrics import precision_recall_curve, auc, PrecisionRecallDisplay
import matplotlib.font_manager as fm

#######################################################################################################################################
# Parameters for the plotting: one known-installed sans-serif font and uniform 16 pt text
# (figure titles slightly smaller at 14 pt).
plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['DejaVu Sans'],  # change font to a known standard font
    'font.size': 16,
    'axes.labelsize': 16,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
    'legend.fontsize': 16,
    'figure.titlesize': 14,
})



#######################################################################################################################################

def plot_colorbar(scenario_colors, alpha=1.0):
    """
    Plot a row of labelled colour swatches, one per entry of ``scenario_colors``.

    Swatches are drawn left-to-right in dict order, each with its label underneath.
    The figure width grows with the number of entries: five entries give the original
    5.5 x 1.1 inch figure (about 275 x 55 pixels at 50 dpi).

    Parameters:
    - scenario_colors (dict): Maps a label to any matplotlib colour
      (hex string such as '#36617B' or an RGB/RGBA tuple).
    - alpha (float): Opacity applied to every swatch. At the default 1.0 the colours
      are drawn exactly as given, including any alpha channel they carry.

    Example:
    plot_colorbar({
        'A': '#36617B',  # Color for scenario A
        'B': '#4993AA',  # Color for scenario B
        'C': '#e3c983',  # Color for scenario C
        'D': '#ECB354',  # Color for scenario D
        'E': '#B95224',  # Color for scenario E
    })
    """
    swatch = 55            # swatch edge length in data units
    swatch_inches = 1.1    # figure inches per swatch
    # Keep at least one slot so an empty dict still yields a valid (non-zero) canvas.
    n_slots = max(len(scenario_colors), 1)
    # Only override alpha when asked, so colours that carry their own alpha are untouched.
    patch_alpha = None if alpha == 1.0 else alpha

    fig, ax = plt.subplots(figsize=(swatch_inches * n_slots, swatch_inches))
    for i, (label, color) in enumerate(scenario_colors.items()):
        x0 = i * swatch
        rect = patches.Rectangle((x0, 0), swatch, swatch, linewidth=2, edgecolor='white',
                                 facecolor=color, alpha=patch_alpha)
        ax.add_patch(rect)
        # Label centred under the swatch.
        ax.text(x0 + swatch / 2, -10, label, ha='center', va='top', fontsize=10, color='black')

    # Size the x-range to the number of swatches instead of a fixed five.
    ax.set_xlim(0, swatch * n_slots)
    ax.set_ylim(-20, swatch)
    ax.axis('off')
    plt.show()

# Seaborn theme: poster-sized plotting context on a white background, "bright" colour cycle.
# NOTE(review): set_theme() rewrites matplotlib rcParams (font family and font sizes included),
# so the plt.rcParams font settings made earlier in this cell are overridden here. Move those
# assignments below this call (or pass them via rc=...) if they are meant to take effect.
sns.set_theme(style='white', context='poster')
sns.set_palette(palette='bright')

###########################################################################################################################################################################################################################################

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.patches as patches
import os
from scipy.stats import sem, t

#################################################################################################################################
########################################    Violin Plots    ###############################################################################
###################################################################################################################################
def create_violin_plots(dataframes, keys_ordered, color_dict, title_dict,
                        display, title_display, gap, inner_detail="quart",
                        highlight_column=None, highlight_color="grey", highlight_marker="o",
                        highlight_size=0.5, alpha=0.7,
                        n_rows=1, n_cols=5, split_by_sex=False, truth="status",
                        font_size=24, fig_path=None, hue_column=None):
    """
    Create a grid of split violin plots of predicted probabilities, one subplot per key.

    Each subplot shows ``y_pred`` split into two half-violins by the binary outcome column
    ``truth``. The lighter half is 0/False (controls) and the darker half is 1/True (cases).
    Rows flagged in ``highlight_column`` can be overlaid as jittered markers. The figure can
    also be saved as SVG.

    Parameters:
    - dataframes (dict): DataFrames keyed by identifier. Each plotted frame needs ``y_pred``
      and the ``truth`` column, plus ``SEX`` when ``split_by_sex`` is True.
    - keys_ordered (list): Keys to plot, in subplot (row-major) order. Must not be empty.
    - color_dict (dict): Maps each key to a matplotlib colour.
    - title_dict (dict): Maps each key to its subplot title.
    - display (str): Tag used in the saved file name ``Violins_{display}.svg``.
    - title_display (str): Text drawn centred below the grid.
    - gap (float): Gap between the two violin halves (forwarded to ``sns.violinplot``).
    - inner_detail (str): "quart" draws quartile lines; "ci" draws no inner marks
      (confidence intervals are not computed).
    - highlight_column (str, optional): Column whose rows equal 1 are overlaid as markers.
      When used, its name is appended to ``display`` for the saved file name.
    - highlight_color (str): Colour of the highlight markers.
    - highlight_marker (str): Marker style of the highlight markers.
    - highlight_size (float): Scatter size of the highlight markers.
    - alpha (float): Transparency of the highlight markers.
    - n_rows (int): Number of rows in the subplot grid.
    - n_cols (int): Number of columns in the subplot grid.
    - split_by_sex (bool): Draw one violin pair per value of ``SEX``. 'Female' violins are hatched.
    - truth (str): Binary outcome column used to split the violins and count cases/controls.
    - font_size (int): Base font size for titles, ticks and legends.
    - fig_path (str, optional): Directory to save the SVG into. Nothing is saved when None.
    - hue_column (str, optional): Alias for ``truth``. Takes precedence when given.

    Returns:
    - fig (Figure): The figure object containing the plots.
    - axes (ndarray): 2-D array (n_rows x n_cols) of Axes objects.

    Raises:
    - ValueError: If ``inner_detail`` is not "quart"/"ci" or ``keys_ordered`` is empty.
    """
    import matplotlib.lines as mlines  # Line2D proxy artist for the highlight legend entry

    # Several call sites pass the outcome column as `hue_column`; accept it as an alias.
    if hue_column is not None:
        truth = hue_column

    inner_by_detail = {"quart": "quart", "ci": None}
    if inner_detail not in inner_by_detail:
        raise ValueError(f"inner_detail must be 'quart' or 'ci', got {inner_detail!r}")
    inner = inner_by_detail[inner_detail]

    if not keys_ordered:
        raise ValueError("keys_ordered must contain at least one key")

    # squeeze=False always returns a 2-D (n_rows, n_cols) array, so axes[row, col] is valid
    # for every grid shape, including a single row and/or a single column.
    fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, sharey=True,
                             figsize=(n_cols * 2, n_rows * 10), squeeze=False)
    plt.subplots_adjust(wspace=0, hspace=0.1)
    ax_rav = axes.ravel()

    highlighted = False  # set once any subplot draws highlight markers

    for key, ax in zip(keys_ordered, ax_rav):
        df = dataframes[key].copy()
        df['ones'] = key  # constant x category: one split violin per subplot
        color = color_dict[key]

        # Set consistent x-axis for all subplots
        ax.set_xlim(-0.5, 0.5)

        if split_by_sex:
            sex_palette = {0: adjust_alpha(color, 0.7), 1: adjust_alpha(color, 0.7)}
            for sex in df['SEX'].unique():
                df_sex = df[df['SEX'] == sex]
                n_before = len(ax.collections)
                sns.violinplot(data=df_sex, y="y_pred", x="ones", hue=truth, split=True,
                               dodge="auto", gap=gap, inner=None, ax=ax,
                               linecolor='white', linewidth=2, palette=sex_palette, saturation=1)
                if sex == 'Female':
                    # Hatch only the artists this call added. Hatching all of ax.collections
                    # would also hatch violins drawn earlier for another sex.
                    for artist in ax.collections[n_before:]:
                        artist.set_hatch('////')
            ax.set_title(title_dict[key], fontsize=font_size, pad=20)
        else:
            # 0/False (controls) lighter, 1/True (cases) fully opaque.
            palette = {0: adjust_alpha(color, 0.5), 1: adjust_alpha(color, 1)}
            sns.violinplot(data=df, y="y_pred", x="ones", hue=truth, split=True, dodge="auto",
                           gap=gap, inner=inner, ax=ax,
                           linecolor='white', linewidth=4, palette=palette, saturation=1)
            ax.set_title(title_dict[key], fontsize=font_size, pad=15, rotation=45,
                         horizontalalignment='left')

        # Overlay highlight markers if the feature column exists
        if highlight_column and highlight_column in df.columns:
            highlighted = True
            highlight_cases = df[df[highlight_column] == 1]
            is_control = (highlight_cases[truth] == 0).to_numpy()
            is_case = (highlight_cases[truth] == 1).to_numpy()

            # Fixed offsets plus small jitter put each marker on its half-violin
            # without changing the violins' extent.
            x_positions = np.zeros(len(highlight_cases))
            x_positions[is_control] = -0.12 + np.random.normal(0, 0.02, is_control.sum())
            x_positions[is_case] = 0.12 + np.random.normal(0, 0.02, is_case.sum())

            ax.scatter(
                x_positions, highlight_cases["y_pred"],
                color=highlight_color, marker=highlight_marker,
                s=highlight_size, alpha=alpha, edgecolors="black",
                zorder=10,  # draw dots on top of the violins
            )

        # Reset all axes to consistent sizes
        ax.set_xlim(-0.5, 0.5)
        per_axis_legend = ax.get_legend()
        if per_axis_legend is not None:
            per_axis_legend.remove()  # replaced by the figure-level legends below
        ax.set_xlabel('')
        ax.set_xticks([])
        ax.set_ylabel('', fontsize=font_size + 4)
        ax.set_ylim((0, 1))
        ax.tick_params(axis='y', labelsize=font_size + 4)
        ax.spines['left'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(True)

    for ax in ax_rav:
        for collection in ax.collections:
            collection.set_edgecolor('face')
            collection.set_linewidth(4)

    # Only the outer axes of each row carry the y-label and side spines.
    for row in axes:
        row[0].set_ylabel('Predicted Probability', fontsize=font_size + 4)
        row[0].spines['left'].set_visible(True)
        row[-1].spines['right'].set_visible(True)

    if highlighted:
        scatter_legend = mlines.Line2D(
            [], [],
            color=highlight_color,
            marker=highlight_marker,
            linestyle='None',
            markersize=np.sqrt(highlight_size) * 16,  # scale to a readable legend size
            markeredgecolor='black',
            alpha=alpha,
            label=highlight_column.replace('_', ' ').title(),
        )
        # Added once for the whole figure (not once per subplot). `prop` overrides
        # `fontsize` in legend calls, so the size is passed inside `prop`.
        fig.legend(
            handles=[scatter_legend],
            loc='lower right',
            bbox_to_anchor=(0.9, 0.1),
            frameon=False,
            handletextpad=0.5,
            prop={'family': 'Arial', 'weight': 'normal', 'size': font_size - 3},
        )

    if split_by_sex:
        custom_lines = [
            patches.Patch(facecolor=adjust_alpha('#000000', 1.0), edgecolor='white', linewidth=2, label='Male'),
            patches.Patch(facecolor=adjust_alpha('#000000', 0.7), edgecolor='white', linewidth=2, hatch='////', label='Female'),
        ]
        fig.legend(handles=custom_lines, loc='upper left', bbox_to_anchor=(0.13, 0.88),
                   fontsize=font_size, title=None, frameon=False)

    # Case/control counts for the legend are taken from the first plotted frame.
    n_cases = dataframes[keys_ordered[0]][truth].sum()
    n_controls = len(dataframes[keys_ordered[0]]) - n_cases

    if not split_by_sex:
        tn_patch = patches.Patch(facecolor=adjust_alpha('#808080', 0.3), edgecolor=adjust_alpha('#808080', 1),
                                 linewidth=3, label=f'Controls (n={n_controls})')
        tp_patch = patches.Patch(facecolor=adjust_alpha('#808080', 1), edgecolor=adjust_alpha('#808080', 1),
                                 linewidth=3, label=f'Cases (n={n_cases})')
        fig.legend(handles=[tn_patch, tp_patch], bbox_to_anchor=(0.13, 0.1), loc='lower left',
                   fontsize=font_size - 3, title=None, frameon=False)

    fig.text(0.5, 0.05, title_display, ha='center', fontsize=font_size + 6)

    if highlighted:  # encode the highlight feature in the saved file name
        display = display + highlight_column

    if fig_path:
        svg_path = os.path.join(fig_path, f"Violins_{display}.svg")
        fig.savefig(svg_path, format='svg', bbox_inches='tight', transparent=True)

    return fig, axes


#######################################################################################################################################

def adjust_alpha(hex_color, alpha=1.0):
    """
    Return ``hex_color`` as an ``(r, g, b, a)`` tuple with its alpha channel set to ``alpha``.

    Any colour spec accepted by ``matplotlib.colors.to_rgba`` works; callers in this file
    pass both hex strings and RGB tuples.
    """
    from matplotlib import colors as mpl_colors

    return mpl_colors.to_rgba(hex_color, alpha=alpha)

#######################################################################################################################################
def adjust_background(ax=None):
    """
    Overlay a light-grey, back-slash-hatched polygon on every filled shape of an Axes.

    For each path of each ``PolyCollection`` already on ``ax`` (for example the filled
    violin bodies), a ``Polygon`` with the same vertices is added. The previous version
    read ``get_offsets()``, which returns collection offsets rather than shape vertices,
    so it produced degenerate polygons.

    Parameters:
    - ax (matplotlib.axes.Axes, optional): Target axes. Defaults to ``plt.gca()``.
    """
    from matplotlib.collections import PolyCollection

    if ax is None:
        ax = plt.gca()

    # Snapshot the collections so the loop does not depend on artists added below.
    for collection in list(ax.collections):
        if not isinstance(collection, PolyCollection):
            continue  # e.g. scatter markers: their paths are marker shapes, not data outlines
        for path in collection.get_paths():
            vertices = path.vertices
            if len(vertices) < 3:  # a fillable polygon needs at least three points
                continue
            # Create a patch filled with a pattern
            poly = patches.Polygon(vertices, closed=True, edgecolor='none',
                                   facecolor='lightgray', hatch='\\')
            ax.add_patch(poly)


## Data Import

In [None]:
path = "/home/jupyter/workspaces/machinelearningforlivercancerriskprediction"
fig_path = f"{path}/visuals"


# Read every sheet of the combined validation-prediction workbook into one dict of DataFrames.
# The context manager closes the workbook handle once all sheets are parsed.
dataframes = {}
with pd.ExcelFile(path + '/combined_output/val/Prediction_values_combined.xlsx') as excel_file:
    sheet_names = excel_file.sheet_names
    for sheet_name in sheet_names:
        dataframes[sheet_name] = pd.read_excel(excel_file, sheet_name)
        print(f"Sheet {sheet_name} read and saved to dataframes dictionary")

# Drop the placeholder sheet named 'Sheet'. The default avoids a KeyError when it is absent.
dataframes.pop('Sheet', None)
# dataframes.pop([i for i in list(dataframes.keys()) if str(i).endswith('Model_D') or str(i).endswith('Demographics')][0])

# Convert true cancerreg columns to boolean.
# NOTE(review): astype(bool) maps NaN to True; confirm these columns have no missing values.
for key in dataframes:
    if 'true_cancerreg' in dataframes[key].columns:
        dataframes[key]['true_cancerreg'] = dataframes[key]['true_cancerreg'].astype(bool)


# Sex per participant, restricted to the ids present in df_nan and keyed by 'eid'.
df_nan = pd.read_excel(path + "/HCC/df_nan.xlsx")
df_covariates = pd.read_csv(path + "/HCC/HCC_covariates_1.7.txt", sep="\t")
df_covariates = df_covariates[df_covariates['person_id'].isin(df_nan['person_id'])]
df_covariates = df_covariates.rename(columns={"person_id": "eid"})
df_covariates = df_covariates[["eid", "SEX"]]

df_ethnicity_boolean = pd.read_csv(os.path.join(path, "HCC/df_ethnicity_boolean.csv"))  # Load ethnicity white/non-white

In [None]:
df_ethnicity_boolean  # display the white / non-white ethnicity table loaded from HCC/df_ethnicity_boolean.csv

In [None]:
# Reload the covariates, keeping only identifier/demographic columns (file column order kept).
df_nan = pd.read_excel(f"{path}/HCC/df_nan.xlsx")
df_covariates = pd.read_csv(f"{path}/HCC/HCC_covariates_1.7.txt", sep="\t")

wanted_cols = ["person_id", "eid", "AGE", "SEX", "BMI"]
keep_mask = df_covariates.columns.isin(wanted_cols)
df_covariates = df_covariates.loc[:, keep_mask]

#df_covariates = df_covariates[df_covariates['person_id'].isin(df_nan['person_id'])]
# NOTE(review): if the file has both 'person_id' and 'eid', this rename yields two 'eid' columns.
df_covariates = df_covariates.rename(columns={"person_id": "eid"})
df_covariates

In [None]:
dataframes["all_Model_AMAP-RFC"]  # preview the 'all_Model_AMAP-RFC' sheet before covariates are merged in

In [None]:
# Attach the covariates and the ethnicity flag to every prediction table.
# Both merges are inner joins on 'eid', so rows without a match in either table are dropped.
# NOTE(review): not idempotent; re-running this cell merges again and creates suffixed duplicate columns.
for key, frame in dataframes.items():
    dataframes[key] = frame.merge(df_covariates, on='eid').merge(df_ethnicity_boolean, on='eid')




In [None]:
# Import AMAP Scores and append them to dataframes dictionary
# amap_all=pd.read_csv(path+'/HCC/df_amap.csv')
# amap_all["aMAP"] = amap_all["aMAP"].apply(lambda x: x if 0 <= x <= 1 else pd.NA)
# amap_all = amap_all.dropna()
# amap_all['aMAP'] = pd.to_numeric(amap_all['aMAP'], errors='coerce')
# amap_all['status'] = pd.to_numeric(amap_all['status'], errors='coerce')

# amap_all.rename(columns={'aMAP': 'proba', 'status': 'true'}, inplace=True)

# dataframes['all_aMAP'] = amap_all[['proba', 'true', "SEX"]]

# amap_all




# amap_par=pd.read_csv(path+'/Models/df_amap_par.csv')
# amap_par.rename(columns={'aMAP': 'proba', 'status': 'true', 'status_cancerreg' : 'true_cancerreg', 'gender' : "SEX"}, inplace=True)
# dataframes['par_aMAP'] = amap_par[['proba', 'true', 'true_cancerreg', "SEX"]]

In [None]:

# Register the raw (non-imputed) literature benchmark scores under "all_<score>", each as a
# two-column frame ('y_pred', 'status') matching the model prediction tables.
# NOTE(review): the next cell reloads this file, imputes it and overwrites these same keys.
benchmarks = pd.read_csv(f"{path}/data/df_benchmark.csv")
benchmarks_to_add = ["aMAP", "APRI", "NFS", "FIB4"]

dataframes.update({
    f"all_{score}": benchmarks[[score, "status"]].rename(columns={score: "y_pred"})
    for score in benchmarks_to_add
})

dataframes

In [None]:
 #Load benchmark data and impute missing rows


# Load the benchmark scores, mean-impute their missing values and register each one under
# "all_<score>" as a ('y_pred', 'status') frame, replacing the raw entries.
benchmarks = pd.read_csv(path + '/data/df_benchmark.csv')
benchmarks_to_add = ["aMAP", "APRI", "NFS", "FIB4"]


# Missing values before imputation
print("NA counts before imputation:")
print(benchmarks[benchmarks_to_add].isnull().sum())

# Impute the specified columns with their respective means. Assign the result back explicitly:
# `benchmarks[col].fillna(..., inplace=True)` is chained assignment and does not update
# `benchmarks` under pandas Copy-on-Write (warned in pandas 2.x, default in 3.0).
benchmarks[benchmarks_to_add] = benchmarks[benchmarks_to_add].fillna(
    benchmarks[benchmarks_to_add].mean()
)

# Verify the imputation
print("NA counts after imputation:")
print(benchmarks[benchmarks_to_add].isnull().sum())

# Optional: Display summary statistics of imputed columns
print("\nSummary of imputed columns:")
print(benchmarks[benchmarks_to_add].describe())


for item in benchmarks_to_add:
    temp_df = benchmarks[[item, "status"]].copy()  # select prediction column + status
    temp_df = temp_df.rename(columns={item: "y_pred"})  # rename columns
    dataframes[f"all_{item}"] = temp_df  # save under all_{item} key

dataframes

In [None]:
# Previous palette, kept for reference:
# color_map = {
#     "df_covariates": (0.21176, 0.38039, 0.48235),
#     "df_diagnosis": (0.28627, 0.57647, 0.66667),
#     "df_blood": (0.93725, 0.85490, 0.67451),
#     "df_snp": (0.92549, 0.70196, 0.32941),
#     "df_metabolomics": (0.72549, 0.32157, 0.14118)
# }

# Feature-set colours, written as hex codes and converted to RGB tuples in [0, 1]
# (each channel is int(hex_pair, 16) / 255).
color_map = {
    name: tuple(int(hex_code[i:i + 2], 16) / 255 for i in (1, 3, 5))
    for name, hex_code in {
        "df_covariates": "#4995AD",
        "df_diagnosis": "#385579",
        "df_blood": "#C13617",
        "df_snp": "#F0903E",
        "df_metabolomics": "#F0C872",
    }.items()
}

plot_colorbar(color_map)

In [None]:
def summarize_dataframes(dataframes_dict, *, proba_col="proba", true_col="true"):
    """
    Collect a structural summary of every DataFrame in a dict.

    Parameters:
    - dataframes_dict (dict): DataFrames keyed by name.
    - proba_col (str): Prediction column whose (min, max) is reported as "proba_range"
      when present. Defaults to the legacy name "proba"; pass e.g. "y_pred" for frames
      that use that name.
    - true_col (str): Label column whose value counts are reported as "true_value_counts"
      when present. Defaults to the legacy name "true".

    Returns:
    - dict: name -> {"shape", "columns", "dtypes", "non_null_counts", "total_nas",
      "numeric_summary"}. "numeric_summary" is ``df.describe()`` as a dict, or None when the
      frame has no numeric column. "proba_range" and "true_value_counts" are added only when
      the corresponding column exists.
    """
    summary = {}
    for key, df in dataframes_dict.items():
        # Check that a numeric column exists. `.columns.any()` would test the truthiness of the
        # column *labels*, which is False for labels such as 0 or "".
        has_numeric = not df.select_dtypes(include=[np.number]).columns.empty
        entry = {
            "shape": df.shape,
            "columns": df.columns.tolist(),
            "dtypes": df.dtypes.to_dict(),
            "non_null_counts": df.count().to_dict(),
            "total_nas": df.isna().sum().sum(),
            "numeric_summary": df.describe().to_dict() if has_numeric else None,
        }

        if proba_col in df.columns:
            entry["proba_range"] = (df[proba_col].min(), df[proba_col].max())
        if true_col in df.columns:
            entry["true_value_counts"] = df[true_col].value_counts().to_dict()

        summary[key] = entry

    return summary


# Summarize every loaded DataFrame and print one report per frame.
dataframes_summary = summarize_dataframes(dataframes)

for key, summary in dataframes_summary.items():
    report = [
        f"\nSummary for {key}:",
        f"Shape: {summary['shape']}",
        f"Columns: {summary['columns']}",
        f"Data types: {summary['dtypes']}",
        f"Non-null counts: {summary['non_null_counts']}",
        f"Total NAs: {summary['total_nas']}",
    ]
    # Optional entries appear only when the summary contains them.
    if 'proba_range' in summary:
        report.append(f"Proba range: {summary['proba_range']}")
    if 'true_value_counts' in summary:
        report.append(f"True value counts: {summary['true_value_counts']}")
    report.append("\n" + "-" * 50)
    print("\n".join(report))


# Violins

## All Models "All"

In [None]:
# Single-modality models followed by the incremental models A-E, all participants.
keys_ordered_all = ['all_Model_Demographics', 'all_Model_Diagnosis', 'all_Model_Blood', 'all_Model_SNP',
                    'all_Model_Metabolomics', 'all_Model_A', 'all_Model_B', 'all_Model_C', 'all_Model_D',
                    'all_Model_E']
color_dict_all = {'all_Model_A':            color_map["df_covariates"],
                  'all_Model_B':            color_map['df_diagnosis'],
                  'all_Model_C':            color_map["df_blood"],
                  'all_Model_D':            color_map["df_snp"],
                  'all_Model_E':            color_map['df_metabolomics'],

                  'all_Model_Demographics': color_map["df_covariates"],
                  'all_Model_Diagnosis':    color_map["df_diagnosis"],
                  'all_Model_Blood':        color_map["df_blood"],
                  'all_Model_SNP':          color_map['df_snp'],
                  'all_Model_Metabolomics': color_map['df_metabolomics']}

title_dict_all = {'all_Model_A':            'Model A',
                  'all_Model_B':            'Model B',
                  'all_Model_C':            'Model C',
                  'all_Model_D':            'Model D',
                  'all_Model_E':            'Model E',

                  'all_Model_Demographics': 'Demographics',
                  'all_Model_Diagnosis':    'Diagnosis',
                  'all_Model_Blood':        'Blood',
                  'all_Model_SNP':          'Genomics',
                  'all_Model_Metabolomics': 'Metabolomics'}

# The outcome column that splits the violins is passed via `truth`.
create_violin_plots(dataframes=dataframes, keys_ordered=keys_ordered_all, color_dict=color_dict_all,
                    title_dict=title_dict_all, display="all", title_display="All", inner_detail="quart",
                    n_cols=10, gap=0, truth="true_cancerreg")

## All: incremental models only

In [None]:
# Incremental models A-E, all participants.
keys_ordered_all = ['all_Model_A', 'all_Model_B', 'all_Model_C', 'all_Model_D', 'all_Model_E']
color_dict_all = {'all_Model_A': color_map["df_covariates"],
                  'all_Model_B': color_map['df_diagnosis'],
                  'all_Model_C': color_map["df_blood"],
                  'all_Model_D': color_map["df_snp"],
                  'all_Model_E': color_map['df_metabolomics']}


title_dict_all = {'all_Model_A': 'Model A',
                  'all_Model_B': 'Model B',
                  'all_Model_C': 'Model C',
                  'all_Model_D': 'Model D',
                  'all_Model_E': 'Model E'}

# The outcome column that splits the violins is passed via `truth`.
create_violin_plots(dataframes=dataframes, keys_ordered=keys_ordered_all, color_dict=color_dict_all,
                    title_dict=title_dict_all, display="all_inc", title_display="All", inner_detail="quart",
                    n_cols=5, gap=0, split_by_sex=False, truth="true_cancerreg")

## All: stratified by sex

In [None]:
# Incremental models A-E, all participants, one violin pair per sex.
keys_ordered_all = ['all_Model_A', 'all_Model_B', 'all_Model_C', 'all_Model_D', 'all_Model_E']
color_dict_all = {'all_Model_A': color_map["df_covariates"],
                  'all_Model_B': color_map['df_diagnosis'],
                  'all_Model_C': color_map["df_blood"],
                  'all_Model_D': color_map["df_snp"],
                  'all_Model_E': color_map['df_metabolomics']}


title_dict_all = {'all_Model_A': 'Model A',
                  'all_Model_B': 'Model B',
                  'all_Model_C': 'Model C',
                  'all_Model_D': 'Model D',
                  'all_Model_E': 'Model E'}

# The outcome column that splits the violins is passed via `truth`.
create_violin_plots(dataframes=dataframes, keys_ordered=keys_ordered_all, color_dict=color_dict_all,
                    title_dict=title_dict_all, display="all_sex", title_display="All", inner_detail="quart",
                    n_cols=5, gap=0, split_by_sex=True, truth="true_cancerreg")

## All Models "PAR"

In [None]:
# Single-modality models followed by the incremental models A-E, PAR cohort.
keys_ordered_par = ['par_Model_Demographics', 'par_Model_Diagnosis', 'par_Model_Blood', 'par_Model_SNP',
                    'par_Model_Metabolomics', 'par_Model_A', 'par_Model_B', 'par_Model_C', 'par_Model_D',
                    'par_Model_E']
color_dict_par = {'par_Model_A':            color_map["df_covariates"],
                  'par_Model_B':            color_map['df_diagnosis'],
                  'par_Model_C':            color_map["df_blood"],
                  'par_Model_D':            color_map["df_snp"],
                  'par_Model_E':            color_map['df_metabolomics'],

                  'par_Model_Demographics': color_map["df_covariates"],
                  'par_Model_Diagnosis':    color_map["df_diagnosis"],
                  'par_Model_Blood':        color_map["df_blood"],
                  'par_Model_SNP':          color_map['df_snp'],
                  'par_Model_Metabolomics': color_map['df_metabolomics']}

title_dict_par = {'par_Model_A':            'Model A',
                  'par_Model_B':            'Model B',
                  'par_Model_C':            'Model C',
                  'par_Model_D':            'Model D',
                  'par_Model_E':            'Model E',

                  'par_Model_Demographics': 'Demographics',
                  'par_Model_Diagnosis':    'Diagnosis',
                  'par_Model_Blood':        'Blood',
                  'par_Model_SNP':          'Genomics',
                  'par_Model_Metabolomics': 'Metabolomics'}


# The outcome column that splits the violins is passed via `truth`.
create_violin_plots(dataframes=dataframes, keys_ordered=keys_ordered_par, color_dict=color_dict_par,
                    title_dict=title_dict_par, title_display="Chronic Liver Disease", display="par",
                    n_cols=10, gap=0, truth="true_cancerreg")

## Incremental "PAR"

In [None]:
# Incremental models A-E for the PAR cohort. Each model letter is paired with the feature-set
# colour it adds; the violins are split on the default `truth` column.
model_letters = 'ABCDE'
feature_sets = ['df_covariates', 'df_diagnosis', 'df_blood', 'df_snp', 'df_metabolomics']

keys_ordered_par_sex = [f'par_Model_{letter}' for letter in model_letters]
color_dict_par = {f'par_Model_{letter}': color_map[feature_set]
                  for letter, feature_set in zip(model_letters, feature_sets)}
title_dict_par = {f'par_Model_{letter}': f'Model {letter}' for letter in model_letters}

create_violin_plots(
    dataframes=dataframes,
    keys_ordered=keys_ordered_par_sex,
    color_dict=color_dict_par,
    title_dict=title_dict_par,
    display="par_inc",
    title_display="Chronic Liver Disease",
    n_cols=5,
    gap=0,
    split_by_sex=False,
)

## PAR: stratified by sex

In [None]:
# Incremental models A-E for the PAR cohort, drawn as one violin pair per sex.
letters_and_sources = [
    ('A', 'df_covariates'),
    ('B', 'df_diagnosis'),
    ('C', 'df_blood'),
    ('D', 'df_snp'),
    ('E', 'df_metabolomics'),
]

keys_ordered_par_sex = [f'par_Model_{letter}' for letter, _ in letters_and_sources]
color_dict_par = {f'par_Model_{letter}': color_map[source] for letter, source in letters_and_sources}
title_dict_par = {f'par_Model_{letter}': f'Model {letter}' for letter, _ in letters_and_sources}

create_violin_plots(
    dataframes=dataframes,
    keys_ordered=keys_ordered_par_sex,
    color_dict=color_dict_par,
    title_dict=title_dict_par,
    display="par_sex",
    title_display="Chronic Liver Disease",
    n_cols=5,
    gap=0,
    split_by_sex=True,
)

## All Reduced Models + Benchmarking Literature Scores

In [None]:
# Reduced models (TOP30, TOP15, aMAP-RFC) next to the literature benchmark scores, all participants.
keys_ordered_reduced = ['all_Model_TOP30', 'all_Model_TOP15', 'all_Model_AMAP-RFC',
                        'all_aMAP', 'all_APRI', 'all_FIB4']

# Benchmarks share one grey; the reduced models use their own shades.
color_dict_all = {benchmark: '#808080' for benchmark in ('all_aMAP', 'all_APRI', 'all_FIB4')}
color_dict_all.update({
    'all_Model_AMAP-RFC': '#c9c9c9',
    #'all_Model_TOP75' : '#cb6043',
    'all_Model_TOP30': '#d1846e',
    'all_Model_TOP15': '#d0a79a',
})

title_dict_all = {
    'all_aMAP': 'aMAP',
    'all_FIB4': 'FIB4',
    'all_APRI': 'APRI',
    'all_Model_AMAP-RFC': 'aMAP RFC',
    #'all_Model_TOP75' : 'TOP75',
    'all_Model_TOP30': 'TOP30',
    'all_Model_TOP15': 'TOP15',
}

create_violin_plots(
    dataframes=dataframes,
    keys_ordered=keys_ordered_reduced,
    color_dict=color_dict_all,
    title_dict=title_dict_all,
    display="small_benchmark_all",
    title_display="All",
    inner_detail="quart",
    n_cols=6,
    gap=0,
    font_size=24,
    split_by_sex=False,
    fig_path=fig_path,
)

##### Split by sex

In [None]:
# Reduced models plus aMAP, all participants, one violin pair per sex.
keys_ordered_reduced = ['all_Model_TOP75', 'all_Model_TOP30', 'all_Model_TOP15', 'all_Model_AMAP-RFC', 'all_aMAP']
color_dict_all = {
    'all_aMAP': '#808080',
    'all_Model_AMAP-RFC': '#c9c9c9',
    'all_Model_TOP75': '#cb6043',
    'all_Model_TOP30': '#d1846e',
    'all_Model_TOP15': '#d0a79a',
}

title_dict_all = {
    'all_aMAP': 'aMAP',
    'all_Model_AMAP-RFC': 'aMAP RFC',
    'all_Model_TOP75': 'TOP75',
    'all_Model_TOP30': 'TOP30',
    'all_Model_TOP15': 'TOP15',
}

# The outcome column that splits the violins is passed via `truth`.
# NOTE(review): 'all_aMAP' is built from df_benchmark.csv with only 'y_pred' and 'status'
# columns, so split_by_sex=True (needs 'SEX') and truth="true" will raise a KeyError for it.
# Confirm which outcome column these frames carry before running.
create_violin_plots(dataframes=dataframes, keys_ordered=keys_ordered_reduced, color_dict=color_dict_all,
                    title_dict=title_dict_all, display="small_benchmark_all_strat", title_display="All",
                    inner_detail="quart", n_cols=5, gap=0, truth="true", font_size=24, split_by_sex=True)

##### Violins: reduced models, PAR

In [None]:
# Model C, the reduced models and aMAP for the PAR cohort.
keys_ordered_reduced = ['par_Model_C', 'par_Model_TOP75', 'par_Model_TOP30', 'par_Model_TOP15',
                        'par_Model_AMAP-RFC', 'par_aMAP']
color_dict_par = {
    'par_Model_C': '#C13617',  # Color for scenario C (Covariates, Diagnosis, Blood)
    'par_aMAP': '#808080',
    'par_Model_AMAP-RFC': '#c9c9c9',
    'par_Model_TOP75': '#cb6043',
    'par_Model_TOP30': '#d1846e',
    'par_Model_TOP15': '#d0a79a',
}

title_dict_par = {
    'par_Model_C': 'Model C',  # Title for scenario C (Covariates, Diagnosis, Blood)
    'par_aMAP': 'aMAP',
    'par_Model_AMAP-RFC': 'aMAP RFC',
    'par_Model_TOP75': 'TOP75',
    'par_Model_TOP30': 'TOP30',
    'par_Model_TOP15': 'TOP15',
}

# The outcome column that splits the violins is passed via `truth`.
# NOTE(review): the cell that created dataframes['par_aMAP'] is commented out, so this raises
# a KeyError unless the prediction workbook itself contains a 'par_aMAP' sheet.
create_violin_plots(dataframes=dataframes, keys_ordered=keys_ordered_reduced, color_dict=color_dict_par,
                    title_dict=title_dict_par, display="small_benchmark_par", title_display="Patients at risk",
                    inner_detail="quart", n_cols=6, gap=0, truth="true_cancerreg")

# Precision Recall Plots


In [None]:
def _pr_curve_label(key):
    """Build the legend label for a dataframe key such as 'all_Model_A'.

    Keeps the first underscore-separated token, then at most the last two remaining
    tokens. For keys with three or more tokens this is '<first> - <second-to-last>_<last>'
    (e.g. 'all - Model_A'). Two-token keys such as 'all_aMAP' give 'all - aMAP'.
    """
    parts = key.split('_')
    return f"{parts[0]} - {'_'.join(parts[1:][-2:])}"


def _pr_curve_linestyle(key, dotted_keys, default):
    """Return the matplotlib linestyle for one curve.

    Keys in ``dotted_keys`` are dotted. Otherwise a first token of 'CatBoost' gives a
    solid line and 'RFC' a dashed line; any other key uses ``default``.
    """
    if key in dotted_keys:
        return ':'
    return {"CatBoost": '-', "RFC": '--'}.get(key.split('_')[0], default)


def _interpolated_precision(precision, recall, base_recall):
    """Remove increasing parts of a PR curve and interpolate it onto ``base_recall``.

    ``precision``/``recall`` are in the order returned by
    ``sklearn.metrics.precision_recall_curve`` (recall non-increasing). Each precision
    value is replaced by the best precision reached at any recall >= its own, then the
    resulting non-increasing curve is linearly interpolated at ``base_recall``.
    """
    # Running max along the native order (recall high -> low) = max precision at recall >= r.
    envelope = np.maximum.accumulate(precision)
    # np.interp needs increasing x values, so flip both arrays to recall low -> high.
    return np.interp(base_recall, recall[::-1], envelope[::-1])


def plot_precision_recall_curves(dataframes, keys_ordered, colors, fig, ax, y_label=None, x_label=None,
                                 display='', ylim=(0, 1), fill_bet=False, title='', fig_path=None,
                                 line_style='-', dotted_keys=None, plot_legend=True, lw=2, font_size=16,
                                 truth="status_cancerreg", legend_font_size=None):
    """
    Plot overlaying precision-recall curves for several models on one axes and optionally save them as SVG.

    Parameters:
    - dataframes (dict): Maps each key to a DataFrame with a ground-truth column named by ``truth``
      and a score column "y_pred"; both are passed to sklearn's precision_recall_curve.
    - keys_ordered (list): Keys of ``dataframes`` in plotting (and legend) order.
    - colors (dict): Line color for each key.
    - fig (matplotlib.figure.Figure): Figure to draw on; it is resized to 5 x 4.1 inches.
    - ax (matplotlib.axes.Axes): Axes to draw on.
    - y_label / x_label (str or None): Axis labels; None uses the default precision / recall labels.
    - display (str): Appended to the title and used in the SVG file name.
    - ylim (tuple): y-axis limits; its repr is also part of the SVG file name.
    - fill_bet (bool): Shade mean +/- one standard deviation of the interpolated precision across curves.
    - title (str): Title prefix; the axes title is "{title}: {display}".
    - fig_path (str or None): Directory for "Prec_Recall_{display}_{ylim}.svg"; None skips saving.
    - line_style (str): Fallback line style for keys not covered by the rules below.
    - dotted_keys (iterable of str, str or None): Keys drawn with a dotted line. A single string is
      treated as one key, so ("key") works as well as ("key",).
    - plot_legend (bool): Whether to draw the legend.
    - lw (float): Line width of every curve.
    - font_size (int): Size of axis labels, tick labels and title. It is also written to
      plt.rcParams['font.size'], which is a global side effect on later figures.
    - truth (str): Name of the ground-truth column.
    - legend_font_size (int or None): Legend text size; None uses ``font_size``.

    Each legend entry shows "<first key token> - <last key tokens>" and the PR AUC. The AUC is
    computed with sklearn.metrics.auc, i.e. the trapezoidal area under the raw curve
    (not average precision).
    """
    base_recall = np.linspace(0, 1, 100)

    # Normalise dotted_keys to a set of whole keys. A bare string would otherwise be
    # iterated character by character, or matched as a substring.
    if dotted_keys is None:
        dotted_keys = ()
    elif isinstance(dotted_keys, str):
        dotted_keys = (dotted_keys,)
    dotted_keys = set(dotted_keys)

    interpolated = []
    for key in keys_ordered:
        df = dataframes[key]
        precision, recall, _ = precision_recall_curve(df[truth], df["y_pred"])
        pr_auc = auc(recall, precision)
        interpolated.append(_interpolated_precision(precision, recall, base_recall))

        ax.plot(recall, precision, alpha=1, lw=lw,
                linestyle=_pr_curve_linestyle(key, dotted_keys, line_style),
                color=colors[key], label=f'{_pr_curve_label(key)} ({pr_auc:.3f})')

    fig.set_size_inches(5, 4.1)

    # Only compute the band when it is drawn, and only when there is at least one curve
    # (the mean of an empty stack would warn and produce NaNs).
    if fill_bet and interpolated:
        stacked = np.array(interpolated)
        mean_precision = stacked.mean(axis=0)
        std_precision = stacked.std(axis=0)
        precision_upper = np.minimum(mean_precision + std_precision, 1)
        precision_lower = np.maximum(mean_precision - std_precision, 0)
        ax.fill_between(base_recall, precision_lower, precision_upper, color='grey', alpha=0.2)

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim(ylim)
    # Minor ticks get a small pad. Major tick labels are pulled in towards the axes (pad -5).
    ax.tick_params(axis='both', which='minor', pad=1)
    ax.tick_params(axis='both', which='major', pad=-5, labelsize=font_size)
    ax.set_xlabel('Recall (TP / (TP + FN))' if x_label is None else x_label, fontsize=font_size)
    ax.set_ylabel('Precision (TP / (TP + FP))' if y_label is None else y_label, fontsize=font_size)
    ax.set_title(f"{title}: {display}", fontsize=font_size, pad=5)

    plt.rcParams.update({'font.size': font_size})

    if plot_legend:
        # Matplotlib ignores `fontsize` when `prop` is given, so set the size on the font itself.
        size = font_size if legend_font_size is None else legend_font_size
        condensed_font = fm.FontProperties(family='DejaVu Sans', style='normal', weight='normal',
                                           stretch='condensed', size=size)
        ax.legend(loc="upper right", bbox_to_anchor=(1.01, 1), frameon=False, prop=condensed_font)

    for spine in ax.spines.values():
        spine.set_linewidth(0.8)  # thin frame around the axes

    if fig_path is not None:
        svg_path = os.path.join(fig_path, f"Prec_Recall_{display}_{ylim}.svg")
        fig.savefig(svg_path, format='svg', bbox_inches='tight', transparent=True)

## All Incremental

In [None]:
# Incremental models A-E plus the aMAP score, all participants, full precision range.
scenarios_colors = {
    'all_Model_A': '#4995AD',  # Color for scenario A (Covariates)
    'all_Model_B': '#385579',  # Color for scenario B (Covariates, Diagnosis)
    'all_Model_C': '#C13617',  # Color for scenario C (Covariates, Diagnosis, Blood)
    'all_Model_D': '#F0903E',  # Color for scenario D (Covariates, Diagnosis, Blood, Genetics)
    'all_Model_E': '#F0C872',  # Color for scenario E (Covariates, Diagnosis, Blood, Genetics, Metabolomics)
    #'all_Model_Blood': '#C13617',
    'all_aMAP': '#808080',
}

# Plotting order; colors keeps only the models actually plotted.
keys_ordered_all = ['all_Model_A', 'all_Model_B', 'all_Model_C', 'all_Model_D', 'all_Model_E', 'all_aMAP']
colors = {key: scenarios_colors[key] for key in keys_ordered_all}

fig_all, ax_all = plt.subplots()

# dotted_keys must be a collection of keys: ("all_aMAP") is just a string, so use a 1-tuple.
plot_precision_recall_curves(dataframes, keys_ordered_all, colors, fig=fig_all, ax=ax_all, ylim=(0, 1),
                             display="All", fill_bet=False, title='Precision-Recall Curves', fig_path=fig_path,
                             dotted_keys=("all_aMAP",), font_size=12)
plt.show()

### All Incremental Zoom In

In [None]:
# Same incremental models as the full-range plot, with precision limited to 0-0.2 and no legend.
scenarios_colors = {
    'all_Model_A': '#4995AD',  # Color for scenario A (Covariates)
    'all_Model_B': '#385579',  # Color for scenario B (Covariates, Diagnosis)
    'all_Model_C': '#C13617',  # Color for scenario C (Covariates, Diagnosis, Blood)
    'all_Model_D': '#F0903E',  # Color for scenario D (Covariates, Diagnosis, Blood, Genetics)
    'all_Model_E': '#F0C872',  # Color for scenario E (Covariates, Diagnosis, Blood, Genetics, Metabolomics)
    #'all_Model_Blood': '#C13617',
    'all_aMAP': '#808080',
}

# Plotting order; colors keeps only the models actually plotted.
keys_ordered_all = ['all_Model_A', 'all_Model_B', 'all_Model_C', 'all_Model_D', 'all_Model_E', 'all_aMAP']
colors = {key: scenarios_colors[key] for key in keys_ordered_all}

fig_all, ax_all = plt.subplots()

# dotted_keys must be a collection of keys: ("all_aMAP") is just a string, so use a 1-tuple.
plot_precision_recall_curves(dataframes, keys_ordered_all, colors, fig=fig_all, ax=ax_all, ylim=(0, 0.2),
                             display="All", fill_bet=False, title='Precision-Recall Curves', fig_path=fig_path,
                             dotted_keys=("all_aMAP",), plot_legend=False, lw=3, font_size=16)
plt.show()

## PAR Incremental

In [None]:
# Incremental models A-E plus the aMAP score on the PAR cohort, full precision range.
scenarios_colors = {
    'par_Model_A': '#4995AD',  # Color for scenario A (Covariates)
    'par_Model_B': '#385579',  # Color for scenario B (Covariates, Diagnosis)
    'par_Model_C': '#C13617',  # Color for scenario C (Covariates, Diagnosis, Blood)
    'par_Model_D': '#F0903E',  # Color for scenario D (Covariates, Diagnosis, Blood, Genetics)
    'par_Model_E': '#F0C872',  # Color for scenario E (Covariates, Diagnosis, Blood, Genetics, Metabolomics)
    'par_aMAP': '#808080',
}

keys_ordered_par = ['par_Model_A', 'par_Model_B', 'par_Model_C', 'par_Model_D', 'par_Model_E', 'par_aMAP']
colors = {key: scenarios_colors[key] for key in keys_ordered_par}

fig_cld, ax_cld = plt.subplots()

# dotted_keys is given as a 1-tuple of keys rather than a bare string.
plot_precision_recall_curves(dataframes, keys_ordered_par, colors, fig=fig_cld, ax=ax_cld,
                             display="Chronic Liver Disease", ylim=(0, 1), fill_bet=False,
                             title='Precision-Recall Curves', fig_path=fig_path, dotted_keys=("par_aMAP",),
                             lw=3, font_size=12)
plt.show()

### PAR Incremental Zoom In

In [None]:
# Incremental PAR models with precision limited to 0-0.2 and no legend.
# The aMAP key is spelled 'par_aMAP', matching the other PAR cells; 'par_amap' presumably
# does not exist in `dataframes` (verify if both spellings were ever loaded).
scenarios_colors = {
    'par_Model_A': '#4995AD',  # Color for scenario A (Covariates)
    'par_Model_B': '#385579',  # Color for scenario B (Covariates, Diagnosis)
    'par_Model_C': '#C13617',  # Color for scenario C (Covariates, Diagnosis, Blood)
    'par_Model_D': '#F0903E',  # Color for scenario D (Covariates, Diagnosis, Blood, Genetics)
    'par_Model_E': '#F0C872',  # Color for scenario E (Covariates, Diagnosis, Blood, Genetics, Metabolomics)
    'par_aMAP': '#808080',
}

keys_ordered_par = ['par_Model_A', 'par_Model_B', 'par_Model_C', 'par_Model_D', 'par_Model_E', 'par_aMAP']
colors = {key: scenarios_colors[key] for key in keys_ordered_par}

fig_cld, ax_cld = plt.subplots()

plot_precision_recall_curves(dataframes, keys_ordered_par, colors, fig=fig_cld, ax=ax_cld, display="PAR",
                             ylim=(0, 0.2), fill_bet=False, title='Precision-Recall Curves', fig_path=fig_path,
                             dotted_keys=("par_aMAP",), plot_legend=False, lw=3, font_size=16)
plt.show()

## All Reduced Models

### All

In [None]:
# Reduced models (TOP75/30/15, aMAP RFC) and the aMAP score, all participants, full precision range.
keys_ordered_reduced = ['all_Model_TOP75', 'all_Model_TOP30', 'all_Model_TOP15', 'all_Model_AMAP-RFC', 'all_aMAP']

scenarios_colors = {
    'all_aMAP': '#808080',
    'all_Model_AMAP-RFC': '#c9c9c9',
    'all_Model_TOP75': '#cb6043',
    'all_Model_TOP30': '#d1846e',
    'all_Model_TOP15': '#d0a79a',
}

# Keep only the models actually plotted.
colors = {key: scenarios_colors[key] for key in keys_ordered_reduced}

fig_all, ax_all = plt.subplots()

# The dotted key must match the dataframe key exactly ('all_aMAP', not 'all_amap').
plot_precision_recall_curves(dataframes, keys_ordered_reduced, colors, fig=fig_all,
                             ax=ax_all, ylim=(0, 1), display="All - Reduced_models", fill_bet=False,
                             title='Precision-Recall Curves', fig_path=fig_path, dotted_keys=("all_aMAP",),
                             plot_legend=True, lw=3, font_size=12, truth="status")
plt.show()

In [None]:
# Reduced models without TOP75, all participants, full precision range.
keys_ordered_reduced = ['all_Model_TOP30', 'all_Model_TOP15', 'all_Model_AMAP-RFC', 'all_aMAP']
scenarios_colors = {
    'all_aMAP': '#808080',
    'all_Model_AMAP-RFC': '#c9c9c9',
    #'all_Model_TOP75': '#cb6043',
    'all_Model_TOP30': '#d1846e',
    'all_Model_TOP15': '#d0a79a',
}
colors = {key: scenarios_colors[key] for key in keys_ordered_reduced}

# Without stratification.
# NOTE(review): display and ylim match the cell that includes TOP75, so this writes the same
# "Prec_Recall_All - Reduced_models_(0, 1).svg" and overwrites that figure — confirm intended.
fig_all, ax_all = plt.subplots()
plot_precision_recall_curves(dataframes, keys_ordered_reduced, colors, fig=fig_all,
                             ax=ax_all, ylim=(0, 1), display="All - Reduced_models", fill_bet=False,
                             title='Precision-Recall Curves', fig_path=fig_path, dotted_keys=("all_aMAP",),
                             plot_legend=True, lw=3, font_size=12, truth="status")
plt.show()


Closer view: the same reduced models with the precision axis limited to 0–0.2.

In [None]:
# Reduced models without TOP75 with precision limited to 0-0.2 and no legend.
keys_ordered_reduced = ['all_Model_TOP30', 'all_Model_TOP15', 'all_Model_AMAP-RFC', 'all_aMAP']
scenarios_colors = {
    'all_aMAP': '#808080',
    'all_Model_AMAP-RFC': '#c9c9c9',
    #'all_Model_TOP75': '#cb6043',
    'all_Model_TOP30': '#d1846e',
    'all_Model_TOP15': '#d0a79a',
}

# Keep only the models actually plotted.
colors = {key: scenarios_colors[key] for key in keys_ordered_reduced}

fig_all, ax_all = plt.subplots()

# The dotted key must match the dataframe key exactly ('all_aMAP', not 'all_amap').
plot_precision_recall_curves(dataframes, keys_ordered_reduced, colors, fig=fig_all,
                             ax=ax_all, ylim=(0, 0.2), display="All - Reduced_models", fill_bet=False,
                             title='Precision-Recall Curves', fig_path=fig_path, dotted_keys=("all_aMAP",),
                             plot_legend=False, lw=3, font_size=16, truth="status")
plt.show()

### PAR

In [None]:
# Model C, the reduced models and the aMAP score on the PAR cohort, full precision range.
keys_ordered_reduced = ['par_Model_C', 'par_Model_TOP75', 'par_Model_TOP30', 'par_Model_TOP15', 'par_Model_AMAP-RFC', 'par_aMAP']

scenarios_colors = {
    'par_Model_C': '#C13617',  # Color for scenario C (Covariates, Diagnosis, Blood)
    'par_aMAP': '#808080',
    'par_Model_AMAP-RFC': '#c9c9c9',
    'par_Model_TOP75': '#cb6043',
    'par_Model_TOP30': '#d1846e',
    'par_Model_TOP15': '#d0a79a',
}

# Keep only the models actually plotted.
colors = {key: scenarios_colors[key] for key in keys_ordered_reduced}

fig_par, ax_par = plt.subplots()

# dotted_keys must be a collection of keys: ("par_aMAP") is just a string, so use a 1-tuple.
plot_precision_recall_curves(dataframes, keys_ordered_reduced, colors, fig=fig_par, ax=ax_par, ylim=(0, 1),
                             display="Reduced_models", fill_bet=False, title='Precision-Recall Curves',
                             fig_path=fig_path, dotted_keys=("par_aMAP",), plot_legend=True, lw=3, font_size=12)
plt.show()

In [None]:
# PAR reduced models with precision limited to 0-0.2 and no legend.
# The aMAP key is spelled 'par_aMAP', matching the other PAR cells; 'par_amap' presumably
# does not exist in `dataframes` (verify if both spellings were ever loaded).
keys_ordered_reduced = ['par_Model_C', 'par_Model_TOP75', 'par_Model_TOP30', 'par_Model_TOP15', 'par_Model_AMAP-RFC', 'par_aMAP']

scenarios_colors = {
    'par_Model_C': '#C13617',  # Color for scenario C (Covariates, Diagnosis, Blood)
    'par_aMAP': '#808080',
    'par_Model_AMAP-RFC': '#c9c9c9',
    'par_Model_TOP75': '#cb6043',
    'par_Model_TOP30': '#d1846e',
    'par_Model_TOP15': '#d0a79a',
}

# Keep only the models actually plotted.
colors = {key: scenarios_colors[key] for key in keys_ordered_reduced}

fig_par, ax_par = plt.subplots()

plot_precision_recall_curves(dataframes, keys_ordered_reduced, colors, fig=fig_par, ax=ax_par, ylim=(0, 0.2),
                             display="Reduced_models", fill_bet=False, title='Precision-Recall Curves',
                             fig_path=fig_path, dotted_keys=("par_aMAP",), plot_legend=False, lw=3, font_size=16)
plt.show()