In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
from pathlib import Path
import xarray as xr
import tensorflow as tf
from shared.models import *
from shared.utilities import *
from shared.training import train_and_evaluate, split_data_on_participants, k_fold_cross_validate, get_compile_kwargs
from shared.normalization import *
from shared.generators import *
%env TF_FORCE_GPU_ALLOW_GROWTH=true
%env TF_GPU_ALLOCATOR=cuda_malloc_async

env: TF_FORCE_GPU_ALLOW_GROWTH=true
env: TF_GPU_ALLOCATOR=cuda_malloc_async


### Set up data

In [2]:
# Dataset variant selected by this cell: split_stage_data.nc.
# Several cells in this notebook assign `data_path` and `data`; whichever of
# them ran last determines what the later cells operate on.
data_path = Path("data/sat1/split_stage_data.nc")

# Load the dataset at `data_path` with xarray.
data = xr.load_dataset(data_path)



In [8]:
# Dataset variant selected by this cell: split_stage_data_weighted_mean.nc.
# Running this cell rebinds `data_path` and `data`, replacing whatever an
# earlier loading cell put there.
data_path = Path("data/sat1/split_stage_data_weighted_mean.nc")

# Load the dataset at `data_path` with xarray.
data = xr.load_dataset(data_path)

In [2]:
# Dataset variant selected by this cell: split_stage_data_unprocessed_500hz.nc.
# Running this cell rebinds `data_path` and `data`, replacing whatever an
# earlier loading cell put there.
data_path = Path("data/sat1/split_stage_data_unprocessed_500hz.nc")

# Load the dataset at `data_path` with xarray.
data = xr.load_dataset(data_path)

In [8]:
# Dataset variant selected by this cell: stage_data.nc.
# This is the last loading cell before the train/val/test split, so when the
# notebook is run top to bottom this is the dataset that `data` ends up holding.
data_path = Path("data/sat1/stage_data.nc")

# Load the dataset at `data_path` with xarray.
data = xr.load_dataset(data_path)

In [9]:
# Split the currently loaded `data` into three parts, unpacked as the
# train / validation / test sets used by the training cells.
# NOTE(review): the meaning of 60 is not visible in this notebook (presumably a
# split size or percentage) — confirm against split_data_on_participants.
train_data, val_data, test_data = split_data_on_participants(
    data,
    60,
    norm_dummy,
)

In [52]:
# Model choice: LSTM. Constructor arguments are the sizes of the dataset's
# `channels`, `samples` and `labels`, in that order.
# This cell rebinds `model`; the model-definition cell run last is the one
# the training cells use.
model_dims = (len(data.channels), len(data.samples), len(data.labels))
model = SAT1LSTM(*model_dims)

# Compile with the project's shared compile settings.
model.compile(**get_compile_kwargs())

In [10]:
# Model choice: GRU. Constructor arguments are the sizes of the dataset's
# `channels`, `samples` and `labels`, in that order.
# This cell rebinds `model`; the model-definition cell run last is the one
# the training cells use.
model_dims = (len(data.channels), len(data.samples), len(data.labels))
model = SAT1GRU(*model_dims)

# Compile with the project's shared compile settings.
model.compile(**get_compile_kwargs())

In [10]:
# Model choice: sequence-to-sequence GRU. Constructor arguments are the sizes
# of the dataset's `channels`, `samples` and `labels`, in that order.
# This cell rebinds `model`; the model-definition cell run last is the one
# the training cells use.
model_dims = (len(data.channels), len(data.samples), len(data.labels))
model = SAT1seq2seqGRU(*model_dims)

# Compile with the project's shared compile settings.
model.compile(**get_compile_kwargs())

In [None]:
# Train and evaluate the current `model` on the three splits, feeding data
# through SequentialSAT1DataGenerator and writing logs under logs/.
# The call is the cell's last expression, so its return value is displayed.
train_and_evaluate(
    model, train_data, val_data, test_data,
    generator=SequentialSAT1DataGenerator,
    epochs=20,
    logs_path=Path("logs/"),
)

In [7]:
# Run before re-training to free the VRAM held by the previous model.
# After this cell `model` no longer exists: re-run one of the model-definition
# cells before calling train_and_evaluate again.
import gc

# Drop the notebook's reference first — while `model` is still bound the
# garbage collector cannot reclaim it. pop() with a default keeps the cell
# re-runnable when `model` has already been removed.
globals().pop("model", None)

# Reset Keras' global state, which may still reference the old model.
tf.keras.backend.clear_session()

# Collect last, now that nothing references the old model. The trailing
# semicolon suppresses the collected-object count from the cell output.
gc.collect();

In [12]:
# Train and evaluate the current `model` on the three splits, writing logs
# under logs/. No `generator` is passed, so train_and_evaluate falls back to
# whatever default it defines.
# The call is the cell's last expression, so its return value is displayed;
# judging by the recorded output this is a Keras History object together with
# a dict of per-class precision / recall / f1-score / support figures.
train_and_evaluate(
    model,
    train_data,
    val_data,
    test_data,
    epochs=20,
    logs_path=Path("logs/"),
)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20


(<keras.src.callbacks.History at 0x7f55e16f9650>,
 {'confirmation': {'precision': 0.6795252225519288,
   'recall': 0.5204545454545455,
   'f1-score': 0.5894465894465895,
   'support': 440},
  'decision': {'precision': 0.681265206812652,
   'recall': 0.6363636363636364,
   'f1-score': 0.6580493537015276,
   'support': 880},
  'encoding': {'precision': 0.659963436928702,
   'recall': 0.8242009132420092,
   'f1-score': 0.732994923857868,
   'support': 876},
  'pre-attentive': {'precision': 0.8543577981651376,
   'recall': 0.8485193621867881,
   'f1-score': 0.8514285714285713,
   'support': 878},
  'response': {'precision': 0.7750906892382105,
   'recall': 0.7300683371298405,
   'f1-score': 0.7519061583577713,
   'support': 878},
  'accuracy': 0.7330465587044535,
  'macro avg': {'precision': 0.7300404707393262,
   'recall': 0.711921358875364,
   'f1-score': 0.7167651193584657,
   'support': 3952},
  'weighted avg': {'precision': 0.7356498538987007,
   'recall': 0.7330465587044535,
   'f1-s

In [6]:
# Save the current `model` under models/gru. The path is fixed, so the saved
# directory is named "gru" regardless of which model-definition cell produced
# `model` — make sure the GRU model is the one in memory before running this.
model.save("models/gru")

INFO:tensorflow:Assets written to: models/gru/assets


INFO:tensorflow:Assets written to: models/gru/assets


In [10]:
# Print the layer-by-layer summary (output shapes and parameter counts) of
# whichever model is currently bound to `model`.
model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 154, 30)]         0         
                                                                 
 masking (Masking)           (None, 154, 30)           0         
                                                                 
 gru (GRU)                   (None, 154, 64)           18432     
                                                                 
 gru_1 (GRU)                 (None, 154, 32)           9408      
                                                                 
 dense (Dense)               (None, 154, 5)            165       
                                                                 
Total params: 28005 (109.39 KB)
Trainable params: 28005 (109.39 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
