In [None]:
# ============================================================
# 0004_OUTPUT_YASA_INTEGRATION.ipynb
# ETL des prédictions YASA vers les tables AI
# ============================================================

# %%
import json
import logging
import math
import os
import re
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import psycopg2
from psycopg2.extras import RealDictCursor

# %%
# ------------------------------------------------------------
# 1. Logging
# ------------------------------------------------------------
logger = logging.getLogger("sleeplab_etl_yasa")
if not logger.handlers:
    # Configure the root logger only if this notebook's logger has no handler
    # yet (avoids stacking configuration when the cell is re-run).
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
    )
logger.setLevel(logging.INFO)

# %%
# ------------------------------------------------------------
# 2. Paths & Postgres config
# ------------------------------------------------------------
# Project root; override with SLEEPLAB_BASE_DIR when the checkout lives
# elsewhere (an empty value falls back to the default).
BASE_DIR = Path(
    os.environ.get("SLEEPLAB_BASE_DIR") or "/workspaces/SleepLab_new"
).resolve()
MODELS_OUTPUT_DIR = BASE_DIR / "data" / "raw" / "MODELS_OUTPUT"

# Example file (adapt if needed)
DEFAULT_YASA_JSON = MODELS_OUTPUT_DIR / "predictions_yasa_S0001.json"

DB_DSN = os.getenv(
    "SLEEPLAB_DB_DSN",
    "postgresql://sleeplab:sleeplab@postgres:5432/sleeplab",
)


def redact_dsn(dsn: str) -> str:
    """
    Return ``dsn`` with any password replaced by ``***``, for safe logging.

    Handles both libpq connection-string forms:
      - URI: 'postgresql://user:secret@host/db' -> 'postgresql://user:***@host/db'
      - key/value (or URI query parameter): 'password=secret' -> 'password=***'
    """
    # URI form: scheme://user:password@host...
    redacted = re.sub(r"(://[^:/@]*):[^@/]*@", r"\1:***@", dsn)
    # Key/value form, possibly single-quoted, and '?password=...' in URIs.
    return re.sub(
        r"(password\s*=\s*)('[^']*'|[^\s&]+)",
        r"\1***",
        redacted,
        flags=re.IGNORECASE,
    )


# Never log the raw DSN: it typically embeds the database password.
logger.info("DB_DSN utilisé : %s", redact_dsn(DB_DSN))


# %%
# ------------------------------------------------------------
# 3. Helpers Postgres locaux
# ------------------------------------------------------------
def get_conn():
    """
    Open a new Postgres connection whose search_path is set to 'sleeplab'.

    Every call opens a fresh connection (DSN taken from DB_DSN); the caller
    owns it and is responsible for closing it.

    Raises:
        psycopg2.Error: if connecting or setting the search_path fails. In the
            latter case the half-initialised connection is closed before the
            error propagates, so it does not leak.
    """
    conn = psycopg2.connect(DB_DSN)
    try:
        with conn.cursor() as cur:
            cur.execute("SET search_path TO sleeplab")
        # Commit the implicit transaction opened for the SET so the setting
        # is kept for the rest of the session.
        conn.commit()
    except Exception:
        conn.close()
        raise
    return conn


def db_cursor(conn, dict_cursor: bool = False):
    """
    Open a cursor on ``conn``.

    Args:
        conn: open psycopg2 connection.
        dict_cursor: when True, use RealDictCursor so rows can be read by
            column name (``row["col"]``); otherwise the default tuple cursor.
    """
    if dict_cursor:
        return conn.cursor(cursor_factory=RealDictCursor)
    return conn.cursor(cursor_factory=None)


# %%
# ------------------------------------------------------------
# 4. Helpers métier : mapping session / run
# ------------------------------------------------------------
def require_file(path: Path, label: str) -> None:
    """
    Ensure ``path`` exists and is a regular file, logging the outcome.

    Args:
        path: filesystem path to check.
        label: human-readable name used in log and error messages.

    Raises:
        FileNotFoundError: if the path does not exist, or exists but is not
            a regular file (e.g. a directory).
    """
    if not path.exists():
        logger.error("[FICHIER MANQUANT] %s : %s", label, path)
        raise FileNotFoundError(f"{label} non trouvé : {path}")
    if path.is_file():
        logger.info("[OK] %s trouvé : %s", label, path)
        return
    logger.error("[CHEMIN NON FICHIER] %s : %s", label, path)
    raise FileNotFoundError(f"{label} n'est pas un fichier : {path}")


def extract_subject_suffix_from_filename(json_path: Path) -> Optional[str]:
    """
    Extract the 'Sdddd' subject/session suffix from a predictions file name.

    Example: 'predictions_yasa_S0001.json' -> 'S0001'. Matching is
    case-insensitive, but the returned suffix always starts with an
    upper-case 'S'. Only used as a consistency check; the EDF key inside the
    file is the authoritative source for the session.

    Returns:
        The suffix, or None (after logging a warning) when the file name does
        not end with '_Sdddd.json'.
    """
    file_name = json_path.name
    found = re.search(r"_S(\d{4})\.json$", file_name, re.IGNORECASE)
    if found is None:
        logger.warning(
            "Impossible d'extraire un suffixe 'Sdddd' depuis le nom de fichier %s",
            file_name,
        )
        return None

    suffix = "S" + found.group(1)
    logger.info(
        "Suffixe subject/session détecté à partir du nom de fichier : %s", suffix
    )
    return suffix


def extract_session_code_from_edf_key(edf_key: str) -> str:
    """
    Derive the session code from a YASA EDF key by stripping '_PSG.edf'.

    Example:
      'PANDORE_SAS_DATASET_S0001_PSG.edf' -> 'PANDORE_SAS_DATASET_S0001'

    The match is case-sensitive and the prefix is greedy, so the last
    '_Sdddd' before '_PSG.edf' ends the returned code.

    Raises:
        ValueError: if the key does not have the form '<root>_Sdddd_PSG.edf'.
    """
    pattern = re.compile(r"^(?P<session_root>.+_S\d{4})_PSG\.edf$")
    match = pattern.match(edf_key)
    if match is None:
        raise ValueError(
            f"Impossible d'extraire session_root depuis la clé EDF '{edf_key}'"
        )
    return match["session_root"]


def get_session_id_from_code(conn, session_code: str) -> str:
    """
    Look up the session_id of the study_session with the given session_code.

    Args:
        conn: open psycopg2 connection (the unqualified table name relies on
            the search_path, as set by get_conn).
        session_code: e.g. 'PANDORE_SAS_DATASET_S0001'.

    Returns:
        The session_id of a matching row (fetchone; the query has no ORDER BY).

    Raises:
        RuntimeError: if no study_session has this session_code.
    """
    with db_cursor(conn, dict_cursor=True) as cur:
        cur.execute(
            """
            SELECT session_id
            FROM study_session
            WHERE session_code = %s
            """,
            (session_code,),
        )
        found = cur.fetchone()

    if not found:
        raise RuntimeError(
            f"Aucune session trouvée pour session_code={session_code}"
        )
    return found["session_id"]


def get_or_create_ai_model_run(
    conn,
    session_id: str,
    model_name: str,
    model_version: Optional[str],
    training_tag: Optional[str],
    source_file_name: str,
    params_json: Optional[Dict[str, Any]] = None,
) -> str:
    """
    Return the run_id of the matching ai_model_run, inserting one if needed.

    A run is identified by (session_id, model_name, model_version,
    source_file_name); a NULL model_version and an empty string are treated
    as equal thanks to COALESCE. An existing run is returned unchanged:
    training_tag and params_json are only written when the run is created.

    NOTE(review): the look-up and the insert are separate statements, so two
    concurrent imports could both insert. The original design assumes a
    uniqueness constraint on these columns in the SQL amendment script —
    confirm it exists.

    Args:
        conn: open psycopg2 connection.
        session_id: owning study_session.
        model_name / model_version: model identification.
        training_tag: optional tag stored on creation.
        source_file_name: name of the imported predictions file.
        params_json: optional dict serialised with json.dumps on creation.

    Returns:
        The existing or newly generated run_id (UUID4 string when created).
    """
    lookup_params = (session_id, model_name, model_version or "", source_file_name)
    with db_cursor(conn) as cur:
        cur.execute(
            """
            SELECT run_id
            FROM ai_model_run
            WHERE session_id = %s
              AND model_name = %s
              AND COALESCE(model_version,'') = %s
              AND source_file_name = %s
            """,
            lookup_params,
        )
        existing = cur.fetchone()
        if existing:
            found_id = existing[0]
            logger.info(
                "Run AI déjà présent (session_id=%s, model=%s, version=%s, src=%s) → run_id=%s",
                session_id,
                model_name,
                model_version,
                source_file_name,
                found_id,
            )
            return found_id

        new_id = str(uuid.uuid4())
        serialized_params = None if params_json is None else json.dumps(params_json)
        cur.execute(
            """
            INSERT INTO ai_model_run (
                run_id, session_id,
                model_name, model_version,
                training_tag, source_file_name,
                params_json, created_ts
            )
            VALUES (%s,%s,%s,%s,%s,%s,%s, NOW())
            """,
            (
                new_id,
                session_id,
                model_name,
                model_version,
                training_tag,
                source_file_name,
                serialized_params,
            ),
        )
        logger.info(
            "Nouveau run AI inséré : run_id=%s (model=%s, version=%s, session_id=%s)",
            new_id,
            model_name,
            model_version,
            session_id,
        )
        return new_id


# %%
# ------------------------------------------------------------
# 5. Parsing des prédictions YASA
# ------------------------------------------------------------
# Canonical sleep-stage codes. The order is also the tie-break order of
# compute_predicted_stage: on equal probabilities the earlier stage wins.
STAGE_ORDER = ["W", "N1", "N2", "N3", "R"]


def parse_yasa_prediction_payload(pred_payload: Any) -> Dict[int, Dict[str, float]]:
    """
    Normalise a YASA 'prediction' payload into per-epoch stage probabilities.

    Args:
        pred_payload: either a dict ``{stage_label: {epoch_index: prob}}``
            (e.g. ``{"N1": {"0": 0.01, ...}, "N2": {...}}``) or the JSON
            string encoding such a dict.

    Returns:
        ``{epoch_index: {stage_code: prob}}``, e.g.
        ``{0: {"N1": 0.01, "N2": 0.93, ...}, 1: {...}}``. Stage labels are
        upper-cased (surrounding whitespace stripped) and "N0" is mapped to "W".

    Entries that cannot be used are skipped with a warning, never fatal:
    labels outside STAGE_ORDER, non-dict epoch maps, non-integer epoch keys,
    and probabilities that are not finite numbers (NaN/inf).

    Raises:
        ValueError: the payload is neither a dict nor a str, the string is
            not valid JSON, or it does not decode to a dict.
    """
    if isinstance(pred_payload, str):
        try:
            pred_dict = json.loads(pred_payload)
        except json.JSONDecodeError as e:
            raise ValueError(
                f"Impossible de parser prediction YASA (string) : {e}"
            ) from e
        # Valid JSON may still be a list/number/etc.; reject it explicitly
        # instead of failing later on `.items()`.
        if not isinstance(pred_dict, dict):
            raise ValueError(
                f"Type inattendu pour 'prediction' YASA (JSON décodé) : {type(pred_dict)}"
            )
    elif isinstance(pred_payload, dict):
        pred_dict = pred_payload
    else:
        raise ValueError(f"Type inattendu pour 'prediction' YASA : {type(pred_payload)}")

    per_epoch: Dict[int, Dict[str, float]] = {}
    for stage_label, epoch_map in pred_dict.items():
        # str(): dicts built in Python (not from JSON) may use non-str keys.
        stage = str(stage_label).strip().upper()
        if stage == "N0":  # defensive alias for wake
            stage = "W"

        if stage not in STAGE_ORDER:
            # Logged but not blocking.
            logger.warning("Stade inattendu dans YASA : '%s'", stage_label)
            continue

        if not isinstance(epoch_map, dict):
            logger.warning(
                "epoch_map inattendu pour stage=%s (type=%s)", stage, type(epoch_map)
            )
            continue

        for epoch_idx_str, prob in epoch_map.items():
            try:
                epoch_idx = int(epoch_idx_str)
            except (TypeError, ValueError, OverflowError):
                logger.warning(
                    "Index d'epoch non entier dans YASA : key=%s", epoch_idx_str
                )
                continue

            try:
                p = float(prob)
            except (TypeError, ValueError, OverflowError):
                logger.warning(
                    "Probabilité non numérique pour epoch=%s, stage=%s : %s",
                    epoch_idx,
                    stage,
                    prob,
                )
                continue

            # json.loads accepts NaN/Infinity literals; such values are not
            # usable probabilities and would also be serialised back as
            # non-standard JSON tokens in extra_json.
            if not math.isfinite(p):
                logger.warning(
                    "Probabilité non finie pour epoch=%s, stage=%s : %s",
                    epoch_idx,
                    stage,
                    prob,
                )
                continue

            per_epoch.setdefault(epoch_idx, {})[stage] = p

    return per_epoch


def compute_predicted_stage(probs_by_stage: Dict[str, float]) -> Optional[str]:
    """
    Return the stage with the highest probability (argmax).

    Only stages listed in STAGE_ORDER are considered, scanned in that order
    with a strict ``>`` comparison, so on a tie the earlier stage wins.

    Returns:
        The winning stage code, or None when ``probs_by_stage`` is empty/None
        or contains no STAGE_ORDER stage.
    """
    if not probs_by_stage:
        return None
    best_stage: Optional[str] = None
    best_prob = -1.0
    for stage in STAGE_ORDER:
        p = probs_by_stage.get(stage)
        if p is not None and p > best_prob:
            best_stage, best_prob = stage, p
    return best_stage


# %%
# ------------------------------------------------------------
# 6. ETL principal YASA
# ------------------------------------------------------------
def _delete_run_predictions(conn, run_id: str) -> int:
    """Delete all ai_epoch_prediction rows of ``run_id``; return the deleted count."""
    with db_cursor(conn) as cur:
        cur.execute(
            "DELETE FROM ai_epoch_prediction WHERE run_id = %s",
            (run_id,),
        )
        return cur.rowcount


def _build_epoch_rows(
    run_id: str,
    per_epoch: Dict[int, Dict[str, float]],
    epoch_length_sec: int,
) -> List[Tuple[Any, ...]]:
    """Build ai_epoch_prediction row tuples (INSERT column order), sorted by epoch."""
    duration_sec = float(epoch_length_sec)
    rows: List[Tuple[Any, ...]] = []
    for epoch_idx in sorted(per_epoch):
        probs = per_epoch[epoch_idx]
        extra_json = {
            "all_probs": probs,
            "epoch_index_source": "YASA",
        }
        rows.append(
            (
                str(uuid.uuid4()),  # epoch_pred_id
                run_id,
                int(epoch_idx),
                float(epoch_idx * epoch_length_sec),  # onset_sec
                duration_sec,
                compute_predicted_stage(probs),
                # Probabilities mapped onto prob_w .. prob_r (None if absent).
                probs.get("W"),
                probs.get("N1"),
                probs.get("N2"),
                probs.get("N3"),
                probs.get("R"),
                json.dumps(extra_json),
            )
        )
    return rows


def _insert_epoch_rows(conn, rows: List[Tuple[Any, ...]]) -> int:
    """Bulk-insert prepared rows into ai_epoch_prediction; return the row count."""
    if not rows:
        return 0
    # Local import: only this helper needs it. execute_values sends the rows
    # as multi-row INSERT statements (page_size rows each) instead of one
    # round-trip per epoch.
    from psycopg2.extras import execute_values

    with db_cursor(conn) as cur:
        execute_values(
            cur,
            """
            INSERT INTO ai_epoch_prediction (
                epoch_pred_id, run_id,
                epoch_index, onset_sec, duration_sec,
                predicted_stage_code,
                prob_w, prob_n1, prob_n2, prob_n3, prob_r,
                extra_json
            )
            VALUES %s
            """,
            rows,
            page_size=1000,
        )
    return len(rows)


def etl_yasa_from_json_file(
    json_path: Path,
    model_name: str = "YASA",
    model_version: Optional[str] = None,
    training_tag: Optional[str] = None,
    epoch_length_sec: int = 30,
) -> Tuple[List[str], int]:
    """
    Load one YASA predictions JSON file into ai_model_run + ai_epoch_prediction.

    The file must be a JSON object mapping EDF keys (e.g.
    'PANDORE_SAS_DATASET_S0001_PSG.edf') to payload dicts holding a
    'prediction' entry (format: see parse_yasa_prediction_payload). For each
    EDF key, the matching study_session is looked up, a run is fetched or
    created, the run's previous epoch predictions are deleted and the new
    ones inserted. All keys are processed in a single transaction, so any
    error rolls back the whole file.

    Args:
        json_path: path to the YASA JSON file.
        model_name: model name stored on the run.
        model_version: optional model version stored on the run.
        training_tag: optional training tag, stored only on a newly created run.
        epoch_length_sec: epoch length in seconds; onset = index * length.

    Returns:
        (run_ids, total_epochs): the run_id used for each processed EDF key
        (created or reused) and the total number of epoch rows inserted.

    Raises:
        ValueError: non-positive epoch_length_sec, malformed JSON, EDF key or
            prediction payload.
        FileNotFoundError: json_path is missing or not a file.
        RuntimeError: no study_session matches a derived session_code.
    """
    if epoch_length_sec <= 0:
        raise ValueError(
            f"epoch_length_sec doit être > 0 (reçu : {epoch_length_sec})"
        )

    require_file(json_path, f"JSON YASA {model_name}")
    suffix = extract_subject_suffix_from_filename(json_path)

    logger.info("=== ETL YASA démarré pour %s ===", json_path.name)
    # Explicit encoding: the default one is platform dependent.
    data = json.loads(json_path.read_text(encoding="utf-8"))

    if not isinstance(data, dict):
        raise ValueError(f"Structure JSON inattendue (top-level) : {type(data)}")

    conn = get_conn()
    total_epochs = 0
    run_ids: List[str] = []

    try:
        # psycopg2: `with conn` commits on success and rolls back on error,
        # but does not close the connection (done in `finally`).
        with conn:
            for edf_key, payload in data.items():
                logger.info("Traitement de la clé EDF : %s", edf_key)

                session_code = extract_session_code_from_edf_key(edf_key)
                logger.info("session_code dérivé de la clé EDF : %s", session_code)

                # The session code ends with '_Sdddd'; compare on that token
                # rather than with a substring test that could match elsewhere.
                if suffix and not session_code.upper().endswith(f"_{suffix.upper()}"):
                    logger.warning(
                        "Le suffixe '%s' (depuis le nom du fichier JSON) "
                        "ne correspond pas au session_code '%s' (clé EDF).",
                        suffix,
                        session_code,
                    )

                # Validate the payload before touching the DB, so ignored
                # entries do not trigger a session lookup.
                if not isinstance(payload, dict):
                    logger.warning(
                        "Payload inattendu pour %s (type=%s) → ignoré",
                        edf_key,
                        type(payload),
                    )
                    continue

                pred_payload = payload.get("prediction")
                if pred_payload is None:
                    logger.warning(
                        "Aucun champ 'prediction' pour %s → ignoré", edf_key
                    )
                    continue

                session_id = get_session_id_from_code(conn, session_code)

                per_epoch = parse_yasa_prediction_payload(pred_payload)
                logger.info(
                    "YASA : %d epochs détectés pour session_code=%s",
                    len(per_epoch),
                    session_code,
                )

                params_json = {
                    "source_edf_key": edf_key,
                    "notes": "YASA predictions imported from JSON",
                }
                run_id = get_or_create_ai_model_run(
                    conn,
                    session_id=session_id,
                    model_name=model_name,
                    model_version=model_version,
                    training_tag=training_tag,
                    source_file_name=json_path.name,
                    params_json=params_json,
                )
                run_ids.append(run_id)

                # Re-import is idempotent: drop this run's previous predictions.
                nb_deleted = _delete_run_predictions(conn, run_id)
                if nb_deleted:
                    logger.info(
                        "%d prédiction(s) existante(s) supprimée(s) pour run_id=%s",
                        nb_deleted,
                        run_id,
                    )

                inserted = _insert_epoch_rows(
                    conn, _build_epoch_rows(run_id, per_epoch, epoch_length_sec)
                )
                logger.info(
                    "YASA : %d prédiction(s) epoch insérées pour run_id=%s",
                    inserted,
                    run_id,
                )
                total_epochs += inserted

        logger.info(
            "=== ETL YASA terminé pour %s : runs=%d, epochs=%d ===",
            json_path.name,
            len(run_ids),
            total_epochs,
        )
        return run_ids, total_epochs

    finally:
        conn.close()


# %%
# ------------------------------------------------------------
# 7. Fonction de contrôle : run_checks_after_etl_yasa
# ------------------------------------------------------------
def log_table_count(conn, table: str) -> int:
    """
    Log and return the number of rows in ``table``.

    A table name cannot be passed as a query parameter, so it is interpolated
    into the SQL text. It is therefore first validated as a plain unquoted
    identifier, optionally schema-qualified, which rules out SQL injection.

    Args:
        conn: open psycopg2 connection.
        table: e.g. 'ai_model_run' or 'sleeplab.ai_model_run'.

    Returns:
        The row count.

    Raises:
        ValueError: ``table`` is not a valid unquoted identifier.
    """
    if not re.fullmatch(
        r"[A-Za-z_][A-Za-z0-9_$]*(?:\.[A-Za-z_][A-Za-z0-9_$]*)?", table
    ):
        raise ValueError(f"Nom de table invalide : {table!r}")
    with db_cursor(conn) as cur:
        cur.execute(f"SELECT COUNT(*) FROM {table}")
        nb = cur.fetchone()[0]
    logger.info("Table %s : %d ligne(s)", table, nb)
    return nb


def run_checks_after_etl_yasa(
    json_path: Path,
    model_version: Optional[str] = None,
    training_tag: Optional[str] = None,
    model_name: str = "YASA",
    epoch_length_sec: int = 30,
) -> Tuple[List[str], int]:
    """
    Run the YASA ETL on ``json_path``, then log control stats and samples.

    After the import, logs the row counts of ai_model_run and
    ai_epoch_prediction, then the last run's ai_model_run row and its first
    10 epoch predictions.

    Args:
        json_path: YASA predictions JSON file.
        model_version, training_tag, model_name, epoch_length_sec: forwarded
            to etl_yasa_from_json_file (same defaults).

    Returns:
        The (run_ids, total_epochs) tuple returned by etl_yasa_from_json_file.
    """
    run_ids, total_epochs = etl_yasa_from_json_file(
        json_path=json_path,
        model_name=model_name,
        model_version=model_version,
        training_tag=training_tag,
        epoch_length_sec=epoch_length_sec,
    )

    conn = get_conn()
    try:
        logger.info("YASA : %d run(s) créés/mis à jour, %d epochs insérés", len(run_ids), total_epochs)

        # Global AI table counters
        for tbl in ["ai_model_run", "ai_epoch_prediction"]:
            log_table_count(conn, tbl)

        # Samples for the last processed run
        if run_ids:
            last_run = run_ids[-1]
            with db_cursor(conn, dict_cursor=True) as cur:
                cur.execute(
                    """
                    SELECT *
                    FROM ai_model_run
                    WHERE run_id = %s
                    """,
                    (last_run,),
                )
                logger.info("Sample ai_model_run(last_run) : %s", cur.fetchall())

            with db_cursor(conn, dict_cursor=True) as cur:
                cur.execute(
                    """
                    SELECT epoch_index, onset_sec, duration_sec,
                           predicted_stage_code, prob_w, prob_n1, prob_n2, prob_n3, prob_r
                    FROM ai_epoch_prediction
                    WHERE run_id = %s
                    ORDER BY epoch_index
                    LIMIT 10
                    """,
                    (last_run,),
                )
                logger.info(
                    "Sample ai_epoch_prediction(last_run, 10 premiers) : %s",
                    cur.fetchall(),
                )
    finally:
        conn.close()

    return run_ids, total_epochs

: 

In [None]:
# %%
# Example notebook call: import the YASA predictions of subject S0001
# (DEFAULT_YASA_JSON) as model version "v1", then print control stats.
yasa_json = DEFAULT_YASA_JSON
run_checks_after_etl_yasa(json_path=yasa_json, model_version="v1", training_tag=None)