In [None]:
# ============================================================
# SleepLab ETL V0 - S0001
# PANDORE Metadata + Hypnogram CSV + EDF PSG
# ============================================================

import os
import json
import uuid
import math
import logging
import re
from pathlib import Path
from typing import Optional, Tuple, Union

import pandas as pd
import psycopg2
from psycopg2.extras import RealDictCursor

import mne
import datetime as dt


# ============================================================
# 0. Logging
# ============================================================

logger = logging.getLogger("sleeplab_etl_s0001")

if not logger.handlers:
    # basicConfig configures the *root* logger, and is itself a no-op when the
    # root logger already has handlers.
    logging.basicConfig(
        format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
        level=logging.INFO,
    )

# Always force INFO on this module's logger, whatever the root level is.
logger.setLevel(logging.INFO)


# ============================================================
# 1. Configuration chemins & paramètres
# ============================================================

# On fixe explicitement la racine du projet (Codespace)
# Project root, pinned explicitly for the Codespace environment.
BASE_DIR = Path("/workspaces/SleepLab_new").resolve()

# Folder holding the raw PANDORE exports.
RAW_DIR = BASE_DIR.joinpath("data", "raw", "PANDORE_SOURCE")

# Input files for session S0001.
METADATA_CSV = RAW_DIR.joinpath("PANDORE_SAS_DATASET_Metadata.csv")
HYPNO_CSV = RAW_DIR.joinpath("PANDORE_SAS_DATASET_S0001_Hypnogram.csv")
EDF_FILE = RAW_DIR.joinpath("PANDORE_SAS_DATASET_S0001_PSG_1.edf")

# Business parameters.
EPOCH_LENGTH_SEC = 30  # epoch length in seconds (V0: not yet used for regridding)
DEFAULT_TZ = "Europe/Brussels"

# Matches a bare time of day, HH:MM or HH:MM:SS, optionally surrounded by spaces.
HOUR_ONLY_RE = re.compile(r"^\s*\d{1,2}:\d{2}(:\d{2})?\s*$")

# Channel type codes known to the ref_channel_type table.
REF_CHANNEL_TYPES = {"AUDIO", "EEG", "EMG", "EOG", "OTHER", "OXIMETER"}

# Postgres configuration: SLEEPLAB_DB_DSN environment variable, with a local default.
DB_DSN = os.getenv(
    "SLEEPLAB_DB_DSN",
    "postgresql://sleeplab:sleeplab@postgres:5432/sleeplab",
)


def _redact_dsn(dsn: str) -> str:
    """
    Return ``dsn`` with any password replaced by ``***`` so it is safe to log.

    Handles the URL form (``scheme://user:password@host/...``), the libpq
    key/value form (``password=secret``) and ``?password=...`` query params.
    """
    # URL form: keep "scheme://user:" and "@", hide what lies between.
    redacted = re.sub(r"(://[^:/@\s]*:)[^@\s]*@", r"\1***@", dsn)
    # key=value / query-parameter form; quoted values are hidden as a whole.
    return re.sub(
        r"(password\s*=\s*)('[^']*'|[^\s&]+)",
        r"\1***",
        redacted,
        flags=re.IGNORECASE,
    )


# Never write credentials to the logs: only the redacted DSN is logged.
logger.info("DB_DSN utilisé : %s", _redact_dsn(DB_DSN))


# ============================================================
# 2. Connexion Postgres & helpers
# ============================================================

conn = psycopg2.connect(DB_DSN)

# Point the session's search_path at the 'sleeplab' schema. The SET is
# committed right away so it persists for the rest of the connection.
_cur = conn.cursor()
try:
    _cur.execute("SET search_path TO sleeplab")
finally:
    _cur.close()
conn.commit()

logger.info("Connexion Postgres OK, search_path fixé sur 'sleeplab'")


def db_cursor(dict_cursor: bool = False):
    """
    Open a new cursor on the module-level Postgres connection ``conn``.

    Args:
        dict_cursor: when True, rows are returned as dicts (RealDictCursor);
            when False (default), rows are plain tuples.
    """
    if dict_cursor:
        return conn.cursor(cursor_factory=RealDictCursor)
    return conn.cursor(cursor_factory=None)


def require_file(path: Path, label: str):
    """
    Ensure ``path`` points to an existing regular file.

    Logs an [OK] line on success. Otherwise it logs the problem and raises.

    Raises:
        FileNotFoundError: the path does not exist, or exists but is not a file.
    """
    if path.is_file():
        logger.info("[OK] %s trouvé : %s", label, path)
        return

    if path.exists():
        # Something is there (e.g. a directory), but it is not a regular file.
        logger.error("[CHEMIN NON FICHIER] %s : %s", label, path)
        raise FileNotFoundError(f"{label} n'est pas un fichier : {path}")

    logger.error("[FICHIER MANQUANT] %s : %s", label, path)
    raise FileNotFoundError(f"{label} non trouvé : {path}")


def sanitize_for_json(value):
    """
    Convert ``value`` into something ``json.dumps(..., allow_nan=False)`` accepts.

    - None, NaN, +/-inf, pandas NaT / NA -> None
    - strings 'nan' / 'NaN' / 'NaT' / '' (case-insensitive, after strip) -> None
    - datetime / date / time / pandas Timestamp -> isoformat() string
    - numpy scalars -> native Python value (then cleaned like any other value)
    - numpy arrays, lists, tuples -> list with every element cleaned
    - dicts -> dict with every value cleaned (keys left untouched)
    - anything else is returned unchanged
    """
    import numpy as np  # always available: pandas depends on numpy

    # pd.NaT subclasses datetime, so it must be caught before the isoformat
    # branch (otherwise it would serialize as the string "NaT").
    if value is None or value is pd.NaT or value is pd.NA:
        return None

    # Unwrap numpy containers/scalars first, so NaN inside e.g. np.float32
    # (not a float subclass) or an array is caught by the float check below.
    if isinstance(value, np.ndarray):
        value = value.tolist()
    elif isinstance(value, np.generic):
        value = value.item()

    if isinstance(value, dict):
        return {k: sanitize_for_json(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [sanitize_for_json(v) for v in value]

    if isinstance(value, float):
        # NaN and +/-inf are not valid JSON numbers.
        return value if math.isfinite(value) else None

    if isinstance(value, (pd.Timestamp, dt.datetime, dt.date, dt.time)):
        return value.isoformat()

    if isinstance(value, str) and value.strip().lower() in ("nan", "nat", ""):
        return None

    return value


def normalize_channel_type_for_ref(channel_type_code: Optional[str]) -> Optional[str]:
    """
    Map a channel type code onto a value accepted by ref_channel_type.

    The code is upper-cased. Codes outside REF_CHANNEL_TYPES (e.g. 'RESP')
    are logged and replaced by 'OTHER'. None is passed through unchanged.
    """
    if channel_type_code is None:
        return None

    upper_code = channel_type_code.upper()
    if upper_code not in REF_CHANNEL_TYPES:
        logger.warning(
            "channel_type_code '%s' non présent dans ref_channel_type, remappé à 'OTHER'",
            upper_code,
        )
        upper_code = "OTHER"
    return upper_code


# ============================================================
# 3. Chargement des données sources
# ============================================================

def load_metadata_dataframe() -> pd.DataFrame:
    """
    Read the PANDORE metadata export (METADATA_CSV) into a DataFrame.

    The separator is sniffed by pandas (sep=None needs the python engine).
    Every column is kept as text (dtype=str).

    Raises:
        FileNotFoundError: METADATA_CSV is missing or not a file.
    """
    require_file(METADATA_CSV, "PANDORE metadata CSV")
    logger.info("Lecture metadata PANDORE : %s", METADATA_CSV)

    read_options = {"sep": None, "engine": "python", "dtype": str}
    df_meta = pd.read_csv(METADATA_CSV, **read_options)

    columns = list(df_meta.columns)
    logger.info(
        "Metadata chargé : %d lignes, %d colonnes. Colonnes=%s",
        len(df_meta),
        len(columns),
        columns,
    )
    return df_meta


def get_metadata_row_for_s0001(df_meta: pd.DataFrame) -> pd.Series:
    """
    Return the metadata row of session S0001.

    The row is found by an exact match on "PANDORE_SAS_DATASET_S0001" in the
    first column whose name contains "folder" (case-insensitive), e.g.
    Folder_id. If several rows match, the first one is used and a warning is
    logged.

    Raises:
        KeyError: no folder-like column exists, or no row matches.
    """
    target_folder = "PANDORE_SAS_DATASET_S0001"

    folder_col = next(
        (name for name in df_meta.columns if "folder" in name.lower()),
        None,
    )
    if folder_col is None:
        logger.error(
            "Impossible de trouver une colonne de type Folder_id. Colonnes disponibles : %s",
            list(df_meta.columns),
        )
        raise KeyError("Colonne Folder_id introuvable dans le metadata PANDORE")

    matches = df_meta.loc[df_meta[folder_col] == target_folder]
    n_matches = len(matches)

    if n_matches == 0:
        logger.error(
            "Aucune ligne metadata pour Folder_id=%s (colonne=%s)",
            target_folder,
            folder_col,
        )
        raise KeyError(f"Aucune ligne metadata pour Folder_id={target_folder}")

    if n_matches > 1:
        logger.warning(
            "%d lignes metadata pour Folder_id=%s ; la première sera utilisée.",
            n_matches,
            target_folder,
        )

    selected = matches.iloc[0]
    logger.info(
        "Ligne metadata sélectionnée pour %s (index=%s). Exemple de valeurs: %s",
        target_folder,
        selected.name,
        selected.to_dict(),
    )
    return selected


def load_hypnogram_df() -> pd.DataFrame:
    """
    Read the S0001 PANDORE hypnogram (HYPNO_CSV) into a DataFrame.

    Header names that normalise to begin / end / event are renamed to those
    exact names. Normalising means removing BOM characters, stripping
    whitespace and lower-casing. The begin and end columns are then parsed
    into tz-aware UTC Timestamps. Unparsable values become NaT and are
    counted in a warning.

    NOTE(review): naive timestamps are interpreted as UTC (utc=True); confirm
    the export is not in local time (DEFAULT_TZ).

    Raises:
        FileNotFoundError: HYPNO_CSV is missing or not a file.
        KeyError: one of begin / end / event cannot be found.
    """
    require_file(HYPNO_CSV, "PANDORE hypnogram CSV")
    logger.info("Lecture hypnogramme : %s", HYPNO_CSV)

    df = pd.read_csv(
        HYPNO_CSV,
        sep=None,
        engine="python",
        dtype=str,
    )
    logger.info(
        "Hypnogramme chargé : %d lignes, %d colonnes. Colonnes brutes=%s",
        len(df),
        len(df.columns),
        list(df.columns),
    )

    required = ("begin", "end", "event")

    # Normalise header names. The BOM is removed *before* stripping, because
    # str.strip() does not treat U+FEFF as whitespace: "\ufeff begin" must
    # still be recognised.
    col_map = {}
    for c in df.columns:
        norm = str(c).replace("\ufeff", "").strip().lower()
        if norm in required:
            col_map[c] = norm

    if col_map:
        df = df.rename(columns=col_map)

    missing = [col for col in required if col not in df.columns]
    if missing:
        logger.error(
            "Colonnes obligatoires manquantes dans l'hypnogramme : %s. Colonnes présentes=%s",
            missing,
            list(df.columns),
        )
        raise KeyError(f"Colonnes manquantes dans l'hypnogramme PANDORE : {missing}")

    # Convert begin/end to UTC Timestamps; unparsable cells become NaT.
    df["begin"] = pd.to_datetime(df["begin"], errors="coerce", utc=True)
    df["end"] = pd.to_datetime(df["end"], errors="coerce", utc=True)

    nb_null_begin = df["begin"].isna().sum()
    nb_null_end = df["end"].isna().sum()
    if nb_null_begin or nb_null_end:
        logger.warning(
            "Hypnogramme : begin NaN=%d, end NaN=%d",
            nb_null_begin,
            nb_null_end,
        )

    return df


def load_edf() -> mne.io.BaseRaw:
    """
    Open the S0001 PSG EDF (EDF_FILE) with mne, without preloading samples.

    Logs channel count, sampling rate, measurement date and sample count.

    Raises:
        FileNotFoundError: EDF_FILE is missing or not a file.
    """
    require_file(EDF_FILE, "EDF PSG S0001")
    logger.info("Chargement EDF via mne : %s", EDF_FILE)

    edf_raw = mne.io.read_raw_edf(str(EDF_FILE), preload=False, verbose=False)
    edf_info = edf_raw.info

    logger.info(
        "EDF chargé : nchan=%d, sfreq=%.3f, meas_date=%s, n_times=%d",
        edf_info["nchan"],
        edf_info["sfreq"],
        edf_info["meas_date"],
        edf_raw.n_times,
    )
    return edf_raw


# ============================================================
# 4. ETL subject
# ============================================================

def etl_subject_from_metadata(row_meta: pd.Series) -> Tuple[uuid.UUID, int]:
    """
    Get or create the subject row whose subject_code is 'S0001'.

    V0: row_meta is accepted for interface stability but not read yet. No
    patient attribute is taken from the EDF or the metadata.

    Returns:
        (subject_id, n_inserted): n_inserted is 1 when a row was created,
        0 when the subject already existed.
        NOTE(review): for an existing subject, subject_id is whatever the
        driver returns for the column, possibly a str rather than uuid.UUID.
    """
    subject_code = "S0001"
    logger.info("ETL subject pour subject_code=%s", subject_code)

    with db_cursor() as cur:
        cur.execute(
            "SELECT subject_id FROM subject WHERE subject_code = %s",
            (subject_code,),
        )
        found = cur.fetchone()

    if found:
        existing_id = found[0]
        logger.info(
            "Sujet déjà présent, subject_id=%s → aucune nouvelle ligne.",
            existing_id,
        )
        return existing_id, 0

    new_id = uuid.uuid4()
    # Only id, code and source system are known in V0: the 7 remaining
    # columns (source_subject_id ... notes) are inserted as NULL.
    params = (str(new_id), subject_code, "PANDORE") + (None,) * 7
    with db_cursor() as cur:
        cur.execute(
            """
            INSERT INTO subject (
                subject_id, subject_code, source_system,
                source_subject_id, birth_year, sex,
                height_cm, weight_kg, bmi, notes
            )
            VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
            """,
            params,
        )
    logger.info(
        "Nouveau sujet inséré, subject_id=%s → 1 ligne insérée dans subject",
        new_id,
    )
    return new_id, 1


# ============================================================
# 5. ETL study_session (avec correction du bug de date)
# ============================================================

def parse_session_start_from_metadata(
    row_meta: pd.Series,
) -> Optional[Union[pd.Timestamp, str]]:
    """
    Read the session start from the metadata column 'start_time'.

    Returns:
        - the stripped string when it is a bare time (HH:MM or HH:MM:SS);
          the caller combines it with a date;
        - a UTC Timestamp when pandas parses it as a full datetime;
        - None when the column is absent, empty, or unparsable.
    """
    if "start_time" not in row_meta.index:
        logger.warning(
            "Colonne start_time absente du metadata → session_start_utc sera déduite de l'hypnogramme ou NULL"
        )
        return None

    cell = row_meta["start_time"]
    txt = "" if pd.isna(cell) else str(cell).strip()
    if not txt:
        logger.warning(
            "start_time vide dans le metadata → session_start_utc sera déduite de l'hypnogramme ou NULL"
        )
        return None

    if HOUR_ONLY_RE.match(txt):
        logger.info(
            "start_time '%s' détecté comme heure (sans date) → "
            "il sera combiné avec la date du premier 'begin' de l'hypnogramme.",
            txt,
        )
        return txt

    try:
        parsed = pd.to_datetime(txt, utc=True)
    except Exception as e:
        logger.warning(
            "Impossible de parser start_time '%s' comme datetime complet : %s → sera ignoré.",
            txt,
            e,
        )
        return None

    logger.info(
        "start_time '%s' parsé comme datetime complet : %s",
        txt,
        parsed,
    )
    return parsed


def build_session_start_utc(
    row_meta: pd.Series,
    df_hypno: pd.DataFrame,
) -> Optional[pd.Timestamp]:
    """
    Decide the UTC session start, in this order of preference:

      1. a full datetime found in metadata 'start_time';
      2. the date of the earliest hypnogram 'begin' combined with a bare
         metadata time (HH:MM[:SS]);
      3. the earliest hypnogram 'begin' itself.

    Returns None when no hypnogram 'begin' is usable and step 1 failed.

    NOTE(review): in step 2 the combined value is interpreted as UTC, and the
    date comes from the UTC 'begin'. Confirm this matches the metadata's clock.
    """
    start_meta = parse_session_start_from_metadata(row_meta)

    # 1. A fully parsed metadata datetime wins outright.
    if isinstance(start_meta, pd.Timestamp):
        logger.info(
            "session_start_utc pris directement du metadata (datetime complet) : %s",
            start_meta,
        )
        return start_meta

    hypno_usable = (
        df_hypno is not None and not df_hypno.empty and "begin" in df_hypno.columns
    )
    if not hypno_usable:
        logger.warning(
            "Pas d'hypnogramme disponible pour inférer la date de début, session_start_utc sera NULL."
        )
        return None

    first_begin = df_hypno["begin"].dropna().min()
    if not isinstance(first_begin, pd.Timestamp):
        logger.warning(
            "Impossible de convertir 'begin' en Timestamp, session_start_utc sera NULL."
        )
        return None

    # 3. Nothing usable in the metadata: take the first hypnogram begin.
    if start_meta is None:
        logger.info(
            "Aucun start_time exploitable, session_start_utc pris depuis le premier 'begin' de l'hypnogramme : %s",
            first_begin,
        )
        return first_begin

    is_hour_only = isinstance(start_meta, str) and HOUR_ONLY_RE.match(start_meta.strip())
    if not is_hour_only:
        logger.warning(
            "start_meta de type inattendu (%r), session_start_utc pris depuis le premier 'begin' : %s",
            start_meta,
            first_begin,
        )
        return first_begin

    # 2. Combine the hypnogram date with the metadata time of day.
    date_part = first_begin.date().isoformat()
    time_part = start_meta.strip()
    try:
        combined = pd.to_datetime(f"{date_part} {time_part}", utc=True)
    except Exception as e:
        logger.warning(
            "Impossible de combiner date hypnogramme + start_time (%s) : %s ; on garde begin=%s",
            time_part,
            e,
            first_begin,
        )
        return first_begin

    logger.info(
        "session_start_utc combiné : date hypnogramme (%s) + heure start_time (%s) → %s",
        date_part,
        time_part,
        combined,
    )
    return combined


def _meta_text_or_none(value) -> Optional[str]:
    """
    Return ``value`` as a string, or None when it is missing.

    Missing values are None, pandas NaN/NaT (which read_csv produces for empty
    cells even with dtype=str) and whitespace-only strings. Other values are
    returned as ``str(value)`` without stripping, so codes already stored in
    the database keep matching.
    """
    if value is None:
        return None
    try:
        if pd.isna(value):
            return None
    except (TypeError, ValueError):
        # Non-scalar values make pd.isna ambiguous; treat them as present.
        pass
    text = str(value)
    return text if text.strip() else None


def etl_study_session_from_metadata(
    row_meta: pd.Series,
    subject_id: uuid.UUID,
    df_hypno: pd.DataFrame,
) -> Tuple[uuid.UUID, int, Optional[pd.Timestamp]]:
    """
    Get or create the study_session of S0001 for ``subject_id``.

    session_code is taken from the first metadata column whose name contains
    "folder". It falls back to "PANDORE_SAS_DATASET_S0001" when that column
    is absent or empty. protocol_code comes from a "study_type"-like or "type"
    column, if one is present. Empty cells (NaN) are treated as missing, so
    they never become the text "nan" and never reach the database as floats.

    Returns:
        (session_id, n_inserted, session_start_utc): n_inserted is 1 when a
        row was created, else 0. session_start_utc is recomputed on every call
        but is only written on insert.
    """
    folder_id = None
    for col in row_meta.index:
        if "folder" in col.lower():
            folder_id = _meta_text_or_none(row_meta[col])
            break
    if folder_id is None:
        folder_id = "PANDORE_SAS_DATASET_S0001"
        logger.warning(
            "Aucune colonne Folder_id dans la ligne metadata → session_code par défaut = %s",
            folder_id,
        )

    session_code = str(folder_id)
    session_start_utc = build_session_start_utc(row_meta, df_hypno)

    protocol_code = None
    for col in row_meta.index:
        if "study_type" in col.lower() or col.lower() == "type":
            protocol_code = _meta_text_or_none(row_meta[col])
            break

    logger.info(
        "ETL study_session pour session_code=%s, protocol_code=%s",
        session_code,
        protocol_code,
    )

    with db_cursor() as cur:
        cur.execute(
            """
            SELECT session_id
            FROM study_session
            WHERE session_code = %s AND subject_id = %s
            """,
            (session_code, str(subject_id)),
        )
        row = cur.fetchone()
        if row:
            session_id = row[0]
            logger.info(
                "Session déjà présente, session_id=%s → aucune nouvelle ligne.",
                session_id,
            )
            return session_id, 0, session_start_utc

    session_id = uuid.uuid4()
    with db_cursor() as cur:
        cur.execute(
            """
            INSERT INTO study_session (
                session_id, subject_id, session_code,
                session_start_utc, session_end_utc,
                site_name, protocol_code, scorer_name,
                tz_name, comments
            )
            VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
            """,
            (
                str(session_id),
                str(subject_id),
                session_code,
                session_start_utc.to_pydatetime()
                if isinstance(session_start_utc, pd.Timestamp)
                else None,
                None,
                None,
                protocol_code,
                None,
                DEFAULT_TZ,
                None,
            ),
        )

    logger.info(
        "Nouvelle session insérée, session_id=%s → 1 ligne insérée dans study_session",
        session_id,
    )
    return session_id, 1, session_start_utc


# ============================================================
# 6. ETL data_file & signal_channel à partir de l'EDF
# ============================================================

def etl_data_file_from_edf(
    raw: mne.io.BaseRaw,
    session_id: uuid.UUID,
) -> Tuple[uuid.UUID, int]:
    """
    Insert one data_file row describing the S0001 EDF recording.

    Only header information (raw.info, raw.n_times) and the file size on
    disk are read; no signal samples are loaded.

    Args:
        raw: the EDF opened with mne.
        session_id: the study_session the file belongs to.

    Returns:
        (file_id, 1): the new row id and the number of rows inserted.
        A fresh file_id is generated on every call. sha256_hex is left NULL
        in V0.
    """
    file_id = uuid.uuid4()
    info = raw.info

    sfreq = float(info["sfreq"] or 0.0)
    n_times = int(raw.n_times)
    duration_sec = int(round(n_times / sfreq)) if sfreq > 0 else None
    byte_size = EDF_FILE.stat().st_size

    logger.info(
        "Préparation data_file EDF : size=%d bytes, duration=%s sec, nchan=%d",
        byte_size,
        duration_sec,
        info["nchan"],
    )

    # NOTE(review): mne documents device_info keys as type/model/serial/site;
    # 'manufacturer' / 'name' may never be present, in which case this is NULL.
    manufacturer = None
    device_serial = None
    device_info = info.get("device_info") or {}
    if isinstance(device_info, dict):
        manufacturer = device_info.get("manufacturer") or device_info.get("name")
        device_serial = device_info.get("serial")

    manufacturer = manufacturer or None

    # Pseudonymise the serial number: store a deterministic UUIDv5 of it,
    # never the raw value.
    device_serial_hash = None
    if device_serial:
        device_serial_hash = str(
            uuid.uuid5(uuid.NAMESPACE_DNS, str(device_serial))
        )

    header = {
        "meas_date": sanitize_for_json(info.get("meas_date")),
        "nchan": sanitize_for_json(info.get("nchan")),
        "sfreq": sanitize_for_json(info.get("sfreq")),
        "highpass": sanitize_for_json(info.get("highpass")),
        "lowpass": sanitize_for_json(info.get("lowpass")),
        "device_info": {
            k: sanitize_for_json(v) for k, v in (device_info or {}).items()
        },
    }
    raw_header_json_str = json.dumps(header, allow_nan=False)

    # mne exposes meas_date as a datetime.datetime (UTC), not a pandas
    # Timestamp. Checking for pd.Timestamp alone never matched, so the start
    # was always NULL. pd.Timestamp subclasses datetime, so both are accepted.
    recorded_start_utc = None
    meas_date = info.get("meas_date")
    if isinstance(meas_date, dt.datetime):
        recorded_start_utc = meas_date.isoformat()

    with db_cursor() as cur:
        cur.execute(
            """
            INSERT INTO data_file (
                file_id, session_id,
                file_type_code, source_path, file_name,
                sha256_hex, byte_size,
                recorded_start_utc, recorded_duration_sec,
                n_channels, manufacturer, device_serial_hash,
                ingest_ts, raw_header_json
            )
            VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,NOW(),%s)
            """,
            (
                str(file_id),
                str(session_id),
                "EDF",
                str(EDF_FILE),
                EDF_FILE.name,
                None,
                byte_size,
                recorded_start_utc,
                duration_sec,
                info["nchan"],
                manufacturer,
                device_serial_hash,
                raw_header_json_str,
            ),
        )

    logger.info(
        "Ligne data_file insérée, file_id=%s → 1 ligne insérée dans data_file",
        file_id,
    )
    return file_id, 1


# 10-20 scalp electrodes. A label is EEG when its first derivation term
# (before "-") is one of them, e.g. "C3", "C3-M2", "C3-A2", "O2-M1".
_EEG_BASE_ELECTRODES = frozenset(
    {
        "c3", "c4", "cz",
        "f3", "f4", "fz",
        "fp1", "fp2", "fpz",
        "o1", "o2", "oz",
        "p3", "p4", "pz",
    }
)


def map_channel_label_to_type(label: str) -> Optional[str]:
    """
    Map an EDF channel label to a coarse channel type, using keywords.

    Returns one of 'EEG', 'EOG', 'EMG', 'OXIMETER', 'RESP', 'AUDIO' or 'OTHER'.
    Returns None for an empty or None label. Matching is case-insensitive and
    ignores surrounding whitespace. Checks run in the order below, so the
    first match wins.
    """
    if not label:
        return None

    low = label.strip().lower()
    base_electrode = low.split("-", 1)[0].strip()

    if low.startswith("eeg") or base_electrode in _EEG_BASE_ELECTRODES:
        return "EEG"

    if "eog" in low or low.startswith(("e1", "e2")):
        return "EOG"

    if "chin" in low or "emg" in low:
        return "EMG"

    if "ox" in low or "spo2" in low:
        return "OXIMETER"

    if "thorax" in low or "abdomen" in low or "resp" in low:
        return "RESP"

    if "mic" in low or "snore" in low or "audio" in low:
        return "AUDIO"

    return "OTHER"


# mne FIFF unit codes: FIFF_UNIT_V = 107; FIFF_UNIT_NONE = -1 and
# FIFF_UNIT_UNITLESS = 0 both mean "no physical unit".
_FIFF_UNIT_V = 107
_FIFF_UNITLESS_CODES = (-1, 0)


def _fiff_unit_to_text(unit) -> Optional[str]:
    """
    Turn an mne channel 'unit' (a FIFF integer code) into text for the DB.

    Volts become "V" and unitless codes become None. Any other code is kept
    as its decimal string, so no information is lost. Non-integer values are
    returned as str(unit).
    """
    if unit is None:
        return None
    try:
        code = int(unit)
    except (TypeError, ValueError):
        return str(unit)
    if code in _FIFF_UNITLESS_CODES:
        return None
    if code == _FIFF_UNIT_V:
        return "V"
    return str(code)


def etl_signal_channels_from_edf(
    raw: mne.io.BaseRaw,
    file_id: uuid.UUID,
) -> int:
    """
    Replace the signal_channel rows of ``file_id`` with one row per EDF channel.

    Existing rows for ``file_id`` are deleted first. Each channel gets:
    its index and label; a type code from map_channel_label_to_type, made
    compatible with ref_channel_type; its unit; the recording sampling rate;
    and the recording-level high-pass/low-pass values as lowcut/highcut.

    Physical/digital ranges, transducer and prefilter are not available from
    raw.info and are stored as NULL. (The previous version wrote the channel
    name into transducer and the sensor location array into prefilter.)

    Returns:
        Number of rows inserted.
    """
    info = raw.info
    sfreq = float(info["sfreq"] or 0.0)
    nchan = info["nchan"]
    # Recording-level filter settings, identical for every channel row.
    lowcut_hz = info.get("highpass")
    highcut_hz = info.get("lowpass")

    logger.info("ETL signal_channel : nchan=%d, sfreq=%.3f", nchan, sfreq)

    with db_cursor() as cur:
        cur.execute("DELETE FROM signal_channel WHERE file_id=%s", (str(file_id),))
        nb_deleted = cur.rowcount
    if nb_deleted:
        logger.info(
            "%d canal(aux) existant(s) supprimé(s) pour file_id=%s",
            nb_deleted,
            file_id,
        )

    nb_inserted = 0
    with db_cursor() as cur:
        for idx, label in enumerate(raw.ch_names):
            ch_info = info["chs"][idx]
            raw_type = map_channel_label_to_type(label)
            channel_type_code = normalize_channel_type_for_ref(raw_type)
            channel_id = uuid.uuid4()

            unit = None
            if isinstance(ch_info, dict):
                unit = _fiff_unit_to_text(ch_info.get("unit"))

            # Not exposed by mne's Info for EDF channels: stored as NULL.
            physical_min = None
            physical_max = None
            digital_min = None
            digital_max = None
            transducer = None
            prefilter = None

            if raw_type != channel_type_code:
                logger.info(
                    "Channel %d ('%s') : type brut='%s' normalisé='%s'",
                    idx,
                    label,
                    raw_type,
                    channel_type_code,
                )

            cur.execute(
                """
                INSERT INTO signal_channel (
                    channel_id, file_id,
                    channel_index, label,
                    channel_type_code, unit,
                    sampling_hz, physical_min, physical_max,
                    digital_min, digital_max, transducer,
                    prefilter, lowcut_hz, highcut_hz, notch_hz
                )
                VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
                """,
                (
                    str(channel_id),
                    str(file_id),
                    idx,
                    label,
                    channel_type_code,
                    unit,
                    sfreq,
                    physical_min,
                    physical_max,
                    digital_min,
                    digital_max,
                    transducer,
                    prefilter,
                    lowcut_hz,
                    highcut_hz,
                    None,
                ),
            )
            nb_inserted += 1

    logger.info(
        "%d ligne(s) insérée(s) dans signal_channel pour file_id=%s",
        nb_inserted,
        file_id,
    )
    return nb_inserted


# ============================================================
# 7. Hypnogramme → hypnogram_epoch + event_annotation
# ============================================================

# Sleep-EDF / R&K style labels such as "Sleep stage 2" or "Sleep stage R".
_SLEEP_EDF_STAGE_RE = re.compile(r"stage[\s_-]*([1-4r])\b")
_SLEEP_EDF_STAGE_CODES = {"1": "N1", "2": "N2", "3": "N3", "4": "N3", "r": "R"}


def map_event_to_sleep_stage(event_str: Optional[str]) -> Optional[str]:
    """
    Map a hypnogram 'event' value to a standard sleep stage code.

    Accepted notations (case-insensitive, surrounding whitespace ignored):
      - PANDORE CSV codes: AWA / N1 / N2 / N3 / REM (plus wake/awake/w, N4);
      - legacy exports: 'stage-n1', 'stage-rem', 'stage w', ...;
      - Sleep-EDF style: 'Sleep stage 1'..'4', 'Sleep stage R'.
    Stage 4 is merged into N3.

    Returns:
        'W', 'N1', 'N2', 'N3', 'R', or None for non-stage events (arousals,
        apneas, positions, ...). Non-string input, such as the float NaN
        pandas yields for empty CSV cells, also returns None instead of
        crashing on .strip().
    """
    if not isinstance(event_str, str):
        return None

    s = event_str.strip().lower()
    if not s:
        return None

    # PANDORE CSV: AWA / N1 / N2 / N3 / REM
    if s in ("awa", "wake", "awake", "w"):
        return "W"
    if s == "n1":
        return "N1"
    if s == "n2":
        return "N2"
    if s in ("n3", "n4"):
        return "N3"
    if s == "rem":
        return "R"

    # Legacy notations: 'stage-n1', 'stage-rem', ...
    if "stage-w" in s or "stage w" in s:
        return "W"
    if "stage-n1" in s or ("n1" in s and "stage" in s):
        return "N1"
    if "stage-n2" in s or ("n2" in s and "stage" in s):
        return "N2"
    if "stage-n3" in s or "stage-n4" in s or (("n3" in s or "n4" in s) and "stage" in s):
        return "N3"
    if "stage-rem" in s or ("rem" in s and "stage" in s):
        return "R"

    # Sleep-EDF / R&K numeric notation: 'sleep stage 2', 'sleep stage r'
    match = _SLEEP_EDF_STAGE_RE.search(s)
    if match:
        return _SLEEP_EDF_STAGE_CODES[match.group(1)]

    # Arousals, apneas, positions, ... are not sleep stages.
    return None


def _hypno_text(value) -> Optional[str]:
    """
    Return ``value`` if it is a non-blank string, otherwise None.

    Hypnogram cells are read with dtype=str, but empty cells still come back
    as float NaN. NaN is truthy, so it would win in ``a or b`` fallbacks, and
    psycopg2 sends it as a float to text columns. Those cells must become None.
    """
    if isinstance(value, str) and value.strip():
        return value
    return None


def etl_hypnogram_to_epochs_and_events(
    df_hypno: pd.DataFrame,
    session_id: uuid.UUID,
    session_start_utc: Optional[pd.Timestamp],
) -> Tuple[int, int]:
    """
    Load the hypnogram rows into hypnogram_epoch and event_annotation.

    V0 grid: every row whose begin and end are Timestamps becomes
      - one hypnogram_epoch row (epoch_index = DataFrame index label), and
      - one event_annotation row (source_name='PANDORE_HYPNO') whose
        extra_json holds the raw row details.

    The session's hypnogram_epoch rows and its PANDORE_HYPNO annotations are
    deleted first, so re-running the load does not duplicate rows.

    Onsets are seconds from session_start_utc; when that is not a Timestamp
    every onset is 0. start_sec and duration_sec in hypnogram_epoch are
    rounded (not truncated) to whole seconds.

    NOTE(review): rows flagged is_deleted are still loaded (the flag is only
    kept in extra_json). Confirm this is intended.

    Returns:
        (epochs inserted, events inserted); (0, 0) if required columns are missing.
    """
    if "begin" not in df_hypno.columns or "end" not in df_hypno.columns:
        logger.error(
            "Colonnes 'begin' et/ou 'end' absentes du hypnogramme → ETL hypnogram annulé"
        )
        return 0, 0

    if "event" not in df_hypno.columns:
        logger.error("Colonne 'event' absente du hypnogramme → ETL hypnogram annulé")
        return 0, 0

    logger.info(
        "ETL hypnogram_epoch + event_annotation pour session_id=%s (%d lignes brutes)",
        session_id,
        len(df_hypno),
    )

    nb_epochs_inserted = 0
    nb_events_inserted = 0

    with db_cursor() as cur:
        cur.execute(
            "DELETE FROM hypnogram_epoch WHERE session_id=%s", (str(session_id),)
        )
        cur.execute(
            """
            DELETE FROM event_annotation
            WHERE session_id=%s AND source_name = 'PANDORE_HYPNO'
            """,
            (str(session_id),),
        )

    has_session_start = isinstance(session_start_utc, pd.Timestamp)

    with db_cursor() as cur:
        for idx, row in df_hypno.iterrows():
            begin_dt = row["begin"]
            end_dt = row["end"]

            # NaT is not a pd.Timestamp instance, so unparsable rows are skipped.
            if not isinstance(begin_dt, pd.Timestamp) or not isinstance(
                end_dt, pd.Timestamp
            ):
                logger.debug(
                    "Ligne %s ignorée car begin/end non parsable : %s / %s",
                    idx,
                    begin_dt,
                    end_dt,
                )
                continue

            event_str = _hypno_text(row.get("event"))
            sleep_stage_code = map_event_to_sleep_stage(event_str)

            if has_session_start:
                onset_sec = (begin_dt - session_start_utc).total_seconds()
            else:
                onset_sec = 0.0

            duration_sec = (end_dt - begin_dt).total_seconds()

            # Missing (NaN/blank) cells become None before any `or` fallback.
            scoring_name = _hypno_text(row.get("scoring_name"))
            scoring_owner = _hypno_text(row.get("scoring_owner"))
            scoring_type = _hypno_text(row.get("scoring_type"))

            # hypnogram_epoch
            epoch_id = uuid.uuid4()
            epoch_index = idx  # V0: 1 hypnogram row = 1 epoch
            scorer_name = scoring_owner or scoring_name

            cur.execute(
                """
                INSERT INTO hypnogram_epoch (
                    epoch_id, session_id,
                    epoch_index, start_sec,
                    duration_sec, sleep_stage_code,
                    scorer_name
                )
                VALUES (%s,%s,%s,%s,%s,%s,%s)
                """,
                (
                    str(epoch_id),
                    str(session_id),
                    int(epoch_index),
                    int(round(onset_sec)),
                    int(round(duration_sec)),
                    sleep_stage_code,
                    scorer_name,
                ),
            )
            nb_epochs_inserted += 1

            # event_annotation
            extra_raw = {
                "begin": begin_dt,
                "end": end_dt,
                "location": row.get("location"),
                "is_deleted": row.get("is_deleted"),
                "scoring_name": scoring_name,
                "scoring_owner": scoring_owner,
                "scoring_type": scoring_type,
                "scoring": row.get("scoring"),
            }

            extra_clean = {k: sanitize_for_json(v) for k, v in extra_raw.items()}
            extra_json_str = json.dumps(extra_clean, allow_nan=False)

            event_id = uuid.uuid4()
            cur.execute(
                """
                INSERT INTO event_annotation (
                    event_id, session_id, file_id,
                    channel_id, event_type_code,
                    event_label, onset_sec, duration_sec,
                    severity, source_name, version_tag,
                    extra_json
                )
                VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
                """,
                (
                    str(event_id),
                    str(session_id),
                    None,
                    None,
                    None,
                    event_str,
                    float(onset_sec),
                    float(duration_sec),
                    None,
                    "PANDORE_HYPNO",
                    scoring_type or scoring_owner or scoring_name,
                    extra_json_str,
                ),
            )
            nb_events_inserted += 1

    logger.info(
        "%d epoch(s) insérée(s) dans hypnogram_epoch pour session_id=%s",
        nb_epochs_inserted,
        session_id,
    )
    logger.info(
        "%d événement(s) inséré(s) dans event_annotation (source=PANDORE_HYPNO) pour session_id=%s",
        nb_events_inserted,
        session_id,
    )
    return nb_epochs_inserted, nb_events_inserted


# ============================================================
# 8. ETL metric_summary à partir du metadata
# ============================================================

def etl_metric_summary_from_metadata(
    row_meta: pd.Series,
    session_id: uuid.UUID,
) -> int:
    """
    Insert the global AHI / AI / HI / ODI / LMI / PLMI metrics of a session
    into ``metric_summary``, reading them from the PANDORE metadata row.

    Metric columns are matched case-insensitively against ``row_meta.index``
    (first matching column wins). Values may use a decimal comma (``"12,5"``).
    A metric is skipped, with a warning, when its column is missing, its cell
    is empty / None / NaN, its value cannot be parsed as a number, or the
    parsed number is not finite (``"nan"``, ``"inf"``).

    Args:
        row_meta: Metadata row of the session (index = metadata column names).
        session_id: Identifier of the ``study_session`` the metrics belong to.

    Returns:
        Number of rows inserted into ``metric_summary``.
    """
    metric_cols = ["ahi", "ai", "hi", "odi", "lmi", "plmi"]
    nb_inserted = 0

    # Case-insensitive column lookup built once; setdefault keeps the first
    # column for a given lowercase name, like a first-match linear scan.
    col_by_lower = {}
    for c in row_meta.index:
        col_by_lower.setdefault(str(c).lower(), c)

    with db_cursor() as cur:
        for col in metric_cols:
            col_in_meta = col_by_lower.get(col)
            if col_in_meta is None:
                logger.warning("Colonne métrique '%s' absente du metadata", col)
                continue

            val_str = row_meta[col_in_meta]
            # pandas loads empty CSV cells as NaN; str(NaN) == "nan" would
            # parse successfully, so NaN must be treated as empty explicitly.
            is_nan = isinstance(val_str, float) and math.isnan(val_str)
            if val_str is None or is_nan or str(val_str).strip() == "":
                logger.warning(
                    "Valeur vide pour la métrique '%s' dans le metadata", col
                )
                continue

            try:
                # Accept European decimal commas, e.g. "12,5".
                val = float(str(val_str).replace(",", "."))
            except (TypeError, ValueError):
                logger.warning(
                    "Impossible de convertir la métrique '%s' (valeur=%s)",
                    col,
                    val_str,
                )
                continue

            # Strings such as "nan" / "inf" parse as floats; do not pass
            # non-finite values on to the INSERT.
            if not math.isfinite(val):
                logger.warning(
                    "Valeur non finie pour la métrique '%s' (valeur=%s)",
                    col,
                    val_str,
                )
                continue

            metric_id = uuid.uuid4()
            cur.execute(
                """
                INSERT INTO metric_summary (
                    metric_id, session_id,
                    channel_id, scope_level,
                    sleep_stage_code, metric_name,
                    metric_value, unit,
                    method, window_sec, computed_ts
                )
                VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,NOW())
                """,
                (
                    str(metric_id),
                    str(session_id),
                    None,                   # channel_id: session-wide metric
                    "session",              # scope_level
                    None,                   # sleep_stage_code: all stages
                    col.upper(),            # metric_name, e.g. "AHI"
                    val,
                    "index_per_hour",       # unit
                    "PANDORE_METADATA_V1",  # method / provenance tag
                    None,                   # window_sec
                ),
            )
            nb_inserted += 1

    logger.info(
        "%d métrique(s) insérée(s) dans metric_summary pour session_id=%s",
        nb_inserted,
        session_id,
    )
    return nb_inserted


# ============================================================
# 9. Orchestration : run_etl_s0001 + contrôles
# ============================================================

def run_etl_s0001():
    """
    V0 orchestrator for the session PANDORE_SAS_DATASET_S0001.

    Order of operations:
        1. Read the metadata CSV and the hypnogram CSV, select the S0001 row.
        2. Metadata      -> subject, then study_session.
        3. EDF recording -> data_file, then signal_channel.
        4. Hypnogram     -> hypnogram_epoch + event_annotation.
        5. Metadata      -> metric_summary.

    Steps 2-5 run inside ``with conn:``. NOTE(review): ``conn`` is a
    module-level object defined elsewhere; presumably the psycopg2
    connection, in which case the block commits on success and rolls back
    on error - confirm.

    Finally, a summary line with the counts returned by each step is logged.
    """
    logger.info("=== ETL V0 S0001 démarré ===")

    # Source CSVs are loaded before any database work starts.
    df_meta = load_metadata_dataframe()
    df_hypno = load_hypnogram_df()
    row_meta = get_metadata_row_for_s0001(df_meta)

    with conn:
        # Metadata -> subject, then the session attached to that subject.
        subject_id, nb_subject = etl_subject_from_metadata(row_meta)
        session_id, nb_session, session_start_utc = (
            etl_study_session_from_metadata(row_meta, subject_id, df_hypno)
        )

        # EDF -> data_file row, then its channels.
        edf_raw = load_edf()
        file_id, nb_df = etl_data_file_from_edf(edf_raw, session_id)
        nb_channels = etl_signal_channels_from_edf(edf_raw, file_id)

        # Hypnogram -> epochs and events, using the session start computed above.
        nb_epochs, nb_events = etl_hypnogram_to_epochs_and_events(
            df_hypno, session_id, session_start_utc
        )

        # Session-level indices taken from the metadata row.
        nb_metrics = etl_metric_summary_from_metadata(row_meta, session_id)

    summary_counts = (
        nb_subject,
        nb_session,
        nb_df,
        nb_channels,
        nb_epochs,
        nb_events,
        nb_metrics,
    )
    logger.info(
        "=== ETL S0001 terminé : subject=%d, session=%d, data_file=%d, "
        "channels=%d, epochs=%d, events=%d, metrics=%d ===",
        *summary_counts,
    )


def log_table_count(table_name: str):
    """
    Log the total number of rows in ``table_name``.

    Table names cannot be passed as query parameters, so the name is
    interpolated into the SQL text. To keep that safe, the name must be a
    plain ``table`` or ``schema.table`` identifier made only of ASCII
    letters, digits and underscores, not starting with a digit.

    Args:
        table_name: Name of the table to count.

    Raises:
        ValueError: if ``table_name`` is not such an identifier.
    """
    # Reject anything that is not a bare identifier before it reaches SQL text.
    if not isinstance(table_name, str) or not re.fullmatch(
        r"[A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?", table_name
    ):
        raise ValueError(f"Nom de table invalide : {table_name!r}")

    with db_cursor() as cur:
        cur.execute(f"SELECT COUNT(*) FROM {table_name}")
        nb = cur.fetchone()[0]
    logger.info("Table %s : %d ligne(s) au total", table_name, nb)


def run_checks_after_etl():
    """
    Run the S0001 ETL, then log sanity-check information.

    Two things are logged:
        - the total row count of every table the ETL writes to;
        - a few sample rows from each of those tables.

    Each sample query is executed on its own dict cursor.
    """
    run_etl_s0001()

    etl_tables = (
        "subject",
        "study_session",
        "data_file",
        "signal_channel",
        "hypnogram_epoch",
        "event_annotation",
        "metric_summary",
    )
    for table in etl_tables:
        log_table_count(table)

    # (SQL query, log format) pairs, run in this order.
    sample_checks = (
        ("SELECT * FROM subject LIMIT 5", "Sample subject : %s"),
        ("SELECT * FROM study_session LIMIT 5", "Sample study_session : %s"),
        (
            "SELECT file_id, file_name FROM data_file LIMIT 5",
            "Sample data_file : %s",
        ),
        (
            """
            SELECT channel_index, label
            FROM signal_channel
            ORDER BY channel_index
            LIMIT 10
            """,
            "Sample signal_channel (10 premiers) : %s",
        ),
        (
            """
            SELECT epoch_index, start_sec, duration_sec, sleep_stage_code
            FROM hypnogram_epoch
            ORDER BY epoch_index
            LIMIT 10
            """,
            "Sample hypnogram_epoch (10 premiers) : %s",
        ),
        (
            """
            SELECT event_label, onset_sec, duration_sec
            FROM event_annotation
            ORDER BY onset_sec
            LIMIT 10
            """,
            "Sample event_annotation (10 premiers) : %s",
        ),
        (
            """
            SELECT metric_name, metric_value
            FROM metric_summary
            ORDER BY metric_name
            LIMIT 10
            """,
            "Sample metric_summary (10 premiers) : %s",
        ),
    )
    for query, log_fmt in sample_checks:
        with db_cursor(dict_cursor=True) as cur:
            cur.execute(query)
            logger.info(log_fmt, cur.fetchall())


: 

In [None]:
# Run the ETL and the post-load checks only when executed as a script or
# notebook cell (where __name__ == "__main__"), not when imported as a module.
if __name__ == "__main__":
    run_checks_after_etl()
