In [1]:

import os
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.decomposition import PCA
from sklearn.metrics import accuracy_score
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC

from eegsc.preprocessing.filters import BandPassFilter
from eegsc.preprocessing.spectrum import SpectrumTransformer
from eegsc.utils.io import read_raw
from eegsc.utils.path import get_data_path
from eegsc.utils.experiments import create_spectrum_dataset
from eegsc.utils.cross_val import train_test_split, cross_val_score

In [2]:
start_time = 5
signal_duration = 10 - start_time
bandpass_filter = BandPassFilter(signal_duration=signal_duration)
spectrum_transformer = SpectrumTransformer(psd_method='periodogram',
                                           signal_duration=signal_duration)

In [8]:
data = read_raw(os.path.join(get_data_path(), 'raw', '1st_Day.mat'),
                # data_types=['left_real', 'right_real',
                #             'left_im1', 'right_im1',
                #             'left_im2', 'right_im2',
                #             'left_quasi', 'right_quasi'])
                # data_types=['left_quasi', 'right_quasi'])
                # data_types=['left_im2', 'right_im2'])
                # data_types=['left_im1', 'right_im1'])
                data_types=['left_real', 'right_real'])

In [9]:
statistics, labels, person_idxs = create_spectrum_dataset(
    data, bandpass_filter, spectrum_transformer, start_time=start_time)
statistics.shape, labels.shape, person_idxs.shape

((569, 1920), (569,), (569,))

In [12]:
metrics = cross_val_score(data=statistics,
                          labels=labels,
                          person_idxs=person_idxs,
                        #   model=RandomForestClassifier(n_estimators=1000),
                        #   model=LogisticRegression(C=.01),
                          model=SVC(C=1),
                          metric=accuracy_score,
                          n_test_persons=3,
                          random_state=0)
metrics

100%|██████████| 5/5 [00:02<00:00,  1.78it/s]


Unnamed: 0,test_persons,train_metrics,test_metrics
0,"[2, 11, 3]",0.580574,0.474138
1,"[10, 0, 4]",0.582222,0.529412
2,"[7, 5, 14]",0.571429,0.482456
3,"[12, 6, 9]",0.590909,0.495327
4,"[13, 8, 1]",0.60307,0.495575


In [5]:
x_train, x_test, y_train, y_test = train_test_split(
    statistics, labels, person_idxs, test_size=0.2, random_state=0)
x_train.shape, x_test.shape

((1834, 1920), (463, 1920))

In [9]:
pca = PCA(n_components=.99, svd_solver='full').fit(x_train, x_test)
pca.n_components_, pca.n_features_

(178, 1920)

In [10]:
x_train = pca.transform(x_train)
x_test = pca.transform(x_test)

In [11]:
rf = RandomForestClassifier().fit(x_train, y_train)

In [12]:
accuracy_score(y_train, rf.predict(x_train)), accuracy_score(y_test, rf.predict(x_test))

(1.0, 0.11447084233261338)

In [10]:
rf.feature_importances_.shape, rf.feature_importances_[rf.feature_importances_ != 0].shape

((1920,), (645,))

In [11]:
[spectrum_transformer.columns[i] for i in np.flatnonzero(rf.feature_importances_)]

['original_mean_2',
 'original_std_0',
 'original_std_1',
 'original_std_2',
 'original_std_3',
 'original_std_4',
 'original_std_5',
 'original_std_6',
 'original_std_7',
 'original_std_8',
 'original_std_9',
 'original_std_10',
 'original_std_11',
 'original_std_12',
 'original_std_13',
 'original_std_14',
 'original_std_15',
 'original_std_16',
 'original_std_17',
 'original_std_18',
 'original_std_19',
 'original_std_20',
 'original_std_21',
 'original_std_22',
 'original_std_23',
 'original_std_24',
 'original_std_25',
 'original_std_26',
 'original_std_27',
 'original_std_28',
 'original_std_29',
 'original_std_30',
 'original_std_31',
 'original_median_0',
 'original_median_1',
 'original_median_2',
 'original_median_5',
 'original_median_6',
 'original_median_7',
 'original_median_8',
 'original_median_10',
 'original_median_11',
 'original_median_13',
 'original_median_15',
 'original_median_17',
 'original_median_18',
 'original_median_19',
 'original_median_20',
 'original_m