In [8]:
import numpy as np
from sklearn.manifold import TSNE
from pathlib import Path
import xarray as xr
from hmpai.normalization import *
from hmpai.data import preprocess
from hmpai.training import split_data_on_participants
import matplotlib.pyplot as plt
from hmpai.utilities import MASKING_VALUE
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
import random

In [2]:
data_path = Path("../data/sat1/split_stage_data.nc")
data = xr.load_dataset(data_path)
random.seed(42)
np.random.seed(42)
train_data, val_data, test_data = split_data_on_participants(data, 60, norm_dummy)
train_data = preprocess(train_data)
test_data = preprocess(test_data)

In [3]:
train_mean_activation = (
    train_data.where(train_data.data != MASKING_VALUE)
    .mean(dim=["samples"], skipna=True)
    .data
)
test_mean_activation = (
    test_data.where(test_data.data != MASKING_VALUE)
    .mean(dim=["samples"], skipna=True)
    .data
)

In [12]:
# stage_durations = np.isnandata.where(data.data != MASKING_VALUE)
# np.isnan(test_set.data.where(test_set.data != MASKING_VALUE))

train_lengths = np.isnan(
    train_data.data.where(train_data.data != MASKING_VALUE)
).argmax(dim=["samples", "channels"])["samples"]

test_lengths = np.isnan(test_data.data.where(test_data.data != MASKING_VALUE)).argmax(
    dim=["samples", "channels"]
)["samples"]

train_lengths_data = train_lengths.data.reshape(-1, 1)
test_lengths_data = test_lengths.data.reshape(-1, 1)

In [25]:
combined_train = np.append(train_mean_activation, train_lengths_data, axis=1)
combined_test = np.append(test_mean_activation, test_lengths_data, axis=1)

In [27]:
rf = RandomForestClassifier(random_state=42).fit(
    combined_train, train_mean_activation.labels
)

In [9]:
rf = RandomForestClassifier(random_state=42).fit(
    train_mean_activation.to_numpy(), train_mean_activation.labels
)

In [13]:
rf = RandomForestClassifier(random_state=42).fit(
    train_lengths_data, train_lengths.labels
)

In [28]:
rf.score(combined_test, test_mean_activation.labels)

0.6357931726907631

In [11]:
rf.score(test_mean_activation, test_mean_activation.labels)

0.4445281124497992

In [14]:
rf.score(test_lengths_data, test_lengths.labels)

0.5240963855421686