-
Notifications
You must be signed in to change notification settings - Fork 110
/
detection.py
2886 lines (2496 loc) · 113 KB
/
detection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
YASA (Yet Another Spindle Algorithm): fast and robust detection of spindles,
slow-waves, and rapid eye movements from sleep EEG recordings.
- Author: Raphael Vallat (www.raphaelvallat.com)
- GitHub: https://github.com/raphaelvallat/yasa
- License: BSD 3-Clause License
"""
import mne
import logging
import numpy as np
import pandas as pd
from scipy import signal
from mne.filter import filter_data
from collections import OrderedDict
from scipy.interpolate import interp1d
from scipy.fftpack import next_fast_len
from sklearn.ensemble import IsolationForest
from .spectral import stft_power
from .numba import _detrend, _rms
from .io import set_log_level, is_tensorpac_installed, is_pyriemann_installed
from .others import (
moving_transform,
trimbothstd,
get_centered_indices,
sliding_window,
_merge_close,
_zerocrossings,
)
logger = logging.getLogger("yasa")
__all__ = [
"art_detect",
"spindles_detect",
"SpindlesResults",
"sw_detect",
"SWResults",
"rem_detect",
"REMResults",
]
#############################################################################
# DATA PREPROCESSING
#############################################################################
def _check_data_hypno(data, sf=None, ch_names=None, hypno=None, include=None, check_amp=True):
    """Validate and standardize the data and hypnogram before a detection.

    Returns the tuple ``(data, sf, ch_names, hypno, include, mask, n_chan,
    n_samples, bad_chan)`` where ``data`` is a float64 array of shape
    (n_chan, n_samples), ``mask`` flags the samples whose sleep stage is in
    ``include``, and ``bad_chan`` flags channels whose trimmed standard
    deviation suggests the data is not in uV.
    """
    # 1) Coerce input to a 2D NumPy array of shape (n_chan, n_samples)
    if isinstance(data, mne.io.BaseRaw):
        # MNE Raw object: read sf / channel names from it and convert to uV
        sf = data.info["sfreq"]
        ch_names = data.ch_names
        data = data.get_data(units=dict(eeg="uV", emg="uV", eog="uV", ecg="uV"))
    else:
        assert sf is not None, "sf must be specified if not using MNE Raw."
        if isinstance(sf, np.ndarray):
            # Deal with sf passed as a 0d array, e.g. array(100.) --> 100.0
            sf = float(sf)
        assert isinstance(sf, (int, float)), "sf must be int or float."
    data = np.asarray(data, dtype=np.float64)
    assert data.ndim in [1, 2], "data must be 1D (times) or 2D (chan, times)."
    if data.ndim == 1:
        # Promote a single-channel vector to a (1, n_samples) array
        data = data[None, ...]
    n_chan, n_samples = data.shape
    # 2) Generate default channel names when none were supplied
    if ch_names is None:
        ch_names = ["CHAN" + str(i).zfill(3) for i in range(n_chan)]
    else:
        assert len(ch_names) == n_chan
    # 3) Validate the hypnogram against the data and `include`
    if hypno is not None:
        hypno = np.asarray(hypno, dtype=int)
        assert hypno.ndim == 1, "Hypno must be one dimensional."
        assert hypno.size == n_samples, "Hypno must have same size as data."
        unique_hypno = np.unique(hypno)
        logger.info("Number of unique values in hypno = %i", unique_hypno.size)
        assert include is not None, "include cannot be None if hypno is given"
        include = np.atleast_1d(np.asarray(include))
        assert include.size >= 1, "`include` must have at least one element."
        assert hypno.dtype.kind == include.dtype.kind, "hypno and include must have same dtype"
        assert np.in1d(hypno, include).any(), (
            "None of the stages specified " "in `include` are present in " "hypno."
        )
    # 4) Log basic info and screen each channel's amplitude
    logger.info("Number of samples in data = %i", n_samples)
    logger.info("Sampling frequency = %.2f Hz", sf)
    logger.info("Data duration = %.2f seconds", n_samples / sf)
    all_ptp = np.ptp(data, axis=-1)
    all_trimstd = trimbothstd(data, cut=0.05)
    bad_chan = np.zeros(n_chan, dtype=bool)
    for i, name in enumerate(ch_names):
        logger.info("Trimmed standard deviation of %s = %.4f uV" % (name, all_trimstd[i]))
        logger.info("Peak-to-peak amplitude of %s = %.4f uV" % (name, all_ptp[i]))
        # A trimmed STD outside [0.1, 1000] uV strongly suggests the wrong unit
        if check_amp and not (0.1 < all_trimstd[i] < 1e3):
            logger.error(
                "Wrong data amplitude for %s "
                "(trimmed STD = %.3f). Unit of data MUST be uV! "
                "Channel will be skipped." % (name, all_trimstd[i])
            )
            bad_chan[i] = True
    # 5) Sample-level boolean mask of the stages to keep (all True if no hypno)
    mask = np.in1d(hypno, include) if hypno is not None else np.ones(n_samples, dtype=bool)
    return (data, sf, ch_names, hypno, include, mask, n_chan, n_samples, bad_chan)
#############################################################################
# BASE DETECTION RESULTS CLASS
#############################################################################
class _DetectionResults(object):
    """Main class for detection results.

    Shared base class for the spindles / slow-waves / REM results objects.
    It stores the detection dataframe (``_events``) together with the raw
    and filtered data so that summaries, masks, event-locked averages and
    interactive plots can be derived from a single place.
    """

    def __init__(self, events, data, sf, ch_names, hypno, data_filt):
        # events: pandas.DataFrame with one row per detected event.
        # data: original data, shape (n_chan, n_samples).
        # sf: sampling frequency in Hz.
        # hypno: sample-level hypnogram vector, or None.
        # ch_names: list of channel names.
        # data_filt: band-pass filtered copy of ``data``.
        self._events = events
        self._data = data
        self._sf = sf
        self._hypno = hypno
        self._ch_names = ch_names
        self._data_filt = data_filt

    def _check_mask(self, mask):
        """Validate an event-level boolean mask; ``None`` keeps all events."""
        assert isinstance(mask, (pd.Series, np.ndarray, list, type(None)))
        n_events = self._events.shape[0]
        if mask is None:
            mask = np.ones(n_events, dtype="bool")  # All set to True
        else:
            mask = np.asarray(mask)
            assert mask.dtype.kind == "b", "Mask must be a boolean array."
            assert mask.ndim == 1, "Mask must be one-dimensional"
            assert mask.size == n_events, "Mask.size must be the number of detected events."
        return mask

    def summary(
        self, event_type, grp_chan=False, grp_stage=False, aggfunc="mean", sort=True, mask=None
    ):
        """Return the masked events, optionally aggregated by channel/stage.

        ``event_type`` ("spindles", "sw", or anything else for REM) selects
        which columns are aggregated. When grouping by stage and a hypnogram
        is available, a per-minute "Density" column is also inserted.
        """
        # Check masking
        mask = self._check_mask(mask)
        # Define grouping
        grouper = []
        if grp_stage is True and "Stage" in self._events:
            grouper.append("Stage")
        if grp_chan is True and "Channel" in self._events:
            grouper.append("Channel")
        if not len(grouper):
            # Return a copy of self._events after masking, without grouping
            return self._events.loc[mask, :].copy()

        if event_type == "spindles":
            aggdict = {
                "Start": "count",
                "Duration": aggfunc,
                "Amplitude": aggfunc,
                "RMS": aggfunc,
                "AbsPower": aggfunc,
                "RelPower": aggfunc,
                "Frequency": aggfunc,
                "Oscillations": aggfunc,
                "Symmetry": aggfunc,
            }
            # if 'SOPhase' in self._events:
            #     from scipy.stats import circmean
            #     aggdict['SOPhase'] = lambda x: circmean(x, low=-np.pi, high=np.pi)
        elif event_type == "sw":
            aggdict = {
                "Start": "count",
                "Duration": aggfunc,
                "ValNegPeak": aggfunc,
                "ValPosPeak": aggfunc,
                "PTP": aggfunc,
                "Slope": aggfunc,
                "Frequency": aggfunc,
            }
            if "PhaseAtSigmaPeak" in self._events:
                # Phases are circular quantities: a plain mean would be wrong
                from scipy.stats import circmean

                aggdict["PhaseAtSigmaPeak"] = lambda x: circmean(x, low=-np.pi, high=np.pi)
                aggdict["ndPAC"] = aggfunc
            if "CooccurringSpindle" in self._events:
                # We do not average "CooccurringSpindlePeak"
                aggdict["CooccurringSpindle"] = aggfunc
                aggdict["DistanceSpindleToSW"] = aggfunc
        else:  # REM
            aggdict = {
                "Start": "count",
                "Duration": aggfunc,
                "LOCAbsValPeak": aggfunc,
                "ROCAbsValPeak": aggfunc,
                "LOCAbsRiseSlope": aggfunc,
                "ROCAbsRiseSlope": aggfunc,
                "LOCAbsFallSlope": aggfunc,
                "ROCAbsFallSlope": aggfunc,
            }

        # Apply grouping, after masking
        df_grp = self._events.loc[mask, :].groupby(grouper, sort=sort, as_index=False).agg(aggdict)
        df_grp = df_grp.rename(columns={"Start": "Count"})

        # Calculate density (= number per min of each stage)
        if self._hypno is not None and grp_stage is True:
            stages = np.unique(self._events["Stage"])
            dur = {}
            for st in stages:
                # Get duration in minutes of each stage present in dataframe
                dur[st] = self._hypno[self._hypno == st].size / (60 * self._sf)
            # Insert new density column in grouped dataframe after count
            df_grp.insert(
                loc=df_grp.columns.get_loc("Count") + 1,
                column="Density",
                value=df_grp.apply(lambda rw: rw["Count"] / dur[rw["Stage"]], axis=1),
            )

        return df_grp.set_index(grouper)

    def get_mask(self):
        """Return a (n_chan, n_samples) 0/1 mask of the detected events.

        Samples falling within any detected event are set to 1, channel by
        channel. The output is squeezed, so a single-channel detection
        returns a 1D vector.
        """
        from yasa.others import _index_to_events

        mask = np.zeros(self._data.shape, dtype=int)
        for i in self._events["IdxChannel"].unique():
            ev_chan = self._events[self._events["IdxChannel"] == i]
            # Convert start/end times (sec) to sample indices spanning each event
            idx_ev = _index_to_events(ev_chan[["Start", "End"]].to_numpy() * self._sf)
            mask[i, idx_ev] = 1
        return np.squeeze(mask)

    def get_sync_events(
        self, center, time_before, time_after, filt=(None, None), mask=None, as_dataframe=True
    ):
        """Return event-locked data epochs (not for REM, spindles & SW only).

        ``center`` is the name of the event column used as time zero (e.g.
        "Peak"); ``time_before``/``time_after`` are in seconds. If ``filt``
        contains at least one cutoff, the data is FIR band-pass filtered
        first. Returns a long-format dataframe, or a list of (n_events,
        n_times) arrays per channel when ``as_dataframe=False``.
        """
        from yasa.others import get_centered_indices

        assert time_before >= 0
        assert time_after >= 0
        bef = int(self._sf * time_before)
        aft = int(self._sf * time_after)
        # TODO: Step size is determined by sf: 0.01 sec at 100 Hz, 0.002 sec at
        # 500 Hz, 0.00390625 sec at 256 Hz. Should we add resample=100 (Hz) or step_size=0.01?
        time = np.arange(-bef, aft + 1, dtype="int") / self._sf

        if any(filt):
            data = mne.filter.filter_data(
                self._data, self._sf, l_freq=filt[0], h_freq=filt[1], method="fir", verbose=False
            )
        else:
            data = self._data

        # Apply mask
        mask = self._check_mask(mask)
        masked_events = self._events.loc[mask, :]

        output = []
        for i in masked_events["IdxChannel"].unique():
            # Copy is required to merge with the stage later on
            ev_chan = masked_events[masked_events["IdxChannel"] == i].copy()
            ev_chan["Event"] = np.arange(ev_chan.shape[0])
            peaks = (ev_chan[center] * self._sf).astype(int).to_numpy()
            # Get centered indices
            idx, idx_valid = get_centered_indices(data[i, :], peaks, bef, aft)
            # If no good epochs are returned raise a warning
            if len(idx_valid) == 0:
                logger.error(
                    "Time before and/or time after exceed data bounds, please "
                    "lower the temporal window around center. Skipping channel."
                )
                continue

            # Get data at indices and time vector
            amps = data[i, idx]

            if not as_dataframe:
                # Output is a list (n_channels) of numpy arrays (n_events, n_times)
                output.append(amps)
                continue

            # Convert to long-format dataframe
            df_chan = pd.DataFrame(amps.T)
            df_chan["Time"] = time
            # Convert to long-format
            df_chan = df_chan.melt(id_vars="Time", var_name="Event", value_name="Amplitude")
            # Append stage
            if "Stage" in masked_events:
                df_chan = df_chan.merge(ev_chan[["Event", "Stage"]].iloc[idx_valid])
            # Append channel name
            df_chan["Channel"] = ev_chan["Channel"].iloc[0]
            df_chan["IdxChannel"] = i
            # Append to master dataframe
            output.append(df_chan)

        if as_dataframe:
            output = pd.concat(output, ignore_index=True)

        return output

    def get_coincidence_matrix(self, scaled=True):
        """Return a channel-by-channel event coincidence matrix.

        Coincidence between two channels is the sum of the elementwise
        product of their event masks; when ``scaled`` it is divided by the
        product of the per-channel mask sums (NaN if that product is zero).
        Requires at least 2 channels.
        """
        if len(self._ch_names) < 2:
            raise ValueError("At least 2 channels are required to calculate coincidence.")
        mask = self.get_mask()
        mask = pd.DataFrame(mask.T, columns=self._ch_names)
        mask.columns.name = "Channel"

        def _coincidence(x, y):
            """Calculate the (scaled) coincidence."""
            coincidence = (x * y).sum()
            if scaled:
                # Handle division by zero error
                denom = x.sum() * y.sum()
                if denom == 0:
                    coincidence = np.nan
                else:
                    coincidence /= denom
            return coincidence

        coinc_mat = mask.corr(method=_coincidence)

        if not scaled:
            # Otherwise diagonal values are set to 1
            np.fill_diagonal(coinc_mat.values, mask.sum())
            coinc_mat = coinc_mat.astype(int)

        return coinc_mat

    def plot_average(
        self,
        event_type,
        center="Peak",
        hue="Channel",
        time_before=1,
        time_after=1,
        filt=(None, None),
        mask=None,
        figsize=(6, 4.5),
        **kwargs,
    ):
        """Plot the average event (not for REM, spindles & SW only)"""
        import seaborn as sns
        import matplotlib.pyplot as plt

        df_sync = self.get_sync_events(
            center=center, time_before=time_before, time_after=time_after, filt=filt, mask=mask
        )
        assert not df_sync.empty, "Could not calculate event-locked data."
        assert hue in ["Stage", "Channel"], "hue must be 'Channel' or 'Stage'"
        assert hue in df_sync.columns, "%s is not present in data." % hue

        if event_type == "spindles":
            title = "Average spindle"
        else:  # "sw":
            title = "Average SW"

        # Start figure
        fig, ax = plt.subplots(1, 1, figsize=figsize)
        sns.lineplot(data=df_sync, x="Time", y="Amplitude", hue=hue, ax=ax, **kwargs)
        # ax.legend(frameon=False, loc='lower right')
        ax.set_xlim(df_sync["Time"].min(), df_sync["Time"].max())
        ax.set_title(title)
        ax.set_xlabel("Time (sec)")
        ax.set_ylabel("Amplitude (uV)")
        return ax

    def plot_detection(self):
        """Plot an overlay of the detected events on the signal.

        Builds an interactive matplotlib figure driven by ipywidgets
        sliders/dropdowns (epoch, amplitude, channel, window size, and a
        raw-vs-filtered toggle). Intended for use inside Jupyter.
        """
        import matplotlib.pyplot as plt
        import ipywidgets as ipy

        # Define mask
        sf = self._sf
        win_size = 10
        mask = self.get_mask()
        # Keep only in-event samples; everything else becomes NaN so that
        # matplotlib leaves gaps instead of drawing zero-lines.
        highlight = self._data * mask
        highlight = np.where(highlight == 0, np.nan, highlight)
        highlight_filt = self._data_filt * mask
        highlight_filt = np.where(highlight_filt == 0, np.nan, highlight_filt)

        n_epochs = int((self._data.shape[-1] / sf) / win_size)
        times = np.arange(self._data.shape[-1]) / sf

        # Define xlim and xrange
        xlim = [0, win_size]
        xrng = np.arange(xlim[0] * sf, (xlim[1] * sf + 1), dtype=int)

        # Plot
        fig, ax = plt.subplots(figsize=(12, 4))
        plt.plot(times[xrng], self._data[0, xrng], "k", lw=1)
        plt.plot(times[xrng], highlight[0, xrng], "indianred")
        plt.xlabel("Time (seconds)")
        plt.ylabel("Amplitude (uV)")
        fig.canvas.header_visible = False
        fig.tight_layout()

        # WIDGETS
        layout = ipy.Layout(width="50%", justify_content="center", align_items="center")

        sl_ep = ipy.IntSlider(
            min=0,
            max=n_epochs,
            step=1,
            value=0,
            layout=layout,
            description="Epoch:",
        )

        sl_amp = ipy.IntSlider(
            min=25,
            max=500,
            step=25,
            value=150,
            layout=layout,
            orientation="horizontal",
            description="Amplitude:",
        )

        dd_ch = ipy.Dropdown(
            options=self._ch_names, value=self._ch_names[0], description="Channel:"
        )

        dd_win = ipy.Dropdown(
            options=[1, 5, 10, 30, 60],
            value=win_size,
            description="Window size:",
        )

        dd_check = ipy.Checkbox(
            value=False,
            description="Filtered",
        )

        def update(epoch, amplitude, channel, win_size, filt):
            """Update plot."""
            n_epochs = int((self._data.shape[-1] / sf) / win_size)
            sl_ep.max = n_epochs
            xlim = [epoch * win_size, (epoch + 1) * win_size]
            xrng = np.arange(xlim[0] * sf, (xlim[1] * sf), dtype=int)
            # Check if filtered
            data = self._data if not filt else self._data_filt
            overlay = highlight if not filt else highlight_filt
            try:
                ax.lines[0].set_data(times[xrng], data[dd_ch.index, xrng])
                ax.lines[1].set_data(times[xrng], overlay[dd_ch.index, xrng])
                ax.set_xlim(xlim)
            except IndexError:
                # Last (partial) epoch may index past the data; keep previous view
                pass
            ax.set_ylim([-amplitude, amplitude])

        return ipy.interact(
            update, epoch=sl_ep, amplitude=sl_amp, channel=dd_ch, win_size=dd_win, filt=dd_check
        )
#############################################################################
# SPINDLES DETECTION
#############################################################################
def spindles_detect(
    data,
    sf=None,
    ch_names=None,
    hypno=None,
    include=(1, 2, 3),
    freq_sp=(12, 15),
    freq_broad=(1, 30),
    duration=(0.5, 2),
    min_distance=500,
    thresh={"rel_pow": 0.2, "corr": 0.65, "rms": 1.5},
    multi_only=False,
    remove_outliers=False,
    verbose=False,
):
    """Spindles detection.

    Parameters
    ----------
    data : array_like
        Single or multi-channel data. Unit must be uV and shape (n_samples) or
        (n_chan, n_samples). Can also be a :py:class:`mne.io.BaseRaw`,
        in which case ``data``, ``sf``, and ``ch_names`` will be automatically
        extracted, and ``data`` will also be automatically converted from
        Volts (MNE) to micro-Volts (YASA).
    sf : float
        Sampling frequency of the data in Hz.
        Can be omitted if ``data`` is a :py:class:`mne.io.BaseRaw`.

        .. tip:: If the detection is taking too long, make sure to downsample
            your data to 100 Hz (or 128 Hz). For more details, please refer to
            :py:func:`mne.filter.resample`.
    ch_names : list of str
        Channel names. Can be omitted if ``data`` is a
        :py:class:`mne.io.BaseRaw`.
    hypno : array_like
        Sleep stage (hypnogram). If the hypnogram is loaded, the
        detection will only be applied to the value defined in
        ``include`` (default = N1 + N2 + N3 sleep).

        The hypnogram must have the same number of samples as ``data``.
        To upsample your hypnogram, please refer to
        :py:func:`yasa.hypno_upsample_to_data`.

        .. note::
            The default hypnogram format in YASA is a 1D integer
            vector where:

            - -2 = Unscored
            - -1 = Artefact / Movement
            - 0 = Wake
            - 1 = N1 sleep
            - 2 = N2 sleep
            - 3 = N3 sleep
            - 4 = REM sleep
    include : tuple, list or int
        Values in ``hypno`` that will be included in the mask. The default is
        (1, 2, 3), meaning that the detection is applied on N1, N2 and N3
        sleep. This has no effect when ``hypno`` is None.
    freq_sp : tuple or list
        Spindles frequency range. Default is 12 to 15 Hz. Please note that YASA
        uses a FIR filter (implemented in MNE) with a 1.5Hz transition band,
        which means that for `freq_sp = (12, 15 Hz)`, the -6 dB points are
        located at 11.25 and 15.75 Hz.
    freq_broad : tuple or list
        Broad band frequency range. Default is 1 to 30 Hz.
    duration : tuple or list
        The minimum and maximum duration of the spindles.
        Default is 0.5 to 2 seconds.
    min_distance : int
        If two spindles are closer than ``min_distance`` (in ms), they are
        merged into a single spindles. Default is 500 ms.
    thresh : dict
        Detection thresholds:

        * ``'rel_pow'``: Relative power (= power ratio freq_sp / freq_broad).
        * ``'corr'``: Moving correlation between original signal and
          sigma-filtered signal.
        * ``'rms'``: Number of standard deviations above the mean of a moving
          root mean square of sigma-filtered signal.

        You can disable one or more threshold by putting ``None`` instead:

        .. code-block:: python

            thresh = {'rel_pow': None, 'corr': 0.65, 'rms': 1.5}
            thresh = {'rel_pow': None, 'corr': None, 'rms': 3}
    multi_only : boolean
        Define the behavior of the multi-channel detection. If True, only
        spindles that are present on at least two channels are kept. If False,
        no selection is applied and the output is just a concatenation of the
        single-channel detection dataframe. Default is False.
    remove_outliers : boolean
        If True, YASA will automatically detect and remove outliers spindles
        using :py:class:`sklearn.ensemble.IsolationForest`.
        The outliers detection is performed on all the spindles
        parameters with the exception of the ``Start``, ``Peak``, ``End``,
        ``Stage``, and ``SOPhase`` columns.
        YASA uses a random seed (42) to ensure reproducible results.
        Note that this step will only be applied if there are more than 50
        detected spindles in the first place. Default to False.
    verbose : bool or str
        Verbose level. Default (False) will only print warning and error
        messages. The logging levels are 'debug', 'info', 'warning', 'error',
        and 'critical'. For most users the choice is between 'info'
        (or ``verbose=True``) and warning (``verbose=False``).

        .. versionadded:: 0.2.0

    Returns
    -------
    sp : :py:class:`yasa.SpindlesResults`
        To get the full detection dataframe, use:

        >>> sp = spindles_detect(...)
        >>> sp.summary()

        This will give a :py:class:`pandas.DataFrame` where each row is a
        detected spindle and each column is a parameter (= feature or property)
        of this spindle. To get the average spindles parameters per channel and
        sleep stage:

        >>> sp.summary(grp_chan=True, grp_stage=True)

    Notes
    -----
    The parameters that are calculated for each spindle are:

    * ``'Start'``: Start time of the spindle, in seconds from the beginning of
      data.
    * ``'Peak'``: Time at the most prominent spindle peak (in seconds).
    * ``'End'`` : End time (in seconds).
    * ``'Duration'``: Duration (in seconds)
    * ``'Amplitude'``: Peak-to-peak amplitude of the (detrended) spindle in
      the raw data (in µV).
    * ``'RMS'``: Root-mean-square (in µV)
    * ``'AbsPower'``: Median absolute power (in log10 µV^2),
      calculated from the Hilbert-transform of the ``freq_sp`` filtered signal.
    * ``'RelPower'``: Median relative power of the ``freq_sp`` band in spindle
      calculated from a short-term fourier transform and expressed as a
      proportion of the total power in ``freq_broad``.
    * ``'Frequency'``: Median instantaneous frequency of spindle (in Hz),
      derived from an Hilbert transform of the ``freq_sp`` filtered signal.
    * ``'Oscillations'``: Number of oscillations (= number of positive peaks
      in spindle.)
    * ``'Symmetry'``: Location of the most prominent peak of spindle,
      normalized from 0 (start) to 1 (end). Ideally this value should be close
      to 0.5, indicating that the most prominent peak is halfway through the
      spindle.
    * ``'Stage'`` : Sleep stage during which spindle occured, if ``hypno``
      was provided.

    All parameters are calculated from the broadband-filtered EEG
    (frequency range defined in ``freq_broad``).

    For better results, apply this detection only on artefact-free NREM sleep.

    .. warning::
        A critical bug was fixed in YASA 0.6.1, in which the number of detected spindles could
        vary drastically depending on the sampling frequency of the data. Please make sure to check
        any results obtained with this function prior to the 0.6.1 release.

    References
    ----------
    The sleep spindles detection algorithm is based on:

    * Lacourse, K., Delfrate, J., Beaudry, J., Peppard, P., & Warby, S. C.
      (2018). `A sleep spindle detection algorithm that emulates human expert
      spindle scoring. <https://doi.org/10.1016/j.jneumeth.2018.08.014>`_
      Journal of Neuroscience Methods.

    Examples
    --------
    For a walkthrough of the spindles detection, please refer to the following
    Jupyter notebooks:

    https://github.com/raphaelvallat/yasa/blob/master/notebooks/01_spindles_detection.ipynb

    https://github.com/raphaelvallat/yasa/blob/master/notebooks/02_spindles_detection_multi.ipynb

    https://github.com/raphaelvallat/yasa/blob/master/notebooks/03_spindles_detection_NREM_only.ipynb

    https://github.com/raphaelvallat/yasa/blob/master/notebooks/04_spindles_slow_fast.ipynb
    """
    set_log_level(verbose)

    (data, sf, ch_names, hypno, include, mask, n_chan, n_samples, bad_chan) = _check_data_hypno(
        data, sf, ch_names, hypno, include
    )

    # If all channels are bad
    if sum(bad_chan) == n_chan:
        logger.warning("All channels have bad amplitude. Returning None.")
        return None

    # Check detection thresholds.
    # BUGFIX: work on a shallow copy so that missing keys are filled in locally
    # instead of mutating the caller's dict (or the shared default argument).
    thresh = dict(thresh)
    if "rel_pow" not in thresh.keys():
        thresh["rel_pow"] = 0.20
    if "corr" not in thresh.keys():
        thresh["corr"] = 0.65
    if "rms" not in thresh.keys():
        thresh["rms"] = 1.5
    do_rel_pow = thresh["rel_pow"] not in [None, "none", "None"]
    do_corr = thresh["corr"] not in [None, "none", "None"]
    do_rms = thresh["rms"] not in [None, "none", "None"]
    n_thresh = sum([do_rel_pow, do_corr, do_rms])
    assert n_thresh >= 1, "At least one threshold must be defined."

    # Filtering
    nfast = next_fast_len(n_samples)
    # 1) Broadband bandpass filter (optional -- careful of lower freq for PAC)
    data_broad = filter_data(data, sf, freq_broad[0], freq_broad[1], method="fir", verbose=0)
    # 2) Sigma bandpass filter
    # The width of the transition band is set to 1.5 Hz on each side,
    # meaning that for freq_sp = (12, 15 Hz), the -6 dB points are located at
    # 11.25 and 15.75 Hz.
    data_sigma = filter_data(
        data,
        sf,
        freq_sp[0],
        freq_sp[1],
        l_trans_bandwidth=1.5,
        h_trans_bandwidth=1.5,
        method="fir",
        verbose=0,
    )

    # Hilbert power (to define the instantaneous frequency / power)
    analytic = signal.hilbert(data_sigma, N=nfast)[:, :n_samples]
    inst_phase = np.angle(analytic)
    inst_pow = np.square(np.abs(analytic))
    inst_freq = sf / (2 * np.pi) * np.diff(inst_phase, axis=-1)

    # Extract the SO signal for coupling
    # if coupling:
    #     # We need to use the original (non-filtered data)
    #     data_so = filter_data(data, sf, freq_so[0], freq_so[1], method='fir',
    #                           l_trans_bandwidth=0.1, h_trans_bandwidth=0.1,
    #                           verbose=0)
    #     # Now extract the instantaneous phase using Hilbert transform
    #     so_phase = np.angle(signal.hilbert(data_so, N=nfast)[:, :n_samples])

    # Initialize empty output dataframe
    df = pd.DataFrame()

    for i in range(n_chan):
        # ####################################################################
        # START SINGLE CHANNEL DETECTION
        # ####################################################################

        # First, skip channels with bad data amplitude
        if bad_chan[i]:
            continue

        # Compute the pointwise relative power using interpolated STFT
        # Here we use a step of 200 ms to speed up the computation.
        # Note that even if the threshold is None we still need to calculate it
        # for the individual spindles parameter (RelPow).
        f, t, Sxx = stft_power(
            data_broad[i, :], sf, window=2, step=0.2, band=freq_broad, interp=False, norm=True
        )
        idx_sigma = np.logical_and(f >= freq_sp[0], f <= freq_sp[1])
        rel_pow = Sxx[idx_sigma].sum(0)

        # Let's interpolate `rel_pow` to get one value per sample
        # Note that we could also have use the `interp=True` in the
        # `stft_power` function, however 2D interpolation is much slower than
        # 1D interpolation.
        func = interp1d(t, rel_pow, kind="cubic", bounds_error=False, fill_value=0)
        t = np.arange(n_samples) / sf
        rel_pow = func(t)

        if do_corr:
            _, mcorr = moving_transform(
                x=data_sigma[i, :],
                y=data_broad[i, :],
                sf=sf,
                window=0.3,
                step=0.1,
                method="corr",
                interp=True,
            )
        if do_rms:
            _, mrms = moving_transform(
                x=data_sigma[i, :], sf=sf, window=0.3, step=0.1, method="rms", interp=True
            )
            # Let's define the thresholds
            if hypno is None:
                thresh_rms = mrms.mean() + thresh["rms"] * trimbothstd(mrms, cut=0.10)
            else:
                thresh_rms = mrms[mask].mean() + thresh["rms"] * trimbothstd(mrms[mask], cut=0.10)
            # Avoid too high threshold caused by Artefacts / Motion during Wake
            thresh_rms = min(thresh_rms, 10)
            logger.info("Moving RMS threshold = %.3f", thresh_rms)

        # Boolean vector of supra-threshold indices
        idx_sum = np.zeros(n_samples)
        if do_rel_pow:
            idx_rel_pow = (rel_pow >= thresh["rel_pow"]).astype(int)
            idx_sum += idx_rel_pow
            logger.info("N supra-theshold relative power = %i", idx_rel_pow.sum())
        if do_corr:
            idx_mcorr = (mcorr >= thresh["corr"]).astype(int)
            idx_sum += idx_mcorr
            logger.info("N supra-theshold moving corr = %i", idx_mcorr.sum())
        if do_rms:
            idx_mrms = (mrms >= thresh_rms).astype(int)
            idx_sum += idx_mrms
            logger.info("N supra-theshold moving RMS = %i", idx_mrms.sum())

        # Make sure that we do not detect spindles outside mask
        if hypno is not None:
            idx_sum[~mask] = 0

        # The detection using the three thresholds tends to underestimate the
        # real duration of the spindle. To overcome this, we compute a soft
        # threshold by smoothing the idx_sum vector with a ~100 ms window.
        # Sampling frequency = 100 Hz --> w = 10 samples
        # Sampling frequecy = 256 Hz --> w = 25 samples = 97 ms
        w = int(0.1 * sf)
        # Critical bugfix March 2022, see https://github.com/raphaelvallat/yasa/pull/55
        idx_sum = np.convolve(idx_sum, np.ones(w), mode="same") / w
        # And we then find indices that are strictly greater than 2, i.e. we
        # find the 'true' beginning and 'true' end of the events by finding
        # where at least two out of the three treshold were crossed.
        where_sp = np.where(idx_sum > (n_thresh - 1))[0]

        # If no events are found, skip to next channel
        if not len(where_sp):
            logger.warning("No spindle were found in channel %s.", ch_names[i])
            continue

        # Merge events that are too close
        if min_distance is not None and min_distance > 0:
            where_sp = _merge_close(where_sp, min_distance, sf)

        # Extract start, end, and duration of each spindle
        sp = np.split(where_sp, np.where(np.diff(where_sp) != 1)[0] + 1)
        idx_start_end = np.array([[k[0], k[-1]] for k in sp]) / sf
        sp_start, sp_end = idx_start_end.T
        sp_dur = sp_end - sp_start

        # Find events with bad duration
        good_dur = np.logical_and(sp_dur > duration[0], sp_dur < duration[1])

        # If no events of good duration are found, skip to next channel
        if all(~good_dur):
            logger.warning("No spindle were found in channel %s.", ch_names[i])
            continue

        # Initialize empty variables
        sp_amp = np.zeros(len(sp))
        sp_freq = np.zeros(len(sp))
        sp_rms = np.zeros(len(sp))
        sp_osc = np.zeros(len(sp))
        sp_sym = np.zeros(len(sp))
        sp_abs = np.zeros(len(sp))
        sp_rel = np.zeros(len(sp))
        sp_sta = np.zeros(len(sp))
        sp_pro = np.zeros(len(sp))
        # sp_cou = np.zeros(len(sp))

        # Number of oscillations (number of peaks separated by at least 60 ms)
        # --> 60 ms because 1000 ms / 16 Hz = 62.5 m, in other words, at 16 Hz,
        # peaks are separated by 62.5 ms. At 11 Hz peaks are separated by 90 ms
        distance = 60 * sf / 1000

        for j in np.arange(len(sp))[good_dur]:
            # Important: detrend the signal to avoid wrong PTP amplitude
            sp_x = np.arange(data_broad[i, sp[j]].size, dtype=np.float64)
            sp_det = _detrend(sp_x, data_broad[i, sp[j]])
            # sp_det = signal.detrend(data_broad[i, sp[i]], type='linear')
            sp_amp[j] = np.ptp(sp_det)  # Peak-to-peak amplitude
            sp_rms[j] = _rms(sp_det)  # Root mean square
            sp_rel[j] = np.median(rel_pow[sp[j]])  # Median relative power

            # Hilbert-based instantaneous properties
            sp_inst_freq = inst_freq[i, sp[j]]
            sp_inst_pow = inst_pow[i, sp[j]]
            sp_abs[j] = np.median(np.log10(sp_inst_pow[sp_inst_pow > 0]))
            sp_freq[j] = np.median(sp_inst_freq[sp_inst_freq > 0])

            # Number of oscillations
            peaks, peaks_params = signal.find_peaks(
                sp_det, distance=distance, prominence=(None, None)
            )
            sp_osc[j] = len(peaks)

            # For frequency and amplitude, we can also optionally use these
            # faster alternatives. If we use them, we do not need to compute
            # the Hilbert transform of the filtered signal.
            # sp_freq[j] = sf / np.mean(np.diff(peaks))
            # sp_amp[j] = peaks_params['prominences'].max()

            # Peak location & symmetry index
            # pk is expressed in sample since the beginning of the spindle
            pk = peaks[peaks_params["prominences"].argmax()]
            sp_pro[j] = sp_start[j] + pk / sf
            sp_sym[j] = pk / sp_det.size

            # SO-spindles coupling
            # if coupling:
            #     sp_cou[j] = so_phase[i, sp[j]][pk]

            # Sleep stage
            if hypno is not None:
                sp_sta[j] = hypno[sp[j]][0]

        # Create a dataframe
        sp_params = {
            "Start": sp_start,
            "Peak": sp_pro,
            "End": sp_end,
            "Duration": sp_dur,
            "Amplitude": sp_amp,
            "RMS": sp_rms,
            "AbsPower": sp_abs,
            "RelPower": sp_rel,
            "Frequency": sp_freq,
            "Oscillations": sp_osc,
            "Symmetry": sp_sym,
            # 'SOPhase': sp_cou,
            "Stage": sp_sta,
        }

        df_chan = pd.DataFrame(sp_params)[good_dur]

        # We need at least 50 detected spindles to apply the Isolation Forest.
        if remove_outliers and df_chan.shape[0] >= 50:
            col_keep = [
                "Duration",
                "Amplitude",
                "RMS",
                "AbsPower",
                "RelPower",
                "Frequency",
                "Oscillations",
                "Symmetry",
            ]
            ilf = IsolationForest(
                contamination="auto", max_samples="auto", verbose=0, random_state=42
            )
            good = ilf.fit_predict(df_chan[col_keep])
            good[good == -1] = 0
            logger.info(
                "%i outliers were removed in channel %s." % ((good == 0).sum(), ch_names[i])
            )
            # Remove outliers from DataFrame
            df_chan = df_chan[good.astype(bool)]

        logger.info("%i spindles were found in channel %s." % (df_chan.shape[0], ch_names[i]))

        # ####################################################################
        # END SINGLE CHANNEL DETECTION
        # ####################################################################

        df_chan["Channel"] = ch_names[i]
        df_chan["IdxChannel"] = i
        df = pd.concat([df, df_chan], axis=0, ignore_index=True)

    # If no spindles were detected, return None
    if df.empty:
        logger.warning("No spindles were found in data. Returning None.")
        return None

    # Remove useless columns
    to_drop = []
    if hypno is None:
        to_drop.append("Stage")
    else:
        df["Stage"] = df["Stage"].astype(int)
    # if not coupling:
    #     to_drop.append('SOPhase')
    if len(to_drop):
        df = df.drop(columns=to_drop)

    # Find spindles that are present on at least two channels
    if multi_only and df["Channel"].nunique() > 1:
        # We round to the nearest second
        idx_good = np.logical_or(
            df["Start"].round(0).duplicated(keep=False), df["End"].round(0).duplicated(keep=False)
        ).to_list()
        df = df[idx_good].reset_index(drop=True)

    return SpindlesResults(
        events=df, data=data, sf=sf, ch_names=ch_names, hypno=hypno, data_filt=data_sigma
    )
class SpindlesResults(_DetectionResults):
"""Output class for spindles detection.
Attributes
----------
_events : :py:class:`pandas.DataFrame`
Output detection dataframe
_data : array_like
Original EEG data of shape *(n_chan, n_samples)*.
_data_filt : array_like
Sigma-filtered EEG data of shape *(n_chan, n_samples)*.
_sf : float
Sampling frequency of data.
_ch_names : list
Channel names.
_hypno : array_like or None
Sleep staging vector.
"""
def __init__(self, events, data, sf, ch_names, hypno, data_filt):
super().__init__(events, data, sf, ch_names, hypno, data_filt)
def summary(self, grp_chan=False, grp_stage=False, mask=None, aggfunc="mean", sort=True):
"""Return a summary of the spindles detection, optionally grouped
across channels and/or stage.
Parameters
----------
grp_chan : bool
If True, group by channel (for multi-channels detection only).
grp_stage : bool
If True, group by sleep stage (provided that an hypnogram was
used).
mask : array_like or None
Custom boolean mask. Only the detected events for which mask is True will be
included in the summary dataframe. Default is None, i.e. no masking