Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make noisy channel exclusion during Reference compatible with MATLAB PREP #93

Merged
merged 16 commits into from
Jun 27, 2021
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
40 changes: 40 additions & 0 deletions docs/matlab_differences.rst
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,43 @@ roughly mean-centered. and will thus produce similar values to normal Pearson
correlation. However, to avoid making any assumptions about the signal for any
given channel / window, PyPREP defaults to normal Pearson correlation unless
strict MATLAB equivalence is requested.


Differences in Robust Referencing
---------------------------------

During the robust referencing part of the pipeline, PREP tries to estimate a
"clean" average reference signal for the dataset, excluding any channels
flagged as noisy from contaminating the reference. The robust referencing
process is performed using the following logic:

1) First, an initial pass of noisy channel detection is performed to identify
channels bad by NaN values, flat signal, or low SNR: the data is then
average-referenced excluding these channels. These channels are subsequently
marked as "unusable" and are excluded from any future average referencing.

2) Noisy channel detection is performed on a copy of the re-referenced signal,
and any newly detected bad channels are added to the full set of channels
to be excluded from the reference signal.

3) After noisy channel detection, all bad channels detected so far are
interpolated, and a new estimate of the robust average reference is
calculated using the mean signal of all good channels and all interpolated
bad channels (except those flagged as "unusable" during the first step).

4) A fresh copy of the re-referenced signal from Step 1 is re-referenced using
the new reference signal calculated in Step 3.

5) Steps 2 through 4 are repeated until either two iterations have passed and
no new noisy channels have been detected since the previous iteration, or
the maximum number of reference iterations has been exceeded (default: 5).
sappelhoff marked this conversation as resolved.
Show resolved Hide resolved


Exclusion of dropout channels
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In MATLAB PREP, dropout channels (i.e., channels that have intermittent periods
of flat signal) are detected on each iteration of the reference loop, but are
currently not factored into the full set of "bad" channels to be interpolated.
By contrast, PyPREP will detect and interpolate any bad-by-dropout channels
detected during robust referencing.
43 changes: 29 additions & 14 deletions pyprep/find_noisy_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,46 +125,61 @@ def _get_filtered_data(self):

return EEG_filt

def get_bads(self, verbose=False):
"""Get a list of all channels currently flagged as bad.
def get_bads(self, verbose=False, as_dict=False):
"""Get the names of all channels currently flagged as bad.

Note that this method does not perform any bad channel detection itself,
and only reports channels already detected as bad by other methods.

Parameters
----------
verbose : bool
verbose : bool, optional
If ``True``, a summary of the channels currently flagged as by bad per
category is printed. Defaults to ``False``.
as_dict: bool, optional
If ``True``, this method will return a dict of the channels currently
flagged as bad by each individual bad channel type. If ``False``, this
method will return a list of all unique bad channels detected so far.
Defaults to ``False``.

Returns
-------
bads : list
THe names of all bad channels detected by any method so far.
bads : list or dict
The names of all bad channels detected so far, either as a combined
list or a dict indicating the channels flagged bad by each type.

"""
bads = {
"n/a": self.bad_by_nan,
"flat": self.bad_by_flat,
"deviation": self.bad_by_deviation,
"hf noise": self.bad_by_hf_noise,
"correl": self.bad_by_correlation,
"SNR": self.bad_by_SNR,
"dropout": self.bad_by_dropout,
"RANSAC": self.bad_by_ransac,
"bad_by_nan": self.bad_by_nan,
"bad_by_flat": self.bad_by_flat,
"bad_by_deviation": self.bad_by_deviation,
"bad_by_hf_noise": self.bad_by_hf_noise,
"bad_by_correlation": self.bad_by_correlation,
"bad_by_SNR": self.bad_by_SNR,
"bad_by_dropout": self.bad_by_dropout,
"bad_by_ransac": self.bad_by_ransac,
}

all_bads = set()
for bad_chs in bads.values():
all_bads.update(bad_chs)

name_map = {"nan": "NaN", "hf_noise": "HF noise", "ransac": "RANSAC"}
if verbose:
out = f"Found {len(all_bads)} uniquely bad channels:\n"
for bad_type, bad_chs in bads.items():
bad_type = bad_type.replace("bad_by_", "")
if bad_type in name_map.keys():
bad_type = name_map[bad_type]
out += f"\n{len(bad_chs)} by {bad_type}: {bad_chs}\n"
print(out)

return list(all_bads)
if as_dict:
bads["bad_all"] = list(all_bads)
else:
bads = list(all_bads)

return bads

def find_all_bads(self, ransac=True, channel_wise=False, max_chunk_size=None):
"""Call all the functions to detect bad channels.
Expand Down
99 changes: 31 additions & 68 deletions pyprep/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,17 +134,7 @@ def perform_reference(self):
# Record Noisy channels and EEG before interpolation
self.bad_before_interpolation = noisy_detector.get_bads(verbose=True)
self.EEG_before_interpolation = self.EEG.copy()
self.noisy_channels_before_interpolation = {
"bad_by_nan": noisy_detector.bad_by_nan,
"bad_by_flat": noisy_detector.bad_by_flat,
"bad_by_deviation": noisy_detector.bad_by_deviation,
"bad_by_hf_noise": noisy_detector.bad_by_hf_noise,
"bad_by_correlation": noisy_detector.bad_by_correlation,
"bad_by_SNR": noisy_detector.bad_by_SNR,
"bad_by_dropout": noisy_detector.bad_by_dropout,
"bad_by_ransac": noisy_detector.bad_by_ransac,
"bad_all": noisy_detector.get_bads(),
}
self.noisy_channels_before_interpolation = noisy_detector.get_bads(as_dict=True)
sappelhoff marked this conversation as resolved.
Show resolved Hide resolved
self._extra_info["interpolated"] = noisy_detector._extra_info

bad_channels = _union(self.bad_before_interpolation, self.unusable_channels)
Expand All @@ -170,17 +160,7 @@ def perform_reference(self):
noisy_detector.find_all_bads(**self.ransac_settings)
self.still_noisy_channels = noisy_detector.get_bads()
self.raw.info["bads"] = self.still_noisy_channels
self.noisy_channels_after_interpolation = {
"bad_by_nan": noisy_detector.bad_by_nan,
"bad_by_flat": noisy_detector.bad_by_flat,
"bad_by_deviation": noisy_detector.bad_by_deviation,
"bad_by_hf_noise": noisy_detector.bad_by_hf_noise,
"bad_by_correlation": noisy_detector.bad_by_correlation,
"bad_by_SNR": noisy_detector.bad_by_SNR,
"bad_by_dropout": noisy_detector.bad_by_dropout,
"bad_by_ransac": noisy_detector.bad_by_ransac,
"bad_all": noisy_detector.get_bads(),
}
self.noisy_channels_after_interpolation = noisy_detector.get_bads(as_dict=True)
self._extra_info["remaining_bad"] = noisy_detector._extra_info

return self
Expand Down Expand Up @@ -213,17 +193,7 @@ def robust_reference(self):
matlab_strict=self.matlab_strict,
)
noisy_detector.find_all_bads(**self.ransac_settings)
self.noisy_channels_original = {
"bad_by_nan": noisy_detector.bad_by_nan,
"bad_by_flat": noisy_detector.bad_by_flat,
"bad_by_deviation": noisy_detector.bad_by_deviation,
"bad_by_hf_noise": noisy_detector.bad_by_hf_noise,
"bad_by_correlation": noisy_detector.bad_by_correlation,
"bad_by_SNR": noisy_detector.bad_by_SNR,
"bad_by_dropout": noisy_detector.bad_by_dropout,
"bad_by_ransac": noisy_detector.bad_by_ransac,
"bad_all": noisy_detector.get_bads(),
}
self.noisy_channels_original = noisy_detector.get_bads(as_dict=True)
self._extra_info["initial_bad"] = noisy_detector._extra_info
logger.info("Bad channels: {}".format(self.noisy_channels_original))

Expand All @@ -235,16 +205,16 @@ def robust_reference(self):
reference_channels = _set_diff(self.reference_channels, self.unusable_channels)

# Initialize channels to permanently flag as bad during referencing
self.noisy_channels = {
noisy = {
"bad_by_nan": noisy_detector.bad_by_nan,
"bad_by_flat": noisy_detector.bad_by_flat,
"bad_by_deviation": [],
"bad_by_hf_noise": [],
"bad_by_correlation": [],
"bad_by_SNR": noisy_detector.bad_by_SNR,
"bad_by_SNR": [],
"bad_by_dropout": [],
"bad_by_ransac": [],
"bad_all": self.unusable_channels,
"bad_all": [],
}

# Get initial estimate of the reference by the specified method
Expand All @@ -260,7 +230,7 @@ def robust_reference(self):
# Remove reference from signal, iteratively interpolating bad channels
raw_tmp = raw.copy()
iterations = 0
noisy_channels_old = []
previous_bads = set()
max_iteration_num = 4
sappelhoff marked this conversation as resolved.
Show resolved Hide resolved

while True:
Expand All @@ -272,51 +242,43 @@ def robust_reference(self):
matlab_strict=self.matlab_strict,
)
# Detrend applied at the beginning of the function.

# Detect all currently bad channels
noisy_detector.find_all_bads(**self.ransac_settings)
self.noisy_channels["bad_by_nan"] = _union(
self.noisy_channels["bad_by_nan"], noisy_detector.bad_by_nan
)
self.noisy_channels["bad_by_flat"] = _union(
self.noisy_channels["bad_by_flat"], noisy_detector.bad_by_flat
)
self.noisy_channels["bad_by_deviation"] = _union(
self.noisy_channels["bad_by_deviation"], noisy_detector.bad_by_deviation
)
self.noisy_channels["bad_by_hf_noise"] = _union(
self.noisy_channels["bad_by_hf_noise"], noisy_detector.bad_by_hf_noise
)
self.noisy_channels["bad_by_correlation"] = _union(
self.noisy_channels["bad_by_correlation"],
noisy_detector.bad_by_correlation,
)
self.noisy_channels["bad_by_ransac"] = _union(
self.noisy_channels["bad_by_ransac"], noisy_detector.bad_by_ransac
)
self.noisy_channels["bad_all"] = _union(
self.noisy_channels["bad_all"], noisy_detector.get_bads()
)
logger.info("Bad channels: {}".format(self.noisy_channels))
noisy_new = noisy_detector.get_bads(as_dict=True)

# Specify bad channel types to ignore when updating noisy channels
# NOTE: MATLAB PREP ignores dropout channels, possibly by mistake?
sappelhoff marked this conversation as resolved.
Show resolved Hide resolved
ignore = ["bad_by_SNR", "bad_all"]
if self.matlab_strict:
ignore += ["bad_by_dropout"]

# Update set of all noisy channels detected so far with any new ones
bad_chans = set()
for bad_type in noisy_new.keys():
noisy[bad_type] = _union(noisy[bad_type], noisy_new[bad_type])
if bad_type not in ignore:
bad_chans.update(noisy[bad_type])
noisy["bad_all"] = list(bad_chans)
logger.info("Bad channels: {}".format(noisy))

if (
iterations > 1
and (
not self.noisy_channels["bad_all"]
or set(self.noisy_channels["bad_all"]) == set(noisy_channels_old)
)
and (len(bad_chans) == 0 or bad_chans == previous_bads)
or iterations > max_iteration_num
):
break
noisy_channels_old = self.noisy_channels["bad_all"].copy()
previous_bads = bad_chans.copy()

if raw_tmp.info["nchan"] - len(self.noisy_channels["bad_all"]) < 2:
if raw_tmp.info["nchan"] - len(bad_chans) < 2:
raise ValueError(
"RobustReference:TooManyBad "
"Could not perform a robust reference -- not enough good channels"
)

if self.noisy_channels["bad_all"]:
if len(bad_chans) > 0:
raw_tmp._data = signal * 1e-6
raw_tmp.info["bads"] = self.noisy_channels["bad_all"]
raw_tmp.info["bads"] = list(bad_chans)
raw_tmp.interpolate_bads()
signal_tmp = raw_tmp.get_data() * 1e6
else:
Expand All @@ -332,6 +294,7 @@ def robust_reference(self):
logger.info("Iterations: {}".format(iterations))

logger.info("Robust reference done")
self.noisy_channels = noisy
return self.noisy_channels, self.reference_signal

@staticmethod
Expand Down