In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
from sklearn.manifold import TSNE
from pathlib import Path
import xarray as xr
from hmpai.normalization import *
from hmpai.data import preprocess
from hmpai.training import split_data_on_participants
import matplotlib.pyplot as plt
from hmpai.utilities import MASKING_VALUE
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
import random
from tslearn.neighbors import KNeighborsTimeSeriesClassifier
from hmpai.pytorch.utilities import set_global_seed
from hmpai.training import calculate_features, k_fold_cross_validate_sklearn
from hmpai.utilities import print_results


In [2]:
data_path = Path("../data/sat1/split_stage_data_100hz.nc")
data = xr.load_dataset(data_path)
set_global_seed(42)

In [3]:
rf = RandomForestClassifier(random_state=42)
results = k_fold_cross_validate_sklearn(rf, data, normalization_fn=norm_min1_to_1, seed=42)

Fold 1: test fold: ['0009' '0017' '0001' '0024' '0012']
Fold 1: Accuracy: 0.637956204379562
Fold 1: F1-Score: 0.618794133152609
Fold 2: test fold: ['0010' '0014' '0002' '0023' '0006']
Fold 2: Accuracy: 0.677671802396674
Fold 2: F1-Score: 0.6598334336143747
Fold 3: test fold: ['0003' '0013' '0016' '0004' '0005']
Fold 3: Accuracy: 0.6161768399899523
Fold 3: F1-Score: 0.6088948270895493
Fold 4: test fold: ['0021' '0018' '0022' '0019' '0025']
Fold 4: Accuracy: 0.6557139405653539
Fold 4: F1-Score: 0.6481566132273803
Fold 5: test fold: ['0008' '0011' '0015' '0020' '0007']
Fold 5: Accuracy: 0.6310703666997026
Fold 5: F1-Score: 0.6214356787659987


In [5]:
print_results(results)

Accuracies
[0.637956204379562, 0.677671802396674, 0.6161768399899523, 0.6557139405653539, 0.6310703666997026]
F1-Scores
[0.618794133152609, 0.6598334336143747, 0.6088948270895493, 0.6481566132273803, 0.6214356787659987]
Average Accuracy: 0.643717830806249, std: 0.021205846828958735
Average F1-Score: 0.6314229371699824, std: 0.019255752323625478


In [3]:
train_data = calculate_features(train_data)
test_data = calculate_features(test_data)

In [4]:
train_data_np = train_data.to_numpy().reshape(-1, train_data.shape[1] * train_data.shape[2])
test_data_np = test_data.to_numpy().reshape(-1, test_data.shape[1] * test_data.shape[2])

In [5]:
rf = RandomForestClassifier(random_state=42).fit(train_data_np, train_data.labels)

In [57]:
accuracy = rf.score(test_data_np, test_data.labels)
print(accuracy)

0.6063668519454269


In [6]:
rf.predict(test_data_np)

array(['response', 'pre-attentive', 'pre-attentive', ..., 'decision',
       'decision', 'encoding'], dtype=object)

In [12]:
svm = SVC(kernel="rbf", random_state=42)
svm.fit(dataset_np, dataset.labels)

In [13]:
accuracy = svm.score(dataset_np_test, dataset_test.labels)
print(accuracy)

0.6647623154036866


In [4]:
# stage_durations = np.isnandata.where(data.data != MASKING_VALUE)
# np.isnan(test_set.data.where(test_set.data != MASKING_VALUE))

train_lengths = np.isnan(
    train_data.data.where(train_data.data != MASKING_VALUE)
).argmax(dim=["samples", "channels"])["samples"]

test_lengths = np.isnan(test_data.data.where(test_data.data != MASKING_VALUE)).argmax(
    dim=["samples", "channels"]
)["samples"]

train_lengths_data = train_lengths.data.reshape(-1, 1)
test_lengths_data = test_lengths.data.reshape(-1, 1)

In [5]:
combined_train = np.append(train_mean_activation, train_lengths_data, axis=1)
combined_test = np.append(test_mean_activation, test_lengths_data, axis=1)

In [8]:
rf = RandomForestClassifier(random_state=42).fit(
    combined_train, train_mean_activation.labels
)

In [10]:
rf = RandomForestClassifier(random_state=42).fit(
    train_mean_activation.to_numpy(), train_mean_activation.labels
)

In [13]:
rf = RandomForestClassifier(random_state=42).fit(
    train_lengths_data, train_lengths.labels
)

In [9]:
rf.score(combined_test, test_mean_activation.labels)

0.7789694944486364

In [11]:
rf.score(test_mean_activation, test_mean_activation.labels)

0.2819068664438935

In [14]:
rf.score(test_lengths_data, test_lengths.labels)

0.5240963855421686