In [38]:
import numpy as np
from tensorflow.keras.utils import to_categorical
from dl_ts_models.main import create_classifier
from drowsiness_detection import config
from drowsiness_detection.data import load_experiment_data
from drowsiness_detection.models import ThreeDStandardScaler

In [66]:
# load data
config.set_paths(30, 10)
X_train, X_test, y_train, y_test, cv_splits = load_experiment_data(
    exclude_by="a",
    num_targets=2,
    seed=45,
    test_size=.2,
    split_by_subjects=True,
    use_dummy_data=False,
    nn_experiment=True,
    feature_col_indices=(5, 8, 9, 14, 15, 16, 19),
    model_name="mvts_transformer")

X_train shape: (29012, 300, 7), y_train shape: (29012,)
X_test shape: (7094, 300, 7), y_test shape: (7094,)


In [67]:
num_samples = 1000
X_train = X_train[:num_samples]
X_test = X_test[:num_samples]
y_test = y_test[:num_samples]
y_train = y_train[:num_samples]



In [68]:
scaler = ThreeDStandardScaler(feature_axis=-1)

X_train_scaled = scaler.fit_transform(X_train, y_train)
X_test_scaled = scaler.transform(X_test)

input_shape = (None, 300, 7)


In [69]:
    # transform the labels from integers to one hot vectors
nb_classes = 2
y_train = to_categorical(y_train, nb_classes)
y_test = to_categorical(y_test, nb_classes)
y_test.shape

(1000, 2)

In [70]:
# base_path = config.SOURCES_ROOT_PATH.parent.joinpath("data/gcloud_dataset/30sec/")
# np.save(base_path.joinpath("x_test.npy"), X_test_scaled)
# np.save(base_path.joinpath("x_train.npy"), X_train_scaled)
# np.save(base_path.joinpath("y_test.npy"), y_test)
# np.save(base_path.joinpath("y_train.npy"), y_train)
X_test.shape

(1000, 300, 7)

In [71]:
model = create_classifier(
    classifier_name="resnet",
    input_shape=input_shape[1:],
    nb_classes=nb_classes,
    output_directory=str(config.SOURCES_ROOT_PATH.parent.joinpath("data/resnet_experiments")) + "/",
    verbose=True,
)

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 300, 7)]     0           []                               
                                                                                                  
 conv1d (Conv1D)                (None, 300, 64)      3648        ['input_1[0][0]']                
                                                                                                  
 batch_normalization (BatchNorm  (None, 300, 64)     256         ['conv1d[0][0]']                 
 alization)                                                                                       
                                                                                                  
 activation (Activation)        (None, 300, 64)      0           ['batch_normalization[0][0]']

In [72]:
metrics = model.fit(x_train=X_train, y_train=y_train, x_val=X_test, y_val=y_test,
                    nb_epochs=2, batch_size=64,
                    # class_weight={0: 0.84, 1: 1.14}
                    class_weight=None
                    )
print(metrics)

Epoch 1/2
Epoch 2/2
[[0.99191886 0.00808115]
 [0.99345076 0.00654917]
 [0.9878036  0.01219641]
 [0.98885286 0.01114711]
 [0.9900662  0.0099338 ]
 [0.99020797 0.009792  ]
 [0.9910924  0.00890765]
 [0.98660594 0.01339407]
 [0.9855483  0.01445163]
 [0.9886281  0.01137191]]
saving results to  /home/tim/IM/data/resnet_experiments/
   precision  accuracy    recall   duration
0   0.705411     0.412  0.501695  25.631132


In [73]:
y_pred = model.predict(X_test_scaled, y_test, return_df_metrics=False)

In [74]:
y_pred

array([[0.00402456, 0.99597543],
       [0.00401705, 0.995983  ],
       [0.00400956, 0.99599046],
       ...,
       [0.00399128, 0.99600875],
       [0.00399631, 0.9960037 ],
       [0.00399645, 0.99600357]], dtype=float32)

In [None]:
np.argmax([[0], [1]], axis=1)

In [None]:
from sklearn.metrics import RocCurveDisplay

y_hat = model.predict(X_test_scaled)

RocCurveDisplay.from_predictions(y_pred=y_hat, y_true=y_test, pos_label=1)