# Denoising ds003020 fMRI With Savitzky-Golay vs Empirical Mode Modeling
Drift visualization and detrending comparison


Load the ds003020 ROI BOLD time series, visualize the raw drift, apply Huth-style Savitzky-Golay (SG) detrending and Empirical Mode Modeling (EMM) detrending (implemented here with Empirical Mode Decomposition, EMD), then visualize the cleaned signals, power spectra, and IMF decomposition. EDM forecasting is intentionally omitted because the current focus is on denoising only.


In [None]:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from scipy.signal import savgol_filter, welch

# Apply seaborn's 'whitegrid' style through its global setter.
sns.set(style='whitegrid')
# Set the default figure resolution to 150 dpi via matplotlib's rcParams.
plt.rcParams['figure.dpi'] = 150


In [None]:

# PyEMD is required for EMM detrending. Install on first run if missing.
try:
    from PyEMD import EMD
except ImportError:
    # In notebooks, uncomment the next line to install:
    # %pip install -q EMD-signal
    try:
        from PyEMD import EMD
    except ImportError as exc:
        # While the install line above stays commented out, the retry fails
        # exactly like the first attempt; say what is missing and how to fix
        # it instead of surfacing the bare import error a second time.
        raise ImportError(
            "PyEMD is not installed; install the 'EMD-signal' package "
            "(e.g. uncomment the %pip line in this cell) and re-run."
        ) from exc


## Load BOLD data

In [None]:

# Location of the cached ROI time series and the run to analyse.
PROJECT_ROOT = Path('/flash/PaoU/seann/fmri-edm-ccm')
DATA_CACHE = PROJECT_ROOT / 'data_cache'
SUBJECT = 'UTS01'
STORY = 'wheretheressmoke'
TR = 2.0  # seconds

bold_path = DATA_CACHE / SUBJECT / STORY / 'schaefer_400.npy'
# Raise explicitly rather than assert: asserts are stripped under `python -O`,
# which would defer the failure to a less informative error inside np.load.
if not bold_path.exists():
    raise FileNotFoundError(f'Missing BOLD cache: {bold_path}')

bold_raw = np.load(bold_path)
# The rest of the notebook indexes the array as [TR, ROI]; fail with a clear
# message instead of an opaque tuple-unpacking error if the cache is not 2-D.
if bold_raw.ndim != 2:
    raise ValueError(f'Expected a 2-D (TR x ROI) BOLD array, got shape {bold_raw.shape}')
n_tr, n_roi = bold_raw.shape
print(f'Loaded BOLD shape: {bold_raw.shape}, TR={TR}s')


## Detrending utilities

In [None]:

from typing import Dict, Sequence, Tuple

def _ensure_trim_indices(n_tr: int, trim_tr: int) -> np.ndarray:
    '''Return the sample indices left after cutting trim_tr samples off each end.

    A non-positive trim_tr keeps all n_tr samples. Raises ValueError when the
    two cuts meet or overlap, i.e. nothing would remain.
    '''
    if trim_tr > 0:
        first, last = trim_tr, n_tr - trim_tr
        if first >= last:
            raise ValueError(f'Trim of {trim_tr} TRs each side leaves no data (n_tr={n_tr}).')
        return np.arange(first, last)
    return np.arange(n_tr)


def _safe_zscore(data: np.ndarray) -> np.ndarray:
    '''Z-score along axis 0, ignoring NaNs when estimating mean and std.

    Where the standard deviation is exactly zero it is replaced by one, so
    those entries are only mean-centred instead of being divided by zero.
    '''
    centre = np.nanmean(data, axis=0, keepdims=True)
    spread = np.nanstd(data, axis=0, keepdims=True)
    # Swap exact-zero spreads for 1 (same dtype) to avoid 0/0 in the division.
    spread = np.where(spread == 0, np.ones_like(spread), spread)
    return (data - centre) / spread


def preprocess_sg(
    bold_ts: np.ndarray,
    tr: float = 2.0,
    window_s: float = 120.0,
    polyorder: int = 2,
    trim_s: float = 20.0,
    zscore: bool = True,
) -> Tuple[np.ndarray, np.ndarray, Dict]:
    '''Huth-style Savitzky-Golay detrending with trimming and z-scoring.

    A Savitzky-Golay smooth along axis 0 is taken as the slow trend and
    subtracted from the data. The result is then trimmed at both ends and,
    optionally, z-scored along axis 0.

    Parameters
    ----------
    bold_ts : array-like, converted to float; axis 0 is filtered and trimmed.
    tr : sampling interval in seconds, used to convert window_s and trim_s
        into sample counts.
    window_s : filter window in seconds; in samples it is at least
        polyorder + 2 and is bumped to the next odd number when even.
    polyorder : polynomial order passed to savgol_filter.
    trim_s : seconds removed from each end after detrending.
    zscore : when True, z-score the trimmed data with _safe_zscore.

    Returns
    -------
    (cleaned, keep_idx, info) where keep_idx holds the retained axis-0
    indices and info carries "trend" (untrimmed), "trim_tr" and "window_tr".

    Raises
    ------
    ValueError if trimming would leave no samples.
    '''
    signal = np.asarray(bold_ts, dtype=float)

    # Window length in samples: floor of polyorder + 2, then forced odd.
    n_window = max(int(round(window_s / tr)), polyorder + 2)
    n_window += 1 - n_window % 2

    slow_trend = savgol_filter(
        signal,
        window_length=n_window,
        polyorder=polyorder,
        axis=0,
        mode='interp',
    )

    n_trim = int(round(trim_s / tr))
    kept = _ensure_trim_indices(signal.shape[0], n_trim)

    output = (signal - slow_trend)[kept]
    output = _safe_zscore(output) if zscore else output

    info = {"trend": slow_trend, "trim_tr": n_trim, "window_tr": n_window}
    return output, kept, info


def preprocess_emm(
    bold_ts: np.ndarray,
    tr: float = 2.0,
    trim_s: float = 20.0,
    zscore: bool = True,
    remove_last_imf: bool = True,
    roi_for_diagnostics: int = 0,
) -> Tuple[np.ndarray, np.ndarray, Dict]:
    '''EMD detrending: subtract residual drift reconstructed from IMFs.

    Each column of the 2-D input is decomposed with EMD and rebuilt as the
    sum of the returned mode rows, optionally leaving out the last row. The
    rebuilt columns are trimmed at both ends and optionally z-scored.

    Parameters
    ----------
    bold_ts : 2-D array-like (rows are samples, columns are ROIs), converted
        to float.
    tr : sampling interval in seconds, used to convert trim_s to samples.
    trim_s : seconds removed from each end after detrending.
    zscore : when True, z-score the trimmed data with _safe_zscore.
    remove_last_imf : when True and EMD returned more than one row, drop the
        last row before summing. NOTE(review): presumably the last row is
        the slow trend/residue of the decomposition - confirm against the
        installed PyEMD version.
    roi_for_diagnostics : column whose decomposition is kept for inspection.
        If no column index equals it, the diagnostics entry is None.

    Returns
    -------
    (cleaned, keep_idx, info) where keep_idx holds the retained row indices
    and info carries "trim_tr" and "diagnostics" (a dict with "imfs",
    "residual" and "original", or None).

    Raises
    ------
    ValueError if trimming would leave no samples; it is raised before any
    decomposition is run.
    '''
    data = np.asarray(bold_ts, dtype=float)
    n_tr, n_roi_local = data.shape

    # Resolve the trim first: it depends only on the input length, and doing
    # so lets an impossible trim fail before the costly per-ROI EMD loop.
    trim_tr = int(round(trim_s / tr))
    keep_idx = _ensure_trim_indices(n_tr, trim_tr)

    emd = EMD()
    detrended = np.zeros_like(data)
    diag = None

    for roi in range(n_roi_local):
        x = data[:, roi]
        imfs = emd(x)
        if imfs.size == 0:
            # Nothing was decomposed: pass the column through unchanged.
            reconstructed = x.copy()
        elif remove_last_imf and imfs.shape[0] > 1:
            reconstructed = np.sum(imfs[:-1], axis=0)
        else:
            reconstructed = np.sum(imfs, axis=0)

        # The detrended signal is the sum of the retained modes. Store it
        # directly; computing it as x - (x - reconstructed) only adds
        # cancellation error when x has a large baseline.
        detrended[:, roi] = reconstructed
        if roi == roi_for_diagnostics:
            diag = {"imfs": imfs, "residual": x - reconstructed, "original": x}

    detrended = detrended[keep_idx]
    if zscore:
        detrended = _safe_zscore(detrended)
    return detrended, keep_idx, {"trim_tr": trim_tr, "diagnostics": diag}


## Run detrending

In [None]:

# Detrending settings shared by the cells below.
SG_WINDOW_S, SG_POLYORDER, TRIM_SECONDS = 120, 2, 20

# Options common to both detrending routes.
_shared_opts = dict(tr=TR, trim_s=TRIM_SECONDS, zscore=True)

# Savitzky-Golay route.
sg_clean, sg_keep, sg_info = preprocess_sg(
    bold_raw, window_s=SG_WINDOW_S, polyorder=SG_POLYORDER, **_shared_opts
)

# EMD route; the decomposition of ROI 0 is kept for the IMF plot.
emm_clean, emm_keep, emm_info = preprocess_emm(
    bold_raw, remove_last_imf=True, roi_for_diagnostics=0, **_shared_opts
)

print(f"SG cleaned shape: {sg_clean.shape}, trimmed TRs: {sg_info['trim_tr']}")
print(f"EMM cleaned shape: {emm_clean.shape}, trimmed TRs: {emm_info['trim_tr']}")


## Drift before and after detrending (sample ROIs)

In [None]:

# ROIs to display, the raw time axis in minutes, and the two trim boundaries.
roi_sample = [0, 1, 2, 50, 120]
time_raw_min = np.arange(n_tr) * TR / 60.0
trim_lines = [TRIM_SECONDS / 60.0, (n_tr * TR - TRIM_SECONDS) / 60.0]

# squeeze=False always yields a 2-D Axes array; taking its single column keeps
# the loop and the axes[0] / axes[-1] lookups working when roi_sample holds
# only one ROI (the default squeeze would return a bare Axes object).
fig, axes = plt.subplots(
    len(roi_sample),
    1,
    figsize=(12, 2.5 * len(roi_sample)),
    sharex=True,
    squeeze=False,
)
axes = axes[:, 0]

# NOTE(review): the raw and z-scored traces share one y-axis; if the cached
# BOLD is not already normalised the cleaned traces may look flat - confirm.
for ax, roi in zip(axes, roi_sample):
    ax.plot(time_raw_min, bold_raw[:, roi], label='Raw', color='0.55', alpha=0.8)
    ax.plot(sg_keep * TR / 60.0, sg_clean[:, roi], label='SG detrended', linewidth=1.3)
    ax.plot(emm_keep * TR / 60.0, emm_clean[:, roi], label='EMM detrended', linewidth=1.3)
    # Dashed markers where the trimmed edges begin and end.
    for line in trim_lines:
        ax.axvline(line, color='k', linestyle='--', alpha=0.25)
    ax.set_ylabel('Signal (a.u.)')
    ax.set_title(f'ROI {roi}')
axes[-1].set_xlabel('Time (minutes)')
axes[0].legend(loc='upper right')
fig.suptitle('Drift removal: raw vs SG vs EMM', y=1.02)
plt.tight_layout()


## Power spectral density (representative ROI)

In [None]:

# Welch power spectra of one ROI before and after each detrending method.
roi_psd = 0
fs = 1.0 / TR  # sampling rate in Hz

raw_trace = bold_raw[:, roi_psd]
sg_trace = sg_clean[:, roi_psd]
emm_trace = emm_clean[:, roi_psd]

# Segment length is capped at 128 samples or the series length, if shorter.
f_raw, p_raw = welch(raw_trace, fs=fs, nperseg=min(128, n_tr))
f_sg, p_sg = welch(sg_trace, fs=fs, nperseg=min(128, sg_clean.shape[0]))
f_emm, p_emm = welch(emm_trace, fs=fs, nperseg=min(128, emm_clean.shape[0]))

plt.figure(figsize=(8, 4))
spectra = (
    (f_raw, p_raw, 'Raw'),
    (f_sg, p_sg, 'SG detrended'),
    (f_emm, p_emm, 'EMM detrended'),
)
for freqs, power, curve_label in spectra:
    plt.semilogy(freqs, power, label=curve_label)
plt.xlabel('Frequency (Hz)')
plt.ylabel('PSD (a.u.)')
plt.title(f'PSD comparison (ROI {roi_psd})')
plt.legend()
plt.tight_layout()


## IMF decomposition (EMD)

In [None]:

roi_diag = 0
imfs = None
residual = None
original = bold_raw[:, roi_diag]

# Reuse the decomposition cached by preprocess_emm only when it was computed
# for this ROI's signal; otherwise the figure title would name roi_diag while
# showing another ROI's modes.
diag = emm_info.get('diagnostics') if emm_info else None
if diag and np.array_equal(diag.get('original', original), original):
    imfs = diag.get('imfs')
    residual = diag.get('residual')
    original = diag.get('original', original)

# No usable cached decomposition: run EMD on the selected ROI here.
if imfs is None or (hasattr(imfs, 'size') and imfs.size == 0):
    emd = EMD()
    imfs = emd(original)
    residual = None
    if getattr(imfs, 'size', 0):
        # Define the residual the way preprocess_emm does with
        # remove_last_imf=True: the signal minus every mode but the last.
        retained = imfs[:-1] if imfs.shape[0] > 1 else imfs
        residual = original - retained.sum(axis=0)

if imfs is not None and getattr(imfs, 'size', 0) > 0:
    time_min = np.arange(imfs.shape[1]) * TR / 60.0

    # When the residual is numerically the last mode row, show that row once
    # (as "Residual") instead of also listing it among the IMFs.
    mode_rows = imfs
    if residual is not None and imfs.shape[0] > 1 and np.allclose(imfs[-1], residual):
        mode_rows = imfs[:-1]

    max_imfs = min(6, mode_rows.shape[0])
    series = [("Signal", original)] + [(f"IMF {i+1}", mode_rows[i]) for i in range(max_imfs)]
    if residual is not None:
        series.append(("Residual", residual))

    # One common scale for all traces; fall back to 1 when the peak amplitude
    # is zero or not finite so the normalisation below cannot divide by 0/NaN.
    max_amp = max(np.nanmax(np.abs(s)) for _, s in series)
    if not np.isfinite(max_amp) or max_amp == 0:
        max_amp = 1.0
    sep = 2.5
    plt.figure(figsize=(10, 6))
    for idx, (label, s) in enumerate(series):
        # Stack traces top-down: the first entry gets the largest offset.
        offset = sep * (len(series) - idx - 1)
        y = (s / max_amp) + offset
        plt.plot(time_min, y, lw=1.2, color='k')
        plt.text(time_min[0], offset + 0.1, label, va='bottom', ha='left', fontsize=9)
    plt.yticks([])
    plt.xlabel('Time (minutes)')
    plt.title(f'EMD mode stack (ROI {roi_diag})')
    plt.tight_layout()
else:
    print('No IMFs were returned by EMD; cannot plot decomposition.')


## Summary

In [None]:

# One sentence per entry, summarising what each detrending route did.
summary_lines = [
    f"Savitzky-Golay removed slow drift with window={SG_WINDOW_S}s and polyorder={SG_POLYORDER}, trimming +/-{TRIM_SECONDS}s.",
    "EMM subtracted the residual drift estimated from IMFs (dropping the slowest IMF).",
    "Visualize above: raw signals show slow drift; both SG and EMM flatten baselines, with IMF plot showing decomposed modes.",
]
# Join with newlines: an empty separator would fuse the sentences together
# ("...+/-20s.EMM subtracted...").
print("\n".join(summary_lines))
