In [1]:
import os
import mne
import glob
import joblib
import logging
import numpy as np
import pandas as pd
import scipy.signal as sp_sig
import scipy.stats as sp_stats
import antropy as ant
from mne.filter import filter_data
from scipy.stats import zscore
from sklearn.preprocessing import robust_scale
from sklearn.metrics import accuracy_score
from imblearn.pipeline import make_pipeline
from imblearn.pipeline import Pipeline
from imblearn.over_sampling import SMOTE
from lightgbm import LGBMClassifier
from joblib import dump, load
from yasa import sliding_window, bandpower_from_psd_ndarray
import time

# Define the feature-extraction function and its parameters

In [2]:
## Feature extraction
def feature_lfp(epochs, sf, bands, kwargs_welch, freq_broad):
    """Compute the per-epoch LFP feature table used as model input.

    Parameters
    ----------
    epochs : np.ndarray
        Epoched signal. Every feature is computed along ``axis=1``, so this is
        presumably ``(n_epochs, n_samples)`` -- confirm for new callers.
    sf : float
        Sampling frequency (Hz) passed to ``scipy.signal.welch``.
    bands : list of (low, high, name) tuples
        Frequency bands for relative band power. The power-ratio features
        require bands named 'Delta', 'Theta', 'Alpha', 'lBeta', 'hBeta',
        'lGamma' and 'hGamma'; a missing name raises ``KeyError``.
    kwargs_welch : dict
        Extra keyword arguments forwarded to ``scipy.signal.welch``.
    freq_broad : (float, float)
        Inclusive frequency range (Hz) integrated for the ``abspow`` feature.

    Returns
    -------
    pd.DataFrame
        One row per epoch (index named ``'epoch'``). float64 columns are
        downcast to float32 and columns are sorted by name (the original
        notes this matches LightGBM's ordering).
    """
    # Standard descriptive statistics and Hjorth parameters
    hmob, hcomp = ant.hjorth_params(epochs, axis=1)
    feat = {
        'std': np.std(epochs, ddof=1, axis=1),
        'iqr': sp_stats.iqr(epochs, rng=(25, 75), axis=1),
        'skew': sp_stats.skew(epochs, axis=1),
        'kurt': sp_stats.kurtosis(epochs, axis=1),
        'nzc': ant.num_zerocross(epochs, axis=1),
        'hmob': hmob,
        'hcomp': hcomp,
    }

    # Relative band power from the Welch PSD, one column per band name
    freqs, psd = sp_sig.welch(epochs, sf, **kwargs_welch)
    bp = bandpower_from_psd_ndarray(psd, freqs, bands=bands, relative=True)
    for j, (_, _, b) in enumerate(bands):
        feat[b] = bp[j]

    # Band-power ratios; beta and gamma are the sums of their two sub-bands
    Gamma = feat['lGamma'] + feat['hGamma']
    Beta = feat['lBeta'] + feat['hBeta']
    feat['dt'] = feat['Delta'] / feat['Theta']
    feat['da'] = feat['Delta'] / feat['Alpha']
    feat['db'] = feat['Delta'] / Beta
    feat['dg'] = feat['Delta'] / Gamma
    feat['ta'] = feat['Theta'] / feat['Alpha']
    feat['tb'] = feat['Theta'] / Beta
    feat['tg'] = feat['Theta'] / Gamma
    feat['ab'] = feat['Alpha'] / Beta
    feat['ag'] = feat['Alpha'] / Gamma
    feat['bg'] = Beta / Gamma

    # Total (absolute) power: integrate the PSD over freq_broad (inclusive)
    idx_broad = np.logical_and(
        freqs >= freq_broad[0], freqs <= freq_broad[1])
    dx = freqs[1] - freqs[0]
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid (same
    # trapezoidal rule); use whichever the installed NumPy provides.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    feat['abspow'] = trapezoid(psd[:, idx_broad], dx=dx)

    # Entropy and fractal-dimension features
    feat['perm'] = np.apply_along_axis(
        ant.perm_entropy, axis=1, arr=epochs, normalize=True)
    feat['higuchi'] = np.apply_along_axis(
        ant.higuchi_fd, axis=1, arr=epochs)
    feat['petrosian'] = ant.petrosian_fd(epochs, axis=1)

    # Assemble the feature table (one row per epoch)
    features = pd.DataFrame(feat)
    features.index.name = 'epoch'

    # Downcast float64 to float32 (to reduce size of training datasets)
    cols_float = features.select_dtypes(np.float64).columns.tolist()
    features[cols_float] = features[cols_float].astype(np.float32)

    # Sort the column names (same behavior as LightGBM)
    features = features.sort_index(axis=1)
    return features

In [3]:
# Parameters for feature extraction.
#
# Sampling rates. The data were downsampled twice because recordings were
# made at 500 or 1000 Hz: the first pass brings every recording to a common
# rate, the second reduces the computational load.
fs_first = 500   # common rate after the first downsampling (Hz)
fs_second = 200  # final analysis rate after the second downsampling (Hz)

# Broad frequency range (Hz): used for band-pass filtering and total power.
freq_broad = (1, 90)

# Individual frequency bands as (low Hz, high Hz, name). 45-55 Hz is not
# covered, presumably to avoid 50 Hz mains noise -- TODO confirm.
bands = [
    (1, 4, 'Delta'),
    (4, 8, 'Theta'),
    (8, 12, 'Alpha'),
    (12, 20, 'lBeta'),
    (20, 30, 'hBeta'),
    (30, 45, 'lGamma'),
    (55, 90, 'hGamma'),
]

# Welch PSD settings: Hamming windows of win_sec seconds at the final rate.
win_sec = 2
win = int(fs_second * win_sec)
kwargs_welch = {'window': 'hamming', 'nperseg': win, 'average': 'median'}

# Load the demo data

In [4]:
# Parameters for loading the demo data.
# Folder containing the demo data -- edit this to point at your local copy.
data_folder = r'D:\a big_decoding_local_temporary\github代码上传'
data_name = 'demo_data.npy'
# os.path.join replaces the original '%s\%s' format string, whose '\%' is an
# invalid escape sequence (SyntaxWarning on Python 3.12+), and uses the
# correct path separator for the current OS.
data = np.load(os.path.join(data_folder, data_name))

In [5]:
# The demo LFP was already downsampled to 200 Hz during preprocessing.
# Layout: (30-second epochs, channels, samples per epoch = 200 Hz * 30 s).
np.shape(data)

(199, 1, 6000)

# Load the trained BGOOSE model

In [6]:
# Parameters for loading the trained model.
# Folder containing the model file -- edit this to point at your local copy.
model_folder = r'D:\a big_decoding_local_temporary\github代码上传'
model_name = 'final_model2312.joblib'
# os.path.join replaces the original '%s\%s' format string, whose '\%' is an
# invalid escape sequence (SyntaxWarning on Python 3.12+), and uses the
# correct path separator for the current OS.
# NOTE: joblib's load unpickles the file, which can execute arbitrary code;
# only load model files that come from a trusted source.
model = load(os.path.join(model_folder, model_name))

# Run the sleep-stage prediction

In [7]:
# Band-pass filter every epoch to the broad range freq_broad (1-90 Hz); the
# split into individual bands happens later, inside feature_lfp, from the PSD.
# Drop only the singleton channel axis of (epoch, 1, time) data: a plain
# .squeeze() would also drop the epoch axis when there is a single epoch.
epochs_raw = data.squeeze(axis=1) if data.ndim == 3 else data
epochs_filter = filter_data(epochs_raw, fs_second,
                            l_freq=freq_broad[0], h_freq=freq_broad[1], verbose=False)
# Build the per-epoch feature table.
features = feature_lfp(epochs_filter, fs_second, bands, kwargs_welch, freq_broad)
# `model` was loaded in the model-loading cell above; reloading it here was
# redundant disk I/O.
# The model takes the first 28 (name-sorted) feature columns, which with the
# `bands` defined above is every column feature_lfp returns.
n_model_features = 28
X = features.iloc[:, :n_model_features].to_numpy(dtype=np.float32)
# Predict one stage label per epoch.
predict = model.predict(X)
# Map labels to stage names. A label outside the mapping is reported as
# 'Unknown (...)' rather than dropped, so printed epoch numbers always line
# up with the input epochs.
stage_names = {0: 'Wake', 1: 'Sleep-NREM', 2: 'Sleep-REM'}
sleep_stage = [stage_names.get(label, 'Unknown (%s)' % label) for label in predict]
# Report the results, one line per epoch.
for i, stage in enumerate(sleep_stage):
    print('The stage for your epoch [%s] is [%s]' % (str(i), stage))

The stage for your epoch [0] is [Sleep-NREM]
The stage for your epoch [1] is [Sleep-NREM]
The stage for your epoch [2] is [Wake]
The stage for your epoch [3] is [Wake]
The stage for your epoch [4] is [Wake]
The stage for your epoch [5] is [Sleep-NREM]
The stage for your epoch [6] is [Sleep-NREM]
The stage for your epoch [7] is [Wake]
The stage for your epoch [8] is [Wake]
The stage for your epoch [9] is [Wake]
The stage for your epoch [10] is [Wake]
The stage for your epoch [11] is [Wake]
The stage for your epoch [12] is [Wake]
The stage for your epoch [13] is [Sleep-NREM]
The stage for your epoch [14] is [Wake]
The stage for your epoch [15] is [Wake]
The stage for your epoch [16] is [Sleep-NREM]
The stage for your epoch [17] is [Sleep-NREM]
The stage for your epoch [18] is [Wake]
The stage for your epoch [19] is [Sleep-NREM]
The stage for your epoch [20] is [Wake]
The stage for your epoch [21] is [Wake]
The stage for your epoch [22] is [Wake]
The stage for your epoch [23] is [Wake]
Th