In [21]:
import pickle
import pandas as pd
import numpy as np
import seaborn as sns
import mne
import matplotlib.pyplot as plt
import pyvista
import ipywidgets
import ipyevents
import pyvistaqt
import yasa

In [2]:
%matplotlib qt
# to make plots interactive

## *Pickle data

### Importing data

In [3]:
file_path = r"C:\EEG DATA\FL_label_data.pickle"
# added r in front of file path to make it a raw string, to make sure that \ is not interpreted as a newline character

# open the pickle file
with open(file_path, "rb") as file:
    label_data = pickle.load(file)

# show the label_data type
print(type(label_data))

<class 'dict'>


In [10]:
label_data['020']

{'label': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1,
        1, 1, 1, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0,
        1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
        0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1,
        0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 3,
        3, 3, 3, 3, 3, 3, 3, 

### Functions

In [5]:
# to return all the results
# returns a dict so should have commas between values

def extract_onsets(label_data):
    onset_dict = {}
    for key, value in label_data.items():
        labels = np.atleast_1d(value['label'])
        onsets = np.atleast_1d(value['onset'])
        # to ensure that labels and onsets are treated as array
        # because subsequently using np.where
        indices = np.where((labels == 1) | (labels == 2))[0]
        # returns indices where the label is 1 (N2) or 2 (N3)
        if indices.size > 0 and np.all(indices < len(onsets)):
            # to ensure that no out-of-bounds error
            selected_onsets = onsets[indices]
            # retrieve onset value corresponding to label 1 or 2
            onset_dict[key] = selected_onsets
            # save extracted onset under correct key in dict
            print(f"Key: {key}, Onset values for labels 1 (N2) and 2 (N3): {', '.join(map(str, selected_onsets))}")
        else:
            print(f"Key: {key}, Warning: The indices do not match")
    return onset_dict
    # returning the onset_dict and what you're printing
    # should I be only returning what is supposed to be printed? or maybe only the dict, since already has commas?

def group_by_increment(onset_values, increment=30):
    groups = []
    # will be a list of lists
    current_group = [float(onset_values[0])]
    # initializes this list with the first value from onset_values (the input)
    
    for i in range(1, len(onset_values)):
        # loops through all the onset values
        if onset_values[i] - onset_values[i - 1] == increment:
            # if i = 1, if onset_values[1] - onset_values[0] == 30
            current_group.append(float(onset_values[i]))
            # add the value at current index
        else:
            # if not a difference of 30
            # means you've reached the end of that sublist
            if len(current_group) > 1:
                # if there is more than one value in that group
                groups.append(current_group)
                # add the sublist to the big list
            current_group = [float(onset_values[i])]
            # starts a new current group with the new value at the current index
    
    if len(current_group) > 1:
        groups.append(current_group)
    # once you exit the group, if the last current_group contains more than one value
    # then you can add it to group
    # to make sure that last sequence is not left out
    
    return groups

def extract_segments(raw, groups):
    raw_segments = []
    
    for group in groups:
        start = group[0]
        # start = first value in group
        #stop = min(group[-1], max_time) 
        stop = group[-1]
        # stop = last value in group

        #if start >= max_time:
            #continue
        # takes the smaller of the two values
        segment = raw.copy().crop(tmin=start, tmax=stop)
        raw_segments.append(segment)
    
    return raw_segments

## *Raw data

### Importing 

In [8]:
# follow instructions from YASA

participant_020_file = r"C:\EEG DATA\020\eeg\TMR.vhdr"
participant_020_raw = mne.io.read_raw_brainvision(vhdr_fname=participant_020_file, preload=True)

participant_020_raw.resample(100)
# downsample to 100 Hz
participant_020_raw.filter(0.1, 40)
# bandpass filter between 0.1 Hz and 40 Hz
participant_020_raw.pick(['Fz'])

Extracting parameters from C:\EEG DATA\020\eeg\TMR.vhdr...
Setting channel info structure...
Reading 0 ... 11188539  =      0.000 ... 22377.078 secs...
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.1 - 40 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.10
- Lower transition bandwidth: 0.10 Hz (-6 dB cutoff frequency: 0.05 Hz)
- Upper passband edge: 40.00 Hz
- Upper transition bandwidth: 10.00 Hz (-6 dB cutoff frequency: 45.00 Hz)
- Filter length: 3301 samples (33.010 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    1.1s


Unnamed: 0,General,General.1
,Filename(s),TMR.eeg
,MNE object type,RawBrainVision
,Measurement date,2023-04-06 at 00:49:29 UTC
,Participant,Unknown
,Experimenter,Unknown
,Acquisition,Acquisition
,Duration,06:12:58 (HH:MM:SS)
,Sampling frequency,100.00 Hz
,Time points,2237708
,Channels,Channels


### Put data in YASA format

In [15]:
participant_020_data, participant_020_times = participant_020_raw.get_data(return_times=True) 

# put data in npz format for the hypnogram 
# numpy array
np.savez("participant_020_npz.npz", data=participant_020_data, times=participant_020_times, ch_names=participant_020_raw.ch_names, sfreq=participant_020_raw.info["sfreq"])

In [28]:
# format the npz data

npzfile_020 = np.load("participant_020_npz.npz")
data_020, ch_names_020 = npzfile_020['data'], npzfile_020['ch_names']
sf_020 = 100
times_020 = np.arange(data_020.size) / sf_020

print(data_020.shape, ch_names_020)
print(np.round(data_020[:, 0:5], 3))

(1, 2237708) ['Fz']
[[-0. -0.  0.  0. -0.]]


In [40]:
# retrieve the labels for participant 020

original_labels_020 = label_data['020']['label']
original_labels_020

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1,
       1, 1, 1, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0,
       1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
       0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1,
       0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,

In [65]:
# remap values for hypnogram
# original values:
# 0: wake/N1, 1: N2, 2: N3, 3: REM
# YASA values
# 0: wake, 1: N1, 2: N2, 3: N3, 4: REM

yasa_labels_020 = np.array(original_labels_020, dtype=object)
yasa_labels_020[original_labels_020 == 0] = "N1"
# only keep N1 for efficiency
yasa_labels_020[original_labels_020 == 1] = "N2"
yasa_labels_020[original_labels_020 == 2] = "N3"
yasa_labels_020[original_labels_020 == 3] = "REM"

yasa_labels_020

array(['N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1',
       'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1',
       'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1',
       'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1',
       'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1',
       'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1',
       'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1',
       'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1',
       'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N1', 'N2',
       'N2', 'N2', 'N2', 'N2', 'N2', 'N1', 'N1', 'N2', 'N2', 'N2', 'N2',
       'N2', 'N2', 'N2', 'N2', 'N3', 'N2', 'N3', 'N3', 'N3', 'N3', 'N3',
       'N3', 'N3', 'N3', 'N3', 'N3', 'N3', 'N3', 'N3', 'N3', 'N3', 'N3',
       'N3', 'N3', 'N3', 'N3', 'N3', 'N3', 'N3', 'N3', 'N3', 'N3', 'N3',
       'N3', 'N3', 'N3', 'N3', 'N3', 'N3', 'N3', 'N

### Hypnogram and spectrogram

In [67]:
hypnogram_020 = yasa.Hypnogram(yasa_labels_020, freq="30s")
hypnogram_020.hypno

Epoch
0      N1
1      N1
2      N1
3      N1
4      N1
       ..
740    N1
741    N1
742    N1
743    N1
744    N1
Name: Stage, Length: 745, dtype: category
Categories (7, object): ['WAKE', 'N1', 'N2', 'N3', 'REM', 'ART', 'UNS']

In [68]:
# plot the hypnogram

fig, ax = plt.subplots(1, 1, figsize=(7,3), constrained_layout=True, dpi=80)
ax = hypnogram_020.plot_hypnogram(fill_color="gainsboro", ax=ax)

In [70]:
# upsample the hypnogram for the spectrogram

hypnogram_020_upsampled = yasa.hypno_upsample_to_data(hypno=hypnogram_020.hypno, sf_hypno=(1/30), data=data_020, sf_data=sf_020)
print(hypnogram_020_upsampled.shape, 'Unique values =', np.unique(hypnogram_020_upsampled))



(2237708,) Unique values = ['N1' 'N2' 'N3' 'REM']


In [72]:
hypnogram_020_upsampled

array(['N1', 'N1', 'N1', ..., 'N1', 'N1', 'N1'],
      shape=(2237708,), dtype=object)

In [73]:
# convert back to numbers

# remap values for hypnogram
# original values:
# 0: wake/N1, 1: N2, 2: N3, 3: REM
# YASA values
# 0: wake, 1: N1, 2: N2, 3: N3, 4: REM

hypnogram_020_upsampled_int = np.array(hypnogram_020_upsampled)
hypnogram_020_upsampled_int[hypnogram_020_upsampled == "N1"] = 1
# only keep N1 for efficiency
hypnogram_020_upsampled_int[hypnogram_020_upsampled == "N2"] = 2
hypnogram_020_upsampled_int[hypnogram_020_upsampled == "N3"] = 3
hypnogram_020_upsampled_int[hypnogram_020_upsampled == "REM"] = 4

hypnogram_020_upsampled_int

array([1, 1, 1, ..., 1, 1, 1], shape=(2237708,), dtype=object)

In [77]:
fig = yasa.plot_spectrogram(data_020[0, :], sf_020, hypnogram_020_upsampled_int)
# data must be a 1D numpy array
fig.suptitle("Spectrogram with Hypnogram of Participant 020", fontsize=14)

plt.show()

### Sleep spindles detection

In [79]:
spindles_020 = yasa.spindles_detect(data_020, sf_020, ch_names=ch_names_020, hypno=hypnogram_020_upsampled_int, include=(2,3))
spindles_020.summary().round(3)


11-Apr-25 17:39:49 | ERROR | Wrong data amplitude for Fz (trimmed STD = 0.000). Unit of data MUST be uV! Channel will be skipped.


AttributeError: 'NoneType' object has no attribute 'summary'

In [80]:
print(np.min(data_020), np.max(data_020))

-0.0013877670752919324 0.00029716413576642777


In [81]:
data_020_uv = data_020 * 1e6

In [82]:
spindles_020 = yasa.spindles_detect(data_020_uv, sf_020, ch_names=ch_names_020, hypno=hypnogram_020_upsampled_int, include=(2,3))
spindles_020.summary().round(3)

Unnamed: 0,Start,Peak,End,Duration,Amplitude,RMS,AbsPower,RelPower,Frequency,Oscillations,Symmetry,Stage,Channel,IdxChannel
0,2940.02,2940.35,2940.71,0.69,20.288,4.488,1.213,0.432,14.082,9.0,0.471,2,Fz,0
1,2950.13,2950.16,2950.95,0.82,12.687,2.767,0.957,0.451,13.331,11.0,0.036,2,Fz,0
2,2959.69,2960.02,2960.53,0.84,26.345,6.205,1.370,0.274,13.799,11.0,0.388,2,Fz,0
3,2988.12,2988.44,2988.70,0.58,18.829,4.639,1.399,0.520,13.179,8.0,0.542,2,Fz,0
4,2996.60,2996.86,2997.31,0.71,11.804,2.839,0.955,0.349,12.088,9.0,0.361,2,Fz,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
427,22037.96,22038.14,22038.52,0.56,11.190,2.611,0.897,0.318,13.116,8.0,0.316,2,Fz,0
428,22072.66,22073.01,22073.21,0.55,17.058,4.103,1.131,0.217,11.935,7.0,0.625,2,Fz,0
429,22075.97,22076.26,22076.49,0.52,13.702,3.002,0.946,0.299,12.342,7.0,0.547,2,Fz,0
430,22077.78,22078.20,22078.62,0.84,14.342,3.220,0.981,0.323,12.754,11.0,0.494,2,Fz,0


In [83]:
spindles_020.summary(grp_chan=True, grp_stage=True, aggfunc='mean')

Unnamed: 0_level_0,Unnamed: 1_level_0,Count,Density,Duration,Amplitude,RMS,AbsPower,RelPower,Frequency,Oscillations,Symmetry
Stage,Channel,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
2,Fz,429,2.648148,0.796177,14.045094,3.127975,0.947426,0.380288,13.138409,10.137529,0.502221
3,Fz,3,0.046875,0.55,15.01517,3.400456,1.008176,0.346264,13.36129,6.666667,0.401881


In [84]:
spindles_020.plot_average(errorbar=None, palette="Set1")

<Axes: title={'center': 'Average spindle'}, xlabel='Time (sec)', ylabel='Amplitude (uV)'>