In [None]:
# ============================================================
# Exporting signals in Empatica E4 format by parts
#
# Processes ONE subject from ~/TFM/all_data.pkl
# (dict: subject_id -> list of blocks)
# and exports CSV files following the Empatica E4 structure,
# split by sensor and temporal parts.
#
# Output structure:
# ~/TFM/step2_e4csv/<SUBJECT>/
#   ├── EDA/  (part01.csv, part02.csv, ...)
#   ├── TEMP/ (part01.csv, ...)
#   ├── BVP/  (part01.csv, ...)
#   ├── ACC/  (part01.csv, ...)
#   ├── IBI/  (part01.csv, ...)
#   ├── _manifest.csv
#   └── _README_UNITS.txt
# ============================================================

import os
import math
import gc
import pickle
from dataclasses import dataclass
from datetime import datetime, timezone

import numpy as np
import pandas as pd


# -------------------- Parameters --------------------
# Pickle holding every subject: dict patient_id -> list of data blocks.
in_all = os.path.expanduser("~/TFM/all_data.pkl")

# Root folder; one sub-folder per exported subject is created below it.
out_root = os.path.expanduser("~/TFM/step2_e4csv")

# Subject to export. Use None to pick the alphabetically first subject id.
patient_id_target = "CLIDEM22"

# Gap thresholds (seconds): a larger gap starts a new output part.
gap_split_sec_uniform = 3 * 60   # EDA / TEMP / BVP / ACC
gap_split_sec_ibi = 2 * 60       # between consecutive systolic peaks (IBI)

# Output sampling frequencies (Hz) for the sensors that are resampled.
FS_TARGET_TEMP = 4    # TEMP is linearly interpolated to this rate
FS_TARGET_ACC = 32    # ACC is anti-aliased and reduced to this rate

# Two sampling rates count as equal if they differ by at most
# max(FS_TOL_ABS, FS_TOL_REL * larger_rate).
FS_TOL_ABS = 0.01     # absolute tolerance, Hz
FS_TOL_REL = 0.0025   # relative tolerance (0.25 %)


# -------------------- Helper functions --------------------
def final_timestamp(t0: float, n: int, fs: float) -> float:
    """Return the time of the last of *n* samples starting at *t0*, spaced 1/fs apart."""
    last_index = n - 1
    return t0 + last_index / fs


def gap_epsilon(fs: float) -> float:
    """Return the jitter tolerance (one sample period, 1/fs) for gap checks.

    A gap whose magnitude is at most this value is treated as continuous.

    Returns 0.0 when *fs* is None, NaN, zero or negative. The check is done
    on ``float(fs)`` so integer rates (e.g. ``0`` or the int constants used
    for resampled sensors) are handled too, instead of raising
    ZeroDivisionError.
    """
    if fs is None:
        return 0.0
    fs = float(fs)
    if math.isnan(fs) or fs <= 0:
        return 0.0
    return 1.0 / fs


def same_fs(a: float, b: float,
            tol_abs: float = FS_TOL_ABS,
            tol_rel: float = FS_TOL_REL) -> bool:
    """Tell whether two sampling frequencies are equal within tolerance.

    The allowed difference is ``max(tol_abs, tol_rel * max(|a|, |b|))``.
    A missing (None) or NaN rate never compares equal.
    """
    def _unusable(rate) -> bool:
        return rate is None or (isinstance(rate, float) and math.isnan(rate))

    if _unusable(a) or _unusable(b):
        return False

    allowed = max(tol_abs, tol_rel * max(abs(a), abs(b)))
    return abs(a - b) <= allowed


def resample_signal(values: np.ndarray, fs_orig: float, fs_target: float) -> np.ndarray:
    """Resample a uniformly sampled signal by linear interpolation.

    The output grid starts at the first input sample and covers
    ``[0, (n - 1) / fs_orig]`` in steps of ``1 / fs_target``. The last grid
    point is kept whenever it coincides with the last input sample, for any
    block length.

    Args:
        values: Input samples (converted to float).
        fs_orig: Input sampling frequency in Hz.
        fs_target: Output sampling frequency in Hz.

    Returns:
        Resampled float array. The input is returned unchanged (as float)
        when it has at most one sample or either rate is not a positive
        number (NaN included).
    """
    values = np.asarray(values, dtype=float)
    # `not (x > 0)` also rejects NaN, which would otherwise crash np.arange.
    if not (fs_orig > 0) or not (fs_target > 0) or values.size <= 1:
        return values

    t_orig = np.arange(values.size) / fs_orig
    # Count target samples with a tolerance expressed in *samples*. The former
    # fixed 1e-12 s tolerance vanished in float rounding for long blocks
    # (t_last >~ 16384 s), which silently dropped the final sample.
    n_target = int(math.floor(t_orig[-1] * fs_target + 1e-9)) + 1
    t_target = np.arange(n_target) / fs_target
    return np.interp(t_target, t_orig, values).astype(float)


def fill_na_edges(x: np.ndarray) -> np.ndarray:
    """Replace leading/trailing NaNs with the nearest valid value.

    Interior NaNs are left untouched. An all-NaN or empty input is returned
    as is. The result is always a new float array, so the caller's array
    is never modified (``np.asarray`` would alias a float input).
    """
    x = np.array(x, dtype=float)  # explicit copy
    valid = np.flatnonzero(~np.isnan(x))
    if valid.size == 0 or valid.size == x.size:
        return x

    first, last = valid[0], valid[-1]
    x[:first] = x[first]
    x[last + 1:] = x[last]
    return x


def lowpass_ma_centered(values: np.ndarray, k: int) -> np.ndarray:
    """Centered moving-average low-pass filter with an odd window of ~k samples.

    *k* is rounded and bumped to the next odd value so the window is
    centered. Near the ends the window is truncated and the average uses only
    the samples that actually exist. ``np.convolve(..., mode="same")`` pads
    with zeros, so a plain sum / k would pull the first and last k//2
    outputs toward 0.

    Returns the input (as float) unchanged when the effective window is 1 or
    longer than the signal. NaNs remaining at the edges are filled with the
    nearest valid value.
    """
    values = np.asarray(values, dtype=float)
    k = max(1, int(round(k)))
    if k % 2 == 0:
        k += 1
    if k == 1 or values.size < k:
        return values

    kernel = np.ones(k, dtype=float)
    sums = np.convolve(values, kernel, mode="same")
    # Number of real samples under the window at each position.
    counts = np.convolve(np.ones(values.size), kernel, mode="same")
    y = sums / counts
    if np.isnan(y).any():
        y = fill_na_edges(y)
    return y


def is_integer_ratio(fs_orig: float, fs_target: float, tol: float = 1e-6) -> bool:
    """Return True when ``fs_orig / fs_target`` is within *tol* of a whole number.

    A missing rate or a zero target rate yields False.
    """
    unusable = fs_orig is None or fs_target is None or fs_target == 0
    if unusable:
        return False
    ratio = fs_orig / fs_target
    nearest = round(ratio)
    return abs(ratio - nearest) < tol


def resample_with_antialias(values: np.ndarray,
                            fs_orig: float,
                            fs_target: float) -> np.ndarray:
    """Change the sampling rate, low-pass filtering first when downsampling.

    - Invalid rates, <= 1 sample, or equal rates: returned unchanged (float).
    - Upsampling: plain linear interpolation (``resample_signal``).
    - Integer downsampling factor r: centered moving average over r samples
      (made odd by ``lowpass_ma_centered``), then keep every r-th sample.
    - Non-integer factor: moving average over ceil(r) samples, then linear
      interpolation onto the target grid.
    """
    data = np.asarray(values, dtype=float)
    trivial = fs_orig <= 0 or fs_target <= 0 or data.size <= 1
    if trivial or abs(fs_orig - fs_target) < 1e-12:
        return data

    if fs_target > fs_orig:
        return resample_signal(data, fs_orig, fs_target)

    ratio = fs_orig / fs_target
    if is_integer_ratio(fs_orig, fs_target):
        step = int(round(ratio))
        return lowpass_ma_centered(data, step)[::step]

    window = int(math.ceil(ratio))
    return resample_signal(lowpass_ma_centered(data, window), fs_orig, fs_target)


def trim_overlap_samples(t_final_prev: float,
                         t0_next: float,
                         fs_next: float) -> int:
    """Count leading samples of the next block that overlap the previous one.

    Samples of the next block sit at ``t0_next + i / fs_next``. Every sample
    at or before ``t_final_prev`` is counted, i.e. ``floor(overlap * fs) + 1``.
    Returns 0 for a non-positive rate or when the blocks do not overlap.
    """
    if fs_next <= 0:
        return 0
    overlap_sec = t_final_prev - t0_next
    if overlap_sec <= 0:
        return 0
    covered = math.floor(overlap_sec * fs_next) + 1
    return max(0, int(covered))


def format_hms(sec: float) -> str | None:
    # Convert seconds to HH:MM:SS
    if sec is None or (isinstance(sec, float) and math.isnan(sec)) or sec < 0:
        return None
    h = int(sec // 3600)
    m = int((sec % 3600) // 60)
    s = int(round(sec % 60))
    return f"{h:02d}:{m:02d}:{s:02d}"


def epoch_to_iso(t0: float) -> str:
    """Render Unix epoch seconds as an ISO-8601 string in UTC (``+00:00`` offset)."""
    moment = datetime.fromtimestamp(float(t0), timezone.utc)
    return moment.isoformat()


# -------------------- E4 writers --------------------
def ensure_dir(path: str) -> None:
    """Make sure directory *path* exists, creating missing parents.

    A no-op when the directory is already there.
    """
    if not os.path.isdir(path):
        os.makedirs(path, exist_ok=True)


def write_signal_single(path: str,
                        t0: float,
                        fs: float,
                        vals: np.ndarray) -> None:
    """Write a single-channel E4-style CSV (used for EDA / TEMP / BVP).

    Layout: line 1 is the start time *t0* (12 decimals), line 2 is the
    sampling frequency *fs* (12 decimals), then one sample per line. NaN
    samples are written as ``NA``. The parent directory is created if
    needed.
    """
    ensure_dir(os.path.dirname(path))
    with open(path, "w", encoding="utf-8") as fh:
        fh.write(f"{t0:.12f}\n")
        fh.write(f"{fs:.12f}\n")
        samples = np.asarray(vals, dtype=float)
        fh.writelines("NA\n" if np.isnan(v) else f"{v}\n" for v in samples)


def write_signal_acc(path: str,
                     t0: float,
                     fs: float,
                     x: np.ndarray,
                     y: np.ndarray,
                     z: np.ndarray) -> None:
    """Write a 3-axis accelerometer CSV in E4 layout.

    Line 1 repeats *t0* three times, line 2 repeats *fs* three times (both
    with 12 decimals, comma-separated). Each following line holds one
    ``x,y,z`` sample, with NaN written as ``NA``. Rows stop at the shortest
    axis. The parent directory is created if needed.
    """
    def fmt(v) -> str:
        # Defined once, not per row as before.
        return "NA" if np.isnan(v) else f"{v}"

    ensure_dir(os.path.dirname(path))
    with open(path, "w", encoding="utf-8") as f:
        f.write(",".join([f"{t0:.12f}"] * 3) + "\n")
        f.write(",".join([f"{fs:.12f}"] * 3) + "\n")
        # zip stops at the shortest axis, i.e. min(len(x), len(y), len(z)) rows.
        f.writelines(f"{fmt(a)},{fmt(b)},{fmt(c)}\n" for a, b, c in zip(x, y, z))


def write_signal_ibi(path: str,
                     t0: float,
                     offsets: np.ndarray,
                     ibis: np.ndarray) -> None:
    """Write an event-based IBI CSV in E4 layout.

    Line 1 is ``"<t0>,IBI"`` (t0 with 6 decimals). Each following line is
    ``"<offset>,<ibi>"`` with 10 decimals, where offset is seconds after t0.
    Rows stop at the shorter of *offsets* / *ibis*. The parent directory is
    created if needed.
    """
    ensure_dir(os.path.dirname(path))
    with open(path, "w", encoding="utf-8") as fh:
        fh.write(f"{t0:.6f},IBI\n")
        fh.writelines(
            f"{off:.10f},{ibi:.10f}\n" for off, ibi in zip(offsets, ibis)
        )


# -------------------- Segmentation structures --------------------
@dataclass
class Part1D:
    """One exported part of a uniformly sampled 1-D sensor (EDA / TEMP / BVP)."""
    t0: float            # start time in epoch seconds (first CSV header line)
    fs: float            # sampling frequency in Hz (second CSV header line)
    vals: np.ndarray     # sample values, one per CSV row
    n_blocks: int        # number of input blocks stitched into this part
    splits: str          # split reason text recorded by the segmenter
    trimmed: int         # samples dropped because consecutive blocks overlapped


@dataclass
class PartACC:
    """One exported part of the 3-axis accelerometer signal."""
    t0: float            # start time in epoch seconds (first CSV header line)
    fs: float            # sampling frequency in Hz (second CSV header line)
    x: np.ndarray        # x-axis samples
    y: np.ndarray        # y-axis samples
    z: np.ndarray        # z-axis samples
    n_blocks: int        # number of input blocks stitched into this part
    splits: str          # split reason text recorded by the segmenter
    trimmed: int         # samples dropped (per axis) because blocks overlapped


# -------------------- Segmentation logic --------------------
def segment_items_1d(items,
                     gap_split_sec: float,
                     check_fs: bool) -> list[Part1D]:
    """Stitch uniformly sampled 1-D blocks into time-contiguous parts.

    Each item is a dict with ``"t0"`` (start, epoch seconds), ``"fs"`` (Hz)
    and ``"vals"`` (samples). Items are processed in ``t0`` order. A new part
    is started when:

    * ``check_fs`` is true and the rate differs beyond ``same_fs`` tolerance
      (reason ``"fs_change"``); or
    * the gap after the current part exceeds ``gap_split_sec``
      (reason ``"gap>{N}s"``).

    Otherwise the block is appended to the current part:

    * samples of the new block that overlap the current part are dropped
      and counted in ``Part1D.trimmed``;
    * a gap of more than one sample period is filled with NaN samples. The
      E4 layout stores only t0 and fs, so sample ``i`` must sit at
      ``t0 + i / fs``; concatenating across a gap would shift every later
      sample. The writer emits NaN as ``NA``.

    ``Part1D.splits`` holds the reason the part was *started* ("" for the
    first part), the same convention as the IBI manifest rows.

    Returns:
        Parts in chronological order (empty list for no items).
    """
    if not items:
        return []

    parts: list[Part1D] = []
    chunks: list[np.ndarray] = []   # pieces of the open part, joined once in flush()
    cur_len = 0                     # samples in `chunks`, NaN padding included
    cur_t0 = None
    cur_fs = None
    cur_blocks = 0
    cur_reason = ""
    total_trim = 0

    def flush() -> None:
        # Emit the open part (if any) and reset the state.
        nonlocal chunks, cur_len, cur_t0, cur_fs, cur_blocks, cur_reason, total_trim
        if not chunks:
            return
        parts.append(
            Part1D(
                t0=cur_t0,
                fs=cur_fs,
                vals=np.concatenate(chunks),
                n_blocks=cur_blocks,
                splits=cur_reason,
                trimmed=total_trim,
            )
        )
        chunks = []
        cur_len = 0
        cur_t0 = None
        cur_fs = None
        cur_blocks = 0
        cur_reason = ""
        total_trim = 0

    def start(t0, fs, vals, reason: str) -> None:
        # Close the open part and open a new one beginning with this block.
        nonlocal chunks, cur_len, cur_t0, cur_fs, cur_blocks, cur_reason
        flush()
        chunks = [vals]
        cur_len = vals.size
        cur_t0, cur_fs = t0, fs
        cur_blocks = 1
        cur_reason = reason

    def extend(vals, n_pad: int = 0) -> None:
        # Append a block to the open part, preceded by n_pad NaN samples.
        nonlocal cur_len, cur_blocks
        n_pad = max(0, n_pad)
        if n_pad:
            chunks.append(np.full(n_pad, np.nan))
        chunks.append(vals)
        cur_len += n_pad + vals.size
        cur_blocks += 1

    for it in sorted(items, key=lambda d: d["t0"]):
        t0, fs = it["t0"], it["fs"]
        # float so NaN padding can be concatenated (the writer casts to float anyway)
        vals = np.asarray(it["vals"], dtype=float)

        if not chunks:
            start(t0, fs, vals, "")
            continue

        if check_fs and not same_fs(cur_fs, fs):
            start(t0, fs, vals, "fs_change")
            continue

        t_final_prev = final_timestamp(cur_t0, cur_len, cur_fs)
        gap = t0 - t_final_prev

        if gap > gap_split_sec:
            start(t0, fs, vals, f"gap>{int(gap_split_sec)}s")
        elif gap < -gap_epsilon(cur_fs):
            # Overlap: drop the head of the new block already covered.
            drop = trim_overlap_samples(t_final_prev, t0, fs)
            total_trim += min(drop, vals.size)  # cannot drop more than exists
            if drop < vals.size:
                extend(vals[drop:])
        else:
            # Missing grid slots between the last sample and t0 (0 when the
            # gap is within one sample period).
            n_pad = int(round(gap * cur_fs)) - 1 if cur_fs > 0 else 0
            extend(vals, n_pad)

    flush()
    return parts


def segment_items_acc(items,
                      gap_split_sec: float) -> list[PartACC]:
    """Stitch 3-axis ACC blocks (already at FS_TARGET_ACC) into parts.

    Each item is a dict with ``"t0"`` (epoch seconds) and ``"x"``, ``"y"``,
    ``"z"`` sample arrays of equal length. Items are processed in ``t0``
    order:

    * a gap larger than ``gap_split_sec`` starts a new part
      (reason ``"gap>{N}s"``);
    * overlapping leading samples of the next block are dropped (counted
      once per sample in ``PartACC.trimmed``);
    * a gap of more than one sample period is filled with NaN samples on
      all axes, so sample ``i`` of a part sits at ``t0 + i / fs`` as the
      E4 layout (t0 + fs only) assumes.

    ``PartACC.splits`` holds the reason the part was *started* ("" for the
    first part), the same convention as the IBI manifest rows.
    """
    if not items:
        return []

    cur_fs = FS_TARGET_ACC  # items are expected to be resampled to this rate already
    parts: list[PartACC] = []
    chunks_x: list[np.ndarray] = []  # pieces of the open part, joined in flush()
    chunks_y: list[np.ndarray] = []
    chunks_z: list[np.ndarray] = []
    cur_len = 0                      # samples per axis, NaN padding included
    cur_t0 = None
    cur_blocks = 0
    cur_reason = ""
    total_trim = 0

    def flush() -> None:
        # Emit the open part (if any) and reset the state.
        nonlocal chunks_x, chunks_y, chunks_z, cur_len, cur_t0
        nonlocal cur_blocks, cur_reason, total_trim
        if not chunks_x:
            return
        parts.append(
            PartACC(
                t0=cur_t0,
                fs=cur_fs,
                x=np.concatenate(chunks_x),
                y=np.concatenate(chunks_y),
                z=np.concatenate(chunks_z),
                n_blocks=cur_blocks,
                splits=cur_reason,
                trimmed=total_trim,
            )
        )
        chunks_x, chunks_y, chunks_z = [], [], []
        cur_len = 0
        cur_t0 = None
        cur_blocks = 0
        cur_reason = ""
        total_trim = 0

    def start(t0, x, y, z, reason: str) -> None:
        # Close the open part and open a new one beginning with this block.
        nonlocal chunks_x, chunks_y, chunks_z, cur_len, cur_t0, cur_blocks, cur_reason
        flush()
        chunks_x, chunks_y, chunks_z = [x], [y], [z]
        cur_len = len(x)
        cur_t0 = t0
        cur_blocks = 1
        cur_reason = reason

    def extend(x, y, z, n_pad: int = 0) -> None:
        # Append a block to the open part, preceded by n_pad NaN samples per axis.
        nonlocal cur_len, cur_blocks
        n_pad = max(0, n_pad)
        if n_pad:
            fill = np.full(n_pad, np.nan)
            chunks_x.append(fill)
            chunks_y.append(fill)
            chunks_z.append(fill)
        chunks_x.append(x)
        chunks_y.append(y)
        chunks_z.append(z)
        cur_len += n_pad + len(x)
        cur_blocks += 1

    for it in sorted(items, key=lambda d: d["t0"]):
        t0 = it["t0"]
        x = np.asarray(it["x"], dtype=float)
        y = np.asarray(it["y"], dtype=float)
        z = np.asarray(it["z"], dtype=float)

        if not chunks_x:
            start(t0, x, y, z, "")
            continue

        t_final_prev = final_timestamp(cur_t0, cur_len, cur_fs)
        gap = t0 - t_final_prev

        if gap > gap_split_sec:
            start(t0, x, y, z, f"gap>{int(gap_split_sec)}s")
        elif gap < -gap_epsilon(cur_fs):
            # Overlap: drop the head of the new block already covered.
            drop = trim_overlap_samples(t_final_prev, t0, cur_fs)
            total_trim += min(drop, len(x))  # cannot drop more than exists
            if drop < len(x):
                extend(x[drop:], y[drop:], z[drop:])
        else:
            # Missing grid slots between the last sample and t0.
            n_pad = int(round(gap * cur_fs)) - 1 if cur_fs > 0 else 0
            extend(x, y, z, n_pad)

    flush()
    return parts


# -------------------- Load data and select subject --------------------
# NOTE: pickle.load can execute arbitrary code; only point `in_all` at a
# trusted, locally produced file.
with open(in_all, "rb") as f:
    all_data = pickle.load(f)

if not (isinstance(all_data, dict) and all_data):
    raise ValueError("all_data must be a non-empty dict")

# patient_id_target = None -> fall back to the alphabetically first subject.
patients = sorted(all_data)
if patient_id_target is None:
    patient_id = patients[0]
else:
    patient_id = patient_id_target

if patient_id not in all_data:
    raise ValueError(f"Subject not found: {patient_id}")

blocks = all_data[patient_id]
if not blocks:
    raise ValueError("Selected subject has no data blocks")

# Output tree: <out_root>/<subject>/{EDA,TEMP,BVP,ACC,IBI}/
ensure_dir(out_root)
pat_dir = os.path.join(out_root, patient_id)
ensure_dir(pat_dir)
for s in ("EDA", "TEMP", "BVP", "ACC", "IBI"):
    ensure_dir(os.path.join(pat_dir, s))

# One row per exported part; filled by the sensor sections below.
manifest_rows = []


# ============================================================
# EDA
# ============================================================
# Collect EDA blocks as-is (no resampling); timestampStart / 1e6 -> epoch seconds.
eda_items = []
for b in blocks:
    eda = b.get("eda")
    if not eda or not eda.get("values"):
        continue
    t0 = float(eda["timestampStart"]) / 1e6
    fs = float(eda["samplingFrequency"])
    eda_items.append({"t0": t0, "fs": fs, "vals": eda["values"]})

# EDA blocks can differ in rate, so rate changes also split parts.
eda_parts = segment_items_1d(eda_items, gap_split_sec_uniform, check_fs=True)
for i, p in enumerate(eda_parts, start=1):
    fname = os.path.join("EDA", f"part{i:02d}.csv")
    write_signal_single(os.path.join(pat_dir, fname), p.t0, p.fs, p.vals)
    n_samples = len(p.vals)
    dur = n_samples / p.fs
    manifest_rows.append(dict(
        sensor="EDA",
        part=i,
        filename=fname,
        t0_epoch=p.t0,
        t0_iso=epoch_to_iso(p.t0),
        fs=p.fs,
        n=n_samples,
        duration_sec=dur,
        duration_hms=format_hms(dur),
        n_blocks=p.n_blocks,
        splits=p.splits,
        overlap_trim_samples=p.trimmed,
    ))


# ============================================================
# TEMP
# ============================================================
# Every TEMP block is brought to FS_TARGET_TEMP before stitching, so the
# segmenter does not compare sampling rates (check_fs=False).
temp_items = []
for b in blocks:
    t = b.get("temperature")
    if not t or not t.get("values"):
        continue
    t0 = float(t["timestampStart"]) / 1e6
    fs_o = float(t["samplingFrequency"])
    vals = np.asarray(t["values"], dtype=float)
    # Keep samples untouched when already at the target rate or the rate is NaN.
    # NOTE(review): NaN-rate blocks are still labelled FS_TARGET_TEMP — confirm.
    keep_as_is = abs(fs_o - FS_TARGET_TEMP) < 1e-12 or math.isnan(fs_o)
    vals_rs = vals if keep_as_is else resample_signal(vals, fs_o, FS_TARGET_TEMP)
    temp_items.append({"t0": t0, "fs": FS_TARGET_TEMP, "vals": vals_rs})

temp_parts = segment_items_1d(temp_items, gap_split_sec_uniform, check_fs=False)
for i, p in enumerate(temp_parts, start=1):
    fname = os.path.join("TEMP", f"part{i:02d}.csv")
    write_signal_single(os.path.join(pat_dir, fname), p.t0, p.fs, p.vals)
    n_samples = len(p.vals)
    dur = n_samples / p.fs
    manifest_rows.append(dict(
        sensor="TEMP",
        part=i,
        filename=fname,
        t0_epoch=p.t0,
        t0_iso=epoch_to_iso(p.t0),
        fs=p.fs,
        n=n_samples,
        duration_sec=dur,
        duration_hms=format_hms(dur),
        n_blocks=p.n_blocks,
        splits=p.splits,
        overlap_trim_samples=p.trimmed,
    ))


# ============================================================
# BVP
# ============================================================
# Collect BVP blocks as-is (no resampling); timestampStart / 1e6 -> epoch seconds.
bvp_items = []
for b in blocks:
    bvp = b.get("bvp")
    if not bvp or not bvp.get("values"):
        continue
    t0 = float(bvp["timestampStart"]) / 1e6
    fs = float(bvp["samplingFrequency"])
    bvp_items.append({"t0": t0, "fs": fs, "vals": bvp["values"]})

# BVP blocks can differ in rate, so rate changes also split parts.
bvp_parts = segment_items_1d(bvp_items, gap_split_sec_uniform, check_fs=True)
for i, p in enumerate(bvp_parts, start=1):
    fname = os.path.join("BVP", f"part{i:02d}.csv")
    write_signal_single(os.path.join(pat_dir, fname), p.t0, p.fs, p.vals)
    n_samples = len(p.vals)
    dur = n_samples / p.fs
    manifest_rows.append(dict(
        sensor="BVP",
        part=i,
        filename=fname,
        t0_epoch=p.t0,
        t0_iso=epoch_to_iso(p.t0),
        fs=p.fs,
        n=n_samples,
        duration_sec=dur,
        duration_hms=format_hms(dur),
        n_blocks=p.n_blocks,
        splits=p.splits,
        overlap_trim_samples=p.trimmed,
    ))


# ============================================================
# ACC
# ============================================================
acc_items = []
acc_skipped = 0  # blocks dropped because their imuParams could not be used
for b in blocks:
    acc = b.get("accelerometer")
    if not acc:
        continue

    # Digital -> physical scale factor from the block's IMU calibration.
    imu = acc.get("imuParams", {}) or {}
    try:
        factor = ((float(imu["physicalMax"]) - float(imu["physicalMin"])) /
                  (float(imu["digitalMax"]) - float(imu["digitalMin"])))
    except (KeyError, TypeError, ValueError, ZeroDivisionError):
        # Missing, non-numeric or degenerate calibration: the counts cannot be
        # scaled, so skip the block but count it (reported after the loop).
        acc_skipped += 1
        continue

    fs_o = float(acc.get("samplingFrequency", float("nan")))
    t0 = float(acc.get("timestampStart")) / 1e6

    # Scaled values x 64 (the README describes these as "g×64 units").
    x = np.asarray(acc.get("x", []), dtype=float) * factor * 64.0
    y = np.asarray(acc.get("y", []), dtype=float) * factor * 64.0
    z = np.asarray(acc.get("z", []), dtype=float) * factor * 64.0
    if not len(x):
        continue

    # NOTE(review): a NaN samplingFrequency leaves the block unresampled, yet
    # it is stitched as FS_TARGET_ACC Hz by segment_items_acc — confirm.
    if not math.isnan(fs_o) and abs(fs_o - FS_TARGET_ACC) > 1e-12:
        x = resample_with_antialias(x, fs_o, FS_TARGET_ACC)
        y = resample_with_antialias(y, fs_o, FS_TARGET_ACC)
        z = resample_with_antialias(z, fs_o, FS_TARGET_ACC)

    # Keep the three axes the same length.
    n = min(len(x), len(y), len(z))
    acc_items.append({"t0": t0, "x": x[:n], "y": y[:n], "z": z[:n]})

if acc_skipped:
    print(f"Warning: skipped {acc_skipped} ACC block(s) with unusable imuParams")

for i, p in enumerate(segment_items_acc(acc_items,
                                        gap_split_sec_uniform), 1):
    fname = os.path.join("ACC", f"part{i:02d}.csv")
    write_signal_acc(os.path.join(pat_dir, fname),
                     p.t0, p.fs, p.x, p.y, p.z)
    dur = len(p.x) / p.fs
    manifest_rows.append({
        "sensor": "ACC",
        "part": i,
        "filename": fname,
        "t0_epoch": p.t0,
        "t0_iso": epoch_to_iso(p.t0),
        "fs": p.fs,
        "n": len(p.x),
        "duration_sec": dur,
        "duration_hms": format_hms(dur),
        "n_blocks": p.n_blocks,
        "splits": p.splits,
        "overlap_trim_samples": p.trimmed,
    })


# ============================================================
# IBI
# ============================================================
# Collect systolic peak times from all blocks (nanoseconds -> seconds).
peaks = []
for b in blocks:
    sp = b.get("systolicPeaks")
    if sp and sp.get("peaksTimeNanos"):
        peaks.extend(float(x) / 1e9 for x in sp["peaksTimeNanos"])

# np.unique both removes duplicates and returns the values sorted.
peaks = np.unique(np.asarray(peaks, dtype=float))

if len(peaks) > 1:
    # Split wherever two consecutive peaks are more than gap_split_sec_ibi apart.
    gaps = np.diff(peaks)
    cut_idx = np.flatnonzero(gaps > gap_split_sec_ibi)
    starts = np.r_[0, cut_idx + 1]
    ends = np.r_[cut_idx, len(peaks) - 1]

    part = 0
    for s, e in zip(starts, ends):
        if e <= s:
            # A lone peak produces no interval; nothing to export.
            continue
        part += 1
        pk = peaks[s:e + 1]
        t0 = pk[0]
        offsets = pk[1:] - t0   # beat times relative to the first peak
        ibis = np.diff(pk)      # interval to the preceding peak

        fname = os.path.join("IBI", f"part{part:02d}.csv")
        write_signal_ibi(os.path.join(pat_dir, fname),
                         t0, offsets, ibis)

        dur = offsets.max() if len(offsets) else 0.0
        manifest_rows.append({
            "sensor": "IBI",
            "part": part,
            "filename": fname,
            "t0_epoch": t0,
            "t0_iso": epoch_to_iso(t0),
            "fs": np.nan,
            "n": len(ibis),
            "duration_sec": dur,
            "duration_hms": format_hms(dur),
            "n_blocks": np.nan,
            # Reason this part was started. Based on the segment position, not
            # the part counter, so a skipped leading lone peak does not hide
            # the gap that preceded the first exported part.
            "splits": "" if s == 0 else f"gap>{gap_split_sec_ibi}s",
            "overlap_trim_samples": np.nan,
        })


# -------------------- Save manifest and README --------------------
# Explicit column order; also yields a proper header when nothing was exported.
MANIFEST_COLUMNS = [
    "sensor", "part", "filename", "t0_epoch", "t0_iso", "fs", "n",
    "duration_sec", "duration_hms", "n_blocks", "splits",
    "overlap_trim_samples",
]
manifest = pd.DataFrame(manifest_rows, columns=MANIFEST_COLUMNS)
manifest.to_csv(os.path.join(pat_dir, "_manifest.csv"), index=False)

# Rates, thresholds and tolerances come from the parameters, so the README
# cannot drift from what the export actually did.
readme_txt = "\n".join([
    "Units per sensor:",
    "- EDA: microSiemens (µS), no resampling.",
    f"- TEMP: Celsius (°C), resampled to {FS_TARGET_TEMP} Hz (linear interpolation).",
    "- BVP: relative PPG units, no resampling.",
    f"- ACC: g×64 units, resampled to {FS_TARGET_ACC} Hz with antialiasing.",
    "- IBI: seconds, event-based.",
    "",
    "Segmentation rules:",
    f"- EDA/TEMP/BVP/ACC: new part if gap > {gap_split_sec_uniform} s.",
    "- EDA/BVP: new part also if the sampling rate changes.",
    f"- IBI: new part if inter-peak gap > {gap_split_sec_ibi} s.",
    "",
    "Tolerances:",
    f"- Sampling rate considered equal if |Δfs| <= max({FS_TOL_ABS:g} Hz, {FS_TOL_REL:.2%}).",
    "- Gaps <= 1/fs are treated as continuous.",
])

with open(os.path.join(pat_dir, "_README_UNITS.txt"), "w", encoding="utf-8") as f:
    f.write(readme_txt + "\n")

print(f"Export completed for subject {patient_id}")
print(f"Output folder: {pat_dir}")
print(f"Manifest file: {os.path.join(pat_dir, '_manifest.csv')}")

gc.collect()