In [2]:
from scipy.signal import butter, lfilter, freqz, welch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

dataset = pd.read_csv('../../dataset/DataSet.csv')
dataset = dataset[dataset.target!=0].drop('Task', axis=1).reset_index(drop=True)

In [3]:
dataset

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,631,632,633,634,635,636,637,638,639,target
0,-0.000014,-0.000013,-0.000007,0.000010,0.000035,0.000006,-0.000001,0.000012,0.000005,0.000021,...,0.000002,-0.000020,-0.000012,0.000003,0.000016,-0.000002,-0.000005,-0.000020,-0.000020,30
1,-0.000027,-0.000036,-0.000026,-0.000021,-0.000008,-0.000012,-0.000010,-0.000004,-0.000004,-0.000015,...,-0.000037,-0.000075,-0.000043,-0.000032,-0.000047,-0.000004,0.000002,-0.000018,-0.000022,40
2,-0.000010,0.000013,0.000030,0.000025,0.000025,-0.000017,-0.000016,0.000004,0.000022,0.000022,...,0.000042,0.000039,0.000027,0.000025,0.000024,0.000032,0.000039,0.000031,0.000013,30
3,-0.000067,-0.000049,-0.000046,-0.000066,-0.000069,-0.000034,-0.000024,-0.000016,-0.000003,0.000004,...,0.000030,-0.000001,-0.000010,0.000002,0.000002,-0.000029,-0.000046,0.000036,0.000047,40
4,0.000010,0.000031,0.000018,0.000010,0.000020,0.000001,-0.000007,0.000039,0.000002,0.000024,...,0.000013,-0.000004,-0.000009,0.000003,-0.000003,0.000000,-0.000027,-0.000029,-0.000024,40
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9832,-0.000044,-0.000028,-0.000013,-0.000013,-0.000039,-0.000041,-0.000013,-0.000023,0.000025,-0.000008,...,-0.000036,-0.000027,-0.000042,-0.000042,-0.000062,-0.000065,-0.000011,-0.000049,-0.000077,20
9833,0.000010,-0.000077,0.000003,-0.000012,-0.000063,-0.000045,-0.000127,-0.000146,0.000011,0.000059,...,0.000065,0.000086,0.000062,0.000108,0.000060,-0.000008,0.000050,0.000106,0.000086,10
9834,0.000116,0.000065,0.000086,0.000103,0.000062,0.000159,0.000125,0.000071,0.000057,-0.000009,...,-0.000175,-0.000200,-0.000194,-0.000210,-0.000244,-0.000250,-0.000185,-0.000260,-0.000253,20
9835,0.000086,0.000041,-0.000060,-0.000122,-0.000157,-0.000053,-0.000198,-0.000231,-0.000256,-0.000204,...,-0.000082,-0.000062,-0.000078,-0.000098,-0.000145,-0.000076,-0.000122,-0.000062,-0.000084,10


In [4]:
# extract mean brainwave
fs = 160  # 采样频率

# 定义带通滤波器函数
def bandpass_filter(data, low, high, sfreq=160, order=5):
    nyq = 0.5 * sfreq
    low = low / nyq
    high = high / nyq
    b, a = butter(order, [low, high], btype='band')
    y = lfilter(b, a, data)
    return y

# 定义一个函数来应用所有的滤波器并返回结果
def apply_filters(row, sfreq=160):
    filters = {
        'delta': (0.1, 3),
        'theta': (4, 7),
        'alpha_slow': (8, 9),
        'alpha_middle': (9, 12),
        'alpha_fast': (12, 14),
        'beta_low': (12.5, 16),
        'beta_middle': (16.5, 20),
        'beta_high': (20.5, 28),
        'gamma': (25, 60)
    }
    results = {}
    for key, (low, high) in filters.items():
        filtered_data = bandpass_filter(row, low, high, sfreq)
        results[key + '_mean'] = np.mean(filtered_data)  # 直接计算并保存平均值
    return pd.Series(results)

# 假设 dataset 是一个 n*640 的 DataFrame
# 应用滤波器并计算每行的平均值
mean_brainwaves = dataset.drop('target',axis=1).apply(apply_filters, axis=1)
mean_brainwaves['target'] = dataset.target

In [5]:
mean_brainwaves

Unnamed: 0,delta_mean,theta_mean,alpha_slow_mean,alpha_middle_mean,alpha_fast_mean,beta_low_mean,beta_middle_mean,beta_high_mean,gamma_mean,target
0,-0.000015,-4.899878e-08,-2.045790e-08,-1.873350e-08,-1.149191e-08,3.755529e-09,1.229423e-08,-4.257819e-10,1.552808e-08,30
1,-0.000004,1.178347e-08,2.180244e-08,1.650733e-09,2.543684e-08,-1.964601e-08,7.445221e-09,1.255333e-08,-9.772046e-09,40
2,0.000059,-1.101760e-08,-2.255887e-08,4.078502e-09,-8.691410e-09,-1.541009e-08,-7.963366e-09,1.998614e-08,-5.487397e-09,30
3,-0.000006,-4.300523e-08,1.612889e-08,-2.493651e-08,-2.232885e-08,8.203456e-09,-4.043102e-10,2.011623e-08,1.452681e-08,40
4,0.000058,4.616067e-08,-1.213564e-08,6.151335e-09,2.960803e-08,-3.527544e-09,6.119341e-10,1.871087e-08,1.333430e-08,40
...,...,...,...,...,...,...,...,...,...,...
9832,0.000716,3.968427e-08,1.487511e-08,-1.027647e-08,3.418260e-09,-3.412457e-09,2.328181e-08,1.269855e-08,-4.786711e-09,20
9833,0.000235,-2.181069e-09,-3.639633e-08,-6.875953e-08,-1.442593e-08,2.401021e-08,2.053111e-08,1.449368e-08,1.238511e-08,10
9834,0.001104,-1.210862e-07,-1.730512e-08,2.579694e-08,-3.457925e-08,3.872236e-08,-2.733391e-08,-4.741316e-10,-1.959760e-08,20
9835,-0.000860,2.961705e-08,-1.797454e-08,3.760366e-08,1.088559e-08,-1.164884e-08,1.031503e-08,1.006632e-08,-2.895528e-08,10


In [6]:
# original Dataset
from autogluon.tabular import TabularPredictor

# 指定目标列名
label = 'target'

predictor = TabularPredictor(label=label).fit(mean_brainwaves)

No path specified. Models will be saved in: "AutogluonModels/ag-20240501_083744"
No presets specified! To achieve strong results with AutoGluon, it is recommended to use the available presets.
	Recommended Presets (For more details refer to https://auto.gluon.ai/stable/tutorials/tabular/tabular-essentials.html#presets):
	presets='best_quality'   : Maximize accuracy. Default time_limit=3600.
	presets='high_quality'   : Strong accuracy with fast inference speed. Default time_limit=3600.
	presets='good_quality'   : Good accuracy with very fast inference speed. Default time_limit=3600.
	presets='medium_quality' : Fast training time, ideal for initial prototyping.
Beginning AutoGluon training ...
AutoGluon will save models to "AutogluonModels/ag-20240501_083744"
AutoGluon Version:  1.0.0
Python Version:     3.8.19
Operating System:   Darwin
Platform Machine:   arm64
Platform Version:   Darwin Kernel Version 23.3.0: Wed Dec 20 21:33:31 PST 2023; root:xnu-10002.81.5~7/RELEASE_ARM64_T8112
CPU 

In [7]:
predictor.leaderboard()

Unnamed: 0,model,score_val,eval_metric,pred_time_val,fit_time,pred_time_val_marginal,fit_time_marginal,stack_level,can_infer,fit_order
0,WeightedEnsemble_L2,0.307927,accuracy,0.158767,29.077604,0.000299,0.317889,2,True,14
1,LightGBM,0.286585,accuracy,0.001188,3.279854,0.001188,3.279854,1,True,5
2,NeuralNetFastAI,0.286585,accuracy,0.005266,4.331172,0.005266,4.331172,1,True,3
3,NeuralNetTorch,0.285569,accuracy,0.003333,5.143878,0.003333,5.143878,1,True,12
4,XGBoost,0.284553,accuracy,0.003886,2.072057,0.003886,2.072057,1,True,11
5,LightGBMLarge,0.28252,accuracy,0.002111,12.702884,0.002111,12.702884,1,True,13
6,CatBoost,0.276423,accuracy,0.001143,1.201138,0.001143,1.201138,1,True,8
7,LightGBMXT,0.27439,accuracy,0.002504,3.5879,0.002504,3.5879,1,True,4
8,RandomForestGini,0.270325,accuracy,0.03898,0.514492,0.03898,0.514492,1,True,6
9,RandomForestEntr,0.269309,accuracy,0.038612,0.475667,0.038612,0.475667,1,True,7
