In [1]:
# ============================================================
# 0006_OUTPUT_LUNA_INTEGRATION.py
# ETL des prédictions LUNA vers les tables AI
# (ai_model_run + ai_epoch_prediction)
# ============================================================

import json
import logging
import math
import os
import re
import uuid
from pathlib import Path
from typing import Optional, Dict, Any, List, Tuple

import psycopg2
from psycopg2.extras import RealDictCursor

# ------------------------------------------------------------
# 1. Logging
# ------------------------------------------------------------
logger = logging.getLogger("sleeplab_etl_luna")
# Install the default console format/level only when this named logger has
# no handler of its own yet (e.g. on the first execution of the cell).
if not logger.handlers:
    logging.basicConfig(
        format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
        level=logging.INFO,
    )
# The ETL logger always emits INFO and above, whatever the root level is.
logger.setLevel(logging.INFO)

# ------------------------------------------------------------
# 2. Config chemins & Postgres
# ------------------------------------------------------------
BASE_DIR = Path("/workspaces/SleepLab_new").resolve()
MODELS_OUTPUT_DIR = BASE_DIR / "data" / "raw" / "MODELS_OUTPUT"

DEFAULT_LUNA_JSON = MODELS_OUTPUT_DIR / "predictions_luna_S0001.json"

# Connection string; overridable through the SLEEPLAB_DB_DSN environment variable.
DB_DSN = os.getenv(
    "SLEEPLAB_DB_DSN",
    "postgresql://sleeplab:sleeplab@postgres:5432/sleeplab",
)


def _redact_dsn(dsn: str) -> str:
    """Return *dsn* with any password replaced by ``***``.

    Handles both the URI form (``scheme://user:password@host/...``) and the
    keyword form (``... password=secret ...``). DSNs without a password are
    returned unchanged.
    """
    redacted = re.sub(r"(://[^:/@]*:)[^@]*@", r"\1***@", dsn)
    return re.sub(r"(password\s*=\s*)\S+", r"\1***", redacted, flags=re.IGNORECASE)


# Log the target database without leaking the credentials into the logs.
logger.info("DB_DSN utilisé : %s", _redact_dsn(DB_DSN))

# ------------------------------------------------------------
# 3. Helpers Postgres
# ------------------------------------------------------------
def get_conn():
    """Open a psycopg2 connection to DB_DSN with ``search_path`` set to ``sleeplab``.

    The ``SET search_path`` statement is committed right away, so the returned
    connection has no pending transaction.

    Raises:
        psycopg2.Error: if the connection or the ``SET`` fails. On a failed
            ``SET``/commit the half-initialised connection is closed before
            re-raising, so it is not leaked.
    """
    conn = psycopg2.connect(DB_DSN)
    try:
        with conn.cursor() as cur:
            cur.execute("SET search_path TO sleeplab")
        conn.commit()
    except Exception:
        # Don't leak the socket if initialisation fails; propagate the error.
        conn.close()
        raise
    return conn


def db_cursor(conn, dict_cursor: bool = False):
    """Return a new cursor on *conn*.

    With ``dict_cursor=True`` rows come back as ``RealDictCursor`` rows
    (accessible by column name); otherwise the default tuple cursor is used.
    """
    if dict_cursor:
        return conn.cursor(cursor_factory=RealDictCursor)
    return conn.cursor(cursor_factory=None)


# ------------------------------------------------------------
# 4. Helpers métier : mapping session / run
# ------------------------------------------------------------
def require_file(path: Path, label: str) -> None:
    """Check that *path* is an existing regular file and log the outcome.

    Args:
        path: path to check.
        label: human-readable name used in log and error messages.

    Raises:
        FileNotFoundError: if *path* does not exist, or exists but is not a
            regular file (e.g. a directory).
    """
    if path.is_file():
        logger.info("[OK] %s trouvé : %s", label, path)
        return

    if path.exists():
        # Something is there, but it is not a regular file.
        logger.error("[CHEMIN NON FICHIER] %s : %s", label, path)
        raise FileNotFoundError(f"{label} n'est pas un fichier : {path}")

    logger.error("[FICHIER MANQUANT] %s : %s", label, path)
    raise FileNotFoundError(f"{label} non trouvé : {path}")


def extract_subject_suffix_from_filename(json_path: Path) -> Optional[str]:
    """Return the ``S`` + 4-digit suffix that ends the file name, or None.

    Example: ``'predictions_luna_S0001.json'`` -> ``'S0001'``. The match is
    case-insensitive, but the returned suffix always starts with an uppercase
    ``S``. A warning is logged when no suffix is found.
    """
    match = re.search(r"_S(\d{4})\.json$", json_path.name, re.IGNORECASE)
    if match is None:
        logger.warning(
            "Impossible d'extraire un suffixe 'Sdddd' depuis le nom de fichier %s",
            json_path.name,
        )
        return None

    digits = match.group(1)
    found = "S" + digits
    logger.info("Suffixe subject/session détecté à partir du nom de fichier : %s", found)
    return found


def extract_session_code_from_edf_key(edf_key: str) -> str:
    """Strip the ``_PSG.edf`` tail from an EDF key to get its session code.

    Example: ``'PANDORE_SAS_DATASET_S0001_PSG.edf'`` ->
    ``'PANDORE_SAS_DATASET_S0001'``. The session code must end with ``_S``
    followed by 4 digits.

    Raises:
        ValueError: if *edf_key* does not have the expected shape.
    """
    pattern = r"^(?P<session_root>.+_S\d{4})_PSG\.edf$"
    match = re.match(pattern, edf_key)
    if match is None:
        raise ValueError(
            f"Impossible d'extraire session_root depuis la clé EDF '{edf_key}'"
        )
    return match["session_root"]


def get_session_id_from_code(conn, session_code: str) -> str:
    """Look up ``study_session.session_id`` for the given ``session_code``.

    Raises:
        RuntimeError: if no ``study_session`` row has that code.
    """
    with db_cursor(conn, dict_cursor=True) as cur:
        cur.execute(
            """
            SELECT session_id
            FROM study_session
            WHERE session_code = %s
            """,
            (session_code,),
        )
        found = cur.fetchone()

    if not found:
        raise RuntimeError(
            f"Aucune session trouvée pour session_code={session_code}"
        )
    return found["session_id"]


def get_or_create_ai_model_run(
    conn,
    session_id: str,
    model_name: str,
    model_version: Optional[str],
    training_tag: Optional[str],
    source_file_name: str,
    params_json: Optional[Dict[str, Any]] = None,
) -> str:
    """Return the ``run_id`` of the matching ``ai_model_run`` row, creating it if needed.

    The lookup key is ``(session_id, model_name, model_version, source_file_name)``,
    with a NULL/None ``model_version`` treated as ``''``. ``training_tag`` and
    ``params_json`` are not part of the key.

    When a matching run already exists, its ``training_tag`` and ``params_json``
    are refreshed with the values passed in. ``None`` keeps the stored value.
    This keeps the run metadata in line with the latest import: the LUNA ETL in
    this file deletes and re-inserts the run's predictions on every import. If
    several rows match, the oldest by ``created_ts`` is used, so the choice is
    deterministic.

    No commit is issued here; the caller owns the transaction.

    Returns:
        The existing or newly generated (uuid4) ``run_id``.
    """
    version_key = model_version or ""
    params_payload = json.dumps(params_json) if params_json is not None else None

    with db_cursor(conn) as cur:
        cur.execute(
            """
            SELECT run_id
            FROM ai_model_run
            WHERE session_id = %s
              AND model_name = %s
              AND COALESCE(model_version,'') = %s
              AND source_file_name = %s
            ORDER BY created_ts
            LIMIT 1
            """,
            (session_id, model_name, version_key, source_file_name),
        )
        row = cur.fetchone()
        if row:
            run_id = row[0]
            # Refresh metadata of the reused run; NULL inputs keep stored values.
            cur.execute(
                """
                UPDATE ai_model_run
                SET training_tag = COALESCE(%s, training_tag),
                    params_json = COALESCE(%s, params_json)
                WHERE run_id = %s
                """,
                (training_tag, params_payload, run_id),
            )
            logger.info(
                "Run AI déjà présent (session_id=%s, model=%s, version=%s, src=%s) → run_id=%s",
                session_id,
                model_name,
                model_version,
                source_file_name,
                run_id,
            )
            return run_id

        run_id = str(uuid.uuid4())
        cur.execute(
            """
            INSERT INTO ai_model_run (
                run_id, session_id,
                model_name, model_version,
                training_tag, source_file_name,
                params_json, created_ts
            )
            VALUES (%s,%s,%s,%s,%s,%s,%s, NOW())
            """,
            (
                run_id,
                session_id,
                model_name,
                model_version,
                training_tag,
                source_file_name,
                params_payload,
            ),
        )
        logger.info(
            "Nouveau run AI inséré : run_id=%s (model=%s, version=%s, session_id=%s)",
            run_id,
            model_name,
            model_version,
            session_id,
        )
        return run_id


# ------------------------------------------------------------
# 5. Mapping LUNA -> stage
# ------------------------------------------------------------
# (stage_code, probability key in the LUNA "predictions" object), in the
# order used for tie-breaking by luna_pick_stage_from_probs.
LUNA_STAGE_KEYS = [
    ("W",  "prob_Wake"),
    ("N1", "prob_N1"),
    ("N2", "prob_N2"),
    ("N3", "prob_N3"),
    ("R",  "prob_REM"),
]


def luna_pick_stage_from_probs(prob_row: Dict[str, Any]) -> Optional[str]:
    """Return the stage code whose probability is highest (argmax).

    Args:
        prob_row: mapping such as ``{'prob_Wake': x, 'prob_N1': y, ...}``.
            Missing keys, ``None`` values, values that ``float()`` cannot
            convert, and NaN values are ignored.

    Returns:
        The ``stage_code`` from ``LUNA_STAGE_KEYS`` with the largest
        probability. On ties, the earlier entry in ``LUNA_STAGE_KEYS`` wins.
        Returns ``None`` when no usable probability is present.
    """
    best_stage = None
    best_p = None
    for stage_code, key in LUNA_STAGE_KEYS:
        raw = prob_row.get(key)
        if raw is None:
            continue
        try:
            p = float(raw)
        except (TypeError, ValueError, OverflowError):
            continue
        # NaN compares False against everything: once stored as best_p it
        # could never be beaten, so a NaN would silently win the argmax.
        if math.isnan(p):
            continue
        if best_p is None or p > best_p:
            best_p = p
            best_stage = stage_code
    return best_stage


# ------------------------------------------------------------
# 6. ETL principal LUNA
# ------------------------------------------------------------
def _parse_luna_payload(
    edf_key: str, payload: Any
) -> Optional[Tuple[List[Any], Dict[str, Any]]]:
    """Validate one per-EDF LUNA payload and return ``(timestamps, predictions)``.

    Returns None, after logging the reason, when the payload must be skipped:
    it is not a dict; ``timestamps`` or ``predictions`` is missing;
    ``timestamps`` is not a list; ``predictions`` does not decode to a JSON
    object; or one of the ``LUNA_STAGE_KEYS`` probability keys is absent.
    """
    if not isinstance(payload, dict):
        logger.warning(
            "Payload LUNA inattendu pour %s (type=%s) → ignoré",
            edf_key,
            type(payload),
        )
        return None

    stamps = payload.get("timestamps")
    pred_str = payload.get("predictions")
    if stamps is None or pred_str is None:
        logger.warning(
            "timestamps ou predictions manquants pour %s → ignoré", edf_key
        )
        return None
    if not isinstance(stamps, (list, tuple)):
        logger.warning(
            "timestamps n'est pas une liste pour %s (type=%s) → ignoré",
            edf_key,
            type(stamps),
        )
        return None

    # "predictions" is normally a JSON-encoded string; an already-decoded
    # object is accepted as-is.
    if isinstance(pred_str, str):
        try:
            pred_dict = json.loads(pred_str)
        except ValueError as e:
            logger.error("Impossible de parser payload['predictions'] (json string) : %s", e)
            return None
    else:
        pred_dict = pred_str
    if not isinstance(pred_dict, dict):
        logger.warning(
            "predictions n'est pas un objet JSON pour %s (type=%s) → ignoré",
            edf_key,
            type(pred_dict),
        )
        return None

    # All 5 per-stage probability dicts are required.
    missing_keys = [k for _, k in LUNA_STAGE_KEYS if k not in pred_dict]
    if missing_keys:
        logger.warning(
            "Clés de proba manquantes dans predictions (%s) pour %s → ignoré",
            missing_keys,
            edf_key,
        )
        return None

    # Epochs without a probability entry are still inserted (NULL probs and
    # NULL stage), and probabilities beyond the last timestamp are ignored,
    # so only warn when the two lengths disagree.
    first_probs = pred_dict[LUNA_STAGE_KEYS[0][1]]
    n_probs = len(first_probs) if isinstance(first_probs, dict) else 0
    if n_probs != len(stamps):
        logger.warning(
            "Nombre de timestamps (%d) différent du nombre de probabilités (%d) pour %s",
            len(stamps),
            n_probs,
            edf_key,
        )
    return list(stamps), pred_dict


def _luna_prob(pred_dict: Dict[str, Any], prob_key: str, epoch_key: str) -> Optional[float]:
    """Return ``pred_dict[prob_key][epoch_key]`` as a float, or None if absent or not numeric."""
    per_epoch = pred_dict.get(prob_key) or {}
    if not isinstance(per_epoch, dict):
        return None
    value = per_epoch.get(epoch_key)
    if value is None:
        return None
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        return None


def _luna_epoch_timing(
    stamps: List[Any], idx: int, default_epoch_length_sec: int
) -> Tuple[float, float]:
    """Return ``(onset_sec, duration_sec)`` for epoch *idx*.

    The onset is ``stamps[idx]``, falling back to ``idx * default_epoch_length_sec``
    when it is not numeric. The duration is the gap to the next timestamp,
    falling back to ``default_epoch_length_sec`` for the last epoch or when
    either timestamp is not numeric.
    """
    ts = stamps[idx]
    try:
        onset_sec = float(ts)
    except (TypeError, ValueError, OverflowError):
        onset_sec = float(idx * default_epoch_length_sec)

    duration_sec = float(default_epoch_length_sec)
    if idx < len(stamps) - 1:
        try:
            duration_sec = float(stamps[idx + 1]) - float(ts)
        except (TypeError, ValueError, OverflowError):
            duration_sec = float(default_epoch_length_sec)
    return onset_sec, duration_sec


def _luna_epoch_rows(
    run_id: str,
    stamps: List[Any],
    pred_dict: Dict[str, Any],
    used_channel: Any,
    default_epoch_length_sec: int,
) -> List[tuple]:
    """Build one ``ai_epoch_prediction`` parameter tuple per timestamp."""
    rows: List[tuple] = []
    for idx, ts in enumerate(stamps):
        onset_sec, duration_sec = _luna_epoch_timing(stamps, idx, default_epoch_length_sec)

        # LUNA probabilities are keyed "1".."N" while epoch_index is 0-based.
        key1 = str(idx + 1)
        probs = {k: _luna_prob(pred_dict, k, key1) for _, k in LUNA_STAGE_KEYS}
        stage = luna_pick_stage_from_probs(probs)

        extra_json = {
            "luna_epoch_key": key1,
            "timestamp_raw": ts,
            "used_channel": used_channel,
        }
        rows.append(
            (
                str(uuid.uuid4()),
                run_id,
                int(idx),
                onset_sec,
                duration_sec,
                stage,
                probs["prob_Wake"],
                probs["prob_N1"],
                probs["prob_N2"],
                probs["prob_N3"],
                probs["prob_REM"],
                json.dumps(extra_json),
            )
        )
    return rows


def etl_luna_from_json_file(
    json_path: Path,
    model_name: str = "LUNA",
    model_version: Optional[str] = None,
    training_tag: Optional[str] = None,
    default_epoch_length_sec: int = 30,
) -> Tuple[List[str], int]:
    """Load a LUNA JSON file into ai_model_run + ai_epoch_prediction.

    Expected structure (one entry per EDF):
      {
        "PANDORE_SAS_DATASET_S0001_PSG.edf": {
          "infos": {...},
          "used_channel": ...,
          "timestamps": [0, 30, 60, ...],
          "predictions": "{\"prob_Wake\": {\"1\":0.1,...}, \"prob_N1\": {...}, ...}"
        }
      }

    For each valid entry, one run is fetched or created, its previous
    predictions are deleted, and one row per timestamp is inserted. The
    predicted stage is the argmax over the 5 probabilities. Invalid payloads
    are logged and skipped. All entries are written inside a single
    ``with conn:`` block, which psycopg2 commits on success and rolls back
    on error.

    Returns:
        ``(run_ids, total_epochs)``: the run_id of every imported entry and the
        total number of inserted epoch rows.

    Raises:
        FileNotFoundError: if *json_path* is not an existing file.
        ValueError: if the top-level JSON is not an object, or an EDF key does
            not match ``<session_code>_PSG.edf``.
        RuntimeError: if a valid entry references an unknown session_code.
    """
    require_file(json_path, f"JSON LUNA {model_name}")
    suffix = extract_subject_suffix_from_filename(json_path)

    logger.info("=== ETL LUNA démarré pour %s ===", json_path.name)
    # JSON is UTF-8; don't depend on the platform default encoding.
    data = json.loads(json_path.read_text(encoding="utf-8"))

    if not isinstance(data, dict):
        raise ValueError(f"Structure JSON inattendue (top-level) : {type(data)}")

    conn = get_conn()
    run_ids: List[str] = []
    total_epochs = 0

    try:
        with conn:
            for edf_key, payload in data.items():
                logger.info("Traitement de la clé EDF : %s", edf_key)

                session_code = extract_session_code_from_edf_key(edf_key)
                logger.info("session_code dérivé de la clé EDF : %s", session_code)

                if suffix and suffix.lower() not in session_code.lower():
                    logger.warning(
                        "Le suffixe '%s' (depuis le nom du fichier JSON) "
                        "ne correspond pas au session_code '%s' (clé EDF).",
                        suffix,
                        session_code,
                    )

                parsed = _parse_luna_payload(edf_key, payload)
                if parsed is None:
                    continue
                stamps, pred_dict = parsed

                # Only resolve the session for payloads that will be imported,
                # so a skipped payload cannot abort the whole file.
                session_id = get_session_id_from_code(conn, session_code)

                used_channel = payload.get("used_channel")
                params_json = {
                    "source_edf_key": edf_key,
                    "model_info": payload.get("infos", {}),
                    "used_channel": used_channel,
                    "notes": "LUNA predictions imported from JSON (argmax over probabilities)",
                }

                run_id = get_or_create_ai_model_run(
                    conn,
                    session_id=session_id,
                    model_name=model_name,
                    model_version=model_version,
                    training_tag=training_tag,
                    source_file_name=json_path.name,
                    params_json=params_json,
                )
                run_ids.append(run_id)

                # Delete first so that re-importing the same file replaces
                # the run's predictions instead of duplicating them.
                with db_cursor(conn) as cur:
                    cur.execute(
                        "DELETE FROM ai_epoch_prediction WHERE run_id = %s",
                        (run_id,),
                    )
                    nb_deleted = cur.rowcount
                if nb_deleted:
                    logger.info(
                        "%d prédiction(s) LUNA existante(s) supprimée(s) pour run_id=%s",
                        nb_deleted,
                        run_id,
                    )

                rows = _luna_epoch_rows(
                    run_id, stamps, pred_dict, used_channel, default_epoch_length_sec
                )
                if rows:
                    with db_cursor(conn) as cur:
                        cur.executemany(
                            """
                            INSERT INTO ai_epoch_prediction (
                                epoch_pred_id, run_id,
                                epoch_index, onset_sec, duration_sec,
                                predicted_stage_code,
                                prob_w, prob_n1, prob_n2, prob_n3, prob_r,
                                extra_json
                            )
                            VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
                            """,
                            rows,
                        )
                inserted = len(rows)

                logger.info(
                    "LUNA : %d prédiction(s) epoch insérées pour run_id=%s",
                    inserted,
                    run_id,
                )
                total_epochs += inserted

        logger.info(
            "=== ETL LUNA terminé pour %s : runs=%d, epochs=%d ===",
            json_path.name,
            len(run_ids),
            total_epochs,
        )
        return run_ids, total_epochs

    finally:
        conn.close()


# ------------------------------------------------------------
# 7. Run checks
# ------------------------------------------------------------
def log_table_count(conn, table: str):
    """Log the number of rows in *table*.

    *table* is interpolated directly into the SQL text, so it must be a
    trusted, hard-coded table name (as it is at the call sites in this file).
    """
    with db_cursor(conn) as cur:
        cur.execute(f"SELECT COUNT(*) FROM {table}")
        (row_count,) = cur.fetchone()
    logger.info("Table %s : %d ligne(s)", table, row_count)


def run_checks_after_etl_luna(
    json_path: Path,
    model_version: Optional[str] = None,
    training_tag: Optional[str] = None,
):
    """Run the LUNA ETL on *json_path*, then log sanity checks.

    The checks use a fresh connection and log:
    - the run and epoch counts returned by the ETL;
    - the row counts of ai_model_run and ai_epoch_prediction;
    - the last run's ai_model_run row and its first 10 epoch predictions.
    """
    run_ids, total_epochs = etl_luna_from_json_file(
        json_path=json_path,
        model_version=model_version,
        training_tag=training_tag,
    )

    conn = get_conn()
    try:
        logger.info(
            "LUNA : %d run(s) créés/mis à jour, %d epochs insérés",
            len(run_ids),
            total_epochs,
        )

        for table_name in ("ai_model_run", "ai_epoch_prediction"):
            log_table_count(conn, table_name)

        if run_ids:
            latest_run_id = run_ids[-1]
            with db_cursor(conn, dict_cursor=True) as cur:
                cur.execute(
                    """
                    SELECT *
                    FROM ai_model_run
                    WHERE run_id = %s
                    """,
                    (latest_run_id,),
                )
                run_rows = cur.fetchall()
                logger.info("Sample ai_model_run(last_run) : %s", run_rows)

            with db_cursor(conn, dict_cursor=True) as cur:
                cur.execute(
                    """
                    SELECT epoch_index, onset_sec, duration_sec,
                           predicted_stage_code,
                           prob_w, prob_n1, prob_n2, prob_n3, prob_r
                    FROM ai_epoch_prediction
                    WHERE run_id = %s
                    ORDER BY epoch_index
                    LIMIT 10
                    """,
                    (latest_run_id,),
                )
                epoch_rows = cur.fetchall()
                logger.info(
                    "Sample ai_epoch_prediction(last_run, 10 premiers) : %s",
                    epoch_rows,
                )
    finally:
        conn.close()


if __name__ == "__main__":
    # Default input file: predictions_luna_S0001.json (see DEFAULT_LUNA_JSON).
    run_checks_after_etl_luna(json_path=DEFAULT_LUNA_JSON)


2025-12-17 08:42:46,276 [INFO] sleeplab_etl_luna - DB_DSN utilisé : postgresql://sleeplab:sleeplab@localhost:5432/sleeplab


2025-12-17 08:42:46,298 [INFO] sleeplab_etl_luna - [OK] JSON LUNA LUNA trouvé : /workspaces/SleepLab_new/data/raw/MODELS_OUTPUT/predictions_luna_S0001.json


2025-12-17 08:42:46,304 [INFO] sleeplab_etl_luna - Suffixe subject/session détecté à partir du nom de fichier : S0001


2025-12-17 08:42:46,310 [INFO] sleeplab_etl_luna - === ETL LUNA démarré pour predictions_luna_S0001.json ===


2025-12-17 08:42:46,388 [INFO] sleeplab_etl_luna - Traitement de la clé EDF : PANDORE_SAS_DATASET_S0001_PSG.edf


2025-12-17 08:42:46,392 [INFO] sleeplab_etl_luna - session_code dérivé de la clé EDF : PANDORE_SAS_DATASET_S0001


2025-12-17 08:42:46,407 [INFO] sleeplab_etl_luna - Nouveau run AI inséré : run_id=c0ea9467-461d-4f90-8b10-7edd299d2de3 (model=LUNA, version=None, session_id=902f9f18-3ff8-47e9-a357-b5ea9a097e34)


2025-12-17 08:42:47,199 [INFO] sleeplab_etl_luna - LUNA : 1245 prédiction(s) epoch insérées pour run_id=c0ea9467-461d-4f90-8b10-7edd299d2de3


2025-12-17 08:42:47,228 [INFO] sleeplab_etl_luna - === ETL LUNA terminé pour predictions_luna_S0001.json : runs=1, epochs=1245 ===


2025-12-17 08:42:47,246 [INFO] sleeplab_etl_luna - LUNA : 1 run(s) créés/mis à jour, 1245 epochs insérés


2025-12-17 08:42:47,248 [INFO] sleeplab_etl_luna - Table ai_model_run : 3 ligne(s)


2025-12-17 08:42:47,250 [INFO] sleeplab_etl_luna - Table ai_epoch_prediction : 3735 ligne(s)


2025-12-17 08:42:47,252 [INFO] sleeplab_etl_luna - Sample ai_model_run(last_run) : [RealDictRow({'run_id': 'c0ea9467-461d-4f90-8b10-7edd299d2de3', 'session_id': '902f9f18-3ff8-47e9-a357-b5ea9a097e34', 'model_name': 'LUNA', 'model_version': None, 'training_tag': None, 'source_file_name': 'predictions_luna_S0001.json', 'params_json': {'notes': 'LUNA predictions imported from JSON (argmax over probabilities)', 'model_info': {'chs': [{'cal': 1.0, 'loc': '[nan nan nan nan nan nan nan nan nan nan nan nan]', 'kind': 2, 'unit': 107, 'logno': 1, 'range': 1.0, 'scanno': 1, 'ch_name': 'Activity', 'unit_mul': 0, 'coil_type': 1, 'coord_frame': 4}, {'cal': 1.0, 'loc': '[nan nan nan nan nan nan nan nan nan nan nan nan]', 'kind': 2, 'unit': 107, 'logno': 2, 'range': 1.0, 'scanno': 2, 'ch_name': 'C3-M2', 'unit_mul': 0, 'coil_type': 1, 'coord_frame': 4}, {'cal': 1.0, 'loc': '[nan nan nan nan nan nan nan nan nan nan nan nan]', 'kind': 2, 'unit': 107, 'logno': 3, 'range': 1.0, 'scanno': 3, 'ch_name': 'C3'

2025-12-17 08:42:47,255 [INFO] sleeplab_etl_luna - Sample ai_epoch_prediction(last_run, 10 premiers) : [RealDictRow({'epoch_index': 0, 'onset_sec': Decimal('0.000'), 'duration_sec': Decimal('30.000'), 'predicted_stage_code': 'W', 'prob_w': Decimal('0.99707'), 'prob_n1': Decimal('0.00120'), 'prob_n2': Decimal('0.00120'), 'prob_n3': Decimal('0.00029'), 'prob_r': Decimal('0.00024')}), RealDictRow({'epoch_index': 1, 'onset_sec': Decimal('30.000'), 'duration_sec': Decimal('30.000'), 'predicted_stage_code': 'W', 'prob_w': Decimal('0.99559'), 'prob_n1': Decimal('0.00278'), 'prob_n2': Decimal('0.00121'), 'prob_n3': Decimal('0.00012'), 'prob_r': Decimal('0.00030')}), RealDictRow({'epoch_index': 2, 'onset_sec': Decimal('60.000'), 'duration_sec': Decimal('30.000'), 'predicted_stage_code': 'W', 'prob_w': Decimal('0.99732'), 'prob_n1': Decimal('0.00163'), 'prob_n2': Decimal('0.00073'), 'prob_n3': Decimal('0.00008'), 'prob_r': Decimal('0.00025')}), RealDictRow({'epoch_index': 3, 'onset_sec': Decimal