# Diagnostic MLP – 430 Features

Version 0.2  
Auteur : Yoan  
Date : 2025‑06‑30

Objectif : prédire automatiquement si un patient est malade (1) ou sain (0) à l’aide de 430 caractéristiques numériques extraites de données d’IRM.

Le notebook suit un pipeline complet : ingestion, nettoyage, split, augmentation (sites synthétiques), feature engineering, hyperparameter tuning (Optuna), entraînement final, évaluation et explicabilité (SHAP, facultatif).


In [None]:
# --- Configuration générale ---
from pathlib import Path
from robust_evaluation_tools.robust_MLP import PatientMLP, MODEL_DIR, build_mlp_from_config, DEFAULT_MODEL_CONFIG
DATA_DIR = Path("DONNES_F") / "COMPILATIONS_AUG_3"   # <-- adapt if needed
# "ALL" merges every disease (see the loading cell); any other value is used
# as the prefix of the per-disease CSV file name.
disease = "ALL"
RUN_NAME = "mlp7_" + disease
SEED = 41
MODEL_DIR.mkdir(parents=True, exist_ok=True)

from robust_evaluation_tools.MLP_train import (
    PatientDataset, make_loaders, train_epoch, eval_epoch, fit
)


## Guide d’utilisation & Configuration unique

- Tous les paramètres du modèle et de l’entraînement sont centralisés ci‑dessous.
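- Pour changer le nombre et la taille des couches cachées, modifiez `MODEL_CFG["hidden_dims"]` (ex. : `[512, 256]`). Avec `USE_OPTUNA = True`, cette valeur est remplacée par les trois largeurs trouvées par Optuna.
- `USE_OPTUNA` détermine si les hyperparamètres trouvés par Optuna (largeurs des couches, learning rate) sont utilisés pour l’entraînement final.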
- Les artefacts sont enregistrés dans `Pytorch_models/` avec `RUN_NAME`.


In [None]:
# --- Config centralisée (modèle + entraînement) ---
from robust_evaluation_tools.robust_MLP import build_mlp_from_config, DEFAULT_MODEL_CONFIG

# torch is otherwise only imported in a later cell; without this import the
# device detection below raises NameError when the notebook runs top-down.
import torch

USE_OPTUNA = False  # False: use MODEL_CFG as-is; True: take hidden_dims / lr from the Optuna study

MODEL_CFG = {
    **DEFAULT_MODEL_CONFIG,
    # Model
    "in_features": 430,              # change if the number of features changes
    "hidden_dims": [256, 128, 64],   # number of hidden layers = len(hidden_dims)
    "activation": "relu",            # "relu" | "gelu" | "leaky_relu" | "elu" | "tanh"
    "batch_norm": True,
    "dropout": 0.5,                  # float or per-layer list, e.g. [0.1, 0.2, 0.2]
}

TRAIN_CFG = {
    "batch_size": 64,
    "lr": 1e-3,
    "weight_decay": 1e-4,
    "epochs": 100,
    "patience": 10,
    "neg_weight": 10.0,              # weight of class 0 in the loss
}

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)


In [None]:
# INSTALLATION (décommente si nécessaire)
# %pip install -q pandas numpy scikit-learn torch optuna shap tensorboard joblib tqdm


In [None]:
import numpy as np
import pandas as pd
import torch, torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (classification_report, roc_auc_score, f1_score,
                             confusion_matrix, ConfusionMatrixDisplay, RocCurveDisplay,
                             PrecisionRecallDisplay)
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import joblib, random, os, json, optuna
from tqdm.auto import tqdm

from robust_evaluation_tools.synthectic_sites_generations import augment_df, split_train_test, generate_sites_no_file
from robust_evaluation_tools.robust_utils import remove_covariates_effects_metrics

# ----- Helpers -----
def set_seed(seed: int = 42):
    """Seed every RNG the notebook touches: Python, NumPy, PyTorch CPU and all CUDA devices."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
set_seed(seed=SEED)

# NOTE(review): this unconditionally overrides the CUDA auto-detection done in
# the configuration cell, so everything below runs on CPU. Possibly deliberate
# (eval_epoch is called later without a device argument); confirm.
device = "cpu"
print(f"Device: {device}")

def show_class_balance(y):
    """Print one "Classe <k> : <count>" line per distinct label in ``y`` (sorted by label)."""
    classes, totals = np.unique(y, return_counts=True)
    lines = (f"Classe {int(cls)} : {n}" for cls, n in zip(classes, totals))
    for line in lines:
        print(line)

def plot_curves(train, val, ylabel="Loss"):
    """Plot per-epoch training and validation curves on a single figure.

    ``train`` and ``val`` are sequences of per-epoch values of equal length;
    the x axis is the 1-based epoch index derived from ``len(train)``.
    """
    epoch_idx = range(1, len(train) + 1)
    plt.figure(figsize=(6, 4))
    for series, name in ((train, "train"), (val, "val")):
        plt.plot(epoch_idx, series, label=name)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.title(f"Courbe {ylabel}")
    plt.legend()
    plt.grid(True)
    plt.show()


In [None]:


# Load the raw long-format table.
# disease == "ALL": use the pre-merged table for every disease. (An earlier
# version rebuilt it by concatenating the per-disease CamCAN CSVs for
# "AD", "ADHD", "BIP", "MCI", "SCHZ", "TBI" while skipping already-seen SIDs.)
if disease == "ALL":
    df_raw = pd.read_csv("DONNES_MLP/train_data_all.csv")
    # Drop healthy controls that do not come from CamCAN, so CamCAN is the only
    # source of HC subjects. The result must be assigned back: evaluating the
    # mask expression on its own has no effect.
    df_raw = df_raw[~((df_raw["disease"] == "HC") & (df_raw["old_site"] != "CamCAN"))]
else:
    df_raw = pd.read_csv(DATA_DIR / f"{disease}_combination_all_metrics_CamCAN.csv.gz")
print("Raw shape:", df_raw.shape)
display(df_raw.head())


In [None]:
# Minimal cleaning: drop the two ventricle "bundles" from the long table.
_VENTRICLES = ["left_ventricle", "right_ventricle"]
df_raw = df_raw.loc[~df_raw["bundle"].isin(_VENTRICLES)].copy()
print("Sans ventricules :", df_raw.shape)


In [None]:
def compute_zscore(df, value_col="mean_no_cov", group_col="metric_bundle"):
    """Return a copy of ``df`` with a ``zscore`` column standardised per group.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format table containing ``group_col`` and ``value_col``.
    value_col : str
        Column to standardise (default ``"mean_no_cov"``).
    group_col : str
        Column defining the groups whose mean / std are used
        (default ``"metric_bundle"``).

    Returns
    -------
    pandas.DataFrame
        A new frame with ``zscore = (value - group_mean) / group_std``. An
        existing ``zscore`` column is overwritten. The temporary
        ``global_mean`` / ``global_std`` columns are dropped.

    Groups whose std is 0 (constant values) or undefined (a single row, where
    pandas' ddof=1 std is NaN) use a std of 1e-6. Their z-score is then 0
    instead of inf / NaN.
    """
    stats = (df.groupby(group_col)[value_col]
               .agg(["mean", "std"])
               .rename(columns={"mean": "global_mean", "std": "global_std"}))
    # fillna covers single-row groups, which would otherwise yield NaN features.
    stats["global_std"] = stats["global_std"].replace(0, 1e-6).fillna(1e-6)
    df = df.merge(stats, on=group_col, how="left")
    df["zscore"] = (df[value_col] - df["global_mean"]) / df["global_std"]
    return df.drop(columns=["global_mean", "global_std"])


def gen_sites_for_mlp(df, sample_sizes=None, disease_ratios=None, num_tests=20, n_jobs=-1):
    """Simulate sites from ``df``, post-process each one and stack the results.

    Each DataFrame produced by ``generate_sites_no_file`` (presumably one per
    simulated site, i.e. per sample-size x disease-ratio x repetition draw;
    confirm against that helper) is processed in three steps:

    1. Its ``sid`` values are tagged with the site index.
    2. It is passed through ``remove_covariates_effects_metrics``.
    3. It is z-scored with ``compute_zscore``.

    All processed sites are then concatenated.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format source table handed to ``generate_sites_no_file``.
    sample_sizes : list of int, optional
        Sample sizes of the simulated sites. Defaults to ``[100]``. Earlier
        runs used ``[5, 10, 20, 30, 100, 150]`` or ``[30, 100, 150]``.
    disease_ratios : list of float, optional
        Fractions of patients per site. Defaults to
        ``[0.03, 0.1, 0.3, 0.5, 0.7, 0.8]``.
    num_tests : int
        Number of repetitions for each (sample size, ratio) combination.
    n_jobs : int
        Parallelism forwarded to ``generate_sites_no_file``.

    Returns
    -------
    pandas.DataFrame
        Concatenation of all processed sites (fresh index). Empty if no site
        was generated.
    """
    if sample_sizes is None:
        sample_sizes = [100]
    if disease_ratios is None:
        disease_ratios = [0.03, 0.1, 0.3, 0.5, 0.7, 0.8]

    sites = generate_sites_no_file(list(sample_sizes), list(disease_ratios), num_tests, df,
                                   disease=None, n_jobs=n_jobs)
    processed = []
    for site_idx, site_df in enumerate(sites):
        # Tag SIDs behind a separator: plain concatenation (sid + str(i)) can
        # collide, e.g. "1" + "23" == "12" + "3" == "123". That would merge two
        # subjects and make the later (sid, metric_bundle) pivot fail.
        # assign() also avoids mutating the frames returned by the generator.
        site_df = site_df.assign(sid=site_df["sid"].astype(str) + f"_{site_idx}")
        site_df = remove_covariates_effects_metrics(site_df)
        processed.append(compute_zscore(site_df))
    if not processed:
        return pd.DataFrame()
    # Single concat instead of growing a DataFrame inside the loop (quadratic).
    return pd.concat(processed, ignore_index=True)
        
        
        

In [None]:
# Split, then augment each split and turn it into synthetic sites.
# First hold out 20% of df_raw, then halve that hold-out into validation and
# test (presumably ~80/10/10 if split_train_test follows sklearn-style
# test_size semantics; confirm).
# NOTE(review): random_state=None means reproducibility depends on
# split_train_test drawing from the global RNG seeded by set_seed(SEED). If it
# does not, pass SEED explicitly.

df_train, df_temp = split_train_test(df_raw, test_size=0.2, random_state=None)

df_val, df_test = split_train_test(df_temp, test_size=0.5, random_state=None)

# Each split is augmented and converted to synthetic sites independently,
# after the split. Keep this call order: both steps may consume the RNG.
# NOTE(review): train uses augment_df(..., 5) while val/test use 8; confirm
# the asymmetry is intended.
df_train = augment_df(df_train, 5)
df_train = gen_sites_for_mlp(df_train)

df_val = augment_df(df_val, 8)
df_val = gen_sites_for_mlp(df_val)
df_test = augment_df(df_test, 8)
df_test = gen_sites_for_mlp(df_test)

In [None]:
# ----- 3. Feature engineering -----

def build_feature_matrix(df, value_col="zscore", bundle_col="metric_bundle", healthy_tag="HC"):
    """Pivot a long table to one row per ``sid`` and attach a binary label.

    Columns of the result: ``sid``, one column per distinct ``bundle_col``
    value (holding ``value_col``), then ``label``. ``label`` is 1 when the
    subject's first ``disease`` entry differs from ``healthy_tag``, else 0.
    ``DataFrame.pivot`` raises if a (sid, bundle) pair occurs more than once.
    """
    wide = df.pivot(index="sid", columns=bundle_col, values=value_col)
    first_diagnosis = df.groupby("sid")["disease"].first()
    is_sick = (first_diagnosis != healthy_tag).astype(int)
    wide = wide.assign(label=is_sick)
    return wide.reset_index()

def make_X_Y(df, value_col="zscore"):
    """Turn a long-format split into dense float32 (X, y) arrays for the MLP.

    Steps:

    1. Recompute z-scores of ``mean_no_cov`` over the whole frame.
    2. Pivot to one row per ``sid`` via ``build_feature_matrix``.
    3. Drop ``sid`` and split off ``label``.

    Prints the class balance of ``y``.

    NOTE(review): compute_zscore here overwrites any ``zscore`` column already
    present, such as the per-site z-scores computed in gen_sites_for_mlp. They
    are replaced by z-scores pooled over the entire split. Confirm this double
    normalisation is intended.
    """
    scored = compute_zscore(df, value_col="mean_no_cov")
    wide = build_feature_matrix(scored, value_col=value_col).drop(columns=["sid"])
    labels = wide.pop("label")
    X = wide.to_numpy().astype(np.float32)
    y = labels.to_numpy().astype(np.float32)
    show_class_balance(y)
    return X, y

# Sanity check: build_feature_matrix pivots on (sid, metric_bundle), and
# DataFrame.pivot raises if any pair appears more than once. Report such pairs
# up front. (A stray bare `df_train` expression here had no effect.)
dupes = (df_train
         .groupby(["sid", "metric_bundle"])
         .size()
         .loc[lambda s: s > 1]
         .sort_values(ascending=False))
print(f"Nombre de paires sid / metric_bundle en double : {dupes.shape[0]}")
if not dupes.empty:
    display(dupes.head(10))


In [None]:
# ----- 4. Feature matrices per split -----
# The split itself happened upstream (split_train_test on df_raw). Here each
# split is converted to an (X, y) pair independently, in train/val/test order.
# No StandardScaler is fitted: the features are the z-scores produced by
# compute_zscore inside make_X_Y.
(X_train, y_train), (X_val, y_val), (X_test, y_test) = (
    make_X_Y(split_df) for split_df in (df_train, df_val, df_test)
)

print("Train:", X_train.shape, "Val:", X_val.shape, "Test:", X_test.shape)


In [None]:
# ----- 5. DataLoader -----
class PatientDataset(Dataset):
    """Map-style dataset over in-memory feature / label arrays.

    NOTE(review): this redefines the ``PatientDataset`` imported from
    ``robust_evaluation_tools.MLP_train`` in the first cell. This cell's
    DataLoaders use this local definition.
    """

    def __init__(self, X, y):
        # Both inputs are copied once into float32 tensors.
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        features, label = self.X[idx], self.y[idx]
        return features, label

# Batch size comes from the central training config (fallback 64).
# Only the training loader shuffles.
BATCH = TRAIN_CFG.get("batch_size", 64)
train_dl, val_dl, test_dl = (
    DataLoader(PatientDataset(X, y), batch_size=BATCH, shuffle=is_train)
    for X, y, is_train in (
        (X_train, y_train, True),
        (X_val, y_val, False),
        (X_test, y_test, False),
    )
)


In [None]:
# ----- 6A. Baseline LogisticRegression -----
from sklearn.linear_model import LogisticRegression
# fit() returns the estimator itself, so construction and training chain.
baseline = LogisticRegression(max_iter=1000, n_jobs=-1).fit(X_train, y_train)
# Column 1 of predict_proba is the probability of class 1.0 (classes_ is sorted).
prob_val = baseline.predict_proba(X_val)[:, 1]
auc_base = roc_auc_score(y_val, prob_val)
print(f"AUC validation LogisticRegression: {auc_base:.3f}")


In [None]:
# ----- 7. Training helpers (importés) -----
from robust_evaluation_tools.MLP_train import (train_epoch, eval_epoch, fit)
# Loss weight for class 0, read from TRAIN_CFG (fallback 10.0).
# NOTE(review): NEG_WEIGHT is not referenced in the visible cells; the final
# fit() call re-reads TRAIN_CFG["neg_weight"] directly.
NEG_WEIGHT = TRAIN_CFG.get("neg_weight", 10.0)


In [None]:
# ----- 8. Hyperparameter tuning (Optuna) -----
def objective(trial):
    """Optuna objective: run a short training and return its best validation AUC.

    Only the three hidden-layer widths and the learning rate are searched.
    Every other model setting comes from ``MODEL_CFG``, and weight decay and
    class weighting come from ``TRAIN_CFG``. The tuned model is therefore built
    exactly like the final one (``build_mlp_from_config`` with a
    ``hidden_dims`` override) and trained with the same loss settings.
    """
    hidden_dims = [
        trial.suggest_int("h1", 128, 512, step=64),
        trial.suggest_int("h2", 64, 256, step=32),
        trial.suggest_int("h3", 32, 128, step=16),
    ]
    lr = trial.suggest_float("lr", 1e-4, 5e-3, log=True)
    # Same weight decay the final training uses when USE_OPTUNA is True.
    wd = float(TRAIN_CFG.get("weight_decay", 1e-4))

    model = build_mlp_from_config({**MODEL_CFG, "hidden_dims": hidden_dims}).to(device)
    _, _, _, best_auc = fit(model, train_dl, val_dl,
                            epochs=15, lr=lr, wd=wd,
                            patience=5, run_name="tune", device=device,
                            neg_weight=float(TRAIN_CFG.get("neg_weight", 10.0)))
    return best_auc


study = None
if USE_OPTUNA:
    # No pruner: objective() never calls trial.report(), so one would never fire.
    # A seeded sampler makes the search reproducible.
    study = optuna.create_study(direction="maximize",
                                sampler=optuna.samplers.TPESampler(seed=SEED))
    study.optimize(objective, n_trials=30, show_progress_bar=True)
    print("Best AUC:", study.best_value)
    print("Best params:", study.best_params)
else:
    print("USE_OPTUNA=False: tuning skipped, MODEL_CFG / TRAIN_CFG are used as-is.")


In [None]:
# ----- 9. Final training with the centralised configuration -----
# With Optuna, its three widths and learning rate override MODEL_CFG / TRAIN_CFG.
# `best`, `lr` and `wd` stay global because the save cell reads them.
if USE_OPTUNA:
    best = study.best_params
    final_cfg = {**MODEL_CFG, "hidden_dims": [best["h1"], best["h2"], best["h3"]]}
    lr = float(best["lr"])
    wd = float(TRAIN_CFG.get("weight_decay", 1e-4))
else:
    final_cfg = MODEL_CFG
    lr = float(TRAIN_CFG["lr"])
    wd = float(TRAIN_CFG["weight_decay"])
model_final = build_mlp_from_config(final_cfg).to(device)

fit_kwargs = dict(
    epochs=int(TRAIN_CFG["epochs"]),
    lr=lr,
    wd=wd,
    patience=int(TRAIN_CFG["patience"]),
    run_name=RUN_NAME,
    device=device,
    neg_weight=float(TRAIN_CFG.get("neg_weight", 10.0)),
)
state, train_losses, val_losses, best_auc = fit(model_final, train_dl, val_dl, **fit_kwargs)


In [None]:
# Learning curves (per-epoch train / validation loss returned by fit()).
plot_curves(train=train_losses, val=val_losses, ylabel="BCE Loss")


In [None]:
# ----- 11. Final evaluation on the test set -----
_, test_auc, test_f1 = eval_epoch(model_final, test_dl, nn.BCEWithLogitsLoss())
print(f"AUC test: {test_auc:.3f} | F1 test: {test_f1:.3f}")

# Confusion matrix at a 0.5 probability threshold.
# NOTE(review): model_final holds whatever weights fit() left in it, while
# `state` (saved as the run's weights) may be a different snapshot, e.g. the
# best epoch. If so, call model_final.load_state_dict(state) before evaluating.
model_final.eval()
prob_chunks, label_chunks = [], []
with torch.no_grad():
    for xb, yb in test_dl:
        logits = model_final(xb.to(device))
        prob_chunks.append(torch.sigmoid(logits).cpu())
        label_chunks.append(yb)
preds = torch.cat(prob_chunks).numpy()
labels = torch.cat(label_chunks).numpy()
ConfusionMatrixDisplay.from_predictions(labels, preds > 0.5)
plt.show()


In [None]:
# ----- 12. Save artifacts -----
torch.save(state, MODEL_DIR / f"{RUN_NAME}_weights.pt")

# Model config plus the optimiser hyper-parameters actually used for this run.
params_to_save = {**MODEL_CFG, "lr": lr, "weight_decay": wd}
if USE_OPTUNA:
    # Optuna's widths replace the configured ones. The key already exists in
    # MODEL_CFG, so its position in the JSON is unchanged.
    params_to_save["hidden_dims"] = [best["h1"], best["h2"], best["h3"]]

params_path = MODEL_DIR / f"{RUN_NAME}_params.json"
with params_path.open("w") as fp:
    json.dump(params_to_save, fp, indent=2)
print("Artifacts saved in", MODEL_DIR)


In [None]:
# ----- 13. Inference example -----
# Demo only: a random vector, not a real z-scored patient profile. Its size
# follows MODEL_CFG["in_features"] instead of a hard-coded 430, so it stays in
# sync with the configured input layer.
n_features = int(MODEL_CFG["in_features"])
sample = np.random.rand(n_features).reshape(1, -1)
# Make sure dropout / BatchNorm are in inference mode, even if this cell is run
# without the evaluation cell that calls eval().
model_final.eval()
with torch.no_grad():
    prob = torch.sigmoid(model_final(torch.tensor(sample, dtype=torch.float32).to(device))).item()
print(f"Probabilité malade: {prob:.3f}")


In [None]:
# ----- 14. Explainability (facultatif) -----
# import shap
# explainer = shap.DeepExplainer(model_final, torch.tensor(X_train[:100]).to(device))
# shap_values = explainer.shap_values(torch.tensor(sample, dtype=torch.float32).to(device))
# shap.summary_plot(shap_values, features=sample)
