In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
from pathlib import Path
import xarray as xr
import tensorflow as tf
from shared.models import *
from shared.generators import SAT1DataGenerator, NewSAT1DataGenerator
from shared.utilities import earlyStopping_cb
from sklearn.metrics import classification_report
import random
# Seed every RNG the notebook relies on, not just Python's `random`:
# - `random` drives the participant split (random.sample further down),
# - NumPy and TensorFlow RNGs are seeded as well so that weight initialisation
#   and any NumPy-based shuffling start from the same state on every run.
# NOTE(review): GPU kernels and multiprocessing workers may still introduce
# run-to-run differences; seeding only removes the RNG-related part.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)
# GPU memory settings for TensorFlow, exported as process environment variables.
# Keep comments on their own lines: text after the value on a %env line would
# become part of the value.
# NOTE(review): these run after `import tensorflow` above; presumably they still
# take effect because the GPU has not been initialised yet — confirm, or move
# them ahead of the import.
# Per the TensorFlow GPU guide: grow GPU memory on demand instead of reserving it all up front.
%env TF_FORCE_GPU_ALLOW_GROWTH=true
# Selects the "cuda_malloc_async" GPU allocator — presumably to reduce memory fragmentation; TODO confirm.
%env TF_GPU_ALLOCATOR=cuda_malloc_async

2023-08-24 13:08:23.199908: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


env: TF_FORCE_GPU_ALLOW_GROWTH=true
env: TF_GPU_ALLOCATOR=cuda_malloc_async


### Set up data

In [2]:
# Location of the dataset file (NetCDF, judging by the .nc suffix), relative to
# the notebook's working directory.
data_path = Path("data") / "sat1" / "split_stage_data.nc"

# load_dataset reads the whole file into memory and closes the file handle.
data = xr.load_dataset(data_path)

In [3]:
# All participant identifiers present in the dataset, as a plain Python list.
participants = data.participant.values.tolist()

# Hold out five randomly drawn participants for testing (draw depends on the
# `random` seed set in the first cell).
test_participants = random.sample(participants, k=5)

# Everyone who was not drawn above is used for training; original order is kept.
train_participants = list(
    filter(lambda pid: pid not in test_participants, participants)
)

In [4]:
# Slice the dataset along the `participant` coordinate into the two partitions.
train_data = data.sel(participant=train_participants)
test_data = data.sel(participant=test_participants)

# Small fixed split kept for quick experiments (uncomment to use):
# train_data = data.sel(participant=["0021", "0022", "0023", "0024"])
# test_data = data.sel(participant=["0025"])

In [22]:
batch_size = 16

# One generator per partition (train first, then test), sharing the batch size.
train_gen, test_gen = (
    NewSAT1DataGenerator(partition, batch_size=batch_size)
    for partition in (train_data, test_data)
)

# tf.config.optimizer.set_experimental_options({"layout_optimizer": False})

In [39]:
# Build the classifier. The summary printed below reports an input of shape
# (None, 30, 157, 1); the third argument is presumably the number of classes
# (four stage labels) — TODO confirm against SAT1Test's signature.
model = SAT1Test(30, 157, 4)

# Adam with a small learning rate; the sparse loss is used with integer class targets.
adam = tf.keras.optimizers.Adam(learning_rate=1e-4)
sparse_cce = tf.keras.losses.SparseCategoricalCrossentropy()
model.compile(optimizer=adam, loss=sparse_cce, metrics=["accuracy"])

model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 30, 157, 1)]      0         
                                                                 
 conv2d (Conv2D)             (None, 30, 153, 64)       384       
                                                                 
 max_pooling2d (MaxPooling2  (None, 30, 76, 64)        0         
 D)                                                              
                                                                 
 conv2d_1 (Conv2D)           (None, 30, 74, 128)       24704     
                                                                 
 max_pooling2d_1 (MaxPoolin  (None, 30, 37, 128)       0         
 g2D)                                                            
                                                                 
 conv2d_2 (Conv2D)           (None, 30, 35, 256)       98560 

In [40]:
# Train for at most 20 epochs with the shared early-stopping callback; batches
# are produced by 8 worker processes.
# NOTE(review): the test generator doubles as validation data here, so whatever
# the callback monitors on validation data comes from the same participants that
# are later used for the test report.
# TODO: Create val_gen (use 6th participant?)
fit = model.fit(
    train_gen,
    validation_data=test_gen,
    epochs=20,
    callbacks=[earlyStopping_cb],
    workers=8,
    use_multiprocessing=True,
)

# Variant without early stopping: the same call minus `callbacks` (and without
# storing the result) always runs all 20 epochs.

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20


In [38]:
# Run before re-training to clear up VRAM.
# NOTE: this removes `model` from the notebook namespace, so do not run it
# between training and evaluation — the evaluation cell calls `model.predict`.
import gc

# Drop the notebook's reference first; collecting while `model` is still bound
# cannot reclaim it. Popping from globals() (the notebook namespace) instead of
# `del model` keeps the cell re-runnable when the model is already gone.
globals().pop("model", None)

# Reset Keras' global state, then collect what has become unreachable.
tf.keras.backend.clear_session()
gc.collect()

# NOTE(review): the History object stored in `fit` may still reference the
# model; if memory is not released, delete `fit` as well — confirm.

In [30]:
# Display the class names; the evaluation cell indexes this list with the argmax
# of the model output to turn predicted class indices into label strings.
test_gen.cat_labels

['confirmation', 'decision', 'encoding', 'response']

In [41]:
print("Testset results")

# Model output for every test sample -> index of the highest-scoring class.
class_scores = model.predict(test_gen)
predicted_indices = np.argmax(class_scores, axis=1)

# Translate indices into label names so they can be compared with full_labels.
# NOTE(review): assumes predictions come back in the same order as
# test_gen.full_labels — confirm the generator does not reshuffle.
predicted_classes = [test_gen.cat_labels[i] for i in predicted_indices]

report = classification_report(test_gen.full_labels, predicted_classes)
print(report)
# print(test_gen.categories)

Testset results
              precision    recall  f1-score   support

confirmation       0.73      0.63      0.68       441
    decision       0.72      0.81      0.76       844
    encoding       0.88      0.89      0.89       869
    response       0.84      0.80      0.82       870

    accuracy                           0.80      3024
   macro avg       0.80      0.78      0.79      3024
weighted avg       0.81      0.80      0.80      3024

