# EEG Signal Processing for Seizure Detection using STFT and DWT (Haar)

This notebook is a companion to the paper **EEG Signal Processing for Seizure Detection using STFT and DWT**.

It demonstrates:
- **STFT** time–frequency analysis (spectrogram + simple bandpower feature)
- **DWT (Haar)** multi-level detail coefficients
- A lightweight **variance change-point heuristic**
- A simple way to combine STFT + DWT features into **candidate seizure windows**

> This notebook supports loading EEG stored in **MATLAB `.mat`** files.  
If no file is provided, it will generate a **synthetic EEG** example to demonstrate the pipeline.


In [None]:
# !pip -q install numpy scipy matplotlib pywavelets pandas


## 1) Imports

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from scipy import signal
from scipy.io import loadmat
import pywt

np.random.seed(42)


## 2) Load EEG from a `.mat` file

### Minimal assumptions
A `.mat` file can store data in many shapes/keys. This loader tries to:
1. Find a reasonable EEG array automatically (largest 2D numeric array)
2. Infer orientation:
   - If shape is `(n_samples, n_channels)`, it will transpose to `(n_channels, n_samples)`
   - If shape is already `(n_channels, n_samples)`, it will keep it
3. Try to infer the sampling rate `fs` from the first common key that exists (`fs`, `Fs`, `srate`, `sampling_rate`, `sample_rate`, `SamplingRate`)

**You should edit** `MAT_PATH` and (optionally) `EEG_KEY`, `FS_KEY` if auto-detection doesn’t pick the right fields.


In [None]:
# --- Configure these ---
# Path to the recording. An empty (or whitespace-only) string skips loading,
# and the notebook runs on the synthetic example instead.
MAT_PATH = ""  # e.g., "data/subject01.mat" or "/path/to/file.mat"
# Name of the variable inside the .mat file that holds the EEG array.
EEG_KEY = None # e.g., "data" or "EEG" or "X"; set to None for auto-detect
# Name of the variable inside the .mat file that holds the sampling rate.
FS_KEY  = None # e.g., "fs" or "srate"; set to None for auto-detect
# -----------------------


def _is_numeric_array(x):
    """Return True when *x* is a NumPy array whose dtype is numeric."""
    if not isinstance(x, np.ndarray):
        return False
    return np.issubdtype(x.dtype, np.number)


def _pick_eeg_key(mat_dict, preferred=None):
    """Choose which variable of a loaded .mat dictionary holds the EEG.

    Parameters
    ----------
    mat_dict : dict
        Mapping of variable name to value, as returned by ``loadmat``.
    preferred : str, optional
        Key requested by the user. It is returned as-is when it exists and
        holds a numeric array (of any dimensionality).

    Returns
    -------
    str or None
        ``preferred`` when usable; otherwise the key of the largest 2-D
        numeric array (first one wins on ties); ``None`` if there is none.

    Warns
    -----
    UserWarning
        When ``preferred`` was given but cannot be used, so the caller is
        told that auto-detection took over instead of silently getting a
        different variable.
    """
    import warnings

    if preferred is not None:
        if preferred in mat_dict and _is_numeric_array(mat_dict[preferred]):
            return preferred
        warnings.warn(
            f"EEG key {preferred!r} is missing or not a numeric array; "
            "falling back to auto-detection.",
            stacklevel=2,
        )

    # Ignore MATLAB metadata keys
    ignore = {'__header__', '__version__', '__globals__'}
    sizes = {
        k: v.size
        for k, v in mat_dict.items()
        if k not in ignore and _is_numeric_array(v) and v.ndim == 2
    }
    if not sizes:
        return None
    # Largest 2D numeric array by element count; max() keeps the first
    # candidate on ties, matching dictionary order.
    return max(sizes, key=sizes.get)


def _pick_fs(mat_dict, preferred=None):
    """Look up the sampling rate in a loaded .mat dictionary.

    Parameters
    ----------
    mat_dict : dict
        Mapping of variable name to value, as returned by ``loadmat``.
    preferred : str, optional
        Key to try before the built-in list of common names.

    Returns
    -------
    float or None
        The first element of the first candidate that holds a real, finite,
        strictly positive number; ``None`` when no candidate qualifies.

    Notes
    -----
    Values may be arrays of any shape or plain/NumPy scalars. Unusable
    entries (empty, non-numeric, complex, NaN/inf, <= 0) are skipped
    explicitly rather than by swallowing arbitrary exceptions.
    """
    keys = [preferred] if preferred else []
    keys += ["fs", "Fs", "srate", "sampling_rate", "sample_rate", "SamplingRate"]
    for k in keys:
        if not k or k not in mat_dict:
            continue
        value = np.asarray(mat_dict[k])
        if value.size == 0:
            continue
        if not np.issubdtype(value.dtype, np.number) or np.iscomplexobj(value):
            continue
        # allow scalar or 1-element array
        rate = float(value.ravel()[0])
        # A sampling rate is later used as a divisor, so it must be > 0.
        if np.isfinite(rate) and rate > 0:
            return rate
    return None


def load_eeg_from_mat(path, eeg_key=None, fs_key=None):
    """Load an EEG array and its sampling rate from a MATLAB .mat file.

    Parameters
    ----------
    path : str
        Location of the .mat file.
    eeg_key : str, optional
        Variable name of the EEG array; auto-detected when ``None``.
    fs_key : str, optional
        Variable name of the sampling rate; common names are tried otherwise.

    Returns
    -------
    eeg : numpy.ndarray
        Array oriented as ``(n_channels, n_samples)``.
    fs : float or None
        Sampling rate, or ``None`` when it could not be found.
    channel_names : list of str
        Generated names ``Ch1``, ``Ch2``, ...
    key : str
        Name of the variable the EEG was read from.

    Raises
    ------
    ValueError
        If no suitable array is found, or the selected variable is neither
        1-D nor 2-D.
    """
    mat = loadmat(path, squeeze_me=True, struct_as_record=False)
    key = _pick_eeg_key(mat, preferred=eeg_key)
    if key is None:
        raise ValueError("Could not find a 2D numeric EEG array in the .mat file. Set EEG_KEY explicitly.")

    X = np.asarray(mat[key])
    fs = _pick_fs(mat, preferred=fs_key)

    # squeeze_me=True collapses a single-channel recording to 1-D; treat it
    # as one channel instead of rejecting it.
    if X.ndim == 1:
        X = X[np.newaxis, :]

    # Ensure 2D
    if X.ndim != 2:
        raise ValueError(f"Selected key '{key}' is not 2D after loading. Got shape {X.shape}.")

    # Orientation: want (n_channels, n_samples)
    n0, n1 = X.shape
    # Heuristic: typically n_samples >> n_channels; more rows than columns
    # with at most 512 columns is read as (samples, channels).
    if n0 > n1 and n1 <= 512:
        eeg = X.T
    else:
        # likely already (channels, samples) or square-ish
        eeg = X

    channel_names = [f"Ch{i+1}" for i in range(eeg.shape[0])]
    return eeg, fs, channel_names, key


eeg = fs = channel_names = None
# Reset the ground-truth window so that a value left over from an earlier
# synthetic run in the same session is never overlaid on a real recording.
true_seizure_window = None
if MAT_PATH.strip():
    eeg, fs, channel_names, used_key = load_eeg_from_mat(MAT_PATH, eeg_key=EEG_KEY, fs_key=FS_KEY)
    print(f"Loaded EEG from: {MAT_PATH}")
    print(f"Using key: {used_key}")
    print("EEG shape (channels, samples):", eeg.shape)
    print("Sampling rate fs:", fs)
    # Every later cell divides by fs; fail here with an actionable message
    # rather than with a TypeError further down.
    if fs is None:
        raise ValueError(
            "No sampling rate found in the .mat file. Set FS_KEY to the variable "
            "that stores it, or assign `fs` manually (e.g. fs = 256.0) before "
            "running the cells below."
        )
else:
    print("MAT_PATH is empty — will use a synthetic EEG example below.")


### Synthetic fallback (optional)
If you don’t have a `.mat` file handy, we generate a synthetic EEG-like signal with an injected seizure-like window.


In [None]:
def synth_eeg(fs=256, duration_s=60, n_channels=2):
    """Generate a synthetic multi-channel EEG with one seizure-like episode.

    Each channel is low-pass filtered (40 Hz) Gaussian noise scaled to a
    20e-6 standard deviation, plus 10 Hz and 20 Hz sinusoids whose phase
    depends on the channel index. A 4 Hz + 8 Hz oscillation with a 2 Hz
    square-wave component is added to all channels between 25 s and 35 s.

    Returns ``(eeg, fs, channel_names, (seizure_start_s, seizure_end_s))``
    where ``eeg`` has shape ``(n_channels, int(fs * duration_s))``.
    Noise comes from NumPy's global random state.
    """
    n_samples = int(fs * duration_s)
    time_s = np.arange(n_samples) / fs

    # The low-pass design does not depend on the channel, so build it once.
    lp_b, lp_a = signal.butter(4, 40/(fs/2), btype='low')

    def _background(idx):
        noise = np.random.randn(n_samples)
        smooth = signal.filtfilt(lp_b, lp_a, noise)
        smooth = smooth / np.std(smooth) * 20e-6
        alpha_wave = 10e-6 * np.sin(2*np.pi*10*time_s + 0.3*idx)
        beta_wave = 5e-6 * np.sin(2*np.pi*20*time_s + 0.5*idx)
        return smooth + alpha_wave + beta_wave

    traces = np.vstack([_background(idx) for idx in range(n_channels)])

    # Seizure-like episode shared by all channels.
    onset_s, offset_s = 25, 35
    first, last = int(onset_s*fs), int(offset_s*fs)
    ictal_t = time_s[first:last]
    rhythm = 60e-6 * np.sin(2*np.pi*4*ictal_t) + 30e-6 * np.sin(2*np.pi*8*ictal_t)
    bursts = 80e-6 * signal.square(2*np.pi*2*ictal_t, duty=0.1)
    traces[:, first:last] += rhythm + 0.3*bursts

    labels = [f"Ch{idx+1}" for idx in range(n_channels)]
    return traces, fs, labels, (onset_s, offset_s)


# Fall back to the synthetic generator when no recording was loaded above
# (eeg is still None in that case).
if eeg is None:
    eeg, fs, channel_names, true_seizure_window = synth_eeg(fs=256, duration_s=60, n_channels=2)
    print("Using synthetic EEG example.")
    print("True seizure window (s):", true_seizure_window)

# Summary of whichever signal the remaining cells will process.
print("EEG shape:", eeg.shape, " | fs:", fs)


## 3) Quick visualization (time domain)

In [None]:
def plot_eeg(eeg, fs, channel=0, tlim=None, title=None):
    """Plot one EEG channel against time in seconds.

    ``tlim`` is an optional ``(start_s, end_s)`` pair (inclusive) restricting
    the plotted span. When ``title`` is empty, a default title is built from
    the module-level ``channel_names`` list.
    """
    trace = eeg[channel]
    times = np.arange(trace.size) / fs

    if tlim is not None:
        inside = (times >= tlim[0]) & (times <= tlim[1])
        times, trace = times[inside], trace[inside]

    heading = title if title else f"EEG time series ({channel_names[channel]})"

    plt.figure(figsize=(12, 3))
    plt.plot(times, trace)
    plt.title(heading)
    plt.xlabel("Time (s)")
    plt.ylabel("Amplitude (a.u.)")
    plt.tight_layout()
    plt.show()

# Whole first channel, then a zoom on the 20–40 s span.
plot_eeg(eeg, fs, channel=0, title="Full EEG trace")
plot_eeg(eeg, fs, channel=0, tlim=(20, 40), title="Zoom (20–40s)")


## 4) STFT (spectrogram + bandpower feature)

In [None]:
def compute_stft(x, fs, nperseg=512, noverlap=384):
    """Return ``(frequencies, segment_times, power)`` for a 1-D signal.

    The STFT uses a Hann window, no detrending, a one-sided spectrum and no
    boundary extension; ``power`` is the squared magnitude of the STFT.
    """
    options = dict(
        fs=fs,
        window='hann',
        nperseg=nperseg,
        noverlap=noverlap,
        detrend=False,
        return_onesided=True,
        boundary=None,
    )
    freqs, times, coeffs = signal.stft(x, **options)
    power = np.abs(coeffs)**2
    return freqs, times, power


def plot_spectrogram(f, t, Pxx, fmax=60, title="STFT Power Spectrogram"):
    """Draw the power spectrogram in dB for frequencies up to ``fmax``."""
    keep = f <= fmax
    # The small offset keeps log10 finite where the power is exactly zero.
    power_db = 10*np.log10(Pxx + 1e-20)

    plt.figure(figsize=(12, 4))
    plt.pcolormesh(t, f[keep], power_db[keep], shading='auto')
    plt.xlabel('Time (s)')
    plt.ylabel('Frequency (Hz)')
    plt.title(title)
    plt.colorbar(label='Power (dB)')
    plt.tight_layout()
    plt.show()


# Spectrogram of the first channel. `x`, `f`, `tt` and `Pxx` are reused by
# the cells that follow.
x = eeg[0]
f, tt, Pxx = compute_stft(x, fs, nperseg=512, noverlap=384)
plot_spectrogram(f, tt, Pxx, fmax=60, title="STFT Spectrogram (Ch1)")


In [None]:
def bandpower_over_time(f, Pxx, band=(3, 12)):
    """Sum spectrogram power over the inclusive band ``[band[0], band[1]]``.

    Returns one value per time column of ``Pxx``.
    """
    low, high = band[0], band[1]
    in_band = np.logical_and(f >= low, f <= high)
    return np.sum(Pxx[in_band], axis=0)


def robust_threshold(x, z=3.5):
    """Return ``median + z * 1.4826 * MAD`` of ``x``.

    A tiny constant is added to the MAD so the spread is never exactly zero.
    """
    center = np.median(x)
    abs_dev = np.abs(x - center)
    spread = np.median(abs_dev) + 1e-12
    return center + z * 1.4826 * spread


# 3–12 Hz bandpower per STFT segment and its median/MAD threshold
# (z=3.0 here, not the function default of 3.5).
bp_3_12 = bandpower_over_time(f, Pxx, band=(3, 12))
thr_bp = robust_threshold(bp_3_12, z=3.0)

# Feature trace with the threshold drawn as a dashed line.
plt.figure(figsize=(12, 3))
plt.plot(tt, bp_3_12, label='Bandpower 3–12 Hz')
plt.axhline(thr_bp, linestyle='--', label='Threshold')
plt.xlabel('Time (s)')
plt.title('STFT bandpower feature')
plt.legend()
plt.tight_layout()
plt.show()

# One boolean per STFT segment; converted to time intervals in a later cell.
stft_flags = bp_3_12 > thr_bp
print("Flagged STFT bins:", stft_flags.sum(), "/", stft_flags.size)


## 5) DWT (Haar) + variance change heuristic

In [None]:
def compute_dwt_details(x, wavelet='haar', level=5):
    """Run a multi-level DWT and split the result.

    Returns ``(cA, cDs)``: the first coefficient array from ``pywt.wavedec``
    and a list holding the remaining arrays in the order ``wavedec``
    produced them.
    """
    approx, *details = pywt.wavedec(x, wavelet=wavelet, level=level)
    return approx, details


def plot_dwt_details(cDs, title="DWT Detail Coefficients (Haar)"):
    """Plot each detail-coefficient array in its own stacked subplot.

    Parameters
    ----------
    cDs : sequence of array-like
        Detail coefficients ordered coarsest level first, i.e. the order
        ``pywt.wavedec`` uses (``[cD_n, ..., cD_1]``) and therefore the
        order ``compute_dwt_details`` returns.
    title : str
        Title placed above the first subplot.
    """
    n_levels = len(cDs)
    plt.figure(figsize=(12, 2.5 + 0.5*n_levels))
    for row, cD in enumerate(cDs, start=1):
        ax = plt.subplot(n_levels, 1, row)
        ax.plot(cD)
        # The first array is the coarsest level, so row 1 is D<n_levels>
        # and the last row is D1.
        ax.set_ylabel(f"D{n_levels - row + 1}")
        if row == 1:
            ax.set_title(title)
    plt.xlabel("Coefficient index (downsampled)")
    plt.tight_layout()
    plt.show()


def rolling_variance(a, win=32):
    """Centered rolling (population) variance of a 1-D sequence.

    Parameters
    ----------
    a : array-like
        Input samples.
    win : int
        Window length in samples. It is capped at ``len(a)`` so the output
        always has the same length as the input.

    Returns
    -------
    numpy.ndarray
        Float array of ``len(a)`` variances; all zeros when the effective
        window is shorter than 2 samples.

    Notes
    -----
    Near the edges only part of the window overlaps the data. Means are
    taken over the samples actually inside the window, so a constant
    signal yields zero variance everywhere, including the edges.
    """
    x = np.asarray(a, dtype=float)
    n = x.size
    win = min(int(win), n)
    if win < 2:
        return np.zeros_like(x)

    box = np.ones(win)
    # Number of real samples under the window at each position.
    count = np.convolve(np.ones(n), box, mode='same')
    mean = np.convolve(x, box, mode='same') / count
    mean2 = np.convolve(x*x, box, mode='same') / count
    # Clamp tiny negative values caused by floating-point cancellation.
    return np.maximum(mean2 - mean*mean, 0.0)


cA, cDs = compute_dwt_details(x, wavelet='haar', level=5)
plot_dwt_details(cDs, title="Haar DWT details for Ch1")

# Robust (median/MAD) z-score of the rolling variance at every detail level.
vars_ = []
for cD in cDs:
    v = rolling_variance(cD, win=32)
    med = np.median(v)
    v = (v - med) / (np.median(np.abs(v - med)) + 1e-12)
    vars_.append(v)

# Each detail level is downsampled by a different power of two, so index i
# refers to a different time at every level. Block-average every level onto
# the coarsest (shortest) grid before summing, so that all rows cover the
# same time bins and one index spans about 2**level input samples.
n_coarse = min(len(v) for v in vars_)
aligned = []
for v in vars_:
    # Start index of each block; strictly increasing because len(v) >= n_coarse.
    starts = (np.arange(n_coarse) * len(v)) // n_coarse
    counts = np.diff(np.append(starts, len(v)))
    aligned.append(np.add.reduceat(v, starts) / counts)
stack = np.vstack(aligned)
vsum = stack.sum(axis=0)

thr_v = robust_threshold(vsum, z=3.0)

plt.figure(figsize=(12, 3))
plt.plot(vsum, label='Summed normalized rolling variance (details)')
plt.axhline(thr_v, linestyle='--', label='Threshold')
plt.title('Wavelet variance-change heuristic (downsampled domain)')
plt.legend()
plt.tight_layout()
plt.show()

dwt_flags = vsum > thr_v
print("Flagged DWT indices:", dwt_flags.sum(), "/", dwt_flags.size)


## 6) Convert flags to time intervals and overlay predictions

In [None]:
def flags_to_intervals(t, flags, min_dur_s=1.0):
    """Turn runs of truthy flags into ``(start, end)`` time intervals.

    A run starts at the time of its first flagged sample and ends at the
    time of the first unflagged sample after it. A run that reaches the end
    of ``flags`` ends at ``t[-1]`` plus the median spacing of ``t``.
    Runs shorter than ``min_dur_s`` are dropped.
    """
    if len(flags) == 0:
        return []
    step = np.median(np.diff(t)) if len(t) > 1 else 0.0

    active = np.asarray(flags, dtype=bool)
    # +1 marks the first sample of a run, -1 the first sample after it.
    padded = np.concatenate(([False], active, [False])).astype(np.int8)
    edges = np.diff(padded)
    run_starts = np.flatnonzero(edges == 1)
    run_stops = np.flatnonzero(edges == -1)

    found = []
    for first, stop in zip(run_starts, run_stops):
        begin = t[first]
        finish = t[stop] if stop < active.size else t[-1] + step
        if (finish - begin) >= min_dur_s:
            found.append((begin, finish))
    return found


def merge_intervals(intervals, gap_s=0.5):
    """Merge intervals that overlap or lie within ``gap_s`` of each other.

    Intervals are sorted by start time first; the result is a list of
    ``(start, end)`` pairs in ascending order.
    """
    if not intervals:
        return []

    queue = iter(sorted(intervals, key=lambda iv: iv[0]))
    result = [next(queue)]
    for begin, finish in queue:
        prev_begin, prev_finish = result[-1]
        close_enough = begin <= prev_finish + gap_s
        if close_enough:
            # Extend the previous interval; never let it shrink.
            result[-1] = (prev_begin, max(prev_finish, finish))
        else:
            result.append((begin, finish))
    return result


stft_intervals = flags_to_intervals(tt, stft_flags, min_dur_s=1.0)

# Map DWT feature indices to seconds. How many input samples one index of
# `vsum` spans depends on how the per-level features were combined, so the
# ratio is taken from the actual lengths instead of assuming 2**level.
factor = eeg.shape[-1] / len(vsum)
t_dwt = np.arange(len(vsum)) * factor / fs
dwt_intervals = flags_to_intervals(t_dwt, dwt_flags, min_dur_s=1.0)

# Union of both detectors; intervals up to 1 s apart are joined.
combined = merge_intervals(stft_intervals + dwt_intervals, gap_s=1.0)

print("STFT candidate intervals:", stft_intervals)
print("DWT  candidate intervals:", dwt_intervals)
print("Combined candidate seizure intervals:", combined)


In [None]:
def plot_with_intervals(eeg, fs, intervals, channel=0, truth=None, title="Candidate seizure windows"):
    """Plot one EEG channel with predicted (and optional true) windows shaded.

    Parameters
    ----------
    eeg : array indexed as ``eeg[channel]``
        Signal to plot.
    fs : float
        Sampling rate used to build the time axis.
    intervals : iterable of (start_s, end_s)
        Predicted windows, shaded in red.
    channel : int
        Index of the channel to draw.
    truth : (start_s, end_s), optional
        Ground-truth window, shaded in green so that it can be told apart
        from the predictions.
    title : str
        Prefix of the figure title; the channel name from the module-level
        ``channel_names`` list is appended.
    """
    x = eeg[channel]
    t = np.arange(x.size) / fs

    plt.figure(figsize=(12, 3))
    plt.plot(t, x, linewidth=1)
    for i, (s, e) in enumerate(intervals):
        # Label only the first span so the legend has a single entry.
        plt.axvspan(s, e, alpha=0.2, color='tab:red', label='Predicted' if i == 0 else None)
    if truth is not None:
        plt.axvspan(truth[0], truth[1], alpha=0.2, color='tab:green', label='True (synthetic)')
    plt.xlabel("Time (s)")
    plt.ylabel("Amplitude (a.u.)")
    plt.title(f"{title} — {channel_names[channel]}")
    # Skip the legend when nothing was shaded; an empty legend only warns.
    handles, _ = plt.gca().get_legend_handles_labels()
    if handles:
        plt.legend()
    plt.tight_layout()
    plt.show()


# Use the `true_seizure_window` global as ground truth when it is defined;
# otherwise pass None so that only the predictions are drawn.
truth = true_seizure_window if 'true_seizure_window' in globals() else None
plot_with_intervals(eeg, fs, combined, channel=0, truth=truth, title="STFT + DWT heuristic detection")


## 7) Notes on using your UC Irvine `.mat` EEG dataset

If the auto-loader chooses the wrong key:
1. Set `EEG_KEY = "your_key_name"` (list the available keys with `loadmat(MAT_PATH).keys()`)
2. If the EEG is nested in a MATLAB struct, you may need to extract it explicitly.

If your `.mat` includes labels:
- Note the key name and format of the labels; you can then extend this notebook to compute metrics (AUC, sensitivity/specificity).
