In [1]:
import os
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

from eegsc.preprocessing.filters import BandPassFilter
from eegsc.preprocessing.spectrum import SpectrumTransformer
from eegsc.utils.io import read_raw
from eegsc.utils.path import get_data_path
from eegsc.utils.experiments import create_spectrum_dataset
from eegsc.utils.cross_val import train_test_split, cross_val_score

In [2]:
start_time = 0
signal_duration = 10 - start_time
bandpass_filter = BandPassFilter(signal_duration=signal_duration)
spectrum_transformer = SpectrumTransformer(psd_method='periodogram',
                                           signal_duration=signal_duration)

In [3]:
# spectrum_transformer.columns

In [4]:
data = read_raw(os.path.join(get_data_path(), 'raw', '1st_Day.mat'),
                # data_types=['left_real', 'right_real',
                #             'left_im1', 'right_im1',
                #             'left_im2', 'right_im2',
                #             'left_quasi', 'right_quasi'])
                # data_types=['left_quasi', 'right_quasi'])
                # data_types=['left_im2', 'right_im2'])
                # data_types=['left_im1', 'right_im1'])
                data_types=['left_real', 'right_real'])

In [5]:
statistics, labels, person_idxs = create_spectrum_dataset(
    data,
    bandpass_filter,
    spectrum_transformer,
    start_time=start_time,
    is_len_equal=True
)
statistics.shape, labels.shape, person_idxs.shape

((569, 320), (569,), (569,))

In [8]:
metrics = cross_val_score(data=statistics,
                          labels=labels,
                          person_idxs=person_idxs,
                          model=RandomForestClassifier(n_estimators=1000),
                          metric=accuracy_score,
                          n_test_persons=3,
                          random_state=0)
metrics

 20%|██        | 1/5 [00:03<00:12,  3.10s/it]

score [ 2 11  3]: 0.4827586206896552


 40%|████      | 2/5 [00:06<00:09,  3.09s/it]

score [10  0  4]: 0.5630252100840336


 60%|██████    | 3/5 [00:09<00:06,  3.00s/it]

score [ 7  5 14]: 0.5263157894736842


 80%|████████  | 4/5 [00:12<00:02,  2.97s/it]

score [12  6  9]: 0.48598130841121495


100%|██████████| 5/5 [00:14<00:00,  2.98s/it]

score [13  8  1]: 0.46017699115044247





Unnamed: 0,test_persons,train_metrics,test_metrics
0,"[2, 11, 3]",1.0,0.482759
1,"[10, 0, 4]",1.0,0.563025
2,"[7, 5, 14]",1.0,0.526316
3,"[12, 6, 9]",1.0,0.485981
4,"[13, 8, 1]",1.0,0.460177
