In [1]:
# ============================================================
# 0005_OUTPUT_GSSC_INTEGRATION.ipynb
# ETL des prédictions GSSC vers les tables AI
# ============================================================

# %%
import os
import re
import json
import uuid
import logging
from pathlib import Path
from typing import Optional, Dict, Any, List, Tuple

import psycopg2
from psycopg2.extras import RealDictCursor

# %%
# ------------------------------------------------------------
# 1. Logging
# ------------------------------------------------------------
logger = logging.getLogger("sleeplab_etl_gssc")
if not logger.handlers:
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
    )
logger.setLevel(logging.INFO)

# %%
# ------------------------------------------------------------
# 2. Config chemins & Postgres
# ------------------------------------------------------------
BASE_DIR = Path("/workspaces/SleepLab_new").resolve()
MODELS_OUTPUT_DIR = BASE_DIR / "data" / "raw" / "MODELS_OUTPUT"

DEFAULT_GSSC_JSON = MODELS_OUTPUT_DIR / "predictions_gssc_S0001.json"

DB_DSN = os.getenv(
    "SLEEPLAB_DB_DSN",
    "postgresql://sleeplab:sleeplab@postgres:5432/sleeplab",
)

logger.info("DB_DSN utilisé : %s", DB_DSN)


# %%
# ------------------------------------------------------------
# 3. Helpers Postgres locaux
# ------------------------------------------------------------
def get_conn():
    conn = psycopg2.connect(DB_DSN)
    with conn.cursor() as cur:
        cur.execute("SET search_path TO sleeplab")
    conn.commit()
    return conn


def db_cursor(conn, dict_cursor: bool = False):
    cursor_factory = RealDictCursor if dict_cursor else None
    return conn.cursor(cursor_factory=cursor_factory)


# %%
# ------------------------------------------------------------
# 4. Helpers métier : mapping session / run
# ------------------------------------------------------------
def require_file(path: Path, label: str) -> None:
    if not path.exists():
        logger.error("[FICHIER MANQUANT] %s : %s", label, path)
        raise FileNotFoundError(f"{label} non trouvé : {path}")
    if not path.is_file():
        logger.error("[CHEMIN NON FICHIER] %s : %s", label, path)
        raise FileNotFoundError(f"{label} n'est pas un fichier : {path}")
    logger.info("[OK] %s trouvé : %s", label, path)


def extract_subject_suffix_from_filename(json_path: Path) -> Optional[str]:
    """
    Ex: 'predictions_gssc_S0001.json' -> 'S0001'
    """
    m = re.search(r"_S(\d{4})\.json$", json_path.name, re.IGNORECASE)
    if not m:
        logger.warning(
            "Impossible d'extraire un suffixe 'Sdddd' depuis le nom de fichier %s",
            json_path.name,
        )
        return None
    suffix = f"S{m.group(1)}"
    logger.info("Suffixe subject/session détecté à partir du nom de fichier : %s", suffix)
    return suffix


def extract_session_code_from_edf_key(edf_key: str) -> str:
    """
    Ex: 'PANDORE_SAS_DATASET_S0001_PSG.edf' -> 'PANDORE_SAS_DATASET_S0001'
    """
    m = re.match(r"^(?P<session_root>.+_S\d{4})_PSG\.edf$", edf_key)
    if not m:
        raise ValueError(
            f"Impossible d'extraire session_root depuis la clé EDF '{edf_key}'"
        )
    return m.group("session_root")


def get_session_id_from_code(conn, session_code: str) -> str:
    with db_cursor(conn, dict_cursor=True) as cur:
        cur.execute(
            """
            SELECT session_id
            FROM study_session
            WHERE session_code = %s
            """,
            (session_code,),
        )
        row = cur.fetchone()
        if not row:
            raise RuntimeError(
                f"Aucune session trouvée pour session_code={session_code}"
            )
        return row["session_id"]


def get_or_create_ai_model_run(
    conn,
    session_id: str,
    model_name: str,
    model_version: Optional[str],
    training_tag: Optional[str],
    source_file_name: str,
    params_json: Optional[Dict[str, Any]] = None,
) -> str:
    version_key = model_version or ""
    with db_cursor(conn) as cur:
        cur.execute(
            """
            SELECT run_id
            FROM ai_model_run
            WHERE session_id = %s
              AND model_name = %s
              AND COALESCE(model_version,'') = %s
              AND source_file_name = %s
            """,
            (session_id, model_name, version_key, source_file_name),
        )
        row = cur.fetchone()
        if row:
            run_id = row[0]
            logger.info(
                "Run AI GSSC déjà présent (session_id=%s, model=%s, version=%s, src=%s) → run_id=%s",
                session_id,
                model_name,
                model_version,
                source_file_name,
                run_id,
            )
            return run_id

        run_id = str(uuid.uuid4())
        cur.execute(
            """
            INSERT INTO ai_model_run (
                run_id, session_id,
                model_name, model_version,
                training_tag, source_file_name,
                params_json, created_ts
            )
            VALUES (%s,%s,%s,%s,%s,%s,%s, NOW())
            """,
            (
                run_id,
                session_id,
                model_name,
                model_version,
                training_tag,
                source_file_name,
                json.dumps(params_json) if params_json is not None else None,
            ),
        )
        logger.info(
            "Nouveau run AI GSSC inséré : run_id=%s (model=%s, version=%s, session_id=%s)",
            run_id,
            model_name,
            model_version,
            session_id,
        )
        return run_id


# %%
# ------------------------------------------------------------
# 5. Mapping des labels GSSC
# ------------------------------------------------------------
# Hypothèse standard GSSC :
#   0 → W, 1 → N1, 2 → N2, 3 → N3, 4 → R
GSSC_LABEL_TO_STAGE = {
    0: "W",
    1: "N1",
    2: "N2",
    3: "N3",
    4: "R",
}


def gssc_label_int_to_stage(label: int) -> Optional[str]:
    stage = GSSC_LABEL_TO_STAGE.get(label)
    if stage is None:
        logger.warning("Label GSSC inattendu : %s", label)
    return stage


# %%
# ------------------------------------------------------------
# 6. ETL principal GSSC
# ------------------------------------------------------------
def etl_gssc_from_json_file(
    json_path: Path,
    model_name: str = "GSSC",
    model_version: Optional[str] = None,
    training_tag: Optional[str] = None,
    default_epoch_length_sec: int = 30,
) -> Tuple[List[str], int]:
    """
    Intègre un JSON GSSC dans ai_model_run + ai_epoch_prediction.

    Structure attendue (par EDF) :
      {
        "PANDORE_SAS_DATASET_S0001_PSG.edf": {
          "infos": {...},
          "prediction": [0, 0, 1, 2, ...],
          "timestamps": [0.0, 30.0, 60.0, ...]
        },
        ...
      }
    """
    require_file(json_path, f"JSON GSSC {model_name}")
    suffix = extract_subject_suffix_from_filename(json_path)

    logger.info("=== ETL GSSC démarré pour %s ===", json_path.name)
    raw_text = json_path.read_text()
    data = json.loads(raw_text)

    if not isinstance(data, dict):
        raise ValueError(f"Structure JSON inattendue (top-level) : {type(data)}")

    conn = get_conn()
    run_ids: List[str] = []
    total_epochs = 0

    try:
        with conn:
            for edf_key, payload in data.items():
                logger.info("Traitement de la clé EDF : %s", edf_key)

                session_code = extract_session_code_from_edf_key(edf_key)
                logger.info("session_code dérivé de la clé EDF : %s", session_code)

                if suffix and suffix.lower() not in session_code.lower():
                    logger.warning(
                        "Le suffixe '%s' (depuis le nom du fichier JSON) "
                        "ne correspond pas au session_code '%s' (clé EDF).",
                        suffix,
                        session_code,
                    )

                session_id = get_session_id_from_code(conn, session_code)

                if not isinstance(payload, dict):
                    logger.warning(
                        "Payload GSSC inattendu pour %s (type=%s) → ignoré",
                        edf_key,
                        type(payload),
                    )
                    continue

                preds = payload.get("prediction")
                stamps = payload.get("timestamps")

                if preds is None or stamps is None:
                    logger.warning(
                        "Prediction ou timestamps manquants pour %s → ignoré", edf_key
                    )
                    continue

                if len(preds) != len(stamps):
                    logger.warning(
                        "Longueur preds (%d) != timestamps (%d) pour %s. "
                        "On utilisera min(n_preds, n_timestamps).",
                        len(preds),
                        len(stamps),
                        edf_key,
                    )
                n = min(len(preds), len(stamps))

                infos = payload.get("infos", {})
                params_json = {
                    "source_edf_key": edf_key,
                    "model_info": infos,
                    "notes": "GSSC predictions imported from JSON",
                }

                run_id = get_or_create_ai_model_run(
                    conn,
                    session_id=session_id,
                    model_name=model_name,
                    model_version=model_version,
                    training_tag=training_tag,
                    source_file_name=json_path.name,
                    params_json=params_json,
                )
                run_ids.append(run_id)

                # Nettoyage préalable
                with db_cursor(conn) as cur:
                    cur.execute(
                        "DELETE FROM ai_epoch_prediction WHERE run_id = %s",
                        (run_id,),
                    )
                    nb_deleted = cur.rowcount
                if nb_deleted:
                    logger.info(
                        "%d prédiction(s) GSSC existante(s) supprimée(s) pour run_id=%s",
                        nb_deleted,
                        run_id,
                    )

                # Insertion
                inserted = 0
                with db_cursor(conn) as cur:
                    for idx in range(n):
                        label = preds[idx]
                        ts = stamps[idx]

                        try:
                            int_label = int(label)
                        except Exception:
                            logger.warning(
                                "Label non entier à l'index %d : %s", idx, label
                            )
                            continue

                        stage = gssc_label_int_to_stage(int_label)

                        # onset_sec : on privilégie timestamps GSSC, sinon idx*epoch_length
                        try:
                            onset_sec = float(ts)
                        except Exception:
                            onset_sec = float(idx * default_epoch_length_sec)

                        # duration_sec : on approxime par différence de timestamps
                        if idx < n - 1:
                            try:
                                duration_sec = float(stamps[idx + 1]) - float(ts)
                            except Exception:
                                duration_sec = float(default_epoch_length_sec)
                        else:
                            duration_sec = float(default_epoch_length_sec)

                        extra_json = {
                            "gssc_label_int": int_label,
                            "timestamp_raw": ts,
                        }

                        epoch_pred_id = str(uuid.uuid4())
                        cur.execute(
                            """
                            INSERT INTO ai_epoch_prediction (
                                epoch_pred_id, run_id,
                                epoch_index, onset_sec, duration_sec,
                                predicted_stage_code,
                                prob_w, prob_n1, prob_n2, prob_n3, prob_r,
                                extra_json
                            )
                            VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
                            """,
                            (
                                epoch_pred_id,
                                run_id,
                                int(idx),
                                onset_sec,
                                duration_sec,
                                stage,
                                None,  # prob_w
                                None,  # prob_n1
                                None,  # prob_n2
                                None,  # prob_n3
                                None,  # prob_r
                                json.dumps(extra_json),
                            ),
                        )
                        inserted += 1

                logger.info(
                    "GSSC : %d prédiction(s) epoch insérées pour run_id=%s",
                    inserted,
                    run_id,
                )
                total_epochs += inserted

        logger.info(
            "=== ETL GSSC terminé pour %s : runs=%d, epochs=%d ===",
            json_path.name,
            len(run_ids),
            total_epochs,
        )
        return run_ids, total_epochs

    finally:
        conn.close()


# %%
# ------------------------------------------------------------
# 7. Fonction de contrôle : run_checks_after_etl_gssc
# ------------------------------------------------------------
def log_table_count(conn, table: str):
    with db_cursor(conn) as cur:
        cur.execute(f"SELECT COUNT(*) FROM {table}")
        nb = cur.fetchone()[0]
    logger.info("Table %s : %d ligne(s)", table, nb)


def run_checks_after_etl_gssc(
    json_path: Path,
    model_version: Optional[str] = None,
    training_tag: Optional[str] = None,
):
    """
    Lance l'ETL GSSC puis affiche quelques stats / échantillons.
    """
    run_ids, total_epochs = etl_gssc_from_json_file(
        json_path=json_path,
        model_version=model_version,
        training_tag=training_tag,
    )

    conn = get_conn()
    try:
        logger.info("GSSC : %d run(s) créés/mis à jour, %d epochs insérés", len(run_ids), total_epochs)

        for tbl in ["ai_model_run", "ai_epoch_prediction"]:
            log_table_count(conn, tbl)

        if run_ids:
            last_run = run_ids[-1]
            with db_cursor(conn, dict_cursor=True) as cur:
                cur.execute(
                    """
                    SELECT *
                    FROM ai_model_run
                    WHERE run_id = %s
                    """,
                    (last_run,),
                )
                logger.info("Sample ai_model_run(last_run) : %s", cur.fetchall())

            with db_cursor(conn, dict_cursor=True) as cur:
                cur.execute(
                    """
                    SELECT epoch_index, onset_sec, duration_sec,
                           predicted_stage_code
                    FROM ai_epoch_prediction
                    WHERE run_id = %s
                    ORDER BY epoch_index
                    LIMIT 10
                    """,
                    (last_run,),
                )
                logger.info(
                    "Sample ai_epoch_prediction(last_run, 10 premiers) : %s",
                    cur.fetchall(),
                )
    finally:
        conn.close()

2025-12-08 09:02:53,585 [INFO] sleeplab_etl_gssc - DB_DSN utilisé : postgresql://sleeplab:sleeplab@postgres:5432/sleeplab


In [2]:
# %%
# Exemple d'appel dans le notebook :
from pathlib import Path
gssc_json = MODELS_OUTPUT_DIR / "predictions_gssc_S0001.json"
run_checks_after_etl_gssc(json_path=gssc_json, model_version="v1", training_tag=None)

2025-12-08 09:03:01,015 [INFO] sleeplab_etl_gssc - [OK] JSON GSSC GSSC trouvé : /workspaces/SleepLab_new/data/raw/MODELS_OUTPUT/predictions_gssc_S0001.json
2025-12-08 09:03:01,018 [INFO] sleeplab_etl_gssc - Suffixe subject/session détecté à partir du nom de fichier : S0001
2025-12-08 09:03:01,018 [INFO] sleeplab_etl_gssc - === ETL GSSC démarré pour predictions_gssc_S0001.json ===
2025-12-08 09:03:01,034 [INFO] sleeplab_etl_gssc - Traitement de la clé EDF : PANDORE_SAS_DATASET_S0001_PSG.edf
2025-12-08 09:03:01,035 [INFO] sleeplab_etl_gssc - session_code dérivé de la clé EDF : PANDORE_SAS_DATASET_S0001
2025-12-08 09:03:01,039 [INFO] sleeplab_etl_gssc - Nouveau run AI GSSC inséré : run_id=161cc47e-6345-4875-bfd5-a1754a20445f (model=GSSC, version=v1, session_id=0ee866aa-e9db-4af5-bc70-89e99827d320)
2025-12-08 09:03:01,269 [INFO] sleeplab_etl_gssc - GSSC : 1245 prédiction(s) epoch insérées pour run_id=161cc47e-6345-4875-bfd5-a1754a20445f
2025-12-08 09:03:01,273 [INFO] sleeplab_etl_gssc - ==