/
test_find_noisy_channels.py
224 lines (199 loc) · 8.86 KB
/
test_find_noisy_channels.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
"""Test the find_noisy_channels module."""
import numpy as np
import pytest
from pyprep.find_noisy_channels import NoisyChannels
from pyprep.removeTrend import removeTrend
@pytest.mark.usefixtures("raw", "montage")
def test_findnoisychannels(raw, montage):
    """Test the individual bad-channel detectors of NoisyChannels.

    The recording is first cleaned by repeatedly interpolating whatever
    ``find_all_bads`` reports and finally dropping any channels still flagged.
    Each detector is then exercised on a fresh copy of that cleaned recording
    into which a single, known artifact has been injected.

    Parameters
    ----------
    raw : fixture
        Recording under test. It is modified in place (montage set, bad
        channels interpolated and dropped).
    montage : fixture
        Montage applied to ``raw`` before any detection is run.

    """
    # One shared random state drives every random choice below, so the order
    # of the draws must stay fixed for the test to remain reproducible.
    rng = np.random.RandomState(30)
    raw.set_montage(montage)
    nd = NoisyChannels(raw, random_state=rng)
    nd.find_all_bads(ransac=True)
    bads = nd.get_bads()

    # Remove noisy channels by interpolating the bads for up to 10 iterations.
    iterations = 10
    for _ in range(iterations):
        if len(bads) == 0:
            # Nothing left to interpolate; later passes would be no-ops.
            break
        raw.info["bads"] = bads
        raw.interpolate_bads()
        nd = NoisyChannels(raw, random_state=rng)
        nd.find_all_bads(ransac=True)
        bads = nd.get_bads()

    # Make sure no bad channels exist in the data
    raw.drop_channels(ch_names=bads)

    # Test for NaN and flat channels
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    # Insert a NaN value into one random channel and make another random
    # channel completely flat (all ones)
    idxs = rng.choice(np.arange(m), size=2, replace=False)
    rand_chn_idx1 = idxs[0]
    rand_chn_idx2 = idxs[1]
    rand_chn_lab1 = raw_tmp.ch_names[rand_chn_idx1]
    rand_chn_lab2 = raw_tmp.ch_names[rand_chn_idx2]
    raw_tmp._data[rand_chn_idx1, n - 1] = np.nan
    raw_tmp._data[rand_chn_idx2, :] = np.ones(n)
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_nan_flat()
    assert nd.bad_by_nan == [rand_chn_lab1]
    assert nd.bad_by_flat == [rand_chn_lab2]

    # Test for high and low deviations in EEG data
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    # Now insert one random channel with very low deviations.
    # The size-1 draw is indexed before int() because converting a
    # 1-element array straight to a scalar is deprecated in recent NumPy.
    rand_chn_idx = int(rng.randint(0, m, 1)[0])
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    raw_tmp._data[rand_chn_idx, :] = raw_tmp._data[rand_chn_idx, :] / 10
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_deviation()
    assert rand_chn_lab in nd.bad_by_deviation

    # Inserting one random channel with a high deviation
    raw_tmp = raw.copy()
    rand_chn_idx = int(rng.randint(0, m, 1)[0])
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    arbitrary_scaling = 5
    raw_tmp._data[rand_chn_idx, :] *= arbitrary_scaling
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_deviation()
    assert rand_chn_lab in nd.bad_by_deviation

    # Test for correlation between EEG channels
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    rand_chn_idx = int(rng.randint(0, m, 1)[0])
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    # Use cosine instead of sine to create a signal
    low = 10
    high = 30
    n_freq = 5
    signal = np.zeros((1, n))
    for _ in range(n_freq):
        # One random integer frequency in [low, high) is drawn per sample
        freq = rng.randint(low, high, n)
        signal[0, :] += np.cos(2 * np.pi * raw.times * freq)
    raw_tmp._data[rand_chn_idx, :] = signal * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_correlation()
    assert rand_chn_lab in nd.bad_by_correlation
    bad_by_correlation_orig = nd.bad_by_correlation  # save for dropout tests

    # Test for channels with signal dropouts (reuse data from correlation tests)
    # Pick a neighbour of the channel modified above so the two differ
    dropout_idx = rand_chn_idx - 1 if rand_chn_idx > 0 else 1
    # Make the first and last quarters of the dropout channel completely flat
    raw_tmp._data[dropout_idx, : n // 4] = 0
    raw_tmp._data[dropout_idx, 3 * n // 4 :] = 0
    # Run correlation and dropout detection on data
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_correlation()  # also does dropout detection
    # Test if dropout channel detected correctly
    assert raw_tmp.ch_names[dropout_idx] in nd.bad_by_dropout
    # Test if correlations still detected correctly: the result must equal the
    # earlier set, either alone or together with the dropout channels
    bad_orig_plus_dropout = bad_by_correlation_orig + nd.bad_by_dropout
    same_bads = set(nd.bad_by_correlation) == set(bad_by_correlation_orig)
    same_plus_dropout = set(nd.bad_by_correlation) == set(bad_orig_plus_dropout)
    assert same_bads or same_plus_dropout

    # Test for high freq noise detection
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    rand_chn_idx = int(rng.randint(0, m, 1)[0])
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    # Use freqs between 90 and 100 Hz to insert hf noise
    signal = np.zeros((1, n))
    for _ in range(n_freq):
        freq = rng.randint(90, 100, n)
        signal[0, :] += np.sin(2 * np.pi * raw.times * freq)
    raw_tmp._data[rand_chn_idx, :] = signal * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_hfnoise()
    assert rand_chn_lab in nd.bad_by_hf_noise

    # Test for high freq noise detection when sample rate < 100 Hz
    raw_tmp.resample(80)  # downsample to 80 Hz
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_hfnoise()
    assert len(nd.bad_by_hf_noise) == 0

    # Test for signal to noise ratio in EEG data
    raw_tmp = raw.copy()
    m, n = raw_tmp._data.shape
    rand_chn_idx = int(rng.randint(0, m, 1)[0])
    rand_chn_lab = raw_tmp.ch_names[rand_chn_idx]
    # Inserting an uncorrelated high frequency (90 Hz) signal in one channel
    raw_tmp[rand_chn_idx, :] = np.sin(2 * np.pi * raw.times * 90) * 1e-6
    nd = NoisyChannels(raw_tmp, random_state=rng)
    nd.find_bad_by_SNR()
    assert rand_chn_lab in nd.bad_by_SNR
@pytest.mark.usefixtures("raw", "montage")
def test_find_bad_by_ransac(raw, montage):
    """Exercise RANSAC bad-channel detection and its error handling.

    Covers the window-wise and channel-wise variants (with and without
    MATLAB-strict mode and chunking), invalid ``n_samples`` values, and the
    cases where too few usable channels remain for RANSAC to run.

    Parameters
    ----------
    raw : fixture
        Recording under test; the montage is set on it in place and copies
        are used for all data manipulation.
    montage : fixture
        Montage applied to ``raw`` before RANSAC is run.

    """
    # A fixed integer seed is passed to every NoisyChannels instance
    seed = 435656
    raw.set_montage(montage)

    # Overwrite the first six channels (indices 0-5) with one identical 30 Hz
    # cosine so that they are perfectly correlated with each other.
    contaminated = raw.copy()
    contaminated._data[0:6, :] = np.cos(2 * np.pi * raw.times * 30) * 1e-6
    # Detrend once up front so each NoisyChannels below can skip that step
    contaminated._data = removeTrend(contaminated.get_data(), raw.info["sfreq"])

    # Each entry is (name, matlab_strict, channel_wise, max_chunk_size)
    variants = (
        ("by_window", False, False, None),
        ("by_channel", False, True, None),
        ("by_channel_maxchunk", False, True, 2),
        ("by_window_strict", True, False, None),
        ("by_channel_strict", True, True, None),
    )
    detected = {}
    correlations = {}
    for label, strict, by_channel, chunk_size in variants:
        detector = NoisyChannels(
            contaminated,
            do_detrend=False,
            random_state=seed,
            matlab_strict=strict,
        )
        detector.find_bad_by_ransac(
            channel_wise=by_channel, max_chunk_size=chunk_size
        )
        # Keep the flagged channels and the correlation matrix for comparison
        detected[label] = detector.bad_by_ransac
        ransac_info = detector._extra_info["bad_by_ransac"]
        correlations[label] = ransac_info["ransac_correlations"]

    # Every variant must flag exactly the six contaminated channels
    expected_bads = contaminated.ch_names[0:6]
    for label, _, _, _ in variants:
        assert detected[label] == expected_bads

    # Non-strict correlation matrices must agree with one another
    for other in ("by_channel", "by_channel_maxchunk"):
        assert np.allclose(correlations["by_window"], correlations[other])
    # MATLAB-strict correlation matrices must agree with one another
    assert np.allclose(
        correlations["by_window_strict"], correlations["by_channel_strict"]
    )
    # Strict and non-strict results are expected to differ
    assert not np.allclose(
        correlations["by_window"], correlations["by_window_strict"]
    )

    # An absurdly large n_samples should raise MemoryError, and a
    # non-integer n_samples should raise TypeError
    invalid_sample_counts = (
        (int(1e100), MemoryError),
        (35.5, TypeError),
    )
    for n_samples, expected_error in invalid_sample_counts:
        detector = NoisyChannels(
            contaminated, do_detrend=False, random_state=seed
        )
        with pytest.raises(expected_error):
            detector.find_bad_by_ransac(n_samples=n_samples)

    # IOError when too few good channels remain for the RANSAC sample size:
    # mark the first 80% of channels as bad by deviation
    mostly_bad = raw.copy()
    detector = NoisyChannels(mostly_bad, random_state=seed)
    detector.find_all_bads(ransac=False)
    n_marked_bad = int(raw._data.shape[0] * 0.8)
    detector.bad_by_deviation = raw.info["ch_names"][0:n_marked_bad]
    with pytest.raises(IOError):
        detector.find_bad_by_ransac()

    # IOError when not enough channels remain for RANSAC predictions:
    # flatten every channel except the last two
    mostly_flat = raw.copy()
    n_flattened = raw._data.shape[0] - 2
    flat_block = np.zeros_like(mostly_flat._data[0:n_flattened, :])
    mostly_flat._data[0:n_flattened, :] = flat_block
    detector = NoisyChannels(mostly_flat, random_state=seed)
    detector.find_all_bads(ransac=False)
    with pytest.raises(IOError):
        detector.find_bad_by_ransac()