# MIDAS Heart/Resp Motion Pipeline

This notebook separates respiratory and cardiac motion, detects beats, and builds
time-series features for grouping into four conditions (control, doxo, empa, empa_doxo);
recordings whose filename matches none of these prefixes are labelled `other`.


## 0) Setup
Import dependencies and define the pipeline parameters. Recordings are read from `DATA_DIR` (`../data`) and their sampling rates are estimated in section 1.


In [None]:
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from scipy.optimize import linear_sum_assignment
from scipy.signal import butter, sosfiltfilt, welch, find_peaks, hilbert
from sklearn.linear_model import RidgeClassifierCV
from sklearn.model_selection import GroupKFold, cross_val_score
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score

try:
    from sktime.transformations.panel.rocket import MiniRocket
except Exception:
    MiniRocket = None

try:
    import hdbscan
except Exception:
    hdbscan = None


In [None]:
# Directory that is searched for the input `*.csv` recordings.
DATA_DIR = Path("../data")
# Respiration band in Hz. In the code shown in this notebook only the upper
# edge is used, as the cutoff of the low-pass filter for the respiratory
# component.
RESP_BAND_HZ = (0.1, 2.0)
# Band-pass edges in Hz used to isolate the cardiac component.
HEART_BAND_HZ = (2.5, 15.0)
# Minimum spacing between two detected beats, in seconds.
REFRACTORY_S = 0.20
# Beat window as (seconds before the peak, seconds after the peak).
BEAT_WINDOW_S = (0.15, 0.25)
# Number of samples every beat window is resampled to.
RESAMPLE_LEN = 128


## 1) Load data and estimate sampling rate


In [None]:
@dataclass
class Record:
    """One recording read from a CSV file, together with derived metadata."""

    path: Path  # CSV file the recording was read from
    label: str  # condition label derived from the file name
    time_s: np.ndarray  # first CSV column as float (presumably seconds)
    signal: np.ndarray  # second CSV column as float
    fs: float  # estimated sampling rate: reciprocal of the median step of `time_s`
    group_id: str  # file stem; used as the grouping key for group-wise CV


def label_from_name(name: str) -> str:
    """Derive the condition label from the prefix of a file name.

    Matching is case-insensitive. The first rule whose prefix matches wins;
    names that match no rule are labelled ``"other"``.
    """
    lowered = name.lower()
    # Order matters: "empa_doxo" has to be tested before the shorter "empa".
    prefix_rules = (
        (("control",), "control"),
        (("doxo",), "doxo"),
        (("empa_doxo", "preconditionare_empa_doxo"), "empa_doxo"),
        (("empa",), "empa"),
    )
    for prefixes, label in prefix_rules:
        if lowered.startswith(prefixes):
            return label
    return "other"


def load_records(data_dir: Path) -> list[Record]:
    """Load every ``*.csv`` file in *data_dir* as a :class:`Record`.

    Files are visited in sorted path order. The first CSV column is read as
    the time axis and the second as the signal, both converted to float
    arrays. The sampling rate is estimated as the reciprocal of the median
    time step.

    Args:
        data_dir: Directory searched (non-recursively) for CSV files.

    Returns:
        One record per CSV file. ``fs`` is ``0.0`` when it cannot be
        estimated: fewer than two samples, or a median time step that is
        zero, negative, or not finite.
    """
    records: list[Record] = []
    for path in sorted(data_dir.glob("*.csv")):
        df = pd.read_csv(path)
        time_s = df.iloc[:, 0].to_numpy(dtype=float)
        signal = df.iloc[:, 1].to_numpy(dtype=float)
        dt = np.diff(time_s)
        median_dt = float(np.median(dt)) if dt.size else 0.0
        # Duplicate, decreasing, or missing timestamps would otherwise give a
        # ZeroDivisionError or a meaningless (negative / NaN) rate; fall back
        # to the same 0.0 sentinel that is used for too-short recordings.
        if np.isfinite(median_dt) and median_dt > 0.0:
            fs = 1.0 / median_dt
        else:
            fs = 0.0
        records.append(
            Record(
                path=path,
                label=label_from_name(path.stem),
                time_s=time_s,
                signal=signal,
                fs=fs,
                group_id=path.stem,
            )
        )
    return records


# Load all recordings once; the cells below reuse `records`.
records = load_records(DATA_DIR)
# Show the estimated sampling rate per file as a quick sanity check.
{rec.path.name: rec.fs for rec in records}


## 2) Spectrum inspection (Welch PSD)


In [None]:
# Inspect the first recording only; `rec` is reused by the following cells.
rec = records[0]
# Welch PSD estimate; the segment length is capped at the signal length.
f, pxx = welch(rec.signal, fs=rec.fs, nperseg=min(2048, len(rec.signal)))
plt.figure(figsize=(8, 4))
plt.semilogy(f, pxx)  # logarithmic power axis
plt.title(f"PSD: {rec.path.name}")
plt.xlabel("Hz")
plt.ylabel("Power")
plt.xlim(0, 20)  # restrict the view to 0-20 on the frequency axis
plt.show()


## 3) Separate respiration and heart (zero-phase filtering)


In [None]:
def band_limits(band: tuple[float, float], fs: float) -> tuple[float, float]:
    """Clamp a ``(low, high)`` band to edges usable at sampling rate *fs*.

    The lower edge is raised to at least 0.01 and the upper edge is capped at
    0.49 times the Nyquist frequency (``fs / 2``).
    """
    requested_low, requested_high = band
    floor = 0.01
    ceiling = 0.49 * (fs / 2.0)
    clamped_low = max(floor, requested_low)
    clamped_high = min(requested_high, ceiling)
    return (clamped_low, clamped_high)


def butter_sos(band: tuple[float, float], fs: float, order: int = 4) -> np.ndarray:
    """Design a Butterworth band-pass filter in second-order-section form.

    The requested band is first clamped by ``band_limits`` for the given
    sampling rate *fs*.
    """
    edges = list(band_limits(band, fs))
    return butter(order, edges, btype="bandpass", fs=fs, output="sos")


def lowpass_sos(cutoff: float, fs: float, order: int = 4) -> np.ndarray:
    """Design a Butterworth low-pass filter in second-order-section form.

    The cutoff is capped at 0.49 times the Nyquist frequency (``fs / 2``)
    before the filter is designed.
    """
    max_cutoff = 0.49 * (fs / 2.0)
    effective_cutoff = min(cutoff, max_cutoff)
    return butter(order, effective_cutoff, btype="lowpass", fs=fs, output="sos")


def separate_components(signal: np.ndarray, fs: float) -> tuple[np.ndarray, np.ndarray]:
    """Split *signal* into a respiratory and a cardiac component.

    The respiratory component is the signal low-passed at the upper edge of
    ``RESP_BAND_HZ``; the cardiac component is the signal band-passed to
    ``HEART_BAND_HZ``. Both filters are applied forward and backward
    (``sosfiltfilt``), so neither component is phase-shifted.

    Returns:
        ``(resp, heart)``.
    """
    filters = (
        lowpass_sos(RESP_BAND_HZ[1], fs),
        butter_sos(HEART_BAND_HZ, fs),
    )
    resp, heart = (sosfiltfilt(sos, signal) for sos in filters)
    return resp, heart


# Separate the example recording; `heart` is reused by the beat-detection cell.
resp, heart = separate_components(rec.signal, rec.fs)
plt.figure(figsize=(10, 4))
# Raw trace is drawn semi-transparent so the filtered components stay visible.
plt.plot(rec.time_s, rec.signal, label="raw", alpha=0.5)
plt.plot(rec.time_s, resp, label="resp")
plt.plot(rec.time_s, heart, label="heart")
plt.legend()
plt.title(f"Separated components: {rec.path.name}")
plt.xlabel("Time (s)")
plt.show()


## 4) Beat detection on heart component


In [None]:
def heart_envelope(heart: np.ndarray) -> np.ndarray:
    """Return the amplitude envelope of *heart*.

    The envelope is the magnitude of the analytic signal computed with the
    Hilbert transform.
    """
    analytic_signal = hilbert(heart)
    return np.abs(analytic_signal)


def detect_beats(heart: np.ndarray, fs: float) -> tuple[np.ndarray, np.ndarray]:
    """Locate beats as prominent peaks of the cardiac envelope.

    Peaks must be at least ``REFRACTORY_S`` seconds apart (never less than one
    sample) and have a prominence of at least half the 75th percentile of the
    envelope.

    Returns:
        ``(peaks, env)``: the peak sample indices and the envelope they were
        detected on.
    """
    envelope = heart_envelope(heart)
    min_spacing = int(REFRACTORY_S * fs)
    if min_spacing < 1:
        min_spacing = 1
    min_prominence = np.percentile(envelope, 75) * 0.5
    beat_idx = find_peaks(envelope, distance=min_spacing, prominence=min_prominence)[0]
    return beat_idx, envelope


# Detect beats on the example recording; `peaks` is reused by the next cell.
peaks, env = detect_beats(heart, rec.fs)
plt.figure(figsize=(10, 4))
plt.plot(rec.time_s, env, label="envelope")
# Mark each detected beat on the envelope it was detected from.
plt.plot(rec.time_s[peaks], env[peaks], "rx", label="beats")
plt.legend()
plt.title(f"Beat detection: {rec.path.name}")
plt.xlabel("Time (s)")
plt.show()


## 5) Extract beat windows (fixed length)


In [None]:
def extract_beat_windows(
    signal: np.ndarray,
    peaks: np.ndarray,
    fs: float,
    window_s: tuple[float, float],
    resample_len: int,
) -> list[np.ndarray]:
    """Cut a fixed window around each peak and resample it to a common length.

    Args:
        signal: 1-D signal the windows are cut from.
        peaks: Sample indices of the detected beats.
        fs: Sampling rate used to convert *window_s* to sample counts.
        window_s: ``(pre, post)`` extent of the window around each peak, in
            seconds. Each extent is truncated to a whole number of samples.
        resample_len: Number of samples in every returned window.

    Returns:
        One array of length *resample_len* per peak whose window lies
        entirely inside *signal*; other peaks are skipped. The list is empty
        when the window rounds down to zero samples at this sampling rate.
    """
    pre_s, post_s = window_s
    pre = int(pre_s * fs)
    post = int(post_s * fs)
    width = pre + post
    if width <= 0:
        # Nothing to interpolate from: np.interp would raise on an empty snippet.
        return []

    n_samples = len(signal)
    # Every snippet has the same length, so both grids are loop-invariant.
    x_old = np.linspace(0.0, 1.0, num=width, endpoint=False)
    x_new = np.linspace(0.0, 1.0, num=resample_len, endpoint=False)

    windows: list[np.ndarray] = []
    for peak in peaks:
        start = peak - pre
        end = peak + post
        # The slice end is exclusive, so end == len(signal) is still in bounds.
        if start < 0 or end > n_samples:
            continue
        windows.append(np.interp(x_new, x_old, signal[start:end]))
    return windows


# Cut windows from the cardiac component of the example recording.
beat_windows = extract_beat_windows(heart, peaks, rec.fs, BEAT_WINDOW_S, RESAMPLE_LEN)
# Number of beats whose window lay fully inside the signal.
len(beat_windows)


## Ground-truth groups (from filenames)

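Labels are derived from filename prefixes only. The clustering models in section 8 never see them during fitting; there they are used **only** for evaluation. The supervised baseline in section 7 does use them as training targets. Files matching none of the prefixes are labelled `other`. The four target groups are: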

- control
- doxo (including doxo_re)
- empa
- empa_doxo (including preconditionare_empa_doxo)



## 6) Build dataset (beats -> features -> grouping)


In [None]:
# Parallel lists: one entry per extracted beat window.
beats: list[np.ndarray] = []
labels: list[str] = []
groups: list[str] = []

# NOTE: the loop rebinds the notebook globals `rec`, `heart`, `peaks` and
# `windows`; after this cell they refer to the last recording processed.
for rec in records:
    # Only the cardiac component is used for beat features.
    _, heart = separate_components(rec.signal, rec.fs)
    peaks, _ = detect_beats(heart, rec.fs)
    windows = extract_beat_windows(heart, peaks, rec.fs, BEAT_WINDOW_S, RESAMPLE_LEN)
    beats.extend(windows)
    # Every beat inherits the label and group id of its source file.
    labels.extend([rec.label] * len(windows))
    groups.extend([rec.group_id] * len(windows))

# One row per beat; an explicit empty matrix keeps the column count defined
# when no beats were extracted.
X = np.stack(beats) if beats else np.empty((0, RESAMPLE_LEN))
y = np.array(labels)
group_ids = np.array(groups)
X.shape


## 7) MiniROCKET baseline (supervised)


In [None]:
# Supervised baseline: MiniRocket features fed to a ridge classifier, scored
# with group-wise cross-validation (groups are the source files).
if MiniRocket is None or len(X) == 0:
    print("MiniRocket not available or no beats extracted.")
else:
    # Insert a channel axis: (n_beats, n_samples) -> (n_beats, 1, n_samples).
    X3d = X[:, np.newaxis, :]
    rocket = MiniRocket()
    # NOTE(review): the transform is fitted on all beats before the CV split,
    # so held-out folds take part in fitting it — confirm this is acceptable.
    X_feat = rocket.fit_transform(X3d)
    clf = RidgeClassifierCV(alphas=np.logspace(-3, 3, 7))
    unique_groups = np.unique(group_ids)
    # Group-wise CV is only attempted with at least three source files.
    if len(unique_groups) >= 3:
        # At most 5 folds, and never more folds than there are groups.
        splits = GroupKFold(n_splits=min(5, len(unique_groups))).split(X_feat, y, group_ids)
        scores = cross_val_score(clf, X_feat, y, cv=splits)
        print(f"Group CV accuracy: {scores.mean():.3f} +/- {scores.std():.3f}")
    else:
        clf.fit(X_feat, y)
        print("Trained on full data (insufficient groups for CV).")


## 8) Clustering options (exploratory)


In [None]:
def evaluate_clustering(y_true: np.ndarray, y_pred: np.ndarray, label: str, ignore_noise: bool = False) -> None:
    """Print agreement metrics between ground-truth labels and cluster ids.

    Reports the adjusted Rand index, the normalized mutual information, and
    the accuracy under the best one-to-one assignment of clusters to labels,
    together with that assignment. Nothing is returned.

    Args:
        y_true: Ground-truth label per sample.
        y_pred: Cluster id per sample.
        label: Prefix used in every printed line.
        ignore_noise: When true, samples with cluster id ``-1`` are dropped
            before scoring and their fraction is printed.
    """
    truth = np.asarray(y_true)
    assigned = np.asarray(y_pred)

    has_noise = ignore_noise and np.any(assigned == -1)
    if has_noise:
        keep = assigned != -1
        print(f"{label} noise fraction: {1 - keep.mean():.2%}")
        truth, assigned = truth[keep], assigned[keep]

    # The metrics are only reported when there are samples and at least two clusters.
    if len(truth) == 0 or len(np.unique(assigned)) < 2:
        print(f"{label}: insufficient clusters for evaluation.")
        return

    ari = adjusted_rand_score(truth, assigned)
    nmi = normalized_mutual_info_score(truth, assigned)
    print(f"{label} ARI: {ari:.3f} | NMI: {nmi:.3f}")

    # Contingency table: rows are sorted true labels, columns sorted cluster ids.
    true_labels, true_idx = np.unique(truth, return_inverse=True)
    pred_labels, pred_idx = np.unique(assigned, return_inverse=True)
    contingency = np.zeros((len(true_labels), len(pred_labels)), dtype=int)
    np.add.at(contingency, (true_idx, pred_idx), 1)

    # Negate the counts so the cost-minimising assignment maximises agreement.
    rows, cols = linear_sum_assignment(-contingency)
    best_acc = contingency[rows, cols].sum() / len(truth)
    mapping = {pred_labels[c]: true_labels[r] for r, c in zip(rows, cols)}
    print(f"{label} best alignment accuracy: {best_acc:.3f}")
    print(f"{label} mapping (cluster -> label): {mapping}")


# Unsupervised exploration: cluster MiniRocket features and compare the
# clusters with the filename-derived labels.
if MiniRocket is None or len(X) == 0:
    print("MiniRocket not available or no beats extracted.")
else:
    # Features are recomputed here rather than reused from the supervised cell.
    X3d = X[:, np.newaxis, :]
    rocket = MiniRocket()
    X_feat = rocket.fit_transform(X3d)

    # Four clusters, matching the four target groups; seed fixed for repeatability.
    kmeans = KMeans(n_clusters=4, n_init=20, random_state=0)
    clusters = kmeans.fit_predict(X_feat)
    unique, counts = np.unique(clusters, return_counts=True)
    print("KMeans cluster counts", dict(zip(unique, counts)))
    evaluate_clustering(y, clusters, label="KMeans")

    # HDBSCAN is optional; it is skipped when the import failed.
    if hdbscan is not None:
        clusterer = hdbscan.HDBSCAN(min_cluster_size=10)
        hdb = clusterer.fit_predict(X_feat)
        unique, counts = np.unique(hdb, return_counts=True)
        print("HDBSCAN cluster counts", dict(zip(unique, counts)))
        # Samples with cluster id -1 are excluded from the scores.
        evaluate_clustering(y, hdb, label="HDBSCAN", ignore_noise=True)
