Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make PyPREP's random channel picks match MATLAB PREP's for the same random seeds #62

Merged
merged 9 commits into from
Apr 10, 2021
1 change: 1 addition & 0 deletions docs/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ Changelog
- Added two attributes :attr:`PrepPipeline.noisy_channels_before_interpolation <prep_pipeline.PrepPipeline>` and :attr:`PrepPipeline.noisy_channels_after_interpolation <prep_pipeline.PrepPipeline>` which have the detailed output of each noisy criteria, by `Yorguin Mantilla`_ (:gh:`45`)
- Added two keys to the :attr:`PrepPipeline.noisy_channels_original <prep_pipeline.PrepPipeline>` dictionary: ``bad_by_dropout`` and ``bad_by_SNR``, by `Yorguin Mantilla`_ (:gh:`45`)
- Changed RANSAC chunking logic to reduce max memory use and prefer equal chunk sizes where possible, by `Austin Hurst`_ (:gh:`44`)
- Changed RANSAC's random channel sampling code to produce the same results as MATLAB PREP for the same random seed, additionally changing the default RANSAC sample size from 25% of all *good* channels (e.g. 15 for a 64-channel dataset with 4 bad channels) to 25% of *all* channels (e.g. 16 for the same dataset), by `Austin Hurst`_ (:gh:`62`)

Bug
~~~
Expand Down
20 changes: 12 additions & 8 deletions pyprep/ransac.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from mne.channels.interpolation import _make_interpolation_matrix
from mne.utils import check_random_state

from pyprep.utils import split_list, verify_free_ram
from pyprep.utils import split_list, verify_free_ram, _get_random_subset


def find_bad_by_ransac(
Expand Down Expand Up @@ -100,14 +100,18 @@ def find_bad_by_ransac(

# Check if we have enough remaining channels
# after exclusion of bad channels
n_pred_chns = int(np.ceil(fraction_good * n_chans_good))
n_chans = data.shape[0]
n_pred_chns = int(np.around(fraction_good * n_chans))

if n_pred_chns <= 3:
raise IOError(
"Too few channels available to reliably perform"
" ransac. Perhaps, too many channels have failed"
" quality tests."
)
sample_pct = int(fraction_good * 100)
e = "Too few channels in the original data to reliably perform ransac "
e += "(minimum {0} for a sample size of {1}%)."
raise IOError(e.format(int(np.floor(4.0 / fraction_good)), sample_pct))
elif n_chans_good < (n_pred_chns + 1):
e = "Too many noisy channels in the data to reliably perform ransac "
e += "(only {0} good channels remaining, need at least {1})."
raise IOError(e.format(n_chans_good, n_pred_chns + 1))

# Before running, make sure we have enough memory when using the
# smallest possible chunk size
Expand All @@ -119,7 +123,7 @@ def find_bad_by_ransac(
rng = check_random_state(random_state)
for i in range(n_samples):
# Pick a random subset of clean channels to use for interpolation
picks = rng.choice(good_chans, size=n_pred_chns, replace=False)
picks = _get_random_subset(good_chans, n_pred_chns, rng)
random_ch_picks.append(picks)

# Correlation windows setup
Expand Down
35 changes: 35 additions & 0 deletions pyprep/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,41 @@ def _mat_iqr(arr, axis=None):
return iqr(arr, rng=np.clip(iqr_adj, 0, 100), axis=axis)


def _get_random_subset(x, size, rand_state):
    """Get a random subset of items from a list or array, without replacement.

    Parameters
    ----------
    x : list or np.ndarray
        One-dimensional array of items to sample from. The input itself is
        not modified.
    size : int
        The number of items to sample. Cannot be larger than the number of
        input items.
    rand_state : np.random.RandomState
        A random state object to use for random number generation.

    Returns
    -------
    sample : list
        A random subset of the input items, in the order they were drawn.

    Raises
    ------
    ValueError
        If ``size`` is larger than the number of items in ``x``.

    Notes
    -----
    This function generates random subsets identical to the internal
    ``randsample`` function in MATLAB PREP's ``findNoisyChannels.m``, allowing
    the same random seed to produce identical results across both PyPREP and
    MATLAB PREP.

    """
    # Work on a copy so that popping items never alters the caller's data
    remaining = list(x)
    if size > len(remaining):
        # Fail early with a clear message instead of an IndexError from
        # popping an exhausted list part-way through sampling
        e = "Cannot take a larger sample ({0}) than the number of items ({1})."
        raise ValueError(e.format(size, len(remaining)))

    sample = []
    for val in rand_state.rand(size):
        # Scale the random float in [0, 1) to a 1-based position among the
        # remaining items, then shift it to a 0-based index. The int() cast
        # ensures the index is a plain integer before it is used for pop().
        # NOTE(review): Python's round() sends exact halves to the nearest
        # even number, which may differ from MATLAB's round() for such ties;
        # presumably irrelevant for random floats, but worth confirming.
        index = int(round(1 + (len(remaining) - 1) * val)) - 1
        sample.append(remaining.pop(index))
    return sample


def filter_design(N_order, amp, freq):
"""Create FIR low-pass filter for EEG data using frequency sampling method.

Expand Down
11 changes: 11 additions & 0 deletions tests/test_find_noisy_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,17 @@ def test_findnoisychannels(raw, montage):
with pytest.raises(TypeError):
nd.find_bad_by_ransac(n_samples=n_samples)

# Test IOError when too few good channels for RANSAC sample size
raw_tmp = raw.copy()
nd = NoisyChannels(raw_tmp, random_state=rng)
nd.find_all_bads(ransac=False)
# Make 80% of channels bad
num_bad_channels = int(raw._data.shape[0] * 0.8)
bad_channels = raw.info["ch_names"][0:num_bad_channels]
nd.bad_by_hf_noise = bad_channels
with pytest.raises(IOError):
nd.find_bad_by_ransac()

# Test IOError when not enough channels for ransac predictions
raw_tmp = raw.copy()
# Make flat all channels except 2
Expand Down
14 changes: 13 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Test various helper functions."""
import numpy as np

from pyprep.utils import _mat_quantile, _mat_iqr
from pyprep.utils import _mat_quantile, _mat_iqr, _get_random_subset


def test_mat_quantile_iqr():
Expand Down Expand Up @@ -35,3 +35,15 @@ def test_mat_quantile_iqr():
# Test IQR equivalence with MATLAB
iqr_actual = _mat_iqr(tst, axis=0)
assert all(np.isclose(iqr_expected, iqr_actual, atol=0.001))


def test_get_random_subset():
    """Check that random channel subsets match the MATLAB reference picks."""
    # Seeded random state and a channel list numbered 1 through 60
    seeded_rng = np.random.RandomState(435656)
    channel_ids = range(1, 61)

    # Draw 8 channels and compare them against the picks expected from MATLAB
    py_picks = _get_random_subset(channel_ids, size=8, rand_state=seeded_rng)
    matlab_picks = [6, 47, 55, 31, 29, 44, 36, 15]
    matches = np.equal(matlab_picks, py_picks)
    assert all(matches)