In [7]:
%load_ext autoreload
%autoreload 2
import numpy as np
from pathlib import Path
import xarray as xr
import tensorflow as tf
from hmpai.models import *
from hmpai.utilities import *
from hmpai.training import train_and_evaluate, split_data_on_participants, k_fold_cross_validate, get_compile_kwargs
from hmpai.normalization import *
from hmpai.generators import *
%env TF_FORCE_GPU_ALLOW_GROWTH=true
%env TF_GPU_ALLOCATOR=cuda_malloc_async

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
env: TF_FORCE_GPU_ALLOW_GROWTH=true
env: TF_GPU_ALLOCATOR=cuda_malloc_async


### Set up data

In [2]:
# Load exactly ONE preprocessed variant of the SAT1 data into `data`.
# `xr.load_dataset` reads the whole file into memory. Keeping one cell per
# variant meant "Restart & Run All" loaded every file in turn and kept only
# the last, so the selection is made explicit here instead.
DATA_DIR = Path("data/sat1")
DATA_FILES = {
    "split": "split_stage_data.nc",
    "split_unprocessed_100hz": "split_stage_data_unprocessed_100hz.nc",
    "split_weighted_mean": "split_stage_data_weighted_mean.nc",
    "split_unprocessed_500hz": "split_stage_data_unprocessed_500hz.nc",
    "stage": "stage_data.nc",
}
# Default matches what a top-to-bottom run of the old cells ended up with.
DATA_VARIANT = "stage"

data_path = DATA_DIR / DATA_FILES[DATA_VARIANT]
data = xr.load_dataset(data_path)

In [3]:
# Split `data` into train / validation / test sets, grouped by participant
# (per the helper's name).
# NOTE(review): the meaning of the literal `60` (seed? participant count?
# percentage?) is not visible here. `norm_dummy` comes from one of the star
# imports and is presumably a no-op normalization. Confirm both and consider
# naming `60` as a constant.
train_data, val_data, test_data = split_data_on_participants(data, 60, norm_dummy)

In [None]:
# Build and compile exactly ONE architecture into `model`.
# Separate cells per architecture all reassigned `model`, so "Restart & Run
# All" built and compiled three models and silently kept the last one.
# Each constructor is called with the channel, sample and label counts of `data`.
MODEL_CLASSES = {
    "lstm": SAT1LSTM,
    "gru": SAT1GRU,
    "seq2seq_gru": SAT1seq2seqGRU,
}
# Default matches what a top-to-bottom run of the old cells ended up with.
MODEL_NAME = "seq2seq_gru"

model = MODEL_CLASSES[MODEL_NAME](
    len(data.channels), len(data.samples), len(data.labels)
)
model.compile(**get_compile_kwargs())

In [5]:
# Train the current `model` (at most 20 epochs, logs under logs/) and evaluate
# it on the test split. The displayed result is the (History, per-class
# classification report) pair returned by `train_and_evaluate`.
train_and_evaluate(
    model,
    train_data,
    val_data,
    test_data,
    logs_path=Path("logs/"),
    epochs=20,
)

Epoch 1/20


2023-10-23 14:03:51.960860: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:432] Loaded cuDNN version 8600
2023-10-23 14:03:52.050649: I tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:606] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
2023-10-23 14:03:52.293323: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x556348d5f7e0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-10-23 14:03:52.293358: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA GeForce RTX 3090, Compute Capability 8.6
2023-10-23 14:03:52.299391: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:255] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2023-10-23 14:03:52.413421: I ./tensorflow/compiler/jit/device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the p

Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20


(<keras.src.callbacks.History at 0x7f27eb4cb990>,
 {'confirmation': {'precision': 0.7322097378277154,
   'recall': 0.8786516853932584,
   'f1-score': 0.7987742594484167,
   'support': 445},
  'decision': {'precision': 0.9064245810055865,
   'recall': 0.7267637178051511,
   'f1-score': 0.8067122436295836,
   'support': 893},
  'encoding': {'precision': 0.7941507311586051,
   'recall': 0.7897091722595079,
   'f1-score': 0.791923724060572,
   'support': 894},
  'pre-attentive': {'precision': 0.7230419977298524,
   'recall': 0.7529550827423168,
   'f1-score': 0.7376954255935148,
   'support': 846},
  'response': {'precision': 0.8280590717299579,
   'recall': 0.8820224719101124,
   'f1-score': 0.8541893362350381,
   'support': 890},
  'accuracy': 0.7983870967741935,
  'macro avg': {'precision': 0.7967772238903434,
   'recall': 0.8060204260220694,
   'f1-score': 0.7978589977934251,
   'support': 3968},
  'weighted avg': {'precision': 0.8049161647545949,
   'recall': 0.7983870967741935,
   'f

In [None]:
# Train/evaluate `model` again, this time feeding batches through
# SequentialSAT1DataGenerator.
# NOTE(review): this reuses the same `model` object as the cell above, so
# under Restart & Run All it presumably continues from already-trained
# weights. Confirm that is intended.
train_and_evaluate(
    model, train_data, val_data, test_data,
    generator=SequentialSAT1DataGenerator,
    logs_path=Path("logs/"),
    epochs=20,
)

In [11]:
# Release the current model (and Keras global state) before re-training, to
# free VRAM. This is opt-in: under "Restart & Run All" the cells below still
# save `model`, so deleting it unconditionally makes them fail with NameError.
# Set RESET_MODEL = True when you actually want to re-train.
import gc

RESET_MODEL = False

if RESET_MODEL:
    # Drop the reference *before* collecting; collecting first leaves the
    # model reachable. pop(..., None) keeps the cell safe to re-run.
    globals().pop("model", None)
    tf.keras.backend.clear_session()
    gc.collect()

In [6]:
# Save whatever `model` currently refers to under models/lstm (the output
# shows an assets/ directory being written there).
# NOTE(review): the directory name is hard-coded and does not follow the
# model's actual architecture. Only run this right after building the LSTM.
model.save("models/lstm")

INFO:tensorflow:Assets written to: models/lstm/assets


INFO:tensorflow:Assets written to: models/lstm/assets


In [6]:
# Save whatever `model` currently refers to under models/gru (the output
# shows an assets/ directory being written there).
# NOTE(review): the directory name is hard-coded and does not follow the
# model's actual architecture. Only run this right after building the GRU.
model.save("models/gru")

INFO:tensorflow:Assets written to: models/gru/assets


INFO:tensorflow:Assets written to: models/gru/assets


In [10]:
# Print a layer-by-layer summary of the model currently held in `model`
# (the saved output below is from a run with a GRU-based model).
model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 154, 30)]         0         
                                                                 
 masking (Masking)           (None, 154, 30)           0         
                                                                 
 gru (GRU)                   (None, 154, 64)           18432     
                                                                 
 gru_1 (GRU)                 (None, 154, 32)           9408      
                                                                 
 dense (Dense)               (None, 154, 5)            165       
                                                                 
Total params: 28005 (109.39 KB)
Trainable params: 28005 (109.39 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
