# Définir le seuil des méthodes robustes

Ce notebook explore différents seuils pour plusieurs méthodes robustes en réutilisant les jeux synthétiques «ALL».\
Il automatise l'évaluation (MAE) pour chaque seuil, puis génère les mêmes types de graphiques que l'analyse MAE :\
10 courbes par métrique et une moyenne globale par méthode.

⚠️ Les boucles peuvent être coûteuses : ajustez les filtres (tests, ratios) avant d'exécuter l'évaluation complète.


In [None]:
import math
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

from robust_evaluation_tools.robust_utils import (
    get_metrics,
    get_camcan_file,
    format_param_token,
)
from robust_evaluation_tools.robust_harmonization import (
    fit,
    apply,
    compare_with_compilation,
)


In [None]:
# Harmonization method name forwarded to every fit()/apply() call.
HARMONIZATION_METHOD = "gmm"

# Input and output locations, relative to the current working directory.
# PROCESSED_ROOT holds the train/test CSVs; SYNTHETIC_ROOT holds the matching
# gt_test CSVs; OUTPUT_ROOT receives the per-run outputs, summaries and plots.
PROCESSED_ROOT = Path("RESULTS/MAE_TEST/PROCESS_gmm/ALL")
SYNTHETIC_ROOT = Path("RESULTS/MAE_TEST/SYNTHETIC_SITES/v1/ALL")
OUTPUT_ROOT = Path("RESULTS/ROBUST_THRESHOLD")
OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)

# Thresholds to sweep, keyed by robust method name. An empty list disables
# the method (it is skipped when the evaluation runs).
ROBUST_THRESHOLD_GRID = {
    "IQR": [0.75, 1.0, 1.5, 2.0, 2.5],
    "MAD": [2.0, 2.5, 3.0, 3.5, 4.0],
    "Z_SCORE_BUNDLE": [1.5, 2.0, 2.5, 3.0, 3.5],
    "Z_SCORE_METRIC": [1.5, 2.0, 2.5, 3.0],
    "SN": [2.0, 2.5, 3.0, 3.5],
    "QN": [2.0, 2.5, 3.0, 3.5],
}

METRICS = get_metrics()
# One CamCAN file per metric; passed to fit() alongside the training CSV.
CAMCAN_REF = {metric: get_camcan_file(metric) for metric in METRICS}

# Adjust the filters to limit the workload (None => everything available)
SELECTED_SAMPLE_SIZES = {100}   # set to None for all
SELECTED_RATIOS = None          # example: {3, 10, 30}
SELECTED_TEST_IDS = set(range(3))  # set to None to use the 20 tests
# Flags forwarded unchanged to fit()/apply() (HC_ONLY to fit() only).
# NOTE(review): their exact semantics live in robust_harmonization — confirm.
USE_RWP = False
HC_ONLY = False
# When True, previously saved result CSVs are reloaded (if present) instead
# of re-running the evaluation.
LOAD_EXISTING = False


def _to_optional_set(values):
    """Coerce a filter specification to a set, leaving ``None`` untouched.

    ``None`` (meaning "no filter") and existing ``set`` instances are
    returned as-is; any other iterable is copied into a new set.
    """
    if values is None or isinstance(values, set):
        return values
    return {*values}


# Normalise the three user-editable filters in one pass so downstream code
# only ever sees ``None`` or a set.
SELECTED_SAMPLE_SIZES, SELECTED_RATIOS, SELECTED_TEST_IDS = map(
    _to_optional_set,
    (SELECTED_SAMPLE_SIZES, SELECTED_RATIOS, SELECTED_TEST_IDS),
)

# Echo the active configuration.
print(f"Méthodes configurées : {', '.join(ROBUST_THRESHOLD_GRID)}")
print(f"Métriques : {', '.join(METRICS)}")


In [None]:
def iter_cases(processed_root, synthetic_root, sample_filter=None, ratio_filter=None, test_filter=None):
    """List the evaluation cases present under both directory trees.

    The expected layout is ``<root>/<sample>_<ratio>/<test_index>/`` where
    ``sample``, ``ratio`` and ``test_index`` are integers. A case is kept
    only when the same relative directory exists under ``processed_root``
    and ``synthetic_root``. Directories whose names do not match the layout
    are silently ignored.

    Args:
        processed_root: ``Path`` of the tree to enumerate.
        synthetic_root: ``Path`` of the tree that must mirror it.
        sample_filter: Container of sample sizes to keep, or ``None`` for all.
        ratio_filter: Container of ratio percentages to keep, or ``None``
            for all.
        test_filter: Container of test indices to keep, or ``None`` for all.

    Returns:
        A list of dicts with keys ``sample_size``, ``ratio`` (fraction),
        ``ratio_percent``, ``label`` (the ``<sample>_<ratio>`` directory
        name), ``test_index``, ``processed_dir`` and ``synthetic_dir``,
        ordered by directory name and then by numeric test index.

    Note:
        Only ``None`` disables a filter. An empty container selects nothing
        rather than being treated as "no filter".
    """
    cases = []
    for sr_dir in sorted(processed_root.iterdir()):
        if not sr_dir.is_dir():
            continue
        try:
            # A wrong number of "_"-separated parts or a non-integer part
            # both surface as ValueError.
            sample_str, ratio_str = sr_dir.name.split("_")
            sample_size = int(sample_str)
            ratio_percent = int(ratio_str)
        except ValueError:
            continue
        # Compare against None explicitly: a truthiness test would turn an
        # empty filter into "keep everything".
        if sample_filter is not None and sample_size not in sample_filter:
            continue
        if ratio_filter is not None and ratio_percent not in ratio_filter:
            continue
        synthetic_sr_dir = synthetic_root / sr_dir.name
        if not synthetic_sr_dir.exists():
            continue
        # isdecimal() rather than isdigit(): the latter also accepts
        # characters such as superscript digits that int() cannot parse.
        test_dirs = [d for d in sr_dir.iterdir() if d.is_dir() and d.name.isdecimal()]
        for test_dir in sorted(test_dirs, key=lambda p: int(p.name)):
            test_idx = int(test_dir.name)
            if test_filter is not None and test_idx not in test_filter:
                continue
            synthetic_test_dir = synthetic_sr_dir / test_dir.name
            if not synthetic_test_dir.exists():
                continue
            cases.append({
                "sample_size": sample_size,
                "ratio": ratio_percent / 100.0,
                "ratio_percent": ratio_percent,
                "label": sr_dir.name,
                "test_index": test_idx,
                "processed_dir": test_dir,
                "synthetic_dir": synthetic_test_dir,
            })
    return cases


def _run_threshold_case(method, threshold, robust_params, case, metric):
    """Fit, apply and score one (case, metric) pair for a single threshold.

    Returns the first row of the table produced by
    ``compare_with_compilation``, or ``None`` when one of the three input
    CSVs is missing or the table is empty.
    """
    label = case["label"]
    test_idx = case["test_index"]
    metric_dir = case["processed_dir"] / metric
    train_file = metric_dir / f"train_{label}_{test_idx}_{metric}.csv"
    test_file = metric_dir / f"test_{label}_{test_idx}_{metric}.csv"
    gt_test_file = case["synthetic_dir"] / f"gt_test_{label}_{test_idx}_{metric}.csv"
    # Short-circuit keeps the lookup order train -> test -> ground truth.
    if not (train_file.exists() and test_file.exists() and gt_test_file.exists()):
        return None

    # Each (method, threshold, case, metric) combination gets its own
    # output directory so runs never overwrite each other.
    run_dir = OUTPUT_ROOT / method / f"thr_{format_param_token(threshold)}" / label / str(test_idx) / metric
    run_dir.mkdir(parents=True, exist_ok=True)

    model_path = fit(
        str(train_file),
        CAMCAN_REF[metric],
        metric,
        HARMONIZATION_METHOD,
        method,
        USE_RWP,
        str(run_dir),
        HC_ONLY,
        robust_params=robust_params,
    )
    test_output = apply(
        str(test_file),
        model_path,
        metric,
        HARMONIZATION_METHOD,
        method,
        USE_RWP,
        str(run_dir),
        robust_params=robust_params,
    )
    mae_df = compare_with_compilation(
        pd.read_csv(test_output),
        pd.read_csv(gt_test_file),
    )
    return None if mae_df.empty else mae_df.iloc[0]


def evaluate_method_thresholds(method, thresholds, cases, metrics):
    """Sweep ``thresholds`` for one robust method over every case and metric.

    Args:
        method: Robust method name, forwarded to ``fit``/``apply`` and used
            in the output directory path.
        thresholds: Iterable of threshold values to try.
        cases: Case dicts as produced by ``iter_cases``.
        metrics: Metric names to evaluate for each case.

    Returns:
        A ``(metric_df, bundle_df)`` pair of DataFrames. ``metric_df`` has
        one row per evaluated run with the mean of the comparison row in
        ``mae``; ``bundle_df`` has one row per column of that comparison row
        (stored under ``bundle``). Both are empty when nothing was evaluated.
    """
    steps_per_threshold = len(cases) * len(metrics)
    if not steps_per_threshold:
        # Nothing to evaluate for any threshold.
        return pd.DataFrame([]), pd.DataFrame([])

    metric_records, bundle_records = [], []
    for threshold in thresholds:
        # One params dict per threshold, shared by every fit/apply call.
        robust_params = {"threshold": threshold}
        with tqdm(total=steps_per_threshold, desc=f"{method} | thr={threshold}") as progress:
            for case in cases:
                # Fields common to every record emitted for this case.
                shared = {
                    "method": method,
                    "threshold": threshold,
                    "sample_size": case["sample_size"],
                    "disease_ratio": case["ratio"],
                    "ratio_percent": case["ratio_percent"],
                    "test_index": case["test_index"],
                }
                for metric in metrics:
                    row = _run_threshold_case(method, threshold, robust_params, case, metric)
                    if row is not None:
                        metric_records.append(
                            {**shared, "metric": metric, "mae": float(row.mean())}
                        )
                        bundle_records.extend(
                            {**shared, "metric": metric, "bundle": bundle_name, "mae": float(value)}
                            for bundle_name, value in row.items()
                        )
                    # Skipped and evaluated runs both count as one step.
                    progress.update(1)
    return pd.DataFrame(metric_records), pd.DataFrame(bundle_records)


def run_all_methods():
    """Run the threshold sweep for every method in ``ROBUST_THRESHOLD_GRID``.

    Cases are discovered once with the module-level filters, then each
    method with a non-empty threshold list is evaluated.

    Returns:
        A ``(metric_results, bundle_results)`` pair of DataFrames, each the
        row-wise concatenation of the non-empty per-method results, or an
        empty DataFrame when no method produced any row.
    """
    cases = iter_cases(
        PROCESSED_ROOT,
        SYNTHETIC_ROOT,
        sample_filter=SELECTED_SAMPLE_SIZES,
        ratio_filter=SELECTED_RATIOS,
        test_filter=SELECTED_TEST_IDS,
    )
    print(f"{len(cases)} cas sélectionnés.")

    # Index 0 collects metric-level frames, index 1 bundle-level frames.
    collected = ([], [])
    for method, thresholds in ROBUST_THRESHOLD_GRID.items():
        if not thresholds:
            continue
        frames = evaluate_method_thresholds(method, thresholds, cases, METRICS)
        for bucket, frame in zip(collected, frames):
            if not frame.empty:
                bucket.append(frame)

    metric_results, bundle_results = (
        pd.concat(bucket, ignore_index=True) if bucket else pd.DataFrame()
        for bucket in collected
    )
    return metric_results, bundle_results


In [None]:
# Location of the cached raw results and of the aggregated summaries.
summaries_dir = OUTPUT_ROOT / "summaries"
summaries_dir.mkdir(parents=True, exist_ok=True)

metric_results_csv = summaries_dir / "metric_results.csv"
bundle_results_csv = summaries_dir / "bundle_results.csv"

# Reload only when BOTH cached files are present: both are read below, so
# checking the metric file alone would raise FileNotFoundError whenever the
# bundle file is missing. With an incomplete cache the evaluation is re-run.
if LOAD_EXISTING and metric_results_csv.exists() and bundle_results_csv.exists():
    metric_results = pd.read_csv(metric_results_csv)
    bundle_results = pd.read_csv(bundle_results_csv)
    print("Résultats chargés depuis les fichiers existants.")
else:
    metric_results, bundle_results = run_all_methods()
    # Empty results are not written, so an earlier cache is never replaced
    # by an empty file.
    if not metric_results.empty:
        metric_results.to_csv(metric_results_csv, index=False)
    if not bundle_results.empty:
        bundle_results.to_csv(bundle_results_csv, index=False)
    print("Évaluation terminée.")

# Last expression of the cell: shown as the notebook output.
metric_results.head()


In [None]:
# Stop here when the evaluation produced nothing to aggregate.
if metric_results.empty:
    raise ValueError("Aucun résultat MAE n'a été généré. Vérifiez les filtres ou relancez l'évaluation.")

# Named aggregations shared by both summaries: mean, standard deviation and
# number of rows of the ``mae`` column.
mae_aggregations = {
    "mean_mae": ("mae", "mean"),
    "std_mae": ("mae", "std"),
    "n_cases": ("mae", "count"),
}

# Per (method, threshold, metric) statistics.
by_metric = metric_results.groupby(["method", "threshold", "metric"], as_index=False)
metric_summary = by_metric.agg(**mae_aggregations)

# Per (method, threshold) statistics, pooling all metrics together.
by_threshold = metric_results.groupby(["method", "threshold"], as_index=False)
global_summary = by_threshold.agg(**mae_aggregations)

metric_summary.to_csv(summaries_dir / "metric_summary.csv", index=False)
global_summary.to_csv(summaries_dir / "global_summary.csv", index=False)

# Last expression of the cell: shown as the notebook output.
metric_summary.head()


In [None]:
def _save_threshold_sweep(summary, title, ylabel, out_path):
    """Draw mean MAE against threshold with std error bars and save it.

    ``summary`` must provide the ``threshold``, ``mean_mae`` and ``std_mae``
    columns; rows are plotted in increasing threshold order. Missing
    standard deviations are drawn as zero-length error bars.
    """
    ordered = summary.sort_values("threshold")
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.errorbar(
        ordered["threshold"],
        ordered["mean_mae"],
        yerr=ordered["std_mae"].fillna(0.0),
        fmt="-o",
        capsize=4,
        linewidth=2,
    )
    ax.set_xlabel("Seuil")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    # One tick per evaluated threshold.
    ax.set_xticks(ordered["threshold"])
    ax.grid(True, axis="y", alpha=0.3)
    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    # Close explicitly so figures do not accumulate across the loop.
    plt.close(fig)


sns.set_theme(style="whitegrid")
plots_dir = OUTPUT_ROOT / "plots"
plots_dir.mkdir(parents=True, exist_ok=True)

for method in ROBUST_THRESHOLD_GRID:
    method_metric = metric_summary[metric_summary["method"] == method]
    if method_metric.empty:
        continue
    method_dir = plots_dir / method
    method_dir.mkdir(parents=True, exist_ok=True)

    # One figure per metric for this method.
    for metric in sorted(method_metric["metric"].unique()):
        _save_threshold_sweep(
            method_metric[method_metric["metric"] == metric],
            f"{method} – {metric.upper()}",
            "MAE moyen",
            method_dir / f"{metric}_threshold_sweep.png",
        )

    # One figure pooling all metrics for this method.
    data_global = global_summary[global_summary["method"] == method]
    if data_global.empty:
        continue
    _save_threshold_sweep(
        data_global,
        f"{method} – Moyenne globale",
        "MAE moyen (toutes métriques)",
        method_dir / "global_threshold_sweep.png",
    )

# Last expression of the cell: shown as the notebook output.
plots_dir
