In [None]:
import re

import h5py
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from plotnine import *
from scipy.signal import butter, filtfilt, iirnotch, sosfiltfilt


def reademg(
    filename: str,
    device: str = "98:D3:C1:FE:04:75",
    fs: int = 1000
) -> pd.DataFrame:
    """
    Purpose: Load EMG data for one device from an HDF5 file and return it preprocessed

    Reads the first column of datasets channel_1..channel_4 under ``<device>/raw``,
    builds a DataFrame with columns 'a1'..'a4' plus a 'ts' timestamp column
    (sample index / fs, in seconds), then runs it through ``preprocess_emg``
    (bandpass filter, notch filter, moving-RMS envelope).

    Parameters:
        - filename (str): Path to the .h5 file containing EMG recordings
        - device (str): Bluetooth device address used as the top-level group name in the HDF5 file
        - fs (int): Sampling frequency in Hz, used for the timestamps and for preprocessing

    Returns:
        - pd.DataFrame: Columns 'a1', 'a2', 'a3', 'a4' (processed EMG envelopes) and 'ts' (seconds)
    """
    # Slice the datasets into memory inside the context manager so the file handle
    # is released as soon as the data has been read.
    with h5py.File(filename, "r") as h5:
        raw = h5[device]["raw"]
        df = pd.DataFrame({
            f"a{ch}": raw[f"channel_{ch}"][:, 0] for ch in range(1, 5)
        })
    df["ts"] = np.arange(len(df)) / fs
    # Forward fs so the filters use the real sampling rate, not preprocess_emg's default.
    return preprocess_emg(df, fs=fs)


def bandpass_filter(signal, lowcut=20, highcut=400, fs=1000, order=10):
    """
    Purpose: Apply a zero-phase Butterworth bandpass filter to isolate the EMG frequency range

    A bandpass of prototype order N is a 2N-th order filter (20th order with the
    default order=10). At that order the (b, a) polynomial form is numerically
    fragile, so the filter is designed and applied as second-order sections (SOS).

    Parameters:
        - signal (np.ndarray): 1D NumPy array of raw signal
        - lowcut (float): Lower cutoff frequency in Hz (must be > 0)
        - highcut (float): Upper cutoff frequency in Hz (must be below fs / 2)
        - fs (float): Sampling frequency in Hz
        - order (int): Order of the Butterworth prototype (higher = sharper roll-off)

    Returns:
        - np.ndarray: Bandpass-filtered signal, same length as ``signal``

    Raises:
        - ValueError: If the cutoffs do not satisfy 0 < lowcut < highcut < fs / 2
    """
    nyquist = 0.5 * fs
    if not 0 < lowcut < highcut < nyquist:
        raise ValueError(
            f"Cutoffs must satisfy 0 < lowcut < highcut < fs/2 "
            f"(got lowcut={lowcut}, highcut={highcut}, fs/2={nyquist})"
        )
    sos = butter(
        order, [lowcut / nyquist, highcut / nyquist], btype="band", output="sos"
    )
    # Forward-backward filtering: zero phase shift, squared magnitude response.
    return sosfiltfilt(sos, signal)


def notch_filter(signal, freq=60, fs=1000, Q=30):
    """
    Purpose: Suppress powerline interference (e.g. 60 Hz in the US) with a
    zero-phase IIR notch filter

    Parameters:
        - signal (np.ndarray): 1D NumPy array of signal
        - freq (int): Frequency to reject in Hz (typically 50 or 60)
        - fs (int): Sampling frequency in Hz
        - Q (int): Quality factor; a larger Q gives a narrower notch

    Returns:
        - np.ndarray: Signal with the notch frequency removed
    """
    # Handing fs to iirnotch lets it normalise freq against the Nyquist rate itself.
    numerator, denominator = iirnotch(freq, Q, fs=fs)
    return filtfilt(numerator, denominator, signal)


def moving_rms(signal, window_size=100):
    """
    Purpose: Sliding-window root mean square (RMS) of a signal, used as the EMG
    envelope in place of rectification followed by smoothing

    The input is mirror-padded by window_size // 2 samples on each side so that
    edge windows still cover window_size samples. The result is then trimmed or
    edge-padded back to the input length.

    Parameters:
        - signal (np.ndarray): 1D NumPy array of EMG signal
        - window_size (int): Number of samples in each RMS window

    Returns:
        - np.ndarray: RMS envelope with the same length as the input
    """
    n = len(signal)
    mirrored = np.pad(signal, window_size // 2, mode='reflect')
    kernel = np.ones(window_size) / window_size
    # Averaging kernel over squared samples -> mean square per window
    mean_square = np.convolve(mirrored ** 2, kernel, mode='valid')
    # Even windows yield one extra sample; keep the first n
    envelope = np.sqrt(mean_square)[:n]
    shortfall = n - len(envelope)
    if shortfall > 0:
        envelope = np.pad(envelope, (0, shortfall), mode='edge')
    return envelope


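# Superseded earlier version of moving_rms, kept for reference only: it used a
# 20-sample default window and mode="same" convolution without reflect padding.
# def moving_rms(x: np.ndarray, window_size: int = 20) -> np.ndarray:
#     """Compute moving RMS over a window for a 1D array."""
#     return np.sqrt(np.convolve(x**2, np.ones(window_size) / window_size, mode="same"))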


def preprocess_emg(df, fs: int = 1000, window_size: int = 100) -> pd.DataFrame:
    """
    Purpose: Preprocess every EMG channel in a DataFrame:
      - Bandpass filter (bandpass_filter defaults)
      - Notch filter (notch_filter defaults)
      - Moving-RMS envelope extraction

    Parameters:
        - df (pd.DataFrame): DataFrame whose EMG channels are named 'a' followed by digits
          (a1, a2, ..., aN); all other columns (e.g. 'ts') are copied through unchanged
        - fs (int): Sampling frequency in Hz, passed to both filters
        - window_size (int): Moving-RMS window length in samples. The default of 100 is
          100 ms at 1 kHz; scale it with fs to keep the same time span.

    Returns:
        - pd.DataFrame: Copy of df in which each EMG channel is replaced by its RMS envelope
    """
    channel_pattern = re.compile(r'^a\d+$')  # Matches 'a1', 'a2', ..., 'a99' etc.
    processed_df = df.copy()

    for col in df.columns:
        if not channel_pattern.match(col):
            continue
        filtered = bandpass_filter(df[col].to_numpy(), fs=fs)
        notch_removed = notch_filter(filtered, fs=fs)
        processed_df[col] = moving_rms(notch_removed, window_size=window_size)

    return processed_df


def _delay_with_zeros(values: np.ndarray, shift: int) -> np.ndarray:
    """Return ``values`` delayed by ``shift`` samples, zero-filling the first ``shift`` entries."""
    if shift <= 0:
        return values
    delayed = np.zeros_like(values)
    if shift < len(values):
        delayed[shift:] = values[:-shift]
    return delayed


def sync_emg_to_angles(
    df_emg: pd.DataFrame,
    df_angles: pd.DataFrame,
    tall_format: bool = False,
    trim_emg: tuple[float, float] = (0.0, 0.0),
    fs_angles: int = 60,
    window_size: int = 20,
    n_lags: int = 5,
) -> pd.DataFrame:
    """
    Align EMG data to joint-angle frames and add lagged moving-RMS EMG features.

    Processing steps:
      1. Trim ``trim_emg`` seconds from the start and end of the EMG recording.
      2. Spread the remaining EMG time span into ``len(df_angles)`` evenly spaced
         timestamps, one per joint-angle row.
      3. Linearly interpolate each EMG channel (columns named 'a' + digits) onto
         those timestamps and smooth it with ``moving_rms``.
      4. Emit ``n_lags`` copies per channel, ``<channel>_lag_<k>`` for k = 0..n_lags-1.
         Lag k is delayed by ``k * window_size`` angle samples, and its first
         ``k * window_size`` values are zero.

    Parameters:
        - df_emg (pd.DataFrame): EMG data with a 'ts' column (seconds) and EMG channels (a1, a2, ...)
        - df_angles (pd.DataFrame): Joint angle data, one row per frame (no timestamp);
          a 'Frame' column, if present, is dropped
        - tall_format (bool): If True, return long format with columns ts, channel, value
        - trim_emg (tuple): Seconds to trim from the start and the end of the EMG recording
        - fs_angles (int): Sampling rate of angles. NOTE: currently unused; the angle
          timestamps are derived from the trimmed EMG duration instead
        - window_size (int): Moving-RMS window in angle samples; also the step between lags
        - n_lags (int): Number of lagged feature copies per channel (default 5: lags 0-4)

    Returns:
        - pd.DataFrame: 'ts' first, then the angle columns and the lagged EMG features.
          All column names are lowercased and rows with NaN are dropped. The result is
          melted when tall_format is True.

    Raises:
        - ValueError: If df_emg is empty or trimming removes every EMG sample
    """
    if df_emg.empty:
        raise ValueError("df_emg is empty")
    trim_start, trim_end = trim_emg

    # Trim EMG data based on time
    emg_ts = df_emg["ts"]
    ts_start = emg_ts.iloc[0] + trim_start
    ts_end = emg_ts.iloc[-1] - trim_end
    df_emg = df_emg[(emg_ts >= ts_start) & (emg_ts <= ts_end)].reset_index(drop=True)
    if df_emg.empty:
        raise ValueError(
            f"trim_emg={trim_emg} removes every EMG sample "
            f"(recording spans {emg_ts.iloc[0]}-{emg_ts.iloc[-1]} s)"
        )

    # Anchor the angle timeline on the first *retained* sample, so the timeline
    # covers exactly the trimmed data even when trim_start falls between samples.
    t0 = df_emg["ts"].iloc[0]
    duration = df_emg["ts"].iloc[-1] - t0
    n_samples = df_angles.shape[0]
    target_ts = np.linspace(0, duration, n_samples) + t0

    # Interpolate each EMG channel onto the angle timeline, then build lagged RMS features
    channel_pattern = re.compile(r"^a\d+$")  # same channel naming as preprocess_emg
    features = {}
    for col in df_emg.columns:
        if not channel_pattern.match(col):
            continue
        interp_emg = np.interp(target_ts, df_emg["ts"], df_emg[col])
        rms = moving_rms(interp_emg, window_size=window_size)
        for lag in range(n_lags):
            features[f"{col}_lag_{lag}"] = _delay_with_zeros(rms, lag * window_size)

    # Add timestamps to angle data and drop a 'Frame' column if present
    df_angles_with_ts = df_angles.copy()
    df_angles_with_ts["ts"] = target_ts
    df_angles_with_ts = df_angles_with_ts.drop(columns="Frame", errors="ignore")

    # Merge angle and EMG data row-by-row (both are indexed 0..n_samples-1)
    df_merged = pd.concat(
        [
            df_angles_with_ts.reset_index(drop=True),
            pd.DataFrame(features, index=pd.RangeIndex(n_samples)),
        ],
        axis=1,
    )

    # Make 'ts' the first column
    cols = ["ts"] + [col for col in df_merged.columns if col != "ts"]
    df_merged = df_merged[cols]

    # Lowercase column names
    df_merged.columns = df_merged.columns.str.lower()
    df_merged = df_merged.dropna()

    # Optional tall format
    if tall_format:
        return pd.melt(
            df_merged, id_vars=["ts"], var_name="channel", value_name="value"
        ).reset_index(drop=True)
    return df_merged

In [36]:
# Each entry: ((seconds trimmed from EMG start, seconds trimmed from EMG end), output path).
# NOTE(review): "ws100" presumably refers to preprocess_emg's 100-sample RMS window;
# sync_emg_to_angles is called with its default window_size=20. Confirm the label.
trim_outputs = [
    ((10.25, 21.75), "data/synced/trial1_5s_10.25_21.75_ws100.parquet"),
    ((11.25, 20.75), "data/synced/trial1_5s_11.25_20.75_ws100.parquet"),
    ((9.25, 22.75), "data/synced/trial1_5s_09.25_22.75_ws100.parquet"),
]

df_emg = reademg("data/emg/April07_2025_trial1_int5s.h5")
df_angles = pd.read_csv("data/joint_angles/April07_2025_trial1_int5s_angles.csv")

for trim, out_path in trim_outputs:
    df_synced = sync_emg_to_angles(df_emg, df_angles, trim_emg=trim, tall_format=False)
    df_synced.to_parquet(out_path)
# df_synced is left holding the last trim window (9.25, 22.75), which later cells use.

In [37]:
len(df_synced)  # number of rows in the last synced dataset

22574

In [39]:
# Scale MCP joint angles down by 10 so they plot on a range comparable to the EMG envelopes.
# The scaled columns are stored on df_synced so later cells can reuse them.
df_synced["index_mcp_scaled"] = df_synced["index_mcp"] / 10
df_synced["middle_mcp_scaled"] = df_synced["middle_mcp"] / 10

# One subplot row per (EMG channel, MCP angle) pair:
# (row, EMG column, EMG legend label, EMG color, angle column, angle color)
panels = [
    (1, "a2_lag_0", "a2", "blue", "index_mcp_scaled", "green"),
    (2, "a3_lag_0", "a3", "purple", "middle_mcp_scaled", "orange"),
]

# Create subplot with 2 rows and shared x-axis
fig = make_subplots(
    rows=2,
    cols=1,
    shared_xaxes=True,
    subplot_titles=["a2 + index_mcp_scaled", "a3 + middle_mcp_scaled"],
)

for row, emg_col, emg_label, emg_color, angle_col, angle_color in panels:
    fig.add_trace(
        go.Scatter(
            x=df_synced["ts"],
            y=df_synced[emg_col],
            mode="lines",
            name=emg_label,
            line=dict(color=emg_color),
        ),
        row=row,
        col=1,
    )
    fig.add_trace(
        go.Scatter(
            x=df_synced["ts"],
            y=df_synced[angle_col],
            mode="lines",
            name=angle_col,
            line=dict(color=angle_color),
        ),
        row=row,
        col=1,
    )

# Layout options
fig.update_layout(
    height=600,
    title="EMG Channels a2 and a3 vs MCP Angles",
    legend=dict(orientation="h", y=-0.2, x=0.5, xanchor="center"),
    margin=dict(t=50, b=80),
)

fig.show()