In [None]:
import pandas as pd
import plotly.express as px
import numpy as np
from scipy.signal import correlate
from itertools import combinations

In [None]:
import sys
from pathlib import Path

# Add the parent directory (repo root) to Python path.
# The membership check keeps this cell idempotent: re-running it does not
# keep prepending duplicate entries to sys.path.
repo_root = Path.cwd().parent
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

# Force reload the module to get the latest version
import importlib
import src.utils.data_processing
importlib.reload(src.utils.data_processing)

from src.utils.data_processing import process_file, prepare_data_for_analysis

In [None]:
# Get FIT files from the sandbox/data directory.
# glob() does not guarantee any particular order, and the file order decides
# the row/column order of every matrix built later, so sort for reproducibility.
data_dir = Path.cwd() / "data"
fit_files = sorted(data_dir.glob("*.fit"))

print(f"Found {len(fit_files)} FIT files:")
for f in fit_files:
    print(f"    {f.name}")

# Process the files - each path is handed to process_file as a string
session_record_list = [process_file(str(file)) for file in fit_files]

In [None]:
# Display the first few rows of the records DataFrame from the first file
session_record_list[0]["records"].head()

In [None]:
# One prepared "heart_rate" frame per processed file.
df_list = [
    prepare_data_for_analysis(session["records"], "heart_rate")
    for session in session_record_list
]

# Stack all files into one frame; ignore_index=True already renumbers the
# rows from 0, so no separate index reset is required.
df = pd.concat(df_list, ignore_index=True)

df

In [None]:
# Heart rate against elapsed time, drawn as one coloured trace per source file.
fig = px.line(
    df,
    x="elapsed_seconds",
    y="heart_rate",
    color="filename",
    labels=dict(
        elapsed_seconds="Elapsed Time (seconds)",
        heart_rate="Heart Rate (bpm)",
        filename="File",
    ),
)

fig.show()

In [None]:
def z_norm(x, eps=1e-8):
    """
    Standardize a 1-D signal to zero mean and unit (population) std.

    The input is converted to a float32 array first. When the standard
    deviation is below ``eps`` the signal is treated as constant and the
    centered values are multiplied by zero instead of being divided.
    """
    values = np.asarray(x, dtype=np.float32)
    centered = values - np.mean(values)
    spread = np.std(values)

    # Nearly-constant signal: dividing by a tiny std would only amplify noise.
    if spread < eps:
        return centered * 0.0

    return centered / spread


def raw_hr_by_file(df, value_col="heart_rate", file_col="filename"):
    """
    Return dict: filename -> np.array(raw values of ``value_col``).

    Rows are grouped by ``file_col`` in order of first appearance. Missing
    values are dropped from each group, and groups that end up empty are
    left out of the result.

    Raises:
        ValueError: if no group has any non-missing value in ``value_col``.
    """
    hr_by_file = {}
    for fname, group in df.groupby(file_col, sort=False):
        values = group[value_col].dropna().to_numpy()
        if values.size == 0:
            continue
        hr_by_file[fname] = values

    if not hr_by_file:
        # Name the column that was actually requested, not a hard-coded one.
        raise ValueError(f"No non-empty {value_col} series found.")
    return hr_by_file


def normalize_hr_by_file(hr_by_file):
    """
    Return dict: filename -> np.array(z-normalized hr).

    Expects the mapping produced by raw_hr_by_file. Every non-empty array is
    passed through z_norm; empty arrays are left out of the result.
    """
    normalized = {}
    for fname, series in hr_by_file.items():
        if series.size == 0:
            continue
        normalized[fname] = z_norm(series)
    return normalized


def _standardize_segment(x, eps=1e-8):
    """Return x as float64 with zero mean and unit population std (zeros if nearly constant)."""
    x = np.asarray(x, dtype=np.float64)
    centered = x - x.mean()
    spread = x.std()
    if spread < eps:
        return np.zeros_like(centered)
    return centered / spread


def pairwise_xcorr(hr_dict, max_lag=None):
    """
    Compute pairwise cross-correlation best lag & score for all files.

    For each pair, both signals are cut to the pair-specific minimum length
    (not the global min) and each cut segment is re-standardized before
    correlating. Without that step a signal that was z-normalized over its
    full length no longer has zero mean / unit variance once truncated, the
    scores can leave [-1, 1], and the leftover offset biases the best lag.
    Already-normalized equal-length inputs are unaffected (up to rounding).

    Args:
        hr_dict : dict[str, array-like]
            filename -> 1-D signal (z-normalized or raw), non-empty.
        max_lag : int or None
            If given, only lags in [-max_lag, +max_lag] (samples) are searched.

    Returns:
        lag_mat : (N, N) int array; lag_mat[i, j] is the lag maximizing
            sum_n x_i[n] * x_j[n - lag], and lag_mat[j, i] == -lag_mat[i, j].
        corr_mat : (N, N) float32 array of the matching scores; diagonal is 1.
        meta : {"order": list of filenames in matrix order}.

    Raises:
        ValueError: if max_lag is negative or a signal is empty.
    """
    if max_lag is not None and max_lag < 0:
        raise ValueError("max_lag must be non-negative or None.")

    names = list(hr_dict.keys())
    N = len(names)

    lag_mat = np.zeros((N, N), dtype=int)
    corr_mat = np.zeros((N, N), dtype=np.float32)

    # self vs self: zero lag (already 0), perfect correlation
    np.fill_diagonal(corr_mat, 1.0)

    # all pairs
    for (i, name_i), (j, name_j) in combinations(enumerate(names), 2):
        xi, xj = hr_dict[name_i], hr_dict[name_j]

        # pair-specific min length
        L = min(len(xi), len(xj))
        if L == 0:
            raise ValueError(
                f"Cannot cross-correlate an empty signal ({name_i} vs {name_j})."
            )

        # re-standardize the truncated segments so c / L is a true correlation
        xi_ = _standardize_segment(xi[:L])
        xj_ = _standardize_segment(xj[:L])

        # mode="full" on two length-L inputs yields lags -(L-1) .. +(L-1)
        c = correlate(xi_, xj_, mode="full")
        lags = np.arange(-L + 1, L, dtype=int)

        if max_lag is not None:
            keep = np.abs(lags) <= max_lag
            c = c[keep]
            lags = lags[keep]

        # biased estimator: divide by L at every lag, which keeps the score in
        # [-1, 1] and does not inflate lags with only a short overlap
        c_norm = c / float(L)

        k = np.argmax(c_norm)
        best_lag = int(lags[k])
        best_corr = float(c_norm[k])

        lag_mat[i, j] = best_lag
        corr_mat[i, j] = best_corr
        lag_mat[j, i] = -best_lag
        corr_mat[j, i] = best_corr

    meta = {"order": names}
    return lag_mat, corr_mat, meta


def summarize_pairs(lag_mat, corr_mat, meta, fs=1.0):
    """
    Print one line per unique file pair with its best lag and correlation.

    The lag is shown in samples and, divided by ``fs``, in seconds. Pairs are
    listed in upper-triangle order of ``meta["order"]``.
    """
    labels = meta["order"]
    for a, b in combinations(range(len(labels)), 2):
        shift = lag_mat[a, b]
        score = corr_mat[a, b]
        pair_part = f"{labels[a]} vs {labels[b]}  |  "
        lag_part = f"best lag = {shift:+d} samples ({shift / fs:+.2f} s)"
        print(f"{pair_part}{lag_part}  |  corr = {score:.3f}")


def mean_pairwise_corr(corr_mat, meta):
    """
    Average each file's correlation with every other file.

    Returns a Series indexed by filename, sorted from highest to lowest mean.
    The self-correlation is removed by subtracting 1 from each row sum, so
    the diagonal of ``corr_mat`` is expected to be 1.
    """
    labels = meta["order"]
    table = pd.DataFrame(corr_mat, index=labels, columns=labels)

    def _mean_without_self(row):
        # drop the diagonal entry (assumed 1) and average the remaining ones
        return (row.sum() - 1) / (len(row) - 1)

    per_file = table.apply(_mean_without_self, axis=1)
    return per_file.sort_values(ascending=False)

In [None]:
def plot_pairwise_aligned(hr_by_file_raw, lag_mat, meta, corr_mat=None, fs=1.0):
    """
    Build one line chart per file pair with the two raw HR traces aligned.

    Alignment uses lag_mat[i, j]: a positive lag drops that many leading
    samples from the first signal, a negative lag drops |lag| leading samples
    from the second. Both traces are then cut to their common length and
    plotted against a shared time axis that starts at 0.

    Args:
        hr_by_file_raw : dict[str, np.ndarray]
            Raw HR signals (in bpm), unnormalized.
        lag_mat : np.ndarray
            Best lag (in samples) for each pair, indexed like meta["order"].
        meta : dict
            Contains 'order': list of filenames (same order as lag_mat).
        corr_mat : np.ndarray or None
            Optional correlation scores; when given, the pair's score is
            appended to the figure title.
        fs : float
            Sampling frequency (Hz). Default 1.0 (1 sample per second).

    Returns:
        dict[(str,str), plotly.Figure]
    """
    axis_labels = {
        "elapsed_seconds": "Elapsed Time (s, shifted)",
        "heart_rate": "Heart Rate (bpm)",
        "filename": "File",
    }
    order = meta["order"]
    figures = {}

    for (a, first), (b, second) in combinations(enumerate(order), 2):
        shift = lag_mat[a, b]

        # Leading samples to discard from each trace so they line up.
        skip_first = shift if shift > 0 else 0
        skip_second = -shift if shift < 0 else 0
        trace_first = hr_by_file_raw[first][skip_first:]
        trace_second = hr_by_file_raw[second][skip_second:]

        # Common length and the shared time axis derived from it.
        n = min(len(trace_first), len(trace_second))
        seconds = np.arange(n) / fs

        frame = pd.DataFrame(
            {
                "elapsed_seconds": np.tile(seconds, 2),
                "heart_rate": np.concatenate([trace_first[:n], trace_second[:n]]),
                "filename": [first] * n + [second] * n,
            }
        )

        heading = f"{first} vs {second} | shift: {shift:+d}"
        if corr_mat is not None:
            heading += f" | corr: {corr_mat[a, b]:.3f}"

        figures[(first, second)] = px.line(
            frame,
            x="elapsed_seconds",
            y="heart_rate",
            color="filename",
            labels=axis_labels,
            title=heading,
        )

    return figures

In [None]:
def snap_to_threshold(val, thresholds):
    """Return the entry of thresholds closest to val (first one wins on ties)."""
    candidates = np.asarray(thresholds)
    distances = np.abs(candidates - val)
    return candidates[np.argmin(distances)]


def plot_corr_heatmap(corr_mat, meta):
    """
    Draw the pairwise best-correlation matrix as an annotated heatmap.

    The colour range is widened outward to fixed thresholds: zmin is the
    largest threshold not above the smallest value, zmax the smallest
    threshold not below the largest value. Snapping to the *nearest*
    threshold instead can clip data (a minimum of 0.3 would give zmin=0.5)
    or collapse the range (values in 0.8..1.0 would give zmin == zmax == 1).

    Args:
        corr_mat : np.ndarray
            Square matrix of correlation scores.
        meta : dict
            Contains 'order': list of filenames labelling rows and columns.

    Returns:
        plotly.Figure
    """
    names = meta["order"]
    corr_df = pd.DataFrame(corr_mat, index=names, columns=names)

    thresholds = [-1, -0.5, 0, 0.5, 1]
    # tolerance absorbs float rounding, e.g. a stored 1.0000001 still maps to 1
    tol = 1e-4
    lowest = float(np.min(corr_mat))
    highest = float(np.max(corr_mat))

    # snap min down / max up; fall back to the data itself outside [-1, 1]
    zmin = max((t for t in thresholds if t <= lowest + tol), default=lowest)
    zmax = min((t for t in thresholds if t >= highest - tol), default=highest)

    if zmin >= zmax:
        # every value sits on one threshold (e.g. a single file): open the
        # range downward so the colour scale is not degenerate
        zmin = max((t for t in thresholds if t < zmax), default=zmax - 1)

    fig = px.imshow(
        corr_df,
        x=names,
        y=names,
        color_continuous_scale="Viridis_r",  # reversed colormap
        text_auto=".2f",
        zmin=zmin,
        zmax=zmax,
        aspect="auto",
        title=f"Pairwise Best Correlation Heatmap (zmin={zmin}, zmax={zmax})",
    )
    fig.update_layout(
        xaxis_title="File",
        yaxis_title="File",
        coloraxis_colorbar=dict(title="Correlation"),
    )
    return fig

In [None]:
# Sampling rate (Hz). If one sample per second, fs = 1.0
# NOTE(review): the value is assumed, not read from the files - confirm it
# matches how the records were actually sampled.
fs = 1.0

# Optional: restrict the maximum lag you care about (in samples).
# Set to None for full range, or e.g. 600 for ±10 minutes at 1 Hz.
max_lag = None

# 1) Prepare data: raw per-file arrays first (kept for plotting), then the
#    z-normalized copies that feed the cross-correlation
hr_by_file = raw_hr_by_file(df, value_col="heart_rate", file_col="filename")
normalized_hr_by_file = normalize_hr_by_file(hr_by_file)

# 2) Cross-correlation across files (pair-specific min length);
#    meta["order"] records which file each matrix row/column refers to
lag_mat, corr_mat, meta = pairwise_xcorr(normalized_hr_by_file, max_lag=max_lag)

# 3) Mean pairwise correlation, sorted from highest to lowest
mean_corr = mean_pairwise_corr(corr_mat, meta)
print("Mean pairwise correlation (excluding self):")
print(mean_corr)

# 4) Heatmap of the full pairwise correlation matrix
plot_corr_heatmap(corr_mat, meta).show()

In [None]:
# Build one aligned figure per file pair from the raw (unnormalized) signals;
# figs is keyed by the (name_i, name_j) tuple of the pair.
figs = plot_pairwise_aligned(hr_by_file, lag_mat, meta, fs=fs, corr_mat=corr_mat)

In [None]:
# Show every pairwise figure except those involving this one file.
# To show only that file's pairs instead, drop the `not` in the condition.
excluded_file = "20250918_yq87LlVo.fit"

for key in figs:
    involves_excluded = any(excluded_file in name for name in key)
    if not involves_excluded:
        figs[key].show()