# In [None]:
# ============================================================
# SUBJECT PAIRING: EMBRACE PLUS vs EMPATICA (robust and optimal)
#
# - Embrace: range from _manifest.csv (start = t0_epoch; end = t0_epoch + duration_sec, or t0_epoch + (n - 1)/fs), else from the E4 part files (first / last non-empty part)
# - Normalizes subject IDs (CLIDEM011 ≡ CLIDEM11) before merging
# - Global 1-to-1 matching (Hungarian algorithm if available, otherwise greedy)
# - Cost = 1*|Δstart| + 0.5*|Δend| + 0.25*|Δduration| (ignores NA terms)
# - Outputs: emparejamientos.csv, rangos_por_reloj.csv
# ============================================================

import os
import re
import math
from datetime import datetime, timezone

import numpy as np
import pandas as pd


# ------------------------------------------------------------
# Configuration
# ------------------------------------------------------------
# Root folders; every immediate subdirectory is treated as one subject.
EMPATICA_ROOT = os.path.expanduser("~/Documents/EMPATICA_all")
EMBRACE_ROOT  = os.path.expanduser("~/TFM/step2_e4csv")

# Sensors used to derive the time range on both devices, and the wider list
# tried when the Embrace range has to be derived from per-sensor folders.
SENSORES_COMUNES  = ["ACC", "EDA", "TEMP", "BVP"]
SENSORES_FALLBACK = ["ACC", "EDA", "TEMP", "BVP", "HR", "IBI"]

# Output locations; the output directory is created at import/run time.
DIR_SALIDA = os.path.expanduser("~/TFM/Emparejamiento")
SALIDA_EMPAREJAMIENTO = os.path.join(DIR_SALIDA, "emparejamientos.csv")
SALIDA_RANGOS = os.path.join(DIR_SALIDA, "rangos_por_reloj.csv")
os.makedirs(DIR_SALIDA, exist_ok=True)

# NOTE(review): not referenced anywhere in the visible code - confirm whether
# a QC step that uses it lives elsewhere.
ALERTA_HORAS = 48  # QC threshold (hours) for large start-time differences


# ------------------------------------------------------------
# General helper functions
# ------------------------------------------------------------
def normalizar_sensor(sensor: str) -> str:
    """Map a free-form sensor label to its canonical upper-case code.

    The label is matched case-insensitively by substring against the known
    sensor fragments, in a fixed priority order. Labels that match none of
    them are returned upper-cased.
    """
    etiqueta = str(sensor)
    en_minusculas = etiqueta.lower()

    # Checked in this order; the first fragment found wins.
    fragmentos = (
        ("acc", "ACC"),
        ("eda", "EDA"),
        ("temp", "TEMP"),
        ("bvp", "BVP"),
        ("ibi", "IBI"),
        ("hr", "HR"),
    )
    for fragmento, codigo in fragmentos:
        if fragmento in en_minusculas:
            return codigo
    return etiqueta.upper()


def epoch_to_dt_utc(x: float) -> datetime | None:
    # Converts epoch (s or ms) to UTC datetime using basic heuristics
    if x is None or (isinstance(x, float) and (math.isnan(x) or not math.isfinite(x))):
        return None
    try:
        if x < 1e7:
            return None
        if x > 1e11:
            x = x / 1000.0
        return datetime.fromtimestamp(float(x), tz=timezone.utc)
    except Exception:
        return None


def fmt_es(dt: datetime | None) -> str | None:
    # Spanish-style datetime formatting
    if dt is None:
        return None
    return dt.astimezone(timezone.utc).strftime("%d/%m/%Y %H:%M:%S")


def normalize_clidem(x: str) -> str:
    """Canonicalise a subject ID such as ``CLIDEM011`` to ``CLIDEM11``.

    The ID is upper-cased and stripped of all whitespace. If it then has the
    shape <letters><digits>, leading zeros are removed from the numeric part;
    any other shape is returned as cleaned, without further changes.
    """
    limpio = re.sub(r"\s+", "", str(x).strip().upper())
    coincidencia = re.match(r"^([A-Z]+)(\d+)$", limpio)
    if coincidencia is None:
        return limpio

    prefijo, digitos = coincidencia.groups()
    try:
        digitos = str(int(digitos))
    except Exception:
        # Keep the digits untouched if they cannot be parsed as an integer.
        pass
    return prefijo + digitos


def normalize_empatica_id(x: str) -> str:
    """Return the Empatica subject ID as text with surrounding whitespace removed."""
    como_texto = str(x)
    return como_texto.strip()


# ------------------------------------------------------------
# EMPATICA: compute start/end time per subject
# ------------------------------------------------------------
def leer_epoch_inicio_empatica(csv_path: str) -> float | None:
    try:
        with open(csv_path, "r", encoding="utf-8", errors="ignore") as f:
            l1 = f.readline().strip()
        if not l1:
            return None
        v = float(l1.split(",")[0])
        if not math.isfinite(v) or v < 1e7:
            return None
        return v
    except Exception:
        return None


def leer_freq_empatica(csv_path: str) -> float | None:
    try:
        with open(csv_path, "r", encoding="utf-8", errors="ignore") as f:
            _ = f.readline()
            l2 = f.readline().strip()
        if not l2:
            return None
        fr = float(l2.split(",")[0])
        if (not math.isfinite(fr)) or fr <= 0:
            return None
        return fr
    except Exception:
        return None


def nfilas_datos_empatica(csv_path: str) -> int | None:
    try:
        df = pd.read_csv(csv_path, skiprows=2, header=None, usecols=[0])
        return int(df.shape[0])
    except Exception:
        return None


def rango_paciente_empatica(dir_paciente: str) -> dict:
    """Compute the overall recording range of one Empatica subject folder.

    For every ``<SENSOR>.csv`` of SENSORES_COMUNES present in the folder, the
    start epoch, sampling frequency and row count give that sensor's interval
    ``[start, start + (n - 1) / fs]``. Sensors with an unreadable header or
    no data rows are skipped.

    Returns a dict with ``paciente`` (folder name), ``start_empatica``
    (earliest start) and ``end_empatica`` (latest end) as UTC datetimes, or
    None for both when no sensor yields a usable interval.
    """
    paciente = os.path.basename(dir_paciente)

    intervalos = []
    for sensor in SENSORES_COMUNES:
        ruta = os.path.join(dir_paciente, f"{sensor}.csv")
        if not os.path.exists(ruta):
            continue

        epoch_inicio = leer_epoch_inicio_empatica(ruta)
        frecuencia = leer_freq_empatica(ruta)
        if epoch_inicio is None or frecuencia is None:
            continue

        n_filas = nfilas_datos_empatica(ruta)
        if n_filas is None or n_filas <= 0:
            continue

        # The last sample sits (n - 1) sampling periods after the first one.
        intervalos.append((epoch_inicio, epoch_inicio + (n_filas - 1) / frecuencia))

    if not intervalos:
        return {"paciente": paciente, "start_empatica": None, "end_empatica": None}

    primer_inicio = min(inicio for inicio, _ in intervalos)
    ultimo_final = max(final for _, final in intervalos)
    return {
        "paciente": paciente,
        "start_empatica": epoch_to_dt_utc(primer_inicio),
        "end_empatica": epoch_to_dt_utc(ultimo_final),
    }


# ------------------------------------------------------------
# EMBRACE: compute start/end time per subject (robust)
# ------------------------------------------------------------
def e4_times_from_file(csv_path: str) -> dict:
    """Extract timing information from one E4-format CSV part.

    Line 1 carries the start epoch and line 2 the sampling frequency (first
    comma-separated field of each); the remaining rows are samples.

    Returns a dict with keys ``start``/``end`` (UTC datetimes or None),
    ``n`` (number of data rows or None), ``st_epoch`` and ``fr``. All values
    are None when the header cannot be read or is invalid. ``end`` is None
    when the row count is unknown or zero.
    """
    try:
        with open(csv_path, "r", encoding="utf-8", errors="ignore") as fh:
            linea_epoch = fh.readline().strip()
            linea_freq = fh.readline().strip()

        st_epoch = float(linea_epoch.split(",")[0])
        frecuencia = float(linea_freq.split(",")[0])
        cabecera_valida = (
            math.isfinite(st_epoch) and math.isfinite(frecuencia) and frecuencia > 0
        )
        if not cabecera_valida:
            return {"start": None, "end": None, "n": None, "st_epoch": None, "fr": None}

        # A part with a valid header but unparsable body still reports its start.
        try:
            tabla = pd.read_csv(csv_path, skiprows=2, header=None, usecols=[0])
            n_filas = int(tabla.shape[0])
        except Exception:
            n_filas = None

        info = {
            "start": epoch_to_dt_utc(st_epoch),
            "end": None,
            "n": n_filas,
            "st_epoch": st_epoch,
            "fr": frecuencia,
        }
        if n_filas is not None and n_filas > 0:
            info["end"] = epoch_to_dt_utc(st_epoch + (n_filas - 1) / frecuencia)
        return info
    except Exception:
        return {"start": None, "end": None, "n": None, "st_epoch": None, "fr": None}


def sensor_start_end_from_dir(dir_sensor: str) -> dict:
    """Derive a sensor's start and end from the CSV parts in its folder.

    Parts are ordered by path name. The start comes from the first part that
    has data rows and a valid start; the end comes from the last part that
    has data rows and a valid end. Scanning stops at the first hit in each
    direction.

    Returns ``{"start": ..., "end": ...}`` with UTC datetimes or None.
    """
    rutas = sorted(
        os.path.join(dir_sensor, nombre)
        for nombre in os.listdir(dir_sensor)
        if nombre.lower().endswith(".csv")
    )
    if not rutas:
        return {"start": None, "end": None}

    def _con_datos(info: dict) -> bool:
        return info["n"] is not None and info["n"] > 0

    # Generators keep the scan lazy: parts after the first match are not read.
    inicio = next(
        (
            info["start"]
            for info in map(e4_times_from_file, rutas)
            if info["start"] is not None and _con_datos(info)
        ),
        None,
    )
    final = next(
        (
            info["end"]
            for info in map(e4_times_from_file, reversed(rutas))
            if info["end"] is not None and _con_datos(info)
        ),
        None,
    )
    return {"start": inicio, "end": final}


def _rango_desde_manifest(manifest_path: str) -> tuple | None:
    """Return ``(start, end)`` UTC datetimes read from a manifest, or None.

    Rows are restricted to SENSORES_COMUNES when a ``sensor`` column exists.
    The end of each row is ``t0_epoch + duration_sec`` or, if that column is
    absent, ``t0_epoch + (n - 1) / fs``. Either element of the tuple may be
    None; the result is None when the manifest yields neither.
    """
    man = pd.read_csv(manifest_path)
    if "sensor" in man.columns:
        sensores = man["sensor"].astype(str).map(normalizar_sensor)
        man = man[sensores.isin(SENSORES_COMUNES)]

    if len(man) == 0 or "t0_epoch" not in man.columns:
        return None

    t0 = pd.to_numeric(man["t0_epoch"], errors="coerce").to_numpy(dtype=float)
    if "duration_sec" in man.columns:
        dur = pd.to_numeric(man["duration_sec"], errors="coerce").to_numpy(dtype=float)
        t_end = t0 + dur
    elif "fs" in man.columns and "n" in man.columns:
        fs = pd.to_numeric(man["fs"], errors="coerce").to_numpy(dtype=float)
        nn = pd.to_numeric(man["n"], errors="coerce").to_numpy(dtype=float)
        t_end = t0 + (nn - 1) / fs
    else:
        t_end = np.full_like(t0, np.nan)

    # Convert each epoch once; invalid or NaN epochs come back as None.
    starts = [dt for dt in map(epoch_to_dt_utc, t0) if dt is not None]
    ends = [dt for dt in map(epoch_to_dt_utc, t_end) if dt is not None]
    if not starts and not ends:
        return None
    return (min(starts) if starts else None, max(ends) if ends else None)


def rango_paciente_embrace(path_paciente: str) -> dict:
    """Compute the overall recording range of one Embrace subject folder.

    The primary source is ``_manifest.csv`` inside the folder. If it is
    missing, unusable, or yields no valid timestamp, the range is derived
    from the per-sensor subfolders listed in SENSORES_FALLBACK.

    Returns a dict with ``paciente`` (folder name), ``start_embrace`` and
    ``end_embrace`` (UTC datetimes or None).
    """
    paciente = os.path.basename(path_paciente)
    manifest_path = os.path.join(path_paciente, "_manifest.csv")

    # Primary source: manifest
    if os.path.exists(manifest_path):
        try:
            rango = _rango_desde_manifest(manifest_path)
        except Exception as exc:
            # Still best-effort, but a broken manifest is reported rather than
            # silently ignored, so data problems are visible in the run log.
            print(f"[WARN] {paciente}: manifest unusable ({exc!r}); using sensor folders")
            rango = None
        if rango is not None:
            return {
                "paciente": paciente,
                "start_embrace": rango[0],
                "end_embrace": rango[1],
            }

    # Fallback: derive from sensor folders
    starts = []
    ends = []
    for sensor in SENSORES_FALLBACK:
        dir_sensor = os.path.join(path_paciente, sensor)
        if not os.path.isdir(dir_sensor):
            continue
        limites = sensor_start_end_from_dir(dir_sensor)
        if limites["start"] is not None:
            starts.append(limites["start"])
        if limites["end"] is not None:
            ends.append(limites["end"])

    return {
        "paciente": paciente,
        "start_embrace": min(starts) if starts else None,
        "end_embrace": max(ends) if ends else None,
    }


# ------------------------------------------------------------
# Build summaries
# ------------------------------------------------------------
def list_dirs(root: str) -> list[str]:
    """Return the full paths of the immediate subdirectories of *root*.

    The result is sorted by entry name. ``os.listdir`` makes no ordering
    promise, and the order here decides both the row order of the output
    files and how ties are broken during matching, so sorting keeps runs
    reproducible. Returns an empty list when *root* is not a directory.
    """
    if not os.path.isdir(root):
        return []
    candidatos = (os.path.join(root, nombre) for nombre in sorted(os.listdir(root)))
    return [ruta for ruta in candidatos if os.path.isdir(ruta)]


def _resumir_rangos(registros: list[dict], normalizar_id, reloj: str) -> pd.DataFrame:
    """Build the per-subject summary table for one device.

    *registros* are dicts with keys ``paciente``, ``start_<reloj>`` and
    ``end_<reloj>``. Subject IDs are normalised with *normalizar_id*, rows
    without a start are dropped, and ``duracion_<reloj>_h`` (hours) is added.
    """
    df = pd.DataFrame(registros)
    if df.empty:
        return df

    col_start, col_end = f"start_{reloj}", f"end_{reloj}"
    df["paciente"] = df["paciente"].map(normalizar_id)

    # Force a datetime dtype: a column made only of None stays object dtype,
    # which breaks both the subtraction and the .dt accessor below and would
    # crash before the "insufficient subjects" check is reached.
    df[col_start] = pd.to_datetime(df[col_start], utc=True)
    df[col_end] = pd.to_datetime(df[col_end], utc=True)

    df = df.dropna(subset=[col_start]).copy()
    df[f"duracion_{reloj}_h"] = (df[col_end] - df[col_start]).dt.total_seconds() / 3600.0
    return df


dirs_empatica = list_dirs(EMPATICA_ROOT)
empatica_df = _resumir_rangos(
    [rango_paciente_empatica(d) for d in dirs_empatica],
    normalize_empatica_id,
    "empatica",
)

dirs_embrace = list_dirs(EMBRACE_ROOT)
embrace_df = _resumir_rangos(
    [rango_paciente_embrace(d) for d in dirs_embrace],
    normalize_clidem,
    "embrace",
)

print(f"Subjects after normalization: Embrace={len(embrace_df)} | Empatica={len(empatica_df)}")

# Matching needs at least one subject with a valid start on each device.
if len(embrace_df) == 0 or len(empatica_df) == 0:
    raise RuntimeError("Insufficient subjects for pairing. Check input paths.")


# ------------------------------------------------------------
# One-to-one matching using combined cost
# ------------------------------------------------------------
def hours_diff(a: datetime | None, b: datetime | None) -> float | None:
    if a is None or b is None:
        return None
    return abs((a - b).total_seconds()) / 3600.0


def pair_cost(se, ee, de, ss, es, ds, w_start=1.0, w_end=0.5, w_dur=0.25) -> float:
    c_start = hours_diff(ss, se)
    c_end = None if (ee is None or es is None) else hours_diff(es, ee)
    c_dur = None if (de is None or ds is None or math.isnan(de) or math.isnan(ds)) else abs(ds - de)
    return (w_start * (c_start or 0.0)
            + w_end * (c_end or 0.0)
            + w_dur * (c_dur or 0.0))


# Compact, zero-indexed views of both summaries: row i of E and row j of P
# correspond to C[i, j].
E = embrace_df[["paciente", "start_embrace", "end_embrace", "duracion_embrace_h"]].reset_index(drop=True)
P = empatica_df[["paciente", "start_empatica", "end_empatica", "duracion_empatica_h"]].reset_index(drop=True)

# C[i, j] is the pairing cost of Embrace subject i with Empatica subject j.
C = np.zeros((len(E), len(P)), dtype=float)
for i, fila_e in enumerate(E.itertuples(index=False)):
    for j, fila_p in enumerate(P.itertuples(index=False)):
        C[i, j] = pair_cost(
            fila_e.start_embrace, fila_e.end_embrace, fila_e.duracion_embrace_h,
            fila_p.start_empatica, fila_p.end_empatica, fila_p.duracion_empatica_h,
        )


def assign_optimal(C: np.ndarray) -> list[tuple[int, int]]:
    """Return a one-to-one assignment of rows to columns with low total cost.

    Uses SciPy's ``linear_sum_assignment`` (optimal) when SciPy is installed
    and accepts the matrix. Otherwise falls back to a greedy pass that takes
    the cheapest remaining finite cell until rows or columns run out.

    Args:
        C: 2-D cost matrix, rows x columns.

    Returns:
        List of ``(row, column)`` pairs as plain Python ints. In the greedy
        path, cells that are NaN or infinite are never selected, so fewer
        than ``min(C.shape)`` pairs may be returned.
    """
    C = np.asarray(C, dtype=float)

    try:
        from scipy.optimize import linear_sum_assignment
    except ImportError:
        linear_sum_assignment = None

    if linear_sum_assignment is not None:
        try:
            filas, columnas = linear_sum_assignment(C)
            return list(zip(filas.tolist(), columnas.tolist()))
        except ValueError:
            # SciPy rejects matrices with NaN/-inf or infeasible ones; the
            # greedy pass below can still pair whatever is finite.
            pass

    if C.size == 0:
        return []

    n_filas, n_columnas = C.shape
    fila_usada = [False] * n_filas
    columna_usada = [False] * n_columnas
    max_pares = min(n_filas, n_columnas)

    # A stable sort of the flattened matrix visits cells from cheapest to most
    # expensive, ties broken by position - the same order as repeatedly taking
    # argmin over the remaining cells. NaN sorts last and is skipped.
    pares: list[tuple[int, int]] = []
    for plano in np.argsort(C, axis=None, kind="stable").tolist():
        i, j = divmod(plano, n_columnas)
        if fila_usada[i] or columna_usada[j] or not np.isfinite(C[i, j]):
            continue
        fila_usada[i] = True
        columna_usada[j] = True
        pares.append((i, j))
        if len(pares) == max_pares:
            break
    return pares


# Solve the assignment on the cost matrix; each element of `pairs` is a
# (row index into E, row index into P) tuple.
pairs = assign_optimal(C)


def get_pair_row(i: int, j: int) -> dict:
    """Build the output record for the pair (Embrace row *i*, Empatica row *j*).

    Reads from the module-level tables E and P. The record holds both
    subject IDs, the absolute start / end / duration differences in hours,
    and the four raw timestamps.
    """
    inicio_embrace = E.at[i, "start_embrace"]
    final_embrace = E.at[i, "end_embrace"]
    inicio_empatica = P.at[j, "start_empatica"]
    final_empatica = P.at[j, "end_empatica"]
    delta_duracion = P.at[j, "duracion_empatica_h"] - E.at[i, "duracion_embrace_h"]

    return {
        "paciente_embrace": E.at[i, "paciente"],
        "paciente_empatica": P.at[j, "paciente"],
        "diferencia_inicio_horas": hours_diff(inicio_empatica, inicio_embrace),
        "diferencia_final_horas": hours_diff(final_empatica, final_embrace),
        "diferencia_duracion_horas": abs(delta_duracion),
        "start_embrace": inicio_embrace,
        "end_embrace": final_embrace,
        "start_empatica": inicio_empatica,
        "end_empatica": final_empatica,
    }


# One row per matched pair, written without the DataFrame index.
emparejamientos_final = pd.DataFrame([get_pair_row(i, j) for i, j in pairs])
emparejamientos_final.to_csv(SALIDA_EMPAREJAMIENTO, index=False)


# ------------------------------------------------------------
# Export ranges per device
# ------------------------------------------------------------
# Long-format table: one row per (device, subject) with the formatted range.
columnas_rango = ["reloj", "paciente", "start", "end"]

rangos_embrace = embrace_df.assign(
    reloj="Embrace",
    start=embrace_df["start_embrace"].map(fmt_es),
    end=embrace_df["end_embrace"].map(fmt_es),
)[columnas_rango]

rangos_empatica = empatica_df.assign(
    reloj="Empatica",
    start=empatica_df["start_empatica"].map(fmt_es),
    end=empatica_df["end_empatica"].map(fmt_es),
)[columnas_rango]

# Embrace rows first, then Empatica rows, with a fresh index.
rangos_por_reloj = pd.concat([rangos_embrace, rangos_empatica], ignore_index=True)
rangos_por_reloj.to_csv(SALIDA_RANGOS, index=False)

print("Pairing completed successfully.")