In [1]:
%load_ext autoreload
%autoreload 2

In [27]:
import os
import glob
import numpy as np
import pandas as pd


def _nearest_index(target_ns: np.ndarray, grid_ns: np.ndarray) -> np.ndarray:
    """
    For each target timestamp, return index of nearest grid timestamp.
    grid_ns must be sorted ascending.
    """
    idx = np.searchsorted(grid_ns, target_ns, side="left")
    idx0 = np.clip(idx - 1, 0, grid_ns.size - 1)
    idx1 = np.clip(idx,     0, grid_ns.size - 1)

    d0 = np.abs(target_ns - grid_ns[idx0])
    d1 = np.abs(target_ns - grid_ns[idx1])
    return np.where(d1 < d0, idx1, idx0)


def load_self_report(self_report_csv: str) -> pd.DataFrame:
    """
    Load a self-report CSV and add UTC timestamp columns.

    Input format:
      timestamp,PainLevel,Action,Trial
      2025-03-12 10:01:45.587,,Session Started,0
      ...
    Behaviour:
      - timestamps are wall-clock America/New_York (EDT, UTC-4, for the March
        recordings) and are converted to UTC
      - Trial (when present) is coerced to numeric, then ffill + bfill; it is
        cast to int only when no gaps remain after filling
      - rows whose timestamp cannot be parsed are dropped (with a warning),
        because they cannot be aligned to any physio sample

    Returns the frame with two extra columns:
      - timestamp_utc: tz-aware UTC timestamp
      - timestamp_ns:  int64 nanoseconds since the Unix epoch (UTC)
    """
    sr = pd.read_csv(self_report_csv)

    # Fill Trial before any row is dropped so a row with a bad timestamp can
    # still propagate its Trial value to its neighbours.
    if "Trial" in sr.columns:
        trial = pd.to_numeric(sr["Trial"], errors="coerce").ffill().bfill()
        # An entirely blank / non-numeric column stays NaN after filling;
        # casting that to int would raise, so only cast when it is safe.
        if not trial.isna().any():
            trial = trial.astype(int)
        sr["Trial"] = trial

    t_local = pd.to_datetime(sr["timestamp"], errors="coerce")

    # A NaT would become the INT64_MIN sentinel once cast to int64 and then be
    # snapped to the very first physio sample, so drop such rows up front.
    parsed = t_local.notna()
    n_bad = int((~parsed).sum())
    if n_bad:
        print(f"[WARN] {self_report_csv}: dropped {n_bad} row(s) with unparseable timestamp")
        sr = sr.loc[parsed].reset_index(drop=True)
        t_local = t_local.loc[parsed].reset_index(drop=True)

    # Localize with the named zone (not a fixed offset) so DST is handled.
    t_utc = (t_local
             .dt.tz_localize("America/New_York", ambiguous="infer", nonexistent="shift_forward")
             .dt.tz_convert("UTC"))

    sr["timestamp_utc"] = t_utc
    # Pin the unit to nanoseconds before the integer cast so the result does
    # not depend on whatever resolution the parser picked.
    sr["timestamp_ns"] = t_utc.astype("datetime64[ns, UTC]").astype("int64")

    return sr


def _write_csv_with_parent(df: pd.DataFrame, out_csv: str) -> None:
    """Write df to out_csv, creating the parent directory when the path has one."""
    parent = os.path.dirname(out_csv)
    # dirname is "" for a bare file name, and os.makedirs("") raises.
    if parent:
        os.makedirs(parent, exist_ok=True)
    df.to_csv(out_csv, index=False)


def _snap_non_null(df: pd.DataFrame, rows: np.ndarray, column: str, values: pd.Series) -> None:
    """Write the non-null entries of values into df[column] at the given row positions."""
    present = values.notna().to_numpy()
    if present.any():
        df.loc[rows[present], column] = values.to_numpy()[present]


def join_self_report_to_physio(
    physio_csv: str,
    self_report_csv: str,
    out_csv: str,
    keep_action: bool = False,
) -> pd.DataFrame:
    """
    Adds PainLevel + Trial (and optionally Action) to physio df by snapping each
    self-report row to its single nearest physio row (by timestamp_ns).

    "excluding action" interpreted as: do not include Action in the output
    but keep rows that don't have PainLevel (session markers).

    Only non-null self-report values are written, so a marker row (no
    PainLevel) can never blank out a rating that landed on the same physio
    row. If several non-null values map to one physio row, the later one wins.

    Parameters:
      physio_csv:      physio CSV; must contain an integer-like timestamp_ns column
      self_report_csv: self-report CSV, read through load_self_report
      out_csv:         destination CSV (its parent directory is created if needed)
      keep_action:     also carry the Action column over when True

    Returns the physio frame, sorted by timestamp_ns, with the added columns.
    Trial is forward/back filled over all physio rows and cast to int32 when
    no gaps remain.

    Raises:
      ValueError: physio_csv has no timestamp_ns column
    """
    df = pd.read_csv(physio_csv)
    if "timestamp_ns" not in df.columns:
        raise ValueError(f"{physio_csv} missing timestamp_ns")

    df["timestamp_ns"] = df["timestamp_ns"].astype("int64")
    # Stable sort keeps the original order of rows sharing a timestamp.
    df = df.sort_values("timestamp_ns", kind="mergesort").reset_index(drop=True)
    grid_ns = df["timestamp_ns"].to_numpy(dtype=np.int64)

    sr = load_self_report(self_report_csv)

    if sr.empty or df.empty:
        # Nothing to align (no reports, or no physio rows to snap onto):
        # still write output with empty PainLevel/Trial columns.
        df["PainLevel"] = np.nan
        df["Trial"] = np.nan
        _write_csv_with_parent(df, out_csv)
        return df

    target_ns = sr["timestamp_ns"].to_numpy(dtype=np.int64)
    # After reset_index the labels equal the positions, so these positions
    # can be used directly with df.loc below.
    nearest_idx = _nearest_index(target_ns, grid_ns)

    # Prepare output columns
    if "PainLevel" not in df.columns:
        df["PainLevel"] = np.nan
    if "Trial" not in df.columns:
        df["Trial"] = np.nan
    if keep_action and "Action" not in df.columns:
        # Object dtype: Action holds text, which must not be written into a
        # float column.
        df["Action"] = pd.Series(np.nan, index=df.index, dtype=object)

    _snap_non_null(df, nearest_idx, "PainLevel", sr["PainLevel"])
    if "Trial" in sr.columns:
        _snap_non_null(df, nearest_idx, "Trial", sr["Trial"])
    if keep_action and "Action" in sr.columns:
        _snap_non_null(df, nearest_idx, "Action", sr["Action"])

    # Spread the sparse Trial markers over every physio row. The column stays
    # NaN when the self report provided no Trial at all; casting that to an
    # integer type would raise.
    trial = df["Trial"].ffill().bfill()
    if not trial.isna().any():
        trial = trial.astype(np.int32)
    df["Trial"] = trial

    _write_csv_with_parent(df, out_csv)
    return df


def batch_join(
    merged_dir: str,
    self_report_dir: str,
    out_dir: str,
    merged_glob: str = "*_merged_64hz.csv",
):
    """
    Join every merged physio file in merged_dir with its subject's self report.

    The subject id is the part of the physio file name before the first "_".
    The self report is looked up in self_report_dir as "<id>.csv",
    "<id>_self_report.csv" or "<id>_selfreport.csv" (first one that exists);
    failing that, the alphabetically first "*<id>*.csv" is used. Subjects
    without any self report are skipped with a warning.

    Example directories:
      merged_dir: ./processed_data/combined
      self_report_dir: ./self_report
      out_dir: ./processed_data/combined_with_self_report

    Raises FileNotFoundError when merged_glob matches nothing in merged_dir.
    """
    os.makedirs(out_dir, exist_ok=True)

    physio_files = sorted(glob.glob(os.path.join(merged_dir, merged_glob)))
    if len(physio_files) == 0:
        raise FileNotFoundError(f"No merged files found in {merged_dir} with glob {merged_glob}")

    # Suffixes of the well-known self-report file names, in priority order.
    known_suffixes = ("", "_self_report", "_selfreport")

    for phys_path in physio_files:
        # e.g. "003" from "003_merged_64hz.csv"
        subject_id = os.path.basename(phys_path).split("_")[0]

        sr_path = None
        for suffix in known_suffixes:
            candidate = os.path.join(self_report_dir, f"{subject_id}{suffix}.csv")
            if os.path.exists(candidate):
                sr_path = candidate
                break
        else:
            # None of the well-known names exist: accept any file containing the id.
            loose_matches = sorted(glob.glob(os.path.join(self_report_dir, f"*{subject_id}*.csv")))
            if loose_matches:
                sr_path = loose_matches[0]

        if sr_path is None:
            print(f"[WARN] no self report for subject {subject_id}; skipping")
            continue

        out_path = os.path.join(out_dir, f"{subject_id}_merged_64hz_with_self_report.csv")
        print(f"[OK] {subject_id}: {phys_path} + {sr_path} -> {out_path}")
        join_self_report_to_physio(phys_path, sr_path, out_path, keep_action=False)



In [29]:
# Join every merged physio CSV found in ./processed_data/per_subject with the
# matching self report from ./original_files/self_report; results are written
# to ./processed_data/with_self_report/per_subject.
batch_join("./processed_data/per_subject", "./original_files/self_report", "./processed_data/with_self_report/per_subject")

[OK] 001: ./processed_data/per_subject/001_merged_64hz.csv + ./original_files/self_report/001.csv -> ./processed_data/with_self_report/per_subject/001_merged_64hz_with_self_report.csv
[OK] 002: ./processed_data/per_subject/002_merged_64hz.csv + ./original_files/self_report/002.csv -> ./processed_data/with_self_report/per_subject/002_merged_64hz_with_self_report.csv
[OK] 003: ./processed_data/per_subject/003_merged_64hz.csv + ./original_files/self_report/003.csv -> ./processed_data/with_self_report/per_subject/003_merged_64hz_with_self_report.csv
[OK] 004: ./processed_data/per_subject/004_merged_64hz.csv + ./original_files/self_report/004.csv -> ./processed_data/with_self_report/per_subject/004_merged_64hz_with_self_report.csv
[OK] 005: ./processed_data/per_subject/005_merged_64hz.csv + ./original_files/self_report/005.csv -> ./processed_data/with_self_report/per_subject/005_merged_64hz_with_self_report.csv
[OK] 006: ./processed_data/per_subject/006_merged_64hz.csv + ./original_files/se