Skip to content

Commit

Permalink
Skip RANSAC predictions for channels already confirmed as bad (#72)
Browse files Browse the repository at this point in the history
* Make RANSAC skip known-bad chans during prediction

* Updated whats_new.rst
  • Loading branch information
a-hurst committed Apr 29, 2021
1 parent fa4bad3 commit 145a5a1
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 48 deletions.
1 change: 1 addition & 0 deletions docs/whats_new.rst
Expand Up @@ -48,6 +48,7 @@ Bug
- Fixed "bad channel by flat" threshold in :meth:`NoisyChannels.find_bad_by_nan_flat` to be consistent with MATLAB PREP, by `Austin Hurst`_ (:gh:`60`)
- Changed "bad channel by deviation" and "bad channel by correlation" detection code in :class:`NoisyChannels` to compute IQR and quantiles in the same manner as MATLAB, thus producing identical results to MATLAB PREP, by `Austin Hurst`_ (:gh:`57`)
- Fixed a bug where EEG data was getting reshaped into RANSAC windows incorrectly (channel samples were not sequential), which was causing considerable variability and noise in RANSAC results, by `Austin Hurst`_ (:gh:`67`)
- Fixed RANSAC to avoid making unnecessary signal predictions for known-bad channels, matching MATLAB behaviour and reducing RAM requirements, by `Austin Hurst`_ (:gh:`72`)

API
~~~
Expand Down
64 changes: 16 additions & 48 deletions pyprep/ransac.py
Expand Up @@ -111,7 +111,6 @@ def find_bad_by_ransac(
# Exclude should be the bad channels from other methods
# That is, identify all bad channels by other means
good_idx = mne.pick_channels(list(complete_chn_labs), include=[], exclude=exclude)
good_chn_labs = complete_chn_labs[good_idx]
n_chans_good = good_idx.shape[0]
chn_pos_good = chn_pos[good_idx, :]

Expand Down Expand Up @@ -165,15 +164,15 @@ def find_bad_by_ransac(
# Calculate smallest chunk size for each possible chunk count
chunk_sizes = []
chunk_count = 0
for i in range(1, n_chans_complete + 1):
n_chunks = int(np.ceil(n_chans_complete / i))
for i in range(1, n_chans_good + 1):
n_chunks = int(np.ceil(n_chans_good / i))
if n_chunks != chunk_count:
chunk_count = n_chunks
chunk_sizes.append(i)

chunk_size = chunk_sizes.pop()
mem_error = True
job = list(range(n_chans_complete))
job = list(range(n_chans_good))

if channel_wise:
chunk_size = 1
Expand All @@ -183,14 +182,11 @@ def find_bad_by_ransac(
total_chunks = len(channel_chunks)
current = 1
for chunk in channel_chunks:
channel_correlations[:, chunk] = _ransac_correlations(
channel_correlations[:, good_idx[chunk]] = _ransac_correlations(
chunk,
random_ch_picks,
chn_pos,
chn_pos_good,
good_chn_labs,
complete_chn_labs,
data,
data[good_idx, :],
n_samples,
n,
w_correlation,
Expand All @@ -217,7 +213,7 @@ def find_bad_by_ransac(
"ested samples."
)

# Thresholding
# Calculate fractions of bad RANSAC windows for each channel
thresholded_correlations = channel_correlations < corr_thresh
frac_bad_corr_windows = np.mean(thresholded_correlations, axis=0)

Expand All @@ -233,10 +229,7 @@ def find_bad_by_ransac(
def _ransac_correlations(
chans_to_predict,
random_ch_picks,
chn_pos,
chn_pos_good,
good_chn_labs,
complete_chn_labs,
data,
n_samples,
n,
Expand All @@ -251,15 +244,9 @@ def _ransac_correlations(
Indexes of the channels to predict as they appear in chn_pos.
random_ch_picks : list
each element is a list of indexes of the channels (as they appear
in good_chn_labs) to use for reconstruction in each of the samples.
chn_pos : np.ndarray
3-D coordinates of the electrode positions to predict
in chn_pos_good) to use for reconstruction in each of the samples.
chn_pos_good : np.ndarray
3-D coordinates of all the channels not detected noisy so far
good_chn_labs : array_like
channel labels for the ch_pos_good channels
complete_chn_labs : array_like
labels of the channels in data in the same order
data : np.ndarray
2-D EEG data
n_samples : int
Expand All @@ -285,10 +272,8 @@ def _ransac_correlations(
ransac_eeg = _run_ransac(
n_samples=n_samples,
random_ch_picks=random_ch_picks,
chn_pos=chn_pos[chans_to_predict, :],
chn_pos=chn_pos_good[chans_to_predict, :],
chn_pos_good=chn_pos_good,
good_chn_labs=good_chn_labs,
complete_chn_labs=complete_chn_labs,
data=data,
matlab_strict=matlab_strict,
)
Expand Down Expand Up @@ -320,8 +305,6 @@ def _run_ransac(
random_ch_picks,
chn_pos,
chn_pos_good,
good_chn_labs,
complete_chn_labs,
data,
matlab_strict,
):
Expand All @@ -336,15 +319,11 @@ def _run_ransac(
number of interpolations from which a median will be computed
random_ch_picks : list
each element is a list of indexes of the channels (as they appear
in good_chn_labs) to use for reconstruction in each of the samples.
in chn_pos_good) to use for reconstruction in each of the samples.
chn_pos : np.ndarray
3-D coordinates of the electrode position
chn_pos_good : np.ndarray
3-D coordinates of all the channels not detected noisy so far
good_chn_labs : array_like
channel labels for the ch_pos_good channels
complete_chn_labs : array_like
labels of the channels in data in the same order
data : np.ndarray
2-D EEG data
matlab_strict : bool
Expand Down Expand Up @@ -372,7 +351,7 @@ def _run_ransac(
# Get the random channel selection for the current sample
reconstr_idx = random_ch_picks[sample]
eeg_predictions[..., sample] = _get_ransac_pred(
chn_pos, chn_pos_good, good_chn_labs, complete_chn_labs, reconstr_idx, data
chn_pos, chn_pos_good, reconstr_idx, data
)

# Form median from all predictions
Expand All @@ -383,12 +362,11 @@ def _run_ransac(
ransac_eeg = eeg_predictions[:, :, median_idx]
else:
ransac_eeg = np.median(eeg_predictions, axis=-1, overwrite_input=True)

return ransac_eeg


def _get_ransac_pred(
chn_pos, chn_pos_good, good_chn_labs, complete_chn_labs, reconstr_idx, data
):
def _get_ransac_pred(chn_pos, chn_pos_good, reconstr_idx, data):
"""Perform RANSAC prediction.
Parameters
Expand All @@ -397,12 +375,8 @@ def _get_ransac_pred(
3-D coordinates of the electrode position
chn_pos_good : np.ndarray
3-D coordinates of all the channels not detected noisy so far
good_chn_labs : array_like
channel labels for the ch_pos_good channels
complete_chn_labs : array_like
labels of the channels in data in the same order
reconstr_idx : array_like
indexes of the channels in good_chn_labs to use for reconstruction
indexes of the channels in chn_pos_good to use for reconstruction
data : np.ndarray
2-D EEG data
Expand All @@ -412,17 +386,11 @@ def _get_ransac_pred(
Single RANSAC prediction
"""
# Get positions and according labels
reconstr_labels = good_chn_labs[reconstr_idx]
# Get positions
reconstr_pos = chn_pos_good[reconstr_idx, :]

# Map the labels to their indices within the complete data
# Do not use mne.pick_channels, because it will return a sorted list.
reconstr_picks = [
list(complete_chn_labs).index(chn_lab) for chn_lab in reconstr_labels
]

# Interpolate
interpol_mat = _make_interpolation_matrix(reconstr_pos, chn_pos)
ransac_pred = np.matmul(interpol_mat, data[reconstr_picks, :])
ransac_pred = np.matmul(interpol_mat, data[reconstr_idx, :])

return ransac_pred

0 comments on commit 145a5a1

Please sign in to comment.