# Notebook 7 - Analyse des seuils

Ce notebook analyse les résultats générés par le notebook 5 (MAE thresholds).
Les figures reprennent celles du notebook 6 mais comparent les différents seuils
pour chaque méthode robuste dans des graphiques séparés.


In [9]:
import os
import pandas as pd
import numpy as np
from pathlib import Path

from joblib import Parallel, delayed
import matplotlib.pyplot as plt
import seaborn as sns

from robust_evaluation_tools.robust_utils import (
    get_metrics,
    add_nb_patients_and_diseased,
)


In [10]:

# Locations of the results to analyse and of the figures this notebook writes.
MAINFOLDER = "RESULTS/MAE_THRESHOLDS"
harmonization_method = "classic"
ANALYSIS_FOLDER = f"{MAINFOLDER}/ANALYSIS_{harmonization_method}"
PLOT_FOLDER = f"{ANALYSIS_FOLDER}/THRESHOLD_PLOTS"

SYNTHETIC_SITES_VERSION = "v1"

metrics = get_metrics()

# Parameters controlling the analysis (leave as None for auto-detection)
diseases = ["ALL"]  # set to None to use every available disease
sample_sizes = None  # infer the available sample sizes automatically
disease_ratios = None  # infer the available ratios automatically
num_tests = None  # infer the number of repetitions automatically
n_jobs = -1

# Base robust-method names; used further down both as prefixes to classify
# method labels and as the set of methods that get a threshold pivot.
ROBUST_METHODS = [
    "HC",
    "NoRobust",
    "IQR",
    "MAD",
    "MMS",
    "G_ZS",
    "ZS",
    "SN",
    "QN",
    "MLP7_ALL"
]

# Make sure the output folder for the figures exists.
Path(PLOT_FOLDER).mkdir(parents=True, exist_ok=True)


In [11]:
def infer_results_grid(mainfolder, harmonization_method, diseases=None):
    """Scan the result tree and deduce the experiment grid it contains.

    The tree is read from ``<mainfolder>/PROCESS_<harmonization_method>``:
    one directory per disease, each holding ``<size>_<ratio>`` directories,
    each of which holds one digit-named directory per run.

    Args:
        mainfolder: Root folder of the results.
        harmonization_method: Suffix of the ``PROCESS_*`` directory to scan.
        diseases: Optional disease names to restrict the scan to. When falsy,
            every sub-directory of the process folder is considered.

    Returns:
        ``(diseases, sample_sizes, disease_ratios, num_tests)``. The first
        three are sorted lists; ratios are the directory suffix divided by
        100; ``num_tests`` is the largest run count seen (the greater of
        "highest run id + 1" and "number of run directories").

    Raises:
        FileNotFoundError: If the process folder does not exist.
        ValueError: If no disease, size, ratio or run could be detected.
    """
    base_dir = Path(mainfolder) / f"PROCESS_{harmonization_method}"
    if not base_dir.exists():
        raise FileNotFoundError(f"{base_dir} introuvable.")

    candidates = diseases or sorted(
        entry.name for entry in base_dir.iterdir() if entry.is_dir()
    )

    found_diseases, found_sizes, found_ratios = set(), set(), set()
    run_count = 0

    for disease_name in candidates:
        disease_path = base_dir / disease_name
        if not disease_path.is_dir():
            continue
        found_diseases.add(disease_name)

        for grid_dir in disease_path.iterdir():
            if not grid_dir.is_dir():
                continue
            size_text, separator, ratio_text = grid_dir.name.partition("_")
            if not separator:
                continue
            # Both parts must be integers, otherwise the folder is ignored.
            try:
                size, ratio = int(size_text), int(ratio_text) / 100
            except ValueError:
                continue
            found_sizes.add(size)
            found_ratios.add(ratio)

            run_ids = [
                int(child.name)
                for child in grid_dir.iterdir()
                if child.name.isdigit() and child.is_dir()
            ]
            if run_ids:
                run_count = max(run_count, max(run_ids) + 1, len(run_ids))

    if not (found_diseases and found_sizes and found_ratios and run_count):
        raise ValueError(
            f"Impossible de déduire les paramètres depuis {base_dir}."
        )
    return sorted(found_diseases), sorted(found_sizes), sorted(found_ratios), run_count


def load_mae_or_maev_compilations(mainfolder, diseases, sample_sizes, disease_ratios, num_tests, mae_or_maev="mae"):
    """Load and concatenate the per-run compilation CSV files of both splits.

    For every (disease, sample size, ratio, run) combination the files
    ``<mae_or_maev>_compilation_test.csv`` and
    ``<mae_or_maev>_compilation_train.csv`` are read from
    ``<mainfolder>/PROCESS_<harmonization_method>/<disease>/<size>_<ratio%>/<run>``.
    Missing files are skipped silently.

    Note: ``harmonization_method`` is read from the module-level variable of
    the same name, not from a parameter.

    Args:
        mainfolder: Root folder of the results.
        diseases: Disease directory names.
        sample_sizes: Sample sizes (first part of the grid directory name).
        disease_ratios: Ratios as fractions; multiplied by 100 to build the
            second part of the grid directory name.
        num_tests: Number of runs; run directories ``0 .. num_tests - 1``.
        mae_or_maev: Prefix of the compilation file names.

    Returns:
        ``(df_test, df_train)``; each is an empty DataFrame when no file of
        that split was found.
    """
    process_dir = os.path.join(mainfolder, f"PROCESS_{harmonization_method}")
    tests, trains = [], []
    for d in diseases:
        for s in sample_sizes:
            for r in disease_ratios:
                # Round instead of truncating: 0.29 * 100 evaluates to
                # 28.999999999999996, which int() alone would turn into 28
                # and the "<size>_29" directory would never be found.
                ratio_pct = int(round(r * 100))
                for i in range(num_tests):
                    base = os.path.join(process_dir, d, f"{s}_{ratio_pct}", str(i))
                    test_path = os.path.join(base, f"{mae_or_maev}_compilation_test.csv")
                    train_path = os.path.join(base, f"{mae_or_maev}_compilation_train.csv")
                    if os.path.isfile(test_path):
                        tests.append(pd.read_csv(test_path))
                    if os.path.isfile(train_path):
                        trains.append(pd.read_csv(train_path))
    df_test = pd.concat(tests, ignore_index=True) if tests else pd.DataFrame()
    df_train = pd.concat(trains, ignore_index=True) if trains else pd.DataFrame()
    return df_test, df_train


def load_compilation(mae_or_maev: str,
                     split: str,
                     *,
                     mainfolder: str,
                     diseases: list[str],
                     sample_sizes: list[int],
                     disease_ratios: list[float],
                     num_tests: int) -> pd.DataFrame:
    """Return the concatenated compilation table of one metric and one split.

    Args:
        mae_or_maev: Compilation kind; one of ``"mae"``, ``"maev"``,
            ``"smape"`` or ``"std_mae"``.
        split: ``"test"`` or ``"train"``.
        mainfolder: Root folder of the results.
        diseases: Disease directory names.
        sample_sizes: Sample sizes to load.
        disease_ratios: Disease ratios to load, as fractions.
        num_tests: Number of runs to load.

    Returns:
        The DataFrame of the requested split, as produced by
        ``load_mae_or_maev_compilations``.

    Raises:
        ValueError: If ``mae_or_maev`` or ``split`` is not an accepted value.
    """
    allowed_kinds = ("mae", "maev", "smape", "std_mae")
    if mae_or_maev not in allowed_kinds:
        # The message lists every accepted value, not only the first two.
        raise ValueError(
            "mae_or_maev doit être 'mae', 'maev', 'smape' ou 'std_mae'"
        )
    if split not in {"test", "train"}:
        raise ValueError("split doit être 'test' ou 'train'")

    df_test, df_train = load_mae_or_maev_compilations(
        mainfolder,
        diseases,
        sample_sizes,
        disease_ratios,
        num_tests,
        mae_or_maev=mae_or_maev
    )
    return df_test if split == "test" else df_train


In [12]:


def transformer_df_large_en_long(df_large: pd.DataFrame) -> pd.DataFrame:
    """Melt a wide compilation table into long format.

    The columns ``site``, ``method``, ``robust_method``, ``robust_threshold``,
    ``disease`` and ``metric`` are kept as identifiers; every other column is
    treated as a bundle and unpivoted into a ``bundle`` / ``mae`` pair.

    Args:
        df_large: Wide table with one column per bundle.

    Returns:
        Long table with the identifier columns plus ``bundle`` and ``mae``.
    """
    identifiers = ["site", "method", "robust_method", "robust_threshold", "disease", "metric"]
    known = frozenset(identifiers)
    bundles = [name for name in df_large.columns if name not in known]
    return pd.melt(
        df_large,
        id_vars=identifiers,
        value_vars=bundles,
        var_name="bundle",
        value_name="mae",
    )


def compute_base_method(row: pd.Series) -> str:
    """Map a row's robust-method label to its base method name.

    Rules, in order:
      * label ``"hc"`` (case-insensitive) -> ``"HC"``;
      * label ``"no"`` (case-insensitive) -> ``"HC"`` when the row's
        ``method`` is exactly ``"hc"``, otherwise ``"NoRobust"``;
      * otherwise the first entry of the module-level ``ROBUST_METHODS`` the
        label starts with;
      * otherwise the label itself.

    Args:
        row: Row exposing ``robust_method`` and ``method`` through ``.get``.

    Returns:
        The base method name.
    """
    raw_label = str(row.get("robust_method", ""))
    folded = raw_label.lower()

    if folded == "hc":
        return "HC"
    if folded == "no":
        trained_on_hc = str(row.get("method", "")) == "hc"
        return "HC" if trained_on_hc else "NoRobust"

    # First matching prefix wins, in the order ROBUST_METHODS lists them.
    return next(
        (prefix for prefix in ROBUST_METHODS if raw_label.startswith(prefix)),
        raw_label,
    )


def format_threshold_label(threshold) -> str:
    """Turn a threshold (or baseline column name) into a short display label.

    * missing values (``None``, NaN, ``pd.NA``) -> ``"default"``;
    * strings are returned unchanged (e.g. baseline names);
    * values that cannot be converted to ``float`` -> ``str(threshold)``;
    * whole numbers are shown without a decimal part (``2.0`` -> ``"2"``);
    * other numbers use their plain float representation.

    Args:
        threshold: Threshold value or column label.

    Returns:
        The label to display.
    """
    try:
        if pd.isna(threshold):
            return "default"
    except (TypeError, ValueError):
        # Array-like input: pd.isna returns an array whose truth value is
        # ambiguous. Such input is handled by the generic fallbacks below.
        pass

    if isinstance(threshold, str):
        return threshold
    try:
        value = float(threshold)
    except (TypeError, ValueError, OverflowError):
        return str(threshold)
    if pd.isna(value):
        return "default"
    if value.is_integer():
        return str(int(value))

    formatted = f"{value}"
    # Only strip trailing zeros from plain decimal notation. In scientific
    # notation the trailing zero belongs to the exponent ("1e-10" must not
    # become "1e-1").
    if "e" in formatted or "." not in formatted:
        return formatted
    return formatted.rstrip("0").rstrip(".")


def build_threshold_pivot(df_long: pd.DataFrame, base_method: str) -> pd.DataFrame:
    """Pivot the rows of one base method into one column per threshold.

    Args:
        df_long: Long table with at least ``base_robust_method``,
            ``robust_threshold``, ``mae`` and the index columns ``site``,
            ``disease``, ``metric``, ``bundle``, ``num_patients``,
            ``disease_ratio`` and ``num_diseased``.
        base_method: Value of ``base_robust_method`` to keep.

    Returns:
        A table indexed by the columns listed above, with one column per
        numeric threshold (non-numeric thresholds are coerced to NaN) holding
        the first ``mae`` of each cell; columns are ordered with NaN first,
        then ascending. An empty DataFrame when no row matches.
    """
    rows = df_long.loc[df_long["base_robust_method"] == base_method]
    if rows.empty:
        return pd.DataFrame()

    rows = rows.assign(
        threshold_value=pd.to_numeric(rows["robust_threshold"], errors="coerce")
    )
    pivot_df = pd.pivot_table(
        rows,
        index=["site", "disease", "metric", "bundle", "num_patients", "disease_ratio", "num_diseased"],
        columns="threshold_value",
        values="mae",
        aggfunc="first",
    )

    def _nan_first(column):
        # NaN labels sort before every numeric threshold.
        if isinstance(column, float) and np.isnan(column):
            return (0, -1)
        return (1, float(column))

    return pivot_df.reindex(columns=sorted(pivot_df.columns, key=_nan_first))


def order_threshold_columns(df_filt: pd.DataFrame) -> list:
    """List the value columns of a flattened pivot in display order.

    The index columns (``site``, ``disease``, ``metric``, ``bundle``,
    ``num_patients``, ``disease_ratio``, ``num_diseased``) are ignored. The
    baselines ``"HC"`` and ``"NoRobust"`` come first (in that order, when
    present), followed by the remaining columns sorted as: NaN, then labels
    convertible to float in ascending order, then all others by their string
    form.

    Args:
        df_filt: Pivot table after ``reset_index``.

    Returns:
        Ordered list of column labels.
    """
    index_names = frozenset(
        ("site", "disease", "metric", "bundle", "num_patients", "disease_ratio", "num_diseased")
    )
    candidates = [label for label in df_filt.columns if label not in index_names]
    baselines = [name for name in ("HC", "NoRobust") if name in candidates]

    def sort_key(label):
        if isinstance(label, float) and np.isnan(label):
            return (0, -1)
        try:
            numeric = float(label)
        except Exception:
            return (2, str(label))
        return (1, numeric)

    thresholds = sorted(
        (label for label in candidates if label not in baselines),
        key=sort_key,
    )
    return [*baselines, *thresholds]


def build_color_map(ordered_cols: list) -> dict:
    """Assign one colour to every column to plot.

    ``"HC"`` is always green and ``"NoRobust"`` always red; the remaining
    columns receive consecutive colours of a seaborn ``viridis`` palette, in
    the order given.

    Args:
        ordered_cols: Column labels in display order.

    Returns:
        Mapping from column label to colour.
    """
    fixed_colors = {"HC": "green", "NoRobust": "red"}
    col_colors = {
        name: color for name, color in fixed_colors.items() if name in ordered_cols
    }
    others = [label for label in ordered_cols if label not in col_colors]
    if others:
        col_colors.update(zip(others, sns.color_palette("viridis", len(others))))
    return col_colors


def rank_methods_per_row(pivot_df: pd.DataFrame) -> pd.DataFrame:
    """Rank the numeric columns of each row, the smallest value getting rank 1.

    Ties share the lowest rank (``method="min"``).

    Args:
        pivot_df: Table with one numeric column per method/threshold.

    Returns:
        A table of ranks with the same index and numeric columns. The dtype
        is ``int`` when every cell is ranked; when some cells are missing the
        nullable ``Int64`` dtype is used and those cells are ``<NA>``.
    """
    method_cols = pivot_df.select_dtypes(include="number").columns
    ranks = pivot_df[method_cols].rank(axis=1, method="min", ascending=True)
    # A missing value has no rank (NaN), and NaN cannot be cast to a plain
    # int: use the nullable integer dtype instead of raising.
    if ranks.isna().any().any():
        return ranks.astype("Int64")
    return ranks.astype(int)


def add_baselines(pivot_df: pd.DataFrame, baselines: dict[str, pd.Series]) -> pd.DataFrame:
    """Append baseline series as extra columns of a pivot table.

    Each non-empty series is left-joined on the index; the new column takes
    the series' own ``name`` (the dictionary key is not used). ``None`` and
    empty series are skipped. The input table is not modified.

    Args:
        pivot_df: Pivot table to extend.
        baselines: Baseline series, indexed like ``pivot_df``.

    Returns:
        A copy of ``pivot_df`` with one additional column per usable baseline.
    """
    usable = [
        series
        for series in baselines.values()
        if series is not None and not series.empty
    ]
    extended = pivot_df.copy()
    for series in usable:
        extended = extended.join(series, how="left")
    return extended


def plot_mae_mean_all_ratios(
    pivot_df, sample_size,
    directory, method_name, dataset_type, Y="MAE"
):
    """Save a bar chart of the mean value of every threshold column.

    All rows of the given sample size are pooled (every disease, metric,
    bundle and ratio). The figure is written to
    ``<directory>/<Y>_PLOTS_MEAN/ALL_DISEASES_ALL_METRICS_ALL_RATIOS/<sample_size>/``.
    Nothing is drawn when there is no row for ``sample_size`` or no value
    column.

    Args:
        pivot_df: Pivot table whose index has a ``num_patients`` level.
        sample_size: Value of ``num_patients`` to keep.
        directory: Root output directory.
        method_name: Method name shown in the title.
        dataset_type: Dataset label used in the title and the file name.
        Y: Name of the plotted quantity (axis label, title, paths).
    """
    df_filt = pivot_df.loc[
        pivot_df.index.get_level_values("num_patients") == sample_size
    ].reset_index()
    # No data for this sample size: skip rather than save an empty chart.
    if df_filt.empty:
        return

    ordered_cols = order_threshold_columns(df_filt)
    if not ordered_cols:
        return

    col_colors = build_color_map(ordered_cols)
    means = [df_filt[col].dropna().mean() for col in ordered_cols]

    x = np.arange(len(ordered_cols))
    bar_w = 0.7

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        ax.bar(
            x, means,
            width=bar_w,
            color=[col_colors[c] for c in ordered_cols],
            edgecolor="black"
        )

        ax.set_xticks(x)
        ax.set_xticklabels(
            [format_threshold_label(c) for c in ordered_cols],
            rotation=45, ha="right"
        )
        ax.set_ylabel(f"{Y}")
        ax.set_title(
            f"{Y} - {method_name}\n"
            f"Nb patients: {sample_size}   |   Dataset: {dataset_type}"
        )
        ax.axhline(y=0, color="black", linestyle="--", linewidth=1)
        fig.tight_layout()

        out_dir = os.path.join(directory, f"{Y}_PLOTS_MEAN",
                               "ALL_DISEASES_ALL_METRICS_ALL_RATIOS",
                               str(sample_size))
        os.makedirs(out_dir, exist_ok=True)
        fname = f"{Y}_mean_all_ratios_{dataset_type}.png"
        # Save through the figure object rather than pyplot's implicit
        # "current figure".
        fig.savefig(os.path.join(out_dir, fname), bbox_inches="tight")
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)


def plot_mae_mean_all_diseases_metrics(
    pivot_df, sample_size,
    directory, method_name, dataset_type, Y="MAE"
):
    """Save a grouped bar chart of the mean value per disease ratio.

    For the given sample size, one group of bars is drawn per
    ``disease_ratio`` and one bar per threshold column, pooling every
    disease, metric and bundle. The figure is written to
    ``<directory>/<Y>_PLOTS_MEAN/ALL_DISEASES_ALL_METRICS/<sample_size>/``.
    Nothing is drawn when there is no row for ``sample_size`` or no value
    column.

    Args:
        pivot_df: Pivot table whose index has ``num_patients`` and
            ``disease_ratio`` levels.
        sample_size: Value of ``num_patients`` to keep.
        directory: Root output directory.
        method_name: Method name shown in the title.
        dataset_type: Dataset label used in the title and the file name.
        Y: Name of the plotted quantity (axis label, title, paths).
    """
    df_filt = pivot_df.loc[
        pivot_df.index.get_level_values("num_patients") == sample_size
    ].reset_index()
    # No data for this sample size: skip rather than save an empty chart.
    if df_filt.empty:
        return

    ordered_cols = order_threshold_columns(df_filt)
    if not ordered_cols:
        return

    col_colors = build_color_map(ordered_cols)

    ratios = sorted(df_filt["disease_ratio"].unique())
    x = np.arange(len(ratios))
    g_width = .8
    n_methods = len(ordered_cols)
    bar_w = g_width / n_methods

    fig, ax = plt.subplots(figsize=(14, 7))
    try:
        for i_m, col in enumerate(ordered_cols):
            means = [
                df_filt[df_filt["disease_ratio"] == r][col].dropna().mean()
                for r in ratios
            ]

            # Centre the group on the tick, then offset each bar within it.
            pos = x - g_width / 2 + (i_m + .5) * bar_w
            ax.bar(
                pos, means,
                width=bar_w * .9,
                color=col_colors[col],
                edgecolor="black",
                label=format_threshold_label(col)
            )

        ax.set_xlabel("Proportion of patients")
        ax.set_ylabel(f"{Y}")
        ax.set_title(
            f"{Y} - {method_name}\n"
            f"Nb patients : {sample_size}   |   Dataset : {dataset_type}"
        )
        ax.set_xticks(x)
        ax.set_xticklabels(ratios)
        ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
        ax.axhline(y=0, color="black", linestyle="--", linewidth=1)
        fig.tight_layout()

        out_dir = os.path.join(directory, f"{Y}_PLOTS_MEAN",
                               "ALL_DISEASES_ALL_METRICS", str(sample_size))
        os.makedirs(out_dir, exist_ok=True)
        fname = f"{Y}_mean_all_diseases_metrics_{dataset_type}.png"
        # Save through the figure object rather than pyplot's implicit
        # "current figure".
        fig.savefig(os.path.join(out_dir, fname), bbox_inches="tight")
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)


def _plot_rank_barplot(
    pivot_df, sample_size, disease, directory,
    method_name, dataset_type, Y="MAE", metric=None, aggregate_metrics=False
):
    """Save a grouped bar chart (mean with std error bars) per disease ratio.

    Rows are restricted to one sample size and one disease, and to one metric
    when ``metric`` is given. One group of bars is drawn per
    ``disease_ratio`` and one bar per threshold column; error bars show the
    population standard deviation (``ddof=0``). The figure is written to
    ``<directory>/<Y>_PLOTS_MEAN/<disease>/<sample_size>/``. Nothing is drawn
    when the selection is empty or has no value column.

    Args:
        pivot_df: Pivot table whose index has ``num_patients``, ``disease``,
            ``metric`` and ``disease_ratio`` levels.
        sample_size: Value of ``num_patients`` to keep.
        disease: Value of ``disease`` to keep.
        directory: Root output directory.
        method_name: Method name shown in the title.
        dataset_type: Dataset label used in the title and the file name.
        Y: Name of the plotted quantity (axis label, title, paths).
        metric: Optional metric to restrict the rows to.
        aggregate_metrics: Selects the "all metrics" title and file name
            instead of the per-metric ones.
    """
    cond = (
        (pivot_df.index.get_level_values("num_patients") == sample_size) &
        (pivot_df.index.get_level_values("disease") == disease)
    )

    if metric is not None:
        cond &= pivot_df.index.get_level_values("metric") == metric

    df_filt = pivot_df.loc[cond].reset_index()
    if df_filt.empty:
        return

    ordered_cols = order_threshold_columns(df_filt)
    if not ordered_cols:
        return

    col_colors = build_color_map(ordered_cols)

    ratios = sorted(df_filt["disease_ratio"].unique())
    x = np.arange(len(ratios))
    g_width = .8
    n_methods = len(ordered_cols)
    bar_w = g_width / n_methods

    fig, ax = plt.subplots(figsize=(14, 7))
    try:
        for i_m, col in enumerate(ordered_cols):
            values_per_ratio = [
                df_filt[df_filt["disease_ratio"] == r][col].dropna().values
                for r in ratios
            ]
            # Ratios without any value get a NaN bar and a zero-length
            # error bar.
            means = np.array([
                vals.mean() if len(vals) else np.nan
                for vals in values_per_ratio
            ], dtype=float)
            stds = np.array([
                vals.std(ddof=0) if len(vals) else 0.0
                for vals in values_per_ratio
            ], dtype=float)

            # Centre the group on the tick, then offset each bar within it.
            pos = x - g_width / 2 + (i_m + .5) * bar_w
            ax.bar(
                pos, means,
                width=bar_w * .9,
                color=col_colors[col],
                edgecolor="black",
                label=format_threshold_label(col)
            )
            ax.errorbar(
                pos, means,
                yerr=stds,
                fmt="none",
                ecolor="black",
                elinewidth=1,
                capsize=3
            )

        ax.set_xlabel("Proportion of patients")
        ylabel = f"{Y}"
        ax.set_ylabel(ylabel)

        if aggregate_metrics:
            ax.set_title(
                f"{Y}, tous bundles et métriques confondus - {method_name}\n"
                f"Maladie : {disease}   |   Nb patients : {sample_size}   |   Dataset : {dataset_type}"
            )
            fname = f"{Y}_all_metrics_mean_{dataset_type}.png"
        else:
            ax.set_title(
                f"{Y}, tous bundles confondus - {method_name}\n"
                f"Maladie : {disease}   |   Metric : {metric}\n"
                f"Nb patients : {sample_size}   |   Dataset : {dataset_type}"
            )
            fname = f"{Y}_{metric}_mean_{dataset_type}.png"

        ax.set_xticks(x)
        ax.set_xticklabels(ratios)

        ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
        ax.axhline(y=0, color="black", linestyle="--", linewidth=1)
        fig.tight_layout()
        out_dir = os.path.join(directory, f"{Y}_PLOTS_MEAN", disease,
                               str(sample_size))
        os.makedirs(out_dir, exist_ok=True)
        # Save through the figure object rather than pyplot's implicit
        # "current figure".
        fig.savefig(os.path.join(out_dir, fname), bbox_inches="tight")
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)


def plot_rank(
    pivot_df, sample_size, disease, metric,
    directory, method_name, dataset_type, Y="MAE"
):
    """Draw the per-ratio bar chart for a single metric.

    Thin wrapper around ``_plot_rank_barplot`` with ``aggregate_metrics``
    disabled; see that function for the meaning of the arguments.
    """
    _plot_rank_barplot(
        pivot_df=pivot_df,
        sample_size=sample_size,
        disease=disease,
        directory=directory,
        method_name=method_name,
        dataset_type=dataset_type,
        Y=Y,
        metric=metric,
        aggregate_metrics=False,
    )


def plot_rank_all_metrics(
    pivot_df, sample_size, disease,
    directory, method_name, dataset_type, Y="MAE"
):
    """Draw the per-ratio bar chart with every metric pooled.

    Thin wrapper around ``_plot_rank_barplot`` with no metric filter and
    ``aggregate_metrics`` enabled; see that function for the meaning of the
    arguments.
    """
    _plot_rank_barplot(
        pivot_df=pivot_df,
        sample_size=sample_size,
        disease=disease,
        directory=directory,
        method_name=method_name,
        dataset_type=dataset_type,
        Y=Y,
        metric=None,
        aggregate_metrics=True,
    )


def plot_mae_all_bundles_pivot(
    pivot_df, sample_size, disease, metric,
    directory, method_name, dataset_type, Y="MAE", bundle=None,
):
    """Save grouped box plots of the values per disease ratio.

    Rows are restricted to one sample size, disease and metric, and
    optionally to one bundle or a collection of bundles. One group of boxes
    is drawn per ``disease_ratio`` and one box per threshold column
    (outliers hidden). The figure is written to
    ``<directory>/<Y>_PLOTS_NEW/<disease>/<sample_size>/``, with an extra
    ``<metric>`` sub-directory when ``bundle`` is given. Nothing is drawn
    when the selection is empty or has no value column.

    Args:
        pivot_df: Pivot table whose index has ``num_patients``, ``disease``,
            ``metric``, ``bundle`` and ``disease_ratio`` levels.
        sample_size: Value of ``num_patients`` to keep.
        disease: Value of ``disease`` to keep.
        metric: Value of ``metric`` to keep.
        directory: Root output directory.
        method_name: Method name shown in the title.
        dataset_type: Dataset label used in the title and the file name.
        Y: Name of the plotted quantity (axis label, title, paths).
        bundle: ``None`` for all bundles, a single bundle name, or a
            list/tuple/set of bundle names.
    """
    cond = (
        (pivot_df.index.get_level_values("num_patients") == sample_size) &
        (pivot_df.index.get_level_values("disease") == disease) &
        (pivot_df.index.get_level_values("metric") == metric)
    )

    many_bundles = isinstance(bundle, (list, tuple, set))
    if bundle is not None:
        if many_bundles:
            cond &= pivot_df.index.get_level_values("bundle").isin(bundle)
        else:
            cond &= pivot_df.index.get_level_values("bundle") == bundle

    df_filt = pivot_df.loc[cond].reset_index()
    ordered_cols = order_threshold_columns(df_filt)
    if df_filt.empty or not ordered_cols:
        return

    # Text form of the bundle selection, used in the title and the file name.
    # Names are converted with str() so non-string labels do not break
    # join()/replace(); sets are sorted so the file name does not depend on
    # set iteration order.
    if bundle is None:
        bundle_str = None
    elif many_bundles:
        names = sorted(bundle, key=str) if isinstance(bundle, set) else bundle
        bundle_str = ", ".join(str(b) for b in names)
    else:
        bundle_str = str(bundle)

    col_colors = build_color_map(ordered_cols)

    ratios = sorted(df_filt["disease_ratio"].unique())
    x = np.arange(len(ratios))
    g_width = .8
    n_methods = len(ordered_cols)
    box_w = g_width / n_methods

    fig, ax = plt.subplots(figsize=(14, 7))
    try:
        for i_m, col in enumerate(ordered_cols):
            data = [
                df_filt[df_filt["disease_ratio"] == r][col].dropna().values
                for r in ratios
            ]
            # Skip columns that have no value at all for this selection.
            if not any(len(d) for d in data):
                continue

            # Centre the group on the tick, then offset each box within it.
            pos = x - g_width / 2 + (i_m + .5) * box_w
            ax.boxplot(
                data,
                positions=pos,
                widths=box_w * .8,
                patch_artist=True,
                showfliers=False,
                boxprops=dict(facecolor=col_colors[col],
                              edgecolor=col_colors[col]),
                medianprops=dict(color="black")
            )

        ax.set_xlabel("Proportion of patients")
        ax.set_ylabel(Y)
        ax.set_title(
            f"{Y} d’harmonisation - {method_name}"
            + (f", bundle : {bundle_str}" if bundle is not None else ", tous bundles confondus")
            + f"\nMaladie : {disease}   |   Metric : {metric}"
            + f"\nNb patients : {sample_size}   |   Dataset : {dataset_type}"
        )
        # Set after the loop so these ticks are the ones finally shown.
        ax.set_xticks(x)
        ax.set_xticklabels(ratios)

        # Box plots carry no legend entries: build one line handle per column.
        handles = [plt.Line2D([0], [0], color=col_colors[c], lw=3, label=format_threshold_label(c))
                   for c in ordered_cols]
        ax.legend(handles=handles, loc="upper left", bbox_to_anchor=(1, 1))
        ax.axhline(y=0, color="black", linestyle="--", linewidth=1)
        fig.tight_layout()

        out_dir = os.path.join(directory, f"{Y}_PLOTS_NEW", disease, str(sample_size))
        if bundle is not None:
            out_dir = os.path.join(out_dir, metric)
        os.makedirs(out_dir, exist_ok=True)
        bundle_suffix = bundle_str.replace(" ", "_") if bundle is not None else "all_bundles"
        # Save through the figure object rather than pyplot's implicit
        # "current figure".
        fig.savefig(
            os.path.join(out_dir, f"{Y}_{metric}_{bundle_suffix}_boxplot_{dataset_type}.png"),
            bbox_inches="tight",
        )
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)


def plot_mae_each_bundle(
    pivot_df, sample_size, disease, metric,
    directory, method_name, dataset_type, Y="MAE",
):
    """Draw one box-plot figure per bundle.

    The bundles are those present in ``pivot_df`` for the given sample size,
    disease and metric; each is passed in turn to
    ``plot_mae_all_bundles_pivot`` with the same remaining arguments.
    """
    row_index = pivot_df.index
    selected = (
        (row_index.get_level_values("num_patients") == sample_size)
        & (row_index.get_level_values("disease") == disease)
        & (row_index.get_level_values("metric") == metric)
    )
    bundle_names = row_index[selected].get_level_values("bundle").unique()

    for bundle_name in bundle_names:
        plot_mae_all_bundles_pivot(
            pivot_df, sample_size, disease, metric,
            directory, method_name, dataset_type,
            Y=Y, bundle=bundle_name,
        )


In [13]:
# Detect which diseases / sample sizes / ratios / runs exist on disk,
# restricted to the requested diseases when a list was configured.
detected_diseases, detected_sample_sizes, detected_disease_ratios, detected_num_tests = infer_results_grid(
    MAINFOLDER, harmonization_method, diseases
)

if diseases is None:
    diseases = detected_diseases
else:
    # Report requested diseases that have no results, then keep the others.
    missing = sorted(set(diseases) - set(detected_diseases))
    if missing:
        print(f"Pas de données pour : {missing}")
    diseases = [d for d in diseases if d in detected_diseases]
    if not diseases:
        raise ValueError("Aucune maladie valide trouvée dans les résultats.")

# Explicitly configured values take precedence over the detected ones.
sample_sizes = detected_sample_sizes if sample_sizes is None else sample_sizes
disease_ratios = detected_disease_ratios if disease_ratios is None else disease_ratios
num_tests = detected_num_tests if num_tests is None else num_tests

print(f"Maladies analysées : {diseases}")
print(f"Tailles d'échantillon : {sample_sizes}")
print(f"Ratios de malades : {disease_ratios}")
print(f"Nombre de répétitions : {num_tests}")

# A single pass over the CSV files yields both splits. Calling
# load_compilation once per split would read every file twice, because each
# call loads both splits and discards one.
std_mae_compilation_test_all, std_mae_compilation_train_all = load_mae_or_maev_compilations(
    MAINFOLDER,
    diseases,
    sample_sizes,
    disease_ratios,
    num_tests,
    mae_or_maev="std_mae",
)


Maladies analysées : ['ALL']
Tailles d'échantillon : [100]
Ratios de malades : [0.03, 0.1, 0.3, 0.5, 0.7, 0.8]
Nombre de répétitions : 2


In [14]:

# Build the long-format table for the train split.
std_mae_long = transformer_df_large_en_long(std_mae_compilation_train_all)
# NOTE(review): presumably adds the num_patients / disease_ratio /
# num_diseased columns used as pivot index below - confirm in robust_utils.
std_mae_long = add_nb_patients_and_diseased(std_mae_long)
# Missing thresholds are replaced by 0 (presumably so that these rows keep a
# numeric column in the threshold pivots - TODO confirm).
std_mae_long["robust_threshold"] = std_mae_long["robust_threshold"].fillna(0)
# Tag every row with its base method and keep only the methods of interest.
std_mae_long["base_robust_method"] = std_mae_long.apply(compute_base_method, axis=1)
std_mae_long = std_mae_long[std_mae_long["base_robust_method"].isin(ROBUST_METHODS)].copy()

# Sites for which at least one row contains a NaN in any column.
sites_with_nan_std = (
    std_mae_long
      .groupby("site")
      .filter(lambda g: g.isna().any().any())
      ["site"]
      .unique()
)

n_nan_sites_std = len(sites_with_nan_std)
n_total_sites_std = std_mae_long["site"].nunique()

print(f"Sites exclus pour NaN : {n_nan_sites_std} / {n_total_sites_std}")
if n_nan_sites_std:
    print("Liste :", list(sites_with_nan_std))

# Drop those sites entirely.
std_mae_long = std_mae_long[~std_mae_long["site"].isin(sites_with_nan_std)].copy()

# One threshold pivot per base method (empty DataFrame when a method has no row).
pivot_by_method = {
    method: build_threshold_pivot(std_mae_long, method)
    for method in ROBUST_METHODS
}

# The first column of the HC / NoRobust pivots serves as the baseline series,
# renamed after the baseline so it becomes a column of that name once joined.
baseline_series = {}
for name in ["HC", "NoRobust"]:
    pivot_ref = pivot_by_method.get(name)
    if pivot_ref is None or pivot_ref.empty:
        continue
    series = pivot_ref.iloc[:, 0]
    series.name = name
    baseline_series[name] = series

# Append the baselines to every method's pivot (including the baselines' own).
pivot_by_method = {
    method: add_baselines(pivot, baseline_series)
    for method, pivot in pivot_by_method.items()
}

# Per-row ranking of the columns, for the non-empty pivots only.
ranked_by_method = {
    method: rank_methods_per_row(pivot)
    for method, pivot in pivot_by_method.items()
    if not pivot.empty
}


Sites exclus pour NaN : 0 / 12


In [15]:
# Display the per-method pivot tables for a quick visual check.
pivot_by_method 


{'HC':                                                                                                      0.0  \
 site                          disease metric bundle    num_patients disease_ratio num_diseased             
 ALL_100_patients_10_percent_0 ALL     ad     mni_AC    100          10            10            0.154314   
                                              mni_AF_L  100          10            10            0.172547   
                                              mni_AF_R  100          10            10            0.146176   
                                              mni_AST_L 100          10            10            0.111345   
                                              mni_AST_R 100          10            10            0.169562   
 ...                                                                                                  ...   
 ALL_100_patients_80_percent_1 ALL     rdt    mni_STT_R 100          80            80            0.194890   
             

In [16]:
# Label of the plotted quantity, used in titles, axis labels and output paths.
Y_LABEL = "STD_MAE"

for method, pivot_df in pivot_by_method.items():
    if pivot_df.empty:
        print(f"Aucune donnée pour {method}")
        continue

    # Each method gets its own sub-folder of figures.
    method_dir = os.path.join(PLOT_FOLDER, method)
    Path(method_dir).mkdir(parents=True, exist_ok=True)

    # Trailing arguments shared by every plotting call of this method.
    shared_args = (method_dir, method, "train", Y_LABEL)

    tasks_mean_all = [
        (pivot_df, sample_size, *shared_args)
        for sample_size in sample_sizes
    ]
    tasks_all_metrics = [
        (pivot_df, sample_size, disease, *shared_args)
        for disease in diseases
        for sample_size in sample_sizes
    ]
    tasks = [
        (pivot_df, sample_size, disease, metric, *shared_args)
        for disease in diseases
        for sample_size in sample_sizes
        for metric in metrics
    ]

    # Plot families, run one after the other; the tasks of each family are
    # dispatched in parallel.
    plot_plan = [
        (plot_mae_mean_all_diseases_metrics, tasks_mean_all),
        (plot_mae_mean_all_ratios, tasks_mean_all),
        (plot_rank_all_metrics, tasks_all_metrics),
        (plot_rank, tasks),
        (plot_mae_all_bundles_pivot, tasks),
        (plot_mae_each_bundle, tasks),
    ]
    for plot_func, plot_tasks in plot_plan:
        Parallel(n_jobs=n_jobs)(
            delayed(plot_func)(*task) for task in plot_tasks
        )
