In [1]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import loadmat
from scipy.signal import butter, filtfilt, welch

# --- project directories, all expressed relative to ROOT ---
ROOT = Path("..")  # notebook lives in notebooks/
DATA_RAW = ROOT / "data/ninapro/db1/raw"
DATA_PROCESSED = ROOT / "data/processed/db1"
DATA_OUTPUTS = ROOT / "data/outputs/db1"
MODELS_DIR = ROOT / "models/db1"
REPORTS_DIR = ROOT / "reports/db1"

# Create every directory this notebook writes into; DATA_RAW is only read,
# so it is deliberately not created here.
for out_dir in (DATA_PROCESSED, DATA_OUTPUTS, MODELS_DIR, REPORTS_DIR):
    out_dir.mkdir(parents=True, exist_ok=True)

# --- signal / windowing parameters ---
FS = 100.0       # DB1 sampling (RMS), in Hz
WINDOW_S = 0.2   # window length in seconds (20 samples at FS)
STEP_S = 0.05    # hop between consecutive windows in seconds (5 samples at FS)

In [2]:
def lower_nonmeta_keys(d):
    """
    Return a copy of `d` with lowercased keys and without MATLAB metadata.

    Keys that start with a double underscore (e.g. '__header__') are dropped;
    every remaining key is lowercased and keeps its original value. If two keys
    collapse to the same lowercase name, the later one wins.

    Example:
        {'EMG': ..., '__header__': ...}
        -> {'emg': ...}
    """
    cleaned = {}
    for key, value in d.items():
        if key.startswith("__"):
            continue  # MATLAB bookkeeping entry, not data
        cleaned[key.lower()] = value
    return cleaned


def pick_key(d, name, required=True):
    """
    Look up `name` in dictionary `d`, ignoring letter case.

    Lookup order:
      1. the lowercased `name` as an exact key;
      2. otherwise the first key (in iteration order) whose lowercase form
         equals the lowercased `name`.

    Returns the matching value. When nothing matches, raises KeyError if
    `required` is True, else returns None.

    Handy because MATLAB .mat keys may vary in case ("emg" vs "EMG").
    """
    target = name.lower()

    # Fast path: the dictionary already uses the lowercase spelling.
    if target in d:
        return d[target]

    # Slow path: scan for a key that differs only by case.
    for candidate in d:
        if candidate.lower() == target:
            return d[candidate]

    if not required:
        return None
    raise KeyError(f"Key '{target}' not in keys: {list(d.keys())}")


def read_mat(path):
    """
    Load a MATLAB .mat file into a plain Python dictionary.

    - Reads the file with `scipy.io.loadmat` (imported at the top of the notebook).
    - `squeeze_me=True` drops unit (length-1) dimensions from the loaded arrays.
    - `struct_as_record=False` makes MATLAB structs come back as objects with
      attribute access instead of NumPy record arrays.
    - MATLAB metadata keys (those starting with '__') are removed and the
      remaining keys are lowercased via `lower_nonmeta_keys`.

    Args:
        path: location of the .mat file, passed straight to `loadmat`.

    Returns:
        dict: lowercase variable name -> loaded value.
    """
    return lower_nonmeta_keys(loadmat(path, squeeze_me=True, struct_as_record=False))


def ensure_samples_channels(X):
    """
    Coerce EMG data to a float32 array laid out as (samples, channels).

    - A 1D input becomes a single-column 2D array.
    - A 2D input with fewer rows than columns AND more than 32 columns is
      assumed to be stored as (channels, samples) and is transposed.
    - The result is always cast to float32.

    Example:
        Input shape: (10,)         -> Output shape: (10, 1)
        Input shape: (32, 2000)    -> Output shape: (2000, 32)
        Input shape: (8, 16)       -> Output shape: (8, 16)   (too narrow to flip)
    """
    arr = np.asarray(X)
    if arr.ndim == 1:
        arr = arr.reshape(-1, 1)  # one channel -> column vector

    n_rows, n_cols = arr.shape[0], arr.shape[1]
    # Heuristic: a wide matrix whose width exceeds any plausible channel count
    # (32) is taken to hold samples along the columns.
    looks_transposed = n_cols > 32 and n_rows < n_cols
    if looks_transposed:
        arr = arr.T
    return arr.astype(np.float32)

In [3]:
def bandpass_filter(x, fs, low=1.0, high=40.0, order=4):
    """
    Zero-phase Butterworth band-pass filter applied along axis 0.

    Args:
        x (ndarray): Input signal of shape (samples, channels).
        fs (float): Sampling rate in Hz.
        low (float): Lower cutoff in Hz (default 1 Hz); raised to 0.1 Hz if smaller.
        high (float): Upper cutoff in Hz (default 40 Hz); lowered to 90% of the
            Nyquist frequency if larger.
        order (int): Order passed to `scipy.signal.butter` (default 4).

    Returns:
        ndarray: Filtered signal with the same shape as x.

    Notes:
        - The pass band is meant to remove slow drift below `low` and noise
          above `high`.
        - `filtfilt` runs the filter forward and backward, so the output has
          no phase lag relative to the input.
    """
    nyquist = fs / 2
    upper = min(high, 0.9 * nyquist)  # keep the upper edge below Nyquist
    lower = max(low, 0.1)             # keep the lower edge strictly positive
    normalized_band = [lower / nyquist, upper / nyquist]
    numerator, denominator = butter(order, normalized_band, btype="band")
    return filtfilt(numerator, denominator, x, axis=0)


def rectify_and_zscore(x, eps=1e-8):
    """
    Full-wave rectify the signal, then standardize each channel.

    Args:
        x (ndarray): Input signal (samples, channels).
        eps (float): Added to the per-channel standard deviation so that a
            constant channel does not cause a division by zero.

    Returns:
        ndarray: (|x| - mean(|x|)) / (std(|x|) + eps), with statistics taken
        over axis 0 (per channel).

    Notes:
        - Rectification replaces every sample by its absolute value.
        - The z-score is computed on the rectified signal, not on the raw one.
    """
    rectified = np.abs(x)
    center = np.mean(rectified, axis=0, keepdims=True)       # per-channel mean
    scale = np.std(rectified, axis=0, keepdims=True) + eps   # per-channel std
    return (rectified - center) / scale


def sliding_window(x, y, window_s=WINDOW_S, step_s=STEP_S, fs=FS):
    """
    Segment EMG signals into overlapping sliding windows.

    Args:
        x (ndarray): Preprocessed EMG data (samples, channels).
        y (ndarray): Label sequence (samples,), assumed to be aligned with `x`
            sample-for-sample (not checked here).
        window_s (float): Window length in seconds.
        step_s (float): Step size in seconds.
        fs (float): Sampling rate in Hz.

    Returns:
        Xw (ndarray): Segments of shape (n_windows, win_len, n_channels), in the
            dtype of `x` (also when no window fits).
        Yw (ndarray): int32 label per window.

    Raises:
        ValueError: if the window or the step rounds to less than one sample.

    Notes:
        - Window labelling: the most frequent NON-ZERO label inside the window
          is used; a window is labelled 0 only when it contains no non-zero
          label at all.
        - Window/step lengths are rounded to the nearest sample. Plain `int()`
          truncation would turn e.g. 0.29 s * 100 Hz (28.999999999999996)
          into 28 samples instead of 29.
    """
    win = int(round(window_s * fs))
    step = int(round(step_s * fs))
    if win < 1 or step < 1:
        raise ValueError(
            f"window_s={window_s} and step_s={step_s} must each cover at least "
            f"one sample at fs={fs} (got win={win}, step={step})"
        )

    x = np.asarray(x)
    y = np.asarray(y)

    Xw, Yw = [], []
    for start in range(0, len(x) - win + 1, step):
        seg = x[start:start + win]
        lab = y[start:start + win]

        # Walk the labels from most to least frequent and keep the first
        # non-zero one; fall back to 0 (rest) if there is none.
        vals, counts = np.unique(lab, return_counts=True)
        chosen = 0
        for j in counts.argsort()[::-1]:
            if vals[j] != 0:
                chosen = vals[j]
                break

        Xw.append(seg)
        Yw.append(chosen)

    # Stack into arrays; keep the dtype of `x` for the empty case too so the
    # result type does not depend on whether any window fitted.
    Xw = np.stack(Xw) if Xw else np.empty((0, win, x.shape[1]), dtype=x.dtype)
    Yw = np.array(Yw, dtype=np.int32)
    return Xw, Yw

# Time-Domain Features

In [4]:
def rms(x, axis=0):
    """Root Mean Square along `axis`: square root of the mean squared amplitude."""
    mean_square = np.mean(np.square(x), axis=axis)
    return np.sqrt(mean_square)

def mav(x, axis=0):
    """Mean Absolute Value along `axis`: average of the rectified signal."""
    rectified = np.abs(x)
    return np.mean(rectified, axis=axis)

def wl(x, axis=0):
    """Waveform Length along `axis`: sum of absolute sample-to-sample differences."""
    step_sizes = np.abs(np.diff(x, axis=axis))
    return step_sizes.sum(axis=axis)

def zc(x, axis=0, thresh=0.01):
    """
    Zero Crossings: number of consecutive-sample pairs along `axis` whose
    signs are opposite AND whose amplitude difference exceeds `thresh`
    (the threshold suppresses crossings caused by small noise).
    """
    # Bring the time axis to the front so the rest can work on axis 0.
    sig = np.swapaxes(x, 0, axis) if axis != 0 else x
    signs = np.sign(sig)
    jumps = np.diff(sig, axis=0)

    opposite_sign = (signs[1:] * signs[:-1]) < 0   # strict sign flip (zeros never count)
    large_enough = np.abs(jumps) > thresh          # amplitude difference check

    return np.logical_and(opposite_sign, large_enough).sum(axis=0)

def ssc(x, axis=0, thresh=0.01):
    """
    Slope Sign Changes: number of points along `axis` where the first
    difference flips sign AND the change between the two consecutive slopes
    exceeds `thresh`.
    """
    # Bring the time axis to the front so the rest can work on axis 0.
    sig = np.swapaxes(x, 0, axis) if axis != 0 else x
    slopes = np.diff(sig, axis=0)
    later, earlier = slopes[1:], slopes[:-1]

    direction_flips = (later * earlier) < 0          # slope sign change
    large_enough = np.abs(later - earlier) > thresh  # magnitude check

    return np.logical_and(direction_flips, large_enough).sum(axis=0)

def var(x, axis=0):
    """Variance of the signal along `axis` (population variance, ddof=0)."""
    return np.var(x, axis=axis)

def std(x, axis=0):
    """Standard deviation of the signal along `axis` (population std, ddof=0)."""
    return np.std(x, axis=axis)

def iemg(x, axis=0):
    """Integrated EMG along `axis`: sum of the absolute sample values."""
    rectified = np.abs(x)
    return np.sum(rectified, axis=axis)

def kf(x, axis=0):
    """
    Kurtosis Factor along `axis`: fourth central moment divided by std**4.

    Sensitive to outliers/spikes. A small constant (1e-8) is added to the
    denominator so a constant signal does not divide by zero.

    The mean is taken with `keepdims=True` so that `x - m` broadcasts along
    whichever axis was requested, not only along axis 0.
    """
    x = np.asarray(x)
    m = x.mean(axis=axis, keepdims=True)
    s = x.std(axis=axis)
    return ((x - m) ** 4).mean(axis=axis) / (s ** 4 + 1e-8)

def skewness(x, axis=0):
    """
    Skewness along `axis`: third central moment divided by std**3.

    Measures the asymmetry of the amplitude distribution. A small constant
    (1e-8) is added to the denominator so a constant signal does not divide
    by zero.

    The mean is taken with `keepdims=True` so that `x - m` broadcasts along
    whichever axis was requested, not only along axis 0.
    """
    x = np.asarray(x)
    m = x.mean(axis=axis, keepdims=True)
    s = x.std(axis=axis)
    return ((x - m) ** 3).mean(axis=axis) / (s ** 3 + 1e-8)

def time_domain_feature_vector(seg):
    """
    Build the time-domain feature vector for a single EMG window.

    Parameters
    ----------
    seg : np.ndarray
        EMG window segment, shape (window_size, n_channels).

    Returns
    -------
    feat : np.ndarray
        1D vector laid out feature-major: all channels of the first feature,
        then all channels of the second, and so on. Feature order is
        RMS, MAV, WL, VAR, STD, IEMG, ZC, SSC, kurtosis factor, skewness.
    """
    # Order matters: downstream consumers index columns by position.
    extractors = (
        rms, mav, wl, var, std, iemg,   # amplitude / power measures
        zc, ssc,                        # signal complexity measures
        kf, skewness,                   # higher-order statistics
    )
    per_feature = [extract(seg, axis=0) for extract in extractors]
    return np.concatenate(per_feature, axis=0)

# Frequency-Domain Features

In [5]:
EPS = 1e-12  # numerical stability

# NumPy 2.0 renamed `trapz` to `trapezoid`; use whichever this NumPy provides
# so the features below run on both major versions.
_trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz


def _welch_psd(x, fs, nperseg=None):
    """
    Compute Welch PSD for each channel.
    Args:
        x: array (samples, channels)
        fs: sampling rate (Hz)
        nperseg: window length for Welch (samples). If None, one segment as
            long as the signal is used, capped at 1024 samples.
    Returns:
        f: frequencies (Hz) shape (F,)
        Pxx: power spectral density (per channel) shape (F, C)
    """
    if nperseg is None:
        nperseg = 1024
    # Never ask Welch for a segment longer than the signal itself.
    f, P = welch(x, fs=fs, nperseg=min(nperseg, len(x)), axis=0)  # P shape (F, C)
    return f, P


def _psd_or_compute(x, fs, psd):
    """Return the caller-supplied (f, P) pair, or estimate it from `x` when `psd` is None."""
    return _welch_psd(x, fs) if psd is None else psd


def total_power(x, fs, psd=None):
    """
    Total power = trapezoidal integral of the PSD across all freqs (per channel).
    Args:
        psd: optional precomputed (f, P) pair; when given, `x`/`fs` are not used.
    """
    f, P = _psd_or_compute(x, fs, psd)
    return _trapezoid(P, f, axis=0)


def bandpower(x, fs, band=(5, 15), psd=None):
    """
    Band-limited power (per channel): trapezoidal integral of the PSD over the
    frequency bins inside `band` (both edges inclusive).
    Args:
        band: (f_low, f_high) in Hz
        psd: optional precomputed (f, P) pair; when given, `x`/`fs` are not used.
    """
    f, P = _psd_or_compute(x, fs, psd)
    idx = (f >= band[0]) & (f <= band[1])
    return _trapezoid(P[idx, :], f[idx], axis=0)


def mean_frequency(x, fs, psd=None):
    """
    Spectral mean frequency (a.k.a. spectral centroid), per channel:
    integral(f * P) / (integral(P) + EPS).
    Args:
        psd: optional precomputed (f, P) pair; when given, `x`/`fs` are not used.
    """
    f, P = _psd_or_compute(x, fs, psd)
    denom = _trapezoid(P, f, axis=0) + EPS
    return _trapezoid(P * f[:, None], f, axis=0) / denom


def median_frequency(x, fs, psd=None):
    """
    Median frequency: first frequency bin at which the cumulative PSD reaches
    half of the total. This is the spectral edge frequency with edge=0.5, so
    it is computed by `spectral_edge_frequency`.
    Args:
        psd: optional precomputed (f, P) pair; when given, `x`/`fs` are not used.
    """
    return spectral_edge_frequency(x, fs, edge=0.5, psd=psd)


def peak_frequency(x, fs, psd=None):
    """
    Frequency at which PSD is maximum (per channel).
    Args:
        psd: optional precomputed (f, P) pair; when given, `x`/`fs` are not used.
    """
    f, P = _psd_or_compute(x, fs, psd)
    return f[np.argmax(P, axis=0)]


def spectral_moments(x, fs, order=2, psd=None):
    """
    Spectral moments around the centroid (trapezoidal integrals over frequency):
    - m0 = integral(P) + EPS                   (total power)
    - m1 = integral(f*P) / m0                  (mean freq / centroid)
    - m2 = integral((f-m1)^2 * P) / m0         (spectral variance / bandwidth^2)
    Args:
        order: currently unused; kept so existing calls keep working.
        psd: optional precomputed (f, P) pair; when given, `x`/`fs` are not used.
    Returns:
        m0, m1, m2 — each (channels,)
    """
    f, P = _psd_or_compute(x, fs, psd)
    m0 = _trapezoid(P, f, axis=0) + EPS
    m1 = _trapezoid(P * f[:, None], f, axis=0) / m0
    # variance around centroid
    m2 = _trapezoid(P * (f[:, None] - m1[None, :]) ** 2, f, axis=0) / m0
    return m0, m1, m2


def spectral_entropy(x, fs, base=2, psd=None):
    """
    Spectral entropy (Shannon) of normalized PSD (per channel).
    Lower entropy -> more concentrated spectrum; higher -> more spread.
    Args:
        base: logarithm base of the result (2 -> bits, 10 -> bans, np.e -> nats).
            None also returns nats.
        psd: optional precomputed (f, P) pair; when given, `x`/`fs` are not used.
    """
    f, P = _psd_or_compute(x, fs, psd)
    # Normalize PSD to a probability distribution across frequencies
    Pn = P / (P.sum(axis=0, keepdims=True) + EPS)
    H = -(Pn * np.log(Pn + EPS)).sum(axis=0)  # nats
    if base is not None:
        # Change of base; honours every base, not only base 2.
        H = H / np.log(base)
    return H


def spectral_edge_frequency(x, fs, edge=0.95, psd=None):
    """
    Spectral edge frequency (SEF): the first frequency bin at which the
    cumulative PSD reaches `edge` (e.g., 95%) of the total (per channel).
    Args:
        psd: optional precomputed (f, P) pair; when given, `x`/`fs` are not used.
    """
    f, P = _psd_or_compute(x, fs, psd)
    cum = np.cumsum(P, axis=0)
    totals = cum[-1, :] + EPS
    sef = np.empty(P.shape[1])
    for c in range(P.shape[1]):
        idx = np.searchsorted(cum[:, c], edge * totals[c])
        # searchsorted returns len(f) when the target is never reached
        idx = np.clip(idx, 0, len(f) - 1)
        sef[c] = f[idx]
    return sef


def freq_domain_feature_vector(x, fs, bands=((5,15), (15,30))):
    """
    Build a concatenated frequency-domain feature vector per channel.
    Features included (in this order, for each channel):
      - total power
      - bandpowers (for each band in `bands`)
      - mean frequency
      - median frequency
      - peak frequency
      - spectral variance (2nd central moment)
      - spectral entropy
      - SEF95 (edge=0.95)

    The Welch PSD is estimated once and shared by every feature instead of
    being recomputed inside each helper.

    Returns:
        feat: array of shape (n_features * n_channels,)
    """
    psd = _welch_psd(x, fs)  # single estimate reused below

    tp, mf, spec_var = spectral_moments(x, fs, psd=psd)        # (C,) each
    medf = median_frequency(x, fs, psd=psd)                    # (C,)
    peakf = peak_frequency(x, fs, psd=psd)                     # (C,)
    sent = spectral_entropy(x, fs, psd=psd)                    # (C,)
    sef95 = spectral_edge_frequency(x, fs, 0.95, psd=psd)      # (C,)

    # bandpowers per band
    bp_list = [bandpower(x, fs, band=b, psd=psd) for b in bands]  # list of (C,)

    # concatenate along feature dimension
    feat_per_channel = np.vstack(
        [tp, *bp_list, mf, medf, peakf, spec_var, sent, sef95]
    )  # shape: (n_feat, C)
    return feat_per_channel.ravel(order="F")  # channel-major concatenation

In [6]:
# --- Feature extraction config ---
USE_TIME = True   # include the time-domain features
USE_FREQ = True   # include the frequency-domain features
FREQ_BANDS = ((5, 15), (15, 30))   # tweak for DB1 envelope; for DB2 you'd use higher bands

def extract_feature_vector(seg, fs):
    """
    Return the 1D feature vector for one window, based on the toggles above.

    Args:
        seg (ndarray): one window, shape (window_size, n_channels).
        fs (float): sampling rate in Hz (used by the frequency-domain features).

    Returns:
        ndarray: time-domain features (if USE_TIME) followed by
        frequency-domain features (if USE_FREQ).

    Raises:
        ValueError: if both USE_TIME and USE_FREQ are False, since there would
        be nothing to return.
    """
    feats = []
    if USE_TIME:
        feats.append(time_domain_feature_vector(seg))       # from earlier cell
    if USE_FREQ:
        feats.append(freq_domain_feature_vector(seg, fs, bands=FREQ_BANDS))  # from earlier cell
    if not feats:
        raise ValueError("No feature family enabled: set USE_TIME and/or USE_FREQ to True.")
    return np.concatenate(feats, axis=0)

In [7]:
from tqdm import tqdm
import re

def subject_id_from_path(p):
    """
    Extract the integer subject ID from a file path such as S1_A1_E1.mat.

    The file stem is lowercased and searched for the first 's' that is
    immediately followed by digits; those digits are the ID. Returns -1 when
    the stem contains no such pattern.
    """
    stem = p.stem.lower()
    found = re.search(r"s(\d+)", stem)
    if found is None:
        return -1
    return int(found.group(1))

# Collect every .mat file below the raw-data directory (recursive).
mats = sorted(DATA_RAW.rglob("*.mat"))
print("found .mat files:", len(mats))
# Report the directory that was actually searched.
assert mats, f"No .mat files found in {DATA_RAW}"

# Limit for quick iteration (optional)
MAX_FILES = None  # e.g., 10; None processes every file

X_feat_list, Y_list = [], []
SUBJ_list = []                              # subject id per window, aligned with Y_list

for p in tqdm(mats[:MAX_FILES], desc="featureizing"):
    d = read_mat(p)
    emg = ensure_samples_channels(pick_key(d, "emg"))
    y   = np.asarray(pick_key(d, "restimulus"), dtype=np.int32)

    # Preprocess → window
    x_p = rectify_and_zscore(bandpass_filter(emg, fs=FS, low=1.0, high=40.0))
    Xw, Yw = sliding_window(x_p, y, window_s=WINDOW_S, step_s=STEP_S, fs=FS)
    if Xw.shape[0] == 0:
        continue  # recording shorter than one window: nothing to extract

    # Extract features per window
    feats = np.vstack([extract_feature_vector(seg, FS) for seg in Xw])
    X_feat_list.append(feats)
    Y_list.append(Yw)

    # Tag each window with this file's subject id (-1 if the name has none)
    sid = subject_id_from_path(p)
    SUBJ_list.append(np.full(len(Yw), sid, dtype=np.int32))

# Concatenate across files
X_feat = np.vstack(X_feat_list) if X_feat_list else np.empty((0, ))
Y      = np.concatenate(Y_list) if Y_list else np.empty((0, ), dtype=np.int32)
SUBJ   = np.concatenate(SUBJ_list) if SUBJ_list else np.empty((0, ), dtype=np.int32)

# Drop 'rest' windows (label==0) for gesture-only models; the same mask is
# applied to features, labels and subject ids so they stay aligned.
mask = (Y != 0)
Xf, Yf, Subjf = X_feat[mask], Y[mask], SUBJ[mask]

# Save for reuse
DATA_PROCESSED.mkdir(parents=True, exist_ok=True)
np.save(DATA_PROCESSED / "X_feat_db1.npy", Xf)
np.save(DATA_PROCESSED / "y_db1.npy",   Yf)
np.save(DATA_PROCESSED / "subjects_db1.npy", Subjf)

Xf.shape, Yf.shape, Subjf.shape, np.unique(Subjf)[:10]

found .mat files: 81


featureizing: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 81/81 [28:27<00:00, 21.08s/it]


((1105889, 190),
 (1105889,),
 (1105889,),
 array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10], dtype=int32))