Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make removeTrend output identical to MATLAB PREP #71

Merged
merged 11 commits into from
Apr 29, 2021
25 changes: 25 additions & 0 deletions docs/matlab_differences.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,31 @@ internal math.
:depth: 3


Differences in Signal Detrending
--------------------------------

In the PREP pipeline, trends (i.e., slow drifts in EEG baseline signal) are
temporarily removed from the data prior to adaptive line-noise removal
as well as prior to bad channel detection via :class:`~pyprep.NoisyChannels`,
which occurs at multiple points during robust re-referencing. This is done to
improve the accuracy of both of these processes, which are sensitive to
influence from trends in the signal.

In MATLAB PREP, the default method of trend removal is to use EEGLAB's
``pop_eegfiltnew``, which creates and applies an FIR high-pass filter to the
data. MNE's :func:`mne.filter.filter_data` offers similar functionality, but
uses slightly different filter creation math and a different filtering
algorithm such that its results and subsequent :class:`~pyprep.NoisyChannels`
values also differ slightly (on the order of ~0.002) for RANSAC correlations.

Because the practical differences are small and MNE's filtering is fast and
well-tested, PyPREP defaults to using :func:`mne.filter.filter_data` for
high-pass trend removal. However, for exact numerical compatibility, PyPREP
has a basic re-implementaion of EEGLAB's ``pop_eegfiltnew`` in Python that
produces identical results to MATLAB PREP's ``removeTrend`` when
``matlab_strict`` is set to ``True``.


Differences in RANSAC
---------------------

Expand Down
1 change: 1 addition & 0 deletions docs/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ Changelog
- Changed RANSAC's random channel sampling code to produce the same results as MATLAB PREP for the same random seed, additionally changing the default RANSAC sample size from 25% of all *good* channels (e.g. 15 for a 64-channel dataset with 4 bad channels) to 25% of *all* channels (e.g. 16 for the same dataset), by `Austin Hurst`_ (:gh:`62`)
- Changed RANSAC so that "bad by high-frequency noise" channels are retained when making channel predictions (provided they aren't flagged as bad by any other metric), matching MATLAB PREP behaviour, by `Austin Hurst`_ (:gh:`64`)
- Added a new flag ``matlab_strict`` to :class:`~pyprep.PrepPipeline`, :class:`~pyprep.Reference`, :class:`~pyprep.NoisyChannels`, and :func:`~pyprep.ransac.find_bad_by_ransac` for optionally matching MATLAB PREP's internal math as closely as possible, overriding areas where PyPREP attempts to improve on the original, by `Austin Hurst`_ (:gh:`70`)
- Added a ``matlab_strict`` method for high-pass trend removal, exactly matching MATLAB PREP's values if ``matlab_strict`` is enabled, by `Austin Hurst`_ (:gh:`71`)

Bug
~~~
Expand Down
2 changes: 1 addition & 1 deletion pyprep/find_noisy_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(self, raw, do_detrend=True, random_state=None, matlab_strict=False)
self.sample_rate = raw.info["sfreq"]
if do_detrend:
self.raw_mne._data = removeTrend(
self.raw_mne.get_data(), sample_rate=self.sample_rate
self.raw_mne.get_data(), self.sample_rate, matlab_strict=matlab_strict
)
self.matlab_strict = matlab_strict

Expand Down
4 changes: 3 additions & 1 deletion pyprep/prep_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ def fit(self):
# reference_channels = _set_diff(self.prep_params["ref_chs"], unusable_channels)
# Step 1: 1Hz high pass filtering
if len(self.prep_params["line_freqs"]) != 0:
self.EEG_new = removeTrend(self.EEG_raw, sample_rate=self.sfreq)
self.EEG_new = removeTrend(
self.EEG_raw, self.sfreq, matlab_strict=self.matlab_strict
)

# Step 2: Removing line noise
linenoise = self.prep_params["line_freqs"]
Expand Down
4 changes: 3 additions & 1 deletion pyprep/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,9 @@ def robust_reference(self):

"""
raw = self.raw.copy()
raw._data = removeTrend(raw.get_data(), sample_rate=self.sfreq)
raw._data = removeTrend(
raw.get_data(), self.sfreq, matlab_strict=self.matlab_strict
)

# Determine unusable channels and remove them from the reference channels
noisy_detector = NoisyChannels(
Expand Down
87 changes: 61 additions & 26 deletions pyprep/removeTrend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,51 +4,76 @@
import mne
import numpy as np

from pyprep.utils import _eeglab_create_highpass, _eeglab_fir_filter


def removeTrend(
EEG,
detrendType="High pass",
sample_rate=160.0,
sample_rate,
detrendType="high pass",
detrendCutoff=1.0,
detrendChannels=None,
matlab_strict=False,
):
"""Perform high pass filtering or detrending.
"""Remove trends (i.e., slow drifts in baseline) from an array of EEG data.

Parameters
----------
EEG : np.ndarray
The input EEG data.
detrendType : str
Type of detrending to be performed: high pass, high pass sinc, or local
detrending.
A 2-D array of EEG data to detrend.
sample_rate : float
Rate at which the EEG data was sampled.
detrendCutoff : float
High pass cut-off frequency.
The sample rate (in Hz) of the input EEG data.
detrendType : str, optional
Type of detrending to be performed: must be one of 'high pass',
'high pass sinc, or 'local detrend'. Defaults to 'high pass'.
detrendCutoff : float, optional
The high-pass cutoff frequency (in Hz) to use for detrending. Defaults
to 1.0 Hz.
detrendChannels : {list, None}, optional
List of all the channels that require detrending/filtering. If None,
all channels are used (default).
List of the indices of all channels that require detrending/filtering.
If ``None``, all channels are used (default).
matlab_strict : bool, optional
Whether or not detrending should strictly follow MATLAB PREP's internal
math, ignoring any improvements made in PyPREP over the original code
(see :ref:`matlab-diffs` for more details). Defaults to ``False``.

Returns
-------
EEG : np.ndarray
Filtered/detrended EEG data.
A 2-D array containing the filtered/detrended EEG data.

Notes
-----
Filtering is implemented using the MNE filter function mne.filter.filter_data.
Local detrending is the python implementation of the chronux_2 runline command.
High-pass filtering is implemented using the MNE filter function
:func:``mne.filter.filter_data`` unless `matlab_strict` is ``True``, in
which case it is performed using a minimal re-implementation of EEGLAB's
``pop_eegfiltnew``. Local detrending is performed using a Python
re-implementation of the ``runline`` function from the Chronux package for
MATLAB [1]_.

References
----------
.. [1] http://chronux.org/

"""
if len(EEG.shape) == 1:
EEG = np.reshape(EEG, (1, EEG.shape[0]))

if detrendType == "High pass":
EEG = mne.filter.filter_data(
EEG, sfreq=sample_rate, l_freq=1, h_freq=None, picks=detrendChannels
)
if detrendType.lower() == "high pass":
if matlab_strict:
picks = detrendChannels if detrendChannels else range(EEG.shape[0])
filt = _eeglab_create_highpass(detrendCutoff, sample_rate)
EEG[picks, :] = _eeglab_fir_filter(EEG[picks, :], filt)
else:
EEG = mne.filter.filter_data(
EEG,
sfreq=sample_rate,
l_freq=detrendCutoff,
h_freq=None,
picks=detrendChannels
)

elif detrendType == "High pass sinc":
elif detrendType.lower() == "high pass sinc":
fOrder = np.round(14080 * sample_rate / 512)
fOrder = np.int(fOrder + fOrder % 2)
EEG = mne.filter.filter_data(
Expand All @@ -60,7 +85,8 @@ def removeTrend(
filter_length=fOrder,
fir_window="blackman",
)
elif detrendType == "Local detrend":

elif detrendType.lower() == "local detrend":
if detrendChannels is None:
detrendChannels = np.arange(0, EEG.shape[0])
windowSize = 1.5 / detrendCutoff
Expand All @@ -82,29 +108,38 @@ def removeTrend(
for ch in detrendChannels:
EEG[:, ch] = runline(EEG[:, ch], np.int(n), np.int(dn))
EEG = np.transpose(EEG)

else:
logging.warning(
"No filtering/detreding performed since the detrend type did not match"
)

return EEG


def runline(y, n, dn):
"""Implement chronux_2 runline command for performing local linear regression.
"""Perform local linear regression on a channel of EEG data.

A re-implementation of the ``runline`` function from the Chronux package
for MATLAB [1]_.

Parameters
----------
y : np.ndarray
Input from one EEG channel.
A 1-D array of data from a single EEG channel.
n : int
length of the detrending window.
Length of the detrending window.
dn : int
length of the window step size.
Length of the window step size.

Returns
-------
y: np.ndarray
Detrended EEG signal for one channel.
The detrended signal for the given EEG channel.

References
----------
.. [1] http://chronux.org/

"""
nt = y.shape[0]
Expand Down
99 changes: 99 additions & 0 deletions pyprep/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import scipy.interpolate
from scipy.stats import iqr
from scipy.signal import firwin, lfilter, lfilter_zi
from psutil import virtual_memory


Expand Down Expand Up @@ -108,6 +109,104 @@ def _mat_iqr(arr, axis=None):
return iqr(arr, rng=np.clip(iqr_adj, 0, 100), axis=axis)


def _eeglab_create_highpass(cutoff, srate):
"""Create a high-pass FIR filter using Hamming windows.

Parameters
----------
cutoff : float
The lower pass-band edge of the filter, in Hz.
srate : float
The sampling rate of the EEG signal, in Hz.

Returns
-------
filter : np.ndarray
A 1-dimensional array of FIR filter coefficients.

Notes
-----
In MATLAB PREP, the internal ``removeTrend`` function uses EEGLAB's
``pop_eegfiltnew`` to high-pass the EEG data to remove slow drifts.
Because MNE's ``mne.filter.filter_data`` and EEGLAB's ``pop_eegfiltnew``
calculate filter parameters slightly differently, this function is
used to precisely match EEGLAB & MATLAB PREP's filtering method.

"""
TRANSITION_WIDTH_RATIO = 0.25
HAMMING_CONSTANT = 3.3 # note: not entirely clear what this represents

# Calculate parameters for constructing filter
trans_bandwidth = cutoff if cutoff < 2 else cutoff * TRANSITION_WIDTH_RATIO
order = HAMMING_CONSTANT / (trans_bandwidth / srate)
order = int(np.ceil(order / 2) * 2) # ensure order is even
stop = cutoff - trans_bandwidth
transition = (stop + cutoff) / srate

# Generate highpass filter
N = order + 1
filt = np.zeros(N)
filt[N // 2] = 1
filt -= firwin(N, transition, window='hamming', nyq=1)
return filt


def _eeglab_fir_filter(data, filt):
"""Apply an FIR filter to a 2-D array of EEG data.

Parameters
----------
data : np.ndarray
A 2-D array of EEG data to filter.
filt : np.ndarray
A 1-D array of FIR filter coefficients.

Returns
-------
filtered : np.ndarray
A 2-D array of FIR-filtered EEG data.

Notes
-----
Produces identical output to EEGLAB's ``firfilt`` function (for non-epoched
data). For internal use within :mod:`pyprep.removeTrend`.

"""
# Initialize parameters for FIR filtering
frames_per_window = 2000
group_delay = int((len(filt) - 1) / 2)
n_samples = data.shape[1]
n_windows = int(np.ceil((n_samples - group_delay) / frames_per_window))
pad_len = min(group_delay, n_samples)

# Prepare initial state of filter, using padding at start of data
start_pad_idx = np.zeros(pad_len, dtype=np.uint8)
start_padded = np.concatenate(
(data[:, start_pad_idx], data[:, :pad_len]),
axis=1
)
zi_init = lfilter_zi(filt, 1) * np.take(start_padded, [0], axis=0)
_, zi = lfilter(filt, 1, start_padded, axis=1, zi=zi_init)

# Iterate over windows of signal, filtering in chunks
out = np.zeros_like(data)
for w in range(n_windows):
start = group_delay + w * frames_per_window
end = min(start + frames_per_window, n_samples)
start_out = start - group_delay
end_out = end - group_delay
out[:, start_out:end_out], zi = lfilter(
filt, 1, data[:, start:end], axis=1, zi=zi
)

# Finish filtering data, using padding at end to calculate final values
end_pad_idx = np.zeros(pad_len, dtype=np.uint8) + (n_samples - 1)
end, _ = lfilter(filt, 1, data[:, end_pad_idx], axis=1, zi=zi)
out[:, (n_samples - pad_len):] = end[:, (group_delay - pad_len):]

return out


def _get_random_subset(x, size, rand_state):
"""Get a random subset of items from a list or array, without replacement.

Expand Down
9 changes: 9 additions & 0 deletions tests/test_removeTrend.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,19 @@ def test_highpass():
lowpass_filt2 = removeTrend.removeTrend(
signal, detrendType="High pass", sample_rate=srate, detrendCutoff=1
)
lowpass_filt3 = removeTrend.removeTrend(
signal,
detrendType="High pass",
sample_rate=srate,
detrendCutoff=1,
matlab_strict=True
)
error1 = lowpass_filt1 - highfreq_signal
error2 = lowpass_filt2 - highfreq_signal
error3 = lowpass_filt3 - highfreq_signal
assert np.sqrt(np.mean(error1 ** 2)) < 0.1
assert np.sqrt(np.mean(error2 ** 2)) < 0.1
assert np.sqrt(np.mean(error3 ** 2)) < 0.1


def test_detrend():
Expand Down
24 changes: 23 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import numpy as np

from pyprep.utils import (
_mat_round, _mat_quantile, _mat_iqr, _get_random_subset, _correlate_arrays
_mat_round, _mat_quantile, _mat_iqr, _get_random_subset, _correlate_arrays,
_eeglab_create_highpass
)


Expand Down Expand Up @@ -92,3 +93,24 @@ def test_correlate_arrays():
corr_expected = np.asarray([-0.0898, 0.0327, -0.1140])
corr_actual = _correlate_arrays(a, b, matlab_strict=True)
assert all(np.isclose(corr_expected, corr_actual, atol=0.001))


def test_eeglab_create_highpass():
"""Test EEGLAB-equivalent high-pass filter creation.

NOTE: EEGLAB values were obtained using breakpoints in ``pop_eegfiltnew``,
since filter creation and data filtering are both done in the same function.
Values here are first 4 values of the array ``b`` which contains the FIR
filter coefficents used by the function.

"""
# Compare initial FIR filter coefficents with EEGLAB
expected_vals = [5.3691e-5, 5.4165e-5, 5.4651e-5, 5.5149e-5]
actual_vals = _eeglab_create_highpass(cutoff=1.0, srate=256)[:4]
assert all(np.isclose(expected_vals, actual_vals, atol=0.001))

# Compare middle FIR filter coefficent with EEGLAB
vals = _eeglab_create_highpass(cutoff=1.0, srate=256)
expected_val = 0.9961
actual_val = vals[len(vals) // 2]
assert np.isclose(expected_val, actual_val, atol=0.001)