In [1]:
import os
import numpy as np
import scipy
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from sklearn.ensemble import RandomForestClassifier
from sktime.classification.compose import ColumnEnsembleClassifier
from sktime.classification.interval_based import DrCIF
from sktime.classification.kernel_based import RocketClassifier
from sktime.classification.hybrid import HIVECOTEV2
from sktime.classification.distance_based import KNeighborsTimeSeriesClassifier
from sktime.transformations.panel.catch22 import Catch22
from sktime.pipeline import make_pipeline

from eegsc.preprocessing.filters import BandPassFilter, FFTFilter
from eegsc.utils.io import read_raw
from eegsc.utils.path import get_data_path
from eegsc.utils.experiments import create_sequence_dataset
from eegsc.utils.cross_val import train_test_split, cross_val_score

In [2]:
# Read the raw '1st_Day.mat' file, requesting only the 'left_real'/'right_real'
# data types. Other `data_types` values previously used in this cell:
#   ['left_im1', 'right_im1'], ['left_im2', 'right_im2'],
#   ['left_quasi', 'right_quasi'], or all eight names together.
raw_file = os.path.join(get_data_path(), 'raw', '1st_Day.mat')
data = read_raw(raw_file, data_types=['left_real', 'right_real'])

In [3]:
# Window configuration shared by the band-pass filter and the dataset builder.
# NOTE(review): presumably the offset (in seconds) at which the analysed window
# starts — confirm against create_sequence_dataset.
start_time = 5
# Whatever remains of a span of 10 after start_time.
# NOTE(review): the literal 10 is presumably the recording length — confirm.
signal_duration = 10 - start_time
bandpass_filter = BandPassFilter(signal_duration=signal_duration)

In [4]:
# Build the model inputs from `data`. The call returns three values: the
# feature array, the label vector and `person_idxs`, which is later passed to
# cross_val_score under the same name.
statistics, labels, person_idxs = create_sequence_dataset(
    data=data,
    bandpass_filter=bandpass_filter,
    start_time=start_time,
    is_len_equal=True  # NOTE(review): presumably forces equal-length sequences — confirm
)
# Cell result; the recorded run displays ((569, 32, 5094), (569,)).
# NOTE(review): axes are presumably (samples, channels, time points) — confirm.
statistics.shape, labels.shape

((569, 32, 5094), (569,))

In [5]:
# Scale every value by 1e6 in place.
# NOTE(review): presumably a unit conversion (e.g. volts to microvolts) — confirm.
# The update mutates `statistics`, so re-running this cell multiplies by 1e6 again.
statistics *= 1e6

In [6]:
# Disabled alternative: a median filter applied to every signal of every sample,
#     statistics[i, j, :] = scipy.signal.medfilt(signal, 5)

# NOTE(review): the meaning of the .75 argument is defined by FFTFilter
# (eegsc.preprocessing.filters) and is not visible here — confirm before tuning.
fft_filter = FFTFilter(.75)
# Rebinds `statistics` to the filtered result, so re-running this cell filters
# the already-filtered data again.
statistics = fft_filter.filter(statistics)

In [7]:
class MultiRandomForest:
    """Majority-vote ensemble of random forests, one per contiguous column group.

    The columns of a 2-D feature matrix are partitioned into ``n_estimators``
    contiguous groups. Forest ``i`` is fitted on, and predicts from, group ``i``
    only. The ensemble prediction for a sample is the label predicted by the
    largest number of forests; ties go to the smallest label in sort order.

    Note that ``n_estimators`` here is the number of *forests* (column groups),
    not the number of trees inside each forest.

    Args:
        n_estimators: number of forests / column groups (must be >= 1).
        **params: keyword arguments forwarded unchanged to every
            ``RandomForestClassifier``.

    Raises:
        ValueError: if ``n_estimators`` is smaller than 1.
    """

    def __init__(self, n_estimators, **params) -> None:
        if n_estimators < 1:
            raise ValueError(f'n_estimators must be >= 1, got {n_estimators}')
        self.n_estimators = n_estimators
        self.estimators = [RandomForestClassifier(**params) for _ in range(n_estimators)]

    def _column_bounds(self, n_columns: int) -> list:
        """Return ``(start, stop)`` column ranges, one per forest.

        Every column is assigned to exactly one forest. When ``n_columns`` is
        not a multiple of ``n_estimators`` the first ``n_columns % n_estimators``
        groups receive one extra column, so no trailing columns are dropped.
        When it is a multiple, all groups have ``n_columns // n_estimators``
        columns.

        Raises:
            ValueError: if there are fewer columns than forests, which would
                leave at least one forest without any feature.
        """
        if n_columns < self.n_estimators:
            raise ValueError(
                f'Cannot split {n_columns} column(s) between '
                f'{self.n_estimators} estimators'
            )
        base_width, n_wider = divmod(n_columns, self.n_estimators)
        bounds = []
        start = 0
        for i in range(self.n_estimators):
            stop = start + base_width + (1 if i < n_wider else 0)
            bounds.append((start, stop))
            start = stop
        return bounds

    def fit(self, x_train: np.ndarray, y_train: np.ndarray):
        """Fit every forest on its own column group of ``x_train``.

        Args:
            x_train: 2-D array, one row per sample.
            y_train: labels, passed unchanged to each forest's ``fit``.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            ValueError: if ``x_train`` has fewer columns than forests.
        """
        bounds = self._column_bounds(x_train.shape[1])
        print('Fitting...')
        # Wrap the list itself (not an enumerate object) so tqdm can read its
        # length and display the total.
        for (start, stop), estimator in zip(bounds, tqdm(self.estimators)):
            estimator.fit(x_train[:, start:stop], y_train)
        return self

    def predict(self, x_test: np.ndarray):
        """Predict one label per row of ``x_test`` by majority vote.

        The columns are split exactly as in ``fit`` (given the same number of
        columns). Labels may be any dtype ``np.unique`` can sort, including
        negative integers and strings.

        Args:
            x_test: 2-D array, one row per sample.

        Returns:
            1-D array with the winning label for each sample.

        Raises:
            ValueError: if ``x_test`` has fewer columns than forests.
        """
        bounds = self._column_bounds(x_test.shape[1])
        # votes has one row per forest and one column per sample.
        votes = np.vstack([
            estimator.predict(x_test[:, start:stop])
            for (start, stop), estimator in zip(bounds, self.estimators)
        ])
        if votes.size == 0:
            return np.array([], dtype=votes.dtype)

        # Encode labels as 0..n_classes-1 so the vote works for any label type.
        classes, encoded = np.unique(votes, return_inverse=True)
        # The shape of the inverse differs between NumPy versions; normalise it.
        encoded = encoded.reshape(votes.shape)
        counts = np.stack(
            [(encoded == k).sum(axis=0) for k in range(classes.size)]
        )
        # argmax returns the first maximum, i.e. the smallest label on ties.
        return classes[counts.argmax(axis=0)]

In [8]:
# Column indices handed to the ensemble: every index along axis 1 of `statistics`.
all_columns = list(range(statistics.shape[1]))

# Single-member column ensemble: ROCKET over all columns.
# Other candidates that were previously present in this cell (disabled):
#   - ('knn', KNeighborsTimeSeriesClassifier(n_neighbors=5, distance='dtw'), all columns)
#   - ('Catch22_RF', make_pipeline(RandomForestClassifier()), all columns)
#   - ROCKET restricted to column [0]
#   - model = MultiRandomForest(n_estimators=statistics.shape[1], random_state=0)
rocket = RocketClassifier(n_jobs=8, random_state=0)
model = ColumnEnsembleClassifier([('ROCKET', rocket, all_columns)])

def catch22_preprocessor(x_train: np.ndarray, x_test: np.ndarray):
    """Turn train and test data into Catch22 feature arrays.

    A single ``Catch22(n_jobs=8, replace_nans=True)`` transformer is fitted on
    ``x_train`` and then reused, without refitting, on ``x_test``. Both results
    are converted with ``.to_numpy()``.

    Args:
        x_train: training data passed to ``fit_transform``.
        x_test: test data passed to ``transform``.

    Returns:
        Tuple ``(train_features, test_features)``.
    """
    extractor = Catch22(n_jobs=8, replace_nans=True)
    # Tuple elements are evaluated left to right, so the fit on the training
    # data always happens before the test data is transformed.
    return (
        extractor.fit_transform(x_train).to_numpy(),
        extractor.transform(x_test).to_numpy(),
    )

# Evaluate `model` with the project's cross_val_score helper, scoring each
# split with accuracy_score.
# NOTE(review): n_test_persons=3 presumably holds out three persons per split;
# the recorded output lists three person ids next to each score (e.g. [2, 11, 3]).
# data_preprocessor=None: catch22_preprocessor is not applied in this run.
# Cost: the recorded run needed about 6.5 hours for its 5 splits.
metrics = cross_val_score(
    data=statistics,
    labels=labels,
    person_idxs=person_idxs,
    model=model,
    metric=accuracy_score,
    n_test_persons=3,
    random_state=0,
    train_score=False,
    data_preprocessor=None
)
# Cell result; the recorded run shows a table with test_persons and test_metrics.
metrics

 20%|██        | 1/5 [1:10:19<4:41:19, 4219.98s/it]

score [ 2 11  3]: 0.46551724137931033


 40%|████      | 2/5 [2:20:00<3:29:50, 4196.84s/it]

score [10  0  4]: 0.5462184873949579


 60%|██████    | 3/5 [3:40:29<2:29:31, 4485.67s/it]

score [ 7  5 14]: 0.5087719298245614


 80%|████████  | 4/5 [5:07:07<1:19:26, 4766.77s/it]

score [12  6  9]: 0.5420560747663551


100%|██████████| 5/5 [6:29:08<00:00, 4669.63s/it]  

score [13  8  1]: 0.6991150442477876





Unnamed: 0,test_persons,test_metrics
0,"[2, 11, 3]",0.465517
1,"[10, 0, 4]",0.546218
2,"[7, 5, 14]",0.508772
3,"[12, 6, 9]",0.542056
4,"[13, 8, 1]",0.699115
